1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | |
24 | #include "cpu/platform.hpp" |
25 | |
26 | #include "cpu/x64/jit_avx512_core_f32_wino_conv_2x3.hpp" |
27 | #include "cpu/x64/jit_generator.hpp" |
28 | #include "cpu/x64/jit_primitive_conf.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | |
35 | using namespace dnnl::impl::format_kind; |
36 | using namespace dnnl::impl::memory_tracking::names; |
37 | using namespace dnnl::impl::utils; |
38 | using namespace Xbyak; |
39 | |
40 | /// SRC TRANSFORMS ///////////////////////////////////////////////////////////// |
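// The source transform computes the Winograd F(2x2, 3x3) input transform
// V = B^T * d * B for each 4x4 input tile d, where
//
//         [ 1  0 -1  0 ]
//   B^T = [ 0  1  1  0 ]
//         [ 0 -1  1  0 ]
//         [ 0  1  0 -1 ]
//
// generate() applies the 1D transform twice: first along x within each row
// (vreg_tmp), then along y within each column (vreg_out), with one input
// channel of the 16-channel block per SIMD lane.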
41 | struct jit_avx512_core_f32_wino_conv_2x3_src_trans_t : public jit_generator { |
42 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_wino_conv_2x3_src_trans_t) |
43 | |
44 | jit_conv_conf_2x3_wino_t jcp; |
45 | |
46 | struct call_params_t { |
47 | const void *src; |
48 | const void *wino_src; |
49 | const void *v_y_masks; |
50 | const void *v_x_masks; |
51 | }; |
52 | |
53 | jit_avx512_core_f32_wino_conv_2x3_src_trans_t( |
54 | const jit_conv_conf_2x3_wino_t &ajcp, const primitive_attr_t &attr) |
55 | : jit_generator(jit_name()), jcp(ajcp) {} |
56 | |
57 | void generate() override; |
58 | |
59 | Zmm vreg_inp(int i) const { |
60 | assert(i < jcp.alpha * jcp.alpha); |
61 | return Zmm(31 - i); |
62 | } |
63 | |
64 | Zmm vreg_tmp(int i) const { |
65 | assert(i < jcp.alpha * jcp.alpha); |
66 | return Zmm(15 - i); |
67 | } |
68 | |
69 | Zmm vreg_out(int i) const { |
70 | assert(i < jcp.alpha * jcp.alpha); |
71 | return Zmm(31 - i); |
72 | } |
73 | |
74 | Opmask y_mask = Opmask(1); |
75 | Opmask r_mask = Opmask(2); |
76 | Opmask x_mask(int id) { |
77 | assert(id < 4); |
78 | return Opmask(3 + id); |
79 | } |
80 | |
81 | Reg64 reg_ptr_v_y_masks = r12; |
82 | Reg64 reg_ptr_v_x_masks = r11; |
83 | |
84 | Reg64 reg_aux_ptr_src = r10; |
85 | Reg64 reg_aux_ptr_dst = r9; |
86 | |
87 | Reg64 reg_ic_block = r8; |
88 | }; |
89 | |
90 | void jit_avx512_core_f32_wino_conv_2x3_src_trans_t::generate() { |
91 | Label ic_block_label; |
92 | |
93 | const int load_block = 16; |
94 | int out_offset = 0, inp_offset = 0; |
95 | preamble(); |
96 | |
97 | #define READ_PARAM(reg, field) \ |
98 | mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) |
99 | READ_PARAM(reg_aux_ptr_src, src); |
100 | READ_PARAM(reg_aux_ptr_dst, wino_src); |
101 | READ_PARAM(reg_ptr_v_y_masks, v_y_masks); |
102 | READ_PARAM(reg_ptr_v_x_masks, v_x_masks); |
103 | #undef READ_PARAM |
104 | |
105 | for (int i = 0; i < jcp.alpha; i++) { |
106 | kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); |
107 | } |
108 | mov(reg_ic_block, jcp.ic / load_block); |
109 | L(ic_block_label); |
110 | { |
111 | for (int y = 0; y < jcp.alpha; y++) { |
112 | kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]); |
113 | for (int x = 0; x < jcp.alpha; x++) { |
114 | Zmm zmm = vreg_inp(y * jcp.alpha + x); |
115 | |
116 | vxorps(zmm, zmm, zmm); |
117 | kandw(r_mask, y_mask, x_mask(x)); |
118 | inp_offset = sizeof(float) |
119 | * ((-jcp.t_pad + y) * jcp.iw * load_block |
120 | + (-jcp.l_pad + x) * load_block); |
121 | vmovups(zmm | r_mask, |
122 | EVEX_compress_addr(reg_aux_ptr_src, inp_offset)); |
123 | } |
124 | } |
125 | for (int y = 0; y < jcp.alpha; y++) { |
126 | vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0), |
127 | vreg_inp(y * jcp.alpha + 2)); |
128 | vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1), |
129 | vreg_inp(y * jcp.alpha + 2)); |
130 | vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2), |
131 | vreg_inp(y * jcp.alpha + 1)); |
132 | vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1), |
133 | vreg_inp(y * jcp.alpha + 3)); |
134 | } |
135 | for (int x = 0; x < jcp.alpha; x++) { |
136 | vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0), |
137 | vreg_tmp(x + jcp.alpha * 2)); |
138 | vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), |
139 | vreg_tmp(x + jcp.alpha * 2)); |
140 | vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2), |
141 | vreg_tmp(x + jcp.alpha * 1)); |
142 | vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1), |
143 | vreg_tmp(x + jcp.alpha * 3)); |
144 | } |
145 | |
146 | for (int i = 0; i < 16; i++) { |
147 | out_offset = sizeof(float) * (jcp.inp_stride * i); |
148 | vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), |
149 | vreg_out(i)); |
150 | } |
151 | |
152 | add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block); |
153 | add(reg_aux_ptr_dst, sizeof(float) * load_block); |
154 | } |
155 | dec(reg_ic_block); |
156 | cmp(reg_ic_block, 0); |
157 | jg(ic_block_label, T_NEAR); |
158 | postamble(); |
159 | } |
160 | |
161 | /// DST TRANSFORMS ///////////////////////////////////////////////////////////// |
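// The destination transform computes the Winograd F(2x2, 3x3) output
// transform Y = A^T * m * A for each 4x4 GEMM result tile m, where
//
//   A^T = [ 1  1  1  0 ]
//         [ 0  1 -1 -1 ]
//
// reducing the tile to 2x2 outputs (one output channel per SIMD lane), then
// applies bias and the supported post-ops before storing through the masks.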
162 | struct jit_avx512_core_f32_wino_conv_2x3_dst_trans_t : public jit_generator { |
163 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_wino_conv_2x3_dst_trans_t) |
164 | |
165 | jit_conv_conf_2x3_wino_t jcp; |
166 | const primitive_attr_t &attr_; |
167 | |
168 | struct call_params_t { |
169 | const void *wino_dst; |
170 | const void *dst; |
171 | const void *v_y_masks; |
172 | const void *v_x_masks; |
173 | |
174 | const void *bias; |
175 | }; |
176 | |
177 | jit_avx512_core_f32_wino_conv_2x3_dst_trans_t( |
178 | const jit_conv_conf_2x3_wino_t &ajcp, const primitive_attr_t &attr) |
179 | : jit_generator(jit_name()), jcp(ajcp), attr_(attr) {} |
180 | |
181 | void generate() override; |
182 | bool maybe_relu(int position); |
183 | |
    Zmm vreg_inp(int i) const { // 16 regs: zmm31 .. zmm16
185 | assert(i < jcp.alpha * jcp.alpha); |
186 | return Zmm(31 - i); |
187 | } |
188 | |
    Zmm vreg_stg(int id) const { // 8 regs: zmm15 .. zmm8
190 | const int id_reg_stg = jcp.alpha * jcp.alpha + id; |
191 | assert(id_reg_stg < jcp.alpha * jcp.alpha + 8); |
192 | return Zmm(31 - id_reg_stg); |
193 | } |
194 | |
    Zmm vreg_out(int id) const { // 4 regs: zmm7 .. zmm4
196 | const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id; |
197 | assert(id_reg_out < jcp.alpha * jcp.alpha + 12); |
198 | return Zmm(31 - id_reg_out); |
199 | } |
200 | |
    Zmm vreg_tmp(int id) const { // 2 regs: zmm3 .. zmm2
202 | const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id; |
203 | assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14); |
204 | return Zmm(31 - id_reg_tmp); |
205 | } |
206 | |
207 | Zmm vreg_zero = Zmm(0); |
208 | Zmm vreg_prev_dst = Zmm(0); |
209 | Zmm vreg_bias = Zmm(2); |
210 | |
211 | Opmask y_mask = Opmask(1); |
212 | Opmask r_mask = Opmask(2); |
213 | Opmask x_mask(int id) { |
214 | assert(id < 4); |
215 | return Opmask(3 + id); |
216 | } |
217 | |
218 | Reg64 reg_ptr_v_y_masks = r12; |
219 | Reg64 reg_ptr_v_x_masks = r11; |
220 | |
221 | Reg64 reg_aux_ptr_src = r10; |
222 | Reg64 reg_aux_ptr_dst = r9; |
223 | |
224 | Reg64 reg_oc_block = r8; |
225 | |
226 | Reg64 reg_ptr_bias = rbx; |
227 | Reg64 reg_ptr_sum_scale = rdx; |
228 | }; |
229 | |
230 | bool jit_avx512_core_f32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) { |
231 | using namespace primitive_kind; |
232 | const auto &p = attr_.post_ops_; |
233 | |
234 | if (position == 0) { |
235 | /* relu before sum */ |
        return p.contain(eltwise, 0);
237 | } else if (position == 1) { |
238 | /* relu after sum */ |
239 | const int sum_idx |
240 | = p.contain(sum, 0) ? 0 : (p.contain(sum, 1) ? 1 : -1); |
241 | if (sum_idx == -1) return false; |
242 | |
        return p.contain(eltwise, sum_idx + 1);
244 | } |
245 | |
246 | return false; |
247 | } |
248 | |
249 | void jit_avx512_core_f32_wino_conv_2x3_dst_trans_t::generate() { |
250 | Label oc_block_label; |
251 | |
252 | const int load_block = 16; |
253 | |
254 | auto loop_body = [=]() { |
255 | const auto &p = attr_.post_ops_; |
256 | const int sum_idx = p.find(primitive_kind::sum); |
257 | const float *p_sum_scale |
258 | = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr; |
259 | if (p_sum_scale && *p_sum_scale != 1.f) |
260 | mov(reg_ptr_sum_scale, (size_t)p_sum_scale); |
261 | |
262 | for (int i = 0; i < 16; i++) { |
263 | int internal_offset = sizeof(float) * jcp.out_stride * i; |
264 | vmovups(vreg_inp(i), |
265 | EVEX_compress_addr(reg_aux_ptr_src, internal_offset)); |
266 | } |
267 | for (int y = 0; y < jcp.alpha; y++) { |
268 | vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1)); |
269 | vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2)); |
270 | |
271 | vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2)); |
272 | vsubps(vreg_stg(y * 2 + 1), vreg_tmp(1), vreg_inp(y * 4 + 3)); |
273 | } |
274 | for (int x = 0; x < jcp.m; x++) { |
275 | vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x + 2 * 1)); |
276 | vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x + 2 * 2)); |
277 | |
278 | vsubps(vreg_tmp(1), vreg_stg(x + 2 * 1), vreg_stg(x + 2 * 2)); |
279 | vsubps(vreg_out(x + 2), vreg_tmp(1), vreg_stg(x + 2 * 3)); |
280 | } |
281 | |
282 | if (jcp.with_bias) { |
283 | auto bias_addr = ptr[reg_ptr_bias]; |
284 | vmovups(vreg_bias, bias_addr); |
285 | } |
286 | for (int y = 0; y < jcp.m; y++) { |
287 | kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]); |
288 | for (int x = 0; x < jcp.m; x++) { |
289 | kandw(r_mask, y_mask, x_mask(x)); |
290 | |
291 | int i = y * jcp.m + x; |
292 | int offset = sizeof(float) |
293 | * (y * jcp.ow * jcp.oc_block + x * jcp.oc_block); |
294 | Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset); |
295 | |
296 | Zmm zmm = vreg_out(i); |
297 | if (jcp.with_bias) vaddps(zmm, zmm, vreg_bias); |
298 | |
299 | if (maybe_relu(0)) { |
300 | vxorps(vreg_zero, vreg_zero, vreg_zero); |
301 | vmaxps(zmm, vreg_zero, zmm); |
302 | } |
303 | if (p_sum_scale) { // post_op: sum |
304 | vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst); |
305 | vmovups(vreg_prev_dst | r_mask, addr); |
306 | if (*p_sum_scale == 1.f) |
307 | vaddps(zmm, vreg_prev_dst); |
308 | else |
309 | vfmadd231ps( |
310 | zmm, vreg_prev_dst, zword_b[reg_ptr_sum_scale]); |
311 | } |
312 | if (maybe_relu(1)) { |
313 | vxorps(vreg_zero, vreg_zero, vreg_zero); |
314 | vmaxps(zmm, vreg_zero, zmm); |
315 | } |
316 | |
317 | vmovups(addr, zmm | r_mask); |
318 | } |
319 | } |
320 | }; |
321 | |
322 | preamble(); |
323 | |
324 | #define READ_PARAM(reg, field) \ |
325 | mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) |
326 | READ_PARAM(reg_aux_ptr_src, wino_dst); |
327 | READ_PARAM(reg_aux_ptr_dst, dst); |
328 | READ_PARAM(reg_ptr_v_y_masks, v_y_masks); |
329 | READ_PARAM(reg_ptr_v_x_masks, v_x_masks); |
330 | READ_PARAM(reg_ptr_bias, bias); |
331 | #undef READ_PARAM |
332 | |
333 | for (int i = 0; i < jcp.alpha * jcp.alpha; i++) |
334 | vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i)); |
335 | |
336 | for (int i = 0; i < jcp.alpha; i++) |
337 | kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]); |
338 | |
    const int oc_blocks = jcp.oc / load_block;
341 | mov(reg_oc_block, oc_blocks); |
342 | L(oc_block_label); |
343 | { |
344 | loop_body(); |
345 | add(reg_aux_ptr_src, sizeof(float) * load_block); |
346 | add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block); |
347 | |
348 | add(reg_ptr_bias, jcp.typesize_bia * load_block); |
349 | } |
350 | dec(reg_oc_block); |
351 | cmp(reg_oc_block, 0); |
352 | jg(oc_block_label, T_NEAR); |
353 | |
354 | sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block); |
355 | |
356 | postamble(); |
357 | } |
358 | |
359 | /// GEMM kernel //////////////////////////////////////////////////////////////// |
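// For each of the alpha^2 = 16 Winograd components the kernel computes an
// independent GEMM: dst(M x N) += src(M x K) * wei(K x N), where M is the
// number of tiles in a block, K = ic, and N = oc. The accumulator is an
// m_block x n2_block tile of zmm registers; each k-step loads n2_block
// weight vectors and broadcasts one src element per accumulator row.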
360 | struct jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t : public jit_generator { |
361 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t) |
362 | jit_conv_conf_2x3_wino_t jcp; |
363 | |
364 | struct call_params_t { |
365 | const void *src; |
366 | const void *dst; |
367 | const void *wei; |
368 | const void *dst_b; |
369 | }; |
370 | |
371 | void generate() override; |
372 | static bool post_ops_ok( |
373 | jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr); |
374 | |
375 | jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t( |
376 | const jit_conv_conf_2x3_wino_t &ajcp, const primitive_attr_t &attr) |
377 | : jit_generator(jit_name()), jcp(ajcp) {} |
378 | |
379 | static status_t init_conf(jit_conv_conf_2x3_wino_t &jcp, |
380 | const convolution_desc_t &cd, memory_desc_t &src_md, |
381 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
382 | memory_desc_t &bias_md, const primitive_attr_t &attr, |
383 | memory_desc_t &expect_wei_md); |
384 | |
385 | Zmm vreg_out(int n, int m) const { |
386 | const int id_reg_out = n * jcp.m_block + m; |
387 | assert(id_reg_out < jcp.n2_block * jcp.m_block); |
388 | return Zmm(31 - id_reg_out); |
389 | } |
390 | Zmm vreg_wei(int i) const { |
391 | assert(31 - jcp.n2_block * jcp.m_block - i > 1); |
392 | return Zmm(31 - jcp.n2_block * jcp.m_block - i); |
393 | } |
394 | |
395 | Zmm vreg_src = Zmm(0); |
396 | Zmm vreg_one = Zmm(1); |
397 | Zmm vreg_tmp = Zmm(2); |
398 | |
399 | Reg64 reg_ptr_src = r15; |
400 | |
401 | Reg64 reg_aux_dst = r12; |
402 | Reg64 reg_aux_dst2 = r11; |
403 | Reg64 reg_aux_wei = r10; |
404 | Reg64 reg_aux_wei2 = r9; |
405 | Reg64 reg_aux_src = r8; |
406 | Reg64 reg_aux_src2 = rax; |
407 | |
408 | Reg64 reg_mb = rbx; |
409 | Reg64 reg_nnb = rdx; |
410 | Reg64 reg_K = rsi; |
411 | }; |
412 | |
413 | bool jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t::post_ops_ok( |
414 | jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) { |
415 | using namespace primitive_kind; |
416 | const auto &p = attr.post_ops_; |
417 | |
418 | auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; |
419 | |
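    // Accepted chains: empty, relu, sum, sum -> relu, relu -> sum, and
    // relu -> sum -> relu.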
420 | switch (p.len()) { |
421 | case 0: return true; |
422 | case 1: return is_relu(0) || p.contain(sum, 0); |
423 | case 2: |
424 | return (p.contain(sum, 0) && is_relu(1)) |
425 | || (p.contain(sum, 1) && is_relu(0)); |
426 | case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2); |
427 | default: return false; |
428 | } |
429 | } |
430 | |
431 | void jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t::generate() { |
432 | Label nnb_loop_label, K_loop_label, mb_loop_label; |
433 | |
434 | preamble(); |
435 | #define READ_PARAM(reg, field) \ |
436 | mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)]) |
437 | READ_PARAM(reg_ptr_src, src); |
438 | READ_PARAM(reg_aux_dst, dst); |
439 | READ_PARAM(reg_aux_wei, wei); |
440 | #undef READ_PARAM |
441 | |
442 | if (!jcp.small_mb) { |
443 | mov(reg_nnb, jcp.n_chunks); |
444 | L(nnb_loop_label); |
445 | } |
446 | mov(reg_aux_dst2, reg_aux_dst); |
447 | mov(reg_aux_src, reg_ptr_src); |
448 | mov(reg_mb, jcp.M / jcp.m_block); |
449 | L(mb_loop_label); |
450 | { |
        for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
453 | for (int m = 0; m < jcp.m_block; m++) { |
454 | vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m)); |
455 | } |
456 | } |
457 | mov(reg_aux_src2, reg_aux_src); |
458 | mov(reg_aux_wei2, reg_aux_wei); |
459 | |
460 | mov(reg_K, jcp.k_chunks); |
461 | L(K_loop_label); |
462 | { |
463 | int wei_offset = 0; |
464 | for (int _i = 0; _i < jcp.k2_block; _i++) { |
465 | for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) { |
466 | if (jcp.small_mb) { |
467 | int wei_offset = sizeof(float) |
468 | * ((nb2 * jcp.nb_ic * jcp.ic_block |
469 | * jcp.oc_block) |
470 | + _i * jcp.oc_block); |
471 | vmovups(vreg_wei(nb2), |
472 | EVEX_compress_addr(reg_aux_wei2, wei_offset)); |
473 | } else { |
474 | vmovups(vreg_wei(nb2), |
475 | EVEX_compress_addr(reg_aux_wei2, |
476 | sizeof(float) * wei_offset)); |
477 | wei_offset += jcp.oc_block; |
478 | } |
479 | } |
480 | for (int m = 0; m < jcp.m_block; m++) { |
481 | int inp_offset = sizeof(float) * (m * jcp.K + _i); |
482 | if (jcp.n2_block > 1) { |
483 | vbroadcastss(vreg_src, |
484 | EVEX_compress_addr(reg_aux_src2, inp_offset)); |
485 | for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) |
486 | vfmadd231ps( |
487 | vreg_out(nb2, m), vreg_wei(nb2), vreg_src); |
488 | } else { |
489 | vfmadd231ps(vreg_out(0, m), vreg_wei(0), |
490 | EVEX_compress_addr( |
491 | reg_aux_src2, inp_offset, true)); |
492 | } |
493 | } |
494 | } |
495 | add(reg_aux_src2, sizeof(float) * jcp.ic_block); |
496 | if (jcp.small_mb) |
497 | add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block); |
498 | else |
499 | add(reg_aux_wei2, |
500 | sizeof(float) * jcp.k2_block * jcp.n2_block |
501 | * jcp.oc_block); |
502 | } |
503 | dec(reg_K); |
504 | cmp(reg_K, 0); |
505 | jg(K_loop_label, T_NEAR); |
506 | |
507 | for (int m = 0; m < jcp.m_block; m++) { |
            for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
510 | int offset = sizeof(float) * (m * jcp.N + nb2 * jcp.oc_block); |
511 | vmovups(EVEX_compress_addr(reg_aux_dst2, offset), |
512 | vreg_out(nb2, m)); |
513 | } |
514 | } |
515 | add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K); |
516 | add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N); |
517 | } |
518 | dec(reg_mb); |
519 | cmp(reg_mb, 0); |
520 | jg(mb_loop_label, T_NEAR); |
521 | |
522 | if (!jcp.small_mb) { |
523 | add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block); |
524 | add(reg_aux_wei, |
525 | sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block |
526 | * jcp.oc_block); |
527 | |
528 | dec(reg_nnb); |
529 | cmp(reg_nnb, 0); |
530 | jg(nnb_loop_label, T_NEAR); |
531 | } |
532 | postamble(); |
533 | } |
534 | |
535 | namespace { |
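// Empirical dispatch rule for alg_kind::convolution_auto: Winograd is only
// preferred over direct convolution for batch sizes of at least 4.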
536 | bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) { |
537 | return jcp.mb >= 4; |
538 | } |
539 | } // namespace |
540 | |
541 | status_t jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t ::init_conf( |
542 | jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd, |
543 | memory_desc_t &src_md, memory_desc_t &wei_md, memory_desc_t &dst_md, |
544 | memory_desc_t &bias_md, const primitive_attr_t &attr, |
545 | memory_desc_t &expect_wei_md) { |
546 | const memory_desc_wrapper src_d(&src_md); |
547 | const memory_desc_wrapper wei_d(&wei_md); |
548 | const memory_desc_wrapper dst_d(&dst_md); |
549 | const memory_desc_wrapper bias_d(&bias_md); |
550 | |
551 | // This kernel only supports 2D convolutions. |
552 | if (src_d.ndims() != 4) return status::unimplemented; |
553 | |
554 | const bool with_groups = wei_d.ndims() == src_d.ndims() + 1; |
555 | |
556 | jcp.nthr = dnnl_get_max_threads(); |
557 | |
558 | jcp.ngroups = with_groups ? wei_d.dims()[0] : 1; |
559 | jcp.mb = src_d.dims()[0]; |
560 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
561 | jcp.oc_without_padding = jcp.oc; |
562 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
563 | jcp.ih = src_d.dims()[2]; |
564 | jcp.iw = src_d.dims()[3]; |
565 | jcp.oh = dst_d.dims()[2]; |
566 | jcp.ow = dst_d.dims()[3]; |
567 | jcp.kh = wei_d.dims()[with_groups + 2]; |
568 | jcp.kw = wei_d.dims()[with_groups + 3]; |
569 | jcp.t_pad = cd.padding[0][0]; |
570 | jcp.l_pad = cd.padding[0][1]; |
571 | jcp.stride_h = cd.strides[0]; |
572 | jcp.stride_w = cd.strides[1]; |
573 | jcp.dilate_h = cd.dilates[0]; |
574 | jcp.dilate_w = cd.dilates[1]; |
575 | |
576 | const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
577 | const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
578 | jcp.r_pad = calculate_end_padding( |
579 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw); |
580 | jcp.b_pad = calculate_end_padding( |
581 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); |
582 | |
583 | jcp.m = 2; |
584 | jcp.r = 3; |
585 | jcp.alpha = jcp.m + jcp.r - 1; |
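    // i.e. F(2x2, 3x3): a 2x2 output tile and a 3x3 filter give an
    // alpha x alpha = 4x4 input tile per output tile.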
586 | int simdw = 16; |
587 | |
588 | format_tag_t dat_tag = format_tag::nChw16c; |
589 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag); |
590 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); |
591 | |
592 | if (jcp.src_tag != dat_tag) return status::unimplemented; |
593 | if (jcp.dst_tag != dat_tag) return status::unimplemented; |
594 | |
595 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
596 | |
597 | if (!post_ops_ok(jcp, attr)) return status::unimplemented; |
598 | |
599 | bool ok_to_pad_channels = jcp.ngroups == 1; |
600 | if (ok_to_pad_channels) { |
601 | jcp.oc = rnd_up(jcp.oc, simdw); |
602 | jcp.ic = rnd_up(jcp.ic, simdw); |
603 | } |
604 | |
605 | if (!(mayiuse(avx512_core))) return status::unimplemented; |
606 | |
607 | if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto, |
608 | is_winograd_faster_than_direct(jcp))) |
609 | return status::unimplemented; |
610 | |
611 | if (src_d.data_type() != data_type::f32) return status::unimplemented; |
612 | if (wei_d.data_type() != data_type::f32) return status::unimplemented; |
613 | if (dst_d.data_type() != data_type::f32) return status::unimplemented; |
614 | |
615 | jcp.ic_block = simdw; |
616 | jcp.oc_block = simdw; |
617 | |
618 | bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1 |
619 | && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0 |
620 | && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0 |
621 | && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad |
622 | && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0 |
623 | && jcp.l_pad < 2 && jcp.l_pad >= 0; |
624 | if (!ok) return status::unimplemented; |
625 | |
626 | const int L2_capacity |
627 | = platform::get_per_core_cache_size(2) / sizeof(float); |
628 | const int L3_capacity |
629 | = platform::get_per_core_cache_size(3) * jcp.nthr / sizeof(float); |
630 | int a = jcp.alpha; |
631 | int aa = a * a; |
632 | int mb = jcp.mb; |
633 | int ic = jcp.ic; |
634 | int oc = jcp.oc; |
635 | int ih = jcp.ih; |
636 | int iw = jcp.iw; |
637 | auto wei_sz = (float)aa * ic * oc; |
638 | auto inp_sz = (float)mb * ih * iw * ic; |
639 | auto sp_sz = (float)mb * ih * iw; |
640 | |
    /* Heuristics here. The numbers '28' and '196' below are empirical
     * observations from data. */
642 | if (wei_sz / inp_sz > 5) |
643 | jcp.small_mb = true; |
644 | else |
645 | jcp.small_mb = false; |
646 | |
647 | if (mb > nstl::min(jcp.nthr, 28) |
648 | || (!jcp.small_mb |
649 | && (wei_sz >= 0.9f * L2_capacity |
650 | || inp_sz > L2_capacity * jcp.nthr + L3_capacity)) |
651 | || (jcp.small_mb && sp_sz > 196)) |
652 | return status::unimplemented; |
653 | |
654 | jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; |
655 | jcp.dst_dt = cd.dst_desc.data_type; |
656 | |
657 | jcp.typesize_bia |
658 | = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; |
659 | |
660 | jcp.nb_oc = jcp.oc / jcp.oc_block; |
661 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
662 | |
663 | const int skx_free_regs = 30; |
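    // = 32 zmm registers minus zmm0 (vreg_src) and zmm1 (vreg_one); the
    // m_block x n2_block accumulator tile plus n2_block weight registers
    // must fit within this budget (see used_regs below).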
664 | |
665 | auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block, |
666 | int &n2_block, float ®_eff) { |
667 | M = (xb * yb) / jcp.alpha; |
668 | int max_m_block = m_block = nstl::min(M, skx_free_regs); |
669 | int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs); |
670 | reg_eff = 0; |
671 | for (int im = max_m_block; im > 0; im--) { |
672 | for (int in2 = max_n2_block; in2 > 0; in2--) { |
673 | int used_regs = in2 * im + in2; |
674 | float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f; |
675 | if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs |
676 | || cur_reg_eff <= reg_eff) |
677 | continue; |
678 | reg_eff = cur_reg_eff; |
679 | m_block = im; |
680 | n2_block = in2; |
681 | } |
682 | } |
683 | }; |
684 | |
685 | int oh = jcp.oh; |
686 | int ow = jcp.ow; |
687 | int nb_oc = jcp.nb_oc; |
688 | int Z = ic + oc; |
689 | int Y = ic * oc; |
690 | const int L3_cap_per_core |
691 | = platform::get_per_core_cache_size(3) / sizeof(float); |
692 | |
693 | /* Selecting xb and yb blocking */ |
694 | int min_yb = jcp.alpha; |
695 | int min_xb = jcp.alpha; |
696 | int max_yb = nstl::max(min_yb, rnd_up(ih, 2)); |
697 | int max_xb = nstl::max(min_xb, rnd_up(iw, 2)); |
698 | float best_eff = 0.f; |
699 | for (int ix = max_xb; ix >= min_xb; ix -= 2) { |
700 | if (rnd_up(ow, ix) < iw - 2) continue; |
701 | for (int iy = max_yb; iy >= min_yb; iy -= 2) { |
702 | if (rnd_up(oh, iy) < ih - 2) continue; |
703 | int ex_y = rnd_up(oh, iy); |
704 | int ex_x = rnd_up(ow, ix); |
705 | float work_eff = (float)(ih * iw) / (ex_y * ex_x); |
706 | |
707 | int M, m_block, n2_b; |
708 | float reg_eff, thr_eff, par_eff, mem_eff, req_mem; |
709 | |
710 | find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff); |
711 | |
712 | /* outer parallelization */ |
713 | int nblocks = mb * div_up(ih, iy) * div_up(iw, ix); |
714 | thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr); |
715 | |
716 | mem_eff = 1.f; |
717 | req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y; |
718 | if (req_mem > L2_capacity / 2) { |
719 | if (req_mem > ((L2_capacity + L3_cap_per_core) * 4) / 7) |
720 | mem_eff /= (n2_b + 1) / 2.f; |
721 | else |
722 | mem_eff /= (n2_b + 1) / 3.f; |
723 | } |
724 | |
725 | float outer_eff = thr_eff + work_eff + reg_eff + mem_eff; |
726 | |
727 | /* inner parallelization */ |
728 | int bsz = iy * ix / a; |
729 | int gemmw = aa * (nb_oc / n2_b); |
730 | int bsz_r = rnd_up(bsz, jcp.nthr); |
731 | int gemmw_r = rnd_up(gemmw, jcp.nthr); |
732 | thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y); |
733 | |
734 | req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic; |
735 | mem_eff = nstl::min(1.f, L2_capacity / req_mem); |
736 | int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr)); |
737 | int oc_per_thr |
738 | = nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr)); |
739 | req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z; |
740 | if (req_mem > L2_capacity) mem_eff = 0.1f; |
741 | par_eff = 1 / (2.f * nblocks); |
742 | |
743 | float inner_eff = thr_eff + work_eff + mem_eff + par_eff; |
744 | |
745 | float eff = jcp.small_mb ? inner_eff : outer_eff; |
746 | if (eff > best_eff) { |
747 | best_eff = eff; |
748 | jcp.yb = iy; |
749 | jcp.xb = ix; |
750 | jcp.M = M; |
751 | jcp.m_block = m_block; |
752 | jcp.n2_block = n2_b; |
753 | } |
754 | } |
755 | } |
756 | |
757 | assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0); |
758 | |
759 | jcp.inp_stride = jcp.M * jcp.ic; |
760 | jcp.out_stride = jcp.M * jcp.oc; |
761 | jcp.wei_stride = jcp.ic * jcp.oc; |
762 | jcp.bia_stride = jcp.oc; |
763 | |
764 | jcp.N = jcp.oc; |
765 | jcp.K = jcp.ic; |
766 | |
767 | jcp.n_block = jcp.oc_block; |
768 | jcp.k_block = jcp.ic_block; |
769 | |
770 | assert(jcp.M % jcp.m_block == 0); |
771 | assert(jcp.nb_oc % jcp.n2_block == 0); |
772 | |
773 | jcp.n_chunks = jcp.nb_oc / jcp.n2_block; |
774 | jcp.k2_block = jcp.ic_block; |
775 | jcp.k_chunks = jcp.K / jcp.k2_block; |
776 | |
    /* re-create the weights memory descriptor
       and set the Winograd blocking format */
779 | expect_wei_md.format_kind = format_kind::wino; |
780 | expect_wei_md.data_type = data_type::f32; |
781 | wino_desc_t &wd = expect_wei_md.format_desc.wino_desc; |
782 | wd.wino_format = jcp.small_mb ? wino_memory_format_t::wino_wei_aaOio |
783 | : wino_memory_format_t::wino_wei_aaOBiOo; |
784 | wd.r = jcp.r; |
785 | wd.alpha = jcp.alpha; |
786 | wd.ic = jcp.ic; |
787 | wd.oc = jcp.oc; |
788 | wd.ic_block = jcp.ic_block; |
789 | wd.oc_block = jcp.oc_block; |
790 | wd.oc2_block = jcp.n2_block; |
791 | wd.ic2_block = 1; |
792 | wd.adj_scale = 1.f; |
793 | size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc; |
794 | wd.size = max_size; |
795 | |
796 | return status::success; |
797 | } |
798 | //////////////////////////////////////////////////////////////////////////////// |
799 | |
800 | status_t jit_avx512_core_f32_wino_conv_2x3_fwd_t ::pd_t::jit_conf( |
801 | memory_desc_t &expect_wei_md) { |
802 | return jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t::init_conf(jcp_, |
803 | *this->desc(), this->src_md_, this->weights_md_, this->dst_md_, |
804 | this->bias_md_, *this->attr(), expect_wei_md); |
805 | } |
806 | |
807 | jit_avx512_core_f32_wino_conv_2x3_fwd_t:: |
808 | jit_avx512_core_f32_wino_conv_2x3_fwd_t(const pd_t *apd) |
809 | : primitive_t(apd) {} |
810 | |
811 | status_t jit_avx512_core_f32_wino_conv_2x3_fwd_t::init(engine_t *engine) { |
812 | CHECK(safe_ptr_assign(kernel_, |
813 | new jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t( |
814 | pd()->jcp_, *pd()->attr()))); |
815 | CHECK(safe_ptr_assign(src_trans_, |
816 | new jit_avx512_core_f32_wino_conv_2x3_src_trans_t( |
817 | pd()->jcp_, *pd()->attr()))); |
818 | CHECK(safe_ptr_assign(dst_trans_, |
819 | new jit_avx512_core_f32_wino_conv_2x3_dst_trans_t( |
820 | pd()->jcp_, *pd()->attr()))); |
821 | CHECK(kernel_->create_kernel()); |
822 | CHECK(src_trans_->create_kernel()); |
823 | CHECK(dst_trans_->create_kernel()); |
824 | return status::success; |
825 | } |
826 | |
827 | jit_avx512_core_f32_wino_conv_2x3_fwd_t:: |
828 | ~jit_avx512_core_f32_wino_conv_2x3_fwd_t() |
829 | = default; |
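
// Two execution strategies: execute_forward_mbN() parallelizes over
// (mb, tile_y, tile_x) blocks, each thread running all three stages on its
// own Winograd scratch buffers, while execute_forward_small_mb() walks the
// tile blocks serially and parallelizes within each stage instead.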
830 | |
831 | void jit_avx512_core_f32_wino_conv_2x3_fwd_t::execute_forward_mbN( |
832 | const float *src, const float *wei, const float *bia, float *dst, |
833 | const memory_tracking::grantor_t &scratchpad) const { |
834 | const auto &jcp = kernel_->jcp; |
835 | |
836 | const size_t wino_size_offset |
837 | = (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) |
838 | + (pd()->jcp_.xb); |
839 | const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16; |
840 | const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16; |
841 | |
842 | if (pd()->wants_padded_bias()) { |
843 | auto padded_bias = scratchpad.get<float>(key_conv_padded_bias); |
844 | utils::array_copy(padded_bias, bia, jcp.oc_without_padding); |
845 | utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, |
846 | jcp.oc - jcp.oc_without_padding); |
847 | bia = padded_bias; |
848 | } |
849 | |
850 | auto ptr_V = scratchpad.get<float>(key_wino_V); |
851 | auto ptr_M = scratchpad.get<float>(key_wino_M); |
852 | |
853 | parallel_nd_ext(jcp.nthr, jcp.mb, div_up(jcp.oh, jcp.yb), |
854 | div_up(jcp.ow, jcp.xb), |
855 | [&](dim_t ithr, dim_t nthr, dim_t mb, dim_t tile_y_b, |
856 | dim_t tile_x_b) { |
857 | assert(nthr <= jcp.nthr); |
858 | MAYBE_UNUSED(nthr); |
859 | |
860 | int tile_y = tile_y_b * jcp.yb; |
861 | int tile_x = tile_x_b * jcp.xb; |
862 | |
863 | auto wino_src = ptr_V + size_wino_src * ithr; |
864 | auto wino_dst = ptr_M + size_wino_dst * ithr; |
865 | |
866 | auto src_trans_p |
867 | = jit_avx512_core_f32_wino_conv_2x3_src_trans_t :: |
868 | call_params_t(); |
869 | auto dst_trans_p |
870 | = jit_avx512_core_f32_wino_conv_2x3_dst_trans_t :: |
871 | call_params_t(); |
872 | auto gemm_p = jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t :: |
873 | call_params_t(); |
874 | |
875 | /* transformation of input tensor to winograd domain */ |
876 | for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { |
877 | for (int x_in_block = 0; x_in_block < jcp.xb; |
878 | x_in_block += 2) { |
879 | |
880 | unsigned short v_y_masks[4], v_x_masks[4]; |
881 | |
882 | int y = y_in_block + tile_y; |
883 | int x = x_in_block + tile_x; |
884 | int m = (y_in_block / 2) * (jcp.xb / 2) |
885 | + (x_in_block / 2); |
886 | |
887 | int v_ys = nstl::max(0, jcp.t_pad - y); |
888 | int v_ye = nstl::min(jcp.alpha, |
889 | nstl::max(0, jcp.ih + jcp.t_pad - y)); |
890 | |
891 | int v_xs = nstl::max(0, jcp.l_pad - x); |
892 | int v_xe = nstl::min(jcp.alpha, |
893 | nstl::max(0, jcp.iw + jcp.l_pad - x)); |
894 | |
895 | #pragma unroll(4) |
896 | for (int i = 0; i < jcp.alpha; i++) { |
897 | v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; |
898 | v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; |
899 | } |
900 | auto local_s = src |
901 | + (dim_t)mb * jcp.nb_ic * jcp.ih * jcp.iw |
902 | * jcp.ic_block |
903 | + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; |
904 | auto local_w = wino_src + m * jcp.ic; |
905 | |
906 | src_trans_p.src = local_s; |
907 | src_trans_p.wino_src = local_w; |
908 | src_trans_p.v_y_masks = v_y_masks; |
909 | src_trans_p.v_x_masks = v_x_masks; |
910 | |
911 | (*src_trans_)(&src_trans_p); |
912 | } |
913 | } |
914 | /* gemms */ |
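                // Start each thread at a different Winograd component
                // ((tile_ij + ithr) % 16), presumably to spread concurrent
                // reads across the shared weight slices.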
915 | for (int tile_ij = 0; tile_ij < 16; tile_ij++) { |
916 | int offset = (tile_ij + ithr) % 16; |
917 | gemm_p.src = wino_src + jcp.inp_stride * offset; |
918 | gemm_p.dst = wino_dst + jcp.out_stride * offset; |
919 | gemm_p.wei = wei + jcp.wei_stride * offset; |
920 | |
921 | (*kernel_)(&gemm_p); |
922 | } |
923 | |
924 | /* transformation from winograd domain to output tensor */ |
925 | for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) { |
926 | for (int x_in_block = 0; x_in_block < jcp.xb; |
927 | x_in_block += 2) { |
928 | unsigned short v_y_masks[2], v_x_masks[2]; |
929 | |
930 | int y = y_in_block + tile_y; |
931 | int x = x_in_block + tile_x; |
932 | int m = (y_in_block / 2) * (jcp.xb / 2) |
933 | + (x_in_block / 2); |
934 | |
935 | #pragma unroll(2) |
936 | for (int i = 0; i < jcp.m; i++) { |
937 | v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; |
938 | v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; |
939 | } |
940 | auto local_d = dst |
941 | + (dim_t)mb * jcp.nb_oc * jcp.oh * jcp.ow |
942 | * jcp.oc_block |
943 | + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; |
944 | auto local_w = wino_dst + m * jcp.oc; |
945 | |
946 | dst_trans_p.dst = local_d; |
947 | dst_trans_p.wino_dst = local_w; |
948 | dst_trans_p.v_y_masks = v_y_masks; |
949 | dst_trans_p.v_x_masks = v_x_masks; |
950 | dst_trans_p.bias = bia; |
951 | |
952 | (*dst_trans_)(&dst_trans_p); |
953 | } |
954 | } |
955 | }); |
956 | } |
957 | |
958 | void jit_avx512_core_f32_wino_conv_2x3_fwd_t::execute_forward_small_mb( |
959 | const float *src, const float *wei, const float *bia, float *dst, |
960 | const memory_tracking::grantor_t &scratchpad) const { |
961 | const auto &jcp = kernel_->jcp; |
962 | |
963 | if (pd()->wants_padded_bias()) { |
964 | auto padded_bias = scratchpad.get<float>(key_conv_padded_bias); |
965 | utils::array_copy(padded_bias, bia, jcp.oc_without_padding); |
966 | utils::array_set(padded_bias + jcp.oc_without_padding, 0.f, |
967 | jcp.oc - jcp.oc_without_padding); |
968 | bia = padded_bias; |
969 | } |
970 | |
971 | auto ptr_V = scratchpad.get<float>(key_wino_V); |
972 | auto ptr_M = scratchpad.get<float>(key_wino_M); |
973 | |
974 | for_(int mb = 0; mb < jcp.mb; mb++) |
975 | for_(int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) |
976 | for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) { |
977 | /* transformation of input tensor to winograd domain */ |
978 | parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), |
979 | [&](dim_t y_in_block_b, dim_t x_in_block_b) { |
980 | int y_in_block = y_in_block_b * 2; |
981 | int x_in_block = x_in_block_b * 2; |
982 | |
983 | auto src_trans_p |
984 | = jit_avx512_core_f32_wino_conv_2x3_src_trans_t :: |
985 | call_params_t(); |
986 | |
987 | unsigned short v_y_masks[4], v_x_masks[4]; |
988 | |
989 | int y = y_in_block + tile_y; |
990 | int x = x_in_block + tile_x; |
991 | int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); |
992 | |
993 | int v_ys = nstl::max(0, jcp.t_pad - y); |
994 | int v_ye = nstl::min( |
995 | jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y)); |
996 | |
997 | int v_xs = nstl::max(0, jcp.l_pad - x); |
998 | int v_xe = nstl::min( |
999 | jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x)); |
1000 | |
1001 | #pragma unroll(4) |
1002 | for (int i = 0; i < jcp.alpha; i++) { |
1003 | v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff; |
1004 | v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff; |
1005 | } |
1006 | auto local_s = src |
1007 | + (dim_t)mb * jcp.nb_ic * jcp.ih * jcp.iw |
1008 | * jcp.ic_block |
1009 | + y * jcp.iw * jcp.ic_block + x * jcp.ic_block; |
1010 | auto local_w = ptr_V + m * jcp.ic; |
1011 | |
1012 | src_trans_p.src = local_s; |
1013 | src_trans_p.wino_src = local_w; |
1014 | src_trans_p.v_y_masks = v_y_masks; |
1015 | src_trans_p.v_x_masks = v_x_masks; |
1016 | |
1017 | (*src_trans_)(&src_trans_p); |
1018 | }); |
1019 | |
1020 | /* gemms */ |
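        // Each (Winograd component, oc chunk) pair is an independent GEMM,
        // so both dimensions are parallelized.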
1021 | parallel_nd(16, jcp.n_chunks, [&](dim_t tile_ij, dim_t nnb) { |
1022 | auto gemm_p = jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t :: |
1023 | call_params_t(); |
1024 | |
1025 | gemm_p.src = ptr_V + jcp.inp_stride * tile_ij; |
1026 | gemm_p.dst = ptr_M + jcp.out_stride * tile_ij |
1027 | + nnb * jcp.n2_block * jcp.n_block; |
1028 | gemm_p.wei = wei + jcp.wei_stride * tile_ij |
1029 | + nnb * jcp.n2_block * jcp.n_block * jcp.K; |
1030 | |
1031 | (*kernel_)(&gemm_p); |
1032 | }); |
1033 | |
1034 | /* transformation from winograd domain to output tensor */ |
1035 | |
1036 | parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), |
1037 | [&](dim_t y_in_block_b, dim_t x_in_block_b) { |
1038 | int y_in_block = y_in_block_b * 2; |
1039 | int x_in_block = x_in_block_b * 2; |
1040 | |
1041 | auto dst_trans_p |
1042 | = jit_avx512_core_f32_wino_conv_2x3_dst_trans_t :: |
1043 | call_params_t(); |
1044 | |
1045 | unsigned short v_y_masks[2], v_x_masks[2]; |
1046 | |
1047 | int y = y_in_block + tile_y; |
1048 | int x = x_in_block + tile_x; |
1049 | int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2); |
1050 | |
1051 | #pragma unroll(2) |
1052 | for (int i = 0; i < jcp.m; i++) { |
1053 | v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0; |
1054 | v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0; |
1055 | } |
1056 | auto local_d = dst |
1057 | + (dim_t)mb * jcp.nb_oc * jcp.oh * jcp.ow |
1058 | * jcp.oc_block |
1059 | + y * jcp.ow * jcp.oc_block + x * jcp.oc_block; |
1060 | auto local_w = ptr_M + m * jcp.oc; |
1061 | |
1062 | dst_trans_p.dst = local_d; |
1063 | dst_trans_p.wino_dst = local_w; |
1064 | dst_trans_p.v_y_masks = v_y_masks; |
1065 | dst_trans_p.v_x_masks = v_x_masks; |
1066 | |
1067 | dst_trans_p.bias = bia; |
1068 | |
1069 | (*dst_trans_)(&dst_trans_p); |
1070 | }); |
1071 | } |
1072 | } |
1073 | |
1074 | } // namespace x64 |
1075 | } // namespace cpu |
1076 | } // namespace impl |
1077 | } // namespace dnnl |
1078 | |