1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include <stdint.h>

#include <limits>
#include <vector>
18
19#include "tensorflow/lite/c/common.h"
20#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21#include "tensorflow/lite/kernels/internal/tensor.h"
22#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23#include "tensorflow/lite/kernels/kernel_util.h"
24#include "tensorflow/lite/kernels/op_macros.h"
25
26namespace tflite {
27namespace ops {
28namespace builtin {
29namespace sparse_to_dense {
30
// Positions of the op's inputs in the node's input list.
constexpr int kIndicesTensor{0};       // Sparse indices (int32 or int64).
constexpr int kOutputShapeTensor{1};   // 1-D dense output shape.
constexpr int kValueInputTensor{2};    // Value(s) written at the indices.
constexpr int kDefaultValueTensor{3};  // Single fill value for the rest.

// Position of the op's only output.
constexpr int kOutputTensor{0};

// Number of coordinates every index is padded to before it is handed to
// reference_ops::SparseToDense; larger index ranks are rejected.
constexpr int kMaxDimensions{4};
38
39template <typename T>
40TfLiteStatus Resize(TfLiteContext* context, const TfLiteTensor* output_shape,
41 TfLiteTensor* output) {
42 const int output_dimensions = NumElements(output_shape);
43 TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions);
44 for (int i = 0; i < output_dimensions; ++i) {
45 output_shape_array->data[i] = GetTensorData<T>(output_shape)[i];
46 }
47
48 return context->ResizeTensor(context, output, output_shape_array);
49}
50
51TfLiteStatus CheckDimensionsMatch(TfLiteContext* context,
52 const TfLiteTensor* indices,
53 const TfLiteTensor* output_shape,
54 const TfLiteTensor* values) {
55 switch (NumDimensions(indices)) {
56 case 0:
57 case 1: {
58 if (NumDimensions(values) == 0) {
59 TF_LITE_ENSURE_EQ(context, NumElements(indices), NumElements(values));
60 }
61 TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 1);
62 break;
63 }
64 case 2: {
65 TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 1),
66 NumElements(output_shape));
67 if (NumDimensions(values) == 0)
68 TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
69 NumElements(values));
70 break;
71 }
72 default:
73 TF_LITE_KERNEL_LOG(context,
74 "Wrong indices dimensions %d, should be less than 3.",
75 NumDimensions(indices));
76 return kTfLiteError;
77 }
78 return kTfLiteOk;
79}
80
81// Convert indices into a vector of 4-d vectors.
82// TODO(renjieliu): Revisit here to improve the performance, since multiple
83// allocations of std::vectors will be quite slow on phones.
84template <typename T>
85TfLiteStatus GetIndicesVector(TfLiteContext* context,
86 const TfLiteTensor* indices,
87 const int num_indices,
88 std::vector<std::vector<T>>* indices_vector) {
89 // Note because TfLite will reverse the dimensions, so pad zeros upfront.
90 switch (NumDimensions(indices)) {
91 case 0:
92 case 1: {
93 const auto indices_data = GetTensorData<T>(indices);
94 for (int i = 0; i < num_indices; ++i) {
95 std::vector<T> index({0, 0, 0, indices_data[i]});
96 indices_vector->push_back(index);
97 }
98 break;
99 }
100 case 2: {
101 const int true_dimensions = SizeOfDimension(indices, 1);
102 TF_LITE_ENSURE(context, true_dimensions <= kMaxDimensions);
103 for (int i = 0; i < num_indices; ++i) {
104 std::vector<T> index;
105 index.reserve(kMaxDimensions);
106 // Fill the index with 1 up to kMaxDimensions - true_dimensions to
107 // satisfy the needs for 4-dimension index.
108 for (int j = 0; j < kMaxDimensions - true_dimensions; ++j) {
109 index.push_back(0);
110 }
111 for (int j = 0; j < true_dimensions; ++j) {
112 index.push_back(GetTensorData<T>(indices)[i * true_dimensions + j]);
113 }
114
115 indices_vector->push_back(index);
116 }
117 break;
118 }
119 default:
120 TF_LITE_KERNEL_LOG(context,
121 "Indices dimensions problem, got %d dimensions",
122 NumDimensions(indices));
123 return kTfLiteError;
124 }
125 return kTfLiteOk;
126}
127
128TfLiteStatus ResizeOutputShape(TfLiteContext* context,
129 const TfLiteTensor* output_shape,
130 TfLiteTensor* output) {
131 if (output_shape->type == kTfLiteInt32) {
132 return Resize<int32_t>(context, output_shape, output);
133 } else if (output_shape->type == kTfLiteInt64) {
134 return Resize<int64_t>(context, output_shape, output);
135 } else {
136 TF_LITE_KERNEL_LOG(context, "Dense shape type %d not supported.",
137 output_shape->type);
138 return kTfLiteError;
139 }
140}
141
142TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
143 TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
144 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
145
146 const TfLiteTensor* indices;
147 TF_LITE_ENSURE_OK(context,
148 GetInputSafe(context, node, kIndicesTensor, &indices));
149 const TfLiteTensor* output_shape;
150 TF_LITE_ENSURE_OK(
151 context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
152 const TfLiteTensor* values;
153 TF_LITE_ENSURE_OK(context,
154 GetInputSafe(context, node, kValueInputTensor, &values));
155 const TfLiteTensor* default_value;
156 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
157 &default_value));
158
159 // TODO(renjieliu): Handle validate_indices.
160
161 // Indices can be 0-D, 1-D or 2-D.
162 TF_LITE_ASSERT(NumDimensions(indices) >= 0);
163 TF_LITE_ENSURE(context, NumDimensions(indices) < 3);
164 TF_LITE_ASSERT(NumDimensions(output_shape) >= 0);
165 TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
166 // Values can be 0-D or 1-D.
167 TF_LITE_ASSERT(NumDimensions(values) >= 0);
168 TF_LITE_ENSURE(context, NumDimensions(values) < 2);
169
170 TF_LITE_ENSURE_EQ(context, NumElements(default_value), 1);
171
172 TF_LITE_ENSURE(
173 context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64);
174 TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 ||
175 output_shape->type == kTfLiteInt64);
176 TF_LITE_ENSURE(context, values->type == kTfLiteInt32 ||
177 values->type == kTfLiteInt64 ||
178 values->type == kTfLiteInt8 ||
179 values->type == kTfLiteUInt8 ||
180 values->type == kTfLiteFloat32);
181 TF_LITE_ENSURE_TYPES_EQ(context, values->type, default_value->type);
182
183 // Ensure dimensions match.
184 TF_LITE_ENSURE_OK(
185 context, CheckDimensionsMatch(context, indices, output_shape, values));
186
187 TfLiteTensor* output;
188 TF_LITE_ENSURE_OK(context,
189 GetOutputSafe(context, node, kOutputTensor, &output));
190 output->type = values->type;
191 TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
192
193 if (!IsConstantOrPersistentTensor(output_shape)) {
194 SetTensorToDynamic(output);
195 return kTfLiteOk;
196 }
197 return ResizeOutputShape(context, output_shape, output);
198}
199
200template <typename T, typename TI>
201TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
202 const TfLiteTensor* indices;
203 TF_LITE_ENSURE_OK(context,
204 GetInputSafe(context, node, kIndicesTensor, &indices));
205 const TfLiteTensor* output_shape;
206 TF_LITE_ENSURE_OK(
207 context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape));
208 const TfLiteTensor* values;
209 TF_LITE_ENSURE_OK(context,
210 GetInputSafe(context, node, kValueInputTensor, &values));
211 const TfLiteTensor* default_value;
212 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor,
213 &default_value));
214 TfLiteTensor* output;
215 TF_LITE_ENSURE_OK(context,
216 GetOutputSafe(context, node, kOutputTensor, &output));
217
218 if (IsDynamicTensor(output)) {
219 TF_LITE_ENSURE_OK(context,
220 ResizeOutputShape(context, output_shape, output));
221 }
222
223 const int num_indices = SizeOfDimension(indices, 0);
224 const bool value_is_scalar = NumDimensions(values) == 0;
225 std::vector<std::vector<TI>> indices_vector;
226 indices_vector.reserve(num_indices);
227 TF_LITE_ENSURE_OK(context, GetIndicesVector<TI>(context, indices, num_indices,
228 &indices_vector));
229 reference_ops::SparseToDense(indices_vector, GetTensorData<T>(values),
230 *GetTensorData<T>(default_value),
231 value_is_scalar, GetTensorShape(output),
232 GetTensorData<T>(output));
233
234 return kTfLiteOk;
235}
236
237template <typename T>
238TfLiteStatus EvalForIndexType(TfLiteContext* context, TfLiteNode* node,
239 const TfLiteTensor* indices) {
240 switch (indices->type) {
241 case kTfLiteInt32: {
242 return SparseToDenseImpl<T, int32_t>(context, node);
243 }
244 case kTfLiteInt64: {
245 return SparseToDenseImpl<T, int64_t>(context, node);
246 }
247 default:
248 TF_LITE_KERNEL_LOG(
249 context,
250 "Indice type %s is currently not supported by sparse to dense.",
251 TfLiteTypeGetName(indices->type));
252 return kTfLiteError;
253 }
254}
255
256TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
257 const TfLiteTensor* indices;
258 TF_LITE_ENSURE_OK(context,
259 GetInputSafe(context, node, kIndicesTensor, &indices));
260 const TfLiteTensor* values;
261 TF_LITE_ENSURE_OK(context,
262 GetInputSafe(context, node, kValueInputTensor, &values));
263
264 switch (values->type) {
265 case kTfLiteFloat32:
266 return EvalForIndexType<float>(context, node, indices);
267 case kTfLiteInt32:
268 return EvalForIndexType<int32_t>(context, node, indices);
269 case kTfLiteInt64:
270 return EvalForIndexType<int64_t>(context, node, indices);
271 case kTfLiteInt8:
272 return EvalForIndexType<int8_t>(context, node, indices);
273 case kTfLiteUInt8:
274 return EvalForIndexType<uint8_t>(context, node, indices);
275 default:
276 TF_LITE_KERNEL_LOG(
277 context,
278 "Value type %s is currently not supported by sparse to dense.",
279 TfLiteTypeGetName(values->type));
280 return kTfLiteError;
281 }
282}
283
284} // namespace sparse_to_dense
285
286TfLiteRegistration* Register_SPARSE_TO_DENSE() {
287 static TfLiteRegistration r = {nullptr, nullptr, sparse_to_dense::Prepare,
288 sparse_to_dense::Eval};
289 return &r;
290}
291
292} // namespace builtin
293} // namespace ops
294} // namespace tflite
295