/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

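// Dispatch blk_off() on the spatial rank of the tensor: 1D (n, c, w),
// 2D (n, c, h, w), or 3D (n, c, d, h, w).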
#define data_blk_off(f, n, c, d, h, w) \
    ((ndims == 3) ? (f).blk_off(n, c, w) \
                  : ((ndims == 4) ? (f).blk_off(n, c, h, w) \
                                  : (f).blk_off(n, c, d, h, w)))

/* convolution forward */
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_1x1_convolution_fwd_t<isa>::execute_forward(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    auto weights_dw = CTX_IN_MEM(
            const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    auto bias_dw = CTX_IN_MEM(
            const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
    const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_
            ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx,
                    pd()->jcp_.post_ops.entry_.size() + 1)
            : std::vector<const void *> {};

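    // Fetch runtime zero points and per-argument scales from the execution
    // context; the DNNL_ARG_ATTR_POST_OP_DW variants belong to the fused
    // depthwise post-op.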
    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    DEFINE_ARG_SCALES_BUFFER(
            dw_wei_scales, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(
            dw_dst_scales, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST);

    auto scratchpad = ctx.get_scratchpad_grantor();

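    // Pre-compute the adjusted output scales (src_scale * wei_scale) in the
    // scratchpad. For s8s8 on ISAs without VNNI, the weights are assumed to
    // have been pre-scaled by wei_adj_scale (at reorder time), so the inverse
    // factor is folded back in here. A common scale (mask == 0) is broadcast;
    // otherwise there is one scale per output channel.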
    auto local_scales
            = scratchpad.template get<float>(key_conv_adjusted_scales);
    const float factor = (pd()->jcp_.signed_input && (!pd()->jcp_.has_vnni))
            ? 1.f / pd()->jcp_.wei_adj_scale
            : 1.0f;
    int wei_mask = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;
    if (wei_mask == 0) {
        utils::array_set(
                local_scales, src_scales[0] * wei_scales[0] * factor, 8);
    } else {
        for (dim_t c = 0; c < pd()->OC(); c++)
            local_scales[c] = src_scales[0] * wei_scales[c] * factor;
    }

    const float *dw_oscales = nullptr;
    if (pd()->jcp_.with_dw_conv) {
        auto jcp_dw = pd()->jcp_dw_;
        memory_tracking::grantor_t dw_scratchpad(
                scratchpad, memory_tracking::names::prefix_fusion);
        auto attr_dw = pd()->dw_conv_pd_->attr();

        auto dw_local_scales
                = dw_scratchpad.template get<float>(key_conv_adjusted_scales);
        int wei_mask = attr_dw->scales_.get(DNNL_ARG_WEIGHTS).mask_;
        float factor = 1.f / jcp_dw->wei_adj_scale;
        if (wei_mask == 0) {
            utils::array_set(dw_local_scales,
                    dw_wei_scales[0] / dst_scales[0] * factor,
                    pd()->jcp_.ic_block);
        } else {
            for (dim_t c = 0; c < pd()->dw_conv_pd_->OC(); c++)
                dw_local_scales[c] = dw_wei_scales[c] / dst_scales[0] * factor;
        }
        dw_oscales = dw_local_scales;
    }
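    // Run the convolution body on jcp.nthr threads; each thread computes its
    // own slice of the iteration space.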
    parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                dst, local_scales, dst_scales, dw_oscales, dw_dst_scales,
                src_zero_point, dst_zero_point, scratchpad,
                post_ops_binary_rhs_arg_vec.data(),
                post_ops_binary_rhs_arg_vec_dw.data());
    });
    return status::success;
}

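// Per-thread body: partitions the work across threads and drives the 1x1
// kernel (and, when fused, the depthwise kernel) over this thread's share.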
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_1x1_convolution_fwd_t<isa>::execute_forward_thr(
        const int ithr, const int nthr, const char *src, const char *weights,
        const char *bias, const char *weights_dw, const char *bias_dw,
        char *dst, const float *oscales, const float *dst_scales,
        const float *dw_oscales, const float *dw_dst_scales,
        const int32_t *src_zero_point, const int32_t *dst_zero_point,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));

    const auto &jcp = pd()->jcp_;

    const size_t src_dt_size = types::data_type_size(src_d.data_type());
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
    const size_t bia_dt_size = pd()->with_bias()
            ? types::data_type_size(pd()->desc()->bias_desc.data_type)
            : 0;

    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<char>(key_conv_rtus_space)
            : nullptr;

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    const int ndims = dst_d.ndims();
    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
    const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4];
    const int stride_w = pd()->desc()->strides[ndims - 3];

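    // The s8s8 compensation and the src zero-point compensation are stored as
    // int32 buffers appended after the weights payload, in that order.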
    auto offset = weights_d.size() - weights_d.additional_buffer_size();
    char *w = const_cast<char *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(w + offset)
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;

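    // Kernel call parameters: p feeds the 1x1 kernel; rp feeds the rtus
    // ("reduce to unit stride") driver that copies a strided source into a
    // dense per-thread workspace.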
    auto p = jit_1x1_conv_call_s();

    auto rp = typename rtus_driver_t<isa>::call_params_t();
    const int nb_oc = jcp.nb_load;
    // Override some blocking constants for the fused dw conv: the 1x1 part
    // then produces one full output row (jcp.ow) at a time.
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;

    // Begin: declare variables needed for the fused dw conv.
    const auto jcp_dw = pd()->jcp_dw_;
    const auto &dw_pd = pd()->dw_conv_pd_;
    memory_tracking::grantor_t dw_scratchpad(
            scratchpad, memory_tracking::names::prefix_fusion);

    const size_t dw_bia_dt_size = jcp_dw && jcp_dw->with_bias
            ? types::data_type_size(dw_pd->desc()->bias_desc.data_type)
            : 0;

    int32_t *compensation_dw {nullptr};

    if (jcp.with_dw_conv) {
        offset = dw_weights_d.size() - dw_weights_d.additional_buffer_size();
        w = const_cast<char *>(weights_dw);
        compensation_dw = (jcp_dw->signed_input)
                ? reinterpret_cast<int32_t *>(w + offset)
                : nullptr;
    }

    char *pbuf {nullptr};
    size_t row_offset {};
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<char *> addrs;
    // End: dw conv variables.

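    // Clamp a blocking step to the remaining work: once the remainder drops
    // below the tail step, consume it whole.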
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

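    // Decompose a flat bcast-space index into (mb, group, spatial block) and
    // derive the corresponding output and input coordinates for this step.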
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        od = os / (jcp.oh * jcp.ow);
        const int os_2d = os % (jcp.oh * jcp.ow);
        oh = os_2d / jcp.ow;
        ow = os_2d % jcp.ow;

        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

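    // Choose how many output-channel blocks to process in one step and flag
    // the last chunk so the kernel can apply its output-channel tail logic.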
    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        p.load_dim = this_block_size(ocb * jcp.oc_block, ocb_end * jcp.oc_block,
                load_step * jcp.oc_block);

        if (ocb + load_step >= nb_oc)
            p.first_last_flag |= FLAG_OC_LAST;
        else
            p.first_last_flag &= ~FLAG_OC_LAST;
    };

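    // The 1x1 kernel always reduces over the full (unpadded) input-channel
    // range.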
    auto init_reduce = [&]() {
        p.reduce_dim = this_block_size(
                0, jcp.ic_without_padding, jcp.ic_without_padding);
        rp.icb = p.reduce_dim;
    };

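    // Set up all pointers (src/weights/bias, compensation, scales, zero
    // points) for one (ocb, spatial block) step and invoke the 1x1 kernel.
    // With a fused dw conv, output rows go into a kh-row ring buffer (pbuf)
    // instead of dst.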
    auto ker_1x1 = [&](int ocb, int ocb_start, int n, int g, int od, int oh,
                           int ow, int id, int ih, int iw) {
        const int icb = 0; // Start from the first IC block
        const int _ocb = g * nb_oc + ocb;
        const int _icb = g;

        const auto src_offset
                = data_blk_off(src_d, n, _icb * jcp.ic_block, id, ih, iw);
        const auto dst_offset
                = data_blk_off(dst_d, n, _ocb * jcp.oc_block, od, oh, ow);
        p.output_data = jcp.with_dw_conv ? pbuf + (oh % jcp_dw->kh) * row_offset
                                         : dst + dst_dt_size * dst_offset;

        const auto wei_offset = pd()->with_groups()
                ? weights_d.blk_off(g, ocb, icb)
                : weights_d.blk_off(ocb, icb);
        p.load_data = weights + wei_offset;
        p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
        p.compensation = (jcp.signed_input) ? &compensation[_ocb * jcp.oc_block]
                                            : nullptr;
        p.zp_compensation = jcp.src_zero_point
                ? zp_compensation + _ocb * jcp.oc_block
                : nullptr;
        p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
        p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
        p.scales = &oscales[jcp.is_oc_scale * _ocb * jcp.oc_block];
        p.dst_scale = dst_scales;
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space
                    + src_dt_size
                            * (ithr * pd()->rtus_.space_per_thread_
                                    + _icb * jcp.is * jcp.ic_block);
            if (ocb == ocb_start) {
                rp.src = src + src_dt_size * src_offset;
                (*rtus_driver_)(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + src_dt_size * src_offset;

        p.oc_l_off = g * nb_oc + ocb * jcp.oc_block;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = jcp.with_dw_conv ? pbuf : dst;

        (*kernel_)(&p);
    };

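    // Drive the 1x1 kernel over [bcast_start, bcast_end) x [ocb_start,
    // ocb_end) in the loop order chosen at configuration time (r = reduce,
    // l = load / output channels, b = bcast / spatial).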
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        if (jcp.loop_order == loop_rlb) {
            init_reduce();
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                            id {0}, ih {0}, iw {0};
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_lbr) {
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                            id {0}, ih {0}, iw {0};
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_rbl) {
            init_reduce();
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                        id {0}, ih {0}, iw {0};
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else if (jcp.loop_order == loop_blr) {
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0},
                        id {0}, ih {0}, iw {0};
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else {
            assert(!"unsupported loop order");
        }
    };

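    // Apply the fused depthwise kernel to one output row: gather the kh
    // source rows from the ring buffer, compute the top/bottom padding
    // overflow, then sweep the channel blocks.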
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        int oh_1x1 = dw_oh * jcp_dw->stride_h - jcp_dw->t_pad;
        int oh_1x1_begin = nstl::max(oh_1x1, 0);

        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1_begin++) % jcp_dw->kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const size_t src_ch_stride = jcp_dw->nb_ch_blocking * jcp_dw->ch_block;
        auto par_conv_dw = jit_conv_call_s();

        par_conv_dw.t_overflow = nstl::min(jcp_dw->kh, nstl::max(0, -oh_1x1));
        par_conv_dw.b_overflow = nstl::min(
                jcp_dw->kh, nstl::max(0, oh_1x1 - jcp.oh + jcp_dw->kh));
        par_conv_dw.kh_padding = nstl::max<int>(0,
                jcp_dw->kh - par_conv_dw.t_overflow - par_conv_dw.b_overflow);

        const size_t dst_offset = n * jcp_dw->ngroups * jcp_dw->oh * jcp_dw->ow
                + dw_oh * jcp_dw->ow * jcp_dw->ngroups;

        const auto wht_h_stride = dw_weights_d.blk_off(0, 0, 0, 1);
        const auto wei_stride = (!jcp_dw->signed_input) * par_conv_dw.t_overflow
                * wht_h_stride;
        for (int ocb = ocb_start; ocb < ocb_end;
                ocb += jcp_dw->nb_ch_blocking) {

            par_conv_dw.src = addrs.data();
            par_conv_dw.dst = dst
                    + (dst_offset + jcp_dw->ch_block * ocb)
                            * jcp_dw->typesize_out;

            par_conv_dw.filt
                    = weights_dw + dw_weights_d.blk_off(ocb, 0) + wei_stride;
            par_conv_dw.bias
                    = &bias_dw[ocb * jcp_dw->ch_block * dw_bia_dt_size];
            par_conv_dw.ur_w = (size_t)(jcp_dw->ow);
            par_conv_dw.owb = jcp_dw->ow;
            par_conv_dw.oc_blocks = ocb;
            par_conv_dw.compensation = compensation_dw
                    ? &compensation_dw[ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.scales = dw_oscales
                    ? &dw_oscales[jcp_dw->is_oc_scale * ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.dst_scale = dw_dst_scales;

            par_conv_dw.oc_l_off = ocb * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += src_ch_stride;
        }
    };

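    // Fused path: compute just enough 1x1 output rows into the ring buffer to
    // cover one depthwise output row, then consume them immediately, so
    // intermediate rows never round-trip through main memory.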
    auto conv_dw = [&]() {
        auto &jcp_dw = pd()->jcp_dw_;
        auto dw_conv_buffer = dw_scratchpad.get<char>(key_fusion_inout_buffer);

        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end;
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n {0}, g {0}, oh_dw {0};
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over the mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // Convert the dw spatial range to the 1x1 spatial range
                // (needed since jcp.oh may differ from jcp_dw->oh).
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

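    // Partition the 2D iteration space (spatial/bcast x output-channel
    // blocks) across threads, then run either the plain 1x1 path or the
    // fused 1x1 + depthwise path.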
    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
                jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
                jcp.load_grp_count);
        if (jcp.nb_load_chunk > 1) {
            ocb_start *= jcp.nb_load_chunk;
            ocb_end *= jcp.nb_load_chunk;
        }
        conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
    }
}

template struct jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>;
template struct jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl