1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_ |
18 | // Functor definition for BatchNormOp, must be compilable by nvcc. |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/framework/tensor_types.h" |
21 | |
22 | namespace tensorflow { |
23 | namespace functor { |
24 | |
25 | // Functor used by BatchNormOp to do the computations. |
26 | template <typename Device, typename T> |
27 | struct BatchNorm { |
28 | void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input, |
29 | typename TTypes<T>::ConstVec mean, |
30 | typename TTypes<T>::ConstVec var, |
31 | typename TTypes<T>::ConstVec beta, |
32 | typename TTypes<T>::ConstVec gamma, T variance_epsilon, |
33 | bool scale_after_normalization, |
34 | typename TTypes<T, 4>::Tensor output) { |
35 | const int depth = mean.dimension(0); |
36 | const int rest_size = input.size() / depth; |
37 | |
38 | Eigen::DSizes<int, 2> rest_by_depth(rest_size, depth); |
39 | Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one; |
40 | rest_by_one.set(0, rest_size); |
41 | Eigen::IndexList<Eigen::type2index<1>, int> one_by_depth; |
42 | one_by_depth.set(1, depth); |
43 | Eigen::IndexList<int, Eigen::type2index<1> > depth_by_one; |
44 | depth_by_one.set(0, depth); |
45 | if (scale_after_normalization) { |
46 | output.reshape(rest_by_depth).device(d) = |
47 | (input.reshape(rest_by_depth) - |
48 | mean.reshape(one_by_depth).broadcast(rest_by_one)) * |
49 | ((var + var.constant(variance_epsilon)).rsqrt() * gamma) |
50 | .eval() |
51 | .reshape(one_by_depth) |
52 | .broadcast(rest_by_one) + |
53 | beta.reshape(one_by_depth).broadcast(rest_by_one); |
54 | } else { |
55 | output.reshape(rest_by_depth).device(d) = |
56 | (input.reshape(rest_by_depth) - |
57 | mean.reshape(one_by_depth).broadcast(rest_by_one)) * |
58 | ((var + var.constant(variance_epsilon)).rsqrt()) |
59 | .eval() |
60 | .reshape(one_by_depth) |
61 | .broadcast(rest_by_one) + |
62 | beta.reshape(one_by_depth).broadcast(rest_by_one); |
63 | } |
64 | } |
65 | }; |
66 | |
// Functor used by BatchNormGradOp to compute the gradients of batch
// normalization w.r.t. its inputs.  All outputs are written via expressions
// evaluated on device `d`; scratch1/scratch2 are caller-provided per-depth
// temporaries whose contents are clobbered.
template <typename Device, typename T>
struct BatchNormGrad {
  // Args:
  //   input: the original 4-D input x (depth-innermost).
  //   mean, var: per-depth batch statistics used in the forward pass.
  //   gamma: per-depth scale (used only when scale_after_normalization).
  //   out_backprop: gradient w.r.t. the forward output, same shape as input.
  //   variance_epsilon: stability constant matching the forward pass.
  //   dx: gradient w.r.t. input (4-D).
  //   dm, dv, db, dg: per-depth gradients w.r.t. mean, var, beta, gamma.
  //   scratch1, scratch2: per-depth scratch vectors (overwritten).
  void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T>::ConstVec mean,
                  typename TTypes<T>::ConstVec var,
                  typename TTypes<T>::ConstVec gamma,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  T variance_epsilon, bool scale_after_normalization,
                  typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm,
                  typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db,
                  typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1,
                  typename TTypes<T>::Vec scratch2) {
    const int depth = mean.dimension(0);
    const int rest_size = input.size() / depth;

    typedef typename TTypes<T>::ConstVec::Index Index;

    // View the 4-D tensors as (rest_size, depth) matrices; per-depth vectors
    // reshape to (1, depth) and broadcast over the rest_size rows.  Sums are
    // reduced over axis 0 (the rest dimension), yielding per-depth vectors.
    Eigen::DSizes<Index, 2> rest_by_depth(rest_size, depth);
    Eigen::IndexList<Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
    Eigen::IndexList<Eigen::type2index<1>, Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;

    // db = out_backprop
    //
    // dg = out_backprop * ((x - m) * rsqrt(v + epsilon))
    //
    // dv = sum_over_rest(out_backprop * gamma * (x - m)) *
    //      (-1/2) * (v + epsilon) ^ (-3/2)
    //
    // dm = sum_over_rest(out_backprop * gamma) * (-1 / rsqrt(v + epsilon))
    //
    // dx = out_backprop * (gamma * rsqrt(v + epsilon))
    db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis);

    // scratch1 = rsqrt(v + epsilon)
    // (Reused below for dx/dm/dg; overwritten afterwards — order matters.)
    scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt();

    // scratch2 = sum_over_rest(out_backprop * (x - m))
    scratch2.device(d) = (out_backprop.reshape(rest_by_depth) *
                          (input.reshape(rest_by_depth) -
                           mean.reshape(one_by_depth).broadcast(rest_by_one)))
                             .sum(reduction_axis);

    if (scale_after_normalization) {
      // .eval() materializes the per-depth product once rather than per row.
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma)
                                                     .eval()
                                                     .reshape(one_by_depth)
                                                     .broadcast(rest_by_one));
      // db already holds sum_over_rest(out_backprop), so dm reuses it.
      dm.device(d) = -db * (scratch1 * gamma).eval();
      dg.device(d) = scratch2 * scratch1;
    } else {
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) *
          scratch1.reshape(one_by_depth).broadcast(rest_by_one);
      dm.device(d) = -db * scratch1;
      dg.device(d) = dg.constant(static_cast<T>(0.0));  // Gamma is not learned.
    }

    // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2)
    // Computed from the old scratch1 = rsqrt(v+eps) by dividing by (v+eps);
    // this overwrite must come after all uses of the old value above.
    scratch1.device(d) = scratch1 * scratch1.constant(static_cast<T>(-0.5f)) /
                         (var + var.constant(variance_epsilon));

    if (scale_after_normalization) {
      dv.device(d) = scratch2 * (scratch1 * gamma).eval();
    } else {
      dv.device(d) = scratch2 * scratch1;
    }
  }
};
139 | |
140 | } // namespace functor |
141 | } // namespace tensorflow |
142 | |
143 | #endif // TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_ |
144 | |