1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_BF16_CONV_KERNEL_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_BF16_CONV_KERNEL_HPP |
19 | |
#include <memory>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
/* JIT code generator for the avx512_core bf16 forward convolution kernel.
 * Templated on the vector register type (Xbyak::Zmm/Ymm/Xmm) so one
 * generator serves the different output-channel block sizes selected by
 * the dispatcher below. */
template <typename Vmm>
struct _jit_avx512_core_bf16_fwd_kernel : public jit_generator {
    _jit_avx512_core_bf16_fwd_kernel(const jit_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md);

    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_bf16_fwd_kernel)

    const jit_conv_conf_t &jcp;
    const primitive_attr_t &attr_;

private:
    // Number of f32 lanes in a full avx512_core vector register.
    constexpr static int isa_simd_width_
            = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    // One register width below Vmm (Zmm -> Ymm, otherwise Xmm); used for
    // the narrower bf16 form of the data.
    using Vmm_down_t =
            typename utils::conditional<std::is_same<Vmm, Xbyak::Zmm>::value,
                    Xbyak::Ymm, Xbyak::Xmm>::type;
    using reg64_t = const Xbyak::Reg64;
    enum {
        ker_reg_base_idx = 28,
        ker_code_size = 1024 * 1024, // upper bound on generated code size
    };

    reg64_t param = abi_param1; //L: RDI, W: RCX

    reg64_t reg_src = r8;
    reg64_t reg_ker = r9;
    reg64_t reg_dst = r10;
    reg64_t reg_owb = r11;

    reg64_t aux_reg_src = r12;
    reg64_t aux_reg_ker = r13;

    reg64_t reg_ic = rax;
    reg64_t reg_oc = r15;
    reg64_t reg_bias = rbx;

    // NOTE: reg_ki aliases reg_bias (rbx) — the two are used in disjoint
    // phases of the generated code.
    reg64_t reg_kj = abi_not_param1;
    reg64_t reg_ki = reg_bias;
    reg64_t reg_oi = rdx;
    reg64_t reg_kh = rsi;

    reg64_t reg_long_offt = r14;

    /* binary post-ops operand; aliases aux_reg_src (r12) */
    reg64_t temp_offset_reg = r12;

    // Register index / register holding the accumulator for output pixel
    // i_ur of output-channel block i_oc.
    int vmm_dst_idx(const int i_ur, const int i_oc) const;
    Vmm vmm_dst(const int i_ur, const int i_oc) const;

    // Broadcast register for source value i_ic; placed above the
    // accumulator range (nb_x_blocking * jcp.ur_w registers).
    Vmm vmm_src(int i_ic, int nb_x_blocking) {
        int idx = i_ic + nb_x_blocking * jcp.ur_w;
        assert(idx < 31); // Vmm(31) is reserved (vmm_wei/vmm_prev_dst/vmm_bias)
        return Vmm(idx);
    }

    // Same slot as vmm_src but viewed through the narrower register type.
    Vmm_down_t vmm_src_down(int i_ic, int nb_x_blocking) {
        int idx = i_ic + nb_x_blocking * jcp.ur_w;
        assert(idx < 31);
        return Vmm_down_t(idx);
    }

    // Attach the oc tail opmask to vmm when mask_flag is set; with
    // zero_mask the masked-off lanes are zeroed (T_z) instead of merged.
    inline Vmm may_be_mask_vmm(Vmm vmm, bool mask_flag, bool zero_mask,
            bool use_extended_mask = false) {
        if (mask_flag) {
            vmm = vmm
                    | (use_extended_mask ? k_oc_tail_mask_extended
                                         : k_oc_tail_mask);
            if (zero_mask) vmm = vmm | T_z;
        }
        return vmm;
    }

    inline Vmm_down_t may_be_mask_vmm(Vmm_down_t vmm, bool mask_flag) {
        return (mask_flag) ? vmm | k_oc_tail_mask : vmm;
    }

    // Vmm(31) is shared by three mutually exclusive uses.
    Vmm vmm_wei = Vmm(31);
    Vmm vmm_prev_dst = Vmm(31);
    Vmm vmm_bias = Vmm(31);

    // Registers reserved for bf16 emulation on ISAs without native bf16.
    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
    reg64_t bf16_emu_scratch = reg_ic;
    Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(29);
    Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);

    Xbyak::Opmask odd_load_mask = Xbyak::Opmask(2);
    Xbyak::Opmask even_load_mask = Xbyak::Opmask(3);
    Xbyak::Opmask k_oc_tail_mask = Xbyak::Opmask(4);
    Xbyak::Opmask k_oc_tail_mask_extended = Xbyak::Opmask(5);
    const Xbyak::Opmask postops_mask = Xbyak::Opmask(6);

    // Stack-frame layout (byte offsets) used by the generated code.
    constexpr static int off_reg_src_ = 0;
    constexpr static int off_reg_ker_ = 8;
    constexpr static int stack_space_needed_ = 16;

    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core, Vmm>>
            postops_injector_;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    inline void prepare_dst(int ur_w);
    void apply_postops(int ur_w);
    inline void store_dst(int ur_w);
    inline void compute_loop(int ur_w, int pad_l, int pad_r);

    void generate() override;

    // Byte offset of dst element (spatial index sp_idx, oc block ocb);
    // stride choice depends on blocked vs nxc destination layout.
    inline dim_t get_dst_offset(dim_t sp_idx, int ocb) {
        const bool is_layout_nxc = is_dst_layout_nxc();
        dim_t sp_str = is_layout_nxc ? jcp.ngroups * jcp.oc : jcp.oc_block;
        dim_t ocb_str = jcp.oc_block
                * (is_layout_nxc ? 1 : (dim_t)jcp.od * jcp.oh * jcp.ow);
        return jcp.typesize_out * (ocb_str * ocb + sp_str * sp_idx);
    }

    // Map filter coordinates to source-space element indices, accounting
    // for dilation, stride, and left padding.
    inline dim_t filter_w_to_src(int kw, int ow = 0, int pad_l = 0) {
        return kw * (jcp.dilate_w + 1) + ow * jcp.stride_w - pad_l;
    }
    inline dim_t filter_h_to_src(int kh) {
        return kh * (jcp.dilate_h + 1) * jcp.iw;
    }
    inline dim_t filter_d_to_src(int kd) {
        return kd * (jcp.dilate_d + 1) * jcp.iw * jcp.ih;
    }

    // Byte offset of the src element at input channel ic_idx and flattened
    // spatial index isp, for nxc / blocked / 1st-conv layouts.
    inline dim_t get_src_offset(dim_t ic_idx, dim_t isp) {
        int icb = ic_idx / jcp.ic_block;
        int ic = ic_idx % jcp.ic_block;
        dim_t isp_str = is_src_layout_nxc()
                ? jcp.ngroups * jcp.ic
                : (jcp.is_1stconv ? 1 : jcp.ic_block);
        dim_t full_spatial_size = (dim_t)jcp.iw * jcp.ih * jcp.id;
        dim_t ic_str = jcp.is_1stconv && !is_src_layout_nxc()
                ? full_spatial_size
                : 1;
        dim_t icb_str
                = (is_src_layout_nxc() ? 1 : full_spatial_size) * jcp.ic_block;
        return jcp.typesize_in * (isp_str * isp + icb_str * icb + ic_str * ic);
    }

    // Byte offset into the VNNI-interleaved weights: pairs of input
    // channels are stored together (scale == 2).
    inline dim_t get_kernel_offset(
            int ocb, int ic_idx, int kw, int kh = 0, int kd = 0) {
        int scale = 2; //bf16 vnni is used
        int rnd_ic_block = utils::rnd_up(jcp.ic_block, scale);
        int icb = ic_idx / jcp.ic_block;
        int ic = ic_idx % jcp.ic_block;
        dim_t ksp_str = rnd_ic_block * jcp.oc_block;
        dim_t ksp_idx = kd * jcp.kh * jcp.kw + kh * jcp.kw + kw;

        dim_t icb_str = jcp.kd * jcp.kh * jcp.kw * ksp_str;
        dim_t ocb_str = jcp.nb_ic * icb_str;
        dim_t ic_off = (ic / scale) * jcp.oc_block * scale + (ic % scale);
        return jcp.typesize_in
                * (ocb * ocb_str + icb * icb_str + ksp_idx * ksp_str + ic_off);
    }

    // First output column whose receptive field is not cut by left padding.
    int get_ow_start(int ki, int pad_l) {
        return nstl::max(0,
                utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
    }

    // One past the last output column not cut by right padding.
    int get_ow_end(int ur_w, int ki, int pad_r) {
        return ur_w
                - nstl::max(0,
                        utils::div_up(
                                pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
                                jcp.stride_w));
    }
    inline bool is_src_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_dst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
};
211 | |
212 | struct jit_avx512_core_bf16_fwd_kernel { |
213 | jit_avx512_core_bf16_fwd_kernel(const jit_conv_conf_t &ajcp, |
214 | const primitive_attr_t &attr, const memory_desc_t &dst_md) |
215 | : kernel_(nullptr) { |
216 | switch (ajcp.oc_block) { |
217 | case 16: |
218 | kernel_ = new _jit_avx512_core_bf16_fwd_kernel<Xbyak::Zmm>( |
219 | ajcp, attr, dst_md); |
220 | return; |
221 | case 8: |
222 | kernel_ = new _jit_avx512_core_bf16_fwd_kernel<Xbyak::Ymm>( |
223 | ajcp, attr, dst_md); |
224 | return; |
225 | case 4: |
226 | kernel_ = new _jit_avx512_core_bf16_fwd_kernel<Xbyak::Xmm>( |
227 | ajcp, attr, dst_md); |
228 | return; |
229 | default: assert(!"invalid channel blocking" ); |
230 | } |
231 | } |
232 | |
233 | status_t create_kernel() { return kernel_->create_kernel(); } |
234 | |
235 | ~jit_avx512_core_bf16_fwd_kernel() { delete kernel_; } |
236 | |
237 | static status_t init_conf(jit_conv_conf_t &jcp, |
238 | const convolution_desc_t &cd, memory_desc_t &src_pd, |
239 | memory_desc_t &weights_pd, memory_desc_t &dst_pd, |
240 | memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads); |
241 | static void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
242 | const jit_conv_conf_t &jcp); |
243 | |
244 | void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); } |
245 | const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); } |
246 | |
247 | private: |
248 | DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_core_bf16_fwd_kernel); |
249 | jit_generator *kernel_; |
250 | }; |
251 | |
/* JIT code generator for the avx512_core bf16 backward-by-data
 * convolution kernel. Templated on the vector register type
 * (Xbyak::Zmm/Ymm/Xmm) to cover the ic_block sizes selected by the
 * dispatcher below. */
