1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #if defined(__ARM_NEON__) || defined(__ARM_NEON) |
19 | #define USE_NEON |
20 | #include <arm_neon.h> |
21 | #endif |
22 | |
23 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
24 | #include "tensorflow/core/framework/numeric_op.h" |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/register_types.h" |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/framework/tensor_shape.h" |
29 | #include "tensorflow/core/kernels/quantization_utils.h" |
30 | |
31 | #ifdef USE_NEON |
32 | namespace { |
33 | |
// Single-pass per-column mean and population variance of a row-major
// [rows x cols] uint8 matrix, vectorized with NEON.
//
// `mean` and `variance` each receive `cols` floats. The statistics are those
// of the raw uint8 codes, i.e. no dequantization scale is applied.
//
// Rows are processed in chunks of up to 256: each chunk is reduced with
// integer arithmetic and then merged into the running mean / second moment
// with the parallel algorithm described in
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
//
// Preconditions: cols % 16 == 0 (16 columns are handled per pass) and
// rows > 0 (otherwise the outputs are NaN).
void ColMeanAndVariance(const uint8_t* input, const uint32_t rows,
                        const uint32_t cols, float* mean, float* variance) {
  // With uint8 inputs a chunk's sum of squares is at most 255^2 * 256, which
  // is below 2^24, so the per-chunk integer sums convert to float exactly.
  constexpr uint32_t kChunkRows = 256;

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Lane group i covers columns col_offset + 4 * (3 - i) .. + 3, i.e.
    // index 0 holds the highest four of the 16 columns.
    uint32x4_t sum[4];
    float32x4_t xA[4];   // Running mean of all rows merged so far.
    float32x4_t M2A[4];  // Running sum of squared deviations from xA.
    for (int i = 0; i < 4; ++i) {
      sum[i] = vdupq_n_u32(0);
      xA[i] = vdupq_n_f32(0.0f);
      M2A[i] = vdupq_n_f32(0.0f);
    }
    float nA = 0.0f;  // Number of rows merged so far.

    const uint8_t* inp_ptr = input + col_offset;
    for (uint32_t row = 0; row < rows; row += kChunkRows) {
      const uint32_t chunk_rows = std::min(kChunkRows, rows - row);

      // 32-bit running sum and sum of squares for this chunk.
      uint32x4_t sub_sum[4];
      uint32x4_t sub_sq_sum[4];
      for (int i = 0; i < 4; ++i) {
        sub_sum[i] = vdupq_n_u32(0);
        sub_sq_sum[i] = vdupq_n_u32(0);
      }
      for (uint32_t r = 0; r < chunk_rows; ++r) {
        const uint8x16_t v = vld1q_u8(inp_ptr);
        inp_ptr += cols;

        const uint16x8_t v_high_u16 = vmovl_u8(vget_high_u8(v));
        const uint16x8_t v_low_u16 = vmovl_u8(vget_low_u8(v));
        const uint16x4_t parts[4] = {
            vget_high_u16(v_high_u16), vget_low_u16(v_high_u16),
            vget_high_u16(v_low_u16), vget_low_u16(v_low_u16)};
        for (int i = 0; i < 4; ++i) {
          sub_sum[i] = vaddw_u16(sub_sum[i], parts[i]);
          sub_sq_sum[i] = vmlal_u16(sub_sq_sum[i], parts[i], parts[i]);
        }
      }

      // Merge chunk B (nB rows) into the running statistics A (nA rows).
      const float nB = static_cast<float>(chunk_rows);
      const float nX = nA + nB;
      for (int i = 0; i < 4; ++i) {
        sum[i] = vaddq_u32(sum[i], sub_sum[i]);

        const float32x4_t sub_sum_f32 = vcvtq_f32_u32(sub_sum[i]);
        // xB = mean of this chunk.
        const float32x4_t xB = vmulq_n_f32(sub_sum_f32, 1.0f / nB);
        // delta = xB - xA
        const float32x4_t delta = vsubq_f32(xB, xA[i]);
        // xA = (nA * xA + nB * xB) / (nA + nB)
        xA[i] = vmulq_n_f32(
            vaddq_f32(vmulq_n_f32(xA[i], nA), vmulq_n_f32(xB, nB)), 1.0f / nX);
        // M2B = sum(x^2) - sum(x)^2 / nB
        const float32x4_t M2B = vsubq_f32(
            vcvtq_f32_u32(sub_sq_sum[i]),
            vmulq_n_f32(vmulq_f32(sub_sum_f32, sub_sum_f32), 1.0f / nB));
        // M2A = M2A + M2B + delta^2 * nA * nB / (nA + nB)
        const float32x4_t cross_term =
            vmulq_n_f32(vmulq_f32(delta, delta), nA * nB / nX);
        M2A[i] = vaddq_f32(vaddq_f32(M2A[i], M2B), cross_term);
      }
      // The merged count is nA + nB (this chunk's size), not the cumulative
      // row index; using the latter over-counts nA from the third chunk on.
      nA = nX;
    }

    // Write the final mean and variance for the 16 columns (lane group 3 is
    // the lowest four columns).
    const float inv_rows = 1.0f / static_cast<float>(rows);
    vst1q_f32(mean + col_offset, vmulq_n_f32(vcvtq_f32_u32(sum[3]), inv_rows));
    vst1q_f32(mean + col_offset + 4,
              vmulq_n_f32(vcvtq_f32_u32(sum[2]), inv_rows));
    vst1q_f32(mean + col_offset + 8,
              vmulq_n_f32(vcvtq_f32_u32(sum[1]), inv_rows));
    vst1q_f32(mean + col_offset + 12,
              vmulq_n_f32(vcvtq_f32_u32(sum[0]), inv_rows));

    vst1q_f32(variance + col_offset, vmulq_n_f32(M2A[3], inv_rows));
    vst1q_f32(variance + col_offset + 4, vmulq_n_f32(M2A[2], inv_rows));
    vst1q_f32(variance + col_offset + 8, vmulq_n_f32(M2A[1], inv_rows));
    vst1q_f32(variance + col_offset + 12, vmulq_n_f32(M2A[0], inv_rows));
  }
}
134 | |
// Computes the global minimum and maximum of
// (input - mean) / sqrt(variance + variance_epsilon) over a row-major
// [rows x cols] uint8 matrix; `mean_ptr` and `variance_ptr` hold one value per
// column. This is a separate pass so the normalized values only ever live in
// float registers and are never stored.
//
// Precondition: cols % 16 == 0. NaN normalized values are ignored. If no
// value is visited, *minimum is set to FLT_MAX and *maximum to -FLT_MAX.
void MinAndMax(const uint8_t* input, const uint32_t rows, const uint32_t cols,
               const float* mean_ptr, const float* variance_ptr,
               float variance_epsilon, float* minimum, float* maximum) {
  // vrsqrteq_f32 alone is only a low-precision estimate; each Newton-Raphson
  // step (vrsqrtsq_f32 computes (3 - a * b) / 2) refines it further.
  const auto inv_sqrt = [](float32x4_t x) {
    float32x4_t e = vrsqrteq_f32(x);
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));
    return e;
  };
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);
  // lowest() rather than min(): min() is the smallest *positive* float, which
  // would be reported as the maximum when every normalized value is negative.
  float32x4_t running_max = vdupq_n_f32(std::numeric_limits<float>::lowest());
  float32x4_t running_min = vdupq_n_f32(std::numeric_limits<float>::max());

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Index i covers columns col_offset + 4 * i .. col_offset + 4 * i + 3.
    float32x4_t mean[4];
    float32x4_t inv_stddev[4];
    for (int i = 0; i < 4; ++i) {
      mean[i] = vld1q_f32(mean_ptr + col_offset + 4 * i);
      inv_stddev[i] = inv_sqrt(
          vaddq_f32(vld1q_f32(variance_ptr + col_offset + 4 * i), eps));
    }

    const uint8_t* inp_ptr = input + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;

      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));    // columns 0-7
      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));  // columns 8-15
      // Same column order as mean[] / inv_stddev[].
      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high)))};

      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        // Lane-wise select; a NaN lane compares false and leaves the running
        // value unchanged.
        running_max =
            vbslq_f32(vcgtq_f32(normed, running_max), normed, running_max);
        running_min =
            vbslq_f32(vcltq_f32(normed, running_min), normed, running_min);
      }
    }
  }

  // Horizontal reduction of the four lanes, done once at the end.
  float max_lanes[4];
  float min_lanes[4];
  vst1q_f32(max_lanes, running_max);
  vst1q_f32(min_lanes, running_min);
  *maximum = *std::max_element(max_lanes, max_lanes + 4);
  *minimum = *std::min_element(min_lanes, min_lanes + 4);
}
191 | |
// Computes (input - mean) / sqrt(variance + variance_epsilon) in floating
// point for a row-major [rows x cols] uint8 matrix and quantizes the result
// linearly so that `minimum` maps to 0 and `maximum` maps to 255, rounding to
// nearest and saturating values outside that range. `mean_ptr` and
// `variance_ptr` hold one value per column; `output` has the same layout as
// `input`.
//
// Precondition: cols % 16 == 0. If maximum <= minimum, every output is 0.
void InstanceNorm(const uint8_t* input, const uint32_t rows,
                  const uint32_t cols, const float* mean_ptr,
                  const float* variance_ptr, float variance_epsilon,
                  float minimum, float maximum, uint8_t* output) {
  // vrsqrteq_f32 alone is only a low-precision estimate; each Newton-Raphson
  // step (vrsqrtsq_f32 computes (3 - a * b) / 2) refines it further.
  const auto inv_sqrt = [](float32x4_t x) {
    float32x4_t e = vrsqrteq_f32(x);
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));
    return e;
  };
  const float32x4_t eps = vdupq_n_f32(variance_epsilon);
  const float32x4_t out_min = vdupq_n_f32(minimum);
  const float32x4_t half = vdupq_n_f32(0.5f);
  // Guard against a degenerate range, which would otherwise give an inf/NaN
  // scale.
  const float out_scale =
      maximum > minimum ? 255.0f / (maximum - minimum) : 0.0f;

  for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    // Index i covers columns col_offset + 4 * i .. col_offset + 4 * i + 3.
    float32x4_t mean[4];
    float32x4_t inv_stddev[4];
    for (int i = 0; i < 4; ++i) {
      mean[i] = vld1q_f32(mean_ptr + col_offset + 4 * i);
      inv_stddev[i] = inv_sqrt(
          vaddq_f32(vld1q_f32(variance_ptr + col_offset + 4 * i), eps));
    }

    const uint8_t* inp_ptr = input + col_offset;
    uint8_t* out_ptr = output + col_offset;
    for (uint32_t row = 0; row < rows; ++row) {
      const uint8x16_t v = vld1q_u8(inp_ptr);
      inp_ptr += cols;

      const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));    // columns 0-7
      const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));  // columns 8-15
      // Same column order as mean[] / inv_stddev[].
      const float32x4_t v_float[4] = {
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
          vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
          vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high)))};

      uint16x4_t quantized[4];
      for (int i = 0; i < 4; ++i) {
        const float32x4_t normed =
            vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
        // vcvtq_s32_f32 truncates toward zero; adding 0.5 first rounds to
        // nearest and avoids a systematic downward bias of up to one step.
        // Negative results saturate to 0 in vqmovun_s32.
        const float32x4_t scaled = vaddq_f32(
            vmulq_n_f32(vsubq_f32(normed, out_min), out_scale), half);
        quantized[i] = vqmovun_s32(vcvtq_s32_f32(scaled));
      }
      vst1_u8(out_ptr, vqmovn_u16(vcombine_u16(quantized[0], quantized[1])));
      vst1_u8(out_ptr + 8,
              vqmovn_u16(vcombine_u16(quantized[2], quantized[3])));
      out_ptr += cols;
    }
  }
}
246 | |
247 | } // end namespace |
248 | #endif // USE_NEON |
249 | |
250 | namespace tensorflow { |
251 | |
252 | typedef Eigen::ThreadPoolDevice CPUDevice; |
253 | |
254 | class QuantizedInstanceNorm : public OpKernel { |
255 | public: |
256 | explicit QuantizedInstanceNorm(OpKernelConstruction* context) |
257 | : OpKernel(context) { |
258 | OP_REQUIRES_OK(context, |
259 | context->GetAttr("variance_epsilon" , &variance_epsilon_)); |
260 | OP_REQUIRES_OK(context, |
261 | context->GetAttr("min_separation" , &min_separation_)); |
262 | OP_REQUIRES_OK( |
263 | context, context->GetAttr("output_range_given" , &output_range_given_)); |
264 | if (output_range_given_) { |
265 | OP_REQUIRES_OK(context, context->GetAttr("given_y_min" , &given_y_min_)); |
266 | OP_REQUIRES_OK(context, context->GetAttr("given_y_max" , &given_y_max_)); |
267 | OP_REQUIRES(context, given_y_min_ < given_y_max_, |
268 | errors::InvalidArgument( |
269 | "given_y_min must be less than given_y_max : " , |
270 | given_y_min_, " >= " , given_y_max_)); |
271 | } |
272 | } |
273 | |
274 | void Compute(OpKernelContext* context) override { |
275 | const Tensor& input = context->input(0); |
276 | |
277 | const Tensor& x_min = context->input(1); |
278 | const Tensor& x_max = context->input(2); |
279 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(x_min.shape()), |
280 | errors::InvalidArgument("`x_min` must be rank 0 but is rank " , |
281 | x_min.dims())); |
282 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(x_max.shape()), |
283 | errors::InvalidArgument("`x_max` must be rank 0 but is rank " , |
284 | x_max.dims())); |
285 | float input_min = x_min.scalar<float>()(); |
286 | float input_max = x_max.scalar<float>()(); |
287 | float input_scale = (input_max - input_min) / 255.0f; |
288 | |
289 | OP_REQUIRES(context, input_min < input_max, |
290 | errors::InvalidArgument( |
291 | "input_min must be less than input_max : " , input_min, |
292 | " >= " , input_max)); |
293 | |
294 | auto input_tensor = input.tensor<quint8, 4>(); |
295 | auto N = input_tensor.dimension(0); |
296 | auto H = input_tensor.dimension(1); |
297 | auto W = input_tensor.dimension(2); |
298 | auto C = input_tensor.dimension(3); |
299 | |
300 | Tensor* output = nullptr; |
301 | OP_REQUIRES_OK(context, |
302 | context->allocate_output(0, input.shape(), &output)); |
303 | |
304 | Tensor* output_min = nullptr; |
305 | OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min)); |
306 | Tensor* output_max = nullptr; |
307 | OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max)); |
308 | |
309 | typedef TTypes<float>::Tensor::Index Index; |
310 | |
311 | const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>> |
312 | reduction_indices; |
313 | Eigen::IndexList<Eigen::type2index<1>, Index, Index, Eigen::type2index<1>> |
314 | broadcast_spec; |
315 | broadcast_spec.set(1, H); |
316 | broadcast_spec.set(2, W); |
317 | Eigen::IndexList<Index, Eigen::type2index<1>, Eigen::type2index<1>, Index> |
318 | expand_spec; |
319 | expand_spec.set(0, N); |
320 | expand_spec.set(3, C); |
321 | |
322 | Eigen::Tensor<float, 2, Eigen::RowMajor> float_mean(N, C); |
323 | Eigen::Tensor<float, 2, Eigen::RowMajor> float_variance(N, C); |
324 | |
325 | #ifdef USE_NEON |
326 | if (N == 1 && (C % 16 == 0)) { |
327 | VLOG(2) << "Calling optimized" ; |
328 | ColMeanAndVariance(reinterpret_cast<const uint8_t*>(input_tensor.data()), |
329 | H * W, C, float_mean.data(), float_variance.data()); |
330 | |
331 | float minimum = given_y_min_, maximum = given_y_max_; |
332 | if (!output_range_given_) { |
333 | MinAndMax(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W, |
334 | C, float_mean.data(), float_variance.data(), |
335 | variance_epsilon_, &minimum, &maximum); |
336 | } |
337 | |
338 | if (maximum - minimum < min_separation_) { |
339 | maximum = minimum + min_separation_; |
340 | } |
341 | |
342 | InstanceNorm(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W, |
343 | C, float_mean.data(), float_variance.data(), |
344 | variance_epsilon_, minimum, maximum, |
345 | reinterpret_cast<uint8_t*>(output->flat<quint8>().data())); |
346 | output_min->scalar<float>()() = minimum; |
347 | output_max->scalar<float>()() = maximum; |
348 | } else // NOLINT(readability/braces) |
349 | #endif |
350 | { |
351 | VLOG(2) << "Calling unoptimized" ; |
352 | float_mean = input_tensor.cast<float>().reduce( |
353 | reduction_indices, Eigen::internal::MeanReducer<float>()); |
354 | |
355 | float_variance = |
356 | (input_scale * |
357 | ((input_tensor.cast<float>() - |
358 | float_mean.reshape(expand_spec).broadcast(broadcast_spec)))) |
359 | .square() |
360 | .reduce(reduction_indices, Eigen::internal::MeanReducer<float>()); |
361 | |
362 | Eigen::Tensor<float, 4, Eigen::RowMajor> instance_normed = |
363 | input_scale * |
364 | (input_tensor.cast<float>() - |
365 | float_mean.reshape(expand_spec).broadcast(broadcast_spec)) * |
366 | (float_variance + variance_epsilon_) |
367 | .rsqrt() |
368 | .reshape(expand_spec) |
369 | .broadcast(broadcast_spec); |
370 | |
371 | Eigen::Tensor<float, 0, Eigen::RowMajor> normed_min; |
372 | Eigen::Tensor<float, 0, Eigen::RowMajor> normed_max; |
373 | |
374 | if (!output_range_given_) { |
375 | normed_min = instance_normed.minimum(); |
376 | normed_max = instance_normed.maximum(); |
377 | } else { |
378 | normed_min() = given_y_min_; |
379 | normed_max() = given_y_max_; |
380 | } |
381 | |
382 | if (normed_max() - normed_min() < min_separation_) { |
383 | normed_max() = normed_min() + min_separation_; |
384 | } |
385 | |
386 | FloatToQuantizedStruct<quint8> output_f2q(normed_min(), normed_max()); |
387 | auto instance_normed_quantized = |
388 | QUANTIZE_WITH_EIGEN(instance_normed, output_f2q, quint8); |
389 | |
390 | output->tensor<quint8, 4>().device( |
391 | context->template eigen_device<CPUDevice>()) = |
392 | instance_normed_quantized; |
393 | output_min->flat<float>()(0) = normed_min(); |
394 | output_max->flat<float>()(0) = normed_max(); |
395 | } |
396 | } |
397 | |
398 | private: |
399 | float variance_epsilon_; |
400 | float min_separation_; |
401 | bool output_range_given_; |
402 | float given_y_min_; |
403 | float given_y_max_; |
404 | }; |
405 | |
// Registers QuantizedInstanceNorm as the CPU kernel for the
// "QuantizedInstanceNorm" op, restricted to attribute T == quint8.
REGISTER_KERNEL_BUILDER(Name("QuantizedInstanceNorm")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T"),
                        QuantizedInstanceNorm);
410 | |
411 | } // namespace tensorflow |
412 | |