1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/nn_ops.cc. |
17 | |
18 | #ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_ |
19 | #define TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_ |
20 | |
21 | #define USE_EIGEN_TENSOR |
22 | #define EIGEN_USE_THREADS |
23 | |
24 | #include <algorithm> |
25 | #include <limits> |
26 | #include <vector> |
27 | |
28 | #include "absl/base/dynamic_annotations.h" |
29 | #include "tensorflow/core/framework/bounds_check.h" |
30 | #include "tensorflow/core/framework/kernel_shape_util.h" |
31 | #include "tensorflow/core/framework/numeric_op.h" |
32 | #include "tensorflow/core/framework/op_kernel.h" |
33 | #include "tensorflow/core/framework/register_types.h" |
34 | #include "tensorflow/core/framework/tensor.h" |
35 | #include "tensorflow/core/framework/tensor_shape.h" |
36 | #include "tensorflow/core/framework/tensor_slice.h" |
37 | #include "tensorflow/core/kernels/conv_2d.h" |
38 | #include "tensorflow/core/kernels/conv_grad_ops.h" |
39 | #include "tensorflow/core/kernels/conv_grad_shape_utils.h" |
40 | #include "tensorflow/core/kernels/fill_functor.h" |
41 | #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS |
42 | #include "tensorflow/core/kernels/xsmm_conv2d.h" |
43 | #endif |
44 | #include "tensorflow/core/lib/core/errors.h" |
45 | #include "tensorflow/core/lib/gtl/array_slice.h" |
46 | #include "tensorflow/core/platform/logging.h" |
47 | #include "tensorflow/core/platform/macros.h" |
48 | #include "tensorflow/core/util/padding.h" |
49 | #include "tensorflow/core/util/tensor_format.h" |
50 | #include "tensorflow/core/util/use_cudnn.h" |
51 | #include "tensorflow/core/util/work_sharder.h" |
52 | |
53 | #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) |
54 | #include "tensorflow/core/kernels/eigen_contraction_kernel.h" |
55 | #endif |
56 | |
57 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
58 | #include "tensorflow/core/kernels/conv_ops_gpu.h" |
59 | #include "tensorflow/core/platform/stream_executor.h" |
60 | #include "tensorflow/core/util/proto/proto_utils.h" |
61 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
62 | #if GOOGLE_CUDA |
63 | #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h" |
64 | #include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" |
65 | #include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h" |
66 | #endif // GOOGLE_CUDA |
67 | |
68 | namespace tensorflow { |
69 | |
70 | typedef Eigen::ThreadPoolDevice CPUDevice; |
71 | typedef Eigen::GpuDevice GPUDevice; |
72 | |
// Accumulates into 'im_data' (assumed to be zero-initialized) the image in
// storage order (height, width, depth), reconstructed from the patches in
// 'col_data', which is required to be in storage order
// (out_height * out_width, filter_height, filter_width, in_depth).
// Implementation by Yangqing Jia (jiayq).
77 | template <typename T> |
78 | void Col2im(const T* col_data, const int depth, const int height, |
79 | const int width, const int filter_h, const int filter_w, |
80 | const int pad_t, const int pad_l, const int pad_b, const int pad_r, |
81 | const int stride_h, const int stride_w, T* __restrict im_data) { |
82 | int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; |
83 | int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; |
84 | int h_pad = -pad_t; |
85 | for (int h = 0; h < height_col; ++h) { |
86 | int w_pad = -pad_l; |
87 | for (int w = 0; w < width_col; ++w) { |
88 | T* im_patch_data = im_data + (h_pad * width + w_pad) * depth; |
89 | for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { |
90 | for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { |
91 | if (ih >= 0 && ih < height && iw >= 0 && iw < width) { |
92 | for (int i = 0; i < depth; ++i) { |
93 | im_patch_data[i] += col_data[i]; |
94 | } |
95 | } |
96 | im_patch_data += depth; |
97 | col_data += depth; |
98 | } |
        // Advance to the start of the next image row within this patch,
        // skipping the (width - filter_w) columns outside the filter window.
100 | im_patch_data += depth * (width - filter_w); |
101 | } |
102 | w_pad += stride_w; |
103 | } |
104 | h_pad += stride_h; |
105 | } |
106 | } |
107 | |
108 | // Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU |
109 | // and GPU (for int32 only). |
110 | template <typename Device, typename T> |
111 | struct LaunchConv2DBackpropInputOpImpl { |
112 | void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, |
113 | const Tensor& out_backprop, const Tensor& filter, |
114 | int row_dilation, int col_dilation, int row_stride, |
115 | int col_stride, const Padding& padding, |
116 | const std::vector<int64_t>& explicit_paddings, |
117 | Tensor* in_backprop, TensorFormat data_format) { |
118 | std::vector<int32> strides(4, 1); |
119 | std::vector<int32> dilations(4, 1); |
120 | |
121 | auto input_h = GetTensorDimIndex(data_format, 'H'); |
122 | auto input_w = GetTensorDimIndex(data_format, 'W'); |
123 | strides[input_h] = row_stride; |
124 | strides[input_w] = col_stride; |
125 | dilations[input_h] = row_dilation; |
126 | dilations[input_w] = col_dilation; |
127 | |
128 | const TensorShape& input_shape = in_backprop->shape(); |
129 | const TensorShape& filter_shape = filter.shape(); |
130 | |
131 | ConvBackpropDimensions dims; |
132 | OP_REQUIRES_OK( |
        ctx, ConvBackpropComputeDimensionsV2(
                 "Conv2DBackpropInput", /*num_spatial_dims=*/2, input_shape,
135 | filter_shape, out_backprop.shape(), dilations, strides, |
136 | padding, explicit_paddings, data_format, &dims)); |
137 | |
138 | int64_t padding_top = -1, padding_bottom = -1; |
139 | int64_t padding_left = -1, padding_right = -1; |
140 | if (padding == EXPLICIT) { |
141 | GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', |
142 | &padding_top, &padding_bottom); |
143 | GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', |
144 | &padding_left, &padding_right); |
145 | } |
146 | |
147 | int64_t expected_out_rows, expected_out_cols; |
    // These calls are guaranteed to succeed because the output shape and
    // padding were already validated above.
150 | TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( |
151 | dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, |
152 | row_dilation, row_stride, padding, &expected_out_rows, &padding_top, |
153 | &padding_bottom)); |
154 | DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); |
155 | |
156 | TF_CHECK_OK(GetWindowedOutputSizeVerboseV2( |
157 | dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, |
158 | col_dilation, col_stride, padding, &expected_out_cols, &padding_left, |
159 | &padding_right)); |
160 | DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols); |
161 | |
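    // For the GPU (int32) path below, reject tensors whose padded extents (or
    // the running product of extents) do not fit in int32, so that downstream
    // index computations cannot overflow 32-bit integers.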
162 | if (std::is_same<Device, GPUDevice>::value) { |
163 | int64_t size = 1; |
164 | #define REQUIRES_32BIT(x) \ |
165 | size *= x; \ |
166 | OP_REQUIRES(ctx, \ |
167 | FastBoundsCheck(x, std::numeric_limits<int32>::max()) && \ |
168 | FastBoundsCheck(size, std::numeric_limits<int32>::max()), \ |
169 | errors::InvalidArgument("Tensor too large")) |
170 | |
171 | REQUIRES_32BIT(in_backprop->dim_size(0)); |
172 | REQUIRES_32BIT(in_backprop->dim_size(1) + padding_top + padding_bottom); |
173 | REQUIRES_32BIT(in_backprop->dim_size(2) + padding_left + padding_right); |
174 | REQUIRES_32BIT(in_backprop->dim_size(3)); |
175 | #undef REQUIRES_32BIT |
176 | } |
177 | |
178 | auto in_backprop_t = in_backprop->tensor<T, 4>(); |
179 | auto out_backprop_t = out_backprop.tensor<T, 4>(); |
180 | auto filter_t = filter.tensor<T, 4>(); |
181 | |
182 | // WARNING: Need to swap row/col, padding_top/padding_left, and |
183 | // padding_bottom/padding_right when calling Eigen. Eigen expects tensors |
    // in NWHC format, but TensorFlow uses NHWC.
185 | |
186 | if (padding != EXPLICIT) { |
187 | // If padding was not explicitly defined, Eigen spatial convolution |
188 | // backward input will infer correct forward paddings from input tensors. |
189 | functor::SpatialConvolutionBackwardInputFunc<Device, T>()( |
190 | ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t, |
191 | col_stride, row_stride, col_dilation, row_dilation); |
192 | } else { |
193 | functor::SpatialConvolutionBackwardInputWithExplicitPaddingFunc<Device, |
194 | T>()( |
195 | ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t, |
196 | in_backprop_t.dimension(2) + (padding_left + padding_right), |
197 | in_backprop_t.dimension(1) + (padding_top + padding_bottom), |
198 | col_stride, row_stride, col_dilation, row_dilation, padding_top, |
199 | padding_left); |
200 | } |
201 | } |
202 | }; |
203 | |
204 | // Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU. |
205 | template <typename T> |
206 | struct LaunchConv2DBackpropInputOp<CPUDevice, T> { |
207 | void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, |
208 | const Tensor& out_backprop, const Tensor& filter, |
209 | int row_dilation, int col_dilation, int row_stride, |
210 | int col_stride, const Padding& padding, |
211 | const std::vector<int64_t>& explicit_paddings, |
212 | Tensor* in_backprop, TensorFormat data_format) { |
213 | LaunchConv2DBackpropInputOpImpl<CPUDevice, T> launcher; |
214 | launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter, |
215 | row_dilation, col_dilation, row_stride, col_stride, padding, |
216 | explicit_paddings, in_backprop, data_format); |
217 | } |
218 | }; |
219 | |
220 | #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS |
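// Generic fallback: reports failure so callers fall back to the default
// (non-libxsmm) path. Only the <CPUDevice, float> specialization below
// dispatches to libxsmm.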
221 | template <typename Device, class T> |
222 | struct LaunchXsmmBackwardInputConvolution { |
223 | bool operator()(OpKernelContext* context, const Device& d, |
224 | typename TTypes<T, 4>::Tensor input_backward, |
225 | typename TTypes<T, 4>::ConstTensor kernel, |
226 | typename TTypes<T, 4>::ConstTensor output_backward, |
227 | int input_rows, int input_cols, int row_stride, |
228 | int col_stride, int pad_h, int pad_w, |
229 | TensorFormat data_format) const { |
230 | return false; |
231 | } |
232 | }; |
233 | |
234 | template <> |
235 | struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> { |
236 | bool operator()(OpKernelContext* context, const CPUDevice& d, |
237 | typename TTypes<float, 4>::Tensor input_backward, |
238 | typename TTypes<float, 4>::ConstTensor kernel, |
239 | typename TTypes<float, 4>::ConstTensor output_backward, |
240 | int input_rows, int input_cols, int row_stride, |
241 | int col_stride, int pad_h, int pad_w, |
242 | TensorFormat data_format) const { |
243 | auto batch = input_backward.dimension(0); |
244 | auto in_depth = input_backward.dimension(3); |
245 | auto out_depth = output_backward.dimension(3); |
246 | auto filter_rows = kernel.dimension(0); |
247 | auto filter_cols = kernel.dimension(1); |
248 | auto num_threads = |
249 | context->device()->tensorflow_cpu_worker_threads()->num_threads; |
250 | // See libxsmm_dnn.h for this struct definition. |
251 | libxsmm_dnn_conv_desc desc; |
252 | desc.N = batch; |
253 | desc.C = in_depth; |
254 | desc.H = input_rows; |
255 | desc.W = input_cols; |
256 | desc.K = out_depth; |
257 | desc.R = filter_rows; |
258 | desc.S = filter_cols; |
259 | desc.u = row_stride; |
260 | desc.v = col_stride; |
261 | desc.pad_h = pad_h; |
262 | desc.pad_w = pad_w; |
263 | desc.pad_h_in = 0; |
264 | desc.pad_w_in = 0; |
265 | desc.pad_h_out = 0; |
266 | desc.pad_w_out = 0; |
267 | desc.threads = num_threads; |
268 | desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT; |
269 | desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC; |
270 | desc.filter_format = |
271 | LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM; // LIBXSMM_DNN_TENSOR_FORMAT_RSCK; |
272 | desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE; |
273 | desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE; |
274 | desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32; |
275 | desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32; |
276 | auto input_ptr = input_backward.data(); |
277 | auto filter_ptr = kernel.data(); |
278 | auto output_ptr = output_backward.data(); |
279 | |
280 | bool success = functor::XsmmBkwInputConv2D<CPUDevice, float>()( |
281 | context, desc, input_ptr, filter_ptr, output_ptr); |
282 | return success; |
283 | } |
284 | }; |
285 | #endif |
286 | |
287 | template <typename T> |
288 | struct Conv2DCustomBackpropInputMatMulFunctor { |
289 | using MatrixMap = Eigen::Map< |
290 | Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>; |
291 | using ConstMatrixMap = Eigen::Map< |
292 | const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>; |
293 | |
294 | void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data, |
295 | const int filter_total_size, const int output_image_size, |
296 | const int dims_out_depth, T* im2col_buf) { |
297 | // Compute gradient into 'im2col_buf'. |
298 | MatrixMap C(im2col_buf, output_image_size, filter_total_size); |
299 | |
300 | ConstMatrixMap A(out_data, output_image_size, dims_out_depth); |
301 | ConstMatrixMap B(filter_data, filter_total_size, dims_out_depth); |
302 | |
303 | C.noalias() = A * B.transpose(); |
304 | } |
305 | }; |
306 | |
307 | #if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL) |
308 | template <> |
309 | struct Conv2DCustomBackpropInputMatMulFunctor<float> { |
310 | using T = float; |
311 | |
312 | void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data, |
313 | const int filter_total_size, const int output_image_size, |
314 | const int dims_out_depth, T* im2col_buf) { |
315 | // Inputs are in RowMajor order. |
316 | // im2col = out_data * filter_data^T |
317 | // [ois x fts] = [ois x dod] * [fts x dod]^T |
318 | // |
319 | // Dimension names: |
320 | // out_image_size -> ois |
321 | // filter_total_size -> fts |
322 | // dims_out_depth -> dod |
323 | |
324 | const int m = output_image_size; |
325 | const int n = filter_total_size; |
326 | const int k = dims_out_depth; // contraction dim |
327 | |
    const char transposeA = 'N';  // sgemm(A) == out_data
    const char transposeB = 'T';  // sgemm(B) == filter_data
330 | |
331 | const int ldA = dims_out_depth; |
332 | const int ldB = dims_out_depth; |
333 | const int ldC = filter_total_size; |
334 | |
335 | const float alpha = 1.0; |
336 | const float beta = 0.0; |
337 | |
338 | // dnnl_sgemm code can't be instrumented with msan. |
339 | ANNOTATE_MEMORY_IS_INITIALIZED( |
340 | im2col_buf, filter_total_size * output_image_size * sizeof(T)); |
341 | |
342 | dnnl_status_t st = |
343 | dnnl_sgemm(transposeA, transposeB, m, n, k, alpha, out_data, ldA, |
344 | filter_data, ldB, beta, im2col_buf, ldC); |
345 | |
346 | OP_REQUIRES( |
347 | ctx, st == 0, |
        errors::Internal("Failed to call dnnl_sgemm. Error code: ", st));
349 | } |
350 | }; |
351 | #endif |
352 | |
353 | template <typename Device, class T> |
354 | class Conv2DBackpropInputOp : public OpKernel { |
355 | public: |
356 | explicit Conv2DBackpropInputOp(OpKernelConstruction* context) |
357 | : OpKernel(context) { |
358 | string data_format; |
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
367 | int stride_n = GetTensorDim(strides_, data_format_, 'N'); |
368 | int stride_c = GetTensorDim(strides_, data_format_, 'C'); |
369 | int stride_h = GetTensorDim(strides_, data_format_, 'H'); |
370 | int stride_w = GetTensorDim(strides_, data_format_, 'W'); |
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
383 | int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); |
384 | int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); |
385 | int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); |
386 | int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); |
    OP_REQUIRES(
        context, (dilation_n == 1 && dilation_c == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilated rates should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
398 | OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_, |
399 | /*num_dims=*/4, data_format_)); |
400 | |
    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
402 | cudnn_use_autotune_ = CudnnUseAutotune(); |
403 | |
404 | if (std::is_same<Device, CPUDevice>::value || |
405 | std::is_same<T, int32>::value) { |
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument("Conv2DBackpropInputOp [CPU or GPU(int32)] "
                                  "only supports NHWC data format."));

      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
      OP_REQUIRES(
          context, (dilation_h == 1 && dilation_w == 1),
          errors::InvalidArgument(
              "Conv2DBackpropInputOp [CPU or GPU(int32)] does not yet "
              "support dilation rates larger than 1."));
417 | } |
418 | } |
419 | |
420 | void Compute(OpKernelContext* context) override { |
421 | const Tensor& input_sizes = context->input(0); |
422 | const Tensor& filter = context->input(1); |
423 | const Tensor& out_backprop = context->input(2); |
424 | |
    OP_REQUIRES(
        context, out_backprop.dims() == 4,
        errors::InvalidArgument("out_backprop must be 4-dimensional, got: ",
                                out_backprop.dims()));
429 | |
430 | TensorShape input_shape; |
431 | OP_REQUIRES_OK(context, |
432 | Conv2DBackpropComputeInputShape(input_sizes, filter.shape(), |
433 | out_backprop.shape(), |
434 | data_format_, &input_shape)); |
435 | |
436 | Tensor* in_backprop = nullptr; |
437 | OP_REQUIRES_OK(context, |
438 | context->allocate_output(0, input_shape, &in_backprop)); |
439 | |
440 | // If there is nothing to compute, return. |
441 | if (input_shape.num_elements() == 0) { |
442 | return; |
443 | } |
444 | |
    // If the shapes are valid but `out_backprop` is empty, set `in_backprop`
    // to all zeros and return early; cuDNN/oneDNN would otherwise fail on the
    // empty input.
447 | if (out_backprop.NumElements() == 0) { |
448 | functor::SetZeroFunctor<Device, T> set_zero; |
449 | set_zero(context->eigen_device<Device>(), |
450 | in_backprop->template flat<T>()); |
451 | return; |
452 | } |
453 | |
454 | // For now we take the stride from the second and third dimensions only (we |
455 | // do not support striding on the batch or depth dimension). |
456 | const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); |
457 | const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); |
458 | const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H'); |
459 | const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W'); |
460 | |
461 | VLOG(2) << "Conv2DBackpropInput:" |
462 | << " input: " << input_shape.DebugString() |
463 | << " filter:" << filter.shape().DebugString() |
464 | << " out_backprop: " << out_backprop.shape().DebugString() |
465 | << " strides: [" << stride_rows << ", " << stride_cols << "]" |
466 | << " dilations: [" << dilation_rows << ", " << dilation_cols << "]" ; |
467 | |
468 | LaunchConv2DBackpropInputOp<Device, T> launch; |
469 | launch(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter, |
470 | dilation_rows, dilation_cols, stride_rows, stride_cols, padding_, |
471 | explicit_paddings_, in_backprop, data_format_); |
472 | } |
473 | |
474 | private: |
475 | std::vector<int32> dilations_; |
476 | std::vector<int32> strides_; |
477 | TensorFormat data_format_; |
478 | Padding padding_; |
479 | std::vector<int64_t> explicit_paddings_; |
480 | |
481 | bool use_cudnn_ = false; |
482 | bool cudnn_use_autotune_ = false; |
483 | |
484 | TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropInputOp); |
485 | }; |
486 | |
487 | // Based on implementation written by Yangqing Jia (jiayq). |
488 | template <typename Device, class T> |
489 | class Conv2DCustomBackpropInputOp : public OpKernel { |
490 | public: |
491 | explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context) |
492 | : OpKernel(context) { |
493 | string data_format; |
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DCustomBackpropInputOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (dilations_[0] == 1 && dilations_[3] == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "dilations in the batch and depth dimensions."));
    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                errors::InvalidArgument(
                    "Current libxsmm and customized CPU implementations do "
                    "not yet support dilation rates larger than 1."));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
527 | OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_, |
528 | /*num_dims=*/4, data_format_)); |
529 | } |
530 | |
531 | void Compute(OpKernelContext* context) override { |
532 | const Tensor& input_sizes = context->input(0); |
533 | const Tensor& filter = context->input(1); |
534 | const Tensor& out_backprop = context->input(2); |
    OP_REQUIRES(
        context, out_backprop.dims() == 4,
        errors::InvalidArgument("out_backprop must be 4-dimensional, got: ",
                                out_backprop.dims()));
539 | |
540 | TensorShape input_shape; |
541 | OP_REQUIRES_OK(context, |
542 | Conv2DBackpropComputeInputShape(input_sizes, filter.shape(), |
543 | out_backprop.shape(), |
544 | data_format_, &input_shape)); |
545 | |
546 | ConvBackpropDimensions dims; |
547 | OP_REQUIRES_OK(context, |
                   ConvBackpropComputeDimensionsV2(
                       "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
550 | input_shape, filter.shape(), out_backprop.shape(), |
551 | /*dilations=*/{1, 1, 1, 1}, strides_, padding_, |
552 | explicit_paddings_, data_format_, &dims)); |
553 | |
    OP_REQUIRES(context, dims.in_depth == filter.shape().dim_size(2),
                errors::InvalidArgument(
                    "Gradients for grouped convolutions are not "
                    "supported on CPU. Please file a feature request if you "
                    "run into this issue. Computed input depth ",
                    dims.in_depth, " doesn't match filter input depth ",
                    filter.shape().dim_size(2)));
    OP_REQUIRES(
        context, dims.out_depth == filter.shape().dim_size(3),
        errors::InvalidArgument("Computed output depth ", dims.out_depth,
                                " doesn't match filter output depth ",
                                filter.shape().dim_size(3)));
566 | |
567 | Tensor* in_backprop = nullptr; |
568 | OP_REQUIRES_OK(context, |
569 | context->allocate_output(0, input_shape, &in_backprop)); |
570 | |
571 | // If there is nothing to compute, return. |
572 | if (input_shape.num_elements() == 0) { |
573 | return; |
574 | } |
575 | |
    // If the shapes are valid but `out_backprop` is empty, set `in_backprop`
    // to all zeros and return early; cuDNN/oneDNN would otherwise fail on the
    // empty input.
578 | if (out_backprop.NumElements() == 0) { |
579 | functor::SetZeroFunctor<Device, T> set_zero; |
580 | set_zero(context->eigen_device<Device>(), |
581 | in_backprop->template flat<T>()); |
582 | return; |
583 | } |
584 | |
585 | // TODO(ezhulenev): Remove custom kernel and move XSMM support to |
586 | // LaunchConv2DBackpropInputOp functor. |
587 | #if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \ |
588 | defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS |
    int64_t pad_top, pad_bottom;
    int64_t pad_left, pad_right;
591 | OP_REQUIRES_OK( |
592 | context, |
593 | GetWindowedOutputSizeVerbose( |
594 | dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, |
595 | dims.spatial_dims[0].stride, padding_, |
596 | &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); |
597 | OP_REQUIRES_OK( |
598 | context, |
599 | GetWindowedOutputSizeVerbose( |
600 | dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, |
601 | dims.spatial_dims[1].stride, padding_, |
602 | &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); |
603 | |
604 | if (pad_left == pad_right && pad_top == pad_bottom) { |
605 | if (LaunchXsmmBackwardInputConvolution<Device, T>()( |
606 | context, context->eigen_device<Device>(), |
607 | in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(), |
608 | out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size, |
609 | dims.spatial_dims[1].input_size, |
610 | static_cast<int>(dims.spatial_dims[0].stride), |
611 | static_cast<int>(dims.spatial_dims[1].stride), |
612 | static_cast<int>(pad_top), static_cast<int>(pad_left), |
613 | data_format_)) { |
614 | return; |
615 | } |
616 | } |
617 | #else |
618 | int64_t pad_top, pad_bottom; |
619 | int64_t pad_left, pad_right; |
620 | #endif |
621 | if (padding_ == Padding::EXPLICIT) { |
622 | pad_top = explicit_paddings_[2]; |
623 | pad_bottom = explicit_paddings_[3]; |
624 | pad_left = explicit_paddings_[4]; |
625 | pad_right = explicit_paddings_[5]; |
626 | } |
627 | OP_REQUIRES_OK( |
628 | context, |
629 | GetWindowedOutputSizeVerbose( |
630 | dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, |
631 | dims.spatial_dims[0].stride, padding_, |
632 | &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); |
633 | OP_REQUIRES_OK( |
634 | context, |
635 | GetWindowedOutputSizeVerbose( |
636 | dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, |
637 | dims.spatial_dims[1].stride, padding_, |
638 | &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); |
639 | |
    // The total number of elements in each filter patch
    // (filter_rows * filter_cols * in_depth).
641 | const int filter_total_size = dims.spatial_dims[0].filter_size * |
642 | dims.spatial_dims[1].filter_size * |
643 | dims.in_depth; |
644 | // The output image size is the spatial size of the output. |
645 | const int output_image_size = |
646 | dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size; |
647 | |
648 | // TODO(andydavis) Get L2/L3 cache sizes from device. |
649 | const size_t l2_cache_size = 256LL << 10; |
650 | const size_t l3_cache_size = 30LL << 20; |
651 | |
652 | // Use L3 cache size as target working set size. |
653 | const size_t target_working_set_size = l3_cache_size / sizeof(T); |
654 | |
655 | // Calculate size of matrices involved in MatMul: C = A x B. |
656 | const size_t size_A = output_image_size * dims.out_depth; |
657 | |
658 | const size_t size_B = filter_total_size * dims.out_depth; |
659 | |
660 | const size_t size_C = output_image_size * filter_total_size; |
661 | |
662 | const size_t work_unit_size = size_A + size_B + size_C; |
663 | |
664 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
665 | |
666 | // Calculate per-thread work unit size. |
667 | const size_t thread_work_unit_size = |
668 | work_unit_size / worker_threads.num_threads; |
669 | |
670 | // Set minimum per-thread work unit size to size of L2 cache. |
671 | const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T); |
672 | |
673 | // Use parallel tensor contractions if there is no batching, or if the |
674 | // minimum per-thread work unit size threshold has been exceeded. |
675 | // Otherwise, revert to multiple single-threaded matmul ops running in |
676 | // parallel to keep all threads busy. |
677 | // TODO(andydavis) Explore alternatives to branching the code in this way |
678 | // (i.e. run multiple, parallel tensor contractions in another thread pool). |
679 | const bool use_parallel_contraction = |
680 | dims.batch_size == 1 || |
681 | thread_work_unit_size >= min_thread_work_unit_size; |
682 | |
683 | OP_REQUIRES( |
684 | context, work_unit_size > 0, |
        errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
                                "must all have at least 1 element"));
687 | |
688 | const size_t shard_size = |
689 | use_parallel_contraction |
690 | ? 1 |
691 | : (target_working_set_size + work_unit_size - 1) / work_unit_size; |
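    // Illustrative numbers for the shard-size heuristic (not taken from any
    // particular model): with float data the target working set is
    // (30 << 20) / 4 ~= 7.8M elements, so a work_unit_size of ~1M elements
    // yields shard_size = 8, i.e. each shard covers up to 8 images and the
    // work sharder runs one single-threaded matmul + Col2im per image.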
692 | |
693 | Tensor col_buffer; |
694 | OP_REQUIRES_OK(context, |
695 | context->allocate_temp( |
696 | DataTypeToEnum<T>::value, |
697 | TensorShape({static_cast<int64_t>(shard_size), |
698 | static_cast<int64_t>(output_image_size), |
699 | static_cast<int64_t>(filter_total_size)}), |
700 | &col_buffer)); |
701 | |
702 | // The input offset corresponding to a single input image. |
703 | const int input_offset = dims.spatial_dims[0].input_size * |
704 | dims.spatial_dims[1].input_size * dims.in_depth; |
705 | // The output offset corresponding to a single output image. |
706 | const int output_offset = dims.spatial_dims[0].output_size * |
707 | dims.spatial_dims[1].output_size * dims.out_depth; |
708 | |
709 | const T* filter_data = filter.template flat<T>().data(); |
710 | T* col_buffer_data = col_buffer.template flat<T>().data(); |
711 | const T* out_backprop_data = out_backprop.template flat<T>().data(); |
712 | |
713 | auto in_backprop_flat = in_backprop->template flat<T>(); |
714 | T* input_backprop_data = in_backprop_flat.data(); |
715 | in_backprop_flat.device(context->eigen_device<Device>()) = |
716 | in_backprop_flat.constant(T(0)); |
717 | |
718 | if (use_parallel_contraction) { |
719 | typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, |
720 | Eigen::Unaligned> |
721 | TensorMap; |
722 | typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, |
723 | Eigen::Unaligned> |
724 | ConstTensorMap; |
725 | |
726 | // Initialize contraction dims (we need to transpose 'B' below). |
727 | Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; |
728 | contract_dims[0].first = 1; |
729 | contract_dims[0].second = 1; |
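      // Contracting dimension 1 of A (out_depth) with dimension 1 of B
      // (out_depth) computes A * B^T, matching the matmul performed by
      // Conv2DCustomBackpropInputMatMulFunctor in the sharded path below.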
730 | |
731 | for (int image_id = 0; image_id < dims.batch_size; ++image_id) { |
732 | // Compute gradient into col_buffer. |
733 | TensorMap C(col_buffer_data, output_image_size, filter_total_size); |
734 | |
735 | ConstTensorMap A(out_backprop_data + output_offset * image_id, |
736 | output_image_size, dims.out_depth); |
737 | ConstTensorMap B(filter_data, filter_total_size, dims.out_depth); |
738 | |
739 | C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims); |
740 | |
741 | Col2im<T>( |
742 | col_buffer_data, dims.in_depth, dims.spatial_dims[0].input_size, |
743 | dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size, |
744 | dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom, |
745 | pad_right, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, |
746 | input_backprop_data); |
747 | |
748 | input_backprop_data += input_offset; |
749 | } |
750 | } else { |
751 | for (int image_id = 0; image_id < dims.batch_size; |
752 | image_id += shard_size) { |
753 | const int shard_limit = |
754 | std::min(static_cast<int>(shard_size), |
755 | static_cast<int>(dims.batch_size) - image_id); |
756 | |
757 | auto shard = [&context, &dims, &pad_top, &pad_left, &pad_bottom, |
758 | &pad_right, &output_image_size, &filter_total_size, |
759 | &input_backprop_data, &col_buffer_data, |
760 | &out_backprop_data, &filter_data, &input_offset, |
761 | &output_offset, &size_C](int64_t start, int64_t limit) { |
762 | for (int shard_id = start; shard_id < limit; ++shard_id) { |
763 | T* im2col_buf = col_buffer_data + shard_id * size_C; |
764 | T* input_data = input_backprop_data + shard_id * input_offset; |
765 | const T* out_data = out_backprop_data + shard_id * output_offset; |
766 | |
767 | Conv2DCustomBackpropInputMatMulFunctor<T>()( |
768 | context, out_data, filter_data, filter_total_size, |
769 | output_image_size, dims.out_depth, im2col_buf); |
770 | |
771 | Col2im<T>(im2col_buf, dims.in_depth, |
772 | dims.spatial_dims[0].input_size, |
773 | dims.spatial_dims[1].input_size, |
774 | dims.spatial_dims[0].filter_size, |
775 | dims.spatial_dims[1].filter_size, pad_top, pad_left, |
776 | pad_bottom, pad_right, dims.spatial_dims[0].stride, |
777 | dims.spatial_dims[1].stride, input_data); |
778 | } |
779 | }; |
780 | Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, |
781 | work_unit_size, shard); |
782 | |
783 | input_backprop_data += input_offset * shard_limit; |
784 | out_backprop_data += output_offset * shard_limit; |
785 | } |
786 | } |
787 | } |
788 | |
789 | private: |
790 | std::vector<int32> dilations_; |
791 | std::vector<int32> strides_; |
792 | Padding padding_; |
793 | std::vector<int64_t> explicit_paddings_; |
794 | TensorFormat data_format_; |
795 | |
796 | TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp); |
797 | }; |
798 | |
799 | // TODO(ezhulenev): Add a cost model to switch between custom/Eigen ops. |
800 | #define DEFAULT_CONV_2D_BACKPROP_CPU_OP Conv2DCustomBackpropInputOp |
801 | |
802 | #define REGISTER_CONV_2D_BACKPROP_CPU_KERNELS(T) \ |
803 | REGISTER_KERNEL_BUILDER( \ |
804 | Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
805 | DEFAULT_CONV_2D_BACKPROP_CPU_OP<CPUDevice, T>); \ |
806 | REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") \ |
807 | .Device(DEVICE_CPU) \ |
808 | .Label("custom") \ |
809 | .TypeConstraint<T>("T"), \ |
810 | Conv2DCustomBackpropInputOp<CPUDevice, T>); \ |
811 | REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") \ |
812 | .Device(DEVICE_CPU) \ |
813 | .Label("eigen_tensor") \ |
814 | .TypeConstraint<T>("T"), \ |
815 | Conv2DBackpropInputOp<CPUDevice, T>); |
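
// A kernel registration .cc file would typically instantiate the macro once
// per supported dtype, e.g. (illustrative):
//
//   REGISTER_CONV_2D_BACKPROP_CPU_KERNELS(float);
//   REGISTER_CONV_2D_BACKPROP_CPU_KERNELS(double);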
816 | |
817 | } // namespace tensorflow |
818 | |
819 | #endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_ |
820 | |