1/*******************************************************************************
2* Copyright 2016-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/dnnl_traits.hpp"
20#include "common/math_utils.hpp"
21#include "common/type_helpers.hpp"
22
23#include "cpu/cpu_primitive.hpp"
24#include "cpu/ref_io_helper.hpp"
25
26#include "cpu/ref_convolution.hpp"
27#include "cpu/ref_convolution_utils.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32
33status_t ref_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
34 status_t status = status::success;
35 auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
36 auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
37 auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
38 auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
39 CHECK(status);
40
41 const memory_desc_wrapper src_d(pd()->src_md());
42 const memory_desc_wrapper dst_d(pd()->dst_md());
43 const memory_desc_wrapper weights_d(pd()->weights_md(0));
44 const memory_desc_wrapper bias_d(pd()->weights_md(1));
45
46 const bool with_groups = pd()->with_groups();
47
48 const auto G = pd()->G();
49 const auto MB = pd()->MB();
50 const auto OD = pd()->OD();
51 const auto OH = pd()->OH();
52 const auto OW = pd()->OW();
53 const auto ID = pd()->ID();
54 const auto IH = pd()->IH();
55 const auto IW = pd()->IW();
56
57 const auto OC = pd()->OC() / G;
58 const auto IC = pd()->IC() / G;
59 const auto KD = pd()->KD();
60 const auto KH = pd()->KH();
61 const auto KW = pd()->KW();
62
63 const auto KSD = pd()->KSD();
64 const auto KSH = pd()->KSH();
65 const auto KSW = pd()->KSW();
66
67 const auto KDD = pd()->KDD() + 1;
68 const auto KDH = pd()->KDH() + 1;
69 const auto KDW = pd()->KDW() + 1;
70
71 const auto padFront = pd()->padFront();
72 const auto padT = pd()->padT();
73 const auto padL = pd()->padL();
74
75 const auto ndims = pd()->desc()->src_desc.ndims;
76
77 auto ker = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
78 float d = 0;
79 for_(dim_t ic = 0; ic < IC; ++ic)
80 for_(dim_t kd = 0; kd < KD; ++kd)
81 for_(dim_t kh = 0; kh < KH; ++kh)
82 for (dim_t kw = 0; kw < KW; ++kw) {
83 const dim_t id = od * KSD - padFront + kd * KDD;
84 const dim_t ih = oh * KSH - padT + kh * KDH;
85 const dim_t iw = ow * KSW - padL + kw * KDW;
86
87 if (id < 0 || id >= ID) continue;
88 if (ih < 0 || ih >= IH) continue;
89 if (iw < 0 || iw >= IW) continue;
90
91 const auto src_off = ref_conv_utils::get_data_off(
92 src_d, ndims, mb, g * IC + ic, id, ih, iw);
93 const auto wei_off = ref_conv_utils::get_weights_off(
94 weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
95
96 const float s
97 = io::load_float_value(src_d.data_type(), src, src_off);
98 const float w = io::load_float_value(
99 weights_d.data_type(), weights, wei_off);
100 d += s * w;
101 }
102 return d;
103 };
104
105 // help compiler optimize the code constants for plain layouts kernel
106 const dims_t &src_str = src_d.blocking_desc().strides;
107 const dim_t src_ic_stride = src_str[1];
108 const dim_t src_id_stride = (ndims == 5) ? src_str[2] : 0;
109 const dim_t src_ih_stride = (ndims >= 4) ? src_str[ndims - 2] : 0;
110 const dim_t src_iw_stride = (ndims >= 3) ? src_str[ndims - 1] : 0;
111 const dims_t &weights_str = weights_d.blocking_desc().strides;
112 const int gr_shift = with_groups ? 1 : 0;
113 const dim_t weights_ic_stride = weights_str[1 + gr_shift];
114 const dim_t weights_kd_stride
115 = (ndims == 5) ? weights_str[2 + gr_shift] : 0;
116 const dim_t weights_kh_stride
117 = (ndims >= 4) ? weights_str[ndims - 2 + gr_shift] : 0;
118 const dim_t weights_kw_stride
119 = (ndims >= 3) ? weights_str[ndims - 1 + gr_shift] : 0;
120
121 auto ker_plain = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh,
122 dim_t ow) {
123 assert(3 <= ndims && ndims <= 5);
124 float d = 0;
125
126 const dim_t src_loc_off = ref_conv_utils::get_data_off(
127 src_d, ndims, mb, g * IC, 0, 0, 0);
128 const dim_t weights_loc_off = ref_conv_utils::get_weights_off(
129 weights_d, with_groups, ndims, g, oc, 0, 0, 0, 0);
130
131 const void *__restrict src_loc = src;
132 const void *__restrict weights_loc = weights;
133
134 if (IC > KW) {
135 for_(dim_t kd = 0; kd < KD; ++kd)
136 for_(dim_t kh = 0; kh < KH; ++kh)
137 for (dim_t kw = 0; kw < KW; ++kw) {
138 const dim_t id = od * KSD - padFront + kd * KDD;
139 const dim_t ih = oh * KSH - padT + kh * KDH;
140 const dim_t iw = ow * KSW - padL + kw * KDW;
141 if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0
142 || iw >= IW)
143 continue;
144
145 for (dim_t ic = 0; ic < IC; ++ic) {
146 const dim_t src_off = ic + id * src_id_stride
147 + ih * src_ih_stride + iw * src_iw_stride;
148 const dim_t weights_off = ic * weights_ic_stride
149 + kd * weights_kd_stride + kh * weights_kh_stride
150 + kw;
151 const float s = io::load_float_value(
152 src_d.data_type(), src_loc, src_off + src_loc_off);
153 const float w = io::load_float_value(weights_d.data_type(),
154 weights_loc, weights_off + weights_loc_off);
155 d += s * w;
156 }
157 }
158 } else {
159 for_(dim_t ic = 0; ic < IC; ++ic)
160 for_(dim_t kd = 0; kd < KD; ++kd)
161 for_(dim_t kh = 0; kh < KH; ++kh)
162 for (dim_t kw = 0; kw < KW; ++kw) {
163 const dim_t id = od * KSD - padFront + kd * KDD;
164 const dim_t ih = oh * KSH - padT + kh * KDH;
165 const dim_t iw = ow * KSW - padL + kw * KDW;
166 if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0
167 || iw >= IW)
168 continue;
169
170 const dim_t src_off = ic + id * src_id_stride
171 + ih * src_ih_stride + iw * src_iw_stride;
172 const dim_t weights_off = ic * weights_ic_stride
173 + kd * weights_kd_stride + kh * weights_kh_stride + kw;
174 const float s = io::load_float_value(
175 src_d.data_type(), src_loc, src_off + src_loc_off);
176 const float w = io::load_float_value(weights_d.data_type(),
177 weights_loc, weights_off + weights_loc_off);
178 d += s * w;
179 }
180 }
181 return d;
182 };
183
184 auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type());
185
186 parallel_nd(G, MB, OC, OD, OH, OW,
187 [&](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
188 float acc = 0;
189 if (src_d.is_plain() && weights_d.is_plain()
190 && src_ic_stride == 1 && weights_kw_stride == 1)
191 acc += ker_plain(g, mb, oc, od, oh, ow);
192 else
193 acc += ker(g, mb, oc, od, oh, ow);
194
195 float d = acc;
196 if (bias) {
197 const auto bias_off = bias_d.off(g * OC + oc);
198 const float b = io::load_float_value(
199 bias_d.data_type(), bias, bias_off);
200 d += b;
201 }
202
203 dim_t dst_off = ref_conv_utils::get_data_off(
204 dst_d, ndims, mb, g * OC + oc, od, oh, ow);
205
206 dim_t dst_l_off = (mb * OC * G + g * OC + oc) * OD * OH * OW
207 + od * OH * OW + oh * OW + ow;
208
209 ref_post_ops_t::args_t args;
210 args.dst_val = io::load_float_value(sum_dt, dst, dst_off);
211 args.ctx = &ctx;
212 args.l_offset = dst_l_off;
213 args.dst_md = pd()->dst_md();
214 ref_post_ops->execute(d, args);
215
216 io::store_float_value(dst_d.data_type(), d, dst, dst_off);
217 });
218
219 return status::success;
220}
221
222status_t ref_convolution_bwd_data_t::execute_backward_data(
223 const exec_ctx_t &ctx) const {
224 status_t status = status::success;
225 auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
226 auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
227 auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status);
228 CHECK(status);
229
230 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
231 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
232 const memory_desc_wrapper weights_d(pd()->weights_md(0));
233
234 const bool with_groups = pd()->with_groups();
235
236 const auto G = pd()->G();
237 const auto MB = pd()->MB();
238 const auto OD = pd()->OD();
239 const auto OH = pd()->OH();
240 const auto OW = pd()->OW();
241 const auto ID = pd()->ID();
242 const auto IH = pd()->IH();
243 const auto IW = pd()->IW();
244
245 const auto OC = pd()->OC() / G;
246 const auto IC = pd()->IC() / G;
247 const auto KD = pd()->KD();
248 const auto KH = pd()->KH();
249 const auto KW = pd()->KW();
250
251 const auto KSD = pd()->KSD();
252 const auto KSH = pd()->KSH();
253 const auto KSW = pd()->KSW();
254
255 const auto KDD = pd()->KDD() + 1;
256 const auto KDH = pd()->KDH() + 1;
257 const auto KDW = pd()->KDW() + 1;
258
259 const auto padFront = pd()->padFront();
260 const auto padT = pd()->padT();
261 const auto padL = pd()->padL();
262
263 const auto ndims = pd()->desc()->diff_src_desc.ndims;
264
265 auto ker = [=](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) {
266 float ds = 0;
267 for_(dim_t oc = 0; oc < OC; ++oc)
268 for_(dim_t kd = 0; kd < KD; ++kd)
269 for_(dim_t kh = 0; kh < KH; ++kh)
270 for (dim_t kw = 0; kw < KW; ++kw) {
271 if (iw + padL < kw * KDW || ih + padT < kh * KDH
272 || id + padFront < kd * KDD)
273 continue;
274 dim_t ow = iw - kw * KDW + padL;
275 dim_t oh = ih - kh * KDH + padT;
276 dim_t od = id - kd * KDD + padFront;
277 if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0) continue;
278
279 ow /= KSW;
280 oh /= KSH;
281 od /= KSD;
282
283 if (od < OD && oh < OH && ow < OW) {
284 const auto diff_dst_off = ref_conv_utils::get_data_off(
285 diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow);
286 const auto weights_off = ref_conv_utils::get_weights_off(
287 weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
288 const float dd = io::load_float_value(
289 diff_dst_d.data_type(), diff_dst, diff_dst_off);
290 const float w = io::load_float_value(
291 weights_d.data_type(), weights, weights_off);
292 ds += dd * w;
293 }
294 }
295 return ds;
296 };
297
298 // help compiler optimize the code constants for plain layouts kernel
299 const dims_t &diff_dst_str = diff_dst_d.blocking_desc().strides;
300 const dim_t diff_dst_oc_stride = diff_dst_str[1];
301 const dim_t diff_dst_ow_stride = diff_dst_str[ndims - 1];
302 const dim_t diff_dst_oh_stride = (ndims >= 4) ? diff_dst_str[ndims - 2] : 0;
303 const dim_t diff_dst_od_stride = (ndims >= 5) ? diff_dst_str[ndims - 3] : 0;
304
305 const dims_t &weights_str = weights_d.blocking_desc().strides;
306 const int gr_shift = with_groups ? 1 : 0;
307 const dim_t weights_oc_stride = weights_str[0 + gr_shift];
308 const dim_t weights_kw_stride = weights_str[ndims - 1 + gr_shift];
309 const dim_t weights_kh_stride
310 = (ndims >= 4) ? weights_str[ndims - 2 + gr_shift] : 0;
311 const dim_t weights_kd_stride
312 = (ndims >= 5) ? weights_str[ndims - 3 + gr_shift] : 0;
313
314 auto ker_plain = [=](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih,
315 dim_t iw) {
316 assert(3 <= ndims && ndims <= 5);
317 float ds = 0;
318 const dim_t diff_dst_loc_off = ref_conv_utils::get_data_off(
319 diff_dst_d, ndims, mb, g * OC, 0, 0, 0);
320 const dim_t weights_loc_off = ref_conv_utils::get_weights_off(
321 weights_d, with_groups, ndims, g, 0, ic, 0, 0, 0);
322
323 const void *__restrict diff_dst_loc = diff_dst;
324 const void *__restrict weights_loc = weights;
325
326 if (OC > KW) {
327 for_(dim_t kd = 0; kd < KD; ++kd)
328 for_(dim_t kh = 0; kh < KH; ++kh)
329 for (dim_t kw = 0; kw < KW; ++kw) {
330 dim_t ow = iw - kw * KDW + padL;
331 dim_t oh = ih - kh * KDH + padT;
332 dim_t od = id - kd * KDD + padFront;
333 if (ow < 0 || oh < 0 || od < 0 || ow % KSW != 0 || oh % KSH != 0
334 || od % KSD != 0)
335 continue;
336 ow /= KSW;
337 oh /= KSH;
338 od /= KSD;
339 if (od >= OD || oh >= OH || ow >= OW) continue;
340 for (dim_t oc = 0; oc < OC; ++oc) {
341 const dim_t diff_dst_off = oc + od * diff_dst_od_stride
342 + oh * diff_dst_oh_stride + ow * diff_dst_ow_stride;
343 const dim_t weights_off = oc * weights_oc_stride
344 + kd * weights_kd_stride + kh * weights_kh_stride
345 + kw;
346 const float dd = io::load_float_value(
347 diff_dst_d.data_type(), diff_dst_loc,
348 diff_dst_off + diff_dst_loc_off);
349 const float w = io::load_float_value(weights_d.data_type(),
350 weights_loc, weights_off + weights_loc_off);
351 ds += dd * w;
352 }
353 }
354 } else {
355 for_(dim_t oc = 0; oc < OC; ++oc)
356 for_(dim_t kd = 0; kd < KD; ++kd)
357 for (dim_t kh = 0; kh < KH; ++kh) {
358 // Note: placing these 2 params outside the `kw-loop` because
359 // of a compiler-generated bug. Declaring 'od' as volatile
360 // fixes a recurring seg-fault.
361 const volatile dim_t od_ = id - kd * KDD + padFront;
362 const dim_t weights_off_ = oc * weights_oc_stride
363 + kd * weights_kd_stride + kh * weights_kh_stride;
364 for (dim_t kw = 0; kw < KW; ++kw) {
365 dim_t ow = iw - kw * KDW + padL;
366 dim_t oh = ih - kh * KDH + padT;
367 dim_t od = od_;
368 if (ow < 0 || oh < 0 || od < 0 || ow % KSW != 0
369 || oh % KSH != 0 || od % KSD != 0)
370 continue;
371 ow /= KSW;
372 oh /= KSH;
373 od /= KSD;
374 if (od >= OD || oh >= OH || ow >= OW) continue;
375 const dim_t diff_dst_off = oc + od * diff_dst_od_stride
376 + oh * diff_dst_oh_stride + ow * diff_dst_ow_stride;
377 const dim_t weights_off = weights_off_ + kw;
378 const float dd = io::load_float_value(
379 diff_dst_d.data_type(), diff_dst_loc,
380 diff_dst_off + diff_dst_loc_off);
381 const float w = io::load_float_value(weights_d.data_type(),
382 weights_loc, weights_off + weights_loc_off);
383 ds += dd * w;
384 }
385 }
386 }
387 return ds;
388 };
389
390 parallel_nd(G, MB, IC, ID, IH, IW,
391 [&](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) {
392 float ds = 0;
393 if (diff_dst_d.is_plain() && weights_d.is_plain()
394 && diff_dst_oc_stride == 1 && weights_kw_stride == 1)
395 ds += ker_plain(g, mb, ic, id, ih, iw);
396 else
397 ds += ker(g, mb, ic, id, ih, iw);
398
399 const auto diff_src_off = ref_conv_utils::get_data_off(
400 diff_src_d, ndims, mb, g * IC + ic, id, ih, iw);
401 io::store_float_value(
402 diff_src_d.data_type(), ds, diff_src, diff_src_off);
403 });
404
405 return status::success;
406}
407
408status_t ref_convolution_bwd_weights_t::execute_backward_weights(
409 const exec_ctx_t &ctx) const {
410 status_t status = status::success;
411 auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
412 auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
413 auto diff_weights
414 = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_WEIGHTS, status);
415 CHECK(status);
416 auto diff_bias = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_BIAS, status);
417 CHECK(status);
418
419 const memory_desc_wrapper src_d(pd()->src_md());
420 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
421 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
422 const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
423
424 const bool with_groups = pd()->with_groups();
425
426 const auto G = pd()->G();
427 const auto MB = pd()->MB();
428 const auto OD = pd()->OD();
429 const auto OH = pd()->OH();
430 const auto OW = pd()->OW();
431 const auto ID = pd()->ID();
432 const auto IH = pd()->IH();
433 const auto IW = pd()->IW();
434
435 const auto OC = pd()->OC() / G;
436 const auto IC = pd()->IC() / G;
437 const auto KD = pd()->KD();
438 const auto KH = pd()->KH();
439 const auto KW = pd()->KW();
440
441 const auto KSD = pd()->KSD();
442 const auto KSH = pd()->KSH();
443 const auto KSW = pd()->KSW();
444
445 const auto KDD = pd()->KDD() + 1;
446 const auto KDH = pd()->KDH() + 1;
447 const auto KDW = pd()->KDW() + 1;
448
449 const auto padFront = pd()->padFront();
450 const auto padT = pd()->padT();
451 const auto padL = pd()->padL();
452
453 const auto ndims = pd()->desc()->src_desc.ndims;
454
455 auto ker = [=](float &dw, dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh,
456 dim_t kw) {
457 for_(dim_t mb = 0; mb < MB; ++mb)
458 for_(dim_t od = 0; od < OD; ++od)
459 for_(dim_t oh = 0; oh < OH; ++oh)
460 for (dim_t ow = 0; ow < OW; ++ow) {
461 if (ow * KSW + kw * KDW < padL || oh * KSH + kh * KDH < padT
462 || od * KSD + kd * KDD < padFront
463 || ow * KSW + kw * KDW >= IW + padL
464 || oh * KSH + kh * KDH >= IH + padT
465 || od * KSD + kd * KDD >= ID + padFront)
466 continue;
467
468 dim_t id = od * KSD - padFront + kd * KDD;
469 dim_t ih = oh * KSH - padT + kh * KDH;
470 dim_t iw = ow * KSW - padL + kw * KDW;
471
472 const auto diff_dst_off = ref_conv_utils::get_data_off(
473 diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow);
474 const auto src_off = ref_conv_utils::get_data_off(
475 src_d, ndims, mb, g * IC + ic, id, ih, iw);
476 float dd = io::load_float_value(
477 diff_dst_d.data_type(), diff_dst, diff_dst_off);
478 float s = io::load_float_value(src_d.data_type(), src, src_off);
479 dw += dd * s;
480 }
481 };
482
483 auto ker_plain = [=](float &dw, dim_t g, dim_t oc, dim_t ic, dim_t kd,
484 dim_t kh, dim_t kw) {
485 assert(3 <= ndims && ndims <= 5);
486 // help compiler optimize the code constants for plain layouts kernel
487 const dims_t &diff_dst_str = diff_dst_d.blocking_desc().strides;
488 const dim_t diff_dst_mb_stride = diff_dst_str[0];
489 const dim_t diff_dst_ow_stride = diff_dst_str[ndims - 1];
490 const dim_t diff_dst_oh_stride
491 = (ndims >= 4) ? diff_dst_str[ndims - 2] : 0;
492 const dim_t diff_dst_od_stride
493 = (ndims >= 5) ? diff_dst_str[ndims - 3] : 0;
494 const dims_t &src_str = src_d.blocking_desc().strides;
495 const dim_t src_mb_stride = src_str[0];
496 const dim_t src_iw_stride = src_str[ndims - 1];
497 const dim_t src_ih_stride = (ndims >= 4) ? src_str[ndims - 2] : 0;
498 const dim_t src_id_stride = (ndims >= 5) ? src_str[ndims - 3] : 0;
499
500 const dim_t diff_dst_loc_off = ref_conv_utils::get_data_off(
501 diff_dst_d, ndims, 0, g * OC + oc, 0, 0, 0);
502 const dim_t src_loc_off = ref_conv_utils::get_data_off(
503 src_d, ndims, 0, g * IC + ic, 0, 0, 0);
504
505 const void *__restrict diff_dst_loc = diff_dst;
506 const void *__restrict src_loc = src;
507
508 for_(dim_t mb = 0; mb < MB; ++mb)
509 for_(dim_t od = 0; od < OD; ++od)
510 for_(dim_t oh = 0; oh < OH; ++oh)
511 for (dim_t ow = 0; ow < OW; ++ow) {
512 const dim_t id = od * KSD - padFront + kd * KDD;
513 const dim_t ih = oh * KSH - padT + kh * KDH;
514 const dim_t iw = ow * KSW - padL + kw * KDW;
515 if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 || iw >= IW)
516 continue;
517 const dim_t diff_dst_off = mb * diff_dst_mb_stride
518 + od * diff_dst_od_stride + oh * diff_dst_oh_stride
519 + ow * diff_dst_ow_stride;
520 const dim_t src_off = mb * src_mb_stride + id * src_id_stride
521 + ih * src_ih_stride + iw * src_iw_stride;
522 float dd = io::load_float_value(diff_dst_d.data_type(),
523 diff_dst_loc, diff_dst_off + diff_dst_loc_off);
524 float s = io::load_float_value(
525 src_d.data_type(), src_loc, src_off + src_loc_off);
526 dw += dd * s;
527 }
528 };
529
530 auto ker_bias = [=](float &db, dim_t g, dim_t oc) {
531 for_(dim_t mb = 0; mb < MB; ++mb)
532 for_(dim_t od = 0; od < OD; ++od)
533 for_(dim_t oh = 0; oh < OH; ++oh)
534 for (dim_t ow = 0; ow < OW; ++ow) {
535 const auto diff_dst_off = ref_conv_utils::get_data_off(
536 diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow);
537 const float dd = io::load_float_value(
538 diff_dst_d.data_type(), diff_dst, diff_dst_off);
539 db += dd;
540 }
541 };
542
543 parallel_nd(G, OC, [&](dim_t g, dim_t oc) {
544 if (diff_bias) {
545 float db = 0;
546 ker_bias(db, g, oc);
547 const auto diff_bias_off = diff_bias_d.off(g * OC + oc);
548 io::store_float_value(
549 diff_bias_d.data_type(), db, diff_bias, diff_bias_off);
550 }
551
552 for_(dim_t ic = 0; ic < IC; ++ic)
553 for_(dim_t kd = 0; kd < KD; ++kd)
554 for_(dim_t kh = 0; kh < KH; ++kh)
555 for (dim_t kw = 0; kw < KW; ++kw) {
556 float dw = 0;
557 if (diff_dst_d.is_plain() && src_d.is_plain())
558 ker_plain(dw, g, oc, ic, kd, kh, kw);
559 else
560 ker(dw, g, oc, ic, kd, kh, kw);
561
562 const dim_t diff_weights_off = ref_conv_utils::get_weights_off(
563 diff_weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
564 io::store_float_value(diff_weights_d.data_type(), dw, diff_weights,
565 diff_weights_off);
566 }
567 });
568
569 return status::success;
570}
571
572} // namespace cpu
573} // namespace impl
574} // namespace dnnl
575
576// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
577