1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "conv/ref_conv.hpp" |
20 | |
21 | namespace conv { |
22 | |
// Reference (direct, non-optimized) forward convolution. For each output
// point it accumulates src * wei over (ic, kd, kh, kw), then adds bias,
// applies post-ops, and finally the dst scale and zero-point.
void compute_ref_direct_fwd(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G; // channels per group
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    // prb->d{d,h,w} store the "extra" dilation (0 means dense); +1 converts
    // them into the step between consecutive kernel taps.
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

    // Accumulates into `d` the convolution sum for output point
    // (g, mb, oc, od, oh, ow). Src zero-point and src/wei scales are
    // applied per element, in that order, before accumulation.
    auto ker = [&](float &d, int64_t g, int64_t mb, int64_t oc, int64_t od,
                       int64_t oh, int64_t ow) {
        // Base pointers: the (mb, g) channel slice of src and the
        // (g, oc) filter of wei; inner offsets below are relative to these.
        const float *__restrict src_loc
                = (const float *)src_m + (mb * IC + g * ICG) * ID * IH * IW;
        const float *__restrict wei_loc
                = (const float *)wei_m + (g * OCG + oc) * ICG * KD * KH * KW;

        for (int64_t kd = 0; kd < KD; ++kd) {
            const int64_t id = od * SD - PD + kd * DD;
            if (id < 0 || id >= ID) continue; // tap falls into padding
            for (int64_t kh = 0; kh < KH; ++kh) {
                const int64_t ih = oh * SH - PH + kh * DH;
                if (ih < 0 || ih >= IH) continue; // tap falls into padding
                for (int64_t kw = 0; kw < KW; ++kw) {
                    const int64_t iw = ow * SW - PW + kw * DW;
                    if (iw < 0 || iw >= IW) continue; // tap falls into padding

                    for (int64_t ic = 0; ic < ICG; ++ic) {
                        int64_t src_off = ((ic * ID + id) * IH + ih) * IW + iw;
                        int64_t wei_off = ((ic * KD + kd) * KH + kh) * KW + kw;
                        float s = src_loc[src_off];
                        // Channel index passed as the full (non-group-local)
                        // input channel for per-channel zp/scale lookup.
                        maybe_zero_point(prb->attr, s, prb->src_zp,
                                g * ICG + ic, DNNL_ARG_SRC);
                        maybe_scale(prb->attr, s, prb->src_scales, g * ICG + ic,
                                DNNL_ARG_SRC);
                        float w = wei_loc[wei_off];
                        maybe_scale(prb->attr, w, prb->wei_scales, g * OCG + oc,
                                DNNL_ARG_WEIGHTS);
                        d += s * w;
                    }
                }
            }
        }
    };

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    benchdnn_parallel_nd(G, MB, OCG, OD, OH, OW,
            [&](int64_t g, int64_t mb, int64_t oc, int64_t od, int64_t oh,
                    int64_t ow) {
                const size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow);
                float &dst = ((float *)dst_m)[dst_off];

                float conv_res = 0;
                ker(conv_res, g, mb, oc, od, oh, ow);

                // Bias is added before post-ops, per oneDNN semantics.
                if (prb->dir & FLAG_BIA) {
                    const size_t bia_off = bia_off_f(prb, g, oc);
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals
                        = prepare_po_vals(dst_m, args, v_po_masks, dst_off);

                maybe_post_ops(prb->attr, conv_res, dst, v_po_vals);

                // Trailing `true` presumably requests the inverse (output-side)
                // application of dst scale/zero-point — confirm against the
                // maybe_scale/maybe_zero_point signatures.
                maybe_scale(prb->attr, conv_res, prb->dst_scales, g * OCG + oc,
                        DNNL_ARG_DST, true);
                maybe_zero_point(prb->attr, conv_res, prb->dst_zp, g * OCG + oc,
                        DNNL_ARG_DST, true);

                dst = conv_res;
            });
}
103 | |
// Reference backward-by-data convolution: computes diff_src from diff_dst
// and weights. Two kernels are provided: a "fast" path that pre-computes the
// small set of valid (k, o) index pairs per input point (usable when all
// kernel dims fit in a fixed-size scratch array), and a generic fallback.
void compute_ref_direct_bwd_d(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G; // channels per group
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    // Stored dilations are "extra" gaps; +1 converts to a tap step size.
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

    enum { precompute_size = 16 };
    // Fast path only when every kernel dimension fits the scratch arrays.
    const bool fast = MAX3(KD, KH, KW) <= precompute_size;

    // From the backward point of view, the forward src zero-point/scale
    // applies to diff_dst, and the forward dst zero-point/scale applies to
    // diff_src. This maps the backward argument to its forward counterpart.
    const auto map_arg_to_zp_arg = [](int num) {
        switch (num) {
            case DNNL_ARG_DIFF_DST: return DNNL_ARG_SRC;
            case DNNL_ARG_DIFF_SRC: return DNNL_ARG_DST;
            default: assert(false && "map_arg_to_zp_arg unsupported arg" );
        }

        return -1; // unreachable in valid usage; silences missing-return
    };
    // Scales follow the same fwd<->bwd argument mapping as zero-points.
    const auto &map_arg_to_sc_arg = map_arg_to_zp_arg;

    /* pre-computes arrays of oh(ow) and kh(kw) for traversing in kernel */
    // For input index i, collects all pairs (k, o) with i = o*S - P + k*D,
    // 0 <= o < O; `num` receives the count, `_o`/`_k` the pair components.
    auto precompute_ok
            = [](int64_t i, int64_t O, int64_t K, int64_t S, int64_t P,
                      int64_t D, int64_t &num, int64_t *_o, int64_t *_k) {
                  assert(K <= precompute_size);
                  num = 0;
                  for (int64_t k = 0; k < K; ++k) {
                      int64_t o = i - k * D + P;
                      // o must be a non-negative multiple of the stride.
                      if (o < 0 || o % S) continue;
                      o /= S;
                      if (o >= O) continue;
                      _k[num] = k;
                      _o[num] = o;
                      ++num;
                  }
              };

    // Fast kernel: iterate only over the pre-computed valid (k, o) pairs.
    auto ker_fast = [&](float &ds, int64_t g, int64_t mb, int64_t ic,
                            int64_t id, int64_t ih, int64_t iw) {
        int64_t kd[precompute_size], od[precompute_size], num_d;
        int64_t kh[precompute_size], oh[precompute_size], num_h;
        int64_t kw[precompute_size], ow[precompute_size], num_w;
        precompute_ok(id, OD, KD, SD, PD, DD, num_d, od, kd);
        precompute_ok(ih, OH, KH, SH, PH, DH, num_h, oh, kh);
        precompute_ok(iw, OW, KW, SW, PW, DW, num_w, ow, kw);

        // Base pointers: the (mb, g) slice of diff_dst and the (g, ic)
        // column of wei; the oc loop below strides across output channels.
        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;

        for_(int64_t d = 0; d < num_d; ++d)
        for_(int64_t h = 0; h < num_h; ++h)
        for (int64_t w = 0; w < num_w; ++w) {
            for (int64_t oc = 0; oc < OCG; ++oc) {
                const int64_t diff_dst_off
                        = ((oc * OD + od[d]) * OH + oh[h]) * OW + ow[w];
                // oc stride is ICG*KD*KH*KW since wei layout is
                // (g, oc, ic, kd, kh, kw) and base already includes ic.
                const int64_t wei_off
                        = ((oc * ICG * KD + kd[d]) * KH + kh[h]) * KW + kw[w];
                float diff_dst_val = diff_dst_loc[diff_dst_off];
                maybe_zero_point(prb->attr, diff_dst_val, prb->src_zp,
                        g * OCG + oc, map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));
                maybe_scale(prb->attr, diff_dst_val, prb->src_scales,
                        g * OCG + oc, map_arg_to_sc_arg(DNNL_ARG_DIFF_DST));

                float wei_val = wei_loc[wei_off];
                maybe_scale(prb->attr, wei_val, prb->wei_scales, g * ICG + ic,
                        DNNL_ARG_WEIGHTS);
                ds += diff_dst_val * wei_val;
            }
        }
    };

    // Generic kernel: derive (od, oh, ow) from each kernel tap on the fly,
    // skipping taps whose output index is out of range or off-stride.
    auto ker = [&](float &ds, int64_t g, int64_t mb, int64_t ic, int64_t id,
                       int64_t ih, int64_t iw) {
        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;

        for (int64_t kd = 0; kd < KD; ++kd) {
            int64_t od = id - kd * DD + PD;
            // Valid only if od is a stride multiple inside [0, OD*SD).
            if (od < 0 || od % SD || od >= OD * SD) continue;
            od /= SD;
            for (int64_t kh = 0; kh < KH; ++kh) {
                int64_t oh = ih - kh * DH + PH;
                if (oh < 0 || oh % SH || oh >= OH * SH) continue;
                oh /= SH;
                for (int64_t kw = 0; kw < KW; ++kw) {
                    int64_t ow = iw - kw * DW + PW;
                    if (ow < 0 || ow % SW || ow >= OW * SW) continue;
                    ow /= SW;
                    for (int64_t oc = 0; oc < OCG; ++oc) {
                        const int64_t diff_dst_off
                                = ((oc * OD + od) * OH + oh) * OW + ow;
                        const int64_t wei_off
                                = ((oc * ICG * KD + kd) * KH + kh) * KW + kw;
                        float diff_dst_val = diff_dst_loc[diff_dst_off];
                        maybe_zero_point(prb->attr, diff_dst_val, prb->src_zp,
                                g * OCG + oc,
                                map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));
                        maybe_scale(prb->attr, diff_dst_val, prb->src_scales,
                                g * OCG + oc,
                                map_arg_to_sc_arg(DNNL_ARG_DIFF_DST));
                        float wei_val = wei_loc[wei_off];
                        maybe_scale(prb->attr, wei_val, prb->wei_scales,
                                g * ICG + ic, DNNL_ARG_WEIGHTS);

                        ds += diff_dst_val * wei_val;
                    }
                }
            }
        }
    };

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    benchdnn_parallel_nd(G, MB, ICG, ID, IH, IW,
            [&](int64_t g, int64_t mb, int64_t ic, int64_t id, int64_t ih,
                    int64_t iw) {
                size_t src_off = src_off_f(prb, mb, g, ic, id, ih, iw);
                float &ds = ((float *)diff_src_m)[src_off];
                float conv_res = 0;
                if (fast)
                    ker_fast(conv_res, g, mb, ic, id, ih, iw);
                else
                    ker(conv_res, g, mb, ic, id, ih, iw);

                if (prb->dir & FLAG_BIA) {
                    // Bias on bwd_d indexes by input channel (dense layout).
                    const size_t bia_off = (size_t)g * ICG + ic;
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals = prepare_po_vals(
                        diff_src_m, args, v_po_masks, src_off);

                maybe_post_ops(prb->attr, conv_res, ds, v_po_vals);
                // Trailing `true` presumably requests the inverse
                // (output-side) application — confirm against helpers.
                maybe_scale(prb->attr, conv_res, prb->dst_scales, g * ICG + ic,
                        map_arg_to_sc_arg(DNNL_ARG_DIFF_SRC), true);
                maybe_zero_point(prb->attr, conv_res, prb->dst_zp, g * ICG + ic,
                        map_arg_to_zp_arg(DNNL_ARG_DIFF_SRC), true);

                ds = conv_res;
            });
}
261 | |
262 | void compute_ref_bwd_weights(const prb_t *prb, const args_t &args) { |
263 | const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC); |
264 | const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS); |
265 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
266 | /* help compiler optimize the code */ |
267 | const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic; |
268 | const int64_t OCG = OC / G, ICG = IC / G; |
269 | const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow; |
270 | const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw; |
271 | const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw; |
272 | const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw; |
273 | const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw; |
274 | const int64_t DD = prb->dd + 1; |
275 | const int64_t DH = prb->dh + 1; |
276 | const int64_t DW = prb->dw + 1; |
277 | |
278 | auto compute_bounds |
279 | = [](int64_t I, int64_t O, int64_t k, int64_t S, int64_t P, |
280 | int64_t D, int64_t &o_s, int64_t &o_e) { |
281 | const float tmp = P - k * D; |
282 | o_s = MAX2(0, ceilf(tmp / S)); |
283 | o_e = MIN2(O, ceilf((I + tmp) / S)); |
284 | }; |
285 | |
286 | auto ker = [&](float &dw, int64_t g, int64_t oc, int64_t ic, int64_t kd, |
287 | int64_t kh, int64_t kw) { |
288 | int64_t od_s, od_e, oh_s, oh_e, ow_s, ow_e; |
289 | compute_bounds(ID, OD, kd, SD, PD, DD, od_s, od_e); |
290 | compute_bounds(IH, OH, kh, SH, PH, DH, oh_s, oh_e); |
291 | compute_bounds(IW, OW, kw, SW, PW, DW, ow_s, ow_e); |
292 | const int64_t id_s = kd * DD - PD; |
293 | const int64_t ih_s = kh * DH - PH; |
294 | const int64_t iw_s = kw * DW - PW; |
295 | |
296 | for (int64_t mb = 0; mb < MB; ++mb) { |
297 | const float *__restrict diff_dst_loc = (const float *)diff_dst_m |
298 | + (mb * OC + g * OCG + oc) * OD * OH * OW; |
299 | const float *__restrict src_loc = (const float *)src_m |
300 | + (mb * IC + g * ICG + ic) * ID * IH * IW; |
301 | |
302 | for_(int64_t od = od_s; od < od_e; ++od) |
303 | for_(int64_t oh = oh_s; oh < oh_e; ++oh) |
304 | for (int64_t ow = ow_s; ow < ow_e; ++ow) { |
305 | const int64_t id = od * SD + id_s; |
306 | const int64_t ih = oh * SH + ih_s; |
307 | const int64_t iw = ow * SW + iw_s; |
308 | |
309 | size_t diff_dst_off = (od * OH + oh) * OW + ow; |
310 | size_t src_off = (id * IH + ih) * IW + iw; |
311 | dw += diff_dst_loc[diff_dst_off] * src_loc[src_off]; |
312 | } |
313 | } |
314 | }; |
315 | |
316 | benchdnn_parallel_nd(G, OCG, ICG, KD, KH, KW, |
317 | [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh, |
318 | int64_t kw) { |
319 | size_t wei_off = wei_off_f(prb, g, oc, ic, kd, kh, kw); |
320 | float &dw = ((float *)diff_wei_m)[wei_off]; |
321 | dw = 0; |
322 | ker(dw, g, oc, ic, kd, kh, kw); |
323 | }); |
324 | } |
325 | |
326 | void compute_ref_bwd_bias(const prb_t *prb, const args_t &args) { |
327 | const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS); |
328 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
329 | /* help compiler optimize the code */ |
330 | const int64_t MB = prb->mb, G = prb->g, OC = prb->oc; |
331 | const int64_t OCG = OC / G; |
332 | const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow; |
333 | |
334 | benchdnn_parallel_nd(G, OCG, [&](int64_t g, int64_t oc) { |
335 | size_t bia_off = bia_off_f(prb, g, oc); |
336 | double sum = 0; |
337 | |
338 | for_(int64_t mb = 0; mb < MB; ++mb) |
339 | for_(int64_t od = 0; od < OD; ++od) |
340 | for_(int64_t oh = 0; oh < OH; ++oh) |
341 | for (int64_t ow = 0; ow < OW; ++ow) { |
342 | size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow); |
343 | sum += ((float *)diff_dst_m)[dst_off]; |
344 | } |
345 | ((float *)diff_bia_m)[bia_off] = (float)sum; |
346 | }); |
347 | } |
348 | |
349 | void compute_ref_direct_bwd_w(const prb_t *prb, const args_t &args) { |
350 | compute_ref_bwd_weights(prb, args); |
351 | if (!(prb->dir & FLAG_BIA)) return; |
352 | compute_ref_bwd_bias(prb, args); |
353 | } |
354 | |
355 | void compute_ref_fwd( |
356 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
357 | if (prim_ref) { |
358 | SAFE_V(execute_and_wait(prim_ref, args)); |
359 | return; |
360 | } |
361 | |
362 | if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) { |
363 | compute_wino_ref_fwd(prb, args); |
364 | } else { |
365 | compute_ref_direct_fwd(prb, args); |
366 | } |
367 | } |
368 | |
369 | void compute_ref_bwd_d( |
370 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
371 | if (prim_ref) { |
372 | SAFE_V(execute_and_wait(prim_ref, args)); |
373 | return; |
374 | } |
375 | |
376 | if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) { |
377 | compute_wino_ref_bwd_d(prb, args); |
378 | } else { |
379 | compute_ref_direct_bwd_d(prb, args); |
380 | } |
381 | } |
382 | |
383 | void compute_ref_bwd_w( |
384 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
385 | if (prim_ref) { |
386 | SAFE_V(execute_and_wait(prim_ref, args)); |
387 | return; |
388 | } |
389 | |
390 | if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) { |
391 | compute_wino_ref_bwd_w(prb, args); |
392 | } else { |
393 | compute_ref_direct_bwd_w(prb, args); |
394 | } |
395 | } |
396 | |
397 | void compute_ref( |
398 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
399 | if (prb->dir & FLAG_FWD) |
400 | compute_ref_fwd(prb, args, prim_ref); |
401 | else if (prb->dir == BWD_D) |
402 | compute_ref_bwd_d(prb, args, prim_ref); |
403 | else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) |
404 | compute_ref_bwd_w(prb, args, prim_ref); |
405 | } |
406 | |
407 | } // namespace conv |
408 | |