1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
#include <stdint.h>

#include <algorithm>
#include <limits>
#include <tuple>
#include <utility>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
27 | |
28 | namespace tflite { |
29 | namespace ops { |
30 | namespace builtin { |
31 | namespace tile { |
32 | |
// Tensor indices used by this kernel: input 0 is the data to tile, input 1
// holds the per-dimension multipliers, and output 0 receives the result.
constexpr int kInputTensor{0};
constexpr int kInputMultipliers{1};
constexpr int kOutputTensor{0};
36 | |
37 | namespace { |
38 | template <typename T> |
39 | TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape, |
40 | const TfLiteTensor* multipliers, |
41 | int num_dimensions) { |
42 | const T* multipliers_v = GetTensorData<T>(multipliers); |
43 | |
44 | TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); |
45 | for (int i = 0; i < num_dimensions; ++i) { |
46 | output_shape->data[i] = shape.data[i] * multipliers_v[i]; |
47 | } |
48 | return output_shape; |
49 | } |
50 | |
51 | TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { |
52 | const TfLiteTensor* input; |
53 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
54 | TfLiteTensor* output; |
55 | TF_LITE_ENSURE_OK(context, |
56 | GetOutputSafe(context, node, kOutputTensor, &output)); |
57 | const TfLiteTensor* multipliers; |
58 | TF_LITE_ENSURE_OK( |
59 | context, GetInputSafe(context, node, kInputMultipliers, &multipliers)); |
60 | |
61 | const int num_dimensions = NumDimensions(input); |
62 | const int num_multipliers = NumElements(multipliers); |
63 | TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers); |
64 | switch (multipliers->type) { |
65 | case kTfLiteInt32: |
66 | return context->ResizeTensor( |
67 | context, output, |
68 | MultiplyShapeDims<int32_t>(*input->dims, multipliers, |
69 | num_dimensions)); |
70 | case kTfLiteInt64: |
71 | return context->ResizeTensor( |
72 | context, output, |
73 | MultiplyShapeDims<int64_t>(*input->dims, multipliers, |
74 | num_dimensions)); |
75 | default: |
76 | TF_LITE_KERNEL_LOG(context, |
77 | "Multipliers of type '%s' are not supported by tile." , |
78 | TfLiteTypeGetName(multipliers->type)); |
79 | return kTfLiteError; |
80 | } |
81 | } |
82 | |
// Writes `multiplier` back-to-back copies of the `in_size` elements at
// `in_data` into `out_data` (no copies when `multiplier` <= 0). After the
// first pass, each later copy is read from the copy just written. This lets a
// caller pass the start of an already-filled output region as `in_data` and
// the position right after it as `out_data`.
template <typename T, typename M>
void CopyMultipleTimes(const T* in_data, int32_t in_size, M multiplier,
                       T* out_data) {
  const T* source = in_data;
  T* destination = out_data;
  for (M remaining = multiplier; remaining > 0; --remaining) {
    T* const next_destination =
        std::copy(source, source + in_size, destination);
    source = destination;
    destination = next_destination;
  }
}
93 | |
94 | template <typename M> |
95 | void CopyStringMultipleTimes(const TfLiteTensor* in_data, int in_data_index, |
96 | const int dimension_size, M multiplier, |
97 | DynamicBuffer* buffer) { |
98 | for (M i = 0; i < multiplier; ++i) { |
99 | for (int j = 0; j < dimension_size; ++j) { |
100 | const auto string_ref = GetString(in_data, in_data_index + j); |
101 | buffer->AddString(string_ref.str, string_ref.len); |
102 | } |
103 | } |
104 | } |
105 | |
106 | template <typename T, typename M> |
107 | std::pair<int, int> TileOneDimension(const TfLiteIntArray& in_dimensions, |
108 | const T* in_data, const M* multipliers, |
109 | T* out_data, int dimension) { |
110 | if (in_dimensions.size == 0) { |
111 | // If input tensor is a scalar, then just copy it to output (no need to |
112 | // multiply). |
113 | *out_data = *in_data; |
114 | return std::make_pair(0, 0); |
115 | } |
116 | |
117 | const int dimension_size = in_dimensions.data[dimension]; |
118 | if (dimension == in_dimensions.size - 1) { |
119 | CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], |
120 | out_data); |
121 | return std::make_pair( |
122 | dimension_size, |
123 | dimension_size * static_cast<int>(multipliers[dimension])); |
124 | } |
125 | int total_stride_size = 0, total_tiled_stride_size = 0; |
126 | const T* copy_from_data = in_data; |
127 | T* copy_to_data = out_data; |
128 | for (int i = 0; i < dimension_size; ++i) { |
129 | int stride_size = 0, tiled_stride_size = 0; |
130 | std::tie(stride_size, tiled_stride_size) = |
131 | TileOneDimension(in_dimensions, copy_from_data, multipliers, |
132 | copy_to_data, dimension + 1); |
133 | copy_from_data += stride_size; |
134 | copy_to_data += tiled_stride_size; |
135 | total_stride_size += stride_size; |
136 | total_tiled_stride_size += tiled_stride_size; |
137 | } |
138 | CopyMultipleTimes(out_data, total_tiled_stride_size, |
139 | multipliers[dimension] - 1, |
140 | out_data + total_tiled_stride_size); |
141 | return std::make_pair( |
142 | total_stride_size, |
143 | static_cast<int>(total_tiled_stride_size * multipliers[dimension])); |
144 | } |
145 | |
146 | template <typename M> |
147 | std::pair<int, int> TileStringOneDimension( |
148 | const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data, |
149 | int in_data_index, const M* multipliers, DynamicBuffer* buffer, |
150 | int buffer_index, int dimension, TfLiteTensor* out_data) { |
151 | const int dimension_size = in_dimensions.data[dimension]; |
152 | if (dimension == in_dimensions.size - 1) { |
153 | CopyStringMultipleTimes(in_data, in_data_index, dimension_size, |
154 | multipliers[dimension], buffer); |
155 | return {dimension_size, |
156 | dimension_size * static_cast<int>(multipliers[dimension])}; |
157 | } |
158 | |
159 | int total_stride_size = 0, total_tiled_stride_size = 0; |
160 | for (int i = 0; i < dimension_size; ++i) { |
161 | int stride_size, tiled_stride_size; |
162 | std::tie(stride_size, tiled_stride_size) = TileStringOneDimension( |
163 | in_dimensions, in_data, in_data_index + total_stride_size, multipliers, |
164 | buffer, buffer_index + total_tiled_stride_size, dimension + 1, |
165 | out_data); |
166 | total_stride_size += stride_size; |
167 | total_tiled_stride_size += tiled_stride_size; |
168 | } |
169 | |
170 | buffer->WriteToTensor(out_data, /*new_shape=*/nullptr); |
171 | CopyStringMultipleTimes(out_data, buffer_index, total_tiled_stride_size, |
172 | multipliers[dimension] - 1, buffer); |
173 | |
174 | return {total_stride_size, |
175 | total_tiled_stride_size * static_cast<int>(multipliers[dimension])}; |
176 | } |
177 | |
178 | template <typename T> |
179 | void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data, |
180 | const TfLiteTensor* multipliers, TfLiteTensor* out_data) { |
181 | // Doing recursively tiling from top to down dimension. |
182 | switch (multipliers->type) { |
183 | case kTfLiteInt32: |
184 | TileOneDimension(in_dimensions, GetTensorData<T>(in_data), |
185 | GetTensorData<int32_t>(multipliers), |
186 | GetTensorData<T>(out_data), 0); |
187 | break; |
188 | case kTfLiteInt64: |
189 | TileOneDimension(in_dimensions, GetTensorData<T>(in_data), |
190 | GetTensorData<int64_t>(multipliers), |
191 | GetTensorData<T>(out_data), 0); |
192 | break; |
193 | default: |
194 | break; |
195 | } |
196 | } |
197 | |
198 | void TileString(const TfLiteIntArray& in_dimensions, |
199 | const TfLiteTensor* in_data, const TfLiteTensor* multipliers, |
200 | DynamicBuffer* buffer, TfLiteTensor* out_data) { |
201 | // Doing recursively tiling from top to down dimension. |
202 | switch (multipliers->type) { |
203 | case kTfLiteInt32: |
204 | TileStringOneDimension(in_dimensions, in_data, 0, |
205 | GetTensorData<int32_t>(multipliers), buffer, 0, 0, |
206 | out_data); |
207 | break; |
208 | case kTfLiteInt64: |
209 | TileStringOneDimension(in_dimensions, in_data, 0, |
210 | GetTensorData<int64_t>(multipliers), buffer, 0, 0, |
211 | out_data); |
212 | break; |
213 | default: |
214 | break; |
215 | } |
216 | } |
217 | } // namespace |
218 | |
219 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
220 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
221 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
222 | |
223 | const TfLiteTensor* input; |
224 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
225 | |
226 | TfLiteTensor* output; |
227 | TF_LITE_ENSURE_OK(context, |
228 | GetOutputSafe(context, node, kOutputTensor, &output)); |
229 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); |
230 | |
231 | const TfLiteTensor* multipliers; |
232 | TF_LITE_ENSURE_OK( |
233 | context, GetInputSafe(context, node, kInputMultipliers, &multipliers)); |
234 | // Only int32 and int64 multipliers type is supported. |
235 | if (multipliers->type != kTfLiteInt32 && multipliers->type != kTfLiteInt64) { |
236 | TF_LITE_KERNEL_LOG(context, |
237 | "Multipliers of type '%s' are not supported by tile." , |
238 | TfLiteTypeGetName(multipliers->type)); |
239 | return kTfLiteError; |
240 | } |
241 | |
242 | if (IsConstantTensor(multipliers)) { |
243 | TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); |
244 | } else { |
245 | SetTensorToDynamic(output); |
246 | } |
247 | return kTfLiteOk; |
248 | } |
249 | |
250 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
251 | const TfLiteTensor* input; |
252 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
253 | TfLiteTensor* output; |
254 | TF_LITE_ENSURE_OK(context, |
255 | GetOutputSafe(context, node, kOutputTensor, &output)); |
256 | const TfLiteTensor* multipliers; |
257 | TF_LITE_ENSURE_OK( |
258 | context, GetInputSafe(context, node, kInputMultipliers, &multipliers)); |
259 | |
260 | if (IsDynamicTensor(output)) { |
261 | TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); |
262 | } |
263 | if (GetTensorShape(output).FlatSize() == 0) { |
264 | return kTfLiteOk; |
265 | } |
266 | |
267 | switch (output->type) { |
268 | case kTfLiteFloat32: |
269 | Tile<float>(*(input->dims), input, multipliers, output); |
270 | break; |
271 | case kTfLiteInt8: |
272 | Tile<int8_t>(*(input->dims), input, multipliers, output); |
273 | break; |
274 | case kTfLiteUInt8: |
275 | Tile<uint8_t>(*(input->dims), input, multipliers, output); |
276 | break; |
277 | case kTfLiteInt32: |
278 | Tile<int32_t>(*(input->dims), input, multipliers, output); |
279 | break; |
280 | case kTfLiteInt64: |
281 | Tile<int64_t>(*(input->dims), input, multipliers, output); |
282 | break; |
283 | case kTfLiteString: { |
284 | DynamicBuffer buffer; |
285 | TileString(*(input->dims), input, multipliers, &buffer, output); |
286 | buffer.WriteToTensor(output, /*new_shape=*/nullptr); |
287 | break; |
288 | } |
289 | case kTfLiteBool: |
290 | Tile<bool>(*(input->dims), input, multipliers, output); |
291 | break; |
292 | default: |
293 | TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by tile." , |
294 | TfLiteTypeGetName(output->type)); |
295 | return kTfLiteError; |
296 | } |
297 | return kTfLiteOk; |
298 | } |
299 | |
300 | } // namespace tile |
301 | TfLiteRegistration* Register_TILE() { |
302 | static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval}; |
303 | return &r; |
304 | } |
305 | } // namespace builtin |
306 | } // namespace ops |
307 | } // namespace tflite |
308 | |