1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "deconv/ref_deconv.hpp" |
20 | |
21 | // Copy paste from conv/ref_wino.cpp. |
22 | // TODO: remove the duplication. |
23 | |
24 | namespace deconv { |
25 | |
// Row-major multi-dimensional view over a flat buffer.
//
// The constructor records one extent per dimension; operator() folds a list
// of indices into a linear offset and returns a reference to that element.
// The extent of the leading dimension is stored but does not take part in
// the offset computation.
template <typename Telem, size_t Tdims>
struct array_offset_calculator_t {
    // base:  first element of the underlying buffer (not owned).
    // Fargs: extents of the dimensions, outermost first.
    template <typename... Targs>
    array_offset_calculator_t(Telem *base, Targs... Fargs)
        : _base_ptr(base), _dims {Fargs...} {}

    // Fargs: one index per dimension, outermost first.
    // Returns a reference to the addressed element; no bounds checking.
    template <typename... Targs>
    inline Telem &operator()(Targs... Fargs) {
        const size_t idx[] = {static_cast<size_t>(Fargs)...};

        // Horner-style accumulation: off = idx[d] + dims[d] * off.
        size_t linear = idx[0];
        for (size_t d = 1; d < sizeof...(Fargs); ++d) {
            linear = idx[d] + (_dims[d] * linear);
        }
        return *(_base_ptr + linear);
    }

private:
    Telem *_base_ptr;
    const int64_t _dims[Tdims];
};
59 | |
// Input (data) tile transform used by the reference Winograd kernels
// (referred to at the call sites as v <- B_t * d * B).
//
// Iw: 6x6 output, the transformed tile.
// I:  6x6 input tile.
// The same 1-D six-point transform is applied first down every column of I
// and then along every row of the intermediate result.
void trans_I_4x4_3x3(float Iw[6][6], float I[6][6]) {
    // 1-D six-point transform shared by both passes.
    const auto transform_1d = [](const float(&in)[6], float(&out)[6]) {
        const float even_a = in[2] * -2.25f + in[4];
        const float odd_a = in[1] * -2.25f + in[3];
        const float even_b = in[2] * -0.390625f + in[4];
        const float odd_b = in[1] * -0.390625f + in[3];
        const float head = in[0] * 0.87890625f + in[4];
        const float tail = in[1] * 0.87890625f + in[5];

        out[0] = in[2] * -2.640625f + head;
        out[1] = odd_a * 0.625f + even_a;
        out[2] = odd_a * -0.625f + even_a;
        out[3] = odd_b * 1.5f + even_b;
        out[4] = odd_b * -1.5f + even_b;
        out[5] = in[3] * -2.640625f + tail;
    };

    float column_pass[6][6];

    // Pass 1: transform each column of I.
    for (int c = 0; c < 6; ++c) {
        const float column[6]
                = {I[0][c], I[1][c], I[2][c], I[3][c], I[4][c], I[5][c]};
        float transformed[6];
        transform_1d(column, transformed);
        for (int r = 0; r < 6; ++r) {
            column_pass[r][c] = transformed[r];
        }
    }

    // Pass 2: transform each row of the intermediate result.
    for (int r = 0; r < 6; ++r) {
        transform_1d(column_pass[r], Iw[r]);
    }
}
101 | |
// Weights transform used by the reference Winograd kernels (referred to at
// the call sites as u <- G * g * G_t).
//
// Fw_: 6x6 output, the transformed filter.
// F:   3x3 input filter.
// A 1-D transform mapping three taps to six values is applied first down
// every column of F and then along every row of the intermediate result.
void trans_W_4x4_3x3(float Fw_[6][6], float F[3][3]) {
    // 1-D transform: three inputs -> six outputs.
    const auto transform_1d = [](const float(&in)[3], float(&out)[6]) {
        const float last = 0.26890756302521f * in[2];
        const float neg_mix = -last - 0.688403361344538f * in[0];
        const float pos_mix = last + 0.119514472455649f * in[0];

        out[0] = 1.13777777777778f * in[0];
        out[1] = neg_mix - 0.430252100840336f * in[1];
        out[2] = neg_mix + 0.430252100840336f * in[1];
        out[3] = pos_mix + 0.179271708683473f * in[1];
        out[4] = pos_mix - 0.179271708683473f * in[1];
        out[5] = in[2];
    };

    float column_pass[6][3];

    // Pass 1: expand each of the three columns from 3 to 6 values.
    for (int c = 0; c < 3; ++c) {
        const float column[3] = {F[0][c], F[1][c], F[2][c]};
        float expanded[6];
        transform_1d(column, expanded);
        for (int r = 0; r < 6; ++r) {
            column_pass[r][c] = expanded[r];
        }
    }

    // Pass 2: expand each of the six rows from 3 to 6 values.
    for (int r = 0; r < 6; ++r) {
        transform_1d(column_pass[r], Fw_[r]);
    }
}
138 | |
// Output tile transform used by the reference Winograd kernels (referred to
// at the call sites as Y = A_t * m * A).
//
// Mw: 6x6 input tile in the transformed domain.
// O:  4x4 output tile.
// A 1-D transform reducing six values to four is applied first down every
// column of Mw and then along every row of the intermediate result.
void trans_O_4x4_3x3(float Mw[6][6], float O[4][4]) {
    // 1-D transform: six inputs -> four outputs.
    const auto transform_1d = [](const float(&in)[6], float(&out)[4]) {
        const float sum_a = in[1] + in[2];
        const float sum_b = in[3] + in[4];
        const float diff_a = in[1] - in[2];
        const float diff_b = in[3] - in[4];

        out[0] = sum_a + sum_b + in[0];
        out[1] = diff_a * 0.625f + diff_b * 1.5f;
        out[2] = sum_a * 0.390625f + sum_b * 2.25f;
        out[3] = diff_a * 0.244140625f + diff_b * 3.375f + in[5];
    };

    float column_pass[4][6];

    // Pass 1: reduce each of the six columns from 6 to 4 values.
    for (int c = 0; c < 6; ++c) {
        const float column[6]
                = {Mw[0][c], Mw[1][c], Mw[2][c], Mw[3][c], Mw[4][c], Mw[5][c]};
        float reduced[4];
        transform_1d(column, reduced);
        for (int r = 0; r < 4; ++r) {
            column_pass[r][c] = reduced[r];
        }
    }

    // Pass 2: reduce each of the four rows from 6 to 4 values.
    for (int r = 0; r < 4; ++r) {
        transform_1d(column_pass[r], O[r]);
    }
}
170 | |
// Transform applied to diff_dst tiles by the backward-by-weights reference
// kernel.
//
// Fw: 6x6 output, the transformed tile.
// F:  input tile; only the top-left 4x4 elements (rows 0..3, columns 0..3)
//     are read.
// A 1-D transform mapping four values to six is applied first down each of
// the four columns and then along each of the six intermediate rows.
void trans_W_3x3_4x4_wu(float Fw[6][6], float F[4][6]) {
    // 1-D transform: four inputs -> six outputs.
    const auto transform_1d = [](const float(&in)[4], float(&out)[6]) {
        const float third = in[2] * 0.26890756302521f;
        const float neg_mix = in[0] * -0.688403361344538f - third;
        const float pos_mix = in[0] * 0.119514472455649f + third;
        const float odd_a
                = in[1] * 0.430252100840336f + in[3] * 0.168067226890756f;
        const float odd_b
                = in[1] * 0.179271708683473f + in[3] * 0.403361344537815f;

        out[0] = in[0] * 1.13777777777778f;
        out[1] = neg_mix - odd_a;
        out[2] = neg_mix + odd_a;
        out[3] = pos_mix + odd_b;
        out[4] = pos_mix - odd_b;
        out[5] = in[3];
    };

    float column_pass[6][4];

    // Pass 1: expand each of the four columns from 4 to 6 values.
    for (int c = 0; c < 4; ++c) {
        const float column[4] = {F[0][c], F[1][c], F[2][c], F[3][c]};
        float expanded[6];
        transform_1d(column, expanded);
        for (int r = 0; r < 6; ++r) {
            column_pass[r][c] = expanded[r];
        }
    }

    // Pass 2: expand each of the six rows from 4 to 6 values.
    for (int r = 0; r < 6; ++r) {
        transform_1d(column_pass[r], Fw[r]);
    }
}
209 | |
// Transform that brings an accumulated 6x6 tile back to a 3x3 block of
// diff weights in the backward-by-weights reference kernel.
//
// Mw: 6x6 input tile in the transformed domain.
// M:  3x3 output.
// A 1-D transform reducing six values to three is applied first down every
// column of Mw and then along every row of the intermediate result.
void trans_O_3x3_4x4_wu(float Mw[6][6], float M[3][3]) {
    // 1-D transform: six inputs -> three outputs.
    const auto transform_1d = [](const float(&in)[6], float(&out)[3]) {
        const float sum_a = in[1] + in[2];
        const float sum_b = in[3] + in[4];
        const float tail = sum_b * 2.25f + in[5];

        out[0] = in[0] + sum_a + sum_b;
        out[1] = 0.625f * (in[1] - in[2]) + 1.5f * (in[3] - in[4]);
        out[2] = sum_a * 0.390625f + tail;
    };

    float column_pass[3][6];

    // Pass 1: reduce each of the six columns from 6 to 3 values.
    for (int c = 0; c < 6; ++c) {
        const float column[6]
                = {Mw[0][c], Mw[1][c], Mw[2][c], Mw[3][c], Mw[4][c], Mw[5][c]};
        float reduced[3];
        transform_1d(column, reduced);
        for (int r = 0; r < 3; ++r) {
            column_pass[r][c] = reduced[r];
        }
    }

    // Pass 2: reduce each of the three rows from 6 to 3 values.
    for (int r = 0; r < 3; ++r) {
        transform_1d(column_pass[r], M[r]);
    }
}
240 | |
// Working buffers and tiling parameters shared by the reference Winograd
// kernels in this file.
//
// All members have default initializers so that a default-constructed
// object holds null buffers and can be handed to the release routine safely.
struct scratchpad_t {
    // Buffer sized alpha * alpha * oc * ic floats by init_scratchpad.
    float *_u_ptr = nullptr;
    // Buffer sized alpha * alpha * oc * mb * h_tiles * w_tiles floats.
    float *_m_ptr = nullptr;
    // Buffer sized alpha * alpha * ic * mb * h_tiles * w_tiles floats.
    float *_v_ptr = nullptr;

    // Number of tiles along height / width; filled in by init_scratchpad.
    int64_t h_tiles = 0;
    int64_t w_tiles = 0;

    // Side of a transformed tile.
    const int64_t alpha = 6;
    // Side of an output tile; also the stride between neighbouring tiles.
    const int64_t out_dim = 4;
};
252 | |
253 | int init_scratchpad(const prb_t *prb, scratchpad_t &sp) { |
254 | if (sp.out_dim != 4 || sp.alpha != 6) return FAIL; |
255 | |
256 | sp.h_tiles = prb->dir == FLAG_FWD ? div_up(prb->oh, sp.out_dim) |
257 | : div_up(prb->ih, sp.out_dim); |
258 | sp.w_tiles = prb->dir == FLAG_FWD ? div_up(prb->ow, sp.out_dim) |
259 | : div_up(prb->iw, sp.out_dim); |
260 | |
261 | sp._u_ptr = (float *)zmalloc( |
262 | sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->ic, 64); |
263 | sp._v_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * prb->ic |
264 | * prb->mb * sp.h_tiles * sp.w_tiles, |
265 | 64); |
266 | sp._m_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * prb->oc |
267 | * prb->mb * sp.h_tiles * sp.w_tiles, |
268 | 64); |
269 | |
270 | if (sp._u_ptr == nullptr || sp._v_ptr == nullptr || sp._m_ptr == nullptr) |
271 | return dnnl_out_of_memory; |
272 | |
273 | array_set((char *)sp._u_ptr, |
274 | sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->ic); |
275 | array_set((char *)sp._v_ptr, |
276 | sizeof(float) * sp.alpha * sp.alpha * prb->ic * prb->mb * sp.h_tiles |
277 | * sp.w_tiles); |
278 | array_set((char *)sp._m_ptr, |
279 | sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->mb * sp.h_tiles |
280 | * sp.w_tiles); |
281 | |
282 | return OK; |
283 | } |
284 | |
285 | void free_scratchpad(scratchpad_t *sp) { |
286 | if (sp->_u_ptr != nullptr) zfree(sp->_u_ptr); |
287 | if (sp->_v_ptr != nullptr) zfree(sp->_v_ptr); |
288 | if (sp->_m_ptr != nullptr) zfree(sp->_m_ptr); |
289 | } |
290 | |
291 | void compute_wino_ref_fwd(const prb_t *prb, const args_t &args) { |
292 | const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC); |
293 | const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS); |
294 | const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS); |
295 | const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST); |
296 | scratchpad_t sp {}; |
297 | SAFE_V(init_scratchpad(prb, sp)); |
298 | |
299 | array_offset_calculator_t<float, 4> U( |
300 | sp._u_ptr, sp.alpha, sp.alpha, prb->oc, prb->ic); |
301 | array_offset_calculator_t<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha, |
302 | prb->ic, prb->mb, sp.h_tiles, sp.w_tiles); |
303 | array_offset_calculator_t<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha, |
304 | prb->oc, prb->mb, sp.h_tiles, sp.w_tiles); |
305 | |
306 | SAFE_V(prb->kh == 3 ? OK : FAIL); |
307 | SAFE_V(prb->kw == 3 ? OK : FAIL); |
308 | |
309 | bool with_bias = prb->dir & FLAG_BIA; |
310 | const int64_t t_pad = prb->ph; |
311 | const int64_t l_pad = prb->pw; |
312 | const int64_t wp_max = prb->iw + l_pad; |
313 | const int64_t hp_max = prb->ih + t_pad; |
314 | const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles; |
315 | |
316 | benchdnn_parallel_nd(prb->mb, prb->ic, sp.h_tiles, sp.w_tiles, |
317 | [&](int64_t img, int64_t c, int64_t hfm, int64_t wfm) { |
318 | float I[6][6] = {}; |
319 | float _v[6][6] = {}; |
320 | /* src_transform v <- B_t * d * B */ |
321 | for (int64_t j = 0; j < sp.alpha; j++) { |
322 | int64_t ydim = hfm * sp.out_dim + j; |
323 | if ((t_pad <= ydim) && (ydim < hp_max)) { |
324 | for (int64_t k = 0; k < sp.alpha; k++) { |
325 | int64_t xdim = wfm * sp.out_dim + k; |
326 | if ((l_pad <= xdim) && (xdim < wp_max)) { |
327 | size_t src_off = src_off_f(prb, img, 0, c, 0, |
328 | ydim - t_pad, xdim - l_pad); |
329 | I[j][k] = ((float *)src_m)[src_off]; |
330 | } |
331 | } |
332 | } |
333 | } |
334 | trans_I_4x4_3x3(_v, I); |
335 | |
336 | /* scatter v:V */ |
337 | for (int64_t j = 0; j < sp.alpha; j++) { |
338 | for (int64_t k = 0; k < sp.alpha; k++) { |
339 | V(j, k, c, img, hfm, wfm) = _v[j][k]; |
340 | } |
341 | } |
342 | }); |
343 | |
344 | benchdnn_parallel_nd(prb->oc, prb->ic, [&](int64_t oc, int64_t ic) { |
345 | float F[3][3] = {}; |
346 | float _u[6][6] = {}; |
347 | /* wei_transform u <- G * g * G_t */ |
348 | for_(int64_t j = 0; j < prb->kh; j++) |
349 | for (int64_t i = 0; i < prb->kw; i++) { |
350 | size_t wei_off = wei_off_f(prb, 0, oc, ic, 0, j, i); |
351 | F[j][i] = ((float *)wei_m)[wei_off]; |
352 | } |
353 | trans_W_4x4_3x3(_u, F); |
354 | |
355 | /* scatter u:U */ |
356 | for_(int64_t j = 0; j < sp.alpha; j++) |
357 | for (int64_t k = 0; k < sp.alpha; k++) { |
358 | U(j, k, oc, ic) = _u[j][k]; |
359 | } |
360 | }); |
361 | |
362 | benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) { |
363 | /* M = U * V */ |
364 | gemm("C" , "N" , "N" , prb->oc, p_dim, prb->ic, 1.0, |
365 | (float *)&(U(j, k, 0, 0)), prb->ic, |
366 | (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0, |
367 | (float *)&(M(j, k, 0, 0, 0, 0)), p_dim); |
368 | }); |
369 | |
370 | auto v_po_masks = prb->attr.post_ops.get_po_masks(); |
371 | benchdnn_parallel_nd(prb->oc, prb->mb, sp.h_tiles, sp.w_tiles, |
372 | [&](int64_t oc, int64_t img, int64_t hfm, int64_t wfm) { |
373 | float O[4][4] = {}; |
374 | float _m[6][6] = {}; |
375 | /* Y = A_t *m * A */ |
376 | for_(int64_t j = 0; j < sp.alpha; j++) |
377 | for (int64_t k = 0; k < sp.alpha; k++) { |
378 | _m[j][k] = M(j, k, oc, img, hfm, wfm); |
379 | } |
380 | trans_O_4x4_3x3(_m, O); |
381 | |
382 | for (int64_t j = 0; j < sp.out_dim; j++) { |
383 | int64_t ydim = hfm * sp.out_dim + j; |
384 | if (ydim >= prb->oh) continue; |
385 | |
386 | for (int64_t k = 0; k < sp.out_dim; k++) { |
387 | float conv_res = O[j][k]; |
388 | int64_t xdim = wfm * sp.out_dim + k; |
389 | if (xdim >= prb->ow) continue; |
390 | |
391 | const size_t dst_off |
392 | = dst_off_f(prb, img, 0, oc, 0, ydim, xdim); |
393 | float &dst = ((float *)dst_m)[dst_off]; |
394 | |
395 | const size_t bia_off = bia_off_f(prb, 0, oc); |
396 | conv_res += with_bias ? ((float *)bia_m)[bia_off] : 0.f; |
397 | |
398 | const auto v_po_vals = prepare_po_vals( |
399 | dst_m, args, v_po_masks, dst_off); |
400 | |
401 | maybe_post_ops(prb->attr, conv_res, dst, v_po_vals); |
402 | |
403 | dst = conv_res; |
404 | } |
405 | } |
406 | }); |
407 | free_scratchpad(&sp); |
408 | } |
409 | |
410 | void compute_wino_ref_bwd_d(const prb_t *prb, const args_t &args) { |
411 | const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC); |
412 | const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS); |
413 | const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS); |
414 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
415 | scratchpad_t sp {}; |
416 | SAFE_V(init_scratchpad(prb, sp)); |
417 | |
418 | array_offset_calculator_t<float, 4> U( |
419 | sp._u_ptr, sp.alpha, sp.alpha, prb->ic, prb->oc); |
420 | array_offset_calculator_t<float, 6> V(sp._m_ptr, sp.alpha, sp.alpha, |
421 | prb->oc, prb->mb, sp.h_tiles, sp.w_tiles); |
422 | array_offset_calculator_t<float, 6> M(sp._v_ptr, sp.alpha, sp.alpha, |
423 | prb->ic, prb->mb, sp.h_tiles, sp.w_tiles); |
424 | |
425 | SAFE_V(prb->kh == 3 ? OK : FAIL); |
426 | SAFE_V(prb->kw == 3 ? OK : FAIL); |
427 | |
428 | const int64_t r_pad = MAX2(0, prb->ow - 1 + prb->kw - prb->iw - prb->pw); |
429 | const int64_t l_pad = prb->iw + r_pad - prb->ow; |
430 | const int64_t t_pad = prb->ih + prb->ph - prb->oh; |
431 | const int64_t wp_max = prb->ow + l_pad; |
432 | const int64_t hp_max = prb->oh + t_pad; |
433 | const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles; |
434 | |
435 | bool with_bias = prb->dir & FLAG_BIA; |
436 | |
437 | benchdnn_parallel_nd(prb->mb, prb->oc, sp.h_tiles, sp.w_tiles, |
438 | [&](int64_t img, int64_t c, int64_t hfm, int64_t wfm) { |
439 | float I[6][6] = {}; |
440 | float _v[6][6] = {}; |
441 | /* diff_src transform v <- B_t * d * B */ |
442 | for (int64_t j = 0; j < sp.alpha; j++) { |
443 | int64_t ydim = hfm * sp.out_dim + j; |
444 | if ((t_pad <= ydim) && (ydim < hp_max)) { |
445 | for (int64_t k = 0; k < sp.alpha; k++) { |
446 | int64_t xdim = wfm * sp.out_dim + k; |
447 | if ((l_pad <= xdim) && (xdim < wp_max)) { |
448 | size_t dst_off = dst_off_f(prb, img, 0, c, 0, |
449 | ydim - t_pad, xdim - l_pad); |
450 | I[j][k] = ((float *)diff_dst_m)[dst_off]; |
451 | } |
452 | } |
453 | } |
454 | trans_I_4x4_3x3(_v, I); |
455 | |
456 | /* scatter v:V */ |
457 | for_(int64_t j = 0; j < sp.alpha; j++) |
458 | for (int64_t k = 0; k < sp.alpha; k++) { |
459 | V(j, k, c, img, hfm, wfm) = _v[j][k]; |
460 | } |
461 | } |
462 | }); |
463 | |
464 | benchdnn_parallel_nd(prb->ic, prb->oc, [&](int64_t ic, int64_t oc) { |
465 | float F[3][3] = {}; |
466 | float _u[6][6] = {}; |
467 | /* wei_transform u <- G * g * G_t */ |
468 | for_(int64_t j = 0; j < prb->kh; j++) |
469 | for (int64_t i = 0; i < prb->kw; i++) { |
470 | size_t wei_off = wei_off_f( |
471 | prb, 0, oc, ic, 0, prb->kh - j - 1, prb->kw - i - 1); |
472 | F[j][i] = ((float *)wei_m)[wei_off]; |
473 | } |
474 | trans_W_4x4_3x3(_u, F); |
475 | |
476 | /* scatter u:U */ |
477 | for_(int64_t j = 0; j < sp.alpha; j++) |
478 | for (int64_t k = 0; k < sp.alpha; k++) { |
479 | U(j, k, ic, oc) = _u[j][k]; |
480 | } |
481 | }); |
482 | |
483 | benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) { |
484 | /* M = U * V */ |
485 | gemm("C" , "N" , "N" , prb->ic, p_dim, prb->oc, 1.0, |
486 | (float *)&(U(j, k, 0, 0)), prb->oc, |
487 | (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0, |
488 | (float *)&(M(j, k, 0, 0, 0, 0)), p_dim); |
489 | }); |
490 | |
491 | benchdnn_parallel_nd(prb->ic, prb->mb, sp.h_tiles, sp.w_tiles, |
492 | [&](int64_t c, int64_t img, int64_t hfm, int64_t wfm) { |
493 | float O[4][4] = {}; |
494 | float _m[6][6] = {}; |
495 | /* diff_dst: Y = A_t *m * A */ |
496 | for_(int64_t j = 0; j < sp.alpha; j++) |
497 | for (int64_t k = 0; k < sp.alpha; k++) { |
498 | _m[j][k] = M(j, k, c, img, hfm, wfm); |
499 | } |
500 | trans_O_4x4_3x3(_m, O); |
501 | |
502 | float bia = with_bias ? ((float *)bia_m)[c] : 0.f; |
503 | |
504 | for (int64_t j = 0; j < sp.out_dim; j++) { |
505 | int64_t ydim = hfm * sp.out_dim + j; |
506 | if (ydim < prb->ih) { |
507 | for (int64_t k = 0; k < sp.out_dim; k++) { |
508 | int64_t xdim = wfm * sp.out_dim + k; |
509 | if (xdim < prb->iw) { |
510 | size_t src_off = src_off_f( |
511 | prb, img, 0, c, 0, ydim, xdim); |
512 | ((float *)diff_src_m)[src_off] = O[j][k] + bia; |
513 | } |
514 | } |
515 | } |
516 | } |
517 | }); |
518 | |
519 | free_scratchpad(&sp); |
520 | } |
521 | |
522 | void compute_wino_ref_bwd_w(const prb_t *prb, const args_t &args) { |
523 | const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC); |
524 | const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS); |
525 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
526 | scratchpad_t sp {}; |
527 | SAFE_V(init_scratchpad(prb, sp)); |
528 | |
529 | array_offset_calculator_t<float, 4> U( |
530 | sp._u_ptr, sp.alpha, sp.alpha, prb->oc, prb->ic); |
531 | array_offset_calculator_t<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha, |
532 | prb->mb, sp.h_tiles, sp.w_tiles, prb->ic); |
533 | array_offset_calculator_t<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha, |
534 | prb->oc, prb->mb, sp.h_tiles, sp.w_tiles); |
535 | |
536 | SAFE_V(prb->kh == 3 ? OK : FAIL); |
537 | SAFE_V(prb->kw == 3 ? OK : FAIL); |
538 | |
539 | const int64_t t_pad = prb->ph; |
540 | const int64_t l_pad = prb->pw; |
541 | const int64_t wp_max = prb->iw + l_pad; |
542 | const int64_t hp_max = prb->ih + t_pad; |
543 | const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles; |
544 | |
545 | benchdnn_parallel_nd(prb->mb, sp.h_tiles, sp.w_tiles, prb->ic, |
546 | [&](int64_t img, int64_t hfm, int64_t wfm, int64_t ic) { |
547 | float I[6][6] = {}; |
548 | float _v[6][6] = {}; |
549 | /* src transform v <- B_t * d * B */ |
550 | for (int64_t j = 0; j < sp.alpha; j++) { |
551 | int64_t ydim = hfm * sp.out_dim + j; |
552 | if ((t_pad <= ydim) && (ydim < hp_max)) { |
553 | for (int64_t k = 0; k < sp.alpha; k++) { |
554 | int64_t xdim = wfm * sp.out_dim + k; |
555 | if ((l_pad <= xdim) && (xdim < wp_max)) { |
556 | size_t src_off = src_off_f(prb, img, 0, ic, 0, |
557 | ydim - t_pad, xdim - l_pad); |
558 | I[j][k] = ((float *)src_m)[src_off]; |
559 | } |
560 | } |
561 | } |
562 | } |
563 | trans_I_4x4_3x3(_v, I); |
564 | |
565 | /* scatter v:V */ |
566 | for_(int64_t j = 0; j < sp.alpha; j++) |
567 | for (int64_t k = 0; k < sp.alpha; k++) { |
568 | V(j, k, img, hfm, wfm, ic) = _v[j][k]; |
569 | } |
570 | }); |
571 | |
572 | benchdnn_parallel_nd(prb->oc, prb->mb, sp.h_tiles, sp.w_tiles, |
573 | [&](int64_t oc, int64_t img, int64_t hfm, int64_t wfm) { |
574 | float O[6][6] = {}; |
575 | float _m[6][6] = {}; |
576 | /* diff_dst transform */ |
577 | for (int64_t j = 0; j < sp.alpha; j++) { |
578 | int64_t ydim = hfm * sp.out_dim + j; |
579 | if (ydim < prb->oh) { |
580 | for (int64_t k = 0; k < sp.alpha; k++) { |
581 | int64_t xdim = wfm * sp.out_dim + k; |
582 | if (xdim < prb->ow) { |
583 | size_t dst_off = dst_off_f( |
584 | prb, img, 0, oc, 0, ydim, xdim); |
585 | O[j][k] = ((float *)diff_dst_m)[dst_off]; |
586 | } |
587 | } |
588 | } |
589 | } |
590 | trans_W_3x3_4x4_wu(_m, O); |
591 | /* scatter v:V */ |
592 | for_(int64_t j = 0; j < sp.alpha; j++) |
593 | for (int64_t k = 0; k < sp.alpha; k++) { |
594 | M(j, k, oc, img, hfm, wfm) = _m[j][k]; |
595 | } |
596 | }); |
597 | |
598 | benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) { |
599 | /* GeMM U = M * V */ |
600 | gemm("C" , "N" , "N" , prb->oc, prb->ic, p_dim, 1.0, |
601 | (float *)&(M(j, k, 0, 0, 0, 0)), p_dim, |
602 | (float *)&(V(j, k, 0, 0, 0, 0)), prb->ic, 1.0, |
603 | (float *)&(U(j, k, 0, 0)), prb->ic); |
604 | }); |
605 | |
606 | benchdnn_parallel_nd(prb->oc, prb->ic, [&](int64_t oc, int64_t ic) { |
607 | float F[6][6] = {}; |
608 | float _u[3][3] = {}; |
609 | for_(int64_t j = 0; j < sp.alpha; j++) |
610 | for (int64_t k = 0; k < sp.alpha; k++) { |
611 | F[j][k] = U(j, k, oc, ic); |
612 | } |
613 | trans_O_3x3_4x4_wu(F, _u); |
614 | |
615 | /* scatter u:U */ |
616 | for_(int64_t kh = 0; kh < prb->kh; kh++) |
617 | for (int64_t kw = 0; kw < prb->kw; kw++) { |
618 | size_t wei_off = wei_off_f(prb, 0, oc, ic, 0, kh, kw); |
619 | ((float *)diff_wei_m)[wei_off] = _u[kh][kw]; |
620 | } |
621 | }); |
622 | |
623 | free_scratchpad(&sp); |
624 | |
625 | if (prb->dir & FLAG_BIA) compute_ref_bwd_bias(prb, args); |
626 | } |
627 | |
628 | } // namespace deconv |
629 | |