1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <stdint.h>
17
18#include "tensorflow/lite/c/common.h"
19#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21#include "tensorflow/lite/kernels/internal/tensor.h"
22#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23#include "tensorflow/lite/kernels/internal/types.h"
24#include "tensorflow/lite/kernels/kernel_util.h"
25
26namespace tflite {
27namespace ops {
28namespace builtin {
29namespace scatter_nd {
// Positions of this kernel's input tensors on the node: the scatter indices,
// the values to scatter, and the 1-D tensor holding the requested output shape.
constexpr int kIndices{0};
constexpr int kUpdates{1};
constexpr int kShape{2};
// Position of the single output tensor on the node.
constexpr int kOutputTensor{0};
34
35template <typename IndicesT>
36TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
37 const TfLiteTensor* shape,
38 TfLiteTensor* output) {
39 const int shape_rank = SizeOfDimension(shape, 0);
40 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape_rank);
41 const auto* shape_data = GetTensorData<IndicesT>(shape);
42
43 for (int i = 0; i < shape_rank; i++) {
44 output_shape->data[i] = shape_data[i];
45 }
46 return context->ResizeTensor(context, output, output_shape);
47}
48
49template <typename IndicesT>
50TfLiteStatus CheckShapes(TfLiteContext* context, const RuntimeShape& indices,
51 const RuntimeShape& updates,
52 const RuntimeShape& shape_shape,
53 const IndicesT* shape_data) {
54 TF_LITE_ENSURE(context, (indices.DimensionsCount() >= 1) &&
55 (updates.DimensionsCount() >= 1) &&
56 (shape_shape.DimensionsCount() == 1));
57
58 const int outer_dims = indices.DimensionsCount() - 1;
59 for (int i = 0; i < outer_dims; ++i) {
60 TF_LITE_ENSURE_EQ(context, indices.Dims(i), updates.Dims(i));
61 }
62
63 const int ix = indices.Dims(outer_dims);
64 TF_LITE_ENSURE_EQ(context, updates.DimensionsCount() - outer_dims,
65 shape_shape.Dims(0) - ix);
66 for (int i = 0; i + outer_dims < updates.DimensionsCount(); ++i) {
67 TF_LITE_ENSURE_EQ(context, updates.Dims(i + outer_dims),
68 shape_data[ix + i]);
69 }
70 return kTfLiteOk;
71}
72
73TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
74 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
75 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
76
77 const TfLiteTensor* indices;
78 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
79 const TfLiteTensor* updates;
80 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
81 const TfLiteTensor* shape;
82 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));
83
84 switch (updates->type) {
85 case kTfLiteFloat32:
86 case kTfLiteUInt8:
87 case kTfLiteBool:
88 case kTfLiteInt8:
89 case kTfLiteInt64:
90 case kTfLiteInt32:
91 break;
92 default:
93 TF_LITE_KERNEL_LOG(
94 context, "Updates of type '%s' are not supported by scatter_nd.",
95 TfLiteTypeGetName(updates->type));
96 return kTfLiteError;
97 }
98 if (indices->type != shape->type) {
99 TF_LITE_KERNEL_LOG(context, "Indices and shape must have the same type.");
100 return kTfLiteError;
101 }
102
103 TfLiteTensor* output;
104 TF_LITE_ENSURE_OK(context,
105 GetOutputSafe(context, node, kOutputTensor, &output));
106 output->type = updates->type;
107
108 if (IsConstantTensor(shape)) {
109 switch (indices->type) {
110 case kTfLiteInt32:
111 TF_LITE_ENSURE_OK(
112 context,
113 CheckShapes<int32_t>(context, GetTensorShape(indices),
114 GetTensorShape(updates), GetTensorShape(shape),
115 GetTensorData<int32_t>(shape)));
116 return ResizeOutputTensor<int32_t>(context, shape, output);
117 default:
118 TF_LITE_KERNEL_LOG(
119 context, "Indices of type '%s' are not supported by scatter_nd.",
120 TfLiteTypeGetName(indices->type));
121 return kTfLiteError;
122 }
123 } else {
124 SetTensorToDynamic(output);
125 return kTfLiteOk;
126 }
127}
128
129template <typename IndicesT, typename UpdatesT>
130TfLiteStatus ScatterNd(const TfLiteTensor* indices, const TfLiteTensor* updates,
131 TfLiteTensor* output) {
132 return reference_ops::ScatterNd(
133 GetTensorShape(indices), GetTensorData<IndicesT>(indices),
134 GetTensorShape(updates), GetTensorData<UpdatesT>(updates),
135 GetTensorShape(output), GetTensorData<UpdatesT>(output));
136}
137
138template <typename IndicesT>
139TfLiteStatus EvalScatterNd(TfLiteContext* context, const TfLiteTensor* indices,
140 const TfLiteTensor* updates,
141 const TfLiteTensor* shape, TfLiteTensor* output) {
142 if (IsDynamicTensor(output)) {
143 TF_LITE_ENSURE_OK(
144 context, CheckShapes<IndicesT>(
145 context, GetTensorShape(indices), GetTensorShape(updates),
146 GetTensorShape(shape), GetTensorData<IndicesT>(shape)));
147 TF_LITE_ENSURE_OK(context,
148 ResizeOutputTensor<IndicesT>(context, shape, output));
149 }
150
151 TfLiteStatus status = kTfLiteError;
152 switch (updates->type) {
153 case kTfLiteFloat32:
154 status = ScatterNd<IndicesT, float>(indices, updates, output);
155 break;
156 case kTfLiteUInt8:
157 status = ScatterNd<IndicesT, uint8_t>(indices, updates, output);
158 break;
159 case kTfLiteBool:
160 status = ScatterNd<IndicesT, bool>(indices, updates, output);
161 break;
162 case kTfLiteInt8:
163 status = ScatterNd<IndicesT, int8_t>(indices, updates, output);
164 break;
165 case kTfLiteInt32:
166 status = ScatterNd<IndicesT, int32_t>(indices, updates, output);
167 break;
168 case kTfLiteInt64:
169 status = ScatterNd<IndicesT, int64_t>(indices, updates, output);
170 break;
171 default:
172 TF_LITE_KERNEL_LOG(
173 context, "Updates of type '%s' are not supported by scatter_nd.",
174 TfLiteTypeGetName(updates->type));
175 return kTfLiteError;
176 }
177 if (status != kTfLiteOk) {
178 TF_LITE_KERNEL_LOG(context, "scatter_nd index out of bounds");
179 }
180 return status;
181}
182
183TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
184 const TfLiteTensor* indices;
185 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
186 const TfLiteTensor* updates;
187 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
188 const TfLiteTensor* shape;
189 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));
190 TfLiteTensor* output;
191 TF_LITE_ENSURE_OK(context,
192 GetOutputSafe(context, node, kOutputTensor, &output));
193
194 switch (indices->type) {
195 case kTfLiteInt32:
196 return EvalScatterNd<int32_t>(context, indices, updates, shape, output);
197 default:
198 TF_LITE_KERNEL_LOG(
199 context, "Indices of type '%s' are not supported by scatter_nd.",
200 TfLiteTypeGetName(indices->type));
201 return kTfLiteError;
202 }
203}
204
205} // namespace scatter_nd
206
207TfLiteRegistration* Register_SCATTER_ND() {
208 static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
209 scatter_nd::Prepare, scatter_nd::Eval};
210 return &r;
211}
212} // namespace builtin
213} // namespace ops
214} // namespace tflite
215