/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @example inference_int8_matmul.cpp
/// > Annotated version: @ref inference_int8_matmul_cpp
///
/// @page inference_int8_matmul_cpp_short
/// C++ API example demonstrating how one can use
/// [MatMul](@ref dev_guide_matmul) fused with ReLU in INT8 inference.
///
/// Concepts:
/// - Asymmetric quantization
///   - Scales: dnnl::primitive_attr::set_scales_mask()
///   - Zero points: dnnl::primitive_attr::set_zero_points_mask()
/// - [Operation fusion](@ref dev_guide_attributes_post_ops)
/// - Create primitive once, use multiple times
/// - Run-time tensor shapes: #DNNL_RUNTIME_DIM_VAL
/// - Weights pre-packing: use #dnnl::memory::format_tag::any
///
/// @page inference_int8_matmul_cpp MatMul Tutorial: INT8 Inference
/// @copydetails inference_int8_matmul_cpp_short
///
/// Assumptions:
/// 1. The shape of the weights (matrix \f$B(K, N)\f$) is known in advance, the
///    data type is `int8_t`, and the values are centered around 0 (i.e. the
///    zero point is 0).
/// 2. The shapes of the source matrix \f$A\f$ and destination matrix \f$C\f$
///    are partially unknown. Both matrices use `uint8_t` data type and might
///    have arbitrary zero points (specified at execution time only).
/// 3. The scaling (re-quantization) factors are specified at run time only.
///
/// Since the shape of the weights is known in advance, the MatMul weights can
/// be created with format tag #dnnl::memory::format_tag::any to enable the
/// library to choose the most appropriate layout for best performance.
///
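/// For illustration, a weights memory descriptor that lets the library pick
/// the layout could be created as follows (a minimal sketch; `K` and `N`
/// denote the known weights dimensions):
/// @code{.cpp}
/// memory::desc b_md({K, N}, memory::data_type::s8, memory::format_tag::any);
/// @endcode
///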
/// @warning
/// The format tag #dnnl::memory::format_tag::any doesn't work for memory
/// descriptors that have one or more unknown dimensions and/or strides.
///
/// @include inference_int8_matmul.cpp

#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"

using namespace dnnl;

namespace {

// Fill a float vector with uniformly distributed random values in [0, 1).
void init_vector(std::vector<float> &v) {
    std::mt19937 gen;
    std::uniform_real_distribution<float> u(0, 1);
    for (auto &e : v)
        e = u(gen);
}

// Fill a uint8_t vector with uniformly distributed random values in [0, 255].
void init_vector(std::vector<uint8_t> &v) {
    std::mt19937 gen;
    std::uniform_int_distribution<unsigned int> u(0, 255);
    for (auto &e : v)
        e = static_cast<uint8_t>(u(gen));
}

} // namespace

int number_of_runs = 1;

// Create a MatMul primitive descriptor for the following op:
// C_u8 = ReLU(sc_A * sc_B[:] * (A_u8 - zp_A) * B_s8) / sc_C + zp_C
//
// Here:
// - Matrices A and C are known to be non-transposed but their M dimension is
//   not known. They can be activation matrices in an MLP topology and the M
//   dimension can be the mini-batch dimension.
// - zp_A and zp_C are zero points for matrices A and C which are stored as
//   uint8_t. These are run-time parameters that are not known at the primitive
//   creation time.
// - The B matrix is stored as int8_t, its zero point is 0, and all its
//   dimensions are known. This matrix can be a matrix of weights in an MLP
//   topology.
// - The scaling values are not known at the primitive creation time.
matmul::primitive_desc matmul_pd_create(
        int64_t K, int64_t N, const engine &eng) {
    const int64_t M = DNNL_RUNTIME_DIM_VAL;

    memory::desc a_md({M, K}, memory::data_type::u8, {K, 1}); // M x K layout
    memory::desc b_md({K, N}, memory::data_type::s8, memory::format_tag::any);
    memory::desc c_md({M, N}, memory::data_type::u8, {N, 1}); // M x N layout

    // Create attributes and indicate that the scales and zero points are
    // run-time parameters. Mask 0 means a single common value per tensor;
    // mask 1 << 1 means one scale per element along dimension 1 of the
    // weights, i.e. per output channel N.
    primitive_attr attr;
    attr.set_scales_mask(DNNL_ARG_SRC, /* mask */ 0);
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, /* mask */ 1 << 1);
    attr.set_scales_mask(DNNL_ARG_DST, /* mask */ 0);
    attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask */ 0);
    attr.set_zero_points_mask(DNNL_ARG_DST, /* mask */ 0);

    // Fuse ReLU as a post-op (alpha = 0.f gives the standard ReLU).
    post_ops po;
    po.append_eltwise(algorithm::eltwise_relu, 0.f, 0.f);
    attr.set_post_ops(po);

    // Create a MatMul primitive descriptor
    return matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
}
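
// A small illustration (not used by the example itself): because b_md above
// uses format_tag::any, the library is free to pick an opaque blocked layout
// for the weights. Whether user data in a plain row-major layout would need a
// reorder can be checked by comparing memory descriptors. This is a sketch
// under the assumptions of this example (s8 weights of shape K x N).
bool plain_weights_need_reorder(
        const matmul::primitive_desc &pd, int64_t K, int64_t N) {
    memory::desc plain_b_md(
            {K, N}, memory::data_type::s8, memory::format_tag::ab);
    return pd.weights_desc() != plain_b_md;
}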

void prepare_input(memory &A_u8_mem, memory &sc_A_mem, memory &sc_B_mem,
        memory &sc_C_mem, memory &zp_A_mem, memory &zp_C_mem) {
    int64_t M = A_u8_mem.get_desc().get_dims()[0];
    int64_t N = sc_B_mem.get_desc().get_dims()[0];
    int64_t K = A_u8_mem.get_desc().get_dims()[1];

    std::vector<uint8_t> A_u8(M * K);
    init_vector(A_u8);

    std::vector<float> sc_B(N);
    init_vector(sc_B);

    float sc_A = 0.5f;
    float sc_C = 0.25f;
    int32_t zp_A = 128, zp_C = 40;

    write_to_dnnl_memory(A_u8.data(), A_u8_mem);
    write_to_dnnl_memory(&zp_A, zp_A_mem);
    write_to_dnnl_memory(&zp_C, zp_C_mem);
    write_to_dnnl_memory(&sc_A, sc_A_mem);
    write_to_dnnl_memory(sc_B.data(), sc_B_mem);
    write_to_dnnl_memory(&sc_C, sc_C_mem);
}
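
// A generic recipe for reference (an illustration, not part of the original
// example): in a real pipeline the scales and zero points above would come
// from calibration rather than being hard-coded. Given the observed range
// [min, max] of the real-valued data, uint8_t quantization parameters can be
// chosen so that real ~= scale * (quant - zero_point). Clamping zero_point to
// [0, 255] is omitted for brevity.
void choose_quant_params_u8(
        float min, float max, float &scale, int32_t &zero_point) {
    scale = (max - min) / 255.f;
    zero_point = static_cast<int32_t>(std::round(-min / scale));
}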

void sanity_check(memory &C_u8_mem, memory &zp_C_mem) {
    int64_t M = C_u8_mem.get_desc().get_dims()[0];
    int64_t N = C_u8_mem.get_desc().get_dims()[1];
    int32_t zp_C = 0;
    std::vector<uint8_t> C_u8(M * N);

    read_from_dnnl_memory(C_u8.data(), C_u8_mem);
    read_from_dnnl_memory(&zp_C, zp_C_mem);

    // simple check: C_u8 >= zp_C
    for (int64_t i = 0; i < M * N; ++i)
        if (C_u8[i] < zp_C)
            throw std::logic_error(
                    "Smoke check failed."
                    "\n\tQuantized value is smaller than the zero point,"
                    "\n\twhich should not happen since ReLU was applied.");
}
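
// For reference (not used by the example): real values can be recovered from
// the quantized output with the usual affine dequantization formula. A minimal
// sketch, where sc_C and zp_C are the destination scale and zero point:
float dequantize_c(uint8_t q, float sc_C, int32_t zp_C) {
    return sc_C * (static_cast<float>(q) - static_cast<float>(zp_C));
}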

void infer(const matmul &matmul_p, int64_t M, int64_t N, int64_t K,
        const memory &B_s8_mem, const engine &eng) {
    // inputs of the current layer / operation
    memory A_u8_mem({{M, K}, memory::data_type::u8, {K, 1}}, eng);
    memory zp_A_mem({{1}, memory::data_type::s32, {1}}, eng);
    memory zp_C_mem({{1}, memory::data_type::s32, {1}}, eng);
    memory sc_A_mem({{1}, memory::data_type::f32, {1}}, eng);
    memory sc_B_mem({{N}, memory::data_type::f32, {1}}, eng);
    memory sc_C_mem({{1}, memory::data_type::f32, {1}}, eng);

    // The function below fills the dnnl::memory objects with some values.
    // Typically these memories would come from the previous layers /
    // operations with meaningful data inside.
    prepare_input(A_u8_mem, sc_A_mem, sc_B_mem, sc_C_mem, zp_A_mem, zp_C_mem);

    // output - no initialization required
    memory C_u8_mem({{M, N}, memory::data_type::u8, {N, 1}}, eng);

    stream s(eng);
    for (int run = 0; run < number_of_runs; ++run)
        matmul_p.execute(s,
                {{DNNL_ARG_SRC, A_u8_mem}, {DNNL_ARG_WEIGHTS, B_s8_mem},
                        {DNNL_ARG_DST, C_u8_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, sc_A_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, sc_B_mem},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, sc_C_mem},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zp_A_mem},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zp_C_mem}});
    s.wait();

    // a sanity check for the correctness of the output
    sanity_check(C_u8_mem, zp_C_mem);
}

void inference_int8_matmul(engine::kind engine_kind) {
    engine eng(engine_kind, 0);

    const int64_t K = 96;
    const int64_t N = 1000;
    auto matmul_pd = matmul_pd_create(K, N, eng);

    // Original weights stored as float in a known format
    std::vector<float> B_f32(K * N);
    init_vector(B_f32);

    // Pre-packed weights stored as int8_t
    memory B_s8_mem(matmul_pd.weights_desc(), eng);
    {
        stream s(eng);
        memory B_f32_mem(
                {{K, N}, memory::data_type::f32, memory::format_tag::ab}, eng);
        write_to_dnnl_memory(B_f32.data(), B_f32_mem);
        reorder(B_f32_mem, B_s8_mem).execute(s, B_f32_mem, B_s8_mem);
        s.wait();
    }
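
    // Note: the reorder above performs a plain f32 -> s8 conversion (with
    // saturation). In a real quantization flow one would typically attach the
    // per-output-channel weight scales to the reorder so the f32 values are
    // scaled during the conversion. A minimal sketch, assuming a memory object
    // sc_B_mem holding the N per-channel scales (it is not in scope here):
    //
    //     primitive_attr reorder_attr;
    //     reorder_attr.set_scales_mask(DNNL_ARG_DST, /* mask */ 1 << 1);
    //     reorder r(B_f32_mem, B_s8_mem, reorder_attr);
    //     r.execute(s,
    //             {{DNNL_ARG_FROM, B_f32_mem}, {DNNL_ARG_TO, B_s8_mem},
    //                     {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, sc_B_mem}});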

    matmul matmul_p(matmul_pd);

    for (int64_t M : {1, 100})
        infer(matmul_p, M, N, K, B_s8_mem, eng);
}

int main(int argc, char **argv) {
    engine::kind engine_kind = parse_engine_kind(argc, argv);
    return handle_example_errors(inference_int8_matmul, engine_kind);
}