1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
17#define TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
18// Functor definition for BatchNormOp, must be compilable by nvcc.
19#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20#include "tensorflow/core/framework/tensor_types.h"
21
22namespace tensorflow {
23namespace functor {
24
25// Functor used by BatchNormOp to do the computations.
26template <typename Device, typename T>
27struct BatchNorm {
28 void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
29 typename TTypes<T>::ConstVec mean,
30 typename TTypes<T>::ConstVec var,
31 typename TTypes<T>::ConstVec beta,
32 typename TTypes<T>::ConstVec gamma, T variance_epsilon,
33 bool scale_after_normalization,
34 typename TTypes<T, 4>::Tensor output) {
35 const int depth = mean.dimension(0);
36 const int rest_size = input.size() / depth;
37
38 Eigen::DSizes<int, 2> rest_by_depth(rest_size, depth);
39 Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one;
40 rest_by_one.set(0, rest_size);
41 Eigen::IndexList<Eigen::type2index<1>, int> one_by_depth;
42 one_by_depth.set(1, depth);
43 Eigen::IndexList<int, Eigen::type2index<1> > depth_by_one;
44 depth_by_one.set(0, depth);
45 if (scale_after_normalization) {
46 output.reshape(rest_by_depth).device(d) =
47 (input.reshape(rest_by_depth) -
48 mean.reshape(one_by_depth).broadcast(rest_by_one)) *
49 ((var + var.constant(variance_epsilon)).rsqrt() * gamma)
50 .eval()
51 .reshape(one_by_depth)
52 .broadcast(rest_by_one) +
53 beta.reshape(one_by_depth).broadcast(rest_by_one);
54 } else {
55 output.reshape(rest_by_depth).device(d) =
56 (input.reshape(rest_by_depth) -
57 mean.reshape(one_by_depth).broadcast(rest_by_one)) *
58 ((var + var.constant(variance_epsilon)).rsqrt())
59 .eval()
60 .reshape(one_by_depth)
61 .broadcast(rest_by_one) +
62 beta.reshape(one_by_depth).broadcast(rest_by_one);
63 }
64 }
65};
66
// Functor used by BatchNormGradOp to compute the gradients of batch
// normalization with respect to the input (dx), mean (dm), variance (dv),
// beta (db) and gamma (dg).
//
// NOTE(review): the statements below are strictly order-dependent —
// scratch1 first holds rsqrt(var + epsilon) and is consumed by the dx/dm/dg
// computations before being overwritten in place for the dv computation.
template <typename Device, typename T>
struct BatchNormGrad {
  // d: the device to run the Eigen expressions on.
  // input: 4-D forward input; the last dimension is the depth (channel) axis.
  // mean/var/gamma: per-channel vectors of length `depth`.
  // out_backprop: gradient flowing in from above, same shape as input.
  // variance_epsilon: small constant added to var for numerical stability.
  // scale_after_normalization: whether the forward pass multiplied by gamma.
  // dx: output gradient w.r.t. input (same shape as input).
  // dm, dv, db, dg: output per-channel gradients w.r.t. mean, variance,
  //   beta and gamma respectively.
  // scratch1, scratch2: caller-provided per-channel temporaries; their
  //   contents on entry are ignored and clobbered.
  void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T>::ConstVec mean,
                  typename TTypes<T>::ConstVec var,
                  typename TTypes<T>::ConstVec gamma,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  T variance_epsilon, bool scale_after_normalization,
                  typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm,
                  typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db,
                  typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1,
                  typename TTypes<T>::Vec scratch2) {
    const int depth = mean.dimension(0);
    const int rest_size = input.size() / depth;

    typedef typename TTypes<T>::ConstVec::Index Index;

    // View the 4-D tensors as 2-D [rest_size, depth] so per-channel vectors
    // can be broadcast across all non-depth dimensions, and reductions over
    // the non-depth dimensions become a single-axis sum.
    Eigen::DSizes<Index, 2> rest_by_depth(rest_size, depth);
    Eigen::IndexList<Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
    Eigen::IndexList<Eigen::type2index<1>, Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;

    // db = out_backprop
    //
    // dg = out_backprop * ((x - m) * rsqrt(v + epsilon))
    //
    // dv = sum_over_rest(out_backprop * gamma * (x - m)) *
    //      (-1/2) * (v + epsilon) ^ (-3/2)
    //
    // dm = sum_over_rest(out_backprop * gamma) * (-1 / rsqrt(v + epsilon))
    //
    // dx = out_backprop * (gamma * rsqrt(v + epsilon))
    db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis);

    // scratch1 = rsqrt(v + epsilon)
    scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt();

    // scratch2 = sum_over_rest(out_backprop * (x - m))
    scratch2.device(d) = (out_backprop.reshape(rest_by_depth) *
                          (input.reshape(rest_by_depth) -
                           mean.reshape(one_by_depth).broadcast(rest_by_one)))
                             .sum(reduction_axis);

    if (scale_after_normalization) {
      // .eval() materializes the small per-channel product once so it is not
      // recomputed for every broadcast element.
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma)
                                                     .eval()
                                                     .reshape(one_by_depth)
                                                     .broadcast(rest_by_one));
      dm.device(d) = -db * (scratch1 * gamma).eval();
      dg.device(d) = scratch2 * scratch1;
    } else {
      // Forward pass did not scale by gamma, so gamma drops out of dx/dm.
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) *
          scratch1.reshape(one_by_depth).broadcast(rest_by_one);
      dm.device(d) = -db * scratch1;
      dg.device(d) = dg.constant(static_cast<T>(0.0));  // Gamma is not learned.
    }

    // Repurpose scratch1 — it must not be read as rsqrt(v + epsilon) below.
    // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2)
    scratch1.device(d) = scratch1 * scratch1.constant(static_cast<T>(-0.5f)) /
                         (var + var.constant(variance_epsilon));

    if (scale_after_normalization) {
      dv.device(d) = scratch2 * (scratch1 * gamma).eval();
    } else {
      dv.device(d) = scratch2 * scratch1;
    }
  }
};
139
140} // namespace functor
141} // namespace tensorflow
142
143#endif // TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
144