1 | // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | // output.h: processing the 32-bit accumulators output by the unpack |
16 | // stage, obtaining the final result matrix entries and storing them into |
17 | // the destination matrix. |
18 | |
19 | #ifndef GEMMLOWP_INTERNAL_OUTPUT_H_ |
20 | #define GEMMLOWP_INTERNAL_OUTPUT_H_ |
21 | |
22 | #include <cmath> |
23 | #include <tuple> |
24 | #include <type_traits> |
25 | #include <typeinfo> |
26 | |
27 | #include "../fixedpoint/fixedpoint.h" |
28 | #include "../public/output_stages.h" |
29 | #include "simd_wrappers.h" |
30 | |
31 | namespace gemmlowp { |
32 | |
// Evaluates one output stage on a raw register buffer.
//
// This primary template is deliberately unusable: it only exists so that a
// missing (OutputStage, InputBufferType) specialization is reported through
// the static_assert message below instead of an obscure lookup error.
template <typename OutputStage, typename InputBufferType>
struct OutputStageEvalBufferImpl {
  // The condition depends on a template parameter, so it is only checked when
  // this generic body is actually instantiated.
  static_assert(
      std::is_same<void, InputBufferType>::value,
      "Unimplemented: missing implementation of this output pipeline stage "
      "for this data type. This would happen if some architecture-specific "
      "SIMD back-end (output_$arch.h) were incomplete.");
};
42 | |
43 | template <typename OutputStage, typename InputType> |
44 | struct OutputStageEvalImpl { |
45 | static constexpr int kRows = InputType::kRows; |
46 | static constexpr int kCols = InputType::kCols; |
47 | using InputBufferType = typename InputType::BufferType; |
48 | using BufferEvalImplType = |
49 | OutputStageEvalBufferImpl<OutputStage, InputBufferType>; |
50 | using OutputBufferType = typename BufferEvalImplType::OutputType; |
51 | using OutputScalarType = typename OutputBufferType::ScalarType; |
52 | using OutputType = RegisterBlock<OutputScalarType, kRows, kCols>; |
53 | |
54 | OutputStageEvalImpl(const OutputStage& s) : buffer_eval_impl(s) {} |
55 | |
56 | OutputType Eval(InputType input, int, int) const { |
57 | OutputType output; |
58 | output.buf = buffer_eval_impl.Eval(input.buf); |
59 | return output; |
60 | } |
61 | |
62 | const BufferEvalImplType buffer_eval_impl; |
63 | }; |
64 | |
65 | template <int Size> |
66 | struct OutputStageEvalBufferImpl<OutputStageQuantizeDownInt32ToUint8Scale, |
67 | RegisterBuffer<std::int32_t, Size>> { |
68 | using InputType = RegisterBuffer<std::int32_t, Size>; |
69 | using OutputType = RegisterBuffer<std::int32_t, Size>; |
70 | |
71 | typedef OutputStageQuantizeDownInt32ToUint8Scale OutputStage; |
72 | |
73 | OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} |
74 | |
75 | OutputType Eval(InputType input) const { |
76 | const int result_shift = output_stage.result_shift; |
77 | const std::int32_t result_mult_int = output_stage.result_mult_int; |
78 | using RegisterType = typename InputType::RegisterType; |
79 | const RegisterType result_offset = |
80 | Dup<RegisterType>(output_stage.result_offset); |
81 | OutputType output; |
82 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
83 | output.reg[i] = RoundingDivideByPOT( |
84 | Mul(Add(input.reg[i], result_offset), result_mult_int), result_shift); |
85 | } |
86 | return output; |
87 | } |
88 | |
89 | const OutputStage& output_stage; |
90 | }; |
91 | |
92 | template <int Rows, int Cols, VectorShape Shape> |
93 | struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>, |
94 | RegisterBlock<std::int32_t, Rows, Cols>> { |
95 | typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; |
96 | typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; |
97 | typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> OutputStage; |
98 | |
99 | OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} |
100 | |
101 | OutputType Eval(InputType input, int row, int col) const { |
102 | OutputType output; |
103 | const int result_shift = output_stage.result_shift; |
104 | const int pos = Shape == VectorShape::Col ? row : col; |
105 | const auto result_mult_int = |
106 | LoadForBroadcasting<InputType>(output_stage.result_mult_int, pos); |
107 | const auto result_offset = |
108 | LoadForBroadcasting<InputType>(output_stage.result_offset, pos); |
109 | const auto dividend = BroadcastMul<InputType>( |
110 | BroadcastAdd<InputType>(input, result_offset), result_mult_int); |
111 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
112 | output.buf.reg[i] = |
113 | RoundingDivideByPOT(dividend.buf.reg[i], result_shift); |
114 | } |
115 | return output; |
116 | } |
117 | |
118 | const OutputStage& output_stage; |
119 | }; |
120 | |
121 | template <int Size> |
122 | struct OutputStageEvalBufferImpl< |
123 | OutputStageQuantizeDownInt32ByFixedPoint, |
124 | RegisterBuffer<std::int32_t, Size>> { |
125 | typedef RegisterBuffer<std::int32_t, Size> InputType; |
126 | typedef RegisterBuffer<std::int32_t, Size> OutputType; |
127 | |
128 | typedef OutputStageQuantizeDownInt32ByFixedPoint OutputStage; |
129 | |
130 | OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} |
131 | |
132 | OutputType Eval(InputType input) const { |
133 | OutputType output; |
134 | using RegisterType = typename InputType::RegisterType; |
135 | const RegisterType result_offset_after_shift = |
136 | Dup<RegisterType>(output_stage.result_offset_after_shift); |
137 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
138 | const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul( |
139 | input.reg[i], output_stage.result_fixedpoint_multiplier); |
140 | output.reg[i] = |
141 | Add(RoundingDivideByPOT(mulhigh_val, output_stage.result_shift), |
142 | result_offset_after_shift); |
143 | } |
144 | return output; |
145 | } |
146 | |
147 | const OutputStage& output_stage; |
148 | }; |
149 | |
150 | template <int Size> |
151 | struct OutputStageEvalBufferImpl<OutputStageScaleInt32ByFixedPointAndExponent, |
152 | RegisterBuffer<std::int32_t, Size>> { |
153 | typedef RegisterBuffer<std::int32_t, Size> InputType; |
154 | typedef RegisterBuffer<std::int32_t, Size> OutputType; |
155 | |
156 | typedef OutputStageScaleInt32ByFixedPointAndExponent OutputStage; |
157 | |
158 | OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) { |
159 | left_shift = std::max(0, output_stage.result_exponent); |
160 | right_shift = std::max(0, -output_stage.result_exponent); |
161 | } |
162 | |
163 | OutputType Eval(InputType input) const { |
164 | OutputType output; |
165 | using RegisterType = typename InputType::RegisterType; |
166 | const RegisterType result_offset_after_shift = |
167 | Dup<RegisterType>(output_stage.result_offset_after_shift); |
168 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
169 | const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul( |
170 | ShiftLeft(input.reg[i], left_shift), |
171 | output_stage.result_fixedpoint_multiplier); |
172 | output.reg[i] = Add(RoundingDivideByPOT(mulhigh_val, right_shift), |
173 | result_offset_after_shift); |
174 | } |
175 | return output; |
176 | } |
177 | |
178 | const OutputStage& output_stage; |
179 | int left_shift; |
180 | int right_shift; |
181 | }; |
182 | |
183 | template <int Rows, int Cols, VectorShape Shape> |
184 | struct OutputStageEvalImpl< |
185 | OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>, |
186 | RegisterBlock<std::int32_t, Rows, Cols>> { |
187 | typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; |
188 | typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; |
189 | |
190 | typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> OutputStage; |
191 | |
192 | OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} |
193 | |
194 | OutputType Eval(InputType input, int row, int col) const { |
195 | OutputType output; |
196 | const int pos = Shape == VectorShape::Row ? col : row; |
197 | using RegisterType = typename InputType::RegisterType; |
198 | const RegisterType result_offset_after_shift = |
199 | Dup<RegisterType>(output_stage.result_offset_after_shift); |
200 | auto left_shift = |
201 | LoadForBroadcasting<InputType>(output_stage.result_exponent, pos); |
202 | auto right_shift = |
203 | LoadForBroadcasting<InputType>(output_stage.result_exponent, pos); |
204 | const auto result_fixedpoint_multiplier = LoadForBroadcasting<InputType>( |
205 | output_stage.result_fixedpoint_multiplier, pos); |
206 | for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) { |
207 | left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0); |
208 | right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0); |
209 | } |
210 | const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul( |
211 | BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier); |
212 | const auto rdpot_val = |
213 | BroadcastRoundingDivideByPOT(mulhigh_val, right_shift); |
214 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
215 | output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift); |
216 | } |
217 | return output; |
218 | } |
219 | |
220 | const OutputStage& output_stage; |
221 | }; |
222 | |
223 | // Implementation of OutputStageSaturatingCastToUint8 for scalar data. |
224 | template <int Size> |
225 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, |
226 | RegisterBuffer<std::int32_t, Size>> { |
227 | typedef RegisterBuffer<std::int32_t, Size> InputType; |
228 | typedef RegisterBuffer<std::uint8_t, Size> OutputType; |
229 | static_assert(InputType::kRegisterLanes == 1, |
230 | "This path is only for scalar values" ); |
231 | |
232 | typedef OutputStageSaturatingCastToUint8 OutputStage; |
233 | |
234 | OutputStageEvalBufferImpl(const OutputStage&) {} |
235 | |
236 | OutputType Eval(InputType input) const { |
237 | OutputType output; |
238 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
239 | std::int32_t data = input.reg[i]; |
240 | output.reg[i] = data > 255 ? 255 : data < 0 ? 0 : data; |
241 | } |
242 | return output; |
243 | } |
244 | }; |
245 | |
246 | // Implementation of OutputStageSaturatingCastToInt8 for scalar data. |
247 | template <int Size> |
248 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8, |
249 | RegisterBuffer<std::int32_t, Size>> { |
250 | typedef RegisterBuffer<std::int32_t, Size> InputType; |
251 | typedef RegisterBuffer<std::int8_t, Size> OutputType; |
252 | static_assert(InputType::kRegisterLanes == 1, |
253 | "This path is only for scalar values" ); |
254 | |
255 | typedef OutputStageSaturatingCastToInt8 OutputStage; |
256 | |
257 | OutputStageEvalBufferImpl(const OutputStage&) {} |
258 | |
259 | OutputType Eval(InputType input) const { |
260 | OutputType output; |
261 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
262 | std::int32_t data = input.reg[i]; |
263 | output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data; |
264 | } |
265 | return output; |
266 | } |
267 | }; |
268 | |
269 | // Implementation of OutputStageSaturatingCastToInt16 for scalar data. |
270 | template <int Size> |
271 | struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, |
272 | RegisterBuffer<std::int32_t, Size>> { |
273 | typedef RegisterBuffer<std::int32_t, Size> InputType; |
274 | typedef RegisterBuffer<std::int16_t, Size> OutputType; |
275 | static_assert(InputType::kRegisterLanes == 1, |
276 | "This path is only for scalar values" ); |
277 | |
278 | typedef OutputStageSaturatingCastToInt16 OutputStage; |
279 | |
280 | OutputStageEvalBufferImpl(const OutputStage&) {} |
281 | |
282 | OutputType Eval(InputType input) const { |
283 | OutputType output; |
284 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
285 | std::int32_t data = input.reg[i]; |
286 | output.reg[i] = data > 32767 ? 32767 : data < -32768 ? -32768 : data; |
287 | } |
288 | return output; |
289 | } |
290 | }; |
291 | |
292 | // Implementation of OutputStageTruncatingCastToUint8 for scalar data |
293 | template <int Size> |
294 | struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8, |
295 | RegisterBuffer<std::int32_t, Size>> { |
296 | typedef RegisterBuffer<std::int32_t, Size> InputType; |
297 | typedef RegisterBuffer<std::uint8_t, Size> OutputType; |
298 | static_assert(InputType::kRegisterLanes == 1, |
299 | "This path is only for scalar values" ); |
300 | |
301 | typedef OutputStageTruncatingCastToUint8 OutputStage; |
302 | |
303 | OutputStageEvalBufferImpl(const OutputStage&) {} |
304 | |
305 | OutputType Eval(InputType input) const { |
306 | OutputType output; |
307 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
308 | output.reg[i] = input.reg[i]; |
309 | } |
310 | return output; |
311 | } |
312 | }; |
313 | |
314 | template <int Rows, int Cols, typename VectorType> |
315 | struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>, |
316 | RegisterBlock<std::int32_t, Rows, Cols>> { |
317 | typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; |
318 | typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; |
319 | typedef OutputStageBiasAddition<VectorType> OutputStage; |
320 | |
321 | OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} |
322 | |
323 | OutputType Eval(InputType input, int row, int col) const { |
324 | const int pos = VectorType::kShape == VectorShape::Row ? col : row; |
325 | return BroadcastAdd<InputType>( |
326 | input, LoadForBroadcasting<InputType>(output_stage.bias_vector, pos)); |
327 | } |
328 | |
329 | const OutputStage& output_stage; |
330 | }; |
331 | |
332 | template <int Size> |
333 | struct OutputStageEvalBufferImpl<OutputStageClamp, |
334 | RegisterBuffer<std::int32_t, Size>> { |
335 | typedef RegisterBuffer<std::int32_t, Size> InputType; |
336 | typedef RegisterBuffer<std::int32_t, Size> OutputType; |
337 | |
338 | typedef OutputStageClamp OutputStage; |
339 | |
340 | OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} |
341 | |
342 | OutputType Eval(InputType input) const { |
343 | using RegisterType = typename InputType::RegisterType; |
344 | const RegisterType min = Dup<RegisterType>(output_stage.min); |
345 | const RegisterType max = Dup<RegisterType>(output_stage.max); |
346 | OutputType output; |
347 | for (int i = 0; i < InputType::kRegisterCount; i++) { |
348 | output.reg[i] = Min(Max(input.reg[i], min), max); |
349 | } |
350 | return output; |
351 | } |
352 | |
353 | const OutputStage& output_stage; |
354 | }; |
355 | |
356 | template <int Size> |
357 | struct OutputStageEvalBufferImpl<OutputStageTanh, |
358 | RegisterBuffer<std::int32_t, Size>> { |
359 | typedef RegisterBuffer<std::int32_t, Size> InputType; |
360 | typedef RegisterBuffer<std::int32_t, Size> OutputType; |
361 | using RegisterType = typename InputType::RegisterType; |
362 | typedef RegisterType DataType; |
363 | typedef OutputStageTanh OutputStage; |
364 | |
365 | OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) { |
366 | const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32; |
367 | const std::int32_t real_amplitude_as_int32 = |
368 | output_stage.real_amplitude_as_int32; |
369 | |
370 | input_cutoff_min = real_zero_as_int32 - 8 * real_amplitude_as_int32; |
371 | input_cutoff_max = real_zero_as_int32 + 8 * real_amplitude_as_int32; |
372 | output_min = real_zero_as_int32 - real_amplitude_as_int32; |
373 | output_max = real_zero_as_int32 + real_amplitude_as_int32; |
374 | |
375 | double inverse_amplitude_normalized_double = 1.0 / real_amplitude_as_int32; |
376 | inverse_amplitude_neg_exponent = 0; |
377 | while (inverse_amplitude_normalized_double < 0.5) { |
378 | inverse_amplitude_normalized_double *= 2; |
379 | inverse_amplitude_neg_exponent++; |
380 | } |
381 | inverse_amplitude_normalized = FixedPoint<DataType, 0>::FromDouble( |
382 | inverse_amplitude_normalized_double); |
383 | |
384 | double amplitude_normalized_double = real_amplitude_as_int32; |
385 | amplitude_exponent = 0; |
386 | while (amplitude_normalized_double >= 1.0) { |
387 | amplitude_normalized_double *= 0.5; |
388 | amplitude_exponent++; |
389 | } |
390 | amplitude_normalized = |
391 | FixedPoint<DataType, 0>::FromDouble(amplitude_normalized_double); |
392 | } |
393 | |
394 | OutputType Eval(InputType input) const { |
395 | const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32; |
396 | |
397 | typedef FixedPoint<DataType, 3> F3; |
398 | typedef FixedPoint<DataType, 0> F0; |
399 | |
400 | OutputType output; |
401 | |
402 | for (int i = 0; i < OutputType::kRegisterCount; i++) { |
403 | // fixed-point affine transformation |
404 | DataType input_centered = |
405 | Sub(input.reg[i], Dup<DataType>(real_zero_as_int32)); |
406 | F3 fixedpoint_input = |
407 | F3::FromRaw(input_centered) * inverse_amplitude_normalized; |
408 | // left shift |
409 | fixedpoint_input.raw() = ShiftLeft(fixedpoint_input.raw(), |
410 | 28 - inverse_amplitude_neg_exponent); |
411 | // fixed-point tanh and multiplication |
412 | F0 fixedpoint_output = tanh(fixedpoint_input) * amplitude_normalized; |
413 | // right shift |
414 | DataType int32_output = |
415 | Add(Dup<DataType>(real_zero_as_int32), |
416 | ShiftRight(fixedpoint_output.raw(), 31 - amplitude_exponent)); |
417 | |
418 | DataType mask_if_below_cutoff_min = |
419 | MaskIfLessThanOrEqual(input.reg[i], Dup<DataType>(input_cutoff_min)); |
420 | DataType mask_if_above_cutoff_max = MaskIfGreaterThanOrEqual( |
421 | input.reg[i], Dup<DataType>(input_cutoff_max)); |
422 | |
423 | output.reg[i] = SelectUsingMask( |
424 | mask_if_below_cutoff_min, Dup<DataType>(output_min), |
425 | SelectUsingMask(mask_if_above_cutoff_max, Dup<DataType>(output_max), |
426 | int32_output)); |
427 | } |
428 | return output; |
429 | } |
430 | |
431 | const OutputStage& output_stage; |
432 | std::int32_t input_cutoff_min, input_cutoff_max; |
433 | std::int32_t output_min, output_max; |
434 | FixedPoint<DataType, 0> inverse_amplitude_normalized; |
435 | int inverse_amplitude_neg_exponent; |
436 | FixedPoint<DataType, 0> amplitude_normalized; |
437 | int amplitude_exponent; |
438 | }; |
439 | |
440 | // OutputPipelineOutputType is a helper to determine the output data type of a |
441 | // pipeline, for a |
442 | // given input data type. It is a recursive template; see the explanation on |
443 | // OutputPipelineEvalImpl below. |
444 | template <typename OutputPipelineType, int FirstStage, typename InputType, |
445 | bool StopRecursion = |
446 | FirstStage == std::tuple_size<OutputPipelineType>::value> |
447 | struct OutputPipelineOutputType { |
448 | typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type |
449 | FirstStageType; |
450 | typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType |
451 | FirstStageOutputType; |
452 | typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage + 1, |
453 | FirstStageOutputType>::Type Type; |
454 | }; |
455 | |
456 | template <typename OutputPipelineType, int FirstStage, typename InputType> |
457 | struct OutputPipelineOutputType<OutputPipelineType, FirstStage, InputType, |
458 | true> { |
459 | typedef InputType Type; |
460 | }; |
461 | |
462 | // OutputPipelineEvalImpl is a helper to implement the evaluation of |
463 | // the whole pipeline. It is a recursive template to implement compile-time |
464 | // unrolling of the loop over all pipeline stages. The 'FirstStage' parameter |
465 | // is how we implement recursion: each specialization implements only |
466 | // evaluation starting at 'FirstStage'. The StopRecursion parameter is just a |
467 | // helper to implement the termination of the recursion as a partial |
468 | // specialization below. |
469 | template <typename OutputPipelineType, int FirstStage, typename InputType, |
470 | bool StopRecursion = |
471 | FirstStage == std::tuple_size<OutputPipelineType>::value> |
472 | struct OutputPipelineEvalImpl { |
473 | typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type |
474 | FirstStageType; |
475 | typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType |
476 | FirstStageOutputType; |
477 | typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage, |
478 | InputType>::Type OutputType; |
479 | |
480 | OutputPipelineEvalImpl(const OutputPipelineType& output_pipeline) |
481 | : head_impl(std::get<FirstStage>(output_pipeline)), |
482 | tail_impl(output_pipeline) {} |
483 | |
484 | OutputType Eval(InputType input, int row, int col) const { |
485 | // Evaluate the first stage. |
486 | FirstStageOutputType first_stage_output = head_impl.Eval(input, row, col); |
487 | // Recurse into the remaining stages. |
488 | return tail_impl.Eval(first_stage_output, row, col); |
489 | } |
490 | |
491 | const OutputStageEvalImpl<FirstStageType, InputType> head_impl; |
492 | const OutputPipelineEvalImpl<OutputPipelineType, FirstStage + 1, |
493 | FirstStageOutputType> |
494 | tail_impl; |
495 | }; |
496 | |
497 | // Specialization on 'StopRecursion' for terminating the recursion. |
498 | template <typename OutputPipelineType, int FirstStage, typename InputType> |
499 | struct OutputPipelineEvalImpl<OutputPipelineType, FirstStage, InputType, true> { |
500 | OutputPipelineEvalImpl(const OutputPipelineType&) {} |
501 | |
502 | InputType Eval(InputType input, int, int) const { |
503 | // Terminating the recursion. |
504 | return input; |
505 | } |
506 | }; |
507 | |
// Stores a final register block into a destination matrix. This primary
// template only traps unsupported RegisterBlockType arguments at compile time;
// the real work lives in the specializations.
template <typename RegisterBlockType, typename DstType>
struct StoreFinalOutputImpl {
  // Dependent condition: only evaluated if this generic body is instantiated.
  static_assert(std::is_same<void, RegisterBlockType>::value,
                "This generic impl should never be hit");
};
513 | |
514 | template <typename ScalarType, int Rows, int Cols, typename DstType> |
515 | struct StoreFinalOutputImpl<RegisterBlock<ScalarType, Rows, Cols>, DstType> { |
516 | using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; |
517 | static void Run(const RegisterBlockType& src, DstType* dst, int row, |
518 | int col) { |
519 | for (int r = 0; r < Rows; r++) { |
520 | for (int c = 0; c < Cols; c++) { |
521 | *dst->data(row + r, col + c) = src.buf.reg[r + c * Rows]; |
522 | } |
523 | } |
524 | } |
525 | }; |
526 | |
527 | // StoreFinalOutput takes the final value at the end of the output pipeline and |
528 | // stores it into the destination matrix. It can be specialized for different |
529 | // data types; the generic implementation here is typically used only for plain |
530 | // old scalar (not SIMD) types. |
531 | template <typename RegisterBlockType, typename DstType> |
532 | void StoreFinalOutput(RegisterBlockType src, DstType* dst, int row, int col) { |
533 | StoreFinalOutputImpl<RegisterBlockType, DstType>::Run(src, dst, row, col); |
534 | } |
535 | |
536 | template <typename OutputPipelineType, typename InputType> |
537 | struct OutputPipelineExecutor { |
538 | OutputPipelineExecutor(const OutputPipelineType& output_pipeline) |
539 | : output_pipeline_eval_impl_(output_pipeline) {} |
540 | |
541 | // Execute is the entry point into the output pipeline evaluation |
542 | // code. It should be the only thing that unpack code calls. It takes the |
543 | // result |
544 | // of the unpack stage and stores it into the destination matrix. |
545 | template <typename DstType> |
546 | void Execute(InputType input, DstType* dst, int src_global_row, |
547 | int src_global_col, int dst_row, int dst_col) const { |
548 | // Statically assert that the output pipeline matches the given destination |
549 | // matrix's scalar type. |
550 | typedef typename OutputPipelineOutputType< |
551 | OutputPipelineType, 0, InputType>::Type::BufferType::ScalarType |
552 | |
553 | ScalarOutputType; |
554 | typedef typename DstType::Scalar ScalarDstType; |
555 | static_assert(std::is_same<ScalarOutputType, ScalarDstType>::value, |
556 | "mismatched destination scalar type and output pipeline" ); |
557 | |
558 | // Evaluate the output pipeline. |
559 | auto output = |
560 | output_pipeline_eval_impl_.Eval(input, src_global_row, src_global_col); |
561 | // Store the result into the destination matrix. |
562 | StoreFinalOutput(output, dst, dst_row, dst_col); |
563 | } |
564 | |
565 | const OutputPipelineEvalImpl<OutputPipelineType, 0, InputType> |
566 | output_pipeline_eval_impl_; |
567 | }; |
568 | |
569 | } // namespace gemmlowp |
570 | |
571 | #ifdef GEMMLOWP_NEON |
572 | #include "output_neon.h" |
573 | #elif defined(GEMMLOWP_SSE4) |
574 | #include "output_sse.h" |
575 | #elif defined(GEMMLOWP_MSA) |
576 | #include "output_msa.h" |
577 | #endif |
578 | |
579 | #endif // GEMMLOWP_INTERNAL_OUTPUT_H_ |
580 | |