1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
28 | |
29 | namespace tflite { |
30 | namespace ops { |
31 | namespace builtin { |
32 | namespace dynamic_update_slice { |
33 | |
// Input tensor indices: the operand to be updated, the update slice, and the
// start indices (a rank-1 int32 tensor with one entry per operand dimension,
// as validated in Prepare).
constexpr int kOperandTensor = 0;
constexpr int kUpdateTensor = 1;
constexpr int kStartIndicesTensor = 2;
// Index of the single output tensor (same type/shape as the operand).
constexpr int kOutputTensor = 0;
38 | |
39 | // TFLite DynamicUpdateSlice op follows the semantics of XLA DynamicUpdateSlice |
40 | // op. See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice |
41 | // for details. |
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Looks up the operand, update, and start-indices inputs plus the output;
  // each Get*Safe call returns an error status on a missing/out-of-range
  // tensor index.
  const TfLiteTensor* operand;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kOperandTensor, &operand));
  const TfLiteTensor* update;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kUpdateTensor, &update));
  const TfLiteTensor* start_indices;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartIndicesTensor,
                                          &start_indices));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // The shape of start_indices must be rank == 1, with dimension size equal to
  // the rank of operand.
  TF_LITE_ENSURE(context, NumDimensions(start_indices) == 1);
  TF_LITE_ENSURE(context,
                 SizeOfDimension(start_indices, 0) == NumDimensions(operand));

  // Update must be less than or equal to the operand size for each dimension
  // to avoid generating out-of-bounds update indices; this is what lets Eval's
  // index clamping assume a non-negative upper bound.
  TF_LITE_ENSURE(context, NumDimensions(update) == NumDimensions(operand));
  for (int i = 0; i < NumDimensions(operand); i++) {
    TF_LITE_ENSURE(context,
                   SizeOfDimension(update, i) <= SizeOfDimension(operand, i));
  }

  // Arity and type validation: exactly 3 inputs and 1 output; operand and
  // update share an element type; start indices are int32.
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  TF_LITE_ENSURE_TYPES_EQ(context, operand->type, update->type);
  TF_LITE_ENSURE_TYPES_EQ(context, start_indices->type, kTfLiteInt32);

  // The output always mirrors the operand's type and shape.
  output->type = operand->type;
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(operand->dims);
  return context->ResizeTensor(context, output, output_size);
}
79 | |
80 | // A helper function that converts a tensor index into a flat array index. |
81 | // Takes `start_indices` as an offset if not null. |
82 | int TensorIndexToFlat(const int* index, const int dims, |
83 | const RuntimeShape& shape, |
84 | const int* start_indices = nullptr) { |
85 | int flat_index = index[0] + (start_indices ? start_indices[0] : 0); |
86 | for (int i = 1; i < dims; i++) { |
87 | flat_index = flat_index * shape.Dims(i) + index[i] + |
88 | (start_indices ? start_indices[i] : 0); |
89 | } |
90 | return flat_index; |
91 | } |
92 | |
93 | // A helper function to compute the clamped start indices to ensure they are |
94 | // not out of bounds. |
95 | std::vector<int> ClampStartIndices(int input_dims, const int32_t* indices_data, |
96 | const RuntimeShape& input_shape, |
97 | const RuntimeShape& update_shape) { |
98 | std::vector<int> clamped_start_indices(input_dims, 0); |
99 | for (int i = 0; i < input_dims; i++) { |
100 | clamped_start_indices[i] = |
101 | std::min(std::max(0, indices_data[i]), |
102 | input_shape.Dims(i) - update_shape.Dims(i)); |
103 | } |
104 | return clamped_start_indices; |
105 | } |
106 | |
107 | template <typename T> |
108 | void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update, |
109 | const TfLiteTensor* indice, TfLiteTensor* output) { |
110 | const auto& input_shape = GetTensorShape(input); |
111 | const auto& update_shape = GetTensorShape(update); |
112 | const T* update_data = GetTensorData<T>(update); |
113 | const int32_t* indices_data = GetTensorData<int32_t>(indice); |
114 | T* output_data = GetTensorData<T>(output); |
115 | |
116 | const int input_dims = input_shape.DimensionsCount(); |
117 | // Computes the effective slice indices. |
118 | // The clamped indices are gauranteed to >= 0 since update is less than or |
119 | // equal to the operand size for each dimension. |
120 | std::vector<int> clamped_start_indices = |
121 | ClampStartIndices(input_dims, indices_data, input_shape, update_shape); |
122 | |
123 | // Copies input to output first. |
124 | memcpy(output->data.raw, input->data.raw, input->bytes); |
125 | |
126 | std::vector<int> current_dim(input_dims, 0); |
127 | // Overwrites update to output. |
128 | do { |
129 | int flat_update_index = |
130 | TensorIndexToFlat(current_dim.data(), input_dims, update_shape); |
131 | int flat_input_index = |
132 | TensorIndexToFlat(current_dim.data(), input_dims, input_shape, |
133 | clamped_start_indices.data()); |
134 | output_data[flat_input_index] = update_data[flat_update_index]; |
135 | } while (NextIndex(input_dims, |
136 | reinterpret_cast<const int*>(update_shape.DimsData()), |
137 | current_dim.data())); |
138 | } |
139 | |
140 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
141 | const TfLiteTensor* operand; |
142 | TF_LITE_ENSURE_OK(context, |
143 | GetInputSafe(context, node, kOperandTensor, &operand)); |
144 | const TfLiteTensor* update; |
145 | TF_LITE_ENSURE_OK(context, |
146 | GetInputSafe(context, node, kUpdateTensor, &update)); |
147 | const TfLiteTensor* indice; |
148 | TF_LITE_ENSURE_OK(context, |
149 | GetInputSafe(context, node, kStartIndicesTensor, &indice)); |
150 | TfLiteTensor* output; |
151 | TF_LITE_ENSURE_OK(context, |
152 | GetOutputSafe(context, node, kOutputTensor, &output)); |
153 | |
154 | switch (operand->type) { |
155 | case kTfLiteFloat32: |
156 | DynamicUpdateSlice<float>(operand, update, indice, output); |
157 | break; |
158 | case kTfLiteBool: |
159 | DynamicUpdateSlice<bool>(operand, update, indice, output); |
160 | break; |
161 | case kTfLiteInt8: |
162 | DynamicUpdateSlice<int8_t>(operand, update, indice, output); |
163 | break; |
164 | case kTfLiteInt32: |
165 | DynamicUpdateSlice<int32_t>(operand, update, indice, output); |
166 | break; |
167 | case kTfLiteInt64: |
168 | DynamicUpdateSlice<int64_t>(operand, update, indice, output); |
169 | break; |
170 | default: |
171 | TF_LITE_KERNEL_LOG(context, |
172 | "DynamicUpdateSlice only currently supports " |
173 | "1-bit/8-bit/32-bit/64-bit integer or " |
174 | "float type, got %d." , |
175 | operand->type); |
176 | return kTfLiteError; |
177 | } |
178 | |
179 | return kTfLiteOk; |
180 | } |
181 | } // namespace dynamic_update_slice |
182 | |
183 | TfLiteRegistration* Register_DYNAMIC_UPDATE_SLICE() { |
184 | static TfLiteRegistration r = {/*init=*/nullptr, |
185 | /*free=*/nullptr, |
186 | dynamic_update_slice::Prepare, |
187 | dynamic_update_slice::Eval}; |
188 | return &r; |
189 | } |
190 | |
191 | } // namespace builtin |
192 | } // namespace ops |
193 | } // namespace tflite |
194 | |