/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include "oneapi/dnnl/dnnl_types.h"

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/gemm_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

namespace {
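// Tracks a position within the convolution problem: mini-batch, group,
// output depth, spatial block, input- and output-channel block.
// do_im2col() reports whether the im2col buffer must be recomputed relative
// to the previous position: moving only to a new output-channel block reuses
// the buffer, any other change invalidates it.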
struct im_pos_t {
    im_pos_t() : n {0}, g {0}, od {0}, sp {0}, ic {0}, oc {0} {}
    dim_t n, g, od, sp, ic, oc;
    bool do_im2col(const im_pos_t &prev) const {
        return true
                && (n != prev.n || g != prev.g || od != prev.od
                        || sp != prev.sp || ic != prev.ic);
    }
};
} // namespace

status_t gemm_convolution_fwd_t::execute_forward_nspc(
        const exec_ctx_t &ctx) const {
    auto src_base = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bia_base = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const conv_gemm_conf_t &jcp = pd()->jcp_;
    std::atomic<status_t> st(status::success);

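    // Each thread processes its own slice of the work and reports a
    // per-thread status; any failing thread overwrites the shared atomic
    // result, so a single failure wins over success.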
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_forward_thr_nspc(ctx, ithr, nthr, src_base,
                wei_base, bia_base, dst_base, scratchpad);
        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

status_t gemm_convolution_fwd_t::execute_forward_thr_nspc(const exec_ctx_t &ctx,
        const int ithr, const int nthr, const data_t *src_base,
        const data_t *wei_base, const data_t *bia_base, data_t *dst_base,
        const memory_tracking::grantor_t &scratchpad) const {
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Src Format: mb-spatial-groups-input_channels
    const dim_t src_mb_stride = jcp.id * jcp.ih * jcp.iw * jcp.ngroups * jcp.ic;
    const dim_t src_g_stride = jcp.ic;
    // Wei Format: spatial-input_channels-groups-output_channels
    const dim_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Dst Format: mb-spatial-groups-output_channels
    const dim_t dst_mb_stride = jcp.od * jcp.oh * jcp.ow * jcp.ngroups * jcp.oc;
    const dim_t dst_g_stride = jcp.oc;
    const dim_t dst_os_stride = jcp.ngroups * jcp.oc;

    data_t *__restrict col = scratchpad.get<data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    data_t *__restrict imtr = scratchpad.get<data_t>(key_conv_gemm_imtr)
            + (ptrdiff_t)ithr * jcp.is * jcp.ic;

    dim_t g {0}, n {0}, ohb {0}, owb {0};
    dim_t start = 0, end = 0;
    const bool is_problem_3d = pd()->ndims() == 5;

    assert(IMPLICATION(is_problem_3d,
            jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow
                    && jcp.ic_block == jcp.ic));
    assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));

    const dim_t nb_oh = div_up(jcp.oh, jcp.oh_block);
    const dim_t nb_ow = div_up(jcp.ow, jcp.ow_block);
    // threads share work across mini-batch, groups, and blocked width/height
    const dim_t work_amount = jcp.mb * jcp.ngroups * nb_oh * nb_ow;
    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);

    if (jcp.im2col_sz && is_problem_3d) {
        // jit_gemm_convolution_utils::im2col_dt_3d() requires external
        // data initialization by zeroes
        PRAGMA_OMP_SIMD()
        for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
            col[i] = 0.0f;
    }
    for (dim_t iwork = start; iwork < end; ++iwork) {
        int oh = ohb * jcp.oh_block;
        int ow = owb * jcp.ow_block;
        const data_t *__restrict src
                = src_base + n * src_mb_stride + g * src_g_stride;
        const data_t *__restrict wei = wei_base + g * wei_g_stride;

        const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
        const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
        if (jcp.im2col_sz && is_problem_3d) {
            jit_gemm_convolution_utils::transpose_dt(jcp, src, imtr);
        }

        for (int od = 0; od < jcp.od; od++) {
            data_t *__restrict dst = dst_base + n * dst_mb_stride
                    + g * dst_g_stride
                    + ((od * jcp.oh + oh) * jcp.ow + ow) * dst_os_stride;
            if (jcp.im2col_sz) {
                if (is_problem_3d)
                    jit_gemm_convolution_utils::im2col_dt_3d<data_t, data_t>(
                            jcp, imtr, col, od);
                else
                    jit_gemm_convolution_utils::im2col_dt<data_t, data_t>(
                            jcp, src, imtr, col, oh, h_step, ow, w_step);
            }

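            // Column-major GEMM: dst_tile(M x N) (+)= wei(M x K) x col(K x N)
            // with M = oc, N = h_step * w_step (spatial block), K = ic * ks.
            // LDA and LDC skip the interleaved groups (ngroups * oc); col is
            // stored spatial-contiguous, hence the "T" flag. When no im2col
            // is needed, the source itself serves as B with leading dimension
            // K * ngroups and no transposition.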
            const dim_t M = jcp.oc;
            const dim_t K = jcp.ks * jcp.ic;
            const dim_t N = h_step * w_step;
            const dim_t LDA = M * jcp.ngroups;
            const dim_t LDB = jcp.im2col_sz ? N : K * jcp.ngroups;
            const dim_t LDC = M * jcp.ngroups;
            const char *BT = jcp.im2col_sz ? "T" : "N";
            const data_t onef = 1.f;
            const float beta = this->beta_;
            const data_t *__restrict src_od
                    = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
            status_t st = extended_sgemm("N", BT, &M, &N, &K, &onef, wei, &LDA,
                    jcp.im2col_sz ? col : (data_t *)src_od, &LDB, &beta, dst,
                    &LDC);
            if (st != status::success) return st;

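            // Post-process the freshly computed tile: add bias and apply
            // post-ops per output channel. A lone ReLU post-op takes a
            // vectorized fast path; everything else goes through the generic
            // ref_post_ops_t executor.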
            if (jcp.with_bias || jcp.with_eltwise || jcp.with_binary) {
                parallel(0, [&](int ithr, int nthr) {
                    dim_t start, end;
                    balance211(N * jcp.oc, nthr, ithr, start, end);

                    const size_t first_oc = start % jcp.oc;
                    const size_t last_oc = (end - 1) % jcp.oc;
                    const size_t first_os = start / jcp.oc;
                    const size_t last_os = (end - 1) / jcp.oc;

                    for (size_t os = first_os; os <= last_os; ++os) {
                        const size_t start_oc = (os == first_os) ? first_oc : 0;
                        const size_t end_oc
                                = (os == last_os) ? last_oc : jcp.oc - 1;

                        const data_t *__restrict bia_arr
                                = bia_base ? bia_base + g * jcp.oc : nullptr;
                        data_t *__restrict dst_arr = dst + os * dst_os_stride;

                        if (jcp.with_bias) {
                            PRAGMA_OMP_SIMD()
                            for (size_t oc = start_oc; oc <= end_oc; oc++) {
                                dst_arr[oc] += bia_arr[oc];
                            }
                        }

                        if (jcp.with_eltwise || jcp.with_binary) {
                            bool fast_relu_done = false;
                            if (jcp.with_eltwise && jcp.post_ops.len() == 1) {
                                // fast branch for ReLU case
                                const auto &eltwise
                                        = jcp.post_ops.entry_.back().eltwise;

                                if (eltwise.alg == alg_kind::eltwise_relu) {
                                    const auto alpha = eltwise.alpha;
                                    const auto scale = eltwise.scale;
                                    PRAGMA_OMP_SIMD()
                                    for (size_t oc = start_oc; oc <= end_oc;
                                            oc++) {
                                        if (dst_arr[oc] < 0)
                                            dst_arr[oc] *= alpha;
                                        dst_arr[oc] *= scale;
                                    }
                                    fast_relu_done = true;
                                }
                            }
                            if (!fast_relu_done) {
                                ref_post_ops_t::args_t args;
                                args.ctx = &ctx;
                                args.dst_md = pd()->dst_md();

                                for (size_t oc = start_oc; oc <= end_oc; oc++) {
                                    args.l_offset = (g * jcp.oc + oc) * jcp.os;
                                    post_ops_->execute(dst_arr[oc], args);
                                }
                            }
                        }
                    }
                });
            }
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
    }
    return status::success;
}

status_t gemm_convolution_fwd_t::execute_forward_ncsp(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    auto col = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_col);

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t weights_oc_size = jcp.ic * jcp.ks;
    const size_t weights_g_size = weights_oc_size * jcp.oc;
    const bool is_problem_3d = pd()->ndims() == 5;

    assert(IMPLICATION(is_problem_3d,
            jcp.os_block == jcp.os && jcp.ic_block == jcp.ic
                    && jcp.os_nb_block == 1));

    status_t st = status::success;
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

        // non-blocked jit_gemm_convolution_utils::im2col_3d() requires
        // external data initialization by zeroes
        const bool outer_padding = jcp.os_nb_block == 1;
        if (outer_padding && is_problem_3d) {
            for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                _col[i] = (data_t)0;
        }
        auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev,
                                 im_pos_t &step, const im_pos_t &end) {
            const data_t *_src
                    = src + (curr.n * jcp.ngroups + curr.g) * src_step;
            step.oc = nstl::min(
                    jcp.oc_block, nstl::min(jcp.oc, end.oc) - curr.oc);
            step.sp = nstl::min(jcp.os_block,
                    nstl::min(jcp.os - curr.sp, end.sp - spatial));
            step.ic = nstl::min(
                    jcp.ic_block, nstl::min(jcp.ic, end.ic) - curr.ic);
            bool do_im2col = curr.do_im2col(prev);
            prev = curr;

            if (jcp.im2col_sz && do_im2col) {
                if (!is_problem_3d)
                    jit_gemm_convolution_utils::im2col<float>(jcp, _src, _col,
                            curr.sp, step.sp, curr.ic, step.ic);
                else
                    jit_gemm_convolution_utils::im2col_3d<float>(
                            jcp, _src, _col, curr.od, 0, jcp.os);
            }
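            // Column-major GEMM over the current blocks:
            // dst(step.sp x step.oc) (+)= col(step.sp x step.ic * ks)
            //         x wei(step.ic * ks x step.oc),
            // where dst columns stride by the full M = od * os. beta selects
            // between overwrite (first ic block) and accumulation.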
            const data_t one = 1.0;

            const dim_t M = jcp.os * jcp.od;
            const size_t dst_step = jcp.oc * M;
            const dim_t m = step.sp;
            const dim_t LDA = jcp.im2col_sz ? m : M;
            data_t *_dst = dst + (curr.n * jcp.ngroups + curr.g) * dst_step
                    + curr.oc * M + curr.od * jcp.os + curr.sp;
            const dim_t K = step.ic * jcp.ks;
            const dim_t LDB = jcp.ic * jcp.ks;
            const dim_t N = step.oc;

            // TODO: what if this->beta_ != 0 && != 1 ?
            const float beta = (curr.ic == 0) ? this->beta_ : one;
            const float *_source = jcp.im2col_sz
                    ? _col
                    : _src + curr.ic * M + curr.od * jcp.os + curr.sp;
            const data_t *_weights = weights + curr.g * weights_g_size
                    + curr.oc * weights_oc_size + curr.ic * jcp.ks;

            status_t st = extended_sgemm("N", "N", &m, &N, &K, &one, _source,
                    &LDA, _weights, &LDB, &beta, _dst, &M);
            if (st != status::success) return st;

            if (curr.ic == jcp.ic - step.ic) {
                // TODO: for "outer threading" this opens a parallel section
                // inside the outermost "parallel", which is undesirable.
                // Consider using "parallel" here with the number of threads
                // passed as a parameter.
                const int oc_start = curr.g * jcp.oc + curr.oc;
                if (jcp.with_eltwise || jcp.with_binary) {
                    bool fast_relu_done = false;
                    if (jcp.with_eltwise && jcp.post_ops.len() == 1) {
                        // fast branch for ReLU case
                        const auto &eltwise
                                = jcp.post_ops.entry_.back().eltwise;
                        if (eltwise.alg == alg_kind::eltwise_relu) {
                            parallel_nd(step.oc, [&](dim_t oc) {
                                data_t b = jcp.with_bias ? bias[oc_start + oc]
                                                         : 0;
                                data_t *d_ = _dst + oc * M;
                                PRAGMA_OMP_SIMD()
                                for (int oS = 0; oS < m; ++oS) {
                                    d_[oS] += b;
                                    if (d_[oS] < 0) d_[oS] *= eltwise.alpha;
                                    d_[oS] *= eltwise.scale;
                                }
                            });
                            fast_relu_done = true;
                        }
                    }
                    if (!fast_relu_done) {
                        parallel_nd(step.oc, [&](dim_t oc) {
                            data_t b = jcp.with_bias ? bias[oc_start + oc] : 0;
                            data_t *d_ = _dst + oc * M;

                            ref_post_ops_t::args_t args;
                            args.ctx = &ctx;
                            args.dst_md = pd()->dst_md();
                            args.l_offset = d_ - dst;

                            PRAGMA_OMP_SIMD()
                            for (int oS = 0; oS < m; ++oS) {
                                d_[oS] += b;
                                post_ops_->execute(d_[oS], args);
                                args.l_offset++;
                            }
                        });
                    }

                } else if (jcp.with_bias) {
                    parallel_nd(step.oc, [&](dim_t oc) {
                        data_t b = bias[oc_start + oc];
                        data_t *d_ = _dst + oc * M;
                        PRAGMA_OMP_SIMD()
                        for (int oS = 0; oS < m; ++oS) {
                            d_[oS] += b;
                        }
                    });
                }
            }

            return status::success;
        };
        im_pos_t start, end;
        end.ic = jcp.ic;

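        // Partition the work in 2D: spatial work (mb * ngroups * od * os)
        // along one axis and oc along the other, using jcp.nthr_oc thread
        // columns. For 3D problems the spatial axis is split on whole od
        // slices first and then rescaled to os units.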
        if (!is_problem_3d) {
            dim_t sp_work = jcp.mb * jcp.ngroups * jcp.od * jcp.os;
            balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc,
                    end.oc, dim_t(jcp.nthr_oc));
        } else {
            dim_t sp_work = jcp.mb * jcp.ngroups * jcp.od;
            balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc,
                    end.oc, dim_t(jcp.nthr_oc));
            start.sp *= jcp.os;
            end.sp *= jcp.os;
        }

        im_pos_t curr, prev, step;
        prev.n = prev.g = prev.od = prev.sp = prev.ic = -1;
        step.oc = jcp.oc_block;
        step.sp = jcp.os_block;
        step.ic = jcp.ic_block;

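        // Both loop orders skip redundant im2col work via
        // im_pos_t::do_im2col() when only the oc block changes:
        // gemm_loop_rlb iterates the ic reduction outermost, while
        // gemm_loop_lrb iterates spatial blocks outermost.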
        if (jcp.loop_order == gemm_loop_rlb)
            for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic)
                for (int spatial = start.sp; spatial < end.sp;
                        spatial += step.sp) {
                    nd_iterator_init(spatial, curr.n, jcp.mb, curr.g,
                            jcp.ngroups, curr.od, jcp.od, curr.sp, jcp.os);
                    for (curr.oc = start.oc; curr.oc < end.oc;
                            curr.oc += step.oc) {
                        status_t st_thr
                                = inner_ker(spatial, curr, prev, step, end);
                        if (st_thr != status::success) {
                            st = st_thr;
                            return;
                        }
                    }
                }
        else if (jcp.loop_order == gemm_loop_lrb)
            for (int spatial = start.sp; spatial < end.sp; spatial += step.sp) {
                nd_iterator_init(spatial, curr.n, jcp.mb, curr.g, jcp.ngroups,
                        curr.od, jcp.od, curr.sp, jcp.os);
                for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic)
                    for (curr.oc = start.oc; curr.oc < end.oc;
                            curr.oc += step.oc) {
                        status_t st_thr
                                = inner_ker(spatial, curr, prev, step, end);
                        if (st_thr != status::success) {
                            st = st_thr;
                            return;
                        }
                    }
            }
        else
            st = status::unimplemented;
    });

    return st;
}

status_t gemm_convolution_bwd_data_t::execute_backward_data_nspc(
        const exec_ctx_t &ctx) const {

    auto diff_dst_base = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto bia_base = CTX_IN_MEM(const data_t *, DNNL_ARG_BIAS);
    auto diff_src_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    auto scratchpad = ctx.get_scratchpad_grantor();
    const conv_gemm_conf_t &jcp = pd()->jcp_;
    std::atomic<status_t> st(status::success);

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        status_t st_thr = execute_backward_data_thr_nspc(ithr, nthr,
                diff_dst_base, wei_base, bia_base, diff_src_base, scratchpad);
        if (st_thr != status::success) st = st_thr;
    });

    return st;
}

status_t gemm_convolution_bwd_data_t::execute_backward_data_thr_nspc(
        const int ithr, const int nthr, const data_t *diff_dst_base,
        const data_t *wei_base, const data_t *bia_base, data_t *diff_src_base,
        const memory_tracking::grantor_t &scratchpad) const {
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Diff_dst Format: mb-spatial-groups-output_channels
    const size_t diff_dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh
            * jcp.ow * jcp.ngroups * jcp.oc;
    const size_t diff_dst_g_stride = jcp.oc;

    // Wei Format: spatial-input_channels-groups-output_channels
    const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Diff_src Format: mb-spatial-groups-input_channels
    const size_t diff_src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih
            * jcp.iw * jcp.ngroups * jcp.ic;
    const size_t diff_src_g_stride = jcp.ic;
    const size_t diff_src_os_stride = jcp.ngroups * jcp.ic;

    // threads share work across mini-batch and groups
    const dim_t work_amount = jcp.ngroups * jcp.mb;

    data_t *__restrict col = scratchpad.get<data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    const bool acc_needed = jcp.ngroups > 1;
    data_t *__restrict acc = acc_needed
            ? scratchpad.get<data_t>(key_conv_gemm_acc)
                    + (ptrdiff_t)ithr * jcp.is * jcp.id * jcp.ic
            : nullptr;

    dim_t n {0}, g {0};
    dim_t start = 0, end = 0;

    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (dim_t iwork = start; iwork < end; ++iwork) {
        const data_t *__restrict diff_dst = diff_dst_base
                + n * diff_dst_mb_stride + g * diff_dst_g_stride;
        const data_t *__restrict wei = wei_base + g * wei_g_stride;
        data_t *__restrict diff_src = diff_src_base + n * diff_src_mb_stride
                + g * diff_src_g_stride;

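        // Column-major GEMM: col(ks * ic x od * os) = wei^T x diff_dst, with
        // M = ks * ic, N = od * os, K = oc; both wei and diff_dst are read
        // with a leading dimension of ngroups * oc to skip the group
        // interleave. The result lands in col (scattered afterwards by
        // col2im_dt), in acc (when several groups must be de-interleaved),
        // or directly in diff_src.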
        const dim_t M = jcp.ks * jcp.ic;
        const dim_t N = jcp.os * jcp.od;
        const dim_t K = jcp.oc;

        const data_t onef = 1.0f, zerof = 0.0f;
        const dim_t LD = K * jcp.ngroups;

        status_t st = extended_sgemm("T", "N", &M, &N, &K, &onef, wei, &LD,
                diff_dst, &LD, &zerof,
                jcp.im2col_sz ? col : (acc_needed ? acc : diff_src), &M);
        if (st != status::success) return st;

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_dt<data_t>(
                    jcp, col, (acc_needed ? acc : diff_src));

        if (acc_needed) {
            parallel_nd(static_cast<size_t>(jcp.is) * jcp.id, [&](size_t is) {
                data_t *__restrict diff_src_arr
                        = diff_src + is * diff_src_os_stride;
                const data_t *__restrict acc_arr = acc + is * jcp.ic;
                PRAGMA_OMP_SIMD()
                for (int ic = 0; ic < jcp.ic; ic++) {
                    diff_src_arr[ic] = acc_arr[ic];
                }
            });
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }
    return status::success;
}

status_t gemm_convolution_bwd_data_t::execute_backward_data_ncsp(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    auto col = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_col);

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const dim_t M = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * M;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;

    const dim_t m = jcp.os_block;
    const dim_t K = jcp.oc;
    const dim_t N = jcp.ic * jcp.ks;

    const dim_t work_amount = (size_t)jcp.ngroups * jcp.mb;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

        dim_t g {0}, n {0};
        dim_t start = 0, end = 0;
        balance211(work_amount, nthr, ithr, start, end);
        nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb);
        for (dim_t iwork = start; iwork < end; ++iwork) {

            data_t *_diff_src = diff_src + (n * jcp.ngroups + g) * src_step;
            if (is_problem_3d && jcp.im2col_sz > 0) {
                // jit_gemm_convolution_utils::col2im_3d() assumes that the
                // accumulator is initialized by zeroes
                for (size_t i = 0; i < src_step; i++)
                    _diff_src[i] = (data_t)0;
            }

            const data_t *_weights = weights + g * weights_g_size;
            for_(int od = 0; od < jcp.od; ++od)
            for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
                auto out_off = os_nb * m + od * jcp.os;
                const data_t *_diff_dst
                        = diff_dst + (n * jcp.ngroups + g) * dst_step + out_off;
                const dim_t os_block
                        = nstl::min((dim_t)jcp.os_block, jcp.os - os_nb * m);
                const dim_t LDC = jcp.im2col_sz ? os_block : M;

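                // Column-major GEMM: col(os_block x ic * ks)
                //         = diff_dst(os_block x oc) x wei^T(oc x ic * ks);
                // diff_dst columns stride by the full M = od * os. With no
                // im2col buffer the result is written straight to diff_src.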
                const data_t zero = 0.0, one = 1.0;
                status_t st_thr = extended_sgemm("N", "T", &os_block, &N, &K,
                        &one, _diff_dst, &M, _weights, &N, &zero,
                        jcp.im2col_sz ? _col : _diff_src + out_off, &LDC);
                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                if (jcp.im2col_sz) {
                    if (!is_problem_3d)
                        jit_gemm_convolution_utils::col2im(jcp, _col, _diff_src,
                                os_nb * jcp.os_block, os_block);
                    else {
                        jit_gemm_convolution_utils::col2im_3d(jcp, _col,
                                _diff_src, od, os_nb * jcp.os_block, os_block);
                    }
                }
            }
            nd_iterator_step(g, jcp.ngroups, n, jcp.mb);
        }
    });

    return st;
}

status_t gemm_convolution_bwd_weights_t::execute_backward_weights_nspc(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto col = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_col);
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    auto wei_reduction
            = ctx.get_scratchpad_grantor().get<data_t>(key_conv_wei_reduction);

    const dim_t K = jcp.os * static_cast<size_t>(jcp.od);
    const size_t src_step
            = static_cast<size_t>(jcp.ic) * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = jcp.oc * K;
    const size_t weights_g_size = jcp.oc;

    const dim_t k = jcp.os;
    const dim_t M = jcp.oc;
    const dim_t N = static_cast<dim_t>(jcp.ic) * jcp.ks;
    const dim_t LDB = jcp.ngroups * jcp.oc;
    const dim_t LDA = jcp.im2col_sz ? jcp.oh * jcp.ow : jcp.ngroups * jcp.ic;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));

        const int need_reduction = nthr_mb != 1;
        const dim_t LDC = need_reduction ? jcp.oc : jcp.ngroups * jcp.oc;
        data_t *__restrict imtr
                = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_imtr)
                + (ptrdiff_t)ithr * jcp.id * jcp.ic * jcp.is;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
            if (is_problem_3d) {
                // jit_gemm_convolution_utils::im2col_3d() requires external
                // data initialization by zeroes
                PRAGMA_OMP_SIMD()
                for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                    _col[i] = 0.0f;
            }

            data_t *weights_reduce_base = wei_reduction
                    + ithr_g * nthr_mb * weights_g_size * jcp.ks * jcp.ic;
            data_t *weights_reduce = weights_reduce_base
                    + ithr_mb * weights_g_size * jcp.ks * jcp.ic;

            for (size_t g = g_start; g < g_end; ++g) {
                data_t *_diff_weights = need_reduction
                        ? weights_reduce
                        : diff_weights + g * weights_g_size;
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const data_t *_src
                            = src + mb * jcp.ngroups * src_step + g * jcp.ic;
                    if (jcp.im2col_sz && is_problem_3d)
                        jit_gemm_convolution_utils::transpose_dt(
                                jcp, _src, imtr);
                    for (int od = 0; od < jcp.od; ++od) {
                        const data_t *_diff_dst = diff_dst
                                + mb * jcp.ngroups * dst_step
                                + od * k * jcp.ngroups * jcp.oc + g * jcp.oc;

                        if (jcp.im2col_sz) {
                            if (is_problem_3d)
                                jit_gemm_convolution_utils::im2col_dt_3d<data_t,
                                        data_t>(jcp, imtr, _col, od);
                            else
                                jit_gemm_convolution_utils::im2col_dt<data_t,
                                        data_t>(jcp, _src, imtr, _col, 0,
                                        jcp.oh, 0, jcp.ow);
                        }
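                        // Accumulate diff_weights(oc x ic * ks) +=
                        //         diff_dst(oc x os) x col(os x ic * ks)
                        // one od slice at a time; beta is zero only for the
                        // very first slice (mb == mb_start && od == 0) so
                        // that later slices accumulate.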
                        const data_t zero = 0.0f, one = 1.0f;
                        status_t st_thr = extended_sgemm("N",
                                jcp.im2col_sz ? "N" : "T", &M, &N, &k, &one,
                                _diff_dst, &LDB,
                                jcp.im2col_sz
                                        ? _col
                                        : _src + od * k * jcp.ngroups * jcp.ic,
                                &LDA, mb == mb_start && od == 0 ? &zero : &one,
                                _diff_weights, &LDC);
                        if (st_thr != status::success) {
                            st = st_thr;
                            // Finish the loops early if a failure occurred.
                            g = g_end;
                            mb = mb_end;
                            od = jcp.od;
                        }
                    }
                }
            }
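            // With a syncable threading runtime, the per-thread partial
            // diff_weights are reduced right after the barrier; otherwise
            // the reduction runs in a second parallel region further below.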
            if (need_reduction && dnnl_thr_syncable()) {
                dnnl_thr_barrier();
                if (st != status::success) return;
                jit_gemm_convolution_utils::bwd_weights_reduction_par_nspc(
                        ithr_mb, nthr_mb, g_start, g_end, jcp,
                        weights_reduce_base, diff_weights);
            }
        } else {
            if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
        }
    });

    if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            int ithr_g, nthr_g, ithr_mb, nthr_mb;
            size_t g_start {0}, g_end {0};
            size_t mb_start {0}, mb_end {0};
            const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
            jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
                    jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
                    nthr_mb);

            assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
            const int need_reduction = nthr_mb != 1;

            if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
                balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
                balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

                assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

                data_t *weights_reduce_base = wei_reduction
                        + ithr_g * nthr_mb * weights_g_size * jcp.ic * jcp.ks;

                jit_gemm_convolution_utils::bwd_weights_reduction_par_nspc(
                        ithr_mb, nthr_mb, g_start, g_end, jcp,
                        weights_reduce_base, diff_weights);
            }
        });
    }

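    // diff_bias[g][oc] is the plain sum of diff_dst over the mini-batch and
    // the full output spatial range; elements along the width stride by
    // ngroups * oc because of the nspc layout.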
    if (jcp.with_bias) {
        parallel_nd(jcp.ngroups, jcp.oc, [&](dim_t g, dim_t oc) {
            data_t db = 0;
            const size_t offset_base = g * jcp.oc + oc;
            for_(dim_t mb = 0; mb < jcp.mb; ++mb)
            for_(dim_t od = 0; od < jcp.od; ++od)
            for (dim_t oh = 0; oh < jcp.oh; ++oh) {
                const data_t *__restrict diff_dst_arr = diff_dst + offset_base
                        + ((static_cast<size_t>(mb) * jcp.od + od) * jcp.oh
                                  + oh)
                                * jcp.ow * jcp.ngroups * jcp.oc;
                const int width_stride = jcp.ngroups * jcp.oc;

                PRAGMA_OMP_SIMD(reduction(+ : db))
                for (int ow = 0; ow < jcp.ow; ++ow) {
                    db += diff_dst_arr[ow * width_stride];
                }
            }
            diff_bias[g * jcp.oc + oc] = db;
        });
    }
    return st;
}

status_t gemm_convolution_bwd_weights_t::execute_backward_weights_ncsp(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_WEIGHTS);
    auto diff_bias = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_BIAS);

    auto col = ctx.get_scratchpad_grantor().get<data_t>(key_conv_gemm_col);
    auto wei_reduction
            = ctx.get_scratchpad_grantor().get<data_t>(key_conv_wei_reduction);

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    const dim_t K = jcp.os * jcp.od;
    const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = jcp.oc * K;
    const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;

    const dim_t k = jcp.os_block;
    const dim_t N = jcp.oc;
    const dim_t M = jcp.ic * jcp.ks;
    const bool is_problem_3d = pd()->ndims() == 5;

    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
        const int need_reduction = nthr_mb != 1;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

            // non-blocked jit_gemm_convolution_utils::im2col_3d() requires
            // external data initialization by zeroes
            const bool outer_padding = jcp.os_nb_block == 1;
            if (outer_padding && is_problem_3d) {
                for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                    _col[i] = (data_t)0;
            }
            data_t *weights_reduce_base
                    = wei_reduction + ithr_g * nthr_mb * weights_g_size;
            data_t *weights_reduce
                    = weights_reduce_base + ithr_mb * weights_g_size;

            for (size_t g = g_start; g < g_end; ++g) {
                data_t *_diff_weights = need_reduction
                        ? weights_reduce
                        : (diff_weights + g * weights_g_size);
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const data_t *_src
                            = src + (mb * jcp.ngroups + g) * src_step;
                    for_(int od = 0; od < jcp.od; ++od)
                    for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
                        auto out_off = os_nb * k + od * jcp.os;
                        const dim_t os_block = nstl::min(
                                (dim_t)jcp.os_block, jcp.os - os_nb * k);
                        const data_t *_diff_dst = diff_dst
                                + (mb * jcp.ngroups + g) * dst_step + out_off;

                        if (jcp.im2col_sz) {
                            if (!is_problem_3d)
                                jit_gemm_convolution_utils::im2col<float>(jcp,
                                        _src, _col, os_nb * jcp.os_block,
                                        os_block, 0, jcp.ic);
                            else
                                jit_gemm_convolution_utils::im2col_3d<float>(
                                        jcp, _src, _col, od,
                                        os_nb * jcp.os_block, os_block);
                        }
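                        // Accumulate diff_weights(ic * ks x oc) +=
                        //         col^T(ic * ks x os_block)
                        //                 x diff_dst(os_block x oc);
                        // beta is zero only for the first block of the first
                        // mini-batch so that later blocks accumulate.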
                        const dim_t LDA = jcp.im2col_sz ? os_block : K;
                        const data_t zero = 0.0, one = 1.0;
                        status_t st_thr = extended_sgemm("T", "N", &M, &N,
                                &os_block, &one,
                                jcp.im2col_sz ? _col : _src + out_off, &LDA,
                                _diff_dst, &K,
                                mb == mb_start && os_nb == 0 && od == 0 ? &zero
                                                                        : &one,
                                _diff_weights, &M);
                        if (st_thr != status::success) {
                            st = st_thr;
                            // Finish the loops early if a failure occurred.
                            g = g_end;
                            mb = mb_end;
                            od = jcp.od;
                            os_nb = jcp.os_nb_block;
                        }
                    }
                }
            }
            if (need_reduction && dnnl_thr_syncable()) {
                dnnl_thr_barrier();
                if (st != status::success) return;
                data_t *weights_base = diff_weights + g_start * weights_g_size;
                jit_gemm_convolution_utils::bwd_weights_reduction_par_ncsp(
                        ithr_mb, nthr_mb, jcp, weights_reduce_base,
                        weights_base);
            }
        } else {
            if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
        }
    });

    if (st != status::success) return st;

    if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            int ithr_g, nthr_g, ithr_mb, nthr_mb;
            size_t g_start {0}, g_end {0};
            const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
            jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
                    jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
                    nthr_mb);

            assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
            const int need_reduction = nthr_mb != 1;

            if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
                balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);

                assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

                data_t *weights_reduce_base
                        = wei_reduction + ithr_g * nthr_mb * weights_g_size;
                data_t *weights_base = diff_weights + g_start * weights_g_size;

                jit_gemm_convolution_utils::bwd_weights_reduction_par_ncsp(
                        ithr_mb, nthr_mb, jcp, weights_reduce_base,
                        weights_base);
            }
        });
    }

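    // In the ncsp layout each (g, oc) channel of diff_dst is contiguous over
    // its K = od * os spatial elements, so diff_bias is a straight sum over
    // mini-batch and spatial offsets.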
    if (jcp.with_bias) {
        parallel_nd(jcp.ngroups, jcp.oc, [&](dim_t g, dim_t oc) {
            data_t db = 0;
            dim_t offset_ = g * dst_step + oc * K;
            for (dim_t mb = 0; mb < jcp.mb; ++mb) {
                dim_t offset = offset_ + mb * jcp.ngroups * dst_step;
                for_(dim_t od = 0; od < jcp.od; ++od)
                for (dim_t oh = 0; oh < jcp.oh; ++oh) {
                    PRAGMA_OMP_SIMD(reduction(+ : db))
                    for (dim_t ow = 0; ow < jcp.ow; ++ow) {
                        db += diff_dst[offset + ow];
                    }
                    offset += jcp.ow;
                }
            }
            diff_bias[g * jcp.oc + oc] = db;
        });
    }

    return st;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl