1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | |
17 | #include <vector> |
18 | |
19 | #include "tensorflow/lite/c/common.h" |
20 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
21 | #include "tensorflow/lite/kernels/internal/tensor.h" |
22 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
23 | #include "tensorflow/lite/kernels/kernel_util.h" |
24 | #include "tensorflow/lite/kernels/op_macros.h" |
25 | |
26 | namespace tflite { |
27 | namespace ops { |
28 | namespace builtin { |
29 | namespace sparse_to_dense { |
30 | |
// Input/output tensor indices, matching the TF SparseToDense op signature.
constexpr int kIndicesTensor = 0;       // sparse element coordinates
constexpr int kOutputShapeTensor = 1;   // 1-D dense output shape
constexpr int kValueInputTensor = 2;    // sparse values (scalar or 1-D)
constexpr int kDefaultValueTensor = 3;  // scalar fill value for uncovered cells
constexpr int kOutputTensor = 0;

// Indices are zero-padded up to this rank before scattering (see
// GetIndicesVector below).
constexpr int kMaxDimensions = 4;
38 | |
39 | template <typename T> |
40 | TfLiteStatus Resize(TfLiteContext* context, const TfLiteTensor* output_shape, |
41 | TfLiteTensor* output) { |
42 | const int output_dimensions = NumElements(output_shape); |
43 | TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions); |
44 | for (int i = 0; i < output_dimensions; ++i) { |
45 | output_shape_array->data[i] = GetTensorData<T>(output_shape)[i]; |
46 | } |
47 | |
48 | return context->ResizeTensor(context, output, output_shape_array); |
49 | } |
50 | |
51 | TfLiteStatus CheckDimensionsMatch(TfLiteContext* context, |
52 | const TfLiteTensor* indices, |
53 | const TfLiteTensor* output_shape, |
54 | const TfLiteTensor* values) { |
55 | switch (NumDimensions(indices)) { |
56 | case 0: |
57 | case 1: { |
58 | if (NumDimensions(values) == 0) { |
59 | TF_LITE_ENSURE_EQ(context, NumElements(indices), NumElements(values)); |
60 | } |
61 | TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 1); |
62 | break; |
63 | } |
64 | case 2: { |
65 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 1), |
66 | NumElements(output_shape)); |
67 | if (NumDimensions(values) == 0) |
68 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), |
69 | NumElements(values)); |
70 | break; |
71 | } |
72 | default: |
73 | TF_LITE_KERNEL_LOG(context, |
74 | "Wrong indices dimensions %d, should be less than 3." , |
75 | NumDimensions(indices)); |
76 | return kTfLiteError; |
77 | } |
78 | return kTfLiteOk; |
79 | } |
80 | |
81 | // Convert indices into a vector of 4-d vectors. |
82 | // TODO(renjieliu): Revisit here to improve the performance, since multiple |
83 | // allocations of std::vectors will be quite slow on phones. |
84 | template <typename T> |
85 | TfLiteStatus GetIndicesVector(TfLiteContext* context, |
86 | const TfLiteTensor* indices, |
87 | const int num_indices, |
88 | std::vector<std::vector<T>>* indices_vector) { |
89 | // Note because TfLite will reverse the dimensions, so pad zeros upfront. |
90 | switch (NumDimensions(indices)) { |
91 | case 0: |
92 | case 1: { |
93 | const auto indices_data = GetTensorData<T>(indices); |
94 | for (int i = 0; i < num_indices; ++i) { |
95 | std::vector<T> index({0, 0, 0, indices_data[i]}); |
96 | indices_vector->push_back(index); |
97 | } |
98 | break; |
99 | } |
100 | case 2: { |
101 | const int true_dimensions = SizeOfDimension(indices, 1); |
102 | TF_LITE_ENSURE(context, true_dimensions <= kMaxDimensions); |
103 | for (int i = 0; i < num_indices; ++i) { |
104 | std::vector<T> index; |
105 | index.reserve(kMaxDimensions); |
106 | // Fill the index with 1 up to kMaxDimensions - true_dimensions to |
107 | // satisfy the needs for 4-dimension index. |
108 | for (int j = 0; j < kMaxDimensions - true_dimensions; ++j) { |
109 | index.push_back(0); |
110 | } |
111 | for (int j = 0; j < true_dimensions; ++j) { |
112 | index.push_back(GetTensorData<T>(indices)[i * true_dimensions + j]); |
113 | } |
114 | |
115 | indices_vector->push_back(index); |
116 | } |
117 | break; |
118 | } |
119 | default: |
120 | TF_LITE_KERNEL_LOG(context, |
121 | "Indices dimensions problem, got %d dimensions" , |
122 | NumDimensions(indices)); |
123 | return kTfLiteError; |
124 | } |
125 | return kTfLiteOk; |
126 | } |
127 | |
128 | TfLiteStatus ResizeOutputShape(TfLiteContext* context, |
129 | const TfLiteTensor* output_shape, |
130 | TfLiteTensor* output) { |
131 | if (output_shape->type == kTfLiteInt32) { |
132 | return Resize<int32_t>(context, output_shape, output); |
133 | } else if (output_shape->type == kTfLiteInt64) { |
134 | return Resize<int64_t>(context, output_shape, output); |
135 | } else { |
136 | TF_LITE_KERNEL_LOG(context, "Dense shape type %d not supported." , |
137 | output_shape->type); |
138 | return kTfLiteError; |
139 | } |
140 | } |
141 | |
142 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
143 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); |
144 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
145 | |
146 | const TfLiteTensor* indices; |
147 | TF_LITE_ENSURE_OK(context, |
148 | GetInputSafe(context, node, kIndicesTensor, &indices)); |
149 | const TfLiteTensor* output_shape; |
150 | TF_LITE_ENSURE_OK( |
151 | context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape)); |
152 | const TfLiteTensor* values; |
153 | TF_LITE_ENSURE_OK(context, |
154 | GetInputSafe(context, node, kValueInputTensor, &values)); |
155 | const TfLiteTensor* default_value; |
156 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor, |
157 | &default_value)); |
158 | |
159 | // TODO(renjieliu): Handle validate_indices. |
160 | |
161 | // Indices can be 0-D, 1-D or 2-D. |
162 | TF_LITE_ASSERT(NumDimensions(indices) >= 0); |
163 | TF_LITE_ENSURE(context, NumDimensions(indices) < 3); |
164 | TF_LITE_ASSERT(NumDimensions(output_shape) >= 0); |
165 | TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); |
166 | // Values can be 0-D or 1-D. |
167 | TF_LITE_ASSERT(NumDimensions(values) >= 0); |
168 | TF_LITE_ENSURE(context, NumDimensions(values) < 2); |
169 | |
170 | TF_LITE_ENSURE_EQ(context, NumElements(default_value), 1); |
171 | |
172 | TF_LITE_ENSURE( |
173 | context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64); |
174 | TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 || |
175 | output_shape->type == kTfLiteInt64); |
176 | TF_LITE_ENSURE(context, values->type == kTfLiteInt32 || |
177 | values->type == kTfLiteInt64 || |
178 | values->type == kTfLiteInt8 || |
179 | values->type == kTfLiteUInt8 || |
180 | values->type == kTfLiteFloat32); |
181 | TF_LITE_ENSURE_TYPES_EQ(context, values->type, default_value->type); |
182 | |
183 | // Ensure dimensions match. |
184 | TF_LITE_ENSURE_OK( |
185 | context, CheckDimensionsMatch(context, indices, output_shape, values)); |
186 | |
187 | TfLiteTensor* output; |
188 | TF_LITE_ENSURE_OK(context, |
189 | GetOutputSafe(context, node, kOutputTensor, &output)); |
190 | output->type = values->type; |
191 | TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); |
192 | |
193 | if (!IsConstantOrPersistentTensor(output_shape)) { |
194 | SetTensorToDynamic(output); |
195 | return kTfLiteOk; |
196 | } |
197 | return ResizeOutputShape(context, output_shape, output); |
198 | } |
199 | |
200 | template <typename T, typename TI> |
201 | TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { |
202 | const TfLiteTensor* indices; |
203 | TF_LITE_ENSURE_OK(context, |
204 | GetInputSafe(context, node, kIndicesTensor, &indices)); |
205 | const TfLiteTensor* output_shape; |
206 | TF_LITE_ENSURE_OK( |
207 | context, GetInputSafe(context, node, kOutputShapeTensor, &output_shape)); |
208 | const TfLiteTensor* values; |
209 | TF_LITE_ENSURE_OK(context, |
210 | GetInputSafe(context, node, kValueInputTensor, &values)); |
211 | const TfLiteTensor* default_value; |
212 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDefaultValueTensor, |
213 | &default_value)); |
214 | TfLiteTensor* output; |
215 | TF_LITE_ENSURE_OK(context, |
216 | GetOutputSafe(context, node, kOutputTensor, &output)); |
217 | |
218 | if (IsDynamicTensor(output)) { |
219 | TF_LITE_ENSURE_OK(context, |
220 | ResizeOutputShape(context, output_shape, output)); |
221 | } |
222 | |
223 | const int num_indices = SizeOfDimension(indices, 0); |
224 | const bool value_is_scalar = NumDimensions(values) == 0; |
225 | std::vector<std::vector<TI>> indices_vector; |
226 | indices_vector.reserve(num_indices); |
227 | TF_LITE_ENSURE_OK(context, GetIndicesVector<TI>(context, indices, num_indices, |
228 | &indices_vector)); |
229 | reference_ops::SparseToDense(indices_vector, GetTensorData<T>(values), |
230 | *GetTensorData<T>(default_value), |
231 | value_is_scalar, GetTensorShape(output), |
232 | GetTensorData<T>(output)); |
233 | |
234 | return kTfLiteOk; |
235 | } |
236 | |
237 | template <typename T> |
238 | TfLiteStatus EvalForIndexType(TfLiteContext* context, TfLiteNode* node, |
239 | const TfLiteTensor* indices) { |
240 | switch (indices->type) { |
241 | case kTfLiteInt32: { |
242 | return SparseToDenseImpl<T, int32_t>(context, node); |
243 | } |
244 | case kTfLiteInt64: { |
245 | return SparseToDenseImpl<T, int64_t>(context, node); |
246 | } |
247 | default: |
248 | TF_LITE_KERNEL_LOG( |
249 | context, |
250 | "Indice type %s is currently not supported by sparse to dense." , |
251 | TfLiteTypeGetName(indices->type)); |
252 | return kTfLiteError; |
253 | } |
254 | } |
255 | |
256 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
257 | const TfLiteTensor* indices; |
258 | TF_LITE_ENSURE_OK(context, |
259 | GetInputSafe(context, node, kIndicesTensor, &indices)); |
260 | const TfLiteTensor* values; |
261 | TF_LITE_ENSURE_OK(context, |
262 | GetInputSafe(context, node, kValueInputTensor, &values)); |
263 | |
264 | switch (values->type) { |
265 | case kTfLiteFloat32: |
266 | return EvalForIndexType<float>(context, node, indices); |
267 | case kTfLiteInt32: |
268 | return EvalForIndexType<int32_t>(context, node, indices); |
269 | case kTfLiteInt64: |
270 | return EvalForIndexType<int64_t>(context, node, indices); |
271 | case kTfLiteInt8: |
272 | return EvalForIndexType<int8_t>(context, node, indices); |
273 | case kTfLiteUInt8: |
274 | return EvalForIndexType<uint8_t>(context, node, indices); |
275 | default: |
276 | TF_LITE_KERNEL_LOG( |
277 | context, |
278 | "Value type %s is currently not supported by sparse to dense." , |
279 | TfLiteTypeGetName(values->type)); |
280 | return kTfLiteError; |
281 | } |
282 | } |
283 | |
284 | } // namespace sparse_to_dense |
285 | |
// Returns the registration for the builtin SPARSE_TO_DENSE op. The kernel is
// stateless, so init and free are left null.
TfLiteRegistration* Register_SPARSE_TO_DENSE() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 sparse_to_dense::Prepare,
                                 sparse_to_dense::Eval};
  return &r;
}
291 | |
292 | } // namespace builtin |
293 | } // namespace ops |
294 | } // namespace tflite |
295 | |