1 | /******************************************************************************* |
2 | * Copyright 2016-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/dnnl_traits.hpp" |
20 | #include "common/math_utils.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | |
23 | #include "cpu/cpu_primitive.hpp" |
24 | #include "cpu/ref_io_helper.hpp" |
25 | |
26 | #include "cpu/ref_convolution.hpp" |
27 | #include "cpu/ref_convolution_utils.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | |
33 | status_t ref_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { |
34 | status_t status = status::success; |
35 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
36 | auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS); |
37 | auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
38 | auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status); |
39 | CHECK(status); |
40 | |
41 | const memory_desc_wrapper src_d(pd()->src_md()); |
42 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
43 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
44 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
45 | |
46 | const bool with_groups = pd()->with_groups(); |
47 | |
48 | const auto G = pd()->G(); |
49 | const auto MB = pd()->MB(); |
50 | const auto OD = pd()->OD(); |
51 | const auto OH = pd()->OH(); |
52 | const auto OW = pd()->OW(); |
53 | const auto ID = pd()->ID(); |
54 | const auto IH = pd()->IH(); |
55 | const auto IW = pd()->IW(); |
56 | |
57 | const auto OC = pd()->OC() / G; |
58 | const auto IC = pd()->IC() / G; |
59 | const auto KD = pd()->KD(); |
60 | const auto KH = pd()->KH(); |
61 | const auto KW = pd()->KW(); |
62 | |
63 | const auto KSD = pd()->KSD(); |
64 | const auto KSH = pd()->KSH(); |
65 | const auto KSW = pd()->KSW(); |
66 | |
67 | const auto KDD = pd()->KDD() + 1; |
68 | const auto KDH = pd()->KDH() + 1; |
69 | const auto KDW = pd()->KDW() + 1; |
70 | |
71 | const auto padFront = pd()->padFront(); |
72 | const auto padT = pd()->padT(); |
73 | const auto padL = pd()->padL(); |
74 | |
75 | const auto ndims = pd()->desc()->src_desc.ndims; |
76 | |
77 | auto ker = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) { |
78 | float d = 0; |
79 | for_(dim_t ic = 0; ic < IC; ++ic) |
80 | for_(dim_t kd = 0; kd < KD; ++kd) |
81 | for_(dim_t kh = 0; kh < KH; ++kh) |
82 | for (dim_t kw = 0; kw < KW; ++kw) { |
83 | const dim_t id = od * KSD - padFront + kd * KDD; |
84 | const dim_t ih = oh * KSH - padT + kh * KDH; |
85 | const dim_t iw = ow * KSW - padL + kw * KDW; |
86 | |
87 | if (id < 0 || id >= ID) continue; |
88 | if (ih < 0 || ih >= IH) continue; |
89 | if (iw < 0 || iw >= IW) continue; |
90 | |
91 | const auto src_off = ref_conv_utils::get_data_off( |
92 | src_d, ndims, mb, g * IC + ic, id, ih, iw); |
93 | const auto wei_off = ref_conv_utils::get_weights_off( |
94 | weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw); |
95 | |
96 | const float s |
97 | = io::load_float_value(src_d.data_type(), src, src_off); |
98 | const float w = io::load_float_value( |
99 | weights_d.data_type(), weights, wei_off); |
100 | d += s * w; |
101 | } |
102 | return d; |
103 | }; |
104 | |
105 | // help compiler optimize the code constants for plain layouts kernel |
106 | const dims_t &src_str = src_d.blocking_desc().strides; |
107 | const dim_t src_ic_stride = src_str[1]; |
108 | const dim_t src_id_stride = (ndims == 5) ? src_str[2] : 0; |
109 | const dim_t src_ih_stride = (ndims >= 4) ? src_str[ndims - 2] : 0; |
110 | const dim_t src_iw_stride = (ndims >= 3) ? src_str[ndims - 1] : 0; |
111 | const dims_t &weights_str = weights_d.blocking_desc().strides; |
112 | const int gr_shift = with_groups ? 1 : 0; |
113 | const dim_t weights_ic_stride = weights_str[1 + gr_shift]; |
114 | const dim_t weights_kd_stride |
115 | = (ndims == 5) ? weights_str[2 + gr_shift] : 0; |
116 | const dim_t weights_kh_stride |
117 | = (ndims >= 4) ? weights_str[ndims - 2 + gr_shift] : 0; |
118 | const dim_t weights_kw_stride |
119 | = (ndims >= 3) ? weights_str[ndims - 1 + gr_shift] : 0; |
120 | |
121 | auto ker_plain = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, |
122 | dim_t ow) { |
123 | assert(3 <= ndims && ndims <= 5); |
124 | float d = 0; |
125 | |
126 | const dim_t src_loc_off = ref_conv_utils::get_data_off( |
127 | src_d, ndims, mb, g * IC, 0, 0, 0); |
128 | const dim_t weights_loc_off = ref_conv_utils::get_weights_off( |
129 | weights_d, with_groups, ndims, g, oc, 0, 0, 0, 0); |
130 | |
131 | const void *__restrict src_loc = src; |
132 | const void *__restrict weights_loc = weights; |
133 | |
134 | if (IC > KW) { |
135 | for_(dim_t kd = 0; kd < KD; ++kd) |
136 | for_(dim_t kh = 0; kh < KH; ++kh) |
137 | for (dim_t kw = 0; kw < KW; ++kw) { |
138 | const dim_t id = od * KSD - padFront + kd * KDD; |
139 | const dim_t ih = oh * KSH - padT + kh * KDH; |
140 | const dim_t iw = ow * KSW - padL + kw * KDW; |
141 | if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 |
142 | || iw >= IW) |
143 | continue; |
144 | |
145 | for (dim_t ic = 0; ic < IC; ++ic) { |
146 | const dim_t src_off = ic + id * src_id_stride |
147 | + ih * src_ih_stride + iw * src_iw_stride; |
148 | const dim_t weights_off = ic * weights_ic_stride |
149 | + kd * weights_kd_stride + kh * weights_kh_stride |
150 | + kw; |
151 | const float s = io::load_float_value( |
152 | src_d.data_type(), src_loc, src_off + src_loc_off); |
153 | const float w = io::load_float_value(weights_d.data_type(), |
154 | weights_loc, weights_off + weights_loc_off); |
155 | d += s * w; |
156 | } |
157 | } |
158 | } else { |
159 | for_(dim_t ic = 0; ic < IC; ++ic) |
160 | for_(dim_t kd = 0; kd < KD; ++kd) |
161 | for_(dim_t kh = 0; kh < KH; ++kh) |
162 | for (dim_t kw = 0; kw < KW; ++kw) { |
163 | const dim_t id = od * KSD - padFront + kd * KDD; |
164 | const dim_t ih = oh * KSH - padT + kh * KDH; |
165 | const dim_t iw = ow * KSW - padL + kw * KDW; |
166 | if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 |
167 | || iw >= IW) |
168 | continue; |
169 | |
170 | const dim_t src_off = ic + id * src_id_stride |
171 | + ih * src_ih_stride + iw * src_iw_stride; |
172 | const dim_t weights_off = ic * weights_ic_stride |
173 | + kd * weights_kd_stride + kh * weights_kh_stride + kw; |
174 | const float s = io::load_float_value( |
175 | src_d.data_type(), src_loc, src_off + src_loc_off); |
176 | const float w = io::load_float_value(weights_d.data_type(), |
177 | weights_loc, weights_off + weights_loc_off); |
178 | d += s * w; |
179 | } |
180 | } |
181 | return d; |
182 | }; |
183 | |
184 | auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type()); |
185 | |
186 | parallel_nd(G, MB, OC, OD, OH, OW, |
187 | [&](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) { |
188 | float acc = 0; |
189 | if (src_d.is_plain() && weights_d.is_plain() |
190 | && src_ic_stride == 1 && weights_kw_stride == 1) |
191 | acc += ker_plain(g, mb, oc, od, oh, ow); |
192 | else |
193 | acc += ker(g, mb, oc, od, oh, ow); |
194 | |
195 | float d = acc; |
196 | if (bias) { |
197 | const auto bias_off = bias_d.off(g * OC + oc); |
198 | const float b = io::load_float_value( |
199 | bias_d.data_type(), bias, bias_off); |
200 | d += b; |
201 | } |
202 | |
203 | dim_t dst_off = ref_conv_utils::get_data_off( |
204 | dst_d, ndims, mb, g * OC + oc, od, oh, ow); |
205 | |
206 | dim_t dst_l_off = (mb * OC * G + g * OC + oc) * OD * OH * OW |
207 | + od * OH * OW + oh * OW + ow; |
208 | |
209 | ref_post_ops_t::args_t args; |
210 | args.dst_val = io::load_float_value(sum_dt, dst, dst_off); |
211 | args.ctx = &ctx; |
212 | args.l_offset = dst_l_off; |
213 | args.dst_md = pd()->dst_md(); |
214 | ref_post_ops->execute(d, args); |
215 | |
216 | io::store_float_value(dst_d.data_type(), d, dst, dst_off); |
217 | }); |
218 | |
219 | return status::success; |
220 | } |
221 | |
222 | status_t ref_convolution_bwd_data_t::execute_backward_data( |
223 | const exec_ctx_t &ctx) const { |
224 | status_t status = status::success; |
225 | auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST); |
226 | auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS); |
227 | auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status); |
228 | CHECK(status); |
229 | |
230 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
231 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
232 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
233 | |
234 | const bool with_groups = pd()->with_groups(); |
235 | |
236 | const auto G = pd()->G(); |
237 | const auto MB = pd()->MB(); |
238 | const auto OD = pd()->OD(); |
239 | const auto OH = pd()->OH(); |
240 | const auto OW = pd()->OW(); |
241 | const auto ID = pd()->ID(); |
242 | const auto IH = pd()->IH(); |
243 | const auto IW = pd()->IW(); |
244 | |
245 | const auto OC = pd()->OC() / G; |
246 | const auto IC = pd()->IC() / G; |
247 | const auto KD = pd()->KD(); |
248 | const auto KH = pd()->KH(); |
249 | const auto KW = pd()->KW(); |
250 | |
251 | const auto KSD = pd()->KSD(); |
252 | const auto KSH = pd()->KSH(); |
253 | const auto KSW = pd()->KSW(); |
254 | |
255 | const auto KDD = pd()->KDD() + 1; |
256 | const auto KDH = pd()->KDH() + 1; |
257 | const auto KDW = pd()->KDW() + 1; |
258 | |
259 | const auto padFront = pd()->padFront(); |
260 | const auto padT = pd()->padT(); |
261 | const auto padL = pd()->padL(); |
262 | |
263 | const auto ndims = pd()->desc()->diff_src_desc.ndims; |
264 | |
265 | auto ker = [=](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) { |
266 | float ds = 0; |
267 | for_(dim_t oc = 0; oc < OC; ++oc) |
268 | for_(dim_t kd = 0; kd < KD; ++kd) |
269 | for_(dim_t kh = 0; kh < KH; ++kh) |
270 | for (dim_t kw = 0; kw < KW; ++kw) { |
271 | if (iw + padL < kw * KDW || ih + padT < kh * KDH |
272 | || id + padFront < kd * KDD) |
273 | continue; |
274 | dim_t ow = iw - kw * KDW + padL; |
275 | dim_t oh = ih - kh * KDH + padT; |
276 | dim_t od = id - kd * KDD + padFront; |
277 | if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0) continue; |
278 | |
279 | ow /= KSW; |
280 | oh /= KSH; |
281 | od /= KSD; |
282 | |
283 | if (od < OD && oh < OH && ow < OW) { |
284 | const auto diff_dst_off = ref_conv_utils::get_data_off( |
285 | diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow); |
286 | const auto weights_off = ref_conv_utils::get_weights_off( |
287 | weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw); |
288 | const float dd = io::load_float_value( |
289 | diff_dst_d.data_type(), diff_dst, diff_dst_off); |
290 | const float w = io::load_float_value( |
291 | weights_d.data_type(), weights, weights_off); |
292 | ds += dd * w; |
293 | } |
294 | } |
295 | return ds; |
296 | }; |
297 | |
298 | // help compiler optimize the code constants for plain layouts kernel |
299 | const dims_t &diff_dst_str = diff_dst_d.blocking_desc().strides; |
300 | const dim_t diff_dst_oc_stride = diff_dst_str[1]; |
301 | const dim_t diff_dst_ow_stride = diff_dst_str[ndims - 1]; |
302 | const dim_t diff_dst_oh_stride = (ndims >= 4) ? diff_dst_str[ndims - 2] : 0; |
303 | const dim_t diff_dst_od_stride = (ndims >= 5) ? diff_dst_str[ndims - 3] : 0; |
304 | |
305 | const dims_t &weights_str = weights_d.blocking_desc().strides; |
306 | const int gr_shift = with_groups ? 1 : 0; |
307 | const dim_t weights_oc_stride = weights_str[0 + gr_shift]; |
308 | const dim_t weights_kw_stride = weights_str[ndims - 1 + gr_shift]; |
309 | const dim_t weights_kh_stride |
310 | = (ndims >= 4) ? weights_str[ndims - 2 + gr_shift] : 0; |
311 | const dim_t weights_kd_stride |
312 | = (ndims >= 5) ? weights_str[ndims - 3 + gr_shift] : 0; |
313 | |
314 | auto ker_plain = [=](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, |
315 | dim_t iw) { |
316 | assert(3 <= ndims && ndims <= 5); |
317 | float ds = 0; |
318 | const dim_t diff_dst_loc_off = ref_conv_utils::get_data_off( |
319 | diff_dst_d, ndims, mb, g * OC, 0, 0, 0); |
320 | const dim_t weights_loc_off = ref_conv_utils::get_weights_off( |
321 | weights_d, with_groups, ndims, g, 0, ic, 0, 0, 0); |
322 | |
323 | const void *__restrict diff_dst_loc = diff_dst; |
324 | const void *__restrict weights_loc = weights; |
325 | |
326 | if (OC > KW) { |
327 | for_(dim_t kd = 0; kd < KD; ++kd) |
328 | for_(dim_t kh = 0; kh < KH; ++kh) |
329 | for (dim_t kw = 0; kw < KW; ++kw) { |
330 | dim_t ow = iw - kw * KDW + padL; |
331 | dim_t oh = ih - kh * KDH + padT; |
332 | dim_t od = id - kd * KDD + padFront; |
333 | if (ow < 0 || oh < 0 || od < 0 || ow % KSW != 0 || oh % KSH != 0 |
334 | || od % KSD != 0) |
335 | continue; |
336 | ow /= KSW; |
337 | oh /= KSH; |
338 | od /= KSD; |
339 | if (od >= OD || oh >= OH || ow >= OW) continue; |
340 | for (dim_t oc = 0; oc < OC; ++oc) { |
341 | const dim_t diff_dst_off = oc + od * diff_dst_od_stride |
342 | + oh * diff_dst_oh_stride + ow * diff_dst_ow_stride; |
343 | const dim_t weights_off = oc * weights_oc_stride |
344 | + kd * weights_kd_stride + kh * weights_kh_stride |
345 | + kw; |
346 | const float dd = io::load_float_value( |
347 | diff_dst_d.data_type(), diff_dst_loc, |
348 | diff_dst_off + diff_dst_loc_off); |
349 | const float w = io::load_float_value(weights_d.data_type(), |
350 | weights_loc, weights_off + weights_loc_off); |
351 | ds += dd * w; |
352 | } |
353 | } |
354 | } else { |
355 | for_(dim_t oc = 0; oc < OC; ++oc) |
356 | for_(dim_t kd = 0; kd < KD; ++kd) |
357 | for (dim_t kh = 0; kh < KH; ++kh) { |
358 | // Note: placing these 2 params outside the `kw-loop` because |
359 | // of a compiler-generated bug. Declaring 'od' as volatile |
360 | // fixes a recurring seg-fault. |
361 | const volatile dim_t od_ = id - kd * KDD + padFront; |
362 | const dim_t weights_off_ = oc * weights_oc_stride |
363 | + kd * weights_kd_stride + kh * weights_kh_stride; |
364 | for (dim_t kw = 0; kw < KW; ++kw) { |
365 | dim_t ow = iw - kw * KDW + padL; |
366 | dim_t oh = ih - kh * KDH + padT; |
367 | dim_t od = od_; |
368 | if (ow < 0 || oh < 0 || od < 0 || ow % KSW != 0 |
369 | || oh % KSH != 0 || od % KSD != 0) |
370 | continue; |
371 | ow /= KSW; |
372 | oh /= KSH; |
373 | od /= KSD; |
374 | if (od >= OD || oh >= OH || ow >= OW) continue; |
375 | const dim_t diff_dst_off = oc + od * diff_dst_od_stride |
376 | + oh * diff_dst_oh_stride + ow * diff_dst_ow_stride; |
377 | const dim_t weights_off = weights_off_ + kw; |
378 | const float dd = io::load_float_value( |
379 | diff_dst_d.data_type(), diff_dst_loc, |
380 | diff_dst_off + diff_dst_loc_off); |
381 | const float w = io::load_float_value(weights_d.data_type(), |
382 | weights_loc, weights_off + weights_loc_off); |
383 | ds += dd * w; |
384 | } |
385 | } |
386 | } |
387 | return ds; |
388 | }; |
389 | |
390 | parallel_nd(G, MB, IC, ID, IH, IW, |
391 | [&](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) { |
392 | float ds = 0; |
393 | if (diff_dst_d.is_plain() && weights_d.is_plain() |
394 | && diff_dst_oc_stride == 1 && weights_kw_stride == 1) |
395 | ds += ker_plain(g, mb, ic, id, ih, iw); |
396 | else |
397 | ds += ker(g, mb, ic, id, ih, iw); |
398 | |
399 | const auto diff_src_off = ref_conv_utils::get_data_off( |
400 | diff_src_d, ndims, mb, g * IC + ic, id, ih, iw); |
401 | io::store_float_value( |
402 | diff_src_d.data_type(), ds, diff_src, diff_src_off); |
403 | }); |
404 | |
405 | return status::success; |
406 | } |
407 | |
408 | status_t ref_convolution_bwd_weights_t::execute_backward_weights( |
409 | const exec_ctx_t &ctx) const { |
410 | status_t status = status::success; |
411 | auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST); |
412 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
413 | auto diff_weights |
414 | = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_WEIGHTS, status); |
415 | CHECK(status); |
416 | auto diff_bias = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_BIAS, status); |
417 | CHECK(status); |
418 | |
419 | const memory_desc_wrapper src_d(pd()->src_md()); |
420 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
421 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
422 | const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1)); |
423 | |
424 | const bool with_groups = pd()->with_groups(); |
425 | |
426 | const auto G = pd()->G(); |
427 | const auto MB = pd()->MB(); |
428 | const auto OD = pd()->OD(); |
429 | const auto OH = pd()->OH(); |
430 | const auto OW = pd()->OW(); |
431 | const auto ID = pd()->ID(); |
432 | const auto IH = pd()->IH(); |
433 | const auto IW = pd()->IW(); |
434 | |
435 | const auto OC = pd()->OC() / G; |
436 | const auto IC = pd()->IC() / G; |
437 | const auto KD = pd()->KD(); |
438 | const auto KH = pd()->KH(); |
439 | const auto KW = pd()->KW(); |
440 | |
441 | const auto KSD = pd()->KSD(); |
442 | const auto KSH = pd()->KSH(); |
443 | const auto KSW = pd()->KSW(); |
444 | |
445 | const auto KDD = pd()->KDD() + 1; |
446 | const auto KDH = pd()->KDH() + 1; |
447 | const auto KDW = pd()->KDW() + 1; |
448 | |
449 | const auto padFront = pd()->padFront(); |
450 | const auto padT = pd()->padT(); |
451 | const auto padL = pd()->padL(); |
452 | |
453 | const auto ndims = pd()->desc()->src_desc.ndims; |
454 | |
455 | auto ker = [=](float &dw, dim_t g, dim_t oc, dim_t ic, dim_t kd, dim_t kh, |
456 | dim_t kw) { |
457 | for_(dim_t mb = 0; mb < MB; ++mb) |
458 | for_(dim_t od = 0; od < OD; ++od) |
459 | for_(dim_t oh = 0; oh < OH; ++oh) |
460 | for (dim_t ow = 0; ow < OW; ++ow) { |
461 | if (ow * KSW + kw * KDW < padL || oh * KSH + kh * KDH < padT |
462 | || od * KSD + kd * KDD < padFront |
463 | || ow * KSW + kw * KDW >= IW + padL |
464 | || oh * KSH + kh * KDH >= IH + padT |
465 | || od * KSD + kd * KDD >= ID + padFront) |
466 | continue; |
467 | |
468 | dim_t id = od * KSD - padFront + kd * KDD; |
469 | dim_t ih = oh * KSH - padT + kh * KDH; |
470 | dim_t iw = ow * KSW - padL + kw * KDW; |
471 | |
472 | const auto diff_dst_off = ref_conv_utils::get_data_off( |
473 | diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow); |
474 | const auto src_off = ref_conv_utils::get_data_off( |
475 | src_d, ndims, mb, g * IC + ic, id, ih, iw); |
476 | float dd = io::load_float_value( |
477 | diff_dst_d.data_type(), diff_dst, diff_dst_off); |
478 | float s = io::load_float_value(src_d.data_type(), src, src_off); |
479 | dw += dd * s; |
480 | } |
481 | }; |
482 | |
483 | auto ker_plain = [=](float &dw, dim_t g, dim_t oc, dim_t ic, dim_t kd, |
484 | dim_t kh, dim_t kw) { |
485 | assert(3 <= ndims && ndims <= 5); |
486 | // help compiler optimize the code constants for plain layouts kernel |
487 | const dims_t &diff_dst_str = diff_dst_d.blocking_desc().strides; |
488 | const dim_t diff_dst_mb_stride = diff_dst_str[0]; |
489 | const dim_t diff_dst_ow_stride = diff_dst_str[ndims - 1]; |
490 | const dim_t diff_dst_oh_stride |
491 | = (ndims >= 4) ? diff_dst_str[ndims - 2] : 0; |
492 | const dim_t diff_dst_od_stride |
493 | = (ndims >= 5) ? diff_dst_str[ndims - 3] : 0; |
494 | const dims_t &src_str = src_d.blocking_desc().strides; |
495 | const dim_t src_mb_stride = src_str[0]; |
496 | const dim_t src_iw_stride = src_str[ndims - 1]; |
497 | const dim_t src_ih_stride = (ndims >= 4) ? src_str[ndims - 2] : 0; |
498 | const dim_t src_id_stride = (ndims >= 5) ? src_str[ndims - 3] : 0; |
499 | |
500 | const dim_t diff_dst_loc_off = ref_conv_utils::get_data_off( |
501 | diff_dst_d, ndims, 0, g * OC + oc, 0, 0, 0); |
502 | const dim_t src_loc_off = ref_conv_utils::get_data_off( |
503 | src_d, ndims, 0, g * IC + ic, 0, 0, 0); |
504 | |
505 | const void *__restrict diff_dst_loc = diff_dst; |
506 | const void *__restrict src_loc = src; |
507 | |
508 | for_(dim_t mb = 0; mb < MB; ++mb) |
509 | for_(dim_t od = 0; od < OD; ++od) |
510 | for_(dim_t oh = 0; oh < OH; ++oh) |
511 | for (dim_t ow = 0; ow < OW; ++ow) { |
512 | const dim_t id = od * KSD - padFront + kd * KDD; |
513 | const dim_t ih = oh * KSH - padT + kh * KDH; |
514 | const dim_t iw = ow * KSW - padL + kw * KDW; |
515 | if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 || iw >= IW) |
516 | continue; |
517 | const dim_t diff_dst_off = mb * diff_dst_mb_stride |
518 | + od * diff_dst_od_stride + oh * diff_dst_oh_stride |
519 | + ow * diff_dst_ow_stride; |
520 | const dim_t src_off = mb * src_mb_stride + id * src_id_stride |
521 | + ih * src_ih_stride + iw * src_iw_stride; |
522 | float dd = io::load_float_value(diff_dst_d.data_type(), |
523 | diff_dst_loc, diff_dst_off + diff_dst_loc_off); |
524 | float s = io::load_float_value( |
525 | src_d.data_type(), src_loc, src_off + src_loc_off); |
526 | dw += dd * s; |
527 | } |
528 | }; |
529 | |
530 | auto ker_bias = [=](float &db, dim_t g, dim_t oc) { |
531 | for_(dim_t mb = 0; mb < MB; ++mb) |
532 | for_(dim_t od = 0; od < OD; ++od) |
533 | for_(dim_t oh = 0; oh < OH; ++oh) |
534 | for (dim_t ow = 0; ow < OW; ++ow) { |
535 | const auto diff_dst_off = ref_conv_utils::get_data_off( |
536 | diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow); |
537 | const float dd = io::load_float_value( |
538 | diff_dst_d.data_type(), diff_dst, diff_dst_off); |
539 | db += dd; |
540 | } |
541 | }; |
542 | |
543 | parallel_nd(G, OC, [&](dim_t g, dim_t oc) { |
544 | if (diff_bias) { |
545 | float db = 0; |
546 | ker_bias(db, g, oc); |
547 | const auto diff_bias_off = diff_bias_d.off(g * OC + oc); |
548 | io::store_float_value( |
549 | diff_bias_d.data_type(), db, diff_bias, diff_bias_off); |
550 | } |
551 | |
552 | for_(dim_t ic = 0; ic < IC; ++ic) |
553 | for_(dim_t kd = 0; kd < KD; ++kd) |
554 | for_(dim_t kh = 0; kh < KH; ++kh) |
555 | for (dim_t kw = 0; kw < KW; ++kw) { |
556 | float dw = 0; |
557 | if (diff_dst_d.is_plain() && src_d.is_plain()) |
558 | ker_plain(dw, g, oc, ic, kd, kh, kw); |
559 | else |
560 | ker(dw, g, oc, ic, kd, kh, kw); |
561 | |
562 | const dim_t diff_weights_off = ref_conv_utils::get_weights_off( |
563 | diff_weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw); |
564 | io::store_float_value(diff_weights_d.data_type(), dw, diff_weights, |
565 | diff_weights_off); |
566 | } |
567 | }); |
568 | |
569 | return status::success; |
570 | } |
571 | |
572 | } // namespace cpu |
573 | } // namespace impl |
574 | } // namespace dnnl |
575 | |
576 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
577 | |