/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_AVX512_CORE_BF16_CONV_KERNEL_HPP
#define CPU_X64_JIT_AVX512_CORE_BF16_CONV_KERNEL_HPP

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <typename Vmm>
struct _jit_avx512_core_bf16_fwd_kernel : public jit_generator {
    _jit_avx512_core_bf16_fwd_kernel(const jit_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md);

    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_bf16_fwd_kernel)

    const jit_conv_conf_t &jcp;
    const primitive_attr_t &attr_;

private:
    constexpr static int isa_simd_width_
            = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    using Vmm_down_t =
            typename utils::conditional<std::is_same<Vmm, Xbyak::Zmm>::value,
                    Xbyak::Ymm, Xbyak::Xmm>::type;
    using reg64_t = const Xbyak::Reg64;
    enum {
        ker_reg_base_idx = 28,
        ker_code_size = 1024 * 1024,
    };

    reg64_t param = abi_param1; // Linux: RDI, Windows: RCX

    reg64_t reg_src = r8;
    reg64_t reg_ker = r9;
    reg64_t reg_dst = r10;
    reg64_t reg_owb = r11;

    reg64_t aux_reg_src = r12;
    reg64_t aux_reg_ker = r13;

    reg64_t reg_ic = rax;
    reg64_t reg_oc = r15;
    reg64_t reg_bias = rbx;

    reg64_t reg_kj = abi_not_param1;
    reg64_t reg_ki = reg_bias;
    reg64_t reg_oi = rdx;
    reg64_t reg_kh = rsi;

    reg64_t reg_long_offt = r14;

    /* binary post-ops operand */
    reg64_t temp_offset_reg = r12;

    int vmm_dst_idx(const int i_ur, const int i_oc) const;
    Vmm vmm_dst(const int i_ur, const int i_oc) const;

    Vmm vmm_src(int i_ic, int nb_x_blocking) {
        int idx = i_ic + nb_x_blocking * jcp.ur_w;
        assert(idx < 31);
        return Vmm(idx);
    }

    Vmm_down_t vmm_src_down(int i_ic, int nb_x_blocking) {
        int idx = i_ic + nb_x_blocking * jcp.ur_w;
        assert(idx < 31);
        return Vmm_down_t(idx);
    }

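    // Attach the output-channel tail opmask to a vmm operand: with
    // zero_mask, lanes past the tail are zeroed (T_z), otherwise merged.
    // The extended variant covers twice as many lanes, presumably for
    // bf16 data, where two elements pack into each f32 lane.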
    inline Vmm may_be_mask_vmm(Vmm vmm, bool mask_flag, bool zero_mask,
            bool use_extended_mask = false) {
        if (mask_flag) {
            vmm = vmm
                    | (use_extended_mask ? k_oc_tail_mask_extended
                                         : k_oc_tail_mask);
            if (zero_mask) vmm = vmm | T_z;
        }
        return vmm;
    }

    inline Vmm_down_t may_be_mask_vmm(Vmm_down_t vmm, bool mask_flag) {
        return mask_flag ? vmm | k_oc_tail_mask : vmm;
    }

    Vmm vmm_wei = Vmm(31);
    Vmm vmm_prev_dst = Vmm(31);
    Vmm vmm_bias = Vmm(31);

    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
    reg64_t bf16_emu_scratch = reg_ic;
    Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(29);
    Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);

    Xbyak::Opmask odd_load_mask = Xbyak::Opmask(2);
    Xbyak::Opmask even_load_mask = Xbyak::Opmask(3);
    Xbyak::Opmask k_oc_tail_mask = Xbyak::Opmask(4);
    Xbyak::Opmask k_oc_tail_mask_extended = Xbyak::Opmask(5);
    const Xbyak::Opmask postops_mask = Xbyak::Opmask(6);

    constexpr static int off_reg_src_ = 0;
    constexpr static int off_reg_ker_ = 8;
    constexpr static int stack_space_needed_ = 16;

    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
            postops_injector_;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    inline void prepare_dst(int ur_w);
    void apply_postops(int ur_w);
    inline void store_dst(int ur_w);
    inline void compute_loop(int ur_w, int pad_l, int pad_r);

    void generate() override;

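    // Byte offset of dst point (sp_idx, ocb). For nxc layouts the spatial
    // stride is the full channel count (ngroups * oc); for blocked layouts
    // it is one oc_block, and the ocb stride spans the whole od*oh*ow
    // spatial range.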
    inline dim_t get_dst_offset(dim_t sp_idx, int ocb) {
        const bool is_layout_nxc = is_dst_layout_nxc();
        dim_t sp_str = is_layout_nxc ? jcp.ngroups * jcp.oc : jcp.oc_block;
        dim_t ocb_str = jcp.oc_block
                * (is_layout_nxc ? 1 : (dim_t)jcp.od * jcp.oh * jcp.ow);
        return jcp.typesize_out * (ocb_str * ocb + sp_str * sp_idx);
    }

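    // Map a filter tap (and output position) to the src element it reads:
    // iw = kw * (dilate_w + 1) + ow * stride_w - pad_l; the h/d variants
    // additionally scale by the row / plane size.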
    inline dim_t filter_w_to_src(int kw, int ow = 0, int pad_l = 0) {
        return kw * (jcp.dilate_w + 1) + ow * jcp.stride_w - pad_l;
    }
    inline dim_t filter_h_to_src(int kh) {
        return kh * (jcp.dilate_h + 1) * jcp.iw;
    }
    inline dim_t filter_d_to_src(int kd) {
        return kd * (jcp.dilate_d + 1) * jcp.iw * jcp.ih;
    }

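    // Byte offset of src point (ic_idx, isp) for three layout families:
    // nxc (spatial stride = ngroups * ic), blocked (spatial stride =
    // ic_block), and the plain 1st-conv layout, where the spatial stride
    // is 1 and each input channel is a full id*ih*iw plane apart.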
    inline dim_t get_src_offset(dim_t ic_idx, dim_t isp) {
        int icb = ic_idx / jcp.ic_block;
        int ic = ic_idx % jcp.ic_block;
        dim_t isp_str = is_src_layout_nxc()
                ? jcp.ngroups * jcp.ic
                : (jcp.is_1stconv ? 1 : jcp.ic_block);
        dim_t full_spatial_size = (dim_t)jcp.iw * jcp.ih * jcp.id;
        dim_t ic_str = jcp.is_1stconv && !is_src_layout_nxc()
                ? full_spatial_size
                : 1;
        dim_t icb_str
                = (is_src_layout_nxc() ? 1 : full_spatial_size) * jcp.ic_block;
        return jcp.typesize_in * (isp_str * isp + icb_str * icb + ic_str * ic);
    }

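    // Weights are laid out for bf16 VNNI: pairs of input channels are
    // interleaved along oc, so element (ic, oc) of a block sits at
    // (ic / 2) * oc_block * 2 + oc * 2 + ic % 2. This helper returns the
    // oc == 0 offset; a full oc_block row is consumed by one vector load.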
    inline dim_t get_kernel_offset(
            int ocb, int ic_idx, int kw, int kh = 0, int kd = 0) {
        int scale = 2; // bf16 vnni is used
        int rnd_ic_block = utils::rnd_up(jcp.ic_block, scale);
        int icb = ic_idx / jcp.ic_block;
        int ic = ic_idx % jcp.ic_block;
        dim_t ksp_str = rnd_ic_block * jcp.oc_block;
        dim_t ksp_idx = kd * jcp.kh * jcp.kw + kh * jcp.kw + kw;

        dim_t icb_str = jcp.kd * jcp.kh * jcp.kw * ksp_str;
        dim_t ocb_str = jcp.nb_ic * icb_str;
        dim_t ic_off = (ic / scale) * jcp.oc_block * scale + (ic % scale);
        return jcp.typesize_in
                * (ocb * ocb_str + icb * icb_str + ksp_idx * ksp_str + ic_off);
    }

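    // First / one-past-last output column inside an ur_w block for which
    // filter tap ki reads non-padded src, i.e. the smallest ow with
    // ow * stride_w + ki * (dilate_w + 1) - pad_l >= 0, and its mirrored
    // counterpart on the right edge.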
    int get_ow_start(int ki, int pad_l) {
        return nstl::max(0,
                utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
    }

    int get_ow_end(int ur_w, int ki, int pad_r) {
        return ur_w
                - nstl::max(0,
                        utils::div_up(
                                pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
                                jcp.stride_w));
    }
    inline bool is_src_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_dst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
};

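// Dispatch wrapper: picks the vector register width matching the channel
// blocking (zmm = 16 f32 lanes, ymm = 8, xmm = 4).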
struct jit_avx512_core_bf16_fwd_kernel {
    jit_avx512_core_bf16_fwd_kernel(const jit_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md)
        : kernel_(nullptr) {
        switch (ajcp.oc_block) {
            case 16:
                kernel_ = new _jit_avx512_core_bf16_fwd_kernel<Xbyak::Zmm>(
                        ajcp, attr, dst_md);
                return;
            case 8:
                kernel_ = new _jit_avx512_core_bf16_fwd_kernel<Xbyak::Ymm>(
                        ajcp, attr, dst_md);
                return;
            case 4:
                kernel_ = new _jit_avx512_core_bf16_fwd_kernel<Xbyak::Xmm>(
                        ajcp, attr, dst_md);
                return;
            default: assert(!"invalid channel blocking");
        }
    }

    status_t create_kernel() { return kernel_->create_kernel(); }

    ~jit_avx512_core_bf16_fwd_kernel() { delete kernel_; }

    static status_t init_conf(jit_conv_conf_t &jcp,
            const convolution_desc_t &cd, memory_desc_t &src_pd,
            memory_desc_t &weights_pd, memory_desc_t &dst_pd,
            memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads);
    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
            const jit_conv_conf_t &jcp);

    void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); }
    const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); }

private:
    DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_core_bf16_fwd_kernel);
    jit_generator *kernel_;
};
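
// A minimal driver-side sketch (assuming the usual jit_conv_call_s fields
// src/dst/filt; the real convolution drivers fill many more):
//     jit_avx512_core_bf16_fwd_kernel kernel(jcp, *attr, *dst_md);
//     CHECK(kernel.create_kernel());
//     jit_conv_call_s par_conv;
//     par_conv.src = src; par_conv.dst = dst; par_conv.filt = weights;
//     kernel(&par_conv);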

template <typename Vmm>
struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator {

    _jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp)
        : jit_generator(
                jit_name(), nullptr, ker_code_size, true, avx512_core_bf16)
        , jcp(ajcp)
        , bf16_emu_(nullptr) {
        if (!isa_has_bf16(jcp.isa))
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                    bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                    bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_5);
    }

    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_bf16_bwd_data_kernel_f32)

    const jit_conv_conf_t &jcp;

private:
    using Vmm_down_t =
            typename utils::conditional<std::is_same<Vmm, Xbyak::Zmm>::value,
                    Xbyak::Ymm, Xbyak::Xmm>::type;
    using reg64_t = const Xbyak::Reg64;
    enum {
        ker_reg_base_idx = 31,
        ker_code_size = 1024 * 1024,
    };

    reg64_t param = abi_param1;
    reg64_t reg_dst = r8;
    reg64_t reg_ker = r9;
    reg64_t reg_src = r10;

    reg64_t reg_iwb = rdx;

    reg64_t aux_reg_dst = r14;
    reg64_t aux_reg_ker = r15;

    reg64_t aux_reg_dst_d = r12;
    reg64_t aux_reg_ker_d = r13;
    reg64_t reg_ki = rsi;

    reg64_t reg_kj = rax;
    reg64_t reg_oi = rbx;
    reg64_t reg_kh = abi_not_param1;

    reg64_t reg_oc = r11;
    reg64_t reg_ic = aux_reg_ker_d;

    Xbyak::Opmask k_ic_tail_mask = Xbyak::Opmask(2);
    Xbyak::Opmask k_ic_tail_mask_extended = Xbyak::Opmask(3);

    Vmm vmm_ddst(int i_ic) {
        int idx = i_ic + jcp.nb_ic_blocking * jcp.ur_w;
        assert(idx < ker_reg_base_idx);
        return Vmm(idx);
    }

    Vmm_down_t vmm_ddst_down(int i_ic) {
        int idx = i_ic + jcp.nb_ic_blocking * jcp.ur_w;
        assert(idx < ker_reg_base_idx);
        return Vmm_down_t(idx);
    }

    Vmm vmm_dsrc(int i_ur, int i_oc) {
        int idx = i_ur + i_oc * jcp.ur_w;
        assert(idx < ker_reg_base_idx);
        return Vmm(idx);
    }

    inline Vmm may_be_mask_vmm(Vmm vmm, bool mask_flag, bool zero_mask,
            bool use_extended_mask = false) {
        if (mask_flag) {
            vmm = vmm
                    | (use_extended_mask ? k_ic_tail_mask_extended
                                         : k_ic_tail_mask);
            if (zero_mask) vmm = vmm | T_z;
        }
        return vmm;
    }

    inline Vmm_down_t may_be_mask_vmm(Vmm_down_t vmm, bool mask_flag) {
        return mask_flag ? vmm | k_ic_tail_mask : vmm;
    }

    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
    reg64_t bf16_emu_scratch = reg_kj;
    Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(29);
    Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);

    Vmm vmm_wei = Vmm(31);
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    inline void prepare_output(int ur_w);
    inline void store_output(int ur_w);
    inline void compute_loop(int ur_w, int l_overflow, int r_overflow);
    void generate() override;

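    // Range of diff_src columns within an ur_w block that receive a
    // contribution from filter tap ki. With stride_w > 1 only every
    // stride_w-th column aligns with a given tap, hence the modulo
    // arithmetic and the bump-up loops below.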
    int get_iw_start(int ki, int l_overflow) {
        int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
                + l_overflow * jcp.stride_w
                - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
        while (res < 0)
            res += jcp.stride_w;

        return res;
    }

    int get_iw_end(int ur_w, int ki, int r_overflow) {
        if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
            ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
        int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
                + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
        while (res < 0)
            res += jcp.stride_w;

        return ur_w - res;
    }

    inline int filter_h_to_dst(int kh) {
        return kh * (jcp.dilate_h + 1) * jcp.ow;
    }
    inline int filter_d_to_dst(int kd) {
        return kd * (jcp.dilate_d + 1) * jcp.ow * jcp.oh;
    }

    inline size_t get_diff_src_offset(int iw_idx, int n_ic_block) {
        const bool is_nxc_layout = is_dsrc_layout_nxc();
        size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic : jcp.ic_block;
        size_t icb_str = jcp.ic_block
                * (is_nxc_layout ? 1 : (size_t)jcp.id * jcp.ih * jcp.iw);
        return jcp.typesize_out * (iw_str * iw_idx + icb_str * n_ic_block);
    }

    inline size_t get_diff_dst_offset(
            int osp_idx, int oc_within_block_idx, int oc_block_idx = 0) {
        const bool is_nxc_layout = is_ddst_layout_nxc();
        size_t osp_str = is_nxc_layout ? jcp.ngroups * jcp.oc : jcp.oc_block;
        size_t ocb_str = jcp.oc_block
                * (is_nxc_layout ? 1 : (size_t)jcp.od * jcp.oh * jcp.ow);
        return jcp.typesize_in
                * (osp_str * osp_idx + ocb_str * oc_block_idx
                        + oc_within_block_idx);
    }

    inline size_t get_kernel_offset(
            int icb, int oc_idx, int kw, int kh = 0, int kd = 0) {
        int scale = 2; // bf16 vnni is used
        int ocb = oc_idx / jcp.oc_block;
        int oc = oc_idx % jcp.oc_block;
        size_t ksp_str = jcp.ic_block * jcp.oc_block;
        size_t ksp_idx = kd * jcp.kh * jcp.kw + kh * jcp.kw + kw;

        size_t icb_str = jcp.kd * jcp.kh * jcp.kw * ksp_str;
        size_t ocb_str = jcp.nb_ic * icb_str;
        size_t oc_off = (oc / scale) * jcp.ic_block * scale + (oc % scale);
        return jcp.typesize_in
                * (ocb * ocb_str + icb * icb_str + ksp_idx * ksp_str + oc_off);
    }

    inline bool is_dsrc_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_ddst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
};

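// Dispatch wrapper for the backward-data kernel: the vector width follows
// the input-channel blocking (zmm = 16, ymm = 8, xmm = 4).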
struct jit_avx512_core_bf16_bwd_data_kernel {

    jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp)
        : kernel_(nullptr) {
        switch (ajcp.ic_block) {
            case 16:
                kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Zmm>(
                        ajcp);
                return;
            case 8:
                kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Ymm>(
                        ajcp);
                return;
            case 4:
                kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Xmm>(
                        ajcp);
                return;
            default: assert(!"invalid channel blocking");
        }
    }

    status_t create_kernel() { return kernel_->create_kernel(); }

    ~jit_avx512_core_bf16_bwd_data_kernel() { delete kernel_; }

    static status_t init_conf(jit_conv_conf_t &jcp,
            const convolution_desc_t &cd, memory_desc_t &diff_src_md,
            memory_desc_t &weights_md, memory_desc_t &diff_dst_md,
            int nthreads);
    void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); }
    const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); }

private:
    DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_core_bf16_bwd_data_kernel);
    jit_generator *kernel_;
};

struct jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 : public jit_generator {

    jit_avx512_core_bf16_conv_bwd_weights_kernel_f32(
            const jit_conv_conf_t &ajcp)
        : jit_generator(
                jit_name(), nullptr, ker_code_size, true, avx512_core_bf16)
        , jcp(ajcp)
        , bf16_emu_(nullptr) {
        if (!isa_has_bf16(jcp.isa)) {
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(
                    this, one, even, selector, scratch, tmp0, tmp1);
        }
    }

    ~jit_avx512_core_bf16_conv_bwd_weights_kernel_f32() = default;

    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            jit_avx512_core_bf16_conv_bwd_weights_kernel_f32)

    static status_t init_conf(jit_conv_conf_t &jcp,
            const convolution_desc_t &cd, memory_desc_t &src_md,
            memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md,
            memory_desc_t &diff_dst_md, int nthreads);
    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
            const jit_conv_conf_t &jcp);

    const jit_conv_conf_t &jcp;

private:
    Xbyak::Label dst_prm_table;
    // Used by compute_ic_block_step_{vpermw, interleave}
    Xbyak::Opmask m_ffffffff = Xbyak::Opmask(1);
    // Used by compute_ic_block_step_vpermw
    Xbyak::Opmask m_0000ffff = Xbyak::Opmask(2);
    Xbyak::Opmask m_ffff0000 = Xbyak::Opmask(3);
    Xbyak::Opmask m_0000_oc_tail = Xbyak::Opmask(4);
    Xbyak::Opmask m_oc_tail_0000 = Xbyak::Opmask(5);
    Xbyak::Opmask m_0000_ic_tail = Xbyak::Opmask(6);
    Xbyak::Opmask m_ic_tail_0000 = Xbyak::Opmask(7);
    // Used by compute_ic_block_step_extern (1st_conv only)
    Xbyak::Opmask everyother_mask = Xbyak::Opmask(6);
    Xbyak::Opmask everyother_shift_mask = Xbyak::Opmask(7);
    // Used by compute_ic_block_step_interleave (1st_conv only)
    Xbyak::Opmask underflow_mask = Xbyak::Opmask(4);
    Xbyak::Opmask overflow_mask = Xbyak::Opmask(5);
    Xbyak::Opmask underflow_stride_mask = Xbyak::Opmask(6);
    Xbyak::Opmask overflow_stride_mask = Xbyak::Opmask(7);

    using reg64_t = const Xbyak::Reg64;
    enum {
        sizeof_cacheline = 64,
        full_spat_opt_working_set_size = 48 * 1024,
        full_spat_max_working_set_size = 128 * 1024,
        ker_code_size = 1024 * 1024,
    };
    static const int max_ur_w;

    reg64_t param = abi_param1;
    reg64_t reg_src = rax;
    reg64_t reg_kernel = rdx;
    reg64_t reg_ddst = rsi;
    reg64_t b_ic = abi_not_param1;
    reg64_t kj = r8;
    reg64_t reg_kh = r9;
    reg64_t reg_ur_w_trips = r10;
    reg64_t reg_oj = r15;
    reg64_t reg_tmp = r14;
    reg64_t reg_ih_shift = reg_tmp;
    reg64_t reg_long_offt = r14;
    reg64_t reg_icb = rbx;

    reg64_t ki = r11;
    reg64_t reg_oj_setup = r11;
    reg64_t reg_kd_count = r12;
    reg64_t reg_oi = r12;
    reg64_t reg_d_index = r13;
    reg64_t reg_src_d = r15;
    reg64_t reg_ddst_d = rbx;
    reg64_t aux_reg_src = r12;
    reg64_t aux_reg_kernel = r13;

    Xbyak::Zmm vreg_bias_acc = Xbyak::Zmm(0);
    Xbyak::Zmm vreg_bias_unit = Xbyak::Zmm(1);
    Xbyak::Zmm vreg_bias_ddst = Xbyak::Zmm(2);

    Xbyak::Zmm one = Xbyak::Zmm(27);
    Xbyak::Zmm even = Xbyak::Zmm(28);
    Xbyak::Zmm selector = Xbyak::Zmm(29);
    Xbyak::Zmm tmp0 = Xbyak::Zmm(30);
    Xbyak::Zmm tmp1 = Xbyak::Zmm(31);
    reg64_t scratch = r11;

    inline void maybe_zero_kernel();
    inline void get_ur_w(int &ur_w, int &ur_w_tail, int &ur_w_trips);
    inline void compute_oh_step_unroll_ow_icblock(int ic_block_step);
    inline void od_step_comeback_pointers();
    inline void oh_step_comeback_pointers();
    inline void compute_oh_step_unroll_ow(int ic_block_step);
    inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r,
            int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline void compute_ic_block_step_extern(int ur_w, int pad_l, int pad_r,
            int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline void compute_ic_block_step_interleave(int ur_w, int pad_l, int pad_r,
            int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline void compute_ic_block_step_vpermw(int ur_w, int pad_l, int pad_r,
            int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline void compute_oh_step_common(int ic_block_step);
    inline void compute_oh_step_disp();
    inline void compute_loop();
    inline void compute_oh_loop_common(bool partial = false);
    inline void compute_od_loop_common(bool partial = false);
    void compute_full_spat_loop();
    void compute_diff_bias_init();
    void compute_diff_bias_row(bool is_partial = true);
    void maybe_compute_diff_bias();
    void convert_src_to_vnni_format(
            int ur_w, int pad_l, int pad_r, int src_offset);
    void may_be_set_oc_tail_mask();
    void may_be_reset_oc_tail_mask();
    inline void compute_ic_block_step_vpermw_expl(int ur_w, int pad_l,
            int pad_r, int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline bool is_src_layout_nxc() {
        return jcp.uses_permw_transposition
                && utils::one_of(jcp.src_tag, format_tag::ndhwc,
                        format_tag::nhwc, format_tag::nwc);
    }
    inline bool is_ddst_layout_nxc() {
        return jcp.uses_permw_transposition
                && utils::one_of(jcp.dst_tag, format_tag::ndhwc,
                        format_tag::nhwc, format_tag::nwc);
    }

    void generate() override;

    static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
            int &nthr_g, int &nthr_oc_b, int &nthr_ic_b);

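    // The vpermw path works on pairs of adjacent output columns at once;
    // the helpers below compute which of the two src positions for a given
    // (i_ur, i_kw) fall inside the valid (non-padded) range and select the
    // matching half-load opmask.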
    void get_w_positions(int ur_w, int pad_l, int pad_r, int i_ur, int i_kw,
            int &iw_0, int &iw_1) {
        auto get_w_position = [=](int idx) {
            int iw = i_ur + idx;
            if (iw >= ur_w) return -1;
            iw += i_kw;
            if (iw - pad_l < 0 || iw > (ur_w - 1) + (jcp.kw - 1) - pad_r)
                return -1;
            return iw - pad_l;
        };
        iw_0 = get_w_position(0);
        iw_1 = get_w_position(1);
    }
    bool check_borders(int ur_w, int pad_l, int pad_r, int i_ur, int i_kw) {
        int iw_0, iw_1;
        get_w_positions(ur_w, pad_l, pad_r, i_ur, i_kw, iw_0, iw_1);

        return !(iw_0 == -1 && iw_1 == -1);
    }
    bool get_load_mask(int ur_w, int pad_l, int pad_r, int i_ur, int i_kw,
            Xbyak::Opmask &load_mask) {
        int iw_0, iw_1;
        get_w_positions(ur_w, pad_l, pad_r, i_ur, i_kw, iw_0, iw_1);

        bool rt = true;
        if (iw_0 != -1 && iw_1 != -1)
            load_mask = m_ffffffff;
        else if (iw_0 != -1 && iw_1 == -1)
            load_mask = m_0000ffff;
        else if (iw_0 == -1 && iw_1 != -1)
            load_mask = m_ffff0000;
        else
            rt = false;

        return rt;
    }

    inline dim_t filter_w_to_src(int kw, int ow = 0, int pad_l = 0) {
        int stride_w = jcp.transpose_src ? 1 : jcp.stride_w;
        return kw * (jcp.dilate_w + 1) + ow * stride_w - pad_l;
    }
    inline dim_t filter_h_to_src(int kh) { return kh * (jcp.dilate_h + 1); }
    inline dim_t filter_d_to_src(int kd) {
        return kd * (jcp.dilate_d + 1) * jcp.ih;
    }

    inline dim_t get_src_offset(dim_t ic_idx, dim_t w_idx, dim_t hd_idx = 0) {
        // Unless is_src_layout_nxc(), only an ic_idx index inside the
        // block is supported: ic_idx == jcp.ic_block is treated as a shift
        // within one block and not as a move to the next ic block.
        assert(IMPLICATION(!is_src_layout_nxc(), ic_idx <= jcp.ic_block));
        dim_t icb = is_src_layout_nxc() ? ic_idx / jcp.ic_block : 0;
        dim_t ic = is_src_layout_nxc() ? ic_idx % jcp.ic_block : ic_idx;
        dim_t iw_str = jcp.is_1stconv || jcp.transpose_src
                ? 1
                : (is_src_layout_nxc() ? jcp.ngroups * jcp.ic : jcp.ic_block);
        dim_t ihid_str
                = jcp.tr_iw * (jcp.transpose_src ? jcp.ic_block : iw_str);
        // For jcp.transpose_src, w_idx may be greater than jcp.tr_iw since
        // the right zero-padding memory is shared with the left zero
        // padding of the next block.
        dim_t isp_off = hd_idx * ihid_str + w_idx * iw_str;
        dim_t full_spatial_size = (dim_t)jcp.tr_iw * jcp.ih * jcp.id;
        dim_t ic_str = jcp.transpose_src
                ? jcp.tr_iw
                : (jcp.is_1stconv ? full_spatial_size : 1);
        dim_t icb_str
                = jcp.ic_block * (is_src_layout_nxc() ? 1 : full_spatial_size);
        return jcp.typesize_in * (isp_off + icb_str * icb + ic_str * ic);
    }

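    // With transpose_dst, diff_dst is repacked so that even/odd ow
    // neighbors stay adjacent (ow_per_oc == 2) for the VNNI accumulation;
    // the channel stride then separates each such pair.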
    inline dim_t get_ddst_offset(dim_t w_idx, dim_t hd_idx = 0) {
        int ow_per_oc = jcp.transpose_dst ? 2 : 1;
        int ch_mult
                = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block;
        dim_t hd_off = jcp.tr_ow * ch_mult * hd_idx;
        dim_t w_off
                = w_idx / ow_per_oc * ow_per_oc * ch_mult + w_idx % ow_per_oc;
        return jcp.typesize_in * (w_off + hd_off);
    }

    inline dim_t get_kernel_offset(int ic_idx, dim_t ksp_idx) {
        // Only an ic_idx index inside the block is supported:
        // ic_idx == jcp.ic_block is treated as a shift within one block
        // and not as a move to the next ic block. Negative values are
        // supported for negative shifts.
        assert(nstl::abs(ic_idx) <= jcp.ic_block);
        return jcp.typesize_out * jcp.oc_block
                * (ksp_idx * jcp.ic_block + ic_idx);
    }

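    // Pick a zmm that is free in the current configuration: zmm24 unless
    // the permw + expl_bcast path is taken; on that path, zmm26 when bf16
    // is emulated (the emulation reserves zmm27-zmm31) and zmm31 with
    // native bf16 support.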
    Xbyak::Zmm get_perm_reg() {
        int idx = !(jcp.uses_permw_transposition
                          && jcp.kernel_kind == expl_bcast)
                ? 24
                : ((!isa_has_bf16(jcp.isa)) ? 26 : 31);
        return Xbyak::Zmm(idx);
    }
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    inline int interleave_w_reorder_size(int ur_w) const;
    inline int interleave_w_reorder_bytes(int ur_w);
    inline int interleave_stack_size(int ur_w, int ic_block_step);
    inline int permw_stack_size(int ur_w) {
        return (ur_w + jcp.kw - 1) * sizeof_cacheline;
    }

    inline void setup_stack_space();
    static const int extern_ic_block_step_stack_size = 0;
    int ic_block_step_stack_size = 0;
    int stack_space_needed = 0;
    int permw_buffer_start = 0;
    int kd_count_offset = 0;
    int src_d_offset = 0;
    int ddst_d_offset = 0;
    int d_index_offset = 0;
    int trans_tmp_offset = 0;
    int ih_dilate_shift = 0;
    int icb_loop_ker_ptr = 0;
    int icb_loop_src_ptr = 0;
};
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl
#endif
