1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <cstdint> |
17 | #include <cstring> |
18 | #include <memory> |
19 | |
20 | #include "tensorflow/lite/c/builtin_op_data.h" |
21 | #include "tensorflow/lite/c/common.h" |
22 | #include "tensorflow/lite/kernels/internal/tensor.h" |
23 | #include "tensorflow/lite/kernels/kernel_util.h" |
24 | |
25 | namespace tflite { |
26 | namespace ops { |
27 | namespace builtin { |
28 | namespace reshape { |
29 | |
// Positional indices of the tensors this kernel reads (inputs) and writes
// (outputs).
constexpr int kInputTensor{0};   // Tensor whose data is reshaped.
constexpr int kShapeTensor{1};   // Optional second input holding the shape.
constexpr int kOutputTensor{0};  // Reshaped result.
33 | |
34 | TfLiteIntArray* GetOutputShape(TfLiteContext*, TfLiteNode*); |
35 | |
36 | TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { |
37 | TfLiteIntArray* output_shape = GetOutputShape(context, node); |
38 | std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> |
39 | scoped_output_shape(output_shape, TfLiteIntArrayFree); |
40 | |
41 | const TfLiteTensor* input; |
42 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
43 | TfLiteTensor* output; |
44 | TF_LITE_ENSURE_OK(context, |
45 | GetOutputSafe(context, node, kOutputTensor, &output)); |
46 | |
47 | // Tensorflow's Reshape allows one of the shape components to have the |
48 | // special -1 value, meaning it will be calculated automatically based on the |
49 | // input. Here we calculate what that dimension should be so that the number |
50 | // of output elements is the same as the number of input elements. |
51 | int64_t non_zero_num_input_elements = 1, num_input_elements = 1; |
52 | const RuntimeShape& input_shape = GetTensorShape(input); |
53 | for (int i = 0; i < input_shape.DimensionsCount(); ++i) { |
54 | const int value = input_shape.Dims(i); |
55 | num_input_elements *= value; |
56 | if (value != 0) { |
57 | non_zero_num_input_elements *= value; |
58 | } |
59 | } |
60 | |
61 | int64_t non_zero_num_output_elements = 1, num_output_elements = 1; |
62 | int stretch_dim = -1; |
63 | for (int i = 0; i < output_shape->size; ++i) { |
64 | const int value = output_shape->data[i]; |
65 | if (value == -1) { |
66 | TF_LITE_ENSURE_EQ(context, stretch_dim, -1); |
67 | stretch_dim = i; |
68 | continue; |
69 | } else if (value != 0) { |
70 | non_zero_num_output_elements *= value; |
71 | } |
72 | num_output_elements *= value; |
73 | } |
74 | |
75 | if (stretch_dim != -1) { |
76 | if (num_input_elements == 0 && num_output_elements != 0) { |
77 | output_shape->data[stretch_dim] = 0; |
78 | } else { |
79 | output_shape->data[stretch_dim] = |
80 | non_zero_num_input_elements / non_zero_num_output_elements; |
81 | } |
82 | num_output_elements *= output_shape->data[stretch_dim]; |
83 | } |
84 | |
85 | TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); |
86 | return context->ResizeTensor(context, output, scoped_output_shape.release()); |
87 | } |
88 | |
89 | inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context, |
90 | TfLiteNode* node) { |
91 | const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); |
92 | if (shape == nullptr) return nullptr; |
93 | |
94 | TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]); |
95 | for (int i = 0; i < output_shape->size; ++i) { |
96 | output_shape->data[i] = shape->data.i32[i]; |
97 | } |
98 | |
99 | return output_shape; |
100 | } |
101 | |
102 | inline TfLiteIntArray* GetOutputShapeFromParam(TfLiteContext* context, |
103 | TfLiteNode* node) { |
104 | auto* params = reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data); |
105 | |
106 | // The function is returned above this line if the shape tensor is usable. |
107 | // Now fallback to the shape parameter in `TfLiteReshapeParams`. |
108 | int num_dimensions = params->num_dimensions; |
109 | if (num_dimensions == 1 && params->shape[0] == 0) { |
110 | // Legacy tflite models use a shape parameter of [0] to indicate scalars, |
111 | // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during |
112 | // toco conversion. |
113 | num_dimensions = 0; |
114 | } |
115 | TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); |
116 | for (int i = 0; i < num_dimensions; ++i) { |
117 | output_shape->data[i] = params->shape[i]; |
118 | } |
119 | |
120 | return output_shape; |
121 | } |
122 | |
123 | // Check if the shape tensor is valid. Shapes should be int32 vectors. |
124 | inline bool ShapeIsVector(TfLiteContext* context, TfLiteNode* node) { |
125 | const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); |
126 | return (shape != nullptr && shape->dims->size == 1 && |
127 | shape->type == kTfLiteInt32); |
128 | } |
129 | |
130 | TfLiteIntArray* GetOutputShape(TfLiteContext* context, TfLiteNode* node) { |
131 | if (NumInputs(node) == 2 && ShapeIsVector(context, node)) { |
132 | return GetOutputShapeFromTensor(context, node); |
133 | } else { |
134 | return GetOutputShapeFromParam(context, node); |
135 | } |
136 | } |
137 | |
138 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
139 | TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); |
140 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
141 | |
142 | // Always postpone sizing string tensors, even if we could in principle |
143 | // calculate their shapes now. String tensors don't benefit from having their |
144 | // shapes precalculated because the actual memory can only be allocated after |
145 | // we know all the content. |
146 | TfLiteTensor* output; |
147 | TF_LITE_ENSURE_OK(context, |
148 | GetOutputSafe(context, node, kOutputTensor, &output)); |
149 | if (output->type != kTfLiteString) { |
150 | if (NumInputs(node) == 1 || |
151 | IsConstantTensor(GetInput(context, node, kShapeTensor))) { |
152 | TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); |
153 | } else { |
154 | SetTensorToDynamic(output); |
155 | } |
156 | } |
157 | return kTfLiteOk; |
158 | } |
159 | |
160 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
161 | const TfLiteTensor* input; |
162 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
163 | TfLiteTensor* output; |
164 | TF_LITE_ENSURE_OK(context, |
165 | GetOutputSafe(context, node, kOutputTensor, &output)); |
166 | |
167 | // There are two ways in which the 'output' can be made dynamic: it could be |
168 | // a string tensor, or its shape cannot be calculated during Prepare(). In |
169 | // either case, we now have all the information to calculate its shape. |
170 | if (IsDynamicTensor(output)) { |
171 | TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); |
172 | } |
173 | |
174 | // Note that string tensors are always "dynamic" in the sense that their size |
175 | // is not known until we have all the content. This applies even when their |
176 | // shape is known ahead of time. As a result, a string tensor is never given |
177 | // any memory by ResizeOutput(), and we need to do it manually here. Since |
178 | // reshape doesn't change the data, the output tensor needs exactly as many |
179 | // bytes as the input tensor. |
180 | if (output->type == kTfLiteString) { |
181 | auto bytes_required = input->bytes; |
182 | TfLiteTensorRealloc(bytes_required, output); |
183 | output->bytes = bytes_required; |
184 | } |
185 | |
186 | memcpy(output->data.raw, input->data.raw, input->bytes); |
187 | |
188 | return kTfLiteOk; |
189 | } |
190 | |
191 | } // namespace reshape |
192 | |
193 | TfLiteRegistration* Register_RESHAPE() { |
194 | static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare, |
195 | reshape::Eval}; |
196 | return &r; |
197 | } |
198 | |
199 | } // namespace builtin |
200 | } // namespace ops |
201 | } // namespace tflite |
202 | |