/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized eight-bit version of the mul operation.

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#include <arm_neon.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {
namespace {

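// All three helpers below rely on the same identity: TensorFlow's quantized
// types encode a real value as real = scale * (quantized - offset), where
// `offset` is the quantized representation of 0.0f. The product of two such
// values is scale_x * scale_y * (qx - offset_x) * (qy - offset_y), so the
// kernels only compute the integer factor (qx - offset_x) * (qy - offset_y);
// the combined scale is reported to callers through the output min/max range.
// For example, with both offsets at 128, quantized inputs 130 and 131 yield
// (130 - 128) * (131 - 128) = 6 in the output.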
template <class T, class Toutput>
void ScalarMultiply(OpKernelContext* context, const T* full_input,
                    int32_t full_input_offset, int64_t num_elements,
                    T scalar_input, int32_t scalar_input_offset,
                    Toutput* output) {
  const int32_t scalar_minus_offset =
      static_cast<int32_t>(scalar_input) - scalar_input_offset;
  for (int64_t i = 0; i < num_elements; ++i) {
    output[i] = (static_cast<int32_t>(full_input[i]) - full_input_offset) *
                scalar_minus_offset;
  }
}

#ifdef USE_NEON

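// NEON specialization of ScalarMultiply. The per-chunk data flow is:
//   vld1q_u8  : load 16 uint8 lanes;
//   vsubl_u8  : subtract the offset, widening to eight uint16 lanes, then
//               reinterpret as int16; this is safe because the true
//               difference always lies in [-255, 255], which round-trips
//               through the modulo-2^16 representation;
//   vmull_s16 : widening multiply of four int16 lanes into four int32 lanes;
// followed by four vst1q_s32 stores for the 16 int32 products.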
template <>
void ScalarMultiply<quint8, qint32>(OpKernelContext* context,
                                    const quint8* full_input,
                                    int32_t full_input_offset,
                                    int64_t num_elements, quint8 scalar_input,
                                    int32_t scalar_input_offset,
                                    qint32* output) {
  const int16_t scalar_minus_offset =
      static_cast<int16_t>(scalar_input) - scalar_input_offset;
  const int16x4_t scalar_minus_offset_16x4 = vmov_n_s16(scalar_minus_offset);
  const uint8x8_t full_input_offset_8x8 = vmov_n_u8(full_input_offset);
  // Go through the results in 16-element chunks for NEON acceleration.
  int64_t i;
  for (i = 0; i < (num_elements - 15); i += 16) {
    // Load the tensor inputs.
    const uint8_t* full_input_ptr = &(full_input->value) + i;
    const uint8x16_t full_input_8x16 = vld1q_u8(full_input_ptr);

    // Break into two sets of vectors so we can do further calculations
    // easily.
    const uint8x8_t full_input_high_8x8 = vget_high_u8(full_input_8x16);
    const uint8x8_t full_input_low_8x8 = vget_low_u8(full_input_8x16);

    // Subtract off the offset value to get 16-bit results.
    const int16x8_t full_input_minus_offset_high_16x8 = vreinterpretq_s16_u16(
        vsubl_u8(full_input_high_8x8, full_input_offset_8x8));
    const int16x8_t full_input_minus_offset_low_16x8 = vreinterpretq_s16_u16(
        vsubl_u8(full_input_low_8x8, full_input_offset_8x8));

    // We have to work with 4-wide vectors, so extract them.
    const int16x4_t x_high_high_16x4 =
        vget_high_s16(full_input_minus_offset_high_16x8);
    const int16x4_t x_high_low_16x4 =
        vget_low_s16(full_input_minus_offset_high_16x8);
    const int16x4_t x_low_high_16x4 =
        vget_high_s16(full_input_minus_offset_low_16x8);
    const int16x4_t x_low_low_16x4 =
        vget_low_s16(full_input_minus_offset_low_16x8);

    // Perform the multiplication.
    const int32x4_t z_high_high_32x4 =
        vmull_s16(x_high_high_16x4, scalar_minus_offset_16x4);
    const int32x4_t z_high_low_32x4 =
        vmull_s16(x_high_low_16x4, scalar_minus_offset_16x4);
    const int32x4_t z_low_high_32x4 =
        vmull_s16(x_low_high_16x4, scalar_minus_offset_16x4);
    const int32x4_t z_low_low_32x4 =
        vmull_s16(x_low_low_16x4, scalar_minus_offset_16x4);

    // Write out the results.
    int32_t* output_ptr = &(output->value) + i;
    vst1q_s32(output_ptr + 0, z_low_low_32x4);
    vst1q_s32(output_ptr + 4, z_low_high_32x4);
    vst1q_s32(output_ptr + 8, z_high_low_32x4);
    vst1q_s32(output_ptr + 12, z_high_high_32x4);
  }
  // Finish up any remaining elements that weren't a multiple of 16.
  for (; i < num_elements; ++i) {
    output[i] = (static_cast<int32_t>(full_input[i]) - full_input_offset) *
                scalar_minus_offset;
  }
}
#endif  // USE_NEON

template <class T, class Toutput>
void VectorMultiply(OpKernelContext* context, const T* x_data,
                    int32_t offset_x, const T* y_data, int32_t offset_y,
                    int64_t num_elements, Toutput* output) {
  for (int64_t i = 0; i < num_elements; ++i) {
    output[i] = (static_cast<int32_t>(x_data[i]) - offset_x) *
                (static_cast<int32_t>(y_data[i]) - offset_y);
  }
}

#ifdef USE_NEON
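// NEON specialization of VectorMultiply: the same widen-subtract-multiply
// pipeline as the ScalarMultiply specialization above, except that both
// operands are loaded fresh for every 16-element chunk.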
template <>
void VectorMultiply<quint8, qint32>(OpKernelContext* context,
                                    const quint8* x_data, int32_t offset_x,
                                    const quint8* y_data, int32_t offset_y,
                                    int64_t num_elements, qint32* output) {
  const uint8x8_t offset_x_8x8 = vmov_n_u8(offset_x);
  const uint8x8_t offset_y_8x8 = vmov_n_u8(offset_y);
  int64_t i;
  // Go through the results in 16-element chunks for NEON acceleration.
  for (i = 0; i < (num_elements - 15); i += 16) {
    // Load the vector inputs.
    const uint8_t* x_data_ptr = &(x_data->value) + i;
    const uint8x16_t x_8x16 = vld1q_u8(x_data_ptr);
    const uint8_t* y_data_ptr = &(y_data->value) + i;
    const uint8x16_t y_8x16 = vld1q_u8(y_data_ptr);

    // Break into two sets of vectors so we can do further calculations easily.
    const uint8x8_t x_high_8x8 = vget_high_u8(x_8x16);
    const uint8x8_t x_low_8x8 = vget_low_u8(x_8x16);
    const uint8x8_t y_high_8x8 = vget_high_u8(y_8x16);
    const uint8x8_t y_low_8x8 = vget_low_u8(y_8x16);

    // Subtract off the offset values to get 16-bit results.
    const int16x8_t x_minus_offset_high_16x8 =
        vreinterpretq_s16_u16(vsubl_u8(x_high_8x8, offset_x_8x8));
    const int16x8_t x_minus_offset_low_16x8 =
        vreinterpretq_s16_u16(vsubl_u8(x_low_8x8, offset_x_8x8));
    const int16x8_t y_minus_offset_high_16x8 =
        vreinterpretq_s16_u16(vsubl_u8(y_high_8x8, offset_y_8x8));
    const int16x8_t y_minus_offset_low_16x8 =
        vreinterpretq_s16_u16(vsubl_u8(y_low_8x8, offset_y_8x8));

    // We have to work with 4-wide vectors, so extract them.
    const int16x4_t x_high_high_16x4 = vget_high_s16(x_minus_offset_high_16x8);
    const int16x4_t x_high_low_16x4 = vget_low_s16(x_minus_offset_high_16x8);
    const int16x4_t x_low_high_16x4 = vget_high_s16(x_minus_offset_low_16x8);
    const int16x4_t x_low_low_16x4 = vget_low_s16(x_minus_offset_low_16x8);
    const int16x4_t y_high_high_16x4 = vget_high_s16(y_minus_offset_high_16x8);
    const int16x4_t y_high_low_16x4 = vget_low_s16(y_minus_offset_high_16x8);
    const int16x4_t y_low_high_16x4 = vget_high_s16(y_minus_offset_low_16x8);
    const int16x4_t y_low_low_16x4 = vget_low_s16(y_minus_offset_low_16x8);

    // Perform the multiplication.
    const int32x4_t z_high_high_32x4 =
        vmull_s16(x_high_high_16x4, y_high_high_16x4);
    const int32x4_t z_high_low_32x4 =
        vmull_s16(x_high_low_16x4, y_high_low_16x4);
    const int32x4_t z_low_high_32x4 =
        vmull_s16(x_low_high_16x4, y_low_high_16x4);
    const int32x4_t z_low_low_32x4 = vmull_s16(x_low_low_16x4, y_low_low_16x4);

    // Write out the results.
    int32_t* output_ptr = &(output->value) + i;
    vst1q_s32(output_ptr + 0, z_low_low_32x4);
    vst1q_s32(output_ptr + 4, z_low_high_32x4);
    vst1q_s32(output_ptr + 8, z_high_low_32x4);
    vst1q_s32(output_ptr + 12, z_high_high_32x4);
  }
  for (; i < num_elements; ++i) {
    output[i] = (static_cast<int32_t>(x_data[i]) - offset_x) *
                (static_cast<int32_t>(y_data[i]) - offset_y);
  }
}
#endif  // USE_NEON

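// Multiplies a vector of length V against a tensor whose element count is a
// multiple of V, tiling the vector across the tensor. This implements the
// rank-2 broadcast case dispatched below, e.g. a [rows, V] tensor multiplied
// elementwise by a [V] vector.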
template <class T, class Toutput>
void VectorTensorMultiply(const T* vector_data, int32_t vector_offset,
                          int64_t vector_num_elements, const T* tensor_data,
                          int32_t tensor_offset, int64_t tensor_num_elements,
                          Toutput* output) {
  for (int64_t i = 0; i < tensor_num_elements; ++i) {
    const int64_t vector_i = i % vector_num_elements;
    output[i] = (static_cast<int32_t>(vector_data[vector_i]) - vector_offset) *
                (static_cast<int32_t>(tensor_data[i]) - tensor_offset);
  }
}

#ifdef USE_NEON
template <>
void VectorTensorMultiply<quint8, qint32>(
    const quint8* vector_data, int32_t vector_offset,
    int64_t vector_num_elements, const quint8* tensor_data,
    int32_t tensor_offset, int64_t tensor_num_elements, qint32* output) {
  const uint8x8_t offset_x_8x8 = vmov_n_u8(vector_offset);
  const uint8x8_t offset_y_8x8 = vmov_n_u8(tensor_offset);
  CHECK_EQ(0, tensor_num_elements % vector_num_elements);
  for (int64_t base_i = 0; base_i < tensor_num_elements;
       base_i += vector_num_elements) {
    int64_t i = base_i;
    const int64_t end_i = base_i + vector_num_elements;
    // Go through the results in 16-element chunks for NEON acceleration.
    int64_t vector_i;
    for (vector_i = 0; vector_i < (vector_num_elements - 15);
         vector_i += 16, i += 16) {
      // Load the vector inputs.
      const uint8_t* x_data_ptr = &(vector_data->value) + vector_i;
      const uint8x16_t x_8x16 = vld1q_u8(x_data_ptr);
      const uint8_t* y_data_ptr = &(tensor_data->value) + i;
      const uint8x16_t y_8x16 = vld1q_u8(y_data_ptr);

      // Break into two sets of vectors so we can do further calculations
      // easily.
      const uint8x8_t x_high_8x8 = vget_high_u8(x_8x16);
      const uint8x8_t x_low_8x8 = vget_low_u8(x_8x16);
      const uint8x8_t y_high_8x8 = vget_high_u8(y_8x16);
      const uint8x8_t y_low_8x8 = vget_low_u8(y_8x16);

      // Subtract off the offset values to get 16-bit results.
      const int16x8_t x_minus_offset_high_16x8 =
          vreinterpretq_s16_u16(vsubl_u8(x_high_8x8, offset_x_8x8));
      const int16x8_t x_minus_offset_low_16x8 =
          vreinterpretq_s16_u16(vsubl_u8(x_low_8x8, offset_x_8x8));
      const int16x8_t y_minus_offset_high_16x8 =
          vreinterpretq_s16_u16(vsubl_u8(y_high_8x8, offset_y_8x8));
      const int16x8_t y_minus_offset_low_16x8 =
          vreinterpretq_s16_u16(vsubl_u8(y_low_8x8, offset_y_8x8));

      // We have to work with 4-wide vectors, so extract them.
      const int16x4_t x_high_high_16x4 =
          vget_high_s16(x_minus_offset_high_16x8);
      const int16x4_t x_high_low_16x4 = vget_low_s16(x_minus_offset_high_16x8);
      const int16x4_t x_low_high_16x4 = vget_high_s16(x_minus_offset_low_16x8);
      const int16x4_t x_low_low_16x4 = vget_low_s16(x_minus_offset_low_16x8);
      const int16x4_t y_high_high_16x4 =
          vget_high_s16(y_minus_offset_high_16x8);
      const int16x4_t y_high_low_16x4 = vget_low_s16(y_minus_offset_high_16x8);
      const int16x4_t y_low_high_16x4 = vget_high_s16(y_minus_offset_low_16x8);
      const int16x4_t y_low_low_16x4 = vget_low_s16(y_minus_offset_low_16x8);

      // Perform the multiplication.
      const int32x4_t z_high_high_32x4 =
          vmull_s16(x_high_high_16x4, y_high_high_16x4);
      const int32x4_t z_high_low_32x4 =
          vmull_s16(x_high_low_16x4, y_high_low_16x4);
      const int32x4_t z_low_high_32x4 =
          vmull_s16(x_low_high_16x4, y_low_high_16x4);
      const int32x4_t z_low_low_32x4 =
          vmull_s16(x_low_low_16x4, y_low_low_16x4);

      // Write out the results.
      int32_t* output_ptr = &(output->value) + i;
      vst1q_s32(output_ptr + 0, z_low_low_32x4);
      vst1q_s32(output_ptr + 4, z_low_high_32x4);
      vst1q_s32(output_ptr + 8, z_high_low_32x4);
      vst1q_s32(output_ptr + 12, z_high_high_32x4);
    }
    for (; i < end_i; ++i, ++vector_i) {
      output[i] =
          (static_cast<int32_t>(vector_data[vector_i]) - vector_offset) *
          (static_cast<int32_t>(tensor_data[i]) - tensor_offset);
    }
  }
}
#endif  // USE_NEON

}  // namespace

template <class T, class Toutput>
class QuantizedMulOp : public OpKernel {
 public:
  explicit QuantizedMulOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& x = context->input(0);
    const Tensor& y = context->input(1);
    auto& min_x_tensor = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_x_tensor.shape()),
                errors::InvalidArgument("min_x must be a scalar"));
    const float min_x = min_x_tensor.flat<float>()(0);
    auto& max_x_tensor = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_x_tensor.shape()),
                errors::InvalidArgument("max_x must be a scalar"));
    const float max_x = max_x_tensor.flat<float>()(0);
    auto& min_y_tensor = context->input(4);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_y_tensor.shape()),
                errors::InvalidArgument("min_y must be a scalar"));
    const float min_y = min_y_tensor.flat<float>()(0);
    auto& max_y_tensor = context->input(5);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_y_tensor.shape()),
                errors::InvalidArgument("max_y must be a scalar"));
    const float max_y = max_y_tensor.flat<float>()(0);

    BCast bcast(BCast::FromShape(x.shape()), BCast::FromShape(y.shape()));
    if (!bcast.IsValid()) {
      context->SetStatus(errors::InvalidArgument(
          "Incompatible shapes: ", x.shape().DebugString(), " vs. ",
          y.shape().DebugString()));
      return;
    }
    Tensor* z = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, BCast::ToShape(bcast.output_shape()), &z));

    // Make sure that we have valid quantization ranges for the input buffers.
    // If the difference between the min and max is negative or zero, it makes
    // it hard to do meaningful intermediate operations on the values.
    OP_REQUIRES(context, (max_x > min_x),
                errors::InvalidArgument("max_x must be larger than min_x."));
    OP_REQUIRES(context, (max_y > min_y),
                errors::InvalidArgument("max_y must be larger than min_y."));
    const int32_t offset_x = FloatToQuantizedUnclamped<T>(0.0f, min_x, max_x);
    const int32_t offset_y = FloatToQuantizedUnclamped<T>(0.0f, min_y, max_y);
    const T* x_data = x.flat<T>().data();
    const T* y_data = y.flat<T>().data();
    Toutput* z_data = z->flat<Toutput>().data();

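    // Dispatch on the rank after broadcast collapsing: rank <= 1 means the
    // shapes are identical or one input is effectively a scalar, and rank 2
    // means one input is a vector that tiles evenly into the other. Anything
    // else falls through to the Unimplemented error below.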
    const int ndims = bcast.x_reshape().size();
    if (ndims <= 1) {
      if (x.NumElements() == 1) {
        ScalarMultiply<T, Toutput>(context, y_data, offset_y, y.NumElements(),
                                   x_data[0], offset_x, z_data);
      } else if (y.NumElements() == 1) {
        ScalarMultiply<T, Toutput>(context, x_data, offset_x, x.NumElements(),
                                   y_data[0], offset_y, z_data);
      } else {
        VectorMultiply<T, Toutput>(context, x_data, offset_x, y_data, offset_y,
                                   x.NumElements(), z_data);
      }
    } else if (ndims == 2) {
      const T* vector_data;
      int64_t vector_num_elements;
      int32_t vector_offset;
      const T* tensor_data;
      int64_t tensor_num_elements;
      int32_t tensor_offset;
      if (x.NumElements() < y.NumElements()) {
        vector_data = x_data;
        vector_num_elements = x.NumElements();
        vector_offset = offset_x;
        tensor_data = y_data;
        tensor_num_elements = y.NumElements();
        tensor_offset = offset_y;
      } else {
        vector_data = y_data;
        vector_num_elements = y.NumElements();
        vector_offset = offset_y;
        tensor_data = x_data;
        tensor_num_elements = x.NumElements();
        tensor_offset = offset_x;
      }
      if (vector_num_elements == 0) {
        context->SetStatus(
            errors::InvalidArgument("vector must have at least 1 element"));
        return;
      }
      VectorTensorMultiply<T, Toutput>(
          vector_data, vector_offset, vector_num_elements, tensor_data,
          tensor_offset, tensor_num_elements, z_data);
    } else {
      LOG(INFO) << "ndims=" << ndims;
      LOG(INFO) << "bcast.x_reshape()="
                << TensorShape(bcast.x_reshape()).DebugString();
      LOG(INFO) << "bcast.y_reshape()="
                << TensorShape(bcast.y_reshape()).DebugString();
      LOG(INFO) << "bcast.x_bcast()="
                << TensorShape(bcast.x_bcast()).DebugString();
      LOG(INFO) << "bcast.y_bcast()="
                << TensorShape(bcast.y_bcast()).DebugString();

      context->SetStatus(errors::Unimplemented(
          "Broadcast between ", context->input(0).shape().DebugString(),
          " and ", context->input(1).shape().DebugString(),
          " is not supported yet."));
      return;
    }

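    // The integer products above are exact; quantization only changes how the
    // qint32 values are interpreted. QuantizationRangeForMultiplication
    // derives the output min/max so that the combined scale of the two input
    // ranges is preserved for downstream ops.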
    float min_z_value;
    float max_z_value;
    QuantizationRangeForMultiplication<T, T, Toutput>(
        min_x, max_x, min_y, max_y, &min_z_value, &max_z_value);
    Tensor* z_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &z_min));
    z_min->flat<float>()(0) = min_z_value;

    Tensor* z_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &z_max));
    z_max->flat<float>()(0) = max_z_value;
  }
};

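// Registers the QuantizedMul kernel for quint8 inputs producing qint32. The
// op consumes six inputs (x, y, min_x, max_x, min_y, max_y) and emits the
// quantized product plus its float min/max range. As a rough usage sketch
// (hedged: the exact Python binding depends on the TensorFlow version), the
// op should be reachable as a raw op, e.g.
//   tf.raw_ops.QuantizedMul(x=x, y=y, min_x=0.0, max_x=1.0,
//                           min_y=0.0, max_y=1.0)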
REGISTER_KERNEL_BUILDER(Name("QuantizedMul")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T1")
                            .TypeConstraint<quint8>("T2")
                            .TypeConstraint<qint32>("Toutput"),
                        QuantizedMulOp<quint8, qint32>);

}  // namespace tensorflow