1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <math.h> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/platform.hpp" |
26 | |
27 | #include "cpu/x64/jit_avx512_core_f32_wino_conv_4x3_kernel.hpp" |
28 | |
29 | #define GET_OFF(field) offsetof(jit_wino_transform_call_s, field) |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
36 | namespace { |
37 | |
38 | using namespace dnnl::impl::utils; |
39 | |
// Per-core cache capacities in bytes, queried once at load time via
// platform::get_per_core_cache_size(level). Levels 1/2/3 map to L1/L2/LLC.
// These feed the blocking heuristics below (check_L1_*/check_L2_*/check_cond*)
// that size GEMM tiles against the cache hierarchy.
unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
unsigned int L2_cache_size = platform::get_per_core_cache_size(2);
unsigned int LLC_data_size = platform::get_per_core_cache_size(3);
43 | |
// The test function takes jcp, the candidate divisor and the current best.
// It returns true if the new candidate is better.
46 | int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number, |
47 | int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int)) { |
48 | int best_divisor = default_best; |
49 | auto test_num |
50 | = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) { |
51 | if (test(jcp, num, best_divisor)) { best_divisor = num; } |
52 | }; |
53 | |
54 | for (int divisor = 1; divisor <= ::sqrt(number); divisor++) { |
55 | if (number % divisor == 0) { |
56 | test_num(jcp, divisor); |
57 | test_num(jcp, number / divisor); |
58 | } |
59 | } |
60 | |
61 | return best_divisor; |
62 | } |
63 | |
64 | namespace { |
65 | bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) { |
66 | /* Determines if current winograd implementation is faster than direct. |
67 | Following conditions are empirical and based on performance data */ |
68 | unsigned int ncores_per_socket |
69 | = cpu().getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); |
70 | unsigned int nthreads = dnnl_get_max_threads(); |
71 | |
72 | if (jcp.prop_kind == prop_kind::forward_inference) { |
73 | return jcp.mb >= 4; |
74 | } else if (nthreads > ncores_per_socket) { |
75 | double src_dst_transforms_per_core = alpha * alpha * (jcp.ic + jcp.oc) |
76 | * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size) |
77 | * ((jcp.ow + tile_size - 1) / tile_size) * sizeof(float) / 1024. |
78 | / 1024. / nthreads; |
79 | double wei_transform = alpha * alpha * jcp.ic * jcp.oc * sizeof(float) |
80 | / 1024. / 1024.; |
81 | |
82 | if (jcp.prop_kind == prop_kind::backward_weights) { |
83 | if (src_dst_transforms_per_core < 0.3 |
84 | || (src_dst_transforms_per_core <= 28 && wei_transform < 4)) |
85 | return false; |
86 | else |
87 | return true; |
88 | } else { |
89 | if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02) |
90 | return false; |
91 | } |
92 | } |
93 | |
94 | return jcp.mb > 8; |
95 | } |
96 | } // namespace |
97 | |
/* assumes 512 bits registers */
/* TODO: add support for strides */
/* TODO: handle the prefetch distance automatically */
// Cache level selector for prefetcher_t below; L1/L2/L3 select the
// prefetcht0/prefetcht1/prefetcht2 instruction respectively.
using cache_t = enum cache_t_ { L1, L2, L3 };
102 | |
// Emits software prefetches for a block of `block_size` elements, spreading
// them evenly across the `nb_instructions_in_block` instructions of the
// surrounding compute block so the prefetch traffic does not burst.
template <typename data_t>
struct prefetcher_t {
    prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
            cache_t cache_type, size_t block_size, /* in number of elements*/
            int nb_instructions_in_block, int fma_ipc)
        : cg_(generator)
        , reg_base_addr_(reg_base_addr)
        , cache_type_(cache_type)
        , cache_block_size_(block_size) {
        // 64-byte cache lines: lines needed to cover one block of data.
        nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
        // How many instructions apart consecutive prefetches are issued,
        // and how many prefetches to emit at each issue point.
        prefetch_spread_
                = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
        prefetch_blk_
                = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);

        /* assumption: when fetch in Li, data is already in L(i+1) */
        // NOTE(review): latency constants are empirical; L2 and L3 share the
        // same value here — confirm whether that is intentional.
        int cache_latency;
        switch (cache_type_) {
            case L1: cache_latency = 14; break;
            case L2: cache_latency = 250; break;
            case L3: cache_latency = 250; break;
        }

        // Prefetch this many blocks ahead so data arrives before it is used.
        prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
    }

    // Called once per emitted compute instruction; issues up to
    // prefetch_blk_ prefetches every prefetch_spread_ instructions until the
    // whole block has been covered.
    void prefetch(int instruction_number) {
        if (instruction_number % prefetch_spread_ == 0) {
            for (int i = 0; (i < prefetch_blk_)
                    && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
                    i++, prefetches_issued_++) {
                prefetch_inst_(cg_->EVEX_compress_addr(reg_base_addr_,
                        (cache_block_size_ * prefetch_distance_)
                                        * sizeof(data_t)
                                + (prefetches_issued_ * 64)));
            }
        }
    }

