1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <math.h>
18
19#include "common/c_types_map.hpp"
20#include "common/dnnl_thread.hpp"
21#include "common/nstl.hpp"
22#include "common/type_helpers.hpp"
23#include "common/utils.hpp"
24
25#include "cpu/platform.hpp"
26
27#include "cpu/x64/jit_avx512_core_f32_wino_conv_4x3_kernel.hpp"
28
29#define GET_OFF(field) offsetof(jit_wino_transform_call_s, field)
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace x64 {
35
36namespace {
37
38using namespace dnnl::impl::utils;
39
// Per-core cache sizes in bytes, queried once at load time via
// platform::get_per_core_cache_size(); used by the cache-blocking
// heuristics (check_* helpers) and streaming-store decisions below.
unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
unsigned int L2_cache_size = platform::get_per_core_cache_size(2);
unsigned int LLC_data_size = platform::get_per_core_cache_size(3);
43
44// the test funtion takes jcp, the candidate and the current best.
45// it returns true if the new candidate is better
46int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
47 int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) {
48 int best_divisor = default_best;
49 auto test_num
50 = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
51 if (test(jcp, num, best_divisor)) { best_divisor = num; }
52 };
53
54 for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
55 if (number % divisor == 0) {
56 test_num(jcp, divisor);
57 test_num(jcp, number / divisor);
58 }
59 }
60
61 return best_divisor;
62}
63
64namespace {
65bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
66 /* Determines if current winograd implementation is faster than direct.
67 Following conditions are empirical and based on performance data */
68 unsigned int ncores_per_socket
69 = cpu().getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
70 unsigned int nthreads = dnnl_get_max_threads();
71
72 if (jcp.prop_kind == prop_kind::forward_inference) {
73 return jcp.mb >= 4;
74 } else if (nthreads > ncores_per_socket) {
75 double src_dst_transforms_per_core = alpha * alpha * (jcp.ic + jcp.oc)
76 * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size)
77 * ((jcp.ow + tile_size - 1) / tile_size) * sizeof(float) / 1024.
78 / 1024. / nthreads;
79 double wei_transform = alpha * alpha * jcp.ic * jcp.oc * sizeof(float)
80 / 1024. / 1024.;
81
82 if (jcp.prop_kind == prop_kind::backward_weights) {
83 if (src_dst_transforms_per_core < 0.3
84 || (src_dst_transforms_per_core <= 28 && wei_transform < 4))
85 return false;
86 else
87 return true;
88 } else {
89 if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02)
90 return false;
91 }
92 }
93
94 return jcp.mb > 8;
95}
96} // namespace
97
98/* assumes 512 bits registers */
99/* TODO: add support for strides */
100/* TODO: handle the prefetch distance automatically */
101using cache_t = enum cache_t_ { L1, L2, L3 };
102
103template <typename data_t>
104struct prefetcher_t {
105 prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
106 cache_t cache_type, size_t block_size, /* in number of elements*/
107 int nb_instructions_in_block, int fma_ipc)
108 : cg_(generator)
109 , reg_base_addr_(reg_base_addr)
110 , cache_type_(cache_type)
111 , cache_block_size_(block_size) {
112 nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
113 prefetch_spread_
114 = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
115 prefetch_blk_
116 = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
117
118 /* assumption: when fetch in Li, data is already in L(i+1) */
119 int cache_latency;
120 switch (cache_type_) {
121 case L1: cache_latency = 14; break;
122 case L2: cache_latency = 250; break;
123 case L3: cache_latency = 250; break;
124 }
125
126 prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
127 }
128
129 void prefetch(int instruction_number) {
130 if (instruction_number % prefetch_spread_ == 0) {
131 for (int i = 0; (i < prefetch_blk_)
132 && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
133 i++, prefetches_issued_++) {
134 prefetch_inst_(cg_->EVEX_compress_addr(reg_base_addr_,
135 (cache_block_size_ * prefetch_distance_)
136 * sizeof(data_t)
137 + (prefetches_issued_ * 64)));
138 }
139 }
140 }
141
142private:
143 void prefetch_inst_(const Xbyak::Address &addr) {
144 switch (cache_type_) {
145 case L1: cg_->prefetcht0(addr); break;
146 case L2: cg_->prefetcht1(addr); break;
147 case L3: cg_->prefetcht2(addr); break;
148 default: break; // TODO: raise an exception or put an assert
149 }
150 }
151
152 jit_generator *cg_;
153 Xbyak::Reg64 reg_base_addr_;
154 cache_t cache_type_;
155 int cache_block_size_ = 0;
156 int nb_cache_lines_to_prefetch_ = 0;
157 int prefetches_issued_ = 0;
158 int prefetch_spread_ = 0;
159 int prefetch_blk_ = 0;
160 int prefetch_distance_ = 0;
161};
162
163// utilities to support kernel parameter selection
164bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp, int dimN_block,
165 float C2_min, float C2_max) {
166 float block_size = alpha * alpha
167 * (2 * (jcp.oc + jcp.ic) * dimN_block * jcp.dimN_reg_block
168 + div_up(jcp.ic * jcp.oc, jcp.nthr))
169 * (float)sizeof(float);
170 float L2_lb = C2_min * L2_cache_size;
171 float L2_ub = C2_max * L2_cache_size;
172 return (block_size > L2_lb && block_size < L2_ub);
173}
174
175bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block,
176 int dimM_block, float C1_min, float C1_max) {
177 float gemm_block_size
178 = (dimM_block * jcp.dimM_simd_block * dimK_block
179 * jcp.dimK_reg_block * jcp.dimM_reg_block
180 + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
181 + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block)
182 * (float)sizeof(float);
183 float L1_lb = C1_min * L1_cache_size;
184 float L1_ub = C1_max * L1_cache_size;
185 return (gemm_block_size > L1_lb && gemm_block_size < L1_ub);
186}
187bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
188 int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) {
189 float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block
190 + dimM_block * dimK_block * dimK_reg_block
191 * dimM_simd_block * dimM_reg_block
192 + dimK_block * dimN_reg_block * dimK_reg_block)
193 * (float)sizeof(float);
194 float rhs = C * L1_cache_size;
195 return (lhs < rhs);
196}
197bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
198 int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) {
199 float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block
200 * dimM_simd_block
201 + dimK_block * dimN_reg_block * dimK_reg_block)
202 * (float)sizeof(float);
203 float rhs = C * L1_cache_size;
204 return (lhs < rhs);
205}
206bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
207 int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block,
208 int dimM_simd_block, float C) {
209 float lhs
210 = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block
211 * dimM_reg_block
212 + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
213 * dimM_simd_block * dimM_reg_block
214 + nb_dimN_reg_block * dimK_nb_block * dimK_block
215 * dimN_reg_block * dimK_reg_block)
216 * (float)sizeof(float);
217 float rhs = C * L2_cache_size;
218 return (lhs < rhs);
219}
220
221bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block,
222 int dimN_block, int dimN_reg_block, int dimK, float C1, float C2) {
223 float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK
224 * (float)sizeof(float);
225 float B_size = dimN_block * dimN_reg_block * dimK * (float)sizeof(float);
226 return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size);
227}
228} // namespace
229
230using namespace dnnl::impl::format_tag;
231using namespace dnnl::impl::utils;
232using namespace Xbyak;
233
void _jit_avx512_core_f32_wino_conv_4x3_data_kernel::gemm_loop_generate() {
    // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
    // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block;
    // dimM_reg_block++) // unrolled
    // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
    // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
    // dimK_reg_block++) // unrolled
    // for (int tile =0; tile < jcp.dimN_reg_block; tile++)
    // C[dimM_block][dimM_reg_block][tile] +=
    // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block]
    // * broadcast(B[dimK_block][tile][dimK_reg_block]);
    // Notes:
    // jcp.kernel_kind defines embedded or explicit broadcast
    // dimM_reg_block=1 for embedded bcast kernel

    // zmm0 is reserved for the A operand.
    auto zmm_srcA = [=]() { return Xbyak::Zmm(0); };
    // zmm1..zmm[dimN_reg_block] hold pre-broadcast B values
    // (explicit-broadcast path only).
    auto zmm_srcB = [=](int tile) {
        int idx = 1 + tile;
        assert(idx < 1 + jcp.dimN_reg_block);
        return Xbyak::Zmm(idx);
    };
    // C accumulators; their register range depends on the broadcast kind.
    auto zmm_dstC = [=](int dimM_reg_block, int tile) {
        int idx {0};
        if (jcp.kernel_kind == embd_bcast)
            idx = 1 + tile;
        else
            idx = 1 + jcp.dimN_reg_block + dimM_reg_block * jcp.dimN_reg_block
                    + tile;
        assert(idx < 32);
        return Xbyak::Zmm(idx);
    };

    // Zero-initialize all C accumulators.
    auto prepare_output = [=]() {
        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block, tile);
                vpxord(zmm, zmm, zmm);
            }
        }
    };
    // Store the accumulators to dstC. When reg_is_beta_zero is 0 the
    // accumulate-with-previous-dstC step is skipped; otherwise existing
    // dstC values are added in first. For large problems under the
    // W_S_G_D schedule aligned stores go non-temporal to avoid polluting
    // the caches.
    auto store_output = [=](bool output_is_aligned) {
        Label save;
        cmp(reg_is_beta_zero, 0);
        je(save, T_NEAR);

        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block, tile);
                int output_offset
                        = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
                vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset));
            }
        }

        L(save);
        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block, tile);
                int output_offset
                        = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;

                // In W_SGD, output will be reused.
                if (output_is_aligned && jcp.dimK_nb_block == 1
                        && jcp.sched_policy == WSCHED_DATA_W_S_G_D
                        && (jcp.dimN * jcp.dimM * alpha * alpha * sizeof(float)
                                > 2 * LLC_data_size * jcp.nthr))
                    vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm);
                else
                    vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm);
            }
        }
    };

    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop;

        if (jcp.dimM_block > 1) {
            mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
            L(dimM_block_loop);
        }

        prepare_output();

        if (jcp.dimK_block > 1) {
            mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
            L(dimK_block_loop);
        }

        // Fully unrolled dimK_reg_block loop of fmas.
        for (int dimK_reg_block = 0; dimK_reg_block < jcp.dimK_reg_block;
                dimK_reg_block++) {

            if (jcp.kernel_kind == expl_bcast) {
                // Pre-broadcast one B scalar per tile into registers.
                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    vbroadcastss(zmm_srcB(tile),
                            ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]);
                }
            }

            /* Performing the fmas */

            for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                    dimM_reg_block++) {

                vmovups(zmm_srcA(),
                        zword[reg_srcA
                                + jcp.dimK_reg_block * jcp.dimK_block * 64
                                        * dimM_reg_block
                                + dimK_reg_block * 64]);

                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    if (jcp.kernel_kind == expl_bcast)
                        vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
                                zmm_srcB(tile));
                    else
                        // Embedded broadcast straight from memory.
                        vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
                                EVEX_compress_addr(reg_srcB,
                                        64 * tile + dimK_reg_block * 4, true));
                }
            }
        }
        add(reg_srcA, jcp.dimK_reg_block * 64);
        add(reg_srcB, jcp.dimN_reg_block * 64);
        if (jcp.dimK_block > 1) {
            sub(reg_dimK_block_loop_cnt, 1);
            jnz(dimK_block_loop);
        }

        // Dispatch to the streaming-capable or plain store path based on
        // dstC's vector-length alignment at run time.
        Label unaligned_store, end_store;
        test(reg_dstC, cpu_isa_traits<avx512_core>::vlen - 1);
        jnz(unaligned_store, T_NEAR);
        store_output(true);
        jmp(end_store, T_NEAR);
        L(unaligned_store);
        { store_output(false); }
        L(end_store);

        if (jcp.dimM_block > 1) {
            // Rewind B and advance A/C pointers for the next dimM block.
            sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
            add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64);
            if (jcp.kernel_kind == expl_bcast) {
                add(reg_srcA,
                        (jcp.dimM_reg_block - 1) * jcp.dimK_reg_block * 64
                                * jcp.dimK_block);
            }
            sub(reg_dimM_block_loop_cnt, 1);
            jnz(dimM_block_loop);
        }
    };

    /* Preamble */
    preamble();

    /* kernel */
    inner_loops();

    /* Postamble */
    postamble();
    ret();
}
396
void _jit_avx512_core_f32_wino_conv_4x3_data_kernel ::
        weights_transform_data_ker_generate() {
    // Generates the weights-transform kernel: loads one kh x kw filter
    // block of simd_w x simd_w floats, applies the two-stage 4x4-tile /
    // 3x3-kernel Winograd weight transform (trans_W_4x4_3x3), and streams
    // the alpha x alpha result to the blocked transformed-weights layout.
    bool is_fwd = one_of(
            jcp.prop_kind, dnnl_forward_training, dnnl_forward_inference);
    int kh = jcp.kh;
    int kw = jcp.kw;

    auto zmm_temp = Xbyak::Zmm(31);
    auto zmm_zero = Xbyak::Zmm(30);

    // Two register banks used by the 16x16 transpose.
    auto zmm_M = [=](int i) { return Xbyak::Zmm(i); };
    auto zmm_MT = [=](int i) { return Xbyak::Zmm(i + simd_w); };

    // Registers used by the transform: G coefficients, filter rows F,
    // intermediate rows T, temporaries t.
    auto zmm_G = [=](int i) { return Xbyak::Zmm(i); };
    auto zmm_F = [=](int i) { return Xbyak::Zmm(alpha + i); };
    auto zmm_T = [=](int i) { return Xbyak::Zmm(alpha + 3 + i); };
    auto zmm_t = [=](int i) { return Xbyak::Zmm(2 * alpha + 3 + i); };

    auto zmm_load = [=](int i) { return Xbyak::Zmm(i); };

    // Broadcast the alpha transform coefficients from the G table and
    // zero the helper register.
    auto init_G = [=]() {
        mov(wreg_temp, ptr[param1 + GET_OFF(G)]);
        for (int i = 0; i < alpha; i++) {
            vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]);
        }
        vpxord(zmm_zero, zmm_zero, zmm_zero);
    };

    // In-register 16x16 f32 transpose from [wreg_M] to [wreg_MT], built
    // from 4-byte unpacks, 8-byte unpacks and 128-bit lane shuffles.
    auto trans16x16 = [=]() {
        for (int i = 0; i < simd_w; i += 2) {
            vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]);
            vmovups(zmm_M(i + 1), ptr[wreg_M + (i + 1) * simd_w * 4]);
            vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i + 1));
            vunpckhps(zmm_MT(i + 1), zmm_M(i), zmm_M(i + 1));
        }
        for (int i = 0; i < simd_w; i += 4) {
            vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i + 2));
            vunpckhpd(zmm_M(i + 1), zmm_MT(i), zmm_MT(i + 2));
            vunpcklpd(zmm_M(i + 2), zmm_MT(i + 1), zmm_MT(i + 3));
            vunpckhpd(zmm_M(i + 3), zmm_MT(i + 1), zmm_MT(i + 3));
        }
        for (int i = 0; i < simd_w; i += 8) {
            vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88);
            vshuff32x4(zmm_MT(i + 1), zmm_M(i + 1), zmm_M(i + 5), 0x88);
            vshuff32x4(zmm_MT(i + 2), zmm_M(i + 2), zmm_M(i + 6), 0x88);
            vshuff32x4(zmm_MT(i + 3), zmm_M(i + 3), zmm_M(i + 7), 0x88);
            vshuff32x4(zmm_MT(i + 4), zmm_M(i), zmm_M(i + 4), 0xdd);
            vshuff32x4(zmm_MT(i + 5), zmm_M(i + 1), zmm_M(i + 5), 0xdd);
            vshuff32x4(zmm_MT(i + 6), zmm_M(i + 2), zmm_M(i + 6), 0xdd);
            vshuff32x4(zmm_MT(i + 7), zmm_M(i + 3), zmm_M(i + 7), 0xdd);
        }
        // Final lane-combine stage; results are stored row by row.
        {
            int i = 0;
            int mask = 0x88;
            vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask);
            vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0));
            vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask);
            vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1));
            vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask);
            vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2));
            vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask);
            vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3));
            vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask);
            vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4));
            vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask);
            vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5));
            vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask);
            vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6));
            vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask);
            vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7));
            mask = 0xdd;
            vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask);
            vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8));
            vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask);
            vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9));
            vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask);
            vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10));
            vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask);
            vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11));
            vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask);
            vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12));
            vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask);
            vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13));
            vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask);
            vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14));
            vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask);
            vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15));
        }
    };

    // Copy the filter block from src into the M scratch buffer. Forward:
    // straight copy. Backward: read the spatially mirrored tap
    // (2 - j, 2 - i) and transpose its 16x16 block via trans16x16.
    auto load_src = [=]() {
        mov(wreg_src, ptr[param1 + GET_OFF(src)]);
        mov(wreg_F, ptr[param1 + GET_OFF(M)]);
        for (int j = 0; j < kh; j++) {
            for (int i = 0; i < kw; i++) {
                if (is_fwd) {
                    for (int v1 = 0; v1 < simd_w; v1++) {
                        int offset_src
                                = (j * kw * simd_w * simd_w
                                          + i * simd_w * simd_w + v1 * simd_w)
                                * typesize;
                        int offset_F
                                = (j * kw * simd_w * simd_w
                                          + i * simd_w * simd_w + v1 * simd_w)
                                * typesize;
                        vmovups(zmm_temp, ptr[wreg_src + offset_src]);
                        vmovups(ptr[wreg_F + offset_F], zmm_temp);
                    }
                } else {
                    int offset_src = ((2 - j) * kw * simd_w * simd_w
                                             + (2 - i) * simd_w * simd_w)
                            * typesize;
                    int offset_F
                            = (j * kw * simd_w * simd_w + i * simd_w * simd_w)
                            * typesize;
                    lea(wreg_M, ptr[wreg_src + offset_src]);
                    lea(wreg_MT, ptr[wreg_F + offset_F]);
                    trans16x16();
                }
            }
        }
    };

    // Scatter the transformed alpha x alpha block from Mw into dst with
    // non-temporal stores; the first pass over each row touches the pages.
    auto store_dst = [=]() {
        mov(wreg_dst, ptr[param1 + GET_OFF(dst)]);
        mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);

        Label Loop_j;
        mov(wreg_cnt_j, 0);
        mov(wreg_dst_aux, wreg_dst);
        mov(wreg_Fw_aux, wreg_Fw);

        // Stride (in elements) between consecutive alpha positions in dst.
        int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block)
                * jcp.dimK_block * simd_w * simd_w;

        L(Loop_j);
        {
            for (int i = 0; i < alpha; i++) {
                // touch pages
                vmovups(zmm_load(0),
                        ptr[wreg_Fw_aux + (i * simd_w * simd_w) * typesize]);
                mov(wreg_dst_idx, i * dim5 * typesize);
                vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0));
            }
            for (int i = 0; i < alpha; i++) {
                for (int v1 = 1; v1 < simd_w; v1++) {
                    int offset_Fw
                            = (i * simd_w * simd_w + v1 * simd_w) * typesize;
                    vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]);
                }
                mov(wreg_dst_idx, i * dim5 * typesize);
                for (int v1 = 1; v1 < simd_w; v1++) {
                    int offset_dst = v1 * simd_w * typesize;
                    vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst],
                            zmm_load(v1));
                }
            }
            add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize);
            add(wreg_dst_aux, alpha * dim5 * typesize);
            add(wreg_cnt_j, 1);
            cmp(wreg_cnt_j, alpha);
            jl(Loop_j, T_NEAR);
        }
    };

    // Two-stage Winograd weight transform applied per 16-element lane:
    // first each of the 3 filter columns is expanded from 3 to 6 rows
    // (stored in T), then each of the 6 rows of T is expanded from 3 to
    // 6 columns and written to Mw.
    auto trans_W_4x4_3x3 = [=]() {
        // dst = a + b * c
        auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vmovups(dst, a);
            vfmadd231ps(dst, b, c);
        };
        // dst = a - b * c
        auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vmulps(zmm_temp, b, c);
            vsubps(dst, a, zmm_temp);
        };
        // dst = -a - b * c
        auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vsubps(dst, zmm_zero, a);
            vfnmadd231ps(dst, b, c);
        };

        mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
        mov(wreg_F, ptr[param1 + GET_OFF(M)]);
        mov(wreg_T, ptr[param1 + GET_OFF(T)]);

        Label Loop_j;
        mov(wreg_cnt_j, 0);
        L(Loop_j);
        mov(wreg_F_aux, wreg_F);
        mov(wreg_Fw_aux, wreg_Fw);
        mov(wreg_temp, wreg_cnt_j);
        shl(wreg_temp, 4 + 2); // lane byte offset: cnt_j * 16 * 4
        lea(wreg_F_aux, ptr[wreg_F + wreg_temp]);
        lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]);

        // Stage 1: columns of F -> 6 rows in T.
        for (int i = 0; i < 3; i++) {
            for (int idx = 0; idx < 3; idx++) {
                vmovups(zmm_F(idx),
                        ptr[wreg_F_aux
                                + (idx * 3 * simd_w * simd_w
                                          + i * simd_w * simd_w)
                                        * typesize]);
            }
            vmulps(zmm_t(0), zmm_G(0), zmm_F(2));
            fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0));
            fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0));

            vmulps(zmm_T(0), zmm_G(3), zmm_F(0));
            fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1));
            fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1));
            fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1));
            fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1));
            vmovaps(zmm_T(5), zmm_F(2));

            for (int idx = 0; idx < 6; idx++) {
                vmovups(ptr[wreg_T
                                + (idx * 3 * simd_w + i * simd_w) * typesize],
                        zmm_T(idx));
            }
        }
        // Stage 2: rows of T -> 6x6 block in Mw.
        for (int i = 0; i < 6; i++) {

            for (int idx = 0; idx < 3; idx++) {
                vmovups(zmm_T(idx),
                        ptr[wreg_T
                                + (i * 3 * simd_w + idx * simd_w) * typesize]);
            }
            vmulps(zmm_t(0), zmm_G(0), zmm_T(2));
            fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0));
            fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0));

            vmulps(zmm_F(0), zmm_G(3), zmm_T(0));
            fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1));
            fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1));
            fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1));
            fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1));
            vmovaps(zmm_F(5), zmm_T(2));

            for (int l = 0; l < 6; l++) {
                vmovups(ptr[wreg_Fw_aux
                                + (i * 6 * simd_w * simd_w
                                          + l * simd_w * simd_w)
                                        * typesize],
                        zmm_F(l));
            }
        }
        add(wreg_cnt_j, 1);
        cmp(wreg_cnt_j, 16); // one iteration per simd lane
        jl(Loop_j, T_NEAR);
    };

    auto inner_loops = [=]() {
        load_src();
        init_G();
        trans_W_4x4_3x3();
        store_dst();
    };

    preamble();
    inner_loops();
    postamble();
}
657
void _jit_avx512_core_f32_wino_conv_4x3_data_kernel ::
        output_transform_data_ker_generate() {
    // Generates the output (inverse) transform: gathers one alpha x alpha
    // tile from the gemm output, reduces it to a tile_size x tile_size
    // spatial tile (trans_O_4x4_3x3), then applies bias / relu / sum
    // post-ops and stores the in-bounds part of the tile to dst.
    bool is_fwd = one_of(
            jcp.prop_kind, dnnl_forward_training, dnnl_forward_inference);
    int outw = is_fwd ? jcp.ow : jcp.iw;
    int outh = is_fwd ? jcp.oh : jcp.ih;
    bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
    bool with_bias = jcp.with_bias;
    bool with_relu = jcp.with_eltwise;
    bool with_relu_postsum = jcp.with_relu_postsum;
    bool with_sum = jcp.with_sum;

    auto zmm_zero = Xbyak::Zmm(0);
    auto zmm_temp = Xbyak::Zmm(31);
    // G coefficients, tile rows O, intermediates T and temporaries t.
    auto zmm_G = [=](int i) { return Xbyak::Zmm(1 + i); };
    auto zmm_O = [=](int i) { return Xbyak::Zmm(1 + alpha + i); };
    auto zmm_T = [=](int i) { return Xbyak::Zmm(1 + 2 * alpha + i); };
    auto zmm_t = [=](int i) { return Xbyak::Zmm(1 + 3 * alpha + i); };

    // Broadcast the 6 inverse-transform coefficients from the G table.
    auto init_G = [=]() {
        mov(oreg_temp, ptr[param1 + GET_OFF(G)]);
        for (int i = 0; i < 6; i++) {
            vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]);
        }
    };

    // Gather the alpha x alpha tile for the current (tile_block,
    // nb_tile_block_ur, tile_block_ur) position into the contiguous Ow
    // scratch buffer.
    auto load_src = [=]() {
        mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
        mov(oreg_src, ptr[param1 + GET_OFF(src)]);

        mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
        imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur,
                (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
                        * jcp.dimM_simd_block * typesize);
        add(oreg_src, oreg_nb_tile_block_ur);

        mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
        imul(oreg_tile_block_ur, oreg_tile_block_ur,
                jcp.dimM_simd_block * typesize);
        add(oreg_src, oreg_tile_block_ur);

        // Under the non-tiled schedule the tile_block index selects an
        // additional outer slice of src.
        if (not_tiled) {
            mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
            imul(oreg_tile_block, oreg_tile_block,
                    jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block
                            * (jcp.dimM_block * jcp.dimM_reg_block)
                            * jcp.dimN_reg_block * jcp.dimM_simd_block
                            * typesize);
            add(oreg_src, oreg_tile_block);
        }

        int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block)
                * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize;
        for (int j = 0; j < alpha; j++) {
            for (int i = 0; i < alpha; i++) {
                int j_base_offset = j * alpha * last4dim;
                int i_base_offset = i * last4dim;
                vmovups(zmm_temp,
                        ptr[oreg_src + j_base_offset + i_base_offset]);
                vmovups(ptr[oreg_Ow
                                + (j * alpha * simd_w + i * simd_w) * typesize],
                        zmm_temp);
            }
        }
    };

    // Apply post-ops and write the tile to dst, skipping rows/columns that
    // fall outside the output (right/bottom edge tiles).
    auto store_dst = [=]() {
        vpxord(zmm_zero, zmm_zero, zmm_zero);
        mov(oreg_dst, ptr[param1 + GET_OFF(dst)]);
        mov(oreg_O, ptr[param1 + GET_OFF(M)]);
        mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]);
        shl(oreg_ydim, 2); // tj * tile_size (==4)
        mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]);
        shl(oreg_xdim, 2); // ti * tilesize (==4)

        if (with_bias) mov(oreg_bias, ptr[param1 + GET_OFF(bias)]);

        // Store one simd_w vector of the tile, applying bias/relu (fwd
        // only), then sum and optional post-sum relu.
        auto store_one = [=](int j, int i, bool is_aligned) {
            auto zmm_O = Xbyak::Zmm(31);
            auto zmm_relu_ns = Xbyak::Zmm(30);
            auto xmm_relu_ns = Xbyak::Xmm(30);
            int offset = (j * tile_size * simd_w + i * simd_w) * typesize;

            vmovups(zmm_O, ptr[oreg_O + offset]);
            if (is_fwd) {
                if (with_bias) { vaddps(zmm_O, zmm_O, ptr[oreg_bias]); }
                if (with_relu) {
                    if (jcp.eltwise.alpha == 0) {
                        // Plain relu: max(x, 0).
                        vmaxps(zmm_O, zmm_O, zmm_zero);
                    } else {
                        // Leaky relu: scale negative lanes by alpha
                        // under an opmask.
                        Opmask kmask = Opmask(7);
                        mov(imm_addr64, float2int(jcp.eltwise.alpha));
                        vmovq(xmm_relu_ns, imm_addr64);
                        vbroadcastss(zmm_relu_ns, xmm_relu_ns);
                        vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
                        vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
                    }
                }
            }
            if (with_sum) {
                vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]);
                if (with_relu_postsum) // orig: with_relu_postsum
                    vmaxps(zmm_O, zmm_O, zmm_zero);
            }
            if (is_aligned)
                vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O);
            else
                vmovups(ptr[oreg_out_j + oreg_temp], zmm_O);
        };

        // Iterate the tile columns, skipping those at or past outw.
        auto i_loop = [=](int j, bool is_aligned) {
            for (int i = 0; i < tile_size; i++) {
                Label next;
                mov(oreg_temp, oreg_xdim);
                add(oreg_temp, i);
                cmp(oreg_temp, outw);
                jge(next, T_NEAR);
                shl(oreg_temp, 4 + 2); // * 16 * 4

                store_one(j, i, is_aligned);

                L(next);
            }
        };

        for (int j = 0; j < tile_size; j++) {
            Label next, unaligned;
            mov(oreg_temp, oreg_ydim);
            add(oreg_temp, j);
            cmp(oreg_temp, outh);
            jge(next, T_NEAR);

            mov(oreg_out_j, oreg_dst);
            imul(oreg_temp, oreg_temp, outw * simd_w * typesize);
            add(oreg_out_j, oreg_temp);

            // Choose streaming vs regular stores from dst's 64-byte
            // alignment at run time.
            test(oreg_dst, 63);
            jnz(unaligned, T_NEAR);

            i_loop(j, true);
            jmp(next, T_NEAR);

            L(unaligned);
            i_loop(j, false);

            L(next);
        }
    };

    // Inverse transform: reduce the alpha x alpha tile in Ow to a
    // tile_size x tile_size tile in O — columns first (via T), then rows.
    auto trans_O_4x4_3x3 = [=]() {
        // dst = v1 * u1 + v2 * u2
        auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2) {
            vmulps(dst, v1, u1);
            vfmadd231ps(dst, v2, u2);
        };
        mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
        mov(oreg_T, ptr[param1 + GET_OFF(T)]);
        mov(oreg_O, ptr[param1 + GET_OFF(M)]);

        for (int i = 0; i < alpha; i++) {
            for (int j = 0; j < alpha; j++) {
                vmovups(zmm_O(j),
                        ptr[oreg_Ow
                                + (j * alpha * simd_w + i * simd_w)
                                        * typesize]);
            }

            vaddps(zmm_t(0), zmm_O(1), zmm_O(2));
            vaddps(zmm_t(1), zmm_O(3), zmm_O(4));
            vsubps(zmm_t(2), zmm_O(1), zmm_O(2));
            vsubps(zmm_t(3), zmm_O(3), zmm_O(4));

            vaddps(zmm_T(0), zmm_t(0), zmm_t(1));
            vaddps(zmm_T(0), zmm_T(0), zmm_O(0));
            fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
            fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
            fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
            vaddps(zmm_T(3), zmm_T(3), zmm_O(5));

            for (int j = 0; j < tile_size; j++) {
                vmovups(ptr[oreg_T
                                + (j * alpha * simd_w + i * simd_w) * typesize],
                        zmm_T(j));
            }
        }
        for (int j = 0; j < tile_size; j++) {
            for (int i = 0; i < alpha; i++) {
                vmovups(zmm_T(i),
                        ptr[oreg_T
                                + (j * alpha * simd_w + i * simd_w)
                                        * typesize]);
            }
            vaddps(zmm_t(0), zmm_T(1), zmm_T(2));
            vaddps(zmm_t(1), zmm_T(3), zmm_T(4));
            vsubps(zmm_t(2), zmm_T(1), zmm_T(2));
            vsubps(zmm_t(3), zmm_T(3), zmm_T(4));

            vaddps(zmm_O(0), zmm_t(0), zmm_t(1));
            vaddps(zmm_O(0), zmm_O(0), zmm_T(0));
            fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
            fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
            fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
            vaddps(zmm_O(3), zmm_O(3), zmm_T(5));

            for (int i = 0; i < tile_size; i++) {
                vmovups(ptr[oreg_O
                                + (j * tile_size * simd_w + i * simd_w)
                                        * typesize],
                        zmm_O(i));
            }
        }
    };

    auto inner_loops = [=]() {
        init_G();
        load_src();
        trans_O_4x4_3x3();
        store_dst();
    };

    preamble();
    inner_loops();
    postamble();
}
881
882void _jit_avx512_core_f32_wino_conv_4x3_data_kernel ::
883 input_transform_data_ker_generate() {
884 bool is_fwd = one_of(
885 jcp.prop_kind, dnnl_forward_training, dnnl_forward_inference);
886 int inpw = is_fwd ? jcp.iw : jcp.ow;
887 int inph = is_fwd ? jcp.ih : jcp.oh;
888 int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
889 int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
890 int wp_max = inpw + l_pad;
891 int hp_max = inph + t_pad;
892 bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
893 int G_size = 9;
894
895 auto zmm_zero = Xbyak::Zmm(0);
896 auto zmm_temp = Xbyak::Zmm(31);
897 auto zmm_G = [=](int i) { return Xbyak::Zmm(1 + i); };
898 auto zmm_I = [=](int i) { return Xbyak::Zmm(1 + G_size + i); };
899 auto zmm_T = [=](int i) { return Xbyak::Zmm(1 + G_size + alpha + i); };
900 auto zmm_t = [=](int i) { return Xbyak::Zmm(1 + G_size + 2 * alpha + i); };
901
902 auto init_G = [=]() {
903 mov(ireg_temp, ptr[param1 + GET_OFF(G)]);
904 for (int i = 0; i < G_size; i++) {
905 vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]);
906 }
907 };
908
909 auto load_src = [=]() {
910 mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp
911 mov(ireg_I, ptr[param1 + GET_OFF(M)]);
912
913 xor_(ireg_zero, ireg_zero);
914 vpxord(zmm_zero, zmm_zero, zmm_zero);
915
916 mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]);
917 shl(ireg_ydim, 2); // tj * tile_size (==4)
918 mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]);
919 shl(ireg_xdim, 2); // ti * tilesize (==4)
920
921 for (int j = 0; j < alpha; j++) {
922 mov(ireg_temp, ireg_ydim);
923 add(ireg_temp, j);
924
925 mov(ireg_mask_j, 0xffff);
926 cmp(ireg_temp, t_pad);
927 cmovl(ireg_mask_j, ireg_zero);
928 cmp(ireg_temp, hp_max);
929 cmovge(ireg_mask_j, ireg_zero);
930
931 sub(ireg_temp, t_pad);
932 imul(ireg_temp, ireg_temp, inpw * simd_w * typesize);
933 mov(ireg_inp_j, ireg_src);
934 add(ireg_inp_j, ireg_temp);
935
936 for (int i = 0; i < alpha; i++) {
937
938 mov(ireg_temp, ireg_xdim);
939 add(ireg_temp, i);
940
941 mov(ireg_mask, 0xffff);
942 cmp(ireg_temp, l_pad);
943 cmovl(ireg_mask, ireg_zero);
944 cmp(ireg_temp, wp_max);
945 cmovge(ireg_mask, ireg_zero);
946 and_(ireg_mask, ireg_mask_j);
947
948 sub(ireg_temp, l_pad);
949 shl(ireg_temp, 4 + 2);
950
951 vpxord(zmm_temp, zmm_temp, zmm_temp);
952 Opmask kmask = Opmask(7);
953 kmovw(kmask, ireg_mask_32);
954 vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]);
955 vmovups(ptr[ireg_I
956 + (j * alpha * simd_w + i * simd_w) * typesize],
957 zmm_temp);
958 }
959 }
960 };
961
962 auto store_Iw = [=]() {
963 mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
964 mov(ireg_output, ptr[param1 + GET_OFF(dst)]);
965
966 bool streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
967 > 2 * LLC_data_size * jcp.nthr
968 ? true
969 : false;
970
971 if (not_tiled) {
972 mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
973 imul(ireg_tile_block, ireg_tile_block,
974 alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block
975 * jcp.dimK_block * jcp.dimN_reg_block
976 * jcp.dimK_reg_block * typesize);
977 }
978
979 mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
980 imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur,
981 jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block
982 * jcp.dimK_reg_block * typesize);
983
984 mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
985 imul(ireg_tile_block_ur, ireg_tile_block_ur,
986 jcp.dimK_reg_block * typesize);
987
988 add(ireg_output, ireg_nb_tile_block_ur);
989 add(ireg_output, ireg_tile_block_ur);
990 if (not_tiled) add(ireg_output, ireg_tile_block);
991
992 for (int j = 0; j < alpha; j++) {
993 for (int i = 0; i < alpha; i++) {
994 vmovups(zmm_temp,
995 ptr[ireg_Iw
996 + (j * alpha * simd_w + i * simd_w)
997 * typesize]);
998
999 int j_base_offset = j * alpha * jcp.dimN_block
1000 * jcp.dimK_nb_block * jcp.dimK_block
1001 * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
1002 int i_base_offset = i * jcp.dimN_block * jcp.dimK_nb_block
1003 * jcp.dimK_block * jcp.dimN_reg_block
1004 * jcp.dimK_reg_block * typesize;
1005
1006 if (not_tiled && streamout)
1007 vmovntps(ptr[ireg_output + j_base_offset + i_base_offset],
1008 zmm_temp);
1009 else
1010 vmovups(ptr[ireg_output + j_base_offset + i_base_offset],
1011 zmm_temp);
1012 }
1013 }
1014 };
1015
1016 auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
1017 vmulps(zmm_temp, a, b);
1018 vaddps(dst, zmm_temp, c);
1019 };
1020
1021 auto trans_I_4x4_3x3 = [=]() {
1022 mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
1023 mov(ireg_T, ptr[param1 + GET_OFF(T)]);
1024 mov(ireg_I, ptr[param1 + GET_OFF(M)]);
1025
1026 mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch
1027 for (int i = 0; i < alpha; i++) {
1028 for (int idx = 0; idx < alpha; idx++) {
1029 vmovups(zmm_I(idx),
1030 ptr[ireg_I
1031 + (idx * alpha * simd_w + i * simd_w)
1032 * typesize]);
1033 int j_base_offset = i * alpha * jcp.dimN_block
1034 * jcp.dimK_nb_block * jcp.dimK_block
1035 * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
1036 int idx_base_offset = idx * jcp.dimN_block * jcp.dimK_nb_block
1037 * jcp.dimK_block * jcp.dimN_reg_block
1038 * jcp.dimK_reg_block * typesize;
1039 prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]);
1040 }
1041
1042 fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
1043 fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
1044 fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
1045 fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
1046 fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
1047 fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
1048
1049 fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
1050 fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
1051 fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
1052 fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
1053 fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
1054 fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
1055
1056 for (int idx = 0; idx < alpha; idx++) {
1057 vmovups(ptr[ireg_T
1058 + (idx * alpha * simd_w + i * simd_w)
1059 * typesize],
1060 zmm_T(idx));
1061 }
1062 }
1063 for (int i = 0; i < alpha; i++) {
1064 for (int idx = 0; idx < alpha; idx++) {
1065 vmovups(zmm_T(idx),
1066 ptr[ireg_T
1067 + (i * alpha * simd_w + idx * simd_w)
1068 * typesize]);
1069 }
1070
1071 fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
1072 fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
1073 fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
1074 fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
1075 fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
1076 fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
1077
1078 fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
1079 fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
1080 fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
1081 fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
1082 fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
1083 fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
1084
1085 for (int idx = 0; idx < alpha; idx++) {
1086 vmovups(ptr[ireg_Iw
1087 + (i * alpha * simd_w + idx * simd_w)
1088 * typesize],
1089 zmm_I(idx));
1090 }
1091 }
1092 };
1093
1094 auto inner_loops = [=]() {
1095 init_G();
1096 load_src();
1097 trans_I_4x4_3x3();
1098 store_Iw();
1099 };
1100
1101 preamble();
1102 inner_loops();
1103 postamble();
1104}
1105
// Fills the configuration fields shared by the forward and backward-data
// 4x3 Winograd kernels from the operation descriptors, and rejects any
// problem shape this implementation cannot handle.
//
// Returns status::success when the problem is supported, otherwise
// status::unimplemented (never modifies descriptors).
status_t _jit_avx512_core_f32_wino_conv_4x3_data_kernel::init_conf_common(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &dst_d) {
    if (!mayiuse(avx512_core)) { return status::unimplemented; }

    // This kernel only supports 2D convolutions.
    if (src_d.ndims() != 4) return status::unimplemented;

    jcp.nthr = dnnl_get_max_threads();

    jcp.prop_kind = cd.prop_kind;

    // With groups the weights have one extra leading dimension (g).
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

    // Problem sizes; oc/ic are per-group.
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];
    // Derive right/bottom padding from the output size (clamped at 0).
    jcp.r_pad = nstl::max(
            0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
    jcp.b_pad = nstl::max(
            0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
    // Padded input extents; output is used unpadded.
    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;

    // Channels may be rounded up to the SIMD width only when there are no
    // groups (per-group channel padding would corrupt group boundaries).
    bool ok_to_pad_channels = jcp.ngroups == 1;
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    // Checking conditions not supported by these kernels
    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
                is_winograd_faster_than_direct(jcp)))
        return status::unimplemented;

    // Supported shape: 3x3 filter, no groups, SIMD-aligned channels,
    // unit strides, no dilation, and at most one pixel of padding per side.
    const bool prb_shape_ok = jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1
            && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 && jcp.stride_h == 1
            && jcp.stride_w == 1 && jcp.dilate_h == 0 && jcp.dilate_w == 0
            && jcp.l_pad <= 1 && jcp.r_pad <= 1 && jcp.t_pad <= 1
            && jcp.b_pad <= 1;
    if (!prb_shape_ok) return status::unimplemented;

    // Activations must be in blocked nChw16c layout.
    format_tag_t dat_tag = nChw16c;
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);

    if (jcp.src_tag != dat_tag) return status::unimplemented;
    if (jcp.dst_tag != dat_tag) return status::unimplemented;

    // Weights: either already/possibly in the wino-specific layout, or in
    // the plain blocked (g)OIhw16i16o layout.
    if (!one_of(weights_d.format_kind(), format_kind::any, format_kind::wino)) {
        format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
        jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;
    }

    // The (possibly rounded-up) channel counts must fit in the descriptors'
    // padded dimensions.
    bool layout_consistency = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1]
            && (one_of(weights_d.format_kind(), format_kind::any,
                        format_kind::wino)
                    || (jcp.ic <= weights_d.padded_dims()[with_groups + 1]
                            && jcp.oc <= weights_d.padded_dims()[with_groups
                                    + 0]));
    if (!layout_consistency) return status::unimplemented;

    return status::success;
}
1189
1190void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) {
1191
1192 /* ----------- dimM reg block ---------------------*/
1193 auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp,
1194 int dimM_reg_block,
1195 int current_best) {
1196 int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4;
1197 return (dimM_reg_block >= 1) && (dimM_reg_block <= max_dimM_reg_block)
1198 && (dimM_reg_block > current_best);
1199 };
1200 jcp.dimM_reg_block = get_divisor_satisfying_cond(
1201 jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond_dimM_reg_block);
1202
1203 /* ----------- dimN reg block ---------------------*/
1204
1205 auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
1206 int dimN_reg_block,
1207 int current_best) {
1208 return jcp.kernel_kind == embd_bcast
1209 ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best
1210 : dimN_reg_block >= 1
1211 && (dimN_reg_block * jcp.dimM_reg_block
1212 + dimN_reg_block)
1213 < jcp.nb_reg
1214 && dimN_reg_block > current_best;
1215 };
1216 jcp.dimN_reg_block = get_divisor_satisfying_cond(
1217 jcp, jcp.dimN, 1, test_cond_dimN_reg_block);
1218}
1219
// Tries to configure the W_SGD schedule (embedded-broadcast kernel).
// The blocking is accepted only if it passes the L2/L1 cache-fit
// heuristics and provides enough parallel work (>= 1.5 blocks per
// thread); otherwise returns status::unimplemented so the caller can
// fall back to the W_S_G_D schedule.
status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {

    jcp.kernel_kind = embd_bcast;

    set_kernel_dims_reg_block(jcp);

    /*-------------- L2 blocking for dimN block ---------*/

    // Largest dimN divisor whose working set fits L2 per thread and that
    // still leaves at least 1.5 outer blocks per thread to parallelize.
    auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp,
                                        int dimN_block, int current_best) {
        return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0)
                && (dimN_block > current_best)
                && ((jcp.dimN / dimN_block / jcp.dimN_reg_block)
                        >= 1.5 * jcp.nthr);
    };

    jcp.dimN_block = get_divisor_satisfying_cond(
            jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
    jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;

    // Re-validate the chosen dimN block with a looser upper bound before
    // committing to this schedule.
    if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2)
            && (jcp.dimN_nb_block >= 1.5 * jcp.nthr)) {

        /* ------------------- L1 blocking for GEMM --------------*/
        /* -------------------- Choose dimK block ----------------*/

        // Largest dimK divisor whose GEMM tile fits the L1 heuristic.
        auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp,
                                            int dimK_block, int current_best) {
            return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5)
                    && (dimK_block > current_best);
        };

        jcp.dimK_block = get_divisor_satisfying_cond(
                jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block);

        if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) {
            jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;

            /* -------------- Choose dimM block -------------------*/
            // Largest dimM divisor that keeps the (dimK_block x dimM_block)
            // tile within the L1 heuristic.
            auto test_cond_dimM_block
                    = [](jit_conv_winograd_conf_t &jcp, int dimM_block,
                              int current_best) {
                          return check_L1_block_gemm(jcp, jcp.dimK_block,
                                         dimM_block, 0.2, 0.5)
                                  && (dimM_block > current_best);
                      };

            jcp.dimM_block = get_divisor_satisfying_cond(jcp,
                    jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1,
                    test_cond_dimM_block);
            jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
                    / jcp.dimM_simd_block;

            jcp.sched_policy = WSCHED_DATA_W_SGD;
            return status::success;
        }
    }
    return status::unimplemented;
}
1279
// Computes the full dimK/dimM/dimN blocking for the W_S_G_D schedule.
// For dimK (and then dimM) it first tries the stricter "cond1_bis"
// streaming criterion; if the whole dimK does not fit a single block it
// falls back to the plain "cond1" criterion (see comment below).
void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) {

    set_kernel_dims_reg_block(jcp);

    //********************* Choosing dimK_block **********************//
    // Cache-fit predicate [1]: largest dimK divisor passing check_cond1.
    auto test_cond1_dimK_block = [](jit_conv_winograd_conf_t &jcp,
                                         int dimK_block, int current_best) {
        return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
                       1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f)
                && (dimK_block > current_best);
    };

    // Stricter predicate [1-bis] used when streaming stores are possible.
    auto test_cond1_bis_dimK_block = [](jit_conv_winograd_conf_t &jcp,
                                             int dimK_block, int current_best) {
        return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
                       jcp.dimK_reg_block, 1, jcp.dimM_reg_block,
                       jcp.dimM_simd_block, .9f)
                && (dimK_block > current_best);
    };

    jcp.dimK_block = get_divisor_satisfying_cond(
            jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
    // If we are not able to use streams, we fall back to condition [1]
    if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
        jcp.dimK_block = get_divisor_satisfying_cond(
                jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
    jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;

    //********************* Choosing dimM_block **********************//
    // Same two predicates, now varying dimM_block with dimK_block fixed.
    auto test_cond1_dimM_block = [](jit_conv_winograd_conf_t &jcp,
                                         int dimM_block, int current_best) {
        return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
                       jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
                       jcp.dimM_simd_block, .5f)
                && (dimM_block > current_best);
    };

    auto test_cond1_bis_dimM_block = [](jit_conv_winograd_conf_t &jcp,
                                             int dimM_block, int current_best) {
        return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
                       jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
                       jcp.dimM_simd_block, .3f)
                && (dimM_block > current_best);
    };

    // Use [1] when dimK had to be split (no streaming), else [1-bis].
    if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
        jcp.dimM_block = get_divisor_satisfying_cond(jcp,
                jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1,
                test_cond1_dimM_block);
    else
        jcp.dimM_block = get_divisor_satisfying_cond(jcp,
                jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1,
                test_cond1_bis_dimM_block);
    jcp.dimM_nb_block = jcp.dimM
            / (jcp.dimM_simd_block * jcp.dimM_block * jcp.dimM_reg_block);

    //******************* Choosing dimN_block *******************//
    // Outer dimN block is bounded by the check_cond2 working-set heuristic.
    auto test_cond2_dimN_block = [](jit_conv_winograd_conf_t &jcp,
                                         int dimN_block, int current_best) {
        return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
                       jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
                       jcp.dimM_reg_block, jcp.dimM_simd_block, .9f)
                && (dimN_block > current_best);
    };

    jcp.dimN_block = get_divisor_satisfying_cond(
            jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
    jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
}
1349
1350status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) {
1351
1352 jcp.kernel_kind = expl_bcast;
1353 set_kernel_blocking_DATA_W_S_G_D(jcp);
1354 if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block,
1355 jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block,
1356 jcp.dimK, .1f, .35f))) {
1357 jcp.kernel_kind = embd_bcast;
1358 set_kernel_blocking_DATA_W_S_G_D(jcp);
1359 }
1360 jcp.sched_policy = WSCHED_DATA_W_S_G_D;
1361 return status::success;
1362}
1363
1364status_t _jit_avx512_core_f32_wino_conv_4x3_data_kernel::init_conf_kernel(
1365 jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) {
1366 jcp.nb_reg = 32;
1367 jcp.dimN = dimN;
1368 jcp.dimK = dimK;
1369 jcp.dimM = dimM;
1370 jcp.sched_policy = WSCHED_INVALID;
1371
1372 jcp.dimK_reg_block = 16;
1373 jcp.dimM_simd_block = 16;
1374
1375 if (jcp.kernel_kind == embd_bcast) { jcp.dimM_reg_block = 1; }
1376
1377 if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success))
1378 set_wsched_DATA_W_S_G_D_avx512_core(jcp);
1379
1380 assert(jcp.sched_policy != WSCHED_INVALID);
1381 return status::success;
1382}
1383
1384bool jit_avx512_core_f32_wino_conv_4x3_fwd_kernel::post_ops_ok(
1385 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1386 const auto &p = attr.post_ops_;
1387
1388 auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
1389 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
1390
1391 switch (p.len()) {
1392 case 0: return true; // no post_ops
1393 case 1: return is_relu(0) || is_sum(0); // relu or sum
1394 case 2:
1395 return (is_sum(0) && is_relu(1))
1396 || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
1397 case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
1398 default: return false;
1399 }
1400}
1401
// Forward-kernel configuration: runs the common checks, validates
// attributes/post-ops, picks the GEMM blocking (dimM=oc, dimN=ntiles,
// dimK=ic), and for inference requests the wino-specific weights layout.
status_t jit_avx512_core_f32_wino_conv_4x3_fwd_kernel::init_conf(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_t &src_md, memory_desc_t &weights_md,
        const memory_desc_t &dst_md, const primitive_attr_t &attr) {

    status_t st = init_conf_common(jcp, cd, src_md, weights_md, dst_md);

    if (st != status::success) return st;

    // Winograd specific initialization: number of 4x4 output tiles
    // (ceil-divide each spatial dimension by the tile size).
    jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
    jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
    jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    // Record the (already validated) post-op chain: an eltwise in the
    // first two positions, an optional sum, and an optional relu-after-sum.
    const auto &p = attr.post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;

    jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
    jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;

    // Forward GEMM view: M = oc, N = tiles, K = ic.
    status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);

    // Translate generic GEMM blocking into convolution-named fields.
    jcp.ic_simd_block = jcp.dimK_reg_block;
    jcp.ic_block = jcp.dimK_block;
    jcp.nb_ic = jcp.dimK_nb_block;
    jcp.oc_simd_block = jcp.dimM_simd_block;
    jcp.oc_block = jcp.dimM_block;
    jcp.oc_reg_block = jcp.dimM_reg_block;
    jcp.ic_reg_block = 1;
    jcp.nb_oc = jcp.dimM_nb_block;
    jcp.tile_block_ur = jcp.dimN_reg_block;
    jcp.nb_tile_block_ur = jcp.dimN_block;
    jcp.tile_block = jcp.dimN_nb_block;

    /* re-create weights primitive descriptor
    and set weights wino_blocking */
    if (cd.prop_kind == dnnl_forward_inference) {
        memory_desc_t expect_wei_md = weights_md;

        // Describe the pre-transformed weights layout this kernel expects:
        // F(4x4, 3x3) => r = 3, alpha = 4 + 3 - 1 = 6.
        expect_wei_md.format_kind = format_kind::wino;
        expect_wei_md.data_type = data_type::f32;
        wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
        wd.wino_format = wino_memory_format_t::wino_wei_OBaaIBOIio;
        wd.r = 3;
        wd.alpha = 6;

        wd.ic = jcp.ic;
        wd.oc = jcp.oc;
        wd.ic_block = jcp.dimK_reg_block;
        wd.oc_block = jcp.dimM_simd_block;
        wd.ic2_block = jcp.dimK_block;
        wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block;
        // Transformed weights: alpha^2 matrices of ic x oc floats.
        size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc;
        wd.size = max_size;
        wd.adj_scale = 1.f;

        // Accept 'any' by imposing our layout; otherwise require an exact
        // match with the expected wino descriptor.
        if (weights_md.format_kind == format_kind::any)
            weights_md = expect_wei_md;
        if (weights_md != expect_wei_md) return status::unimplemented;
    }

    return res;
}
1471
1472status_t jit_avx512_core_f32_wino_conv_4x3_bwd_data_kernel::init_conf(
1473 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1474 const memory_desc_wrapper &diff_src_d,
1475 const memory_desc_wrapper &weights_d,
1476 const memory_desc_wrapper &diff_dst_d) {
1477 status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
1478
1479 if (st != status::success) return st;
1480
1481 jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
1482 jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
1483 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
1484
1485 status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
1486
1487 jcp.oc_simd_block = jcp.dimK_reg_block;
1488 jcp.oc_block = jcp.dimK_block;
1489 jcp.nb_oc = jcp.dimK_nb_block;
1490 jcp.ic_simd_block = jcp.dimM_simd_block;
1491 jcp.ic_block = jcp.dimM_block;
1492 jcp.ic_reg_block = jcp.dimM_reg_block;
1493 jcp.oc_reg_block = 1;
1494 jcp.nb_ic = jcp.dimM_nb_block;
1495 jcp.tile_block_ur = jcp.dimN_reg_block;
1496 jcp.nb_tile_block_ur = jcp.dimN_block;
1497 jcp.tile_block = jcp.dimN_nb_block;
1498
1499 return res;
1500}
1501
// Emits the JIT source-transform routine for the backward-weights pass:
// for each 4x4 tile it gathers a zero-padded 6x6 (alpha x alpha) input
// patch and applies the Winograd input transform B^T * d * B via the
// 9 broadcast coefficients in G. Two outer drivers are generated,
// selected at codegen time by jcp.sched_policy.
//
// Register plan: zmm0..8 hold the G coefficients, then three groups of
// `alpha` registers for I (patch rows), T (intermediate), and t (temps).
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::
        src_transform_generate() {
    constexpr int G_size = 9;
    // First out-of-range coordinate on the right/bottom, in padded space.
    const size_t ifwp = jcp.iw + jcp.l_pad;
    const size_t ifhp = jcp.ih + jcp.t_pad;

    auto zmm_G = [=](int i) { return Xbyak::Zmm(i); };
    auto zmm_I = [=](int i) { return Xbyak::Zmm(G_size + i); };
    auto zmm_T = [=](int i) { return Xbyak::Zmm(G_size + alpha + i); };
    auto zmm_t = [=](int i) { return Xbyak::Zmm(G_size + 2 * alpha + i); };

    // Broadcast each of the 9 transform coefficients into its own zmm.
    auto init_G = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        for (int i = 0; i < G_size; i++) {
            vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
        }
    };

    // Load a 6x6 patch of the source into the M scratch buffer, writing
    // zeros (via a masked load) for positions that fall in the padding.
    auto load_src = [=]() {
        mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
        xor_(reg_zero, reg_zero);

        mov(reg_ydim, reg_tj);
        shl(reg_ydim, 2); // tj * tile_size(=4)

        for (int j = 0; j < alpha; j++) {
            /* check if tile index is within physical spatial boundaries*/
            mov(reg_maskj, 0xffff);
            cmp(reg_ydim, jcp.t_pad);
            cmovl(reg_maskj, reg_zero);
            cmp(reg_ydim, ifhp);
            cmovge(reg_maskj, reg_zero);

            /*address offset for tile in src*/
            mov(reg_src_offset, reg_ydim);
            sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad
            imul(reg_src_offset, reg_src_offset, jcp.iw);

            mov(reg_xdim, reg_ti);
            shl(reg_xdim, 2); // xdim = ti * tile_size

            add(reg_src_offset, reg_xdim);
            sub(reg_src_offset, jcp.l_pad);
            imul(reg_src_offset, reg_src_offset, simd_w * typesize);
            for (int i = 0; i < alpha; i++) {
                /* check if tile index is within physical spatial boundaries*/
                mov(reg_maski, 0xffff);
                cmp(reg_xdim, jcp.l_pad);
                cmovl(reg_maski, reg_zero);
                cmp(reg_xdim, ifwp);
                cmovge(reg_maski, reg_zero);
                // Combined row/column validity mask for the masked load.
                and_(reg_maski, reg_maskj);

                Opmask kmask_src = Xbyak::Opmask(7);
                auto zmm_src = Xbyak::Zmm(31);
                kmovw(kmask_src, reg_maski_32);
                // Zero first so masked-off lanes stay zero (padding).
                vpxord(zmm_src, zmm_src, zmm_src);
                vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]);
                vmovups(ptr[reg_I], zmm_src);

                add(reg_xdim, 1); // xdim = ti * tile_size + i
                add(reg_src_offset, simd_w * typesize);
                add(reg_I, simd_w * typesize);
            }
            add(reg_ydim, 1);
        }
    };

    // dst = a * b + c (vfmadd231: dst += a * b after dst is seeded with c).
    auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) {
        vmovups(dst, c);
        vfmadd231ps(dst, a, b);
    };

    // Apply the input transform: columns first (I -> T), then rows
    // (T -> output), writing results directly at their blocked dst offsets.
    auto trans_I_3x3_4x4 = [=]() {
        //Use 24 registers
        mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
        mov(reg_T, ptr[reg_transp + GET_OFF(T)]);
        for (int i = 0; i < alpha; i++) {
            // Load column i of the 6x6 patch.
            for (int j = 0; j < alpha; j++) {
                size_t I_off = (j * alpha + i) * simd_w * typesize;
                vmovups(zmm_I(j), ptr[reg_I + I_off]);
            }

            fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
            fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
            fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
            fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
            fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
            fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));

            fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
            fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));

            // Store transformed column i into the T scratch buffer.
            for (int j = 0; j < alpha; j++) {
                vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize],
                        zmm_T(j));
            }
        }

        for (int j = 0; j < alpha; j++) {
            // Load row j of the intermediate T.
            for (int i = 0; i < alpha; i++) {
                vmovups(zmm_T(i),
                        ptr[reg_T + (j * alpha + i) * simd_w * typesize]);
            }

            fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
            fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
            fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
            fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
            fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
            fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));

            fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
            fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));

            // Scatter row j into the blocked destination layout, one
            // (j, i) transform component plane at a time.
            for (int i = 0; i < alpha; i++) {
                size_t dst_off
                        = (j * alpha * jcp.ic_block * jcp.nb_tile_block_ur
                                          * jcp.tile_block_ur
                                  + i * jcp.ic_block * jcp.nb_tile_block_ur
                                          * jcp.tile_block_ur)
                        * simd_w * typesize;
                vmovups(ptr[reg_dst + dst_off], zmm_I(i));
            }
        }
    };

    // Driver for the SDGtWo schedule: ti/tj/dst come in via the call
    // params; processes exactly one block of tiles
    // (nb_tile_block_ur * tile_block_ur), walking mb/jtiles/itiles.
    auto compute_transform_SDGtWo = [=]() {
        mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
        mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        xor_(reg_tile_count, reg_tile_count);
        Label loop_mb, loop_jtiles, loop_itiles, done;
        L(loop_mb);
        {
            L(loop_jtiles);
            {
                L(loop_itiles);
                {
                    load_src();

                    trans_I_3x3_4x4();

                    add(reg_tile_count, 1);
                    cmp(reg_tile_count,
                            jcp.nb_tile_block_ur * jcp.tile_block_ur);
                    jge(done);

                    add(reg_dst, simd_w * typesize);
                    add(reg_ti, 1);
                    cmp(reg_ti, jcp.itiles);
                    jl(loop_itiles);
                }
                xor_(reg_ti, reg_ti);
                add(reg_tj, 1);
                cmp(reg_tj, jcp.jtiles);
                jl(loop_jtiles);
            }
            xor_(reg_tj, reg_tj);
            // Next image in the minibatch.
            add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize);
            jmp(loop_mb);
        }
        L(done);
    };

    // Driver for the default schedule: transforms one whole image,
    // resuming at tile_count within the current tile block and hopping to
    // the next tile block when the current one fills up.
    auto compute_transform = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        xor_(reg_ti, reg_ti);
        xor_(reg_tj, reg_tj);

        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
        imul(reg_temp, reg_tile_count, simd_w * typesize);
        add(reg_dst, reg_temp);

        Label loop_jtiles, loop_itiles, next_tile_block, next_tile;
        L(loop_jtiles);

        {
            L(loop_itiles);
            {
                load_src();

                trans_I_3x3_4x4();

                add(reg_tile_count, 1);
                cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
                jge(next_tile_block);
                add(reg_dst, simd_w * typesize);
                jmp(next_tile);

                L(next_tile_block);
                // Rewind to the start of the current tile block, then jump
                // ahead by one full block's worth of transformed data.
                sub(reg_dst,
                        (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) * simd_w
                                * typesize);
                size_t tblk_off = alpha * alpha * jcp.ic_block
                        * jcp.nb_tile_block_ur * jcp.tile_block_ur * simd_w
                        * typesize;
                add(reg_dst, tblk_off);
                xor_(reg_tile_count, reg_tile_count);

                L(next_tile);
                add(reg_ti, 1);
                cmp(reg_ti, jcp.itiles);
                jl(loop_itiles);
            }
            xor_(reg_ti, reg_ti);
            add(reg_tj, 1);
            cmp(reg_tj, jcp.jtiles);
            jl(loop_jtiles);
        }
    };

    preamble();
    init_G();
    if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
        compute_transform_SDGtWo();
    else
        compute_transform();
    postamble();
}
1732
1733void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::
1734 diff_dst_transform_generate(bool with_bias) {
1735
1736 constexpr int G_size = 8;
1737 auto zmm_G = [](int i) { return Xbyak::Zmm(31); };
1738
1739 auto zmm_src = [=](int j, int i) { return Xbyak::Zmm(G_size + j * 4 + i); };
1740
1741 auto zmm_bias = Xbyak::Zmm(31);
1742
1743 auto load_src = [=]() {
1744 if (with_bias) vmovups(zmm_bias, ptr[reg_bias]);
1745 mov(reg_ydim, reg_tj);
1746 shl(reg_ydim, 2); //tj * tile_size(=4)
1747 for (int j = 0; j < tile_size; j++) {
1748 /* check if tile index is within physical spatial boundaries*/
1749 mov(reg_maskj, 0xffff);
1750 cmp(reg_ydim, jcp.oh);
1751 cmovge(reg_maskj, reg_zero);
1752
1753 /*address offset for tile in src*/
1754 mov(reg_src_offset, reg_ydim);
1755 imul(reg_src_offset, reg_src_offset, jcp.ow);
1756
1757 mov(reg_xdim, reg_ti);
1758 shl(reg_xdim, 2); // xdim = ti * tile_size
1759
1760 add(reg_src_offset, reg_xdim);
1761 imul(reg_src_offset, reg_src_offset, simd_w * typesize);
1762 for (int i = 0; i < tile_size; i++) {
1763 /* check if tile index is within physical spatial boundaries*/
1764 mov(reg_maski, 0xffff);
1765 cmp(reg_xdim, jcp.ow);
1766 cmovge(reg_maski, reg_zero);
1767 and_(reg_maski, reg_maskj);
1768
1769 Opmask kmask_src = Xbyak::Opmask(7);
1770 kmovw(kmask_src, reg_maski_32);
1771 vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i));
1772 vmovups(zmm_src(j, i) | kmask_src,
1773 ptr[reg_src + reg_src_offset]);
1774 if (with_bias)
1775 vaddps(zmm_bias | kmask_src, zmm_bias,
1776 ptr[reg_src + reg_src_offset]);
1777
1778 add(reg_xdim, 1); //xdim = ti * tile_size + i
1779 add(reg_src_offset, simd_w * typesize);
1780 }
1781 add(reg_ydim, 1);
1782 }
1783 if (with_bias) vmovups(ptr[reg_bias], zmm_bias);
1784 };
1785
1786 auto zmm_t = [=](int i) { return Xbyak::Zmm(G_size + 16 + i); };
1787
1788 auto zmm_T = [=](int j, int i) { return Xbyak::Zmm(j * 4 + i); };
1789
1790 auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) {
1791 if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
1792 vmovups(ptr[reg_dst + dst_off], a);
1793 else
1794 vmovntps(ptr[reg_dst + dst_off], a);
1795 };
1796
1797 auto trans_W_3x3_4x4 = [=]() {
1798 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
1799 for (int i = 0; i < tile_size; i++) {
1800 vbroadcastss(zmm_G(0), ptr[reg_G]);
1801 vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0));
1802
1803 vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
1804 vmovups(zmm_t(1), zmm_t(0));
1805 vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1));
1806
1807 vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
1808 vmovups(zmm_t(2), zmm_t(0));
1809 vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2));
1810
1811 vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
1812 vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3));
1813
1814 vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
1815 vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4));
1816
1817 vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
1818 vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5));
1819
1820 vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
1821 vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6));
1822
1823 vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
1824 vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7));
1825 vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3));
1826 vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3));
1827 vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4));
1828 vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4));
1829 vmovups(zmm_T(5, i), zmm_src(3, i));
1830 }
1831
1832 for (int j = 0; j < alpha; j++) {
1833 vbroadcastss(zmm_G(0), ptr[reg_G]);
1834 vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0));
1835
1836 vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
1837 vmovups(zmm_t(1), zmm_t(0));
1838 vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1));
1839
1840 vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
1841 vmovups(zmm_t(2), zmm_t(0));
1842 vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2));
1843
1844 vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
1845 vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3));
1846
1847 vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
1848 vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4));
1849
1850 vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
1851 vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5));
1852
1853 vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
1854 vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6));
1855
1856 vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
1857 vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7));
1858 vsubps(zmm_t(5), zmm_t(1), zmm_t(3));
1859 vaddps(zmm_t(1), zmm_t(1), zmm_t(3));
1860 vaddps(zmm_t(6), zmm_t(2), zmm_t(4));
1861 vsubps(zmm_t(2), zmm_t(2), zmm_t(4));
1862 vmovups(zmm_t(3), zmm_T(j, 3));
1863
1864 int alpha_offset = (jcp.oc / jcp.nb_oc)
1865 * (jcp.ntiles / jcp.tile_block) * typesize;
1866 int dst_off = j * alpha * alpha_offset;
1867 movps(reg_dst, dst_off, zmm_t(0));
1868 dst_off += alpha_offset;
1869 movps(reg_dst, dst_off, zmm_t(5));
1870 dst_off += alpha_offset;
1871 movps(reg_dst, dst_off, zmm_t(1));
1872 dst_off += alpha_offset;
1873 movps(reg_dst, dst_off, zmm_t(6));
1874 dst_off += alpha_offset;
1875 movps(reg_dst, dst_off, zmm_t(2));
1876 dst_off += alpha_offset;
1877 movps(reg_dst, dst_off, zmm_t(3));
1878 }
1879 };
1880 auto compute_transform_SDGtWo = [=]() {
1881 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1882 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1883 if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
1884
1885 xor_(reg_zero, reg_zero);
1886 xor_(reg_oc_ur, reg_oc_ur);
1887 Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done;
1888
1889 L(loop_oc_ur);
1890 {
1891 mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
1892 mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
1893 xor_(reg_tile_count, reg_tile_count);
1894 L(loop_mb);
1895 {
1896 L(loop_jtiles);
1897 {
1898 L(loop_itiles);
1899 {
1900 load_src();
1901
1902 trans_W_3x3_4x4();
1903
1904 add(reg_tile_count, 1);
1905 cmp(reg_tile_count,
1906 jcp.nb_tile_block_ur * jcp.tile_block_ur);
1907 jge(tiles_done);
1908
1909 add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
1910 add(reg_ti, 1);
1911 cmp(reg_ti, jcp.itiles);
1912 jl(loop_itiles);
1913 }
1914 xor_(reg_ti, reg_ti);
1915 add(reg_tj, 1);
1916 cmp(reg_tj, jcp.jtiles);
1917 jl(loop_jtiles);
1918 }
1919 xor_(reg_tj, reg_tj);
1920 add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize);
1921 jmp(loop_mb);
1922 }
1923
1924 L(tiles_done);
1925 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1926 add(reg_dst, simd_w * typesize);
1927 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1928 add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
1929
1930 if (with_bias) add(reg_bias, simd_w * typesize);
1931 add(reg_oc_ur, 1);
1932 cmp(reg_oc_ur, jcp.oc_reg_block);
1933 jl(loop_oc_ur);
1934 }
1935 };
1936
1937 auto compute_transform = [=]() {
1938 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1939 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
1940 if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
1941
1942 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1943 mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
1944 imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
1945 add(reg_dst, reg_temp);
1946
1947 xor_(reg_zero, reg_zero);
1948 xor_(reg_oc_ur, reg_oc_ur);
1949 Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block,
1950 next_tile;
1951
1952 L(loop_oc_ur);
1953 {
1954 xor_(reg_ti, reg_ti);
1955 xor_(reg_tj, reg_tj);
1956
1957 L(loop_jtiles);
1958 {
1959 L(loop_itiles);
1960 {
1961 load_src();
1962
1963 trans_W_3x3_4x4();
1964
1965 add(reg_tile_count, 1);
1966 cmp(reg_tile_count,
1967 jcp.nb_tile_block_ur * jcp.tile_block_ur);
1968 jge(next_tile_block);
1969 add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
1970 jmp(next_tile);
1971
1972 L(next_tile_block);
1973 sub(reg_dst,
1974 (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
1975 * jcp.oc_reg_block * simd_w * typesize);
1976 int tblk_off = alpha * alpha * (jcp.oc / jcp.nb_oc)
1977 * (jcp.ntiles / jcp.tile_block) * typesize;
1978 add(reg_dst, tblk_off);
1979 xor_(reg_tile_count, reg_tile_count);
1980
1981 L(next_tile);
1982 add(reg_ti, 1);
1983 cmp(reg_ti, jcp.itiles);
1984 jl(loop_itiles);
1985 }
1986 xor_(reg_ti, reg_ti);
1987 add(reg_tj, 1);
1988 cmp(reg_tj, jcp.jtiles);
1989 jl(loop_jtiles);
1990 }
1991
1992 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1993 mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
1994 imul(reg_temp, reg_tile_count,
1995 jcp.oc_reg_block * simd_w * typesize);
1996 add(reg_dst, reg_temp);
1997 add(reg_dst, simd_w * typesize);
1998 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1999 add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
2000
2001 if (with_bias) add(reg_bias, simd_w * typesize);
2002 add(reg_oc_ur, 1);
2003 cmp(reg_oc_ur, jcp.oc_reg_block);
2004 jl(loop_oc_ur);
2005 }
2006 };
2007
2008 preamble();
2009 if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
2010 compute_transform_SDGtWo();
2011 } else {
2012 compute_transform();
2013 }
2014 postamble();
2015}
2016
// Emits the JIT routine that turns one accumulated alpha x alpha Winograd
// "U" matrix (per 16x16 oc/ic simd slice) into the kh x kw diff-weights
// block, via a two-pass transform (columns first, then rows) using the
// broadcast coefficients G[0..3]. With first_tile == false the transformed
// values are added to the existing destination contents (accumulation
// across tile blocks) instead of overwriting them.
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::
        diff_weights_transform_generate(bool first_tile) {
    // Number of scalar transform coefficients kept broadcast in registers.
    int G_size = 4;

    // zmm0..zmm3: broadcast transform coefficients.
    auto zmm_G = [](int i) { return Xbyak::Zmm(i); };

    // Broadcast the G_size coefficients from memory once, up front.
    auto init_G = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        for (int i = 0; i < G_size; i++)
            vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
    };

    // Registers holding one column (alpha rows) of the source matrix.
    auto zmm_src = [=](int i) { return Xbyak::Zmm(G_size + i); };

    // Load column i of the alpha x alpha source: one zmm per row j.
    auto load_src = [=](int i) {
        for (int j = 0; j < alpha; j++) {
            // Stride in memory between consecutive alpha-matrix elements.
            size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block * jcp.ic_block
                    * simd_w * simd_w * typesize;
            size_t src_off = (j * alpha + i) * alpha_offset;
            vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off));
        }
    };

    // Scratch registers for the per-pass intermediate sums.
    auto zmm_t = [=](int i) { return Xbyak::Zmm(G_size + 6 + i); };

    // T(j, i): kh x alpha intermediate produced by the first (column) pass.
    auto zmm_T = [=](int j, int i) {
        return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i);
    };

    // Final kw output values; reuses the zmm_src register range (the
    // sources are no longer live during the second pass).
    auto zmm_dst = [=](int i) { return Xbyak::Zmm(G_size + i); };

    auto zmm_temp = Xbyak::Zmm(31);

    // Store row j of the kh x kw result; accumulates when !first_tile.
    auto store_dst = [=](int j) {
        for (int i = 0; i < jcp.kw; i++) {
            size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize;

            if (!first_tile) {
                // Add the previously written diff-weights before storing.
                vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off));
                vaddps(zmm_dst(i), zmm_dst(i), zmm_temp);
            }
            // Non-temporal store: the result is not re-read here.
            vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i));
        }
    };

    // Loop over the simd_w lanes; each iteration transforms one
    // alpha x alpha matrix into one kh x kw block.
    auto compute_transform = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);

        xor_(reg_ic_simd, reg_ic_simd);
        Label loop_ic_simd;
        L(loop_ic_simd);
        {
            // First pass: reduce the alpha columns into T (kh x alpha).
            for (int i = 0; i < alpha; i++) {
                load_src(i);

                vaddps(zmm_t(0), zmm_src(1), zmm_src(2));
                vaddps(zmm_t(1), zmm_src(3), zmm_src(4));
                vmovups(zmm_t(2), zmm_src(5));
                vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));

                vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0));
                vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1));
                vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2));
                vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1));
                vsubps(zmm_temp, zmm_src(3), zmm_src(4));
                vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2));
                vmovups(zmm_T(2, i), zmm_t(2));
                vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3));
            }

            // Second pass: same reduction along rows -> kh x kw output.
            for (int j = 0; j < jcp.kh; j++) {
                vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2));
                vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4));
                vmovups(zmm_t(2), zmm_T(j, 5));
                vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));

                vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0));
                vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1));
                vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2));
                vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1));
                vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4));
                vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2));
                vmovups(zmm_dst(2), zmm_t(2));
                vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3));

                store_dst(j);
            }

            // Advance src/dst to the next simd lane.
            add(reg_src, jcp.oc_reg_block * simd_w * typesize);
            add(reg_dst, simd_w * typesize);
            add(reg_ic_simd, 1);
            cmp(reg_ic_simd, simd_w);
            jl(loop_ic_simd);
        }
    };
    preamble();
    // Temporarily widen the EVEX compressed-displacement base so the large
    // offsets built by EVEX_compress_addr stay encodable; restored below.
    push(reg_EVEX_max_8b_offt);
    mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
    init_G();
    compute_transform();
    pop(reg_EVEX_max_8b_offt);
    postamble();
}
2121
// Emits the register-blocked GEMM kernel used by the weight update:
// accumulator tiles of dimM_reg_block x dimN_bcast_ur zmm registers are
// filled with FMAs over the K dimension, inside loops over dimM_block,
// dimN_block and dimK_block. With is_first_tile the accumulators overwrite
// dstC; otherwise the previous dstC contents are added in before storing.
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate(
        bool is_first_tile) {
    // zmm0: current srcA vector, reloaded per (K_ur, M_ur) step.
    auto zmm_srcA = [=]() { return Xbyak::Zmm(0); };

    // zmm1..zmm{dimN_bcast_ur}: broadcast srcB scalars.
    auto zmm_srcB = [=](size_t N_ur) { return Xbyak::Zmm(N_ur + 1); };

    // Broadcast dimN_bcast_ur consecutive srcB scalars for K step K_ur.
    auto broadcastB = [=](size_t K_ur) {
        for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
            size_t srcB_off
                    = (K_ur * jcp.dimN_reg_block + N_bcast) * sizeof(float);
            vbroadcastss(
                    zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off));
        }
    };

    // Load one simd-wide srcA column for (K_ur, M_ur).
    auto load_srcA = [=](size_t K_ur, int M_ur) {
        size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
                                  + M_ur * jcp.dimM_simd_block)
                * sizeof(float);
        vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off));
    };

    // Accumulators live above the srcA/srcB registers, one per
    // (M_reg_ur, N_bcast) pair.
    auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast) {
        size_t idx = 1 // zmm_srcA
                + jcp.dimN_bcast_ur // zmm_srcB
                + M_reg_ur * jcp.dimN_bcast_ur + N_bcast;
        assert(idx < 32);
        return Xbyak::Zmm(idx);
    };
    // Zero all accumulators before the K reduction.
    auto prepare_accumm = [=]() {
        for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
            for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
                Zmm zmm = zmm_dstC(M_reg_ur, N_bcast);
                vpxord(zmm, zmm, zmm);
            }
        }
    };

    auto store_dstC = [=]() {
        /******** Write C back to memory *******/
        for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) {
            for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) {
                Zmm zmm = zmm_dstC(M_reg, N_ur);
                size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
                                       + M_reg * jcp.dimM_simd_block)
                        * sizeof(float);
                if (!is_first_tile) {
                    // Accumulate into the existing C tile; zmm0 (srcA) is
                    // not live across the store, so it is reused here.
                    vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off));
                    vaddps(zmm, zmm, Xbyak::Zmm(0));
                }
                vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm);
            }
        }
    };

    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur;

        mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
        L(dimM_block_loop);
        { /************* OC_block (M) loop ***********/
            mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
            L(dimN_block_loop);
            { /*************** IC_block (N) loop *********/

                mov(reg_nb_dimN_bcast_ur,
                        jcp.dimN_reg_block / jcp.dimN_bcast_ur);
                L(dimN_bcast_ur);
                {
                    prepare_accumm();

                    mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
                    L(dimK_block_loop);
                    {
                        /************* nb_tile_ur(K) loop ********/
                        // Fully unrolled over dimK_reg_block K steps.
                        for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) {

                            broadcastB(K_ur);

                            for (int M_reg_ur = 0;
                                    M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
                                load_srcA(K_ur, M_reg_ur);
                                for (int N_bcast = 0;
                                        N_bcast < jcp.dimN_bcast_ur;
                                        ++N_bcast) {
                                    vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast),
                                            zmm_srcA(), zmm_srcB(N_bcast));
                                }
                            }
                        }
                        // Advance A and B past the consumed K panel.
                        add(reg_srcA,
                                jcp.dimK_reg_block * jcp.dimM_reg_block
                                        * jcp.dimM_simd_block * sizeof(float));
                        add(reg_srcB,
                                jcp.dimK_reg_block * jcp.dimN_reg_block
                                        * sizeof(float));
                        sub(reg_dimK_block_loop_cnt, 1);
                        jnz(dimK_block_loop);
                    }

                    store_dstC();

                    // Rewind A and B over the whole K range, then step B
                    // and C to the next dimN_bcast_ur column group.
                    sub(reg_srcA,
                            jcp.dimK_block * jcp.dimK_reg_block
                                    * jcp.dimM_reg_block * jcp.dimM_simd_block
                                    * sizeof(float));
                    sub(reg_srcB,
                            jcp.dimK_block * jcp.dimK_reg_block
                                    * jcp.dimN_reg_block * sizeof(float));
                    add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float));
                    add(reg_dstC,
                            jcp.dimN_bcast_ur * jcp.dimM_reg_block
                                    * jcp.dimM_simd_block * sizeof(float));
                    sub(reg_nb_dimN_bcast_ur, 1);
                    jnz(dimN_bcast_ur);
                }

                // Undo the accumulated per-bcast B offset and advance B to
                // the next N block.
                sub(reg_srcB, jcp.dimN_reg_block * sizeof(float));
                add(reg_srcB,
                        jcp.dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
                                * sizeof(float));
                sub(reg_dimN_block_loop_cnt, 1);
                jnz(dimN_block_loop);
            }

            // Rewind B to its origin and advance A to the next M block.
            sub(reg_srcB,
                    jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block
                            * jcp.dimN_reg_block * sizeof(float));
            add(reg_srcA,
                    jcp.dimK_block * jcp.dimK_reg_block * jcp.dimM_reg_block
                            * jcp.dimM_simd_block * sizeof(float));
            sub(reg_dimM_block_loop_cnt, 1);
            jnz(dimM_block_loop);
        }
    };

    /* Preamble */
    preamble();

    inner_loops();

    /* Postamble */
    postamble();
    // NOTE(review): postamble() presumably already emits a ret, which would
    // make this second ret unreachable (harmless). TODO confirm before
    // removing.
    ret();
}
2267
2268namespace {
2269
2270void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) {
2271 /*M params*/
2272 jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
2273 / jcp.dimM_simd_block;
2274 jcp.oc_reg_block = jcp.dimM_reg_block;
2275 jcp.oc_block = jcp.dimM_block;
2276 jcp.nb_oc = jcp.dimM_nb_block;
2277 /*N params*/
2278 jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
2279 jcp.ic_block = jcp.dimN_block;
2280 jcp.nb_ic = jcp.dimN_nb_block;
2281
2282 /*K params*/
2283 jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
2284 jcp.tile_block_ur = jcp.dimK_reg_block;
2285 jcp.nb_tile_block_ur = jcp.dimK_block;
2286 jcp.tile_block = jcp.dimK_nb_block;
2287}
2288
// Try to select the WSCHED_WEI_SDGtWo scheduling policy: pick a
// (K_blk_ur, N_blk, M_blk) blocking whose per-thread working set meets the
// L1/L2 budgets below. Returns status::unimplemented when the strategy is
// infeasible or no blocking satisfies the constraints.
status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {

    size_t K_blk_ur, N_blk, M_blk;
    /* Is this strategy feasible? M and V must be big enough that even the
     * per-thread share exceeds twice the L2 size, and there must be at
     * least one unit of dimK work per thread. */
    auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
        size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
        size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
        return (((V_sz + M_sz) / jcp.nthr) >= 2 * L2_cache_size)
                && (jcp.dimK / jcp.nthr >= 1.0);
    };

    // L1 feasibility of a dimK block: the M- and N-side panels together
    // should occupy 10%..50% of L1, the M slab must fit in L2, and -- when
    // dimK divides evenly over threads -- the number of dimK blocks must
    // also divide evenly.
    auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
                                    int max_block = 1) {
        size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block
                * dimK_block_ur * sizeof(float);
        size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
        size_t M_L2_block
                = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
        bool load_balance = true;
        if (!(jcp.dimK % jcp.nthr)) {
            // dimK splits evenly across threads: also require the block
            // count to split evenly so no thread idles.
            load_balance = ((jcp.dimK / dimK_block_ur) % jcp.nthr == 0);
        }
        return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
                && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
                && load_balance && (M_L2_block < L2_cache_size);
    };

    // Acceptable unroll factors for the innermost K loop: 2..8.
    auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
                                int useless_arg = 0) {
        return (dimK_ur >= 2) && (dimK_ur <= 8);
    };

    // Candidate (M_blk, N_blk, K_blk_ur): the combined M/V/U working set
    // must take between 10% and 120% of L2.
    auto blocking_ok = [&]() {
        size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block
                * jcp.dimM_simd_block * K_blk_ur * sizeof(float);
        size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
                * K_blk_ur * sizeof(float);
        size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block
                * jcp.dimM_simd_block * N_blk * jcp.dimN_reg_block
                * sizeof(float);
        size_t L2_block = M_L2_block + V_L2_block + U_L2_block;
        /* TODO: consider an L2 + LLC budget instead of L2 alone. */
        return (L2_block > 0.1 * L2_cache_size)
                && (L2_block <= 1.2 * L2_cache_size);
    };

    if (test_MV_large_enough(jcp)) {
        // Pair OC simd blocks into a 2-wide M register block when possible.
        if ((jcp.dimM / jcp.dimM_simd_block) % 2 == 0) {
            jcp.dimM_reg_block = 2;
        } else {
            jcp.dimM_reg_block = 1;
        }
        jcp.dimM_simd_block = jcp.oc_simd_block;
        jcp.dimN_reg_block = jcp.ic_simd_block;
        jcp.dimN_bcast_ur = 8;
        /* Starting point for the dimK-block search: a divisor of dimK that
         * satisfies the L1 constraints. */
        size_t min_dimK_block_ur = get_divisor_satisfying_cond(
                jcp, jcp.dimK, 1, test_min_dimK_L1);

        jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
        jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
        // Largest-first search over (K, N, M) blockings.
        for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) {
            if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) {
                for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
                    if (!(jcp.dimN_block % N_blk)) {
                        for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
                            if (!(jcp.dimM_block % M_blk) && blocking_ok()) {
                                // Split K_blk_ur into an unroll factor in
                                // [2, 8] and an outer dimK_block count.
                                jcp.dimK_reg_block
                                        = get_divisor_satisfying_cond(
                                                jcp, K_blk_ur, 1, test_dimK_ur);
                                if (!test_dimK_ur(jcp, jcp.dimK_reg_block))
                                    return status::unimplemented;
                                jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
                                jcp.dimN_block = N_blk;
                                jcp.dimM_block = M_blk;
                                jcp.sched_policy = WSCHED_WEI_SDGtWo;
                                set_jcp_WEI_params(jcp);
                                // Work is parallelized over tile blocks;
                                // never use more threads than blocks.
                                jcp.nthr = nstl::min(jcp.nthr, jcp.tile_block);
                                return status::success;
                            }
                        }
                    }
                }
            }
        }
    }
    return status::unimplemented;
}
2377
// Fallback weight-update schedule (WSCHED_WEI_S_D_Giot_W); always succeeds.
// Searches for a (K_blk_ur, N_blk, M_blk) blocking meeting the L1/L2
// budgets and a thread load-balance requirement; if none is found, falls
// back to the minimal blocking (dimK_block = dimK_reg_block = 1).
status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
    // Pair OC simd blocks into a 2-wide M register block when possible.
    if ((jcp.dimM / jcp.dimM_simd_block) % 2 == 0) {
        jcp.dimM_reg_block = 2;
    } else {
        jcp.dimM_reg_block = 1;
    }
    jcp.dimN_bcast_ur = 8;
    jcp.dimN_reg_block = jcp.ic_simd_block;
    jcp.dimM_simd_block = jcp.oc_simd_block;
    jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
    jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
    // Cache occupancy bounds: [C1, C1_max] of L1 and [C2, C2_max] of L2;
    // the lower bounds are currently 0, i.e. effectively disabled.
    float C1 = 0.0, C2 = 0.0;
    float C1_max = 0.5, C2_max = 1.4;
    int N_blk, M_blk, K_blk_ur;

    // Acceptable unroll factors for the innermost K loop: 2..8.
    auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
                                int useless_arg = 0) {
        return (dimK_ur >= 2) && (dimK_ur <= 8);
    };

    auto blocking_ok = [&]() -> bool {
        // One register tile's M- and N-side panels must fit in the L1
        // budget.
        size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur
                * sizeof(float);
        size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float);
        bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size)
                && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size);

        // Require at least one block of work per thread.
        size_t nb_N_blk = jcp.dimN / N_blk / jcp.dimN_reg_block;
        size_t nb_M_blk
                = jcp.dimM / M_blk / jcp.dimM_reg_block / jcp.dimM_simd_block;
        size_t nb_K_blk = jcp.dimK / K_blk_ur;
        bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk)
                >= static_cast<size_t>(jcp.nthr);
        // NOTE(review): this clause is a tautology -- it is only entered
        // when nb_K_blk % nthr == 0 and then re-checks that same condition.
        // Possibly a different quantity was intended (compare the analogous
        // check in set_wsched_WEI_SDGtWo); confirm upstream intent.
        if (!(nb_K_blk % jcp.nthr)) {
            load_balance = load_balance && (nb_K_blk % jcp.nthr == 0);
        }

        size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
                * K_blk_ur * sizeof(float);

        size_t L2_block = V_L2_block;
        /* TODO: consider an L2 + LLC budget instead of L2 alone. */
        bool L2_cond = (L2_block >= C2 * L2_cache_size)
                && (L2_block <= C2_max * L2_cache_size);
        return L1_cond && load_balance && L2_cond;
    };

    // Largest-first search over divisors of dimK, dimN_block, dimM_block.
    for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) {
        if (jcp.dimK % K_blk_ur == 0) {
            for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
                if (jcp.dimN_block % N_blk == 0) {
                    for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
                        if (jcp.dimM_block % M_blk == 0) {
                            if (blocking_ok()) {
                                jcp.dimN_block = N_blk;
                                jcp.dimM_block = M_blk;
                                // Split K_blk_ur into an unroll factor in
                                // [2, 8] and an outer dimK_block count.
                                jcp.dimK_reg_block
                                        = get_divisor_satisfying_cond(
                                                jcp, K_blk_ur, 1, test_dimK_ur);
                                jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
                                jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
                                set_jcp_WEI_params(jcp);
                                return status::success;
                            }
                        }
                    }
                }
            }
        }
    }
    // No blocking satisfied the constraints: use the minimal, always-valid
    // blocking instead.
    jcp.dimK_reg_block = 1;
    jcp.dimK_block = 1;
    jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
    set_jcp_WEI_params(jcp);
    return status::success;
}
2454} // namespace
2455status_t jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::init_conf(
2456 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
2457 const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
2458 const memory_desc_wrapper &diff_weights_d) {
2459 if (!mayiuse(avx512_core)) return status::unimplemented;
2460
2461 // This kernel only supports 2D convolutions.
2462 if (src_d.ndims() != 4) return status::unimplemented;
2463
2464 jcp.nthr = dnnl_get_max_threads();
2465
2466 jcp.prop_kind = cd.prop_kind;
2467 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
2468 jcp.mb = src_d.dims()[0];
2469 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
2470 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
2471 jcp.oc_without_padding = jcp.oc;
2472 jcp.ic = src_d.dims()[1] / jcp.ngroups;
2473 jcp.ih = src_d.dims()[2];
2474 jcp.iw = src_d.dims()[3];
2475 jcp.oh = diff_dst_d.dims()[2];
2476 jcp.ow = diff_dst_d.dims()[3];
2477 jcp.kh = diff_weights_d.dims()[with_groups + 2];
2478 jcp.kw = diff_weights_d.dims()[with_groups + 3];
2479 jcp.t_pad = cd.padding[0][0];
2480 jcp.l_pad = cd.padding[0][1];
2481 jcp.stride_h = cd.strides[0];
2482 jcp.stride_w = cd.strides[1];
2483 jcp.r_pad = nstl::max(
2484 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
2485 jcp.b_pad = nstl::max(
2486 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
2487 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
2488 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
2489 jcp.ohp = jcp.oh;
2490 jcp.owp = jcp.ow;
2491 jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef);
2492 jcp.dilate_h = cd.dilates[0];
2493 jcp.dilate_w = cd.dilates[1];
2494
2495 bool ok_to_pad_channels = jcp.ngroups == 1;
2496 if (ok_to_pad_channels) {
2497 jcp.oc = rnd_up(jcp.oc, simd_w);
2498 jcp.ic = rnd_up(jcp.ic, simd_w);
2499 }
2500
2501 // Winograd specific initialization
2502 jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
2503 jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
2504 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
2505
2506 // Winograd kernel works only for 3x3 convolution with stride 1
2507 if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
2508 is_winograd_faster_than_direct(jcp)))
2509 return status::unimplemented;
2510
2511 const bool prb_shape_ok = jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1
2512 && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 && jcp.stride_h == 1
2513 && jcp.stride_w == 1 && jcp.dilate_h == 0 && jcp.dilate_w == 0
2514 && jcp.l_pad <= 1 && jcp.r_pad <= 1 && jcp.t_pad <= 1
2515 && jcp.b_pad <= 1;
2516 if (!prb_shape_ok) return status::unimplemented;
2517
2518 format_tag_t dat_tag = nChw16c;
2519 format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
2520 jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
2521 jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
2522 jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
2523
2524 if (jcp.src_tag != dat_tag) return status::unimplemented;
2525 if (jcp.wei_tag != wei_tag) return status::unimplemented;
2526 if (jcp.dst_tag != dat_tag) return status::unimplemented;
2527
2528 bool layout_consistency = true && jcp.ic <= src_d.padded_dims()[1]
2529 && jcp.oc <= diff_dst_d.padded_dims()[1]
2530 && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
2531 && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
2532 if (!layout_consistency) return status::unimplemented;
2533
2534 /******************Kernel blocking Parameters ***********/
2535 jcp.ic_simd_block = simd_w;
2536 jcp.oc_simd_block = simd_w;
2537
2538 jcp.dimK = jcp.ntiles;
2539 jcp.dimN = jcp.ic;
2540 jcp.dimM = jcp.oc;
2541 jcp.dimM_simd_block = jcp.oc_simd_block;
2542 jcp.dimN_reg_block = jcp.ic_simd_block;
2543 jcp.sched_policy = WSCHED_INVALID;
2544 status_t res = set_wsched_WEI_SDGtWo(jcp);
2545 if (res == status::unimplemented) {
2546 res = set_wsched_WEI_S_D_Giot_W(jcp);
2547 assert(res == status::success);
2548 }
2549 return res;
2550}
2551} // namespace x64
2552} // namespace cpu
2553} // namespace impl
2554} // namespace dnnl
2555
2556// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
2557