1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils/parallel.hpp"
18
19#include "deconv/ref_deconv.hpp"
20
// Copied from conv/ref_wino.cpp.
// TODO: remove the duplication.
23
24namespace deconv {
25
/* Lightweight N-D view over a flat buffer.
 *
 * Construct with the base pointer followed by the extent of every dimension
 * (outermost first). Calling the object with one index per dimension returns
 * a reference to the element at the row-major (innermost-last) offset
 *   ((i0 * d1 + i1) * d2 + i2) * ... + i{N-1}.
 * The extent of the outermost dimension is stored but never needed for the
 * offset. No bounds checking is performed. */
template <typename Telem, size_t Tdims>
struct array_offset_calculator_t {
    template <typename... Targs>
    array_offset_calculator_t(Telem *base, Targs... Fargs)
        : _base_ptr(base), _dims {Fargs...} {}

    template <typename... Targs>
    inline Telem &operator()(Targs... Fargs) {
        // Indices are folded in unsigned (size_t) arithmetic, outermost first.
        const size_t idx[] = {static_cast<size_t>(Fargs)...};
        size_t off = idx[0];
        for (size_t d = 1; d < sizeof...(Targs); ++d)
            off = off * _dims[d] + idx[d];
        return _base_ptr[off];
    }

private:
    Telem *_base_ptr;
    const int64_t _dims[Tdims];
};
59
/* Winograd F(4x4, 3x3) input transform of one 6x6 tile: Iw = B^T * I * B.
 *
 * The same 1-D transform B^T is applied to every column of `I`, then to every
 * row of the intermediate result. The constants are built from 0.625 and 1.5:
 * 0.390625 = 0.625^2, 2.25 = 1.5^2, 0.87890625 = 0.625^2 * 1.5^2 and
 * 2.640625 = 0.625^2 + 1.5^2. `Iw` is written only in the second pass, after
 * `I` has been fully read. */
void trans_I_4x4_3x3(float Iw[6][6], float I[6][6]) {
    // r = B^T * d for six samples; `d` and `r` must not alias.
    const auto apply_bt = [](const float d[6], float r[6]) {
        const float a0 = d[2] * -2.25f + d[4];
        const float a1 = d[1] * -2.25f + d[3];
        const float a2 = d[2] * -0.390625f + d[4];
        const float a3 = d[1] * -0.390625f + d[3];
        const float a4 = d[0] * 0.87890625f + d[4];
        const float a5 = d[1] * 0.87890625f + d[5];

        r[0] = d[2] * -2.640625f + a4;
        r[1] = a1 * 0.625f + a0;
        r[2] = a1 * -0.625f + a0;
        r[3] = a3 * 1.5f + a2;
        r[4] = a3 * -1.5f + a2;
        r[5] = d[3] * -2.640625f + a5;
    };

    // Column pass: T = B^T * I.
    float T[6][6];
    for (int col = 0; col < 6; col++) {
        float in[6];
        float out[6];
        for (int row = 0; row < 6; row++)
            in[row] = I[row][col];
        apply_bt(in, out);
        for (int row = 0; row < 6; row++)
            T[row][col] = out[row];
    }

    // Row pass: Iw = T * B.
    for (int row = 0; row < 6; row++)
        apply_bt(T[row], Iw[row]);
}
101
/* Winograd F(4x4, 3x3) weight transform of one 3x3 kernel: Fw_ = G * F * G^T.
 *
 * The same 1-D transform G (3 samples -> 6) is applied to every column of `F`,
 * then to every row of the 6x3 intermediate result. The constants are the
 * entries of the (scaled) G matrix used by this Winograd variant. */
void trans_W_4x4_3x3(float Fw_[6][6], float F[3][3]) {
    // r = G * g for three samples; `g` and `r` must not alias.
    const auto apply_g = [](const float g[3], float r[6]) {
        const float a = 0.26890756302521f * g[2];
        const float b = -a - 0.688403361344538f * g[0];
        const float c = a + 0.119514472455649f * g[0];

        r[0] = 1.13777777777778f * g[0];
        r[1] = b - 0.430252100840336f * g[1];
        r[2] = b + 0.430252100840336f * g[1];
        r[3] = c + 0.179271708683473f * g[1];
        r[4] = c - 0.179271708683473f * g[1];
        r[5] = g[2];
    };

    // Column pass: T = G * F (6x3).
    float T[6][3];
    for (int col = 0; col < 3; col++) {
        const float in[3] = {F[0][col], F[1][col], F[2][col]};
        float out[6];
        apply_g(in, out);
        for (int row = 0; row < 6; row++)
            T[row][col] = out[row];
    }

    // Row pass: Fw_ = T * G^T (6x6).
    for (int row = 0; row < 6; row++)
        apply_g(T[row], Fw_[row]);
}
138
/* Winograd F(4x4, 3x3) output transform of one 6x6 tile: O = A^T * Mw * A.
 *
 * The same 1-D transform A^T (6 samples -> 4) is applied to every column of
 * `Mw`, then to every row of the 4x6 intermediate result. The rows of A^T use
 * powers of 0.625 and 1.5 (0.390625 = 0.625^2, 2.25 = 1.5^2,
 * 0.244140625 = 0.625^3, 3.375 = 1.5^3); the last sample only contributes to
 * the last output. */
void trans_O_4x4_3x3(float Mw[6][6], float O[4][4]) {
    // r = A^T * m for six samples; `m` and `r` must not alias.
    const auto apply_at = [](const float m[6], float r[4]) {
        const float sum12 = m[1] + m[2];
        const float sum34 = m[3] + m[4];
        const float dif12 = m[1] - m[2];
        const float dif34 = m[3] - m[4];

        r[0] = sum12 + sum34 + m[0];
        r[1] = dif12 * 0.625f + dif34 * 1.5f;
        r[2] = sum12 * 0.390625f + sum34 * 2.25f;
        r[3] = dif12 * 0.244140625f + dif34 * 3.375f + m[5];
    };

    // Column pass: T = A^T * Mw (4x6).
    float T[4][6];
    for (int col = 0; col < 6; col++) {
        float in[6];
        float out[4];
        for (int row = 0; row < 6; row++)
            in[row] = Mw[row][col];
        apply_at(in, out);
        for (int row = 0; row < 4; row++)
            T[row][col] = out[row];
    }

    // Row pass: O = T * A (4x4).
    for (int row = 0; row < 4; row++)
        apply_at(T[row], O[row]);
}
170
/* Weight-update transform of a 4x4 diff_dst block into a 6x6 Winograd tile:
 * Fw = G4 * F * G4^T with a 6x4 matrix G4.
 *
 * Only F[0..3][0..3] are read; the remaining columns of `F` are ignored.
 * The same 1-D transform is applied to every column of that block, then to
 * every row of the 6x4 intermediate result. */
void trans_W_3x3_4x4_wu(float Fw[6][6], float F[4][6]) {
    // r = G4 * f for four samples; `f` and `r` must not alias.
    const auto apply_g4 = [](const float f[4], float r[6]) {
        const float a = f[2] * 0.26890756302521f;
        const float b = f[0] * -0.688403361344538f - a;
        const float c = f[0] * 0.119514472455649f + a;
        const float d = f[1] * 0.430252100840336f + f[3] * 0.168067226890756f;
        const float e = f[1] * 0.179271708683473f + f[3] * 0.403361344537815f;

        r[0] = f[0] * 1.13777777777778f;
        r[1] = b - d;
        r[2] = b + d;
        r[3] = c + e;
        r[4] = c - e;
        r[5] = f[3];
    };

    // Column pass: T = G4 * F[0..3][0..3] (6x4).
    float T[6][4];
    for (int col = 0; col < 4; col++) {
        const float in[4] = {F[0][col], F[1][col], F[2][col], F[3][col]};
        float out[6];
        apply_g4(in, out);
        for (int row = 0; row < 6; row++)
            T[row][col] = out[row];
    }

    // Row pass: Fw = T * G4^T (6x6).
    for (int row = 0; row < 6; row++)
        apply_g4(T[row], Fw[row]);
}
209
/* Weight-update output transform: reduces a 6x6 Winograd tile to a 3x3
 * kernel gradient, M = A3^T * Mw * A3 with a 3x6 matrix A3.
 *
 * The same 1-D transform (6 samples -> 3) is applied to every column of `Mw`,
 * then to every row of the 3x6 intermediate result. */
void trans_O_3x3_4x4_wu(float Mw[6][6], float M[3][3]) {
    // r = A3^T * m for six samples; `m` and `r` must not alias.
    const auto apply_a3t = [](const float m[6], float r[3]) {
        const float sum12 = m[1] + m[2];
        const float sum34 = m[3] + m[4];
        const float tail = sum34 * 2.25f + m[5];

        r[0] = m[0] + sum12 + sum34;
        r[1] = 0.625f * (m[1] - m[2]) + 1.5f * (m[3] - m[4]);
        r[2] = sum12 * 0.390625f + tail;
    };

    // Column pass: T = A3^T * Mw (3x6).
    float T[3][6];
    for (int col = 0; col < 6; col++) {
        float in[6];
        float out[3];
        for (int row = 0; row < 6; row++)
            in[row] = Mw[row][col];
        apply_a3t(in, out);
        for (int row = 0; row < 3; row++)
            T[row][col] = out[row];
    }

    // Row pass: M = T * A3 (3x3).
    for (int row = 0; row < 3; row++)
        apply_a3t(T[row], M[row]);
}
240
/* Scratch buffers and tiling parameters shared by the Winograd reference
 * implementations below.
 *
 * Every data member has a default, so even a default-initialized instance
 * holds null pointers and is safe to pass to free_scratchpad(), which only
 * frees non-null pointers. */
struct scratchpad_t {
    // Transformed weights (alpha x alpha x oc x ic elements).
    float *_u_ptr = nullptr;
    // Activation / GEMM-result buffers; which one holds what depends on the
    // propagation direction.
    float *_m_ptr = nullptr;
    float *_v_ptr = nullptr;

    // Number of out_dim x out_dim tiles along height and width.
    int64_t h_tiles = 0;
    int64_t w_tiles = 0;

    // Winograd tile size: out_dim + 3 - 1 for the 3x3 kernels supported here.
    const int64_t alpha = 6;
    // Spatial size of the output block produced from one tile.
    const int64_t out_dim = 4;
};
252
253int init_scratchpad(const prb_t *prb, scratchpad_t &sp) {
254 if (sp.out_dim != 4 || sp.alpha != 6) return FAIL;
255
256 sp.h_tiles = prb->dir == FLAG_FWD ? div_up(prb->oh, sp.out_dim)
257 : div_up(prb->ih, sp.out_dim);
258 sp.w_tiles = prb->dir == FLAG_FWD ? div_up(prb->ow, sp.out_dim)
259 : div_up(prb->iw, sp.out_dim);
260
261 sp._u_ptr = (float *)zmalloc(
262 sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->ic, 64);
263 sp._v_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * prb->ic
264 * prb->mb * sp.h_tiles * sp.w_tiles,
265 64);
266 sp._m_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * prb->oc
267 * prb->mb * sp.h_tiles * sp.w_tiles,
268 64);
269
270 if (sp._u_ptr == nullptr || sp._v_ptr == nullptr || sp._m_ptr == nullptr)
271 return dnnl_out_of_memory;
272
273 array_set((char *)sp._u_ptr,
274 sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->ic);
275 array_set((char *)sp._v_ptr,
276 sizeof(float) * sp.alpha * sp.alpha * prb->ic * prb->mb * sp.h_tiles
277 * sp.w_tiles);
278 array_set((char *)sp._m_ptr,
279 sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->mb * sp.h_tiles
280 * sp.w_tiles);
281
282 return OK;
283}
284
285void free_scratchpad(scratchpad_t *sp) {
286 if (sp->_u_ptr != nullptr) zfree(sp->_u_ptr);
287 if (sp->_v_ptr != nullptr) zfree(sp->_v_ptr);
288 if (sp->_m_ptr != nullptr) zfree(sp->_m_ptr);
289}
290
291void compute_wino_ref_fwd(const prb_t *prb, const args_t &args) {
292 const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
293 const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
294 const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
295 const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
296 scratchpad_t sp {};
297 SAFE_V(init_scratchpad(prb, sp));
298
299 array_offset_calculator_t<float, 4> U(
300 sp._u_ptr, sp.alpha, sp.alpha, prb->oc, prb->ic);
301 array_offset_calculator_t<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha,
302 prb->ic, prb->mb, sp.h_tiles, sp.w_tiles);
303 array_offset_calculator_t<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha,
304 prb->oc, prb->mb, sp.h_tiles, sp.w_tiles);
305
306 SAFE_V(prb->kh == 3 ? OK : FAIL);
307 SAFE_V(prb->kw == 3 ? OK : FAIL);
308
309 bool with_bias = prb->dir & FLAG_BIA;
310 const int64_t t_pad = prb->ph;
311 const int64_t l_pad = prb->pw;
312 const int64_t wp_max = prb->iw + l_pad;
313 const int64_t hp_max = prb->ih + t_pad;
314 const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles;
315
316 benchdnn_parallel_nd(prb->mb, prb->ic, sp.h_tiles, sp.w_tiles,
317 [&](int64_t img, int64_t c, int64_t hfm, int64_t wfm) {
318 float I[6][6] = {};
319 float _v[6][6] = {};
320 /* src_transform v <- B_t * d * B */
321 for (int64_t j = 0; j < sp.alpha; j++) {
322 int64_t ydim = hfm * sp.out_dim + j;
323 if ((t_pad <= ydim) && (ydim < hp_max)) {
324 for (int64_t k = 0; k < sp.alpha; k++) {
325 int64_t xdim = wfm * sp.out_dim + k;
326 if ((l_pad <= xdim) && (xdim < wp_max)) {
327 size_t src_off = src_off_f(prb, img, 0, c, 0,
328 ydim - t_pad, xdim - l_pad);
329 I[j][k] = ((float *)src_m)[src_off];
330 }
331 }
332 }
333 }
334 trans_I_4x4_3x3(_v, I);
335
336 /* scatter v:V */
337 for (int64_t j = 0; j < sp.alpha; j++) {
338 for (int64_t k = 0; k < sp.alpha; k++) {
339 V(j, k, c, img, hfm, wfm) = _v[j][k];
340 }
341 }
342 });
343
344 benchdnn_parallel_nd(prb->oc, prb->ic, [&](int64_t oc, int64_t ic) {
345 float F[3][3] = {};
346 float _u[6][6] = {};
347 /* wei_transform u <- G * g * G_t */
348 for_(int64_t j = 0; j < prb->kh; j++)
349 for (int64_t i = 0; i < prb->kw; i++) {
350 size_t wei_off = wei_off_f(prb, 0, oc, ic, 0, j, i);
351 F[j][i] = ((float *)wei_m)[wei_off];
352 }
353 trans_W_4x4_3x3(_u, F);
354
355 /* scatter u:U */
356 for_(int64_t j = 0; j < sp.alpha; j++)
357 for (int64_t k = 0; k < sp.alpha; k++) {
358 U(j, k, oc, ic) = _u[j][k];
359 }
360 });
361
362 benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) {
363 /* M = U * V */
364 gemm("C", "N", "N", prb->oc, p_dim, prb->ic, 1.0,
365 (float *)&(U(j, k, 0, 0)), prb->ic,
366 (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0,
367 (float *)&(M(j, k, 0, 0, 0, 0)), p_dim);
368 });
369
370 auto v_po_masks = prb->attr.post_ops.get_po_masks();
371 benchdnn_parallel_nd(prb->oc, prb->mb, sp.h_tiles, sp.w_tiles,
372 [&](int64_t oc, int64_t img, int64_t hfm, int64_t wfm) {
373 float O[4][4] = {};
374 float _m[6][6] = {};
375 /* Y = A_t *m * A */
376 for_(int64_t j = 0; j < sp.alpha; j++)
377 for (int64_t k = 0; k < sp.alpha; k++) {
378 _m[j][k] = M(j, k, oc, img, hfm, wfm);
379 }
380 trans_O_4x4_3x3(_m, O);
381
382 for (int64_t j = 0; j < sp.out_dim; j++) {
383 int64_t ydim = hfm * sp.out_dim + j;
384 if (ydim >= prb->oh) continue;
385
386 for (int64_t k = 0; k < sp.out_dim; k++) {
387 float conv_res = O[j][k];
388 int64_t xdim = wfm * sp.out_dim + k;
389 if (xdim >= prb->ow) continue;
390
391 const size_t dst_off
392 = dst_off_f(prb, img, 0, oc, 0, ydim, xdim);
393 float &dst = ((float *)dst_m)[dst_off];
394
395 const size_t bia_off = bia_off_f(prb, 0, oc);
396 conv_res += with_bias ? ((float *)bia_m)[bia_off] : 0.f;
397
398 const auto v_po_vals = prepare_po_vals(
399 dst_m, args, v_po_masks, dst_off);
400
401 maybe_post_ops(prb->attr, conv_res, dst, v_po_vals);
402
403 dst = conv_res;
404 }
405 }
406 });
407 free_scratchpad(&sp);
408}
409
410void compute_wino_ref_bwd_d(const prb_t *prb, const args_t &args) {
411 const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC);
412 const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
413 const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
414 const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
415 scratchpad_t sp {};
416 SAFE_V(init_scratchpad(prb, sp));
417
418 array_offset_calculator_t<float, 4> U(
419 sp._u_ptr, sp.alpha, sp.alpha, prb->ic, prb->oc);
420 array_offset_calculator_t<float, 6> V(sp._m_ptr, sp.alpha, sp.alpha,
421 prb->oc, prb->mb, sp.h_tiles, sp.w_tiles);
422 array_offset_calculator_t<float, 6> M(sp._v_ptr, sp.alpha, sp.alpha,
423 prb->ic, prb->mb, sp.h_tiles, sp.w_tiles);
424
425 SAFE_V(prb->kh == 3 ? OK : FAIL);
426 SAFE_V(prb->kw == 3 ? OK : FAIL);
427
428 const int64_t r_pad = MAX2(0, prb->ow - 1 + prb->kw - prb->iw - prb->pw);
429 const int64_t l_pad = prb->iw + r_pad - prb->ow;
430 const int64_t t_pad = prb->ih + prb->ph - prb->oh;
431 const int64_t wp_max = prb->ow + l_pad;
432 const int64_t hp_max = prb->oh + t_pad;
433 const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles;
434
435 bool with_bias = prb->dir & FLAG_BIA;
436
437 benchdnn_parallel_nd(prb->mb, prb->oc, sp.h_tiles, sp.w_tiles,
438 [&](int64_t img, int64_t c, int64_t hfm, int64_t wfm) {
439 float I[6][6] = {};
440 float _v[6][6] = {};
441 /* diff_src transform v <- B_t * d * B */
442 for (int64_t j = 0; j < sp.alpha; j++) {
443 int64_t ydim = hfm * sp.out_dim + j;
444 if ((t_pad <= ydim) && (ydim < hp_max)) {
445 for (int64_t k = 0; k < sp.alpha; k++) {
446 int64_t xdim = wfm * sp.out_dim + k;
447 if ((l_pad <= xdim) && (xdim < wp_max)) {
448 size_t dst_off = dst_off_f(prb, img, 0, c, 0,
449 ydim - t_pad, xdim - l_pad);
450 I[j][k] = ((float *)diff_dst_m)[dst_off];
451 }
452 }
453 }
454 trans_I_4x4_3x3(_v, I);
455
456 /* scatter v:V */
457 for_(int64_t j = 0; j < sp.alpha; j++)
458 for (int64_t k = 0; k < sp.alpha; k++) {
459 V(j, k, c, img, hfm, wfm) = _v[j][k];
460 }
461 }
462 });
463
464 benchdnn_parallel_nd(prb->ic, prb->oc, [&](int64_t ic, int64_t oc) {
465 float F[3][3] = {};
466 float _u[6][6] = {};
467 /* wei_transform u <- G * g * G_t */
468 for_(int64_t j = 0; j < prb->kh; j++)
469 for (int64_t i = 0; i < prb->kw; i++) {
470 size_t wei_off = wei_off_f(
471 prb, 0, oc, ic, 0, prb->kh - j - 1, prb->kw - i - 1);
472 F[j][i] = ((float *)wei_m)[wei_off];
473 }
474 trans_W_4x4_3x3(_u, F);
475
476 /* scatter u:U */
477 for_(int64_t j = 0; j < sp.alpha; j++)
478 for (int64_t k = 0; k < sp.alpha; k++) {
479 U(j, k, ic, oc) = _u[j][k];
480 }
481 });
482
483 benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) {
484 /* M = U * V */
485 gemm("C", "N", "N", prb->ic, p_dim, prb->oc, 1.0,
486 (float *)&(U(j, k, 0, 0)), prb->oc,
487 (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0,
488 (float *)&(M(j, k, 0, 0, 0, 0)), p_dim);
489 });
490
491 benchdnn_parallel_nd(prb->ic, prb->mb, sp.h_tiles, sp.w_tiles,
492 [&](int64_t c, int64_t img, int64_t hfm, int64_t wfm) {
493 float O[4][4] = {};
494 float _m[6][6] = {};
495 /* diff_dst: Y = A_t *m * A */
496 for_(int64_t j = 0; j < sp.alpha; j++)
497 for (int64_t k = 0; k < sp.alpha; k++) {
498 _m[j][k] = M(j, k, c, img, hfm, wfm);
499 }
500 trans_O_4x4_3x3(_m, O);
501
502 float bia = with_bias ? ((float *)bia_m)[c] : 0.f;
503
504 for (int64_t j = 0; j < sp.out_dim; j++) {
505 int64_t ydim = hfm * sp.out_dim + j;
506 if (ydim < prb->ih) {
507 for (int64_t k = 0; k < sp.out_dim; k++) {
508 int64_t xdim = wfm * sp.out_dim + k;
509 if (xdim < prb->iw) {
510 size_t src_off = src_off_f(
511 prb, img, 0, c, 0, ydim, xdim);
512 ((float *)diff_src_m)[src_off] = O[j][k] + bia;
513 }
514 }
515 }
516 }
517 });
518
519 free_scratchpad(&sp);
520}
521
522void compute_wino_ref_bwd_w(const prb_t *prb, const args_t &args) {
523 const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
524 const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS);
525 const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
526 scratchpad_t sp {};
527 SAFE_V(init_scratchpad(prb, sp));
528
529 array_offset_calculator_t<float, 4> U(
530 sp._u_ptr, sp.alpha, sp.alpha, prb->oc, prb->ic);
531 array_offset_calculator_t<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha,
532 prb->mb, sp.h_tiles, sp.w_tiles, prb->ic);
533 array_offset_calculator_t<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha,
534 prb->oc, prb->mb, sp.h_tiles, sp.w_tiles);
535
536 SAFE_V(prb->kh == 3 ? OK : FAIL);
537 SAFE_V(prb->kw == 3 ? OK : FAIL);
538
539 const int64_t t_pad = prb->ph;
540 const int64_t l_pad = prb->pw;
541 const int64_t wp_max = prb->iw + l_pad;
542 const int64_t hp_max = prb->ih + t_pad;
543 const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles;
544
545 benchdnn_parallel_nd(prb->mb, sp.h_tiles, sp.w_tiles, prb->ic,
546 [&](int64_t img, int64_t hfm, int64_t wfm, int64_t ic) {
547 float I[6][6] = {};
548 float _v[6][6] = {};
549 /* src transform v <- B_t * d * B */
550 for (int64_t j = 0; j < sp.alpha; j++) {
551 int64_t ydim = hfm * sp.out_dim + j;
552 if ((t_pad <= ydim) && (ydim < hp_max)) {
553 for (int64_t k = 0; k < sp.alpha; k++) {
554 int64_t xdim = wfm * sp.out_dim + k;
555 if ((l_pad <= xdim) && (xdim < wp_max)) {
556 size_t src_off = src_off_f(prb, img, 0, ic, 0,
557 ydim - t_pad, xdim - l_pad);
558 I[j][k] = ((float *)src_m)[src_off];
559 }
560 }
561 }
562 }
563 trans_I_4x4_3x3(_v, I);
564
565 /* scatter v:V */
566 for_(int64_t j = 0; j < sp.alpha; j++)
567 for (int64_t k = 0; k < sp.alpha; k++) {
568 V(j, k, img, hfm, wfm, ic) = _v[j][k];
569 }
570 });
571
572 benchdnn_parallel_nd(prb->oc, prb->mb, sp.h_tiles, sp.w_tiles,
573 [&](int64_t oc, int64_t img, int64_t hfm, int64_t wfm) {
574 float O[6][6] = {};
575 float _m[6][6] = {};
576 /* diff_dst transform */
577 for (int64_t j = 0; j < sp.alpha; j++) {
578 int64_t ydim = hfm * sp.out_dim + j;
579 if (ydim < prb->oh) {
580 for (int64_t k = 0; k < sp.alpha; k++) {
581 int64_t xdim = wfm * sp.out_dim + k;
582 if (xdim < prb->ow) {
583 size_t dst_off = dst_off_f(
584 prb, img, 0, oc, 0, ydim, xdim);
585 O[j][k] = ((float *)diff_dst_m)[dst_off];
586 }
587 }
588 }
589 }
590 trans_W_3x3_4x4_wu(_m, O);
591 /* scatter v:V */
592 for_(int64_t j = 0; j < sp.alpha; j++)
593 for (int64_t k = 0; k < sp.alpha; k++) {
594 M(j, k, oc, img, hfm, wfm) = _m[j][k];
595 }
596 });
597
598 benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) {
599 /* GeMM U = M * V */
600 gemm("C", "N", "N", prb->oc, prb->ic, p_dim, 1.0,
601 (float *)&(M(j, k, 0, 0, 0, 0)), p_dim,
602 (float *)&(V(j, k, 0, 0, 0, 0)), prb->ic, 1.0,
603 (float *)&(U(j, k, 0, 0)), prb->ic);
604 });
605
606 benchdnn_parallel_nd(prb->oc, prb->ic, [&](int64_t oc, int64_t ic) {
607 float F[6][6] = {};
608 float _u[3][3] = {};
609 for_(int64_t j = 0; j < sp.alpha; j++)
610 for (int64_t k = 0; k < sp.alpha; k++) {
611 F[j][k] = U(j, k, oc, ic);
612 }
613 trans_O_3x3_4x4_wu(F, _u);
614
615 /* scatter u:U */
616 for_(int64_t kh = 0; kh < prb->kh; kh++)
617 for (int64_t kw = 0; kw < prb->kw; kw++) {
618 size_t wei_off = wei_off_f(prb, 0, oc, ic, 0, kh, kw);
619 ((float *)diff_wei_m)[wei_off] = _u[kh][kw];
620 }
621 });
622
623 free_scratchpad(&sp);
624
625 if (prb->dir & FLAG_BIA) compute_ref_bwd_bias(prb, args);
626}
627
628} // namespace deconv
629