1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <atomic>
17
18#define EIGEN_USE_THREADS
19
20#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
21#define EIGEN_USE_GPU
22#if GOOGLE_CUDA
23#include "third_party/gpus/cudnn/cudnn.h"
24#endif // GOOGLE_CUDA
25
26#include "tensorflow/core/kernels/conv_2d.h"
27#include "tensorflow/core/platform/stream_executor.h"
28#include "tensorflow/core/util/stream_executor_util.h"
29#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
30
31#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
32#include "tensorflow/core/framework/op_kernel.h"
33#include "tensorflow/core/framework/register_types.h"
34#include "tensorflow/core/framework/tensor.h"
35#include "tensorflow/core/framework/tensor_types.h"
36#include "tensorflow/core/kernels/fill_functor.h"
37#include "tensorflow/core/kernels/fused_batch_norm_op.h"
38#include "tensorflow/core/kernels/redux_functor.h"
39#include "tensorflow/core/kernels/transpose_functor.h"
40#include "tensorflow/core/platform/blocking_counter.h"
41#include "tensorflow/core/util/env_var.h"
42#include "tensorflow/core/util/tensor_format.h"
43
44namespace tensorflow {
45using CPUDevice = Eigen::ThreadPoolDevice;
46using GPUDevice = Eigen::GpuDevice;
47
48namespace functor {
49
50#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
51using se::DeviceMemory;
52using se::ScratchAllocator;
53using se::Stream;
54using se::port::StatusOr;
55#endif
56
57string ToString(FusedBatchNormActivationMode activation_mode) {
58 switch (activation_mode) {
59 case FusedBatchNormActivationMode::kIdentity:
60 return "Identity";
61 case FusedBatchNormActivationMode::kRelu:
62 return "Relu";
63 }
64}
65
66Status ParseActivationMode(OpKernelConstruction* context,
67 FusedBatchNormActivationMode* activation_mode) {
68 string activation_mode_str;
69 TF_RETURN_IF_ERROR(context->GetAttr("activation_mode", &activation_mode_str));
70
71 if (activation_mode_str == "Identity") {
72 *activation_mode = FusedBatchNormActivationMode::kIdentity;
73 return OkStatus();
74 }
75 if (activation_mode_str == "Relu") {
76 *activation_mode = FusedBatchNormActivationMode::kRelu;
77 return OkStatus();
78 }
79 return errors::InvalidArgument("Unsupported activation mode: ",
80 activation_mode_str);
81}
82
83// Functor used by FusedBatchNormOp to do the computations.
84template <typename Device, typename T, typename U, bool is_training>
85struct FusedBatchNorm;
86// Functor used by FusedBatchNormGradOp to do the computations when
87// is_training=True.
88template <typename Device, typename T, typename U>
89struct FusedBatchNormGrad;
90
91template <typename T, typename U>
92struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ true> {
93 void operator()(OpKernelContext* context, const Tensor& x_input,
94 const Tensor& scale_input, const Tensor& offset_input,
95 const Tensor& running_mean_input,
96 const Tensor& running_variance_input,
97 const Tensor* side_input, U epsilon, U exponential_avg_factor,
98 FusedBatchNormActivationMode activation_mode,
99 Tensor* y_output, Tensor* running_mean_output,
100 Tensor* running_var_output, Tensor* saved_batch_mean_output,
101 Tensor* saved_batch_var_output, TensorFormat tensor_format,
102 bool use_reserved_space) {
103 OP_REQUIRES(context, side_input == nullptr,
104 errors::Internal(
105 "The CPU implementation of FusedBatchNorm does not support "
106 "side input."));
107 OP_REQUIRES(context,
108 activation_mode == FusedBatchNormActivationMode::kIdentity,
109 errors::Internal("The CPU implementation of FusedBatchNorm "
110 "does not support activations."));
111
112 if (use_reserved_space) {
113 Tensor* dummy_reserve_space = nullptr;
114 OP_REQUIRES_OK(context,
115 context->allocate_output(5, {}, &dummy_reserve_space));
116 // Initialize the memory, to avoid sanitizer alerts.
117 dummy_reserve_space->flat<U>()(0) = U();
118 }
119
120 // If input is empty, return NaN mean/variance
121 if (x_input.shape().num_elements() == 0) {
122 functor::SetNanFunctor<CPUDevice, U> f;
123 f(context->eigen_device<CPUDevice>(), running_mean_output->flat<U>());
124 f(context->eigen_device<CPUDevice>(), running_var_output->flat<U>());
125 return;
126 }
127
128 Tensor transformed_x;
129 Tensor transformed_y;
130 if (tensor_format == FORMAT_NCHW) {
131 const int64_t in_batch = GetTensorDim(x_input, tensor_format, 'N');
132 const int64_t in_rows = GetTensorDim(x_input, tensor_format, 'H');
133 const int64_t in_cols = GetTensorDim(x_input, tensor_format, 'W');
134 const int64_t in_depths = GetTensorDim(x_input, tensor_format, 'C');
135 OP_REQUIRES_OK(context, context->allocate_temp(
136 DataTypeToEnum<T>::value,
137 ShapeFromFormat(FORMAT_NHWC, in_batch,
138 in_rows, in_cols, in_depths),
139 &transformed_x));
140 OP_REQUIRES_OK(context, context->allocate_temp(
141 DataTypeToEnum<T>::value,
142 ShapeFromFormat(FORMAT_NHWC, in_batch,
143 in_rows, in_cols, in_depths),
144 &transformed_y));
145 // Perform NCHW to NHWC
146 std::vector<int32> perm = {0, 2, 3, 1};
147 OP_REQUIRES_OK(
148 context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
149 x_input, perm, &transformed_x));
150 } else {
151 transformed_x = x_input;
152 transformed_y = *y_output;
153 }
154 typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
155 typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
156 typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
157 typename TTypes<U>::ConstVec old_mean(running_mean_input.vec<U>());
158 typename TTypes<U>::ConstVec old_variance(running_variance_input.vec<U>());
159 typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
160 typename TTypes<U>::Vec new_mean(running_mean_output->vec<U>());
161 typename TTypes<U>::Vec new_variance(running_var_output->vec<U>());
162 typename TTypes<U>::Vec saved_batch_mean(saved_batch_mean_output->vec<U>());
163 typename TTypes<U>::Vec saved_batch_var(saved_batch_var_output->vec<U>());
164
165 const CPUDevice& d = context->eigen_device<CPUDevice>();
166
167 const int depth = x.dimension(3);
168 const int size = x.size();
169 const int rest_size = size / depth;
170 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
171
172 Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
173 one_by_depth.set(1, depth);
174 Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
175 Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
176 bcast_spec.set(0, rest_size);
177
178 auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
179 const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
180 U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
181 // This adjustment is for Bessel's correction
182 U rest_size_adjust =
183 static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
184
185 Eigen::Tensor<U, 1, Eigen::RowMajor> batch_mean(depth);
186 Eigen::Tensor<U, 1, Eigen::RowMajor> batch_variance(depth);
187
188 batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
189 auto x_centered = x_rest_by_depth -
190 batch_mean.reshape(one_by_depth).broadcast(bcast_spec);
191
192 batch_variance.device(d) =
193 x_centered.square().sum(reduce_dims) * rest_size_inv;
194 auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale)
195 .eval()
196 .reshape(one_by_depth)
197 .broadcast(bcast_spec);
198 auto x_scaled = x_centered * scaling_factor;
199 auto x_shifted =
200 (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
201 .template cast<T>();
202
203 y.reshape(rest_by_depth).device(d) = x_shifted;
204 if (exponential_avg_factor == U(1.0)) {
205 saved_batch_var.device(d) = batch_variance;
206 saved_batch_mean.device(d) = batch_mean;
207 new_variance.device(d) = batch_variance * rest_size_adjust;
208 new_mean.device(d) = batch_mean;
209 } else {
210 U one_minus_factor = U(1) - exponential_avg_factor;
211 saved_batch_var.device(d) = batch_variance;
212 saved_batch_mean.device(d) = batch_mean;
213 new_variance.device(d) =
214 one_minus_factor * old_variance +
215 (exponential_avg_factor * rest_size_adjust) * batch_variance;
216 new_mean.device(d) =
217 one_minus_factor * old_mean + exponential_avg_factor * batch_mean;
218 }
219
220 if (tensor_format == FORMAT_NCHW) {
221 // Perform NHWC to NCHW
222 const std::vector<int32> perm = {0, 3, 1, 2};
223 const Status s = ::tensorflow::DoTranspose(
224 context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
225 if (!s.ok()) {
226 context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
227 }
228 }
229 }
230};
231
232template <typename T, typename U>
233struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> {
234 void operator()(OpKernelContext* context, const Tensor& x_input,
235 const Tensor& scale_input, const Tensor& offset_input,
236 const Tensor& estimated_mean_input,
237 const Tensor& estimated_variance_input,
238 const Tensor* side_input, U epsilon, U exponential_avg_factor,
239 FusedBatchNormActivationMode activation_mode,
240 Tensor* y_output, Tensor* batch_mean_output,
241 Tensor* batch_var_output, Tensor* saved_mean_output,
242 Tensor* saved_var_output, TensorFormat tensor_format,
243 bool use_reserved_space) {
244 OP_REQUIRES(context, side_input == nullptr,
245 errors::Internal(
246 "The CPU implementation of FusedBatchNorm does not support "
247 "side input."));
248 OP_REQUIRES(context,
249 activation_mode == FusedBatchNormActivationMode::kIdentity,
250 errors::Internal("The CPU implementation of FusedBatchNorm "
251 "does not support activations."));
252
253 if (use_reserved_space) {
254 Tensor* dummy_reserve_space = nullptr;
255 OP_REQUIRES_OK(context,
256 context->allocate_output(5, {}, &dummy_reserve_space));
257 // Initialize the memory, to avoid sanitizer alerts.
258 dummy_reserve_space->flat<U>()(0) = U();
259 }
260
261 // If input is empty, return NaN mean/variance
262 if (x_input.shape().num_elements() == 0) {
263 functor::SetNanFunctor<CPUDevice, U> f;
264 f(context->eigen_device<CPUDevice>(), batch_mean_output->flat<U>());
265 f(context->eigen_device<CPUDevice>(), batch_var_output->flat<U>());
266 return;
267 }
268
269 Tensor transformed_x;
270 Tensor transformed_y;
271 if (tensor_format == FORMAT_NCHW) {
272 const int64_t in_batch = GetTensorDim(x_input, tensor_format, 'N');
273 const int64_t in_rows = GetTensorDim(x_input, tensor_format, 'H');
274 const int64_t in_cols = GetTensorDim(x_input, tensor_format, 'W');
275 const int64_t in_depths = GetTensorDim(x_input, tensor_format, 'C');
276 OP_REQUIRES_OK(context, context->allocate_temp(
277 DataTypeToEnum<T>::value,
278 ShapeFromFormat(FORMAT_NHWC, in_batch,
279 in_rows, in_cols, in_depths),
280 &transformed_x));
281 OP_REQUIRES_OK(context, context->allocate_temp(
282 DataTypeToEnum<T>::value,
283 ShapeFromFormat(FORMAT_NHWC, in_batch,
284 in_rows, in_cols, in_depths),
285 &transformed_y));
286 // Perform NCHW to NHWC
287 std::vector<int32> perm = {0, 2, 3, 1};
288 OP_REQUIRES_OK(
289 context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
290 x_input, perm, &transformed_x));
291 } else {
292 transformed_x = x_input;
293 transformed_y = *y_output;
294 }
295 typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
296 typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
297 typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
298 typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>());
299 typename TTypes<U>::ConstVec estimated_variance(
300 estimated_variance_input.vec<U>());
301 typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
302 typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
303 typename TTypes<U>::Vec batch_variance(batch_var_output->vec<U>());
304
305 const CPUDevice& d = context->eigen_device<CPUDevice>();
306
307 const int depth = x.dimension(3);
308 OP_REQUIRES(
309 context, depth != 0,
310 errors::Internal("The 4th element in the input shape cannot be 0."));
311 const int size = x.size();
312 const int rest_size = size / depth;
313 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
314 Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
315 one_by_depth.set(1, depth);
316 Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
317 bcast_spec.set(0, rest_size);
318
319 auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
320 auto x_centered =
321 x_rest_by_depth -
322 estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
323 auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
324 .eval()
325 .reshape(one_by_depth)
326 .broadcast(bcast_spec);
327 auto x_scaled = x_centered * scaling_factor;
328 auto x_shifted =
329 (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
330 .template cast<T>();
331
332 y.reshape(rest_by_depth).device(d) = x_shifted;
333 batch_mean.device(d) = estimated_mean;
334 batch_variance.device(d) = estimated_variance;
335
336 if (tensor_format == FORMAT_NCHW) {
337 // Perform NHWC to NCHW
338 const std::vector<int32> perm = {0, 3, 1, 2};
339 const Status s = ::tensorflow::DoTranspose(
340 context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
341 if (!s.ok()) {
342 context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
343 }
344 }
345 }
346};
347
348template <typename T, typename U>
349struct FusedBatchNormGrad<CPUDevice, T, U> {
350 void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
351 const Tensor& x_input, const Tensor& scale_input,
352 const Tensor* offset_input, const Tensor& mean_input,
353 const Tensor& variance_input, const Tensor* y_input,
354 U epsilon, FusedBatchNormActivationMode activation_mode,
355 Tensor* x_backprop_output, Tensor* scale_backprop_output,
356 Tensor* offset_backprop_output,
357 Tensor* side_input_backprop_output, bool use_reserved_space,
358 TensorFormat tensor_format) {
359 OP_REQUIRES(context,
360 y_input == nullptr &&
361 activation_mode == FusedBatchNormActivationMode::kIdentity,
362 errors::Internal(
363 "The CPU implementation of FusedBatchNormGrad does not "
364 "support activations."));
365 OP_REQUIRES(context, side_input_backprop_output == nullptr,
366 errors::Internal("The CPU implementation of FusedBatchNormGrad "
367 "does not support side input."));
368
369 Tensor transformed_y_backprop_input;
370 Tensor transformed_x_input;
371 Tensor transformed_x_backprop_output;
372 if (tensor_format == FORMAT_NCHW) {
373 const int64_t in_batch = GetTensorDim(x_input, tensor_format, 'N');
374 const int64_t in_rows = GetTensorDim(x_input, tensor_format, 'H');
375 const int64_t in_cols = GetTensorDim(x_input, tensor_format, 'W');
376 const int64_t in_depths = GetTensorDim(x_input, tensor_format, 'C');
377 OP_REQUIRES_OK(context, context->allocate_temp(
378 DataTypeToEnum<T>::value,
379 ShapeFromFormat(FORMAT_NHWC, in_batch,
380 in_rows, in_cols, in_depths),
381 &transformed_y_backprop_input));
382 OP_REQUIRES_OK(context, context->allocate_temp(
383 DataTypeToEnum<T>::value,
384 ShapeFromFormat(FORMAT_NHWC, in_batch,
385 in_rows, in_cols, in_depths),
386 &transformed_x_input));
387 OP_REQUIRES_OK(context, context->allocate_temp(
388 DataTypeToEnum<T>::value,
389 ShapeFromFormat(FORMAT_NHWC, in_batch,
390 in_rows, in_cols, in_depths),
391 &transformed_x_backprop_output));
392 // Perform NCHW to NHWC
393 std::vector<int32> perm = {0, 2, 3, 1};
394 OP_REQUIRES_OK(
395 context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
396 y_backprop_input, perm,
397 &transformed_y_backprop_input));
398 OP_REQUIRES_OK(context, ::tensorflow::DoTranspose(
399 context->eigen_device<CPUDevice>(), x_input,
400 perm, &transformed_x_input));
401 } else {
402 transformed_y_backprop_input = y_backprop_input;
403 transformed_x_input = x_input;
404 transformed_x_backprop_output = *x_backprop_output;
405 }
406 typename TTypes<T, 4>::Tensor y_backprop(
407 transformed_y_backprop_input.tensor<T, 4>());
408 typename TTypes<T, 4>::Tensor x(transformed_x_input.tensor<T, 4>());
409 typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
410 typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
411 typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
412 typename TTypes<T, 4>::Tensor x_backprop(
413 transformed_x_backprop_output.tensor<T, 4>());
414 typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
415
416 // Note: the following formulas are used to compute the gradients for
417 // back propagation.
418 // x_backprop = scale * rsqrt(variance + epsilon) *
419 // [y_backprop - mean(y_backprop) - (x - mean(x)) *
420 // mean(y_backprop * (x - mean(x))) / (variance + epsilon)]
421 // scale_backprop = sum(y_backprop *
422 // (x - mean(x)) * rsqrt(variance + epsilon))
423 // offset_backprop = sum(y_backprop)
424
425 const CPUDevice& d = context->eigen_device<CPUDevice>();
426 const int depth = x.dimension(3);
427 const int size = x.size();
428 const int rest_size = size / depth;
429 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
430 Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
431 one_by_depth.set(1, depth);
432 Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
433 bcast_spec.set(0, rest_size);
434
435 auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
436 U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
437
438 // Eigen is notoriously bad at reducing outer dimension, so we materialize
439 // all temporary tensors that require reduction, and then use Eigen redux
440 // functor, that is optimized for this particular task.
441 //
442 // All reductions are of this type: [rest_size, depth] -> [depth].
443 using ScalarSum = Eigen::internal::scalar_sum_op<U>;
444 const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
445 const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
446
447 auto scratch_dtype = DataTypeToEnum<U>::value;
448
449 // Allocate a temporary workspace of [depth] shape.
450 Tensor scratch_one_by_depth;
451 OP_REQUIRES_OK(context, context->allocate_temp(scratch_dtype, {depth},
452 &scratch_one_by_depth));
453
454 // Maybe allocate a temporary workspace of [rest_size, depth] shape.
455 Tensor scratch_rest_by_depth;
456 if (std::is_same<T, U>::value) {
457 OP_REQUIRES(context,
458 scratch_rest_by_depth.CopyFrom(transformed_x_backprop_output,
459 {rest_size, depth}),
460 errors::Internal("Failed to copy a tensor"));
461 } else {
462 OP_REQUIRES_OK(context,
463 context->allocate_temp(scratch_dtype, {rest_size, depth},
464 &scratch_rest_by_depth));
465 }
466
467 typename TTypes<U, 2>::Tensor scratch_tensor(
468 scratch_rest_by_depth.tensor<U, 2>());
469 typename TTypes<U>::Vec scratch_vector(scratch_one_by_depth.vec<U>());
470
471 auto x_mean_rest_by_depth =
472 mean.reshape(one_by_depth).broadcast(bcast_spec);
473 auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth);
474 auto coef0_one_by_depth =
475 (variance.reshape(one_by_depth) + epsilon).rsqrt();
476 auto coef0_rest_by_depth = coef0_one_by_depth.broadcast(bcast_spec);
477 auto x_scaled = x_centered * coef0_rest_by_depth;
478
479 auto y_backprop_rest_by_depth =
480 y_backprop.reshape(rest_by_depth).template cast<U>();
481
482 // Compute `scale_backprop_output`:
483 // scale_backprop =
484 // (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims)
485 scratch_tensor.device(d) = y_backprop_rest_by_depth * x_scaled;
486 redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, scale_backprop_output);
487
488 // Compute 'offset_backprop_output':
489 // offset_backprop =
490 // y_backprop_rest_by_depth.sum(reduce_dims)
491 redux_sum_t(d, rest_by_depth, transformed_y_backprop_input,
492 offset_backprop_output);
493 auto y_backprop_sum = offset_backprop;
494
495 auto y_backprop_sum_one_by_depth = y_backprop_sum.reshape(one_by_depth);
496 auto y_backprop_mean_one_by_depth =
497 y_backprop_sum_one_by_depth * rest_size_inv;
498 auto y_backprop_mean_rest_by_depth =
499 y_backprop_mean_one_by_depth.broadcast(bcast_spec);
500 auto y_backprop_centered =
501 y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
502
503 // Compute expression:
504 // y_backprop_centered_mean =
505 // (y_backprop_rest_by_depth * x_centered).mean(reduce_dims)
506 scratch_tensor.device(d) = y_backprop_rest_by_depth * x_centered;
507 redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, &scratch_one_by_depth);
508 auto y_backprop_centered_mean =
509 scratch_vector.reshape(one_by_depth) / static_cast<U>(rest_size);
510
511 auto coef1 = (scale.reshape(one_by_depth) * coef0_one_by_depth)
512 .broadcast(bcast_spec);
513 auto coef2 = (coef0_one_by_depth.square() * y_backprop_centered_mean)
514 .broadcast(bcast_spec);
515
516 x_backprop.reshape(rest_by_depth).device(d) =
517 (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
518
519 if (tensor_format == FORMAT_NCHW) {
520 // Perform NHWC to NCHW
521 std::vector<int32> perm = {0, 3, 1, 2};
522 OP_REQUIRES_OK(
523 context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
524 transformed_x_backprop_output,
525 perm, x_backprop_output));
526 }
527 }
528};
529
530template <typename T, typename U>
531struct FusedBatchNormFreezeGrad<CPUDevice, T, U> {
532 void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
533 const Tensor& x_input, const Tensor& scale_input,
534 const Tensor& pop_mean_input,
535 const Tensor& pop_variance_input, U epsilon,
536 Tensor* x_backprop_output, Tensor* scale_backprop_output,
537 Tensor* offset_backprop_output) {
538 typename TTypes<T, 4>::ConstTensor y_backprop(
539 y_backprop_input.tensor<T, 4>());
540 typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
541 typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
542 typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
543 typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
544 typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
545 typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
546
547 const int depth = pop_mean.dimension(0);
548 const int rest_size = input.size() / depth;
549
550 const CPUDevice& d = context->eigen_device<CPUDevice>();
551
552 // Allocate two temporary workspaces of [depth] shape.
553 Tensor scratch1_vec, scratch2_vec;
554 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
555 {depth}, &scratch1_vec));
556 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
557 {depth}, &scratch2_vec));
558
559 // Maybe allocate a temporary workspace of [rest_size, depth] shape.
560 Tensor scratch3_tensor;
561 if (std::is_same<T, U>::value) {
562 OP_REQUIRES(
563 context,
564 scratch3_tensor.CopyFrom(*x_backprop_output, {rest_size, depth}),
565 errors::Internal("Failed to copy a tensor"));
566 } else {
567 OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
568 {rest_size, depth},
569 &scratch3_tensor));
570 }
571
572 typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
573 typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());
574 typename TTypes<U, 2>::Tensor scratch3(scratch3_tensor.tensor<U, 2>());
575
576 Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
577 Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
578 one_by_depth.set(1, depth);
579 Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> rest_by_one;
580 rest_by_one.set(0, rest_size);
581
582 // Sum reduction along the 0th dimension using custom CPU functor.
583 using ScalarSum = Eigen::internal::scalar_sum_op<U>;
584 const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
585 const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
586
587 // offset_backprop = sum(y_backprop)
588 // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
589 // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
590
591 // NOTE: DEFAULT DEVICE comment is added to expression assignments that
592 // we don't want to be executed in a thread pool.
593
594 auto y_backprop_rest_by_depth =
595 y_backprop.reshape(rest_by_depth).template cast<U>();
596 auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
597
598 // offset_backprop = sum(y_backprop)
599 redux_sum_t(d, rest_by_depth, y_backprop_input, offset_backprop_output);
600
601 // scratch1 = rsqrt(pop_var + epsilon)
602 scratch1 = (pop_var + pop_var.constant(epsilon)).rsqrt(); // DEFAULT DEVICE
603
604 // scratch2 = sum(y_backprop * (x - mean))
605 scratch3.device(d) =
606 y_backprop_rest_by_depth *
607 (input_rest_by_depth -
608 pop_mean.reshape(one_by_depth).broadcast(rest_by_one));
609 redux_sum_u(d, rest_by_depth, scratch3_tensor, &scratch2_vec);
610
611 x_backprop.reshape(rest_by_depth).device(d) =
612 (y_backprop_rest_by_depth *
613 ((scratch1.reshape(one_by_depth) * scale.reshape(one_by_depth))
614 .broadcast(rest_by_one)))
615 .template cast<T>();
616 scale_backprop = scratch2 * scratch1; // DEFAULT DEVICE
617 }
618};
619
620#if !GOOGLE_CUDA
621namespace {
622// See implementation under GOOGLE_CUDA #ifdef below.
623// This is a CUDA specific feature, do not enable it for non-CUDA builds
624bool BatchnormSpatialPersistentEnabled() { return false; }
625} // namespace
626#endif
627
628#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
629
630namespace {
631
632se::dnn::ActivationMode AsDnnActivationMode(
633 const FusedBatchNormActivationMode activation_mode) {
634 switch (activation_mode) {
635 case FusedBatchNormActivationMode::kIdentity:
636 return se::dnn::ActivationMode::kNone;
637 case FusedBatchNormActivationMode::kRelu:
638 return se::dnn::ActivationMode::kRelu;
639 }
640}
641
642#if GOOGLE_CUDA
643// NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
644// `cuda_dnn.cc` for details.
645bool BatchnormSpatialPersistentEnabled() {
646#if CUDNN_VERSION >= 7402
647 static bool is_enabled = [] {
648 bool is_enabled = false;
649 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
650 "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
651 /*default_val=*/false, &is_enabled));
652 return is_enabled;
653 }();
654 return is_enabled;
655#else
656 return false;
657#endif
658}
659#endif
660
661} // namespace
662
663template <typename U, typename T>
664DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
665 return DeviceMemory<U>::MakeFromByteSize(
666 tensor->template flat<T>().data(),
667 tensor->template flat<T>().size() * sizeof(T));
668}
669
670// A helper to allocate temporary scratch memory for Cudnn BatchNormEx ops. It
671// takes the ownership of the underlying memory. The expectation is that the
672// memory should be alive for the span of the Cudnn BatchNormEx itself.
673template <typename T>
674class CudnnBatchNormAllocatorInTemp : public ScratchAllocator {
675 public:
676 ~CudnnBatchNormAllocatorInTemp() override = default;
677
678 explicit CudnnBatchNormAllocatorInTemp(OpKernelContext* context)
679 : context_(context) {}
680
681 int64_t GetMemoryLimitInBytes() override {
682 return std::numeric_limits<int64_t>::max();
683 }
684
685 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
686 Tensor temporary_memory;
687 const DataType tf_data_type = DataTypeToEnum<T>::v();
688 int64_t allocate_count =
689 Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T)));
690 Status allocation_status(context_->allocate_temp(
691 tf_data_type, TensorShape({allocate_count}), &temporary_memory));
692 if (!allocation_status.ok()) {
693 return allocation_status;
694 }
695 // Hold the reference of the allocated tensors until the end of the
696 // allocator.
697 allocated_tensors_.push_back(temporary_memory);
698 total_byte_size_ += byte_size;
699 return DeviceMemory<uint8>::MakeFromByteSize(
700 temporary_memory.template flat<T>().data(),
701 temporary_memory.template flat<T>().size() * sizeof(T));
702 }
703
704 int64_t TotalByteSize() const { return total_byte_size_; }
705
706 Tensor get_allocated_tensor(int index) const {
707 return allocated_tensors_[index];
708 }
709
710 private:
711 int64_t total_byte_size_ = 0;
712 OpKernelContext* context_; // not owned
713 std::vector<Tensor> allocated_tensors_;
714};
715
716// A helper to allocate memory for Cudnn BatchNormEx as a kernel output. It is
717// used by forward pass kernel to feed the output to the backward pass.
718// The memory is expected to live long enough after the backward pass is
719// finished.
720template <typename T>
721class CudnnBatchNormAllocatorInOutput : public ScratchAllocator {
722 public:
723 ~CudnnBatchNormAllocatorInOutput() override {
724 if (!output_allocated) {
725 Tensor* dummy_reserve_space = nullptr;
726 OP_REQUIRES_OK(context_, context_->allocate_output(output_index_, {},
727 &dummy_reserve_space));
728 }
729 }
730
731 CudnnBatchNormAllocatorInOutput(OpKernelContext* context, int output_index)
732 : context_(context), output_index_(output_index) {}
733
734 int64_t GetMemoryLimitInBytes() override {
735 return std::numeric_limits<int64_t>::max();
736 }
737
738 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
739 output_allocated = true;
740 DCHECK(total_byte_size_ == 0)
741 << "Reserve space allocator can only be called once";
742 int64_t allocate_count =
743 Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T)));
744
745 Tensor* temporary_memory = nullptr;
746 Status allocation_status(context_->allocate_output(
747 output_index_, TensorShape({allocate_count}), &temporary_memory));
748 if (!allocation_status.ok()) {
749 return allocation_status;
750 }
751 total_byte_size_ += byte_size;
752 auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
753 temporary_memory->template flat<T>().data(),
754 temporary_memory->template flat<T>().size() * sizeof(T));
755 return StatusOr<DeviceMemory<uint8>>(memory_uint8);
756 }
757
758 int64_t TotalByteSize() { return total_byte_size_; }
759
760 private:
761 int64_t total_byte_size_ = 0;
762 OpKernelContext* context_; // not owned
763 int output_index_;
764 bool output_allocated = false;
765};
766
767template <typename T, typename U, bool is_training>
768struct FusedBatchNorm<GPUDevice, T, U, is_training> {
769 void operator()(OpKernelContext* context, const Tensor& x,
770 const Tensor& scale, const Tensor& offset,
771 const Tensor& estimated_mean,
772 const Tensor& estimated_variance, const Tensor* side_input,
773 U epsilon, U exponential_avg_factor,
774 FusedBatchNormActivationMode activation_mode, Tensor* y,
775 Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
776 Tensor* saved_inv_var, TensorFormat tensor_format,
777 bool use_reserved_space) {
778 auto* stream = context->op_device_context()->stream();
779 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
780
781 const int64_t batch_size = GetTensorDim(x, tensor_format, 'N');
782 const int64_t channels = GetTensorDim(x, tensor_format, 'C');
783 const int64_t height = GetTensorDim(x, tensor_format, 'H');
784 const int64_t width = GetTensorDim(x, tensor_format, 'W');
785
786 // If use_reserved_space we have reserve_space_3 output (only in
787 // FusedBatchNormV3 op).
788
789#if GOOGLE_CUDA
790 // Check if cuDNN batch normalization has a fast NHWC implementation:
791 // (1) In inference mode it's always fast.
792 // (2) Tensorflow enabled batchnorm spatial persistence, we are called
793 // from
794 // FusedBatchNormV3, i.e. use_reserved_space is true.
795 const bool fast_nhwc_batch_norm =
796 !is_training ||
797 (BatchnormSpatialPersistentEnabled() &&
798 DataTypeToEnum<T>::value == DT_HALF && use_reserved_space);
799#else
800 // fast NHWC implementation is a CUDA only feature
801 const bool fast_nhwc_batch_norm = false;
802#endif
803
804 // If input tensor is in NHWC format, and we have a fast cuDNN
805 // implementation, there is no need to do data format conversion.
806 TensorFormat compute_format =
807 fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC
808 : FORMAT_NCHW;
809
810 VLOG(2) << "FusedBatchNorm:"
811 << " batch_size: " << batch_size << " channels: " << channels
812 << " height: " << height << " width:" << width
813 << " x shape: " << x.shape().DebugString()
814 << " scale shape: " << scale.shape().DebugString()
815 << " offset shape: " << offset.shape().DebugString()
816 << " activation mode: " << ToString(activation_mode)
817 << " tensor format: " << ToString(tensor_format)
818 << " compute format: " << ToString(compute_format);
819
820 auto maybe_make_dummy_output = [context, use_reserved_space]() -> Status {
821 if (use_reserved_space) {
822 Tensor* dummy_reserve_space = nullptr;
823 return context->allocate_output(5, {}, &dummy_reserve_space);
824 }
825 return OkStatus();
826 };
827
828 // If input is empty, return NaN mean/variance
829 if (x.shape().num_elements() == 0) {
830 OP_REQUIRES_OK(context, maybe_make_dummy_output());
831 functor::SetNanFunctor<GPUDevice, U> f;
832 f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
833 f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
834 return;
835 }
836
837 // In inference mode we use custom CUDA kernel, because cuDNN does not
838 // support side input and activations for inference.
839 const bool has_side_input = side_input != nullptr;
840 const bool has_activation =
841 activation_mode != FusedBatchNormActivationMode::kIdentity;
842
843 if (!is_training && (has_side_input || has_activation)) {
844 OP_REQUIRES_OK(context, maybe_make_dummy_output());
845 FusedBatchNormInferenceFunctor<GPUDevice, T, U> inference_functor;
846
847 if (has_side_input) {
848 inference_functor(context, tensor_format, x.tensor<T, 4>(),
849 scale.vec<U>(), offset.vec<U>(),
850 estimated_mean.vec<U>(), estimated_variance.vec<U>(),
851 side_input->tensor<T, 4>(), epsilon, activation_mode,
852 y->tensor<T, 4>());
853 } else {
854 typename TTypes<T, 4>::ConstTensor empty_tensor(nullptr, 0, 0, 0, 0);
855 inference_functor(context, tensor_format, x.tensor<T, 4>(),
856 scale.vec<U>(), offset.vec<U>(),
857 estimated_mean.vec<U>(), estimated_variance.vec<U>(),
858 empty_tensor, epsilon, activation_mode,
859 y->tensor<T, 4>());
860 }
861 return;
862 }
863
864 Tensor x_maybe_transformed = x;
865 Tensor x_transformed;
866 Tensor y_transformed;
867 se::DeviceMemory<T> y_ptr;
868
869 if (tensor_format == compute_format) {
870 y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y);
871 } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
872 OP_REQUIRES_OK(context, context->allocate_temp(
873 DataTypeToEnum<T>::value,
874 ShapeFromFormat(compute_format, batch_size,
875 height, width, channels),
876 &x_transformed));
877 functor::NHWCToNCHW<GPUDevice, T, 4>()(
878 context->eigen_device<GPUDevice>(),
879 const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
880 x_transformed.tensor<T, 4>());
881 x_maybe_transformed = x_transformed;
882
883 OP_REQUIRES_OK(context, context->allocate_temp(
884 DataTypeToEnum<T>::value,
885 ShapeFromFormat(compute_format, batch_size,
886 height, width, channels),
887 &y_transformed));
888 y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(y_transformed);
889 } else {
890 context->SetStatus(errors::Internal(
891 "Unsupported tensor format: ", ToString(tensor_format),
892 " and compute format: ", ToString(compute_format)));
893 return;
894 }
895
896 const se::dnn::DataLayout data_layout =
897 compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth
898 : se::dnn::DataLayout::kBatchDepthYX;
899
900 se::dnn::BatchDescriptor x_desc;
901 x_desc.set_count(batch_size)
902 .set_feature_map_count(channels)
903 .set_height(height)
904 .set_width(width)
905 .set_layout(data_layout);
906
907 se::dnn::BatchDescriptor scale_offset_desc;
908 scale_offset_desc.set_count(1)
909 .set_feature_map_count(channels)
910 .set_height(1)
911 .set_width(1)
912 .set_layout(se::dnn::DataLayout::kBatchDepthYX);
913
914 auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
915 auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
916 auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
917 auto estimated_mean_ptr =
918 StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
919 auto estimated_variance_ptr =
920 StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
921 auto side_input_ptr =
922 side_input != nullptr
923 ? StreamExecutorUtil::AsDeviceMemory<T>(*side_input)
924 : se::DeviceMemory<T>();
925 auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);
926
927 auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
928 auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
929 auto saved_inv_var_ptr =
930 StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);
931
932 std::unique_ptr<functor::CudnnBatchNormAllocatorInOutput<U>>
933 reserve_space_allocator;
934 std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>>
935 workspace_allocator;
936 if (use_reserved_space) {
937 reserve_space_allocator.reset(
938 new functor::CudnnBatchNormAllocatorInOutput<U>(context, 5));
939 workspace_allocator.reset(
940 new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
941 }
942 if (!batch_mean->SharesBufferWith(estimated_mean) &&
943 exponential_avg_factor != 1.0f) {
944 OP_REQUIRES(
945 context,
946 stream
947 ->ThenMemcpyD2D(&batch_mean_ptr, estimated_mean_ptr,
948 estimated_mean.NumElements() * sizeof(U))
949 .ok(),
950 errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
951 "from device"));
952 }
953 if (!batch_var->SharesBufferWith(estimated_variance) &&
954 exponential_avg_factor != 1.0f) {
955 OP_REQUIRES(
956 context,
957 stream
958 ->ThenMemcpyD2D(&batch_var_ptr, estimated_variance_ptr,
959 estimated_variance.NumElements() * sizeof(U))
960 .ok(),
961 errors::Internal("MatrixTriangularSolveOp: failed to copy rhs "
962 "from device"));
963 }
964 bool cudnn_launch_status =
965 stream
966 ->ThenBatchNormalizationForward(
967 x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
968 estimated_variance_ptr, side_input_ptr, x_desc,
969 scale_offset_desc, static_cast<double>(epsilon),
970 static_cast<double>(exponential_avg_factor),
971 AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr,
972 &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
973 is_training, reserve_space_allocator.get(),
974 workspace_allocator.get())
975 .ok();
976
977 if (!cudnn_launch_status) {
978 context->SetStatus(
979 errors::Internal("cuDNN launch failure : input shape (",
980 x.shape().DebugString(), ")"));
981 return;
982 }
983
984 if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
985 functor::NCHWToNHWC<GPUDevice, T, 4>()(
986 context->eigen_device<GPUDevice>(),
987 const_cast<const Tensor&>(y_transformed).tensor<T, 4>(),
988 y->tensor<T, 4>());
989 }
990 }
991};
992
993template <typename T, typename U>
994struct FusedBatchNormGrad<GPUDevice, T, U> {
995 void operator()(OpKernelContext* context, const Tensor& y_backprop,
996 const Tensor& x, const Tensor& scale, const Tensor* offset,
997 const Tensor& mean, const Tensor& inv_variance,
998 const Tensor* y, U epsilon,
999 FusedBatchNormActivationMode activation_mode,
1000 Tensor* x_backprop, Tensor* scale_backprop,
1001 Tensor* offset_backprop, Tensor* side_input_backprop,
1002 bool use_reserved_space, TensorFormat tensor_format) {
1003 auto* stream = context->op_device_context()->stream();
1004 OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
1005
1006 const int64_t batch_size = GetTensorDim(x, tensor_format, 'N');
1007 const int64_t channels = GetTensorDim(x, tensor_format, 'C');
1008 const int64_t height = GetTensorDim(x, tensor_format, 'H');
1009 const int64_t width = GetTensorDim(x, tensor_format, 'W');
1010
1011#if GOOGLE_CUDA
1012 // Check if cuDNN batch normalization has a fast NHWC implementation:
1013 // (1) Tensorflow enabled batchnorm spatial persistence, and
1014 // FusedBatchNormGradV3 passed non-null reserve space and allocator.
1015 const bool fast_nhwc_batch_norm = BatchnormSpatialPersistentEnabled() &&
1016 DataTypeToEnum<T>::value == DT_HALF &&
1017 use_reserved_space;
1018#else
1019 // fast NHWC implementation is a CUDA only feature
1020 const bool fast_nhwc_batch_norm = false;
1021#endif
1022
1023 // If input tensor is in NHWC format, and we have a fast cuDNN
1024 // implementation, there is no need to do data format conversion.
1025 TensorFormat compute_format =
1026 fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC
1027 : FORMAT_NCHW;
1028
1029 VLOG(2) << "FusedBatchNormGrad:"
1030 << " batch_size: " << batch_size << " channels: " << channels
1031 << " height: " << height << " width: " << width
1032 << " y_backprop shape: " << y_backprop.shape().DebugString()
1033 << " x shape: " << x.shape().DebugString()
1034 << " scale shape: " << scale.shape().DebugString()
1035 << " activation mode: " << ToString(activation_mode)
1036 << " tensor format: " << ToString(tensor_format)
1037 << " compute format: " << ToString(compute_format);
1038
1039 // Inputs
1040 Tensor y_backprop_maybe_transformed = y_backprop;
1041 Tensor x_maybe_transformed = x;
1042 Tensor y_backprop_transformed;
1043 Tensor x_transformed;
1044
1045 // Outputs
1046 Tensor x_backprop_transformed;
1047 se::DeviceMemory<T> x_backprop_ptr;
1048
1049 if (tensor_format == compute_format) {
1050 x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
1051 } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
1052 // Transform inputs from 'NHWC' to 'NCHW'
1053 OP_REQUIRES_OK(context, context->allocate_temp(
1054 DataTypeToEnum<T>::value,
1055 ShapeFromFormat(FORMAT_NCHW, batch_size,
1056 height, width, channels),
1057 &y_backprop_transformed));
1058 functor::NHWCToNCHW<GPUDevice, T, 4>()(
1059 context->eigen_device<GPUDevice>(),
1060 const_cast<const Tensor&>(y_backprop_maybe_transformed)
1061 .tensor<T, 4>(),
1062 y_backprop_transformed.tensor<T, 4>());
1063 y_backprop_maybe_transformed = y_backprop_transformed;
1064
1065 OP_REQUIRES_OK(context, context->allocate_temp(
1066 DataTypeToEnum<T>::value,
1067 ShapeFromFormat(FORMAT_NCHW, batch_size,
1068 height, width, channels),
1069 &x_transformed));
1070 functor::NHWCToNCHW<GPUDevice, T, 4>()(
1071 context->eigen_device<GPUDevice>(),
1072 const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
1073 x_transformed.tensor<T, 4>());
1074 x_maybe_transformed = x_transformed;
1075
1076 // Allocate memory for transformed outputs in 'NCHW'
1077 OP_REQUIRES_OK(context, context->allocate_temp(
1078 DataTypeToEnum<T>::value,
1079 ShapeFromFormat(FORMAT_NCHW, batch_size,
1080 height, width, channels),
1081 &x_backprop_transformed));
1082 x_backprop_ptr =
1083 StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed);
1084 } else {
1085 context->SetStatus(errors::Internal(
1086 "Unsupported tensor format: ", ToString(tensor_format),
1087 " and compute format: ", ToString(compute_format)));
1088 return;
1089 }
1090
1091 const se::dnn::DataLayout data_layout =
1092 compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth
1093 : se::dnn::DataLayout::kBatchDepthYX;
1094
1095 se::dnn::BatchDescriptor x_desc;
1096 x_desc.set_count(batch_size)
1097 .set_feature_map_count(channels)
1098 .set_height(height)
1099 .set_width(width)
1100 .set_layout(data_layout);
1101
1102 se::dnn::BatchDescriptor scale_offset_desc;
1103 scale_offset_desc.set_count(1)
1104 .set_feature_map_count(channels)
1105 .set_height(1)
1106 .set_width(1)
1107 .set_layout(se::dnn::DataLayout::kBatchDepthYX);
1108
1109 auto y_backprop_ptr =
1110 StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
1111 auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
1112 auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
1113 auto offset_ptr = offset != nullptr
1114 ? StreamExecutorUtil::AsDeviceMemory<U>(*offset)
1115 : se::DeviceMemory<U>();
1116 auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
1117 auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
1118 auto y_ptr = y != nullptr ? StreamExecutorUtil::AsDeviceMemory<T>(*y)
1119 : se::DeviceMemory<T>();
1120 auto scale_backprop_ptr =
1121 StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
1122 auto offset_backprop_ptr =
1123 StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);
1124 auto side_input_backprop_ptr =
1125 side_input_backprop != nullptr
1126 ? StreamExecutorUtil::AsDeviceMemory<T>(*side_input_backprop)
1127 : se::DeviceMemory<T>();
1128
1129 std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>>
1130 workspace_allocator;
1131 DeviceMemory<uint8>* reserve_space_data_ptr = nullptr;
1132 DeviceMemory<uint8> reserve_space_data;
1133#if CUDNN_VERSION >= 7402
1134 if (use_reserved_space) {
1135 const Tensor& reserve_space = context->input(5);
1136 workspace_allocator.reset(
1137 new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
1138
1139 // the cudnn kernel outputs inverse variance in forward and reuse it in
1140 // backward
1141 if (reserve_space.dims() != 0) {
1142 reserve_space_data = functor::CastDeviceMemory<uint8, U>(
1143 const_cast<Tensor*>(&reserve_space));
1144 reserve_space_data_ptr = &reserve_space_data;
1145 }
1146 }
1147#endif // CUDNN_VERSION >= 7402
1148
1149 bool cudnn_launch_status =
1150 stream
1151 ->ThenBatchNormalizationBackward(
1152 y_backprop_ptr, x_ptr, scale_ptr, offset_ptr, mean_ptr,
1153 inv_variance_ptr, y_ptr, x_desc, scale_offset_desc,
1154 static_cast<double>(epsilon),
1155 AsDnnActivationMode(activation_mode), &x_backprop_ptr,
1156 &scale_backprop_ptr, &offset_backprop_ptr,
1157 &side_input_backprop_ptr, reserve_space_data_ptr,
1158 workspace_allocator.get())
1159 .ok();
1160
1161 if (!cudnn_launch_status) {
1162 context->SetStatus(
1163 errors::Internal("cuDNN launch failure : input shape (",
1164 x.shape().DebugString(), ")"));
1165 }
1166 if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
1167 functor::NCHWToNHWC<GPUDevice, T, 4>()(
1168 context->eigen_device<GPUDevice>(),
1169 const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(),
1170 x_backprop->tensor<T, 4>());
1171 }
1172 }
1173};
1174
1175// Forward declarations of the functor specializations for GPU.
1176#define DECLARE_GPU_SPEC(T, U) \
1177 template <> \
1178 void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()( \
1179 OpKernelContext* context, const Tensor& y_backprop_input, \
1180 const Tensor& x_input, const Tensor& scale_input, \
1181 const Tensor& mean_input, const Tensor& variance_input, U epsilon, \
1182 Tensor* x_backprop_output, Tensor* scale_backprop_output, \
1183 Tensor* offset_backprop_output); \
1184 extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>; \
1185 template <> \
1186 void FusedBatchNormInferenceFunctor<GPUDevice, T, U>::operator()( \
1187 OpKernelContext* context, TensorFormat tensor_format, \
1188 typename TTypes<T, 4>::ConstTensor in, \
1189 typename TTypes<U>::ConstVec scale, typename TTypes<U>::ConstVec offset, \
1190 typename TTypes<U>::ConstVec estimated_mean, \
1191 typename TTypes<U>::ConstVec estimated_variance, \
1192 typename TTypes<T, 4>::ConstTensor side_input, U epsilon, \
1193 FusedBatchNormActivationMode activation_mode, \
1194 typename TTypes<T, 4>::Tensor out); \
1195 extern template struct FusedBatchNormInferenceFunctor<GPUDevice, T, U>;
1196
1197DECLARE_GPU_SPEC(float, float);
1198DECLARE_GPU_SPEC(Eigen::half, float);
1199
1200#undef DECLARE_GPU_SPEC
1201
1202#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1203} // namespace functor
1204
1205template <typename Device, typename T, typename U>
1206class FusedBatchNormOpBase : public OpKernel {
1207 using FbnActivationMode = functor::FusedBatchNormActivationMode;
1208
1209 protected:
1210 explicit FusedBatchNormOpBase(OpKernelConstruction* context,
1211 bool is_batch_norm_ex = false)
1212 : OpKernel(context) {
1213 float epsilon;
1214 OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
1215 epsilon_ = U(epsilon);
1216 float exponential_avg_factor;
1217 OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
1218 &exponential_avg_factor));
1219 exponential_avg_factor_ = U(exponential_avg_factor);
1220 string tensor_format;
1221 OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
1222 OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
1223 errors::InvalidArgument("Invalid data format"));
1224 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1225
1226 if (!is_batch_norm_ex) {
1227 has_side_input_ = false;
1228 activation_mode_ = FbnActivationMode::kIdentity;
1229 } else {
1230 OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));
1231
1232 int num_side_inputs;
1233 OP_REQUIRES_OK(context,
1234 context->GetAttr("num_side_inputs", &num_side_inputs));
1235 OP_REQUIRES(context, num_side_inputs >= 0 && num_side_inputs <= 1,
1236 errors::InvalidArgument(
1237 "FusedBatchNorm accepts at most one side input."));
1238 has_side_input_ = (num_side_inputs == 1);
1239 if (has_side_input_ && is_training_) {
1240 OP_REQUIRES(
1241 context, activation_mode_ != FbnActivationMode::kIdentity,
1242 errors::InvalidArgument("Identity activation is not supported with "
1243 "non-empty side input"));
1244 }
1245 }
1246
1247 if (activation_mode_ != FbnActivationMode::kIdentity && is_training_) {
1248 // NOTE(ezhulenev): Following requirements are coming from implementation
1249 // details of cudnnBatchNormalizationForwardTrainingEx used in training
1250 // mode. In inference mode we call custom CUDA kernel that supports all
1251 // data formats and data types.
1252 OP_REQUIRES(context, DataTypeToEnum<T>::value == DT_HALF,
1253 errors::InvalidArgument("FusedBatchNorm with activation "
1254 "supports only DT_HALF data type."));
1255 OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
1256 errors::InvalidArgument("FusedBatchNorm with activation "
1257 "supports only NHWC tensor format."));
1258 OP_REQUIRES(context, functor::BatchnormSpatialPersistentEnabled(),
1259 errors::InvalidArgument(
1260 "FusedBatchNorm with activation must run with cuDNN "
1261 "spatial persistence mode enabled."));
1262 }
1263 }
1264
1265 // If use_reserved_space is true, we need to handle the 5th output (a reserved
1266 // space) and a new cudnn batch norm will be called if the version > 7.4.2.
1267 // If use_reserved_space is false, we don't have 5th output.
1268 virtual void ComputeWithReservedSpace(OpKernelContext* context,
1269 bool use_reserved_space) {
1270 Tensor x = context->input(0);
1271 const Tensor& scale = context->input(1);
1272 const Tensor& offset = context->input(2);
1273 const Tensor& estimated_mean = context->input(3);
1274 const Tensor& estimated_variance = context->input(4);
1275 const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr;
1276
1277 OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
1278 errors::InvalidArgument("input must be 4 or 5-dimensional",
1279 x.shape().DebugString()));
1280 OP_REQUIRES(context, scale.dims() == 1,
1281 errors::InvalidArgument("scale must be 1-dimensional",
1282 scale.shape().DebugString()));
1283 OP_REQUIRES(context, offset.dims() == 1,
1284 errors::InvalidArgument("offset must be 1-dimensional",
1285 offset.shape().DebugString()));
1286 OP_REQUIRES(context, estimated_mean.dims() == 1,
1287 errors::InvalidArgument("estimated_mean must be 1-dimensional",
1288 estimated_mean.shape().DebugString()));
1289 OP_REQUIRES(
1290 context, estimated_variance.dims() == 1,
1291 errors::InvalidArgument("estimated_variance must be 1-dimensional",
1292 estimated_variance.shape().DebugString()));
1293 bool use_reshape = (x.dims() == 5);
1294 auto x_shape = x.shape();
1295 TensorShape dest_shape;
1296 if (use_reshape) {
1297 const int64_t in_batch = GetTensorDim(x, tensor_format_, 'N');
1298 int64_t in_planes = GetTensorDim(x, tensor_format_, '0');
1299 int64_t in_rows = GetTensorDim(x, tensor_format_, '1');
1300 int64_t in_cols = GetTensorDim(x, tensor_format_, '2');
1301 const int64_t in_depth = GetTensorDim(x, tensor_format_, 'C');
1302 dest_shape = ShapeFromFormat(tensor_format_, in_batch,
1303 {{in_planes, in_rows * in_cols}}, in_depth);
1304 OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
1305 errors::InvalidArgument("Error during tensor copy."));
1306 }
1307
1308 const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
1309 OP_REQUIRES(
1310 context, scale.NumElements() == num_channels,
1311 errors::InvalidArgument("scale must have the same number of elements "
1312 "as the channels of x, got ",
1313 scale.NumElements(), " and ", num_channels));
1314 OP_REQUIRES(
1315 context, offset.NumElements() == num_channels,
1316 errors::InvalidArgument("offset must have the same number of elements "
1317 "as the channels of x, got ",
1318 offset.NumElements(), " and ", num_channels));
1319 if (!is_training_ || exponential_avg_factor_ != 1.) {
1320 std::string prefix_msg = is_training_ ? "When exponential_avg_factor != 1"
1321 : "When is_training=false";
1322 OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
1323 errors::InvalidArgument(
1324 prefix_msg,
1325 ", mean must have the same number "
1326 "of elements as the channels of x, got ",
1327 estimated_mean.NumElements(), " and ", num_channels));
1328 OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
1329 errors::InvalidArgument(
1330 prefix_msg,
1331 ", variance must have the same "
1332 "number of elements as the channels of x, got ",
1333 estimated_variance.NumElements(), " and ", num_channels));
1334 }
1335
1336 if (has_side_input_) {
1337 OP_REQUIRES(context, side_input->shape() == x.shape(),
1338 errors::InvalidArgument(
1339 "side_input shape must be equal to input shape: ",
1340 side_input->shape().DebugString(),
1341 " != ", x.shape().DebugString()));
1342 }
1343
1344 if (activation_mode_ != FbnActivationMode::kIdentity) {
1345 // NOTE(ezhulenev): This requirement is coming from implementation
1346 // details of cudnnBatchNormalizationForwardTrainingEx.
1347 OP_REQUIRES(
1348 context, !is_training_ || num_channels % 4 == 0,
1349 errors::InvalidArgument("FusedBatchNorm with activation requires "
1350 "channel dimension to be a multiple of 4."));
1351 }
1352
1353 Tensor* y = nullptr;
1354 auto alloc_shape = use_reshape ? dest_shape : x_shape;
1355 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1356 {0}, 0, alloc_shape, &y));
1357
1358 Tensor* batch_mean = nullptr;
1359 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1360 {3}, 1, scale.shape(), &batch_mean));
1361 Tensor* batch_var = nullptr;
1362 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
1363 {4}, 2, scale.shape(), &batch_var));
1364 Tensor* saved_mean = nullptr;
1365 OP_REQUIRES_OK(context,
1366 context->allocate_output(3, scale.shape(), &saved_mean));
1367 Tensor* saved_maybe_inv_var = nullptr;
1368 OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
1369 &saved_maybe_inv_var));
1370
1371 if (is_training_) {
1372 functor::FusedBatchNorm<Device, T, U, true>()(
1373 context, x, scale, offset, estimated_mean, estimated_variance,
1374 side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
1375 batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
1376 tensor_format_, use_reserved_space);
1377 } else {
1378 functor::FusedBatchNorm<Device, T, U, false>()(
1379 context, x, scale, offset, estimated_mean, estimated_variance,
1380 side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
1381 batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
1382 tensor_format_, use_reserved_space);
1383 }
1384 if (use_reshape) {
1385 OP_REQUIRES(context, y->CopyFrom(*y, x_shape),
1386 errors::InvalidArgument("Error during tensor copy."));
1387 }
1388 }
1389
1390 private:
1391 U epsilon_;
1392 U exponential_avg_factor_;
1393 TensorFormat tensor_format_;
1394 bool is_training_;
1395 bool has_side_input_;
1396 FbnActivationMode activation_mode_;
1397};
1398
template <typename Device, typename T, typename U>
class FusedBatchNormOp : public FusedBatchNormOpBase<Device, T, U> {
 public:
  explicit FusedBatchNormOp(OpKernelConstruction* context)
      : FusedBatchNormOpBase<Device, T, U>(context) {}

  void Compute(OpKernelContext* context) override {
    FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context,
                                                                 false);
  }
};

template <typename Device, typename T, typename U>
class FusedBatchNormOpV3 : public FusedBatchNormOpBase<Device, T, U> {
 public:
  explicit FusedBatchNormOpV3(OpKernelConstruction* context)
      : FusedBatchNormOpBase<Device, T, U>(context) {}

  void Compute(OpKernelContext* context) override {
    FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true);
  }
};

template <typename Device, typename T, typename U>
class FusedBatchNormOpEx : public FusedBatchNormOpBase<Device, T, U> {
  static constexpr bool kWithSideInputAndActivation = true;

 public:
  explicit FusedBatchNormOpEx(OpKernelConstruction* context)
      : FusedBatchNormOpBase<Device, T, U>(context,
                                           kWithSideInputAndActivation) {}

  void Compute(OpKernelContext* context) override {
    FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true);
  }
};

template <typename Device, typename T, typename U>
class FusedBatchNormGradOpBase : public OpKernel {
  using FbnActivationMode = functor::FusedBatchNormActivationMode;

 protected:
  explicit FusedBatchNormGradOpBase(OpKernelConstruction* context,
                                    bool is_batch_norm_grad_ex = false)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
    if (!is_batch_norm_grad_ex) {
      has_side_input_ = false;
      activation_mode_ = FbnActivationMode::kIdentity;
    } else {
      OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_));

      int num_side_inputs;
      OP_REQUIRES_OK(context,
                     context->GetAttr("num_side_inputs", &num_side_inputs));
      OP_REQUIRES(context, num_side_inputs >= 0 && num_side_inputs <= 1,
                  errors::InvalidArgument(
                      "FusedBatchNormGrad accepts at most one side input."));
      has_side_input_ = (num_side_inputs == 1);
      if (has_side_input_ && is_training_) {
        OP_REQUIRES(
            context, activation_mode_ != FbnActivationMode::kIdentity,
            errors::InvalidArgument("Identity activation is not supported with "
                                    "non-empty side input"));
      }
    }

    if (activation_mode_ != FbnActivationMode::kIdentity && is_training_) {
      // NOTE(kaixih@nvidia): The following requirements come from the
      // implementation details of cudnnBatchNormalizationBackwardEx used in
      // training mode.
      OP_REQUIRES(context, DataTypeToEnum<T>::value == DT_HALF,
                  errors::InvalidArgument("FusedBatchNormGrad with activation "
                                          "supports only DT_HALF data type."));
      OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
                  errors::InvalidArgument("FusedBatchNormGrad with activation "
                                          "supports only NHWC tensor format."));
      OP_REQUIRES(context, functor::BatchnormSpatialPersistentEnabled(),
                  errors::InvalidArgument(
                      "FusedBatchNormGrad with activation must run with cuDNN "
                      "spatial persistence mode enabled."));
    }
  }

  virtual void ComputeWithReservedSpace(OpKernelContext* context,
                                        bool use_reserved_space) {
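    // Gradient op inputs: 0=y_backprop, 1=x, 2=scale, 3=saved (or population)
    // mean, 4=saved (possibly inverted) variance or population variance. For
    // the Ex variant with activation, inputs 6 (offset) and 7 (y) are also
    // read below; the V3/Ex reserved-space input, when present, is consumed
    // inside the functor rather than here.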
    Tensor y_backprop = context->input(0);
    Tensor x = context->input(1);
    const Tensor& scale = context->input(2);
    // When is_training=True, batch mean and variance/inverted variance are
    // saved in the forward pass to be reused here. When is_training=False,
    // population mean and variance need to be forwarded here to compute the
    // gradients.
    const Tensor& saved_mean_or_pop_mean = context->input(3);
    // The Eigen implementation saves variance in the forward pass, while cuDNN
    // saves inverted variance.
    const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);
    bool use_activation = activation_mode_ != FbnActivationMode::kIdentity;
    const Tensor* offset = use_activation ? &context->input(6) : nullptr;
    const Tensor* y = use_activation ? &context->input(7) : nullptr;

    OP_REQUIRES(context, y_backprop.dims() == 4 || y_backprop.dims() == 5,
                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                        y_backprop.shape().DebugString()));
    OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5,
                errors::InvalidArgument("input must be 4 or 5-dimensional",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(
        context, saved_mean_or_pop_mean.dims() == 1,
        errors::InvalidArgument("saved mean must be 1-dimensional",
                                saved_mean_or_pop_mean.shape().DebugString()));
    OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
                errors::InvalidArgument(
                    "saved variance must be 1-dimensional",
                    saved_maybe_inv_var_or_pop_var.shape().DebugString()));
    OP_REQUIRES(
        context, x.shape() == y_backprop.shape(),
        errors::InvalidArgument(
            "x and y_backprop must have same shape, but x has shape ",
            x.shape(), " and y_backprop has shape ", y_backprop.shape()));
    if (use_activation) {
      OP_REQUIRES(
          context, x.dim_size(3) % 4 == 0,
          errors::InvalidArgument("FusedBatchNormGrad with activation requires "
                                  "channel dimension to be a multiple of 4."));
      OP_REQUIRES(context, offset->dims() == 1,
                  errors::InvalidArgument("offset must be 1-dimensional",
                                          offset->shape().DebugString()));
    }
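    // 5D inputs are collapsed to an equivalent 4D shape by merging two of the
    // spatial dimensions so the 4D kernels can be reused as-is; e.g. a
    // channels-last input of shape [2, 4, 8, 8, 16] becomes [2, 4, 64, 16].
    // The gradient w.r.t. x is reshaped back to the original 5D shape at the
    // end of this function.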
    bool use_reshape = (x.dims() == 5);
    auto x_shape = x.shape();
    TensorShape dest_shape;
    if (use_reshape) {
      const int64_t in_batch = GetTensorDim(x, tensor_format_, 'N');
      int64_t in_planes = GetTensorDim(x, tensor_format_, '0');
      int64_t in_rows = GetTensorDim(x, tensor_format_, '1');
      int64_t in_cols = GetTensorDim(x, tensor_format_, '2');
      const int64_t in_depth = GetTensorDim(x, tensor_format_, 'C');
      dest_shape = ShapeFromFormat(tensor_format_, in_batch,
                                   {{in_planes, in_rows * in_cols}}, in_depth);
      OP_REQUIRES(context, x.CopyFrom(x, dest_shape),
                  errors::InvalidArgument("Error during tensor copy."));
      OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape),
                  errors::InvalidArgument("Error during tensor copy."));
    }

    const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
    OP_REQUIRES(
        context, scale.NumElements() == num_channels,
        errors::InvalidArgument("scale must have the same number of elements "
                                "as the channels of x, got ",
                                scale.NumElements(), " and ", num_channels));
    OP_REQUIRES(
        context, saved_mean_or_pop_mean.NumElements() == num_channels,
        errors::InvalidArgument("reserve_space_1 must have the same number of "
                                "elements as the channels of x, got ",
                                saved_mean_or_pop_mean.NumElements(), " and ",
                                num_channels));
    OP_REQUIRES(
        context, saved_maybe_inv_var_or_pop_var.NumElements() == num_channels,
        errors::InvalidArgument("reserve_space_2 must have the same number of "
                                "elements as the channels of x, got ",
                                saved_maybe_inv_var_or_pop_var.NumElements(),
                                " and ", num_channels));

    Tensor* x_backprop = nullptr;
    auto alloc_shape = use_reshape ? dest_shape : x_shape;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, alloc_shape, &x_backprop));

    const TensorShape& scale_offset_shape = scale.shape();
    Tensor* scale_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, scale_offset_shape,
                                                     &scale_backprop));
    Tensor* offset_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
                                                     &offset_backprop));
    // Two placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    // They are allocated as empty (zero-element) tensors.
    Tensor* placeholder_1 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(3, TensorShape({0}), &placeholder_1));
    Tensor* placeholder_2 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(4, TensorShape({0}), &placeholder_2));

    Tensor* side_input_backprop = nullptr;
    if (has_side_input_) {
      OP_REQUIRES_OK(context, context->allocate_output(5, alloc_shape,
                                                       &side_input_backprop));
    }

    // If input is empty, set gradients w.r.t. scale/offset to zero.
    if (x.shape().num_elements() == 0) {
      functor::SetZeroFunctor<Device, U> f;
      f(context->eigen_device<Device>(), scale_backprop->flat<U>());
      f(context->eigen_device<Device>(), offset_backprop->flat<U>());
      return;
    }

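    // For reference, with x_hat = (x - mean) * rsqrt(variance + epsilon), the
    // training-mode gradients computed by the functor are (up to layout and
    // reduction details):
    //   offset_backprop = sum(y_backprop)
    //   scale_backprop  = sum(y_backprop * x_hat)
    //   x_backprop      = scale * rsqrt(variance + epsilon) *
    //                     (y_backprop - mean(y_backprop) -
    //                      x_hat * mean(y_backprop * x_hat))
    // where sum/mean reduce over all dimensions except the channel dimension.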
    if (is_training_) {
      functor::FusedBatchNormGrad<Device, T, U>()(
          context, y_backprop, x, scale, offset, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, y, epsilon_, activation_mode_,
          x_backprop, scale_backprop, offset_backprop, side_input_backprop,
          use_reserved_space, tensor_format_);
    } else {
      OP_REQUIRES(
          context,
          activation_mode_ == FbnActivationMode::kIdentity && !has_side_input_,
          errors::InvalidArgument(
              "FusedBatchNormGrad with activation is only supported "
              "when is_training=True."));
      // Necessary layout conversion is currently done in Python.
      OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC,
                  errors::InvalidArgument(
                      "The implementation of "
                      "FusedBatchNormGrad with is_training=False only supports "
                      "the NHWC tensor format for now."));
      functor::FusedBatchNormFreezeGrad<Device, T, U>()(
          context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop);
    }
    if (use_reshape) {
      OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape),
                  errors::InvalidArgument("Error during tensor copy."));
    }
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
  bool has_side_input_;
  FbnActivationMode activation_mode_;
};

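// As with the forward kernels, the gradient variants below differ only in
// whether a reserved space is consumed: FusedBatchNormGrad/V2 do not, while
// FusedBatchNormGradV3 and _FusedBatchNormGradEx (side input + activation) do.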
template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public FusedBatchNormGradOpBase<Device, T, U> {
 public:
  explicit FusedBatchNormGradOp(OpKernelConstruction* context)
      : FusedBatchNormGradOpBase<Device, T, U>(context) {}

  void Compute(OpKernelContext* context) override {
    FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
                                                                     false);
  }
};

template <typename Device, typename T, typename U>
class FusedBatchNormGradOpV3 : public FusedBatchNormGradOpBase<Device, T, U> {
 public:
  explicit FusedBatchNormGradOpV3(OpKernelConstruction* context)
      : FusedBatchNormGradOpBase<Device, T, U>(context) {}

  void Compute(OpKernelContext* context) override {
    FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
                                                                     true);
  }
};

template <typename Device, typename T, typename U>
class FusedBatchNormGradOpEx : public FusedBatchNormGradOpBase<Device, T, U> {
  static constexpr bool kWithSideInputAndActivation = true;

 public:
  explicit FusedBatchNormGradOpEx(OpKernelConstruction* context)
      : FusedBatchNormGradOpBase<Device, T, U>(context,
                                               kWithSideInputAndActivation) {}

  void Compute(OpKernelContext* context) override {
    FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context,
                                                                     true);
  }
};

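// Kernel registrations. CPU kernels are registered for T=float and
// T=Eigen::half (always with U=float for the statistics); the GPU kernels,
// including the fused _FusedBatchNormEx / _FusedBatchNormGradEx variants, are
// only compiled when CUDA or ROCm support is enabled.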
REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpV3<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpV3<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpV3<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpV3<CPUDevice, Eigen::half, float>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpV3<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpEx<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpV3<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormGradEx")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpEx<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpV3<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOpEx<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpV3<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormGradEx")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOpEx<GPUDevice, Eigen::half, float>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow