/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/conv3d_transpose.h"

#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace conv3d_transpose {

enum KernelType {
  kReference,
  kGenericOptimized,
};

const int kTensorNotAllocated = -1;

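// Op-specific data, created in Init and persisted across Prepare/Eval calls
// through node->user_data.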
struct OpData {
  Padding3DValues padding;

  // The id of the temporary col2im tensor.
  int col2im_id = kTensorNotAllocated;

  // The index of the col2im tensor in the temporaries list.
  int col2im_index;

  bool need_col2im = false;
};

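// Init/Free own the OpData instance stored in node->user_data.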
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* opdata = new OpData;
  return opdata;
}

void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

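// Requests the temporary col2im tensor needed by the GEMM-based optimized
// kernel; the reference kernel needs no temporaries.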
static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
                                                       TfLiteNode* node,
                                                       KernelType kernel_type) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  int temporaries_count = 0;

  // Allocate col2im tensor for the optimized kernel.
  if (kernel_type == kGenericOptimized) {
    if (data->col2im_id == kTensorNotAllocated) {
      TF_LITE_ENSURE_OK(context,
                        context->AddTensors(context, 1, &data->col2im_id));
    }
    data->col2im_index = temporaries_count++;
    data->need_col2im = true;
  }

  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(temporaries_count);

  return kTfLiteOk;
}

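// Validates the requested output shape against the input/filter dims and
// resizes the output (and, for the optimized kernel, the col2im temporary).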
TfLiteStatus ResizeOutputAndTemporaryTensors(
    TfLiteContext* context, OpData* opdata, TfLiteConv3DTransposeParams* params,
    const TfLiteTensor* shape_tensor, const TfLiteTensor* filter,
    const TfLiteTensor* input, TfLiteTensor* col2im, TfLiteTensor* output) {
  const int32_t* shape_data = GetTensorData<int32_t>(shape_tensor);
  // The output and input tensors must have the same batch size.
  TF_LITE_ENSURE_EQ(context, shape_data[0], SizeOfDimension(input, 0));
  // The output channel count must be divisible by the filter's output channel
  // dimension.
  TF_LITE_ENSURE_EQ(context, shape_data[4] % SizeOfDimension(filter, 3), 0);

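  // ComputePadding3DValues runs the forward-convolution shape arithmetic on
  // the requested output dims; for a valid transposed convolution the result
  // must reproduce the input's spatial dims (checked below).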
  // Compute padding.
  const RuntimeShape& filter_shape = GetTensorShape(filter);
  const int depth = shape_data[1];
  const int height = shape_data[2];
  const int width = shape_data[3];
  const int filter_depth = filter_shape.Dims(0);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  int unused_out_width, unused_out_height, unused_out_depth;
  opdata->padding = ComputePadding3DValues(
      params->stride_height, params->stride_width, params->stride_depth,
      params->dilation_height_factor, params->dilation_width_factor,
      params->dilation_depth_factor, height, width, depth, filter_height,
      filter_width, filter_depth, params->padding, &unused_out_height,
      &unused_out_width, &unused_out_depth);
  // The computed forward-conv output dims must match the input's spatial dims.
  TF_LITE_ENSURE_EQ(context, unused_out_depth, SizeOfDimension(input, 1));
  TF_LITE_ENSURE_EQ(context, unused_out_height, SizeOfDimension(input, 2));
  TF_LITE_ENSURE_EQ(context, unused_out_width, SizeOfDimension(input, 3));

  TfLiteIntArray* output_shape =
      TfLiteIntArrayCreate(NumElements(shape_tensor));
  for (int i = 0; i < output_shape->size; ++i) {
    output_shape->data[i] = shape_data[i];
  }

  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));

  // Resize col2im tensor.
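  // The col2im buffer is the 2-D GEMM workspace: one row per input spatial
  // position, one column per (filter voxel, output channel) pair.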
  if (opdata->need_col2im) {
    TfLiteIntArray* col2im_shape_array = TfLiteIntArrayCreate(2);
    const RuntimeShape& input_shape = GetTensorShape(input);
    col2im_shape_array->data[0] =
        input_shape.Dims(1) * input_shape.Dims(2) * input_shape.Dims(3);
    col2im_shape_array->data[1] =
        filter_depth * filter_height * filter_width * filter_shape.Dims(3);

    col2im->type = kTfLiteFloat32;
    col2im->allocation_type = kTfLiteDynamic;
    return context->ResizeTensor(context, col2im, col2im_shape_array);
  }
  return kTfLiteOk;
}

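// Prepare: validates ranks and types, allocates temporaries, and resizes the
// output when the shape tensor is constant. Inputs are ordered as
// 0: output shape (1-D int32, 5 elements), 1: filter, 2: input,
// 3: optional bias.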
TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
                     TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteConv3DTransposeParams*>(node->builtin_data);
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
  // Check number of inputs/outputs.
  TF_LITE_ENSURE(context, node->inputs->size == 3 || node->inputs->size == 4);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  const TfLiteTensor* output_shape;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &output_shape));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &input));

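  // Expected layouts (implied by the checks below): input is
  // [batch, depth, height, width, in_channels] and filter is
  // [depth, height, width, out_channels, in_channels].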
  // Check dimensionality of inputs/outputs.
  TF_LITE_ENSURE_EQ(context, output_shape->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 5);
  TF_LITE_ENSURE_EQ(context, input->dims->size, 5);
  TF_LITE_ENSURE_EQ(context, filter->dims->size, 5);

  // Input and filter must have the same number of channels.
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 4),
                    SizeOfDimension(filter, 4));

  // Check types.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
  TF_LITE_ENSURE_TYPES_EQ(context, output_shape->type, kTfLiteInt32);

  // Check the optional bias; GetInput returns nullptr when it is absent.
  const TfLiteTensor* bias = GetInput(context, node, 3);
  if (bias) {
    TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input->type);
    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 3));
  }

  // The GenericOptimized kernel currently doesn't support dilation; fall back
  // to the reference kernel.
  if (params->dilation_depth_factor > 1 ||
      params->dilation_height_factor > 1 ||
      params->dilation_width_factor > 1) {
    kernel_type = kReference;
  }

  // Allocate temporary tensors.
  TF_LITE_ENSURE_STATUS(
      AllocateTemporaryTensorsIfRequired(context, node, kernel_type));

  // Check temporary tensors.
  TfLiteTensor* col2im = nullptr;
  if (opdata->need_col2im) {
    node->temporaries->data[opdata->col2im_index] = opdata->col2im_id;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node,
                                                opdata->col2im_index, &col2im));
  }

  // Resize the output tensor now if the shape is constant; otherwise defer
  // to Eval.
  if (!IsConstantTensor(output_shape)) {
    SetTensorToDynamic(output);
    if (opdata->need_col2im) {
      SetTensorToDynamic(col2im);
    }
  } else {
    TF_LITE_ENSURE_STATUS(ResizeOutputAndTemporaryTensors(
        context, opdata, params, output_shape, filter, input, col2im, output));
  }
  return kTfLiteOk;
}

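// Template wrappers bind the kernel type at registration time.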
template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return Prepare(kernel_type, context, node);
}

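// Packs the builtin params and computed padding into Conv3DTransposeParams
// and dispatches to the reference or optimized float implementation.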
void EvalFloat(KernelType kernel_type, TfLiteContext* context, TfLiteNode* node,
               TfLiteConv3DTransposeParams* params, OpData* opdata,
               const TfLiteTensor* input, const TfLiteTensor* filter,
               const TfLiteTensor* bias, TfLiteTensor* col2im,
               TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);

  Conv3DTransposeParams runtime_params;
  runtime_params.padding_values = opdata->padding;
  runtime_params.stride_depth = params->stride_depth;
  runtime_params.stride_height = params->stride_height;
  runtime_params.stride_width = params->stride_width;
  runtime_params.dilation_depth = params->dilation_depth_factor;
  runtime_params.dilation_height = params->dilation_height_factor;
  runtime_params.dilation_width = params->dilation_width_factor;
  runtime_params.float_activation_min = output_activation_min;
  runtime_params.float_activation_max = output_activation_max;

  switch (kernel_type) {
    case kReference: {
      reference_ops::Conv3DTranspose(
          runtime_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(filter), GetTensorData<float>(filter),
          GetTensorShape(bias), GetTensorData<float>(bias),
          GetTensorShape(output), GetTensorData<float>(output));
      break;
    }
    case kGenericOptimized: {
      optimized_ops::Conv3DTranspose(
          runtime_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(filter), GetTensorData<float>(filter),
          GetTensorShape(bias), GetTensorData<float>(bias),
          GetTensorShape(output), GetTensorData<float>(output),
          GetTensorShape(col2im), GetTensorData<float>(col2im),
          CpuBackendContext::GetFromContext(context));
      break;
    }
  }
}

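// Eval: fetches tensors, resizes dynamic outputs, and dispatches on the
// input type.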
TfLiteStatus Eval(KernelType kernel_type, TfLiteContext* context,
                  TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteConv3DTransposeParams*>(node->builtin_data);
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  const TfLiteTensor* output_shape;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &output_shape));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 2, &input));
  const TfLiteTensor* bias = GetInput(context, node, 3);
  TfLiteTensor* col2im = opdata->need_col2im
                             ? GetTemporary(context, node, opdata->col2im_index)
                             : nullptr;

  if (IsDynamicTensor(output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputAndTemporaryTensors(
                                   context, opdata, params, output_shape,
                                   filter, input, col2im, output));
  }

  // The GenericOptimized kernel currently doesn't support dilation; fall back
  // to the reference kernel.
  if (params->dilation_depth_factor > 1 ||
      params->dilation_height_factor > 1 ||
      params->dilation_width_factor > 1) {
    kernel_type = kReference;
  }

  switch (input->type) {
    case kTfLiteFloat32:
      EvalFloat(kernel_type, context, node, params, opdata, input, filter,
                bias, col2im, output);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  return Eval(kernel_type, context, node);
}

}  // namespace conv3d_transpose

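// Registrations: TfLiteRegistration is initialized as {init, free, prepare,
// invoke}. _REF always uses the reference kernel, while the default
// registration uses the col2im/GEMM-based optimized kernel.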
TfLiteRegistration* Register_CONV_3D_TRANSPOSE_REF() {
  static TfLiteRegistration r = {
      conv3d_transpose::Init, conv3d_transpose::Free,
      conv3d_transpose::Prepare<conv3d_transpose::kReference>,
      conv3d_transpose::Eval<conv3d_transpose::kReference>};
  return &r;
}

TfLiteRegistration* Register_CONV_3D_TRANSPOSE_GENERIC_OPT() {
  static TfLiteRegistration r = {
      conv3d_transpose::Init, conv3d_transpose::Free,
      conv3d_transpose::Prepare<conv3d_transpose::kGenericOptimized>,
      conv3d_transpose::Eval<conv3d_transpose::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_CONV_3D_TRANSPOSE() {
  return Register_CONV_3D_TRANSPOSE_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite