1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
19 | #include "tensorflow/core/framework/numeric_op.h" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/register_types.h" |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/kernels/quantization_utils.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | namespace { |
28 | |
29 | // A slow but straightforward implementation of batch normalization. |
30 | template <typename T1, typename T2> |
31 | void ReferenceBatchNorm(const Tensor& input, const float input_min, |
32 | const float input_max, const Tensor& mean, |
33 | float mean_min, float mean_max, const Tensor& var, |
34 | float var_min, float var_max, const Tensor& beta, |
35 | float beta_min, float beta_max, const Tensor& gamma, |
36 | float gamma_min, float gamma_max, |
37 | float variance_epsilon, bool scale_after_normalization, |
38 | Tensor* output, float* output_min, float* output_max) { |
39 | auto input_flat = input.flat<T1>(); |
40 | auto mean_flat = mean.flat<T1>(); |
41 | auto var_flat = var.flat<T1>(); |
42 | auto beta_flat = beta.flat<T1>(); |
43 | auto gamma_flat = gamma.flat<T1>(); |
44 | auto output_flat = output->flat<T2>(); |
45 | |
46 | const int depth = mean.dim_size(0); |
47 | const int row_count = input_flat.size() / depth; |
48 | |
49 | *output_min = std::numeric_limits<float>::max(); |
50 | *output_max = std::numeric_limits<float>::lowest(); |
51 | for (int pass = 0; pass < 2; ++pass) { |
52 | const bool is_range_pass = (pass == 0); |
53 | for (int row_index = 0; row_index < row_count; ++row_index) { |
54 | for (int channel = 0; channel < depth; ++channel) { |
55 | const int input_index = (row_index * depth) + channel; |
56 | const float input_value = |
57 | QuantizedToFloat(input_flat(input_index), input_min, input_max); |
58 | const float mean_value = |
59 | QuantizedToFloat(mean_flat(channel), mean_min, mean_max); |
60 | const float var_value = |
61 | QuantizedToFloat(var_flat(channel), var_min, var_max); |
62 | const float beta_value = |
63 | QuantizedToFloat(beta_flat(channel), beta_min, beta_max); |
64 | const float gamma_value = |
65 | QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max); |
66 | float output_value; |
67 | if (scale_after_normalization) { |
68 | output_value = (((input_value - mean_value) / |
69 | sqrtf(var_value + variance_epsilon)) * |
70 | gamma_value) + |
71 | beta_value; |
72 | } else { |
73 | output_value = ((input_value - mean_value) / |
74 | sqrtf(var_value + variance_epsilon)) + |
75 | beta_value; |
76 | } |
77 | if (is_range_pass) { |
78 | *output_min = std::min(output_value, *output_min); |
79 | *output_max = std::max(output_value, *output_max); |
80 | } else { |
81 | output_flat(input_index) = |
82 | FloatToQuantized<T2>(output_value, *output_min, *output_max); |
83 | } |
84 | } |
85 | } |
86 | } |
87 | } |
88 | |
// An implementation of batch normalization that does the main calculations
// using only fixed-point arithmetic. There's a prologue with some floating
// calculations, but assuming the weights are constant these could be hoisted to
// an offline process, or baked into the weights.
//
// The per-channel transform is folded into y = (x * scale) + offset, where
// scale = gamma / sqrt(var + epsilon) (or 1 / sqrt(var + epsilon) when
// scale_after_normalization is false) and offset = (-mean * scale) + beta,
// so the inner loop needs only a multiply, a divide, and an add per element.
template <typename T1, typename T2>
void FixedPointBatchNorm(const Tensor& input, const float input_min,
                         const float input_max, const Tensor& mean,
                         float mean_min, float mean_max, const Tensor& var,
                         float var_min, float var_max, const Tensor& beta,
                         float beta_min, float beta_max, const Tensor& gamma,
                         float gamma_min, float gamma_max,
                         float variance_epsilon, bool scale_after_normalization,
                         Tensor* output, float* output_min, float* output_max) {
  auto input_flat = input.flat<T1>();
  auto mean_flat = mean.flat<T1>();
  auto var_flat = var.flat<T1>();
  auto beta_flat = beta.flat<T1>();
  auto gamma_flat = gamma.flat<T1>();
  auto output_flat = output->flat<T2>();

  // Channel is the innermost dimension: elements are indexed as
  // (row * depth) + channel below, one row per `depth` consecutive values.
  const int depth = mean.dim_size(0);
  const int row_count = input_flat.size() / depth;

  // The range here is chosen so that typical input values fit in without any
  // overflow or loss of precision, going from +1m to -1m with 10 bits of fixed
  // point precision.
  *output_min = -(1 << 20);
  *output_max = (1 << 20);

  // Floating-point prologue: fold each channel's mean/var/beta/gamma into a
  // single quantized (scale, offset) pair, expressed in the output range.
  Tensor scale_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto scale_flat = scale_tensor.flat<T2>();
  Tensor offset_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto offset_flat = offset_tensor.flat<T2>();
  for (int channel = 0; channel < depth; ++channel) {
    const float mean_value =
        QuantizedToFloat(mean_flat(channel), mean_min, mean_max);
    const float var_value =
        QuantizedToFloat(var_flat(channel), var_min, var_max);
    const float beta_value =
        QuantizedToFloat(beta_flat(channel), beta_min, beta_max);
    const float gamma_value =
        QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max);
    float scale_value;
    if (scale_after_normalization) {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon)) * gamma_value;
    } else {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon));
    }
    const float offset_value = (-mean_value * scale_value) + beta_value;
    scale_flat(channel) =
        FloatToQuantized<T2>(scale_value, *output_min, *output_max);
    offset_flat(channel) =
        FloatToQuantized<T2>(offset_value, *output_min, *output_max);
  }

  // Fixed-point main loop. Both operands of the multiply are quantized in the
  // output range, so the raw product is scaled by an extra factor; dividing by
  // the quantized representation of 1.0 brings it back into the output range.
  const T2 one_in_output_space =
      FloatToQuantized<T2>(1.0f, *output_min, *output_max);
  for (int row_index = 0; row_index < row_count; ++row_index) {
    for (int channel = 0; channel < depth; ++channel) {
      const int input_index = (row_index * depth) + channel;
      // Re-express the input in the output range so all arithmetic happens in
      // one quantization space.
      const T2 input_value =
          RequantizeInNewRange<T1, T2>(input_flat(input_index), input_min,
                                       input_max, *output_min, *output_max);
      const T2 scale_value = scale_flat(channel);
      const T2 offset_value = offset_flat(channel);
      const T2 output_value =
          ((input_value * scale_value) / one_in_output_space) + offset_value;
      output_flat(input_index) = output_value;
    }
  }
}
160 | |
161 | } // namespace |
162 | |
163 | template <typename T1, typename T2> |
164 | class QuantizedBatchNormOp : public OpKernel { |
165 | public: |
166 | explicit QuantizedBatchNormOp(OpKernelConstruction* context) |
167 | : OpKernel(context) { |
168 | OP_REQUIRES_OK(context, |
169 | context->GetAttr("variance_epsilon" , &variance_epsilon_)); |
170 | OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization" , |
171 | &scale_after_normalization_)); |
172 | } |
173 | |
174 | void Compute(OpKernelContext* context) override { |
175 | const Tensor& input = context->input(0); |
176 | const auto& input_min_tensor = context->input(1); |
177 | OP_REQUIRES(context, input_min_tensor.NumElements() == 1, |
178 | errors::InvalidArgument("input_min must have 1 element" )); |
179 | const float input_min = input_min_tensor.flat<float>()(0); |
180 | const auto& input_max_tensor = context->input(2); |
181 | OP_REQUIRES(context, input_max_tensor.NumElements() == 1, |
182 | errors::InvalidArgument("input_max must have 1 element" )); |
183 | const float input_max = input_max_tensor.flat<float>()(0); |
184 | const Tensor& mean = context->input(3); |
185 | const auto& mean_min_tensor = context->input(4); |
186 | OP_REQUIRES(context, mean_min_tensor.NumElements() == 1, |
187 | errors::InvalidArgument("mean_min must have 1 element" )); |
188 | const float mean_min = mean_min_tensor.flat<float>()(0); |
189 | const auto& mean_max_tensor = context->input(5); |
190 | OP_REQUIRES(context, mean_max_tensor.NumElements() == 1, |
191 | errors::InvalidArgument("mean_max must have 1 element" )); |
192 | const float mean_max = mean_max_tensor.flat<float>()(0); |
193 | const Tensor& var = context->input(6); |
194 | const auto& var_min_tensor = context->input(7); |
195 | OP_REQUIRES(context, var_min_tensor.NumElements() == 1, |
196 | errors::InvalidArgument("var_min must have 1 element" )); |
197 | const float var_min = var_min_tensor.flat<float>()(0); |
198 | const auto& var_max_tensor = context->input(8); |
199 | OP_REQUIRES(context, var_max_tensor.NumElements() == 1, |
200 | errors::InvalidArgument("var_max must have 1 element" )); |
201 | const float var_max = var_max_tensor.flat<float>()(0); |
202 | const Tensor& beta = context->input(9); |
203 | const auto& beta_min_tensor = context->input(10); |
204 | OP_REQUIRES(context, beta_min_tensor.NumElements() == 1, |
205 | errors::InvalidArgument("beta_min must have 1 element" )); |
206 | const float beta_min = beta_min_tensor.flat<float>()(0); |
207 | const auto& beta_max_tensor = context->input(11); |
208 | OP_REQUIRES(context, beta_max_tensor.NumElements() == 1, |
209 | errors::InvalidArgument("beta_max must have 1 element" )); |
210 | const float beta_max = beta_max_tensor.flat<float>()(0); |
211 | const Tensor& gamma = context->input(12); |
212 | const auto& gamma_min_tensor = context->input(13); |
213 | OP_REQUIRES(context, gamma_min_tensor.NumElements() == 1, |
214 | errors::InvalidArgument("gamma_min must have 1 element" )); |
215 | const float gamma_min = gamma_min_tensor.flat<float>()(0); |
216 | const auto& gamma_max_tensor = context->input(14); |
217 | OP_REQUIRES(context, gamma_max_tensor.NumElements() == 1, |
218 | errors::InvalidArgument("gamma_max must have 1 element" )); |
219 | const float gamma_max = gamma_max_tensor.flat<float>()(0); |
220 | |
221 | OP_REQUIRES(context, input.dims() == 4, |
222 | errors::InvalidArgument("input must be 4-dimensional" , |
223 | input.shape().DebugString())); |
224 | OP_REQUIRES(context, mean.dims() == 1, |
225 | errors::InvalidArgument("mean must be 1-dimensional" , |
226 | mean.shape().DebugString())); |
227 | OP_REQUIRES(context, var.dims() == 1, |
228 | errors::InvalidArgument("var must be 1-dimensional" , |
229 | var.shape().DebugString())); |
230 | OP_REQUIRES(context, beta.dims() == 1, |
231 | errors::InvalidArgument("beta must be 1-dimensional" , |
232 | beta.shape().DebugString())); |
233 | OP_REQUIRES(context, gamma.dims() == 1, |
234 | errors::InvalidArgument("gamma must be 1-dimensional" , |
235 | gamma.shape().DebugString())); |
236 | OP_REQUIRES(context, mean.NumElements() > 1, |
237 | errors::InvalidArgument("Must have at least a mean value" , |
238 | gamma.shape().DebugString())); |
239 | OP_REQUIRES(context, mean.NumElements() > 1, |
240 | errors::InvalidArgument("Must have at least a mean value" )); |
241 | const auto last_dim = input.shape().dims() - 1; |
242 | OP_REQUIRES(context, |
243 | mean.shape().dim_size(0) == input.shape().dim_size(last_dim), |
244 | errors::InvalidArgument("Must provide as many means as the " |
245 | "last dimension of the input tensor: " , |
246 | mean.shape().DebugString(), " vs. " , |
247 | input.shape().DebugString())); |
248 | OP_REQUIRES( |
249 | context, mean.shape().dim_size(0) == var.shape().dim_size(0), |
250 | errors::InvalidArgument( |
251 | "Mean and variance tensors must have the same shape: " , |
252 | mean.shape().DebugString(), " vs. " , var.shape().DebugString())); |
253 | OP_REQUIRES( |
254 | context, mean.shape().dim_size(0) == beta.shape().dim_size(0), |
255 | errors::InvalidArgument( |
256 | "Mean and beta tensors must have the same shape: " , |
257 | mean.shape().DebugString(), " vs. " , beta.shape().DebugString())); |
258 | OP_REQUIRES( |
259 | context, mean.shape().dim_size(0) == gamma.shape().dim_size(0), |
260 | errors::InvalidArgument( |
261 | "Mean and gamma tensors must have the same shape: " , |
262 | mean.shape().DebugString(), " vs. " , gamma.shape().DebugString())); |
263 | |
264 | Tensor* output = nullptr; |
265 | OP_REQUIRES_OK(context, |
266 | context->allocate_output(0, input.shape(), &output)); |
267 | float output_min; |
268 | float output_max; |
269 | FixedPointBatchNorm<T1, T2>(input, input_min, input_max, mean, mean_min, |
270 | mean_max, var, var_min, var_max, beta, beta_min, |
271 | beta_max, gamma, gamma_min, gamma_max, |
272 | variance_epsilon_, scale_after_normalization_, |
273 | output, &output_min, &output_max); |
274 | |
275 | Tensor* output_min_tensor = nullptr; |
276 | OP_REQUIRES_OK(context, |
277 | context->allocate_output(1, {}, &output_min_tensor)); |
278 | output_min_tensor->flat<float>()(0) = output_min; |
279 | |
280 | Tensor* output_max_tensor = nullptr; |
281 | OP_REQUIRES_OK(context, |
282 | context->allocate_output(2, {}, &output_max_tensor)); |
283 | output_max_tensor->flat<float>()(0) = output_max; |
284 | } |
285 | |
286 | private: |
287 | float variance_epsilon_; |
288 | bool scale_after_normalization_; |
289 | }; |
290 | |
// CPU registration: constrained to quint8 inputs and qint32 outputs, the only
// type combination this file instantiates.
REGISTER_KERNEL_BUILDER(Name("QuantizedBatchNormWithGlobalNormalization" )
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput" )
                            .TypeConstraint<qint32>("out_type" ),
                        QuantizedBatchNormOp<quint8, qint32>);
296 | |
297 | } // namespace tensorflow |
298 | |