1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include <stdint.h>

#include <algorithm>
#include <limits>
#include <tuple>
#include <utility>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
27
28namespace tflite {
29namespace ops {
30namespace builtin {
31namespace tile {
32
// Node tensor indices used by the tile kernel.
constexpr int kInputTensor = 0;       // Input: tensor whose contents are tiled.
constexpr int kInputMultipliers = 1;  // Input: int32/int64 repeat count per
                                      // input dimension.
constexpr int kOutputTensor = 0;      // Output: tiled result.
36
37namespace {
38template <typename T>
39TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape,
40 const TfLiteTensor* multipliers,
41 int num_dimensions) {
42 const T* multipliers_v = GetTensorData<T>(multipliers);
43
44 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
45 for (int i = 0; i < num_dimensions; ++i) {
46 output_shape->data[i] = shape.data[i] * multipliers_v[i];
47 }
48 return output_shape;
49}
50
51TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
52 const TfLiteTensor* input;
53 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
54 TfLiteTensor* output;
55 TF_LITE_ENSURE_OK(context,
56 GetOutputSafe(context, node, kOutputTensor, &output));
57 const TfLiteTensor* multipliers;
58 TF_LITE_ENSURE_OK(
59 context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
60
61 const int num_dimensions = NumDimensions(input);
62 const int num_multipliers = NumElements(multipliers);
63 TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers);
64 switch (multipliers->type) {
65 case kTfLiteInt32:
66 return context->ResizeTensor(
67 context, output,
68 MultiplyShapeDims<int32_t>(*input->dims, multipliers,
69 num_dimensions));
70 case kTfLiteInt64:
71 return context->ResizeTensor(
72 context, output,
73 MultiplyShapeDims<int64_t>(*input->dims, multipliers,
74 num_dimensions));
75 default:
76 TF_LITE_KERNEL_LOG(context,
77 "Multipliers of type '%s' are not supported by tile.",
78 TfLiteTypeGetName(multipliers->type));
79 return kTfLiteError;
80 }
81}
82
// Writes `multiplier` back-to-back copies of the `in_size` elements starting
// at `in_data` into `out_data`. Every copy after the first reads from the
// chunk written just before it. This lets `in_data` point at the region that
// immediately precedes `out_data` in the same buffer, which is how
// TileOneDimension replicates an already-tiled slice. A non-positive
// `multiplier` writes nothing.
template <typename T, typename M>
void CopyMultipleTimes(const T* in_data, int32_t in_size, M multiplier,
                       T* out_data) {
  const T* source = in_data;
  T* dest = out_data;
  for (M remaining = multiplier; remaining > 0; --remaining) {
    T* const chunk_end = std::copy(source, source + in_size, dest);
    source = dest;
    dest = chunk_end;
  }
}
93
94template <typename M>
95void CopyStringMultipleTimes(const TfLiteTensor* in_data, int in_data_index,
96 const int dimension_size, M multiplier,
97 DynamicBuffer* buffer) {
98 for (M i = 0; i < multiplier; ++i) {
99 for (int j = 0; j < dimension_size; ++j) {
100 const auto string_ref = GetString(in_data, in_data_index + j);
101 buffer->AddString(string_ref.str, string_ref.len);
102 }
103 }
104}
105
106template <typename T, typename M>
107std::pair<int, int> TileOneDimension(const TfLiteIntArray& in_dimensions,
108 const T* in_data, const M* multipliers,
109 T* out_data, int dimension) {
110 if (in_dimensions.size == 0) {
111 // If input tensor is a scalar, then just copy it to output (no need to
112 // multiply).
113 *out_data = *in_data;
114 return std::make_pair(0, 0);
115 }
116
117 const int dimension_size = in_dimensions.data[dimension];
118 if (dimension == in_dimensions.size - 1) {
119 CopyMultipleTimes(in_data, dimension_size, multipliers[dimension],
120 out_data);
121 return std::make_pair(
122 dimension_size,
123 dimension_size * static_cast<int>(multipliers[dimension]));
124 }
125 int total_stride_size = 0, total_tiled_stride_size = 0;
126 const T* copy_from_data = in_data;
127 T* copy_to_data = out_data;
128 for (int i = 0; i < dimension_size; ++i) {
129 int stride_size = 0, tiled_stride_size = 0;
130 std::tie(stride_size, tiled_stride_size) =
131 TileOneDimension(in_dimensions, copy_from_data, multipliers,
132 copy_to_data, dimension + 1);
133 copy_from_data += stride_size;
134 copy_to_data += tiled_stride_size;
135 total_stride_size += stride_size;
136 total_tiled_stride_size += tiled_stride_size;
137 }
138 CopyMultipleTimes(out_data, total_tiled_stride_size,
139 multipliers[dimension] - 1,
140 out_data + total_tiled_stride_size);
141 return std::make_pair(
142 total_stride_size,
143 static_cast<int>(total_tiled_stride_size * multipliers[dimension]));
144}
145
146template <typename M>
147std::pair<int, int> TileStringOneDimension(
148 const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data,
149 int in_data_index, const M* multipliers, DynamicBuffer* buffer,
150 int buffer_index, int dimension, TfLiteTensor* out_data) {
151 const int dimension_size = in_dimensions.data[dimension];
152 if (dimension == in_dimensions.size - 1) {
153 CopyStringMultipleTimes(in_data, in_data_index, dimension_size,
154 multipliers[dimension], buffer);
155 return {dimension_size,
156 dimension_size * static_cast<int>(multipliers[dimension])};
157 }
158
159 int total_stride_size = 0, total_tiled_stride_size = 0;
160 for (int i = 0; i < dimension_size; ++i) {
161 int stride_size, tiled_stride_size;
162 std::tie(stride_size, tiled_stride_size) = TileStringOneDimension(
163 in_dimensions, in_data, in_data_index + total_stride_size, multipliers,
164 buffer, buffer_index + total_tiled_stride_size, dimension + 1,
165 out_data);
166 total_stride_size += stride_size;
167 total_tiled_stride_size += tiled_stride_size;
168 }
169
170 buffer->WriteToTensor(out_data, /*new_shape=*/nullptr);
171 CopyStringMultipleTimes(out_data, buffer_index, total_tiled_stride_size,
172 multipliers[dimension] - 1, buffer);
173
174 return {total_stride_size,
175 total_tiled_stride_size * static_cast<int>(multipliers[dimension])};
176}
177
178template <typename T>
179void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data,
180 const TfLiteTensor* multipliers, TfLiteTensor* out_data) {
181 // Doing recursively tiling from top to down dimension.
182 switch (multipliers->type) {
183 case kTfLiteInt32:
184 TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
185 GetTensorData<int32_t>(multipliers),
186 GetTensorData<T>(out_data), 0);
187 break;
188 case kTfLiteInt64:
189 TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
190 GetTensorData<int64_t>(multipliers),
191 GetTensorData<T>(out_data), 0);
192 break;
193 default:
194 break;
195 }
196}
197
198void TileString(const TfLiteIntArray& in_dimensions,
199 const TfLiteTensor* in_data, const TfLiteTensor* multipliers,
200 DynamicBuffer* buffer, TfLiteTensor* out_data) {
201 // Doing recursively tiling from top to down dimension.
202 switch (multipliers->type) {
203 case kTfLiteInt32:
204 TileStringOneDimension(in_dimensions, in_data, 0,
205 GetTensorData<int32_t>(multipliers), buffer, 0, 0,
206 out_data);
207 break;
208 case kTfLiteInt64:
209 TileStringOneDimension(in_dimensions, in_data, 0,
210 GetTensorData<int64_t>(multipliers), buffer, 0, 0,
211 out_data);
212 break;
213 default:
214 break;
215 }
216}
217} // namespace
218
219TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
220 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
221 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
222
223 const TfLiteTensor* input;
224 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
225
226 TfLiteTensor* output;
227 TF_LITE_ENSURE_OK(context,
228 GetOutputSafe(context, node, kOutputTensor, &output));
229 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
230
231 const TfLiteTensor* multipliers;
232 TF_LITE_ENSURE_OK(
233 context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
234 // Only int32 and int64 multipliers type is supported.
235 if (multipliers->type != kTfLiteInt32 && multipliers->type != kTfLiteInt64) {
236 TF_LITE_KERNEL_LOG(context,
237 "Multipliers of type '%s' are not supported by tile.",
238 TfLiteTypeGetName(multipliers->type));
239 return kTfLiteError;
240 }
241
242 if (IsConstantTensor(multipliers)) {
243 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
244 } else {
245 SetTensorToDynamic(output);
246 }
247 return kTfLiteOk;
248}
249
250TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
251 const TfLiteTensor* input;
252 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
253 TfLiteTensor* output;
254 TF_LITE_ENSURE_OK(context,
255 GetOutputSafe(context, node, kOutputTensor, &output));
256 const TfLiteTensor* multipliers;
257 TF_LITE_ENSURE_OK(
258 context, GetInputSafe(context, node, kInputMultipliers, &multipliers));
259
260 if (IsDynamicTensor(output)) {
261 TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
262 }
263 if (GetTensorShape(output).FlatSize() == 0) {
264 return kTfLiteOk;
265 }
266
267 switch (output->type) {
268 case kTfLiteFloat32:
269 Tile<float>(*(input->dims), input, multipliers, output);
270 break;
271 case kTfLiteInt8:
272 Tile<int8_t>(*(input->dims), input, multipliers, output);
273 break;
274 case kTfLiteUInt8:
275 Tile<uint8_t>(*(input->dims), input, multipliers, output);
276 break;
277 case kTfLiteInt32:
278 Tile<int32_t>(*(input->dims), input, multipliers, output);
279 break;
280 case kTfLiteInt64:
281 Tile<int64_t>(*(input->dims), input, multipliers, output);
282 break;
283 case kTfLiteString: {
284 DynamicBuffer buffer;
285 TileString(*(input->dims), input, multipliers, &buffer, output);
286 buffer.WriteToTensor(output, /*new_shape=*/nullptr);
287 break;
288 }
289 case kTfLiteBool:
290 Tile<bool>(*(input->dims), input, multipliers, output);
291 break;
292 default:
293 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by tile.",
294 TfLiteTypeGetName(output->type));
295 return kTfLiteError;
296 }
297 return kTfLiteOk;
298}
299
300} // namespace tile
301TfLiteRegistration* Register_TILE() {
302 static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval};
303 return &r;
304}
305} // namespace builtin
306} // namespace ops
307} // namespace tflite
308