1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/reference/conv3d_transpose.h" |
16 | |
17 | #include <cstddef> |
18 | #include <cstdint> |
19 | |
20 | #include "tensorflow/lite/c/builtin_op_data.h" |
21 | #include "tensorflow/lite/c/common.h" |
22 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
23 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
24 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
25 | #include "tensorflow/lite/kernels/internal/types.h" |
26 | #include "tensorflow/lite/kernels/kernel_util.h" |
27 | #include "tensorflow/lite/kernels/padding.h" |
28 | |
29 | namespace tflite { |
30 | namespace ops { |
31 | namespace builtin { |
32 | namespace conv3d_transpose { |
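// CONV_3D_TRANSPOSE expects the following tensors:
//   input 0: output_shape, a 1-D int32 tensor with 5 elements
//            [batch, depth, height, width, channels].
//   input 1: filter, a 5-D float32 tensor laid out as
//            [depth, height, width, output_channels, input_channels].
//   input 2: input, a 5-D float32 tensor laid out as
//            [batch, depth, height, width, input_channels].
//   input 3: bias (optional), a 1-D float32 tensor with as many elements as
//            the filter's output_channels dimension.
// It produces a single float32 output tensor with shape `output_shape`.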
33 | |
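// Kernel variants: a portable reference implementation and an optimized
// implementation that uses a temporary col2im buffer (see
// AllocateTemporaryTensorsIfRequired below).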
34 | enum KernelType { |
35 | kReference, |
36 | kGenericOptimized, |
37 | }; |
38 | |
39 | const int kTensorNotAllocated = -1; |
40 | |
41 | struct OpData { |
42 | Padding3DValues padding; |
43 | |
44 | // The id of the temporary col2im tensor. |
45 | int col2im_id = kTensorNotAllocated; |
46 | |
  // The index of the col2im tensor in the temporaries list.
48 | int col2im_index; |
49 | |
50 | bool need_col2im = false; |
51 | }; |
52 | |
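// Allocates the per-node OpData. Builtin parameters are delivered later via
// node->builtin_data, so `buffer` and `length` are unused here.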
53 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
54 | auto* opdata = new OpData; |
55 | return opdata; |
56 | } |
57 | |
58 | void Free(TfLiteContext* context, void* buffer) { |
59 | delete static_cast<OpData*>(buffer); |
60 | } |
61 | |
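// Registers the col2im temporary tensor with the context (at most once per
// node) and rebuilds node->temporaries. Only the optimized kernel needs the
// temporary; for the reference kernel the temporaries list stays empty.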
62 | static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, |
63 | TfLiteNode* node, |
64 | KernelType kernel_type) { |
65 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
66 | int temporaries_count = 0; |
67 | |
68 | // Allocate col2im tensor for the optimized kernel. |
69 | if (kernel_type == kGenericOptimized) { |
70 | if (data->col2im_id == kTensorNotAllocated) { |
71 | context->AddTensors(context, 1, &data->col2im_id); |
72 | } |
73 | data->col2im_index = temporaries_count++; |
74 | data->need_col2im = true; |
75 | } |
76 | |
77 | TfLiteIntArrayFree(node->temporaries); |
78 | node->temporaries = TfLiteIntArrayCreate(temporaries_count); |
79 | |
80 | return kTfLiteOk; |
81 | } |
82 | |
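// Validates the requested output shape against the input and filter, computes
// the padding values, resizes the output tensor, and, when the optimized
// kernel is used, resizes the col2im temporary as well.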
83 | TfLiteStatus ResizeOutputAndTemporaryTensors( |
84 | TfLiteContext* context, OpData* opdata, TfLiteConv3DTransposeParams* params, |
85 | const TfLiteTensor* shape_tensor, const TfLiteTensor* filter, |
86 | const TfLiteTensor* input, TfLiteTensor* col2im, TfLiteTensor* output) { |
87 | auto shape_data = GetTensorData<int32_t>(shape_tensor); |
  // Output and input tensors must have the same batch size.
89 | TF_LITE_ENSURE_EQ(context, shape_data[0], SizeOfDimension(input, 0)); |
90 | // The number of channels of output must be divisible by that of filter. |
91 | TF_LITE_ENSURE_EQ(context, shape_data[4] % SizeOfDimension(filter, 3), 0); |
92 | |
93 | // Compute padding. |
94 | const RuntimeShape& filter_shape = GetTensorShape(filter); |
95 | const int depth = shape_data[1]; |
96 | const int height = shape_data[2]; |
97 | const int width = shape_data[3]; |
98 | const int filter_depth = filter_shape.Dims(0); |
99 | const int filter_height = filter_shape.Dims(1); |
100 | const int filter_width = filter_shape.Dims(2); |
101 | int unused_out_width, unused_out_height, unused_out_depth; |
102 | opdata->padding = ComputePadding3DValues( |
103 | params->stride_height, params->stride_width, params->stride_depth, |
104 | params->dilation_height_factor, params->dilation_width_factor, |
105 | params->dilation_depth_factor, height, width, depth, filter_height, |
106 | filter_width, filter_depth, params->padding, &unused_out_height, |
107 | &unused_out_width, &unused_out_depth); |
108 | // Computed shape must match the shape of the input tensor. |
109 | TF_LITE_ENSURE_EQ(context, unused_out_depth, SizeOfDimension(input, 1)); |
110 | TF_LITE_ENSURE_EQ(context, unused_out_height, SizeOfDimension(input, 2)); |
111 | TF_LITE_ENSURE_EQ(context, unused_out_width, SizeOfDimension(input, 3)); |
112 | |
113 | TfLiteIntArray* output_shape = |
114 | TfLiteIntArrayCreate(NumElements(shape_tensor)); |
115 | for (int i = 0; i < output_shape->size; ++i) { |
116 | output_shape->data[i] = GetTensorData<int32_t>(shape_tensor)[i]; |
117 | } |
118 | |
119 | TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape)); |
120 | |
121 | // Resize col2im tensor. |
122 | if (opdata->need_col2im) { |
123 | TfLiteIntArray* col2im_shape_array = TfLiteIntArrayCreate(2); |
124 | const RuntimeShape& input_shape = GetTensorShape(input); |
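    // col2im is a 2-D buffer with one row per input spatial position
    // (depth * height * width) and one column per filter element and output
    // channel.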
125 | col2im_shape_array->data[0] = |
126 | input_shape.Dims(1) * input_shape.Dims(2) * input_shape.Dims(3); |
127 | col2im_shape_array->data[1] = |
128 | filter_depth * filter_height * filter_width * filter_shape.Dims(3); |
129 | |
130 | col2im->type = kTfLiteFloat32; |
131 | col2im->allocation_type = kTfLiteDynamic; |
132 | return context->ResizeTensor(context, col2im, col2im_shape_array); |
133 | } |
134 | return kTfLiteOk; |
135 | } |
136 | |
137 | TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context, |
138 | TfLiteNode* node) { |
139 | auto* params = |
140 | reinterpret_cast<TfLiteConv3DTransposeParams*>(node->builtin_data); |
141 | OpData* opdata = reinterpret_cast<OpData*>(node->user_data); |
142 | // Check number of inputs/outputs. |
143 | TF_LITE_ENSURE(context, node->inputs->size == 3 || node->inputs->size == 4); |
144 | TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); |
145 | TfLiteTensor* output; |
146 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
147 | const TfLiteTensor* output_shape; |
148 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &output_shape)); |
149 | const TfLiteTensor* filter; |
150 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter)); |
151 | const TfLiteTensor* input; |
152 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &input)); |
153 | |
154 | // Check dimensionality of inputs/outputs. |
155 | TF_LITE_ENSURE_EQ(context, output_shape->dims->size, 1); |
156 | TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 5); |
157 | TF_LITE_ENSURE_EQ(context, input->dims->size, 5); |
158 | TF_LITE_ENSURE_EQ(context, filter->dims->size, 5); |
159 | |
160 | // Input and filter must have the same number of channels. |
161 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 4), |
162 | SizeOfDimension(filter, 4)); |
163 | |
164 | // Check types. |
165 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32); |
166 | TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32); |
167 | TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type); |
168 | TF_LITE_ENSURE_TYPES_EQ(context, output_shape->type, kTfLiteInt32); |
169 | |
170 | // Check bias. |
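  // The bias input is optional; GetInput returns nullptr when it is absent.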
171 | const TfLiteTensor* bias = GetInput(context, node, 3); |
172 | if (bias) { |
173 | TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input->type); |
174 | TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 3)); |
175 | } |
176 | |
177 | // GenericOptimized kernel currently doesn't support dilation. |
178 | if (params->dilation_depth_factor > 1 || params->dilation_height_factor > 1 || |
179 | params->dilation_width_factor > 1) { |
180 | kernel_type = kReference; |
181 | } |
182 | |
183 | // Allocate temporary tensors. |
184 | TF_LITE_ENSURE_STATUS( |
185 | AllocateTemporaryTensorsIfRequired(context, node, kernel_type)); |
186 | |
187 | // Check temporary tensors. |
188 | TfLiteTensor* col2im = nullptr; |
189 | if (opdata->need_col2im) { |
190 | node->temporaries->data[opdata->col2im_index] = opdata->col2im_id; |
191 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, |
192 | opdata->col2im_index, &col2im)); |
193 | } |
194 | |
195 | // Resize the output tensor. |
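  // If the requested shape is not a constant tensor, its values are only
  // known at Eval time, so mark the output (and col2im) dynamic and resize
  // them there.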
196 | if (!IsConstantTensor(output_shape)) { |
197 | SetTensorToDynamic(output); |
198 | if (opdata->need_col2im) { |
199 | SetTensorToDynamic(col2im); |
200 | } |
201 | } else { |
202 | TF_LITE_ENSURE_STATUS(ResizeOutputAndTemporaryTensors( |
203 | context, opdata, params, output_shape, filter, input, col2im, output)); |
204 | } |
205 | return kTfLiteOk; |
206 | } |
207 | |
208 | template <KernelType kernel_type> |
209 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
210 | return Prepare(kernel_type, context, node); |
211 | } |
212 | |
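// Builds the runtime Conv3DTransposeParams from the builtin params and the
// padding computed in Prepare, then dispatches to the reference or optimized
// kernel. Only the optimized kernel reads the col2im buffer.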
213 | void EvalFloat(KernelType kernel_type, TfLiteContext* context, TfLiteNode* node, |
214 | TfLiteConv3DTransposeParams* params, OpData* opdata, |
215 | const TfLiteTensor* input, const TfLiteTensor* filter, |
216 | const TfLiteTensor* bias, TfLiteTensor* col2im, |
217 | TfLiteTensor* output) { |
218 | float output_activation_min, output_activation_max; |
219 | CalculateActivationRange(params->activation, &output_activation_min, |
220 | &output_activation_max); |
221 | |
222 | Conv3DTransposeParams runtime_params; |
223 | runtime_params.padding_values = opdata->padding; |
224 | runtime_params.stride_depth = params->stride_depth; |
225 | runtime_params.stride_height = params->stride_height; |
226 | runtime_params.stride_width = params->stride_width; |
227 | runtime_params.dilation_depth = params->dilation_depth_factor; |
228 | runtime_params.dilation_height = params->dilation_height_factor; |
229 | runtime_params.dilation_width = params->dilation_width_factor; |
230 | runtime_params.float_activation_min = output_activation_min; |
231 | runtime_params.float_activation_max = output_activation_max; |
232 | |
233 | switch (kernel_type) { |
234 | case kReference: { |
235 | reference_ops::Conv3DTranspose( |
236 | runtime_params, GetTensorShape(input), GetTensorData<float>(input), |
237 | GetTensorShape(filter), GetTensorData<float>(filter), |
238 | GetTensorShape(bias), GetTensorData<float>(bias), |
239 | GetTensorShape(output), GetTensorData<float>(output)); |
240 | break; |
241 | } |
242 | case kGenericOptimized: { |
243 | optimized_ops::Conv3DTranspose( |
244 | runtime_params, GetTensorShape(input), GetTensorData<float>(input), |
245 | GetTensorShape(filter), GetTensorData<float>(filter), |
246 | GetTensorShape(bias), GetTensorData<float>(bias), |
247 | GetTensorShape(output), GetTensorData<float>(output), |
248 | GetTensorShape(col2im), GetTensorData<float>(col2im), |
249 | CpuBackendContext::GetFromContext(context)); |
250 | } break; |
251 | } |
252 | } |
253 | |
254 | TfLiteStatus Eval(KernelType kernel_type, TfLiteContext* context, |
255 | TfLiteNode* node) { |
256 | auto* params = |
257 | reinterpret_cast<TfLiteConv3DTransposeParams*>(node->builtin_data); |
258 | OpData* opdata = reinterpret_cast<OpData*>(node->user_data); |
259 | |
260 | TfLiteTensor* output; |
261 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
262 | const TfLiteTensor* output_shape; |
263 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &output_shape)); |
264 | const TfLiteTensor* filter; |
265 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter)); |
266 | const TfLiteTensor* input; |
267 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &input)); |
268 | const TfLiteTensor* bias = GetInput(context, node, 3); |
269 | TfLiteTensor* col2im = opdata->need_col2im |
270 | ? GetTemporary(context, node, opdata->col2im_index) |
271 | : nullptr; |
272 | |
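  // If Prepare deferred the resize (non-constant output shape), do it now
  // that the shape values are available.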
273 | if (IsDynamicTensor(output)) { |
274 | TF_LITE_ENSURE_OK(context, ResizeOutputAndTemporaryTensors( |
275 | context, opdata, params, output_shape, |
276 | filter, input, col2im, output)); |
277 | } |
278 | |
279 | // GenericOptimized kernel currently doesn't support dilation. |
280 | if (params->dilation_depth_factor > 1 || params->dilation_height_factor > 1 || |
281 | params->dilation_width_factor > 1) { |
282 | kernel_type = kReference; |
283 | } |
284 | |
285 | switch (input->type) { |
286 | case kTfLiteFloat32: |
287 | EvalFloat(kernel_type, context, node, params, opdata, input, filter, bias, |
288 | col2im, output); |
289 | break; |
290 | default: |
291 | TF_LITE_KERNEL_LOG(context, "Type %s currently not supported." , |
292 | TfLiteTypeGetName(input->type)); |
293 | return kTfLiteError; |
294 | } |
295 | return kTfLiteOk; |
296 | } |
297 | |
298 | template <KernelType kernel_type> |
299 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
300 | return Eval(kernel_type, context, node); |
301 | } |
302 | |
303 | } // namespace conv3d_transpose |
304 | |
305 | TfLiteRegistration* Register_CONV_3D_TRANSPOSE_REF() { |
306 | static TfLiteRegistration r = { |
307 | conv3d_transpose::Init, conv3d_transpose::Free, |
308 | conv3d_transpose::Prepare<conv3d_transpose::kReference>, |
309 | conv3d_transpose::Eval<conv3d_transpose::kReference>}; |
310 | return &r; |
311 | } |
312 | |
313 | TfLiteRegistration* Register_CONV_3D_TRANSPOSE_GENERIC_OPT() { |
314 | static TfLiteRegistration r = { |
315 | conv3d_transpose::Init, conv3d_transpose::Free, |
316 | conv3d_transpose::Prepare<conv3d_transpose::kGenericOptimized>, |
317 | conv3d_transpose::Eval<conv3d_transpose::kGenericOptimized>}; |
318 | return &r; |
319 | } |
320 | |
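// By default the optimized kernel is registered; it still falls back to the
// reference path at runtime when any dilation factor is greater than 1.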
321 | TfLiteRegistration* Register_CONV_3D_TRANSPOSE() { |
322 | return Register_CONV_3D_TRANSPOSE_GENERIC_OPT(); |
323 | } |
324 | |
325 | } // namespace builtin |
326 | } // namespace ops |
327 | } // namespace tflite |
328 | |