/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @example cpu_matmul_quantization.cpp
/// > Annotated version: @ref cpu_matmul_quantization_cpp
///
/// @page cpu_matmul_quantization_cpp_short
/// C++ API example demonstrating how one can perform reduced precision
/// matrix-matrix multiplication using [MatMul](@ref dev_guide_matmul) and how
/// the accuracy of the result compares to floating point computations.
///
/// Concepts:
/// - **Static** and **dynamic** quantization
/// - Asymmetric quantization
///   - Scales: dnnl::primitive_attr::set_scales_mask()
///   - Zero points: dnnl::primitive_attr::set_zero_points_mask()
///
/// @page cpu_matmul_quantization_cpp MatMul Tutorial: Quantization
/// @copydetails cpu_matmul_quantization_cpp_short
///
/// The example focuses on the following computation:
/// \f[
///     C = A \times B
/// \f]
///
/// First, we produce the reference result, with the original matrices
/// \f$A\f$ and \f$B\f$ in the #dnnl::memory::data_type::f32 data type.
///
/// For reduced precision computations, the matrices \f$A\f$ and \f$C\f$ will
/// use the #dnnl::memory::data_type::u8 data type and will have the
/// appropriate zero points. For the matrix \f$B\f$, we will use the
/// #dnnl::memory::data_type::s8 data type, assuming that the data is centered
/// around zero (hence, the zero point is simply 0).
///
/// The quantization formula is:
/// \f[
///     X_{f32}(:) := scale\_X \cdot (X_{int8}(:) - zp\_X),
/// \f]
///
/// where:
/// - \f$X_{f32}(:)\f$ -- original matrix;
///
/// - \f$X_{int8}(:)\f$ -- quantized matrix, where `int8` is either `u8`
///   (`uint8_t`) for the matrices \f$A\f$ and \f$C\f$, or `s8` (`int8_t`)
///   for the matrix \f$B\f$;
///
/// - \f$scale\_X\f$ -- `f32` scaling factor. For simplicity, we will use a
///   single scale factor for each matrix, though for better accuracy it might
///   be a good idea to use a per-N-dimension scaling factor for the matrix
///   \f$B\f$.
///
/// - \f$zp\_X\f$ -- integer quantization parameter "zero point" (essentially,
///   the representation of the real 0 in the quantized data type).
///
/// For a given matrix \f$X_{f32}\f$ and `int8` data type (`u8` or `s8`), the
/// process of finding the proper \f$scale\_X\f$ and \f$zp\_X\f$ is a research
/// problem and can differ depending on the domain. For the purposes of this
/// example, we use the simplest approach: map the maximum (minimum)
/// \f$X_{f32}\f$ elements to the maximum (minimum) number in the corresponding
/// integer data type, using the following formulas:
///
/// 1. Since:
///    - \f$max(X_{f32}(:)) = scale\_X \cdot (max_{int8} - zp\_X)\f$
///    - \f$min(X_{f32}(:)) = scale\_X \cdot (min_{int8} - zp\_X)\f$
///
/// 2. Hence:
///    - \f$scale\_X =
///      \frac{max(X_{f32}(:)) - min(X_{f32}(:))}{max_{int8} - min_{int8}}\f$
///    - \f$zp\_X = max_{int8} - \frac{max(X_{f32}(:))}{scale\_X}\f$
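///
/// For illustration only (the code below additionally shrinks the integer
/// range slightly for accuracy and hardware reasons), assume that a matrix
/// quantized to `u8` (so \f$min_{int8} = 0\f$ and \f$max_{int8} = 255\f$) has
/// values spanning \f$[-2.0, 1.4]\f$. Then:
/// \f[
///     scale\_X = \frac{1.4 - (-2.0)}{255 - 0} \approx 0.0133, \qquad
///     zp\_X = 255 - \frac{1.4}{0.0133} \approx 150,
/// \f]
/// so the real value \f$0\f$ maps to the integer \f$150\f$, \f$-2.0\f$ maps
/// to \f$0\f$, and \f$1.4\f$ maps to \f$255\f$.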
///
/// It is worth noting that the quantization parameters are not always computed
/// at actual run-time. For example, if we perform the MatMul operation on
/// _similar_ matrices (in the sense that the data distribution is similar
/// between runs), we can simply _guess_ the proper quantization parameters by
/// collecting some statistics during the early runs. This approach is called
/// **static** quantization. It gives good performance (since no cycles are
/// spent on computing those parameters) and is typically used in reduced
/// precision CNN inference. However, **static** quantization has an obvious
/// disadvantage -- the _guessed_ parameters might not work well for some
/// particular matrices. For example, that would most likely be the case if we
/// could not guarantee the similarity of the input matrices. In this case,
/// **dynamic** quantization is used, i.e. the parameters are (re-)computed at
/// runtime. This gives slightly worse performance, but that might be
/// inevitable due to accuracy considerations.
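///
/// In code, the only difference between the two schemes is where the
/// quantization parameters come from. A minimal sketch (here
/// `calibrated_scale_A` and `calibrated_zp_A` are hypothetical values
/// estimated offline, while compute_q10n_params() is the helper defined
/// below):
///
///     // Dynamic quantization: recompute the parameters on every run.
///     compute_q10n_params<uint8_t>("A", A_f32, scale_A, zp_A);
///
///     // Static quantization: reuse parameters estimated ahead of time.
///     scale_A = calibrated_scale_A;
///     zp_A = calibrated_zp_A;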
///
/// Only the dynamic approach is demonstrated in this example.
///
/// Other details:
/// - For simplicity, all matrices will be stored in row-major format (see the
///   short sketch right after this list).
/// - The shapes of the matrices are assumed to be known at creation time.
///   However, for dynamic quantization we consider the q10n parameters
///   (\f$scale\_X\f$ and \f$zp\_X\f$) to be known at run-time only. By
///   contrast, for static quantization these parameters would be known at
///   creation time as well.
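///
/// As a sketch of the first point (mirroring the memory descriptors created
/// later in this example), a row-major \f$M \times K\f$ `f32` matrix is
/// described by its dimensions plus row-major strides:
///
///     memory::desc a_md({M, K}, memory::data_type::f32, {K, 1});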
///
/// @include cpu_matmul_quantization.cpp

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <vector>
#include <type_traits>

#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"

using namespace dnnl;

namespace {

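// Fill a vector with random real numbers uniformly distributed in
// [min_value, max_value].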
void init_vector(std::vector<float> &v, float min_value, float max_value) {
    std::mt19937 gen;
    std::uniform_real_distribution<float> u(min_value, max_value);

    for (auto &e : v)
        e = u(gen);
}

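// Find the minimum and maximum elements of a vector (returned as floats).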
template <typename T>
void find_min_max(const std::vector<T> &v, float &min_value, float &max_value) {
    min_value = max_value = v[0];
    for (auto &e : v) {
        min_value = std::min<float>(min_value, e);
        max_value = std::max<float>(max_value, e);
    }
}

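// Compute the quantization parameters (scale and zero point) that map the
// range of the data in `v` onto the range of the integer type T, following
// the formulas from the tutorial text above.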
template <typename T>
void compute_q10n_params(const char *message, const std::vector<float> &v,
        float &scale, int32_t &zp) {
    // Find the properties of the integer type T.
    // Simple trick to improve accuracy: shrink the range a little bit.
    float max_int = (float)std::numeric_limits<T>::max() - 1;
    float min_int = (float)std::numeric_limits<T>::lowest() + 1;

#ifndef OMIT_WORKAROUND_FOR_SKX
    // Read more in CPU / Section 1 here:
    // https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html
    if (std::is_same<T, uint8_t>::value) max_int /= 2;
#endif

    // Find the min and max values in the array
    float min_val = v[0], max_val = v[0];
    find_min_max(v, min_val, max_val);

    // Compute the appropriate scale
    scale = (max_val - min_val) / (max_int - min_int);

    // Compute the appropriate offset (zero point)
    if (std::is_same<T, int8_t>::value)
        zp = 0;
    else
        zp = (int32_t)(max_int - max_val / scale);
    printf("\tComputing q10n params for %s\n"
           "\t\tData type: %s\n"
           "\t\tScale:%.3g (inverse scale:%.3g)\n"
           "\t\tZero point:%d\n\n",
            message, std::is_same<T, int8_t>::value ? "int8_t" : "uint8_t",
            scale, 1 / scale, zp);
}

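// Compare an f32 reference vector with a quantized u8 vector by dequantizing
// the latter and checking the relative L2 error against the given threshold.
// Returns 0 on success and 1 on failure.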
int compare_vectors(const std::vector<float> &v1,
        const std::vector<uint8_t> &v2, float scale_v2, int32_t zp_v2,
        float threshold) {
    double v1_l2 = 0, diff_l2 = 0;
    for (size_t n = 0; n < v1.size(); ++n) {
        float v2_n = scale_v2 * (v2[n] - zp_v2); // deq10n v2
        float diff = v1[n] - v2_n;
        v1_l2 += v1[n] * v1[n];
        diff_l2 += diff * diff;
    }

    v1_l2 = std::sqrt(v1_l2);
    diff_l2 = std::sqrt(diff_l2);
    bool ok = diff_l2 <= threshold * v1_l2;

    printf("\tComparison (using l2-norms)\n"
           "\t\tReference matrix:%g\n\t\tError:%g\n\t\tRelative error:%g\n"
           "\nAccuracy check: %s\n\n",
            v1_l2, diff_l2, diff_l2 / v1_l2, ok ? "OK" : "FAILED");

    return ok ? 0 : 1;
}

} // namespace

engine eng(engine::kind::cpu, 0); // We create a global engine for simplicity

// Quantize float data into the X_int_m oneDNN memory using the q10n parameters
//
// Inputs:
// - X_f32 -- source f32 matrix
// - scale_X, zp_X -- quantization parameters
// Outputs:
// - X_int_m -- prepared oneDNN memory that will hold the quantized values
void quantize(const std::vector<float> &X_f32, float scale_X, int32_t zp_X,
        memory &X_int_m) {
    using dt = memory::data_type;

    stream s(eng);

    memory::desc x_int_md = X_int_m.get_desc();
    const auto &dims = x_int_md.get_dims();

    memory::desc x_f32_md({dims[0], dims[1]}, dt::f32, {dims[1], 1});
    memory X_f32_m(x_f32_md, eng, (void *)X_f32.data());

    primitive_attr q10n_attr;
    q10n_attr.set_scales_mask(DNNL_ARG_DST, /* mask */ 0);
    q10n_attr.set_zero_points_mask(DNNL_ARG_DST, /* mask */ 0);

    reorder::primitive_desc q10n_pd(eng, x_f32_md, eng, x_int_md, q10n_attr);
    memory dst_scale_X_m({{1}, dt::f32, {1}}, eng, &scale_X);
    memory zp_X_m({{1}, dt::s32, {1}}, eng, &zp_X);
    reorder(q10n_pd).execute(s,
            {{DNNL_ARG_SRC, X_f32_m}, {DNNL_ARG_DST, X_int_m},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale_X_m},
                    {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zp_X_m}});

    s.wait();
}

// Floating point MatMul
// Inputs:
// - Shape: M, N, K
// - Matrices A and B
// Outputs:
// - Matrix C
void f32_matmul_compute(int64_t M, int64_t N, int64_t K,
        const std::vector<float> &A_f32, const std::vector<float> &B_f32,
        std::vector<float> &C_f32) {
    // Initialize memory descriptors that describe matrices in row-major format
    memory::desc a_md({M, K}, memory::data_type::f32, {K, 1});
    memory::desc b_md({K, N}, memory::data_type::f32, {N, 1});
    memory::desc c_md({M, N}, memory::data_type::f32, {N, 1});

    // Wrap raw pointers into oneDNN memory objects
    memory A_f32_m(a_md, eng, (void *)A_f32.data());
    memory B_f32_m(b_md, eng, (void *)B_f32.data());
    memory C_f32_m(c_md, eng, (void *)C_f32.data());

    // Create a MatMul primitive
    matmul::primitive_desc matmul_pd(eng, a_md, b_md, c_md);
    matmul matmul_p(matmul_pd);

    stream s(eng);
    matmul_p.execute(s,
            {{DNNL_ARG_SRC, A_f32_m}, {DNNL_ARG_WEIGHTS, B_f32_m},
                    {DNNL_ARG_DST, C_f32_m}});
    s.wait();
}

// Reduced precision MatMul with **dynamic** quantization
// Inputs:
// - Shape: M, N, K
// - Matrices A and B in float (will be quantized inside the function)
// Outputs:
// - Matrix C in uint8_t
// - Quantization parameters: scale_C and zp_C
void dynamic_q10n_matmul(int64_t M, int64_t N, int64_t K,
        const std::vector<float> &A_f32, const std::vector<float> &B_f32,
        std::vector<uint8_t> &C_u8, float &scale_C, int32_t &zp_C) {
    stream s(eng);

    float scale_A, scale_B;
    int32_t zp_A, zp_B;

    // We compute the q10n parameters here, but in real-world applications
    // these parameters for the inputs are typically passed in from the
    // previous layers
    compute_q10n_params<uint8_t>("A", A_f32, scale_A, zp_A);
    compute_q10n_params<int8_t>("B", B_f32, scale_B, zp_B);
    assert(zp_B == 0 && "for int8 q10n we assume zero point = 0");

    // Quantize matrix A into A_u8 using the reorder primitive
    std::vector<uint8_t> A_u8(M * K, 0);
    memory::desc a_u8_md({M, K}, memory::data_type::u8, {K, 1});
    memory A_u8_m(a_u8_md, eng, (void *)A_u8.data());
    quantize(A_f32, scale_A, zp_A, A_u8_m);

    // Quantize matrix B into B_s8 using the reorder primitive
    std::vector<int8_t> B_s8(K * N, 0);
    memory::desc b_s8_md({K, N}, memory::data_type::s8, {N, 1});
    memory B_s8_m(b_s8_md, eng, (void *)B_s8.data());
    quantize(B_f32, scale_B, 0, B_s8_m);

    // Compute C_f32. We cannot directly compute C_u8 since we don't know the
    // appropriate quantization parameters.
    //
    // Note: typically the computed data type in this case is int32_t and not
    //       float. But for brevity we are going to embed the scale_A and
    //       scale_B directly in this quantized MatMul, and hence will get the
    //       intermediate computation in floating point anyway, so there is
    //       no point in converting the result to int32_t.
    //       In theory, we could postpone using the scale_A and scale_B,
    //       compute the exact C_s32 := (A_u8 - zp_A) * B_s8, and then find
    //       the appropriate quantization parameters for matrix C.
    //       Let it be an exercise :)

    std::vector<float> C_f32(M * N, 0);
    memory::desc c_f32_md({M, N}, memory::data_type::f32, {N, 1});
    memory C_f32_m(c_f32_md, eng, (void *)C_f32.data());

    // Create and compute a reduced precision MatMul primitive
    {
        primitive_attr matmul_attr;
        matmul_attr.set_scales_mask(DNNL_ARG_SRC, /* mask */ 0);
        matmul_attr.set_scales_mask(DNNL_ARG_WEIGHTS, /* mask */ 0);
        matmul_attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask */ 0);

        matmul::primitive_desc matmul_pd(
                eng, a_u8_md, b_s8_md, c_f32_md, matmul_attr);
        matmul matmul_p(matmul_pd);

        memory scales_A_m({{1}, memory::data_type::f32, {1}}, eng, &scale_A);
        memory scales_B_m({{1}, memory::data_type::f32, {1}}, eng, &scale_B);
        memory zp_A_m({{1}, memory::data_type::s32, {1}}, eng, &zp_A);

        matmul_p.execute(s,
                {{DNNL_ARG_SRC, A_u8_m}, {DNNL_ARG_WEIGHTS, B_s8_m},
                        {DNNL_ARG_DST, C_f32_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scales_A_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scales_B_m},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zp_A_m}});
    }

    // Find quantization parameters for matrix C
    compute_q10n_params<uint8_t>("C", C_f32, scale_C, zp_C);

    // Finally quantize the matrix C
    memory::desc c_u8_md({M, N}, memory::data_type::u8, {N, 1});
    memory C_u8_m(c_u8_md, eng, (void *)C_u8.data());
    quantize(C_f32, scale_C, zp_C, C_u8_m);
}

void compare_f32_and_quantized_matmuls() {
    // MatMul parameters
    const int64_t M = 10, N = 20, K = 30;

    // Data distribution for matrices A and B
    const float param_A_min_val = -2.f;
    const float param_A_max_val = 1.4f;

    const float param_B_min_val = -1.f;
    const float param_B_max_val = -param_B_min_val; // B is centered around 0

    // Thresholds
    //
    const float threshold_dynamic_q10n = 3 * 1e-2f;

    // Prepare matrices
    std::vector<float> A_f32(M * K), B_f32(K * N), C_f32(M * N, 0);
    init_vector(A_f32, param_A_min_val, param_A_max_val);
    init_vector(B_f32, param_B_min_val, param_B_max_val);

    // Compute _true_ f32 result
    f32_matmul_compute(M, N, K, A_f32, B_f32, C_f32);

    std::vector<uint8_t> C_u8_dynamic_q10n(M * N, 0);

    float scale_C_dynamic_q10n; // Q10n parameters we don't know yet
    int32_t zp_C_dynamic_q10n;

    dynamic_q10n_matmul(M, N, K, A_f32, B_f32, C_u8_dynamic_q10n,
            scale_C_dynamic_q10n, zp_C_dynamic_q10n);

    // Compare _true_ f32 result with dynamic q10n
    int rc = compare_vectors(C_f32, C_u8_dynamic_q10n, scale_C_dynamic_q10n,
            zp_C_dynamic_q10n, threshold_dynamic_q10n);
    if (rc) throw std::logic_error("Dynamic quantization accuracy failed.");
}

int main(int argc, char **argv) {
    return handle_example_errors(
            {engine::kind::cpu}, compare_f32_and_quantized_matmuls);
}