/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

/* convolution forward */
status_t jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    auto weights_dw = CTX_IN_MEM(
            const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    auto bias_dw = CTX_IN_MEM(
            const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
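    // The dw conv's binary post-op args are indexed after the 1x1's own
    // post-op entries; the `entry_.size() + 1` offset below appears to
    // account for those entries plus the dw post-op entry itself.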
    const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_
            ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx,
                    pd()->jcp_.post_ops.entry_.size() + 1)
            : std::vector<const void *> {};

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(
            dw_wei_scales, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(
            dw_dst_scales, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    auto scratchpad = ctx.get_scratchpad_grantor();

    auto local_scales
            = scratchpad.template get<float>(key_conv_adjusted_scales);
    // Src scale is always a single value
    float src_scale = src_scales[0];
    int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;
    float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni))
            ? 1.f / pd()->jcp_.wei_adj_scale
            : 1.f;
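    // The kernel consumes combined (src * wei) scales. Without VNNI, s8s8
    // weights are pre-scaled by wei_adj_scale at reorder time to avoid
    // intermediate overflow; dividing by it here undoes that scaling.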
    switch (wei_mask) {
        case 0:
            utils::array_set(local_scales, src_scale * wei_scales[0] * factor,
                    pd()->jcp_.ic_block);
            break;
        default:
            for (dim_t c = 0; c < pd()->OC(); c++)
                local_scales[c] = src_scale * wei_scales[c] * factor;
    }

    const float *dw_oscales = nullptr;
    if (pd()->jcp_.with_dw_conv) {
        auto jcp_dw = pd()->jcp_dw_;
        memory_tracking::grantor_t dw_scratchpad(
                scratchpad, memory_tracking::names::prefix_fusion);
        auto dw_local_scales
                = dw_scratchpad.template get<float>(key_conv_adjusted_scales);
        auto attr_dw = pd()->dw_conv_pd_->attr();
        int wei_mask = attr_dw->scales_.get(DNNL_ARG_WEIGHTS).mask_;
        dim_t count = wei_mask == 0 ? 1 : pd()->dw_conv_pd_->OC();
        float factor = 1.f / jcp_dw->wei_adj_scale;
        if (count == 1) {
            utils::array_set(dw_local_scales,
                    dw_wei_scales[0] / dst_scales[0] * factor,
                    pd()->jcp_.ic_block);
        } else {
            for (dim_t c = 0; c < count; c++)
                dw_local_scales[c] = dw_wei_scales[c] / dst_scales[0] * factor;
        }
        dw_oscales = dw_local_scales;
    }
    parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                dst, local_scales, dst_scales, dw_oscales, dw_dst_scales,
                src_zero_point, dst_zero_point, scratchpad,
                post_ops_binary_rhs_arg_vec.data(),
                post_ops_binary_rhs_arg_vec_dw.data());
    });
    return status::success;
}

void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::execute_forward_thr(
        const int ithr, const int nthr, const char *src, const char *weights,
        const char *bias, const char *weights_dw, const char *bias_dw,
        char *dst, const float *oscales, const float *dst_scales,
        const float *dw_oscales, const float *dw_dst_scales,
        const int32_t *src_zero_point, const int32_t *dst_zero_point,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));

    const auto &jcp = pd()->jcp_;

    const size_t src_dt_size = types::data_type_size(src_d.data_type());
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
    const size_t bia_dt_size = pd()->with_bias()
            ? types::data_type_size(pd()->desc()->bias_desc.data_type)
            : 0;

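    // "rtus" (reduce to unit stride): per-thread scratchpad holding a densely
    // packed copy of the strided source so the 1x1 kernel can assume unit
    // stride.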
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<char>(key_conv_rtus_space)
            : nullptr;

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    const bool is_2d = pd()->ndims() == 4;
    const bool is_3d = pd()->ndims() == 5;

    const int stride_d = pd()->KSD();
    const int stride_h = pd()->KSH();
    const int stride_w = pd()->KSW();

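    // Compensation (for signed input) and src zero-point compensation are
    // stored past the weights' payload, in the weights' additional buffer.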
    auto offset = weights_d.size() - weights_d.additional_buffer_size();
    char *w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(w + offset)
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(w + offset)
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;

    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_core>::call_params_t();
    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    // override some constants for fused dw_conv
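    // (with dw fusion the 1x1 produces whole output rows: os_block = ow and
    // one row per bcast step, so the dw kernel can consume complete rows)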
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;

    // Begin: declare variables needed for the fused dw conv.
    const auto jcp_dw = pd()->jcp_dw_;
    const auto &dw_pd = pd()->dw_conv_pd_;
    memory_tracking::grantor_t dw_scratchpad(
            scratchpad, memory_tracking::names::prefix_fusion);

    size_t dw_bia_dt_size = 0;
    int32_t *compensation_dw {nullptr};
    if (jcp.with_dw_conv && jcp_dw) {
        if (jcp_dw->with_bias)
            dw_bia_dt_size
                    = types::data_type_size(dw_pd->desc()->bias_desc.data_type);

        offset = dw_weights_d.size() - dw_weights_d.additional_buffer_size();
        w = const_cast<char *>(weights_dw);
        compensation_dw = (jcp_dw->signed_input)
                ? reinterpret_cast<int32_t *>(w + offset)
                : nullptr;
        dw_oscales = dw_scratchpad.get<float>(key_conv_adjusted_scales);
    }

    char *pbuf {nullptr};
    size_t row_offset {};
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<char *> addrs;
    // End: dw conv variables.

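    // If fewer than tail_step iterations remain, take them all in one step;
    // otherwise advance by default_step.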
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

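    // Map a flat bcast work index to (mb, group, spatial block) coordinates
    // and set the matching bcast dimensions in the kernel/rtus params.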
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        const int depth_orthogonal_area = jcp.ow * jcp.oh;
        od = os / depth_orthogonal_area;
        oh = (os % depth_orthogonal_area) / jcp.ow;
        ow = (os % depth_orthogonal_area) % jcp.ow;

        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        p.load_dim = this_block_size(ocb * jcp.oc_block, ocb_end * jcp.oc_block,
                load_step * jcp.oc_block);

        if (ocb + load_step >= nb_oc)
            p.first_last_flag |= FLAG_OC_LAST;
        else
            p.first_last_flag &= ~FLAG_OC_LAST;
    };

    auto init_reduce = [&]() {
        p.reduce_dim = this_block_size(
                0, jcp.ic_without_padding, jcp.ic_without_padding);
        rp.icb = p.reduce_dim;
    };

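    // Run the 1x1 kernel on one (OC block range, spatial block) tile. With dw
    // fusion, output rows land in the intermediate ring buffer instead of dst.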
    auto ker_1x1 = [&](int ocb, int ocb_start, int n, int g, int od, int oh,
                           int ow, int id, int ih, int iw) {
        const int icb = 0; // Start from the first IC block
        const int _ocb = g * nb_oc + ocb;
        const int _icb = g * nb_ic + icb;

        const size_t dst_off = is_3d
                ? dst_d.blk_off(n, _ocb * jcp.oc_block, od, oh, ow)
                : is_2d ? dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow)
                        : dst_d.blk_off(n, _ocb * jcp.oc_block, ow);

        p.output_data = jcp.with_dw_conv ? pbuf + (oh % jcp_dw->kh) * row_offset
                                         : dst + dst_dt_size * dst_off;
        const auto wei_offset = pd()->with_groups()
                ? weights_d.blk_off(g, ocb, icb)
                : weights_d.blk_off(ocb, icb);
        p.load_data = weights + wei_offset;
        p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
        p.compensation = (jcp.signed_input) ? &compensation[_ocb * jcp.oc_block]
                                            : nullptr;
        p.zp_compensation = jcp.src_zero_point
                ? zp_compensation + _ocb * jcp.oc_block
                : nullptr;
        p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
        p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
        p.scales = &oscales[jcp.is_oc_scale * _ocb * jcp.oc_block];
        p.dst_scale = dst_scales;
        const size_t src_off = is_3d
                ? src_d.blk_off(n, _icb * jcp.ic_block, id, ih, iw)
                : is_2d ? src_d.blk_off(n, _icb * jcp.ic_block, ih, iw)
                        : src_d.blk_off(n, _icb * jcp.ic_block, iw);
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space
                    + src_dt_size
                            * (ithr * pd()->rtus_.space_per_thread_
                                    + _icb * jcp.is * jcp.ic_block);
            if (ocb == ocb_start) {
                rp.src = src + src_dt_size * src_off;
                (*rtus_driver_)(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + src_dt_size * src_off;

        p.dst_l_off = dst_off;
        p.oc_l_off = _ocb * jcp.oc_block;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };

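    // Loop-order mnemonics: r = reduce (IC), l = load (OC blocks), b = bcast
    // (mb/spatial), outermost to innermost; init_reduce() is hoisted out of
    // any loop it does not depend on.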
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        if (jcp.loop_order == loop_rlb) {
            init_reduce();
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n, g, bcast_step, od, oh, ow, id, ih, iw;
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_lbr) {
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n, g, bcast_step, od, oh, ow, id, ih, iw;
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_rbl) {
            init_reduce();
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, od, oh, ow, id, ih, iw;
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else if (jcp.loop_order == loop_blr) {
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, od, oh, ow, id, ih, iw;
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else {
            assert(!"unsupported loop order");
        }
    };

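    // Apply the dw kernel to one dw output row, reading its kh source rows
    // from the ring buffer filled by the preceding conv_1x1 calls.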
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        int oh_1x1 = dw_oh * jcp_dw->stride_h - jcp_dw->t_pad;
        int oh_1x1_begin = nstl::max(oh_1x1, 0);

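        // Gather the kh source-row pointers: pbuf stores 1x1 output row r at
        // offset (r % kh) * row_offset.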
        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1_begin++) % jcp_dw->kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const size_t src_ch_stride = jcp_dw->nb_ch_blocking * jcp_dw->ch_block;
        auto par_conv_dw = jit_conv_call_s();

        par_conv_dw.t_overflow = nstl::min(jcp_dw->kh, nstl::max(0, -oh_1x1));
        par_conv_dw.b_overflow = nstl::min(
                jcp_dw->kh, nstl::max(0, oh_1x1 - jcp.oh + jcp_dw->kh));
        par_conv_dw.kh_padding = nstl::max<int>(0,
                jcp_dw->kh - par_conv_dw.t_overflow - par_conv_dw.b_overflow);

        const size_t dst_offset = n * jcp_dw->ngroups * jcp_dw->oh * jcp_dw->ow
                + dw_oh * jcp_dw->ow * jcp_dw->ngroups;

        const auto wht_h_stride = dw_weights_d.blk_off(0, 0, 0, 1);
        const auto wei_stride = (!jcp_dw->signed_input) * par_conv_dw.t_overflow
                * wht_h_stride;
        for (int ocb = ocb_start; ocb < ocb_end;
                ocb += jcp_dw->nb_ch_blocking) {

            par_conv_dw.src = addrs.data();
            par_conv_dw.dst = dst
                    + (dst_offset + jcp_dw->ch_block * ocb)
                            * jcp_dw->typesize_out;

            par_conv_dw.filt
                    = weights_dw + dw_weights_d.blk_off(ocb, 0) + wei_stride;
            par_conv_dw.bias
                    = &bias_dw[ocb * jcp_dw->ch_block * dw_bia_dt_size];
            par_conv_dw.ur_w = (size_t)(jcp_dw->ow);
            par_conv_dw.owb = jcp_dw->ow;
            par_conv_dw.oc_blocks = ocb;
            par_conv_dw.compensation = compensation_dw
                    ? &compensation_dw[ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.scales = dw_oscales
                    ? &dw_oscales[jcp_dw->is_oc_scale * ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.dst_scale = dw_dst_scales;

            par_conv_dw.oc_l_off = ocb * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += src_ch_stride;
        }
    };

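    // Fused driver: for each dw output row, first run the 1x1 over exactly
    // the source rows that row needs (skipping rows already in the ring
    // buffer), then invoke the dw kernel on it.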
    auto conv_dw = [&]() {
        auto jcp_dw = pd()->jcp_dw_;
        auto dw_conv_buffer = dw_scratchpad.get<char>(key_fusion_inout_buffer);

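        // Per-thread ring buffer: kh rows of 1x1 output, each row spanning
        // nb_buffer OC blocks of width ow.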
        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end;
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n, g, oh_dw;
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // Convert the dw spatial range to the 1x1 spatial range
                // (needed when jcp.oh != jcp_dw->oh).
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

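    // Top-level dispatch: fused dw path, or plain 1x1 with work balanced in
    // 2D across the bcast and load (OC) dimensions.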
    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
                jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
                jcp.load_grp_count);
        if (jcp.nb_load_chunk > 1) {
            ocb_start *= jcp.nb_load_chunk;
            ocb_end *= jcp.nb_load_chunk;
        }
        conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl