/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#include <arm_neon.h>
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/quantization_utils.h"

#ifdef USE_NEON
namespace {

// Single-pass mean and variance.
// Shape of `input` is [rows x cols]; shape of both `mean` and `variance`
// is [cols].
// Note: `mean` and `variance` are computed over the raw quantized (uint8)
// input values, i.e. they are not scaled back into the float range.
// The following is a straightforward implementation of the parallel algorithm
// described in
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
void ColMeanAndVariance(const uint8_t* input, const uint32_t rows,
                        const uint32_t cols, float* mean, float* variance) {
  // The implementation operates on 16 columns at a time.
  // Assumes cols % 16 == 0.
  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Vector registers to track the running sum across the rows. Since there
    // are 16 columns, we have 4 32x4 registers.
    uint32x4_t sum[4] = {0};

    float nA = 0.0f;
    // Running average and the second moment.
    float32x4_t xA[4] = {0.0f};
    float32x4_t M2A[4] = {0.0f};

    const uint8_t* inp_ptr = input + col_offset;
    // Go over the rows in chunks of 256. This is so that we can use 16 bit
    // adds to do the accumulation.
    for (uint32_t row = 0; row < rows; row += 256) {
      // Running sum and sum of squares for the 256 rows.
      uint32x4_t sub_sum[4] = {0};
      uint32x4_t sub_sq_sum[4] = {0};
      const uint32_t limit = std::min(rows, row + 256);
      const float nB = limit - row;
      for (uint32_t subrow = row; subrow < limit; ++subrow) {
        const uint8x16_t v = vld1q_u8(inp_ptr);
        inp_ptr += cols;

        const uint8x8_t v_high = vget_high_u8(v);
        const uint8x8_t v_low = vget_low_u8(v);

        const uint16x8_t v_high_u16 = vmovl_u8(v_high);
        const uint16x8_t v_low_u16 = vmovl_u8(v_low);

        const uint16x4_t v_high_high = vget_high_u16(v_high_u16);
        const uint16x4_t v_high_low = vget_low_u16(v_high_u16);
        const uint16x4_t v_low_high = vget_high_u16(v_low_u16);
        const uint16x4_t v_low_low = vget_low_u16(v_low_u16);

        sub_sum[0] = vaddw_u16(sub_sum[0], v_high_high);
        sub_sum[1] = vaddw_u16(sub_sum[1], v_high_low);
        sub_sum[2] = vaddw_u16(sub_sum[2], v_low_high);
        sub_sum[3] = vaddw_u16(sub_sum[3], v_low_low);

        sub_sq_sum[0] = vmlal_u16(sub_sq_sum[0], v_high_high, v_high_high);
        sub_sq_sum[1] = vmlal_u16(sub_sq_sum[1], v_high_low, v_high_low);
        sub_sq_sum[2] = vmlal_u16(sub_sq_sum[2], v_low_high, v_low_high);
        sub_sq_sum[3] = vmlal_u16(sub_sq_sum[3], v_low_low, v_low_low);
      }

      // Update the full running sum and moment from this chunk's partial sums.
      for (int i = 0; i < 4; ++i) {
        sum[i] = vaddq_u32(sum[i], sub_sum[i]);
        const float nX = nA + nB;
        // xB is the average of up to 256 elements.
        const float32x4_t xB =
            vmulq_n_f32(vcvtq_f32_u32(sub_sum[i]), 1.0f / nB);

        // delta = xB - xA
        const float32x4_t delta = vsubq_f32(xB, xA[i]);
        // xA = (nA * xA + nB * xB) / (nA + nB)
        xA[i] = vmulq_n_f32(
            vaddq_f32(vmulq_n_f32(xA[i], nA), vmulq_n_f32(xB, nB)), 1.0f / nX);

        const float32x4_t sub_sum_f32 = vcvtq_f32_u32(sub_sum[i]);
        const float32x4_t sub_sum_sq = vmulq_f32(sub_sum_f32, sub_sum_f32);
        // M2B is the sum of squared deviations within this chunk:
        // M2B = sum(x^2) - (sum(x))^2 / nB
        const float32x4_t M2B = vsubq_f32(vcvtq_f32_u32(sub_sq_sum[i]),
                                          vmulq_n_f32(sub_sum_sq, 1.0f / nB));
        const float32x4_t last_term =
            vmulq_n_f32(vmulq_f32(delta, delta), nA * nB / nX);
        // M2A = oldM2A + M2B + delta^2 * nA*nB/nX
        M2A[i] = vaddq_f32(vaddq_f32(M2A[i], M2B), last_term);
      }
      // After merging this chunk, nA is the total number of rows seen so far.
      nA += nB;
    }

    // Write the final mean and variance for the 16 columns. The accumulators
    // are indexed high-to-low (sum[0] came from the highest-address bytes of
    // the loaded vector), so sum[3]/M2A[3] hold columns col_offset..+3.
    const float inv_rows = 1.0f / static_cast<float>(rows);
    vst1q_f32(mean + col_offset, vmulq_n_f32(vcvtq_f32_u32(sum[3]), inv_rows));
    vst1q_f32(mean + col_offset + 4,
              vmulq_n_f32(vcvtq_f32_u32(sum[2]), inv_rows));
    vst1q_f32(mean + col_offset + 8,
              vmulq_n_f32(vcvtq_f32_u32(sum[1]), inv_rows));
    vst1q_f32(mean + col_offset + 12,
              vmulq_n_f32(vcvtq_f32_u32(sum[0]), inv_rows));

    vst1q_f32(variance + col_offset, vmulq_n_f32(M2A[3], inv_rows));
    vst1q_f32(variance + col_offset + 4, vmulq_n_f32(M2A[2], inv_rows));
    vst1q_f32(variance + col_offset + 8, vmulq_n_f32(M2A[1], inv_rows));
    vst1q_f32(variance + col_offset + 12, vmulq_n_f32(M2A[0], inv_rows));
  }
}
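
// For reference, the chunk-merge step above follows the two-set update from
// the Wikipedia article cited in the comment: for disjoint sets A and B with
// counts nA, nB, means xA, xB and sums of squared deviations M2A, M2B,
//   x  = (nA * xA + nB * xB) / (nA + nB)
//   M2 = M2A + M2B + (xB - xA)^2 * nA * nB / (nA + nB)
// The scalar sketch below applies the same update one element at a time for a
// single column. It is illustrative only: this helper is not part of the
// kernel and is never called.
inline void ScalarColMeanAndVarianceReference(const uint8_t* input,
                                              const uint32_t rows,
                                              const uint32_t cols,
                                              const uint32_t col, float* mean,
                                              float* variance) {
  float nA = 0.0f, xA = 0.0f, M2A = 0.0f;
  for (uint32_t row = 0; row < rows; ++row) {
    // Merge a "chunk" B consisting of the single element input[row][col].
    const float nB = 1.0f;
    const float xB = input[row * cols + col];
    const float M2B = 0.0f;  // A single element has no deviation from itself.
    const float nX = nA + nB;
    const float delta = xB - xA;
    xA = (nA * xA + nB * xB) / nX;
    M2A = M2A + M2B + delta * delta * nA * nB / nX;
    nA = nX;
  }
  *mean = xA;
  *variance = M2A / nA;  // Population variance, matching the NEON path.
}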

// Compute min and max of (input - mean) / sqrt(variance + epsilon).
// This is done in a separate pass so that the normalized values can be
// temporarily computed in floating point precision and not stored anywhere.
void MinAndMax(const uint8_t* input, const uint32_t rows, const uint32_t cols,
               const float* mean_ptr, const float* variance_ptr,
               float variance_epsilon, float* minimum, float* maximum) {
  // Start from the extreme float values so that any normalized value updates
  // the running min/max.
  float v_maximum = std::numeric_limits<float>::lowest();
  float v_minimum = std::numeric_limits<float>::max();
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Load mean/variance in the same order that v_float is extracted below:
    // index 0 holds columns col_offset+12..15 and index 3 holds columns
    // col_offset..+3, matching the ordering used in InstanceNorm.
    const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset + 12),
                                 vld1q_f32(mean_ptr + col_offset + 8),
                                 vld1q_f32(mean_ptr + col_offset + 4),
                                 vld1q_f32(mean_ptr + col_offset)};
    const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset + 12),
                                     vld1q_f32(variance_ptr + col_offset + 8),
                                     vld1q_f32(variance_ptr + col_offset + 4),
                                     vld1q_f32(variance_ptr + col_offset)};
    const float32x4_t inv_stddev[4] = {
        vrsqrteq_f32(vaddq_f32(variance[0], eps)),
        vrsqrteq_f32(vaddq_f32(variance[1], eps)),
        vrsqrteq_f32(vaddq_f32(variance[2], eps)),
        vrsqrteq_f32(vaddq_f32(variance[3], eps))};

    const uint8_t* inp_ptr = input + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;

      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));

      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};

      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        const float32x2_t high = vget_high_f32(normed);
        const float32x2_t low = vget_low_f32(normed);
        float32x2_t tmp_max = vpmax_f32(low, high);
        tmp_max = vpmax_f32(tmp_max, tmp_max);
        v_maximum = std::max(v_maximum, vget_lane_f32(tmp_max, 0));
        float32x2_t tmp_min = vpmin_f32(low, high);
        tmp_min = vpmin_f32(tmp_min, tmp_min);
        v_minimum = std::min(v_minimum, vget_lane_f32(tmp_min, 0));
      }
    }
  }
  *minimum = v_minimum;
  *maximum = v_maximum;
}
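
// Note on precision: vrsqrteq_f32, used above and in InstanceNorm below, only
// produces a low-precision reciprocal-square-root estimate, which is accepted
// here for speed. If a more accurate 1/sqrt(variance + eps) were ever needed,
// a standard Newton-Raphson refinement step can be applied to the estimate.
// A minimal sketch, illustrative only and not called by the kernel:
inline float32x4_t RefinedInvStddev(const float32x4_t variance,
                                    const float32x4_t eps) {
  const float32x4_t d = vaddq_f32(variance, eps);
  // Initial estimate x0 ~= 1/sqrt(d).
  float32x4_t x = vrsqrteq_f32(d);
  // One Newton-Raphson step: x1 = x0 * (3 - d * x0 * x0) / 2, expressed with
  // vrsqrtsq_f32(a, b) == (3 - a * b) / 2.
  x = vmulq_f32(x, vrsqrtsq_f32(vmulq_f32(d, x), x));
  return x;
}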

// Compute (input - mean) / sqrt(variance + epsilon) in floating point,
// quantize it in the range (minimum, maximum) and store the result as quint8.
void InstanceNorm(const uint8_t* input, const uint32_t rows,
                  const uint32_t cols, const float* mean_ptr,
                  const float* variance_ptr, float variance_epsilon,
                  float minimum, float maximum, uint8_t* output) {
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);
  const float32x4_t out_min = vdupq_n_f32(minimum);
  const float out_scale = 255.0f / (maximum - minimum);

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset + 12),
                                 vld1q_f32(mean_ptr + col_offset + 8),
                                 vld1q_f32(mean_ptr + col_offset + 4),
                                 vld1q_f32(mean_ptr + col_offset)};
    const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset + 12),
                                     vld1q_f32(variance_ptr + col_offset + 8),
                                     vld1q_f32(variance_ptr + col_offset + 4),
                                     vld1q_f32(variance_ptr + col_offset)};
    const float32x4_t inv_stddev[4] = {
        vrsqrteq_f32(vaddq_f32(variance[0], eps)),
        vrsqrteq_f32(vaddq_f32(variance[1], eps)),
        vrsqrteq_f32(vaddq_f32(variance[2], eps)),
        vrsqrteq_f32(vaddq_f32(variance[3], eps))};
    const uint8_t* inp_ptr = input + col_offset;
    uint8_t* out_ptr = output + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;
      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));

      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};

      uint16x4_t normed_uint16[4];
      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        const int32x4_t normed_int32 =
            vcvtq_s32_f32(vmulq_n_f32(vsubq_f32(normed, out_min), out_scale));
        normed_uint16[i] = vqmovun_s32(normed_int32);
      }
      vst1_u8(out_ptr,
              vqmovn_u16(vcombine_u16(normed_uint16[3], normed_uint16[2])));
      vst1_u8(out_ptr + 8,
              vqmovn_u16(vcombine_u16(normed_uint16[1], normed_uint16[0])));
      out_ptr += cols;
    }
  }
}
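
// For reference, each lane of the loop above applies the affine mapping
//   q = (x_normed - minimum) * 255 / (maximum - minimum)
// followed by saturating narrowing casts (vqmovun_s32, then vqmovn_u16) that
// clamp the result to [0, 255]. A scalar sketch of the same mapping,
// illustrative only and not called by the kernel:
inline uint8_t QuantizeNormedValue(float x_normed, float minimum,
                                   float maximum) {
  const float out_scale = 255.0f / (maximum - minimum);
  // Truncation toward zero, mirroring vcvtq_s32_f32.
  const int32_t q = static_cast<int32_t>((x_normed - minimum) * out_scale);
  // Saturate to the quint8 range, mirroring the saturating NEON narrows.
  if (q < 0) return 0;
  if (q > 255) return 255;
  return static_cast<uint8_t>(q);
}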

}  // end namespace
#endif  // USE_NEON

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

class QuantizedInstanceNorm : public OpKernel {
 public:
  explicit QuantizedInstanceNorm(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("min_separation", &min_separation_));
    OP_REQUIRES_OK(
        context, context->GetAttr("output_range_given", &output_range_given_));
    if (output_range_given_) {
      OP_REQUIRES_OK(context, context->GetAttr("given_y_min", &given_y_min_));
      OP_REQUIRES_OK(context, context->GetAttr("given_y_max", &given_y_max_));
      OP_REQUIRES(context, given_y_min_ < given_y_max_,
                  errors::InvalidArgument(
                      "given_y_min must be less than given_y_max : ",
                      given_y_min_, " >= ", given_y_max_));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    const Tensor& x_min = context->input(1);
    const Tensor& x_max = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(x_min.shape()),
                errors::InvalidArgument("`x_min` must be rank 0 but is rank ",
                                        x_min.dims()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(x_max.shape()),
                errors::InvalidArgument("`x_max` must be rank 0 but is rank ",
                                        x_max.dims()));
    float input_min = x_min.scalar<float>()();
    float input_max = x_max.scalar<float>()();
    float input_scale = (input_max - input_min) / 255.0f;

    OP_REQUIRES(context, input_min < input_max,
                errors::InvalidArgument(
                    "input_min must be less than input_max : ", input_min,
                    " >= ", input_max));

    auto input_tensor = input.tensor<quint8, 4>();
    auto N = input_tensor.dimension(0);
    auto H = input_tensor.dimension(1);
    auto W = input_tensor.dimension(2);
    auto C = input_tensor.dimension(3);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));

    typedef TTypes<float>::Tensor::Index Index;

    const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>>
        reduction_indices;
    Eigen::IndexList<Eigen::type2index<1>, Index, Index, Eigen::type2index<1>>
        broadcast_spec;
    broadcast_spec.set(1, H);
    broadcast_spec.set(2, W);
    Eigen::IndexList<Index, Eigen::type2index<1>, Eigen::type2index<1>, Index>
        expand_spec;
    expand_spec.set(0, N);
    expand_spec.set(3, C);

    Eigen::Tensor<float, 2, Eigen::RowMajor> float_mean(N, C);
    Eigen::Tensor<float, 2, Eigen::RowMajor> float_variance(N, C);

#ifdef USE_NEON
    if (N == 1 && (C % 16 == 0)) {
      VLOG(2) << "Calling optimized";
      ColMeanAndVariance(reinterpret_cast<const uint8_t*>(input_tensor.data()),
                         H * W, C, float_mean.data(), float_variance.data());

      float minimum = given_y_min_, maximum = given_y_max_;
      if (!output_range_given_) {
        MinAndMax(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
                  C, float_mean.data(), float_variance.data(),
                  variance_epsilon_, &minimum, &maximum);
      }

      if (maximum - minimum < min_separation_) {
        maximum = minimum + min_separation_;
      }

      InstanceNorm(reinterpret_cast<const uint8_t*>(input_tensor.data()),
                   H * W, C, float_mean.data(), float_variance.data(),
                   variance_epsilon_, minimum, maximum,
                   reinterpret_cast<uint8_t*>(output->flat<quint8>().data()));
      output_min->scalar<float>()() = minimum;
      output_max->scalar<float>()() = maximum;
    } else  // NOLINT(readability/braces)
#endif
    {
      VLOG(2) << "Calling unoptimized";
      float_mean = input_tensor.cast<float>().reduce(
          reduction_indices, Eigen::internal::MeanReducer<float>());

      float_variance =
          (input_scale *
           ((input_tensor.cast<float>() -
             float_mean.reshape(expand_spec).broadcast(broadcast_spec))))
              .square()
              .reduce(reduction_indices,
                      Eigen::internal::MeanReducer<float>());

      Eigen::Tensor<float, 4, Eigen::RowMajor> instance_normed =
          input_scale *
          (input_tensor.cast<float>() -
           float_mean.reshape(expand_spec).broadcast(broadcast_spec)) *
          (float_variance + variance_epsilon_)
              .rsqrt()
              .reshape(expand_spec)
              .broadcast(broadcast_spec);

      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_min;
      Eigen::Tensor<float, 0, Eigen::RowMajor> normed_max;

      if (!output_range_given_) {
        normed_min = instance_normed.minimum();
        normed_max = instance_normed.maximum();
      } else {
        normed_min() = given_y_min_;
        normed_max() = given_y_max_;
      }

      if (normed_max() - normed_min() < min_separation_) {
        normed_max() = normed_min() + min_separation_;
      }

      FloatToQuantizedStruct<quint8> output_f2q(normed_min(), normed_max());
      auto instance_normed_quantized =
          QUANTIZE_WITH_EIGEN(instance_normed, output_f2q, quint8);

      output->tensor<quint8, 4>().device(
          context->template eigen_device<CPUDevice>()) =
          instance_normed_quantized;
      output_min->flat<float>()(0) = normed_min();
      output_max->flat<float>()(0) = normed_max();
    }
  }

 private:
  float variance_epsilon_;
  float min_separation_;
  bool output_range_given_;
  float given_y_min_;
  float given_y_max_;
};

REGISTER_KERNEL_BUILDER(Name("QuantizedInstanceNorm")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T"),
                        QuantizedInstanceNorm);

}  // namespace tensorflow