1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19#include "tensorflow/core/framework/numeric_op.h"
20#include "tensorflow/core/framework/op_kernel.h"
21#include "tensorflow/core/framework/register_types.h"
22#include "tensorflow/core/framework/tensor.h"
23#include "tensorflow/core/kernels/quantization_utils.h"
24
25namespace tensorflow {
26
27namespace {
28
29// A slow but straightforward implementation of batch normalization.
30template <typename T1, typename T2>
31void ReferenceBatchNorm(const Tensor& input, const float input_min,
32 const float input_max, const Tensor& mean,
33 float mean_min, float mean_max, const Tensor& var,
34 float var_min, float var_max, const Tensor& beta,
35 float beta_min, float beta_max, const Tensor& gamma,
36 float gamma_min, float gamma_max,
37 float variance_epsilon, bool scale_after_normalization,
38 Tensor* output, float* output_min, float* output_max) {
39 auto input_flat = input.flat<T1>();
40 auto mean_flat = mean.flat<T1>();
41 auto var_flat = var.flat<T1>();
42 auto beta_flat = beta.flat<T1>();
43 auto gamma_flat = gamma.flat<T1>();
44 auto output_flat = output->flat<T2>();
45
46 const int depth = mean.dim_size(0);
47 const int row_count = input_flat.size() / depth;
48
49 *output_min = std::numeric_limits<float>::max();
50 *output_max = std::numeric_limits<float>::lowest();
51 for (int pass = 0; pass < 2; ++pass) {
52 const bool is_range_pass = (pass == 0);
53 for (int row_index = 0; row_index < row_count; ++row_index) {
54 for (int channel = 0; channel < depth; ++channel) {
55 const int input_index = (row_index * depth) + channel;
56 const float input_value =
57 QuantizedToFloat(input_flat(input_index), input_min, input_max);
58 const float mean_value =
59 QuantizedToFloat(mean_flat(channel), mean_min, mean_max);
60 const float var_value =
61 QuantizedToFloat(var_flat(channel), var_min, var_max);
62 const float beta_value =
63 QuantizedToFloat(beta_flat(channel), beta_min, beta_max);
64 const float gamma_value =
65 QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max);
66 float output_value;
67 if (scale_after_normalization) {
68 output_value = (((input_value - mean_value) /
69 sqrtf(var_value + variance_epsilon)) *
70 gamma_value) +
71 beta_value;
72 } else {
73 output_value = ((input_value - mean_value) /
74 sqrtf(var_value + variance_epsilon)) +
75 beta_value;
76 }
77 if (is_range_pass) {
78 *output_min = std::min(output_value, *output_min);
79 *output_max = std::max(output_value, *output_max);
80 } else {
81 output_flat(input_index) =
82 FloatToQuantized<T2>(output_value, *output_min, *output_max);
83 }
84 }
85 }
86 }
87}
88
// An implementation of batch normalization that does the main calculations
// using only fixed-point arithmetic. There's a prologue with some floating
// calculations, but assuming the weights are constant these could be hoisted to
// an offline process, or baked into the weights.
template <typename T1, typename T2>
void FixedPointBatchNorm(const Tensor& input, const float input_min,
                         const float input_max, const Tensor& mean,
                         float mean_min, float mean_max, const Tensor& var,
                         float var_min, float var_max, const Tensor& beta,
                         float beta_min, float beta_max, const Tensor& gamma,
                         float gamma_min, float gamma_max,
                         float variance_epsilon, bool scale_after_normalization,
                         Tensor* output, float* output_min, float* output_max) {
  auto input_flat = input.flat<T1>();
  auto mean_flat = mean.flat<T1>();
  auto var_flat = var.flat<T1>();
  auto beta_flat = beta.flat<T1>();
  auto gamma_flat = gamma.flat<T1>();
  auto output_flat = output->flat<T2>();

  // Per-channel parameter tensors define the depth; the input is treated as a
  // flat [row_count, depth] matrix.
  const int depth = mean.dim_size(0);
  const int row_count = input_flat.size() / depth;

  // The range here is chosen so that typical input values fit in without any
  // overflow or loss of precision, going from +1m to -1m with 10 bits of fixed
  // point precision.
  *output_min = -(1 << 20);
  *output_max = (1 << 20);

  // Prologue (floating point): fold each channel's normalization into a single
  // affine transform so that output = input * scale + offset, where
  //   scale  = gamma / sqrt(var + epsilon)   (gamma omitted when
  //                                           !scale_after_normalization)
  //   offset = beta - mean * scale
  // Both are quantized into the fixed output range chosen above.
  Tensor scale_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto scale_flat = scale_tensor.flat<T2>();
  Tensor offset_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto offset_flat = offset_tensor.flat<T2>();
  for (int channel = 0; channel < depth; ++channel) {
    const float mean_value =
        QuantizedToFloat(mean_flat(channel), mean_min, mean_max);
    const float var_value =
        QuantizedToFloat(var_flat(channel), var_min, var_max);
    const float beta_value =
        QuantizedToFloat(beta_flat(channel), beta_min, beta_max);
    const float gamma_value =
        QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max);
    float scale_value;
    if (scale_after_normalization) {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon)) * gamma_value;
    } else {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon));
    }
    const float offset_value = (-mean_value * scale_value) + beta_value;
    scale_flat(channel) =
        FloatToQuantized<T2>(scale_value, *output_min, *output_max);
    offset_flat(channel) =
        FloatToQuantized<T2>(offset_value, *output_min, *output_max);
  }

  // Quantized representation of 1.0f in the output range, used below to
  // renormalize the fixed-point product back into the output scale.
  const T2 one_in_output_space =
      FloatToQuantized<T2>(1.0f, *output_min, *output_max);
  // Main loop (fixed point only): requantize each input into the output range
  // and apply the per-channel affine transform.
  for (int row_index = 0; row_index < row_count; ++row_index) {
    for (int channel = 0; channel < depth; ++channel) {
      const int input_index = (row_index * depth) + channel;
      const T2 input_value =
          RequantizeInNewRange<T1, T2>(input_flat(input_index), input_min,
                                       input_max, *output_min, *output_max);
      const T2 scale_value = scale_flat(channel);
      const T2 offset_value = offset_flat(channel);
      // Fixed-point multiply: dividing by the quantized 1.0 undoes the extra
      // scale factor introduced by multiplying two quantized values.
      const T2 output_value =
          ((input_value * scale_value) / one_in_output_space) + offset_value;
      output_flat(input_index) = output_value;
    }
  }
}
160
161} // namespace
162
163template <typename T1, typename T2>
164class QuantizedBatchNormOp : public OpKernel {
165 public:
166 explicit QuantizedBatchNormOp(OpKernelConstruction* context)
167 : OpKernel(context) {
168 OP_REQUIRES_OK(context,
169 context->GetAttr("variance_epsilon", &variance_epsilon_));
170 OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
171 &scale_after_normalization_));
172 }
173
174 void Compute(OpKernelContext* context) override {
175 const Tensor& input = context->input(0);
176 const auto& input_min_tensor = context->input(1);
177 OP_REQUIRES(context, input_min_tensor.NumElements() == 1,
178 errors::InvalidArgument("input_min must have 1 element"));
179 const float input_min = input_min_tensor.flat<float>()(0);
180 const auto& input_max_tensor = context->input(2);
181 OP_REQUIRES(context, input_max_tensor.NumElements() == 1,
182 errors::InvalidArgument("input_max must have 1 element"));
183 const float input_max = input_max_tensor.flat<float>()(0);
184 const Tensor& mean = context->input(3);
185 const auto& mean_min_tensor = context->input(4);
186 OP_REQUIRES(context, mean_min_tensor.NumElements() == 1,
187 errors::InvalidArgument("mean_min must have 1 element"));
188 const float mean_min = mean_min_tensor.flat<float>()(0);
189 const auto& mean_max_tensor = context->input(5);
190 OP_REQUIRES(context, mean_max_tensor.NumElements() == 1,
191 errors::InvalidArgument("mean_max must have 1 element"));
192 const float mean_max = mean_max_tensor.flat<float>()(0);
193 const Tensor& var = context->input(6);
194 const auto& var_min_tensor = context->input(7);
195 OP_REQUIRES(context, var_min_tensor.NumElements() == 1,
196 errors::InvalidArgument("var_min must have 1 element"));
197 const float var_min = var_min_tensor.flat<float>()(0);
198 const auto& var_max_tensor = context->input(8);
199 OP_REQUIRES(context, var_max_tensor.NumElements() == 1,
200 errors::InvalidArgument("var_max must have 1 element"));
201 const float var_max = var_max_tensor.flat<float>()(0);
202 const Tensor& beta = context->input(9);
203 const auto& beta_min_tensor = context->input(10);
204 OP_REQUIRES(context, beta_min_tensor.NumElements() == 1,
205 errors::InvalidArgument("beta_min must have 1 element"));
206 const float beta_min = beta_min_tensor.flat<float>()(0);
207 const auto& beta_max_tensor = context->input(11);
208 OP_REQUIRES(context, beta_max_tensor.NumElements() == 1,
209 errors::InvalidArgument("beta_max must have 1 element"));
210 const float beta_max = beta_max_tensor.flat<float>()(0);
211 const Tensor& gamma = context->input(12);
212 const auto& gamma_min_tensor = context->input(13);
213 OP_REQUIRES(context, gamma_min_tensor.NumElements() == 1,
214 errors::InvalidArgument("gamma_min must have 1 element"));
215 const float gamma_min = gamma_min_tensor.flat<float>()(0);
216 const auto& gamma_max_tensor = context->input(14);
217 OP_REQUIRES(context, gamma_max_tensor.NumElements() == 1,
218 errors::InvalidArgument("gamma_max must have 1 element"));
219 const float gamma_max = gamma_max_tensor.flat<float>()(0);
220
221 OP_REQUIRES(context, input.dims() == 4,
222 errors::InvalidArgument("input must be 4-dimensional",
223 input.shape().DebugString()));
224 OP_REQUIRES(context, mean.dims() == 1,
225 errors::InvalidArgument("mean must be 1-dimensional",
226 mean.shape().DebugString()));
227 OP_REQUIRES(context, var.dims() == 1,
228 errors::InvalidArgument("var must be 1-dimensional",
229 var.shape().DebugString()));
230 OP_REQUIRES(context, beta.dims() == 1,
231 errors::InvalidArgument("beta must be 1-dimensional",
232 beta.shape().DebugString()));
233 OP_REQUIRES(context, gamma.dims() == 1,
234 errors::InvalidArgument("gamma must be 1-dimensional",
235 gamma.shape().DebugString()));
236 OP_REQUIRES(context, mean.NumElements() > 1,
237 errors::InvalidArgument("Must have at least a mean value",
238 gamma.shape().DebugString()));
239 OP_REQUIRES(context, mean.NumElements() > 1,
240 errors::InvalidArgument("Must have at least a mean value"));
241 const auto last_dim = input.shape().dims() - 1;
242 OP_REQUIRES(context,
243 mean.shape().dim_size(0) == input.shape().dim_size(last_dim),
244 errors::InvalidArgument("Must provide as many means as the "
245 "last dimension of the input tensor: ",
246 mean.shape().DebugString(), " vs. ",
247 input.shape().DebugString()));
248 OP_REQUIRES(
249 context, mean.shape().dim_size(0) == var.shape().dim_size(0),
250 errors::InvalidArgument(
251 "Mean and variance tensors must have the same shape: ",
252 mean.shape().DebugString(), " vs. ", var.shape().DebugString()));
253 OP_REQUIRES(
254 context, mean.shape().dim_size(0) == beta.shape().dim_size(0),
255 errors::InvalidArgument(
256 "Mean and beta tensors must have the same shape: ",
257 mean.shape().DebugString(), " vs. ", beta.shape().DebugString()));
258 OP_REQUIRES(
259 context, mean.shape().dim_size(0) == gamma.shape().dim_size(0),
260 errors::InvalidArgument(
261 "Mean and gamma tensors must have the same shape: ",
262 mean.shape().DebugString(), " vs. ", gamma.shape().DebugString()));
263
264 Tensor* output = nullptr;
265 OP_REQUIRES_OK(context,
266 context->allocate_output(0, input.shape(), &output));
267 float output_min;
268 float output_max;
269 FixedPointBatchNorm<T1, T2>(input, input_min, input_max, mean, mean_min,
270 mean_max, var, var_min, var_max, beta, beta_min,
271 beta_max, gamma, gamma_min, gamma_max,
272 variance_epsilon_, scale_after_normalization_,
273 output, &output_min, &output_max);
274
275 Tensor* output_min_tensor = nullptr;
276 OP_REQUIRES_OK(context,
277 context->allocate_output(1, {}, &output_min_tensor));
278 output_min_tensor->flat<float>()(0) = output_min;
279
280 Tensor* output_max_tensor = nullptr;
281 OP_REQUIRES_OK(context,
282 context->allocate_output(2, {}, &output_max_tensor));
283 output_max_tensor->flat<float>()(0) = output_max;
284 }
285
286 private:
287 float variance_epsilon_;
288 bool scale_after_normalization_;
289};
290
// Registers the CPU kernel: quint8 quantized inputs, qint32 quantized
// outputs (plus their float min/max range as outputs 1 and 2).
REGISTER_KERNEL_BUILDER(Name("QuantizedBatchNormWithGlobalNormalization")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint32>("out_type"),
                        QuantizedBatchNormOp<quint8, qint32>);
296
297} // namespace tensorflow
298