/*******************************************************************************
* Copyright 2017-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifdef __INTEL_COMPILER
#include <immintrin.h>
#endif

#include "oneapi/dnnl/dnnl_types.h"

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_avx512_core_f32_wino_conv_4x3.hpp"

#ifndef _MSC_VER
#define pragma_unroll _Pragma("unroll")
#else
#define pragma_unroll
#endif

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

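/* Transforms one simd_w x simd_w channel slice of a 3x3 kernel into the
   Winograd domain. The values in G are the transform coefficients consumed
   by the JIT kernel (the U = G * g * G^T step of the Winograd F(4x4, 3x3)
   algorithm); the per-slice work is done by weights_transform_data_ker(). */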
template <bool is_fwd>
void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::weight_transform_data(
        const jit_conv_winograd_conf_t &jcp, float *wp, float *twp) const {
    float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
            1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
    const int kh = 3;
    const int kw = 3;
    float Fw[alpha][alpha][simd_w][simd_w];
    float F[kh][kw][simd_w][simd_w];
    float T[alpha][3][simd_w];
    auto p = jit_wino_transform_call_s();

    p.src = wp;
    p.dst = twp;
    p.G = G;
    p.M = F;
    p.Mw = Fw;
    p.T = T;

    kernel_->weights_transform_data_ker(&p);
}

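/* Inverse (output) transform for all tiles of one image: converts the gemm
   results from the Winograd domain back to the spatial domain; bias, when
   present, is added inside the JIT kernel. The (tile_block,
   nb_tile_block_ur, tile_block_ur) triple is advanced as a mixed-radix
   counter over the linear tile index. */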
template <bool is_fwd>
void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::output_transform_data(
        int image, const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
        float *toutp, float *pout_b, float *bias) const {

    float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
    float Ow[alpha][alpha][simd_w];
    float O[tile_size][tile_size][simd_w];
    float T[tile_size][alpha][simd_w];

    auto p = jit_wino_transform_call_s();
    p.src = toutp;
    p.dst = pout_b;
    p.G = G;
    p.M = O;
    p.Mw = Ow;
    p.T = T;
    p.bias = bias;

    int tile_base_index = image * jcp.itiles * jcp.jtiles;
    int tile_block_ur = tile_base_index % jcp.tile_block_ur;
    int nb_tile_block_ur
            = (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
    int tile_block
            = (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;

    for (int tj = 0; tj < jcp.jtiles; tj++) {
        for (int ti = 0; ti < jcp.itiles; ti++) {

            p.tile_block_ur = tile_block_ur;
            p.nb_tile_block_ur = nb_tile_block_ur;
            p.tile_block = tile_block;
            p.tj = tj;
            p.ti = ti;

            kernel_->output_transform_data_ker(&p);

            tile_block_ur++;
            if (tile_block_ur >= jcp.tile_block_ur) {
                tile_block_ur = 0;
                nb_tile_block_ur++;
            }
            if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
                nb_tile_block_ur = 0;
                tile_block++;
            }
        }
    }
}

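/* Same as output_transform_data(), but restricted to the tiles owned by a
   single tile_block; used by the W_SGD driver. The destination pointer is
   recomputed per tile because one tile block may span several images. */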
template <bool is_fwd>
void _jit_avx512_core_f32_wino_conv_4x3_t<
        is_fwd>::output_transform_tileblock_data(int tile_block,
        const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
        float *toutp, float *outp, float *bias) const {

    float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
    float Ow[alpha][alpha][simd_w];
    float O[tile_size][tile_size][simd_w];
    float T[tile_size][alpha][simd_w];

    auto p = jit_wino_transform_call_s();
    p.src = toutp;
    p.dst = outp;
    p.G = G;
    p.M = O;
    p.Mw = Ow;
    p.T = T;
    p.bias = bias;

    int outw = is_fwd ? jcp.ow : jcp.iw;
    int outh = is_fwd ? jcp.oh : jcp.ih;

    int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;

    for (int nb_tile_block_ur = 0; nb_tile_block_ur < jcp.nb_tile_block_ur;
            nb_tile_block_ur++) {

        for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
                tile_block_ur++) {
            int img = tile_index / (jcp.jtiles * jcp.itiles);
            int ti = tile_index % jcp.itiles;
            int tj = (tile_index / jcp.itiles) % jcp.jtiles;

            p.tile_block_ur = tile_block_ur;
            p.nb_tile_block_ur = nb_tile_block_ur;
            p.tile_block = tile_block;
            p.tj = tj;
            p.ti = ti;
            p.dst = outp
                    + img * (jcp.dimM / jcp.dimM_simd_block) * outh * outw
                            * jcp.dimM_simd_block;

            kernel_->output_transform_data_ker(&p);

            tile_index++;
        }
    }
}

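/* Forward (input) transform for all tiles of one image: scatters each
   alpha x alpha input tile into the V buffer laid out for the gemm kernel.
   The values in G are the coefficients of the Winograd input transform
   (V = B^T * d * B); the per-tile work is done by input_transform_data_ker(). */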
template <bool is_fwd>
void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::input_transform_data(
        int image, const jit_conv_winograd_conf_t &jcp, float *inp,
        float *tinp) const {
    float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, 0.625f, -0.625f,
            1.5f, -1.5f, -2.640625f};

    float Iw[alpha][alpha][simd_w];
    float I[alpha][alpha][simd_w];
    float T[alpha][alpha][simd_w];

    auto p = jit_wino_transform_call_s();

    p.src = inp;
    p.dst = tinp;
    p.G = G;
    p.M = I;
    p.Mw = Iw;
    p.T = T;

    int tile_base_index = image * jcp.itiles * jcp.jtiles;
    int tile_block_ur = tile_base_index % jcp.tile_block_ur;
    int nb_tile_block_ur
            = (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
    int tile_block
            = (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;

    for (int tj = 0; tj < jcp.jtiles; tj++) {
        for (int ti = 0; ti < jcp.itiles; ti++) {

            p.tile_block_ur = tile_block_ur;
            p.nb_tile_block_ur = nb_tile_block_ur;
            p.tile_block = tile_block;
            p.tj = tj;
            p.ti = ti;

            kernel_->input_transform_data_ker(&p);

            tile_block_ur++;
            if (tile_block_ur >= jcp.tile_block_ur) {
                tile_block_ur = 0;
                nb_tile_block_ur++;
            }
            if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
                nb_tile_block_ur = 0;
                tile_block++;
            }
        }
    }
}

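/* Same as input_transform_data(), but restricted to the tiles owned by a
   single tile_block, so the W_SGD driver can keep V in a per-thread
   scratchpad slice. The source pointer is recomputed per tile because one
   tile block may span several images. */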
template <bool is_fwd>
void _jit_avx512_core_f32_wino_conv_4x3_t<
        is_fwd>::input_transform_tileblock_data(int tile_block,
        const jit_conv_winograd_conf_t &jcp, float *inp, float *tinp) const {
    float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, 0.625f, -0.625f,
            1.5f, -1.5f, -2.640625f};
    float Iw[alpha][alpha][simd_w];
    float I[alpha][alpha][simd_w];
    float T[alpha][alpha][simd_w];

    const int inph = is_fwd ? jcp.ih : jcp.oh;
    const int inpw = is_fwd ? jcp.iw : jcp.ow;

    array_offset_calculator<float, 5> input(
            inp, jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w);
    array_offset_calculator<float, 7> output(tinp, alpha, alpha, jcp.dimN_block,
            jcp.dimK_nb_block, jcp.dimK_block, jcp.dimN_reg_block,
            jcp.dimK_reg_block);

    auto p = jit_wino_transform_call_s();

    p.dst = tinp;
    p.G = G;
    p.M = I;
    p.Mw = Iw;
    p.T = T;

    int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;

    for (int nb_tile_block_ur = 0; nb_tile_block_ur < jcp.nb_tile_block_ur;
            nb_tile_block_ur++) {

        for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
                tile_block_ur++) {

            int img = tile_index / (jcp.jtiles * jcp.itiles);
            int ti = tile_index % jcp.itiles;
            int tj = (tile_index / jcp.itiles) % jcp.jtiles;
            float *pinp_b = &(input(img, 0, 0, 0, 0));

            p.src = pinp_b;
            p.tile_block_ur = tile_block_ur;
            p.nb_tile_block_ur = nb_tile_block_ur;
            p.tj = tj;
            p.ti = ti;

            kernel_->input_transform_data_ker(&p);

            tile_index++;
        }
    }
}

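/* W_S_G_D schedule: the four stages -- weight transform into U, src
   transform into V, gemm in the Winograd domain into M, and dst transform
   back to the spatial domain -- run one after another over the whole
   minibatch, each parallelized with parallel_nd. The weight transform is
   skipped for forward inference, where wei_ptr already holds
   Winograd-domain weights. */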
template <bool is_fwd>
void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
        float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
        const memory_tracking::grantor_t &scratchpad) const {
    const auto &jcp = kernel_->jcp;
    const auto &p_ops = attr_->post_ops_;

    const int inph = is_fwd ? jcp.ih : jcp.oh;
    const int inpw = is_fwd ? jcp.iw : jcp.ow;
    const int outh = is_fwd ? jcp.oh : jcp.ih;
    const int outw = is_fwd ? jcp.ow : jcp.iw;

    /* Notation:
       FWD: dimM:oc, dimN:ntiles, dimK:ic,
       BWD: dimM:ic, dimN:ntiles, dimK:oc,
       FWD/BWD: V: src/diff_dst transform, U:weight transform,
                M:dst/diff_src transform */
    array_offset_calculator<float, 5> input(inp_ptr, jcp.mb,
            jcp.dimK / jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
    array_offset_calculator<float, 5> output(out_ptr, jcp.mb,
            jcp.dimM / jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
    array_offset_calculator<float, 6> weights(wei_ptr,
            jcp.oc / jcp.oc_simd_block, jcp.ic / jcp.ic_simd_block, jcp.kh,
            jcp.kw, jcp.ic_simd_block, jcp.oc_simd_block);
    array_offset_calculator<float, 2> bias(
            bias_ptr, jcp.dimM / jcp.dimM_simd_block, jcp.dimM_simd_block);

    array_offset_calculator<float, 8> M(is_fwd
                    ? scratchpad.template get<float>(key_wino_M)
                    : scratchpad.template get<float>(key_wino_V),
            jcp.dimN_nb_block, jcp.dimM_nb_block, alpha, alpha, jcp.dimN_block,
            jcp.dimM_block * jcp.dimM_reg_block, jcp.dimN_reg_block,
            jcp.dimM_simd_block);

    auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
            ? wei_ptr
            : scratchpad.template get<float>(key_wino_U);

    array_offset_calculator<float, 8> U(wino_wei, jcp.dimM_nb_block, alpha,
            alpha, jcp.dimK_nb_block, jcp.dimM_block * jcp.dimM_reg_block,
            jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_simd_block);
    array_offset_calculator<float, 8> V(is_fwd
                    ? scratchpad.template get<float>(key_wino_V)
                    : scratchpad.template get<float>(key_wino_M),
            jcp.dimN_nb_block, alpha, alpha, jcp.dimN_block, jcp.dimK_nb_block,
            jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);

    const bool wants_padded_bias
            = jcp.with_bias && jcp.oc_without_padding != jcp.oc;
    float last_slice_bias[simd_w] = {0};
    if (wants_padded_bias) {
        for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
            last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
    }

    parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block,
            [&](dim_t img, dim_t K_blk1, dim_t K_blk2) {
                input_transform_data(img, jcp,
                        &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0,
                                0)),
                        &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
            });

    if (jcp.prop_kind != prop_kind::forward_inference) {
        parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
                (jcp.ic_block * jcp.ic_reg_block),
                [&](dim_t ofm1, dim_t ifm1, dim_t ofm2, dim_t ifm2) {
                    float *U_base_ptr = is_fwd
                            ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
                            : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
                    weight_transform_data(jcp,
                            &(weights(ofm1 * jcp.oc_block * jcp.oc_reg_block
                                            + ofm2,
                                    ifm1 * jcp.ic_block * jcp.ic_reg_block
                                            + ifm2,
                                    0, 0, 0, 0)),
                            U_base_ptr);
                });
    }

    parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block,
            [&](dim_t N_blk1, dim_t oj, dim_t oi, dim_t M_blk1) {
                for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
                    for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
                        kernel_->gemm_loop_ker((float *)&(M(N_blk1, M_blk1, oj,
                                                       oi, N_blk2, 0, 0, 0)),
                                (const float *)&(
                                        U(M_blk1, oj, oi, K_blk1, 0, 0, 0, 0)),
                                (const float *)&(V(N_blk1, oj, oi, N_blk2,
                                        K_blk1, 0, 0, 0)),
                                K_blk1);
            });

    parallel_nd(jcp.mb, jcp.dimM_nb_block,
            (jcp.dimM_block * jcp.dimM_reg_block),
            [&](dim_t img, dim_t M_blk1, dim_t M_blk2) {
                const int M_blk
                        = M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;

                float *bias_ptr = wants_padded_bias
                                && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
                        ? last_slice_bias
                        : jcp.with_bias ? &bias(M_blk, 0) : nullptr;
                output_transform_data(img, jcp, p_ops,
                        &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
                        &(output(img, M_blk, 0, 0, 0)), bias_ptr);
            });
}

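/* W_SGD schedule: the weight transform is still done up front (or skipped
   for forward inference), but the src transform, gemm, and dst transform
   are fused per tile_block. Each thread processes one tile_block at a time
   and keeps its V and M buffers in per-thread scratchpad slices indexed by
   ithr, which keeps the Winograd-domain working set per thread small. */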
template <bool is_fwd>
void _jit_avx512_core_f32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
        float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
        const memory_tracking::grantor_t &scratchpad) const {
    const auto &jcp = kernel_->jcp;
    const auto &p_ops = attr_->post_ops_;

    const int inph = is_fwd ? jcp.ih : jcp.oh;
    const int inpw = is_fwd ? jcp.iw : jcp.ow;
    const int outh = is_fwd ? jcp.oh : jcp.ih;
    const int outw = is_fwd ? jcp.ow : jcp.iw;

    array_offset_calculator<float, 5> input(inp_ptr, jcp.mb,
            jcp.dimK / jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
    array_offset_calculator<float, 5> output(out_ptr, jcp.mb,
            jcp.dimM / jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
    array_offset_calculator<float, 6> weights(wei_ptr,
            jcp.oc / jcp.oc_simd_block, jcp.ic / jcp.ic_simd_block, jcp.kh,
            jcp.kw, jcp.ic_simd_block, jcp.oc_simd_block);
    array_offset_calculator<float, 2> bias(
            bias_ptr, jcp.oc / jcp.oc_simd_block, jcp.oc_simd_block);

    auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
            ? wei_ptr
            : scratchpad.template get<float>(key_wino_U);

    array_offset_calculator<float, 8> U(wino_wei, jcp.dimM_nb_block, alpha,
            alpha, jcp.dimK_nb_block, jcp.dimM_block * jcp.dimM_reg_block,
            jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_simd_block);

    array_offset_calculator<float, 8> M(is_fwd
                    ? scratchpad.template get<float>(key_wino_M)
                    : scratchpad.template get<float>(key_wino_V),
            0, jcp.dimM_nb_block, alpha, alpha, jcp.dimN_block,
            jcp.dimM_block * jcp.dimM_reg_block, jcp.dimN_reg_block,
            jcp.dimM_simd_block);
    array_offset_calculator<float, 8> V(is_fwd
                    ? scratchpad.template get<float>(key_wino_V)
                    : scratchpad.template get<float>(key_wino_M),
            0, alpha, alpha, jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
            jcp.dimN_reg_block, jcp.dimK_reg_block);

    const bool wants_padded_bias
            = jcp.with_bias && jcp.oc_without_padding != jcp.oc;
    float last_slice_bias[simd_w] = {0};
    if (wants_padded_bias) {
        for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
            last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
    }

    if (jcp.prop_kind != prop_kind::forward_inference) {

        parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
                (jcp.ic_block * jcp.ic_reg_block),
                [&](dim_t ofm1, dim_t ifm1, dim_t ofm2, dim_t ifm2) {
                    float *U_base_ptr = is_fwd
                            ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
                            : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
                    weight_transform_data(jcp,
                            &(weights(ofm1 * jcp.oc_block * jcp.oc_reg_block
                                            + ofm2,
                                    ifm1 * jcp.ic_block * jcp.ic_reg_block
                                            + ifm2,
                                    0, 0, 0, 0)),
                            U_base_ptr);
                });
    }

    parallel_nd_ext(jcp.nthr, jcp.tile_block,
            [&](int ithr, int nthr, dim_t tile_block) {
                assert(nthr <= jcp.nthr);
                MAYBE_UNUSED(nthr);

                for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
                    for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {

                        input_transform_tileblock_data(tile_block, jcp,
                                &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0,
                                        0, 0)),
                                &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
                    }
                }

                for (int oj = 0; oj < alpha; oj++) {
                    for (int oi = 0; oi < alpha; oi++) {
                        for_(int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block;
                                M_blk1++)
                        for_(int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
                                K_blk1++)
                        for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
                            kernel_->gemm_loop_ker(
                                    (float *)&(M(ithr, M_blk1, oj, oi, N_blk, 0,
                                            0, 0)),
                                    (const float *)&(U(M_blk1, oj, oi, K_blk1,
                                            0, 0, 0, 0)),
                                    (const float *)&(V(ithr, oj, oi, N_blk,
                                            K_blk1, 0, 0, 0)),
                                    K_blk1);
                    }
                }

                for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
                    for (int M_blk2 = 0;
                            M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
                            M_blk2++) {
                        const int M_blk
                                = M_blk1 * jcp.dimM_block * jcp.dimM_reg_block
                                + M_blk2;

                        float *bias_ptr = wants_padded_bias
                                        && M_blk
                                                == jcp.dimM / jcp.dimM_simd_block
                                                        - 1
                                ? last_slice_bias
                                : jcp.with_bias ? &bias(M_blk, 0) : nullptr;

                        output_transform_tileblock_data(tile_block, jcp, p_ops,
                                &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
                                &(output(0, M_blk, 0, 0, 0)), bias_ptr);
                    }
                }
            });
}

template struct _jit_avx512_core_f32_wino_conv_4x3_t<true>;
template struct _jit_avx512_core_f32_wino_conv_4x3_t<false>;

namespace {

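/* Reduces num_arrs per-thread partial buffers into output over nelems
   elements. Each input_ptrs[a] only contributes within
   [input_starts[a], input_ends[a]); outside the valid range of the first
   buffer the destination is zeroed before the remaining buffers are
   accumulated. Work is split into ~16 KB blocks distributed across OpenMP
   threads with balance211(). */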
void subarray_sum(size_t num_arrs, float *output, size_t nelems,
        float *input_ptrs[], size_t input_starts[], size_t input_ends[]) {
    using namespace nstl;
    const size_t block_size = 16 * 1024 / sizeof(float);
    const size_t blocks_number = nelems / block_size;
    const size_t tail = nelems % block_size;

    PRAGMA_OMP(parallel)
    {
        const int ithr = OMP_GET_THREAD_NUM();
        const int nthr = OMP_GET_NUM_THREADS();
        size_t start {0}, end {0};
        balance211(blocks_number, nthr, ithr, start, end);

        for (size_t nb = start; nb < end; ++nb) {
            size_t start_e = nb * block_size;
            size_t end_e = start_e + block_size;
            size_t input_start = max(start_e, min(input_starts[0], end_e));
            size_t input_end = max(start_e, min(input_ends[0], end_e));

            PRAGMA_OMP_SIMD()
            for (size_t e = start_e; e < input_start; e++) {
                output[e] = 0.f;
            }

            PRAGMA_OMP_SIMD()
            for (size_t e = input_start; e < input_end; e++) {
                output[e] = input_ptrs[0][e];
            }

            PRAGMA_OMP_SIMD()
            for (size_t e = input_end; e < end_e; e++) {
                output[e] = 0.f;
            }

            for (size_t a = 1; a < num_arrs; a++) {
                input_start = max(start_e, input_starts[a]);
                input_end = min(input_ends[a], end_e);

                PRAGMA_OMP_SIMD()
                for (size_t e = input_start; e < input_end; e++) {
                    output[e] += input_ptrs[a][e];
                }
            }
        }

        if (tail != 0 && ithr == nthr - 1) {
            size_t start_e = nelems - tail;
            size_t end_e = nelems;
            size_t input_start = max(start_e, min(input_starts[0], end_e));
            size_t input_end = max(start_e, min(input_ends[0], end_e));

            PRAGMA_OMP_SIMD()
            for (size_t e = start_e; e < input_start; e++) {
                output[e] = 0.f;
            }

            PRAGMA_OMP_SIMD()
            for (size_t e = input_start; e < input_end; e++) {
                output[e] = input_ptrs[0][e];
            }

            PRAGMA_OMP_SIMD()
            for (size_t e = input_end; e < end_e; e++) {
                output[e] = 0.f;
            }

            for (size_t a = 1; a < num_arrs; a++) {
                input_start = max(start_e, input_starts[a]);
                input_end = min(input_ends[a], end_e);

                PRAGMA_OMP_SIMD()
                for (size_t e = input_start; e < input_end; e++) {
                    output[e] += input_ptrs[a][e];
                }
            }
        }
    }
}

const int max_threads_number = 1024;

// Sum to the first buffer array
void array_sum(size_t num_arrs, float *output, size_t nelems,
        float *input_ptrs[], bool reduce_to_first = true) {
    const size_t block_size = 16 * 1024 / sizeof(float);
    const size_t blocks_number = nelems / block_size;
    const size_t tail = nelems % block_size;

    PRAGMA_OMP(parallel)
    {
        const size_t ithr = OMP_GET_THREAD_NUM();
        const size_t nthr = OMP_GET_NUM_THREADS();
        size_t start {0}, end {0};
        balance211(blocks_number, nthr, ithr, start, end);

        for (size_t nb = start; nb < end; ++nb) {
            size_t start_e = nb * block_size;
            size_t end_e = start_e + block_size;
            if (!reduce_to_first) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start_e; e < end_e; e++) {
                    output[e] = input_ptrs[0][e];
                }
            }
            for (size_t a = 1; a < num_arrs; a++) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start_e; e < end_e; e++) {
                    output[e] += input_ptrs[a][e];
                }
            }
        }

        if (tail != 0 && ithr == nthr - 1) {
            size_t start_e = nelems - tail;
            size_t end_e = nelems;
            if (!reduce_to_first) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start_e; e < end_e; e++) {
                    output[e] = input_ptrs[0][e];
                }
            }
            for (size_t a = 1; a < num_arrs; a++) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start_e; e < end_e; e++) {
                    output[e] += input_ptrs[a][e];
                }
            }
        }
    }
}
} // namespace

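/* Backward-weights driver, SDGtWo variant: every thread accumulates a full
   private copy of diff_weights. For each tile_block it owns, a thread
   transforms the src tiles (V) and diff_dst tiles (M), runs the gemm per
   (oj, oi) point of the Winograd domain into its Us scratch, and transforms
   the result into its private diff_weights buffer right away. The
   per-thread weight and bias partials are reduced at the end with
   array_sum(). */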
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
        _execute_backward_weights_SDGtWo(const float *ptr_src,
                const float *ptr_diff_dst, float *ptr_diff_weights,
                float *ptr_diff_bias,
                const memory_tracking::grantor_t &scratchpad) const {
    const auto &jcp = kernel_->jcp;
    const int nthreads = jcp.nthr;

    array_offset_calculator<float, 5> src(
            (float *)ptr_src, jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
    array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst, jcp.mb,
            jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
    array_offset_calculator<float, 6> diff_weights(ptr_diff_weights,
            jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);

    array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U), 0,
            alpha, alpha, jcp.oc_block, jcp.ic_block, jcp.ic_simd_block,
            jcp.oc_reg_block, jcp.oc_simd_block);

    const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc * jcp.ic
            / jcp.nb_ic;
    array_offset_calculator<float, 7> diff_weights_prv(
            scratchpad.get<float>(key_wino_U) + U_sz, 0, jcp.oc / simd_w,
            jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);

    array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M), 0,
            alpha, alpha, jcp.oc_block, jcp.nb_tile_block_ur, jcp.tile_block_ur,
            jcp.oc_reg_block, jcp.oc_simd_block);

    array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V), 0,
            alpha, alpha, jcp.ic_block, jcp.nb_tile_block_ur, jcp.tile_block_ur,
            jcp.ic_simd_block);

    array_offset_calculator<float, 2> diff_bias_prv(
            scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);

    auto trans_ker_p = jit_wino_transform_call_s();
    float I[alpha][alpha][simd_w];
    float T[alpha][alpha][simd_w];
    float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, 0.625f,
            -0.625f, 1.5f, -1.5f, -2.640625f};
    float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
            0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
            0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
    float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};

    PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T))
    {
        if (jcp.with_bias) {
            parallel_nd_in_omp(
                    nthreads, jcp.oc / simd_w, [&](dim_t ithr, dim_t ofm) {
                        float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
                        PRAGMA_OMP_SIMD()
                        for (int v = 0; v < simd_w; v++) {
                            pdbias[v] = 0.0f;
                        }
                    });
        }

        int ithr = OMP_GET_THREAD_NUM();
        for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
            int first_tblk = 0;
            PRAGMA_OMP(for)
            for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
                int tile_index
                        = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
                int img = tile_index / (jcp.itiles * jcp.jtiles);
                trans_ker_p.ti = tile_index % jcp.itiles;
                trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
                trans_ker_p.M = I;
                trans_ker_p.T = T;
                trans_ker_p.G = G_I_3x3_4x4;
                for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
                    int ifm = ifm1 * jcp.ic_block + ifm2;
                    trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
                    trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
                    kernel_->src_transform(&trans_ker_p);
                }

                for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
                    trans_ker_p.G = G_W_3x3_4x4;
                    for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
                        int ofm = (ofm1 * jcp.oc_block + ofm2)
                                * jcp.oc_reg_block;
                        trans_ker_p.src
                                = (float *)&(diff_dst(img, ofm, 0, 0, 0));
                        trans_ker_p.dst
                                = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
                        if (jcp.with_bias && ifm1 == 0) {
                            trans_ker_p.bias = (float *)&(
                                    diff_bias_prv(ithr, ofm * simd_w));
                            kernel_->diff_dst_transform_wbias(&trans_ker_p);
                        } else {
                            kernel_->diff_dst_transform(&trans_ker_p);
                        }
                    }

                    for (int oj = 0; oj < alpha; ++oj) {
                        for (int oi = 0; oi < alpha; ++oi) {
                            kernel_->gemm_loop_ker_first_iter(
                                    &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
                                    &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
                                    &(V(ithr, oj, oi, 0, 0, 0, 0)));
                        }
                    }
                    trans_ker_p.G = G_O_3x3_4x4;
                    for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
                        for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
                            int ofm = (ofm1 * jcp.oc_block + ofm2)
                                            * jcp.oc_reg_block
                                    + ofm3;
                            for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
                                int ifm = ifm1 * jcp.ic_block + ifm2;
                                trans_ker_p.src = (float *)&(
                                        Us(ithr, 0, 0, ofm2, ifm2, 0, ofm3, 0));
                                trans_ker_p.dst = (float *)&(diff_weights_prv(
                                        ithr, ofm, ifm, 0, 0, 0, 0));
                                if (first_tblk == 0) {
                                    kernel_->diff_weights_transform(
                                            &trans_ker_p);
                                } else {
                                    kernel_->diff_weights_transform_accum(
                                            &trans_ker_p);
                                }
                            }
                        }
                    }
                }
                ++first_tblk;
            }
        }
    }

    // Reduce diff-weights
    {
        float *output = ptr_diff_weights;
        float *input_base = scratchpad.get<float>(key_wino_U) + U_sz;
        int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
        float *input_ptrs[max_threads_number];
        for (int i = 0; i < nthreads; ++i) {
            input_ptrs[i] = input_base + nelems * i;
        }
        array_sum(nthreads, output, nelems, input_ptrs, false);

        if (jcp.with_bias) {
            output = ptr_diff_bias;
            input_base = scratchpad.get<float>(key_conv_bia_reduction);
            for (int i = 0; i < nthreads; ++i) {
                input_ptrs[i] = input_base + jcp.oc * i;
            }
            array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
                    false);
        }
    }
}

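/* Backward-weights driver, S_D_Giot_W variant: src and diff_dst of the
   whole minibatch are transformed first (into V and M), then each thread
   accumulates Winograd-domain weight partials into its own slice of Us,
   recording the element range it touched in input_starts / input_ends.
   subarray_sum() reduces those slices into U, a final pass applies the
   inverse weight transform into diff_weights, and the bias partials in
   diff_bias_prv are reduced separately. */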
void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
        _execute_backward_weights_S_D_Giot_W(const float *ptr_src,
                const float *ptr_diff_dst, float *ptr_diff_weights,
                float *ptr_diff_bias,
                const memory_tracking::grantor_t &scratchpad) const {
    const auto &jcp = kernel_->jcp;
    const int nthreads = jcp.nthr;

    array_offset_calculator<float, 5> src(
            (float *)ptr_src, jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
    array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst, jcp.mb,
            jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
    array_offset_calculator<float, 6> diff_weights((float *)ptr_diff_weights,
            jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
    array_offset_calculator<float, 1> diff_bias((float *)ptr_diff_bias, jcp.oc);

    array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U),
            jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.oc_block, jcp.ic_block,
            jcp.ic_simd_block, jcp.oc_reg_block, jcp.oc_simd_block);

    const int U_size = jcp.oc * jcp.ic * alpha * alpha;
    array_offset_calculator<float, 10> Us(
            scratchpad.get<float>(key_wino_U) + U_size, 0, jcp.nb_ic, jcp.nb_oc,
            alpha, alpha, jcp.oc_block, jcp.ic_block, jcp.ic_simd_block,
            jcp.oc_reg_block, jcp.oc_simd_block);

    array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M),
            jcp.nb_oc, jcp.tile_block, alpha, alpha, jcp.oc_block,
            jcp.nb_tile_block_ur, jcp.tile_block_ur, jcp.oc_reg_block,
            jcp.oc_simd_block);

    array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V),
            jcp.nb_ic, jcp.tile_block, alpha, alpha, jcp.ic_block,
            jcp.nb_tile_block_ur, jcp.tile_block_ur, jcp.ic_simd_block);

    array_offset_calculator<float, 2> diff_bias_prv(
            scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);

    size_t input_starts[max_threads_number] = {0};
    size_t input_ends[max_threads_number] = {0};
    size_t first_tblk = 0;

    auto trans_ker_p = jit_wino_transform_call_s();
    float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f, 0.625f,
            -0.625f, 1.5f, -1.5f, -2.640625f};
    float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
            0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
            0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
    float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
    float I[alpha][alpha][simd_w];
    float T[alpha][alpha][simd_w];

    PRAGMA_OMP(parallel num_threads(nthreads)
                    firstprivate(first_tblk, trans_ker_p, I, T))
    {
        if (jcp.with_bias) {
            parallel_nd_in_omp(nthreads, jcp.oc, [&](dim_t ithr, dim_t ofm) {
                diff_bias_prv(ithr, ofm) = 0.0f;
            });
        }

        trans_ker_p.G = G_I_3x3_4x4;
        trans_ker_p.M = I;
        trans_ker_p.T = T;

        parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb,
                [&](dim_t ifm1, dim_t ifm2, dim_t img) {
                    size_t ifm = ifm1 * jcp.ic_block + ifm2;
                    size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
                    size_t tblk3 = tile_base_index % jcp.tile_block_ur;
                    size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
                            % jcp.nb_tile_block_ur;
                    size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
                            / jcp.nb_tile_block_ur;
                    trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
                    trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
                    trans_ker_p.dst
                            = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
                    kernel_->src_transform(&trans_ker_p);
                });

        int ithr = OMP_GET_THREAD_NUM();
        trans_ker_p.G = G_W_3x3_4x4;
        parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb,
                [&](dim_t ofm1, dim_t ofm2, dim_t img) {
                    int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
                    size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
                    size_t tblk3 = tile_base_index % jcp.tile_block_ur;
                    size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
                            % jcp.nb_tile_block_ur;
                    size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
                            / jcp.nb_tile_block_ur;
                    trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
                    trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
                    trans_ker_p.dst = (float *)&(
                            M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
                    if (jcp.with_bias) {
                        trans_ker_p.bias
                                = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
                        kernel_->diff_dst_transform_wbias(&trans_ker_p);
                    } else {
                        kernel_->diff_dst_transform(&trans_ker_p);
                    }
                });

        PRAGMA_OMP(barrier)

        parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
                [&](dim_t ifm1, dim_t ofm1, dim_t oj, dim_t oi, dim_t tblk1) {
                    if (first_tblk == 0) {
                        input_starts[ithr] = (float *)&(Us(ithr, ifm1, ofm1, oj,
                                                     oi, 0, 0, 0, 0, 0))
                                - (float *)&(
                                        Us(ithr, 0, 0, 0, 0, 0, 0, 0, 0, 0));
                        input_ends[ithr] = input_starts[ithr]
                                + jcp.oc_block * jcp.ic_block
                                        * jcp.ic_simd_block * jcp.oc_reg_block
                                        * jcp.oc_simd_block;
                    } else if (tblk1 == 0) {
                        input_ends[ithr] += jcp.oc_block * jcp.ic_block
                                * jcp.ic_simd_block * jcp.oc_reg_block
                                * jcp.oc_simd_block;
                    }

                    if (first_tblk == 0 || tblk1 == 0) {
                        kernel_->gemm_loop_ker_first_iter(
                                &(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, 0, 0)),
                                &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
                                &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
                    } else {
                        kernel_->gemm_loop_ker(
                                &(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0, 0, 0)),
                                &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
                                &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
                    }
                    ++first_tblk;
                });
    }

    // Reduce diff-weights
    {
        float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
        size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
        float *input_ptrs[max_threads_number];
        for (int i = 0; i < nthreads; ++i)
            input_ptrs[i] = output + nelems * (i + 1);
        subarray_sum(
                nthreads, output, nelems, input_ptrs, input_starts, input_ends);
    }

    trans_ker_p.G = G_O_3x3_4x4;
    PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p))
    {
        parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
                jcp.oc_reg_block,
                [&](dim_t ifm1, dim_t ofm1, dim_t ofm2, dim_t ifm2,
                        dim_t ofm3) {
                    int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
                            + ofm3;
                    int ifm = ifm1 * jcp.ic_block + ifm2;
                    trans_ker_p.src = (float *)&(
                            U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, ofm3, 0));
                    trans_ker_p.dst
                            = (float *)&(diff_weights(ofm, ifm, 0, 0, 0, 0));
                    kernel_->diff_weights_transform(&trans_ker_p);
                });
    }

    if (jcp.with_bias) {
        parallel_nd(jcp.oc / simd_w, [&](dim_t ofm1) {
            float *pbias = &(diff_bias(ofm1 * simd_w));
            float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));

            const int blk_sz = ofm1 == jcp.oc / simd_w - 1
                    ? jcp.oc_without_padding - ofm1 * simd_w
                    : simd_w;

            PRAGMA_OMP_SIMD()
            for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
                pbias[ofm2] = pbias_prv[ofm2];
            }

            for (int ithr = 1; ithr < nthreads; ++ithr) {
                pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
                PRAGMA_OMP_SIMD()
                for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
                    pbias[ofm2] += pbias_prv[ofm2];
                }
            }
        });
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl
// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s