/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include <algorithm>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace {

// Returns in 'col_data', image patches in storage order (height, width, depth)
// extracted from image at 'input_data', which is required to be in storage
// order (batch, height, width, depth).
// Implementation written by Yangqing Jia (jiayq).
template <typename T>
void Im2col(const T* input_data, const int depth, const int height,
            const int width, const int filter_h, const int filter_w,
            const int pad_t, const int pad_l, const int pad_b, const int pad_r,
            const int stride_h, const int stride_w, T* col_data) {
  int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;

  int h_pad = -pad_t;
  for (int h = 0; h < height_col; ++h) {
    int w_pad = -pad_l;
    for (int w = 0; w < width_col; ++w) {
      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            memcpy(col_data, input_data + (ih * width + iw) * depth,
                   sizeof(T) * depth);
          } else {
            // This location falls in the padding region, so fill it with
            // zeros.
            memset(col_data, 0, sizeof(T) * depth);
          }
          col_data += depth;
        }
      }
      w_pad += stride_w;
    }
    h_pad += stride_h;
  }
}
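
// Worked example (illustrative only; not code used by the kernels below): for
// a 4x4 single-channel input, a 3x3 filter, stride 1, and one pixel of padding
// on every side,
//   height_col = (4 + 1 + 1 - 3) / 1 + 1 = 4
//   width_col  = (4 + 1 + 1 - 3) / 1 + 1 = 4
// so a caller must size 'col_data' to hold height_col * width_col * filter_h *
// filter_w * depth = 4 * 4 * 3 * 3 * 1 = 144 elements, e.g.:
//
//   std::vector<float> col_data(4 * 4 * 3 * 3);
//   Im2col<float>(input, /*depth=*/1, /*height=*/4, /*width=*/4,
//                 /*filter_h=*/3, /*filter_w=*/3, /*pad_t=*/1, /*pad_l=*/1,
//                 /*pad_b=*/1, /*pad_r=*/1, /*stride_h=*/1, /*stride_w=*/1,
//                 col_data.data());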

}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename T>
struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& input,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings,
                  Tensor* filter_backprop, TensorFormat data_format) {
    std::vector<int32> dilations(4, 1);
    dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation;
    dilations[GetTensorDimIndex(data_format, 'W')] = col_dilation;

    std::vector<int32> strides(4, 1);
    strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
    strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
    TensorShape filter_shape = filter_backprop->shape();

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        ctx, ConvBackpropComputeDimensionsV2(
                 "Conv2DBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
                 filter_shape, out_backprop.shape(), dilations, strides,
                 padding, explicit_paddings, data_format, &dims));

    int64_t padding_top = -1, padding_bottom = -1;
    int64_t padding_left = -1, padding_right = -1;
    if (padding == EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                               &padding_top, &padding_bottom);
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                               &padding_left, &padding_right);
    }
    int64_t expected_out_rows, expected_out_cols;
    // The function is guaranteed to succeed because we checked that the output
    // and padding were valid earlier.
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
        row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
        &padding_bottom));
    DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
        col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
        &padding_right));
    DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

    const CPUDevice& d = ctx->eigen_device<CPUDevice>();

    // WARNING: Need to swap row/col, padding_top/padding_left, and
    // padding_bottom/padding_right when calling Eigen. Eigen expects tensors
    // in NWHC format, but TensorFlow uses NHWC.
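    // For example, a filter stored as HWIO with filter_rows = 3 and
    // filter_cols = 5 is passed to Eigen below with kernel dimensions (5, 3),
    // i.e. dimension(1) before dimension(0), and with col_stride/col_dilation
    // ahead of row_stride/row_dilation.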

    auto filter_backprop_t = filter_backprop->tensor<T, 4>();
    auto input_t = input.tensor<T, 4>();
    auto out_backprop_t = out_backprop.tensor<T, 4>();

    if (padding != EXPLICIT) {
      // If padding was not explicitly defined, Eigen spatial convolution
      // backward filter will infer correct forward paddings from input tensors.
      filter_backprop_t.device(d) = Eigen::SpatialConvolutionBackwardKernel(
          input_t, out_backprop_t, filter_backprop_t.dimension(1),
          filter_backprop_t.dimension(0), col_stride, row_stride, col_dilation,
          row_dilation);

    } else {
      // Otherwise we have to explicitly pad the input before passing it to
      // spatial convolution backward filter.
      Eigen::array<std::pair<int, int>, 4> paddings;
      paddings[0] = {0, 0};
      paddings[1] = {padding_top, padding_bottom};
      paddings[2] = {padding_left, padding_right};
      paddings[3] = {0, 0};

      auto padded_t = input_t.pad(paddings, T(0));

      // TODO(ezhulenev): Pass explicit paddings to Eigen spatial backward
      // convolution and do not rely on tensor padding expression.
      filter_backprop_t.device(d) = Eigen::SpatialConvolutionBackwardKernel(
          padded_t, out_backprop_t, filter_backprop_t.dimension(1),
          filter_backprop_t.dimension(0), col_stride, row_stride, col_dilation,
          row_dilation);
    }
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
template <typename Device, class T>
struct LaunchXsmmBackwardFilter {
  bool operator()(OpKernelContext* context, const Device& d,
                  typename TTypes<T, 4>::ConstTensor input_backward,
                  typename TTypes<T, 4>::Tensor kernel,
                  typename TTypes<T, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    return false;
  }
};

template <>
struct LaunchXsmmBackwardFilter<CPUDevice, float> {
  bool operator()(OpKernelContext* context, const CPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor input,
                  typename TTypes<float, 4>::Tensor filter,
                  typename TTypes<float, 4>::ConstTensor output, int input_rows,
                  int input_cols, int row_stride, int col_stride, int pad_h,
                  int pad_w, TensorFormat data_format) const {
    auto batch = input.dimension(0);
    auto in_depth = input.dimension(3);
    auto out_depth = output.dimension(3);
    auto filter_rows = filter.dimension(0);
    auto filter_cols = filter.dimension(1);

    auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = row_stride;
    desc.v = col_stride;
    desc.pad_h = pad_h;
    desc.pad_w = pad_w;
    desc.pad_h_in = 0;  // pad_rows;  // ignored by libxsmm for now.
    desc.pad_w_in = 0;  // pad_cols;  // ignored by libxsmm for now.
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
    if (!CanUseXsmmConv2D(desc, data_format)) {
      return false;
    }

    auto input_ptr = input.data();
    auto filter_ptr = filter.data();
    auto output_ptr = output.data();
    bool success = functor::XsmmBkwFilterConv2D<CPUDevice, float>()(
        context, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

template <typename Device, class T>
class Conv2DBackpropFilterOp : public OpKernel {
 public:
  explicit Conv2DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    int stride_h = GetTensorDim(strides_, data_format_, 'H');
    int stride_w = GetTensorDim(strides_, data_format_, 'W');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilation rates should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));

    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    cudnn_use_autotune_ = CudnnUseAutotune();

    if (std::is_same<Device, CPUDevice>::value) {
      OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                  errors::InvalidArgument("Conv2DBackpropFilterOp [CPU] "
                                          "only supports NHWC data format."));

      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
      OP_REQUIRES(
          context, (dilation_h == 1 && dilation_w == 1),
          errors::InvalidArgument("Conv2DBackpropFilterOp [CPU] does not yet "
                                  "support dilation rates larger than 1."));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter_sizes = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(filter_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
            filter_sizes.dims()));
    TensorShape filter_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                filter_sizes.vec<int32>(), &filter_shape));

    Tensor* filter_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    // If there is nothing to compute, return.
    if (filter_shape.num_elements() == 0) {
      return;
    }
    // If input is empty, set gradients to zero.
    if (input.shape().num_elements() == 0) {
      functor::SetZeroFunctor<Device, T> f;
      f(context->eigen_device<Device>(), filter_backprop->flat<T>());
      return;
    }

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');

    VLOG(2) << "Conv2DBackpropFilter:"
            << " input: " << input.shape().DebugString()
            << " filter:" << filter_shape.DebugString()
            << " out_backprop: " << out_backprop.shape().DebugString()
            << " strides: [" << stride_rows << ", " << stride_cols << "]"
            << " dilations: [" << dilation_rows << ", " << dilation_cols << "]";

    launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
              dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
              explicit_paddings_, filter_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  bool use_cudnn_;
  TensorFormat data_format_;
  LaunchConv2DBackpropFilterOp<Device, T> launcher_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropFilterOp);
};

// Based on implementation written by Yangqing Jia (jiayq).
template <typename Device, class T>
class Conv2DCustomBackpropFilterOp : public OpKernel {
 public:
  explicit Conv2DCustomBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DCustomBackpropFilterOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    if (std::is_same<Device, CPUDevice>::value ||
        std::is_same<Device, GPUDevice>::value) {
      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
      OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                  errors::InvalidArgument(
                      "Current libxsmm and customized CPU implementations do "
                      "not yet support dilation rates larger than 1."));
      dilations_ = {1, 1, 1, 1};
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter_sizes = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(filter_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, "
            "not ",
            filter_sizes.dims()));
    TensorShape filter_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                filter_sizes.vec<int32>(), &filter_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        context,
        ConvBackpropComputeDimensionsV2(
            "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
            filter_shape, out_backprop.shape(), dilations_, strides_, padding_,
            explicit_paddings_, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    // If there is nothing to compute, return.
    if (filter_shape.num_elements() == 0) {
      return;
    }

    int64_t pad_top, pad_bottom;
    int64_t pad_left, pad_right;
    if (padding_ == Padding::EXPLICIT) {
      pad_top = explicit_paddings_[2];
      pad_bottom = explicit_paddings_[3];
      pad_left = explicit_paddings_[4];
      pad_right = explicit_paddings_[5];
    }
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \
    defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardFilter<Device, T>()(
              context, context->eigen_device<Device>(), input.tensor<T, 4>(),
              filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
              dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#endif

    // The total dimension size of each kernel.
    const int filter_total_size = dims.spatial_dims[0].filter_size *
                                  dims.spatial_dims[1].filter_size *
                                  dims.in_depth;
    OP_REQUIRES(
        context,
        filter_total_size * dims.out_depth == filter_backprop->NumElements(),
        errors::InvalidArgument(
            "filter_size does not have enough elements, requested ",
            filter_total_size * dims.out_depth, ", got ",
            filter_backprop->NumElements()));

    // The output image size is the spatial size of the output.
    const int output_image_size =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;

    // Shard 'batch' images into 'shard_size' groups of images to be fed
    // into the parallel matmul. Calculate 'shard_size' by dividing the L3 cache
    // size ('target_working_set_size') by the matmul size of an individual
    // image ('work_unit_size').

    // TODO(andydavis)
    // *) Get L3 cache size from device at runtime (30MB is from ivybridge).
    // *) Consider reducing 'target_working_set_size' if L3 is shared by
    //    other concurrently running TensorFlow ops.
    const size_t target_working_set_size = (30LL << 20) / sizeof(T);

    const size_t size_A = output_image_size * filter_total_size;

    const size_t size_B = output_image_size * dims.out_depth;

    const size_t size_C = filter_total_size * dims.out_depth;

    const size_t work_unit_size = size_A + size_B + size_C;

    OP_REQUIRES(
        context, work_unit_size != 0,
        errors::InvalidArgument(
            "Work size for convolution would be 0, which is not acceptable"));

    const size_t shard_size =
        (target_working_set_size + work_unit_size - 1) / work_unit_size;
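    // Worked example (illustrative, assuming float inputs): the working-set
    // target is (30 << 20) / 4 ~= 7.9M elements; if one image's A + B + C
    // matrices total, say, 1M elements, the ceiling division above yields a
    // shard_size of 8 images per GEMM shard.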

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({static_cast<int64_t>(shard_size),
                                    static_cast<int64_t>(output_image_size),
                                    static_cast<int64_t>(filter_total_size)}),
                       &col_buffer));

    // The input offset corresponding to a single input image.
    const int input_offset = dims.spatial_dims[0].input_size *
                             dims.spatial_dims[1].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int output_offset = dims.spatial_dims[0].output_size *
                              dims.spatial_dims[1].output_size * dims.out_depth;

    const T* input_data = input.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();
    T* filter_backprop_data = filter_backprop->template flat<T>().data();

    typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        TensorMap;
    typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        ConstTensorMap;

    TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
    C.setZero();

    // Initialize contraction dims (we need to transpose 'A' below).
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
    contract_dims[0].first = 0;
    contract_dims[0].second = 0;
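
    // Shapes in the contraction below (see the shard loop): A is the im2col
    // buffer of shape [output_image_size * shard_limit, filter_total_size] and
    // B is out_backprop of shape [output_image_size * shard_limit, out_depth].
    // Contracting both over dimension 0 computes C += A^T * B, where C has
    // shape [filter_total_size, out_depth] -- exactly the filter gradient
    // accumulated over the shard.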

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
      const int shard_limit =
          std::min(static_cast<int>(shard_size),
                   static_cast<int>(dims.batch_size) - image_id);

      auto shard = [&input_data, &col_buffer_data, &dims, &pad_top, &pad_left,
                    &pad_bottom, &pad_right, &input_offset,
                    &size_A](int64_t start, int64_t limit) {
        for (int shard_id = start; shard_id < limit; ++shard_id) {
          const T* input_data_shard = input_data + shard_id * input_offset;
          T* col_data_shard = col_buffer_data + shard_id * size_A;

          // When we compute the gradient with respect to the filters, we need
          // to do im2col to allow gemm-type computation.
          Im2col<T>(
              input_data_shard, dims.in_depth, dims.spatial_dims[0].input_size,
              dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
              dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
              pad_right, dims.spatial_dims[0].stride,
              dims.spatial_dims[1].stride, col_data_shard);
        }
      };
      Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
            size_A, shard);

      ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
                       filter_total_size);
      ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
                       dims.out_depth);

      // Gradient with respect to filter.
      C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);

      input_data += input_offset * shard_limit;
      out_backprop_data += output_offset * shard_limit;
    }
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropFilterOp);
};

#define REGISTER_CPU_KERNELS(T)                                               \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv2DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DCustomBackpropFilterOp<CPUDevice, T>);                            \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("custom")                                \
                              .TypeConstraint<T>("T")                         \
                              .AttrConstraint("data_format", "NHWC"),         \
                          Conv2DCustomBackpropFilterOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("eigen_tensor")                          \
                              .TypeConstraint<T>("T")                         \
                              .AttrConstraint("data_format", "NHWC"),         \
                          Conv2DBackpropFilterOp<CPUDevice, T>);

TF_CALL_bfloat16(REGISTER_CPU_KERNELS);
TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
TF_CALL_double(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::bfloat16>;
template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;

// GPU definitions.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// The slow version (but compiles for GPU)

// A dummy type to group backward filter autotune results together.
struct ConvBackwardFilterAutotuneGroup {
  static string name() { return "ConvBwdFilter"; }
};

typedef AutotuneSingleton<ConvBackwardFilterAutotuneGroup, ConvParameters,
                          AutotuneEntry<se::dnn::ConvOp>>
    AutotuneConvBwdFilter;

template <typename T>
void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& out_backprop, const Tensor& input, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    const std::vector<int64_t>& explicit_paddings, Tensor* filter_backprop,
    TensorFormat data_format) {
  using se::dnn::AlgorithmConfig;
  using se::dnn::AlgorithmDesc;
  using se::dnn::ProfileResult;

  std::vector<int32> dilations(4, 1);
  dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation;
  dilations[GetTensorDimIndex(data_format, 'W')] = col_dilation;

  std::vector<int32> strides(4, 1);
  strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
  strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
  TensorShape filter_shape = filter_backprop->shape();

  ConvBackpropDimensions dims;
  OP_REQUIRES_OK(
      ctx, ConvBackpropComputeDimensionsV2(
               "Conv2DBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
               filter_shape, out_backprop.shape(), dilations, strides, padding,
               explicit_paddings, data_format, &dims));

  int64_t padding_top = -1, padding_bottom = -1;
  int64_t padding_left = -1, padding_right = -1;
  if (padding == EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
                             &padding_bottom);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
                             &padding_right);
  }
  int64_t expected_out_rows, expected_out_cols;
  // The function is guaranteed to succeed because we checked that the output
  // and padding were valid earlier.
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
      row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
      &padding_bottom));
  DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
      col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
      &padding_right));
  DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(errors::Unimplemented(
        "Conv2DBackprop for GPU is not currently supported "
        "without cudnn"));
    return;
  }

  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than
  // the input depth, it's a depthwise convolution. More generally, if the
  // filter in-depth divides but is smaller than the input depth, it is a
  // grouped convolution.
  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
  bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
  if (!cudnn_disable_conv_1x1_optimization_ &&
      dims.spatial_dims[0].filter_size == 1 &&
      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
      dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
      data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) {
    const uint64 m = dims.in_depth;
    const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size;
    const uint64 n = dims.out_depth;
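
    // In matrix form (illustrative): viewing the NHWC input as a k x m matrix
    // X and out_backprop as a k x n matrix dY, the 1x1 filter gradient is
    //   dW = X^T * dY   (an m x n, i.e. in_depth x out_depth, matrix),
    // which the ThenBlasGemm call below computes directly; the transpose flags
    // and leading dimensions account for cublas's column-major layout.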

    // The shape of output backprop is
    //   [batch, out_rows, out_cols, out_depth]
    // From cublas's perspective, it is: n x k
    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());

    // The shape of input is
    //   [batch, in_rows, in_cols, in_depth]
    // From cublas's perspective, it is: m x k
    auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());

    // The shape of the filter backprop from the conv_2d should be
    //   [1, 1, in_depth, out_depth]
    // From cublas's perspective, it is: n x m
    auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                filter_backprop->template flat<T>().size());

    OP_REQUIRES_OK(
        ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                                  se::blas::Transpose::kTranspose, n, m, k,
                                  a_ptr, n, b_ptr, m, &c_ptr, n,
                                  se::blas::kDefaultComputePrecision));
    return;
  } else if (dims.spatial_dims[0].filter_size ==
                 dims.spatial_dims[0].input_size &&
             dims.spatial_dims[1].filter_size ==
                 dims.spatial_dims[1].input_size &&
             !is_grouped_convolution && padding == VALID &&
             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, and we are not
    // using grouped convolution, so call cublas directly.
    const uint64 m = dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size * dims.in_depth;
    const uint64 k = dims.batch_size;
    const uint64 n = dims.out_depth;

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                filter_backprop->template flat<T>().size());

    OP_REQUIRES_OK(
        ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                                  se::blas::Transpose::kTranspose, n, m, k,
                                  b_ptr, n, a_ptr, m, &c_ptr, n,
                                  se::blas::kDefaultComputePrecision));
    return;
  }

  const int64_t common_padding_rows = std::min(padding_top, padding_bottom);
  const int64_t common_padding_cols = std::min(padding_left, padding_right);
  Tensor compatible_input;
  if (padding_top != padding_bottom || padding_left != padding_right) {
    // Pad the input in the same way we did during the forward pass, so that
    // cuDNN or MIOpen receives the same input during the backward pass
    // function as it did during the forward pass function.
    const int64_t padding_rows_diff = std::abs(padding_bottom - padding_top);
    const int64_t padding_cols_diff = std::abs(padding_right - padding_left);
    const int64_t new_in_rows =
        dims.spatial_dims[0].input_size + padding_rows_diff;
    const int64_t new_in_cols =
        dims.spatial_dims[1].input_size + padding_cols_diff;
    const int64_t input_pad_top = padding_top - common_padding_rows;
    const int64_t input_pad_bottom = padding_bottom - common_padding_rows;
    const int64_t input_pad_left = padding_left - common_padding_cols;
    const int64_t input_pad_right = padding_right - common_padding_cols;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(data_format, dims.batch_size, new_in_rows,
                                 new_in_cols, dims.in_depth),
                 &compatible_input));

    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->template eigen_device<GPUDevice>(), To32Bit(input.tensor<T, 4>()),
        {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
        {{static_cast<int>(input_pad_bottom),
          static_cast<int>(input_pad_right)}},
        To32Bit(compatible_input.tensor<T, 4>()), data_format, T{});
  } else {
    compatible_input = input;
  }
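
  // Worked example (illustrative): with explicit paddings top = 0, bottom = 1,
  // left = 2, right = 2, we get common_padding_rows = min(0, 1) = 0 and
  // common_padding_cols = 2. Only the row padding is asymmetric, so the input
  // is physically padded with one extra row at the bottom (input_pad_top = 0,
  // input_pad_bottom = 1), and cuDNN is then asked for a symmetric zero
  // padding of (0, 2).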

  CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
      << "Negative row or col paddings: (" << common_padding_rows << ", "
      << common_padding_cols << ")";

  // The Tensor Core in NVIDIA Volta+ GPUs supports efficient convolution with
  // fp16 in NHWC data layout. In all other configurations it's more efficient
  // to run computation in NCHW data format.
  const bool compute_in_nhwc = DataTypeToEnum<T>::value == DT_HALF &&
                               stream->GetCudaComputeCapability().IsAtLeast(
                                   se::CudaComputeCapability::VOLTA);

  // We only do one directional conversion: NHWC->NCHW. We never convert in the
  // other direction. Grappler layout optimizer selects the preferred layout and
  // adds necessary annotations to the graph.
  const TensorFormat compute_data_format =
      (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
                                                      : FORMAT_NCHW;

  VLOG(3) << "Compute Conv2DBackpropFilter with cuDNN:"
          << " data_format=" << ToString(data_format)
          << " compute_data_format=" << ToString(compute_data_format);

  constexpr auto kComputeInNHWC =
      std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
                      se::dnn::FilterLayout::kOutputYXInput);
  constexpr auto kComputeInNCHW =
      std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
                      se::dnn::FilterLayout::kOutputInputYX);

  se::dnn::DataLayout compute_data_layout;
  se::dnn::FilterLayout filter_layout;

  std::tie(compute_data_layout, filter_layout) =
      compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;

  se::dnn::BatchDescriptor input_desc;
  input_desc.set_count(dims.batch_size)
      .set_height(GetTensorDim(compatible_input, data_format, 'H'))
      .set_width(GetTensorDim(compatible_input, data_format, 'W'))
      .set_feature_map_count(dims.in_depth)
      .set_layout(compute_data_layout);
  se::dnn::BatchDescriptor output_desc;
  output_desc.set_count(dims.batch_size)
      .set_height(dims.spatial_dims[0].output_size)
      .set_width(dims.spatial_dims[1].output_size)
      .set_feature_map_count(dims.out_depth)
      .set_layout(compute_data_layout);
  se::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
      .set_input_filter_width(dims.spatial_dims[1].filter_size)
      .set_input_feature_map_count(filter_shape.dim_size(2))
      .set_output_feature_map_count(filter_shape.dim_size(3))
      .set_layout(filter_layout);
  se::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
      .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
      .set_vertical_filter_stride(dims.spatial_dims[0].stride)
      .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
      .set_zero_padding_height(common_padding_rows)
      .set_zero_padding_width(common_padding_cols)
      .set_group_count(dims.in_depth / filter_shape.dim_size(2));
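
  // For example (illustrative): in_depth = 64 with filter_shape.dim_size(2) =
  // 16 yields group_count = 4, i.e. a grouped convolution; when dim_size(2) ==
  // in_depth this is the ordinary dense case with group_count = 1.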

  // TensorFlow filter format: HWIO
  // cuDNN filter formats: (data format) -> (filter format)
  //   (1) NCHW -> OIHW
  //   (2) NHWC -> OHWI
  //
  // We compute filter backprop into temporary tensor, and then convert it to
  // the HWIO data format at the end.

  Tensor pre_transformed_filter_backprop;
  OP_REQUIRES_OK(
      ctx,
      ctx->allocate_temp(
          DataTypeToEnum<T>::value,
          TensorShape({filter_shape.dim_size(3), filter_shape.dim_size(2),
                       filter_shape.dim_size(0), filter_shape.dim_size(1)}),
          &pre_transformed_filter_backprop));

  Tensor transformed_out_backprop;
  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the `out_backprop` tensor from NHWC to NCHW.";
    TensorShape compute_shape = ShapeFromFormat(
        compute_data_format, dims.batch_size, dims.spatial_dims[0].output_size,
        dims.spatial_dims[1].output_size, dims.out_depth);
    if (dims.out_depth > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, compute_shape,
                                        &transformed_out_backprop));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
          transformed_out_backprop.tensor<T, 4>());
    } else {
      // If depth <= 1, just reshape.
      CHECK(transformed_out_backprop.CopyFrom(out_backprop, compute_shape));
    }
  } else {
    transformed_out_backprop = out_backprop;
  }

  Tensor transformed_input;
  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the `input` tensor from NHWC to NCHW.";
    TensorShape compute_shape = ShapeFromFormat(
        compute_data_format, GetTensorDim(compatible_input, data_format, 'N'),
        GetTensorDim(compatible_input, data_format, 'H'),
        GetTensorDim(compatible_input, data_format, 'W'),
        GetTensorDim(compatible_input, data_format, 'C'));
    if (compute_shape.dim_size(1) > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, compute_shape,
                                        &transformed_input));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
          transformed_input.tensor<T, 4>());
    } else {
      // If depth <= 1, just reshape.
      CHECK(transformed_input.CopyFrom(compatible_input, compute_shape));
    }
  } else {
    transformed_input = compatible_input;
  }

  se::DeviceMemory<T> out_backprop_ptr =
      AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                     transformed_out_backprop.template flat<T>().size());
  se::DeviceMemory<T> filter_backprop_ptr =
      AsDeviceMemory(pre_transformed_filter_backprop.template flat<T>().data(),
                     pre_transformed_filter_backprop.template flat<T>().size());
  auto input_ptr = AsDeviceMemory(transformed_input.template flat<T>().data(),
                                  transformed_input.template flat<T>().size());

  static int64_t ConvolveBackwardFilterScratchSize =
      GetDnnWorkspaceLimitOrDefault();
  int device_id = stream->parent()->device_ordinal();
  DataType dtype = input.dtype();
  ConvParameters conv_parameters = {
      dims.batch_size,                     // batch
      dims.in_depth,                       // in_depths
      {{input_desc.height(),               // in_rows
        input_desc.width()}},              // in_cols
      compute_data_format,                 // compute_data_format
      dims.out_depth,                      // out_depths
      {{dims.spatial_dims[0].filter_size,  // filter_rows
        dims.spatial_dims[1].filter_size,  // filter_cols
        filter_shape.dim_size(2)}},        // filter_depth
      {{dims.spatial_dims[0].dilation,     // dilation_rows
        dims.spatial_dims[1].dilation}},   // dilation_cols
      {{dims.spatial_dims[0].stride,       // stride_rows
        dims.spatial_dims[1].stride}},     // stride_cols
      {{common_padding_rows,               // padding_rows
        common_padding_cols}},             // padding_cols
      dtype,                               // tensor datatype
      device_id,                           // device_id
      conv_desc.group_count()              // group_count
  };

  auto entry_or = AutotuneUnfusedConv(
      cudnn_use_autotune, AutotuneConvBwdFilter::GetInstance(), conv_parameters,
      ctx, se::dnn::ConvolutionKind::BACKWARD_FILTER, input_desc, input_ptr,
      filter_desc, filter_backprop_ptr, conv_desc, output_desc,
      out_backprop_ptr, ConvolveBackwardFilterScratchSize);
  OP_REQUIRES_OK(ctx, entry_or.status());
  auto autotune_entry = std::move(entry_or).value();

  DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, ctx);
  Status cudnn_launch_status = LaunchAutotunedConv(
      autotune_entry, &scratch_allocator,
      se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc, input_ptr,
      filter_desc, filter_backprop_ptr, conv_desc, output_desc,
      out_backprop_ptr);
  if (!cudnn_launch_status.ok()) {
    ctx->SetStatus(cudnn_launch_status);
    return;
  }

  FilterTensorFormat src_filter_format =
      compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI;

  auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
  functor::ReverseTransformFilter<GPUDevice, T, 4>()(
      ctx->eigen_device<GPUDevice>(), src_filter_format,
      toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
      filter_backprop->tensor<T, 4>());
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                             \
  template <>                                                           \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(               \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,         \
      typename TTypes<T, 4, int>::ConstTensor in,                       \
      typename TTypes<T, 4, int>::Tensor out);                          \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;         \
  template <>                                                           \
  void PadInput<GPUDevice, T, int, 4>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,   \
      const std::array<int, 2>& padding_left,                           \
      const std::array<int, 2>& padding_right,                          \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
      const T& padding_value);                                          \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T")
                            .HostMemory("filter_sizes"),
                        Conv2DBackpropFilterOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("filter_sizes"),
                        Conv2DBackpropFilterOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("filter_sizes"),
                        Conv2DBackpropFilterOp<GPUDevice, Eigen::half>);

// To be used inside depthwise_conv_grad_op.cc.
// TODO(reedwm): Move this and the definition to depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow