1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils/parallel.hpp"
18
19#include "conv/ref_conv.hpp"
20
21namespace conv {
22
/* Lightweight accessor that maps a tuple of indices onto a flat buffer laid
 * out with the last index varying fastest.
 *
 * The constructor takes the base pointer followed by the extent of every
 * dimension; operator() takes one index per dimension and returns a reference
 * to the addressed element. The extent of the first dimension is stored but
 * never needed to compute an offset. No bounds checking is performed. */
template <typename Telem, size_t Tdims>
struct array_offset_calculator_t {
    template <typename... Targs>
    array_offset_calculator_t(Telem *base, Targs... Fargs)
        : _base_ptr(base), _dims {Fargs...} {}

    /* Returns a reference to the element selected by the given indices. */
    template <typename... Targs>
    inline Telem &operator()(Targs... Fargs) {
        const size_t idx[] = {static_cast<size_t>(Fargs)...};

        // Fold the indices left to right: each step scales the running
        // offset by the extent of the next dimension and adds its index.
        size_t flat = idx[0];
        for (size_t d = 1; d < sizeof...(Fargs); ++d)
            flat = idx[d] + (_dims[d] * flat);

        return *(_base_ptr + flat);
    }

private:
    Telem *_base_ptr;
    const int64_t _dims[Tdims];
};
56
/* Data-tile transform of the Winograd scheme used in this file (callers label
 * it "v <- B_t * d * B").
 *
 * The same six-point 1-D transform is applied first down every column of I
 * and then along every row of the intermediate tile.
 *
 * Iw - output, transformed 6x6 tile.
 * I  - input 6x6 tile. */
void trans_I_4x4_3x3(float Iw[6][6], float I[6][6]) {
    // 1-D transform of six samples d[0..5] into r[0..5].
    const auto transform_1d = [](const float (&d)[6], float (&r)[6]) {
        const float even_hi = d[2] * -2.25f + d[4];
        const float odd_hi = d[1] * -2.25f + d[3];
        const float even_lo = d[2] * -0.390625f + d[4];
        const float odd_lo = d[1] * -0.390625f + d[3];
        const float head = d[0] * 0.87890625f + d[4];
        const float tail = d[1] * 0.87890625f + d[5];

        r[0] = d[2] * -2.640625f + head;
        r[1] = odd_hi * 0.625f + even_hi;
        r[2] = odd_hi * -0.625f + even_hi;
        r[3] = odd_lo * 1.5f + even_lo;
        r[4] = odd_lo * -1.5f + even_lo;
        r[5] = d[3] * -2.640625f + tail;
    };

    float T[6][6];
    float col_in[6];
    float col_out[6];

    // Pass 1: transform each column of I into the matching column of T.
    for (int c = 0; c < 6; ++c) {
        for (int r = 0; r < 6; ++r)
            col_in[r] = I[r][c];
        transform_1d(col_in, col_out);
        for (int r = 0; r < 6; ++r)
            T[r][c] = col_out[r];
    }

    // Pass 2: transform each row of T into the matching row of Iw.
    for (int r = 0; r < 6; ++r)
        transform_1d(T[r], Iw[r]);
}
98
/* Filter transform of the Winograd scheme used in this file (callers label it
 * "u <- G * g * G_t").
 *
 * A three-point to six-point 1-D transform is applied first down every column
 * of F and then along every row of the intermediate 6x3 tile.
 *
 * Fw_ - output, transformed 6x6 tile.
 * F   - input 3x3 filter. */
void trans_W_4x4_3x3(float Fw_[6][6], float F[3][3]) {
    // 1-D transform of three taps g[0..2] into six values w[0..5].
    const auto transform_1d = [](const float (&g)[3], float (&w)[6]) {
        const float last = 0.26890756302521f * g[2];
        const float neg = -last - 0.688403361344538f * g[0];
        const float pos = last + 0.119514472455649f * g[0];

        w[0] = 1.13777777777778f * g[0];
        w[1] = neg - 0.430252100840336f * g[1];
        w[2] = neg + 0.430252100840336f * g[1];
        w[3] = pos + 0.179271708683473f * g[1];
        w[4] = pos - 0.179271708683473f * g[1];
        w[5] = g[2];
    };

    float T[6][3];
    float col_in[3];
    float col_out[6];

    // Pass 1: columns of F (3 values each) -> columns of T (6 values each).
    for (int c = 0; c < 3; ++c) {
        for (int r = 0; r < 3; ++r)
            col_in[r] = F[r][c];
        transform_1d(col_in, col_out);
        for (int r = 0; r < 6; ++r)
            T[r][c] = col_out[r];
    }

    // Pass 2: rows of T (3 values each) -> rows of the 6x6 result.
    for (int r = 0; r < 6; ++r)
        transform_1d(T[r], Fw_[r]);
}
135
/* Output transform of the Winograd scheme used in this file (callers label it
 * "Y = A_t * m * A").
 *
 * A six-point to four-point 1-D transform is applied first down every column
 * of Mw and then along every row of the intermediate 4x6 tile.
 *
 * Mw - input 6x6 tile.
 * O  - output 4x4 tile. */
void trans_O_4x4_3x3(float Mw[6][6], float O[4][4]) {
    // 1-D transform of six values m[0..5] into four values y[0..3].
    const auto transform_1d = [](const float (&m)[6], float (&y)[4]) {
        const float sum_a = m[1] + m[2];
        const float sum_b = m[3] + m[4];
        const float dif_a = m[1] - m[2];
        const float dif_b = m[3] - m[4];

        y[0] = sum_a + sum_b + m[0];
        y[1] = dif_a * 0.625f + dif_b * 1.5f;
        y[2] = sum_a * 0.390625f + sum_b * 2.25f;
        y[3] = dif_a * 0.244140625f + dif_b * 3.375f + m[5];
    };

    float T[4][6];
    float col_in[6];
    float col_out[4];

    // Pass 1: columns of Mw (6 values each) -> columns of T (4 values each).
    for (int c = 0; c < 6; ++c) {
        for (int r = 0; r < 6; ++r)
            col_in[r] = Mw[r][c];
        transform_1d(col_in, col_out);
        for (int r = 0; r < 4; ++r)
            T[r][c] = col_out[r];
    }

    // Pass 2: rows of T (6 values each) -> rows of O (4 values each).
    for (int r = 0; r < 4; ++r)
        transform_1d(T[r], O[r]);
}
167
/* Transform applied to diff_dst tiles by the weight-update kernel in this
 * file.
 *
 * A four-point to six-point 1-D transform is applied first down the leading
 * four columns of F (rows 0..3) and then along every row of the intermediate
 * 6x4 tile. Columns 4 and 5 of F are never read.
 *
 * Fw - output, transformed 6x6 tile.
 * F  - input tile; only its leading 4x4 part is used. */
void trans_W_3x3_4x4_wu(float Fw[6][6], float F[4][6]) {
    // 1-D transform of four values f[0..3] into six values w[0..5].
    const auto transform_1d = [](const float (&f)[4], float (&w)[6]) {
        const float mid = f[2] * 0.26890756302521f;
        const float neg = f[0] * -0.688403361344538f - mid;
        const float pos = f[0] * 0.119514472455649f + mid;
        const float odd_a
                = f[1] * 0.430252100840336f + f[3] * 0.168067226890756f;
        const float odd_b
                = f[1] * 0.179271708683473f + f[3] * 0.403361344537815f;

        w[0] = f[0] * 1.13777777777778f;
        w[1] = neg - odd_a;
        w[2] = neg + odd_a;
        w[3] = pos + odd_b;
        w[4] = pos - odd_b;
        w[5] = f[3];
    };

    float T[6][4];
    float col_in[4];
    float col_out[6];

    // Pass 1: leading four columns of F -> columns of T.
    for (int c = 0; c < 4; ++c) {
        for (int r = 0; r < 4; ++r)
            col_in[r] = F[r][c];
        transform_1d(col_in, col_out);
        for (int r = 0; r < 6; ++r)
            T[r][c] = col_out[r];
    }

    // Pass 2: rows of T (4 values each) -> rows of Fw (6 values each).
    for (int r = 0; r < 6; ++r)
        transform_1d(T[r], Fw[r]);
}
206
/* Transform that turns an accumulated 6x6 tile into a 3x3 result; used by
 * the weight-update kernel in this file to produce diff_weights.
 *
 * A six-point to three-point 1-D transform is applied first down every column
 * of Mw and then along every row of the intermediate 3x6 tile.
 *
 * Mw - input 6x6 tile.
 * M  - output 3x3 tile. */
void trans_O_3x3_4x4_wu(float Mw[6][6], float M[3][3]) {
    // 1-D transform of six values m[0..5] into three values y[0..2].
    const auto transform_1d = [](const float (&m)[6], float (&y)[3]) {
        const float sum_a = m[1] + m[2];
        const float sum_b = m[3] + m[4];
        const float tail = sum_b * 2.25f + m[5];

        y[0] = m[0] + sum_a + sum_b;
        y[1] = 0.625f * (m[1] - m[2]) + 1.5f * (m[3] - m[4]);
        y[2] = sum_a * 0.390625f + tail;
    };

    float T[3][6];
    float col_in[6];
    float col_out[3];

    // Pass 1: columns of Mw (6 values each) -> columns of T (3 values each).
    for (int c = 0; c < 6; ++c) {
        for (int r = 0; r < 6; ++r)
            col_in[r] = Mw[r][c];
        transform_1d(col_in, col_out);
        for (int r = 0; r < 3; ++r)
            T[r][c] = col_out[r];
    }

    // Pass 2: rows of T (6 values each) -> rows of M (3 values each).
    for (int r = 0; r < 3; ++r)
        transform_1d(T[r], M[r]);
}
237
/* Working storage shared by the Winograd reference kernels in this file.
 *
 * Every member has a default initializer, so an instance is in a well-defined
 * "nothing allocated" state even when it is declared without `{}`; in
 * particular it can be handed to free_scratchpad() safely. */
struct scratchpad_t {
    // Three heap buffers filled in by init_scratchpad(). The compute kernels
    // view them through array_offset_calculator_t as the U, M and V arrays
    // (the backward-data kernel swaps the roles of the M and V buffers).
    float *_u_ptr = nullptr;
    float *_m_ptr = nullptr;
    float *_v_ptr = nullptr;

    // Number of tiles along the height / width, set by init_scratchpad().
    int64_t h_tiles = 0;
    int64_t w_tiles = 0;

    // Edge length of a transformed tile (the kernels work on 6x6 arrays).
    const int64_t alpha = 6;
    // Step between neighbouring tiles and edge length of an output tile.
    const int64_t out_dim = 4;
};
249
250int init_scratchpad(const prb_t *prb, scratchpad_t &sp) {
251 if (sp.out_dim != 4 || sp.alpha != 6) return FAIL;
252
253 sp.h_tiles = prb->dir == FLAG_FWD ? div_up(prb->oh, sp.out_dim)
254 : div_up(prb->ih, sp.out_dim);
255 sp.w_tiles = prb->dir == FLAG_FWD ? div_up(prb->ow, sp.out_dim)
256 : div_up(prb->iw, sp.out_dim);
257
258 sp._u_ptr = (float *)zmalloc(
259 sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->ic, 64);
260 sp._v_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * prb->ic
261 * prb->mb * sp.h_tiles * sp.w_tiles,
262 64);
263 sp._m_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * prb->oc
264 * prb->mb * sp.h_tiles * sp.w_tiles,
265 64);
266
267 if (sp._u_ptr == nullptr || sp._v_ptr == nullptr || sp._m_ptr == nullptr)
268 return dnnl_out_of_memory;
269
270 array_set((char *)sp._u_ptr,
271 sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->ic);
272 array_set((char *)sp._v_ptr,
273 sizeof(float) * sp.alpha * sp.alpha * prb->ic * prb->mb * sp.h_tiles
274 * sp.w_tiles);
275 array_set((char *)sp._m_ptr,
276 sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->mb * sp.h_tiles
277 * sp.w_tiles);
278
279 return OK;
280}
281
282void free_scratchpad(scratchpad_t *sp) {
283 if (sp->_u_ptr != nullptr) zfree(sp->_u_ptr);
284 if (sp->_v_ptr != nullptr) zfree(sp->_v_ptr);
285 if (sp->_m_ptr != nullptr) zfree(sp->_m_ptr);
286}
287
288void compute_wino_ref_fwd(const prb_t *prb, const args_t &args) {
289 const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
290 const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
291 const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
292 const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
293 scratchpad_t sp {};
294 SAFE_V(init_scratchpad(prb, sp));
295
296 array_offset_calculator_t<float, 4> U(
297 sp._u_ptr, sp.alpha, sp.alpha, prb->oc, prb->ic);
298 array_offset_calculator_t<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha,
299 prb->ic, prb->mb, sp.h_tiles, sp.w_tiles);
300 array_offset_calculator_t<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha,
301 prb->oc, prb->mb, sp.h_tiles, sp.w_tiles);
302
303 SAFE_V(prb->kh == 3 ? OK : FAIL);
304 SAFE_V(prb->kw == 3 ? OK : FAIL);
305
306 bool with_bias = prb->dir & FLAG_BIA;
307 const int64_t t_pad = prb->ph;
308 const int64_t l_pad = prb->pw;
309 const int64_t wp_max = prb->iw + l_pad;
310 const int64_t hp_max = prb->ih + t_pad;
311 const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles;
312
313 benchdnn_parallel_nd(prb->mb, prb->ic, sp.h_tiles, sp.w_tiles,
314 [&](int64_t img, int64_t c, int64_t hfm, int64_t wfm) {
315 float I[6][6] = {};
316 float _v[6][6] = {};
317 /* src_transform v <- B_t * d * B */
318 for (int64_t j = 0; j < sp.alpha; j++) {
319 int64_t ydim = hfm * sp.out_dim + j;
320 if ((t_pad <= ydim) && (ydim < hp_max)) {
321 for (int64_t k = 0; k < sp.alpha; k++) {
322 int64_t xdim = wfm * sp.out_dim + k;
323 if ((l_pad <= xdim) && (xdim < wp_max)) {
324 size_t src_off = src_off_f(prb, img, 0, c, 0,
325 ydim - t_pad, xdim - l_pad);
326 I[j][k] = ((float *)src_m)[src_off];
327 }
328 }
329 }
330 }
331 trans_I_4x4_3x3(_v, I);
332
333 /* scatter v:V */
334 for (int64_t j = 0; j < sp.alpha; j++) {
335 for (int64_t k = 0; k < sp.alpha; k++) {
336 V(j, k, c, img, hfm, wfm) = _v[j][k];
337 }
338 }
339 });
340
341 benchdnn_parallel_nd(prb->oc, prb->ic, [&](int64_t oc, int64_t ic) {
342 float F[3][3] = {};
343 float _u[6][6] = {};
344 /* wei_transform u <- G * g * G_t */
345 for_(int64_t j = 0; j < prb->kh; j++)
346 for (int64_t i = 0; i < prb->kw; i++) {
347 size_t wei_off = wei_off_f(prb, 0, oc, ic, 0, j, i);
348 F[j][i] = ((float *)wei_m)[wei_off];
349 }
350 trans_W_4x4_3x3(_u, F);
351
352 /* scatter u:U */
353 for_(int64_t j = 0; j < sp.alpha; j++)
354 for (int64_t k = 0; k < sp.alpha; k++) {
355 U(j, k, oc, ic) = _u[j][k];
356 }
357 });
358
359 benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) {
360 /* M = U * V */
361 gemm("C", "N", "N", prb->oc, p_dim, prb->ic, 1.0,
362 (float *)&(U(j, k, 0, 0)), prb->ic,
363 (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0,
364 (float *)&(M(j, k, 0, 0, 0, 0)), p_dim);
365 });
366
367 auto v_po_masks = prb->attr.post_ops.get_po_masks();
368 benchdnn_parallel_nd(prb->oc, prb->mb, sp.h_tiles, sp.w_tiles,
369 [&](int64_t oc, int64_t img, int64_t hfm, int64_t wfm) {
370 float O[4][4] = {};
371 float _m[6][6] = {};
372 /* Y = A_t *m * A */
373 for_(int64_t j = 0; j < sp.alpha; j++)
374 for (int64_t k = 0; k < sp.alpha; k++) {
375 _m[j][k] = M(j, k, oc, img, hfm, wfm);
376 }
377 trans_O_4x4_3x3(_m, O);
378
379 for (int64_t j = 0; j < sp.out_dim; j++) {
380 int64_t ydim = hfm * sp.out_dim + j;
381 if (ydim >= prb->oh) continue;
382
383 for (int64_t k = 0; k < sp.out_dim; k++) {
384 float conv_res = O[j][k];
385 int64_t xdim = wfm * sp.out_dim + k;
386 if (xdim >= prb->ow) continue;
387
388 const size_t dst_off
389 = dst_off_f(prb, img, 0, oc, 0, ydim, xdim);
390 float &dst = ((float *)dst_m)[dst_off];
391
392 const size_t bia_off = bia_off_f(prb, 0, oc);
393 conv_res += with_bias ? ((float *)bia_m)[bia_off] : 0.f;
394
395 const auto v_po_vals = prepare_po_vals(
396 dst_m, args, v_po_masks, dst_off);
397
398 maybe_post_ops(prb->attr, conv_res, dst, v_po_vals);
399
400 dst = conv_res;
401 }
402 }
403 });
404 free_scratchpad(&sp);
405}
406
407void compute_wino_ref_bwd_d(const prb_t *prb, const args_t &args) {
408 const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC);
409 const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
410 const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
411 const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
412 scratchpad_t sp {};
413 SAFE_V(init_scratchpad(prb, sp));
414
415 array_offset_calculator_t<float, 4> U(
416 sp._u_ptr, sp.alpha, sp.alpha, prb->ic, prb->oc);
417 array_offset_calculator_t<float, 6> V(sp._m_ptr, sp.alpha, sp.alpha,
418 prb->oc, prb->mb, sp.h_tiles, sp.w_tiles);
419 array_offset_calculator_t<float, 6> M(sp._v_ptr, sp.alpha, sp.alpha,
420 prb->ic, prb->mb, sp.h_tiles, sp.w_tiles);
421
422 SAFE_V(prb->kh == 3 ? OK : FAIL);
423 SAFE_V(prb->kw == 3 ? OK : FAIL);
424
425 const int64_t r_pad = MAX2(0, prb->ow - 1 + prb->kw - prb->iw - prb->pw);
426 const int64_t l_pad = prb->iw + r_pad - prb->ow;
427 const int64_t t_pad = prb->ih + prb->ph - prb->oh;
428 const int64_t wp_max = prb->ow + l_pad;
429 const int64_t hp_max = prb->oh + t_pad;
430 const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles;
431
432 bool with_bias = prb->dir & FLAG_BIA;
433
434 benchdnn_parallel_nd(prb->mb, prb->oc, sp.h_tiles, sp.w_tiles,
435 [&](int64_t img, int64_t c, int64_t hfm, int64_t wfm) {
436 float I[6][6] = {};
437 float _v[6][6] = {};
438 /* diff_src transform v <- B_t * d * B */
439 for (int64_t j = 0; j < sp.alpha; j++) {
440 int64_t ydim = hfm * sp.out_dim + j;
441 if ((t_pad <= ydim) && (ydim < hp_max)) {
442 for (int64_t k = 0; k < sp.alpha; k++) {
443 int64_t xdim = wfm * sp.out_dim + k;
444 if ((l_pad <= xdim) && (xdim < wp_max)) {
445 size_t dst_off = dst_off_f(prb, img, 0, c, 0,
446 ydim - t_pad, xdim - l_pad);
447 I[j][k] = ((float *)diff_dst_m)[dst_off];
448 }
449 }
450 }
451 trans_I_4x4_3x3(_v, I);
452
453 /* scatter v:V */
454 for_(int64_t j = 0; j < sp.alpha; j++)
455 for (int64_t k = 0; k < sp.alpha; k++) {
456 V(j, k, c, img, hfm, wfm) = _v[j][k];
457 }
458 }
459 });
460
461 benchdnn_parallel_nd(prb->ic, prb->oc, [&](int64_t ic, int64_t oc) {
462 float F[3][3] = {};
463 float _u[6][6] = {};
464 /* wei_transform u <- G * g * G_t */
465 for_(int64_t j = 0; j < prb->kh; j++)
466 for (int64_t i = 0; i < prb->kw; i++) {
467 size_t wei_off = wei_off_f(
468 prb, 0, oc, ic, 0, prb->kh - j - 1, prb->kw - i - 1);
469 F[j][i] = ((float *)wei_m)[wei_off];
470 }
471 trans_W_4x4_3x3(_u, F);
472
473 /* scatter u:U */
474 for_(int64_t j = 0; j < sp.alpha; j++)
475 for (int64_t k = 0; k < sp.alpha; k++) {
476 U(j, k, ic, oc) = _u[j][k];
477 }
478 });
479
480 benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) {
481 /* M = U * V */
482 gemm("C", "N", "N", prb->ic, p_dim, prb->oc, 1.0,
483 (float *)&(U(j, k, 0, 0)), prb->oc,
484 (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0,
485 (float *)&(M(j, k, 0, 0, 0, 0)), p_dim);
486 });
487
488 benchdnn_parallel_nd(prb->ic, prb->mb, sp.h_tiles, sp.w_tiles,
489 [&](int64_t c, int64_t img, int64_t hfm, int64_t wfm) {
490 float O[4][4] = {};
491 float _m[6][6] = {};
492 /* diff_dst: Y = A_t *m * A */
493 for_(int64_t j = 0; j < sp.alpha; j++)
494 for (int64_t k = 0; k < sp.alpha; k++) {
495 _m[j][k] = M(j, k, c, img, hfm, wfm);
496 }
497 trans_O_4x4_3x3(_m, O);
498
499 float bia = with_bias ? ((float *)bia_m)[c] : 0.f;
500
501 for (int64_t j = 0; j < sp.out_dim; j++) {
502 int64_t ydim = hfm * sp.out_dim + j;
503 if (ydim < prb->ih) {
504 for (int64_t k = 0; k < sp.out_dim; k++) {
505 int64_t xdim = wfm * sp.out_dim + k;
506 if (xdim < prb->iw) {
507 size_t src_off = src_off_f(
508 prb, img, 0, c, 0, ydim, xdim);
509 ((float *)diff_src_m)[src_off] = O[j][k] + bia;
510 }
511 }
512 }
513 }
514 });
515
516 free_scratchpad(&sp);
517}
518
519void compute_wino_ref_bwd_w(const prb_t *prb, const args_t &args) {
520 const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
521 const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS);
522 const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
523 scratchpad_t sp {};
524 SAFE_V(init_scratchpad(prb, sp));
525
526 array_offset_calculator_t<float, 4> U(
527 sp._u_ptr, sp.alpha, sp.alpha, prb->oc, prb->ic);
528 array_offset_calculator_t<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha,
529 prb->mb, sp.h_tiles, sp.w_tiles, prb->ic);
530 array_offset_calculator_t<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha,
531 prb->oc, prb->mb, sp.h_tiles, sp.w_tiles);
532
533 SAFE_V(prb->kh == 3 ? OK : FAIL);
534 SAFE_V(prb->kw == 3 ? OK : FAIL);
535
536 const int64_t t_pad = prb->ph;
537 const int64_t l_pad = prb->pw;
538 const int64_t wp_max = prb->iw + l_pad;
539 const int64_t hp_max = prb->ih + t_pad;
540 const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles;
541
542 benchdnn_parallel_nd(prb->mb, sp.h_tiles, sp.w_tiles, prb->ic,
543 [&](int64_t img, int64_t hfm, int64_t wfm, int64_t ic) {
544 float I[6][6] = {};
545 float _v[6][6] = {};
546 /* src transform v <- B_t * d * B */
547 for (int64_t j = 0; j < sp.alpha; j++) {
548 int64_t ydim = hfm * sp.out_dim + j;
549 if ((t_pad <= ydim) && (ydim < hp_max)) {
550 for (int64_t k = 0; k < sp.alpha; k++) {
551 int64_t xdim = wfm * sp.out_dim + k;
552 if ((l_pad <= xdim) && (xdim < wp_max)) {
553 size_t src_off = src_off_f(prb, img, 0, ic, 0,
554 ydim - t_pad, xdim - l_pad);
555 I[j][k] = ((float *)src_m)[src_off];
556 }
557 }
558 }
559 }
560 trans_I_4x4_3x3(_v, I);
561
562 /* scatter v:V */
563 for_(int64_t j = 0; j < sp.alpha; j++)
564 for (int64_t k = 0; k < sp.alpha; k++) {
565 V(j, k, img, hfm, wfm, ic) = _v[j][k];
566 }
567 });
568
569 benchdnn_parallel_nd(prb->oc, prb->mb, sp.h_tiles, sp.w_tiles,
570 [&](int64_t oc, int64_t img, int64_t hfm, int64_t wfm) {
571 float O[6][6] = {};
572 float _m[6][6] = {};
573 /* diff_dst transform */
574 for (int64_t j = 0; j < sp.alpha; j++) {
575 int64_t ydim = hfm * sp.out_dim + j;
576 if (ydim < prb->oh) {
577 for (int64_t k = 0; k < sp.alpha; k++) {
578 int64_t xdim = wfm * sp.out_dim + k;
579 if (xdim < prb->ow) {
580 size_t dst_off = dst_off_f(
581 prb, img, 0, oc, 0, ydim, xdim);
582 O[j][k] = ((float *)diff_dst_m)[dst_off];
583 }
584 }
585 }
586 }
587 trans_W_3x3_4x4_wu(_m, O);
588 /* scatter v:V */
589 for_(int64_t j = 0; j < sp.alpha; j++)
590 for (int64_t k = 0; k < sp.alpha; k++) {
591 M(j, k, oc, img, hfm, wfm) = _m[j][k];
592 }
593 });
594
595 benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) {
596 /* GeMM U = M * V */
597 gemm("C", "N", "N", prb->oc, prb->ic, p_dim, 1.0,
598 (float *)&(M(j, k, 0, 0, 0, 0)), p_dim,
599 (float *)&(V(j, k, 0, 0, 0, 0)), prb->ic, 1.0,
600 (float *)&(U(j, k, 0, 0)), prb->ic);
601 });
602
603 benchdnn_parallel_nd(prb->oc, prb->ic, [&](int64_t oc, int64_t ic) {
604 float F[6][6] = {};
605 float _u[3][3] = {};
606 for_(int64_t j = 0; j < sp.alpha; j++)
607 for (int64_t k = 0; k < sp.alpha; k++) {
608 F[j][k] = U(j, k, oc, ic);
609 }
610 trans_O_3x3_4x4_wu(F, _u);
611
612 /* scatter u:U */
613 for_(int64_t kh = 0; kh < prb->kh; kh++)
614 for (int64_t kw = 0; kw < prb->kw; kw++) {
615 size_t wei_off = wei_off_f(prb, 0, oc, ic, 0, kh, kw);
616 ((float *)diff_wei_m)[wei_off] = _u[kh][kw];
617 }
618 });
619
620 free_scratchpad(&sp);
621
622 if (prb->dir & FLAG_BIA) compute_ref_bwd_bias(prb, args);
623}
624
625} // namespace conv
626