1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/lite/kernels/internal/reference/conv3d.h"
17
18#include <cstddef>
19#include <cstdint>
20#include <vector>
21
22#include "tensorflow/lite/c/builtin_op_data.h"
23#include "tensorflow/lite/c/common.h"
24#include "tensorflow/lite/kernels/cpu_backend_context.h"
25#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
26#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27#include "tensorflow/lite/kernels/internal/types.h"
28#include "tensorflow/lite/kernels/kernel_util.h"
29#include "tensorflow/lite/kernels/padding.h"
30#include "tensorflow/lite/util.h"
31
32namespace tflite {
33namespace ops {
34namespace builtin {
35namespace conv3d {
36
// Selects which Conv3D implementation Eval dispatches to.
enum KernelType {
  kReference,         // Portable reference implementation.
  kGenericOptimized,  // Optimized path; may use im2col and a transposed
                      // filter temporary (set up in Prepare).
};
41
// Sentinel: the temporary tensor has not been registered with the context yet.
const int kTensorNotAllocated = -1;
// Im2col is disabled on mobile platforms when the buffer would reach this
// size (see AllocateTemporaryTensorsIfRequired).
static constexpr size_t kMaxIm2colBufferSizeMobile = 1024 * 1024 * 1024;  // 1GB

// Struct to carry data from Prepare to Eval.
struct OpData {
  Padding3DValues padding;
  // Ids (indices into TfLiteContext::tensors) of the temporary tensors, or
  // kTensorNotAllocated if they have not been created.
  int im2col_tensor_id = kTensorNotAllocated;
  int transposed_filter_tensor_id = kTensorNotAllocated;

  // Whether the optimized kernel needs each temporary for the current shapes.
  bool need_im2col = false;
  bool need_transposed_filter = false;

  // Disable im2col if the temporary im2col tensor requires too much memory
  // (i.e. >= kMaxIm2colBufferSizeMobile).
  bool im2col_oversized = false;

  // Positions of the temporaries within node->temporaries (only valid when
  // the corresponding need_* flag is true).
  int32_t im2col_index;
  int32_t transposed_filter_index;
};
61
62void* Init(TfLiteContext* context, const char* buffer, size_t length) {
63 auto* opdata = new OpData;
64 return opdata;
65}
66
67void Free(TfLiteContext* context, void* buffer) {
68 delete static_cast<OpData*>(buffer);
69}
70
71TfLiteStatus AllocateTemporaryTensorsIfRequired(
72 KernelType kernel_type, TfLiteContext* context, TfLiteNode* node,
73 OpData* opdata, TfLiteConv3DParams* params, const TfLiteTensor* filter,
74 size_t im2col_bytes) {
75 int temporaries_count = 0;
76 const bool need_dilated_im2col = params->dilation_width_factor != 1 ||
77 params->dilation_height_factor != 1 ||
78 params->dilation_depth_factor != 1;
79 const bool need_non_dilated_im2col =
80 params->stride_depth != 1 || params->stride_width != 1 ||
81 params->stride_height != 1 || filter->dims->data[2] != 1 ||
82 filter->dims->data[1] != 1 || filter->dims->data[0] != 1;
83
84 opdata->need_im2col = (kernel_type == kGenericOptimized) &&
85 (need_dilated_im2col || need_non_dilated_im2col);
86 // TODO(b/183455632): Add transposing logic in converter so constant folding
87 // might work on constant filter tensor.
88 opdata->need_transposed_filter = (kernel_type == kGenericOptimized);
89
90 // On mobile platforms, the generic optimized kernel will not be used if the
91 // temporary im2col tensor requires too much memory.
92 if (IsMobilePlatform() && opdata->need_im2col &&
93 im2col_bytes >= kMaxIm2colBufferSizeMobile) {
94 opdata->need_im2col = false;
95 opdata->need_transposed_filter = false;
96 opdata->im2col_oversized = true;
97 }
98
99 if (opdata->need_im2col) {
100 if (opdata->im2col_tensor_id == kTensorNotAllocated) {
101 TF_LITE_ENSURE_OK(
102 context, context->AddTensors(context, 1, &opdata->im2col_tensor_id));
103 }
104 opdata->im2col_index = temporaries_count++;
105 }
106
107 if (opdata->need_transposed_filter) {
108 if (opdata->transposed_filter_tensor_id == kTensorNotAllocated) {
109 TF_LITE_ENSURE_OK(
110 context, context->AddTensors(context, 1,
111 &opdata->transposed_filter_tensor_id));
112 }
113 opdata->transposed_filter_index = temporaries_count++;
114 }
115
116 TfLiteIntArrayFree(node->temporaries);
117 node->temporaries = TfLiteIntArrayCreate(temporaries_count);
118 return kTfLiteOk;
119}
120
// Validates inputs, computes padding and the output shape, and sets up the
// optional im2col / transposed-filter temporaries for the optimized kernel.
TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
                     TfLiteNode* node) {
  auto* params = static_cast<TfLiteConv3DParams*>(node->builtin_data);
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);

  // Check number of inputs/outputs. Bias (input 2) is optional.
  TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));

  // Check dimensionality of input, filter.
  TF_LITE_ENSURE_EQ(context, input->dims->size, 5);
  TF_LITE_ENSURE_EQ(context, filter->dims->size, 5);

  // Check input channels matching filter.
  TF_LITE_ENSURE_EQ(context, input->dims->data[4], filter->dims->data[3]);

  // Check types. Only float32 is supported.
  TfLiteType input_type = input->type;
  TF_LITE_ENSURE_TYPES_EQ(context, input_type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);

  // Check bias; bias may be null when the node has only two inputs.
  const TfLiteTensor* bias = GetInput(context, node, 2);
  if (bias) {
    TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input_type);
    // One bias element per output channel.
    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 4));
  }

  // Filter has shape of [filter_depth, filter_height, filter_width,
  // in_channels, out_channels]. Input is NDHWC.
  int batches = input->dims->data[0];
  int channels_out = filter->dims->data[4];
  int depth = input->dims->data[1];
  int height = input->dims->data[2];
  int width = input->dims->data[3];
  int filter_depth = filter->dims->data[0];
  int filter_height = filter->dims->data[1];
  int filter_width = filter->dims->data[2];
  int input_channel = filter->dims->data[3];

  // Matching GetWindowedOutputSize in TensorFlow.
  int out_width, out_height, out_depth;
  opdata->padding = ComputePadding3DValues(
      params->stride_height, params->stride_width, params->stride_depth,
      params->dilation_height_factor, params->dilation_width_factor,
      params->dilation_depth_factor, height, width, depth, filter_height,
      filter_width, filter_depth, params->padding, &out_height, &out_width,
      &out_depth);

  // Output is NDHWC.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(5);
  output_size->data[0] = batches;
  output_size->data[1] = out_depth;
  output_size->data[2] = out_height;
  output_size->data[3] = out_width;
  output_size->data[4] = channels_out;
  // ResizeTensor takes ownership of output_size; on success it becomes
  // output->dims, so reading output_size->data below remains valid.
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // Allocate temporary tensors. im2col_bytes is only an estimate used to
  // decide whether the im2col optimization is affordable on mobile.
  size_t input_type_size;
  TF_LITE_ENSURE_STATUS(GetSizeOfType(context, input->type, &input_type_size));
  const size_t im2col_bytes = batches * out_depth * out_height * out_width *
                              input_channel * filter_depth * filter_height *
                              filter_width * input_type_size;
  TF_LITE_ENSURE_OK(context, AllocateTemporaryTensorsIfRequired(
                                 kernel_type, context, node, opdata, params,
                                 filter, im2col_bytes));

  if (opdata->need_im2col) {
    // im2col shape: [N, D_out, H_out, W_out, in_ch * kD * kH * kW].
    TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(5);
    im2col_size->data[0] = output_size->data[0];
    im2col_size->data[1] = output_size->data[1];
    im2col_size->data[2] = output_size->data[2];
    im2col_size->data[3] = output_size->data[3];
    im2col_size->data[4] =
        input_channel * filter_depth * filter_height * filter_width;

    TfLiteTensor* im2col;
    node->temporaries->data[opdata->im2col_index] = opdata->im2col_tensor_id;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node,
                                                opdata->im2col_index, &im2col));
    im2col->type = input->type;
    im2col->allocation_type = kTfLiteArenaRw;
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, im2col, im2col_size));
  }

  if (opdata->need_transposed_filter) {
    // Transposed filter layout: [out_ch, kD, kH, kW, in_ch] — the
    // out_channels dimension is moved to the front.
    TfLiteIntArray* transposed_filter_size = TfLiteIntArrayCreate(5);
    transposed_filter_size->data[0] = filter->dims->data[4];
    transposed_filter_size->data[1] = filter->dims->data[0];
    transposed_filter_size->data[2] = filter->dims->data[1];
    transposed_filter_size->data[3] = filter->dims->data[2];
    transposed_filter_size->data[4] = filter->dims->data[3];

    TfLiteTensor* transposed_filter;
    node->temporaries->data[opdata->transposed_filter_index] =
        opdata->transposed_filter_tensor_id;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node,
                                                opdata->transposed_filter_index,
                                                &transposed_filter));
    transposed_filter->type = filter->type;
    transposed_filter->allocation_type = kTfLiteArenaRw;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, transposed_filter,
                                                     transposed_filter_size));
  }
  return kTfLiteOk;
}
236
// Templated shim so each registration binds a fixed kernel type at compile
// time; forwards to the runtime-dispatched Prepare above.
template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return Prepare(kernel_type, context, node);
}
241
242void EvalFloat(KernelType kernel_type, TfLiteContext* context, TfLiteNode* node,
243 TfLiteConv3DParams* params, OpData* opdata,
244 const TfLiteTensor* input, const TfLiteTensor* filter,
245 const TfLiteTensor* bias, TfLiteTensor* im2col,
246 TfLiteTensor* tranposed_filter, TfLiteTensor* output) {
247 float output_activation_min, output_activation_max;
248 CalculateActivationRange(params->activation, &output_activation_min,
249 &output_activation_max);
250
251 Conv3DParams runtime_params;
252 runtime_params.padding_values = opdata->padding;
253 runtime_params.stride_depth = params->stride_depth;
254 runtime_params.stride_height = params->stride_height;
255 runtime_params.stride_width = params->stride_width;
256 runtime_params.dilation_depth = params->dilation_depth_factor;
257 runtime_params.dilation_height = params->dilation_height_factor;
258 runtime_params.dilation_width = params->dilation_width_factor;
259 runtime_params.float_activation_min = output_activation_min;
260 runtime_params.float_activation_max = output_activation_max;
261 switch (kernel_type) {
262 case kReference: {
263 reference_ops::Conv3D(runtime_params, GetTensorShape(input),
264 GetTensorData<float>(input), GetTensorShape(filter),
265 GetTensorData<float>(filter), GetTensorShape(bias),
266 GetTensorData<float>(bias), GetTensorShape(output),
267 GetTensorData<float>(output));
268 break;
269 }
270 case kGenericOptimized: {
271 optimized_ops::Conv3D(
272 runtime_params, GetTensorShape(input), GetTensorData<float>(input),
273 GetTensorShape(filter), GetTensorData<float>(filter),
274 GetTensorShape(bias), GetTensorData<float>(bias),
275 GetTensorShape(output), GetTensorData<float>(output),
276 GetTensorShape(im2col), GetTensorData<float>(im2col),
277 GetTensorShape(tranposed_filter),
278 GetTensorData<float>(tranposed_filter),
279 CpuBackendContext::GetFromContext(context));
280 } break;
281 }
282}
283
284TfLiteStatus Eval(KernelType kernel_type, TfLiteContext* context,
285 TfLiteNode* node) {
286 auto* params = reinterpret_cast<TfLiteConv3DParams*>(node->builtin_data);
287 OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
288
289 TfLiteTensor* output;
290 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
291 const TfLiteTensor* input;
292 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
293 const TfLiteTensor* filter;
294 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
295 const TfLiteTensor* bias = GetInput(context, node, 2);
296
297 TfLiteTensor* im2col = opdata->need_im2col
298 ? &context->tensors[opdata->im2col_tensor_id]
299 : nullptr;
300 TfLiteTensor* transposed_filter =
301 opdata->need_transposed_filter
302 ? &context->tensors[opdata->transposed_filter_tensor_id]
303 : nullptr;
304
305 // Fallback to reference execution path when im2col is needed but disabled.
306 if (opdata->im2col_oversized) {
307 kernel_type = kReference;
308 }
309
310 switch (input->type) {
311 case kTfLiteFloat32:
312 EvalFloat(kernel_type, context, node, params, opdata, input, filter, bias,
313 im2col, transposed_filter, output);
314 break;
315 default:
316 TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
317 TfLiteTypeGetName(input->type));
318 return kTfLiteError;
319 }
320 return kTfLiteOk;
321}
322
// Templated shim binding a fixed kernel type for registration; forwards to
// the runtime-dispatched Eval above.
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  return Eval(kernel_type, context, node);
}
327
328} // namespace conv3d
329
// Registration for the portable reference kernel (no temporary buffers).
TfLiteRegistration* Register_CONV_3D_REF() {
  static TfLiteRegistration r = {/*init=*/conv3d::Init, /*free=*/conv3d::Free,
                                 /*prepare=*/conv3d::Prepare<conv3d::kReference>,
                                 /*invoke=*/conv3d::Eval<conv3d::kReference>};
  return &r;
}
336
// Registration for the generic optimized kernel (may use im2col and a
// transposed filter temporary; falls back to reference when oversized).
TfLiteRegistration* Register_CONV_3D_GENERIC_OPT() {
  static TfLiteRegistration r = {
      /*init=*/conv3d::Init, /*free=*/conv3d::Free,
      /*prepare=*/conv3d::Prepare<conv3d::kGenericOptimized>,
      /*invoke=*/conv3d::Eval<conv3d::kGenericOptimized>};
  return &r;
}
343
// Default CONV_3D registration: the generic optimized kernel.
TfLiteRegistration* Register_CONV_3D() {
  return Register_CONV_3D_GENERIC_OPT();
}
347
348} // namespace builtin
349} // namespace ops
350} // namespace tflite
351