template <typename Vmm>
struct _jit_avx512_core_bf16_bwd_data_kernel : public jit_generator {

    _jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp)
        : jit_generator(
                jit_name(), nullptr, ker_code_size, true, avx512_core_bf16)
        , jcp(ajcp)
        , bf16_emu_(nullptr) {
        // Fall back to software bf16 emulation when the ISA lacks native
        // bf16 instructions.
        if (!isa_has_bf16(jcp.isa))
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                    bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                    bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_5);
    }

    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_bf16_bwd_data_kernel_f32)

    const jit_conv_conf_t &jcp;

private:
    // One register width below Vmm (Zmm -> Ymm, otherwise Xmm).
    using Vmm_down_t =
            typename utils::conditional<std::is_same<Vmm, Xbyak::Zmm>::value,
                    Xbyak::Ymm, Xbyak::Xmm>::type;
    using reg64_t = const Xbyak::Reg64;
    enum {
        ker_reg_base_idx = 31,
        ker_code_size = 1024 * 1024, // upper bound on generated code size
    };

    reg64_t param = abi_param1;
    reg64_t reg_dst = r8;
    reg64_t reg_ker = r9;
    reg64_t reg_src = r10;

    reg64_t reg_iwb = rdx;

    reg64_t aux_reg_dst = r14;
    reg64_t aux_reg_ker = r15;

    reg64_t aux_reg_dst_d = r12;
    reg64_t aux_reg_ker_d = r13;
    reg64_t reg_ki = rsi;

    reg64_t reg_kj = rax;
    reg64_t reg_oi = rbx;
    reg64_t reg_kh = abi_not_param1;

    // NOTE: reg_ic aliases aux_reg_ker_d (r13) — used in disjoint phases.
    reg64_t reg_oc = r11;
    reg64_t reg_ic = aux_reg_ker_d;

    Xbyak::Opmask k_ic_tail_mask = Xbyak::Opmask(2);
    Xbyak::Opmask k_ic_tail_mask_extended = Xbyak::Opmask(3);

    // Broadcast register for diff_dst value i_ic; placed above the
    // accumulator range (nb_ic_blocking * jcp.ur_w registers).
    Vmm vmm_ddst(int i_ic) {
        int idx = i_ic + jcp.nb_ic_blocking * jcp.ur_w;
        assert(idx < ker_reg_base_idx);
        return Vmm(idx);
    }

    // Same slot as vmm_ddst viewed through the narrower register type.
    Vmm_down_t vmm_ddst_down(int i_ic) {
        int idx = i_ic + jcp.nb_ic_blocking * jcp.ur_w;
        assert(idx < ker_reg_base_idx);
        return Vmm_down_t(idx);
    }

    // Accumulator register for diff_src pixel i_ur of ic block i_oc.
    Vmm vmm_dsrc(int i_ur, int i_oc) {
        int idx = i_ur + i_oc * jcp.ur_w;
        assert(idx < ker_reg_base_idx);
        return Vmm(idx);
    }

    // Attach the ic tail opmask to vmm when mask_flag is set; with
    // zero_mask the masked-off lanes are zeroed (T_z) instead of merged.
    inline Vmm may_be_mask_vmm(Vmm vmm, bool mask_flag, bool zero_mask,
            bool use_extended_mask = false) {
        if (mask_flag) {
            vmm = vmm
                    | (use_extended_mask ? k_ic_tail_mask_extended
                                         : k_ic_tail_mask);
            if (zero_mask) vmm = vmm | T_z;
        }
        return vmm;
    }

    inline Vmm_down_t may_be_mask_vmm(Vmm_down_t vmm, bool mask_flag) {
        return mask_flag ? vmm | k_ic_tail_mask : vmm;
    }

    // Registers reserved for bf16 emulation on ISAs without native bf16.
    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(26);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(27);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(28);
    reg64_t bf16_emu_scratch = reg_kj;
    Xbyak::Zmm bf16_emu_reserv_4 = Xbyak::Zmm(29);
    Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);

    Vmm vmm_wei = Vmm(31);
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    inline void prepare_output(int ur_w);
    inline void store_output(int ur_w);
    inline void compute_loop(int ur_w, int l_overflow, int r_overflow);
    void generate() override;

    // First input column of the current block reachable by filter tap ki
    // given the left overflow; result is normalized into [0, stride_w).
    int get_iw_start(int ki, int l_overflow) {
        int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
                + l_overflow * jcp.stride_w
                - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
        while (res < 0)
            res += jcp.stride_w;

        return res;
    }

    // One past the last input column of the current block reachable by
    // filter tap ki given the right overflow.
    int get_iw_end(int ur_w, int ki, int r_overflow) {
        if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
            ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
        int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
                + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
        while (res < 0)
            res += jcp.stride_w;

        return ur_w - res;
    }

    // Map filter coordinates to diff_dst-space element indices.
    inline int filter_h_to_dst(int kh) {
        return kh * (jcp.dilate_h + 1) * jcp.ow;
    }
    inline int filter_d_to_dst(int kd) {
        return kd * (jcp.dilate_d + 1) * jcp.ow * jcp.oh;
    }

    // Byte offset of the diff_src element (width iw_idx, ic block
    // n_ic_block) for nxc vs blocked layouts.
    inline size_t get_diff_src_offset(int iw_idx, int n_ic_block) {
        const bool is_nxc_layout = is_dsrc_layout_nxc();
        size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic : jcp.ic_block;
        size_t icb_str = jcp.ic_block
                * (is_nxc_layout ? 1 : (size_t)jcp.id * jcp.ih * jcp.iw);
        return jcp.typesize_out * (iw_str * iw_idx + icb_str * n_ic_block);
    }

    // Byte offset of the diff_dst element at flattened spatial index
    // osp_idx, channel oc_within_block_idx within oc block oc_block_idx.
    inline size_t get_diff_dst_offset(
            int osp_idx, int oc_within_block_idx, int oc_block_idx = 0) {
        const bool is_nxc_layout = is_ddst_layout_nxc();
        size_t osp_str = is_nxc_layout ? jcp.ngroups * jcp.oc : jcp.oc_block;
        size_t ocb_str = jcp.oc_block
                * (is_nxc_layout ? 1 : (size_t)jcp.od * jcp.oh * jcp.ow);
        return jcp.typesize_in
                * (osp_str * osp_idx + ocb_str * oc_block_idx
                        + oc_within_block_idx);
    }

    // Byte offset into the VNNI-interleaved weights: pairs of output
    // channels are stored together (scale == 2).
    inline size_t get_kernel_offset(
            int icb, int oc_idx, int kw, int kh = 0, int kd = 0) {
        int scale = 2; //bf16 vnni is used
        int ocb = oc_idx / jcp.oc_block;
        int oc = oc_idx % jcp.oc_block;
        size_t ksp_str = jcp.ic_block * jcp.oc_block;
        size_t ksp_idx = kd * jcp.kh * jcp.kw + kh * jcp.kw + kw;

        size_t icb_str = jcp.kd * jcp.kh * jcp.kw * ksp_str;
        size_t ocb_str = jcp.nb_ic * icb_str;
        size_t oc_off = (oc / scale) * jcp.ic_block * scale + (oc % scale);
        return jcp.typesize_in
                * (ocb * ocb_str + icb * icb_str + ksp_idx * ksp_str + oc_off);
    }

    inline bool is_dsrc_layout_nxc() {
        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
    inline bool is_ddst_layout_nxc() {
        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                format_tag::nwc);
    }
};
423 | |
424 | struct jit_avx512_core_bf16_bwd_data_kernel { |
425 | |
426 | jit_avx512_core_bf16_bwd_data_kernel(const jit_conv_conf_t &ajcp) |
427 | : kernel_(nullptr) { |
428 | switch (ajcp.ic_block) { |
429 | case 16: |
430 | kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Zmm>( |
431 | ajcp); |
432 | return; |
433 | case 8: |
434 | kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Ymm>( |
435 | ajcp); |
436 | return; |
437 | case 4: |
438 | kernel_ = new _jit_avx512_core_bf16_bwd_data_kernel<Xbyak::Xmm>( |
439 | ajcp); |
440 | return; |
441 | default: assert(!"invalid channel blocking" ); |
442 | } |
443 | } |
444 | |
445 | status_t create_kernel() { return kernel_->create_kernel(); } |
446 | |
447 | ~jit_avx512_core_bf16_bwd_data_kernel() { delete kernel_; } |
448 | |
449 | static status_t init_conf(jit_conv_conf_t &jcp, |
450 | const convolution_desc_t &cd, memory_desc_t &diff_src_md, |
451 | memory_desc_t &weights_md, memory_desc_t &diff_dst_md, |
452 | int nthreads); |
453 | void operator()(const jit_conv_call_s *p) const { (*kernel_)(p); } |
454 | const Xbyak::uint8 *jit_ker() const { return kernel_->jit_ker(); } |
455 | |
456 | private: |
457 | DNNL_DISALLOW_COPY_AND_ASSIGN(jit_avx512_core_bf16_bwd_data_kernel); |
458 | jit_generator *kernel_; |
459 | }; |
460 | |
/* JIT code generator for the avx512_core bf16 backward-by-weights
 * convolution kernel. Not templated: several compute_ic_block_step
 * variants (extern / interleave / vpermw) are selected at generation
 * time; note they deliberately reuse the same opmask registers for
 * mutually exclusive code paths (see comments below). */
