/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <cstring>
#include <memory>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace reshape {

constexpr int kInputTensor = 0;
constexpr int kShapeTensor = 1;
constexpr int kOutputTensor = 0;

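// Forward declaration: defined below, after the shape-extraction helpers it
// dispatches to. Declared here because ResizeOutput() needs it.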
TfLiteIntArray* GetOutputShape(TfLiteContext*, TfLiteNode*);

TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
  TfLiteIntArray* output_shape = GetOutputShape(context, node);
  std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)>
      scoped_output_shape(output_shape, TfLiteIntArrayFree);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // TensorFlow's Reshape allows one of the shape components to have the
  // special -1 value, meaning it will be calculated automatically based on the
  // input. Here we calculate what that dimension should be so that the number
  // of output elements is the same as the number of input elements.
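  // Example: a [2, 8] input (16 elements) reshaped with shape [-1, 4]
  // resolves the -1 to 16 / 4 = 4, producing a [4, 4] output.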
  int64_t non_zero_num_input_elements = 1, num_input_elements = 1;
  const RuntimeShape& input_shape = GetTensorShape(input);
  for (int i = 0; i < input_shape.DimensionsCount(); ++i) {
    const int value = input_shape.Dims(i);
    num_input_elements *= value;
    if (value != 0) {
      non_zero_num_input_elements *= value;
    }
  }

  int64_t non_zero_num_output_elements = 1, num_output_elements = 1;
  int stretch_dim = -1;
  for (int i = 0; i < output_shape->size; ++i) {
    const int value = output_shape->data[i];
    if (value == -1) {
      TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
      stretch_dim = i;
      continue;
    }
    if (value != 0) {
      non_zero_num_output_elements *= value;
    }
    num_output_elements *= value;
  }

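  // Example with zero-sized dimensions: a [0, 3] input reshaped with shape
  // [-1, 3] has zero input elements, so the stretch dimension below resolves
  // to 0 and the final output shape is [0, 3].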
  if (stretch_dim != -1) {
    if (num_input_elements == 0 && num_output_elements != 0) {
      output_shape->data[stretch_dim] = 0;
    } else {
      output_shape->data[stretch_dim] =
          non_zero_num_input_elements / non_zero_num_output_elements;
    }
    num_output_elements *= output_shape->data[stretch_dim];
  }

  TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
  return context->ResizeTensor(context, output, scoped_output_shape.release());
}

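// Reads the target output shape from the shape tensor. Callers in this file
// only reach this path after ShapeIsVector() has verified that the tensor is
// a 1-D int32 vector.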
inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context,
                                                TfLiteNode* node) {
  const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
  if (shape == nullptr) return nullptr;

  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]);
  for (int i = 0; i < output_shape->size; ++i) {
    output_shape->data[i] = shape->data.i32[i];
  }

  return output_shape;
}

inline TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context,
                                               TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);

  // This path is taken only when the shape tensor is absent or unusable, so
  // fall back to the shape parameter in `TfLiteReshapeParams`.
  int num_dimensions = params->num_dimensions;
  if (num_dimensions == 1 && params->shape[0] == 0) {
    // Legacy tflite models use a shape parameter of [0] to indicate scalars,
    // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers
    // during toco conversion.
    num_dimensions = 0;
  }
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
  for (int i = 0; i < num_dimensions; ++i) {
    output_shape->data[i] = params->shape[i];
  }

  return output_shape;
}

// Check if the shape tensor is valid. Shapes should be int32 vectors.
inline bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* shape = GetInput(context, node, kShapeTensor);
  return (shape != nullptr && shape->dims->size == 1 &&
          shape->type == kTfLiteInt32);
}

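// Prefers the shape tensor when the node has two inputs and the tensor is a
// valid int32 vector; otherwise falls back to the shape stored in the
// builtin `TfLiteReshapeParams`.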
TfLiteIntArray* GetOutputShape(TfLiteContext* context, TfLiteNode* node) {
  if (NumInputs(node) == 2 && ShapeIsVector(context, node)) {
    return GetOutputShapeFromTensor(context, node);
  } else {
    return GetOutputShapeFromParam(context, node);
  }
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // Always postpone sizing string tensors, even if we could in principle
  // calculate their shapes now. String tensors don't benefit from having their
  // shapes precalculated because the actual memory can only be allocated after
  // we know all the content.
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  if (output->type != kTfLiteString) {
    if (NumInputs(node) == 1 ||
        IsConstantTensor(GetInput(context, node, kShapeTensor))) {
      TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
    } else {
      SetTensorToDynamic(output);
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // There are two ways in which 'output' can be dynamic: it could be a string
  // tensor, or its shape could not be calculated during Prepare(). In either
  // case, we now have all the information needed to calculate its shape.
  if (IsDynamicTensor(output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
  }

  // Note that string tensors are always "dynamic" in the sense that their size
  // is not known until we have all the content. This applies even when their
  // shape is known ahead of time. As a result, a string tensor is never given
  // any memory by ResizeOutput(), and we need to do it manually here. Since
  // reshape doesn't change the data, the output tensor needs exactly as many
  // bytes as the input tensor.
  if (output->type == kTfLiteString) {
    auto bytes_required = input->bytes;
    TfLiteTensorRealloc(bytes_required, output);
    output->bytes = bytes_required;
  }

  memcpy(output->data.raw, input->data.raw, input->bytes);

  return kTfLiteOk;
}

}  // namespace reshape

TfLiteRegistration* Register_RESHAPE() {
  static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare,
                                 reshape::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite