1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define USE_EIGEN_TENSOR |
17 | #define EIGEN_USE_THREADS |
18 | |
19 | #include <utility> |
20 | |
21 | #include "tensorflow/core/framework/kernel_shape_util.h" |
22 | #include "tensorflow/core/framework/numeric_op.h" |
23 | #include "tensorflow/core/framework/op_kernel.h" |
24 | #include "tensorflow/core/framework/register_types.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor_shape.h" |
27 | #include "tensorflow/core/framework/tensor_slice.h" |
28 | #include "tensorflow/core/framework/tensor_util.h" |
29 | #include "tensorflow/core/kernels/conv_2d.h" |
30 | #include "tensorflow/core/kernels/conv_3d.h" |
31 | #include "tensorflow/core/kernels/conv_grad_ops.h" |
32 | #include "tensorflow/core/kernels/conv_grad_shape_utils.h" |
33 | #include "tensorflow/core/kernels/conv_ops_gpu.h" |
34 | #include "tensorflow/core/lib/core/errors.h" |
35 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
36 | #include "tensorflow/core/profiler/lib/scoped_annotation.h" |
37 | #include "tensorflow/core/util/padding.h" |
38 | #include "tensorflow/core/util/tensor_format.h" |
39 | #include "tensorflow/core/util/use_cudnn.h" |
40 | #include "tensorflow/core/util/work_sharder.h" |
41 | |
42 | #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) |
43 | #include "tensorflow/core/kernels/eigen_contraction_kernel.h" |
44 | #endif |
45 | |
46 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
47 | #include "tensorflow/core/platform/stream_executor.h" |
48 | using stream_executor::dnn::DimIndex; |
49 | #include "tensorflow/core/protobuf/autotuning.pb.h" |
50 | #include "tensorflow/core/util/autotune_maps/conv_parameters.h" |
51 | #include "tensorflow/core/util/proto/proto_utils.h" |
52 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
53 | #if GOOGLE_CUDA |
54 | #include "third_party/gpus/cudnn/cudnn.h" |
55 | #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h" |
56 | #include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" |
57 | #include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h" |
58 | #endif // GOOGLE_CUDA |
59 | |
60 | namespace { |
61 | |
62 | // TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and |
63 | // conv_grad_input_ops_3d.cc. |
64 | |
65 | // TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels. |
66 | |
67 | // "Depth" is already used for the channel dimension, so for the third spatial |
68 | // dimension in this file we use "plane", although in NDHWC layout it's |
69 | // indicated with a "D". |
70 | |
71 | // Returns in 'im_data' (assumed to be zero-initialized) image patch in storage |
72 | // order (planes, height, width, depth), constructed from patches in 'col_data', |
73 | // which is required to be in storage order (out_planes * out_height * |
74 | // out_width, filter_planes, filter_height, filter_width, in_depth). |
75 | // |
76 | // Based on 2-dimensional implementation written by Yangqing Jia (jiayq). |
77 | template <typename T> |
78 | void Col2im(const T* col_data, const int depth, const int planes, |
79 | const int height, const int width, const int filter_p, |
80 | const int filter_h, const int filter_w, const int pad_pt, |
81 | const int pad_t, const int pad_l, const int pad_pb, const int pad_b, |
82 | const int pad_r, const int stride_p, const int stride_h, |
83 | const int stride_w, T* im_data) { |
84 | const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1; |
85 | const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; |
86 | const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; |
87 | int p_pad = -pad_pt; |
88 | for (int p = 0; p < planes_col; ++p) { |
89 | int h_pad = -pad_t; |
90 | for (int h = 0; h < height_col; ++h) { |
91 | int w_pad = -pad_l; |
92 | for (int w = 0; w < width_col; ++w) { |
93 | T* im_patch_data = |
94 | im_data + (p_pad * height * width + h_pad * width + w_pad) * depth; |
95 | for (int ip = p_pad; ip < p_pad + filter_p; ++ip) { |
96 | for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { |
97 | for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { |
98 | if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 && |
99 | iw < width) { |
100 | for (int i = 0; i < depth; ++i) { |
101 | im_patch_data[i] += col_data[i]; |
102 | } |
103 | } |
104 | im_patch_data += depth; |
105 | col_data += depth; |
106 | } |
107 | // Jump over remaining number of depth. |
108 | im_patch_data += depth * (width - filter_w); |
109 | } |
110 | // Jump over remaining number of (depth * width). |
111 | im_patch_data += (depth * width) * (height - filter_h); |
112 | } |
113 | w_pad += stride_w; |
114 | } |
115 | h_pad += stride_h; |
116 | } |
117 | p_pad += stride_p; |
118 | } |
119 | } |
120 | |
121 | // Returns in 'col_data', image patches in storage order (planes, height, width, |
122 | // depth) extracted from image at 'input_data', which is required to be in |
123 | // storage order (batch, planes, height, width, depth). |
124 | // |
125 | // Based on 2-dimensional implementation written by Yangqing Jia (jiayq). |
126 | template <typename T> |
127 | void Im2col(const T* input_data, const int depth, const int planes, |
128 | const int height, const int width, const int filter_p, |
129 | const int filter_h, const int filter_w, const int pad_pt, |
130 | const int pad_t, const int pad_l, const int pad_pb, const int pad_b, |
131 | const int pad_r, const int stride_p, const int stride_h, |
132 | const int stride_w, T* col_data) { |
133 | const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1; |
134 | const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; |
135 | const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; |
136 | |
137 | int p_pad = -pad_pt; |
138 | for (int p = 0; p < planes_col; ++p) { |
139 | int h_pad = -pad_t; |
140 | for (int h = 0; h < height_col; ++h) { |
141 | int w_pad = -pad_l; |
142 | for (int w = 0; w < width_col; ++w) { |
143 | for (int ip = p_pad; ip < p_pad + filter_p; ++ip) { |
144 | for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { |
145 | for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { |
146 | if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 && |
147 | iw < width) { |
148 | memcpy(col_data, |
149 | input_data + |
150 | (ip * height * width + ih * width + iw) * depth, |
151 | sizeof(T) * depth); |
152 | } else { |
153 | // This should be simply padded with zero. |
154 | memset(col_data, 0, sizeof(T) * depth); |
155 | } |
156 | col_data += depth; |
157 | } |
158 | } |
159 | } |
160 | w_pad += stride_w; |
161 | } |
162 | h_pad += stride_h; |
163 | } |
164 | p_pad += stride_p; |
165 | } |
166 | } |
167 | |
168 | } // namespace |
169 | |
170 | namespace tensorflow { |
171 | |
172 | typedef Eigen::ThreadPoolDevice CPUDevice; |
173 | typedef Eigen::GpuDevice GPUDevice; |
174 | |
175 | // Backprop for input that offloads computation to |
176 | // Eigen::CuboidConvolutionBackwardInput. |
177 | template <typename Device, class T> |
178 | class Conv3DBackpropInputOp : public OpKernel { |
179 | public: |
180 | explicit Conv3DBackpropInputOp(OpKernelConstruction* context) |
181 | : OpKernel(context), |
182 | data_format_(FORMAT_NHWC), |
183 | takes_shape_(type_string().find("V2" ) != std::string::npos) { |
184 | // data_format is only available in V2. |
185 | if (takes_shape_) { |
186 | string data_format; |
187 | OP_REQUIRES_OK(context, context->GetAttr("data_format" , &data_format)); |
188 | OP_REQUIRES(context, FormatFromString(data_format, &data_format_), |
189 | errors::InvalidArgument("Invalid data format" )); |
190 | OP_REQUIRES( |
191 | context, data_format_ == FORMAT_NHWC, |
192 | errors::InvalidArgument( |
193 | "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU." )); |
194 | } |
195 | |
196 | OP_REQUIRES_OK(context, context->GetAttr("dilations" , &dilation_)); |
197 | OP_REQUIRES(context, dilation_.size() == 5, |
198 | errors::InvalidArgument("Dilation rates field must " |
199 | "specify 5 dimensions" )); |
200 | OP_REQUIRES(context, |
201 | (GetTensorDim(dilation_, data_format_, 'C') == 1 && |
202 | GetTensorDim(dilation_, data_format_, 'N') == 1), |
203 | errors::InvalidArgument( |
204 | "Current implementation does not yet support " |
205 | "dilation rates in the batch and depth dimensions." )); |
206 | |
207 | // TODO(yangzihao): Add CPU version of dilated conv 3D. |
208 | OP_REQUIRES(context, |
209 | (GetTensorDim(dilation_, data_format_, '0') == 1 && |
210 | GetTensorDim(dilation_, data_format_, '1') == 1 && |
211 | GetTensorDim(dilation_, data_format_, '2') == 1), |
212 | errors::InvalidArgument( |
213 | "Current CPU implementation does not yet support " |
214 | "dilation rates larger than 1." )); |
215 | |
216 | OP_REQUIRES_OK(context, context->GetAttr("strides" , &stride_)); |
217 | OP_REQUIRES(context, stride_.size() == 5, |
218 | errors::InvalidArgument("Sliding window strides field must " |
219 | "specify 5 dimensions" )); |
220 | OP_REQUIRES( |
221 | context, |
222 | (GetTensorDim(stride_, data_format_, 'C') == 1 && |
223 | GetTensorDim(stride_, data_format_, 'N') == 1), |
224 | errors::InvalidArgument("Current implementation does not yet support " |
225 | "strides in the batch and depth dimensions." )); |
226 | OP_REQUIRES_OK(context, context->GetAttr("padding" , &padding_)); |
227 | } |
228 | |
229 | void Compute(OpKernelContext* context) override { |
230 | const Tensor& filter = context->input(1); |
231 | const TensorShape& filter_shape = filter.shape(); |
232 | |
233 | const Tensor& out_backprop = context->input(2); |
234 | const TensorShape& out_backprop_shape = out_backprop.shape(); |
235 | |
236 | TensorShape input_shape; |
237 | if (takes_shape_) { |
238 | const Tensor& input_sizes = context->input(0); |
239 | // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for |
240 | // input_sizes. |
241 | OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape)); |
242 | } else { |
243 | input_shape = context->input(0).shape(); |
244 | } |
245 | |
246 | OP_REQUIRES(context, input_shape.dims() == 5, |
247 | errors::InvalidArgument("input tensor must have 5 dimensions" )); |
248 | OP_REQUIRES( |
249 | context, filter_shape.dims() == 5, |
250 | errors::InvalidArgument("filter_sizes tensor must have 5 dimensions" )); |
251 | OP_REQUIRES( |
252 | context, out_backprop_shape.dims() == 5, |
253 | errors::InvalidArgument("out_backprop tensor must have 5 dimensions" )); |
254 | OP_REQUIRES( |
255 | context, input_shape.dim_size(4) == filter_shape.dim_size(3), |
256 | errors::InvalidArgument("input and filter_sizes must have the same " |
257 | "number of channels. Got " , |
258 | input_shape.dim_size(4), " for input and " , |
259 | filter_shape.dim_size(3), " for filter_sizes" )); |
260 | OP_REQUIRES( |
261 | context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4), |
262 | errors::InvalidArgument("out_backprop and filter_sizes must have the " |
263 | "same number of channels. Got " , |
264 | out_backprop_shape.dim_size(4), |
265 | " for out_backprop and " , |
266 | filter_shape.dim_size(4), " for filter_sizes" )); |
267 | |
268 | ConvBackpropDimensions dims; |
269 | OP_REQUIRES_OK(context, ConvBackpropComputeDimensions( |
270 | "Conv3DBackpropInputOp" , /*num_spatial_dims=*/3, |
271 | input_shape, filter_shape, out_backprop_shape, |
272 | stride_, padding_, data_format_, &dims)); |
273 | |
274 | Tensor* in_backprop; |
275 | OP_REQUIRES_OK(context, |
276 | context->allocate_output(0, input_shape, &in_backprop)); |
277 | |
278 | functor::CuboidConvolutionBackwardInput<Device, T>()( |
279 | context->eigen_device<Device>(), |
280 | in_backprop->tensor<T, 5>(), // input_backward |
281 | filter.tensor<T, 5>(), // filter |
282 | out_backprop.tensor<T, 5>(), // output_backward |
283 | static_cast<int>(dims.spatial_dims[0].stride), // stride_planes |
284 | static_cast<int>(dims.spatial_dims[1].stride), // stride_rows |
285 | static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols |
286 | } |
287 | |
288 | private: |
289 | std::vector<int32> dilation_; |
290 | std::vector<int32> stride_; |
291 | Padding padding_; |
292 | TensorFormat data_format_; |
293 | bool takes_shape_; |
294 | |
295 | TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp); |
296 | }; |
297 | |
298 | // Custom backprop for input that explicitly does the work sharding and calls |
299 | // Eigen only to multiply matrices. |
300 | template <typename Device, class T> |
301 | class Conv3DCustomBackpropInputOp : public OpKernel { |
302 | // Limit the maximum size of allocated temporary buffer to |
303 | // kMaxTempAllocationOverhead times the size of the input tensors (input, |
304 | // filter, out_backprop). If the size of the temporary buffer exceeds this |
305 | // limit, fallback on Eigen implementation. |
306 | static constexpr int kMaxTempAllocationOverhead = 25; |
307 | |
308 | public: |
309 | explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context) |
310 | : OpKernel(context), |
311 | data_format_(FORMAT_NHWC), |
312 | takes_shape_(type_string().find("V2" ) != std::string::npos) { |
313 | // data_format is only available in V2. |
314 | if (takes_shape_) { |
315 | string data_format; |
316 | OP_REQUIRES_OK(context, context->GetAttr("data_format" , &data_format)); |
317 | OP_REQUIRES(context, FormatFromString(data_format, &data_format_), |
318 | errors::InvalidArgument("Invalid data format" )); |
319 | OP_REQUIRES( |
320 | context, data_format_ == FORMAT_NHWC, |
321 | errors::InvalidArgument( |
322 | "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU." )); |
323 | } |
324 | |
325 | OP_REQUIRES_OK(context, context->GetAttr("dilations" , &dilation_)); |
326 | OP_REQUIRES(context, dilation_.size() == 5, |
327 | errors::InvalidArgument("Dilation rates field must " |
328 | "specify 5 dimensions" )); |
329 | OP_REQUIRES(context, |
330 | (GetTensorDim(dilation_, data_format_, 'C') == 1 && |
331 | GetTensorDim(dilation_, data_format_, 'N') == 1), |
332 | errors::InvalidArgument( |
333 | "Current implementation does not yet support " |
334 | "dilation rates in the batch and depth dimensions." )); |
335 | |
336 | // TODO(yangzihao): Add CPU version of dilated conv 3D. |
337 | OP_REQUIRES(context, |
338 | (GetTensorDim(dilation_, data_format_, '0') == 1 && |
339 | GetTensorDim(dilation_, data_format_, '1') == 1 && |
340 | GetTensorDim(dilation_, data_format_, '2') == 1), |
341 | errors::InvalidArgument( |
342 | "Current CPU implementation does not yet support " |
343 | "dilation rates larger than 1." )); |
344 | |
345 | OP_REQUIRES_OK(context, context->GetAttr("strides" , &stride_)); |
346 | OP_REQUIRES(context, stride_.size() == 5, |
347 | errors::InvalidArgument("Sliding window strides field must " |
348 | "specify 5 dimensions" )); |
349 | OP_REQUIRES( |
350 | context, |
351 | (GetTensorDim(stride_, data_format_, 'C') == 1 && |
352 | GetTensorDim(stride_, data_format_, 'N') == 1), |
353 | errors::InvalidArgument("Current implementation does not yet support " |
354 | "strides in the batch and depth dimensions." )); |
355 | OP_REQUIRES_OK(context, context->GetAttr("padding" , &padding_)); |
356 | } |
357 | |
358 | void Compute(OpKernelContext* context) override { |
359 | const Tensor& filter = context->input(1); |
360 | const TensorShape& filter_shape = filter.shape(); |
361 | |
362 | const Tensor& out_backprop = context->input(2); |
363 | const TensorShape& out_backprop_shape = out_backprop.shape(); |
364 | |
365 | TensorShape input_shape; |
366 | if (takes_shape_) { |
367 | const Tensor& input_sizes = context->input(0); |
368 | // tensor::MakeShape is able to handle both DT_INT32 and DT_INT64 for |
369 | // input_sizes. |
370 | OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape)); |
371 | } else { |
372 | input_shape = context->input(0).shape(); |
373 | } |
374 | |
375 | OP_REQUIRES(context, input_shape.dims() == 5, |
376 | errors::InvalidArgument("input tensor must have 5 dimensions" )); |
377 | OP_REQUIRES( |
378 | context, filter_shape.dims() == 5, |
379 | errors::InvalidArgument("filter_sizes tensor must have 5 dimensions" )); |
380 | OP_REQUIRES( |
381 | context, out_backprop_shape.dims() == 5, |
382 | errors::InvalidArgument("out_backprop tensor must have 5 dimensions" )); |
383 | OP_REQUIRES( |
384 | context, input_shape.dim_size(4) == filter_shape.dim_size(3), |
385 | errors::InvalidArgument("input and filter_sizes must have the same " |
386 | "number of channels. Got " , |
387 | input_shape.dim_size(4), " for input and " , |
388 | filter_shape.dim_size(3), " for filter_sizes" )); |
389 | OP_REQUIRES( |
390 | context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4), |
391 | errors::InvalidArgument("out_backprop and filter_sizes must have the " |
392 | "same number of channels. Got " , |
393 | out_backprop_shape.dim_size(4), |
394 | " for out_backprop and " , |
395 | filter_shape.dim_size(4), " for filter_sizes" )); |
396 | |
397 | ConvBackpropDimensions dims; |
398 | OP_REQUIRES_OK(context, ConvBackpropComputeDimensions( |
399 | "Conv3DBackpropInputOp" , /*num_spatial_dims=*/3, |
400 | input_shape, filter_shape, out_backprop_shape, |
401 | stride_, padding_, data_format_, &dims)); |
402 | |
403 | Tensor* in_backprop; |
404 | OP_REQUIRES_OK(context, |
405 | context->allocate_output(0, input_shape, &in_backprop)); |
406 | |
407 | int64_t top_pad_planes, bottom_pad_planes; |
408 | int64_t top_pad_rows, bottom_pad_rows; |
409 | int64_t left_pad_cols, right_pad_cols; |
410 | |
411 | OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( |
412 | dims.spatial_dims[0].input_size, |
413 | dims.spatial_dims[0].filter_size, |
414 | dims.spatial_dims[0].stride, padding_, |
415 | &dims.spatial_dims[0].output_size, |
416 | &top_pad_planes, &bottom_pad_planes)); |
417 | OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( |
418 | dims.spatial_dims[1].input_size, |
419 | dims.spatial_dims[1].filter_size, |
420 | dims.spatial_dims[1].stride, padding_, |
421 | &dims.spatial_dims[1].output_size, |
422 | &top_pad_rows, &bottom_pad_rows)); |
423 | OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( |
424 | dims.spatial_dims[2].input_size, |
425 | dims.spatial_dims[2].filter_size, |
426 | dims.spatial_dims[2].stride, padding_, |
427 | &dims.spatial_dims[2].output_size, |
428 | &left_pad_cols, &right_pad_cols)); |
429 | |
430 | // TODO(ezhulenev): Extract work size and shard estimation to shared |
431 | // functions in conv_grad_ops, and update 2d convolution backprop. |
432 | |
433 | // The total dimension size of each kernel. |
434 | const int64_t filter_total_size = |
435 | dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size * |
436 | dims.spatial_dims[2].filter_size * dims.in_depth; |
437 | |
438 | // The output image size is the spatial size of the output. |
439 | const int64_t output_image_size = dims.spatial_dims[0].output_size * |
440 | dims.spatial_dims[1].output_size * |
441 | dims.spatial_dims[2].output_size; |
442 | |
443 | const auto cache_sizes = Eigen::internal::CacheSizes(); |
444 | const ptrdiff_t l3_cache_size = cache_sizes.m_l3; |
445 | |
446 | // Use L3 cache size as target working set size. |
447 | const size_t target_working_set_size = l3_cache_size / sizeof(T); |
448 | |
449 | // Calculate size of matrices involved in MatMul: C = A x B. |
450 | const int64_t size_A = output_image_size * dims.out_depth; |
451 | |
452 | const int64_t size_B = filter_total_size * dims.out_depth; |
453 | |
454 | const int64_t size_C = output_image_size * filter_total_size; |
455 | |
456 | const int64_t work_unit_size = size_A + size_B + size_C; |
457 | |
458 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
459 | |
460 | // Use parallel tensor contractions if there is no batching. |
461 | // |
462 | // Compared to Conv2D code, this version is missing work size estimation. In |
463 | // benchmarks I didn't find a case when it's beneficial to run parallel |
464 | // contraction compared to sharding and matmuls. |
465 | const bool use_parallel_contraction = dims.batch_size == 1; |
466 | |
467 | OP_REQUIRES( |
468 | context, work_unit_size > 0, |
469 | errors::InvalidArgument("input, filter_sizes and out_backprop tensors " |
470 | "must all have at least 1 element" )); |
471 | |
472 | const size_t shard_size = |
473 | use_parallel_contraction |
474 | ? 1 |
475 | : (target_working_set_size + work_unit_size - 1) / work_unit_size; |
476 | |
477 | // Total number of elements in all the tensors used by this kernel. |
478 | int64_t total_tensor_elements = input_shape.num_elements() + |
479 | filter_shape.num_elements() + |
480 | out_backprop_shape.num_elements(); |
481 | |
482 | // Shape of the temporary workspace buffer. |
483 | TensorShape col_buffer_shape = {static_cast<int64_t>(shard_size), |
484 | static_cast<int64_t>(output_image_size), |
485 | static_cast<int64_t>(filter_total_size)}; |
486 | int64_t col_buffer_elements = col_buffer_shape.num_elements(); |
487 | |
488 | // If the temporary allocation overhead is too large, fallback on Eigen |
489 | // implementation which requires much less memory. |
490 | int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements; |
491 | if (col_buffer_overhead > kMaxTempAllocationOverhead) { |
492 | VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: " |
493 | "col_buffer_overhead=" |
494 | << col_buffer_overhead; |
495 | |
496 | functor::CuboidConvolutionBackwardInput<Device, T>()( |
497 | context->eigen_device<Device>(), |
498 | in_backprop->tensor<T, 5>(), // input_backward |
499 | filter.tensor<T, 5>(), // filter |
500 | out_backprop.tensor<T, 5>(), // output_backward |
501 | static_cast<int>(dims.spatial_dims[0].stride), // stride_planes |
502 | static_cast<int>(dims.spatial_dims[1].stride), // stride_rows |
503 | static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols |
504 | |
505 | return; |
506 | } |
507 | |
508 | Tensor col_buffer; |
509 | OP_REQUIRES_OK(context, |
510 | context->allocate_temp(DataTypeToEnum<T>::value, |
511 | col_buffer_shape, &col_buffer)); |
512 | |
513 | // The input offset corresponding to a single input image. |
514 | const int64_t input_offset = |
515 | dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size * |
516 | dims.spatial_dims[2].input_size * dims.in_depth; |
517 | |
518 | // The output offset corresponding to a single output image. |
519 | const int64_t output_offset = |
520 | dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size * |
521 | dims.spatial_dims[2].output_size * dims.out_depth; |
522 | |
523 | const T* filter_data = filter.template flat<T>().data(); |
524 | T* col_buffer_data = col_buffer.template flat<T>().data(); |
525 | const T* out_backprop_data = out_backprop.template flat<T>().data(); |
526 | |
527 | auto in_backprop_flat = in_backprop->template flat<T>(); |
528 | T* input_backprop_data = in_backprop_flat.data(); |
529 | in_backprop_flat.device(context->eigen_device<Device>()) = |
530 | in_backprop_flat.constant(T(0)); |
531 | |
532 | if (use_parallel_contraction) { |
533 | typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, |
534 | Eigen::Unaligned> |
535 | TensorMap; |
536 | typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, |
537 | Eigen::Unaligned> |
538 | ConstTensorMap; |
539 | |
540 | // Initialize contraction dims (we need to transpose 'B' below). |
541 | Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; |
542 | contract_dims[0].first = 1; |
543 | contract_dims[0].second = 1; |
544 | |
545 | for (int image_id = 0; image_id < dims.batch_size; ++image_id) { |
546 | // Compute gradient into col_buffer. |
547 | TensorMap C(col_buffer_data, output_image_size, filter_total_size); |
548 | |
549 | ConstTensorMap A(out_backprop_data + output_offset * image_id, |
550 | output_image_size, dims.out_depth); |
551 | ConstTensorMap B(filter_data, filter_total_size, dims.out_depth); |
552 | |
553 | C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims); |
554 | |
555 | Col2im<T>(col_buffer_data, dims.in_depth, |
556 | // Input spatial dimensions. |
557 | dims.spatial_dims[0].input_size, // input planes |
558 | dims.spatial_dims[1].input_size, // input rows |
559 | dims.spatial_dims[2].input_size, // input cols |
560 | // Filter spatial dimensions. |
561 | dims.spatial_dims[0].filter_size, // filter planes |
562 | dims.spatial_dims[1].filter_size, // filter rows |
563 | dims.spatial_dims[2].filter_size, // filter cols |
564 | // Spatial padding. |
565 | top_pad_planes, top_pad_rows, left_pad_cols, |
566 | bottom_pad_planes, bottom_pad_rows, right_pad_cols, |
567 | // Spatial striding. |
568 | dims.spatial_dims[0].stride, // stride planes |
569 | dims.spatial_dims[1].stride, // stride rows |
570 | dims.spatial_dims[2].stride, // stride cols |
571 | input_backprop_data); |
572 | |
573 | input_backprop_data += input_offset; |
574 | } |
575 | } else { |
576 | typedef Eigen::Map< |
577 | Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> |
578 | MatrixMap; |
579 | typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, |
580 | Eigen::RowMajor>> |
581 | ConstMatrixMap; |
582 | |
583 | for (int image_id = 0; image_id < dims.batch_size; |
584 | image_id += shard_size) { |
585 | const int shard_limit = |
586 | std::min(static_cast<int>(shard_size), |
587 | static_cast<int>(dims.batch_size) - image_id); |
588 | |
589 | auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols, |
590 | &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols, |
591 | &output_image_size, &filter_total_size, |
592 | &input_backprop_data, &col_buffer_data, |
593 | &out_backprop_data, &filter_data, &input_offset, |
594 | &output_offset, &size_C](int64_t start, int64_t limit) { |
595 | for (int shard_id = start; shard_id < limit; ++shard_id) { |
596 | T* im2col_buf = col_buffer_data + shard_id * size_C; |
597 | T* input_data = input_backprop_data + shard_id * input_offset; |
598 | const T* out_data = out_backprop_data + shard_id * output_offset; |
599 | |
600 | // Compute gradient into 'im2col_buf'. |
601 | MatrixMap C(im2col_buf, output_image_size, filter_total_size); |
602 | |
603 | ConstMatrixMap A(out_data, output_image_size, dims.out_depth); |
604 | ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth); |
605 | |
606 | C.noalias() = A * B.transpose(); |
607 | |
608 | Col2im<T>(im2col_buf, dims.in_depth, |
609 | // Input spatial dimensions. |
610 | dims.spatial_dims[0].input_size, // input planes |
611 | dims.spatial_dims[1].input_size, // input rows |
612 | dims.spatial_dims[2].input_size, // input cols |
613 | // Filter spatial dimensions. |
614 | dims.spatial_dims[0].filter_size, // filter planes |
615 | dims.spatial_dims[1].filter_size, // filter rows |
616 | dims.spatial_dims[2].filter_size, // filter cols |
617 | // Spatial padding. |
618 | top_pad_planes, top_pad_rows, left_pad_cols, |
619 | bottom_pad_planes, bottom_pad_rows, right_pad_cols, |
620 | // Spatial striding. |
621 | dims.spatial_dims[0].stride, // stride planes |
622 | dims.spatial_dims[1].stride, // stride rows |
623 | dims.spatial_dims[2].stride, // stride cols |
624 | input_data); |
625 | } |
626 | }; |
627 | Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, |
628 | work_unit_size, shard); |
629 | |
630 | input_backprop_data += input_offset * shard_limit; |
631 | out_backprop_data += output_offset * shard_limit; |
632 | } |
633 | } |
634 | } |
635 | |
636 | private: |
637 | std::vector<int32> dilation_; |
638 | std::vector<int32> stride_; |
639 | Padding padding_; |
640 | TensorFormat data_format_; |
641 | bool takes_shape_; |
642 | |
643 | TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp); |
644 | }; |
645 | |
646 | // Custom backrop input kernel is 30% - 4x faster when compiled with AVX2 than |
647 | // default Eigen implementation (at the cost of ~2x-8x peak memory usage). |
648 | |
649 | #define REGISTER_CPU_KERNEL(T) \ |
650 | REGISTER_KERNEL_BUILDER( \ |
651 | Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
652 | Conv3DCustomBackpropInputOp<CPUDevice, T>); \ |
653 | REGISTER_KERNEL_BUILDER( \ |
654 | Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
655 | Conv3DCustomBackpropInputOp<CPUDevice, T>); \ |
656 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \ |
657 | .Device(DEVICE_CPU) \ |
658 | .Label("custom") \ |
659 | .TypeConstraint<T>("T"), \ |
660 | Conv3DCustomBackpropInputOp<CPUDevice, T>); \ |
661 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ |
662 | .Device(DEVICE_CPU) \ |
663 | .Label("custom") \ |
664 | .TypeConstraint<T>("T"), \ |
665 | Conv3DCustomBackpropInputOp<CPUDevice, T>); \ |
666 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \ |
667 | .Device(DEVICE_CPU) \ |
668 | .Label("eigen_tensor") \ |
669 | .TypeConstraint<T>("T"), \ |
670 | Conv3DBackpropInputOp<CPUDevice, T>); \ |
671 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ |
672 | .Device(DEVICE_CPU) \ |
673 | .Label("eigen_tensor") \ |
674 | .TypeConstraint<T>("T"), \ |
675 | Conv3DBackpropInputOp<CPUDevice, T>); |
676 | |
677 | TF_CALL_half(REGISTER_CPU_KERNEL); |
678 | TF_CALL_float(REGISTER_CPU_KERNEL); |
679 | TF_CALL_double(REGISTER_CPU_KERNEL); |
680 | #undef REGISTER_CPU_KERNEL |
681 | |
682 | #define REGISTER_CPU_KERNEL(T) \ |
683 | REGISTER_KERNEL_BUILDER( \ |
684 | Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
685 | Conv3DCustomBackpropInputOp<CPUDevice, T>); \ |
686 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ |
687 | .Device(DEVICE_CPU) \ |
688 | .Label("custom") \ |
689 | .TypeConstraint<T>("T"), \ |
690 | Conv3DCustomBackpropInputOp<CPUDevice, T>); \ |
691 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ |
692 | .Device(DEVICE_CPU) \ |
693 | .Label("eigen_tensor") \ |
694 | .TypeConstraint<T>("T"), \ |
695 | Conv3DBackpropInputOp<CPUDevice, T>); |
696 | |
697 | TF_CALL_bfloat16(REGISTER_CPU_KERNEL); |
698 | #undef REGISTER_CPU_KERNEL |
699 | |
700 | // Backprop for filter that offloads computation to |
701 | // Eigen::CuboidConvolutionBackwardFilter. |
702 | template <typename Device, class T> |
703 | class Conv3DBackpropFilterOp : public OpKernel { |
704 | public: |
705 | explicit Conv3DBackpropFilterOp(OpKernelConstruction* context) |
706 | : OpKernel(context), |
707 | data_format_(FORMAT_NHWC), |
708 | takes_shape_(type_string().find("V2" ) != std::string::npos) { |
709 | // data_format is only available in V2. |
710 | if (takes_shape_) { |
711 | string data_format; |
712 | OP_REQUIRES_OK(context, context->GetAttr("data_format" , &data_format)); |
713 | OP_REQUIRES(context, FormatFromString(data_format, &data_format_), |
714 | errors::InvalidArgument("Invalid data format" )); |
715 | OP_REQUIRES( |
716 | context, data_format_ == FORMAT_NHWC, |
717 | errors::InvalidArgument( |
718 | "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU." )); |
719 | } |
720 | |
721 | OP_REQUIRES_OK(context, context->GetAttr("dilations" , &dilation_)); |
722 | OP_REQUIRES(context, dilation_.size() == 5, |
723 | errors::InvalidArgument("Dilation rates field must " |
724 | "specify 5 dimensions" )); |
725 | OP_REQUIRES(context, |
726 | (GetTensorDim(dilation_, data_format_, 'C') == 1 && |
727 | GetTensorDim(dilation_, data_format_, 'N') == 1), |
728 | errors::InvalidArgument( |
729 | "Current implementation does not yet support " |
730 | "dilation rates in the batch and depth dimensions." )); |
731 | |
732 | // TODO(yangzihao): Add CPU version of dilated conv 3D. |
733 | OP_REQUIRES(context, |
734 | (GetTensorDim(dilation_, data_format_, '0') == 1 && |
735 | GetTensorDim(dilation_, data_format_, '1') == 1 && |
736 | GetTensorDim(dilation_, data_format_, '2') == 1), |
737 | errors::InvalidArgument( |
738 | "Current CPU implementation does not yet support " |
739 | "dilation rates larger than 1." )); |
740 | |
741 | OP_REQUIRES_OK(context, context->GetAttr("strides" , &stride_)); |
742 | OP_REQUIRES(context, stride_.size() == 5, |
743 | errors::InvalidArgument("Sliding window strides field must " |
744 | "specify 5 dimensions" )); |
745 | OP_REQUIRES( |
746 | context, |
747 | (GetTensorDim(stride_, data_format_, 'C') == 1 && |
748 | GetTensorDim(stride_, data_format_, 'N') == 1), |
749 | errors::InvalidArgument("Current implementation does not yet support " |
750 | "strides in the batch and depth dimensions." )); |
751 | OP_REQUIRES_OK(context, context->GetAttr("padding" , &padding_)); |
752 | } |
753 | |
754 | void Compute(OpKernelContext* context) override { |
755 | const Tensor& input = context->input(0); |
756 | const TensorShape& input_shape = input.shape(); |
757 | |
758 | const Tensor& out_backprop = context->input(2); |
759 | const TensorShape& out_backprop_shape = out_backprop.shape(); |
760 | |
761 | TensorShape filter_shape; |
762 | if (takes_shape_) { |
763 | const Tensor& filter_sizes = context->input(1); |
764 | OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()), |
765 | errors::InvalidArgument( |
766 | "filter_sizes shape must be rank 1 but is rank " , |
767 | filter_sizes.shape().dims())); |
768 | OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( |
769 | filter_sizes.vec<int32>(), &filter_shape)); |
770 | } else { |
771 | filter_shape = context->input(1).shape(); |
772 | } |
773 | |
774 | OP_REQUIRES(context, input_shape.dims() == 5, |
775 | errors::InvalidArgument("input tensor must have 5 dimensions" )); |
776 | OP_REQUIRES( |
777 | context, filter_shape.dims() == 5, |
778 | errors::InvalidArgument("filter_sizes tensor must have 5 dimensions" )); |
779 | OP_REQUIRES( |
780 | context, out_backprop_shape.dims() == 5, |
781 | errors::InvalidArgument("out_backprop tensor must have 5 dimensions" )); |
782 | OP_REQUIRES( |
783 | context, input_shape.dim_size(4) == filter_shape.dim_size(3), |
784 | errors::InvalidArgument("input and filter_sizes must have the same " |
785 | "number of channels. Got " , |
786 | input_shape.dim_size(4), " for input and " , |
787 | filter_shape.dim_size(3), " for filter_sizes" )); |
788 | OP_REQUIRES( |
789 | context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4), |
790 | errors::InvalidArgument("out_backprop and filter_sizes must have the " |
791 | "same number of channels. Got " , |
792 | out_backprop_shape.dim_size(4), |
793 | " for out_backprop and " , |
794 | filter_shape.dim_size(4), " for filter_sizes" )); |
795 | |
796 | ConvBackpropDimensions dims; |
797 | OP_REQUIRES_OK(context, |
798 | ConvBackpropComputeDimensions( |
799 | "Conv3DBackpropFilterOp" , /*num_spatial_dims=*/3, |
800 | input_shape, filter_shape, out_backprop_shape, stride_, |
801 | padding_, data_format_, &dims)); |
802 | |
803 | Tensor* filter_backprop; |
804 | OP_REQUIRES_OK(context, |
805 | context->allocate_output(0, filter_shape, &filter_backprop)); |
806 | |
807 | if (input_shape.num_elements() == 0) { |
808 | filter_backprop->template flat<T>().setZero(); |
809 | return; |
810 | } |
811 | |
812 | functor::CuboidConvolutionBackwardFilter<Device, T>()( |
813 | context->eigen_device<Device>(), |
814 | filter_backprop->tensor<T, 5>(), // filter_backward |
815 | input.tensor<T, 5>(), // input |
816 | out_backprop.tensor<T, 5>(), // output_backward |
817 | static_cast<int>(dims.spatial_dims[0].stride), // stride_planes |
818 | static_cast<int>(dims.spatial_dims[1].stride), // stride_rows |
819 | static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols |
820 | } |
821 | |
822 | private: |
823 | std::vector<int32> dilation_; |
824 | std::vector<int32> stride_; |
825 | Padding padding_; |
826 | TensorFormat data_format_; |
827 | bool takes_shape_; |
828 | |
829 | TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp); |
830 | }; |
831 | |
832 | // Custom backprop for filter that explicitly does the work sharding and calls |
833 | // Eigen only to multiply matrices. |
834 | template <typename Device, class T> |
835 | class Conv3DCustomBackpropFilterOp : public OpKernel { |
836 | // Limit the maximum size of allocated temporary buffer to |
837 | // kMaxTempAllocationOverhead times the size of the input tensors (input, |
838 | // filter, out_backprop). If the size of the temporary buffer exceeds this |
839 | // limit, fallback on Eigen implementation. |
840 | static constexpr int kMaxTempAllocationOverhead = 25; |
841 | |
842 | public: |
843 | explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context) |
844 | : OpKernel(context), |
845 | data_format_(FORMAT_NHWC), |
846 | takes_shape_(type_string().find("V2" ) != std::string::npos) { |
847 | // data_format is only available in V2. |
848 | if (takes_shape_) { |
849 | string data_format; |
850 | OP_REQUIRES_OK(context, context->GetAttr("data_format" , &data_format)); |
851 | OP_REQUIRES(context, FormatFromString(data_format, &data_format_), |
852 | errors::InvalidArgument("Invalid data format" )); |
853 | OP_REQUIRES( |
854 | context, data_format_ == FORMAT_NHWC, |
855 | errors::InvalidArgument( |
856 | "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU." )); |
857 | } |
858 | |
859 | OP_REQUIRES_OK(context, context->GetAttr("dilations" , &dilation_)); |
860 | OP_REQUIRES(context, dilation_.size() == 5, |
861 | errors::InvalidArgument("Dilation rates field must " |
862 | "specify 5 dimensions" )); |
863 | OP_REQUIRES(context, |
864 | (GetTensorDim(dilation_, data_format_, 'C') == 1 && |
865 | GetTensorDim(dilation_, data_format_, 'N') == 1), |
866 | errors::InvalidArgument( |
867 | "Current implementation does not yet support " |
868 | "dilation rates in the batch and depth dimensions." )); |
869 | |
870 | // TODO(yangzihao): Add CPU version of dilated conv 3D. |
871 | OP_REQUIRES(context, |
872 | (GetTensorDim(dilation_, data_format_, '0') == 1 && |
873 | GetTensorDim(dilation_, data_format_, '1') == 1 && |
874 | GetTensorDim(dilation_, data_format_, '2') == 1), |
875 | errors::InvalidArgument( |
876 | "Current CPU implementation does not yet support " |
877 | "dilation rates larger than 1." )); |
878 | |
879 | OP_REQUIRES_OK(context, context->GetAttr("strides" , &stride_)); |
880 | OP_REQUIRES(context, stride_.size() == 5, |
881 | errors::InvalidArgument("Sliding window strides field must " |
882 | "specify 5 dimensions" )); |
883 | OP_REQUIRES( |
884 | context, |
885 | (GetTensorDim(stride_, data_format_, 'C') == 1 && |
886 | GetTensorDim(stride_, data_format_, 'N') == 1), |
887 | errors::InvalidArgument("Current implementation does not yet support " |
888 | "strides in the batch and depth dimensions." )); |
889 | OP_REQUIRES_OK(context, context->GetAttr("padding" , &padding_)); |
890 | } |
891 | |
892 | void Compute(OpKernelContext* context) override { |
893 | const Tensor& input = context->input(0); |
894 | const TensorShape& input_shape = input.shape(); |
895 | |
896 | const Tensor& out_backprop = context->input(2); |
897 | const TensorShape& out_backprop_shape = out_backprop.shape(); |
898 | |
899 | TensorShape filter_shape; |
900 | if (takes_shape_) { |
901 | const Tensor& filter_sizes = context->input(1); |
902 | OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()), |
903 | errors::InvalidArgument( |
904 | "filter_sizes shape must be rank 1 but is rank " , |
905 | filter_sizes.shape().dims())); |
906 | OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( |
907 | filter_sizes.vec<int32>(), &filter_shape)); |
908 | } else { |
909 | filter_shape = context->input(1).shape(); |
910 | } |
911 | |
912 | OP_REQUIRES(context, input_shape.dims() == 5, |
913 | errors::InvalidArgument("input tensor must have 5 dimensions" )); |
914 | OP_REQUIRES( |
915 | context, filter_shape.dims() == 5, |
916 | errors::InvalidArgument("filter_sizes tensor must have 5 dimensions" )); |
917 | OP_REQUIRES( |
918 | context, out_backprop_shape.dims() == 5, |
919 | errors::InvalidArgument("out_backprop tensor must have 5 dimensions" )); |
920 | OP_REQUIRES( |
921 | context, input_shape.dim_size(4) == filter_shape.dim_size(3), |
922 | errors::InvalidArgument("input and filter_sizes must have the same " |
923 | "number of channels. Got " , |
924 | input_shape.dim_size(4), " for input and " , |
925 | filter_shape.dim_size(3), " for filter_sizes" )); |
926 | OP_REQUIRES( |
927 | context, out_backprop_shape.dim_size(4) == filter_shape.dim_size(4), |
928 | errors::InvalidArgument("out_backprop and filter_sizes must have the " |
929 | "same number of channels. Got " , |
930 | out_backprop_shape.dim_size(4), |
931 | " for out_backprop and " , |
932 | filter_shape.dim_size(4), " for filter_sizes" )); |
933 | |
934 | ConvBackpropDimensions dims; |
935 | OP_REQUIRES_OK(context, |
936 | ConvBackpropComputeDimensions( |
937 | "Conv3DBackpropFilterOp" , /*num_spatial_dims=*/3, |
938 | input_shape, filter_shape, out_backprop_shape, stride_, |
939 | padding_, data_format_, &dims)); |
940 | |
941 | Tensor* filter_backprop; |
942 | OP_REQUIRES_OK(context, |
943 | context->allocate_output(0, filter_shape, &filter_backprop)); |
944 | |
945 | if (input_shape.num_elements() == 0) { |
946 | filter_backprop->template flat<T>().setZero(); |
947 | return; |
948 | } |
949 | |
950 | int64_t top_pad_planes, bottom_pad_planes; |
951 | int64_t top_pad_rows, bottom_pad_rows; |
952 | int64_t left_pad_cols, right_pad_cols; |
953 | |
954 | OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( |
955 | dims.spatial_dims[0].input_size, |
956 | dims.spatial_dims[0].filter_size, |
957 | dims.spatial_dims[0].stride, padding_, |
958 | &dims.spatial_dims[0].output_size, |
959 | &top_pad_planes, &bottom_pad_planes)); |
960 | OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( |
961 | dims.spatial_dims[1].input_size, |
962 | dims.spatial_dims[1].filter_size, |
963 | dims.spatial_dims[1].stride, padding_, |
964 | &dims.spatial_dims[1].output_size, |
965 | &top_pad_rows, &bottom_pad_rows)); |
966 | OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( |
967 | dims.spatial_dims[2].input_size, |
968 | dims.spatial_dims[2].filter_size, |
969 | dims.spatial_dims[2].stride, padding_, |
970 | &dims.spatial_dims[2].output_size, |
971 | &left_pad_cols, &right_pad_cols)); |
972 | |
973 | // TODO(ezhulenev): Extract work size and shard estimation to shared |
974 | // functions in conv_grad_ops, and update 2d convolution backprop. |
975 | |
976 | // The total dimension size of each kernel. |
977 | const int64_t filter_total_size = |
978 | dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size * |
979 | dims.spatial_dims[2].filter_size * dims.in_depth; |
980 | // The output image size is the spatial size of the output. |
981 | const int64_t output_image_size = dims.spatial_dims[0].output_size * |
982 | dims.spatial_dims[1].output_size * |
983 | dims.spatial_dims[2].output_size; |
984 | |
985 | // Shard 'batch' images (volumes) into 'shard_size' groups of images |
986 | // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by |
987 | // dividing the L3 cache size ('target_working_set_size') by the matmul size |
988 | // of an individual image ('work_unit_size'). |
989 | |
990 | const auto cache_sizes = Eigen::internal::CacheSizes(); |
991 | const ptrdiff_t l3_cache_size = cache_sizes.m_l3; |
992 | |
993 | // TODO(andydavis) |
994 | // *) Consider reducing 'target_working_set_size' if L3 is shared by |
995 | // other concurrently running tensorflow ops. |
996 | const size_t target_working_set_size = l3_cache_size / sizeof(T); |
997 | |
998 | const int64_t size_A = output_image_size * filter_total_size; |
999 | |
1000 | const int64_t size_B = output_image_size * dims.out_depth; |
1001 | |
1002 | const int64_t size_C = filter_total_size * dims.out_depth; |
1003 | |
1004 | const int64_t work_unit_size = size_A + size_B + size_C; |
1005 | |
1006 | OP_REQUIRES( |
1007 | context, work_unit_size > 0, |
1008 | errors::InvalidArgument("input, filter_sizes and out_backprop tensors " |
1009 | "must all have at least 1 element" )); |
1010 | |
1011 | const size_t shard_size = |
1012 | (target_working_set_size + work_unit_size - 1) / work_unit_size; |
1013 | |
1014 | // Total number of elements in all the tensors used by this kernel. |
1015 | int64_t total_tensor_elements = input_shape.num_elements() + |
1016 | filter_shape.num_elements() + |
1017 | out_backprop_shape.num_elements(); |
1018 | |
1019 | // Shape of the temporary workspace buffer. |
1020 | TensorShape col_buffer_shape = {static_cast<int64_t>(shard_size), |
1021 | static_cast<int64_t>(output_image_size), |
1022 | static_cast<int64_t>(filter_total_size)}; |
1023 | int64_t col_buffer_elements = col_buffer_shape.num_elements(); |
1024 | |
1025 | // If the temporary allocation overhead is too large, fallback on Eigen |
1026 | // implementation which requires much less memory. |
1027 | int64_t col_buffer_overhead = col_buffer_elements / total_tensor_elements; |
1028 | if (col_buffer_overhead > kMaxTempAllocationOverhead) { |
1029 | VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: " |
1030 | "col_buffer_overhead=" |
1031 | << col_buffer_overhead; |
1032 | |
1033 | functor::CuboidConvolutionBackwardFilter<Device, T>()( |
1034 | context->eigen_device<Device>(), |
1035 | filter_backprop->tensor<T, 5>(), // filter_backward |
1036 | input.tensor<T, 5>(), // input |
1037 | out_backprop.tensor<T, 5>(), // output_backward |
1038 | static_cast<int>(dims.spatial_dims[0].stride), // stride_planes |
1039 | static_cast<int>(dims.spatial_dims[1].stride), // stride_rows |
1040 | static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols |
1041 | |
1042 | return; |
1043 | } |
1044 | |
1045 | Tensor col_buffer; |
1046 | OP_REQUIRES_OK(context, |
1047 | context->allocate_temp(DataTypeToEnum<T>::value, |
1048 | col_buffer_shape, &col_buffer)); |
1049 | |
1050 | // The input offset corresponding to a single input image. |
1051 | const int64_t input_offset = |
1052 | dims.spatial_dims[0].input_size * dims.spatial_dims[1].input_size * |
1053 | dims.spatial_dims[2].input_size * dims.in_depth; |
1054 | // The output offset corresponding to a single output image. |
1055 | const int64_t output_offset = |
1056 | dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size * |
1057 | dims.spatial_dims[2].output_size * dims.out_depth; |
1058 | |
1059 | const T* input_data = input.template flat<T>().data(); |
1060 | T* col_buffer_data = col_buffer.template flat<T>().data(); |
1061 | const T* out_backprop_data = out_backprop.template flat<T>().data(); |
1062 | T* filter_backprop_data = filter_backprop->template flat<T>().data(); |
1063 | |
1064 | typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, |
1065 | Eigen::Unaligned> |
1066 | TensorMap; |
1067 | typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, |
1068 | Eigen::Unaligned> |
1069 | ConstTensorMap; |
1070 | |
1071 | TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth); |
1072 | C.setZero(); |
1073 | |
1074 | // Initialize contraction dims (we need to transpose 'A' below). |
1075 | Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; |
1076 | contract_dims[0].first = 0; |
1077 | contract_dims[0].second = 0; |
1078 | |
1079 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
1080 | |
1081 | for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) { |
1082 | const int shard_limit = |
1083 | std::min(static_cast<int>(shard_size), |
1084 | static_cast<int>(dims.batch_size) - image_id); |
1085 | |
1086 | auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes, |
1087 | &top_pad_rows, &left_pad_cols, &bottom_pad_planes, |
1088 | &bottom_pad_rows, &right_pad_cols, &input_offset, |
1089 | &size_A](int64_t start, int64_t limit) { |
1090 | for (int shard_id = start; shard_id < limit; ++shard_id) { |
1091 | const T* input_data_shard = input_data + shard_id * input_offset; |
1092 | T* col_data_shard = col_buffer_data + shard_id * size_A; |
1093 | |
1094 | // When we compute the gradient with respect to the filters, we need |
1095 | // to do im2col to allow gemm-type computation. |
1096 | Im2col<T>(input_data_shard, dims.in_depth, |
1097 | // Input spatial dimensions. |
1098 | dims.spatial_dims[0].input_size, // input planes |
1099 | dims.spatial_dims[1].input_size, // input rows |
1100 | dims.spatial_dims[2].input_size, // input cols |
1101 | // Filter spatial dimensions. |
1102 | dims.spatial_dims[0].filter_size, // filter planes |
1103 | dims.spatial_dims[1].filter_size, // filter rows |
1104 | dims.spatial_dims[2].filter_size, // filter cols |
1105 | // Spatial padding. |
1106 | top_pad_planes, top_pad_rows, left_pad_cols, |
1107 | bottom_pad_planes, bottom_pad_rows, right_pad_cols, |
1108 | // Spatial striding. |
1109 | dims.spatial_dims[0].stride, // stride planes |
1110 | dims.spatial_dims[1].stride, // stride rows |
1111 | dims.spatial_dims[2].stride, // stride cols |
1112 | col_data_shard); |
1113 | } |
1114 | }; |
1115 | Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, |
1116 | size_A, shard); |
1117 | |
1118 | ConstTensorMap A(col_buffer_data, output_image_size * shard_limit, |
1119 | filter_total_size); |
1120 | ConstTensorMap B(out_backprop_data, output_image_size * shard_limit, |
1121 | dims.out_depth); |
1122 | |
1123 | // Gradient with respect to filter. |
1124 | C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims); |
1125 | |
1126 | input_data += input_offset * shard_limit; |
1127 | out_backprop_data += output_offset * shard_limit; |
1128 | } |
1129 | } |
1130 | |
1131 | private: |
1132 | std::vector<int32> dilation_; |
1133 | std::vector<int32> stride_; |
1134 | Padding padding_; |
1135 | TensorFormat data_format_; |
1136 | bool takes_shape_; |
1137 | |
1138 | TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp); |
1139 | }; |
1140 | |
1141 | // Custom backrop input kernel is 30% - 4x faster when compiled with AVX2 than |
1142 | // default Eigen implementation (at the cost of ~2x-8x peak memory usage). |
1143 | |
1144 | #define REGISTER_CPU_KERNEL(T) \ |
1145 | REGISTER_KERNEL_BUILDER( \ |
1146 | Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
1147 | Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ |
1148 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ |
1149 | .Device(DEVICE_CPU) \ |
1150 | .TypeConstraint<T>("T"), \ |
1151 | Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ |
1152 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \ |
1153 | .Device(DEVICE_CPU) \ |
1154 | .Label("custom") \ |
1155 | .TypeConstraint<T>("T"), \ |
1156 | Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ |
1157 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ |
1158 | .Device(DEVICE_CPU) \ |
1159 | .Label("custom") \ |
1160 | .TypeConstraint<T>("T"), \ |
1161 | Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ |
1162 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \ |
1163 | .Device(DEVICE_CPU) \ |
1164 | .Label("eigen_tensor") \ |
1165 | .TypeConstraint<T>("T"), \ |
1166 | Conv3DBackpropFilterOp<CPUDevice, T>); \ |
1167 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ |
1168 | .Device(DEVICE_CPU) \ |
1169 | .Label("eigen_tensor") \ |
1170 | .TypeConstraint<T>("T"), \ |
1171 | Conv3DBackpropFilterOp<CPUDevice, T>); |
1172 | |
1173 | TF_CALL_float(REGISTER_CPU_KERNEL); |
1174 | TF_CALL_double(REGISTER_CPU_KERNEL); |
1175 | #undef REGISTER_CPU_KERNEL |
1176 | |
1177 | #define REGISTER_CPU_KERNEL(T) \ |
1178 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ |
1179 | .Device(DEVICE_CPU) \ |
1180 | .TypeConstraint<T>("T"), \ |
1181 | Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ |
1182 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ |
1183 | .Device(DEVICE_CPU) \ |
1184 | .Label("custom") \ |
1185 | .TypeConstraint<T>("T"), \ |
1186 | Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ |
1187 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ |
1188 | .Device(DEVICE_CPU) \ |
1189 | .Label("eigen_tensor") \ |
1190 | .TypeConstraint<T>("T"), \ |
1191 | Conv3DBackpropFilterOp<CPUDevice, T>); |
1192 | |
1193 | TF_CALL_bfloat16(REGISTER_CPU_KERNEL); |
1194 | #undef REGISTER_CPU_KERNEL |
1195 | |
1196 | // WARNING: Eigen::half is not trivially copyable and can't be used in |
1197 | // custom backprop filter kernel because of memcpy and memset in Im2col. |
1198 | #define REGISTER_CPU_KERNEL(T) \ |
1199 | REGISTER_KERNEL_BUILDER( \ |
1200 | Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
1201 | Conv3DBackpropFilterOp<CPUDevice, T>); \ |
1202 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ |
1203 | .Device(DEVICE_CPU) \ |
1204 | .TypeConstraint<T>("T"), \ |
1205 | Conv3DBackpropFilterOp<CPUDevice, T>); |
1206 | |
1207 | TF_CALL_half(REGISTER_CPU_KERNEL); |
1208 | #undef REGISTER_CPU_KERNEL |
1209 | |
1210 | // GPU definitions of both ops. |
1211 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1212 | // Forward declarations of the functor specializations for GPU. |
1213 | // This ensures that the custom implementation is used instead of the default |
1214 | // Eigen one (which is used for CPU). |
1215 | namespace functor { |
1216 | #define DECLARE_GPU_SPEC(T) \ |
1217 | template <> \ |
1218 | void TransformFilter<GPUDevice, T, int, 5>::operator()( \ |
1219 | const GPUDevice& d, FilterTensorFormat dst_filter_format, \ |
1220 | typename TTypes<T, 5, int>::ConstTensor in, \ |
1221 | typename TTypes<T, 5, int>::Tensor out); \ |
1222 | template <> \ |
1223 | void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \ |
1224 | const GPUDevice& d, FilterTensorFormat src_filter_format, \ |
1225 | typename TTypes<T, 5>::ConstTensor in, \ |
1226 | typename TTypes<T, 5>::Tensor out); \ |
1227 | template <> \ |
1228 | void PadInput<GPUDevice, T, int, 5>::operator()( \ |
1229 | const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \ |
1230 | const std::array<int, 3>& padding_left, \ |
1231 | const std::array<int, 3>& padding_right, \ |
1232 | typename TTypes<T, 5, int>::Tensor out, TensorFormat format, \ |
1233 | const T& padding_value); |
1234 | |
1235 | DECLARE_GPU_SPEC(Eigen::half); |
1236 | DECLARE_GPU_SPEC(float); |
1237 | DECLARE_GPU_SPEC(double); |
1238 | #undef DECLARE_GPU_SPEC |
1239 | } // namespace functor |
1240 | |
1241 | // A dummy type to group backward data autotune results together. |
1242 | struct Conv3dBackwardDataAutotuneGroup { |
1243 | static string name() { return "Conv3dBwdData" ; } |
1244 | }; |
1245 | |
1246 | typedef AutotuneSingleton<Conv3dBackwardDataAutotuneGroup, ConvParameters, |
1247 | AutotuneEntry<se::dnn::ConvOp>> |
1248 | |
1249 | AutotuneConv3dBwdData; |
1250 | |
1251 | template <typename T> |
1252 | class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { |
1253 | public: |
1254 | explicit Conv3DBackpropInputOp(OpKernelConstruction* context) |
1255 | : OpKernel(context), |
1256 | data_format_(FORMAT_NHWC), |
1257 | takes_shape_(type_string().find("V2" ) != std::string::npos) { |
1258 | // data_format is only available in V2. |
1259 | if (takes_shape_) { |
1260 | string data_format; |
1261 | OP_REQUIRES_OK(context, context->GetAttr("data_format" , &data_format)); |
1262 | OP_REQUIRES(context, FormatFromString(data_format, &data_format_), |
1263 | errors::InvalidArgument("Invalid data format" )); |
1264 | } |
1265 | OP_REQUIRES_OK(context, context->GetAttr("dilations" , &dilation_)); |
1266 | OP_REQUIRES(context, dilation_.size() == 5, |
1267 | errors::InvalidArgument("Dilation rates field must " |
1268 | "specify 5 dimensions" )); |
1269 | OP_REQUIRES(context, |
1270 | (GetTensorDim(dilation_, data_format_, 'C') == 1 && |
1271 | GetTensorDim(dilation_, data_format_, 'N') == 1), |
1272 | errors::InvalidArgument( |
1273 | "Current implementation does not yet support " |
1274 | "dilation rates in the batch and depth dimensions." )); |
1275 | OP_REQUIRES( |
1276 | context, |
1277 | (GetTensorDim(dilation_, data_format_, '0') > 0 && |
1278 | GetTensorDim(dilation_, data_format_, '1') > 0 && |
1279 | GetTensorDim(dilation_, data_format_, '2') > 0), |
1280 | errors::InvalidArgument("Dilated rates should be larger than 0." )); |
1281 | OP_REQUIRES_OK(context, context->GetAttr("strides" , &stride_)); |
1282 | OP_REQUIRES(context, stride_.size() == 5, |
1283 | errors::InvalidArgument("Sliding window strides field must " |
1284 | "specify 5 dimensions" )); |
1285 | OP_REQUIRES( |
1286 | context, |
1287 | (GetTensorDim(stride_, data_format_, 'C') == 1 && |
1288 | GetTensorDim(stride_, data_format_, 'N') == 1), |
1289 | errors::InvalidArgument("Current implementation does not yet support " |
1290 | "strides in the batch and depth dimensions." )); |
1291 | OP_REQUIRES( |
1292 | context, |
1293 | (GetTensorDim(stride_, data_format_, '0') > 0 && |
1294 | GetTensorDim(stride_, data_format_, '1') > 0 && |
1295 | GetTensorDim(stride_, data_format_, '2') > 0), |
1296 | errors::InvalidArgument("Spatial strides should be larger than 0." )); |
1297 | OP_REQUIRES_OK(context, context->GetAttr("padding" , &padding_)); |
1298 | cudnn_use_autotune_ = CudnnUseAutotune(); |
1299 | } |
1300 | void Compute(OpKernelContext* context) override { |
1301 | const Tensor& filter = context->input(1); |
1302 | const TensorShape& filter_shape = filter.shape(); |
1303 | |
1304 | const Tensor& out_backprop = context->input(2); |
1305 | const TensorShape& out_backprop_shape = out_backprop.shape(); |
1306 | |
1307 | TensorShape input_shape; |
1308 | if (takes_shape_) { |
1309 | const Tensor& input_sizes = context->input(0); |
1310 | OP_REQUIRES_OK(context, tensor::MakeShape(input_sizes, &input_shape)); |
1311 | } else { |
1312 | input_shape = context->input(0).shape(); |
1313 | } |
1314 | |
1315 | ConvBackpropDimensions dims; |
1316 | OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2( |
1317 | "Conv3DBackpropInputOp" , /*num_spatial_dims=*/3, |
1318 | input_shape, filter_shape, out_backprop_shape, |
1319 | dilation_, stride_, padding_, |
1320 | /*explicit_paddings=*/{}, data_format_, &dims)); |
1321 | |
1322 | Tensor* in_backprop; |
1323 | OP_REQUIRES_OK(context, |
1324 | context->allocate_output(0, input_shape, &in_backprop)); |
1325 | |
1326 | auto* stream = context->op_device_context()->stream(); |
1327 | OP_REQUIRES(context, stream, errors::Internal("No GPU stream available." )); |
1328 | |
1329 | bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth; |
1330 | if (!is_grouped_convolution && dims.filter_size(0) == 1 && |
1331 | dims.filter_size(1) == 1 && dims.filter_size(2) == 1 && |
1332 | dims.dilation(0) == 1 && dims.dilation(1) == 1 && |
1333 | dims.dilation(2) == 1 && dims.stride(0) == 1 && dims.stride(1) == 1 && |
1334 | dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) { |
1335 | const uint64 m = dims.batch_size * dims.input_size(0) * |
1336 | dims.input_size(1) * dims.input_size(2); |
1337 | const uint64 k = dims.out_depth; |
1338 | const uint64 n = dims.in_depth; |
1339 | |
1340 | auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), |
1341 | out_backprop.template flat<T>().size()); |
1342 | auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(), |
1343 | filter.template flat<T>().size()); |
1344 | auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(), |
1345 | in_backprop->template flat<T>().size()); |
1346 | |
1347 | auto transpose = se::blas::Transpose::kTranspose; |
1348 | auto no_transpose = se::blas::Transpose::kNoTranspose; |
1349 | |
1350 | OP_REQUIRES_OK( |
1351 | context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, |
1352 | k, a_ptr, k, &c_ptr, n, |
1353 | se::blas::kDefaultComputePrecision)); |
1354 | return; |
1355 | } else if (!is_grouped_convolution && |
1356 | dims.filter_size(0) == dims.input_size(0) && |
1357 | dims.filter_size(1) == dims.input_size(1) && |
1358 | dims.filter_size(2) == dims.input_size(2) && |
1359 | padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) { |
1360 | const uint64 m = dims.batch_size; |
1361 | const uint64 k = dims.out_depth; |
1362 | const uint64 n = dims.input_size(0) * dims.input_size(1) * |
1363 | dims.input_size(2) * dims.in_depth; |
1364 | |
1365 | auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), |
1366 | out_backprop.template flat<T>().size()); |
1367 | auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(), |
1368 | filter.template flat<T>().size()); |
1369 | auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(), |
1370 | in_backprop->template flat<T>().size()); |
1371 | |
1372 | auto transpose = se::blas::Transpose::kTranspose; |
1373 | auto no_transpose = se::blas::Transpose::kNoTranspose; |
1374 | |
1375 | OP_REQUIRES_OK( |
1376 | context, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, |
1377 | k, a_ptr, k, &c_ptr, n, |
1378 | se::blas::kDefaultComputePrecision)); |
1379 | return; |
1380 | } |
1381 | |
1382 | int padding_planes = dims.SpatialPadding(padding_, 0); |
1383 | int padding_rows = dims.SpatialPadding(padding_, 1); |
1384 | int padding_cols = dims.SpatialPadding(padding_, 2); |
1385 | const bool planes_odd = (padding_planes % 2 != 0); |
1386 | const bool rows_odd = (padding_rows % 2 != 0); |
1387 | const bool cols_odd = (padding_cols % 2 != 0); |
1388 | |
1389 | TensorShape compatible_input_shape; |
1390 | if (rows_odd || cols_odd || planes_odd) { |
1391 | // cuDNN only supports the same amount of padding on both sides. |
1392 | compatible_input_shape = { |
1393 | dims.batch_size, |
1394 | dims.in_depth, |
1395 | dims.input_size(0) + planes_odd, |
1396 | dims.input_size(1) + rows_odd, |
1397 | dims.input_size(2) + cols_odd, |
1398 | }; |
1399 | } else { |
1400 | compatible_input_shape = {dims.batch_size, dims.in_depth, |
1401 | dims.input_size(0), dims.input_size(1), |
1402 | dims.input_size(2)}; |
1403 | } |
1404 | |
1405 | CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0) |
1406 | << "Negative paddings: (" << padding_rows << ", " << padding_cols |
1407 | << ", " << padding_planes << ")" ; |
1408 | |
1409 | #if GOOGLE_CUDA |
1410 | const bool compute_in_nhwc = |
1411 | CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF; |
1412 | #else |
1413 | // fast NDHWC implementation is a CUDA only feature |
1414 | const bool compute_in_nhwc = false; |
1415 | #endif |
1416 | const TensorFormat compute_data_format = |
1417 | (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC |
1418 | : FORMAT_NCHW; |
1419 | |
1420 | VLOG(3) << "Compute Conv3DBackpropInput with cuDNN:" |
1421 | << " data_format=" << ToString(data_format_) |
1422 | << " compute_data_format=" << ToString(compute_data_format); |
1423 | |
1424 | constexpr auto kComputeInNHWC = |
1425 | std::make_tuple(se::dnn::DataLayout::kBatchYXDepth, |
1426 | se::dnn::FilterLayout::kOutputYXInput); |
1427 | constexpr auto kComputeInNCHW = |
1428 | std::make_tuple(se::dnn::DataLayout::kBatchDepthYX, |
1429 | se::dnn::FilterLayout::kOutputInputYX); |
1430 | |
1431 | se::dnn::DataLayout compute_data_layout; |
1432 | se::dnn::FilterLayout filter_layout; |
1433 | |
1434 | std::tie(compute_data_layout, filter_layout) = |
1435 | compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW; |
1436 | |
1437 | se::dnn::BatchDescriptor input_desc(3); |
1438 | input_desc.set_count(dims.batch_size) |
1439 | .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4)) |
1440 | .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3)) |
1441 | .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2)) |
1442 | .set_feature_map_count(dims.in_depth) |
1443 | .set_layout(compute_data_layout); |
1444 | se::dnn::BatchDescriptor output_desc(3); |
1445 | output_desc.set_count(dims.batch_size) |
1446 | .set_spatial_dim(DimIndex::X, dims.output_size(2)) |
1447 | .set_spatial_dim(DimIndex::Y, dims.output_size(1)) |
1448 | .set_spatial_dim(DimIndex::Z, dims.output_size(0)) |
1449 | .set_feature_map_count(dims.out_depth) |
1450 | .set_layout(compute_data_layout); |
1451 | se::dnn::FilterDescriptor filter_desc(3); |
1452 | filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) |
1453 | .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) |
1454 | .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) |
1455 | .set_input_feature_map_count(filter_shape.dim_size(3)) |
1456 | .set_output_feature_map_count(filter_shape.dim_size(4)) |
1457 | .set_layout(filter_layout); |
1458 | se::dnn::ConvolutionDescriptor conv_desc(3); |
1459 | conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) |
1460 | .set_dilation_rate(DimIndex::Y, dims.dilation(1)) |
1461 | .set_dilation_rate(DimIndex::Z, dims.dilation(0)) |
1462 | .set_filter_stride(DimIndex::X, dims.stride(2)) |
1463 | .set_filter_stride(DimIndex::Y, dims.stride(1)) |
1464 | .set_filter_stride(DimIndex::Z, dims.stride(0)) |
1465 | .set_zero_padding(DimIndex::X, padding_cols / 2) |
1466 | .set_zero_padding(DimIndex::Y, padding_rows / 2) |
1467 | .set_zero_padding(DimIndex::Z, padding_planes / 2) |
1468 | .set_group_count(dims.in_depth / filter_shape.dim_size(3)); |
1469 | |
1470 | // Shape: out, in, z, y, x. |
1471 | Tensor transformed_filter; |
1472 | auto dst_format = |
1473 | compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI; |
1474 | TensorShape dst_shape = |
1475 | dst_format == FORMAT_OIHW |
1476 | ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3), |
1477 | dims.filter_size(0), dims.filter_size(1), |
1478 | dims.filter_size(2)}) |
1479 | : TensorShape({filter_shape.dim_size(4), dims.filter_size(0), |
1480 | dims.filter_size(1), dims.filter_size(2), |
1481 | filter_shape.dim_size(3)}); |
1482 | OP_REQUIRES_OK(context, |
1483 | context->allocate_temp(DataTypeToEnum<T>::value, dst_shape, |
1484 | &transformed_filter)); |
1485 | |
1486 | functor::TransformFilter<GPUDevice, T, int, 5>()( |
1487 | context->eigen_device<GPUDevice>(), dst_format, |
1488 | To32Bit(filter.tensor<T, 5>()), |
1489 | To32Bit(transformed_filter.tensor<T, 5>())); |
1490 | |
1491 | // Shape: batch, filters, z, y, x. |
1492 | Tensor transformed_out_backprop; |
1493 | if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { |
1494 | TensorShape nchw_shape = {dims.batch_size, dims.out_depth, |
1495 | dims.output_size(0), dims.output_size(1), |
1496 | dims.output_size(2)}; |
1497 | if (dims.out_depth > 1) { |
1498 | OP_REQUIRES_OK(context, context->allocate_temp( |
1499 | DataTypeToEnum<T>::value, nchw_shape, |
1500 | &transformed_out_backprop)); |
1501 | functor::NHWCToNCHW<GPUDevice, T, 5>()( |
1502 | context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(), |
1503 | transformed_out_backprop.tensor<T, 5>()); |
1504 | } else { |
1505 | CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); |
1506 | } |
1507 | } else { |
1508 | transformed_out_backprop = out_backprop; |
1509 | } |
1510 | // Shape: batch, filters, z, y, x. |
1511 | Tensor pre_transformed_in_backprop; |
1512 | OP_REQUIRES_OK(context, |
1513 | context->allocate_temp( |
1514 | DataTypeToEnum<T>::value, |
1515 | ShapeFromFormat(compute_data_format, |
1516 | compatible_input_shape.dim_size(0), |
1517 | {{compatible_input_shape.dim_size(2), |
1518 | compatible_input_shape.dim_size(3), |
1519 | compatible_input_shape.dim_size(4)}}, |
1520 | compatible_input_shape.dim_size(1)), |
1521 | &pre_transformed_in_backprop)); |
1522 | |
1523 | auto out_backprop_ptr = |
1524 | AsDeviceMemory(transformed_out_backprop.template flat<T>().data(), |
1525 | transformed_out_backprop.template flat<T>().size()); |
1526 | auto filter_ptr = |
1527 | AsDeviceMemory(transformed_filter.template flat<T>().data(), |
1528 | transformed_filter.template flat<T>().size()); |
1529 | auto in_backprop_ptr = |
1530 | AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(), |
1531 | pre_transformed_in_backprop.template flat<T>().size()); |
1532 | |
1533 | static int64_t ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit( |
1534 | "TF_CUDNN_WORKSPACE_LIMIT_IN_MB" , 1LL << 33); // 8GB by default |
1535 | |
1536 | const int device_id = stream->parent()->device_ordinal(); |
1537 | // To make sure the Conv3DBackpropInputV2 get the correct dtype, we infer |
1538 | // the dtype from 2nd input, i.e., out_backprop. |
1539 | DataType dtype = context->input(2).dtype(); |
1540 | const ConvParameters conv_parameters = { |
1541 | dims.batch_size, |
1542 | dims.in_depth, |
1543 | {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, |
1544 | compute_data_format, |
1545 | dims.out_depth, |
1546 | {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, |
1547 | {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, |
1548 | {{dims.stride(0), dims.stride(1), dims.stride(2)}}, |
1549 | {{padding_planes, padding_rows, padding_cols}}, |
1550 | dtype, |
1551 | device_id, |
1552 | conv_desc.group_count()}; |
1553 | |
1554 | using se::dnn::AlgorithmConfig; |
1555 | using se::dnn::AlgorithmDesc; |
1556 | using se::dnn::ProfileResult; |
1557 | |
1558 | auto entry_or = AutotuneUnfusedConv( |
1559 | cudnn_use_autotune_, AutotuneConv3dBwdData::GetInstance(), |
1560 | conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_DATA, |
1561 | input_desc, in_backprop_ptr, filter_desc, filter_ptr, conv_desc, |
1562 | output_desc, out_backprop_ptr, ConvolveBackwardDataScratchSize); |
1563 | OP_REQUIRES_OK(context, entry_or.status()); |
1564 | auto autotune_entry = std::move(entry_or).value(); |
1565 | |
1566 | DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, |
1567 | context); |
1568 | Status cudnn_launch_status = LaunchAutotunedConv( |
1569 | autotune_entry, &scratch_allocator, |
1570 | se::dnn::ConvolutionKind::BACKWARD_DATA, stream, input_desc, |
1571 | in_backprop_ptr, filter_desc, filter_ptr, conv_desc, output_desc, |
1572 | out_backprop_ptr); |
1573 | if (!cudnn_launch_status.ok()) { |
1574 | context->SetStatus(cudnn_launch_status); |
1575 | return; |
1576 | } |
1577 | |
1578 | if (rows_odd || cols_odd || planes_odd) { |
1579 | Tensor in_backprop_remove_padding; |
1580 | OP_REQUIRES_OK( |
1581 | context, context->allocate_temp( |
1582 | DataTypeToEnum<T>::value, |
1583 | ShapeFromFormat(compute_data_format, dims.batch_size, |
1584 | {{dims.input_size(0), dims.input_size(1), |
1585 | dims.input_size(2)}}, |
1586 | dims.in_depth), |
1587 | &in_backprop_remove_padding)); |
1588 | |
1589 | // Remove the padding for odd spatial dimensions. |
1590 | functor::PadInput<GPUDevice, T, int, 5>()( |
1591 | context->eigen_device<GPUDevice>(), |
1592 | To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop) |
1593 | .tensor<T, 5>()), |
1594 | {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}}, |
1595 | To32Bit(in_backprop_remove_padding.tensor<T, 5>()), |
1596 | compute_data_format, T{}); |
1597 | |
1598 | pre_transformed_in_backprop = in_backprop_remove_padding; |
1599 | } |
1600 | |
1601 | if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { |
1602 | auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; |
1603 | functor::NCHWToNHWC<GPUDevice, T, 5>()( |
1604 | context->eigen_device<GPUDevice>(), |
1605 | toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(), |
1606 | in_backprop->tensor<T, 5>()); |
1607 | } else { |
1608 | *in_backprop = pre_transformed_in_backprop; |
1609 | } |
1610 | } |
1611 | |
1612 | private: |
1613 | std::vector<int32> dilation_; |
1614 | std::vector<int32> stride_; |
1615 | Padding padding_; |
1616 | TensorFormat data_format_; |
1617 | bool takes_shape_; |
1618 | bool cudnn_use_autotune_; |
1619 | }; |
1620 | |
1621 | // A dummy type to group backward filter autotune results together. |
1622 | struct Conv3dBackwardFilterAutotuneGroup { |
1623 | static string name() { return "Conv3dBwdFilter" ; } |
1624 | }; |
1625 | |
1626 | typedef AutotuneSingleton<Conv3dBackwardFilterAutotuneGroup, ConvParameters, |
1627 | AutotuneEntry<se::dnn::ConvOp>> |
1628 | AutotuneConv3dBwdFilter; |
1629 | |
1630 | template <typename T> |
1631 | class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { |
1632 | public: |
1633 | explicit Conv3DBackpropFilterOp(OpKernelConstruction* context) |
1634 | : OpKernel(context), |
1635 | data_format_(FORMAT_NHWC), |
1636 | takes_shape_(type_string().find("V2" ) != std::string::npos) { |
1637 | // data_format is only available in V2. |
1638 | if (takes_shape_) { |
1639 | string data_format; |
1640 | OP_REQUIRES_OK(context, context->GetAttr("data_format" , &data_format)); |
1641 | OP_REQUIRES(context, FormatFromString(data_format, &data_format_), |
1642 | errors::InvalidArgument("Invalid data format" )); |
1643 | } |
1644 | OP_REQUIRES_OK(context, context->GetAttr("dilations" , &dilation_)); |
1645 | OP_REQUIRES(context, dilation_.size() == 5, |
1646 | errors::InvalidArgument("Dilation rates field must " |
1647 | "specify 5 dimensions" )); |
1648 | OP_REQUIRES(context, |
1649 | (GetTensorDim(dilation_, data_format_, 'C') == 1 && |
1650 | GetTensorDim(dilation_, data_format_, 'N') == 1), |
1651 | errors::InvalidArgument( |
1652 | "Current implementation does not yet support " |
1653 | "dilation rates in the batch and depth dimensions." )); |
1654 | OP_REQUIRES( |
1655 | context, |
1656 | (GetTensorDim(dilation_, data_format_, '0') > 0 && |
1657 | GetTensorDim(dilation_, data_format_, '1') > 0 && |
1658 | GetTensorDim(dilation_, data_format_, '2') > 0), |
1659 | errors::InvalidArgument("Dilated rates should be larger than 0." )); |
1660 | OP_REQUIRES_OK(context, context->GetAttr("strides" , &stride_)); |
1661 | OP_REQUIRES(context, stride_.size() == 5, |
1662 | errors::InvalidArgument("Sliding window strides field must " |
1663 | "specify 5 dimensions" )); |
1664 | OP_REQUIRES( |
1665 | context, |
1666 | (GetTensorDim(stride_, data_format_, 'C') == 1 && |
1667 | GetTensorDim(stride_, data_format_, 'N') == 1), |
1668 | errors::InvalidArgument("Current implementation does not yet support " |
1669 | "strides in the batch and depth dimensions." )); |
1670 | OP_REQUIRES( |
1671 | context, |
1672 | (GetTensorDim(stride_, data_format_, '0') > 0 && |
1673 | GetTensorDim(stride_, data_format_, '1') > 0 && |
1674 | GetTensorDim(stride_, data_format_, '2') > 0), |
1675 | errors::InvalidArgument("Spatial strides should be larger than 0." )); |
1676 | OP_REQUIRES_OK(context, context->GetAttr("padding" , &padding_)); |
1677 | cudnn_use_autotune_ = CudnnUseAutotune(); |
1678 | } |
1679 | |
1680 | void Compute(OpKernelContext* context) override { |
1681 | const Tensor& input = context->input(0); |
1682 | const TensorShape& input_shape = input.shape(); |
1683 | |
1684 | const Tensor& out_backprop = context->input(2); |
1685 | const TensorShape& out_backprop_shape = out_backprop.shape(); |
1686 | |
1687 | TensorShape filter_shape; |
1688 | if (takes_shape_) { |
1689 | const Tensor& filter_sizes = context->input(1); |
1690 | OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()), |
1691 | errors::InvalidArgument( |
1692 | "filter_sizes shape must be rank 1 but is rank " , |
1693 | filter_sizes.shape().dims())); |
1694 | OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape)); |
1695 | } else { |
1696 | filter_shape = context->input(1).shape(); |
1697 | } |
1698 | |
1699 | ConvBackpropDimensions dims; |
1700 | OP_REQUIRES_OK( |
1701 | context, |
1702 | ConvBackpropComputeDimensionsV2( |
1703 | "Conv3DBackpropFilterOp" , /*num_spatial_dims=*/3, input_shape, |
1704 | filter_shape, out_backprop_shape, dilation_, stride_, padding_, |
1705 | /*explicit_paddings=*/{}, data_format_, &dims)); |
1706 | |
1707 | Tensor* filter_backprop; |
1708 | OP_REQUIRES_OK(context, |
1709 | context->allocate_output(0, filter_shape, &filter_backprop)); |
1710 | |
1711 | auto* stream = context->op_device_context()->stream(); |
1712 | OP_REQUIRES(context, stream, errors::Internal("No GPU stream available." )); |
1713 | |
1714 | bool is_grouped_convolution = filter_shape.dim_size(3) != dims.in_depth; |
1715 | if (!is_grouped_convolution && dims.filter_size(1) == 1 && |
1716 | dims.filter_size(2) == 1 && dims.filter_size(0) == 1 && |
1717 | dims.dilation(2) == 1 && dims.dilation(1) == 1 && |
1718 | dims.dilation(0) == 1 && dims.stride(2) == 1 && dims.stride(1) == 1 && |
1719 | dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) { |
1720 | const uint64 m = dims.in_depth; |
1721 | const uint64 k = dims.batch_size * dims.input_size(1) * |
1722 | dims.input_size(2) * dims.input_size(0); |
1723 | const uint64 n = dims.out_depth; |
1724 | |
1725 | // The shape of output backprop is |
1726 | // [batch, out_z, out_y, out_x, out_depth] |
1727 | // From cublas's perspective, it is: n x k |
1728 | auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), |
1729 | out_backprop.template flat<T>().size()); |
1730 | |
1731 | // The shape of input is: |
1732 | // [batch, in_z, in_y, in_x, in_depth], |
1733 | // From cublas's perspective, it is: m x k |
1734 | auto b_ptr = AsDeviceMemory(input.template flat<T>().data(), |
1735 | input.template flat<T>().size()); |
1736 | |
1737 | // The shape of the filter backprop is: |
1738 | // [1, 1, 1, in_depth, out_depth] |
1739 | // From cublas's perspective, it is: n x m |
1740 | auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(), |
1741 | filter_backprop->template flat<T>().size()); |
1742 | |
1743 | OP_REQUIRES_OK(context, |
1744 | stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, |
1745 | se::blas::Transpose::kTranspose, n, m, |
1746 | k, a_ptr, n, b_ptr, m, &c_ptr, n, |
1747 | se::blas::kDefaultComputePrecision)); |
1748 | return; |
1749 | } else if (!is_grouped_convolution && |
1750 | dims.filter_size(0) == dims.input_size(0) && |
1751 | dims.filter_size(1) == dims.input_size(1) && |
1752 | dims.filter_size(2) == dims.input_size(2) && |
1753 | padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) { |
1754 | const uint64 m = dims.input_size(0) * dims.input_size(1) * |
1755 | dims.input_size(2) * dims.in_depth; |
1756 | const uint64 k = dims.batch_size; |
1757 | const uint64 n = dims.out_depth; |
1758 | |
1759 | auto a_ptr = AsDeviceMemory(input.template flat<T>().data(), |
1760 | input.template flat<T>().size()); |
1761 | auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), |
1762 | out_backprop.template flat<T>().size()); |
1763 | auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(), |
1764 | filter_backprop->template flat<T>().size()); |
1765 | |
1766 | OP_REQUIRES_OK(context, |
1767 | stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, |
1768 | se::blas::Transpose::kTranspose, n, m, |
1769 | k, b_ptr, n, a_ptr, m, &c_ptr, n, |
1770 | se::blas::kDefaultComputePrecision)); |
1771 | return; |
1772 | } |
1773 | |
1774 | int padding_planes = dims.SpatialPadding(padding_, 0); |
1775 | int padding_rows = dims.SpatialPadding(padding_, 1); |
1776 | int padding_cols = dims.SpatialPadding(padding_, 2); |
1777 | const bool planes_odd = (padding_planes % 2 != 0); |
1778 | const bool rows_odd = (padding_rows % 2 != 0); |
1779 | const bool cols_odd = (padding_cols % 2 != 0); |
1780 | |
1781 | Tensor compatible_input; |
1782 | if (rows_odd || cols_odd || planes_odd) { |
1783 | OP_REQUIRES_OK(context, |
1784 | context->allocate_temp( |
1785 | DataTypeToEnum<T>::value, |
1786 | ShapeFromFormat(data_format_, dims.batch_size, |
1787 | {{dims.input_size(0) + planes_odd, |
1788 | dims.input_size(1) + rows_odd, |
1789 | dims.input_size(2) + cols_odd}}, |
1790 | dims.in_depth), |
1791 | &compatible_input)); |
1792 | functor::PadInput<GPUDevice, T, int, 5>()( |
1793 | context->template eigen_device<GPUDevice>(), |
1794 | To32Bit(input.tensor<T, 5>()), {{0, 0, 0}}, |
1795 | {{planes_odd, rows_odd, cols_odd}}, |
1796 | To32Bit(compatible_input.tensor<T, 5>()), data_format_, T{}); |
1797 | } else { |
1798 | compatible_input = input; |
1799 | } |
1800 | |
1801 | CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0) |
1802 | << "Negative paddings: (" << padding_rows << ", " << padding_cols |
1803 | << ", " << padding_planes << ")" ; |
1804 | |
1805 | #if GOOGLE_CUDA |
1806 | const bool compute_in_nhwc = |
1807 | CUDNN_VERSION >= 8000 && DataTypeToEnum<T>::value == DT_HALF; |
1808 | #else |
1809 | // fast NDHWC implementation is a CUDA only feature |
1810 | const bool compute_in_nhwc = false; |
1811 | #endif |
1812 | const TensorFormat compute_data_format = |
1813 | (compute_in_nhwc && data_format_ == FORMAT_NHWC) ? FORMAT_NHWC |
1814 | : FORMAT_NCHW; |
1815 | |
1816 | VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:" |
1817 | << " data_format=" << ToString(data_format_) |
1818 | << " compute_data_format=" << ToString(compute_data_format); |
1819 | |
1820 | constexpr auto kComputeInNHWC = |
1821 | std::make_tuple(se::dnn::DataLayout::kBatchYXDepth, |
1822 | se::dnn::FilterLayout::kOutputYXInput); |
1823 | constexpr auto kComputeInNCHW = |
1824 | std::make_tuple(se::dnn::DataLayout::kBatchDepthYX, |
1825 | se::dnn::FilterLayout::kOutputInputYX); |
1826 | |
1827 | se::dnn::DataLayout compute_data_layout; |
1828 | se::dnn::FilterLayout filter_layout; |
1829 | |
1830 | std::tie(compute_data_layout, filter_layout) = |
1831 | compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW; |
1832 | |
1833 | se::dnn::BatchDescriptor input_desc(3); |
1834 | input_desc.set_count(dims.batch_size) |
1835 | .set_spatial_dim(DimIndex::X, |
1836 | GetTensorDim(compatible_input, data_format_, '2')) |
1837 | .set_spatial_dim(DimIndex::Y, |
1838 | GetTensorDim(compatible_input, data_format_, '1')) |
1839 | .set_spatial_dim(DimIndex::Z, |
1840 | GetTensorDim(compatible_input, data_format_, '0')) |
1841 | .set_feature_map_count(dims.in_depth) |
1842 | .set_layout(compute_data_layout); |
1843 | se::dnn::BatchDescriptor output_desc(3); |
1844 | output_desc.set_count(dims.batch_size) |
1845 | .set_spatial_dim(DimIndex::X, dims.output_size(2)) |
1846 | .set_spatial_dim(DimIndex::Y, dims.output_size(1)) |
1847 | .set_spatial_dim(DimIndex::Z, dims.output_size(0)) |
1848 | .set_feature_map_count(dims.out_depth) |
1849 | .set_layout(compute_data_layout); |
1850 | se::dnn::FilterDescriptor filter_desc(3); |
1851 | filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) |
1852 | .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) |
1853 | .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) |
1854 | .set_input_feature_map_count(filter_shape.dim_size(3)) |
1855 | .set_output_feature_map_count(filter_shape.dim_size(4)) |
1856 | .set_layout(filter_layout); |
1857 | se::dnn::ConvolutionDescriptor conv_desc(3); |
1858 | conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) |
1859 | .set_dilation_rate(DimIndex::Y, dims.dilation(1)) |
1860 | .set_dilation_rate(DimIndex::Z, dims.dilation(0)) |
1861 | .set_filter_stride(DimIndex::X, dims.stride(2)) |
1862 | .set_filter_stride(DimIndex::Y, dims.stride(1)) |
1863 | .set_filter_stride(DimIndex::Z, dims.stride(0)) |
1864 | .set_zero_padding(DimIndex::X, padding_cols / 2) |
1865 | .set_zero_padding(DimIndex::Y, padding_rows / 2) |
1866 | .set_zero_padding(DimIndex::Z, padding_planes / 2) |
1867 | .set_group_count(dims.in_depth / filter_shape.dim_size(3)); |
1868 | |
1869 | Tensor pre_transformed_filter_backprop; |
1870 | auto dst_format = |
1871 | compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI; |
1872 | TensorShape dst_shape = |
1873 | dst_format == FORMAT_OIHW |
1874 | ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3), |
1875 | dims.filter_size(0), dims.filter_size(1), |
1876 | dims.filter_size(2)}) |
1877 | : TensorShape({filter_shape.dim_size(4), dims.filter_size(0), |
1878 | dims.filter_size(1), dims.filter_size(2), |
1879 | filter_shape.dim_size(3)}); |
1880 | OP_REQUIRES_OK(context, |
1881 | context->allocate_temp(DataTypeToEnum<T>::value, dst_shape, |
1882 | &pre_transformed_filter_backprop)); |
1883 | |
1884 | Tensor transformed_out_backprop; |
1885 | if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { |
1886 | VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW." ; |
1887 | TensorShape nchw_shape = {dims.batch_size, dims.out_depth, |
1888 | dims.output_size(0), dims.output_size(1), |
1889 | dims.output_size(2)}; |
1890 | OP_REQUIRES_OK( |
1891 | context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape, |
1892 | &transformed_out_backprop)); |
1893 | if (dims.out_depth > 1) { |
1894 | functor::NHWCToNCHW<GPUDevice, T, 5>()( |
1895 | context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(), |
1896 | transformed_out_backprop.tensor<T, 5>()); |
1897 | } else { |
1898 | CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); |
1899 | } |
1900 | } else { |
1901 | transformed_out_backprop = out_backprop; |
1902 | } |
1903 | Tensor transformed_input; |
1904 | if (data_format_ == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { |
1905 | VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW." ; |
1906 | TensorShape nchw_shape = { |
1907 | dims.batch_size, dims.in_depth, compatible_input.dim_size(1), |
1908 | compatible_input.dim_size(2), compatible_input.dim_size(3)}; |
1909 | if (dims.in_depth > 1) { |
1910 | OP_REQUIRES_OK(context, |
1911 | context->allocate_temp(DataTypeToEnum<T>::value, |
1912 | nchw_shape, &transformed_input)); |
1913 | functor::NHWCToNCHW<GPUDevice, T, 5>()( |
1914 | context->eigen_device<GPUDevice>(), |
1915 | const_cast<const Tensor&>(compatible_input).tensor<T, 5>(), |
1916 | transformed_input.tensor<T, 5>()); |
1917 | } else { |
1918 | CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape)); |
1919 | } |
1920 | } else { |
1921 | transformed_input = compatible_input; |
1922 | } |
1923 | |
1924 | auto out_backprop_ptr = |
1925 | AsDeviceMemory(transformed_out_backprop.template flat<T>().data(), |
1926 | transformed_out_backprop.template flat<T>().size()); |
1927 | auto filter_backprop_ptr = AsDeviceMemory( |
1928 | pre_transformed_filter_backprop.template flat<T>().data(), |
1929 | pre_transformed_filter_backprop.template flat<T>().size()); |
1930 | auto input_ptr = |
1931 | AsDeviceMemory(transformed_input.template flat<T>().data(), |
1932 | transformed_input.template flat<T>().size()); |
1933 | |
1934 | static int64_t ConvolveBackwardFilterScratchSize = |
1935 | GetDnnWorkspaceLimitOrDefault(); |
1936 | |
1937 | const int device_id = stream->parent()->device_ordinal(); |
1938 | DataType dtype = input.dtype(); |
1939 | const ConvParameters conv_parameters = { |
1940 | dims.batch_size, |
1941 | dims.in_depth, |
1942 | {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, |
1943 | compute_data_format, |
1944 | dims.out_depth, |
1945 | {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, |
1946 | {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, |
1947 | {{dims.stride(0), dims.stride(1), dims.stride(2)}}, |
1948 | {{padding_planes, padding_rows, padding_cols}}, |
1949 | dtype, |
1950 | device_id, |
1951 | conv_desc.group_count()}; |
1952 | |
1953 | using se::dnn::AlgorithmConfig; |
1954 | using se::dnn::AlgorithmDesc; |
1955 | using se::dnn::ProfileResult; |
1956 | |
1957 | auto entry_or = AutotuneUnfusedConv( |
1958 | cudnn_use_autotune_, AutotuneConv3dBwdFilter::GetInstance(), |
1959 | conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_FILTER, |
1960 | input_desc, input_ptr, filter_desc, filter_backprop_ptr, conv_desc, |
1961 | output_desc, out_backprop_ptr, ConvolveBackwardFilterScratchSize); |
1962 | OP_REQUIRES_OK(context, entry_or.status()); |
1963 | auto autotune_entry = std::move(entry_or).value(); |
1964 | |
1965 | DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, |
1966 | context); |
1967 | Status cudnn_launch_status = LaunchAutotunedConv( |
1968 | autotune_entry, &scratch_allocator, |
1969 | se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc, |
1970 | input_ptr, filter_desc, filter_backprop_ptr, conv_desc, output_desc, |
1971 | out_backprop_ptr); |
1972 | if (!cudnn_launch_status.ok()) { |
1973 | context->SetStatus(cudnn_launch_status); |
1974 | return; |
1975 | } |
1976 | |
1977 | auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; |
1978 | functor::ReverseTransformFilter<GPUDevice, T, 5>()( |
1979 | context->eigen_device<GPUDevice>(), /*src_filter_format=*/dst_format, |
1980 | toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(), |
1981 | filter_backprop->tensor<T, 5>()); |
1982 | } |
1983 | |
1984 | private: |
1985 | std::vector<int32> dilation_; |
1986 | std::vector<int32> stride_; |
1987 | Padding padding_; |
1988 | TensorFormat data_format_; |
1989 | bool takes_shape_; |
1990 | bool cudnn_use_autotune_; |
1991 | }; |
1992 | |
1993 | #define REGISTER_GPU_KERNEL(T) \ |
1994 | REGISTER_KERNEL_BUILDER( \ |
1995 | Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ |
1996 | Conv3DBackpropInputOp<GPUDevice, T>); \ |
1997 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ |
1998 | .Device(DEVICE_GPU) \ |
1999 | .TypeConstraint<T>("T") \ |
2000 | .HostMemory("input_sizes"), \ |
2001 | Conv3DBackpropInputOp<GPUDevice, T>); \ |
2002 | REGISTER_KERNEL_BUILDER( \ |
2003 | Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ |
2004 | Conv3DBackpropFilterOp<GPUDevice, T>); \ |
2005 | REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ |
2006 | .Device(DEVICE_GPU) \ |
2007 | .TypeConstraint<T>("T") \ |
2008 | .HostMemory("filter_sizes"), \ |
2009 | Conv3DBackpropFilterOp<GPUDevice, T>); |
2010 | TF_CALL_half(REGISTER_GPU_KERNEL); |
2011 | TF_CALL_float(REGISTER_GPU_KERNEL); |
2012 | TF_CALL_double(REGISTER_GPU_KERNEL); |
2013 | #undef REGISTER_GPU_KERNEL |
2014 | |
2015 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
2016 | |
2017 | } // namespace tensorflow |
2018 | |