/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/binary_injector_utils.hpp"
#include "cpu/cpu_primitive.hpp"
#include "cpu/gemm/gemm.hpp"
#include "cpu/gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
#include "cpu/gemm_x8s8s32x_convolution.hpp"
#include "cpu/ref_io_helper.hpp"
#include "cpu/scale_utils.hpp"
#include "cpu/simple_q10n.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::memory_tracking::names;

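// Scale the per-output-channel src zero-point compensation (precomputed from
// the weights) by the common src zero-point value, writing the result into
// the scratchpad. Work is split into cache-line-sized blocks so that
// concurrent threads write disjoint cache lines; the remainder that does not
// fill a whole block is handled on the calling thread.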
const int32_t *mul_zp_src_comp_from_wei_by_zp_src(const int zp_comp_size,
        int32_t *zp_src_comp_scratch_dst,
        const int32_t *const zp_src_comp_from_wei, const int32_t zp_src) {
    static constexpr int cache_line_size
            = platform::get_cache_line_size() / sizeof(int);
    const auto res = std::div(zp_comp_size, cache_line_size);

    if (res.quot) {
        parallel_nd(res.quot, [&](size_t shift_factor) {
            const auto shift = shift_factor * cache_line_size;
            const int32_t *__restrict const src = zp_src_comp_from_wei + shift;
            int32_t *__restrict dst = zp_src_comp_scratch_dst + shift;

            PRAGMA_OMP_SIMD()
            for (int i = 0; i < cache_line_size; ++i) {
                dst[i] = src[i] * zp_src;
            }
        });
    }

    if (res.rem) {
        const auto shift = res.quot * cache_line_size;
        const int32_t *__restrict const src = zp_src_comp_from_wei + shift;
        int32_t *__restrict dst = zp_src_comp_scratch_dst + shift;

        PRAGMA_OMP_SIMD()
        for (int i = 0; i < res.rem; ++i) {
            dst[i] = src[i] * zp_src;
        }
    }

    return zp_src_comp_scratch_dst;
}

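// Collect every zero-point related pointer the post-processing needs: the raw
// src/dst zero points, the src compensation derived from the weights
// (pre-multiplied by the common src zero point when one is used) and, if the
// convolution has padding, the extra compensation for padded spatial
// positions computed into the scratchpad.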
static zero_point_call_params_t prepare_zp_params(const conv_gemm_conf_t &jcp,
        const memory_tracking::grantor_t &scratchpad, const int8_t *weights,
        const memory_desc_wrapper &weights_md, bool with_groups,
        const int32_t *zp_src, const int32_t *zp_dst) {

    int32_t *zp_src_comp_pad = nullptr;
    const int32_t *zp_src_comp = nullptr;

    if (jcp.zp.src_exists) {
        const int32_t *zp_src_comp_from_wei = get_src_zp_comp_from_wei(
                weights, weights_md, jcp.signed_input, jcp.ngroups, jcp.oc);
        int32_t *zp_src_comp_scratch
                = scratchpad.get<int32_t>(key_conv_gemm_zp_src_comp);
        static constexpr auto cache_line_size
                = platform::get_cache_line_size() / sizeof(int);
        const auto zp_comp_size = jcp.oc * jcp.ngroups;

        if (jcp.zp.src_is_common) {
            zp_src_comp = mul_zp_src_comp_from_wei_by_zp_src(zp_comp_size,
                    zp_src_comp_scratch, zp_src_comp_from_wei, *zp_src);
        } else
            zp_src_comp = zp_src_comp_from_wei;

        if (jit_gemm_convolution_utils::padding_exists(jcp)) {
            const auto shift = jcp.zp.src_is_common
                    ? utils::rnd_up(zp_comp_size, cache_line_size)
                    : 0;
            zp_src_comp_pad = zp_src_comp_scratch + shift;
            compute_zp_src_comp_pad(jcp, zp_src_comp_pad, zp_src, weights,
                    weights_md, with_groups);
        }
    }

    return {zp_src, zp_dst, zp_src_comp, zp_src_comp_pad};
}

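// Forward int8 GEMM convolution driver: fetches the arguments from the
// execution context, prepares zero-point compensation and the precomputed
// scales, then runs the per-thread kernel over jcp.nthr threads and reduces
// the per-thread statuses into a single return value.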
status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    const conv_gemm_conf_t &jcp = this->pd()->jcp_;
    auto src_base = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto wei_base = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    auto bia_base = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst_base = CTX_OUT_MEM(void *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector_utils::prepare_binary_args(
                    this->pd()->attr()->post_ops_, ctx);

    auto scratchpad = ctx.get_scratchpad_grantor();

    assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));

    const zero_point_call_params_t zp = prepare_zp_params(jcp, scratchpad,
            wei_base, memory_desc_wrapper(pd()->weights_md(0)),
            this->pd()->with_groups(), zp_src, zp_dst);

    std::atomic<status_t> st(status::success);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const float *scales = precompute_scales(
            scratchpad, src_scales, wei_scales, pd()->OC(), pd()->attr());

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_forward_thr(ithr, nthr, src_base, wei_base,
                bia_base, dst_base, scales, dst_scales, zp, scratchpad,
                post_ops_binary_rhs_arg_vec.data(), ctx);

        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

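// The weights compensation is stored right after the weights data inside the
// same memory object; compute its offset from the memory descriptor.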
static const int32_t *get_wei_comp(
        const int8_t *weights, const memory_desc_wrapper &weights_md) {
    const size_t comp_off
            = weights_md.size() - weights_md.additional_buffer_size();
    return reinterpret_cast<const int32_t *>(&weights[comp_off]);
}

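// Per-thread forward kernel: each thread owns a slice of the
// (mb, group, oh-block, ow-block) space and, for every chunk, performs
// im2col (when needed), an int8 GEMM into the int32 accumulator, and then the
// post-processing kernel (scales, bias, post-ops, zero points, conversion to
// the destination data type).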
status_t gemm_x8s8s32x_convolution_fwd_t::execute_forward_thr(const int ithr,
        const int nthr, const char *src_base, const int8_t *wei_base,
        const char *bia_base, void *dst_base, const float *scales,
        const float *dst_scales, const zero_point_call_params_t &zp,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec, const exec_ctx_t &ctx) const {

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const auto src_md = memory_desc_wrapper(pd()->src_md());
    const size_t src_mb_stride = src_md.blk_off(1);
    const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;

    const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;

    const auto dst_md = memory_desc_wrapper(pd()->dst_md());
    const size_t dst_mb_stride = dst_md.blk_off(1);
    const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;

    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
    const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;

    uint8_t *__restrict col = scratchpad.get<uint8_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    char *__restrict imtr = scratchpad.get<char>(key_conv_gemm_imtr)
            + (ptrdiff_t)ithr * jcp.is * jcp.ic;
    int *__restrict acc = scratchpad.get<int>(key_conv_int_dat_in_acc_dt)
            + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc;

    const int32_t *_wei_comp
            = jcp.signed_input ? get_wei_comp(wei_base, wei_md) : nullptr;

    const bool should_apply_zp_src_comp_pad = jcp.zp.src_exists
            && jit_gemm_convolution_utils::padding_exists(jcp);
    const bool should_apply_zp_src_comp_pad_jit_pp
            = should_apply_zp_src_comp_pad
            && gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(
                    dst_md.data_type());
    const bool should_apply_zp_src_comp_outside_pp
            = should_apply_zp_src_comp_pad
            && !gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(
                    dst_md.data_type());

    dim_t g {0}, n {0}, ohb {0}, owb {0};
    dim_t start = 0, end = 0;

    const bool is_problem_3d = pd()->ndims() == 5;
    assert(IMPLICATION(is_problem_3d,
            jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow
                    && jcp.ic_block == jcp.ic));

    const dim_t nb_oh = div_up(jcp.oh, jcp.oh_block);
    const dim_t nb_ow = div_up(jcp.ow, jcp.ow_block);
    const dim_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow;
    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
    const uint8_t shift = jcp.signed_input ? 128 : 0;
    parallel_nd(jcp.im2col_sz, [&](ptrdiff_t i) { col[i] = shift; });

    status_t st = status::success;

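    // Iterate over this thread's portion of the (mb, group, oh-block,
    // ow-block) work items; ohb/owb index spatial blocks of size
    // jcp.oh_block x jcp.ow_block.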
    for (dim_t iwork = start; iwork < end; ++iwork) {
        const int oh = ohb * jcp.oh_block;
        const int ow = owb * jcp.ow_block;
        const char *__restrict src
                = src_base + n * src_mb_stride + g * src_g_stride;
        const int8_t *__restrict wei = wei_base + g * wei_g_stride;
        const int32_t *__restrict wei_comp
                = _wei_comp ? _wei_comp + g * jcp.oc : nullptr;
        const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
        const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
        if (jcp.im2col_sz && is_problem_3d)
            jit_gemm_convolution_utils::transpose_dt<char>(jcp, src, imtr);

        for (int od = 0; od < jcp.od; od++) {
            const auto dst_off = n * dst_mb_stride + g * dst_g_stride
                    + ((od * jcp.oh + oh) * jcp.ow + ow) * jcp.dst_os_stride;
            char *__restrict dst = (char *)dst_base
                    + types::data_type_size(dst_md.data_type()) * dst_off;
            if (jcp.im2col_sz) {
                switch (src_md.data_type()) {
                    case data_type::s8: {
                        if (is_problem_3d)
                            jit_gemm_convolution_utils::im2col_dt_3d<int8_t,
                                    uint8_t>(jcp, imtr, col, od);
                        else
                            jit_gemm_convolution_utils::im2col_dt<int8_t,
                                    uint8_t>(jcp, src, imtr, col, oh, h_step,
                                    ow, w_step);
                    } break;
                    case data_type::u8: {
                        if (is_problem_3d)
                            jit_gemm_convolution_utils::im2col_dt_3d<uint8_t,
                                    uint8_t>(jcp, imtr, col, od);
                        else
                            jit_gemm_convolution_utils::im2col_dt<uint8_t,
                                    uint8_t>(jcp, src, imtr, col, oh, h_step,
                                    ow, w_step);
                    } break;
                    default: assert(!"unsupported data type"); break;
                }
            }

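            // Integer GEMM: the weights act as the M x K matrix
            // (M = oc, K = ks * ic) and the im2col buffer (or the source
            // directly when no im2col is needed) as the K x N matrix
            // (N = h_step * w_step), accumulating into the int32 buffer.
            // With signed input the per-oc weights compensation is added via
            // the GEMM offset argument.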
            const dim_t M = jcp.oc;
            const dim_t K = jcp.ks * jcp.ic;
            const dim_t N = h_step * w_step;
            const dim_t LDA = M * jcp.ngroups;
            const dim_t LDB = jcp.im2col_sz ? N : K * jcp.ngroups;
            const char *BT = jcp.im2col_sz ? "T" : "N";
            const int8_t off_a = 0;
            const uint8_t off_b = 0;
            const int32_t off_c = 0;
            const float onef = 1.f, zerof = 0.f;
            const char *__restrict src_od
                    = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
            st = gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F", &M, &N, &K,
                    &onef, wei, &LDA, &off_a,
                    jcp.im2col_sz ? col : (uint8_t *)src_od, &LDB, &off_b,
                    &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c);

            if (st != status::success) return st;

            const auto wei_adj_scale
                    = (wei_md.extra().flags & memory_extra_flags::scale_adjust)
                    ? wei_md.extra().scale_adjust
                    : 1.f;

            if (should_apply_zp_src_comp_outside_pp)
                apply_zp_src_comp_pad(jcp, g, od, oh, ow, h_step, w_step, acc,
                        zp.src_pad_comp);

            const single_gemm_conv_chunk_desc_t chunk_desc
                    = should_apply_zp_src_comp_pad_jit_pp
                    ? single_gemm_conv_chunk_desc_t {od, 1, oh, h_step, ow,
                            w_step}
                    : single_gemm_conv_chunk_desc_t {};

            parallel(0, [&](int ithr, int nthr) {
                dim_t _start {}, _end {};
                balance211(N * jcp.oc, nthr, ithr, _start, _end);

                (*pp_ker_)(dst, acc, bia_base, scales, dst_scales[0], sum_scale,
                        1.f / wei_adj_scale, g, _start, _end, zp,
                        post_ops_binary_rhs_arg_vec, dst_base, ctx,
                        *pd()->dst_md(), chunk_desc);
            });
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
    }

    return st;
}

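// Backward-data int8 GEMM convolution driver: splits the (mb, group) space
// across jcp.nthr threads and reduces the per-thread statuses.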
status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst_base = CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST);
    auto wei_base = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    auto bia_base = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto diff_src_base = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);

    auto scratchpad = ctx.get_scratchpad_grantor();

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_backward_data_thr(ithr, nthr, diff_dst_base,
                wei_base, bia_base, diff_src_base, scratchpad, ctx);

        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

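// Per-thread backward-data kernel: for each (mb, group) work item, run an
// int8 GEMM of the transposed weights with diff_dst, turn the result into
// per-pixel accumulators via col2im when im2col space is used, then
// dequantize with the precomputed scales, add bias, and store in the
// diff_src data type.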
status_t gemm_x8s8s32x_convolution_bwd_data_t::execute_backward_data_thr(
        const int ithr, const int nthr, const char *diff_dst_base,
        const int8_t *wei_base, const char *bia_base, char *diff_src_base,
        const memory_tracking::grantor_t &scratchpad,
        const exec_ctx_t &ctx) const {
    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md());
    const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
    const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;

    const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;

    const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md());
    const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
    const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
    const size_t diff_src_os_stride
            = diff_src_md.blocking_desc().strides[pd()->ndims() - 1];
    const auto diff_src_dt_size
            = types::data_type_size(diff_src_md.data_type());

    const int scale_idx_mult = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_
            == (1 << static_cast<int>(pd()->with_groups()));
    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const float *scales = precompute_scales(
            scratchpad, src_scales, wei_scales, pd()->OC(), pd()->attr());

    const dim_t work_amount = jcp.ngroups * jcp.mb;

    int *__restrict col = scratchpad.get<int>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    int *__restrict acc = scratchpad.get<int>(key_conv_int_dat_in_acc_dt)
            + (ptrdiff_t)ithr * jcp.is * jcp.id * jcp.ic;

    dim_t n = 0, g = 0;
    dim_t start = 0, end = 0;

    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (dim_t iwork = start; iwork < end; ++iwork) {
        const int8_t *__restrict wei = wei_base + g * wei_g_stride;
        char *__restrict diff_src = diff_src_base
                + diff_src_dt_size
                        * (n * diff_src_mb_stride + g * diff_src_g_stride);

        const dim_t M = jcp.ks * jcp.ic;
        const dim_t N = jcp.os * jcp.od;
        const dim_t K = jcp.oc;
        const int8_t off_a = 0;
        const int32_t off_c = 0;
        const float onef = 1.f, zerof = 0.f;
        const dim_t LD = K * jcp.ngroups;

        status_t st = status::runtime_error;
        switch (diff_dst_md.data_type()) {
            case data_type::s8: {
                const int8_t *__restrict diff_dst
                        = reinterpret_cast<const int8_t *>(diff_dst_base)
                        + n * diff_dst_mb_stride + g * diff_dst_g_stride;
                const int8_t off_b = 0;
                st = gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, wei, &LD,
                        &off_a, diff_dst, &LD, &off_b, &zerof,
                        jcp.im2col_sz ? col : acc, &M, &off_c);
            } break;
            case data_type::u8: {
                const uint8_t *__restrict diff_dst
                        = reinterpret_cast<const uint8_t *>(diff_dst_base)
                        + n * diff_dst_mb_stride + g * diff_dst_g_stride;
                const uint8_t off_b = 0;
                st = gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef, wei, &LD,
                        &off_a, diff_dst, &LD, &off_b, &zerof,
                        jcp.im2col_sz ? col : acc, &M, &off_c);
            } break;
            default: assert(!"unsupported data type"); break;
        }

        if (st != status::success) return st;

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_dt<int32_t>(jcp, col, acc);

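        // Dequantize the int32 accumulators: apply the combined src/weights
        // scales (per-channel when the scale mask requires it), add bias if
        // present, apply the dst scale, and store in the diff_src data type.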
        parallel_nd(jcp.is * jcp.id, [&](dim_t is) {
            char *__restrict diff_src_loc
                    = diff_src + diff_src_dt_size * is * diff_src_os_stride;
            const int *__restrict acc_loc = acc + is * jcp.ic;
            const float *__restrict scales_loc
                    = scales + g * jcp.ic * scale_idx_mult;
            for (int ic = 0; ic < jcp.ic; ic++) {
                float d = static_cast<float>(acc_loc[ic]);
                d *= scales_loc[ic * scale_idx_mult];
                if (jcp.with_bias) {
                    const float b = io::load_float_value(
                            pd()->desc()->bias_desc.data_type, bia_base,
                            g * jcp.ic + ic);
                    d += b;
                }
                if (jcp.with_dst_scale) d *= dst_scales[0];
                io::store_float_value(
                        diff_src_md.data_type(), d, diff_src_loc, ic);
            }
        });
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }

    return status::success;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl