1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "conv/ref_conv.hpp" |
20 | |
21 | namespace conv { |
22 | |
// Non-owning view that maps a Tdims-dimensional index tuple onto a flat
// buffer of Telem laid out row-major (last index varies fastest).
// The extents of all dimensions are captured at construction; the extent of
// dimension 0 is stored but not needed to compute an offset.
template <typename Telem, size_t Tdims>
struct array_offset_calculator_t {
    static_assert(Tdims > 0, "array_offset_calculator_t: Tdims must be > 0");

    // `base` points at element (0, ..., 0). `Fargs` are the Tdims extents,
    // outermost first. The buffer is not owned and must outlive this view.
    template <typename... Targs>
    array_offset_calculator_t(Telem *base, Targs... Fargs)
        : _base_ptr(base), _dims {Fargs...} {
        static_assert(sizeof...(Targs) == Tdims,
                "array_offset_calculator_t: expected one extent per "
                "dimension");
    }

    // Returns a reference to the element at the given indices (outermost
    // first). Indices are not bounds-checked; only their count is checked,
    // at compile time, so a wrong arity can no longer read past `_dims`.
    template <typename... Targs>
    inline Telem &operator()(Targs... Fargs) {
        static_assert(sizeof...(Targs) == Tdims,
                "array_offset_calculator_t: expected one index per "
                "dimension");
        const size_t idx[Tdims] = {static_cast<size_t>(Fargs)...};
        // Horner-style linearization: ((i0 * d1 + i1) * d2 + i2) * ... .
        size_t off = idx[0];
        for (size_t d = 1; d < Tdims; ++d)
            off = idx[d] + static_cast<size_t>(_dims[d]) * off;
        return *(_base_ptr + off);
    }

private:
    Telem *_base_ptr;
    const int64_t _dims[Tdims];
};
56 | |
// Winograd F(4x4, 3x3) input-tile transform of a 6x6 tile, Iw = B^T * I * B.
// The same 1-D transform is applied down every column of I (into a local
// temporary) and then along every row of that temporary (into Iw).
void trans_I_4x4_3x3(float Iw[6][6], float I[6][6]) {
    // Maps six input samples `x` to six transformed samples `y`.
    const auto bt_1d = [](const float x[6], float y[6]) {
        const float a = x[2] * -2.25f + x[4];
        const float b = x[1] * -2.25f + x[3];
        const float c = x[2] * -0.390625f + x[4];
        const float d = x[1] * -0.390625f + x[3];
        const float e = x[0] * 0.87890625f + x[4];
        const float f = x[1] * 0.87890625f + x[5];

        y[0] = x[2] * -2.640625f + e;
        y[1] = b * 0.625f + a;
        y[2] = b * -0.625f + a;
        y[3] = d * 1.5f + c;
        y[4] = d * -1.5f + c;
        y[5] = x[3] * -2.640625f + f;
    };

    float T[6][6];

    // Column pass: I -> T.
    for (int col = 0; col < 6; ++col) {
        float x[6];
        float y[6];
        for (int r = 0; r < 6; ++r)
            x[r] = I[r][col];
        bt_1d(x, y);
        for (int r = 0; r < 6; ++r)
            T[r][col] = y[r];
    }

    // Row pass: T -> Iw.
    for (int row = 0; row < 6; ++row)
        bt_1d(T[row], Iw[row]);
}
98 | |
// Winograd F(4x4, 3x3) kernel transform of a 3x3 kernel into a 6x6 tile,
// Fw_ = G * F * G^T. The same 1-D transform (3 samples -> 6) is applied down
// every column of F (into a 6x3 temporary) and then along each of its rows.
void trans_W_4x4_3x3(float Fw_[6][6], float F[3][3]) {
    // Maps three kernel taps `x` to six transformed values `y`.
    const auto g_1d = [](const float x[3], float y[6]) {
        const float s = 0.26890756302521f * x[2];
        const float p = -s - 0.688403361344538f * x[0];
        const float q = s + 0.119514472455649f * x[0];

        y[0] = 1.13777777777778f * x[0];
        y[1] = p - 0.430252100840336f * x[1];
        y[2] = p + 0.430252100840336f * x[1];
        y[3] = q + 0.179271708683473f * x[1];
        y[4] = q - 0.179271708683473f * x[1];
        y[5] = x[2];
    };

    float T[6][3];

    // Column pass: F (3x3) -> T (6x3).
    for (int col = 0; col < 3; ++col) {
        float x[3];
        float y[6];
        for (int r = 0; r < 3; ++r)
            x[r] = F[r][col];
        g_1d(x, y);
        for (int r = 0; r < 6; ++r)
            T[r][col] = y[r];
    }

    // Row pass: T (6x3) -> Fw_ (6x6).
    for (int row = 0; row < 6; ++row)
        g_1d(T[row], Fw_[row]);
}
135 | |
// Winograd F(4x4, 3x3) output transform of a 6x6 tile into a 4x4 tile,
// O = A^T * Mw * A. The same 1-D transform (6 samples -> 4) is applied down
// every column of Mw (into a 4x6 temporary) and then along each of its rows.
void trans_O_4x4_3x3(float Mw[6][6], float O[4][4]) {
    // Maps six Winograd-domain samples `x` to four spatial outputs `y`.
    const auto at_1d = [](const float x[6], float y[4]) {
        const float sum12 = x[1] + x[2];
        const float sum34 = x[3] + x[4];
        const float dif12 = x[1] - x[2];
        const float dif34 = x[3] - x[4];

        y[0] = sum12 + sum34 + x[0];
        y[1] = dif12 * 0.625f + dif34 * 1.5f;
        y[2] = sum12 * 0.390625f + sum34 * 2.25f;
        y[3] = dif12 * 0.244140625f + dif34 * 3.375f + x[5];
    };

    float T[4][6];

    // Column pass: Mw (6x6) -> T (4x6).
    for (int col = 0; col < 6; ++col) {
        float x[6];
        float y[4];
        for (int r = 0; r < 6; ++r)
            x[r] = Mw[r][col];
        at_1d(x, y);
        for (int r = 0; r < 4; ++r)
            T[r][col] = y[r];
    }

    // Row pass: T (4x6) -> O (4x4).
    for (int row = 0; row < 4; ++row)
        at_1d(T[row], O[row]);
}
167 | |
// Weight-update transform: maps a 4x4 tile (the leading 4x4 block of F; its
// columns 4 and 5 are never read) into a 6x6 Winograd-domain tile Fw.
// The same 1-D transform (4 samples -> 6) is applied down each of the first
// four columns of F (into a 6x4 temporary) and then along each of its rows.
void trans_W_3x3_4x4_wu(float Fw[6][6], float F[4][6]) {
    // Maps four samples `x` to six transformed values `y`.
    const auto g_1d = [](const float x[4], float y[6]) {
        const float s = x[2] * 0.26890756302521f;
        const float p = x[0] * -0.688403361344538f - s;
        const float q = x[0] * 0.119514472455649f + s;
        const float u = x[1] * 0.430252100840336f + x[3] * 0.168067226890756f;
        const float v = x[1] * 0.179271708683473f + x[3] * 0.403361344537815f;

        y[0] = x[0] * 1.13777777777778f;
        y[1] = p - u;
        y[2] = p + u;
        y[3] = q + v;
        y[4] = q - v;
        y[5] = x[3];
    };

    float T[6][4];

    // Column pass over the first four columns: F -> T (6x4).
    for (int col = 0; col < 4; ++col) {
        float x[4];
        float y[6];
        for (int r = 0; r < 4; ++r)
            x[r] = F[r][col];
        g_1d(x, y);
        for (int r = 0; r < 6; ++r)
            T[r][col] = y[r];
    }

    // Row pass: T (6x4) -> Fw (6x6).
    for (int row = 0; row < 6; ++row)
        g_1d(T[row], Fw[row]);
}
206 | |
// Weight-update output transform: maps a 6x6 Winograd-domain tile Mw to a
// 3x3 kernel-shaped result M. The same 1-D transform (6 samples -> 3) is
// applied down every column of Mw (into a 3x6 temporary) and then along each
// of its rows.
void trans_O_3x3_4x4_wu(float Mw[6][6], float M[3][3]) {
    // Maps six Winograd-domain samples `x` to three outputs `y`.
    const auto at_1d = [](const float x[6], float y[3]) {
        const float s12 = x[1] + x[2];
        const float s34 = x[3] + x[4];
        const float tail = s34 * 2.25f + x[5];

        y[0] = x[0] + s12 + s34;
        y[1] = 0.625f * (x[1] - x[2]) + 1.5f * (x[3] - x[4]);
        y[2] = s12 * 0.390625f + tail;
    };

    float T[3][6];

    // Column pass: Mw (6x6) -> T (3x6).
    for (int col = 0; col < 6; ++col) {
        float x[6];
        float y[3];
        for (int r = 0; r < 6; ++r)
            x[r] = Mw[r][col];
        at_1d(x, y);
        for (int r = 0; r < 3; ++r)
            T[r][col] = y[r];
    }

    // Row pass: T (3x6) -> M (3x3).
    for (int row = 0; row < 3; ++row)
        at_1d(T[row], M[row]);
}
237 | |
// Scratch buffers for the Winograd F(4x4, 3x3) reference implementation.
// The buffers are allocated by init_scratchpad() and released by
// free_scratchpad(). Every member has a default initializer, so even a
// default-initialized (no braces) instance holds null pointers and is safe
// to pass to free_scratchpad().
struct scratchpad_t {
    // Sized for alpha * alpha * oc * ic floats.
    float *_u_ptr = nullptr;
    // Sized for alpha * alpha * oc * mb * h_tiles * w_tiles floats.
    float *_m_ptr = nullptr;
    // Sized for alpha * alpha * ic * mb * h_tiles * w_tiles floats.
    float *_v_ptr = nullptr;

    // Number of out_dim-sized tiles along the spatial height / width.
    int64_t h_tiles = 0;
    int64_t w_tiles = 0;

    // Winograd tile edge (out_dim + 3 - 1 for a 3x3 kernel).
    const int64_t alpha = 6;
    // Output tile edge produced per Winograd tile.
    const int64_t out_dim = 4;
};
249 | |
250 | int init_scratchpad(const prb_t *prb, scratchpad_t &sp) { |
251 | if (sp.out_dim != 4 || sp.alpha != 6) return FAIL; |
252 | |
253 | sp.h_tiles = prb->dir == FLAG_FWD ? div_up(prb->oh, sp.out_dim) |
254 | : div_up(prb->ih, sp.out_dim); |
255 | sp.w_tiles = prb->dir == FLAG_FWD ? div_up(prb->ow, sp.out_dim) |
256 | : div_up(prb->iw, sp.out_dim); |
257 | |
258 | sp._u_ptr = (float *)zmalloc( |
259 | sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->ic, 64); |
260 | sp._v_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * prb->ic |
261 | * prb->mb * sp.h_tiles * sp.w_tiles, |
262 | 64); |
263 | sp._m_ptr = (float *)zmalloc(sizeof(float) * sp.alpha * sp.alpha * prb->oc |
264 | * prb->mb * sp.h_tiles * sp.w_tiles, |
265 | 64); |
266 | |
267 | if (sp._u_ptr == nullptr || sp._v_ptr == nullptr || sp._m_ptr == nullptr) |
268 | return dnnl_out_of_memory; |
269 | |
270 | array_set((char *)sp._u_ptr, |
271 | sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->ic); |
272 | array_set((char *)sp._v_ptr, |
273 | sizeof(float) * sp.alpha * sp.alpha * prb->ic * prb->mb * sp.h_tiles |
274 | * sp.w_tiles); |
275 | array_set((char *)sp._m_ptr, |
276 | sizeof(float) * sp.alpha * sp.alpha * prb->oc * prb->mb * sp.h_tiles |
277 | * sp.w_tiles); |
278 | |
279 | return OK; |
280 | } |
281 | |
282 | void free_scratchpad(scratchpad_t *sp) { |
283 | if (sp->_u_ptr != nullptr) zfree(sp->_u_ptr); |
284 | if (sp->_v_ptr != nullptr) zfree(sp->_v_ptr); |
285 | if (sp->_m_ptr != nullptr) zfree(sp->_m_ptr); |
286 | } |
287 | |
288 | void compute_wino_ref_fwd(const prb_t *prb, const args_t &args) { |
289 | const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC); |
290 | const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS); |
291 | const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS); |
292 | const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST); |
293 | scratchpad_t sp {}; |
294 | SAFE_V(init_scratchpad(prb, sp)); |
295 | |
296 | array_offset_calculator_t<float, 4> U( |
297 | sp._u_ptr, sp.alpha, sp.alpha, prb->oc, prb->ic); |
298 | array_offset_calculator_t<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha, |
299 | prb->ic, prb->mb, sp.h_tiles, sp.w_tiles); |
300 | array_offset_calculator_t<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha, |
301 | prb->oc, prb->mb, sp.h_tiles, sp.w_tiles); |
302 | |
303 | SAFE_V(prb->kh == 3 ? OK : FAIL); |
304 | SAFE_V(prb->kw == 3 ? OK : FAIL); |
305 | |
306 | bool with_bias = prb->dir & FLAG_BIA; |
307 | const int64_t t_pad = prb->ph; |
308 | const int64_t l_pad = prb->pw; |
309 | const int64_t wp_max = prb->iw + l_pad; |
310 | const int64_t hp_max = prb->ih + t_pad; |
311 | const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles; |
312 | |
313 | benchdnn_parallel_nd(prb->mb, prb->ic, sp.h_tiles, sp.w_tiles, |
314 | [&](int64_t img, int64_t c, int64_t hfm, int64_t wfm) { |
315 | float I[6][6] = {}; |
316 | float _v[6][6] = {}; |
317 | /* src_transform v <- B_t * d * B */ |
318 | for (int64_t j = 0; j < sp.alpha; j++) { |
319 | int64_t ydim = hfm * sp.out_dim + j; |
320 | if ((t_pad <= ydim) && (ydim < hp_max)) { |
321 | for (int64_t k = 0; k < sp.alpha; k++) { |
322 | int64_t xdim = wfm * sp.out_dim + k; |
323 | if ((l_pad <= xdim) && (xdim < wp_max)) { |
324 | size_t src_off = src_off_f(prb, img, 0, c, 0, |
325 | ydim - t_pad, xdim - l_pad); |
326 | I[j][k] = ((float *)src_m)[src_off]; |
327 | } |
328 | } |
329 | } |
330 | } |
331 | trans_I_4x4_3x3(_v, I); |
332 | |
333 | /* scatter v:V */ |
334 | for (int64_t j = 0; j < sp.alpha; j++) { |
335 | for (int64_t k = 0; k < sp.alpha; k++) { |
336 | V(j, k, c, img, hfm, wfm) = _v[j][k]; |
337 | } |
338 | } |
339 | }); |
340 | |
341 | benchdnn_parallel_nd(prb->oc, prb->ic, [&](int64_t oc, int64_t ic) { |
342 | float F[3][3] = {}; |
343 | float _u[6][6] = {}; |
344 | /* wei_transform u <- G * g * G_t */ |
345 | for_(int64_t j = 0; j < prb->kh; j++) |
346 | for (int64_t i = 0; i < prb->kw; i++) { |
347 | size_t wei_off = wei_off_f(prb, 0, oc, ic, 0, j, i); |
348 | F[j][i] = ((float *)wei_m)[wei_off]; |
349 | } |
350 | trans_W_4x4_3x3(_u, F); |
351 | |
352 | /* scatter u:U */ |
353 | for_(int64_t j = 0; j < sp.alpha; j++) |
354 | for (int64_t k = 0; k < sp.alpha; k++) { |
355 | U(j, k, oc, ic) = _u[j][k]; |
356 | } |
357 | }); |
358 | |
359 | benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) { |
360 | /* M = U * V */ |
361 | gemm("C" , "N" , "N" , prb->oc, p_dim, prb->ic, 1.0, |
362 | (float *)&(U(j, k, 0, 0)), prb->ic, |
363 | (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0, |
364 | (float *)&(M(j, k, 0, 0, 0, 0)), p_dim); |
365 | }); |
366 | |
367 | auto v_po_masks = prb->attr.post_ops.get_po_masks(); |
368 | benchdnn_parallel_nd(prb->oc, prb->mb, sp.h_tiles, sp.w_tiles, |
369 | [&](int64_t oc, int64_t img, int64_t hfm, int64_t wfm) { |
370 | float O[4][4] = {}; |
371 | float _m[6][6] = {}; |
372 | /* Y = A_t *m * A */ |
373 | for_(int64_t j = 0; j < sp.alpha; j++) |
374 | for (int64_t k = 0; k < sp.alpha; k++) { |
375 | _m[j][k] = M(j, k, oc, img, hfm, wfm); |
376 | } |
377 | trans_O_4x4_3x3(_m, O); |
378 | |
379 | for (int64_t j = 0; j < sp.out_dim; j++) { |
380 | int64_t ydim = hfm * sp.out_dim + j; |
381 | if (ydim >= prb->oh) continue; |
382 | |
383 | for (int64_t k = 0; k < sp.out_dim; k++) { |
384 | float conv_res = O[j][k]; |
385 | int64_t xdim = wfm * sp.out_dim + k; |
386 | if (xdim >= prb->ow) continue; |
387 | |
388 | const size_t dst_off |
389 | = dst_off_f(prb, img, 0, oc, 0, ydim, xdim); |
390 | float &dst = ((float *)dst_m)[dst_off]; |
391 | |
392 | const size_t bia_off = bia_off_f(prb, 0, oc); |
393 | conv_res += with_bias ? ((float *)bia_m)[bia_off] : 0.f; |
394 | |
395 | const auto v_po_vals = prepare_po_vals( |
396 | dst_m, args, v_po_masks, dst_off); |
397 | |
398 | maybe_post_ops(prb->attr, conv_res, dst, v_po_vals); |
399 | |
400 | dst = conv_res; |
401 | } |
402 | } |
403 | }); |
404 | free_scratchpad(&sp); |
405 | } |
406 | |
407 | void compute_wino_ref_bwd_d(const prb_t *prb, const args_t &args) { |
408 | const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC); |
409 | const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS); |
410 | const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS); |
411 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
412 | scratchpad_t sp {}; |
413 | SAFE_V(init_scratchpad(prb, sp)); |
414 | |
415 | array_offset_calculator_t<float, 4> U( |
416 | sp._u_ptr, sp.alpha, sp.alpha, prb->ic, prb->oc); |
417 | array_offset_calculator_t<float, 6> V(sp._m_ptr, sp.alpha, sp.alpha, |
418 | prb->oc, prb->mb, sp.h_tiles, sp.w_tiles); |
419 | array_offset_calculator_t<float, 6> M(sp._v_ptr, sp.alpha, sp.alpha, |
420 | prb->ic, prb->mb, sp.h_tiles, sp.w_tiles); |
421 | |
422 | SAFE_V(prb->kh == 3 ? OK : FAIL); |
423 | SAFE_V(prb->kw == 3 ? OK : FAIL); |
424 | |
425 | const int64_t r_pad = MAX2(0, prb->ow - 1 + prb->kw - prb->iw - prb->pw); |
426 | const int64_t l_pad = prb->iw + r_pad - prb->ow; |
427 | const int64_t t_pad = prb->ih + prb->ph - prb->oh; |
428 | const int64_t wp_max = prb->ow + l_pad; |
429 | const int64_t hp_max = prb->oh + t_pad; |
430 | const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles; |
431 | |
432 | bool with_bias = prb->dir & FLAG_BIA; |
433 | |
434 | benchdnn_parallel_nd(prb->mb, prb->oc, sp.h_tiles, sp.w_tiles, |
435 | [&](int64_t img, int64_t c, int64_t hfm, int64_t wfm) { |
436 | float I[6][6] = {}; |
437 | float _v[6][6] = {}; |
438 | /* diff_src transform v <- B_t * d * B */ |
439 | for (int64_t j = 0; j < sp.alpha; j++) { |
440 | int64_t ydim = hfm * sp.out_dim + j; |
441 | if ((t_pad <= ydim) && (ydim < hp_max)) { |
442 | for (int64_t k = 0; k < sp.alpha; k++) { |
443 | int64_t xdim = wfm * sp.out_dim + k; |
444 | if ((l_pad <= xdim) && (xdim < wp_max)) { |
445 | size_t dst_off = dst_off_f(prb, img, 0, c, 0, |
446 | ydim - t_pad, xdim - l_pad); |
447 | I[j][k] = ((float *)diff_dst_m)[dst_off]; |
448 | } |
449 | } |
450 | } |
451 | trans_I_4x4_3x3(_v, I); |
452 | |
453 | /* scatter v:V */ |
454 | for_(int64_t j = 0; j < sp.alpha; j++) |
455 | for (int64_t k = 0; k < sp.alpha; k++) { |
456 | V(j, k, c, img, hfm, wfm) = _v[j][k]; |
457 | } |
458 | } |
459 | }); |
460 | |
461 | benchdnn_parallel_nd(prb->ic, prb->oc, [&](int64_t ic, int64_t oc) { |
462 | float F[3][3] = {}; |
463 | float _u[6][6] = {}; |
464 | /* wei_transform u <- G * g * G_t */ |
465 | for_(int64_t j = 0; j < prb->kh; j++) |
466 | for (int64_t i = 0; i < prb->kw; i++) { |
467 | size_t wei_off = wei_off_f( |
468 | prb, 0, oc, ic, 0, prb->kh - j - 1, prb->kw - i - 1); |
469 | F[j][i] = ((float *)wei_m)[wei_off]; |
470 | } |
471 | trans_W_4x4_3x3(_u, F); |
472 | |
473 | /* scatter u:U */ |
474 | for_(int64_t j = 0; j < sp.alpha; j++) |
475 | for (int64_t k = 0; k < sp.alpha; k++) { |
476 | U(j, k, ic, oc) = _u[j][k]; |
477 | } |
478 | }); |
479 | |
480 | benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) { |
481 | /* M = U * V */ |
482 | gemm("C" , "N" , "N" , prb->ic, p_dim, prb->oc, 1.0, |
483 | (float *)&(U(j, k, 0, 0)), prb->oc, |
484 | (float *)&(V(j, k, 0, 0, 0, 0)), p_dim, 1.0, |
485 | (float *)&(M(j, k, 0, 0, 0, 0)), p_dim); |
486 | }); |
487 | |
488 | benchdnn_parallel_nd(prb->ic, prb->mb, sp.h_tiles, sp.w_tiles, |
489 | [&](int64_t c, int64_t img, int64_t hfm, int64_t wfm) { |
490 | float O[4][4] = {}; |
491 | float _m[6][6] = {}; |
492 | /* diff_dst: Y = A_t *m * A */ |
493 | for_(int64_t j = 0; j < sp.alpha; j++) |
494 | for (int64_t k = 0; k < sp.alpha; k++) { |
495 | _m[j][k] = M(j, k, c, img, hfm, wfm); |
496 | } |
497 | trans_O_4x4_3x3(_m, O); |
498 | |
499 | float bia = with_bias ? ((float *)bia_m)[c] : 0.f; |
500 | |
501 | for (int64_t j = 0; j < sp.out_dim; j++) { |
502 | int64_t ydim = hfm * sp.out_dim + j; |
503 | if (ydim < prb->ih) { |
504 | for (int64_t k = 0; k < sp.out_dim; k++) { |
505 | int64_t xdim = wfm * sp.out_dim + k; |
506 | if (xdim < prb->iw) { |
507 | size_t src_off = src_off_f( |
508 | prb, img, 0, c, 0, ydim, xdim); |
509 | ((float *)diff_src_m)[src_off] = O[j][k] + bia; |
510 | } |
511 | } |
512 | } |
513 | } |
514 | }); |
515 | |
516 | free_scratchpad(&sp); |
517 | } |
518 | |
519 | void compute_wino_ref_bwd_w(const prb_t *prb, const args_t &args) { |
520 | const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC); |
521 | const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS); |
522 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
523 | scratchpad_t sp {}; |
524 | SAFE_V(init_scratchpad(prb, sp)); |
525 | |
526 | array_offset_calculator_t<float, 4> U( |
527 | sp._u_ptr, sp.alpha, sp.alpha, prb->oc, prb->ic); |
528 | array_offset_calculator_t<float, 6> V(sp._v_ptr, sp.alpha, sp.alpha, |
529 | prb->mb, sp.h_tiles, sp.w_tiles, prb->ic); |
530 | array_offset_calculator_t<float, 6> M(sp._m_ptr, sp.alpha, sp.alpha, |
531 | prb->oc, prb->mb, sp.h_tiles, sp.w_tiles); |
532 | |
533 | SAFE_V(prb->kh == 3 ? OK : FAIL); |
534 | SAFE_V(prb->kw == 3 ? OK : FAIL); |
535 | |
536 | const int64_t t_pad = prb->ph; |
537 | const int64_t l_pad = prb->pw; |
538 | const int64_t wp_max = prb->iw + l_pad; |
539 | const int64_t hp_max = prb->ih + t_pad; |
540 | const int64_t p_dim = prb->mb * sp.h_tiles * sp.w_tiles; |
541 | |
542 | benchdnn_parallel_nd(prb->mb, sp.h_tiles, sp.w_tiles, prb->ic, |
543 | [&](int64_t img, int64_t hfm, int64_t wfm, int64_t ic) { |
544 | float I[6][6] = {}; |
545 | float _v[6][6] = {}; |
546 | /* src transform v <- B_t * d * B */ |
547 | for (int64_t j = 0; j < sp.alpha; j++) { |
548 | int64_t ydim = hfm * sp.out_dim + j; |
549 | if ((t_pad <= ydim) && (ydim < hp_max)) { |
550 | for (int64_t k = 0; k < sp.alpha; k++) { |
551 | int64_t xdim = wfm * sp.out_dim + k; |
552 | if ((l_pad <= xdim) && (xdim < wp_max)) { |
553 | size_t src_off = src_off_f(prb, img, 0, ic, 0, |
554 | ydim - t_pad, xdim - l_pad); |
555 | I[j][k] = ((float *)src_m)[src_off]; |
556 | } |
557 | } |
558 | } |
559 | } |
560 | trans_I_4x4_3x3(_v, I); |
561 | |
562 | /* scatter v:V */ |
563 | for_(int64_t j = 0; j < sp.alpha; j++) |
564 | for (int64_t k = 0; k < sp.alpha; k++) { |
565 | V(j, k, img, hfm, wfm, ic) = _v[j][k]; |
566 | } |
567 | }); |
568 | |
569 | benchdnn_parallel_nd(prb->oc, prb->mb, sp.h_tiles, sp.w_tiles, |
570 | [&](int64_t oc, int64_t img, int64_t hfm, int64_t wfm) { |
571 | float O[6][6] = {}; |
572 | float _m[6][6] = {}; |
573 | /* diff_dst transform */ |
574 | for (int64_t j = 0; j < sp.alpha; j++) { |
575 | int64_t ydim = hfm * sp.out_dim + j; |
576 | if (ydim < prb->oh) { |
577 | for (int64_t k = 0; k < sp.alpha; k++) { |
578 | int64_t xdim = wfm * sp.out_dim + k; |
579 | if (xdim < prb->ow) { |
580 | size_t dst_off = dst_off_f( |
581 | prb, img, 0, oc, 0, ydim, xdim); |
582 | O[j][k] = ((float *)diff_dst_m)[dst_off]; |
583 | } |
584 | } |
585 | } |
586 | } |
587 | trans_W_3x3_4x4_wu(_m, O); |
588 | /* scatter v:V */ |
589 | for_(int64_t j = 0; j < sp.alpha; j++) |
590 | for (int64_t k = 0; k < sp.alpha; k++) { |
591 | M(j, k, oc, img, hfm, wfm) = _m[j][k]; |
592 | } |
593 | }); |
594 | |
595 | benchdnn_parallel_nd(sp.alpha, sp.alpha, [&](int64_t j, int64_t k) { |
596 | /* GeMM U = M * V */ |
597 | gemm("C" , "N" , "N" , prb->oc, prb->ic, p_dim, 1.0, |
598 | (float *)&(M(j, k, 0, 0, 0, 0)), p_dim, |
599 | (float *)&(V(j, k, 0, 0, 0, 0)), prb->ic, 1.0, |
600 | (float *)&(U(j, k, 0, 0)), prb->ic); |
601 | }); |
602 | |
603 | benchdnn_parallel_nd(prb->oc, prb->ic, [&](int64_t oc, int64_t ic) { |
604 | float F[6][6] = {}; |
605 | float _u[3][3] = {}; |
606 | for_(int64_t j = 0; j < sp.alpha; j++) |
607 | for (int64_t k = 0; k < sp.alpha; k++) { |
608 | F[j][k] = U(j, k, oc, ic); |
609 | } |
610 | trans_O_3x3_4x4_wu(F, _u); |
611 | |
612 | /* scatter u:U */ |
613 | for_(int64_t kh = 0; kh < prb->kh; kh++) |
614 | for (int64_t kw = 0; kw < prb->kw; kw++) { |
615 | size_t wei_off = wei_off_f(prb, 0, oc, ic, 0, kh, kw); |
616 | ((float *)diff_wei_m)[wei_off] = _u[kh][kw]; |
617 | } |
618 | }); |
619 | |
620 | free_scratchpad(&sp); |
621 | |
622 | if (prb->dir & FLAG_BIA) compute_ref_bwd_bias(prb, args); |
623 | } |
624 | |
625 | } // namespace conv |
626 | |