private:
    // Dispatches to the prefetch instruction matching the target cache level.
    void prefetch_inst_(const Xbyak::Address &addr) {
        switch (cache_type_) {
            case L1: cg_->prefetcht0(addr); break;
            case L2: cg_->prefetcht1(addr); break;
            case L3: cg_->prefetcht2(addr); break;
            default: break; // TODO: raise an exception or put an assert
        }
    }

    jit_generator *cg_; // code generator the prefetches are emitted into
    Xbyak::Reg64 reg_base_addr_;
    cache_t cache_type_;
    int cache_block_size_ = 0; // block size in elements
    int nb_cache_lines_to_prefetch_ = 0;
    int prefetches_issued_ = 0;
    int prefetch_spread_ = 0;
    int prefetch_blk_ = 0;
    int prefetch_distance_ = 0; // in blocks
};
162 | |
163 | // utilities to support kernel parameter selection |
164 | bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp, int dimN_block, |
165 | float C2_min, float C2_max) { |
166 | float block_size = alpha * alpha |
167 | * (2 * (jcp.oc + jcp.ic) * dimN_block * jcp.dimN_reg_block |
168 | + div_up(jcp.ic * jcp.oc, jcp.nthr)) |
169 | * (float)sizeof(float); |
170 | float L2_lb = C2_min * L2_cache_size; |
171 | float L2_ub = C2_max * L2_cache_size; |
172 | return (block_size > L2_lb && block_size < L2_ub); |
173 | } |
174 | |
175 | bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block, |
176 | int dimM_block, float C1_min, float C1_max) { |
177 | float gemm_block_size |
178 | = (dimM_block * jcp.dimM_simd_block * dimK_block |
179 | * jcp.dimK_reg_block * jcp.dimM_reg_block |
180 | + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block |
181 | + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block) |
182 | * (float)sizeof(float); |
183 | float L1_lb = C1_min * L1_cache_size; |
184 | float L1_ub = C1_max * L1_cache_size; |
185 | return (gemm_block_size > L1_lb && gemm_block_size < L1_ub); |
186 | } |
187 | bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block, |
188 | int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) { |
189 | float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block |
190 | + dimM_block * dimK_block * dimK_reg_block |
191 | * dimM_simd_block * dimM_reg_block |
192 | + dimK_block * dimN_reg_block * dimK_reg_block) |
193 | * (float)sizeof(float); |
194 | float rhs = C * L1_cache_size; |
195 | return (lhs < rhs); |
196 | } |
197 | bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block, |
198 | int dimM_block, int dimM_reg_block, int dimM_simd_block, float C) { |
199 | float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block |
200 | * dimM_simd_block |
201 | + dimK_block * dimN_reg_block * dimK_reg_block) |
202 | * (float)sizeof(float); |
203 | float rhs = C * L1_cache_size; |
204 | return (lhs < rhs); |
205 | } |
206 | bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block, |
207 | int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block, |
208 | int dimM_simd_block, float C) { |
209 | float lhs |
210 | = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block |
211 | * dimM_reg_block |
212 | + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block |
213 | * dimM_simd_block * dimM_reg_block |
214 | + nb_dimN_reg_block * dimK_nb_block * dimK_block |
215 | * dimN_reg_block * dimK_reg_block) |
216 | * (float)sizeof(float); |
217 | float rhs = C * L2_cache_size; |
218 | return (lhs < rhs); |
219 | } |
220 | |
221 | bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block, |
222 | int dimN_block, int dimN_reg_block, int dimK, float C1, float C2) { |
223 | float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK |
224 | * (float)sizeof(float); |
225 | float B_size = dimN_block * dimN_reg_block * dimK * (float)sizeof(float); |
226 | return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size); |
227 | } |
228 | } // namespace |
229 | |
230 | using namespace dnnl::impl::format_tag; |
231 | using namespace dnnl::impl::utils; |
232 | using namespace Xbyak; |
233 | |
// Emits the Winograd-domain GEMM micro-kernel. Equivalent C loop nest:
//
// for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
//     for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block;
//             dimM_reg_block++) // unrolled
//         for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
//             for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
//                     dimK_reg_block++) // unrolled
//                 for (int tile =0; tile < jcp.dimN_reg_block; tile++)
//                     C[dimM_block][dimM_reg_block][tile] +=
//                         A[dimM_block][dimM_reg_block][dimK_block]
//                          [dimK_reg_block]
//                         * broadcast(B[dimK_block][tile][dimK_reg_block]);
// Notes:
//     jcp.kernel_kind defines embedded or explicit broadcast
//     dimM_reg_block=1 for embedded bcast kernel
void _jit_avx512_core_f32_wino_conv_4x3_data_kernel::gemm_loop_generate() {
    // Register allocation: zmm0 carries the A vector; for the explicit-
    // broadcast kernel, zmm1..zmm{dimN_reg_block} hold broadcast B scalars
    // and the accumulators follow; for the embedded-broadcast kernel the
    // accumulators start at zmm1.
    auto zmm_srcA = [=]() { return Xbyak::Zmm(0); };
    auto zmm_srcB = [=](int tile) {
        int idx = 1 + tile;
        assert(idx < 1 + jcp.dimN_reg_block);
        return Xbyak::Zmm(idx);
    };
    auto zmm_dstC = [=](int dimM_reg_block, int tile) {
        int idx {0};
        if (jcp.kernel_kind == embd_bcast)
            idx = 1 + tile;
        else
            idx = 1 + jcp.dimN_reg_block + dimM_reg_block * jcp.dimN_reg_block
                    + tile;
        assert(idx < 32);
        return Xbyak::Zmm(idx);
    };

    // Zeroes all accumulator registers before a K-reduction.
    auto prepare_output = [=]() {
        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block, tile);
                vpxord(zmm, zmm, zmm);
            }
        }
    };
    // Writes accumulators back to dstC. When beta != 0 (reg_is_beta_zero
    // == 0) the existing dstC contents are added in first. Non-temporal
    // stores are used for aligned output when the data set is too large to
    // be reused from cache.
    auto store_output = [=](bool output_is_aligned) {
        Label save;
        cmp(reg_is_beta_zero, 0);
        je(save, T_NEAR);

        // beta != 0: accumulate into the existing output.
        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block, tile);
                int output_offset
                        = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
                vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset));
            }
        }

        L(save);
        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block, tile);
                int output_offset
                        = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;

                // In W_SGD, output will be reused.
                if (output_is_aligned && jcp.dimK_nb_block == 1
                        && jcp.sched_policy == WSCHED_DATA_W_S_G_D
                        && (jcp.dimN * jcp.dimM * alpha * alpha * sizeof(float)
                                > 2 * LLC_data_size * jcp.nthr))
                    vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm);
                else
                    vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm);
            }
        }
    };

    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop;

        // Outer runtime loop over dimM blocks (emitted only when needed).
        if (jcp.dimM_block > 1) {
            mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
            L(dimM_block_loop);
        }

        prepare_output();

        // Runtime loop over dimK blocks; the dimK_reg_block loop below is
        // fully unrolled at code-generation time.
        if (jcp.dimK_block > 1) {
            mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
            L(dimK_block_loop);
        }

        for (int dimK_reg_block = 0; dimK_reg_block < jcp.dimK_reg_block;
                dimK_reg_block++) {

            // Explicit-broadcast kernel: load all B scalars up front.
            if (jcp.kernel_kind == expl_bcast) {
                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    vbroadcastss(zmm_srcB(tile),
                            ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]);
                }
            }

            /* Performing the fmas */

            for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                    dimM_reg_block++) {

                vmovups(zmm_srcA(),
                        zword[reg_srcA
                                + jcp.dimK_reg_block * jcp.dimK_block * 64
                                        * dimM_reg_block
                                + dimK_reg_block * 64]);

                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    if (jcp.kernel_kind == expl_bcast)
                        vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
                                zmm_srcB(tile));
                    else
                        // Embedded broadcast directly from memory ({1to16}).
                        vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
                                EVEX_compress_addr(reg_srcB,
                                        64 * tile + dimK_reg_block * 4, true));
                }
            }
        }
        // Advance A and B to the next dimK block.
        add(reg_srcA, jcp.dimK_reg_block * 64);
        add(reg_srcB, jcp.dimN_reg_block * 64);
        if (jcp.dimK_block > 1) {
            sub(reg_dimK_block_loop_cnt, 1);
            jnz(dimK_block_loop);
        }

        // Pick aligned vs unaligned store path based on dstC alignment.
        Label unaligned_store, end_store;
        test(reg_dstC, cpu_isa_traits<avx512_core>::vlen - 1);
        jnz(unaligned_store, T_NEAR);
        store_output(true);
        jmp(end_store, T_NEAR);
        L(unaligned_store);
        { store_output(false); }
        L(end_store);

        if (jcp.dimM_block > 1) {
            // Rewind B to its start, advance C; for the explicit-broadcast
            // kernel A must also skip the extra dimM_reg_block rows.
            sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
            add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64);
            if (jcp.kernel_kind == expl_bcast) {
                add(reg_srcA,
                        (jcp.dimM_reg_block - 1) * jcp.dimK_reg_block * 64
                                * jcp.dimK_block);
            }
            sub(reg_dimM_block_loop_cnt, 1);
            jnz(dimM_block_loop);
        }
    };

    /* Preamble */
    preamble();

    /* kernel */
    inner_loops();

    /* Postamble */
    postamble();
    // NOTE(review): postamble() is expected to emit the function return
    // already — confirm whether this trailing ret() is reachable.
    ret();
}
396 | |
// Emits the weight-transform kernel: expands each kh x kw (3x3) filter tile
// into the alpha x alpha Winograd domain via U = G * g * G^T. For backward
// data the filter is additionally rotated 180 degrees and transposed
// (trans16x16) while being staged into the M buffer.
void _jit_avx512_core_f32_wino_conv_4x3_data_kernel ::
        weights_transform_data_ker_generate() {
    bool is_fwd = one_of(
            jcp.prop_kind, dnnl_forward_training, dnnl_forward_inference);
    int kh = jcp.kh;
    int kw = jcp.kw;

    auto zmm_temp = Xbyak::Zmm(31);
    auto zmm_zero = Xbyak::Zmm(30);

    // Register aliases; the same physical zmm registers are reused by the
    // distinct phases below (transpose vs transform), which never overlap.
    auto zmm_M = [=](int i) { return Xbyak::Zmm(i); };
    auto zmm_MT = [=](int i) { return Xbyak::Zmm(i + simd_w); };

    auto zmm_G = [=](int i) { return Xbyak::Zmm(i); };
    auto zmm_F = [=](int i) { return Xbyak::Zmm(alpha + i); };
    auto zmm_T = [=](int i) { return Xbyak::Zmm(alpha + 3 + i); };
    auto zmm_t = [=](int i) { return Xbyak::Zmm(2 * alpha + 3 + i); };

    auto zmm_load = [=](int i) { return Xbyak::Zmm(i); };

    // Broadcasts the G-matrix coefficients (passed via the call params) into
    // registers and zeroes zmm_zero.
    auto init_G = [=]() {
        mov(wreg_temp, ptr[param1 + GET_OFF(G)]);
        for (int i = 0; i < alpha; i++) {
            vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]);
        }
        vpxord(zmm_zero, zmm_zero, zmm_zero);
    };

    // In-register 16x16 float transpose: unpack ps/pd stages followed by
    // 128-bit lane shuffles; reads from wreg_M, writes to wreg_MT.
    auto trans16x16 = [=]() {
        for (int i = 0; i < simd_w; i += 2) {
            vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]);
            vmovups(zmm_M(i + 1), ptr[wreg_M + (i + 1) * simd_w * 4]);
            vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i + 1));
            vunpckhps(zmm_MT(i + 1), zmm_M(i), zmm_M(i + 1));
        }
        for (int i = 0; i < simd_w; i += 4) {
            vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i + 2));
            vunpckhpd(zmm_M(i + 1), zmm_MT(i), zmm_MT(i + 2));
            vunpcklpd(zmm_M(i + 2), zmm_MT(i + 1), zmm_MT(i + 3));
            vunpckhpd(zmm_M(i + 3), zmm_MT(i + 1), zmm_MT(i + 3));
        }
        for (int i = 0; i < simd_w; i += 8) {
            vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88);
            vshuff32x4(zmm_MT(i + 1), zmm_M(i + 1), zmm_M(i + 5), 0x88);
            vshuff32x4(zmm_MT(i + 2), zmm_M(i + 2), zmm_M(i + 6), 0x88);
            vshuff32x4(zmm_MT(i + 3), zmm_M(i + 3), zmm_M(i + 7), 0x88);
            vshuff32x4(zmm_MT(i + 4), zmm_M(i), zmm_M(i + 4), 0xdd);
            vshuff32x4(zmm_MT(i + 5), zmm_M(i + 1), zmm_M(i + 5), 0xdd);
            vshuff32x4(zmm_MT(i + 6), zmm_M(i + 2), zmm_M(i + 6), 0xdd);
            vshuff32x4(zmm_MT(i + 7), zmm_M(i + 3), zmm_M(i + 7), 0xdd);
        }
        {
            // Final lane shuffle stage, storing each transposed row out.
            int i = 0;
            int mask = 0x88;
            vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask);
            vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0));
            vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask);
            vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1));
            vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask);
            vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2));
            vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask);
            vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3));
            vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask);
            vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4));
            vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask);
            vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5));
            vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask);
            vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6));
            vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask);
            vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7));
            mask = 0xdd;
            vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask);
            vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8));
            vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask);
            vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9));
            vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask);
            vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10));
            vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask);
            vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11));
            vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask);
            vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12));
            vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask);
            vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13));
            vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask);
            vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14));
            vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask);
            vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15));
        }
    };

    // Stages the filter into the M scratch buffer. Forward: straight copy.
    // Backward: the kernel is rotated 180 degrees spatially ((2-j, 2-i)
    // indexing) and each 16x16 ic/oc block is transposed.
    auto load_src = [=]() {
        mov(wreg_src, ptr[param1 + GET_OFF(src)]);
        mov(wreg_F, ptr[param1 + GET_OFF(M)]);
        for (int j = 0; j < kh; j++) {
            for (int i = 0; i < kw; i++) {
                if (is_fwd) {
                    for (int v1 = 0; v1 < simd_w; v1++) {
                        int offset_src
                                = (j * kw * simd_w * simd_w
                                          + i * simd_w * simd_w + v1 * simd_w)
                                * typesize;
                        int offset_F
                                = (j * kw * simd_w * simd_w
                                          + i * simd_w * simd_w + v1 * simd_w)
                                * typesize;
                        vmovups(zmm_temp, ptr[wreg_src + offset_src]);
                        vmovups(ptr[wreg_F + offset_F], zmm_temp);
                    }
                } else {
                    int offset_src = ((2 - j) * kw * simd_w * simd_w
                                             + (2 - i) * simd_w * simd_w)
                            * typesize;
                    int offset_F
                            = (j * kw * simd_w * simd_w + i * simd_w * simd_w)
                            * typesize;
                    lea(wreg_M, ptr[wreg_src + offset_src]);
                    lea(wreg_MT, ptr[wreg_F + offset_F]);
                    trans16x16();
                }
            }
        }
    };

    // Scatters the transformed tile from the Mw buffer into the blocked
    // destination layout using streaming (non-temporal) stores.
    auto store_dst = [=]() {
        mov(wreg_dst, ptr[param1 + GET_OFF(dst)]);
        mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);

        Label Loop_j;
        mov(wreg_cnt_j, 0);
        mov(wreg_dst_aux, wreg_dst);
        mov(wreg_Fw_aux, wreg_Fw);

        // Stride (in elements) between consecutive alpha positions in dst.
        int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block)
                * jcp.dimK_block * simd_w * simd_w;

        L(Loop_j);
        {
            for (int i = 0; i < alpha; i++) {
                // touch pages
                vmovups(zmm_load(0),
                        ptr[wreg_Fw_aux + (i * simd_w * simd_w) * typesize]);
                mov(wreg_dst_idx, i * dim5 * typesize);
                vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0));
            }
            for (int i = 0; i < alpha; i++) {
                // Remaining simd_w - 1 vectors of each alpha position.
                for (int v1 = 1; v1 < simd_w; v1++) {
                    int offset_Fw
                            = (i * simd_w * simd_w + v1 * simd_w) * typesize;
                    vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]);
                }
                mov(wreg_dst_idx, i * dim5 * typesize);
                for (int v1 = 1; v1 < simd_w; v1++) {
                    int offset_dst = v1 * simd_w * typesize;
                    vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst],
                            zmm_load(v1));
                }
            }
            add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize);
            add(wreg_dst_aux, alpha * dim5 * typesize);
            add(wreg_cnt_j, 1);
            cmp(wreg_cnt_j, alpha);
            jl(Loop_j, T_NEAR);
        }
    };

    // Computes the Winograd weight transform: first T = G * F over columns,
    // then Fw = T * G^T over rows, once per vector lane group.
    auto trans_W_4x4_3x3 = [=]() {
        // fma4:  dst = a + b * c
        auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vmovups(dst, a);
            vfmadd231ps(dst, b, c);
        };
        // fms4:  dst = a - b * c
        auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vmulps(zmm_temp, b, c);
            vsubps(dst, a, zmm_temp);
        };
        // fnms4: dst = -a - b * c
        auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vsubps(dst, zmm_zero, a);
            vfnmadd231ps(dst, b, c);
        };

        mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
        mov(wreg_F, ptr[param1 + GET_OFF(M)]);
        mov(wreg_T, ptr[param1 + GET_OFF(T)]);

        Label Loop_j;
        mov(wreg_cnt_j, 0);
        L(Loop_j);
        mov(wreg_F_aux, wreg_F);
        mov(wreg_Fw_aux, wreg_Fw);
        mov(wreg_temp, wreg_cnt_j);
        shl(wreg_temp, 4 + 2); // * 16 * 4, i.e. one simd vector in bytes
        lea(wreg_F_aux, ptr[wreg_F + wreg_temp]);
        lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]);

        // Column pass: T(6x3) = G(6x3) * F(3x3), column i at a time.
        for (int i = 0; i < 3; i++) {
            for (int idx = 0; idx < 3; idx++) {
                vmovups(zmm_F(idx),
                        ptr[wreg_F_aux
                                + (idx * 3 * simd_w * simd_w
                                          + i * simd_w * simd_w)
                                        * typesize]);
            }
            vmulps(zmm_t(0), zmm_G(0), zmm_F(2));
            fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0));
            fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0));

            vmulps(zmm_T(0), zmm_G(3), zmm_F(0));
            fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1));
            fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1));
            fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1));
            fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1));
            vmovaps(zmm_T(5), zmm_F(2));

            for (int idx = 0; idx < 6; idx++) {
                vmovups(ptr[wreg_T
                                + (idx * 3 * simd_w + i * simd_w) * typesize],
                        zmm_T(idx));
            }
        }
        // Row pass: Fw(6x6) = T(6x3) * G^T(3x6), row i at a time.
        for (int i = 0; i < 6; i++) {

            for (int idx = 0; idx < 3; idx++) {
                vmovups(zmm_T(idx),
                        ptr[wreg_T
                                + (i * 3 * simd_w + idx * simd_w) * typesize]);
            }
            vmulps(zmm_t(0), zmm_G(0), zmm_T(2));
            fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0));
            fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0));

            vmulps(zmm_F(0), zmm_G(3), zmm_T(0));
            fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1));
            fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1));
            fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1));
            fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1));
            vmovaps(zmm_F(5), zmm_T(2));

            for (int l = 0; l < 6; l++) {
                vmovups(ptr[wreg_Fw_aux
                                + (i * 6 * simd_w * simd_w
                                          + l * simd_w * simd_w)
                                        * typesize],
                        zmm_F(l));
            }
        }
        add(wreg_cnt_j, 1);
        cmp(wreg_cnt_j, 16);
        jl(Loop_j, T_NEAR);
    };

    auto inner_loops = [=]() {
        load_src();
        init_G();
        trans_W_4x4_3x3();
        store_dst();
    };

    preamble();
    inner_loops();
    postamble();
}
657 | |
// Emits the output (inverse) transform kernel: maps each alpha x alpha
// Winograd-domain accumulation tile back to a tile_size x tile_size spatial
// tile (O = A^T * M * A), then applies bias, eltwise (ReLU) and sum
// post-ops before storing, with bounds checks against the real output
// width/height for partial edge tiles.
void _jit_avx512_core_f32_wino_conv_4x3_data_kernel ::
        output_transform_data_ker_generate() {
    bool is_fwd = one_of(
            jcp.prop_kind, dnnl_forward_training, dnnl_forward_inference);
    // For backward data the "output" of the transform is the input gradient.
    int outw = is_fwd ? jcp.ow : jcp.iw;
    int outh = is_fwd ? jcp.oh : jcp.ih;
    bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
    bool with_bias = jcp.with_bias;
    bool with_relu = jcp.with_eltwise;
    bool with_relu_postsum = jcp.with_relu_postsum;
    bool with_sum = jcp.with_sum;

    auto zmm_zero = Xbyak::Zmm(0);
    auto zmm_temp = Xbyak::Zmm(31);
    // Register banks for G coefficients, tile rows and temporaries.
    auto zmm_G = [=](int i) { return Xbyak::Zmm(1 + i); };
    auto zmm_O = [=](int i) { return Xbyak::Zmm(1 + alpha + i); };
    auto zmm_T = [=](int i) { return Xbyak::Zmm(1 + 2 * alpha + i); };
    auto zmm_t = [=](int i) { return Xbyak::Zmm(1 + 3 * alpha + i); };

    // Broadcasts the 6 inverse-transform coefficients into registers.
    auto init_G = [=]() {
        mov(oreg_temp, ptr[param1 + GET_OFF(G)]);
        for (int i = 0; i < 6; i++) {
            vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]);
        }
    };

    // Gathers one alpha x alpha tile from the blocked GEMM output into the
    // contiguous Ow scratch buffer.
    auto load_src = [=]() {
        mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
        mov(oreg_src, ptr[param1 + GET_OFF(src)]);

        mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
        imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur,
                (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
                        * jcp.dimM_simd_block * typesize);
        add(oreg_src, oreg_nb_tile_block_ur);

        mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
        imul(oreg_tile_block_ur, oreg_tile_block_ur,
                jcp.dimM_simd_block * typesize);
        add(oreg_src, oreg_tile_block_ur);

        // Non-tiled schedule keeps all tile blocks in one buffer, so the
        // tile_block index contributes to the offset as well.
        if (not_tiled) {
            mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
            imul(oreg_tile_block, oreg_tile_block,
                    jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block
                            * (jcp.dimM_block * jcp.dimM_reg_block)
                            * jcp.dimN_reg_block * jcp.dimM_simd_block
                            * typesize);
            add(oreg_src, oreg_tile_block);
        }

        int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block)
                * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize;
        for (int j = 0; j < alpha; j++) {
            for (int i = 0; i < alpha; i++) {
                int j_base_offset = j * alpha * last4dim;
                int i_base_offset = i * last4dim;
                vmovups(zmm_temp,
                        ptr[oreg_src + j_base_offset + i_base_offset]);
                vmovups(ptr[oreg_Ow
                                + (j * alpha * simd_w + i * simd_w) * typesize],
                        zmm_temp);
            }
        }
    };

    // Applies post-ops and writes the spatial tile to dst, skipping rows and
    // columns that fall outside the real output (partial edge tiles).
    auto store_dst = [=]() {
        vpxord(zmm_zero, zmm_zero, zmm_zero);
        mov(oreg_dst, ptr[param1 + GET_OFF(dst)]);
        mov(oreg_O, ptr[param1 + GET_OFF(M)]);
        mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]);
        shl(oreg_ydim, 2); // tj * tile_size (==4)
        mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]);
        shl(oreg_xdim, 2); // ti * tilesize (==4)

        if (with_bias) mov(oreg_bias, ptr[param1 + GET_OFF(bias)]);

        // Stores one simd vector at tile position (j, i) with bias / ReLU /
        // sum post-ops applied (fwd only for bias+ReLU).
        auto store_one = [=](int j, int i, bool is_aligned) {
            auto zmm_O = Xbyak::Zmm(31);
            auto zmm_relu_ns = Xbyak::Zmm(30);
            auto xmm_relu_ns = Xbyak::Xmm(30);
            int offset = (j * tile_size * simd_w + i * simd_w) * typesize;

            vmovups(zmm_O, ptr[oreg_O + offset]);
            if (is_fwd) {
                if (with_bias) { vaddps(zmm_O, zmm_O, ptr[oreg_bias]); }
                if (with_relu) {
                    if (jcp.eltwise.alpha == 0) {
                        // Plain ReLU: max(x, 0).
                        vmaxps(zmm_O, zmm_O, zmm_zero);
                    } else {
                        // Leaky ReLU: scale negative lanes (selected via
                        // mask) by eltwise.alpha.
                        Opmask kmask = Opmask(7);
                        mov(imm_addr64, float2int(jcp.eltwise.alpha));
                        vmovq(xmm_relu_ns, imm_addr64);
                        vbroadcastss(zmm_relu_ns, xmm_relu_ns);
                        vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
                        vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
                    }
                }
            }
            if (with_sum) {
                vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]);
                if (with_relu_postsum) // orig: with_relu_postsum
                    vmaxps(zmm_O, zmm_O, zmm_zero);
            }
            if (is_aligned)
                vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O);
            else
                vmovups(ptr[oreg_out_j + oreg_temp], zmm_O);
        };

        // Column loop with bounds check: skip columns past outw.
        auto i_loop = [=](int j, bool is_aligned) {
            for (int i = 0; i < tile_size; i++) {
                Label next;
                mov(oreg_temp, oreg_xdim);
                add(oreg_temp, i);
                cmp(oreg_temp, outw);
                jge(next, T_NEAR);
                shl(oreg_temp, 4 + 2); // * 16 * 4

                store_one(j, i, is_aligned);

                L(next);
            }
        };

        // Row loop with bounds check: skip rows past outh; pick aligned vs
        // unaligned store path per destination alignment.
        for (int j = 0; j < tile_size; j++) {
            Label next, unaligned;
            mov(oreg_temp, oreg_ydim);
            add(oreg_temp, j);
            cmp(oreg_temp, outh);
            jge(next, T_NEAR);

            mov(oreg_out_j, oreg_dst);
            imul(oreg_temp, oreg_temp, outw * simd_w * typesize);
            add(oreg_out_j, oreg_temp);

            test(oreg_dst, 63);
            jnz(unaligned, T_NEAR);

            i_loop(j, true);
            jmp(next, T_NEAR);

            L(unaligned);
            i_loop(j, false);

            L(next);
        }
    };

    // Inverse Winograd transform: column pass (alpha -> tile_size rows into
    // T) then row pass (alpha -> tile_size columns into O).
    auto trans_O_4x4_3x3 = [=]() {
        // fma2: dst = v1 * u1 + v2 * u2
        auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2) {
            vmulps(dst, v1, u1);
            vfmadd231ps(dst, v2, u2);
        };
        mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
        mov(oreg_T, ptr[param1 + GET_OFF(T)]);
        mov(oreg_O, ptr[param1 + GET_OFF(M)]);

        for (int i = 0; i < alpha; i++) {
            for (int j = 0; j < alpha; j++) {
                vmovups(zmm_O(j),
                        ptr[oreg_Ow
                                + (j * alpha * simd_w + i * simd_w)
                                        * typesize]);
            }

            vaddps(zmm_t(0), zmm_O(1), zmm_O(2));
            vaddps(zmm_t(1), zmm_O(3), zmm_O(4));
            vsubps(zmm_t(2), zmm_O(1), zmm_O(2));
            vsubps(zmm_t(3), zmm_O(3), zmm_O(4));

            vaddps(zmm_T(0), zmm_t(0), zmm_t(1));
            vaddps(zmm_T(0), zmm_T(0), zmm_O(0));
            fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
            fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
            fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
            vaddps(zmm_T(3), zmm_T(3), zmm_O(5));

            for (int j = 0; j < tile_size; j++) {
                vmovups(ptr[oreg_T
                                + (j * alpha * simd_w + i * simd_w) * typesize],
                        zmm_T(j));
            }
        }
        for (int j = 0; j < tile_size; j++) {
            for (int i = 0; i < alpha; i++) {
                vmovups(zmm_T(i),
                        ptr[oreg_T
                                + (j * alpha * simd_w + i * simd_w)
                                        * typesize]);
            }
            vaddps(zmm_t(0), zmm_T(1), zmm_T(2));
            vaddps(zmm_t(1), zmm_T(3), zmm_T(4));
            vsubps(zmm_t(2), zmm_T(1), zmm_T(2));
            vsubps(zmm_t(3), zmm_T(3), zmm_T(4));

            vaddps(zmm_O(0), zmm_t(0), zmm_t(1));
            vaddps(zmm_O(0), zmm_O(0), zmm_T(0));
            fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
            fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
            fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
            vaddps(zmm_O(3), zmm_O(3), zmm_T(5));

            for (int i = 0; i < tile_size; i++) {
                vmovups(ptr[oreg_O
                                + (j * tile_size * simd_w + i * simd_w)
                                        * typesize],
                        zmm_O(i));
            }
        }
    };

    auto inner_loops = [=]() {
        init_G();
        load_src();
        trans_O_4x4_3x3();
        store_dst();
    };

    preamble();
    inner_loops();
    postamble();
}
881 | |
882 | void _jit_avx512_core_f32_wino_conv_4x3_data_kernel :: |
883 | input_transform_data_ker_generate() { |
884 | bool is_fwd = one_of( |
885 | jcp.prop_kind, dnnl_forward_training, dnnl_forward_inference); |
886 | int inpw = is_fwd ? jcp.iw : jcp.ow; |
887 | int inph = is_fwd ? jcp.ih : jcp.oh; |
888 | int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow; |
889 | int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh; |
890 | int wp_max = inpw + l_pad; |
891 | int hp_max = inph + t_pad; |
892 | bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D; |
893 | int G_size = 9; |
894 | |
895 | auto zmm_zero = Xbyak::Zmm(0); |
896 | auto zmm_temp = Xbyak::Zmm(31); |
897 | auto zmm_G = [=](int i) { return Xbyak::Zmm(1 + i); }; |
898 | auto zmm_I = [=](int i) { return Xbyak::Zmm(1 + G_size + i); }; |
899 | auto zmm_T = [=](int i) { return Xbyak::Zmm(1 + G_size + alpha + i); }; |
900 | auto zmm_t = [=](int i) { return Xbyak::Zmm(1 + G_size + 2 * alpha + i); }; |
901 | |
902 | auto init_G = [=]() { |
903 | mov(ireg_temp, ptr[param1 + GET_OFF(G)]); |
904 | for (int i = 0; i < G_size; i++) { |
905 | vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]); |
906 | } |
907 | }; |
908 | |
909 | auto load_src = [=]() { |
910 | mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp |
911 | mov(ireg_I, ptr[param1 + GET_OFF(M)]); |
912 | |
913 | xor_(ireg_zero, ireg_zero); |
914 | vpxord(zmm_zero, zmm_zero, zmm_zero); |
915 | |
916 | mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]); |
917 | shl(ireg_ydim, 2); // tj * tile_size (==4) |
918 | mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]); |
919 | shl(ireg_xdim, 2); // ti * tilesize (==4) |
920 | |
921 | for (int j = 0; j < alpha; j++) { |
922 | mov(ireg_temp, ireg_ydim); |
923 | add(ireg_temp, j); |
924 | |
925 | mov(ireg_mask_j, 0xffff); |
926 | cmp(ireg_temp, t_pad); |
927 | cmovl(ireg_mask_j, ireg_zero); |
928 | cmp(ireg_temp, hp_max); |
929 | cmovge(ireg_mask_j, ireg_zero); |
930 | |
931 | sub(ireg_temp, t_pad); |
932 | imul(ireg_temp, ireg_temp, inpw * simd_w * typesize); |
933 | mov(ireg_inp_j, ireg_src); |
934 | add(ireg_inp_j, ireg_temp); |
935 | |
936 | for (int i = 0; i < alpha; i++) { |
937 | |
938 | mov(ireg_temp, ireg_xdim); |
939 | add(ireg_temp, i); |
940 | |
941 | mov(ireg_mask, 0xffff); |
942 | cmp(ireg_temp, l_pad); |
943 | cmovl(ireg_mask, ireg_zero); |
944 | cmp(ireg_temp, wp_max); |
945 | cmovge(ireg_mask, ireg_zero); |
946 | and_(ireg_mask, ireg_mask_j); |
947 | |
948 | sub(ireg_temp, l_pad); |
949 | shl(ireg_temp, 4 + 2); |
950 | |
951 | vpxord(zmm_temp, zmm_temp, zmm_temp); |
952 | Opmask kmask = Opmask(7); |
953 | kmovw(kmask, ireg_mask_32); |
954 | vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]); |
955 | vmovups(ptr[ireg_I |
956 | + (j * alpha * simd_w + i * simd_w) * typesize], |
957 | zmm_temp); |
958 | } |
959 | } |
960 | }; |
961 | |
962 | auto store_Iw = [=]() { |
963 | mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); |
964 | mov(ireg_output, ptr[param1 + GET_OFF(dst)]); |
965 | |
966 | bool streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float) |
967 | > 2 * LLC_data_size * jcp.nthr |
968 | ? true |
969 | : false; |
970 | |
971 | if (not_tiled) { |
972 | mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]); |
973 | imul(ireg_tile_block, ireg_tile_block, |
974 | alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block |
975 | * jcp.dimK_block * jcp.dimN_reg_block |
976 | * jcp.dimK_reg_block * typesize); |
977 | } |
978 | |
979 | mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]); |
980 | imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur, |
981 | jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block |
982 | * jcp.dimK_reg_block * typesize); |
983 | |
984 | mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]); |
985 | imul(ireg_tile_block_ur, ireg_tile_block_ur, |
986 | jcp.dimK_reg_block * typesize); |
987 | |
988 | add(ireg_output, ireg_nb_tile_block_ur); |
989 | add(ireg_output, ireg_tile_block_ur); |
990 | if (not_tiled) add(ireg_output, ireg_tile_block); |
991 | |
992 | for (int j = 0; j < alpha; j++) { |
993 | for (int i = 0; i < alpha; i++) { |
994 | vmovups(zmm_temp, |
995 | ptr[ireg_Iw |
996 | + (j * alpha * simd_w + i * simd_w) |
997 | * typesize]); |
998 | |
999 | int j_base_offset = j * alpha * jcp.dimN_block |
1000 | * jcp.dimK_nb_block * jcp.dimK_block |
1001 | * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; |
1002 | int i_base_offset = i * jcp.dimN_block * jcp.dimK_nb_block |
1003 | * jcp.dimK_block * jcp.dimN_reg_block |
1004 | * jcp.dimK_reg_block * typesize; |
1005 | |
1006 | if (not_tiled && streamout) |
1007 | vmovntps(ptr[ireg_output + j_base_offset + i_base_offset], |
1008 | zmm_temp); |
1009 | else |
1010 | vmovups(ptr[ireg_output + j_base_offset + i_base_offset], |
1011 | zmm_temp); |
1012 | } |
1013 | } |
1014 | }; |
1015 | |
1016 | auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) { |
1017 | vmulps(zmm_temp, a, b); |
1018 | vaddps(dst, zmm_temp, c); |
1019 | }; |
1020 | |
1021 | auto trans_I_4x4_3x3 = [=]() { |
1022 | mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]); |
1023 | mov(ireg_T, ptr[param1 + GET_OFF(T)]); |
1024 | mov(ireg_I, ptr[param1 + GET_OFF(M)]); |
1025 | |
1026 | mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch |
1027 | for (int i = 0; i < alpha; i++) { |
1028 | for (int idx = 0; idx < alpha; idx++) { |
1029 | vmovups(zmm_I(idx), |
1030 | ptr[ireg_I |
1031 | + (idx * alpha * simd_w + i * simd_w) |
1032 | * typesize]); |
1033 | int j_base_offset = i * alpha * jcp.dimN_block |
1034 | * jcp.dimK_nb_block * jcp.dimK_block |
1035 | * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize; |
1036 | int idx_base_offset = idx * jcp.dimN_block * jcp.dimK_nb_block |
1037 | * jcp.dimK_block * jcp.dimN_reg_block |
1038 | * jcp.dimK_reg_block * typesize; |
1039 | prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]); |
1040 | } |
1041 | |
1042 | fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4)); |
1043 | fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3)); |
1044 | fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4)); |
1045 | fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3)); |
1046 | fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4)); |
1047 | fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5)); |
1048 | |
1049 | fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4)); |
1050 | fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0)); |
1051 | fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0)); |
1052 | fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2)); |
1053 | fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2)); |
1054 | fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5)); |
1055 | |
1056 | for (int idx = 0; idx < alpha; idx++) { |
1057 | vmovups(ptr[ireg_T |
1058 | + (idx * alpha * simd_w + i * simd_w) |
1059 | * typesize], |
1060 | zmm_T(idx)); |
1061 | } |
1062 | } |
1063 | for (int i = 0; i < alpha; i++) { |
1064 | for (int idx = 0; idx < alpha; idx++) { |
1065 | vmovups(zmm_T(idx), |
1066 | ptr[ireg_T |
1067 | + (i * alpha * simd_w + idx * simd_w) |
1068 | * typesize]); |
1069 | } |
1070 | |
1071 | fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4)); |
1072 | fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3)); |
1073 | fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4)); |
1074 | fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3)); |
1075 | fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4)); |
1076 | fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5)); |
1077 | |
1078 | fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4)); |
1079 | fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0)); |
1080 | fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0)); |
1081 | fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2)); |
1082 | fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2)); |
1083 | fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5)); |
1084 | |
1085 | for (int idx = 0; idx < alpha; idx++) { |
1086 | vmovups(ptr[ireg_Iw |
1087 | + (i * alpha * simd_w + idx * simd_w) |
1088 | * typesize], |
1089 | zmm_I(idx)); |
1090 | } |
1091 | } |
1092 | }; |
1093 | |
1094 | auto inner_loops = [=]() { |
1095 | init_G(); |
1096 | load_src(); |
1097 | trans_I_4x4_3x3(); |
1098 | store_Iw(); |
1099 | }; |
1100 | |
1101 | preamble(); |
1102 | inner_loops(); |
1103 | postamble(); |
1104 | } |
1105 | |
status_t _jit_avx512_core_f32_wino_conv_4x3_data_kernel::init_conf_common(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &dst_d) {
    // Fills the jcp fields shared by fwd and bwd-data and rejects any
    // problem this Winograd kernel cannot handle.
    if (!mayiuse(avx512_core)) { return status::unimplemented; }

    // This kernel only supports 2D convolutions.
    if (src_d.ndims() != 4) return status::unimplemented;

    jcp.nthr = dnnl_get_max_threads();

    jcp.prop_kind = cd.prop_kind;

    // Grouped weights carry an extra leading "groups" dimension.
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

    // Problem geometry (channel counts are per group).
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];
    // Right/bottom pads derived so input+pads covers the output extent.
    jcp.r_pad = nstl::max(
            0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
    jcp.b_pad = nstl::max(
            0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;

    // Round channels up to a full simd vector; only safe without groups.
    bool ok_to_pad_channels = jcp.ngroups == 1;
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    // Checking conditions not supported by these kernels
    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
                is_winograd_faster_than_direct(jcp)))
        return status::unimplemented;

    // Only 3x3, single-group, stride-1, non-dilated problems with pads of
    // at most 1 and full-vector channel counts are supported.
    const bool prb_shape_ok = jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1
            && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 && jcp.stride_h == 1
            && jcp.stride_w == 1 && jcp.dilate_h == 0 && jcp.dilate_w == 0
            && jcp.l_pad <= 1 && jcp.r_pad <= 1 && jcp.t_pad <= 1
            && jcp.b_pad <= 1;
    if (!prb_shape_ok) return status::unimplemented;

    // Activations must be in blocked nChw16c layout.
    format_tag_t dat_tag = nChw16c;
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);

    if (jcp.src_tag != dat_tag) return status::unimplemented;
    if (jcp.dst_tag != dat_tag) return status::unimplemented;

    // Weights are either in the dedicated wino format / "any" (resolved
    // later), or must match plain (g)OIhw16i16o.
    if (!one_of(weights_d.format_kind(), format_kind::any, format_kind::wino)) {
        format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
        jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;
    }

    // The (possibly rounded-up) channel counts must not exceed what the
    // memory descriptors were padded for.
    bool layout_consistency = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1]
            && (one_of(weights_d.format_kind(), format_kind::any,
                        format_kind::wino)
                    || (jcp.ic <= weights_d.padded_dims()[with_groups + 1]
                            && jcp.oc <= weights_d.padded_dims()[with_groups
                                    + 0]));
    if (!layout_consistency) return status::unimplemented;

    return status::success;
}
1189 | |
1190 | void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) { |
1191 | |
1192 | /* ----------- dimM reg block ---------------------*/ |
1193 | auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp, |
1194 | int dimM_reg_block, |
1195 | int current_best) { |
1196 | int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4; |
1197 | return (dimM_reg_block >= 1) && (dimM_reg_block <= max_dimM_reg_block) |
1198 | && (dimM_reg_block > current_best); |
1199 | }; |
1200 | jcp.dimM_reg_block = get_divisor_satisfying_cond( |
1201 | jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond_dimM_reg_block); |
1202 | |
1203 | /* ----------- dimN reg block ---------------------*/ |
1204 | |
1205 | auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp, |
1206 | int dimN_reg_block, |
1207 | int current_best) { |
1208 | return jcp.kernel_kind == embd_bcast |
1209 | ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best |
1210 | : dimN_reg_block >= 1 |
1211 | && (dimN_reg_block * jcp.dimM_reg_block |
1212 | + dimN_reg_block) |
1213 | < jcp.nb_reg |
1214 | && dimN_reg_block > current_best; |
1215 | }; |
1216 | jcp.dimN_reg_block = get_divisor_satisfying_cond( |
1217 | jcp, jcp.dimN, 1, test_cond_dimN_reg_block); |
1218 | } |
1219 | |
1220 | status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) { |
1221 | |
1222 | jcp.kernel_kind = embd_bcast; |
1223 | |
1224 | set_kernel_dims_reg_block(jcp); |
1225 | |
1226 | /*-------------- L2 blocking for dimN block ---------*/ |
1227 | |
1228 | auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp, |
1229 | int dimN_block, int current_best) { |
1230 | return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0) |
1231 | && (dimN_block > current_best) |
1232 | && ((jcp.dimN / dimN_block / jcp.dimN_reg_block) |
1233 | >= 1.5 * jcp.nthr); |
1234 | }; |
1235 | |
1236 | jcp.dimN_block = get_divisor_satisfying_cond( |
1237 | jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block); |
1238 | jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; |
1239 | |
1240 | if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2) |
1241 | && (jcp.dimN_nb_block >= 1.5 * jcp.nthr)) { |
1242 | |
1243 | /* ------------------- L1 blocking for GEMM --------------*/ |
1244 | /* -------------------- Choose dimK block ----------------*/ |
1245 | |
1246 | auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp, |
1247 | int dimK_block, int current_best) { |
1248 | return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5) |
1249 | && (dimK_block > current_best); |
1250 | }; |
1251 | |
1252 | jcp.dimK_block = get_divisor_satisfying_cond( |
1253 | jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block); |
1254 | |
1255 | if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) { |
1256 | jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; |
1257 | |
1258 | /* -------------- Choose dimM block -------------------*/ |
1259 | auto test_cond_dimM_block |
1260 | = [](jit_conv_winograd_conf_t &jcp, int dimM_block, |
1261 | int current_best) { |
1262 | return check_L1_block_gemm(jcp, jcp.dimK_block, |
1263 | dimM_block, 0.2, 0.5) |
1264 | && (dimM_block > current_best); |
1265 | }; |
1266 | |
1267 | jcp.dimM_block = get_divisor_satisfying_cond(jcp, |
1268 | jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1, |
1269 | test_cond_dimM_block); |
1270 | jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block |
1271 | / jcp.dimM_simd_block; |
1272 | |
1273 | jcp.sched_policy = WSCHED_DATA_W_SGD; |
1274 | return status::success; |
1275 | } |
1276 | } |
1277 | return status::unimplemented; |
1278 | } |
1279 | |
1280 | void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) { |
1281 | |
1282 | set_kernel_dims_reg_block(jcp); |
1283 | |
1284 | //********************* Choosing dimK_block **********************// |
1285 | auto test_cond1_dimK_block = [](jit_conv_winograd_conf_t &jcp, |
1286 | int dimK_block, int current_best) { |
1287 | return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block, |
1288 | 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f) |
1289 | && (dimK_block > current_best); |
1290 | }; |
1291 | |
1292 | auto test_cond1_bis_dimK_block = [](jit_conv_winograd_conf_t &jcp, |
1293 | int dimK_block, int current_best) { |
1294 | return check_cond1_bis(jcp.dimN_reg_block, dimK_block, |
1295 | jcp.dimK_reg_block, 1, jcp.dimM_reg_block, |
1296 | jcp.dimM_simd_block, .9f) |
1297 | && (dimK_block > current_best); |
1298 | }; |
1299 | |
1300 | jcp.dimK_block = get_divisor_satisfying_cond( |
1301 | jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block); |
1302 | // If we are not able to use streams, we fall back to condition [1] |
1303 | if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) |
1304 | jcp.dimK_block = get_divisor_satisfying_cond( |
1305 | jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block); |
1306 | jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block; |
1307 | |
1308 | //********************* Choosing dimM_block **********************// |
1309 | auto test_cond1_dimM_block = [](jit_conv_winograd_conf_t &jcp, |
1310 | int dimM_block, int current_best) { |
1311 | return check_cond1(jcp.dimN_reg_block, jcp.dimK_block, |
1312 | jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, |
1313 | jcp.dimM_simd_block, .5f) |
1314 | && (dimM_block > current_best); |
1315 | }; |
1316 | |
1317 | auto test_cond1_bis_dimM_block = [](jit_conv_winograd_conf_t &jcp, |
1318 | int dimM_block, int current_best) { |
1319 | return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block, |
1320 | jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block, |
1321 | jcp.dimM_simd_block, .3f) |
1322 | && (dimM_block > current_best); |
1323 | }; |
1324 | |
1325 | if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block) |
1326 | jcp.dimM_block = get_divisor_satisfying_cond(jcp, |
1327 | jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1, |
1328 | test_cond1_dimM_block); |
1329 | else |
1330 | jcp.dimM_block = get_divisor_satisfying_cond(jcp, |
1331 | jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1, |
1332 | test_cond1_bis_dimM_block); |
1333 | jcp.dimM_nb_block = jcp.dimM |
1334 | / (jcp.dimM_simd_block * jcp.dimM_block * jcp.dimM_reg_block); |
1335 | |
1336 | //******************* Choosing dimN_block *******************// |
1337 | auto test_cond2_dimN_block = [](jit_conv_winograd_conf_t &jcp, |
1338 | int dimN_block, int current_best) { |
1339 | return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block, |
1340 | jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block, |
1341 | jcp.dimM_reg_block, jcp.dimM_simd_block, .9f) |
1342 | && (dimN_block > current_best); |
1343 | }; |
1344 | |
1345 | jcp.dimN_block = get_divisor_satisfying_cond( |
1346 | jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block); |
1347 | jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block); |
1348 | } |
1349 | |
1350 | status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) { |
1351 | |
1352 | jcp.kernel_kind = expl_bcast; |
1353 | set_kernel_blocking_DATA_W_S_G_D(jcp); |
1354 | if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block, |
1355 | jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, |
1356 | jcp.dimK, .1f, .35f))) { |
1357 | jcp.kernel_kind = embd_bcast; |
1358 | set_kernel_blocking_DATA_W_S_G_D(jcp); |
1359 | } |
1360 | jcp.sched_policy = WSCHED_DATA_W_S_G_D; |
1361 | return status::success; |
1362 | } |
1363 | |
1364 | status_t _jit_avx512_core_f32_wino_conv_4x3_data_kernel::init_conf_kernel( |
1365 | jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK) { |
1366 | jcp.nb_reg = 32; |
1367 | jcp.dimN = dimN; |
1368 | jcp.dimK = dimK; |
1369 | jcp.dimM = dimM; |
1370 | jcp.sched_policy = WSCHED_INVALID; |
1371 | |
1372 | jcp.dimK_reg_block = 16; |
1373 | jcp.dimM_simd_block = 16; |
1374 | |
1375 | if (jcp.kernel_kind == embd_bcast) { jcp.dimM_reg_block = 1; } |
1376 | |
1377 | if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success)) |
1378 | set_wsched_DATA_W_S_G_D_avx512_core(jcp); |
1379 | |
1380 | assert(jcp.sched_policy != WSCHED_INVALID); |
1381 | return status::success; |
1382 | } |
1383 | |
1384 | bool jit_avx512_core_f32_wino_conv_4x3_fwd_kernel::post_ops_ok( |
1385 | jit_conv_conf_t &jcp, const primitive_attr_t &attr) { |
1386 | const auto &p = attr.post_ops_; |
1387 | |
1388 | auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; |
1389 | auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); }; |
1390 | |
1391 | switch (p.len()) { |
1392 | case 0: return true; // no post_ops |
1393 | case 1: return is_relu(0) || is_sum(0); // relu or sum |
1394 | case 2: |
1395 | return (is_sum(0) && is_relu(1)) |
1396 | || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum |
1397 | case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu |
1398 | default: return false; |
1399 | } |
1400 | } |
1401 | |
status_t jit_avx512_core_f32_wino_conv_4x3_fwd_kernel::init_conf(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_t &src_md, memory_desc_t &weights_md,
        const memory_desc_t &dst_md, const primitive_attr_t &attr) {

    // Common geometry/ISA/layout validation and jcp setup.
    status_t st = init_conf_common(jcp, cd, src_md, weights_md, dst_md);

    if (st != status::success) return st;

    // Winograd specific initialization
    // Number of 4x4 output tiles per row/column (rounded up) and in total.
    jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
    jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
    jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    // A leading eltwise (searched only at index 0, range [0, 1)) is fused
    // into the kernel; an eltwise found at index >= 1 is a post-sum relu.
    const auto &p = attr.post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;

    jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
    jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;

    // GEMM dimensions for fwd: M <- oc, N <- ntiles, K <- ic.
    status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);

    // Map the generic GEMM blocking back to convolution-dimension names.
    jcp.ic_simd_block = jcp.dimK_reg_block;
    jcp.ic_block = jcp.dimK_block;
    jcp.nb_ic = jcp.dimK_nb_block;
    jcp.oc_simd_block = jcp.dimM_simd_block;
    jcp.oc_block = jcp.dimM_block;
    jcp.oc_reg_block = jcp.dimM_reg_block;
    jcp.ic_reg_block = 1;
    jcp.nb_oc = jcp.dimM_nb_block;
    jcp.tile_block_ur = jcp.dimN_reg_block;
    jcp.nb_tile_block_ur = jcp.dimN_block;
    jcp.tile_block = jcp.dimN_nb_block;

    /* re-create weights primitive descriptor
    and set weights wino_blocking */
    if (cd.prop_kind == dnnl_forward_inference) {
        // For inference, advertise the dedicated Winograd weights format
        // so the weight transform can happen once, ahead of execution.
        memory_desc_t expect_wei_md = weights_md;

        expect_wei_md.format_kind = format_kind::wino;
        expect_wei_md.data_type = data_type::f32;
        wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
        wd.wino_format = wino_memory_format_t::wino_wei_OBaaIBOIio;
        wd.r = 3;
        wd.alpha = 6;

        wd.ic = jcp.ic;
        wd.oc = jcp.oc;
        wd.ic_block = jcp.dimK_reg_block;
        wd.oc_block = jcp.dimM_simd_block;
        wd.ic2_block = jcp.dimK_block;
        wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block;
        // alpha^2 transformed planes covering all ic x oc weights.
        size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc;
        wd.size = max_size;
        wd.adj_scale = 1.f;

        // Accept "any" by substituting our format; anything else must
        // match it exactly.
        if (weights_md.format_kind == format_kind::any)
            weights_md = expect_wei_md;
        if (weights_md != expect_wei_md) return status::unimplemented;
    }

    return res;
}
1471 | |
1472 | status_t jit_avx512_core_f32_wino_conv_4x3_bwd_data_kernel::init_conf( |
1473 | jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd, |
1474 | const memory_desc_wrapper &diff_src_d, |
1475 | const memory_desc_wrapper &weights_d, |
1476 | const memory_desc_wrapper &diff_dst_d) { |
1477 | status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d); |
1478 | |
1479 | if (st != status::success) return st; |
1480 | |
1481 | jcp.itiles = (jcp.iw + tile_size - 1) / tile_size; |
1482 | jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size; |
1483 | jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles; |
1484 | |
1485 | status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc); |
1486 | |
1487 | jcp.oc_simd_block = jcp.dimK_reg_block; |
1488 | jcp.oc_block = jcp.dimK_block; |
1489 | jcp.nb_oc = jcp.dimK_nb_block; |
1490 | jcp.ic_simd_block = jcp.dimM_simd_block; |
1491 | jcp.ic_block = jcp.dimM_block; |
1492 | jcp.ic_reg_block = jcp.dimM_reg_block; |
1493 | jcp.oc_reg_block = 1; |
1494 | jcp.nb_ic = jcp.dimM_nb_block; |
1495 | jcp.tile_block_ur = jcp.dimN_reg_block; |
1496 | jcp.nb_tile_block_ur = jcp.dimN_block; |
1497 | jcp.tile_block = jcp.dimN_nb_block; |
1498 | |
1499 | return res; |
1500 | } |
1501 | |
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::
        src_transform_generate() {
    // Generates the source transform for the weights-backward pass: walks
    // the tiles of the source, loads each zero-padded alpha x alpha tile
    // and applies the Winograd input transform (trans_I_3x3_4x4), storing
    // the result into the blocked layout consumed by the GEMM stage.
    constexpr int G_size = 9;
    const size_t ifwp = jcp.iw + jcp.l_pad;
    const size_t ifhp = jcp.ih + jcp.t_pad;

    // Register map: transform coefficients first, then the I/T/t work sets.
    auto zmm_G = [=](int i) { return Xbyak::Zmm(i); };
    auto zmm_I = [=](int i) { return Xbyak::Zmm(G_size + i); };
    auto zmm_T = [=](int i) { return Xbyak::Zmm(G_size + alpha + i); };
    auto zmm_t = [=](int i) { return Xbyak::Zmm(G_size + 2 * alpha + i); };

    // Broadcast the scalar transform coefficients into zmm_G(0..G_size-1).
    auto init_G = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        for (int i = 0; i < G_size; i++) {
            vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
        }
    };

    // Copy one alpha x alpha tile from src into the M buffer, zero-filling
    // out-of-bounds rows/columns via a masked load.
    auto load_src = [=]() {
        mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
        xor_(reg_zero, reg_zero);

        mov(reg_ydim, reg_tj);
        shl(reg_ydim, 2); //tj * tile_size(=4)

        for (int j = 0; j < alpha; j++) {
            /* check if tile index is within physical spatial boundaries*/
            mov(reg_maskj, 0xffff);
            cmp(reg_ydim, jcp.t_pad);
            cmovl(reg_maskj, reg_zero);
            cmp(reg_ydim, ifhp);
            cmovge(reg_maskj, reg_zero);

            /*address offset for tile in src*/
            mov(reg_src_offset, reg_ydim);
            sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad
            imul(reg_src_offset, reg_src_offset, jcp.iw);

            mov(reg_xdim, reg_ti);
            shl(reg_xdim, 2); // xdim = ti * tile_size

            add(reg_src_offset, reg_xdim);
            sub(reg_src_offset, jcp.l_pad);
            imul(reg_src_offset, reg_src_offset, simd_w * typesize);
            for (int i = 0; i < alpha; i++) {
                /* check if tile index is within physical spatial boundaries*/
                mov(reg_maski, 0xffff);
                cmp(reg_xdim, jcp.l_pad);
                cmovl(reg_maski, reg_zero);
                cmp(reg_xdim, ifwp);
                cmovge(reg_maski, reg_zero);
                and_(reg_maski, reg_maskj);

                // Masked load: lanes outside the image remain zero.
                Opmask kmask_src = Xbyak::Opmask(7);
                auto zmm_src = Xbyak::Zmm(31);
                kmovw(kmask_src, reg_maski_32);
                vpxord(zmm_src, zmm_src, zmm_src);
                vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]);
                vmovups(ptr[reg_I], zmm_src);

                add(reg_xdim, 1); //xdim = ti * tile_size + i
                add(reg_src_offset, simd_w * typesize);
                add(reg_I, simd_w * typesize);
            }
            add(reg_ydim, 1);
        }
    };

    // dst = a * b + c (copies c into dst, then fused multiply-add).
    auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) {
        vmovups(dst, c);
        vfmadd231ps(dst, a, b);
    };

    // Two-pass Winograd input transform: column pass (M -> T) followed by
    // the row pass writing directly to dst in the blocked GEMM layout.
    auto trans_I_3x3_4x4 = [=]() {
        //Use 24 registers
        mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
        mov(reg_T, ptr[reg_transp + GET_OFF(T)]);
        for (int i = 0; i < alpha; i++) {
            for (int j = 0; j < alpha; j++) {
                size_t I_off = (j * alpha + i) * simd_w * typesize;
                vmovups(zmm_I(j), ptr[reg_I + I_off]);
            }

            fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
            fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
            fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
            fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
            fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
            fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));

            fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
            fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));

            for (int j = 0; j < alpha; j++) {
                vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize],
                        zmm_T(j));
            }
        }

        for (int j = 0; j < alpha; j++) {
            for (int i = 0; i < alpha; i++) {
                vmovups(zmm_T(i),
                        ptr[reg_T + (j * alpha + i) * simd_w * typesize]);
            }

            fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
            fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
            fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
            fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
            fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
            fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));

            fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
            fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));

            for (int i = 0; i < alpha; i++) {
                size_t dst_off
                        = (j * alpha * jcp.ic_block * jcp.nb_tile_block_ur
                                          * jcp.tile_block_ur
                                  + i * jcp.ic_block * jcp.nb_tile_block_ur
                                          * jcp.tile_block_ur)
                        * simd_w * typesize;
                vmovups(ptr[reg_dst + dst_off], zmm_I(i));
            }
        }
    };

    // Tile loop for the SDGtWo schedule: dst advances one vector per tile;
    // the only exit is the per-block tile counter reaching
    // nb_tile_block_ur * tile_block_ur (the mb loop itself is unbounded).
    auto compute_transform_SDGtWo = [=]() {
        mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
        mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        xor_(reg_tile_count, reg_tile_count);
        Label loop_mb, loop_jtiles, loop_itiles, done;
        L(loop_mb);
        {
            L(loop_jtiles);
            {
                L(loop_itiles);
                {
                    load_src();

                    trans_I_3x3_4x4();

                    add(reg_tile_count, 1);
                    cmp(reg_tile_count,
                            jcp.nb_tile_block_ur * jcp.tile_block_ur);
                    jge(done);

                    add(reg_dst, simd_w * typesize);
                    add(reg_ti, 1);
                    cmp(reg_ti, jcp.itiles);
                    jl(loop_itiles);
                }
                xor_(reg_ti, reg_ti);
                add(reg_tj, 1);
                cmp(reg_tj, jcp.jtiles);
                jl(loop_jtiles);
            }
            xor_(reg_tj, reg_tj);
            // Advance to the next image in the mini-batch.
            add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize);
            jmp(loop_mb);
        }
        L(done);
    };

    // Tile loop for the general schedule: resumes at tile_count within the
    // current tile block; when the counter wraps, dst is rewound to the
    // block start and stepped to the next tile block.
    auto compute_transform = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        xor_(reg_ti, reg_ti);
        xor_(reg_tj, reg_tj);

        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
        imul(reg_temp, reg_tile_count, simd_w * typesize);
        add(reg_dst, reg_temp);

        Label loop_jtiles, loop_itiles, next_tile_block, next_tile;
        L(loop_jtiles);

        {
            L(loop_itiles);
            {
                load_src();

                trans_I_3x3_4x4();

                add(reg_tile_count, 1);
                cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
                jge(next_tile_block);
                add(reg_dst, simd_w * typesize);
                jmp(next_tile);

                L(next_tile_block);
                // Rewind to the block start, then step to the next block.
                sub(reg_dst,
                        (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1) * simd_w
                                * typesize);
                size_t tblk_off = alpha * alpha * jcp.ic_block
                        * jcp.nb_tile_block_ur * jcp.tile_block_ur * simd_w
                        * typesize;
                add(reg_dst, tblk_off);
                xor_(reg_tile_count, reg_tile_count);

                L(next_tile);
                add(reg_ti, 1);
                cmp(reg_ti, jcp.itiles);
                jl(loop_itiles);
            }
            xor_(reg_ti, reg_ti);
            add(reg_tj, 1);
            cmp(reg_tj, jcp.jtiles);
            jl(loop_jtiles);
        }
    };

    preamble();
    init_G();
    if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
        compute_transform_SDGtWo();
    else
        compute_transform();
    postamble();
}
1732 | |
// Emits the diff_dst (output-gradient) transform for the 4x3 Winograd
// backward-weights kernel: each 4x4 spatial tile of diff_dst is transformed
// into the alpha x alpha Winograd domain and written to the layout expected
// by the GEMM stage. When `with_bias` is true, the raw (untransformed)
// diff_dst values are additionally accumulated into the bias gradient.
// Two drivers are emitted depending on jcp.sched_policy:
//  - WSCHED_WEI_SDGtWo: per-thread scratch layout (compute_transform_SDGtWo)
//  - otherwise: global destination layout (compute_transform)
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::
        diff_dst_transform_generate(bool with_bias) {

    constexpr int G_size = 8;
    // Scratch register for broadcast transform coefficients. The index is
    // deliberately ignored: every zmm_G(i) aliases zmm31, and each
    // coefficient is re-broadcast immediately before its single use, so one
    // register is sufficient.
    auto zmm_G = [](int i) { return Xbyak::Zmm(31); };

    // 4x4 grid of loaded source tiles: zmm8..zmm23.
    auto zmm_src = [=](int j, int i) { return Xbyak::Zmm(G_size + j * 4 + i); };

    // Bias accumulator shares zmm31 with zmm_G. This is safe because the
    // bias is flushed back to memory at the end of load_src(), before the
    // transform starts broadcasting coefficients into zmm31.
    auto zmm_bias = Xbyak::Zmm(31);

    // Loads one 4x4 tile of diff_dst, zero-masking lanes whose spatial
    // coordinates fall outside (jcp.oh, jcp.ow); optionally accumulates the
    // loaded values into the bias vector.
    auto load_src = [=]() {
        if (with_bias) vmovups(zmm_bias, ptr[reg_bias]);
        mov(reg_ydim, reg_tj);
        shl(reg_ydim, 2); //tj * tile_size(=4)
        for (int j = 0; j < tile_size; j++) {
            /* check if tile index is within physical spatial boundaries*/
            mov(reg_maskj, 0xffff);
            cmp(reg_ydim, jcp.oh);
            cmovge(reg_maskj, reg_zero);

            /*address offset for tile in src*/
            mov(reg_src_offset, reg_ydim);
            imul(reg_src_offset, reg_src_offset, jcp.ow);

            mov(reg_xdim, reg_ti);
            shl(reg_xdim, 2); // xdim = ti * tile_size

            add(reg_src_offset, reg_xdim);
            imul(reg_src_offset, reg_src_offset, simd_w * typesize);
            for (int i = 0; i < tile_size; i++) {
                /* check if tile index is within physical spatial boundaries*/
                mov(reg_maski, 0xffff);
                cmp(reg_xdim, jcp.ow);
                cmovge(reg_maski, reg_zero);
                // Combine row (maskj) and column (maski) validity.
                and_(reg_maski, reg_maskj);

                Opmask kmask_src = Xbyak::Opmask(7);
                kmovw(kmask_src, reg_maski_32);
                // Zero first so masked-off lanes contribute 0 downstream.
                vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i));
                vmovups(zmm_src(j, i) | kmask_src,
                        ptr[reg_src + reg_src_offset]);
                if (with_bias)
                    vaddps(zmm_bias | kmask_src, zmm_bias,
                            ptr[reg_src + reg_src_offset]);

                add(reg_xdim, 1); //xdim = ti * tile_size + i
                add(reg_src_offset, simd_w * typesize);
            }
            add(reg_ydim, 1);
        }
        if (with_bias) vmovups(ptr[reg_bias], zmm_bias);
    };

    // Transform temporaries: zmm24..zmm30.
    auto zmm_t = [=](int i) { return Xbyak::Zmm(G_size + 16 + i); };

    // Row-transformed tiles: zmm0..zmm23 (reuses the low register file,
    // which is free once zmm_src values for a column have been consumed).
    auto zmm_T = [=](int j, int i) { return Xbyak::Zmm(j * 4 + i); };

    // Store helper: the SDGtWo schedule writes to a per-thread scratch that
    // is re-read shortly after, so a regular store is used; the global
    // layout is written once with a non-temporal (streaming) store.
    auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) {
        if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
            vmovups(ptr[reg_dst + dst_off], a);
        else
            vmovntps(ptr[reg_dst + dst_off], a);
    };

    // Applies the Winograd diff_dst transform W' = G x W x G^T in two
    // passes: first along columns (producing zmm_T), then along rows,
    // storing each of the 6 results at its alpha-strided destination slot.
    auto trans_W_3x3_4x4 = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        // Column pass: combine the 4 source rows into 6 transformed rows.
        for (int i = 0; i < tile_size; i++) {
            vbroadcastss(zmm_G(0), ptr[reg_G]);
            vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0));

            vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
            vmovups(zmm_t(1), zmm_t(0));
            vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1));

            vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
            vmovups(zmm_t(2), zmm_t(0));
            vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2));

            vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
            vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3));

            vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
            vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4));

            vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
            vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5));

            vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
            vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6));

            vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
            vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7));
            vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3));
            vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3));
            vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4));
            vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4));
            vmovups(zmm_T(5, i), zmm_src(3, i));
        }

        // Row pass: same coefficient pattern applied along the other axis;
        // each of the 6 outputs is stored with stride `alpha_offset`.
        for (int j = 0; j < alpha; j++) {
            vbroadcastss(zmm_G(0), ptr[reg_G]);
            vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0));

            vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
            vmovups(zmm_t(1), zmm_t(0));
            vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1));

            vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
            vmovups(zmm_t(2), zmm_t(0));
            vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2));

            vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
            vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3));

            vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
            vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4));

            vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
            vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5));

            vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
            vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6));

            vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
            vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7));
            vsubps(zmm_t(5), zmm_t(1), zmm_t(3));
            vaddps(zmm_t(1), zmm_t(1), zmm_t(3));
            vaddps(zmm_t(6), zmm_t(2), zmm_t(4));
            vsubps(zmm_t(2), zmm_t(2), zmm_t(4));
            vmovups(zmm_t(3), zmm_T(j, 3));

            // Distance between consecutive alpha slots in the transformed
            // diff_dst buffer (per oc chunk, per tile chunk).
            int alpha_offset = (jcp.oc / jcp.nb_oc)
                    * (jcp.ntiles / jcp.tile_block) * typesize;
            int dst_off = j * alpha * alpha_offset;
            movps(reg_dst, dst_off, zmm_t(0));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(5));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(1));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(6));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(2));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(3));
        }
    };
    // SDGtWo driver: iterates (oc_reg_block x mb x jtiles x itiles) starting
    // from the (ti, tj) passed by the caller; the inner tile iteration is
    // terminated solely by the tile counter reaching one tile block
    // (nb_tile_block_ur * tile_block_ur), not by an mb bound.
    auto compute_transform_SDGtWo = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);

        xor_(reg_zero, reg_zero);
        xor_(reg_oc_ur, reg_oc_ur);
        Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done;

        L(loop_oc_ur);
        {
            mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
            mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
            xor_(reg_tile_count, reg_tile_count);
            L(loop_mb);
            {
                L(loop_jtiles);
                {
                    L(loop_itiles);
                    {
                        load_src();

                        trans_W_3x3_4x4();

                        add(reg_tile_count, 1);
                        cmp(reg_tile_count,
                                jcp.nb_tile_block_ur * jcp.tile_block_ur);
                        jge(tiles_done);

                        add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
                        add(reg_ti, 1);
                        cmp(reg_ti, jcp.itiles);
                        jl(loop_itiles);
                    }
                    xor_(reg_ti, reg_ti);
                    add(reg_tj, 1);
                    cmp(reg_tj, jcp.jtiles);
                    jl(loop_jtiles);
                }
                xor_(reg_tj, reg_tj);
                // Advance to the same spatial position in the next image.
                add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize);
                jmp(loop_mb);
            }

            L(tiles_done);
            // Rewind to the base pointers, then step to the next oc_reg
            // slice (one simd lane group in dst/src/bias).
            mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
            add(reg_dst, simd_w * typesize);
            mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
            add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);

            if (with_bias) add(reg_bias, simd_w * typesize);
            add(reg_oc_ur, 1);
            cmp(reg_oc_ur, jcp.oc_reg_block);
            jl(loop_oc_ur);
        }
    };

    // Default driver: walks all tiles of one image, placing each result at
    // tile_count's slot; when a tile block fills up, the destination pointer
    // rolls over to the next tile-block chunk.
    auto compute_transform = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);

        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
        imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
        add(reg_dst, reg_temp);

        xor_(reg_zero, reg_zero);
        xor_(reg_oc_ur, reg_oc_ur);
        Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block,
                next_tile;

        L(loop_oc_ur);
        {
            xor_(reg_ti, reg_ti);
            xor_(reg_tj, reg_tj);

            L(loop_jtiles);
            {
                L(loop_itiles);
                {
                    load_src();

                    trans_W_3x3_4x4();

                    add(reg_tile_count, 1);
                    cmp(reg_tile_count,
                            jcp.nb_tile_block_ur * jcp.tile_block_ur);
                    jge(next_tile_block);
                    add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
                    jmp(next_tile);

                    L(next_tile_block);
                    // Undo the per-tile advances within the finished block,
                    // then jump to the start of the next tile-block chunk.
                    sub(reg_dst,
                            (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
                                    * jcp.oc_reg_block * simd_w * typesize);
                    int tblk_off = alpha * alpha * (jcp.oc / jcp.nb_oc)
                            * (jcp.ntiles / jcp.tile_block) * typesize;
                    add(reg_dst, tblk_off);
                    xor_(reg_tile_count, reg_tile_count);

                    L(next_tile);
                    add(reg_ti, 1);
                    cmp(reg_ti, jcp.itiles);
                    jl(loop_itiles);
                }
                xor_(reg_ti, reg_ti);
                add(reg_tj, 1);
                cmp(reg_tj, jcp.jtiles);
                jl(loop_jtiles);
            }

            // Re-derive the dst base for the next oc_reg slice from the
            // original tile_count, then advance src/bias by one simd group.
            mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
            mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
            imul(reg_temp, reg_tile_count,
                    jcp.oc_reg_block * simd_w * typesize);
            add(reg_dst, reg_temp);
            add(reg_dst, simd_w * typesize);
            mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
            add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);

            if (with_bias) add(reg_bias, simd_w * typesize);
            add(reg_oc_ur, 1);
            cmp(reg_oc_ur, jcp.oc_reg_block);
            jl(loop_oc_ur);
        }
    };

    preamble();
    if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
        compute_transform_SDGtWo();
    } else {
        compute_transform();
    }
    postamble();
}
2016 | |
// Emits the inverse transform of the weight gradients: the alpha x alpha
// Winograd-domain accumulation is reduced back to the spatial kh x kw (3x3)
// diff_weights layout. `first_tile` selects overwrite semantics; otherwise
// the results are accumulated into the existing diff_weights values.
// The kernel loops over the simd_w input-channel lanes one at a time.
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::
        diff_weights_transform_generate(bool first_tile) {
    int G_size = 4;

    // Transform coefficients, kept resident in zmm0..zmm3 for the whole
    // kernel (unlike the diff_dst transform, which re-broadcasts).
    auto zmm_G = [](int i) { return Xbyak::Zmm(i); };

    auto init_G = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        for (int i = 0; i < G_size; i++)
            vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
    };

    // One column (alpha rows) of Winograd-domain input: zmm4..zmm9.
    auto zmm_src = [=](int i) { return Xbyak::Zmm(G_size + i); };

    // Loads column `i` of the alpha x alpha grid; rows are alpha_offset
    // apart in memory.
    auto load_src = [=](int i) {
        for (int j = 0; j < alpha; j++) {
            size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block * jcp.ic_block
                    * simd_w * simd_w * typesize;
            size_t src_off = (j * alpha + i) * alpha_offset;
            vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off));
        }
    };

    // Temporaries: zmm10..zmm12.
    auto zmm_t = [=](int i) { return Xbyak::Zmm(G_size + 6 + i); };

    // Column-pass results, kh rows x alpha columns: zmm13..zmm30.
    auto zmm_T = [=](int j, int i) {
        return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i);
    };

    // Row-pass results deliberately alias the zmm_src registers (zmm4..),
    // which are free once the column pass has consumed them.
    auto zmm_dst = [=](int i) { return Xbyak::Zmm(G_size + i); };

    auto zmm_temp = Xbyak::Zmm(31);

    // Writes row `j` of the kh x kw output; accumulates with the existing
    // value unless this is the first tile. Streaming stores are used since
    // the output is not re-read by this kernel.
    auto store_dst = [=](int j) {
        for (int i = 0; i < jcp.kw; i++) {
            size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize;

            if (!first_tile) {
                vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off));
                vaddps(zmm_dst(i), zmm_dst(i), zmm_temp);
            }
            vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i));
        }
    };

    // Per-ic-lane loop: column pass (alpha -> kh rows), then row pass
    // (alpha -> kw cols) with store, then advance src/dst by one lane.
    auto compute_transform = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);

        xor_(reg_ic_simd, reg_ic_simd);
        Label loop_ic_simd;
        L(loop_ic_simd);
        {
            for (int i = 0; i < alpha; i++) {
                load_src(i);

                vaddps(zmm_t(0), zmm_src(1), zmm_src(2));
                vaddps(zmm_t(1), zmm_src(3), zmm_src(4));
                vmovups(zmm_t(2), zmm_src(5));
                vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));

                vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0));
                vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1));
                vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2));
                vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1));
                vsubps(zmm_temp, zmm_src(3), zmm_src(4));
                vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2));
                vmovups(zmm_T(2, i), zmm_t(2));
                vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3));
            }

            for (int j = 0; j < jcp.kh; j++) {
                vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2));
                vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4));
                vmovups(zmm_t(2), zmm_T(j, 5));
                vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));

                vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0));
                vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1));
                vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2));
                vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1));
                vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4));
                vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2));
                vmovups(zmm_dst(2), zmm_t(2));
                vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3));

                store_dst(j);
            }

            add(reg_src, jcp.oc_reg_block * simd_w * typesize);
            add(reg_dst, simd_w * typesize);
            add(reg_ic_simd, 1);
            cmp(reg_ic_simd, simd_w);
            jl(loop_ic_simd);
        }
    };
    preamble();
    // Widen the EVEX compressed-displacement window for the large src
    // offsets used by load_src; restored before returning.
    push(reg_EVEX_max_8b_offt);
    mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
    init_G();
    compute_transform();
    pop(reg_EVEX_max_8b_offt);
    postamble();
}
2121 | |
// Emits the blocked GEMM used to accumulate the Winograd-domain weight
// gradient: dstC[M, N] (+)= srcA[M, K] * srcB[K, N], with M blocked over
// dimM_block/dimM_reg_block/dimM_simd_block, N over dimN_block and
// dimN_bcast_ur, and K over dimK_block/dimK_reg_block.
// `is_first_tile` selects overwrite vs. accumulate-into-memory for dstC.
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate(
        bool is_first_tile) {
    // zmm0 is a scratch register for loaded A vectors (and for the previous
    // C value in store_dstC when accumulating).
    auto zmm_srcA = [=]() { return Xbyak::Zmm(0); };

    // Broadcast B scalars: zmm1..zmm(dimN_bcast_ur).
    auto zmm_srcB = [=](size_t N_ur) { return Xbyak::Zmm(N_ur + 1); };

    // Broadcasts dimN_bcast_ur consecutive B scalars for K-row K_ur.
    auto broadcastB = [=](size_t K_ur) {
        for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
            size_t srcB_off
                    = (K_ur * jcp.dimN_reg_block + N_bcast) * sizeof(float);
            vbroadcastss(
                    zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off));
        }
    };

    auto load_srcA = [=](size_t K_ur, int M_ur) {
        size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
                                  + M_ur * jcp.dimM_simd_block)
                * sizeof(float);
        vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off));
    };

    // C accumulators occupy the registers after srcA and the srcB bank; the
    // assert guards against overflowing the 32-register zmm file.
    auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast) {
        size_t idx = 1 // zmm_srcA
                + jcp.dimN_bcast_ur // zmm_srcB
                + M_reg_ur * jcp.dimN_bcast_ur + N_bcast;
        assert(idx < 32);
        return Xbyak::Zmm(idx);
    };
    auto prepare_accumm = [=]() {
        for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
            for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
                Zmm zmm = zmm_dstC(M_reg_ur, N_bcast);
                vpxord(zmm, zmm, zmm);
            }
        }
    };

    auto store_dstC = [=]() {
        /******** Write C back to memory *******/
        for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) {
            for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) {
                Zmm zmm = zmm_dstC(M_reg, N_ur);
                size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
                                       + M_reg * jcp.dimM_simd_block)
                        * sizeof(float);
                // Accumulate with previous tiles' contribution unless this
                // is the first tile for this output.
                if (!is_first_tile) {
                    vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off));
                    vaddps(zmm, zmm, Xbyak::Zmm(0));
                }
                vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm);
            }
        }
    };

    // Loop nest: M blocks -> N blocks -> N broadcast sub-blocks -> K blocks,
    // with fully unrolled dimK_reg_block FMAs innermost. After each loop the
    // pointers are rewound/advanced to keep srcA/srcB/dstC consistent.
    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur;

        mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
        L(dimM_block_loop);
        { /************* OC_block (M) loop ***********/
            mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
            L(dimN_block_loop);
            { /*************** IC_block (N) loop *********/

                mov(reg_nb_dimN_bcast_ur,
                        jcp.dimN_reg_block / jcp.dimN_bcast_ur);
                L(dimN_bcast_ur);
                {
                    prepare_accumm();

                    mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
                    L(dimK_block_loop);
                    {
                        /************* nb_tile_ur(K) loop ********/
                        for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) {

                            broadcastB(K_ur);

                            for (int M_reg_ur = 0;
                                    M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
                                load_srcA(K_ur, M_reg_ur);
                                for (int N_bcast = 0;
                                        N_bcast < jcp.dimN_bcast_ur;
                                        ++N_bcast) {
                                    vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast),
                                            zmm_srcA(), zmm_srcB(N_bcast));
                                }
                            }
                        }
                        add(reg_srcA,
                                jcp.dimK_reg_block * jcp.dimM_reg_block
                                        * jcp.dimM_simd_block * sizeof(float));
                        add(reg_srcB,
                                jcp.dimK_reg_block * jcp.dimN_reg_block
                                        * sizeof(float));
                        sub(reg_dimK_block_loop_cnt, 1);
                        jnz(dimK_block_loop);
                    }

                    store_dstC();

                    // Rewind A and B over the K dimension; step B and C to
                    // the next N broadcast sub-block.
                    sub(reg_srcA,
                            jcp.dimK_block * jcp.dimK_reg_block
                                    * jcp.dimM_reg_block * jcp.dimM_simd_block
                                    * sizeof(float));
                    sub(reg_srcB,
                            jcp.dimK_block * jcp.dimK_reg_block
                                    * jcp.dimN_reg_block * sizeof(float));
                    add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float));
                    add(reg_dstC,
                            jcp.dimN_bcast_ur * jcp.dimM_reg_block
                                    * jcp.dimM_simd_block * sizeof(float));
                    sub(reg_nb_dimN_bcast_ur, 1);
                    jnz(dimN_bcast_ur);
                }

                // Undo the broadcast-sub-block advances (one dimN_reg_block)
                // and move B to the next N block.
                sub(reg_srcB, jcp.dimN_reg_block * sizeof(float));
                add(reg_srcB,
                        jcp.dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
                                * sizeof(float));
                sub(reg_dimN_block_loop_cnt, 1);
                jnz(dimN_block_loop);
            }

            // Rewind B fully; advance A to the next M block.
            sub(reg_srcB,
                    jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block
                            * jcp.dimN_reg_block * sizeof(float));
            add(reg_srcA,
                    jcp.dimK_block * jcp.dimK_reg_block * jcp.dimM_reg_block
                            * jcp.dimM_simd_block * sizeof(float));
            sub(reg_dimM_block_loop_cnt, 1);
            jnz(dimM_block_loop);
        }
    };

    /* Preamble */
    preamble();

    inner_loops();

    /* Postamble */
    postamble();
    // NOTE(review): postamble() is expected to emit the function return
    // already, which would make this extra ret() unreachable — confirm
    // against jit_generator::postamble before removing.
    ret();
}
2267 | |
2268 | namespace { |
2269 | |
2270 | void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) { |
2271 | /*M params*/ |
2272 | jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block |
2273 | / jcp.dimM_simd_block; |
2274 | jcp.oc_reg_block = jcp.dimM_reg_block; |
2275 | jcp.oc_block = jcp.dimM_block; |
2276 | jcp.nb_oc = jcp.dimM_nb_block; |
2277 | /*N params*/ |
2278 | jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block; |
2279 | jcp.ic_block = jcp.dimN_block; |
2280 | jcp.nb_ic = jcp.dimN_nb_block; |
2281 | |
2282 | /*K params*/ |
2283 | jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block; |
2284 | jcp.tile_block_ur = jcp.dimK_reg_block; |
2285 | jcp.nb_tile_block_ur = jcp.dimK_block; |
2286 | jcp.tile_block = jcp.dimK_nb_block; |
2287 | } |
2288 | |
// Attempts the SDGtWo schedule for backward weights: searches for a
// (K_blk_ur, N_blk, M_blk) blocking whose working set fits the L1/L2 cache
// heuristics below. On success the jcp blocking fields are filled in and
// status::success is returned; otherwise status::unimplemented so the
// caller can fall back to the S_D_Giot_W schedule.
status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {

    size_t K_blk_ur, N_blk, M_blk;
    /* IS this strategy feasible? */
    // SDGtWo only pays off when the transformed buffers (M: diff_dst
    // transform, V: src transform) substantially exceed per-thread L2 and
    // there is at least one tile-chunk of K work per thread. Note the
    // jcp.dimK / jcp.nthr term is integer division, so the >= 1.0 compare
    // is equivalent to jcp.dimK >= jcp.nthr.
    auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
        size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
        size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
        return (((V_sz + M_sz) / jcp.nthr) >= 2 * L2_cache_size)
                && (jcp.dimK / jcp.nthr >= 1.0);
    };

    // L1 fit test for a candidate K micro-block; the unused max_block
    // parameter exists to match the function-pointer signature required by
    // get_divisor_satisfying_cond.
    auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
                                    int max_block = 1) {
        size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block
                * dimK_block_ur * sizeof(float);
        size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
        size_t M_L2_block
                = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
        bool load_balance = true;
        // When nthr divides dimK, also require it to divide the number of
        // K chunks so threads get equal work.
        if (!(jcp.dimK % jcp.nthr)) {
            load_balance = ((jcp.dimK / dimK_block_ur) % jcp.nthr == 0);
        }
        return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
                && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
                && load_balance && (M_L2_block < L2_cache_size);
    };

    // Acceptable inner-K unroll range (register-blocking granularity).
    auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
                                int useless_arg = 0) {
        return (dimK_ur >= 2) && (dimK_ur <= 8);
    };

    // L2 fit test for the full (M, V, U) working set of one block.
    auto blocking_ok = [&]() {
        size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block
                * jcp.dimM_simd_block * K_blk_ur * sizeof(float);
        size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
                * K_blk_ur * sizeof(float);
        size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block
                * jcp.dimM_simd_block * N_blk * jcp.dimN_reg_block
                * sizeof(float);
        size_t L2_block = M_L2_block + V_L2_block + U_L2_block;
        // NOTE(review): stale comment below — the code bounds the block by
        // [0.1, 1.2] x L2, not by an L2+L3-derived 2.375 factor.
        /*Replace 2.375 with L2+L3 cache size*/
        return (L2_block > 0.1 * L2_cache_size)
                && (L2_block <= 1.2 * L2_cache_size);
    };

    if (test_MV_large_enough(jcp)) {
        // Prefer 2-wide M register blocking when oc allows it.
        if ((jcp.dimM / jcp.dimM_simd_block) % 2 == 0) {
            jcp.dimM_reg_block = 2;
        } else {
            jcp.dimM_reg_block = 1;
        }
        jcp.dimM_simd_block = jcp.oc_simd_block;
        jcp.dimN_reg_block = jcp.ic_simd_block;
        jcp.dimN_bcast_ur = 8;
        /*dimK_block and dimK_ur*/
        size_t min_dimK_block_ur = get_divisor_satisfying_cond(
                jcp, jcp.dimK, 1, test_min_dimK_L1);

        jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
        jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
        // Exhaustive search, largest blocks first; take the first candidate
        // satisfying both the L1 (K) and L2 (M/N/K) constraints.
        for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) {
            if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) {
                for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
                    if (!(jcp.dimN_block % N_blk)) {
                        for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
                            if (!(jcp.dimM_block % M_blk) && blocking_ok()) {
                                jcp.dimK_reg_block
                                        = get_divisor_satisfying_cond(
                                                jcp, K_blk_ur, 1, test_dimK_ur);
                                if (!test_dimK_ur(jcp, jcp.dimK_reg_block))
                                    return status::unimplemented;
                                jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
                                jcp.dimN_block = N_blk;
                                jcp.dimM_block = M_blk;
                                jcp.sched_policy = WSCHED_WEI_SDGtWo;
                                set_jcp_WEI_params(jcp);
                                // No point in more threads than tile blocks.
                                jcp.nthr = nstl::min(jcp.nthr, jcp.tile_block);
                                return status::success;
                            }
                        }
                    }
                }
            }
        }
    }
    return status::unimplemented;
}
2377 | |
// Fallback schedule for backward weights (S_D_Giot_W): searches for a
// blocking that satisfies L1/L2 heuristics and thread load balance. Unlike
// set_wsched_WEI_SDGtWo, this routine always succeeds — if the search finds
// nothing it falls back to a trivial K blocking of 1.
status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
    // Prefer 2-wide M register blocking when oc allows it.
    if ((jcp.dimM / jcp.dimM_simd_block) % 2 == 0) {
        jcp.dimM_reg_block = 2;
    } else {
        jcp.dimM_reg_block = 1;
    }
    jcp.dimN_bcast_ur = 8;
    jcp.dimN_reg_block = jcp.ic_simd_block;
    jcp.dimM_simd_block = jcp.oc_simd_block;
    jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
    jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
    // Cache-occupancy bounds: [C1, C1_max] x L1 and [C2, C2_max] x L2.
    // C1 == C2 == 0.0 makes the lower bounds vacuous.
    float C1 = 0.0, C2 = 0.0;
    float C1_max = 0.5, C2_max = 1.4;
    int N_blk, M_blk, K_blk_ur;

    // Acceptable inner-K unroll range; the unused argument matches the
    // function-pointer signature of get_divisor_satisfying_cond's test.
    auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
                                int useless_arg = 0) {
        return (dimK_ur >= 2) && (dimK_ur <= 8);
    };

    auto blocking_ok = [&]() -> bool {
        // L1: one M micro-panel plus one N micro-panel of K_blk_ur depth.
        size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur
                * sizeof(float);
        size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float);
        bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size)
                && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size);

        size_t nb_N_blk = jcp.dimN / N_blk / jcp.dimN_reg_block;
        size_t nb_M_blk
                = jcp.dimM / M_blk / jcp.dimM_reg_block / jcp.dimM_simd_block;
        size_t nb_K_blk = jcp.dimK / K_blk_ur;
        bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk)
                >= static_cast<size_t>(jcp.nthr);
        // NOTE(review): this branch is a tautology — the guard already
        // implies nb_K_blk % jcp.nthr == 0, so load_balance is unchanged.
        // Possibly a different quantity was intended here (compare the
        // analogous check in set_wsched_WEI_SDGtWo); confirm upstream.
        if (!(nb_K_blk % jcp.nthr)) {
            load_balance = load_balance && (nb_K_blk % jcp.nthr == 0);
        }

        size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
                * K_blk_ur * sizeof(float);

        size_t L2_block = V_L2_block;
        // NOTE(review): stale comment below — the code bounds the block by
        // [C2, C2_max] x L2, not by an L2+L3-derived 2.375 factor.
        /*Replace 2.375 with L2+L3 cache size*/
        bool L2_cond = (L2_block >= C2 * L2_cache_size)
                && (L2_block <= C2_max * L2_cache_size);
        return L1_cond && load_balance && L2_cond;
    };

    // Exhaustive search, largest blocks first; first acceptable candidate
    // wins.
    for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) {
        if (jcp.dimK % K_blk_ur == 0) {
            for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
                if (jcp.dimN_block % N_blk == 0) {
                    for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
                        if (jcp.dimM_block % M_blk == 0) {
                            if (blocking_ok()) {
                                jcp.dimN_block = N_blk;
                                jcp.dimM_block = M_blk;
                                jcp.dimK_reg_block
                                        = get_divisor_satisfying_cond(
                                                jcp, K_blk_ur, 1, test_dimK_ur);
                                jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
                                jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
                                set_jcp_WEI_params(jcp);
                                return status::success;
                            }
                        }
                    }
                }
            }
        }
    }
    // Nothing satisfied the heuristics: fall back to trivial K blocking.
    jcp.dimK_reg_block = 1;
    jcp.dimK_block = 1;
    jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
    set_jcp_WEI_params(jcp);
    return status::success;
}
2454 | } // namespace |
// Validates the problem against the constraints of the AVX-512 4x3 Winograd
// backward-weights kernel and fills in the blocking configuration `jcp`.
// Returns status::unimplemented when any requirement fails (ISA, 2D shape,
// 3x3/stride-1/no-dilation geometry, padded-to-simd channels, nChw16c /
// (g)OIhw16i16o layouts), otherwise selects a schedule and returns success.
status_t jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::init_conf(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
        const memory_desc_wrapper &diff_weights_d) {
    if (!mayiuse(avx512_core)) return status::unimplemented;

    // This kernel only supports 2D convolutions.
    if (src_d.ndims() != 4) return status::unimplemented;

    jcp.nthr = dnnl_get_max_threads();

    // Problem geometry, copied from the descriptors. Grouped weights carry
    // an extra leading dimension.
    jcp.prop_kind = cd.prop_kind;
    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
    jcp.mb = src_d.dims()[0];
    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = diff_dst_d.dims()[2];
    jcp.ow = diff_dst_d.dims()[3];
    jcp.kh = diff_weights_d.dims()[with_groups + 2];
    jcp.kw = diff_weights_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    // Derive right/bottom padding from the convolution arithmetic.
    jcp.r_pad = nstl::max(
            0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
    jcp.b_pad = nstl::max(
            0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;
    jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef);
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    // Round channels up to the simd width (only safe without groups).
    bool ok_to_pad_channels = jcp.ngroups == 1;
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    // Winograd specific initialization
    jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
    jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
    jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;

    // Winograd kernel works only for 3x3 convolution with stride 1
    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
                is_winograd_faster_than_direct(jcp)))
        return status::unimplemented;

    const bool prb_shape_ok = jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1
            && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 && jcp.stride_h == 1
            && jcp.stride_w == 1 && jcp.dilate_h == 0 && jcp.dilate_w == 0
            && jcp.l_pad <= 1 && jcp.r_pad <= 1 && jcp.t_pad <= 1
            && jcp.b_pad <= 1;
    if (!prb_shape_ok) return status::unimplemented;

    // Require the exact blocked layouts this kernel was written for.
    format_tag_t dat_tag = nChw16c;
    format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
    jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);

    if (jcp.src_tag != dat_tag) return status::unimplemented;
    if (jcp.wei_tag != wei_tag) return status::unimplemented;
    if (jcp.dst_tag != dat_tag) return status::unimplemented;

    // The (possibly rounded-up) channel counts must fit the padded dims.
    bool layout_consistency = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= diff_dst_d.padded_dims()[1]
            && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
    if (!layout_consistency) return status::unimplemented;

    /******************Kernel blocking Parameters ***********/
    jcp.ic_simd_block = simd_w;
    jcp.oc_simd_block = simd_w;

    // GEMM-view mapping for backward weights: K = tiles, N = ic, M = oc.
    jcp.dimK = jcp.ntiles;
    jcp.dimN = jcp.ic;
    jcp.dimM = jcp.oc;
    jcp.dimM_simd_block = jcp.oc_simd_block;
    jcp.dimN_reg_block = jcp.ic_simd_block;
    jcp.sched_policy = WSCHED_INVALID;
    // Prefer SDGtWo; its fallback S_D_Giot_W always succeeds.
    status_t res = set_wsched_WEI_SDGtWo(jcp);
    if (res == status::unimplemented) {
        res = set_wsched_WEI_S_D_Giot_W(jcp);
        assert(res == status::success);
    }
    return res;
}
2551 | } // namespace x64 |
2552 | } // namespace cpu |
2553 | } // namespace impl |
2554 | } // namespace dnnl |
2555 | |
2556 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
2557 | |