1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <utility> |
18 | |
19 | #include "utils/parallel.hpp" |
20 | |
21 | #include "deconv/ref_deconv.hpp" |
22 | |
23 | namespace deconv { |
24 | |
// Reference f32 direct forward convolution kernel. Note the deconv
// cross-wiring: this routine is invoked from compute_ref_bwd_d() (deconv
// bwd_d is computed as conv fwd on the transposed problem), which remaps
// DIFF_DST -> SRC and DIFF_SRC -> DST before calling it.
// Per output point: accumulate src * wei with per-channel src zero-point,
// src scale and wei scale applied, then add bias, apply post-ops, and
// finally dst scale and dst zero-point.
void compute_ref_direct_fwd(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G; // channels per group
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    // Stored dilations are off-by-one; '+ 1' gives the step between taps.
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

    // Accumulates into `d` the convolution sum for one output point
    // (g, mb, oc, od, oh, ow): loops over kernel taps and the group's
    // input channels, skipping taps that fall into padding.
    auto ker = [&](float &d, int64_t g, int64_t mb, int64_t oc, int64_t od,
                       int64_t oh, int64_t ow) {
        // Bases for this (mb, group) slab of src and this (g, oc) filter.
        const float *__restrict src_loc
                = (const float *)src_m + (mb * IC + g * ICG) * ID * IH * IW;
        const float *__restrict wei_loc
                = (const float *)wei_m + (g * OCG + oc) * ICG * KD * KH * KW;

        for (int64_t kd = 0; kd < KD; ++kd) {
            const int64_t id = od * SD - PD + kd * DD;
            if (id < 0 || id >= ID) continue; // tap lands in padding
            for (int64_t kh = 0; kh < KH; ++kh) {
                const int64_t ih = oh * SH - PH + kh * DH;
                if (ih < 0 || ih >= IH) continue;
                for (int64_t kw = 0; kw < KW; ++kw) {
                    const int64_t iw = ow * SW - PW + kw * DW;
                    if (iw < 0 || iw >= IW) continue;

                    for (int64_t ic = 0; ic < ICG; ++ic) {
                        int64_t src_off = ((ic * ID + id) * IH + ih) * IW + iw;
                        int64_t wei_off = ((ic * KD + kd) * KH + kh) * KW + kw;
                        // Apply per-channel src zero-point/scale and wei
                        // scale before accumulating.
                        float s = src_loc[src_off];
                        maybe_zero_point(prb->attr, s, prb->src_zp,
                                g * ICG + ic, DNNL_ARG_SRC);
                        maybe_scale(prb->attr, s, prb->src_scales, g * ICG + ic,
                                DNNL_ARG_SRC);
                        float w = wei_loc[wei_off];
                        maybe_scale(prb->attr, w, prb->wei_scales, g * OCG + oc,
                                DNNL_ARG_WEIGHTS);
                        d += s * w;
                    }
                }
            }
        }
    };

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    // One independent accumulation per output element.
    benchdnn_parallel_nd(G, MB, OCG, OD, OH, OW,
            [&](int64_t g, int64_t mb, int64_t oc, int64_t od, int64_t oh,
                    int64_t ow) {
                const size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow);
                float &dst = ((float *)dst_m)[dst_off];

                float conv_res = 0;
                ker(conv_res, g, mb, oc, od, oh, ow);

                if (prb->dir & FLAG_BIA) {
                    const size_t bia_off = bia_off_f(prb, g, oc);
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals
                        = prepare_po_vals(dst_m, args, v_po_masks, dst_off);

                maybe_post_ops(prb->attr, conv_res, dst, v_po_vals);

                // dst-side quantization is applied last, after post-ops.
                maybe_scale(prb->attr, conv_res, prb->dst_scales, g * OCG + oc,
                        DNNL_ARG_DST, true);
                maybe_zero_point(prb->attr, conv_res, prb->dst_zp, g * OCG + oc,
                        DNNL_ARG_DST, true);

                dst = conv_res;
            });
}
105 | |
// Reference f32 direct backward-by-data convolution kernel. Note the deconv
// cross-wiring: this routine is invoked from compute_ref_fwd() (deconv fwd
// is computed as conv bwd_d on the transposed problem), which remaps
// SRC -> DIFF_DST and DST -> DIFF_SRC before calling it. That is why
// fwd-style quantization attributes (src_zp/src_scales, dst_zp/dst_scales)
// are applied here, routed through map_arg_to_zp_arg().
void compute_ref_direct_bwd_d(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &diff_src_m = args.find(DNNL_ARG_DIFF_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
    /* help compiler optimize the code */
    const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic;
    const int64_t OCG = OC / G, ICG = IC / G; // channels per group
    const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
    const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
    const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
    const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
    const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
    // Stored dilations are off-by-one; '+ 1' gives the step between taps.
    const int64_t DD = prb->dd + 1;
    const int64_t DH = prb->dh + 1;
    const int64_t DW = prb->dw + 1;

    // Fast path precomputes valid (o, k) index pairs per spatial dimension
    // into fixed-size arrays; only usable while each kernel dimension fits.
    enum { precompute_size = 16 };
    const bool fast = MAX3(KD, KH, KW) <= precompute_size;

    // From the backward point of view, the fwd src zero-point applies to
    // diff_dst and the fwd dst zero-point applies to diff_src.
    const auto map_arg_to_zp_arg = [](int num) {
        switch (num) {
            case DNNL_ARG_DIFF_DST: return DNNL_ARG_SRC;
            case DNNL_ARG_DIFF_SRC: return DNNL_ARG_DST;
            default: assert(false && "map_arg_to_zp_arg unsupported arg" );
        }

        return -1;
    };

    /* pre-computes arrays of oh(ow) and kh(kw) for traversing in kernel */
    auto precompute_ok
            = [](int64_t i, int64_t O, int64_t K, int64_t S, int64_t P,
                      int64_t D, int64_t &num, int64_t *_o, int64_t *_k) {
                  assert(K <= precompute_size);
                  num = 0;
                  for (int64_t k = 0; k < K; ++k) {
                      // o is valid only when non-negative, aligned to the
                      // stride, and inside the output range.
                      int64_t o = i - k * D + P;
                      if (o < 0 || o % S) continue;
                      o /= S;
                      if (o >= O) continue;
                      _k[num] = k;
                      _o[num] = o;
                      ++num;
                  }
              };

    // Fast kernel: iterates only over the precomputed valid (o, k) pairs.
    auto ker_fast = [&](float &ds, int64_t g, int64_t mb, int64_t ic,
                            int64_t id, int64_t ih, int64_t iw) {
        int64_t kd[precompute_size], od[precompute_size], num_d;
        int64_t kh[precompute_size], oh[precompute_size], num_h;
        int64_t kw[precompute_size], ow[precompute_size], num_w;
        precompute_ok(id, OD, KD, SD, PD, DD, num_d, od, kd);
        precompute_ok(ih, OH, KH, SH, PH, DH, num_h, oh, kh);
        precompute_ok(iw, OW, KW, SW, PW, DW, num_w, ow, kw);

        // Bases for this (mb, group) slab of diff_dst and this (g, ic)
        // column of the weights.
        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;

        for_(int64_t d = 0; d < num_d; ++d)
        for_(int64_t h = 0; h < num_h; ++h)
        for (int64_t w = 0; w < num_w; ++w) {
            for (int64_t oc = 0; oc < OCG; ++oc) {
                const int64_t diff_dst_off
                        = ((oc * OD + od[d]) * OH + oh[h]) * OW + ow[w];
                const int64_t wei_off
                        = ((oc * ICG * KD + kd[d]) * KH + kh[h]) * KW + kw[w];
                float diff_dst_val = diff_dst_loc[diff_dst_off];
                maybe_zero_point(prb->attr, diff_dst_val, prb->src_zp,
                        g * OCG + oc, map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));
                maybe_scale(prb->attr, diff_dst_val, prb->src_scales,
                        g * OCG + oc, map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));

                float wei_val = wei_loc[wei_off];
                maybe_scale(prb->attr, wei_val, prb->wei_scales, g * ICG + ic,
                        DNNL_ARG_WEIGHTS);
                ds += diff_dst_val * wei_val;
            }
        }
    };

    // General kernel: same math as ker_fast, but derives the valid output
    // index inline for each kernel tap (no precompute_size limit).
    auto ker = [&](float &ds, int64_t g, int64_t mb, int64_t ic, int64_t id,
                       int64_t ih, int64_t iw) {
        const float *__restrict diff_dst_loc = (const float *)diff_dst_m
                + (mb * OC + g * OCG) * OD * OH * OW;
        const float *__restrict wei_loc
                = (const float *)wei_m + ((g * OCG) * ICG + ic) * KD * KH * KW;

        for (int64_t kd = 0; kd < KD; ++kd) {
            int64_t od = id - kd * DD + PD;
            if (od < 0 || od % SD || od >= OD * SD) continue;
            od /= SD;
            for (int64_t kh = 0; kh < KH; ++kh) {
                int64_t oh = ih - kh * DH + PH;
                if (oh < 0 || oh % SH || oh >= OH * SH) continue;
                oh /= SH;
                for (int64_t kw = 0; kw < KW; ++kw) {
                    int64_t ow = iw - kw * DW + PW;
                    if (ow < 0 || ow % SW || ow >= OW * SW) continue;
                    ow /= SW;
                    for (int64_t oc = 0; oc < OCG; ++oc) {
                        const int64_t diff_dst_off
                                = ((oc * OD + od) * OH + oh) * OW + ow;
                        const int64_t wei_off
                                = ((oc * ICG * KD + kd) * KH + kh) * KW + kw;
                        float diff_dst_val = diff_dst_loc[diff_dst_off];
                        maybe_zero_point(prb->attr, diff_dst_val, prb->src_zp,
                                g * OCG + oc,
                                map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));
                        maybe_scale(prb->attr, diff_dst_val, prb->src_scales,
                                g * OCG + oc,
                                map_arg_to_zp_arg(DNNL_ARG_DIFF_DST));

                        float wei_val = wei_loc[wei_off];
                        maybe_scale(prb->attr, wei_val, prb->wei_scales,
                                g * ICG + ic, DNNL_ARG_WEIGHTS);

                        ds += diff_dst_val * wei_val;
                    }
                }
            }
        }
    };

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    // One independent accumulation per diff_src element.
    benchdnn_parallel_nd(G, MB, ICG, ID, IH, IW,
            [&](int64_t g, int64_t mb, int64_t ic, int64_t id, int64_t ih,
                    int64_t iw) {
                size_t src_off = src_off_f(prb, mb, g, ic, id, ih, iw);
                float &ds = ((float *)diff_src_m)[src_off];
                float conv_res = 0;
                if (fast)
                    ker_fast(conv_res, g, mb, ic, id, ih, iw);
                else
                    ker(conv_res, g, mb, ic, id, ih, iw);

                if (prb->dir & FLAG_BIA) {
                    const size_t bia_off = (size_t)g * ICG + ic;
                    conv_res += ((float *)bia_m)[bia_off];
                }

                const auto v_po_vals = prepare_po_vals(
                        diff_src_m, args, v_po_masks, src_off);

                maybe_post_ops(prb->attr, conv_res, ds, v_po_vals);
                // Output-side quantization is applied last, after post-ops.
                maybe_scale(prb->attr, conv_res, prb->dst_scales, g * ICG + ic,
                        map_arg_to_zp_arg(DNNL_ARG_DIFF_SRC), true);
                maybe_zero_point(prb->attr, conv_res, prb->dst_zp, g * ICG + ic,
                        map_arg_to_zp_arg(DNNL_ARG_DIFF_SRC), true);

                ds = conv_res;
            });
}
263 | |
264 | void compute_ref_bwd_weights(const prb_t *prb, const args_t &args) { |
265 | const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC); |
266 | const dnn_mem_t &diff_wei_m = args.find(DNNL_ARG_DIFF_WEIGHTS); |
267 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
268 | /* help compiler optimize the code */ |
269 | const int64_t MB = prb->mb, G = prb->g, OC = prb->oc, IC = prb->ic; |
270 | const int64_t OCG = OC / G, ICG = IC / G; |
271 | const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow; |
272 | const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw; |
273 | const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw; |
274 | const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw; |
275 | const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw; |
276 | const int64_t DD = prb->dd + 1; |
277 | const int64_t DH = prb->dh + 1; |
278 | const int64_t DW = prb->dw + 1; |
279 | |
280 | auto compute_bounds |
281 | = [](int64_t I, int64_t O, int64_t k, int64_t S, int64_t P, |
282 | int64_t D, int64_t &o_s, int64_t &o_e) { |
283 | const float tmp = P - k * D; |
284 | o_s = MAX2(0, ceilf(tmp / S)); |
285 | o_e = MIN2(O, ceilf((I + tmp) / S)); |
286 | }; |
287 | |
288 | auto ker = [&](float &dw, int64_t g, int64_t oc, int64_t ic, int64_t kd, |
289 | int64_t kh, int64_t kw) { |
290 | int64_t od_s, od_e, oh_s, oh_e, ow_s, ow_e; |
291 | compute_bounds(ID, OD, kd, SD, PD, DD, od_s, od_e); |
292 | compute_bounds(IH, OH, kh, SH, PH, DH, oh_s, oh_e); |
293 | compute_bounds(IW, OW, kw, SW, PW, DW, ow_s, ow_e); |
294 | const int64_t id_s = kd * DD - PD; |
295 | const int64_t ih_s = kh * DH - PH; |
296 | const int64_t iw_s = kw * DW - PW; |
297 | |
298 | for (int64_t mb = 0; mb < MB; ++mb) { |
299 | const float *__restrict diff_dst_loc = (const float *)diff_dst_m |
300 | + (mb * OC + g * OCG + oc) * OD * OH * OW; |
301 | const float *__restrict src_loc = (const float *)src_m |
302 | + (mb * IC + g * ICG + ic) * ID * IH * IW; |
303 | |
304 | for_(int64_t od = od_s; od < od_e; ++od) |
305 | for_(int64_t oh = oh_s; oh < oh_e; ++oh) |
306 | for (int64_t ow = ow_s; ow < ow_e; ++ow) { |
307 | const int64_t id = od * SD + id_s; |
308 | const int64_t ih = oh * SH + ih_s; |
309 | const int64_t iw = ow * SW + iw_s; |
310 | |
311 | size_t diff_dst_off = (od * OH + oh) * OW + ow; |
312 | size_t src_off = (id * IH + ih) * IW + iw; |
313 | dw += diff_dst_loc[diff_dst_off] * src_loc[src_off]; |
314 | } |
315 | } |
316 | }; |
317 | |
318 | benchdnn_parallel_nd(G, OCG, ICG, KD, KH, KW, |
319 | [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh, |
320 | int64_t kw) { |
321 | size_t wei_off = wei_off_f(prb, g, oc, ic, kd, kh, kw); |
322 | float &dw = ((float *)diff_wei_m)[wei_off]; |
323 | dw = 0; |
324 | ker(dw, g, oc, ic, kd, kh, kw); |
325 | }); |
326 | } |
327 | |
328 | void compute_ref_bwd_bias(const prb_t *prb, const args_t &args) { |
329 | const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS); |
330 | const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST); |
331 | /* help compiler optimize the code */ |
332 | const int64_t MB = prb->mb, G = prb->g, OC = prb->oc; |
333 | const int64_t OCG = OC / G; |
334 | const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow; |
335 | |
336 | benchdnn_parallel_nd(G, OCG, [&](int64_t g, int64_t oc) { |
337 | size_t bia_off = bia_off_f(prb, g, oc); |
338 | double sum = 0; |
339 | |
340 | for_(int64_t mb = 0; mb < MB; ++mb) |
341 | for_(int64_t od = 0; od < OD; ++od) |
342 | for_(int64_t oh = 0; oh < OH; ++oh) |
343 | for (int64_t ow = 0; ow < OW; ++ow) { |
344 | size_t dst_off = dst_off_f(prb, mb, g, oc, od, oh, ow); |
345 | sum += ((float *)diff_dst_m)[dst_off]; |
346 | } |
347 | ((float *)diff_bia_m)[bia_off] = (float)sum; |
348 | }); |
349 | } |
350 | |
351 | void compute_ref_fwd( |
352 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
353 | if (prim_ref) { |
354 | SAFE_V(execute_and_wait(prim_ref, args)); |
355 | return; |
356 | } |
357 | |
358 | // Swap arguments to re-use existing conv ref implementation. |
359 | args_t ref_conv_args; |
360 | for (int i = 0; i < args.size(); i++) { |
361 | if (args.arg(i) == DNNL_ARG_SRC) |
362 | ref_conv_args.set(DNNL_ARG_DIFF_DST, args.dnn_mem(i)); |
363 | else if (args.arg(i) == DNNL_ARG_WEIGHTS) |
364 | ref_conv_args.set(DNNL_ARG_WEIGHTS, args.find(DNNL_ARG_WEIGHTS_1)); |
365 | else if (args.arg(i) == DNNL_ARG_DST) |
366 | ref_conv_args.set(DNNL_ARG_DIFF_SRC, args.dnn_mem(i)); |
367 | else |
368 | ref_conv_args.set(args.arg(i), args.dnn_mem(i)); |
369 | } |
370 | |
371 | if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) { |
372 | compute_wino_ref_bwd_d(prb, ref_conv_args); |
373 | } else { |
374 | compute_ref_direct_bwd_d(prb, ref_conv_args); |
375 | } |
376 | } |
377 | |
378 | void compute_ref_bwd_d( |
379 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
380 | if (prim_ref) { |
381 | SAFE_V(execute_and_wait(prim_ref, args)); |
382 | return; |
383 | } |
384 | |
385 | // Swap arguments to re-use existing conv ref implementation. |
386 | args_t ref_conv_args; |
387 | for (int i = 0; i < args.size(); i++) { |
388 | if (args.arg(i) == DNNL_ARG_DIFF_SRC) |
389 | ref_conv_args.set(DNNL_ARG_DST, args.dnn_mem(i)); |
390 | else if (args.arg(i) == DNNL_ARG_WEIGHTS) |
391 | ref_conv_args.set(DNNL_ARG_WEIGHTS, args.find(DNNL_ARG_WEIGHTS_1)); |
392 | else if (args.arg(i) == DNNL_ARG_DIFF_DST) |
393 | ref_conv_args.set(DNNL_ARG_SRC, args.dnn_mem(i)); |
394 | else |
395 | ref_conv_args.set(args.arg(i), args.dnn_mem(i)); |
396 | } |
397 | |
398 | if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) { |
399 | compute_wino_ref_fwd(prb, ref_conv_args); |
400 | } else { |
401 | compute_ref_direct_fwd(prb, ref_conv_args); |
402 | } |
403 | } |
404 | |
// Deconv backward-by-weights reference: computed as conv bwd_w on the
// (already transposed) problem by swapping SRC <-> DIFF_DST and using the
// dedicated transposed-weights memory (DNNL_ARG_DIFF_WEIGHTS_1). The
// result is transposed back for comparison, and diff_bias (if requested)
// is reduced here directly from the original diff_dst memory.
void compute_ref_bwd_w(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    // Delegate to a faster reference primitive when one is provided.
    if (prim_ref) {
        SAFE_V(execute_and_wait(prim_ref, args));
        return;
    }

    // Swap arguments to re-use existing conv ref implementation.
    args_t ref_conv_args;
    for (int i = 0; i < args.size(); i++) {
        if (args.arg(i) == DNNL_ARG_SRC)
            ref_conv_args.set(DNNL_ARG_DIFF_DST, args.dnn_mem(i));
        else if (args.arg(i) == DNNL_ARG_DIFF_WEIGHTS)
            ref_conv_args.set(
                    DNNL_ARG_DIFF_WEIGHTS, args.find(DNNL_ARG_DIFF_WEIGHTS_1));
        else if (args.arg(i) == DNNL_ARG_DIFF_DST)
            ref_conv_args.set(DNNL_ARG_SRC, args.dnn_mem(i));
        else
            ref_conv_args.set(args.arg(i), args.dnn_mem(i));
    }

    if (prb->alg == WINO && prb->cfg[SRC].dt == dnnl_f32) {
        compute_wino_ref_bwd_w(prb, ref_conv_args);
    } else {
        compute_ref_bwd_weights(prb, ref_conv_args);
    }

    // Need to transpose data in weights back for proper comparison. This step
    // is done here as it's not needed for fast-ref-gpu.
    transpose_data_wei(prb, args.find(DNNL_ARG_DIFF_WEIGHTS_1),
            args.find(DNNL_ARG_DIFF_WEIGHTS));

    // We don't reuse `compute_ref_bwd_bias` as it doesn't match arguments and
    // entry problem which is transposed - `p_tr`. Simpler to use the kernel
    // directly.
    // Take original memories, not `ref_conv_args`.
    if (prb->dir & FLAG_BIA) {
        const dnn_mem_t &diff_bia_m = args.find(DNNL_ARG_DIFF_BIAS);
        const dnn_mem_t &diff_dst_m = args.find(DNNL_ARG_DIFF_DST);
        /* help compiler optimize the code */
        // `prb` here is the transposed problem, so the deconv's output
        // dimensions live in the ic/id/ih/iw fields.
        const int64_t MB = prb->mb, G = prb->g;
        const int64_t OC = prb->ic; // prb.oc = p_tr.ic
        const int64_t OCG = OC / G;
        const int64_t OD = prb->id; // prb.od = p_tr.id
        const int64_t OH = prb->ih; // prb.oh = p_tr.ih
        const int64_t OW = prb->iw; // prb.ow = p_tr.iw

        // diff_bias[g, oc] = sum of diff_dst over mb and output spatial.
        benchdnn_parallel_nd(G, OCG, [&](int64_t g, int64_t oc) {
            size_t bia_off = g * OCG + oc;
            double sum = 0;

            for_(int64_t mb = 0; mb < MB; ++mb)
            for_(int64_t od = 0; od < OD; ++od)
            for_(int64_t oh = 0; oh < OH; ++oh)
            for (int64_t ow = 0; ow < OW; ++ow) {
                // src_off_f instead of dst_off_f due to inverse descriptor.
                size_t dst_off = src_off_f(prb, mb, g, oc, od, oh, ow);
                sum += ((float *)diff_dst_m)[dst_off];
            }
            ((float *)diff_bia_m)[bia_off] = (float)sum;
        });
    }
}
468 | |
469 | void compute_ref( |
470 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
471 | // Update prb descriptor to re-use convolution reference. |
472 | prb_t prb_tr((desc_t)*prb, prb->dir, prb->cfg, prb->stag, prb->wtag, |
473 | prb->dtag, prb->alg, prb->attr, prb->ctx_init, prb->ctx_exe, |
474 | prb->mb); |
475 | std::swap(prb_tr.ic, prb_tr.oc); |
476 | std::swap(prb_tr.ih, prb_tr.oh); |
477 | std::swap(prb_tr.id, prb_tr.od); |
478 | std::swap(prb_tr.iw, prb_tr.ow); |
479 | |
480 | if (prb->dir & FLAG_FWD) |
481 | compute_ref_fwd(&prb_tr, args, prim_ref); |
482 | else if (prb->dir == BWD_D) |
483 | compute_ref_bwd_d(&prb_tr, args, prim_ref); |
484 | else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) |
485 | compute_ref_bwd_w(&prb_tr, args, prim_ref); |
486 | } |
487 | |
488 | } // namespace deconv |
489 | |