1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <atomic> |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
21 | #define EIGEN_USE_GPU |
22 | #if GOOGLE_CUDA |
23 | #include "third_party/gpus/cudnn/cudnn.h" |
24 | #endif // GOOGLE_CUDA |
25 | |
26 | #include "tensorflow/core/kernels/conv_2d.h" |
27 | #include "tensorflow/core/platform/stream_executor.h" |
28 | #include "tensorflow/core/util/stream_executor_util.h" |
29 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
30 | |
31 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
32 | #include "tensorflow/core/framework/op_kernel.h" |
33 | #include "tensorflow/core/framework/register_types.h" |
34 | #include "tensorflow/core/framework/tensor.h" |
35 | #include "tensorflow/core/framework/tensor_types.h" |
36 | #include "tensorflow/core/kernels/fill_functor.h" |
37 | #include "tensorflow/core/kernels/fused_batch_norm_op.h" |
38 | #include "tensorflow/core/kernels/redux_functor.h" |
39 | #include "tensorflow/core/kernels/transpose_functor.h" |
40 | #include "tensorflow/core/platform/blocking_counter.h" |
41 | #include "tensorflow/core/util/env_var.h" |
42 | #include "tensorflow/core/util/tensor_format.h" |
43 | |
44 | namespace tensorflow { |
45 | using CPUDevice = Eigen::ThreadPoolDevice; |
46 | using GPUDevice = Eigen::GpuDevice; |
47 | |
48 | namespace functor { |
49 | |
50 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
51 | using se::DeviceMemory; |
52 | using se::ScratchAllocator; |
53 | using se::Stream; |
54 | using se::port::StatusOr; |
55 | #endif |
56 | |
57 | string ToString(FusedBatchNormActivationMode activation_mode) { |
58 | switch (activation_mode) { |
59 | case FusedBatchNormActivationMode::kIdentity: |
60 | return "Identity" ; |
61 | case FusedBatchNormActivationMode::kRelu: |
62 | return "Relu" ; |
63 | } |
64 | } |
65 | |
66 | Status ParseActivationMode(OpKernelConstruction* context, |
67 | FusedBatchNormActivationMode* activation_mode) { |
68 | string activation_mode_str; |
69 | TF_RETURN_IF_ERROR(context->GetAttr("activation_mode" , &activation_mode_str)); |
70 | |
71 | if (activation_mode_str == "Identity" ) { |
72 | *activation_mode = FusedBatchNormActivationMode::kIdentity; |
73 | return OkStatus(); |
74 | } |
75 | if (activation_mode_str == "Relu" ) { |
76 | *activation_mode = FusedBatchNormActivationMode::kRelu; |
77 | return OkStatus(); |
78 | } |
79 | return errors::InvalidArgument("Unsupported activation mode: " , |
80 | activation_mode_str); |
81 | } |
82 | |
83 | // Functor used by FusedBatchNormOp to do the computations. |
84 | template <typename Device, typename T, typename U, bool is_training> |
85 | struct FusedBatchNorm; |
86 | // Functor used by FusedBatchNormGradOp to do the computations when |
87 | // is_training=True. |
88 | template <typename Device, typename T, typename U> |
89 | struct FusedBatchNormGrad; |
90 | |
91 | template <typename T, typename U> |
92 | struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ true> { |
93 | void operator()(OpKernelContext* context, const Tensor& x_input, |
94 | const Tensor& scale_input, const Tensor& offset_input, |
95 | const Tensor& running_mean_input, |
96 | const Tensor& running_variance_input, |
97 | const Tensor* side_input, U epsilon, U exponential_avg_factor, |
98 | FusedBatchNormActivationMode activation_mode, |
99 | Tensor* y_output, Tensor* running_mean_output, |
100 | Tensor* running_var_output, Tensor* saved_batch_mean_output, |
101 | Tensor* saved_batch_var_output, TensorFormat tensor_format, |
102 | bool use_reserved_space) { |
103 | OP_REQUIRES(context, side_input == nullptr, |
104 | errors::Internal( |
105 | "The CPU implementation of FusedBatchNorm does not support " |
106 | "side input." )); |
107 | OP_REQUIRES(context, |
108 | activation_mode == FusedBatchNormActivationMode::kIdentity, |
109 | errors::Internal("The CPU implementation of FusedBatchNorm " |
110 | "does not support activations." )); |
111 | |
112 | if (use_reserved_space) { |
113 | Tensor* dummy_reserve_space = nullptr; |
114 | OP_REQUIRES_OK(context, |
115 | context->allocate_output(5, {}, &dummy_reserve_space)); |
116 | // Initialize the memory, to avoid sanitizer alerts. |
117 | dummy_reserve_space->flat<U>()(0) = U(); |
118 | } |
119 | |
120 | // If input is empty, return NaN mean/variance |
121 | if (x_input.shape().num_elements() == 0) { |
122 | functor::SetNanFunctor<CPUDevice, U> f; |
123 | f(context->eigen_device<CPUDevice>(), running_mean_output->flat<U>()); |
124 | f(context->eigen_device<CPUDevice>(), running_var_output->flat<U>()); |
125 | return; |
126 | } |
127 | |
128 | Tensor transformed_x; |
129 | Tensor transformed_y; |
130 | if (tensor_format == FORMAT_NCHW) { |
131 | const int64_t in_batch = GetTensorDim(x_input, tensor_format, 'N'); |
132 | const int64_t in_rows = GetTensorDim(x_input, tensor_format, 'H'); |
133 | const int64_t in_cols = GetTensorDim(x_input, tensor_format, 'W'); |
134 | const int64_t in_depths = GetTensorDim(x_input, tensor_format, 'C'); |
135 | OP_REQUIRES_OK(context, context->allocate_temp( |
136 | DataTypeToEnum<T>::value, |
137 | ShapeFromFormat(FORMAT_NHWC, in_batch, |
138 | in_rows, in_cols, in_depths), |
139 | &transformed_x)); |
140 | OP_REQUIRES_OK(context, context->allocate_temp( |
141 | DataTypeToEnum<T>::value, |
142 | ShapeFromFormat(FORMAT_NHWC, in_batch, |
143 | in_rows, in_cols, in_depths), |
144 | &transformed_y)); |
145 | // Perform NCHW to NHWC |
146 | std::vector<int32> perm = {0, 2, 3, 1}; |
147 | OP_REQUIRES_OK( |
148 | context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(), |
149 | x_input, perm, &transformed_x)); |
150 | } else { |
151 | transformed_x = x_input; |
152 | transformed_y = *y_output; |
153 | } |
154 | typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>()); |
155 | typename TTypes<U>::ConstVec scale(scale_input.vec<U>()); |
156 | typename TTypes<U>::ConstVec offset(offset_input.vec<U>()); |
157 | typename TTypes<U>::ConstVec old_mean(running_mean_input.vec<U>()); |
158 | typename TTypes<U>::ConstVec old_variance(running_variance_input.vec<U>()); |
159 | typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>()); |
160 | typename TTypes<U>::Vec new_mean(running_mean_output->vec<U>()); |
161 | typename TTypes<U>::Vec new_variance(running_var_output->vec<U>()); |
162 | typename TTypes<U>::Vec saved_batch_mean(saved_batch_mean_output->vec<U>()); |
163 | typename TTypes<U>::Vec saved_batch_var(saved_batch_var_output->vec<U>()); |
164 | |
165 | const CPUDevice& d = context->eigen_device<CPUDevice>(); |
166 | |
167 | const int depth = x.dimension(3); |
168 | const int size = x.size(); |
169 | const int rest_size = size / depth; |
170 | Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth); |
171 | |
172 | Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth; |
173 | one_by_depth.set(1, depth); |
174 | Eigen::IndexList<Eigen::type2index<0>> reduce_dims; |
175 | Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec; |
176 | bcast_spec.set(0, rest_size); |
177 | |
178 | auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>(); |
179 | const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1; |
180 | U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size)); |
181 | // This adjustment is for Bessel's correction |
182 | U rest_size_adjust = |
183 | static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one); |
184 | |
185 | Eigen::Tensor<U, 1, Eigen::RowMajor> batch_mean(depth); |
186 | Eigen::Tensor<U, 1, Eigen::RowMajor> batch_variance(depth); |
187 | |
188 | batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv); |
189 | auto x_centered = x_rest_by_depth - |
190 | batch_mean.reshape(one_by_depth).broadcast(bcast_spec); |
191 | |
192 | batch_variance.device(d) = |
193 | x_centered.square().sum(reduce_dims) * rest_size_inv; |
194 | auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale) |
195 | .eval() |
196 | .reshape(one_by_depth) |
197 | .broadcast(bcast_spec); |
198 | auto x_scaled = x_centered * scaling_factor; |
199 | auto x_shifted = |
200 | (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec)) |
201 | .template cast<T>(); |
202 | |
203 | y.reshape(rest_by_depth).device(d) = x_shifted; |
204 | if (exponential_avg_factor == U(1.0)) { |
205 | saved_batch_var.device(d) = batch_variance; |
206 | saved_batch_mean.device(d) = batch_mean; |
207 | new_variance.device(d) = batch_variance * rest_size_adjust; |
208 | new_mean.device(d) = batch_mean; |
209 | } else { |
210 | U one_minus_factor = U(1) - exponential_avg_factor; |
211 | saved_batch_var.device(d) = batch_variance; |
212 | saved_batch_mean.device(d) = batch_mean; |
213 | new_variance.device(d) = |
214 | one_minus_factor * old_variance + |
215 | (exponential_avg_factor * rest_size_adjust) * batch_variance; |
216 | new_mean.device(d) = |
217 | one_minus_factor * old_mean + exponential_avg_factor * batch_mean; |
218 | } |
219 | |
220 | if (tensor_format == FORMAT_NCHW) { |
221 | // Perform NHWC to NCHW |
222 | const std::vector<int32> perm = {0, 3, 1, 2}; |
223 | const Status s = ::tensorflow::DoTranspose( |
224 | context->eigen_device<CPUDevice>(), transformed_y, perm, y_output); |
225 | if (!s.ok()) { |
226 | context->SetStatus(errors::InvalidArgument("Transpose failed: " , s)); |
227 | } |
228 | } |
229 | } |
230 | }; |
231 | |
232 | template <typename T, typename U> |
233 | struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> { |
234 | void operator()(OpKernelContext* context, const Tensor& x_input, |
235 | const Tensor& scale_input, const Tensor& offset_input, |
236 | const Tensor& estimated_mean_input, |
237 | const Tensor& estimated_variance_input, |
238 | const Tensor* side_input, U epsilon, U exponential_avg_factor, |
239 | FusedBatchNormActivationMode activation_mode, |
240 | Tensor* y_output, Tensor* batch_mean_output, |
241 | Tensor* batch_var_output, Tensor* saved_mean_output, |
242 | Tensor* saved_var_output, TensorFormat tensor_format, |
243 | bool use_reserved_space) { |
244 | OP_REQUIRES(context, side_input == nullptr, |
245 | errors::Internal( |
246 | "The CPU implementation of FusedBatchNorm does not support " |
247 | "side input." )); |
248 | OP_REQUIRES(context, |
249 | activation_mode == FusedBatchNormActivationMode::kIdentity, |
250 | errors::Internal("The CPU implementation of FusedBatchNorm " |
251 | "does not support activations." )); |
252 | |
253 | if (use_reserved_space) { |
254 | Tensor* dummy_reserve_space = nullptr; |
255 | OP_REQUIRES_OK(context, |
256 | context->allocate_output(5, {}, &dummy_reserve_space)); |
257 | // Initialize the memory, to avoid sanitizer alerts. |
258 | dummy_reserve_space->flat<U>()(0) = U(); |
259 | } |
260 | |
261 | // If input is empty, return NaN mean/variance |
262 | if (x_input.shape().num_elements() == 0) { |
263 | functor::SetNanFunctor<CPUDevice, U> f; |
264 | f(context->eigen_device<CPUDevice>(), batch_mean_output->flat<U>()); |
265 | f(context->eigen_device<CPUDevice>(), batch_var_output->flat<U>()); |
266 | return; |
267 | } |
268 | |
269 | Tensor transformed_x; |
270 | Tensor transformed_y; |
271 | if (tensor_format == FORMAT_NCHW) { |
272 | const int64_t in_batch = GetTensorDim(x_input, tensor_format, 'N'); |
273 | const int64_t in_rows = GetTensorDim(x_input, tensor_format, 'H'); |
274 | const int64_t in_cols = GetTensorDim(x_input, tensor_format, 'W'); |
275 | const int64_t in_depths = GetTensorDim(x_input, tensor_format, 'C'); |
276 | OP_REQUIRES_OK(context, context->allocate_temp( |
277 | DataTypeToEnum<T>::value, |
278 | ShapeFromFormat(FORMAT_NHWC, in_batch, |
279 | in_rows, in_cols, in_depths), |
280 | &transformed_x)); |
281 | OP_REQUIRES_OK(context, context->allocate_temp( |
282 | DataTypeToEnum<T>::value, |
283 | ShapeFromFormat(FORMAT_NHWC, in_batch, |
284 | in_rows, in_cols, in_depths), |
285 | &transformed_y)); |
286 | // Perform NCHW to NHWC |
287 | std::vector<int32> perm = {0, 2, 3, 1}; |
288 | OP_REQUIRES_OK( |
289 | context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(), |
290 | x_input, perm, &transformed_x)); |
291 | } else { |
292 | transformed_x = x_input; |
293 | transformed_y = *y_output; |
294 | } |
295 | typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>()); |
296 | typename TTypes<U>::ConstVec scale(scale_input.vec<U>()); |
297 | typename TTypes<U>::ConstVec offset(offset_input.vec<U>()); |
298 | typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>()); |
299 | typename TTypes<U>::ConstVec estimated_variance( |
300 | estimated_variance_input.vec<U>()); |
301 | typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>()); |
302 | typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>()); |
303 | typename TTypes<U>::Vec batch_variance(batch_var_output->vec<U>()); |
304 | |
305 | const CPUDevice& d = context->eigen_device<CPUDevice>(); |
306 | |
307 | const int depth = x.dimension(3); |
308 | OP_REQUIRES( |
309 | context, depth != 0, |
310 | errors::Internal("The 4th element in the input shape cannot be 0." )); |
311 | const int size = x.size(); |
312 | const int rest_size = size / depth; |
313 | Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth); |
314 | Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth; |
315 | one_by_depth.set(1, depth); |
316 | Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec; |
317 | bcast_spec.set(0, rest_size); |
318 | |
319 | auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>(); |
320 | auto x_centered = |
321 | x_rest_by_depth - |
322 | estimated_mean.reshape(one_by_depth).broadcast(bcast_spec); |
323 | auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale) |
324 | .eval() |
325 | .reshape(one_by_depth) |
326 | .broadcast(bcast_spec); |
327 | auto x_scaled = x_centered * scaling_factor; |
328 | auto x_shifted = |
329 | (x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec)) |
330 | .template cast<T>(); |
331 | |
332 | y.reshape(rest_by_depth).device(d) = x_shifted; |
333 | batch_mean.device(d) = estimated_mean; |
334 | batch_variance.device(d) = estimated_variance; |
335 | |
336 | if (tensor_format == FORMAT_NCHW) { |
337 | // Perform NHWC to NCHW |
338 | const std::vector<int32> perm = {0, 3, 1, 2}; |
339 | const Status s = ::tensorflow::DoTranspose( |
340 | context->eigen_device<CPUDevice>(), transformed_y, perm, y_output); |
341 | if (!s.ok()) { |
342 | context->SetStatus(errors::InvalidArgument("Transpose failed: " , s)); |
343 | } |
344 | } |
345 | } |
346 | }; |
347 | |
348 | template <typename T, typename U> |
349 | struct FusedBatchNormGrad<CPUDevice, T, U> { |
350 | void operator()(OpKernelContext* context, const Tensor& y_backprop_input, |
351 | const Tensor& x_input, const Tensor& scale_input, |
352 | const Tensor* offset_input, const Tensor& mean_input, |
353 | const Tensor& variance_input, const Tensor* y_input, |
354 | U epsilon, FusedBatchNormActivationMode activation_mode, |
355 | Tensor* x_backprop_output, Tensor* scale_backprop_output, |
356 | Tensor* offset_backprop_output, |
357 | Tensor* side_input_backprop_output, bool use_reserved_space, |
358 | TensorFormat tensor_format) { |
359 | OP_REQUIRES(context, |
360 | y_input == nullptr && |
361 | activation_mode == FusedBatchNormActivationMode::kIdentity, |
362 | errors::Internal( |
363 | "The CPU implementation of FusedBatchNormGrad does not " |
364 | "support activations." )); |
365 | OP_REQUIRES(context, side_input_backprop_output == nullptr, |
366 | errors::Internal("The CPU implementation of FusedBatchNormGrad " |
367 | "does not support side input." )); |
368 | |
369 | Tensor transformed_y_backprop_input; |
370 | Tensor transformed_x_input; |
371 | Tensor transformed_x_backprop_output; |
372 | if (tensor_format == FORMAT_NCHW) { |
373 | const int64_t in_batch = GetTensorDim(x_input, tensor_format, 'N'); |
374 | const int64_t in_rows = GetTensorDim(x_input, tensor_format, 'H'); |
375 | const int64_t in_cols = GetTensorDim(x_input, tensor_format, 'W'); |
376 | const int64_t in_depths = GetTensorDim(x_input, tensor_format, 'C'); |
377 | OP_REQUIRES_OK(context, context->allocate_temp( |
378 | DataTypeToEnum<T>::value, |
379 | ShapeFromFormat(FORMAT_NHWC, in_batch, |
380 | in_rows, in_cols, in_depths), |
381 | &transformed_y_backprop_input)); |
382 | OP_REQUIRES_OK(context, context->allocate_temp( |
383 | DataTypeToEnum<T>::value, |
384 | ShapeFromFormat(FORMAT_NHWC, in_batch, |
385 | in_rows, in_cols, in_depths), |
386 | &transformed_x_input)); |
387 | OP_REQUIRES_OK(context, context->allocate_temp( |
388 | DataTypeToEnum<T>::value, |
389 | ShapeFromFormat(FORMAT_NHWC, in_batch, |
390 | in_rows, in_cols, in_depths), |
391 | &transformed_x_backprop_output)); |
392 | // Perform NCHW to NHWC |
393 | std::vector<int32> perm = {0, 2, 3, 1}; |
394 | OP_REQUIRES_OK( |
395 | context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(), |
396 | y_backprop_input, perm, |
397 | &transformed_y_backprop_input)); |
398 | OP_REQUIRES_OK(context, ::tensorflow::DoTranspose( |
399 | context->eigen_device<CPUDevice>(), x_input, |
400 | perm, &transformed_x_input)); |
401 | } else { |
402 | transformed_y_backprop_input = y_backprop_input; |
403 | transformed_x_input = x_input; |
404 | transformed_x_backprop_output = *x_backprop_output; |
405 | } |
406 | typename TTypes<T, 4>::Tensor y_backprop( |
407 | transformed_y_backprop_input.tensor<T, 4>()); |
408 | typename TTypes<T, 4>::Tensor x(transformed_x_input.tensor<T, 4>()); |
409 | typename TTypes<U>::ConstVec scale(scale_input.vec<U>()); |
410 | typename TTypes<U>::ConstVec mean(mean_input.vec<U>()); |
411 | typename TTypes<U>::ConstVec variance(variance_input.vec<U>()); |
412 | typename TTypes<T, 4>::Tensor x_backprop( |
413 | transformed_x_backprop_output.tensor<T, 4>()); |
414 | typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>()); |
415 | |
416 | // Note: the following formulas are used to compute the gradients for |
417 | // back propagation. |
418 | // x_backprop = scale * rsqrt(variance + epsilon) * |
419 | // [y_backprop - mean(y_backprop) - (x - mean(x)) * |
420 | // mean(y_backprop * (x - mean(x))) / (variance + epsilon)] |
421 | // scale_backprop = sum(y_backprop * |
422 | // (x - mean(x)) * rsqrt(variance + epsilon)) |
423 | // offset_backprop = sum(y_backprop) |
424 | |
425 | const CPUDevice& d = context->eigen_device<CPUDevice>(); |
426 | const int depth = x.dimension(3); |
427 | const int size = x.size(); |
428 | const int rest_size = size / depth; |
429 | Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth); |
430 | Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth; |
431 | one_by_depth.set(1, depth); |
432 | Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec; |
433 | bcast_spec.set(0, rest_size); |
434 | |
435 | auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>(); |
436 | U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size)); |
437 | |
438 | // Eigen is notoriously bad at reducing outer dimension, so we materialize |
439 | // all temporary tensors that require reduction, and then use Eigen redux |
440 | // functor, that is optimized for this particular task. |
441 | // |
442 | // All reductions are of this type: [rest_size, depth] -> [depth]. |
443 | using ScalarSum = Eigen::internal::scalar_sum_op<U>; |
444 | const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t; |
445 | const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u; |
446 | |
447 | auto scratch_dtype = DataTypeToEnum<U>::value; |
448 | |
449 | // Allocate a temporary workspace of [depth] shape. |
450 | Tensor scratch_one_by_depth; |
451 | OP_REQUIRES_OK(context, context->allocate_temp(scratch_dtype, {depth}, |
452 | &scratch_one_by_depth)); |
453 | |
454 | // Maybe allocate a temporary workspace of [rest_size, depth] shape. |
455 | Tensor scratch_rest_by_depth; |
456 | if (std::is_same<T, U>::value) { |
457 | OP_REQUIRES(context, |
458 | scratch_rest_by_depth.CopyFrom(transformed_x_backprop_output, |
459 | {rest_size, depth}), |
460 | errors::Internal("Failed to copy a tensor" )); |
461 | } else { |
462 | OP_REQUIRES_OK(context, |
463 | context->allocate_temp(scratch_dtype, {rest_size, depth}, |
464 | &scratch_rest_by_depth)); |
465 | } |
466 | |
467 | typename TTypes<U, 2>::Tensor scratch_tensor( |
468 | scratch_rest_by_depth.tensor<U, 2>()); |
469 | typename TTypes<U>::Vec scratch_vector(scratch_one_by_depth.vec<U>()); |
470 | |
471 | auto x_mean_rest_by_depth = |
472 | mean.reshape(one_by_depth).broadcast(bcast_spec); |
473 | auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth); |
474 | auto coef0_one_by_depth = |
475 | (variance.reshape(one_by_depth) + epsilon).rsqrt(); |
476 | auto coef0_rest_by_depth = coef0_one_by_depth.broadcast(bcast_spec); |
477 | auto x_scaled = x_centered * coef0_rest_by_depth; |
478 | |
479 | auto y_backprop_rest_by_depth = |
480 | y_backprop.reshape(rest_by_depth).template cast<U>(); |
481 | |
482 | // Compute `scale_backprop_output`: |
483 | // scale_backprop = |
484 | // (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims) |
485 | scratch_tensor.device(d) = y_backprop_rest_by_depth * x_scaled; |
486 | redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, scale_backprop_output); |
487 | |
488 | // Compute 'offset_backprop_output': |
489 | // offset_backprop = |
490 | // y_backprop_rest_by_depth.sum(reduce_dims) |
491 | redux_sum_t(d, rest_by_depth, transformed_y_backprop_input, |
492 | offset_backprop_output); |
493 | auto y_backprop_sum = offset_backprop; |
494 | |
495 | auto y_backprop_sum_one_by_depth = y_backprop_sum.reshape(one_by_depth); |
496 | auto y_backprop_mean_one_by_depth = |
497 | y_backprop_sum_one_by_depth * rest_size_inv; |
498 | auto y_backprop_mean_rest_by_depth = |
499 | y_backprop_mean_one_by_depth.broadcast(bcast_spec); |
500 | auto y_backprop_centered = |
501 | y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth; |
502 | |
503 | // Compute expression: |
504 | // y_backprop_centered_mean = |
505 | // (y_backprop_rest_by_depth * x_centered).mean(reduce_dims) |
506 | scratch_tensor.device(d) = y_backprop_rest_by_depth * x_centered; |
507 | redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, &scratch_one_by_depth); |
508 | auto y_backprop_centered_mean = |
509 | scratch_vector.reshape(one_by_depth) / static_cast<U>(rest_size); |
510 | |
511 | auto coef1 = (scale.reshape(one_by_depth) * coef0_one_by_depth) |
512 | .broadcast(bcast_spec); |
513 | auto coef2 = (coef0_one_by_depth.square() * y_backprop_centered_mean) |
514 | .broadcast(bcast_spec); |
515 | |
516 | x_backprop.reshape(rest_by_depth).device(d) = |
517 | (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>(); |
518 | |
519 | if (tensor_format == FORMAT_NCHW) { |
520 | // Perform NHWC to NCHW |
521 | std::vector<int32> perm = {0, 3, 1, 2}; |
522 | OP_REQUIRES_OK( |
523 | context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(), |
524 | transformed_x_backprop_output, |
525 | perm, x_backprop_output)); |
526 | } |
527 | } |
528 | }; |
529 | |
530 | template <typename T, typename U> |
531 | struct FusedBatchNormFreezeGrad<CPUDevice, T, U> { |
532 | void operator()(OpKernelContext* context, const Tensor& y_backprop_input, |
533 | const Tensor& x_input, const Tensor& scale_input, |
534 | const Tensor& pop_mean_input, |
535 | const Tensor& pop_variance_input, U epsilon, |
536 | Tensor* x_backprop_output, Tensor* scale_backprop_output, |
537 | Tensor* offset_backprop_output) { |
538 | typename TTypes<T, 4>::ConstTensor y_backprop( |
539 | y_backprop_input.tensor<T, 4>()); |
540 | typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>()); |
541 | typename TTypes<U>::ConstVec scale(scale_input.vec<U>()); |
542 | typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>()); |
543 | typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>()); |
544 | typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>()); |
545 | typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>()); |
546 | |
547 | const int depth = pop_mean.dimension(0); |
548 | const int rest_size = input.size() / depth; |
549 | |
550 | const CPUDevice& d = context->eigen_device<CPUDevice>(); |
551 | |
552 | // Allocate two temporary workspaces of [depth] shape. |
553 | Tensor scratch1_vec, scratch2_vec; |
554 | OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value, |
555 | {depth}, &scratch1_vec)); |
556 | OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value, |
557 | {depth}, &scratch2_vec)); |
558 | |
559 | // Maybe allocate a temporary workspace of [rest_size, depth] shape. |
560 | Tensor scratch3_tensor; |
561 | if (std::is_same<T, U>::value) { |
562 | OP_REQUIRES( |
563 | context, |
564 | scratch3_tensor.CopyFrom(*x_backprop_output, {rest_size, depth}), |
565 | errors::Internal("Failed to copy a tensor" )); |
566 | } else { |
567 | OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value, |
568 | {rest_size, depth}, |
569 | &scratch3_tensor)); |
570 | } |
571 | |
572 | typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>()); |
573 | typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>()); |
574 | typename TTypes<U, 2>::Tensor scratch3(scratch3_tensor.tensor<U, 2>()); |
575 | |
576 | Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth); |
577 | Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth; |
578 | one_by_depth.set(1, depth); |
579 | Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> rest_by_one; |
580 | rest_by_one.set(0, rest_size); |
581 | |
582 | // Sum reduction along the 0th dimension using custom CPU functor. |
583 | using ScalarSum = Eigen::internal::scalar_sum_op<U>; |
584 | const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t; |
585 | const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u; |
586 | |
587 | // offset_backprop = sum(y_backprop) |
588 | // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon)) |
589 | // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) |
590 | |
591 | // NOTE: DEFAULT DEVICE comment is added to expression assignments that |
592 | // we don't want to be executed in a thread pool. |
593 | |
594 | auto y_backprop_rest_by_depth = |
595 | y_backprop.reshape(rest_by_depth).template cast<U>(); |
596 | auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>(); |
597 | |
598 | // offset_backprop = sum(y_backprop) |
599 | redux_sum_t(d, rest_by_depth, y_backprop_input, offset_backprop_output); |
600 | |
601 | // scratch1 = rsqrt(pop_var + epsilon) |
602 | scratch1 = (pop_var + pop_var.constant(epsilon)).rsqrt(); // DEFAULT DEVICE |
603 | |
604 | // scratch2 = sum(y_backprop * (x - mean)) |
605 | scratch3.device(d) = |
606 | y_backprop_rest_by_depth * |
607 | (input_rest_by_depth - |
608 | pop_mean.reshape(one_by_depth).broadcast(rest_by_one)); |
609 | redux_sum_u(d, rest_by_depth, scratch3_tensor, &scratch2_vec); |
610 | |
611 | x_backprop.reshape(rest_by_depth).device(d) = |
612 | (y_backprop_rest_by_depth * |
613 | ((scratch1.reshape(one_by_depth) * scale.reshape(one_by_depth)) |
614 | .broadcast(rest_by_one))) |
615 | .template cast<T>(); |
616 | scale_backprop = scratch2 * scratch1; // DEFAULT DEVICE |
617 | } |
618 | }; |
619 | |
620 | #if !GOOGLE_CUDA |
621 | namespace { |
622 | // See implementation under GOOGLE_CUDA #ifdef below. |
623 | // This is a CUDA specific feature, do not enable it for non-CUDA builds |
624 | bool BatchnormSpatialPersistentEnabled() { return false; } |
625 | } // namespace |
626 | #endif |
627 | |
628 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
629 | |
630 | namespace { |
631 | |
632 | se::dnn::ActivationMode AsDnnActivationMode( |
633 | const FusedBatchNormActivationMode activation_mode) { |
634 | switch (activation_mode) { |
635 | case FusedBatchNormActivationMode::kIdentity: |
636 | return se::dnn::ActivationMode::kNone; |
637 | case FusedBatchNormActivationMode::kRelu: |
638 | return se::dnn::ActivationMode::kRelu; |
639 | } |
640 | } |
641 | |
642 | #if GOOGLE_CUDA |
643 | // NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the |
644 | // `cuda_dnn.cc` for details. |
645 | bool BatchnormSpatialPersistentEnabled() { |
646 | #if CUDNN_VERSION >= 7402 |
647 | static bool is_enabled = [] { |
648 | bool is_enabled = false; |
649 | TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar( |
650 | "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT" , |
651 | /*default_val=*/false, &is_enabled)); |
652 | return is_enabled; |
653 | }(); |
654 | return is_enabled; |
655 | #else |
656 | return false; |
657 | #endif |
658 | } |
659 | #endif |
660 | |
661 | } // namespace |
662 | |
663 | template <typename U, typename T> |
664 | DeviceMemory<U> CastDeviceMemory(Tensor* tensor) { |
665 | return DeviceMemory<U>::MakeFromByteSize( |
666 | tensor->template flat<T>().data(), |
667 | tensor->template flat<T>().size() * sizeof(T)); |
668 | } |
669 | |
670 | // A helper to allocate temporary scratch memory for Cudnn BatchNormEx ops. It |
671 | // takes the ownership of the underlying memory. The expectation is that the |
672 | // memory should be alive for the span of the Cudnn BatchNormEx itself. |
673 | template <typename T> |
674 | class CudnnBatchNormAllocatorInTemp : public ScratchAllocator { |
675 | public: |
676 | ~CudnnBatchNormAllocatorInTemp() override = default; |
677 | |
678 | explicit CudnnBatchNormAllocatorInTemp(OpKernelContext* context) |
679 | : context_(context) {} |
680 | |
681 | int64_t GetMemoryLimitInBytes() override { |
682 | return std::numeric_limits<int64_t>::max(); |
683 | } |
684 | |
685 | StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override { |
686 | Tensor temporary_memory; |
687 | const DataType tf_data_type = DataTypeToEnum<T>::v(); |
688 | int64_t allocate_count = |
689 | Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T))); |
690 | Status allocation_status(context_->allocate_temp( |
691 | tf_data_type, TensorShape({allocate_count}), &temporary_memory)); |
692 | if (!allocation_status.ok()) { |
693 | return allocation_status; |
694 | } |
695 | // Hold the reference of the allocated tensors until the end of the |
696 | // allocator. |
697 | allocated_tensors_.push_back(temporary_memory); |
698 | total_byte_size_ += byte_size; |
699 | return DeviceMemory<uint8>::MakeFromByteSize( |
700 | temporary_memory.template flat<T>().data(), |
701 | temporary_memory.template flat<T>().size() * sizeof(T)); |
702 | } |
703 | |
704 | int64_t TotalByteSize() const { return total_byte_size_; } |
705 | |
706 | Tensor get_allocated_tensor(int index) const { |
707 | return allocated_tensors_[index]; |
708 | } |
709 | |
710 | private: |
711 | int64_t total_byte_size_ = 0; |
712 | OpKernelContext* context_; // not owned |
713 | std::vector<Tensor> allocated_tensors_; |
714 | }; |
715 | |
716 | // A helper to allocate memory for Cudnn BatchNormEx as a kernel output. It is |
717 | // used by forward pass kernel to feed the output to the backward pass. |
718 | // The memory is expected to live long enough after the backward pass is |
719 | // finished. |
720 | template <typename T> |
721 | class CudnnBatchNormAllocatorInOutput : public ScratchAllocator { |
722 | public: |
723 | ~CudnnBatchNormAllocatorInOutput() override { |
724 | if (!output_allocated) { |
725 | Tensor* dummy_reserve_space = nullptr; |
726 | OP_REQUIRES_OK(context_, context_->allocate_output(output_index_, {}, |
727 | &dummy_reserve_space)); |
728 | } |
729 | } |
730 | |
731 | CudnnBatchNormAllocatorInOutput(OpKernelContext* context, int output_index) |
732 | : context_(context), output_index_(output_index) {} |
733 | |
734 | int64_t GetMemoryLimitInBytes() override { |
735 | return std::numeric_limits<int64_t>::max(); |
736 | } |
737 | |
738 | StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override { |
739 | output_allocated = true; |
740 | DCHECK(total_byte_size_ == 0) |
741 | << "Reserve space allocator can only be called once" ; |
742 | int64_t allocate_count = |
743 | Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T))); |
744 | |
745 | Tensor* temporary_memory = nullptr; |
746 | Status allocation_status(context_->allocate_output( |
747 | output_index_, TensorShape({allocate_count}), &temporary_memory)); |
748 | if (!allocation_status.ok()) { |
749 | return allocation_status; |
750 | } |
751 | total_byte_size_ += byte_size; |
752 | auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize( |
753 | temporary_memory->template flat<T>().data(), |
754 | temporary_memory->template flat<T>().size() * sizeof(T)); |
755 | return StatusOr<DeviceMemory<uint8>>(memory_uint8); |
756 | } |
757 | |
758 | int64_t TotalByteSize() { return total_byte_size_; } |
759 | |
760 | private: |
761 | int64_t total_byte_size_ = 0; |
762 | OpKernelContext* context_; // not owned |
763 | int output_index_; |
764 | bool output_allocated = false; |
765 | }; |
766 | |
767 | template <typename T, typename U, bool is_training> |
768 | struct FusedBatchNorm<GPUDevice, T, U, is_training> { |
769 | void operator()(OpKernelContext* context, const Tensor& x, |
770 | const Tensor& scale, const Tensor& offset, |
771 | const Tensor& estimated_mean, |
772 | const Tensor& estimated_variance, const Tensor* side_input, |
773 | U epsilon, U exponential_avg_factor, |
774 | FusedBatchNormActivationMode activation_mode, Tensor* y, |
775 | Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean, |
776 | Tensor* saved_inv_var, TensorFormat tensor_format, |
777 | bool use_reserved_space) { |
778 | auto* stream = context->op_device_context()->stream(); |
779 | OP_REQUIRES(context, stream, errors::Internal("No GPU stream available" )); |
780 | |
781 | const int64_t batch_size = GetTensorDim(x, tensor_format, 'N'); |
782 | const int64_t channels = GetTensorDim(x, tensor_format, 'C'); |
783 | const int64_t height = GetTensorDim(x, tensor_format, 'H'); |
784 | const int64_t width = GetTensorDim(x, tensor_format, 'W'); |
785 | |
786 | // If use_reserved_space we have reserve_space_3 output (only in |
787 | // FusedBatchNormV3 op). |
788 | |
789 | #if GOOGLE_CUDA |
790 | // Check if cuDNN batch normalization has a fast NHWC implementation: |
791 | // (1) In inference mode it's always fast. |
792 | // (2) Tensorflow enabled batchnorm spatial persistence, we are called |
793 | // from |
794 | // FusedBatchNormV3, i.e. use_reserved_space is true. |
795 | const bool fast_nhwc_batch_norm = |
796 | !is_training || |
797 | (BatchnormSpatialPersistentEnabled() && |
798 | DataTypeToEnum<T>::value == DT_HALF && use_reserved_space); |
799 | #else |
800 | // fast NHWC implementation is a CUDA only feature |
801 | const bool fast_nhwc_batch_norm = false; |
802 | #endif |
803 | |
804 | // If input tensor is in NHWC format, and we have a fast cuDNN |
805 | // implementation, there is no need to do data format conversion. |
806 | TensorFormat compute_format = |
807 | fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC |
808 | : FORMAT_NCHW; |
809 | |
810 | VLOG(2) << "FusedBatchNorm:" |
811 | << " batch_size: " << batch_size << " channels: " << channels |
812 | << " height: " << height << " width:" << width |
813 | << " x shape: " << x.shape().DebugString() |
814 | << " scale shape: " << scale.shape().DebugString() |
815 | << " offset shape: " << offset.shape().DebugString() |
816 | << " activation mode: " << ToString(activation_mode) |
817 | << " tensor format: " << ToString(tensor_format) |
818 | << " compute format: " << ToString(compute_format); |
819 | |
820 | auto maybe_make_dummy_output = [context, use_reserved_space]() -> Status { |
821 | if (use_reserved_space) { |
822 | Tensor* dummy_reserve_space = nullptr; |
823 | return context->allocate_output(5, {}, &dummy_reserve_space); |
824 | } |
825 | return OkStatus(); |
826 | }; |
827 | |
828 | // If input is empty, return NaN mean/variance |
829 | if (x.shape().num_elements() == 0) { |
830 | OP_REQUIRES_OK(context, maybe_make_dummy_output()); |
831 | functor::SetNanFunctor<GPUDevice, U> f; |
832 | f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>()); |
833 | f(context->eigen_device<GPUDevice>(), batch_var->flat<U>()); |
834 | return; |
835 | } |
836 | |
837 | // In inference mode we use custom CUDA kernel, because cuDNN does not |
838 | // support side input and activations for inference. |
839 | const bool has_side_input = side_input != nullptr; |
840 | const bool has_activation = |
841 | activation_mode != FusedBatchNormActivationMode::kIdentity; |
842 | |
843 | if (!is_training && (has_side_input || has_activation)) { |
844 | OP_REQUIRES_OK(context, maybe_make_dummy_output()); |
845 | FusedBatchNormInferenceFunctor<GPUDevice, T, U> inference_functor; |
846 | |
847 | if (has_side_input) { |
848 | inference_functor(context, tensor_format, x.tensor<T, 4>(), |
849 | scale.vec<U>(), offset.vec<U>(), |
850 | estimated_mean.vec<U>(), estimated_variance.vec<U>(), |
851 | side_input->tensor<T, 4>(), epsilon, activation_mode, |
852 | y->tensor<T, 4>()); |
853 | } else { |
854 | typename TTypes<T, 4>::ConstTensor empty_tensor(nullptr, 0, 0, 0, 0); |
855 | inference_functor(context, tensor_format, x.tensor<T, 4>(), |
856 | scale.vec<U>(), offset.vec<U>(), |
857 | estimated_mean.vec<U>(), estimated_variance.vec<U>(), |
858 | empty_tensor, epsilon, activation_mode, |
859 | y->tensor<T, 4>()); |
860 | } |
861 | return; |
862 | } |
863 | |
864 | Tensor x_maybe_transformed = x; |
865 | Tensor x_transformed; |
866 | Tensor y_transformed; |
867 | se::DeviceMemory<T> y_ptr; |
868 | |
869 | if (tensor_format == compute_format) { |
870 | y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y); |
871 | } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) { |
872 | OP_REQUIRES_OK(context, context->allocate_temp( |
873 | DataTypeToEnum<T>::value, |
874 | ShapeFromFormat(compute_format, batch_size, |
875 | height, width, channels), |
876 | &x_transformed)); |
877 | functor::NHWCToNCHW<GPUDevice, T, 4>()( |
878 | context->eigen_device<GPUDevice>(), |
879 | const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(), |
880 | x_transformed.tensor<T, 4>()); |
881 | x_maybe_transformed = x_transformed; |
882 | |
883 | OP_REQUIRES_OK(context, context->allocate_temp( |
884 | DataTypeToEnum<T>::value, |
885 | ShapeFromFormat(compute_format, batch_size, |
886 | height, width, channels), |
887 | &y_transformed)); |
888 | y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(y_transformed); |
889 | } else { |
890 | context->SetStatus(errors::Internal( |
891 | "Unsupported tensor format: " , ToString(tensor_format), |
892 | " and compute format: " , ToString(compute_format))); |
893 | return; |
894 | } |
895 | |
896 | const se::dnn::DataLayout data_layout = |
897 | compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth |
898 | : se::dnn::DataLayout::kBatchDepthYX; |
899 | |
900 | se::dnn::BatchDescriptor x_desc; |
901 | x_desc.set_count(batch_size) |
902 | .set_feature_map_count(channels) |
903 | .set_height(height) |
904 | .set_width(width) |
905 | .set_layout(data_layout); |
906 | |
907 | se::dnn::BatchDescriptor scale_offset_desc; |
908 | scale_offset_desc.set_count(1) |
909 | .set_feature_map_count(channels) |
910 | .set_height(1) |
911 | .set_width(1) |
912 | .set_layout(se::dnn::DataLayout::kBatchDepthYX); |
913 | |
914 | auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed); |
915 | auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale); |
916 | auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset); |
917 | auto estimated_mean_ptr = |
918 | StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean); |
919 | auto estimated_variance_ptr = |
920 | StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance); |
921 | auto side_input_ptr = |
922 | side_input != nullptr |
923 | ? StreamExecutorUtil::AsDeviceMemory<T>(*side_input) |
924 | : se::DeviceMemory<T>(); |
925 | auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean); |
926 | |
927 | auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var); |
928 | auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean); |
929 | auto saved_inv_var_ptr = |
930 | StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var); |
931 | |
932 | std::unique_ptr<functor::CudnnBatchNormAllocatorInOutput<U>> |
933 | reserve_space_allocator; |
934 | std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>> |
935 | workspace_allocator; |
936 | if (use_reserved_space) { |
937 | reserve_space_allocator.reset( |
938 | new functor::CudnnBatchNormAllocatorInOutput<U>(context, 5)); |
939 | workspace_allocator.reset( |
940 | new functor::CudnnBatchNormAllocatorInTemp<uint8>(context)); |
941 | } |
942 | if (!batch_mean->SharesBufferWith(estimated_mean) && |
943 | exponential_avg_factor != 1.0f) { |
944 | OP_REQUIRES( |
945 | context, |
946 | stream |
947 | ->ThenMemcpyD2D(&batch_mean_ptr, estimated_mean_ptr, |
948 | estimated_mean.NumElements() * sizeof(U)) |
949 | .ok(), |
950 | errors::Internal("MatrixTriangularSolveOp: failed to copy rhs " |
951 | "from device" )); |
952 | } |
953 | if (!batch_var->SharesBufferWith(estimated_variance) && |
954 | exponential_avg_factor != 1.0f) { |
955 | OP_REQUIRES( |
956 | context, |
957 | stream |
958 | ->ThenMemcpyD2D(&batch_var_ptr, estimated_variance_ptr, |
959 | estimated_variance.NumElements() * sizeof(U)) |
960 | .ok(), |
961 | errors::Internal("MatrixTriangularSolveOp: failed to copy rhs " |
962 | "from device" )); |
963 | } |
964 | bool cudnn_launch_status = |
965 | stream |
966 | ->ThenBatchNormalizationForward( |
967 | x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr, |
968 | estimated_variance_ptr, side_input_ptr, x_desc, |
969 | scale_offset_desc, static_cast<double>(epsilon), |
970 | static_cast<double>(exponential_avg_factor), |
971 | AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr, |
972 | &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr, |
973 | is_training, reserve_space_allocator.get(), |
974 | workspace_allocator.get()) |
975 | .ok(); |
976 | |
977 | if (!cudnn_launch_status) { |
978 | context->SetStatus( |
979 | errors::Internal("cuDNN launch failure : input shape (" , |
980 | x.shape().DebugString(), ")" )); |
981 | return; |
982 | } |
983 | |
984 | if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) { |
985 | functor::NCHWToNHWC<GPUDevice, T, 4>()( |
986 | context->eigen_device<GPUDevice>(), |
987 | const_cast<const Tensor&>(y_transformed).tensor<T, 4>(), |
988 | y->tensor<T, 4>()); |
989 | } |
990 | } |
991 | }; |
992 | |
993 | template <typename T, typename U> |
994 | struct FusedBatchNormGrad<GPUDevice, T, U> { |
995 | void operator()(OpKernelContext* context, const Tensor& y_backprop, |
996 | const Tensor& x, const Tensor& scale, const Tensor* offset, |
997 | const Tensor& mean, const Tensor& inv_variance, |
998 | const Tensor* y, U epsilon, |
999 | FusedBatchNormActivationMode activation_mode, |
1000 | Tensor* x_backprop, Tensor* scale_backprop, |
1001 | Tensor* offset_backprop, Tensor* side_input_backprop, |
1002 | bool use_reserved_space, TensorFormat tensor_format) { |
1003 | auto* stream = context->op_device_context()->stream(); |
1004 | OP_REQUIRES(context, stream, errors::Internal("No GPU stream available" )); |
1005 | |
1006 | const int64_t batch_size = GetTensorDim(x, tensor_format, 'N'); |
1007 | const int64_t channels = GetTensorDim(x, tensor_format, 'C'); |
1008 | const int64_t height = GetTensorDim(x, tensor_format, 'H'); |
1009 | const int64_t width = GetTensorDim(x, tensor_format, 'W'); |
1010 | |
1011 | #if GOOGLE_CUDA |
1012 | // Check if cuDNN batch normalization has a fast NHWC implementation: |
1013 | // (1) Tensorflow enabled batchnorm spatial persistence, and |
1014 | // FusedBatchNormGradV3 passed non-null reserve space and allocator. |
1015 | const bool fast_nhwc_batch_norm = BatchnormSpatialPersistentEnabled() && |
1016 | DataTypeToEnum<T>::value == DT_HALF && |
1017 | use_reserved_space; |
1018 | #else |
1019 | // fast NHWC implementation is a CUDA only feature |
1020 | const bool fast_nhwc_batch_norm = false; |
1021 | #endif |
1022 | |
1023 | // If input tensor is in NHWC format, and we have a fast cuDNN |
1024 | // implementation, there is no need to do data format conversion. |
1025 | TensorFormat compute_format = |
1026 | fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC |
1027 | : FORMAT_NCHW; |
1028 | |
1029 | VLOG(2) << "FusedBatchNormGrad:" |
1030 | << " batch_size: " << batch_size << " channels: " << channels |
1031 | << " height: " << height << " width: " << width |
1032 | << " y_backprop shape: " << y_backprop.shape().DebugString() |
1033 | << " x shape: " << x.shape().DebugString() |
1034 | << " scale shape: " << scale.shape().DebugString() |
1035 | << " activation mode: " << ToString(activation_mode) |
1036 | << " tensor format: " << ToString(tensor_format) |
1037 | << " compute format: " << ToString(compute_format); |
1038 | |
1039 | // Inputs |
1040 | Tensor y_backprop_maybe_transformed = y_backprop; |
1041 | Tensor x_maybe_transformed = x; |
1042 | Tensor y_backprop_transformed; |
1043 | Tensor x_transformed; |
1044 | |
1045 | // Outputs |
1046 | Tensor x_backprop_transformed; |
1047 | se::DeviceMemory<T> x_backprop_ptr; |
1048 | |
1049 | if (tensor_format == compute_format) { |
1050 | x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop); |
1051 | } else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) { |
1052 | // Transform inputs from 'NHWC' to 'NCHW' |
1053 | OP_REQUIRES_OK(context, context->allocate_temp( |
1054 | DataTypeToEnum<T>::value, |
1055 | ShapeFromFormat(FORMAT_NCHW, batch_size, |
1056 | height, width, channels), |
1057 | &y_backprop_transformed)); |
1058 | functor::NHWCToNCHW<GPUDevice, T, 4>()( |
1059 | context->eigen_device<GPUDevice>(), |
1060 | const_cast<const Tensor&>(y_backprop_maybe_transformed) |
1061 | .tensor<T, 4>(), |
1062 | y_backprop_transformed.tensor<T, 4>()); |
1063 | y_backprop_maybe_transformed = y_backprop_transformed; |
1064 | |
1065 | OP_REQUIRES_OK(context, context->allocate_temp( |
1066 | DataTypeToEnum<T>::value, |
1067 | ShapeFromFormat(FORMAT_NCHW, batch_size, |
1068 | height, width, channels), |
1069 | &x_transformed)); |
1070 | functor::NHWCToNCHW<GPUDevice, T, 4>()( |
1071 | context->eigen_device<GPUDevice>(), |
1072 | const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(), |
1073 | x_transformed.tensor<T, 4>()); |
1074 | x_maybe_transformed = x_transformed; |
1075 | |
1076 | // Allocate memory for transformed outputs in 'NCHW' |
1077 | OP_REQUIRES_OK(context, context->allocate_temp( |
1078 | DataTypeToEnum<T>::value, |
1079 | ShapeFromFormat(FORMAT_NCHW, batch_size, |
1080 | height, width, channels), |
1081 | &x_backprop_transformed)); |
1082 | x_backprop_ptr = |
1083 | StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed); |
1084 | } else { |
1085 | context->SetStatus(errors::Internal( |
1086 | "Unsupported tensor format: " , ToString(tensor_format), |
1087 | " and compute format: " , ToString(compute_format))); |
1088 | return; |
1089 | } |
1090 | |
1091 | const se::dnn::DataLayout data_layout = |
1092 | compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth |
1093 | : se::dnn::DataLayout::kBatchDepthYX; |
1094 | |
1095 | se::dnn::BatchDescriptor x_desc; |
1096 | x_desc.set_count(batch_size) |
1097 | .set_feature_map_count(channels) |
1098 | .set_height(height) |
1099 | .set_width(width) |
1100 | .set_layout(data_layout); |
1101 | |
1102 | se::dnn::BatchDescriptor scale_offset_desc; |
1103 | scale_offset_desc.set_count(1) |
1104 | .set_feature_map_count(channels) |
1105 | .set_height(1) |
1106 | .set_width(1) |
1107 | .set_layout(se::dnn::DataLayout::kBatchDepthYX); |
1108 | |
1109 | auto y_backprop_ptr = |
1110 | StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed); |
1111 | auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed); |
1112 | auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale); |
1113 | auto offset_ptr = offset != nullptr |
1114 | ? StreamExecutorUtil::AsDeviceMemory<U>(*offset) |
1115 | : se::DeviceMemory<U>(); |
1116 | auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean); |
1117 | auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance); |
1118 | auto y_ptr = y != nullptr ? StreamExecutorUtil::AsDeviceMemory<T>(*y) |
1119 | : se::DeviceMemory<T>(); |
1120 | auto scale_backprop_ptr = |
1121 | StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop); |
1122 | auto offset_backprop_ptr = |
1123 | StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop); |
1124 | auto side_input_backprop_ptr = |
1125 | side_input_backprop != nullptr |
1126 | ? StreamExecutorUtil::AsDeviceMemory<T>(*side_input_backprop) |
1127 | : se::DeviceMemory<T>(); |
1128 | |
1129 | std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>> |
1130 | workspace_allocator; |
1131 | DeviceMemory<uint8>* reserve_space_data_ptr = nullptr; |
1132 | DeviceMemory<uint8> reserve_space_data; |
1133 | #if CUDNN_VERSION >= 7402 |
1134 | if (use_reserved_space) { |
1135 | const Tensor& reserve_space = context->input(5); |
1136 | workspace_allocator.reset( |
1137 | new functor::CudnnBatchNormAllocatorInTemp<uint8>(context)); |
1138 | |
1139 | // the cudnn kernel outputs inverse variance in forward and reuse it in |
1140 | // backward |
1141 | if (reserve_space.dims() != 0) { |
1142 | reserve_space_data = functor::CastDeviceMemory<uint8, U>( |
1143 | const_cast<Tensor*>(&reserve_space)); |
1144 | reserve_space_data_ptr = &reserve_space_data; |
1145 | } |
1146 | } |
1147 | #endif // CUDNN_VERSION >= 7402 |
1148 | |
1149 | bool cudnn_launch_status = |
1150 | stream |
1151 | ->ThenBatchNormalizationBackward( |
1152 | y_backprop_ptr, x_ptr, scale_ptr, offset_ptr, mean_ptr, |
1153 | inv_variance_ptr, y_ptr, x_desc, scale_offset_desc, |
1154 | static_cast<double>(epsilon), |
1155 | AsDnnActivationMode(activation_mode), &x_backprop_ptr, |
1156 | &scale_backprop_ptr, &offset_backprop_ptr, |
1157 | &side_input_backprop_ptr, reserve_space_data_ptr, |
1158 | workspace_allocator.get()) |
1159 | .ok(); |
1160 | |
1161 | if (!cudnn_launch_status) { |
1162 | context->SetStatus( |
1163 | errors::Internal("cuDNN launch failure : input shape (" , |
1164 | x.shape().DebugString(), ")" )); |
1165 | } |
1166 | if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) { |
1167 | functor::NCHWToNHWC<GPUDevice, T, 4>()( |
1168 | context->eigen_device<GPUDevice>(), |
1169 | const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(), |
1170 | x_backprop->tensor<T, 4>()); |
1171 | } |
1172 | } |
1173 | }; |
1174 | |
1175 | // Forward declarations of the functor specializations for GPU. |
1176 | #define DECLARE_GPU_SPEC(T, U) \ |
1177 | template <> \ |
1178 | void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()( \ |
1179 | OpKernelContext* context, const Tensor& y_backprop_input, \ |
1180 | const Tensor& x_input, const Tensor& scale_input, \ |
1181 | const Tensor& mean_input, const Tensor& variance_input, U epsilon, \ |
1182 | Tensor* x_backprop_output, Tensor* scale_backprop_output, \ |
1183 | Tensor* offset_backprop_output); \ |
1184 | extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>; \ |
1185 | template <> \ |
1186 | void FusedBatchNormInferenceFunctor<GPUDevice, T, U>::operator()( \ |
1187 | OpKernelContext* context, TensorFormat tensor_format, \ |
1188 | typename TTypes<T, 4>::ConstTensor in, \ |
1189 | typename TTypes<U>::ConstVec scale, typename TTypes<U>::ConstVec offset, \ |
1190 | typename TTypes<U>::ConstVec estimated_mean, \ |
1191 | typename TTypes<U>::ConstVec estimated_variance, \ |
1192 | typename TTypes<T, 4>::ConstTensor side_input, U epsilon, \ |
1193 | FusedBatchNormActivationMode activation_mode, \ |
1194 | typename TTypes<T, 4>::Tensor out); \ |
1195 | extern template struct FusedBatchNormInferenceFunctor<GPUDevice, T, U>; |
1196 | |
1197 | DECLARE_GPU_SPEC(float, float); |
1198 | DECLARE_GPU_SPEC(Eigen::half, float); |
1199 | |
1200 | #undef DECLARE_GPU_SPEC |
1201 | |
1202 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1203 | } // namespace functor |
1204 | |
1205 | template <typename Device, typename T, typename U> |
1206 | class FusedBatchNormOpBase : public OpKernel { |
1207 | using FbnActivationMode = functor::FusedBatchNormActivationMode; |
1208 | |
1209 | protected: |
1210 | explicit FusedBatchNormOpBase(OpKernelConstruction* context, |
1211 | bool is_batch_norm_ex = false) |
1212 | : OpKernel(context) { |
1213 | float epsilon; |
1214 | OP_REQUIRES_OK(context, context->GetAttr("epsilon" , &epsilon)); |
1215 | epsilon_ = U(epsilon); |
1216 | float exponential_avg_factor; |
1217 | OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor" , |
1218 | &exponential_avg_factor)); |
1219 | exponential_avg_factor_ = U(exponential_avg_factor); |
1220 | string tensor_format; |
1221 | OP_REQUIRES_OK(context, context->GetAttr("data_format" , &tensor_format)); |
1222 | OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_), |
1223 | errors::InvalidArgument("Invalid data format" )); |
1224 | OP_REQUIRES_OK(context, context->GetAttr("is_training" , &is_training_)); |
1225 | |
1226 | if (!is_batch_norm_ex) { |
1227 | has_side_input_ = false; |
1228 | activation_mode_ = FbnActivationMode::kIdentity; |
1229 | } else { |
1230 | OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_)); |
1231 | |
1232 | int num_side_inputs; |
1233 | OP_REQUIRES_OK(context, |
1234 | context->GetAttr("num_side_inputs" , &num_side_inputs)); |
1235 | OP_REQUIRES(context, num_side_inputs >= 0 && num_side_inputs <= 1, |
1236 | errors::InvalidArgument( |
1237 | "FusedBatchNorm accepts at most one side input." )); |
1238 | has_side_input_ = (num_side_inputs == 1); |
1239 | if (has_side_input_ && is_training_) { |
1240 | OP_REQUIRES( |
1241 | context, activation_mode_ != FbnActivationMode::kIdentity, |
1242 | errors::InvalidArgument("Identity activation is not supported with " |
1243 | "non-empty side input" )); |
1244 | } |
1245 | } |
1246 | |
1247 | if (activation_mode_ != FbnActivationMode::kIdentity && is_training_) { |
1248 | // NOTE(ezhulenev): Following requirements are coming from implementation |
1249 | // details of cudnnBatchNormalizationForwardTrainingEx used in training |
1250 | // mode. In inference mode we call custom CUDA kernel that supports all |
1251 | // data formats and data types. |
1252 | OP_REQUIRES(context, DataTypeToEnum<T>::value == DT_HALF, |
1253 | errors::InvalidArgument("FusedBatchNorm with activation " |
1254 | "supports only DT_HALF data type." )); |
1255 | OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC, |
1256 | errors::InvalidArgument("FusedBatchNorm with activation " |
1257 | "supports only NHWC tensor format." )); |
1258 | OP_REQUIRES(context, functor::BatchnormSpatialPersistentEnabled(), |
1259 | errors::InvalidArgument( |
1260 | "FusedBatchNorm with activation must run with cuDNN " |
1261 | "spatial persistence mode enabled." )); |
1262 | } |
1263 | } |
1264 | |
1265 | // If use_reserved_space is true, we need to handle the 5th output (a reserved |
1266 | // space) and a new cudnn batch norm will be called if the version > 7.4.2. |
1267 | // If use_reserved_space is false, we don't have 5th output. |
1268 | virtual void ComputeWithReservedSpace(OpKernelContext* context, |
1269 | bool use_reserved_space) { |
1270 | Tensor x = context->input(0); |
1271 | const Tensor& scale = context->input(1); |
1272 | const Tensor& offset = context->input(2); |
1273 | const Tensor& estimated_mean = context->input(3); |
1274 | const Tensor& estimated_variance = context->input(4); |
1275 | const Tensor* side_input = has_side_input_ ? &context->input(5) : nullptr; |
1276 | |
1277 | OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5, |
1278 | errors::InvalidArgument("input must be 4 or 5-dimensional" , |
1279 | x.shape().DebugString())); |
1280 | OP_REQUIRES(context, scale.dims() == 1, |
1281 | errors::InvalidArgument("scale must be 1-dimensional" , |
1282 | scale.shape().DebugString())); |
1283 | OP_REQUIRES(context, offset.dims() == 1, |
1284 | errors::InvalidArgument("offset must be 1-dimensional" , |
1285 | offset.shape().DebugString())); |
1286 | OP_REQUIRES(context, estimated_mean.dims() == 1, |
1287 | errors::InvalidArgument("estimated_mean must be 1-dimensional" , |
1288 | estimated_mean.shape().DebugString())); |
1289 | OP_REQUIRES( |
1290 | context, estimated_variance.dims() == 1, |
1291 | errors::InvalidArgument("estimated_variance must be 1-dimensional" , |
1292 | estimated_variance.shape().DebugString())); |
1293 | bool use_reshape = (x.dims() == 5); |
1294 | auto x_shape = x.shape(); |
1295 | TensorShape dest_shape; |
1296 | if (use_reshape) { |
1297 | const int64_t in_batch = GetTensorDim(x, tensor_format_, 'N'); |
1298 | int64_t in_planes = GetTensorDim(x, tensor_format_, '0'); |
1299 | int64_t in_rows = GetTensorDim(x, tensor_format_, '1'); |
1300 | int64_t in_cols = GetTensorDim(x, tensor_format_, '2'); |
1301 | const int64_t in_depth = GetTensorDim(x, tensor_format_, 'C'); |
1302 | dest_shape = ShapeFromFormat(tensor_format_, in_batch, |
1303 | {{in_planes, in_rows * in_cols}}, in_depth); |
1304 | OP_REQUIRES(context, x.CopyFrom(x, dest_shape), |
1305 | errors::InvalidArgument("Error during tensor copy." )); |
1306 | } |
1307 | |
1308 | const auto num_channels = GetTensorDim(x, tensor_format_, 'C'); |
1309 | OP_REQUIRES( |
1310 | context, scale.NumElements() == num_channels, |
1311 | errors::InvalidArgument("scale must have the same number of elements " |
1312 | "as the channels of x, got " , |
1313 | scale.NumElements(), " and " , num_channels)); |
1314 | OP_REQUIRES( |
1315 | context, offset.NumElements() == num_channels, |
1316 | errors::InvalidArgument("offset must have the same number of elements " |
1317 | "as the channels of x, got " , |
1318 | offset.NumElements(), " and " , num_channels)); |
1319 | if (!is_training_ || exponential_avg_factor_ != 1.) { |
1320 | std::string prefix_msg = is_training_ ? "When exponential_avg_factor != 1" |
1321 | : "When is_training=false" ; |
1322 | OP_REQUIRES(context, estimated_mean.NumElements() == num_channels, |
1323 | errors::InvalidArgument( |
1324 | prefix_msg, |
1325 | ", mean must have the same number " |
1326 | "of elements as the channels of x, got " , |
1327 | estimated_mean.NumElements(), " and " , num_channels)); |
1328 | OP_REQUIRES(context, estimated_variance.NumElements() == num_channels, |
1329 | errors::InvalidArgument( |
1330 | prefix_msg, |
1331 | ", variance must have the same " |
1332 | "number of elements as the channels of x, got " , |
1333 | estimated_variance.NumElements(), " and " , num_channels)); |
1334 | } |
1335 | |
1336 | if (has_side_input_) { |
1337 | OP_REQUIRES(context, side_input->shape() == x.shape(), |
1338 | errors::InvalidArgument( |
1339 | "side_input shape must be equal to input shape: " , |
1340 | side_input->shape().DebugString(), |
1341 | " != " , x.shape().DebugString())); |
1342 | } |
1343 | |
1344 | if (activation_mode_ != FbnActivationMode::kIdentity) { |
1345 | // NOTE(ezhulenev): This requirement is coming from implementation |
1346 | // details of cudnnBatchNormalizationForwardTrainingEx. |
1347 | OP_REQUIRES( |
1348 | context, !is_training_ || num_channels % 4 == 0, |
1349 | errors::InvalidArgument("FusedBatchNorm with activation requires " |
1350 | "channel dimension to be a multiple of 4." )); |
1351 | } |
1352 | |
1353 | Tensor* y = nullptr; |
1354 | auto alloc_shape = use_reshape ? dest_shape : x_shape; |
1355 | OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( |
1356 | {0}, 0, alloc_shape, &y)); |
1357 | |
1358 | Tensor* batch_mean = nullptr; |
1359 | OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( |
1360 | {3}, 1, scale.shape(), &batch_mean)); |
1361 | Tensor* batch_var = nullptr; |
1362 | OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( |
1363 | {4}, 2, scale.shape(), &batch_var)); |
1364 | Tensor* saved_mean = nullptr; |
1365 | OP_REQUIRES_OK(context, |
1366 | context->allocate_output(3, scale.shape(), &saved_mean)); |
1367 | Tensor* saved_maybe_inv_var = nullptr; |
1368 | OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(), |
1369 | &saved_maybe_inv_var)); |
1370 | |
1371 | if (is_training_) { |
1372 | functor::FusedBatchNorm<Device, T, U, true>()( |
1373 | context, x, scale, offset, estimated_mean, estimated_variance, |
1374 | side_input, epsilon_, exponential_avg_factor_, activation_mode_, y, |
1375 | batch_mean, batch_var, saved_mean, saved_maybe_inv_var, |
1376 | tensor_format_, use_reserved_space); |
1377 | } else { |
1378 | functor::FusedBatchNorm<Device, T, U, false>()( |
1379 | context, x, scale, offset, estimated_mean, estimated_variance, |
1380 | side_input, epsilon_, exponential_avg_factor_, activation_mode_, y, |
1381 | batch_mean, batch_var, saved_mean, saved_maybe_inv_var, |
1382 | tensor_format_, use_reserved_space); |
1383 | } |
1384 | if (use_reshape) { |
1385 | OP_REQUIRES(context, y->CopyFrom(*y, x_shape), |
1386 | errors::InvalidArgument("Error during tensor copy." )); |
1387 | } |
1388 | } |
1389 | |
1390 | private: |
1391 | U epsilon_; |
1392 | U exponential_avg_factor_; |
1393 | TensorFormat tensor_format_; |
1394 | bool is_training_; |
1395 | bool has_side_input_; |
1396 | FbnActivationMode activation_mode_; |
1397 | }; |
1398 | |
1399 | template <typename Device, typename T, typename U> |
1400 | class FusedBatchNormOp : public FusedBatchNormOpBase<Device, T, U> { |
1401 | public: |
1402 | explicit FusedBatchNormOp(OpKernelConstruction* context) |
1403 | : FusedBatchNormOpBase<Device, T, U>(context) {} |
1404 | |
1405 | void Compute(OpKernelContext* context) override { |
1406 | FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, |
1407 | false); |
1408 | } |
1409 | }; |
1410 | |
1411 | template <typename Device, typename T, typename U> |
1412 | class FusedBatchNormOpV3 : public FusedBatchNormOpBase<Device, T, U> { |
1413 | public: |
1414 | explicit FusedBatchNormOpV3(OpKernelConstruction* context) |
1415 | : FusedBatchNormOpBase<Device, T, U>(context) {} |
1416 | |
1417 | void Compute(OpKernelContext* context) override { |
1418 | FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true); |
1419 | } |
1420 | }; |
1421 | |
1422 | template <typename Device, typename T, typename U> |
1423 | class FusedBatchNormOpEx : public FusedBatchNormOpBase<Device, T, U> { |
1424 | static constexpr bool kWithSideInputAndActivation = true; |
1425 | |
1426 | public: |
1427 | explicit FusedBatchNormOpEx(OpKernelConstruction* context) |
1428 | : FusedBatchNormOpBase<Device, T, U>(context, |
1429 | kWithSideInputAndActivation) {} |
1430 | |
1431 | void Compute(OpKernelContext* context) override { |
1432 | FusedBatchNormOpBase<Device, T, U>::ComputeWithReservedSpace(context, true); |
1433 | } |
1434 | }; |
1435 | |
1436 | template <typename Device, typename T, typename U> |
1437 | class FusedBatchNormGradOpBase : public OpKernel { |
1438 | using FbnActivationMode = functor::FusedBatchNormActivationMode; |
1439 | |
1440 | protected: |
1441 | explicit FusedBatchNormGradOpBase(OpKernelConstruction* context, |
1442 | bool is_batch_norm_grad_ex = false) |
1443 | : OpKernel(context) { |
1444 | float epsilon; |
1445 | OP_REQUIRES_OK(context, context->GetAttr("epsilon" , &epsilon)); |
1446 | epsilon_ = U(epsilon); |
1447 | string tensor_format; |
1448 | OP_REQUIRES_OK(context, context->GetAttr("data_format" , &tensor_format)); |
1449 | OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_), |
1450 | errors::InvalidArgument("Invalid data format" )); |
1451 | OP_REQUIRES_OK(context, context->GetAttr("is_training" , &is_training_)); |
1452 | if (!is_batch_norm_grad_ex) { |
1453 | has_side_input_ = false; |
1454 | activation_mode_ = FbnActivationMode::kIdentity; |
1455 | } else { |
1456 | OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_)); |
1457 | |
1458 | int num_side_inputs; |
1459 | OP_REQUIRES_OK(context, |
1460 | context->GetAttr("num_side_inputs" , &num_side_inputs)); |
1461 | OP_REQUIRES(context, num_side_inputs >= 0 && num_side_inputs <= 1, |
1462 | errors::InvalidArgument( |
1463 | "FusedBatchNormGrad accepts at most one side input." )); |
1464 | has_side_input_ = (num_side_inputs == 1); |
1465 | if (has_side_input_ && is_training_) { |
1466 | OP_REQUIRES( |
1467 | context, activation_mode_ != FbnActivationMode::kIdentity, |
1468 | errors::InvalidArgument("Identity activation is not supported with " |
1469 | "non-empty side input" )); |
1470 | } |
1471 | } |
1472 | |
1473 | if (activation_mode_ != FbnActivationMode::kIdentity && is_training_) { |
1474 | // NOTE(kaixih@nvidia): Following requirements are coming from |
1475 | // implementation details of cudnnBatchNormalizationBackwardEx used in |
1476 | // training mode. |
1477 | OP_REQUIRES(context, DataTypeToEnum<T>::value == DT_HALF, |
1478 | errors::InvalidArgument("FusedBatchNormGrad with activation " |
1479 | "supports only DT_HALF data type." )); |
1480 | OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC, |
1481 | errors::InvalidArgument("FusedBatchNormGrad with activation " |
1482 | "supports only NHWC tensor format." )); |
1483 | OP_REQUIRES(context, functor::BatchnormSpatialPersistentEnabled(), |
1484 | errors::InvalidArgument( |
1485 | "FusedBatchNormGrad with activation must run with cuDNN " |
1486 | "spatial persistence mode enabled." )); |
1487 | } |
1488 | } |
1489 | |
1490 | virtual void ComputeWithReservedSpace(OpKernelContext* context, |
1491 | bool use_reserved_space) { |
1492 | Tensor y_backprop = context->input(0); |
1493 | Tensor x = context->input(1); |
1494 | const Tensor& scale = context->input(2); |
1495 | // When is_training=True, batch mean and variance/inverted variance are |
1496 | // saved in the forward pass to be reused here. When is_training=False, |
1497 | // population mean and variance need to be forwarded here to compute the |
1498 | // gradients. |
1499 | const Tensor& saved_mean_or_pop_mean = context->input(3); |
1500 | // The Eigen implementation saves variance in the forward pass, while cuDNN |
1501 | // saves inverted variance. |
1502 | const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4); |
1503 | bool use_activation = activation_mode_ != FbnActivationMode::kIdentity; |
1504 | const Tensor* offset = use_activation ? &context->input(6) : nullptr; |
1505 | const Tensor* y = use_activation ? &context->input(7) : nullptr; |
1506 | |
1507 | OP_REQUIRES(context, y_backprop.dims() == 4 || y_backprop.dims() == 5, |
1508 | errors::InvalidArgument("input must be 4 or 5-dimensional" , |
1509 | y_backprop.shape().DebugString())); |
1510 | OP_REQUIRES(context, x.dims() == 4 || x.dims() == 5, |
1511 | errors::InvalidArgument("input must be 4 or 5-dimensional" , |
1512 | x.shape().DebugString())); |
1513 | OP_REQUIRES(context, scale.dims() == 1, |
1514 | errors::InvalidArgument("scale must be 1-dimensional" , |
1515 | scale.shape().DebugString())); |
1516 | OP_REQUIRES( |
1517 | context, saved_mean_or_pop_mean.dims() == 1, |
1518 | errors::InvalidArgument("saved mean must be 1-dimensional" , |
1519 | saved_mean_or_pop_mean.shape().DebugString())); |
1520 | OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1, |
1521 | errors::InvalidArgument( |
1522 | "saved variance must be 1-dimensional" , |
1523 | saved_maybe_inv_var_or_pop_var.shape().DebugString())); |
1524 | OP_REQUIRES( |
1525 | context, x.shape() == y_backprop.shape(), |
1526 | errors::InvalidArgument( |
1527 | "x and y_backprop must have same shape, but x has shape " , |
1528 | x.shape(), " and y_backprop has shape " , y_backprop.shape())); |
1529 | if (use_activation) { |
1530 | OP_REQUIRES( |
1531 | context, x.dim_size(3) % 4 == 0, |
1532 | errors::InvalidArgument("FusedBatchNormGrad with activation requires " |
1533 | "channel dimension to be a multiple of 4." )); |
1534 | OP_REQUIRES(context, offset->dims() == 1, |
1535 | errors::InvalidArgument("offset must be 1-dimensional" , |
1536 | offset->shape().DebugString())); |
1537 | } |
1538 | bool use_reshape = (x.dims() == 5); |
1539 | auto x_shape = x.shape(); |
1540 | TensorShape dest_shape; |
1541 | if (use_reshape) { |
1542 | const int64_t in_batch = GetTensorDim(x, tensor_format_, 'N'); |
1543 | int64_t in_planes = GetTensorDim(x, tensor_format_, '0'); |
1544 | int64_t in_rows = GetTensorDim(x, tensor_format_, '1'); |
1545 | int64_t in_cols = GetTensorDim(x, tensor_format_, '2'); |
1546 | const int64_t in_depth = GetTensorDim(x, tensor_format_, 'C'); |
1547 | dest_shape = ShapeFromFormat(tensor_format_, in_batch, |
1548 | {{in_planes, in_rows * in_cols}}, in_depth); |
1549 | OP_REQUIRES(context, x.CopyFrom(x, dest_shape), |
1550 | errors::InvalidArgument("Error during tensor copy." )); |
1551 | OP_REQUIRES(context, y_backprop.CopyFrom(y_backprop, dest_shape), |
1552 | errors::InvalidArgument("Error during tensor copy." )); |
1553 | } |
1554 | |
1555 | const auto num_channels = GetTensorDim(x, tensor_format_, 'C'); |
1556 | OP_REQUIRES( |
1557 | context, scale.NumElements() == num_channels, |
1558 | errors::InvalidArgument("scale must have the same number of elements " |
1559 | "as the channels of x, got " , |
1560 | scale.NumElements(), " and " , num_channels)); |
1561 | OP_REQUIRES( |
1562 | context, saved_mean_or_pop_mean.NumElements() == num_channels, |
1563 | errors::InvalidArgument("reserve_space_1 must have the same number of " |
1564 | "elements as the channels of x, got " , |
1565 | saved_mean_or_pop_mean.NumElements(), " and " , |
1566 | num_channels)); |
1567 | OP_REQUIRES( |
1568 | context, saved_maybe_inv_var_or_pop_var.NumElements() == num_channels, |
1569 | errors::InvalidArgument("reserve_space_2 must have the same number of " |
1570 | "elements as the channels of x, got " , |
1571 | saved_maybe_inv_var_or_pop_var.NumElements(), |
1572 | " and " , num_channels)); |
1573 | |
1574 | Tensor* x_backprop = nullptr; |
1575 | auto alloc_shape = use_reshape ? dest_shape : x_shape; |
1576 | OP_REQUIRES_OK(context, |
1577 | context->allocate_output(0, alloc_shape, &x_backprop)); |
1578 | |
1579 | const TensorShape& scale_offset_shape = scale.shape(); |
1580 | Tensor* scale_backprop = nullptr; |
1581 | OP_REQUIRES_OK(context, context->allocate_output(1, scale_offset_shape, |
1582 | &scale_backprop)); |
1583 | Tensor* offset_backprop = nullptr; |
1584 | OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape, |
1585 | &offset_backprop)); |
1586 | // Two placeholders for estimated_mean and estimated_variance, which are |
1587 | // used for inference and thus not needed here for gradient computation. |
1588 | // They are filled with zeros so as to avoid NaN outputs. |
1589 | Tensor* placeholder_1 = nullptr; |
1590 | OP_REQUIRES_OK( |
1591 | context, context->allocate_output(3, TensorShape({0}), &placeholder_1)); |
1592 | Tensor* placeholder_2 = nullptr; |
1593 | OP_REQUIRES_OK( |
1594 | context, context->allocate_output(4, TensorShape({0}), &placeholder_2)); |
1595 | |
1596 | Tensor* side_input_backprop = nullptr; |
1597 | if (has_side_input_) { |
1598 | OP_REQUIRES_OK(context, context->allocate_output(5, alloc_shape, |
1599 | &side_input_backprop)); |
1600 | } |
1601 | |
1602 | // If input is empty, set gradients w.r.t scale/offset to zero. |
1603 | if (x.shape().num_elements() == 0) { |
1604 | functor::SetZeroFunctor<Device, U> f; |
1605 | f(context->eigen_device<Device>(), scale_backprop->flat<U>()); |
1606 | f(context->eigen_device<Device>(), offset_backprop->flat<U>()); |
1607 | return; |
1608 | } |
1609 | |
1610 | if (is_training_) { |
1611 | functor::FusedBatchNormGrad<Device, T, U>()( |
1612 | context, y_backprop, x, scale, offset, saved_mean_or_pop_mean, |
1613 | saved_maybe_inv_var_or_pop_var, y, epsilon_, activation_mode_, |
1614 | x_backprop, scale_backprop, offset_backprop, side_input_backprop, |
1615 | use_reserved_space, tensor_format_); |
1616 | } else { |
1617 | OP_REQUIRES( |
1618 | context, |
1619 | activation_mode_ == FbnActivationMode::kIdentity && !has_side_input_, |
1620 | errors::InvalidArgument( |
1621 | "FusedBatchNormGrad with activation is only supported " |
1622 | "when is_training=True." )); |
1623 | // Necessary layout conversion is currently done in python. |
1624 | OP_REQUIRES(context, tensor_format_ == FORMAT_NHWC, |
1625 | errors::InvalidArgument( |
1626 | "The implementation of " |
1627 | "FusedBatchNormGrad with is_training=False only support " |
1628 | "NHWC tensor format for now." )); |
1629 | functor::FusedBatchNormFreezeGrad<Device, T, U>()( |
1630 | context, y_backprop, x, scale, saved_mean_or_pop_mean, |
1631 | saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop, |
1632 | offset_backprop); |
1633 | } |
1634 | if (use_reshape) { |
1635 | OP_REQUIRES(context, x_backprop->CopyFrom(*x_backprop, x_shape), |
1636 | errors::InvalidArgument("Error during tensor copy." )); |
1637 | } |
1638 | } |
1639 | |
1640 | private: |
1641 | U epsilon_; |
1642 | TensorFormat tensor_format_; |
1643 | bool is_training_; |
1644 | bool has_side_input_; |
1645 | FbnActivationMode activation_mode_; |
1646 | }; |
1647 | |
1648 | template <typename Device, typename T, typename U> |
1649 | class FusedBatchNormGradOp : public FusedBatchNormGradOpBase<Device, T, U> { |
1650 | public: |
1651 | explicit FusedBatchNormGradOp(OpKernelConstruction* context) |
1652 | : FusedBatchNormGradOpBase<Device, T, U>(context) {} |
1653 | |
1654 | void Compute(OpKernelContext* context) override { |
1655 | FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context, |
1656 | false); |
1657 | } |
1658 | }; |
1659 | |
1660 | template <typename Device, typename T, typename U> |
1661 | class FusedBatchNormGradOpV3 : public FusedBatchNormGradOpBase<Device, T, U> { |
1662 | public: |
1663 | explicit FusedBatchNormGradOpV3(OpKernelConstruction* context) |
1664 | : FusedBatchNormGradOpBase<Device, T, U>(context) {} |
1665 | |
1666 | void Compute(OpKernelContext* context) override { |
1667 | FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context, |
1668 | true); |
1669 | } |
1670 | }; |
1671 | |
1672 | template <typename Device, typename T, typename U> |
1673 | class FusedBatchNormGradOpEx : public FusedBatchNormGradOpBase<Device, T, U> { |
1674 | static constexpr bool kWithSideInputAndActivation = true; |
1675 | |
1676 | public: |
1677 | explicit FusedBatchNormGradOpEx(OpKernelConstruction* context) |
1678 | : FusedBatchNormGradOpBase<Device, T, U>(context, |
1679 | kWithSideInputAndActivation) {} |
1680 | |
1681 | void Compute(OpKernelContext* context) override { |
1682 | FusedBatchNormGradOpBase<Device, T, U>::ComputeWithReservedSpace(context, |
1683 | true); |
1684 | } |
1685 | }; |
1686 | |
1687 | REGISTER_KERNEL_BUILDER( |
1688 | Name("FusedBatchNorm" ).Device(DEVICE_CPU).TypeConstraint<float>("T" ), |
1689 | FusedBatchNormOp<CPUDevice, float, float>); |
1690 | |
1691 | REGISTER_KERNEL_BUILDER( |
1692 | Name("FusedBatchNormGrad" ).Device(DEVICE_CPU).TypeConstraint<float>("T" ), |
1693 | FusedBatchNormGradOp<CPUDevice, float, float>); |
1694 | |
1695 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2" ) |
1696 | .Device(DEVICE_CPU) |
1697 | .TypeConstraint<float>("T" ) |
1698 | .TypeConstraint<float>("U" ), |
1699 | FusedBatchNormOp<CPUDevice, float, float>); |
1700 | |
1701 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2" ) |
1702 | .Device(DEVICE_CPU) |
1703 | .TypeConstraint<float>("T" ) |
1704 | .TypeConstraint<float>("U" ), |
1705 | FusedBatchNormGradOp<CPUDevice, float, float>); |
1706 | |
1707 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2" ) |
1708 | .Device(DEVICE_CPU) |
1709 | .TypeConstraint<Eigen::half>("T" ) |
1710 | .TypeConstraint<float>("U" ), |
1711 | FusedBatchNormOp<CPUDevice, Eigen::half, float>); |
1712 | |
1713 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2" ) |
1714 | .Device(DEVICE_CPU) |
1715 | .TypeConstraint<Eigen::half>("T" ) |
1716 | .TypeConstraint<float>("U" ), |
1717 | FusedBatchNormGradOp<CPUDevice, Eigen::half, float>); |
1718 | |
1719 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3" ) |
1720 | .Device(DEVICE_CPU) |
1721 | .TypeConstraint<float>("T" ) |
1722 | .TypeConstraint<float>("U" ), |
1723 | FusedBatchNormOpV3<CPUDevice, float, float>); |
1724 | |
1725 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3" ) |
1726 | .Device(DEVICE_CPU) |
1727 | .TypeConstraint<float>("T" ) |
1728 | .TypeConstraint<float>("U" ), |
1729 | FusedBatchNormGradOpV3<CPUDevice, float, float>); |
1730 | |
1731 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3" ) |
1732 | .Device(DEVICE_CPU) |
1733 | .TypeConstraint<Eigen::half>("T" ) |
1734 | .TypeConstraint<float>("U" ), |
1735 | FusedBatchNormOpV3<CPUDevice, Eigen::half, float>); |
1736 | |
1737 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3" ) |
1738 | .Device(DEVICE_CPU) |
1739 | .TypeConstraint<Eigen::half>("T" ) |
1740 | .TypeConstraint<float>("U" ), |
1741 | FusedBatchNormGradOpV3<CPUDevice, Eigen::half, float>); |
1742 | |
1743 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1744 | |
1745 | REGISTER_KERNEL_BUILDER( |
1746 | Name("FusedBatchNorm" ).Device(DEVICE_GPU).TypeConstraint<float>("T" ), |
1747 | FusedBatchNormOp<GPUDevice, float, float>); |
1748 | |
1749 | REGISTER_KERNEL_BUILDER( |
1750 | Name("FusedBatchNormGrad" ).Device(DEVICE_GPU).TypeConstraint<float>("T" ), |
1751 | FusedBatchNormGradOp<GPUDevice, float, float>); |
1752 | |
1753 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2" ) |
1754 | .Device(DEVICE_GPU) |
1755 | .TypeConstraint<float>("T" ) |
1756 | .TypeConstraint<float>("U" ), |
1757 | FusedBatchNormOp<GPUDevice, float, float>); |
1758 | |
1759 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2" ) |
1760 | .Device(DEVICE_GPU) |
1761 | .TypeConstraint<float>("T" ) |
1762 | .TypeConstraint<float>("U" ), |
1763 | FusedBatchNormGradOp<GPUDevice, float, float>); |
1764 | |
1765 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2" ) |
1766 | .Device(DEVICE_GPU) |
1767 | .TypeConstraint<Eigen::half>("T" ) |
1768 | .TypeConstraint<float>("U" ), |
1769 | FusedBatchNormOp<GPUDevice, Eigen::half, float>); |
1770 | |
1771 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2" ) |
1772 | .Device(DEVICE_GPU) |
1773 | .TypeConstraint<Eigen::half>("T" ) |
1774 | .TypeConstraint<float>("U" ), |
1775 | FusedBatchNormGradOp<GPUDevice, Eigen::half, float>); |
1776 | |
1777 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3" ) |
1778 | .Device(DEVICE_GPU) |
1779 | .TypeConstraint<float>("T" ) |
1780 | .TypeConstraint<float>("U" ), |
1781 | FusedBatchNormOpV3<GPUDevice, float, float>); |
1782 | |
1783 | REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx" ) |
1784 | .Device(DEVICE_GPU) |
1785 | .TypeConstraint<float>("T" ) |
1786 | .TypeConstraint<float>("U" ), |
1787 | FusedBatchNormOpEx<GPUDevice, float, float>); |
1788 | |
1789 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3" ) |
1790 | .Device(DEVICE_GPU) |
1791 | .TypeConstraint<float>("T" ) |
1792 | .TypeConstraint<float>("U" ), |
1793 | FusedBatchNormGradOpV3<GPUDevice, float, float>); |
1794 | |
1795 | REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormGradEx" ) |
1796 | .Device(DEVICE_GPU) |
1797 | .TypeConstraint<float>("T" ) |
1798 | .TypeConstraint<float>("U" ), |
1799 | FusedBatchNormGradOpEx<GPUDevice, float, float>); |
1800 | |
1801 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV3" ) |
1802 | .Device(DEVICE_GPU) |
1803 | .TypeConstraint<Eigen::half>("T" ) |
1804 | .TypeConstraint<float>("U" ), |
1805 | FusedBatchNormOpV3<GPUDevice, Eigen::half, float>); |
1806 | |
1807 | REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx" ) |
1808 | .Device(DEVICE_GPU) |
1809 | .TypeConstraint<Eigen::half>("T" ) |
1810 | .TypeConstraint<float>("U" ), |
1811 | FusedBatchNormOpEx<GPUDevice, Eigen::half, float>); |
1812 | |
1813 | REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV3" ) |
1814 | .Device(DEVICE_GPU) |
1815 | .TypeConstraint<Eigen::half>("T" ) |
1816 | .TypeConstraint<float>("U" ), |
1817 | FusedBatchNormGradOpV3<GPUDevice, Eigen::half, float>); |
1818 | |
1819 | REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormGradEx" ) |
1820 | .Device(DEVICE_GPU) |
1821 | .TypeConstraint<Eigen::half>("T" ) |
1822 | .TypeConstraint<float>("U" ), |
1823 | FusedBatchNormGradOpEx<GPUDevice, Eigen::half, float>); |
1824 | |
1825 | #endif |
1826 | |
1827 | } // namespace tensorflow |
1828 | |