1 | /* Copyright 2020 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Provides a reference (portable, non-optimized) ApplyMultiplier function. |
17 | // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! |
18 | // Warning: this code is not meant to be bit-exact-normative. |
19 | // Please refer to the class comment of ruy::MulParams, in mul_params.h. |
20 | // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! |
21 | |
22 | #ifndef RUY_RUY_APPLY_MULTIPLIER_H_ |
23 | #define RUY_RUY_APPLY_MULTIPLIER_H_ |
24 | |
25 | #include <cstdint> |
26 | #include <type_traits> |
27 | |
28 | #include "ruy/check_macros.h" |
29 | #include "ruy/mul_params.h" |
30 | |
31 | namespace ruy { |
32 | |
33 | // Applies the quantized multiplier to the `*accum` accumulator value, if |
34 | // applicable, that is, if AccumScalar==int32 and DstScalar!=int32. Otherwise, |
35 | // does nothing. |
36 | // |
37 | // This is slow, portable, 'reference' code. It should only be used in |
38 | // ReferenceMul and in Path::kStandardCpp. There isn't a point in optimizing it, |
39 | // either. Fast paths have that multiplier work done as part of the kernel, |
40 | // typically written in assembly anyway. |
41 | template <typename AccumScalar, typename DstScalar> |
42 | void ApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params, |
43 | int channel, AccumScalar* accum); |
44 | |
45 | namespace detail { |
46 | |
// Declaration only; the definition is not part of this header.
// Copied from TF Lite code. Takes the int32 value `x` together with a
// fixed-point `quantized_multiplier` and an integer `shift`, and returns an
// int32 result.
std::int32_t MultiplyByQuantizedMultiplier(
    std::int32_t x,
    std::int32_t quantized_multiplier,
    int shift);
51 | |
// Dispatch helper for applying a fixed-point multiplier. The defaulted
// `IsApplicable` parameter is true only when the accumulators are int32 (i.e.
// every case except floating-point) and the destination type is not int32
// (i.e. the user did not ask for raw accumulators). This primary template is
// intentionally empty; the partial specializations on `IsApplicable` supply
// the actual Run() implementations.
template <typename AccumScalar, typename DstScalar,
          bool IsApplicable = !std::is_same<DstScalar, std::int32_t>::value &&
                              std::is_same<AccumScalar, std::int32_t>::value>
struct ApplyMultiplierImpl {};
59 | |
60 | // Specialization in non-applicable case: do nothing. |
61 | template <typename AccumScalar, typename DstScalar> |
62 | struct ApplyMultiplierImpl<AccumScalar, DstScalar, false> { |
63 | static void Run(const MulParams<AccumScalar, DstScalar>&, int, AccumScalar*) { |
64 | } |
65 | }; |
66 | |
67 | template <typename AccumScalar, typename DstScalar> |
68 | struct ApplyMultiplierImpl<AccumScalar, DstScalar, true> { |
69 | static void Run(const MulParams<AccumScalar, DstScalar>& mul_params, |
70 | int channel, AccumScalar* accum) { |
71 | AccumScalar m = mul_params.multiplier_fixedpoint_perchannel() |
72 | ? mul_params.multiplier_fixedpoint_perchannel()[channel] |
73 | : mul_params.multiplier_fixedpoint(); |
74 | int e = mul_params.multiplier_exponent_perchannel() |
75 | ? mul_params.multiplier_exponent_perchannel()[channel] |
76 | : mul_params.multiplier_exponent(); |
77 | *accum = MultiplyByQuantizedMultiplier(*accum, m, e); |
78 | } |
79 | }; |
80 | |
81 | } // namespace detail |
82 | |
83 | template <typename AccumScalar, typename DstScalar> |
84 | void ApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params, |
85 | int channel, AccumScalar* accum) { |
86 | detail::ApplyMultiplierImpl<AccumScalar, DstScalar>::Run(mul_params, channel, |
87 | accum); |
88 | } |
89 | |
90 | } // namespace ruy |
91 | |
92 | #endif // RUY_RUY_APPLY_MULTIPLIER_H_ |
93 | |