struct jit_avx512_core_bf16_conv_bwd_weights_kernel_f32 : public jit_generator {

    jit_avx512_core_bf16_conv_bwd_weights_kernel_f32(
            const jit_conv_conf_t &ajcp)
        : jit_generator(
                jit_name(), nullptr, ker_code_size, true, avx512_core_bf16)
        , jcp(ajcp)
        , bf16_emu_(nullptr) {
        // Fall back to software bf16 emulation when the ISA lacks native
        // bf16 instructions.
        if (!isa_has_bf16(jcp.isa)) {
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(
                    this, one, even, selector, scratch, tmp0, tmp1);
        }
    }

    ~jit_avx512_core_bf16_conv_bwd_weights_kernel_f32() = default;

    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            jit_avx512_core_bf16_conv_bwd_weights_kernel_f32)

    static status_t init_conf(jit_conv_conf_t &jcp,
            const convolution_desc_t &cd, memory_desc_t &src_md,
            memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md,
            memory_desc_t &diff_dst_md, int nthreads);
    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
            const jit_conv_conf_t &jcp);

    const jit_conv_conf_t &jcp;

private:
    Xbyak::Label dst_prm_table;
    // Used by compute_ic_block_step_{vpermw, interleave}
    Xbyak::Opmask m_ffffffff = Xbyak::Opmask(1);
    // Used by compute_ic_block_step_vpermw
    Xbyak::Opmask m_0000ffff = Xbyak::Opmask(2);
    Xbyak::Opmask m_ffff0000 = Xbyak::Opmask(3);
    Xbyak::Opmask m_0000_oc_tail = Xbyak::Opmask(4);
    Xbyak::Opmask m_oc_tail_0000 = Xbyak::Opmask(5);
    Xbyak::Opmask m_0000_ic_tail = Xbyak::Opmask(6);
    Xbyak::Opmask m_ic_tail_0000 = Xbyak::Opmask(7);
    // Used by compute_ic_block_step_extern (1st_conv only)
    // NOTE: Opmask(4..7) are intentionally reused across the three step
    // variants; the variants never run within the same generated kernel.
    Xbyak::Opmask everyother_mask = Xbyak::Opmask(6);
    Xbyak::Opmask everyother_shift_mask = Xbyak::Opmask(7);
    // Used by compute_ic_block_step_interleave (1st_conv only)
    Xbyak::Opmask underflow_mask = Xbyak::Opmask(4);
    Xbyak::Opmask overflow_mask = Xbyak::Opmask(5);
    Xbyak::Opmask underflow_stride_mask = Xbyak::Opmask(6);
    Xbyak::Opmask overflow_stride_mask = Xbyak::Opmask(7);

    using reg64_t = const Xbyak::Reg64;
    enum {
        sizeof_cacheline = 64,
        // Working-set thresholds steering the full-spatial optimization.
        full_spat_opt_working_set_size = 48 * 1024,
        full_spat_max_working_set_size = 128 * 1024,
        ker_code_size = 1024 * 1024, // upper bound on generated code size
    };
    static const int max_ur_w;

    reg64_t param = abi_param1;
    reg64_t reg_src = rax;
    reg64_t reg_kernel = rdx;
    reg64_t reg_ddst = rsi;
    reg64_t b_ic = abi_not_param1;
    reg64_t kj = r8;
    reg64_t reg_kh = r9;
    reg64_t reg_ur_w_trips = r10;
    reg64_t reg_oj = r15;
    reg64_t reg_tmp = r14;
    reg64_t reg_ih_shift = reg_tmp;
    reg64_t reg_long_offt = r14;
    reg64_t reg_icb = rbx;

    // NOTE: several registers below alias each other (r11/r12/r13/r15/rbx);
    // the aliased names are used in mutually exclusive phases of the
    // generated code.
    reg64_t ki = r11;
    reg64_t reg_oj_setup = r11;
    reg64_t reg_kd_count = r12;
    reg64_t reg_oi = r12;
    reg64_t reg_d_index = r13;
    reg64_t reg_src_d = r15;
    reg64_t reg_ddst_d = rbx;
    reg64_t aux_reg_src = r12;
    reg64_t aux_reg_kernel = r13;

    // Registers for diff-bias accumulation.
    Xbyak::Zmm vreg_bias_acc = Xbyak::Zmm(0);
    Xbyak::Zmm vreg_bias_unit = Xbyak::Zmm(1);
    Xbyak::Zmm vreg_bias_ddst = Xbyak::Zmm(2);

    // Registers reserved for bf16 emulation (see constructor).
    Xbyak::Zmm one = Xbyak::Zmm(27);
    Xbyak::Zmm even = Xbyak::Zmm(28);
    Xbyak::Zmm selector = Xbyak::Zmm(29);
    Xbyak::Zmm tmp0 = Xbyak::Zmm(30);
    Xbyak::Zmm tmp1 = Xbyak::Zmm(31);
    reg64_t scratch = r11;

    inline void maybe_zero_kernel();
    inline void get_ur_w(int &ur_w, int &ur_w_tail, int &ur_w_trips);
    inline void compute_oh_step_unroll_ow_icblock(int ic_block_step);
    inline void od_step_comeback_pointers();
    inline void oh_step_comeback_pointers();
    inline void compute_oh_step_unroll_ow(int ic_block_step);
    inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r,
            int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline void compute_ic_block_step_extern(int ur_w, int pad_l, int pad_r,
            int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline void compute_ic_block_step_interleave(int ur_w, int pad_l, int pad_r,
            int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline void compute_ic_block_step_vpermw(int ur_w, int pad_l, int pad_r,
            int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline void compute_oh_step_common(int ic_block_step);
    inline void compute_oh_step_disp();
    inline void compute_loop();
    inline void compute_oh_loop_common(bool partial = false);
    inline void compute_od_loop_common(bool partial = false);
    void compute_full_spat_loop();
    void compute_diff_bias_init();
    void compute_diff_bias_row(bool is_partial = true);
    void maybe_compute_diff_bias();
    void convert_src_to_vnni_format(
            int ur_w, int pad_l, int pad_r, int src_offset);
    void may_be_set_oc_tail_mask();
    void may_be_reset_oc_tail_mask();
    inline void compute_ic_block_step_vpermw_expl(int ur_w, int pad_l,
            int pad_r, int ic_block_step, int src_offset, int kernel_offset,
            int ddst_offset, bool is_tail = false);
    inline bool is_src_layout_nxc() {
        return jcp.uses_permw_transposition
                && utils::one_of(jcp.src_tag, format_tag::ndhwc,
                        format_tag::nhwc, format_tag::nwc);
    }
    inline bool is_ddst_layout_nxc() {
        return jcp.uses_permw_transposition
                && utils::one_of(jcp.dst_tag, format_tag::ndhwc,
                        format_tag::nhwc, format_tag::nwc);
    }

    void generate() override;

    // Splits nthr across mb/groups/oc-blocks/ic-blocks dimensions.
    static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
            int &nthr_g, int &nthr_oc_b, int &nthr_ic_b);

    // Computes the (up to two) source column indices touched by output
    // position i_ur with filter tap i_kw; -1 marks an out-of-bounds column.
    void get_w_positions(int ur_w, int pad_l, int pad_r, int i_ur, int i_kw,
            int &iw_0, int &iw_1) {
        auto get_w_position = [=](int idx) {
            int iw = i_ur + idx;
            if (iw >= ur_w) return -1;
            iw += i_kw;
            if (iw - pad_l < 0 || iw > (ur_w - 1) + (jcp.kw - 1) - pad_r)
                return -1;
            return iw - pad_l;
        };
        iw_0 = get_w_position(0);
        iw_1 = get_w_position(1);
    }
    // True when at least one of the two source columns is in bounds.
    bool check_borders(int ur_w, int pad_l, int pad_r, int i_ur, int i_kw) {
        int iw_1, iw_2;
        get_w_positions(ur_w, pad_l, pad_r, i_ur, i_kw, iw_1, iw_2);

        return (iw_1 == -1 && iw_2 == -1) ? false : true;
    }
    // Picks the opmask matching which of the two source columns are in
    // bounds; returns false when both are out of bounds (nothing to load).
    bool get_load_mask(int ur_w, int pad_l, int pad_r, int i_ur, int i_kw,
            Xbyak::Opmask &load_mask) {
        int iw_1, iw_2;
        get_w_positions(ur_w, pad_l, pad_r, i_ur, i_kw, iw_1, iw_2);

        bool rt = true;
        if (iw_1 != -1 && iw_2 != -1)
            load_mask = m_ffffffff;
        else if (iw_1 != -1 && iw_2 == -1)
            load_mask = m_0000ffff;
        else if (iw_1 == -1 && iw_2 != -1)
            load_mask = m_ffff0000;
        else
            rt = false;

        return rt;
    }

    // Map filter coordinates to source-space element indices; with
    // transpose_src the width stride is folded into the transposed layout.
    inline dim_t filter_w_to_src(int kw, int ow = 0, int pad_l = 0) {
        int stride_w = jcp.transpose_src ? 1 : jcp.stride_w;
        return kw * (jcp.dilate_w + 1) + ow * stride_w - pad_l;
    }
    inline dim_t filter_h_to_src(int kh) { return kh * (jcp.dilate_h + 1); }
    inline dim_t filter_d_to_src(int kd) {
        return kd * (jcp.dilate_d + 1) * jcp.ih;
    }

    inline dim_t get_src_offset(dim_t ic_idx, dim_t w_idx, dim_t hd_idx = 0) {
        // For is_src_layout_nxc() the ic_idx index inside the block
        // is supported only ic_idx == jcp.ic_block is considered as a shift
        // within one block and not as moving to the next ic block.
        assert(IMPLICATION(!is_src_layout_nxc(), ic_idx <= jcp.ic_block));
        dim_t icb = is_src_layout_nxc() ? ic_idx / jcp.ic_block : 0;
        dim_t ic = is_src_layout_nxc() ? ic_idx % jcp.ic_block : ic_idx;
        dim_t iw_str = jcp.is_1stconv || jcp.transpose_src
                ? 1
                : (is_src_layout_nxc() ? jcp.ngroups * jcp.ic : jcp.ic_block);
        dim_t ihid_str
                = jcp.tr_iw * (jcp.transpose_src ? jcp.ic_block : iw_str);
        // jcp.transpose_src w_idx might be greater than jcp.tr_iw as right zero
        // padding memory is shared with left zero padding of the next block
        dim_t isp_off = hd_idx * ihid_str + w_idx * iw_str;
        dim_t full_spatial_size = (dim_t)jcp.tr_iw * jcp.ih * jcp.id;
        dim_t ic_str = jcp.transpose_src
                ? jcp.tr_iw
                : (jcp.is_1stconv ? full_spatial_size : 1);
        dim_t icb_str
                = jcp.ic_block * (is_src_layout_nxc() ? 1 : full_spatial_size);
        return jcp.typesize_in * (isp_off + icb_str * icb + ic_str * ic);
    }

    // Byte offset of the diff_dst element at width w_idx and flattened
    // height/depth index hd_idx; transpose_dst interleaves ow in pairs.
    inline dim_t get_ddst_offset(dim_t w_idx, dim_t hd_idx = 0) {
        int ow_per_oc = jcp.transpose_dst ? 2 : 1;
        int ch_mult
                = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block;
        dim_t hd_off = jcp.tr_ow * ch_mult * hd_idx;
        dim_t w_off
                = w_idx / ow_per_oc * ow_per_oc * ch_mult + w_idx % ow_per_oc;
        return jcp.typesize_in * (w_off + hd_off);
    }

    inline dim_t get_kernel_offset(int ic_idx, dim_t ksp_idx) {
        // Only the ic_idx index inside the block is supported,
        // ic_idx == jcp.ic_block is considered as a shift inside one block
        // and not as moving to the next ic block.
        // Negative values are supported for negative shift.
        assert(nstl::abs(ic_idx) <= jcp.ic_block);
        return jcp.typesize_out * jcp.oc_block
                * (ksp_idx * jcp.ic_block + ic_idx);
    }

    // Zmm used as the permutation table for vpermw; the index depends on
    // which other registers are free in the selected code path.
    Xbyak::Zmm get_perm_reg() {
        int idx = !(jcp.uses_permw_transposition
                          && jcp.kernel_kind == expl_bcast)
                ? 24
                : ((!isa_has_bf16(jcp.isa)) ? 26 : 31);
        return Xbyak::Zmm(idx);
    }
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

    inline int interleave_w_reorder_size(int ur_w) const;
    inline int interleave_w_reorder_bytes(int ur_w);
    inline int interleave_stack_size(int ur_w, int ic_block_step);
    // Stack bytes needed by the vpermw path: one cacheline per column.
    inline int permw_stack_size(int ur_w) {
        return (ur_w + jcp.kw - 1) * sizeof_cacheline;
    }

    inline void setup_stack_space();
    // Stack-frame byte offsets computed by setup_stack_space().
    static const int extern_ic_block_step_stack_size = 0;
    int ic_block_step_stack_size = 0;
    int stack_space_needed = 0;
    int permw_buffer_start = 0;
    int kd_count_offset = 0;
    int src_d_offset = 0;
    int ddst_d_offset = 0;
    int d_index_offset = 0;
    int trans_tmp_offset = 0;
    int ih_dilate_shift = 0;
    int icb_loop_ker_ptr = 0;
    int icb_loop_src_ptr = 0;
};
723 | } // namespace x64 |
724 | } // namespace cpu |
725 | } // namespace impl |
726 | } // namespace dnnl |
727 | #endif |
728 | |