/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/batch_norm_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

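// Kernel for the BatchNormWithGlobalNormalization op: normalizes the input
// with precomputed ("global") per-depth statistics,
//   output = (input - mean) * gamma / sqrt(var + variance_epsilon) + beta,
// with the gamma factor applied only when scale_after_normalization is true.
// The elementwise math lives in functor::BatchNorm (batch_norm_op.h); this
// kernel validates shapes and dispatches to it.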
template <typename Device, typename T>
class BatchNormOp : public OpKernel {
 public:
  explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& beta = context->input(3);
    const Tensor& gamma = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, beta.dims() == 1,
                errors::InvalidArgument("beta must be 1-dimensional",
                                        beta.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    functor::BatchNorm<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), beta.vec<T>(), gamma.vec<T>(), variance_epsilon_,
        scale_after_normalization_, output->tensor<T, 4>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

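// Gradient of the forward op above. Given out_backprop (the gradient with
// respect to the op's output), it produces gradients with respect to input,
// mean, var, beta, and gamma, emitted as outputs dx, dm, dv, db, and dg.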
template <typename Device, typename T>
class BatchNormGradOp : public OpKernel {
 public:
  explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& gamma = context->input(3);
    const Tensor& out_backprop = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional",
                                        out_backprop.shape().DebugString()));

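    // forward_input_or_allocate_output reuses the buffer of one of the listed
    // candidate inputs for the output when that buffer is exclusively owned,
    // avoiding a fresh allocation; otherwise it allocates a new output of the
    // given shape.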
    Tensor* dx = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0, 4}, 0, input.shape(), &dx));
    Tensor* dm = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {1}, 1, mean.shape(), &dm));
    Tensor* dv = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {2}, 2, var.shape(), &dv));
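    // Input 3 (gamma) is forwarded into db only when scale_after_normalization
    // is false, since in that case the gradient functor does not read gamma;
    // when scaling is enabled, db gets a fresh buffer so gamma stays intact.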
    Tensor* db = nullptr;
    if (scale_after_normalization_) {
      OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
    } else {
      OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                  {3}, 3, mean.shape(), &db));
    }
    Tensor* dg = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));

    // Scratch buffer of [depth] dimension (the 4th dimension of input,
    // dim_size(3)) used to hold various combinations of (var + epsilon).
    Tensor scratch1;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch1));

    // Scratch buffer of [depth] dimension for saving intermediate calculation
    // values.
    Tensor scratch2;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch2));

    functor::BatchNormGrad<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), gamma.vec<T>(), out_backprop.tensor<T, 4>(),
        variance_epsilon_, scale_after_normalization_, dx->tensor<T, 4>(),
        dm->vec<T>(), dv->vec<T>(), db->vec<T>(), dg->vec<T>(),
        scratch1.vec<T>(), scratch2.vec<T>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

#define REGISTER_KERNEL(T)                                          \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization")  \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T"),              \
                          BatchNormOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
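// The explicit instantiations behind these declarations live in the
// separately compiled GPU build unit (conventionally batch_norm_op_gpu.cu.cc);
// the extern template declarations keep this file from instantiating them.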
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                   \
  template <>                                                                 \
  void BatchNorm<GPUDevice, T>::operator()(                                   \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,           \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var,    \
      typename TTypes<T>::ConstVec beta, typename TTypes<T>::ConstVec gamma,  \
      T variance_epsilon, bool scale_after_normalization,                     \
      typename TTypes<T, 4>::Tensor output);                                  \
  extern template struct BatchNorm<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                       \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization")   \
                              .Device(DEVICE_GPU)                    \
                              .TypeConstraint<T>("T"),               \
                          BatchNormOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_KERNEL(T)                                               \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad")   \
                              .Device(DEVICE_CPU)                        \
                              .TypeConstraint<T>("T"),                   \
                          BatchNormGradOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                   \
  template <>                                                                 \
  void BatchNormGrad<GPUDevice, T>::operator()(                               \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,           \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var,    \
      typename TTypes<T>::ConstVec gamma,                                     \
      typename TTypes<T, 4>::ConstTensor out_backprop, T variance_epsilon,    \
      bool scale_after_normalization, typename TTypes<T, 4>::Tensor dx,       \
      typename TTypes<T>::Vec dm, typename TTypes<T>::Vec dv,                 \
      typename TTypes<T>::Vec db, typename TTypes<T>::Vec dg,                 \
      typename TTypes<T>::Vec scratch1, typename TTypes<T>::Vec scratch2);    \
  extern template struct BatchNormGrad<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad")    \
                              .Device(DEVICE_GPU)                         \
                              .TypeConstraint<T>("T"),                    \
                          BatchNormGradOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow