1 | /******************************************************************************* |
2 | * Copyright 2017-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_SIMPLE_Q10N_HPP |
18 | #define CPU_SIMPLE_Q10N_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/math_utils.hpp" |
24 | #include "common/nstl.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | |
32 | template <typename data_t, typename acc_t> |
33 | inline typename utils::enable_if<!nstl::is_integral<data_t>::value, |
34 | typename utils::remove_reference<acc_t>::type>::type |
35 | saturate(const acc_t &x) { |
36 | acc_t v = x; |
37 | return v; |
38 | } |
39 | |
40 | template <typename data_t, typename acc_t> |
41 | inline typename utils::enable_if<nstl::is_integral<data_t>::value, |
42 | typename utils::remove_reference<acc_t>::type>::type |
43 | saturate(const acc_t &x) { |
44 | acc_t v = x; |
45 | acc_t lbound = (acc_t)nstl::numeric_limits<data_t>::lowest(); |
46 | // Pick up a modified version of max value when do f32 -> s32. |
47 | acc_t ubound = types::max_value<acc_t>(data_traits<data_t>::data_type); |
48 | if (v < lbound) v = lbound; |
49 | if (v > ubound) v = ubound; |
50 | return v; |
51 | } |
52 | |
53 | template <> |
54 | inline uint8_t saturate<int8_t, uint8_t>(const uint8_t &x) { |
55 | return x <= 127u ? x : 127; |
56 | } |
57 | |
58 | template <> |
59 | inline int8_t saturate<uint8_t, int8_t>(const int8_t &x) { |
60 | return x >= 0 ? x : 0; |
61 | } |
62 | |
63 | template <typename out_t> |
64 | inline typename utils::enable_if<nstl::is_integral<out_t>::value, |
65 | typename utils::remove_reference<out_t>::type>::type |
66 | out_round(float v) { |
67 | return (out_t)math::mxcsr_cvt(v); |
68 | } |
69 | |
70 | template <typename out_t> |
71 | inline typename utils::enable_if<!nstl::is_integral<out_t>::value, |
72 | typename utils::remove_reference<out_t>::type>::type |
73 | out_round(float v) { |
74 | return v; |
75 | } |
76 | |
77 | template <typename out_t, typename acc_t = float> |
78 | inline out_t saturate_and_round(acc_t f) { |
79 | return out_round<out_t>(saturate<out_t, acc_t>(f)); |
80 | } |
81 | |
82 | /* Quantization with alpha == 1 and beta == 0 */ |
83 | template <typename in_t, typename out_t, typename enabled = void> |
84 | struct qz_a1b0 { |
85 | out_t operator()(in_t in) { return saturate_and_round<out_t>((float)in); } |
86 | }; |
87 | |
88 | template <typename in_t, typename out_t> |
89 | struct qz_a1b0<in_t, out_t, |
90 | typename utils::enable_if<true && nstl::is_integral<in_t>::value |
91 | && !is_subset<in_t, out_t>::value>::type> { |
92 | out_t operator()(in_t in) { return saturate<out_t>(in); } |
93 | }; |
94 | |
95 | template <typename in_t, typename out_t> |
96 | struct qz_a1b0<in_t, out_t, |
97 | typename utils::enable_if<is_subset<in_t, out_t>::value>::type> { |
98 | out_t operator()(in_t in) { return (out_t)in; } |
99 | }; |
100 | |
101 | /* Quantization with alpha == 1 */ |
102 | template <typename in_t, typename out_t> |
103 | struct qz_a1 { |
104 | out_t operator()(in_t in, out_t out, float beta) { |
105 | return saturate_and_round<out_t>((float)in + beta * out); |
106 | } |
107 | }; |
108 | |
109 | template <typename in_t> |
110 | struct qz_a1<in_t, float> { |
111 | float operator()(in_t in, float out, float beta) { |
112 | return (float)in + beta * out; |
113 | } |
114 | }; |
115 | |
116 | /* Quantization with beta == 0 */ |
117 | template <typename in_t, typename out_t> |
118 | struct qz_b0 { |
119 | out_t operator()(in_t in, float alpha) { |
120 | return saturate_and_round<out_t>(alpha * in); |
121 | } |
122 | }; |
123 | |
124 | template <typename in_t> |
125 | struct qz_b0<in_t, float> { |
126 | float operator()(in_t in, float alpha) { return alpha * in; } |
127 | }; |
128 | |
129 | /* Quantization */ |
130 | template <typename in_t, typename out_t> |
131 | struct qz { |
132 | out_t operator()(in_t in, out_t out, float alpha, float beta) { |
133 | return saturate_and_round<out_t>(alpha * in + (beta ? beta * out : 0)); |
134 | } |
135 | }; |
136 | |
137 | template <typename in_t> |
138 | struct qz<in_t, float> { |
139 | float operator()(in_t in, float out, float alpha, float beta) { |
140 | return alpha * in + (beta ? beta * out : 0); |
141 | } |
142 | }; |
143 | |
144 | template <> |
145 | struct qz<bfloat16_t, bfloat16_t> { |
146 | float operator()(bfloat16_t in, bfloat16_t out, float alpha, float beta) { |
147 | return (bfloat16_t)(alpha * (float)in + (beta ? beta * (float)out : 0)); |
148 | } |
149 | }; |
150 | |
151 | template <> |
152 | struct qz<float, bfloat16_t> { |
153 | float operator()(float in, bfloat16_t out, float alpha, float beta) { |
154 | return (bfloat16_t)(alpha * in + (beta ? beta * out : 0)); |
155 | } |
156 | }; |
157 | |
158 | template <> |
159 | struct qz<float16_t, float16_t> { |
160 | float operator()(float16_t in, float16_t out, float alpha, float beta) { |
161 | return (float16_t)(alpha * (float)in + (beta ? beta * (float)out : 0)); |
162 | } |
163 | }; |
164 | |
165 | template <> |
166 | struct qz<float, float16_t> { |
167 | float operator()(float in, float16_t out, float alpha, float beta) { |
168 | return (float16_t)(alpha * in + (beta ? beta * out : 0)); |
169 | } |
170 | }; |
171 | |
172 | } // namespace cpu |
173 | } // namespace impl |
174 | } // namespace dnnl |
175 | |
176 | #endif |
177 | |