1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
28
29namespace tflite {
30namespace ops {
31namespace builtin {
32namespace dynamic_update_slice {
33
// Positions of the op's tensors within the node's input list.
constexpr int kOperandTensor{0};       // Tensor that receives the update.
constexpr int kUpdateTensor{1};        // Slice written into the operand.
constexpr int kStartIndicesTensor{2};  // Per-dimension start offsets.

// Position of the result within the node's output list.
constexpr int kOutputTensor{0};
38
39// TFLite DynamicUpdateSlice op follows the semantics of XLA DynamicUpdateSlice
40// op. See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice
41// for details.
42TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
43 const TfLiteTensor* operand;
44 TF_LITE_ENSURE_OK(context,
45 GetInputSafe(context, node, kOperandTensor, &operand));
46 const TfLiteTensor* update;
47 TF_LITE_ENSURE_OK(context,
48 GetInputSafe(context, node, kUpdateTensor, &update));
49 const TfLiteTensor* start_indices;
50 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartIndicesTensor,
51 &start_indices));
52 TfLiteTensor* output;
53 TF_LITE_ENSURE_OK(context,
54 GetOutputSafe(context, node, kOutputTensor, &output));
55
56 // The shape of start_indices must be rank == 1, with dimension size equal to
57 // the rank of operand.
58 TF_LITE_ENSURE(context, NumDimensions(start_indices) == 1);
59 TF_LITE_ENSURE(context,
60 SizeOfDimension(start_indices, 0) == NumDimensions(operand));
61
62 // Update must be less than or equal to the operand size for each dimension to
63 // avoid generating out-of-bounds update indices.
64 TF_LITE_ENSURE(context, NumDimensions(update) == NumDimensions(operand));
65 for (int i = 0; i < NumDimensions(operand); i++) {
66 TF_LITE_ENSURE(context,
67 SizeOfDimension(update, i) <= SizeOfDimension(operand, i));
68 }
69
70 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
71 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
72 TF_LITE_ENSURE_TYPES_EQ(context, operand->type, update->type);
73 TF_LITE_ENSURE_TYPES_EQ(context, start_indices->type, kTfLiteInt32);
74
75 output->type = operand->type;
76 TfLiteIntArray* output_size = TfLiteIntArrayCopy(operand->dims);
77 return context->ResizeTensor(context, output, output_size);
78}
79
80// A helper function that converts a tensor index into a flat array index.
81// Takes `start_indices` as an offset if not null.
82int TensorIndexToFlat(const int* index, const int dims,
83 const RuntimeShape& shape,
84 const int* start_indices = nullptr) {
85 int flat_index = index[0] + (start_indices ? start_indices[0] : 0);
86 for (int i = 1; i < dims; i++) {
87 flat_index = flat_index * shape.Dims(i) + index[i] +
88 (start_indices ? start_indices[i] : 0);
89 }
90 return flat_index;
91}
92
93// A helper function to compute the clamped start indices to ensure they are
94// not out of bounds.
95std::vector<int> ClampStartIndices(int input_dims, const int32_t* indices_data,
96 const RuntimeShape& input_shape,
97 const RuntimeShape& update_shape) {
98 std::vector<int> clamped_start_indices(input_dims, 0);
99 for (int i = 0; i < input_dims; i++) {
100 clamped_start_indices[i] =
101 std::min(std::max(0, indices_data[i]),
102 input_shape.Dims(i) - update_shape.Dims(i));
103 }
104 return clamped_start_indices;
105}
106
107template <typename T>
108void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update,
109 const TfLiteTensor* indice, TfLiteTensor* output) {
110 const auto& input_shape = GetTensorShape(input);
111 const auto& update_shape = GetTensorShape(update);
112 const T* update_data = GetTensorData<T>(update);
113 const int32_t* indices_data = GetTensorData<int32_t>(indice);
114 T* output_data = GetTensorData<T>(output);
115
116 const int input_dims = input_shape.DimensionsCount();
117 // Computes the effective slice indices.
118 // The clamped indices are gauranteed to >= 0 since update is less than or
119 // equal to the operand size for each dimension.
120 std::vector<int> clamped_start_indices =
121 ClampStartIndices(input_dims, indices_data, input_shape, update_shape);
122
123 // Copies input to output first.
124 memcpy(output->data.raw, input->data.raw, input->bytes);
125
126 std::vector<int> current_dim(input_dims, 0);
127 // Overwrites update to output.
128 do {
129 int flat_update_index =
130 TensorIndexToFlat(current_dim.data(), input_dims, update_shape);
131 int flat_input_index =
132 TensorIndexToFlat(current_dim.data(), input_dims, input_shape,
133 clamped_start_indices.data());
134 output_data[flat_input_index] = update_data[flat_update_index];
135 } while (NextIndex(input_dims,
136 reinterpret_cast<const int*>(update_shape.DimsData()),
137 current_dim.data()));
138}
139
140TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
141 const TfLiteTensor* operand;
142 TF_LITE_ENSURE_OK(context,
143 GetInputSafe(context, node, kOperandTensor, &operand));
144 const TfLiteTensor* update;
145 TF_LITE_ENSURE_OK(context,
146 GetInputSafe(context, node, kUpdateTensor, &update));
147 const TfLiteTensor* indice;
148 TF_LITE_ENSURE_OK(context,
149 GetInputSafe(context, node, kStartIndicesTensor, &indice));
150 TfLiteTensor* output;
151 TF_LITE_ENSURE_OK(context,
152 GetOutputSafe(context, node, kOutputTensor, &output));
153
154 switch (operand->type) {
155 case kTfLiteFloat32:
156 DynamicUpdateSlice<float>(operand, update, indice, output);
157 break;
158 case kTfLiteBool:
159 DynamicUpdateSlice<bool>(operand, update, indice, output);
160 break;
161 case kTfLiteInt8:
162 DynamicUpdateSlice<int8_t>(operand, update, indice, output);
163 break;
164 case kTfLiteInt32:
165 DynamicUpdateSlice<int32_t>(operand, update, indice, output);
166 break;
167 case kTfLiteInt64:
168 DynamicUpdateSlice<int64_t>(operand, update, indice, output);
169 break;
170 default:
171 TF_LITE_KERNEL_LOG(context,
172 "DynamicUpdateSlice only currently supports "
173 "1-bit/8-bit/32-bit/64-bit integer or "
174 "float type, got %d.",
175 operand->type);
176 return kTfLiteError;
177 }
178
179 return kTfLiteOk;
180}
181} // namespace dynamic_update_slice
182
183TfLiteRegistration* Register_DYNAMIC_UPDATE_SLICE() {
184 static TfLiteRegistration r = {/*init=*/nullptr,
185 /*free=*/nullptr,
186 dynamic_update_slice::Prepare,
187 dynamic_update_slice::Eval};
188 return &r;
189}
190
191} // namespace builtin
192} // namespace ops
193} // namespace tflite
194