1/*******************************************************************************
2* Copyright 2017-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_SIMPLE_Q10N_HPP
18#define CPU_SIMPLE_Q10N_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/math_utils.hpp"
24#include "common/nstl.hpp"
25#include "common/type_helpers.hpp"
26#include "common/utils.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31
32template <typename data_t, typename acc_t>
33inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
34 typename utils::remove_reference<acc_t>::type>::type
35saturate(const acc_t &x) {
36 acc_t v = x;
37 return v;
38}
39
40template <typename data_t, typename acc_t>
41inline typename utils::enable_if<nstl::is_integral<data_t>::value,
42 typename utils::remove_reference<acc_t>::type>::type
43saturate(const acc_t &x) {
44 acc_t v = x;
45 acc_t lbound = (acc_t)nstl::numeric_limits<data_t>::lowest();
46 // Pick up a modified version of max value when do f32 -> s32.
47 acc_t ubound = types::max_value<acc_t>(data_traits<data_t>::data_type);
48 if (v < lbound) v = lbound;
49 if (v > ubound) v = ubound;
50 return v;
51}
52
53template <>
54inline uint8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
55 return x <= 127u ? x : 127;
56}
57
58template <>
59inline int8_t saturate<uint8_t, int8_t>(const int8_t &x) {
60 return x >= 0 ? x : 0;
61}
62
63template <typename out_t>
64inline typename utils::enable_if<nstl::is_integral<out_t>::value,
65 typename utils::remove_reference<out_t>::type>::type
66out_round(float v) {
67 return (out_t)math::mxcsr_cvt(v);
68}
69
70template <typename out_t>
71inline typename utils::enable_if<!nstl::is_integral<out_t>::value,
72 typename utils::remove_reference<out_t>::type>::type
73out_round(float v) {
74 return v;
75}
76
77template <typename out_t, typename acc_t = float>
78inline out_t saturate_and_round(acc_t f) {
79 return out_round<out_t>(saturate<out_t, acc_t>(f));
80}
81
82/* Quantization with alpha == 1 and beta == 0 */
83template <typename in_t, typename out_t, typename enabled = void>
84struct qz_a1b0 {
85 out_t operator()(in_t in) { return saturate_and_round<out_t>((float)in); }
86};
87
88template <typename in_t, typename out_t>
89struct qz_a1b0<in_t, out_t,
90 typename utils::enable_if<true && nstl::is_integral<in_t>::value
91 && !is_subset<in_t, out_t>::value>::type> {
92 out_t operator()(in_t in) { return saturate<out_t>(in); }
93};
94
95template <typename in_t, typename out_t>
96struct qz_a1b0<in_t, out_t,
97 typename utils::enable_if<is_subset<in_t, out_t>::value>::type> {
98 out_t operator()(in_t in) { return (out_t)in; }
99};
100
101/* Quantization with alpha == 1 */
102template <typename in_t, typename out_t>
103struct qz_a1 {
104 out_t operator()(in_t in, out_t out, float beta) {
105 return saturate_and_round<out_t>((float)in + beta * out);
106 }
107};
108
109template <typename in_t>
110struct qz_a1<in_t, float> {
111 float operator()(in_t in, float out, float beta) {
112 return (float)in + beta * out;
113 }
114};
115
116/* Quantization with beta == 0 */
117template <typename in_t, typename out_t>
118struct qz_b0 {
119 out_t operator()(in_t in, float alpha) {
120 return saturate_and_round<out_t>(alpha * in);
121 }
122};
123
124template <typename in_t>
125struct qz_b0<in_t, float> {
126 float operator()(in_t in, float alpha) { return alpha * in; }
127};
128
129/* Quantization */
130template <typename in_t, typename out_t>
131struct qz {
132 out_t operator()(in_t in, out_t out, float alpha, float beta) {
133 return saturate_and_round<out_t>(alpha * in + (beta ? beta * out : 0));
134 }
135};
136
137template <typename in_t>
138struct qz<in_t, float> {
139 float operator()(in_t in, float out, float alpha, float beta) {
140 return alpha * in + (beta ? beta * out : 0);
141 }
142};
143
144template <>
145struct qz<bfloat16_t, bfloat16_t> {
146 float operator()(bfloat16_t in, bfloat16_t out, float alpha, float beta) {
147 return (bfloat16_t)(alpha * (float)in + (beta ? beta * (float)out : 0));
148 }
149};
150
151template <>
152struct qz<float, bfloat16_t> {
153 float operator()(float in, bfloat16_t out, float alpha, float beta) {
154 return (bfloat16_t)(alpha * in + (beta ? beta * out : 0));
155 }
156};
157
158template <>
159struct qz<float16_t, float16_t> {
160 float operator()(float16_t in, float16_t out, float alpha, float beta) {
161 return (float16_t)(alpha * (float)in + (beta ? beta * (float)out : 0));
162 }
163};
164
165template <>
166struct qz<float, float16_t> {
167 float operator()(float in, float16_t out, float alpha, float beta) {
168 return (float16_t)(alpha * in + (beta ? beta * out : 0));
169 }
170};
171
172} // namespace cpu
173} // namespace impl
174} // namespace dnnl
175
176#endif
177