/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "utils/parallel.hpp"

#include "conv/ref_conv.hpp"

namespace conv {

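// Reference (direct) forward convolution:
//   dst(mb, g*OCG + oc, od, oh, ow)
//       = bias(g*OCG + oc)
//       + sum_{ic, kd, kh, kw} src(mb, g*ICG + ic, id, ih, iw)
//                            * wei(g, oc, ic, kd, kh, kw),
// with id = od*SD - PD + kd*DD (and likewise for ih/iw); padded positions are
// skipped. Zero points, scales and post-ops from the attributes are applied
// on top of the plain f32 accumulation.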
void compute_ref_direct_fwd(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

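    // Accumulates into `d` the contribution of one output point: for every
    // in-bounds (kd, kh, kw) and every input channel of the group, the src
    // zero point and src/wei scales are applied, then s * w is added.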
    auto ker = [&](float &d, int64_t g, int64_t mb, int64_t oc, int64_t od,
                       int64_t oh, int64_t ow) {
        const float *__restrict src_loc
                = (const float *)src_m + (mb * IC + g * ICG) * ID * IH * IW;
        const float *__restrict wei_loc
                = (const float *)wei_m + (g * OCG + oc) * ICG * KD * KH * KW;

        for (int64_t kd = 0; kd < KD; ++kd) {
            const int64_t id = od * SD - PD + kd * DD;
            if (id < 0 || id >= ID) continue;
            for (int64_t kh = 0; kh < KH; ++kh) {
                const int64_t ih = oh * SH - PH + kh * DH;
                if (ih < 0 || ih >= IH) continue;
                for (int64_t kw = 0; kw < KW; ++kw) {
                    const int64_t iw = ow * SW - PW + kw * DW;
                    if (iw < 0 || iw >= IW) continue;

                    for (int64_t ic = 0; ic < ICG; ++ic) {
                        int64_t src_off = ((ic * ID + id) * IH + ih) * IW + iw;
                        int64_t wei_off = ((ic * KD + kd) * KH + kh) * KW + kw;
                        float s = src_loc[src_off];
                        maybe_zero_point(prb->attr, s, prb->src_zp,
                                g * ICG + ic, DNNL_ARG_SRC);
                        maybe_scale(prb->attr, s, prb->src_scales, g * ICG + ic,
                                DNNL_ARG_SRC);
                        float w = wei_loc[wei_off];
                        maybe_scale(prb->attr, w, prb->wei_scales, g * OCG + oc,
                                DNNL_ARG_WEIGHTS);
                        d += s * w;
                    }
                }
            }
        }
    };

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    benchdnn_parallel_nd(G, MB, OCG, OD, OH, OW,
            [&](int64_t g, int64_t mb, int64_t oc, int64_t od, int64_t oh,
                    int64_t ow) {
                const size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow);
                float &dst = ((float *)dst_m)[dst_off];

                float conv_res = 0;
                ker(conv_res, g, mb, oc, od, oh, ow);

                if (prb->dir & FLAG_BIA) {
                    const size_t bia_off = bia_off_f(prb, g, oc);
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals
                        = prepare_po_vals(dst_m, args, v_po_masks, dst_off);

                maybe_post_ops(prb->attr, conv_res, dst, v_po_vals);

                maybe_scale(prb->attr, conv_res, prb->dst_scales, g * OCG + oc,
                        DNNL_ARG_DST, true);
                maybe_zero_point(prb->attr, conv_res, prb->dst_zp, g * OCG + oc,
                        DNNL_ARG_DST, true);

                dst = conv_res;
            });
}

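// Reference backward-by-data convolution:
//   diff_src(mb, g*ICG + ic, id, ih, iw)
//       = sum_{oc, kd, kh, kw} diff_dst(mb, g*OCG + oc, od, oh, ow)
//                            * wei(g, oc, ic, kd, kh, kw),
// where od = (id + PD - kd*DD) / SD contributes only when the division is
// exact and od lies in [0, OD) (and likewise for oh/ow).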
void compute_ref_direct_bwd_d(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

    enum { precompute_size = 16 };
    const bool fast = MAX3(KD, KH, KW) <= precompute_size;

    // From the backward point of view, the forward src zero point corresponds
    // to diff_dst and the forward dst zero point corresponds to diff_src.
    const auto map_arg_to_zp_arg = [](int num) {
        switch (num) {
            case DNNL_ARG_DIFF_DST: return DNNL_ARG_SRC;
            case DNNL_ARG_DIFF_SRC: return DNNL_ARG_DST;
            default: assert(false && "map_arg_to_zp_arg unsupported arg");
        }

        return -1;
    };
    const auto &map_arg_to_sc_arg = map_arg_to_zp_arg;

    /* pre-computes arrays of oh(ow) and kh(kw) for traversing in kernel */
    auto precompute_ok
            = [](int64_t i, int64_t O, int64_t K, int64_t S, int64_t P,
                      int64_t D, int64_t &num, int64_t *_o, int64_t *_k) {
                  assert(K <= precompute_size);
                  num = 0;
                  for (int64_t k = 0; k < K; ++k) {
                      int64_t o = i - k * D + P;
                      if (o < 0 || o % S) continue;
                      o /= S;
                      if (o >= O) continue;
                      _k[num] = k;
                      _o[num] = o;
                      ++num;
                  }
              };

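    // Fast path: the valid (k, o) pairs for each spatial dimension are
    // precomputed once per destination point instead of being re-derived in
    // the innermost loops. For example, with K = 3, S = 2, P = 1, D = 1 and
    // i = 4, only k = 1 qualifies: o = 4 - 1 + 1 = 4 is divisible by S and
    // yields o = 2 (assuming O > 2).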
    auto ker_fast = [&](float &ds, int64_t g, int64_t mb, int64_t ic,
                            int64_t id, int64_t ih, int64_t iw) {
        int64_t kd[precompute_size], od[precompute_size], num_d;
        int64_t kh[precompute_size], oh[precompute_size], num_h;
        int64_t kw[precompute_size], ow[precompute_size], num_w;
        precompute_ok(id, OD, KD, SD, PD, DD, num_d, od, kd);
        precompute_ok(ih, OH, KH, SH, PH, DH, num_h, oh, kh);
        precompute_ok(iw, OW, KW, SW, PW, DW, num_w, ow, kw);

        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;

        for_(int64_t d = 0; d < num_d; ++d)
        for_(int64_t h = 0; h < num_h; ++h)
        for (int64_t w = 0; w < num_w; ++w) {
            for (int64_t oc = 0; oc < OCG; ++oc) {
                const int64_t diff_dst_off
                        = ((oc * OD + od[d]) * OH + oh[h]) * OW + ow[w];
                const int64_t wei_off
                        = ((oc * ICG * KD + kd[d]) * KH + kh[h]) * KW + kw[w];
                float diff_dst_val = diff_dst_loc[diff_dst_off];
                maybe_zero_point(prb->attr, diff_dst_val, prb->src_zp,
                        g * OCG + oc, map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));
                maybe_scale(prb->attr, diff_dst_val, prb->src_scales,
                        g * OCG + oc, map_arg_to_sc_arg(DNNL_ARG_DIFF_DST));

                float wei_val = wei_loc[wei_off];
                maybe_scale(prb->attr, wei_val, prb->wei_scales, g * ICG + ic,
                        DNNL_ARG_WEIGHTS);
                ds += diff_dst_val * wei_val;
            }
        }
    };

    auto ker = [&](float &ds, int64_t g, int64_t mb, int64_t ic, int64_t id,
                       int64_t ih, int64_t iw) {
        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;

        for (int64_t kd = 0; kd < KD; ++kd) {
            int64_t od = id - kd * DD + PD;
            if (od < 0 || od % SD || od >= OD * SD) continue;
            od /= SD;
            for (int64_t kh = 0; kh < KH; ++kh) {
                int64_t oh = ih - kh * DH + PH;
                if (oh < 0 || oh % SH || oh >= OH * SH) continue;
                oh /= SH;
                for (int64_t kw = 0; kw < KW; ++kw) {
                    int64_t ow = iw - kw * DW + PW;
                    if (ow < 0 || ow % SW || ow >= OW * SW) continue;
                    ow /= SW;
                    for (int64_t oc = 0; oc < OCG; ++oc) {
                        const int64_t diff_dst_off
                                = ((oc * OD + od) * OH + oh) * OW + ow;
                        const int64_t wei_off
                                = ((oc * ICG * KD + kd) * KH + kh) * KW + kw;
                        float diff_dst_val = diff_dst_loc[diff_dst_off];
                        maybe_zero_point(prb->attr, diff_dst_val, prb->src_zp,
                                g * OCG + oc,
                                map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));
                        maybe_scale(prb->attr, diff_dst_val, prb->src_scales,
                                g * OCG + oc,
                                map_arg_to_sc_arg(DNNL_ARG_DIFF_DST));
                        float wei_val = wei_loc[wei_off];
                        maybe_scale(prb->attr, wei_val, prb->wei_scales,
                                g * ICG + ic, DNNL_ARG_WEIGHTS);

                        ds += diff_dst_val * wei_val;
                    }
                }
            }
        }
    };

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    benchdnn_parallel_nd(G, MB, ICG, ID, IH, IW,
            [&](int64_t g, int64_t mb, int64_t ic, int64_t id, int64_t ih,
                    int64_t iw) {
                size_t src_off = src_off_f(prb, mb, g, ic, id, ih, iw);
                float &ds = ((float *)diff_src_m)[src_off];
                float conv_res = 0;
                if (fast)
                    ker_fast(conv_res, g, mb, ic, id, ih, iw);
                else
                    ker(conv_res, g, mb, ic, id, ih, iw);

                if (prb->dir & FLAG_BIA) {
                    const size_t bia_off = (size_t)g * ICG + ic;
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals = prepare_po_vals(
                        diff_src_m, args, v_po_masks, src_off);

                maybe_post_ops(prb->attr, conv_res, ds, v_po_vals);
                maybe_scale(prb->attr, conv_res, prb->dst_scales, g * ICG + ic,
                        map_arg_to_sc_arg(DNNL_ARG_DIFF_SRC), true);
                maybe_zero_point(prb->attr, conv_res, prb->dst_zp, g * ICG + ic,
                        map_arg_to_zp_arg(DNNL_ARG_DIFF_SRC), true);

                ds = conv_res;
            });
}

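// Reference backward-by-weights convolution:
//   diff_wei(g, oc, ic, kd, kh, kw)
//       = sum_{mb, od, oh, ow} diff_dst(mb, g*OCG + oc, od, oh, ow)
//                            * src(mb, g*ICG + ic, id, ih, iw),
// with id = od*SD - PD + kd*DD (and likewise for ih/iw); the output-space
// bounds are narrowed up front so only in-bounds input points are visited.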
void compute_ref_bwd_weights(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

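    // Requiring 0 <= o*S - P + k*D < I for the corresponding input index
    // gives o_s = max(0, ceil((P - k*D) / S)) and
    // o_e = min(O, ceil((I + P - k*D) / S)), i.e. the half-open output range
    // [o_s, o_e) computed below.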
    auto compute_bounds
            = [](int64_t I, int64_t O, int64_t k, int64_t S, int64_t P,
                      int64_t D, int64_t &o_s, int64_t &o_e) {
                  const float tmp = P - k * D;
                  o_s = MAX2(0, ceilf(tmp / S));
                  o_e = MIN2(O, ceilf((I + tmp) / S));
              };

    auto ker = [&](float &dw, int64_t g, int64_t oc, int64_t ic, int64_t kd,
                       int64_t kh, int64_t kw) {
        int64_t od_s, od_e, oh_s, oh_e, ow_s, ow_e;
        compute_bounds(ID, OD, kd, SD, PD, DD, od_s, od_e);
        compute_bounds(IH, OH, kh, SH, PH, DH, oh_s, oh_e);
        compute_bounds(IW, OW, kw, SW, PW, DW, ow_s, ow_e);
        const int64_t id_s = kd * DD - PD;
        const int64_t ih_s = kh * DH - PH;
        const int64_t iw_s = kw * DW - PW;

        for (int64_t mb = 0; mb < MB; ++mb) {
            const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                    + (mb * OC + g * OCG + oc) * OD * OH * OW;
            const float *__restrict src_loc = (const float *)src_m
                    + (mb * IC + g * ICG + ic) * ID * IH * IW;

            for_(int64_t od = od_s; od < od_e; ++od)
            for_(int64_t oh = oh_s; oh < oh_e; ++oh)
            for (int64_t ow = ow_s; ow < ow_e; ++ow) {
                const int64_t id = od * SD + id_s;
                const int64_t ih = oh * SH + ih_s;
                const int64_t iw = ow * SW + iw_s;

                size_t diff_dst_off = (od * OH + oh) * OW + ow;
                size_t src_off = (id * IH + ih) * IW + iw;
                dw += diff_dst_loc[diff_dst_off] * src_loc[src_off];
            }
        }
    };

    benchdnn_parallel_nd(G, OCG, ICG, KD, KH, KW,
            [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh,
                    int64_t kw) {
                size_t wei_off = wei_off_f(prb, g, oc, ic, kd, kh, kw);
                float &dw = ((float *)diff_wei_m)[wei_off];
                dw = 0;
                ker(dw, g, oc, ic, kd, kh, kw);
            });
}

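// Reference backward-by-bias: diff_bias(g*OCG + oc) is the sum of diff_dst
// over the minibatch and all spatial positions, accumulated in double to
// limit rounding error before the final conversion to float.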
void compute_ref_bwd_bias(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc;
    const int64_t OCG = OC / G;
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;

    benchdnn_parallel_nd(G, OCG, [&](int64_t g, int64_t oc) {
        size_t bia_off = bia_off_f(prb, g, oc);
        double sum = 0;

        for_(int64_t mb = 0; mb < MB; ++mb)
        for_(int64_t od = 0; od < OD; ++od)
        for_(int64_t oh = 0; oh < OH; ++oh)
        for (int64_t ow = 0; ow < OW; ++ow) {
            size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow);
            sum += ((float *)diff_dst_m)[dst_off];
        }
        ((float *)diff_bia_m)[bia_off] = (float)sum;
    });
}

void compute_ref_direct_bwd_w(const prb_t *prb, const args_t &args) {
    compute_ref_bwd_weights(prb, args);
    if (!(prb->dir & FLAG_BIA)) return;
    compute_ref_bwd_bias(prb, args);
}

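// Entry points: when a reference primitive (prim_ref) is available it is
// executed directly; otherwise f32 Winograd problems go through the
// Winograd-specific reference and everything else uses the direct reference
// implementations above.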
void compute_ref_fwd(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prim_ref) {
        SAFE_V(execute_and_wait(prim_ref, args));
        return;
    }

    if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) {
        compute_wino_ref_fwd(prb, args);
    } else {
        compute_ref_direct_fwd(prb, args);
    }
}

void compute_ref_bwd_d(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prim_ref) {
        SAFE_V(execute_and_wait(prim_ref, args));
        return;
    }

    if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) {
        compute_wino_ref_bwd_d(prb, args);
    } else {
        compute_ref_direct_bwd_d(prb, args);
    }
}

void compute_ref_bwd_w(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prim_ref) {
        SAFE_V(execute_and_wait(prim_ref, args));
        return;
    }

    if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) {
        compute_wino_ref_bwd_w(prb, args);
    } else {
        compute_ref_direct_bwd_w(prb, args);
    }
}

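// Top-level dispatch on the problem direction: forward, backward by data, or
// backward by weights (optionally with bias).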
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prb->dir & FLAG_FWD)
        compute_ref_fwd(prb, args, prim_ref);
    else if (prb->dir == BWD_D)
        compute_ref_bwd_d(prb, args, prim_ref);
    else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI)
        compute_ref_bwd_w(prb, args, prim_ref);
}

} // namespace conv