/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <utility>

#include "utils/parallel.hpp"

#include "deconv/ref_deconv.hpp"

namespace deconv {

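// Reference convolution forward kernel. For deconvolution the directions are
// swapped: `compute_ref_bwd_d` below maps its deconv arguments onto this
// routine, while deconv forward is served by `compute_ref_direct_bwd_d`.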
void compute_ref_direct_fwd(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
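    // Dilations are stored zero-based (0 means a dense kernel), so the step
    // between kernel taps is `dd + 1` along each dimension.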
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

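    // Accumulates src * wei over the kernel window for a single dst point,
    // skipping taps that land in the padding area.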
    auto ker = [&](float &d, int64_t g, int64_t mb, int64_t oc, int64_t od,
                       int64_t oh, int64_t ow) {
        const float *__restrict src_loc
                = (const float *)src_m + (mb * IC + g * ICG) * ID * IH * IW;
        const float *__restrict wei_loc
                = (const float *)wei_m + (g * OCG + oc) * ICG * KD * KH * KW;

        for (int64_t kd = 0; kd < KD; ++kd) {
            const int64_t id = od * SD - PD + kd * DD;
            if (id < 0 || id >= ID) continue;
            for (int64_t kh = 0; kh < KH; ++kh) {
                const int64_t ih = oh * SH - PH + kh * DH;
                if (ih < 0 || ih >= IH) continue;
                for (int64_t kw = 0; kw < KW; ++kw) {
                    const int64_t iw = ow * SW - PW + kw * DW;
                    if (iw < 0 || iw >= IW) continue;

                    for (int64_t ic = 0; ic < ICG; ++ic) {
                        int64_t src_off = ((ic * ID + id) * IH + ih) * IW + iw;
                        int64_t wei_off = ((ic * KD + kd) * KH + kh) * KW + kw;
                        float s = src_loc[src_off];
                        maybe_zero_point(prb->attr, s, prb->src_zp,
                                g * ICG + ic, DNNL_ARG_SRC);
                        maybe_scale(prb->attr, s, prb->src_scales,
                                g * ICG + ic, DNNL_ARG_SRC);
                        float w = wei_loc[wei_off];
                        maybe_scale(prb->attr, w, prb->wei_scales,
                                g * OCG + oc, DNNL_ARG_WEIGHTS);
                        d += s * w;
                    }
                }
            }
        }
    };

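    // One dst point per work item: accumulate, add bias, apply post-ops,
    // then the dst scale and zero point.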
    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    benchdnn_parallel_nd(G, MB, OCG, OD, OH, OW,
            [&](int64_t g, int64_t mb, int64_t oc, int64_t od, int64_t oh,
                    int64_t ow) {
                const size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow);
                float &dst = ((float *)dst_m)[dst_off];

                float conv_res = 0;
                ker(conv_res, g, mb, oc, od, oh, ow);

                if (prb->dir & FLAG_BIA) {
                    const size_t bia_off = bia_off_f(prb, g, oc);
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals
                        = prepare_po_vals(dst_m, args, v_po_masks, dst_off);

                maybe_post_ops(prb->attr, conv_res, dst, v_po_vals);

                maybe_scale(prb->attr, conv_res, prb->dst_scales, g * OCG + oc,
                        DNNL_ARG_DST, true);
                maybe_zero_point(prb->attr, conv_res, prb->dst_zp, g * OCG + oc,
                        DNNL_ARG_DST, true);

                dst = conv_res;
            });
}

void compute_ref_direct_bwd_d(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

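    // Fast path: when every kernel dimension fits in `precompute_size`, the
    // valid (output, kernel) index pairs for each input point are precomputed
    // per spatial axis, which keeps the stride/bounds checks out of the
    // innermost loops.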
    enum { precompute_size = 16 };
    const bool fast = MAX3(KD, KH, KW) <= precompute_size;

    // From the bwd point of view, the fwd src zero point becomes the diff_dst
    // zero point, and the fwd dst zero point becomes the diff_src zero point.
    const auto map_arg_to_zp_arg = [](int num) {
        switch (num) {
            case DNNL_ARG_DIFF_DST: return DNNL_ARG_SRC;
            case DNNL_ARG_DIFF_SRC: return DNNL_ARG_DST;
            default: assert(false && "map_arg_to_zp_arg unsupported arg");
        }

        return -1;
    };

    /* Pre-computes, for one spatial axis, the output points `_o` and kernel
     * taps `_k` that contribute to input point `i`. */
    auto precompute_ok
            = [](int64_t i, int64_t O, int64_t K, int64_t S, int64_t P,
                      int64_t D, int64_t &num, int64_t *_o, int64_t *_k) {
                assert(K <= precompute_size);
                num = 0;
                for (int64_t k = 0; k < K; ++k) {
                    int64_t o = i - k * D + P;
                    if (o < 0 || o % S) continue;
                    o /= S;
                    if (o >= O) continue;
                    _k[num] = k;
                    _o[num] = o;
                    ++num;
                }
            };

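    // Kernel for the fast path; relies on the precomputed index pairs above.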
    auto ker_fast = [&](float &ds, int64_t g, int64_t mb, int64_t ic,
                            int64_t id, int64_t ih, int64_t iw) {
        int64_t kd[precompute_size], od[precompute_size], num_d;
        int64_t kh[precompute_size], oh[precompute_size], num_h;
        int64_t kw[precompute_size], ow[precompute_size], num_w;
        precompute_ok(id, OD, KD, SD, PD, DD, num_d, od, kd);
        precompute_ok(ih, OH, KH, SH, PH, DH, num_h, oh, kh);
        precompute_ok(iw, OW, KW, SW, PW, DW, num_w, ow, kw);

        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;

        for_(int64_t d = 0; d < num_d; ++d)
        for_(int64_t h = 0; h < num_h; ++h)
        for (int64_t w = 0; w < num_w; ++w) {
            for (int64_t oc = 0; oc < OCG; ++oc) {
                const int64_t diff_dst_off
                        = ((oc * OD + od[d]) * OH + oh[h]) * OW + ow[w];
                const int64_t wei_off
                        = ((oc * ICG * KD + kd[d]) * KH + kh[h]) * KW + kw[w];
                float diff_dst_val = diff_dst_loc[diff_dst_off];
                maybe_zero_point(prb->attr, diff_dst_val, prb->src_zp,
                        g * OCG + oc, map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));
                maybe_scale(prb->attr, diff_dst_val, prb->src_scales,
                        g * OCG + oc, map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));

                float wei_val = wei_loc[wei_off];
                maybe_scale(prb->attr, wei_val, prb->wei_scales, g * ICG + ic,
                        DNNL_ARG_WEIGHTS);
                ds += diff_dst_val * wei_val;
            }
        }
    };

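    // Generic kernel: for each tap, recover the output point that maps onto
    // this input point and keep it only if it is stride-aligned and in range.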
    auto ker = [&](float &ds, int64_t g, int64_t mb, int64_t ic, int64_t id,
                       int64_t ih, int64_t iw) {
        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;

        for (int64_t kd = 0; kd < KD; ++kd) {
            int64_t od = id - kd * DD + PD;
            if (od < 0 || od % SD || od >= OD * SD) continue;
            od /= SD;
            for (int64_t kh = 0; kh < KH; ++kh) {
                int64_t oh = ih - kh * DH + PH;
                if (oh < 0 || oh % SH || oh >= OH * SH) continue;
                oh /= SH;
                for (int64_t kw = 0; kw < KW; ++kw) {
                    int64_t ow = iw - kw * DW + PW;
                    if (ow < 0 || ow % SW || ow >= OW * SW) continue;
                    ow /= SW;
                    for (int64_t oc = 0; oc < OCG; ++oc) {
                        const int64_t diff_dst_off
                                = ((oc * OD + od) * OH + oh) * OW + ow;
                        const int64_t wei_off
                                = ((oc * ICG * KD + kd) * KH + kh) * KW + kw;
                        float diff_dst_val = diff_dst_loc[diff_dst_off];
                        maybe_zero_point(prb->attr, diff_dst_val, prb->src_zp,
                                g * OCG + oc,
                                map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));
                        maybe_scale(prb->attr, diff_dst_val, prb->src_scales,
                                g * OCG + oc,
                                map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));

                        float wei_val = wei_loc[wei_off];
                        maybe_scale(prb->attr, wei_val, prb->wei_scales,
                                g * ICG + ic, DNNL_ARG_WEIGHTS);

                        ds += diff_dst_val * wei_val;
                    }
                }
            }
        }
    };

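    // One diff_src point per work item; scales and zero points go through the
    // fwd<->bwd argument mapping above.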
    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    benchdnn_parallel_nd(G, MB, ICG, ID, IH, IW,
            [&](int64_t g, int64_t mb, int64_t ic, int64_t id, int64_t ih,
                    int64_t iw) {
                size_t src_off = src_off_f(prb, mb, g, ic, id, ih, iw);
                float &ds = ((float *)diff_src_m)[src_off];
                float conv_res = 0;
                if (fast)
                    ker_fast(conv_res, g, mb, ic, id, ih, iw);
                else
                    ker(conv_res, g, mb, ic, id, ih, iw);

                if (prb->dir & FLAG_BIA) {
                    const size_t bia_off = (size_t)g * ICG + ic;
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals = prepare_po_vals(
                        diff_src_m, args, v_po_masks, src_off);

                maybe_post_ops(prb->attr, conv_res, ds, v_po_vals);
                maybe_scale(prb->attr, conv_res, prb->dst_scales, g * ICG + ic,
                        map_arg_to_zp_arg(DNNL_ARG_DIFF_SRC), true);
                maybe_zero_point(prb->attr, conv_res, prb->dst_zp, g * ICG + ic,
                        map_arg_to_zp_arg(DNNL_ARG_DIFF_SRC), true);

                ds = conv_res;
            });
}

void compute_ref_bwd_weights(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

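    // For a fixed kernel tap `k`, the input point is i = o * S - P + k * D;
    // requiring 0 <= i < I and solving for `o` gives the valid output range
    // [o_s, o_e).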
    auto compute_bounds
            = [](int64_t I, int64_t O, int64_t k, int64_t S, int64_t P,
                      int64_t D, int64_t &o_s, int64_t &o_e) {
                const float tmp = P - k * D;
                o_s = MAX2(0, ceilf(tmp / S));
                o_e = MIN2(O, ceilf((I + tmp) / S));
            };

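    // Accumulates diff_dst * src over the minibatch and the valid output
    // region for a single weight element.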
    auto ker = [&](float &dw, int64_t g, int64_t oc, int64_t ic, int64_t kd,
                       int64_t kh, int64_t kw) {
        int64_t od_s, od_e, oh_s, oh_e, ow_s, ow_e;
        compute_bounds(ID, OD, kd, SD, PD, DD, od_s, od_e);
        compute_bounds(IH, OH, kh, SH, PH, DH, oh_s, oh_e);
        compute_bounds(IW, OW, kw, SW, PW, DW, ow_s, ow_e);
        const int64_t id_s = kd * DD - PD;
        const int64_t ih_s = kh * DH - PH;
        const int64_t iw_s = kw * DW - PW;

        for (int64_t mb = 0; mb < MB; ++mb) {
            const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                    + (mb * OC + g * OCG + oc) * OD * OH * OW;
            const float *__restrict src_loc = (const float *)src_m
                    + (mb * IC + g * ICG + ic) * ID * IH * IW;

            for_(int64_t od = od_s; od < od_e; ++od)
            for_(int64_t oh = oh_s; oh < oh_e; ++oh)
            for (int64_t ow = ow_s; ow < ow_e; ++ow) {
                const int64_t id = od * SD + id_s;
                const int64_t ih = oh * SH + ih_s;
                const int64_t iw = ow * SW + iw_s;

                size_t diff_dst_off = (od * OH + oh) * OW + ow;
                size_t src_off = (id * IH + ih) * IW + iw;
                dw += diff_dst_loc[diff_dst_off] * src_loc[src_off];
            }
        }
    };

    benchdnn_parallel_nd(G, OCG, ICG, KD, KH, KW,
            [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh,
                    int64_t kw) {
                size_t wei_off = wei_off_f(prb, g, oc, ic, kd, kh, kw);
                float &dw = ((float *)diff_wei_m)[wei_off];
                dw = 0;
                ker(dw, g, oc, ic, kd, kh, kw);
            });
}

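// Reduces diff_dst over the minibatch and spatial dimensions for each output
// channel.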
void compute_ref_bwd_bias(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc;
    const int64_t OCG = OC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;

    benchdnn_parallel_nd(G, OCG, [&](int64_t g, int64_t oc) {
        size_t bia_off = bia_off_f(prb, g, oc);
        double sum = 0;

        for_(int64_t mb = 0; mb < MB; ++mb)
        for_(int64_t od = 0; od < OD; ++od)
        for_(int64_t oh = 0; oh < OH; ++oh)
        for (int64_t ow = 0; ow < OW; ++ow) {
            size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow);
            sum += ((float *)diff_dst_m)[dst_off];
        }
        ((float *)diff_bia_m)[bia_off] = (float)sum;
    });
}

void compute_ref_fwd(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prim_ref) {
        SAFE_V(execute_and_wait(prim_ref, args));
        return;
    }

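    // Deconvolution forward is mathematically a convolution backward (data),
    // so map SRC -> DIFF_DST and DST -> DIFF_SRC and run the bwd_d kernel.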
    // Swap arguments to re-use existing conv ref implementation.
    args_t ref_conv_args;
    for (int i = 0; i < args.size(); i++) {
        if (args.arg(i) == DNNL_ARG_SRC)
            ref_conv_args.set(DNNL_ARG_DIFF_DST, args.dnn_mem(i));
        else if (args.arg(i) == DNNL_ARG_WEIGHTS)
            ref_conv_args.set(DNNL_ARG_WEIGHTS, args.find(DNNL_ARG_WEIGHTS_1));
        else if (args.arg(i) == DNNL_ARG_DST)
            ref_conv_args.set(DNNL_ARG_DIFF_SRC, args.dnn_mem(i));
        else
            ref_conv_args.set(args.arg(i), args.dnn_mem(i));
    }

    if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) {
        compute_wino_ref_bwd_d(prb, ref_conv_args);
    } else {
        compute_ref_direct_bwd_d(prb, ref_conv_args);
    }
}

void compute_ref_bwd_d(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prim_ref) {
        SAFE_V(execute_and_wait(prim_ref, args));
        return;
    }

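    // Conversely, deconvolution backward (data) is a convolution forward:
    // map DIFF_SRC -> DST and DIFF_DST -> SRC.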
    // Swap arguments to re-use existing conv ref implementation.
    args_t ref_conv_args;
    for (int i = 0; i < args.size(); i++) {
        if (args.arg(i) == DNNL_ARG_DIFF_SRC)
            ref_conv_args.set(DNNL_ARG_DST, args.dnn_mem(i));
        else if (args.arg(i) == DNNL_ARG_WEIGHTS)
            ref_conv_args.set(DNNL_ARG_WEIGHTS, args.find(DNNL_ARG_WEIGHTS_1));
        else if (args.arg(i) == DNNL_ARG_DIFF_DST)
            ref_conv_args.set(DNNL_ARG_SRC, args.dnn_mem(i));
        else
            ref_conv_args.set(args.arg(i), args.dnn_mem(i));
    }

    if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) {
        compute_wino_ref_fwd(prb, ref_conv_args);
    } else {
        compute_ref_direct_fwd(prb, ref_conv_args);
    }
}

void compute_ref_bwd_w(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prim_ref) {
        SAFE_V(execute_and_wait(prim_ref, args));
        return;
    }

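    // Deconvolution backward (weights) reuses the convolution bwd_w kernel
    // with SRC and DIFF_DST swapped; the result lands in the transposed
    // weights memory (DNNL_ARG_DIFF_WEIGHTS_1) and is transposed back below.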
    // Swap arguments to re-use existing conv ref implementation.
    args_t ref_conv_args;
    for (int i = 0; i < args.size(); i++) {
        if (args.arg(i) == DNNL_ARG_SRC)
            ref_conv_args.set(DNNL_ARG_DIFF_DST, args.dnn_mem(i));
        else if (args.arg(i) == DNNL_ARG_DIFF_WEIGHTS)
            ref_conv_args.set(
                    DNNL_ARG_DIFF_WEIGHTS, args.find(DNNL_ARG_DIFF_WEIGHTS_1));
        else if (args.arg(i) == DNNL_ARG_DIFF_DST)
            ref_conv_args.set(DNNL_ARG_SRC, args.dnn_mem(i));
        else
            ref_conv_args.set(args.arg(i), args.dnn_mem(i));
    }

    if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) {
        compute_wino_ref_bwd_w(prb, ref_conv_args);
    } else {
        compute_ref_bwd_weights(prb, ref_conv_args);
    }

    // Need to transpose data in weights back for proper comparison. This step
    // is done here as it's not needed for fast-ref-gpu.
    transpose_data_wei(prb, args.find(DNNL_ARG_DIFF_WEIGHTS_1),
            args.find(DNNL_ARG_DIFF_WEIGHTS));

    // Don't reuse `compute_ref_bwd_bias`: its arguments don't match the entry
    // problem, which is transposed (`prb_tr`). It's simpler to apply the
    // kernel directly.
    // Take original memories, not `ref_conv_args`.
    if (prb->dir & FLAG_BIA) {
        const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS);
        const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
        /* help compiler optimize the code */
        const int64_t MB = prb->mb, G = prb->g;
        const int64_t OC = prb->ic; // prb.oc = prb_tr.ic
        const int64_t OCG = OC / G;
        const int64_t OD = prb->id; // prb.od = prb_tr.id
        const int64_t OH = prb->ih; // prb.oh = prb_tr.ih
        const int64_t OW = prb->iw; // prb.ow = prb_tr.iw

        benchdnn_parallel_nd(G, OCG, [&](int64_t g, int64_t oc) {
            size_t bia_off = g * OCG + oc;
            double sum = 0;

            for_(int64_t mb = 0; mb < MB; ++mb)
            for_(int64_t od = 0; od < OD; ++od)
            for_(int64_t oh = 0; oh < OH; ++oh)
            for (int64_t ow = 0; ow < OW; ++ow) {
                // src_off_f instead of dst_off_f due to inverse descriptor.
                size_t dst_off = src_off_f(prb, mb, g, oc, od, oh, ow);
                sum += ((float *)diff_dst_m)[dst_off];
            }
            ((float *)diff_bia_m)[bia_off] = (float)sum;
        });
    }
}

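// Entry point: builds a transposed problem (conv <-> deconv swap of ic/oc and
// of the spatial dimensions) and dispatches to the direction-specific helper.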
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    // Update prb descriptor to re-use convolution reference.
    prb_t prb_tr((desc_t)*prb, prb->dir, prb->cfg, prb->stag, prb->wtag,
            prb->dtag, prb->alg, prb->attr, prb->ctx_init, prb->ctx_exe,
            prb->mb);
    std::swap(prb_tr.ic, prb_tr.oc);
    std::swap(prb_tr.ih, prb_tr.oh);
    std::swap(prb_tr.id, prb_tr.od);
    std::swap(prb_tr.iw, prb_tr.ow);

    if (prb->dir & FLAG_FWD)
        compute_ref_fwd(&prb_tr, args, prim_ref);
    else if (prb->dir == BWD_D)
        compute_ref_bwd_d(&prb_tr, args, prim_ref);
    else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI)
        compute_ref_bwd_w(&prb_tr, args, prim_ref);
}

} // namespace deconv