1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include <algorithm>
19#include <cmath>
20
21#include "tensorflow/core/framework/bounds_check.h"
22#include "tensorflow/core/framework/kernel_shape_util.h"
23#include "tensorflow/core/framework/numeric_op.h"
24#include "tensorflow/core/framework/op_kernel.h"
25#include "tensorflow/core/framework/register_types.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/framework/tensor_shape.h"
28#include "tensorflow/core/framework/tensor_types.h"
29#include "tensorflow/core/framework/types.h"
30#include "tensorflow/core/kernels/cast_op.h"
31#include "tensorflow/core/kernels/conv_grad_ops.h"
32#include "tensorflow/core/kernels/depthwise_conv_op.h"
33#include "tensorflow/core/lib/core/status.h"
34#include "tensorflow/core/platform/logging.h"
35#include "tensorflow/core/platform/types.h"
36#include "tensorflow/core/util/determinism.h"
37#include "tensorflow/core/util/padding.h"
38#include "tensorflow/core/util/tensor_format.h"
39#include "tensorflow/core/util/use_cudnn.h"
40#include "tensorflow/core/util/work_sharder.h"
41
42#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
43
44#if GOOGLE_CUDA
45#include "third_party/gpus/cudnn/cudnn.h"
46#endif
47
48#include "tensorflow/core/platform/stream_executor.h"
49#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
50
51namespace tensorflow {
52
53// Gradient operations for depthwise convolution.
54
55typedef Eigen::ThreadPoolDevice CPUDevice;
56typedef Eigen::GpuDevice GPUDevice;
57
// Common code between the two backward pass kernels: verifies that the
// dimensions all match and extracts the padded rows and columns.
//
// The macro expands to a run of statements (not a single expression), so it
// must be used at statement scope. It reads the following names from the
// enclosing scope:
//   context       - handed to OP_REQUIRES / OP_REQUIRES_OK and used to fetch
//                   input 2, which is bound to 'out_backprop'.
//   input_shape   - 4-D shape whose 'H', 'W' and 'C' dimensions (looked up
//                   via data_format_) give input rows, cols and depth;
//                   dimension 0 is taken as the batch size.
//   filter_shape  - 4-D shape read as
//                   [filter_rows, filter_cols, in_depth, depth_multiplier].
//   data_format_, stride_, padding_, explicit_paddings_
//
// It declares (among others) the locals 'out_backprop', 'batch',
// 'input_rows', 'input_cols', 'filter_rows', 'filter_cols', 'output_rows',
// 'output_cols', 'in_depth', 'depth_multiplier', 'out_depth', 'stride',
// 'out_rows', 'out_cols', 'pad_top', 'pad_bottom', 'pad_left', 'pad_right',
// the '*_raw' intermediates, and a filled-in DepthwiseArgs named 'args'.
//
// Row, column and output-depth extents are bounds-checked against the int32
// maximum before being narrowed. 'label' is prepended to most of the
// InvalidArgument messages and streamed into the VLOG(2) summary.
#define EXTRACT_AND_VERIFY_DIMENSIONS(label)                                   \
  const Tensor& out_backprop = context->input(2);                              \
  OP_REQUIRES(                                                                 \
      context, input_shape.dims() == 4,                                        \
      errors::InvalidArgument(label, ": input must be 4-dimensional"));        \
  OP_REQUIRES(                                                                 \
      context, filter_shape.dims() == 4,                                       \
      errors::InvalidArgument(label, ": filter must be 4-dimensional"));       \
  OP_REQUIRES(                                                                 \
      context, out_backprop.dims() == 4,                                       \
      errors::InvalidArgument(label, ": out_backprop must be 4-dimensional")); \
  const int64_t batch = input_shape.dim_size(0);                               \
  OP_REQUIRES(                                                                 \
      context, batch == out_backprop.dim_size(0),                              \
      errors::InvalidArgument(                                                 \
          label, ": input and out_backprop must have the same batch size"));   \
  const int64_t input_rows_raw = GetTensorDim(input_shape, data_format_, 'H'); \
  OP_REQUIRES(                                                                 \
      context,                                                                 \
      FastBoundsCheck(input_rows_raw, std::numeric_limits<int32>::max()),      \
      errors::InvalidArgument("Input rows too large"));                        \
  const int32 input_rows = static_cast<int32>(input_rows_raw);                 \
  const int64_t input_cols_raw = GetTensorDim(input_shape, data_format_, 'W'); \
  OP_REQUIRES(                                                                 \
      context,                                                                 \
      FastBoundsCheck(input_cols_raw, std::numeric_limits<int32>::max()),      \
      errors::InvalidArgument("Input cols too large"));                        \
  const int32 input_cols = static_cast<int32>(input_cols_raw);                 \
  const int64_t filter_rows = filter_shape.dim_size(0);                        \
  const int64_t filter_cols = filter_shape.dim_size(1);                        \
  const int64_t output_rows_raw =                                              \
      GetTensorDim(out_backprop.shape(), data_format_, 'H');                   \
  OP_REQUIRES(                                                                 \
      context,                                                                 \
      FastBoundsCheck(output_rows_raw, std::numeric_limits<int32>::max()),     \
      errors::InvalidArgument("Output rows too large"));                       \
  const int32 output_rows = static_cast<int32>(output_rows_raw);               \
  const int64_t output_cols_raw =                                              \
      GetTensorDim(out_backprop.shape(), data_format_, 'W');                   \
  OP_REQUIRES(                                                                 \
      context,                                                                 \
      FastBoundsCheck(output_cols_raw, std::numeric_limits<int32>::max()),     \
      errors::InvalidArgument("Output cols too large"));                       \
  const int32 output_cols = static_cast<int32>(output_cols_raw);               \
  const int64_t in_depth = GetTensorDim(input_shape, data_format_, 'C');       \
  OP_REQUIRES(context, in_depth == filter_shape.dim_size(2),                   \
              errors::InvalidArgument(                                         \
                  label, ": input and filter must have the same in_depth"));   \
  const int64_t depth_multiplier = filter_shape.dim_size(3);                   \
  const int64_t out_depth_raw =                                                \
      GetTensorDim(out_backprop.shape(), data_format_, 'C');                   \
  OP_REQUIRES(                                                                 \
      context,                                                                 \
      FastBoundsCheck(out_depth_raw, std::numeric_limits<int32>::max()),       \
      errors::InvalidArgument("Output depth too large"));                      \
  const int32 out_depth = static_cast<int32>(out_depth_raw);                   \
  OP_REQUIRES(                                                                 \
      context, (depth_multiplier * in_depth) == out_depth,                     \
      errors::InvalidArgument(                                                 \
          label, ": depth_multiplier * in_depth not equal to out_depth"));     \
  const auto stride = stride_;                                                 \
  int64_t out_rows = 0, out_cols = 0, pad_top = 0, pad_bottom = 0,             \
          pad_left = 0, pad_right = 0;                                         \
  if (padding_ == Padding::EXPLICIT) {                                         \
    GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H', &pad_top,  \
                             &pad_bottom);                                     \
    GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left, \
                             &pad_right);                                      \
  }                                                                            \
  OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(                        \
                              input_rows, filter_rows, stride_, padding_,      \
                              &out_rows, &pad_top, &pad_bottom));              \
  OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(                        \
                              input_cols, filter_cols, stride_, padding_,      \
                              &out_cols, &pad_left, &pad_right));              \
  OP_REQUIRES(                                                                 \
      context, output_rows == out_rows,                                        \
      errors::InvalidArgument(                                                 \
          label, ": Number of rows of out_backprop doesn't match computed: ",  \
          "actual = ", output_rows, ", computed = ", out_rows));               \
  OP_REQUIRES(                                                                 \
      context, output_cols == out_cols,                                        \
      errors::InvalidArgument(                                                 \
          label, ": Number of cols of out_backprop doesn't match computed: ",  \
          "actual = ", output_cols, ", computed = ", out_cols));               \
  DepthwiseArgs args;                                                          \
  args.batch = batch;                                                          \
  args.in_rows = input_rows;                                                   \
  args.in_cols = input_cols;                                                   \
  args.in_depth = in_depth;                                                    \
  args.filter_rows = filter_rows;                                              \
  args.filter_cols = filter_cols;                                              \
  args.depth_multiplier = depth_multiplier;                                    \
  args.stride = stride;                                                        \
  args.pad_rows = pad_top;                                                     \
  args.pad_cols = pad_left;                                                    \
  args.out_rows = out_rows;                                                    \
  args.out_cols = out_cols;                                                    \
  args.out_depth = out_depth;                                                  \
  VLOG(2) << "DepthwiseConv2d: " << label << " Input: [" << batch << ", "      \
          << input_rows << ", " << input_cols << ", " << in_depth              \
          << "]; Filter: [" << filter_rows << ", " << filter_cols << ", "      \
          << in_depth << ", " << depth_multiplier << "]; stride = " << stride  \
          << ", pad_rows = " << pad_top << ", pad_cols = " << pad_left         \
          << ", output: [" << batch << ", " << out_rows << ", " << out_cols    \
          << ", " << out_depth << "]";
166
167// Copies data from local region in 'out_backprop' into 'buffer'.
168// The local region coordinates are calculated as the set of output points which
169// used the input point ('in_r', 'in_'c') as input during the forward pass.
170// Rather than spatially reversing the filter, the input is reversed during
171// the copy. The copied data is padded to vector register-width boundaries so
172// that it is aligned for efficient traversal and vector multiply-add by the
173// depthwise input kernel.
174//
175// EX:
176// in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
177//
178// 'out_backprop': [batch, out_rows, out_cols, out_depth]
179//
180// [a00, a01, a10, a11] [a20, a21, b00, b01]
181// [b10, b11, b20, b21] [...]
182// [e00, e01, e10, e11] [e20, e21, f00, f01]
183// [f10, f11, f20, f21] [...]
184//
185// 'buffer' (register boundaries shown):
186//
187// [f00, f01, f10, f11] [f20, f21, 0, 0] in_row = 0, in_col = 0
188// [e00, e01, e10, e11] [e20, e21, 0, 0] in_row = 0, in_col = 1
189// [b00, b01, b10, b11] [b20, b21, 0, 0] in_row = 1, in_col = 0
190// [a00, a01, a10, a11] [a20, a21, 0, 0] in_row = 1, in_col = 1
191//
192template <typename T>
193static void CopyOutputBackpropRegion(const DepthwiseArgs& args,
194 const int64_t padded_filter_inner_dim_size,
195 const int64_t in_r, const int64_t in_c,
196 const T* out_backprop, T* buffer) {
197 typedef typename Eigen::internal::packet_traits<T>::type Packet;
198 static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
199
200 const int64_t stride = args.stride;
201 const int64_t filter_rows = args.filter_rows;
202 const int64_t filter_cols = args.filter_cols;
203 const int64_t pad_rows = args.pad_rows;
204 const int64_t pad_cols = args.pad_cols;
205 const int64_t out_rows = args.out_rows;
206 const int64_t out_cols = args.out_cols;
207
208 // Calculate the output spatial region which used point (in_r, in_c) as input.
209 const int64_t out_r_start =
210 std::max(static_cast<int64_t>(0),
211 (in_r - filter_rows + pad_rows + stride) / stride);
212 const int64_t out_r_end = std::min(out_rows - 1, (in_r + pad_rows) / stride);
213 const int64_t out_c_start =
214 std::max(static_cast<int64_t>(0),
215 (in_c - filter_cols + pad_cols + stride) / stride);
216 const int64_t out_c_end = std::min(out_cols - 1, (in_c + pad_cols) / stride);
217
218 // Zero-pad 'buffer' if output region is smaller than filter spatial size.
219 const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
220 if ((out_r_end - out_r_start + 1) < args.filter_rows ||
221 (out_c_end - out_c_start + 1) < args.filter_cols) {
222 memset(buffer, 0,
223 filter_spatial_size * padded_filter_inner_dim_size * sizeof(T));
224 }
225
226 // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
227 const int64_t vectorized_size = (args.out_depth / kPacketSize) * kPacketSize;
228 const int64_t scalar_size = args.out_depth % kPacketSize;
229 const int64_t pad_size = scalar_size > 0 ? kPacketSize - scalar_size : 0;
230
231 for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) {
232 const int64_t f_r = in_r + pad_rows - out_r * stride;
233 for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) {
234 const int64_t f_c = in_c + pad_cols - out_c * stride;
235 const int64_t buf_base =
236 (f_r * filter_cols + f_c) * padded_filter_inner_dim_size;
237 // Calculate index into 'out_backprop' for coordinate (out_r, out_c).
238 auto* out_bprop =
239 out_backprop + (out_r * args.out_cols + out_c) * args.out_depth;
240
241 // Copy vectorized portion of inner dimension into 'buffer'.
242 for (int64_t d = 0; d < vectorized_size; d += kPacketSize) {
243 auto v = Eigen::internal::ploadu<Packet>(out_bprop + d);
244 Eigen::internal::pstoreu<T>(buffer + buf_base + d, v);
245 }
246 // Copy scalar portion of out_bprop to 'buffer'
247 for (int64_t d = 0; d < scalar_size; ++d) {
248 buffer[buf_base + vectorized_size + d] = out_bprop[vectorized_size + d];
249 }
250 // Pad to vector-register width (if needed).
251 for (int64_t d = 0; d < pad_size; ++d) {
252 buffer[buf_base + vectorized_size + scalar_size + d] =
253 static_cast<T>(0);
254 }
255 }
256 }
257}
258
259// Computes the vectorized product of 'buffer' and 'filter' and stores
260// result in 'output' at location computed from 'in_r' and 'in_c'.
261// If depth_multiplier is > 1, the intermediate output is reduced along
262// the depth_multiplier dimension.
263//
264// EX:
265// in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
266// Both 'input_buffer' and 'filter' are padded to register-width boundaries.
267//
268// 'buffer' [rows, cols, in_depth, depth_multiplier]
269//
270// [f00, f01, f10, f11] [f20, f21, 0, 0] in_row = 0, in_col = 0
271// [e00, e01, e10, e11] [e20, e21, 0, 0] in_row = 0, in_col = 1
272// [b00, b01, b10, b11] [b20, b21, 0, 0] in_row = 1, in_col = 0
273// [a00, a01, a10, a11] [a20, a21, 0, 0] in_row = 1, in_col = 1
274//
275// filter [rows, cols, in_depth, depth_multiplier]
276// [u0, v0, w0, x0] [y0, z0, 0, 0] [u1, v1, w1, x1] [y1, z1, 0, 0]
277// [u2, v2, w2, x2] [y2, z2, 0, 0] [u3, v3, w3, x3] [y3, z3, 0, 0]
278//
279// First output register [in_depth, depth_multiplier]
280// [q00, q01, q10, q11] = ([f00, f01, f10, f11] x [u0, v0, w0, x0]) +
281// ([e00, e01, e10, e11] x [u1, v1, w1, x1]) +
282// ([b00, b01, b10, b11] x [u2, v2, w2, x2]) +
283// ([a00, a01, a10, a11] x [u3, v3, w3, x3])
284//
285// Reduction step along depth-multiplier dimension:
286//
287// [q00, q01, q10, q11] [q20, q21, 0, 0] -> [r0, r1, r2, 0]
288//
289
290template <typename T>
291static void ComputeBackpropInput(const DepthwiseArgs& args,
292 const int64_t padded_filter_inner_dim_size,
293 const int64_t in_r, const int64_t in_c,
294 const T* filter, const T* buffer,
295 T* out_buffer, T* output) {
296 typedef typename Eigen::internal::packet_traits<T>::type Packet;
297 static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
298
299 const int64_t in_depth = args.in_depth;
300 const int64_t depth_multiplier = args.depth_multiplier;
301 const int64_t out_depth = args.out_depth;
302 const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
303
304 // Calculate vectorized and scalar lengths of 'out_depth'.
305 const int64_t output_vectorized_size =
306 (out_depth / kPacketSize) * kPacketSize;
307 const int64_t output_scalar_size = out_depth % kPacketSize;
308
309 // Calculate base index at which to begin writing output.
310 const int64_t base_output_index = (in_r * args.in_cols + in_c) * in_depth;
311
312 // Calculate vectorized and scalar lengths for 'depth_multiplier'. This is
313 // used to efficiently reduce output when 'depth_multiplier' > kPacketSize.
314 const int64_t dm_vectorized_size =
315 (depth_multiplier / kPacketSize) * kPacketSize;
316 const int64_t dm_scalar_size = depth_multiplier % kPacketSize;
317
318 for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
319 // Reset accumulator.
320 auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
321 for (int j = 0; j < filter_spatial_size; ++j) {
322 // Calculate index.
323 const int64_t index = i + j * padded_filter_inner_dim_size;
324 // Load filter.
325 const auto filter_block = Eigen::internal::ploadu<Packet>(filter + index);
326 // Load input.
327 const auto data_block = Eigen::internal::ploadu<Packet>(buffer + index);
328 // Vector multiply-add.
329 vaccum = Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
330 }
331 if (depth_multiplier == 1) {
332 // Write directly to the output.
333 Eigen::internal::pstoreu<T>(output + base_output_index + i, vaccum);
334 } else {
335 // Buffer output for subsequent reduction step.
336 Eigen::internal::pstoreu<T>(out_buffer + i, vaccum);
337 }
338 }
339
340 if (output_scalar_size > 0) {
341 auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
342 for (int j = 0; j < filter_spatial_size; ++j) {
343 const int64_t index =
344 output_vectorized_size + j * padded_filter_inner_dim_size;
345 const auto filter_block = Eigen::internal::ploadu<Packet>(filter + index);
346 const auto data_block = Eigen::internal::ploadu<Packet>(buffer + index);
347 vaccum = Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
348 }
349 // Load accumulator into an array and loop through output.
350 T out_buf[kPacketSize];
351 Eigen::internal::pstoreu<T>(out_buf, vaccum);
352 if (depth_multiplier == 1) {
353 // Write directly to the output.
354 for (int j = 0; j < output_scalar_size; ++j) {
355 output[base_output_index + output_vectorized_size + j] = out_buf[j];
356 }
357 } else {
358 // Buffer output for subsequent reduction step.
359 for (int j = 0; j < output_scalar_size; ++j) {
360 out_buffer[output_vectorized_size + j] = out_buf[j];
361 }
362 }
363 }
364
365 // Iterate over 'in_depth', reduce over 'depth_multiplier', write 'output'.
366 if (depth_multiplier > 1) {
367 for (int64_t d = 0; d < in_depth; ++d) {
368 const int64_t index = d * args.depth_multiplier;
369 T accum = static_cast<T>(0);
370 for (int64_t dm = 0; dm < dm_vectorized_size; dm += kPacketSize) {
371 const auto v = Eigen::internal::ploadu<Packet>(out_buffer + index + dm);
372 accum += Eigen::internal::predux(v);
373 }
374 // Copy scalar portion of replicated output.
375 for (int64_t dm = 0; dm < dm_scalar_size; ++dm) {
376 accum += out_buffer[index + dm_vectorized_size + dm];
377 }
378 // Copy to output.
379 output[base_output_index + d] = accum;
380 }
381 }
382}
383
384// Computes the depthwise conv2d backprop input of 'out_backprop' by
385// 'depthwise_filter' and stores the result in 'in_backprop'.
386template <typename T>
387struct LaunchDepthwiseConvBackpropInputOp<CPUDevice, T> {
388 typedef typename Eigen::internal::packet_traits<T>::type Packet;
389
390 void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
391 const T* out_backprop, const T* depthwise_filter,
392 T* in_backprop, TensorFormat data_format) {
393 OP_REQUIRES(
394 ctx, data_format == FORMAT_NHWC,
395 errors::Unimplemented(
396 "Depthwise convolution on CPU is only supported for NHWC format"));
397
398 static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
399
400 // Pad 'depthwise_filter' to vector register width (if needed).
401 const bool pad_filter = (args.out_depth % kPacketSize) == 0 ? false : true;
402 Tensor padded_filter;
403 if (pad_filter) {
404 // Allocate space for padded filter.
405 const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
406 const int64_t padded_filter_inner_dim_size =
407 ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
408 OP_REQUIRES_OK(
409 ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
410 TensorShape({filter_spatial_size,
411 padded_filter_inner_dim_size}),
412 &padded_filter));
413 // Write out padded filter.
414 functor::DepthwiseFilterPadOp<T>()(
415 args, depthwise_filter, padded_filter.template flat<T>().data());
416 }
417 const T* filter_data =
418 pad_filter ? padded_filter.template flat<T>().data() : depthwise_filter;
419
420 // Computes one shard of depthwise conv2d backprop input.
421 auto shard = [&ctx, &args, &out_backprop, &filter_data, &in_backprop](
422 int64_t start, int64_t limit) {
423 static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
424
425 const int64_t input_image_size =
426 args.in_rows * args.in_cols * args.in_depth;
427 const int64_t output_image_size =
428 args.out_rows * args.out_cols * args.out_depth;
429 const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
430 const int64_t padded_filter_inner_dim_size =
431 ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
432
433 // Allocate buffer to copy regions from 'out_backprop'.
434 Tensor out_bprop_buffer;
435 OP_REQUIRES_OK(
436 ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
437 TensorShape({filter_spatial_size,
438 padded_filter_inner_dim_size}),
439 &out_bprop_buffer));
440 T* out_bprop_buf = out_bprop_buffer.template flat<T>().data();
441
442 // Allocate buffer for intermediate results.
443 Tensor in_bprop_buffer;
444 OP_REQUIRES_OK(
445 ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
446 TensorShape({padded_filter_inner_dim_size}),
447 &in_bprop_buffer));
448 T* in_bprop_buf = in_bprop_buffer.template flat<T>().data();
449
450 for (int64_t b = start; b < limit; ++b) {
451 for (int64_t in_r = 0; in_r < args.in_rows; ++in_r) {
452 for (int64_t in_c = 0; in_c < args.in_cols; ++in_c) {
453 // Populate 'out_bprop_buf' from local 'out_backprop' region.
454 CopyOutputBackpropRegion<T>(
455 args, padded_filter_inner_dim_size, in_r, in_c,
456 out_backprop + b * output_image_size, out_bprop_buf);
457
458 // Compute depthwise backprop input.
459 ComputeBackpropInput<T>(args, padded_filter_inner_dim_size, in_r,
460 in_c, filter_data, out_bprop_buf,
461 in_bprop_buf,
462 in_backprop + b * input_image_size);
463 }
464 }
465 }
466 };
467
468 const int64_t shard_cost = args.in_rows * args.in_cols * args.out_depth;
469 auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
470 Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
471 shard_cost, shard);
472 }
473};
474
475template <typename T>
476static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,
477 const T* out_backprop,
478 const T* filter,
479 T* in_backprop) {
480 // Naive for loop as a reference point without concerns about performance.
481 for (int b = 0; b < args.batch; ++b) {
482 for (int in_r = 0; in_r < args.in_rows; ++in_r) {
483 for (int in_c = 0; in_c < args.in_cols; ++in_c) {
484 for (int in_d = 0; in_d < args.in_depth; ++in_d) {
485 T sum = 0;
486 const int stride = args.stride;
487 const int out_d_start = in_d * args.depth_multiplier;
488 const int out_d_end = out_d_start + args.depth_multiplier;
489
490 for (int out_d = out_d_start; out_d < out_d_end; ++out_d) {
491 const int out_r_start = std::max(
492 0, (in_r - args.filter_rows + args.pad_rows + stride) / stride);
493 const int out_r_end =
494 std::min(args.out_rows - 1, (in_r + args.pad_rows) / stride);
495
496 for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) {
497 const int out_c_start = std::max(
498 0,
499 (in_c - args.filter_cols + args.pad_cols + stride) / stride);
500 const int out_c_end =
501 std::min(args.out_cols - 1, (in_c + args.pad_cols) / stride);
502
503 for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) {
504 int f_r = in_r + args.pad_rows - out_r * stride;
505 int f_c = in_c + args.pad_cols - out_c * stride;
506 int filter_dm = out_d - out_d_start;
507 int out_backprop_offset =
508 out_d +
509 args.out_depth *
510 (out_c + args.out_cols * (out_r + args.out_rows * b));
511 int filter_offset =
512 filter_dm +
513 args.depth_multiplier *
514 (in_d + args.in_depth * (f_c + args.filter_cols * f_r));
515 sum +=
516 out_backprop[out_backprop_offset] * filter[filter_offset];
517 }
518 }
519 }
520
521 int in_backprop_offset =
522 in_d +
523 args.in_depth * (in_c + args.in_cols * (in_r + args.in_rows * b));
524 in_backprop[in_backprop_offset] = sum;
525 }
526 }
527 }
528 }
529}
530
// Extern template instantiated in conv_grad_input_ops.cc.
// Declaring these 'extern' keeps this translation unit from implicitly
// instantiating the Conv2D launcher that DepthwiseConv2dNativeBackpropInputOp
// stores as its 'launcher_' member.
extern template struct LaunchConv2DBackpropInputOp<CPUDevice, bfloat16>;
extern template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
extern template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Extern template instantiated in conv_grad_input_ops.cc.
// GPU counterparts of the declarations above (no bfloat16 variant is
// declared for GPU here).
extern template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
extern template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
extern template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;

// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
// The GPU depthwise launcher is only declared here; the CPU specialization is
// defined in this file.
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice,
                                                          Eigen::half>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
551
552// Kernel to compute the input backprop for depthwise convolution.
553template <typename Device, class T>
554class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
555 public:
556 explicit DepthwiseConv2dNativeBackpropInputOp(OpKernelConstruction* context)
557 : OpKernel(context) {
558 OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
559 OP_REQUIRES(context, strides_.size() == 4,
560 errors::InvalidArgument("Sliding window strides field must "
561 "specify 4 dimensions"));
562
563 string data_format;
564 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
565 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
566 errors::InvalidArgument("Invalid data format"));
567
568 stride_ = GetTensorDim(strides_, data_format_, 'H');
569 const int64_t stride_w = GetTensorDim(strides_, data_format_, 'W');
570 const int64_t stride_n = GetTensorDim(strides_, data_format_, 'N');
571 const int64_t stride_c = GetTensorDim(strides_, data_format_, 'C');
572
573 OP_REQUIRES(context, stride_ == stride_w,
574 errors::InvalidArgument(
575 "Current implementation only supports equal length "
576 "strides in the row and column dimensions."));
577 OP_REQUIRES(
578 context, (stride_n == 1 && stride_c == 1),
579 errors::InvalidArgument("Current implementation does not yet support "
580 "strides in the batch and depth dimensions."));
581 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
582 OP_REQUIRES_OK(context,
583 context->GetAttr("explicit_paddings", &explicit_paddings_));
584 OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
585 /*num_dims=*/4, data_format_));
586
587 cudnn_use_autotune_ = CudnnUseAutotune();
588 dtype_ = DataTypeToEnum<T>::value;
589#if CUDNN_VERSION >= 8000
590 // From the cuDNN release note 8.0: We’ve extended the fprop and dgrad
591 // NHWC depthwise kernels to support more combinations (filter
592 // sizes/strides) such as 5x5/1x1, 5x5/2x2, 7x7/1x1, 7x7/2x2 (in addition
593 // to what we already have, 1x1/1x1, 3x3/1x1, 3x3/2x2), which provides
594 // good performance. (https://docs.nvidia.com/deeplearning/sdk/cudnn-
595 // release-notes/rel_8.html#rel_8)
596 use_cudnn_grouped_conv_ =
597 dtype_ == DT_HALF &&
598 ((data_format_ == FORMAT_NCHW && stride_ == 1 && stride_w == 1) ||
599 (data_format_ == FORMAT_NHWC && stride_ == stride_w &&
600 (stride_ == 1 || stride_ == 2)));
601#elif CUDNN_VERSION >= 7603
602 // Use CuDNN grouped conv (input gradient) when stride = 1, input/output is
603 // NCHW and float16(half). See cudnn release note 7.6.3 (https://docs.nvidi
604 // a.com/deeplearning/sdk/cudnn-release-notes/rel_763.html#rel_763).
605 use_cudnn_grouped_conv_ = dtype_ == DT_HALF &&
606 data_format_ == FORMAT_NCHW && stride_ == 1 &&
607 stride_w == 1;
608#else
609 use_cudnn_grouped_conv_ = false;
610#endif
611 }
612
613 void Compute(OpKernelContext* context) override {
614 const Tensor& input_sizes = context->input(0);
615 const Tensor& filter = context->input(1);
616 OP_REQUIRES(
617 context, TensorShapeUtils::IsVector(input_sizes.shape()),
618 errors::InvalidArgument(
619 "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
620 input_sizes.dims()));
621 TensorShape input_shape;
622 const int32* in_sizes_data = input_sizes.template flat<int32>().data();
623
624 for (int i = 0; i < input_sizes.NumElements(); ++i) {
625 OP_REQUIRES(context, in_sizes_data[i] >= 0,
626 errors::InvalidArgument("Dimension ", i,
627 " of input_sizes must be >= 0"));
628 OP_REQUIRES_OK(context, input_shape.AddDimWithStatus(in_sizes_data[i]));
629 }
630 const TensorShape& filter_shape = filter.shape();
631 EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropInput");
632
633 Tensor* in_backprop = nullptr;
634 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
635 {0}, 0, input_shape, &in_backprop));
636
637 // If there is nothing to compute, return.
638 if (input_shape.num_elements() == 0) {
639 return;
640 }
641
642 // If in_depth==1, this operation is just a standard convolution.
643 // Depthwise convolution is a special case of cuDNN's grouped convolution.
644 bool use_cudnn =
645 std::is_same<Device, GPUDevice>::value &&
646 (in_depth == 1 || (use_cudnn_grouped_conv_ &&
647 ShouldCudnnGroupedConvolutionBeUsed(
648 filter_rows, filter_cols, in_depth, out_depth)));
649
650 VLOG(2) << "DepthwiseConv2dNativeBackpropInput: "
651 << " Input: [" << batch << ", " << input_rows << ", " << input_cols
652 << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
653 << filter_cols << ", " << in_depth << ", " << depth_multiplier
654 << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
655 << ", " << out_depth << "], stride = " << stride_
656 << ", pad_rows = " << pad_top << ", pad_cols = " << pad_left
657 << ", Use cuDNN: " << use_cudnn;
658
659 if (use_cudnn) {
660 // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
661 //
662 // | TensorFlow | cuDNN
663 // --------------------------------------------------------------------
664 // filter_out_depth | depth_multiplier | depth_multiplier * group_count
665 // filter_in_depth | in_depth | in_depth / group_count
666 //
667 // For depthwise convolution, we have group_count == in_depth.
668 int32_t filter_in_depth = 1;
669 TensorShape shape =
670 TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
671 Tensor reshaped_filter(/*type=*/dtype_);
672 OP_REQUIRES(
673 context, reshaped_filter.CopyFrom(filter, shape),
674 errors::Internal(
675 "Failed to reshape filter tensor for grouped convolution."));
676 // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
677 // conv is supported.
678 launcher_(context, /*use_cudnn=*/true, cudnn_use_autotune_, out_backprop,
679 reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
680 stride_, stride_, padding_, explicit_paddings_, in_backprop,
681 data_format_);
682 return;
683 }
684
685 auto out_backprop_ptr = out_backprop.template flat<T>().data();
686 auto filter_ptr = filter.template flat<T>().data();
687 auto in_backprop_ptr = in_backprop->template flat<T>().data();
688 LaunchDepthwiseConvBackpropInputOp<Device, T>()(
689 context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr,
690 data_format_);
691 }
692
693 protected:
694 bool use_cudnn_grouped_conv_;
695
696 private:
697 std::vector<int32> strides_;
698 Padding padding_;
699 std::vector<int64_t> explicit_paddings_;
700 TensorFormat data_format_;
701 int64_t stride_;
702
703 // For in_depth == 1 and grouped convolutions.
704 LaunchConv2DBackpropInputOp<Device, T> launcher_;
705 bool cudnn_use_autotune_;
706 DataType dtype_;
707
708 TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropInputOp);
709};
710
711#define REGISTER_CPU_KERNEL(T) \
712 REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
713 .Device(DEVICE_CPU) \
714 .TypeConstraint<T>("T"), \
715 DepthwiseConv2dNativeBackpropInputOp<CPUDevice, T>);
716
717TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
718TF_CALL_half(REGISTER_CPU_KERNEL);
719TF_CALL_float(REGISTER_CPU_KERNEL);
720#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
721TF_CALL_double(REGISTER_CPU_KERNEL);
722#endif
723#undef REGISTER_CPU_KERNEL
724
725#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
726
// Registers the GPU kernel for element type T. "input_sizes" is declared as a
// host-memory input; Compute() reads its int32 values directly to build the
// output shape.
#define REGISTER_GPU_KERNEL(T)                                       \
  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<T>("T")                \
                              .HostMemory("input_sizes"),            \
                          DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>)

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
738
739#if CUDNN_VERSION >= 7000
740template <typename T>
741class DepthwiseConv2dGroupedConvBackpropInputOp
742 : public DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T> {
743 public:
744 DepthwiseConv2dGroupedConvBackpropInputOp(OpKernelConstruction* context)
745 : DepthwiseConv2dNativeBackpropInputOp<GPUDevice, T>(context) {
746 this->use_cudnn_grouped_conv_ = true;
747 }
748};
749
// Registers DepthwiseConv2dGroupedConvBackpropInputOp<T> as an additional GPU
// kernel for the same op name, distinguished by the
// "cudnn_grouped_convolution" label. As with the unlabeled GPU kernel,
// "input_sizes" is declared as a host-memory input.
#define REGISTER_GROUPED_CONV_KERNEL(T)                              \
  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropInput") \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<T>("T")                \
                              .HostMemory("input_sizes")             \
                              .Label("cudnn_grouped_convolution"),   \
                          DepthwiseConv2dGroupedConvBackpropInputOp<T>)

TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
#undef REGISTER_GROUPED_CONV_KERNEL
762#endif // CUDNN_VERSION
763#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
764
765// Kernels to compute the gradients of the filters for depthwise convolution.
766
767// Computes filter backprop using 'out_backprop' and 'input_buffer', storing the
768// result in 'output_buffer' at an index computed from 'out_r' and 'out_c'.
769//
770// EX:
771// in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
772// Both 'input_buffer' and 'filter' are padded to register-width boundaries.
773//
774// 'input_buffer' [rows, cols, in_depth, depth_multiplier]
775//
776// [f00, f01, f10, f11] [f20, f21, 0, 0] in_row = 0, in_col = 0
777// [e00, e01, e10, e11] [e20, e21, 0, 0] in_row = 0, in_col = 1
778// [b00, b01, b10, b11] [b20, b21, 0, 0] in_row = 1, in_col = 0
779// [a00, a01, a10, a11] [a20, a21, 0, 0] in_row = 1, in_col = 1
780//
781// 'out_backprop' [out_rows, out_cols, in_depth, depth_multiplier]
782//
783// [q00, q01, q10, q11] [q20, q21, r00, r01]
784// [r10, r11, r20, r21] [s00, s01, s10, s11]
//   [s20, s21, t00, t01] [t10, t11, t20, t21]
786//
787// First output register of 'filter_backprop'
788// [u0, v0, w0, x0] += ([f00, f01, f10, f11] x [q00, q01, q10, q11])
789//
790template <typename T>
791static void ComputeBackpropFilter(const DepthwiseArgs& args,
792 const int64_t padded_out_depth_size,
793 const int64_t out_r, const int64_t out_c,
794 const T* out_backprop, const T* input_buffer,
795 T* output_buffer) {
796 typedef typename Eigen::internal::packet_traits<T>::type Packet;
797 static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
798 // Calculate vectorized size of 'padded_out_depth_size'.
799 const int64_t out_depth = args.out_depth;
800 const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
801 const int64_t output_vectorized_size =
802 (padded_out_depth_size / kPacketSize) * kPacketSize;
803 const int64_t base_output_index = (out_r * args.out_cols + out_c) * out_depth;
804 // Determine whether we can execute fast or slow code path.
805 const int64_t output_image_size =
806 args.out_rows * args.out_cols * args.out_depth;
807 const int64_t output_last_vector_index =
808 output_image_size - (filter_spatial_size * padded_out_depth_size);
809 const bool fast_path = base_output_index <= output_last_vector_index;
810
811 if (fast_path) {
812 // TODO(andydavis) Process multiple inputs in 'input_buffer' so we can
813 // amortize the cost of 'output_buffer' load store in the loop below.
814 for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
815 // Load vector register from 'out_backprop'.
816 const auto out_bprop_block =
817 Eigen::internal::ploadu<Packet>(out_backprop + base_output_index + i);
818 for (int j = 0; j < filter_spatial_size; ++j) {
819 const int64_t index = i + j * padded_out_depth_size;
820 // Load vector register from 'input_buffer'.
821 const auto input_block =
822 Eigen::internal::ploadu<Packet>(input_buffer + index);
823 // Load output block into vector register.
824 auto out_block_data = output_buffer + index;
825 auto out_block = Eigen::internal::ploadu<Packet>(out_block_data);
826 // Vector multiply-add.
827 out_block = Eigen::internal::pmadd<Packet>(out_bprop_block, input_block,
828 out_block);
829 // Store 'out_block' back to memory.
830 Eigen::internal::pstoreu<T>(out_block_data, out_block);
831 }
832 }
833 } else {
834 // Slow path (cant do vector reads from non-padded 'out_backprop'.
835 for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
836 // Calculate safe read size from 'out_backprop'.
837 const int64_t out_bprop_index = base_output_index + i;
838 const int64_t out_bprop_limit =
839 std::min(output_image_size, out_bprop_index + kPacketSize);
840 T out_buf[kPacketSize];
841 memset(&out_buf, 0, kPacketSize * sizeof(T));
842 const int64_t scalar_size = out_bprop_limit - out_bprop_index;
843 for (int64_t j = 0; j < scalar_size; ++j) {
844 out_buf[j] = out_backprop[out_bprop_index + j];
845 }
846 // Load vector register from 'out_buf'.
847 const auto out_bprop_block = Eigen::internal::ploadu<Packet>(out_buf);
848 for (int j = 0; j < filter_spatial_size; ++j) {
849 const int64_t index = i + j * padded_out_depth_size;
850 // Load vector register from 'input_buffer'.
851 const auto input_block =
852 Eigen::internal::ploadu<Packet>(input_buffer + index);
853 // Load output block into vector register.
854 auto out_block_data = output_buffer + index;
855 auto out_block = Eigen::internal::ploadu<Packet>(out_block_data);
856 // Vector multiply-add.
857 out_block = Eigen::internal::pmadd<Packet>(out_bprop_block, input_block,
858 out_block);
859 // Store 'out_block' back to memory.
860 Eigen::internal::pstoreu<T>(out_block_data, out_block);
861 }
862 }
863 }
864}
865
866template <typename Device, typename T>
867struct LaunchDepthwiseConvBackpropFilterOp;
868
869template <typename T>
870struct LaunchDepthwiseConvBackpropFilterOp<CPUDevice, T> {
871 typedef typename Eigen::internal::packet_traits<T>::type Packet;
872
873 void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
874 const T* out_backprop, const T* input, T* filter_backprop,
875 TensorFormat data_format) {
876 OP_REQUIRES(
877 ctx, data_format == FORMAT_NHWC,
878 errors::Unimplemented(
879 "Depthwise convolution on CPU is only supported for NHWC format"));
880
881 static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
882
883 const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
884 const int64_t padded_out_depth_size =
885 ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
886
887 // Allocate output buffers for each image in 'batch' (padded to vector
888 // register boundaries).
889 Tensor output_buffer;
890 OP_REQUIRES_OK(
891 ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
892 TensorShape({args.batch, filter_spatial_size,
893 padded_out_depth_size}),
894 &output_buffer));
895 T* output_buffer_data = output_buffer.template flat<T>().data();
896
897 // Computes one shard of depthwise conv2d backprop filter.
898 auto shard = [&ctx, &args, &out_backprop, &input, &output_buffer_data](
899 int64_t start, int64_t limit) {
900 static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));
901 const int64_t filter_spatial_size = args.filter_rows * args.filter_cols;
902 const int64_t padded_out_depth_size =
903 ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
904
905 // Allocate buffer for local input regions.
906 Tensor input_buffer;
907 OP_REQUIRES_OK(
908 ctx, ctx->allocate_temp(
909 DataTypeToEnum<T>::value,
910 TensorShape({filter_spatial_size, padded_out_depth_size}),
911 &input_buffer));
912 T* input_buffer_data = input_buffer.template flat<T>().data();
913
914 const int64_t input_image_size =
915 args.in_rows * args.in_cols * args.in_depth;
916 const int64_t output_image_size =
917 args.out_rows * args.out_cols * args.out_depth;
918 const int64_t padded_filter_size =
919 filter_spatial_size * padded_out_depth_size;
920
921 for (int b = start; b < limit; ++b) {
922 // Initialize 'output_buffer' for 'b'.
923 auto* output_buffer = output_buffer_data + b * padded_filter_size;
924 memset(output_buffer, 0, padded_filter_size * sizeof(T));
925
926 for (int out_r = 0; out_r < args.out_rows; ++out_r) {
927 for (int out_c = 0; out_c < args.out_cols; ++out_c) {
928 // Populate 'input_buffer_data' with data from local input region.
929 functor::DepthwiseInputCopyOp<T>()(
930 args, padded_out_depth_size, out_r, out_c,
931 input + b * input_image_size, input_buffer_data);
932 // Compute depthwise backprop filter.
933 ComputeBackpropFilter(args, padded_out_depth_size, out_r, out_c,
934 out_backprop + b * output_image_size,
935 input_buffer_data, output_buffer);
936 }
937 }
938 }
939 };
940 const int64_t shard_cost = args.out_rows * args.out_cols * args.out_depth;
941 auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
942 Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
943 shard_cost, shard);
944
945 // Accumulate 'output_buffer' from each shard into 'output'.
946 const int64_t out_depth = args.out_depth;
947 const int64_t vectorized_size = (out_depth / kPacketSize) * kPacketSize;
948 const int64_t scalar_size = out_depth - vectorized_size;
949 const int64_t padded_filter_size =
950 filter_spatial_size * padded_out_depth_size;
951 memset(filter_backprop, 0, filter_spatial_size * out_depth * sizeof(T));
952
953 for (int64_t i = 0; i < filter_spatial_size; ++i) {
954 const int64_t buffer_base = i * padded_out_depth_size;
955 const int64_t output_base = i * out_depth;
956 // Write vectorized length of filter's inner dimension to output.
957 for (int64_t j = 0; j < vectorized_size; j += kPacketSize) {
958 // Load data from 'filter_backprop' into vector register.
959 auto out_block_data = filter_backprop + output_base + j;
960 auto out_block = Eigen::internal::ploadu<Packet>(out_block_data);
961 for (int b = 0; b < args.batch; ++b) {
962 // Load data from 'output_buffer' for 'b'.
963 const auto* output_buffer =
964 output_buffer_data + b * padded_filter_size;
965 const auto v =
966 Eigen::internal::ploadu<Packet>(output_buffer + buffer_base + j);
967 // Add 'v' to 'out_block'.
968 out_block = Eigen::internal::padd<Packet>(out_block, v);
969 }
970 // Store 'out_block' back to memory.
971 Eigen::internal::pstoreu<T>(out_block_data, out_block);
972 }
973 // Write scalar length of filter's inner dimension to output.
974 for (int64_t j = 0; j < scalar_size; ++j) {
975 for (int b = 0; b < args.batch; ++b) {
976 const auto* output_buffer =
977 output_buffer_data + b * padded_filter_size;
978 filter_backprop[output_base + vectorized_size + j] +=
979 output_buffer[buffer_base + vectorized_size + j];
980 }
981 }
982 }
983 }
984};
985
986template <typename T>
987static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
988 const T* out_backprop,
989 const T* input,
990 T* filter_backprop) {
991 int num_filter_backprop = args.filter_rows * args.filter_cols *
992 args.in_depth * args.depth_multiplier;
993 memset(filter_backprop, 0, num_filter_backprop * sizeof(T));
994 // Naive for loop as a reference point without concerns about performance.
995 for (int b = 0; b < args.batch; ++b) {
996 for (int out_r = 0; out_r < args.out_rows; ++out_r) {
997 for (int out_c = 0; out_c < args.out_cols; ++out_c) {
998 for (int out_d = 0; out_d < args.out_depth; ++out_d) {
999 const int in_d = out_d / args.depth_multiplier;
1000 const int dm = out_d % args.depth_multiplier;
1001 const int in_r_start = out_r * args.stride - args.pad_rows;
1002 const int in_c_start = out_c * args.stride - args.pad_cols;
1003
1004 for (int f_r = 0; f_r < args.filter_rows; ++f_r) {
1005 for (int f_c = 0; f_c < args.filter_cols; ++f_c) {
1006 const int in_r = in_r_start + f_r;
1007 const int in_c = in_c_start + f_c;
1008
1009 if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 &&
1010 in_c < args.in_cols) {
1011 int out_backprop_offset =
1012 out_d +
1013 args.out_depth *
1014 (out_c + args.out_cols * (out_r + args.out_rows * b));
1015 int input_offset =
1016 in_d +
1017 args.in_depth *
1018 (in_c + args.in_cols * (in_r + args.in_rows * b));
1019 int filter_backprop_offset =
1020 dm +
1021 args.depth_multiplier *
1022 (in_d + args.in_depth * (f_c + args.filter_cols * f_r));
1023 filter_backprop[filter_backprop_offset] +=
1024 input[input_offset] * out_backprop[out_backprop_offset];
1025 }
1026 }
1027 }
1028 }
1029 }
1030 }
1031 }
1032}
1033
1034// Extern template instantiated in conv_grad_filter_ops.cc.
1035extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, bfloat16>;
1036extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, Eigen::half>;
1037extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
1038extern template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;
1039
1040#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1041
1042// Extern template instantiated in conv_grad_filter_ops.cc.
1043extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
1044extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
1045extern template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;
1046
1047// Extern template instantiated in depthwise_conv_op_gpu.cu.cc.
1048extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice,
1049 Eigen::half>;
1050extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, float>;
1051extern template struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, double>;
1052
1053#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1054
1055// Kernel to compute the filter backprop for depthwise convolution.
1056template <typename Device, class T>
1057class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
1058 public:
1059 explicit DepthwiseConv2dNativeBackpropFilterOp(OpKernelConstruction* context)
1060 : OpKernel(context) {
1061 OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
1062 OP_REQUIRES(context, strides_.size() == 4,
1063 errors::InvalidArgument("Sliding window strides field must "
1064 "specify 4 dimensions"));
1065
1066 string data_format;
1067 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
1068 OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
1069 errors::InvalidArgument("Invalid data format"));
1070
1071 stride_ = GetTensorDim(strides_, data_format_, 'H');
1072 const int64_t stride_w = GetTensorDim(strides_, data_format_, 'W');
1073 const int64_t stride_n = GetTensorDim(strides_, data_format_, 'N');
1074 const int64_t stride_c = GetTensorDim(strides_, data_format_, 'C');
1075
1076 OP_REQUIRES(context, stride_ == stride_w,
1077 errors::InvalidArgument(
1078 "Current implementation only supports equal length "
1079 "strides in the row and column dimensions."));
1080 OP_REQUIRES(
1081 context, (stride_n == 1 && stride_c == 1),
1082 errors::InvalidArgument("Current implementation does not yet support "
1083 "strides in the batch and depth dimensions."));
1084 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
1085 OP_REQUIRES_OK(context,
1086 context->GetAttr("explicit_paddings", &explicit_paddings_));
1087 OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
1088 /*num_dims=*/4, data_format_));
1089
1090 cudnn_use_autotune_ = CudnnUseAutotune();
1091
1092 if (std::is_same<T, bfloat16>::value) {
1093 dtype_ = DT_BFLOAT16;
1094 } else if (std::is_same<T, Eigen::half>::value) {
1095 dtype_ = DT_HALF;
1096 } else if (std::is_same<T, float>::value) {
1097 dtype_ = DT_FLOAT;
1098 } else if (std::is_same<T, double>::value) {
1099 dtype_ = DT_DOUBLE;
1100 } else {
1101 LOG(ERROR) << "Only bfloat16, half, float, and double are supported.";
1102 }
1103#if CUDNN_VERSION >= 7603
1104 // Use CuDNN grouped conv (filter gradients) when input/output is
1105 // float16(half). See cudnn release note 7.6.3. (https://docs.nvidia.com/dee
1106 // plearning/sdk/cudnn-release-notes/rel_763.html#rel_763)
1107 //
1108 // Grouped convolution was added to cuDNN in version 7.0.1 but
1109 // TensorFlow op-determinism has been added only for cuDNN versions 7.6.3
1110 // and later intentionally. This is to avoid potential issues with earlier
1111 // versions of cuDNN.
1112 use_cudnn_grouped_conv_ = OpDeterminismRequired() || dtype_ == DT_HALF;
1113#else
1114 use_cudnn_grouped_conv_ = false;
1115#endif
1116 }
1117
1118 void Compute(OpKernelContext* context) override {
1119 const Tensor& input = context->input(0);
1120 const Tensor& filter_sizes = context->input(1);
1121 OP_REQUIRES(
1122 context, TensorShapeUtils::IsVector(filter_sizes.shape()),
1123 errors::InvalidArgument(
1124 "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
1125 filter_sizes.dims()));
1126 TensorShape filter_shape;
1127 const int32* filter_sizes_data = filter_sizes.template flat<int32>().data();
1128 for (int i = 0; i < filter_sizes.NumElements(); ++i) {
1129 OP_REQUIRES(context, filter_sizes_data[i] >= 0,
1130 errors::InvalidArgument("Dimension ", i,
1131 " of filter_sizes must be >= 0"));
1132 OP_REQUIRES_OK(context,
1133 filter_shape.AddDimWithStatus(filter_sizes_data[i]));
1134 }
1135 const TensorShape& input_shape = input.shape();
1136
1137 EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropFilter");
1138 Tensor* filter_backprop = nullptr;
1139 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1140 {1}, 0, filter_shape, &filter_backprop));
1141
1142 // If there is nothing to compute, return.
1143 if (out_backprop.shape().num_elements() == 0) {
1144 return;
1145 }
1146
1147 // If in_depth==1, this operation is just a standard convolution.
1148 // Depthwise convolution is a special case of cuDNN's grouped convolution.
1149 bool use_cudnn = std::is_same<Device, GPUDevice>::value &&
1150 (in_depth == 1 ||
1151 (use_cudnn_grouped_conv_ &&
1152 (ShouldCudnnGroupedConvolutionBeUsed(
1153 filter_rows, filter_cols, in_depth, out_depth) ||
1154 OpDeterminismRequired())));
1155
1156 VLOG(2) << "DepthwiseConv2dNativeBackpropFilter: "
1157 << " Input: [" << batch << ", " << input_rows << ", " << input_cols
1158 << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
1159 << filter_cols << ", " << in_depth << ", " << depth_multiplier
1160 << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
1161 << ", " << out_depth << "], stride = " << stride_
1162 << ", pad_rows = " << pad_top << ", pad_cols = " << pad_left
1163 << ", Use cuDNN: " << use_cudnn;
1164
1165 if (use_cudnn) {
1166 // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
1167 //
1168 // | TensorFlow | cuDNN
1169 // --------------------------------------------------------------------
1170 // filter_out_depth | depth_multiplier | depth_multiplier * group_count
1171 // filter_in_depth | in_depth | in_depth / group_count
1172 //
1173 // For depthwise convolution, we have group_count == in_depth.
1174 int32_t filter_in_depth = 1;
1175 TensorShape shape =
1176 TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
1177 Tensor reshaped_filter(/*type=*/dtype_);
1178 OP_REQUIRES(
1179 context, reshaped_filter.CopyFrom(*filter_backprop, shape),
1180 errors::Internal(
1181 "Failed to reshape filter tensor for grouped convolution."));
1182
1183 // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
1184 // conv is supported.
1185 launcher_(context, /*use_cudnn=*/true, cudnn_use_autotune_, out_backprop,
1186 input,
1187 /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
1188 padding_, explicit_paddings_, &reshaped_filter, data_format_);
1189 return;
1190 }
1191
1192 // For GPU inputs with type half, we cast inputs to float and outputs back
1193 // to half, as half implementation is slow and does not use full precision
1194 // accumulation in some cases.
1195 constexpr bool cast_to_float = std::is_same<T, Eigen::half>::value &&
1196 std::is_same<Device, GPUDevice>::value;
1197 using U = typename std::conditional<cast_to_float, float, T>::type;
1198 Tensor casted_out_backprop = out_backprop;
1199 Tensor casted_input = input;
1200 Tensor casted_filter_backprop = *filter_backprop;
1201 const Device& device = context->template eigen_device<Device>();
1202 if (cast_to_float) {
1203 functor::CastFunctor<Device, float, Eigen::half> cast;
1204 OP_REQUIRES_OK(context,
1205 context->allocate_temp(DT_FLOAT, out_backprop.shape(),
1206 &casted_out_backprop));
1207 cast(device, casted_out_backprop.template flat<float>(),
1208 out_backprop.template flat<Eigen::half>());
1209 OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, input.shape(),
1210 &casted_input));
1211 cast(device, casted_input.template flat<float>(),
1212 input.template flat<Eigen::half>());
1213 OP_REQUIRES_OK(context,
1214 context->allocate_temp(DT_FLOAT, filter_backprop->shape(),
1215 &casted_filter_backprop));
1216 }
1217
1218 auto out_backprop_ptr = casted_out_backprop.template flat<U>().data();
1219 auto input_ptr = casted_input.template flat<U>().data();
1220 auto filter_backprop_ptr = casted_filter_backprop.template flat<U>().data();
1221 LaunchDepthwiseConvBackpropFilterOp<Device, U>()(
1222 context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
1223 data_format_);
1224
1225 if (cast_to_float) {
1226 functor::CastFunctor<Device, Eigen::half, float> cast;
1227 const Tensor& casted_filter_backprop_const = casted_filter_backprop;
1228 cast(device, filter_backprop->template flat<Eigen::half>(),
1229 casted_filter_backprop_const.template flat<float>());
1230 }
1231 }
1232
1233 protected:
1234 bool use_cudnn_grouped_conv_;
1235
1236 private:
1237 std::vector<int32> strides_;
1238 Padding padding_;
1239 std::vector<int64_t> explicit_paddings_;
1240 TensorFormat data_format_;
1241 int64_t stride_;
1242
1243 // For in_depth == 1 and grouped convolutions.
1244 LaunchConv2DBackpropFilterOp<Device, T> launcher_;
1245 bool cudnn_use_autotune_;
1246 DataType dtype_;
1247
1248 TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropFilterOp);
1249};
1250
1251#define REGISTER_CPU_KERNEL(T) \
1252 REGISTER_KERNEL_BUILDER( \
1253 Name("DepthwiseConv2dNativeBackpropFilter") \
1254 .Device(DEVICE_CPU) \
1255 .TypeConstraint<T>("T"), \
1256 DepthwiseConv2dNativeBackpropFilterOp<CPUDevice, T>);
1257TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
1258TF_CALL_half(REGISTER_CPU_KERNEL);
1259TF_CALL_float(REGISTER_CPU_KERNEL);
1260#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
1261TF_CALL_double(REGISTER_CPU_KERNEL);
1262#endif
1263#undef REGISTER_CPU_KERNEL
1264
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Registers the default GPU kernel of DepthwiseConv2dNativeBackpropFilter for
// element type 'T'. The 'filter_sizes' input is kept in host memory.
#define REGISTER_GPU_KERNEL(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
                              .Device(DEVICE_GPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .HostMemory("filter_sizes"),            \
                          DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>)

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#if CUDNN_VERSION >= 7000
// GPU kernel variant that unconditionally sets 'use_cudnn_grouped_conv_' after
// the base class has been constructed. It is registered below under the
// "cudnn_grouped_convolution" label.
template <typename T>
class DepthwiseConv2dGroupedConvBackpropFilterOp
    : public DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T> {
 public:
  // 'explicit' matches the base-class constructor and rules out accidental
  // implicit conversion from an OpKernelConstruction pointer.
  explicit DepthwiseConv2dGroupedConvBackpropFilterOp(
      OpKernelConstruction* context)
      : DepthwiseConv2dNativeBackpropFilterOp<GPUDevice, T>(context) {
    this->use_cudnn_grouped_conv_ = true;
  }
};

// Registers the labelled grouped-convolution GPU kernel for element type 'T'.
#define REGISTER_GROUPED_CONV_KERNEL(T)                               \
  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNativeBackpropFilter") \
                              .Device(DEVICE_GPU)                     \
                              .TypeConstraint<T>("T")                 \
                              .HostMemory("filter_sizes")             \
                              .Label("cudnn_grouped_convolution"),    \
                          DepthwiseConv2dGroupedConvBackpropFilterOp<T>)

TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
#undef REGISTER_GROUPED_CONV_KERNEL
#endif  // CUDNN_VERSION
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1303
1304} // namespace tensorflow
1305