1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/lite/kernels/internal/reference/conv3d.h" |
17 | |
18 | #include <cstddef> |
19 | #include <cstdint> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/lite/c/builtin_op_data.h" |
23 | #include "tensorflow/lite/c/common.h" |
24 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
25 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
26 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
27 | #include "tensorflow/lite/kernels/internal/types.h" |
28 | #include "tensorflow/lite/kernels/kernel_util.h" |
29 | #include "tensorflow/lite/kernels/padding.h" |
30 | #include "tensorflow/lite/util.h" |
31 | |
32 | namespace tflite { |
33 | namespace ops { |
34 | namespace builtin { |
35 | namespace conv3d { |
36 | |
// Selects which Conv3D implementation Eval dispatches to.
enum KernelType {
  kReference,         // Portable reference implementation.
  kGenericOptimized,  // Optimized path using im2col/transposed-filter buffers.
};
41 | |
// Sentinel meaning a temporary tensor id has not been obtained from
// TfLiteContext::AddTensors yet.
const int kTensorNotAllocated = -1;
// Upper bound for the im2col scratch buffer on mobile platforms; at or above
// this size, im2col is disabled and Eval falls back to the reference kernel.
static constexpr size_t kMaxIm2colBufferSizeMobile = 1024 * 1024 * 1024;  // 1GB

// Struct to carry data from Prepare to Eval.
struct OpData {
  Padding3DValues padding;
  // Tensor ids (indices into context->tensors) of the temporaries, or
  // kTensorNotAllocated until they are created.
  int im2col_tensor_id = kTensorNotAllocated;
  int transposed_filter_tensor_id = kTensorNotAllocated;

  // Whether the optimized kernel needs each temporary for this node.
  bool need_im2col = false;
  bool need_transposed_filter = false;

  // Disable im2col if the temporary im2col tensor requires too much memory
  // (i.e. >= kMaxIm2colBufferSizeMobile).
  bool im2col_oversized = false;

  // Positions within node->temporaries; only valid when the corresponding
  // need_* flag is true.
  int32_t im2col_index;
  int32_t transposed_filter_index;
};
61 | |
62 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
63 | auto* opdata = new OpData; |
64 | return opdata; |
65 | } |
66 | |
67 | void Free(TfLiteContext* context, void* buffer) { |
68 | delete static_cast<OpData*>(buffer); |
69 | } |
70 | |
// Decides which temporary tensors (im2col buffer, transposed filter) this
// node needs, obtains tensor ids for them on first use, and resizes
// node->temporaries to hold them. Only bookkeeping is done here; the actual
// temporary tensor shapes are set afterwards in Prepare.
TfLiteStatus AllocateTemporaryTensorsIfRequired(
    KernelType kernel_type, TfLiteContext* context, TfLiteNode* node,
    OpData* opdata, TfLiteConv3DParams* params, const TfLiteTensor* filter,
    size_t im2col_bytes) {
  int temporaries_count = 0;
  const bool need_dilated_im2col = params->dilation_width_factor != 1 ||
                                   params->dilation_height_factor != 1 ||
                                   params->dilation_depth_factor != 1;
  // Filter dims are [depth, height, width, in_channels, out_channels]; a
  // 1x1x1 filter with all-unit strides does not need im2col.
  const bool need_non_dilated_im2col =
      params->stride_depth != 1 || params->stride_width != 1 ||
      params->stride_height != 1 || filter->dims->data[2] != 1 ||
      filter->dims->data[1] != 1 || filter->dims->data[0] != 1;

  // im2col is only ever used by the optimized kernel.
  opdata->need_im2col = (kernel_type == kGenericOptimized) &&
                        (need_dilated_im2col || need_non_dilated_im2col);
  // TODO(b/183455632): Add transposing logic in converter so constant folding
  // might work on constant filter tensor.
  opdata->need_transposed_filter = (kernel_type == kGenericOptimized);

  // On mobile platforms, the generic optimized kernel will not be used if the
  // temporary im2col tensor requires too much memory. Eval checks
  // im2col_oversized and falls back to the reference path.
  if (IsMobilePlatform() && opdata->need_im2col &&
      im2col_bytes >= kMaxIm2colBufferSizeMobile) {
    opdata->need_im2col = false;
    opdata->need_transposed_filter = false;
    opdata->im2col_oversized = true;
  }

  if (opdata->need_im2col) {
    // Allocate the tensor id only once; it is reused across Prepare calls.
    if (opdata->im2col_tensor_id == kTensorNotAllocated) {
      TF_LITE_ENSURE_OK(
          context, context->AddTensors(context, 1, &opdata->im2col_tensor_id));
    }
    opdata->im2col_index = temporaries_count++;
  }

  if (opdata->need_transposed_filter) {
    if (opdata->transposed_filter_tensor_id == kTensorNotAllocated) {
      TF_LITE_ENSURE_OK(
          context, context->AddTensors(context, 1,
                                       &opdata->transposed_filter_tensor_id));
    }
    opdata->transposed_filter_index = temporaries_count++;
  }

  // Replace any previous temporaries array with one of the required size;
  // Prepare writes the tensor ids into the slots.
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(temporaries_count);
  return kTfLiteOk;
}
120 | |
121 | TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context, |
122 | TfLiteNode* node) { |
123 | auto* params = static_cast<TfLiteConv3DParams*>(node->builtin_data); |
124 | OpData* opdata = reinterpret_cast<OpData*>(node->user_data); |
125 | |
126 | // Check number of inputs/outputs. |
127 | TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3); |
128 | TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); |
129 | TfLiteTensor* output; |
130 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
131 | const TfLiteTensor* input; |
132 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
133 | const TfLiteTensor* filter; |
134 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter)); |
135 | |
136 | // Check dimensionality of input, filter. |
137 | TF_LITE_ENSURE_EQ(context, input->dims->size, 5); |
138 | TF_LITE_ENSURE_EQ(context, filter->dims->size, 5); |
139 | |
140 | // Check input channels matching filter. |
141 | TF_LITE_ENSURE_EQ(context, input->dims->data[4], filter->dims->data[3]); |
142 | |
143 | // Check types. |
144 | TfLiteType input_type = input->type; |
145 | TF_LITE_ENSURE_TYPES_EQ(context, input_type, kTfLiteFloat32); |
146 | TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32); |
147 | TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type); |
148 | |
149 | // Check bias. |
150 | const TfLiteTensor* bias = GetInput(context, node, 2); |
151 | if (bias) { |
152 | TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input_type); |
153 | TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 4)); |
154 | } |
155 | |
156 | // Filter has shape of [filter_depth, filter_height, filter_width, |
157 | // in_channels, out_channels]. |
158 | int batches = input->dims->data[0]; |
159 | int channels_out = filter->dims->data[4]; |
160 | int depth = input->dims->data[1]; |
161 | int height = input->dims->data[2]; |
162 | int width = input->dims->data[3]; |
163 | int filter_depth = filter->dims->data[0]; |
164 | int filter_height = filter->dims->data[1]; |
165 | int filter_width = filter->dims->data[2]; |
166 | int input_channel = filter->dims->data[3]; |
167 | |
168 | // Matching GetWindowedOutputSize in TensorFlow. |
169 | int out_width, out_height, out_depth; |
170 | opdata->padding = ComputePadding3DValues( |
171 | params->stride_height, params->stride_width, params->stride_depth, |
172 | params->dilation_height_factor, params->dilation_width_factor, |
173 | params->dilation_depth_factor, height, width, depth, filter_height, |
174 | filter_width, filter_depth, params->padding, &out_height, &out_width, |
175 | &out_depth); |
176 | |
177 | TfLiteIntArray* output_size = TfLiteIntArrayCreate(5); |
178 | output_size->data[0] = batches; |
179 | output_size->data[1] = out_depth; |
180 | output_size->data[2] = out_height; |
181 | output_size->data[3] = out_width; |
182 | output_size->data[4] = channels_out; |
183 | TF_LITE_ENSURE_OK(context, |
184 | context->ResizeTensor(context, output, output_size)); |
185 | |
186 | // Allocate temporary tensors. |
187 | size_t input_type_size; |
188 | TF_LITE_ENSURE_STATUS(GetSizeOfType(context, input->type, &input_type_size)); |
189 | const size_t im2col_bytes = batches * out_depth * out_height * out_width * |
190 | input_channel * filter_depth * filter_height * |
191 | filter_width * input_type_size; |
192 | TF_LITE_ENSURE_OK(context, AllocateTemporaryTensorsIfRequired( |
193 | kernel_type, context, node, opdata, params, |
194 | filter, im2col_bytes)); |
195 | |
196 | if (opdata->need_im2col) { |
197 | TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(5); |
198 | im2col_size->data[0] = output_size->data[0]; |
199 | im2col_size->data[1] = output_size->data[1]; |
200 | im2col_size->data[2] = output_size->data[2]; |
201 | im2col_size->data[3] = output_size->data[3]; |
202 | im2col_size->data[4] = |
203 | input_channel * filter_depth * filter_height * filter_width; |
204 | |
205 | TfLiteTensor* im2col; |
206 | node->temporaries->data[opdata->im2col_index] = opdata->im2col_tensor_id; |
207 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, |
208 | opdata->im2col_index, &im2col)); |
209 | im2col->type = input->type; |
210 | im2col->allocation_type = kTfLiteArenaRw; |
211 | TF_LITE_ENSURE_OK(context, |
212 | context->ResizeTensor(context, im2col, im2col_size)); |
213 | } |
214 | |
215 | if (opdata->need_transposed_filter) { |
216 | TfLiteIntArray* transposed_filter_size = TfLiteIntArrayCreate(5); |
217 | transposed_filter_size->data[0] = filter->dims->data[4]; |
218 | transposed_filter_size->data[1] = filter->dims->data[0]; |
219 | transposed_filter_size->data[2] = filter->dims->data[1]; |
220 | transposed_filter_size->data[3] = filter->dims->data[2]; |
221 | transposed_filter_size->data[4] = filter->dims->data[3]; |
222 | |
223 | TfLiteTensor* transposed_filter; |
224 | node->temporaries->data[opdata->transposed_filter_index] = |
225 | opdata->transposed_filter_tensor_id; |
226 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, |
227 | opdata->transposed_filter_index, |
228 | &transposed_filter)); |
229 | transposed_filter->type = filter->type; |
230 | transposed_filter->allocation_type = kTfLiteArenaRw; |
231 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, transposed_filter, |
232 | transposed_filter_size)); |
233 | } |
234 | return kTfLiteOk; |
235 | } |
236 | |
// Registration shim: binds the kernel type at compile time so each
// TfLiteRegistration gets a distinct Prepare entry point.
template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return Prepare(kernel_type, context, node);
}
241 | |
242 | void EvalFloat(KernelType kernel_type, TfLiteContext* context, TfLiteNode* node, |
243 | TfLiteConv3DParams* params, OpData* opdata, |
244 | const TfLiteTensor* input, const TfLiteTensor* filter, |
245 | const TfLiteTensor* bias, TfLiteTensor* im2col, |
246 | TfLiteTensor* tranposed_filter, TfLiteTensor* output) { |
247 | float output_activation_min, output_activation_max; |
248 | CalculateActivationRange(params->activation, &output_activation_min, |
249 | &output_activation_max); |
250 | |
251 | Conv3DParams runtime_params; |
252 | runtime_params.padding_values = opdata->padding; |
253 | runtime_params.stride_depth = params->stride_depth; |
254 | runtime_params.stride_height = params->stride_height; |
255 | runtime_params.stride_width = params->stride_width; |
256 | runtime_params.dilation_depth = params->dilation_depth_factor; |
257 | runtime_params.dilation_height = params->dilation_height_factor; |
258 | runtime_params.dilation_width = params->dilation_width_factor; |
259 | runtime_params.float_activation_min = output_activation_min; |
260 | runtime_params.float_activation_max = output_activation_max; |
261 | switch (kernel_type) { |
262 | case kReference: { |
263 | reference_ops::Conv3D(runtime_params, GetTensorShape(input), |
264 | GetTensorData<float>(input), GetTensorShape(filter), |
265 | GetTensorData<float>(filter), GetTensorShape(bias), |
266 | GetTensorData<float>(bias), GetTensorShape(output), |
267 | GetTensorData<float>(output)); |
268 | break; |
269 | } |
270 | case kGenericOptimized: { |
271 | optimized_ops::Conv3D( |
272 | runtime_params, GetTensorShape(input), GetTensorData<float>(input), |
273 | GetTensorShape(filter), GetTensorData<float>(filter), |
274 | GetTensorShape(bias), GetTensorData<float>(bias), |
275 | GetTensorShape(output), GetTensorData<float>(output), |
276 | GetTensorShape(im2col), GetTensorData<float>(im2col), |
277 | GetTensorShape(tranposed_filter), |
278 | GetTensorData<float>(tranposed_filter), |
279 | CpuBackendContext::GetFromContext(context)); |
280 | } break; |
281 | } |
282 | } |
283 | |
284 | TfLiteStatus Eval(KernelType kernel_type, TfLiteContext* context, |
285 | TfLiteNode* node) { |
286 | auto* params = reinterpret_cast<TfLiteConv3DParams*>(node->builtin_data); |
287 | OpData* opdata = reinterpret_cast<OpData*>(node->user_data); |
288 | |
289 | TfLiteTensor* output; |
290 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
291 | const TfLiteTensor* input; |
292 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
293 | const TfLiteTensor* filter; |
294 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter)); |
295 | const TfLiteTensor* bias = GetInput(context, node, 2); |
296 | |
297 | TfLiteTensor* im2col = opdata->need_im2col |
298 | ? &context->tensors[opdata->im2col_tensor_id] |
299 | : nullptr; |
300 | TfLiteTensor* transposed_filter = |
301 | opdata->need_transposed_filter |
302 | ? &context->tensors[opdata->transposed_filter_tensor_id] |
303 | : nullptr; |
304 | |
305 | // Fallback to reference execution path when im2col is needed but disabled. |
306 | if (opdata->im2col_oversized) { |
307 | kernel_type = kReference; |
308 | } |
309 | |
310 | switch (input->type) { |
311 | case kTfLiteFloat32: |
312 | EvalFloat(kernel_type, context, node, params, opdata, input, filter, bias, |
313 | im2col, transposed_filter, output); |
314 | break; |
315 | default: |
316 | TF_LITE_KERNEL_LOG(context, "Type %s currently not supported." , |
317 | TfLiteTypeGetName(input->type)); |
318 | return kTfLiteError; |
319 | } |
320 | return kTfLiteOk; |
321 | } |
322 | |
// Registration shim: binds the kernel type at compile time so each
// TfLiteRegistration gets a distinct Eval entry point.
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  return Eval(kernel_type, context, node);
}
327 | |
328 | } // namespace conv3d |
329 | |
// Registration for the portable reference Conv3D kernel.
TfLiteRegistration* Register_CONV_3D_REF() {
  static TfLiteRegistration r = {conv3d::Init, conv3d::Free,
                                 conv3d::Prepare<conv3d::kReference>,
                                 conv3d::Eval<conv3d::kReference>};
  return &r;
}
336 | |
// Registration for the optimized Conv3D kernel (uses im2col and a transposed
// filter temporary where beneficial).
TfLiteRegistration* Register_CONV_3D_GENERIC_OPT() {
  static TfLiteRegistration r = {conv3d::Init, conv3d::Free,
                                 conv3d::Prepare<conv3d::kGenericOptimized>,
                                 conv3d::Eval<conv3d::kGenericOptimized>};
  return &r;
}
343 | |
// Default Conv3D registration: the generic optimized kernel.
TfLiteRegistration* Register_CONV_3D() {
  return Register_CONV_3D_GENERIC_OPT();
}
347 | |
348 | } // namespace builtin |
349 | } // namespace ops |
350 | } // namespace tflite |
351 | |