/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include "oneapi/dnnl/dnnl_types.h"

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/gemm_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

namespace {
struct im_pos_t {
    im_pos_t() : n {0}, g {0}, od {0}, sp {0}, ic {0}, oc {0} {}
    dim_t n, g, od, sp, ic, oc;
    bool do_im2col(const im_pos_t &prev) const {
        return true
                && (n != prev.n || g != prev.g || od != prev.od || sp != prev.sp
                        || ic != prev.ic);
    }
};
} // namespace

status_t gemm_convolution_fwd_t::execute_forward_nspc(
        const exec_ctx_t &ctx) const {
    auto src_base = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bia_base = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const conv_gemm_conf_t &jcp = pd()->jcp_;
    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_forward_thr_nspc(ctx, ithr, nthr, src_base,
                wei_base, bia_base, dst_base, scratchpad);
        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

status_t gemm_convolution_fwd_t::execute_forward_thr_nspc(const exec_ctx_t &ctx,
        const int ithr, const int nthr, const data_t *src_base,
        const data_t *wei_base, const data_t *bia_base, data_t *dst_base,
        const memory_tracking::grantor_t &scratchpad) const {
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Src Format: mb-spatial-groups-input_channels
    const dim_t src_mb_stride = jcp.id * jcp.ih * jcp.iw * jcp.ngroups * jcp.ic;
    const dim_t src_g_stride = jcp.ic;
    // Wei Format: spatial-input_channels-groups-output_channels
    const dim_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Dst Format: mb-spatial-groups-output_channels
    const dim_t dst_mb_stride = jcp.od * jcp.oh * jcp.ow * jcp.ngroups * jcp.oc;
    const dim_t dst_g_stride = jcp.oc;
    const dim_t dst_os_stride = jcp.ngroups * jcp.oc;
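    // With this layout, element (n, od, oh, ow, g, oc) of dst lives at
    //   n * dst_mb_stride + ((od * jcp.oh + oh) * jcp.ow + ow) * dst_os_stride
    //       + g * dst_g_stride + oc,
    // which is exactly how the dst pointer is advanced below.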

    data_t *__restrict col = scratchpad.get<data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    data_t *__restrict imtr = scratchpad.get<data_t>(key_conv_gemm_imtr)
            + (ptrdiff_t)ithr * jcp.is * jcp.ic;

    dim_t g {0}, n {0}, ohb {0}, owb {0};
    dim_t start = 0, end = 0;
    const bool is_problem_3d = pd()->ndims() == 5;

    assert(IMPLICATION(is_problem_3d,
            jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow
                    && jcp.ic_block == jcp.ic));
    assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));

    const dim_t nb_oh = div_up(jcp.oh, jcp.oh_block);
    const dim_t nb_ow = div_up(jcp.ow, jcp.ow_block);
    // threads share work across mini-batch, groups, and blocked width/height
    const dim_t work_amount = jcp.mb * jcp.ngroups * nb_oh * nb_ow;
    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
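    // balance211() hands this thread a contiguous sub-range [start, end) of
    // the flattened (mb, ngroups, nb_oh, nb_ow) loop nest; nd_iterator_init()
    // converts `start` back into the individual loop indices.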

    if (jcp.im2col_sz && is_problem_3d) {
        // jit_gemm_convolution_utils::im2col_dt_3d() requires the col buffer
        // to be zero-initialized externally
        PRAGMA_OMP_SIMD()
        for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
            col[i] = 0.0f;
    }
    for (dim_t iwork = start; iwork < end; ++iwork) {
        int oh = ohb * jcp.oh_block;
        int ow = owb * jcp.ow_block;
        const data_t *__restrict src
                = src_base + n * src_mb_stride + g * src_g_stride;
        const data_t *__restrict wei = wei_base + g * wei_g_stride;

        const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
        const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
        if (jcp.im2col_sz && is_problem_3d) {
            jit_gemm_convolution_utils::transpose_dt(jcp, src, imtr);
        }

        for (int od = 0; od < jcp.od; od++) {
            data_t *__restrict dst = dst_base + n * dst_mb_stride
                    + g * dst_g_stride
                    + ((od * jcp.oh + oh) * jcp.ow + ow) * dst_os_stride;
            if (jcp.im2col_sz) {
                if (is_problem_3d)
                    jit_gemm_convolution_utils::im2col_dt_3d<data_t, data_t>(
                            jcp, imtr, col, od);
                else
                    jit_gemm_convolution_utils::im2col_dt<data_t, data_t>(
                            jcp, src, imtr, col, oh, h_step, ow, w_step);
            }

            const dim_t M = jcp.oc;
            const dim_t K = jcp.ks * jcp.ic;
            const dim_t N = h_step * w_step;
            const dim_t LDA = M * jcp.ngroups;
            const dim_t LDB = jcp.im2col_sz ? N : K * jcp.ngroups;
            const dim_t LDC = M * jcp.ngroups;
            const char *BT = jcp.im2col_sz ? "T" : "N";
            const data_t onef = 1.f;
            const float beta = this->beta_;
            const data_t *__restrict src_od
                    = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
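            // GEMM view (column-major): dst(oc, spatial) = wei(oc, ks * ic)
            // x col(ks * ic, spatial). With LDC = ngroups * oc, consecutive
            // spatial points of the nspc dst are LDC apart, so the result is
            // written into the destination tensor directly, no reorder needed.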
            status_t st = extended_sgemm("N", BT, &M, &N, &K, &onef, wei, &LDA,
                    jcp.im2col_sz ? col : (data_t *)src_od, &LDB, &beta, dst,
                    &LDC);
            if (st != status::success) return st;

            if (jcp.with_bias || jcp.with_eltwise || jcp.with_binary) {
                parallel(0, [&](int ithr, int nthr) {
                    dim_t start, end;
                    balance211(N * jcp.oc, nthr, ithr, start, end);

                    const size_t first_oc = start % jcp.oc;
                    const size_t last_oc = (end - 1) % jcp.oc;
                    const size_t first_os = start / jcp.oc;
                    const size_t last_os = (end - 1) / jcp.oc;
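                    // Work items are (os, oc) pairs flattened as
                    // os * jcp.oc + oc, so this thread's range may start and
                    // end in the middle of a spatial point; first_oc/last_oc
                    // give the partial channel ranges at the two ends.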

                    for (size_t os = first_os; os <= last_os; ++os) {
                        const size_t start_oc = (os == first_os) ? first_oc : 0;
                        const size_t end_oc
                                = (os == last_os) ? last_oc : jcp.oc - 1;

                        const data_t *__restrict bia_arr
                                = bia_base ? bia_base + g * jcp.oc : nullptr;
                        data_t *__restrict dst_arr = dst + os * dst_os_stride;

                        if (jcp.with_bias) {
                            PRAGMA_OMP_SIMD()
                            for (size_t oc = start_oc; oc <= end_oc; oc++) {
                                dst_arr[oc] += bia_arr[oc];
                            }
                        }

                        if (jcp.with_eltwise || jcp.with_binary) {
                            bool fast_relu_done = false;
                            if (jcp.with_eltwise && jcp.post_ops.len() == 1) {
                                // fast branch for ReLU case
                                const auto &eltwise
                                        = jcp.post_ops.entry_.back().eltwise;

                                if (eltwise.alg == alg_kind::eltwise_relu) {
                                    const auto alpha = eltwise.alpha;
                                    const auto scale = eltwise.scale;
                                    PRAGMA_OMP_SIMD()
                                    for (size_t oc = start_oc; oc <= end_oc;
                                            oc++) {
                                        if (dst_arr[oc] < 0)
                                            dst_arr[oc] *= alpha;
                                        dst_arr[oc] *= scale;
                                    }
                                    fast_relu_done = true;
                                }
                            }
                            if (!fast_relu_done) {
                                ref_post_ops_t::args_t args;
                                args.ctx = &ctx;
                                args.dst_md = pd()->dst_md();

                                for (size_t oc = start_oc; oc <= end_oc; oc++) {
                                    args.l_offset = (g * jcp.oc + oc) * jcp.os;
                                    post_ops_->execute(dst_arr[oc], args);
                                }
                            }
                        }
                    }
                });
            }
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
    }
    return status::success;
}

status_t gemm_convolution_fwd_t::execute_forward_ncsp(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    auto col = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_col);

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t weights_oc_size = jcp.ic * jcp.ks;
    const size_t weights_g_size = weights_oc_size * jcp.oc;
    const bool is_problem_3d = pd()->ndims() == 5;

    assert(IMPLICATION(is_problem_3d,
            jcp.os_block == jcp.os && jcp.ic_block == jcp.ic
                    && jcp.os_nb_block == 1));

    status_t st = status::success;
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

        // non-blocked jit_gemm_convolution_utils::im2col_3d() requires the
        // col buffer to be zero-initialized externally
        const bool outer_padding = jcp.os_nb_block == 1;
        if (outer_padding && is_problem_3d) {
            for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                _col[i] = (data_t)0;
        }
        auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev,
                                 im_pos_t &step, const im_pos_t &end) {
            const data_t *_src
                    = src + (curr.n * jcp.ngroups + curr.g) * src_step;
            step.oc = nstl::min(
                    jcp.oc_block, nstl::min(jcp.oc, end.oc) - curr.oc);
            step.sp = nstl::min(jcp.os_block,
                    nstl::min(jcp.os - curr.sp, end.sp - spatial));
            step.ic = nstl::min(
                    jcp.ic_block, nstl::min(jcp.ic, end.ic) - curr.ic);
            bool do_im2col = curr.do_im2col(prev);
            prev = curr;

            if (jcp.im2col_sz && do_im2col) {
                if (!is_problem_3d)
                    jit_gemm_convolution_utils::im2col<float>(jcp, _src, _col,
                            curr.sp, step.sp, curr.ic, step.ic);
                else
                    jit_gemm_convolution_utils::im2col_3d<float>(
                            jcp, _src, _col, curr.od, 0, jcp.os);
            }
            const data_t one = 1.0;

            const dim_t M = jcp.os * jcp.od;
            const size_t dst_step = jcp.oc * M;
            const dim_t m = step.sp;
            const dim_t LDA = jcp.im2col_sz ? m : M;
            data_t *_dst = dst + (curr.n * jcp.ngroups + curr.g) * dst_step
                    + curr.oc * M + curr.od * jcp.os + curr.sp;
            const dim_t K = step.ic * jcp.ks;
            const dim_t LDB = jcp.ic * jcp.ks;
            const dim_t N = step.oc;

            // TODO: what if this->beta_ is neither 0 nor 1?
            const float beta = (curr.ic == 0) ? this->beta_ : one;
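            // The input-channel dimension is processed in blocks of
            // jcp.ic_block: the first block writes dst with the user beta,
            // every following block accumulates on top of it (beta = 1), so
            // the partial products over ic blocks sum up in the destination.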
            const float *_source = jcp.im2col_sz
                    ? _col
                    : _src + curr.ic * M + curr.od * jcp.os + curr.sp;
            const data_t *_weights = weights + curr.g * weights_g_size
                    + curr.oc * weights_oc_size + curr.ic * jcp.ks;
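            // GEMM view (column-major): dst(sp, oc) = src_cols(sp, ks * ic_blk)
            // x wei(ks * ic_blk, oc_blk). The leading dimension of dst is the
            // full spatial size M, so consecutive oc values land M elements
            // apart, matching the ncsp layout of the destination.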

            status_t st = extended_sgemm("N", "N", &m, &N, &K, &one, _source,
                    &LDA, _weights, &LDB, &beta, _dst, &M);
            if (st != status::success) return st;

            if (curr.ic == jcp.ic - step.ic) {
                // TODO: for "outer threading" we have a parallel section
                // within the outermost "parallel". This is not ideal; consider
                // using "parallel" here with the number of threads passed as a
                // parameter
                const int oc_start = curr.g * jcp.oc + curr.oc;
                if (jcp.with_eltwise || jcp.with_binary) {
                    bool fast_relu_done = false;
                    if (jcp.with_eltwise && jcp.post_ops.len() == 1) {
                        // fast branch for ReLU case
                        const auto &eltwise
                                = jcp.post_ops.entry_.back().eltwise;
                        if (eltwise.alg == alg_kind::eltwise_relu) {
                            parallel_nd(step.oc, [&](dim_t oc) {
                                data_t b = jcp.with_bias ? bias[oc_start + oc]
                                                         : 0;
                                data_t *d_ = _dst + oc * M;
                                PRAGMA_OMP_SIMD()
                                for (int oS = 0; oS < m; ++oS) {
                                    d_[oS] += b;
                                    if (d_[oS] < 0) d_[oS] *= eltwise.alpha;
                                    d_[oS] *= eltwise.scale;
                                }
                            });
                            fast_relu_done = true;
                        }
                    }
                    if (!fast_relu_done) {
                        parallel_nd(step.oc, [&](dim_t oc) {
                            data_t b = jcp.with_bias ? bias[oc_start + oc] : 0;
                            data_t *d_ = _dst + oc * M;

                            ref_post_ops_t::args_t args;
                            args.ctx = &ctx;
                            args.dst_md = pd()->dst_md();
                            args.l_offset = d_ - dst;

                            PRAGMA_OMP_SIMD()
                            for (int oS = 0; oS < m; ++oS) {
                                d_[oS] += b;
                                post_ops_->execute(d_[oS], args);
                                args.l_offset++;
                            }
                        });
                    }

                } else if (jcp.with_bias) {
                    parallel_nd(step.oc, [&](dim_t oc) {
                        data_t b = bias[oc_start + oc];
                        data_t *d_ = _dst + oc * M;
                        PRAGMA_OMP_SIMD()
                        for (int oS = 0; oS < m; ++oS) {
                            d_[oS] += b;
                        }
                    });
                }
            }

            return status::success;
        };
        im_pos_t start, end;
        end.ic = jcp.ic;

        if (!is_problem_3d) {
            dim_t sp_work = jcp.mb * jcp.ngroups * jcp.od * jcp.os;
            balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc,
                    end.oc, dim_t(jcp.nthr_oc));
        } else {
            dim_t sp_work = jcp.mb * jcp.ngroups * jcp.od;
            balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc,
                    end.oc, dim_t(jcp.nthr_oc));
            start.sp *= jcp.os;
            end.sp *= jcp.os;
        }
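        // balance2D() splits the work in two dimensions: the flattened
        // spatial work is distributed across threads first, and the oc
        // dimension is additionally split across jcp.nthr_oc threads
        // ("outer threading" over output channels).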

        im_pos_t curr, prev, step;
        prev.n = prev.g = prev.od = prev.sp = prev.ic = -1;
        step.oc = jcp.oc_block;
        step.sp = jcp.os_block;
        step.ic = jcp.ic_block;

        if (jcp.loop_order == gemm_loop_rlb)
            for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic)
                for (int spatial = start.sp; spatial < end.sp;
                        spatial += step.sp) {
                    nd_iterator_init(spatial, curr.n, jcp.mb, curr.g,
                            jcp.ngroups, curr.od, jcp.od, curr.sp, jcp.os);
                    for (curr.oc = start.oc; curr.oc < end.oc;
                            curr.oc += step.oc) {
                        status_t st_thr
                                = inner_ker(spatial, curr, prev, step, end);
                        if (st_thr != status::success) {
                            st = st_thr;
                            return;
                        }
                    }
                }
        else if (jcp.loop_order == gemm_loop_lrb)
            for (int spatial = start.sp; spatial < end.sp; spatial += step.sp) {
                nd_iterator_init(spatial, curr.n, jcp.mb, curr.g, jcp.ngroups,
                        curr.od, jcp.od, curr.sp, jcp.os);
                for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic)
                    for (curr.oc = start.oc; curr.oc < end.oc;
                            curr.oc += step.oc) {
                        status_t st_thr
                                = inner_ker(spatial, curr, prev, step, end);
                        if (st_thr != status::success) {
                            st = st_thr;
                            return;
                        }
                    }
            }
        else
            st = status::unimplemented;
    });

    return st;
}

status_t gemm_convolution_bwd_data_t::execute_backward_data_nspc(
        const exec_ctx_t &ctx) const {

    auto diff_dst_base = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bia_base = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto diff_src_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const conv_gemm_conf_t &jcp = pd()->jcp_;
    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_backward_data_thr_nspc(ithr, nthr,
                diff_dst_base, wei_base, bia_base, diff_src_base, scratchpad);
        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc(
        const int ithr, const int nthr, const data_t *diff_dst_base,
        const data_t *wei_base, const data_t *bia_base, data_t *diff_src_base,
        const memory_tracking::grantor_t &scratchpad) const {
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Diff_dst Format: mb-spatial-groups-output_channels
    const size_t diff_dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh
            * jcp.ow * jcp.ngroups * jcp.oc;
    const size_t diff_dst_g_stride = jcp.oc;

    // Wei Format: spatial-input_channels-groups-output_channels
    const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Diff_src Format: mb-spatial-groups-input_channels
    const size_t diff_src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih
            * jcp.iw * jcp.ngroups * jcp.ic;
    const size_t diff_src_g_stride = jcp.ic;
    const size_t diff_src_os_stride = jcp.ngroups * jcp.ic;

    // threads share work across mini-batch and groups
    const dim_t work_amount = jcp.ngroups * jcp.mb;

    data_t *__restrict col = scratchpad.get<data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    const bool acc_needed = jcp.ngroups > 1;
    data_t *__restrict acc = acc_needed
            ? scratchpad.get<data_t>(key_conv_gemm_acc)
                    + (ptrdiff_t)ithr * jcp.is * jcp.id * jcp.ic
            : nullptr;

    dim_t n {0}, g {0};
    dim_t start = 0, end = 0;

    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (dim_t iwork = start; iwork < end; ++iwork) {
        const data_t *__restrict diff_dst = diff_dst_base
                + n * diff_dst_mb_stride + g * diff_dst_g_stride;
        const data_t *__restrict wei = wei_base + g * wei_g_stride;
        data_t *__restrict diff_src = diff_src_base + n * diff_src_mb_stride
                + g * diff_src_g_stride;

        const dim_t M = jcp.ks * jcp.ic;
        const dim_t N = jcp.os * jcp.od;
        const dim_t K = jcp.oc;

        const data_t onef = 1.0f, zerof = 0.0f;
        const dim_t LD = K * jcp.ngroups;

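        // GEMM view: col(ks * ic, os * od) = wei(ks * ic, oc) x
        // diff_dst(oc, os * od) for the current (mb, g) slice; col2im_dt()
        // below scatters the columns back into diff_src. Without an im2col
        // buffer (1x1 case) the result is written to diff_src (or acc when
        // group accumulation is needed) directly.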
        status_t st = extended_sgemm("T", "N", &M, &N, &K, &onef, wei, &LD,
                diff_dst, &LD, &zerof,
                jcp.im2col_sz ? col : (acc_needed ? acc : diff_src), &M);
        if (st != status::success) return st;

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_dt<data_t>(
                    jcp, col, (acc_needed ? acc : diff_src));

        if (acc_needed) {
            parallel_nd(static_cast<size_t>(jcp.is) * jcp.id, [&](size_t is) {
                data_t *__restrict diff_src_arr
                        = diff_src + is * diff_src_os_stride;
                const data_t *__restrict acc_arr = acc + is * jcp.ic;
                PRAGMA_OMP_SIMD()
                for (int ic = 0; ic < jcp.ic; ic++) {
                    diff_src_arr[ic] = acc_arr[ic];
                }
            });
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }
    return status::success;
}

status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    auto col = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_col);

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const dim_t M = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * M;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;

    const dim_t m = jcp.os_block;
    const dim_t K = jcp.oc;
    const dim_t N = jcp.ic * jcp.ks;

    const dim_t work_amount = (size_t)jcp.ngroups * jcp.mb;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

        dim_t g {0}, n {0};
        dim_t start = 0, end = 0;
        balance211(work_amount, nthr, ithr, start, end);
        nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb);
        for (dim_t iwork = start; iwork < end; ++iwork) {

            data_t *_diff_src = diff_src + (n * jcp.ngroups + g) * src_step;
            if (is_problem_3d && jcp.im2col_sz > 0) {
                // jit_gemm_convolution_utils::col2im_3d() assumes the
                // accumulator is zero-initialized
                for (size_t i = 0; i < src_step; i++)
                    _diff_src[i] = (data_t)0;
            }

            const data_t *_weights = weights + g * weights_g_size;
            for_(int od = 0; od < jcp.od; ++od)
            for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
                auto out_off = os_nb * m + od * jcp.os;
                const data_t *_diff_dst
                        = diff_dst + (n * jcp.ngroups + g) * dst_step + out_off;
                const dim_t os_block
                        = nstl::min((dim_t)jcp.os_block, jcp.os - os_nb * m);
                const dim_t LDC = jcp.im2col_sz ? os_block : M;

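                // GEMM view: col(os_block, ic * ks) = diff_dst(os_block, oc)
                // x wei(oc, ic * ks) for one spatial block of one output
                // depth; col2im()/col2im_3d() below accumulate the columns
                // back into diff_src. For the 1x1 case the block is written
                // to diff_src in place.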
                const data_t zero = 0.0, one = 1.0;
                status_t st_thr = extended_sgemm("N", "T", &os_block, &N, &K,
                        &one, _diff_dst, &M, _weights, &N, &zero,
                        jcp.im2col_sz ? _col : _diff_src + out_off, &LDC);
                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                if (jcp.im2col_sz) {
                    if (!is_problem_3d)
                        jit_gemm_convolution_utils::col2im(jcp, _col, _diff_src,
                                os_nb * jcp.os_block, os_block);
                    else {
                        jit_gemm_convolution_utils::col2im_3d(jcp, _col,
                                _diff_src, od, os_nb * jcp.os_block, os_block);
                    }
                }
            }
            nd_iterator_step(g, jcp.ngroups, n, jcp.mb);
        }
    });

    return st;
}

status_t gemm_convolution_bwd_weights_t::execute_backward_weights_nspc(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto col = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_col);
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    auto wei_reduction
            = ctx.get_scratchpad_grantor().get<data_t>(key_conv_wei_reduction);

    const dim_t K = jcp.os * static_cast<size_t>(jcp.od);
    const size_t src_step
            = static_cast<size_t>(jcp.ic) * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = jcp.oc * K;
    const size_t weights_g_size = jcp.oc;

    const dim_t k = jcp.os;
    const dim_t M = jcp.oc;
    const dim_t N = static_cast<dim_t>(jcp.ic) * jcp.ks;
    const dim_t LDB = jcp.ngroups * jcp.oc;
    const dim_t LDA = jcp.im2col_sz ? jcp.oh * jcp.ow : jcp.ngroups * jcp.ic;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));

        const int need_reduction = nthr_mb != 1;
        const dim_t LDC = need_reduction ? jcp.oc : jcp.ngroups * jcp.oc;
        data_t *__restrict imtr
                = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_imtr)
                + (ptrdiff_t)ithr * jcp.id * jcp.ic * jcp.is;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
            if (is_problem_3d) {
                // jit_gemm_convolution_utils::im2col_3d() requires the col
                // buffer to be zero-initialized externally
                PRAGMA_OMP_SIMD()
                for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                    _col[i] = 0.0f;
            }

            data_t *weights_reduce_base = wei_reduction
                    + ithr_g * nthr_mb * weights_g_size * jcp.ks * jcp.ic;
            data_t *weights_reduce = weights_reduce_base
                    + ithr_mb * weights_g_size * jcp.ks * jcp.ic;

            for (size_t g = g_start; g < g_end; ++g) {
                data_t *_diff_weights = need_reduction
                        ? weights_reduce
                        : diff_weights + g * weights_g_size;
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const data_t *_src
                            = src + mb * jcp.ngroups * src_step + g * jcp.ic;
                    if (jcp.im2col_sz && is_problem_3d)
                        jit_gemm_convolution_utils::transpose_dt(
                                jcp, _src, imtr);
                    for (int od = 0; od < jcp.od; ++od) {
                        const data_t *_diff_dst = diff_dst
                                + mb * jcp.ngroups * dst_step
                                + od * k * jcp.ngroups * jcp.oc + g * jcp.oc;

                        if (jcp.im2col_sz) {
                            if (is_problem_3d)
                                jit_gemm_convolution_utils::im2col_dt_3d<data_t,
                                        data_t>(jcp, imtr, _col, od);
                            else
                                jit_gemm_convolution_utils::im2col_dt<data_t,
                                        data_t>(jcp, _src, imtr, _col, 0,
                                        jcp.oh, 0, jcp.ow);
                        }
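                        // Each GEMM accumulates one (mb, od) slice into the
                        // per-thread diff_weights buffer:
                        //   diff_wei(oc, ks * ic) += diff_dst_slice(oc, os)
                        //       x col(os, ks * ic).
                        // beta is zero only for the very first slice of this
                        // thread's range, so later slices accumulate on top.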
                        const data_t zero = 0.0f, one = 1.0f;
                        status_t st_thr = extended_sgemm("N",
                                jcp.im2col_sz ? "N" : "T", &M, &N, &k, &one,
                                _diff_dst, &LDB,
                                jcp.im2col_sz
                                        ? _col
                                        : _src + od * k * jcp.ngroups * jcp.ic,
                                &LDA, mb == mb_start && od == 0 ? &zero : &one,
                                _diff_weights, &LDC);
                        if (st_thr != status::success) {
                            st = st_thr;
                            // Finish the loops early if a failure occurred.
                            g = g_end;
                            mb = mb_end;
                            od = jcp.od;
                        }
                    }
                }
            }
            if (need_reduction && dnnl_thr_syncable()) {
                dnnl_thr_barrier();
                if (st != status::success) return;
                jit_gemm_convolution_utils::bwd_weights_reduction_par_nspc(
                        ithr_mb, nthr_mb, g_start, g_end, jcp,
                        weights_reduce_base, diff_weights);
            }
        } else {
            if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
        }
    });

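    // When the threading runtime cannot synchronize inside a parallel region
    // (dnnl_thr_syncable() is false, e.g. with TBB), the per-thread partial
    // diff_weights cannot be reduced above; the separate parallel section
    // below performs the reduction instead.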
    if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            int ithr_g, nthr_g, ithr_mb, nthr_mb;
            size_t g_start {0}, g_end {0};
            size_t mb_start {0}, mb_end {0};
            const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
            jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
                    jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
                    nthr_mb);

            assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
            const int need_reduction = nthr_mb != 1;

            if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
                balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
                balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

                assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

                data_t *weights_reduce_base = wei_reduction
                        + ithr_g * nthr_mb * weights_g_size * jcp.ic * jcp.ks;

                jit_gemm_convolution_utils::bwd_weights_reduction_par_nspc(
                        ithr_mb, nthr_mb, g_start, g_end, jcp,
                        weights_reduce_base, diff_weights);
            }
        });
    }

    if (jcp.with_bias) {
        parallel_nd(jcp.ngroups, jcp.oc, [&](dim_t g, dim_t oc) {
            data_t db = 0;
            const size_t offset_base = g * jcp.oc + oc;
            for_(dim_t mb = 0; mb < jcp.mb; ++mb)
            for_(dim_t od = 0; od < jcp.od; ++od)
            for (dim_t oh = 0; oh < jcp.oh; ++oh) {
                const data_t *__restrict diff_dst_arr = diff_dst + offset_base
                        + ((static_cast<size_t>(mb) * jcp.od + od) * jcp.oh
                                  + oh)
                                * jcp.ow * jcp.ngroups * jcp.oc;
                const int width_stride = jcp.ngroups * jcp.oc;

                PRAGMA_OMP_SIMD(reduction(+ : db))
                for (int ow = 0; ow < jcp.ow; ++ow) {
                    db += diff_dst_arr[ow * width_stride];
                }
            }
            diff_bias[g * jcp.oc + oc] = db;
        });
    }
    return st;
}

status_t gemm_convolution_bwd_weights_t::execute_backward_weights_ncsp(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto col = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_col);
    auto wei_reduction
            = ctx.get_scratchpad_grantor().get<data_t>(key_conv_wei_reduction);

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const dim_t K = jcp.os * jcp.od;
    const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = jcp.oc * K;
    const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;

    const dim_t k = jcp.os_block;
    const dim_t N = jcp.oc;
    const dim_t M = jcp.ic * jcp.ks;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
        const int need_reduction = nthr_mb != 1;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

            // non-blocked jit_gemm_convolution_utils::im2col_3d() requires the
            // col buffer to be zero-initialized externally
            const bool outer_padding = jcp.os_nb_block == 1;
            if (outer_padding && is_problem_3d) {
                for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                    _col[i] = (data_t)0;
            }
            data_t *weights_reduce_base
                    = wei_reduction + ithr_g * nthr_mb * weights_g_size;
            data_t *weights_reduce
                    = weights_reduce_base + ithr_mb * weights_g_size;

            for (size_t g = g_start; g < g_end; ++g) {
                data_t *_diff_weights = need_reduction
                        ? weights_reduce
                        : (diff_weights + g * weights_g_size);
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const data_t *_src
                            = src + (mb * jcp.ngroups + g) * src_step;
                    for_(int od = 0; od < jcp.od; ++od)
                    for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
                        auto out_off = os_nb * k + od * jcp.os;
                        const dim_t os_block = nstl::min(
                                (dim_t)jcp.os_block, jcp.os - os_nb * k);
                        const data_t *_diff_dst = diff_dst
                                + (mb * jcp.ngroups + g) * dst_step + out_off;

                        if (jcp.im2col_sz) {
                            if (!is_problem_3d)
                                jit_gemm_convolution_utils::im2col<float>(jcp,
                                        _src, _col, os_nb * jcp.os_block,
                                        os_block, 0, jcp.ic);
                            else
                                jit_gemm_convolution_utils::im2col_3d<float>(
                                        jcp, _src, _col, od,
                                        os_nb * jcp.os_block, os_block);
                        }
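                        // GEMM view: diff_wei(ic * ks, oc) +=
                        //   col(ic * ks, os_block) x diff_dst(os_block, oc),
                        // written with leading dimension M = ic * ks so the
                        // result matches the ncsp weights layout (oc, ic, ks).
                        // beta is zero only for the thread's very first block,
                        // so all later (mb, od, os_nb) blocks accumulate.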
                        const dim_t LDA = jcp.im2col_sz ? os_block : K;
                        const data_t zero = 0.0, one = 1.0;
                        status_t st_thr = extended_sgemm("T", "N", &M, &N,
                                &os_block, &one,
                                jcp.im2col_sz ? _col : _src + out_off, &LDA,
                                _diff_dst, &K,
                                mb == mb_start && os_nb == 0 && od == 0 ? &zero
                                                                        : &one,
                                _diff_weights, &M);
                        if (st_thr != status::success) {
                            st = st_thr;
                            // Finish the loops early if a failure occurred.
                            g = g_end;
                            mb = mb_end;
                            od = jcp.od;
                            os_nb = jcp.os_nb_block;
                        }
                    }
                }
            }
            if (need_reduction && dnnl_thr_syncable()) {
                dnnl_thr_barrier();
                if (st != status::success) return;
                data_t *weights_base = diff_weights + g_start * weights_g_size;
                jit_gemm_convolution_utils::bwd_weights_reduction_par_ncsp(
                        ithr_mb, nthr_mb, jcp, weights_reduce_base,
                        weights_base);
            }
        } else {
            if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
        }
    });

    if (st != status::success) return st;

    if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            int ithr_g, nthr_g, ithr_mb, nthr_mb;
            size_t g_start {0}, g_end {0};
            const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
            jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
                    jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
                    nthr_mb);

            assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
            const int need_reduction = nthr_mb != 1;

            if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
                balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);

                assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

                data_t *weights_reduce_base
                        = wei_reduction + ithr_g * nthr_mb * weights_g_size;
                data_t *weights_base = diff_weights + g_start * weights_g_size;

                jit_gemm_convolution_utils::bwd_weights_reduction_par_ncsp(
                        ithr_mb, nthr_mb, jcp, weights_reduce_base,
                        weights_base);
            }
        });
    }

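    // diff_bias[g, oc] is simply the sum of diff_dst over the mini-batch and
    // all spatial positions for that output channel.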
    if (jcp.with_bias) {
        parallel_nd(jcp.ngroups, jcp.oc, [&](dim_t g, dim_t oc) {
            data_t db = 0;
            dim_t offset_ = g * dst_step + oc * K;
            for (dim_t mb = 0; mb < jcp.mb; ++mb) {
                dim_t offset = offset_ + mb * jcp.ngroups * dst_step;
                for_(dim_t od = 0; od < jcp.od; ++od)
                for (dim_t oh = 0; oh < jcp.oh; ++oh) {
                    PRAGMA_OMP_SIMD(reduction(+ : db))
                    for (dim_t ow = 0; ow < jcp.ow; ++ow) {
                        db += diff_dst[offset + ow];
                    }
                    offset += jcp.ow;
                }
            }
            diff_bias[g * jcp.oc + oc] = db;
        });
    }

    return st;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl