1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "oneapi/dnnl/dnnl_types.h"
18
19#include "common/bfloat16.hpp"
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/type_helpers.hpp"
23#include "common/utils.hpp"
24#include "cpu/gemm_convolution_utils.hpp"
25#include "cpu/scale_utils.hpp"
26#if DNNL_X64
27#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
28#endif
29
30#include "cpu/platform.hpp"
31
32#if DNNL_X64
33#include "cpu/x64/cpu_isa_traits.hpp"
34#endif
35
36namespace dnnl {
37namespace impl {
38namespace cpu {
39
40using namespace dnnl::impl::status;
41using namespace dnnl::impl::utils;
42using namespace prop_kind;
43using namespace data_type;
44
45single_gemm_conv_chunk_desc_t::single_gemm_conv_chunk_desc_t(dim_t d_off,
46 dim_t d_size, dim_t h_off, dim_t h_size, dim_t w_off, dim_t w_size)
47 : d_off_(d_off)
48 , d_size_(d_size)
49 , h_off_(h_off)
50 , h_size_(h_size)
51 , w_off_(w_off)
52 , w_size_(w_size) {}
53
54namespace jit_gemm_convolution_utils {
55
56template <typename data_type_t>
57void im2col_3d(const conv_gemm_conf_t &jcp, const data_type_t *im,
58 data_type_t *col, dim_t od, int spatial_step, int spatial_block) {
59 using data_t =
60 typename conditional<data_traits<data_type_t>::data_type == bf16,
61 uint16_t, data_type_t>::type;
62 const data_t *__restrict _im
63 = reinterpret_cast<const data_t *__restrict>(im);
64 data_t *__restrict _col = reinterpret_cast<data_t *__restrict>(col);
65
66 const size_t OHW = spatial_block;
67 const size_t im_step = jcp.ih * jcp.iw * jcp.id;
68 const size_t col_step = jcp.ks * OHW;
69
70 auto compute_im2col_outer_padding = [&](dim_t ic) {
71 const data_t *__restrict im_loc = _im + ic * im_step;
72 data_t *__restrict col_loc = _col + ic * col_step;
73 dim_t id = od * jcp.stride_d - jcp.f_pad;
74 for (dim_t kd = 0; kd < jcp.kd; ++kd) {
75 data_t *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
76 if (id < 0 || id >= jcp.id) {
77 dim_t ih_ = -jcp.t_pad;
78 for (dim_t kh = 0; kh < jcp.kh; ++kh) {
79 dim_t ih = ih_;
80 for (dim_t oh = 0; oh < jcp.oh; ++oh) {
81 if (ih < 0 || ih >= jcp.ih) {
82 ih += jcp.stride_h;
83 continue;
84 }
85 dim_t iw_ = -jcp.l_pad;
86 for (dim_t kw = 0; kw < jcp.kw; ++kw) {
87 dim_t iw = iw_;
88 for (dim_t ow = 0; ow < jcp.ow; ++ow) {
89 if (iw < 0 || iw >= jcp.iw) {
90 iw += jcp.stride_w;
91 continue;
92 }
93
94 const size_t col_idx
95 = kw * OHW + oh * jcp.ow + ow;
96
97 col_[col_idx] = 0;
98 iw += jcp.stride_w;
99 }
100 iw_ += (1 + jcp.dilate_w);
101 }
102 ih += jcp.stride_h;
103 }
104 ih_ += (1 + jcp.dilate_h);
105 col_ += jcp.kw * OHW;
106 }
107 } else {
108 const data_t *__restrict im_ = im_loc + id * jcp.ih * jcp.iw;
109 dim_t ih_ = -jcp.t_pad;
110 for (dim_t kh = 0; kh < jcp.kh; ++kh) {
111 dim_t ih = ih_;
112 for (dim_t oh = 0; oh < jcp.oh; ++oh) {
113 if (ih < 0 || ih >= jcp.ih) {
114 ih += jcp.stride_h;
115 continue;
116 }
117 dim_t iw_ = -jcp.l_pad;
118 for (dim_t kw = 0; kw < jcp.kw; ++kw) {
119 dim_t iw = iw_;
120 for (dim_t ow = 0; ow < jcp.ow; ++ow) {
121 if (iw < 0 || iw >= jcp.iw) {
122 iw += jcp.stride_w;
123 continue;
124 }
125
126 const size_t col_idx
127 = kw * OHW + oh * jcp.ow + ow;
128 const size_t im_idx = ih * jcp.iw + iw;
129
130 col_[col_idx] = im_[im_idx];
131 iw += jcp.stride_w;
132 }
133 iw_ += (1 + jcp.dilate_w);
134 }
135 ih += jcp.stride_h;
136 }
137 ih_ += (1 + jcp.dilate_h);
138 col_ += jcp.kw * OHW;
139 }
140 }
141 id += (1 + jcp.dilate_d);
142 }
143 };
144 auto compute_im2col_padding = [&](dim_t ic) {
145 const dim_t first_oh = spatial_step / jcp.ow;
146 const dim_t last_oh = (spatial_step + spatial_block - 1) / jcp.ow;
147 const dim_t oh_begin = first_oh;
148 const dim_t oh_end = last_oh + 1;
149 const dim_t first_ow = spatial_step % jcp.ow;
150 const dim_t last_ow = (spatial_step + spatial_block - 1) % jcp.ow;
151
152 const data_t *__restrict im_loc = _im + ic * im_step;
153 data_t *__restrict col_loc = _col + ic * col_step;
154 dim_t id = od * jcp.stride_d - jcp.f_pad;
155 for (dim_t kd = 0; kd < jcp.kd; ++kd) {
156 data_t *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
157 if (id < 0 || id >= jcp.id) {
158 for (dim_t kh = 0; kh < jcp.kh; ++kh) {
159 for (dim_t oh = oh_begin; oh < oh_end; ++oh) {
160 const dim_t ow_begin = (oh == first_oh) ? first_ow : 0;
161 const dim_t ow_end
162 = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
163 for (dim_t kw = 0; kw < jcp.kw; ++kw) {
164 for (dim_t ow = ow_begin; ow < ow_end; ++ow) {
165 const size_t col_idx = kw * OHW + oh * jcp.ow
166 + ow - spatial_step;
167 col_[col_idx] = 0;
168 }
169 }
170 }
171 col_ += jcp.kw * OHW;
172 }
173 } else {
174 const data_t *__restrict im_ = im_loc + id * jcp.ih * jcp.iw;
175 dim_t ih_ = oh_begin * jcp.stride_h - jcp.t_pad;
176 for (dim_t kh = 0; kh < jcp.kh; ++kh) {
177 dim_t ih = ih_;
178 for (dim_t oh = oh_begin; oh < oh_end; ++oh) {
179 const dim_t ow_begin = (oh == first_oh) ? first_ow : 0;
180 const dim_t ow_end
181 = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
182 if (ih < 0 || ih >= jcp.ih) {
183 for (dim_t kw = 0; kw < jcp.kw; ++kw) {
184 for (dim_t ow = ow_begin; ow < ow_end; ++ow) {
185 const size_t col_idx = kw * OHW
186 + oh * jcp.ow + ow - spatial_step;
187 col_[col_idx] = 0;
188 }
189 }
190 ih += jcp.stride_h;
191 continue;
192 }
193 dim_t iw_ = ow_begin * jcp.stride_w - jcp.l_pad;
194 for (dim_t kw = 0; kw < jcp.kw; ++kw) {
195 dim_t iw = iw_;
196 for (dim_t ow = ow_begin; ow < ow_end; ++ow) {
197 const size_t col_idx = kw * OHW + oh * jcp.ow
198 + ow - spatial_step;
199 if (iw < 0 || iw >= jcp.iw) {
200 col_[col_idx] = 0;
201 iw += jcp.stride_w;
202 continue;
203 }
204 const size_t im_idx = ih * jcp.iw + iw;
205 col_[col_idx] = im_[im_idx];
206 iw += jcp.stride_w;
207 }
208 iw_ += (1 + jcp.dilate_w);
209 }
210 ih += jcp.stride_h;
211 }
212 ih_ += (1 + jcp.dilate_h);
213 col_ += jcp.kw * OHW;
214 }
215 }
216 id += (1 + jcp.dilate_d);
217 }
218 };
219
220 // zero padding is handled outside im2col
221 const bool outer_padding = jcp.os_nb_block == 1;
222 if (outer_padding)
223 parallel_nd(jcp.ic, compute_im2col_outer_padding);
224 else
225 parallel_nd(jcp.ic, compute_im2col_padding);
226}
227
228template void im2col_3d(const conv_gemm_conf_t &jcp, const float *im,
229 float *col, dim_t od, int spatial_step, int spatial_block);
230
231template void im2col_3d(const conv_gemm_conf_t &jcp, const bfloat16_t *im,
232 bfloat16_t *col, dim_t od, int spatial_step, int spatial_block);
233
234/* imtr[ic][od][oh][ow] <-- im[id][ih][iw][ic]*/
235template <typename T>
236void transpose_dt(const conv_gemm_conf_t &jcp, const T *__restrict im,
237 T *__restrict imtr) {
238 uint8_t shift = jcp.signed_input ? 128 : 0;
239 const dim_t ic_stride = jcp.id * jcp.ih * jcp.iw;
240 const dim_t IC = jcp.ngroups * jcp.ic;
241 const dim_t IHW = jcp.ih * jcp.iw;
242 constexpr dim_t ic_block = platform::get_cache_line_size();
243 const dim_t nb_ic = jcp.ic / ic_block;
244 const dim_t ic_blocked = nb_ic * ic_block;
245 parallel_nd(jcp.id, jcp.ih, [&](dim_t id, dim_t ih) {
246 const T *__restrict im_h = im + id * IHW * IC + ih * jcp.iw * IC;
247 T *__restrict imtr_h = imtr + id * IHW + ih * jcp.iw;
248 for (dim_t iw = 0; iw < jcp.iw; iw++) {
249 const T *__restrict im_w = im_h + iw * IC;
250 T *__restrict imtr_w = imtr_h + iw;
251 for (dim_t icb = 0; icb < nb_ic; icb++) {
252 const T *__restrict im_icb = im_w + icb * ic_block;
253 T *__restrict imtr_icb = imtr_w + icb * ic_block * ic_stride;
254 PRAGMA_OMP_SIMD()
255 for (dim_t ic = 0; ic < ic_block; ic++) {
256 imtr_icb[ic * ic_stride] = im_icb[ic] + shift;
257 }
258 }
259 for (dim_t ic = ic_blocked; ic < jcp.ic; ic++) {
260 imtr_w[ic * ic_stride] = im_w[ic] + shift;
261 }
262 }
263 });
264}
265
266template void transpose_dt(const conv_gemm_conf_t &jcp,
267 const int8_t *__restrict im, int8_t *__restrict imtr);
268template void transpose_dt(const conv_gemm_conf_t &jcp,
269 const uint8_t *__restrict im, uint8_t *__restrict imtr);
270template void transpose_dt(const conv_gemm_conf_t &jcp,
271 const char *__restrict im, char *__restrict imtr);
272template void transpose_dt(const conv_gemm_conf_t &jcp,
273 const float *__restrict im, float *__restrict imtr);
274template void transpose_dt(const conv_gemm_conf_t &jcp,
275 const bfloat16_t *__restrict im, bfloat16_t *__restrict imtr);
276
277/* col[kd][kh][kw][g][ic][od][oh][ow] <-- im2col_dt_3d(im[id][ih][iw][g][ic]) */
278template <typename orig_im_dt, typename orig_col_dt>
279void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr,
280 orig_col_dt *__restrict _col, dim_t od) {
281 // For performance reasons, use uint16_t as a proxy for bfloat16_t
282 using im_dt = typename utils::conditional<data_traits<orig_im_dt>::data_type
283 == bf16,
284 uint16_t, orig_im_dt>::type;
285 using col_dt =
286 typename utils::conditional<data_traits<orig_col_dt>::data_type
287 == bf16,
288 uint16_t, orig_col_dt>::type;
289 const im_dt *__restrict imtr
290 = reinterpret_cast<const im_dt *__restrict>(_imtr);
291 col_dt *__restrict col = reinterpret_cast<col_dt *__restrict>(_col);
292
293 col_dt shift = static_cast<col_dt>(jcp.signed_input ? 128 : 0);
294 const dim_t dd = 1 + jcp.dilate_d;
295 const dim_t dh = 1 + jcp.dilate_h;
296 const dim_t dw = 1 + jcp.dilate_w;
297 const dim_t sd = jcp.stride_d;
298 const dim_t sh = jcp.stride_h;
299 const dim_t sw = jcp.stride_w;
300 const dim_t fp = jcp.f_pad;
301 const dim_t tp = jcp.t_pad;
302 const dim_t lp = jcp.l_pad;
303 const dim_t col_ic_s = jcp.oh * jcp.ow;
304 const dim_t col_kw_s = jcp.ic * col_ic_s;
305 const dim_t col_kh_s = jcp.kw * col_kw_s;
306 const dim_t col_kd_s = jcp.kh * col_kh_s;
307 const dim_t IHW = jcp.ih * jcp.iw;
308 const dim_t OHW = jcp.oh * jcp.ow;
309
310 if (sd == 1 && sh == 1 && sw == 1 && dd == 1 && dh == 1 && dw == 1)
311 parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
312 [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
313 col_dt *__restrict col_loc = col + kd * col_kd_s
314 + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
315 const dim_t id = od - fp + kd;
316 if (id < 0 || id >= jcp.id) {
317 for (ptrdiff_t i = 0; i < OHW; i++)
318 col_loc[i] = shift;
319 return;
320 }
321 const im_dt *__restrict imtr_loc
322 = imtr + (ic * jcp.id + id) * IHW;
323 const dim_t oh_start = saturate(dim_t(0), jcp.oh, tp - kh);
324 const dim_t oh_end
325 = saturate(dim_t(0), jcp.oh, jcp.ih + tp - kh);
326 const dim_t ow_start = saturate(dim_t(0), jcp.ow, lp - kw);
327 const dim_t ow_end
328 = saturate(dim_t(0), jcp.ow, jcp.iw + lp - kw);
329 for (dim_t oh = oh_start, ih = oh_start - tp + kh;
330 oh < oh_end; oh++, ih++) {
331 col_dt *__restrict col_h = col_loc + oh * jcp.ow;
332 const im_dt *__restrict imtr_h = imtr_loc + ih * jcp.iw;
333 for (dim_t ow = ow_start, iw = ow_start - lp + kw;
334 ow < ow_end; ow++, iw++) {
335 col_h[ow] = imtr_h[iw];
336 }
337 }
338 });
339 else if (sd == 2 && sh == 2 && sw == 2 && dd == 1 && dh == 1 && dw == 1)
340 parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
341 [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
342 col_dt *__restrict col_loc = col + kd * col_kd_s
343 + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
344 const dim_t id = od * 2 - fp + kd;
345 if (id < 0 || id >= jcp.id) {
346 for (ptrdiff_t i = 0; i < OHW; i++)
347 col_loc[i] = shift;
348 return;
349 }
350 const im_dt *__restrict imtr_loc
351 = imtr + (ic * jcp.id + id) * IHW;
352 const dim_t oh_start
353 = saturate(dim_t(0), jcp.oh, div_up(tp - kh, 2));
354 const dim_t oh_end = saturate(
355 dim_t(0), jcp.oh, div_up(jcp.ih + tp - kh, 2));
356 const dim_t ow_start
357 = saturate(dim_t(0), jcp.ow, div_up(lp - kw, 2));
358 const dim_t ow_end = saturate(
359 dim_t(0), jcp.ow, div_up(jcp.iw + lp - kw, 2));
360 for (dim_t oh = oh_start, ih = oh_start * 2 - tp + kh;
361 oh < oh_end; ++oh, ih += 2) {
362 col_dt *__restrict col_h = col_loc + oh * jcp.ow;
363 const im_dt *__restrict imtr_h = imtr_loc + ih * jcp.iw;
364 for (dim_t ow = ow_start, iw = ow_start * 2 - lp + kw;
365 ow < ow_end; ++ow, iw += 2) {
366 col_h[ow] = imtr_h[iw];
367 }
368 }
369 });
370 else
371 parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
372 [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
373 col_dt *__restrict col_loc = col + kd * col_kd_s
374 + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
375 const dim_t id = od * sd - fp + kd * dd;
376 if (id < 0 || id >= jcp.id) {
377 for (ptrdiff_t i = 0; i < OHW; i++)
378 col_loc[i] = shift;
379 return;
380 }
381 const im_dt *__restrict imtr_loc
382 = imtr + (ic * jcp.id + id) * IHW;
383 const dim_t oh_start = saturate(
384 dim_t(0), jcp.oh, div_up(tp - kh * dh, sh));
385 const dim_t oh_end = saturate(dim_t(0), jcp.oh,
386 div_up(jcp.ih + tp - kh * dh, sh));
387 const dim_t ow_start = saturate(
388 dim_t(0), jcp.ow, div_up(lp - kw * dw, sw));
389 const dim_t ow_end = saturate(dim_t(0), jcp.ow,
390 div_up(jcp.iw + lp - kw * dw, sw));
391 for (dim_t oh = oh_start, ih = oh_start * sh - tp + kh * dh;
392 oh < oh_end; ++oh, ih += sh) {
393 col_dt *__restrict col_h = col_loc + oh * jcp.ow;
394 const im_dt *__restrict imtr_h = imtr_loc + ih * jcp.iw;
395 for (dim_t ow = ow_start,
396 iw = ow_start * sw - lp + kw * dw;
397 ow < ow_end; ++ow, iw += sw) {
398 col_h[ow] = imtr_h[iw];
399 }
400 }
401 });
402}
403
404template void im2col_dt_3d<int8_t, uint8_t>(const conv_gemm_conf_t &jcp,
405 const void *__restrict im, uint8_t *__restrict col, dim_t od);
406template void im2col_dt_3d<uint8_t, uint8_t>(const conv_gemm_conf_t &jcp,
407 const void *__restrict im, uint8_t *__restrict col, dim_t od);
408template void im2col_dt_3d<float, float>(const conv_gemm_conf_t &jcp,
409 const void *__restrict im, float *__restrict col, dim_t od);
410template void im2col_dt_3d<bfloat16_t, bfloat16_t>(const conv_gemm_conf_t &jcp,
411 const void *__restrict im, bfloat16_t *__restrict col, dim_t od);
412
413/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */
414template <typename data_type_t>
415void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im,
416 data_type_t *__restrict col, dim_t ss, dim_t sb, dim_t cs, dim_t cb) {
417
418 using data_t =
419 typename utils::conditional<data_traits<data_type_t>::data_type
420 == bf16,
421 uint16_t, data_type_t>::type;
422 const data_t *__restrict _im
423 = reinterpret_cast<const data_t *__restrict>(im);
424 data_t *__restrict _col = reinterpret_cast<data_t *__restrict>(col);
425
426 const size_t im_step = jcp.is;
427 const size_t col_step = jcp.ks * sb;
428 const dim_t dh = 1 + jcp.dilate_h;
429 const dim_t dw = 1 + jcp.dilate_w;
430 const dim_t sh = jcp.stride_h;
431 const dim_t sw = jcp.stride_w;
432 const dim_t tp = jcp.t_pad;
433 const dim_t lp = jcp.l_pad;
434 const dim_t first_oh = ss / jcp.ow;
435 const dim_t last_oh = (ss + sb - 1) / jcp.ow;
436 const dim_t oh_begin = first_oh;
437 const dim_t oh_end = last_oh + 1;
438 const dim_t first_ow = ss % jcp.ow;
439 const dim_t last_ow = (ss + sb - 1) % jcp.ow;
440
441 const data_t zero_val = 0;
442
443 if (jcp.outer_threading) {
444 if (sw == 1) {
445 // Generated code is more optimized for stride_w == 1
446 // because innermost loop is by width
447 for (dim_t ic = 0; ic < cb; ic++) {
448 const data_t *__restrict im_ic = _im + (ic + cs) * im_step;
449 for (dim_t kh = 0; kh < jcp.kh; kh++) {
450 for (dim_t kw = 0; kw < jcp.kw; kw++) {
451 data_t *__restrict col_k = _col + ic * col_step
452 + (kh * jcp.kw + kw) * sb;
453 for (dim_t oh = oh_begin; oh < oh_end; oh++) {
454 const dim_t ih = oh * sh - tp + kh * dh;
455 const data_t *__restrict im_
456 = im_ic + ih * jcp.iw - lp + kw * dw;
457 const dim_t ow_begin
458 = (oh == first_oh) ? first_ow : 0;
459 const dim_t ow_end
460 = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
461 data_t *__restrict col_ = col_k + oh * jcp.ow - ss;
462 if (ih < 0 || ih >= jcp.ih)
463 for (dim_t ow = ow_begin; ow < ow_end; ow++)
464 col_[ow] = zero_val;
465 else {
466 for (dim_t ow = ow_begin; ow < ow_end; ++ow) {
467 const dim_t iw = ow;
468 if (iw < lp - kw * dw
469 || iw >= jcp.iw + lp - kw * dw)
470 col_[ow] = zero_val;
471 else
472 col_[ow] = im_[iw];
473 }
474 }
475 }
476 }
477 }
478 }
479 } else {
480 for (dim_t ic = 0; ic < cb; ic++) {
481 const data_t *__restrict im_ = _im + (ic + cs) * im_step;
482 for (dim_t kh = 0; kh < jcp.kh; kh++) {
483 for (dim_t kw = 0; kw < jcp.kw; kw++) {
484 data_t *__restrict col_k = _col + ic * col_step
485 + (kh * jcp.kw + kw) * sb;
486 for (dim_t oh = oh_begin; oh < oh_end; oh++) {
487 const dim_t ih = oh * sh - tp + kh * dh;
488 const dim_t ow_begin
489 = (oh == first_oh) ? first_ow : 0;
490 const dim_t ow_end
491 = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
492 data_t *__restrict col_oh
493 = col_k + oh * jcp.ow - ss;
494 if (ih < 0 || ih >= jcp.ih)
495 for (dim_t ow = ow_begin; ow < ow_end; ow++)
496 col_oh[ow] = zero_val;
497 else
498 for (dim_t ow = ow_begin; ow < ow_end; ow++) {
499 const dim_t iw = ow * sw - lp + kw * dw;
500 if (iw < 0 || iw >= jcp.iw)
501 col_oh[ow] = zero_val;
502 else {
503 const ptrdiff_t im_idx
504 = ih * jcp.iw + iw;
505 col_oh[ow] = im_[im_idx];
506 }
507 }
508 }
509 }
510 }
511 }
512 }
513 } else {
514 // TODO: optimize threading if jcp.ic*jcp.kh*jcp.kw*oh_range is small
515 // comparing to number of threads
516 const dim_t oh_range = oh_end - oh_begin;
517 // Generated code is more optimized for stride_w == 1
518 // because innermost loop is by width
519 if (sw == 1)
520 parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
521 [&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) {
522 const dim_t oh = ohr + oh_begin;
523 const dim_t ih = oh * sh - tp + kh * dh;
524 const dim_t ow_start = (oh == first_oh) ? first_ow : 0;
525 const dim_t ow_end
526 = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
527 data_t *__restrict col_oh = _col + ic * col_step
528 + (kh * jcp.kw + kw) * sb + oh * jcp.ow - ss;
529 const data_t *__restrict im_
530 = _im + (ic + cs) * im_step + ih * jcp.iw;
531 const dim_t iw_shift = kw * dw - lp;
532 if (ih < 0 || ih >= jcp.ih)
533 for (dim_t ow = ow_start; ow < ow_end; ow++)
534 col_oh[ow] = zero_val;
535 else
536 for (dim_t ow = ow_start; ow < ow_end; ow++) {
537 const dim_t iw = ow + iw_shift;
538 if (iw < 0 || iw >= jcp.iw)
539 col_oh[ow] = zero_val;
540 else
541 col_oh[ow] = im_[iw];
542 }
543 });
544 else
545 parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
546 [&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) {
547 const dim_t oh = ohr + oh_begin;
548 const dim_t ih = oh * sh - tp + kh * dh;
549 const dim_t ow_start = (oh == first_oh) ? first_ow : 0;
550 const dim_t ow_end
551 = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
552 data_t *__restrict col_oh = _col + ic * col_step
553 + (kh * jcp.kw + kw) * sb + oh * jcp.ow - ss;
554 const data_t *__restrict im_
555 = _im + (ic + cs) * im_step;
556 if (ih < 0 || ih >= jcp.ih)
557 for (dim_t ow = ow_start; ow < ow_end; ow++)
558 col_oh[ow] = zero_val;
559 else
560 for (dim_t ow = ow_start; ow < ow_end; ow++) {
561 const dim_t iw = ow * sw - lp + kw * dw;
562 if (iw < 0 || iw >= jcp.iw)
563 col_oh[ow] = zero_val;
564 else {
565 const ptrdiff_t im_idx = ih * jcp.iw + iw;
566 col_oh[ow] = im_[im_idx];
567 }
568 }
569 });
570 }
571}
572
573template void im2col(const conv_gemm_conf_t &jcp, const float *__restrict im,
574 float *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb);
575
576template void im2col(const conv_gemm_conf_t &jcp,
577 const bfloat16_t *__restrict im, bfloat16_t *__restrict col, dim_t hs,
578 dim_t hb, dim_t ws, dim_t wb);
579
580/* col[kh][kw][ic][oh][ow] <-- im2col_dt(im[ih][iw][ic]) */
581template <typename orig_im_dt, typename orig_col_dt>
582void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im,
583 void *__restrict _imtr, orig_col_dt *__restrict _col, dim_t hs,
584 dim_t hb, dim_t ws, dim_t wb) {
585 // For performance reasons, use uint16_t as a proxy for bfloat16_t
586 using im_dt = typename utils::conditional<data_traits<orig_im_dt>::data_type
587 == bf16,
588 uint16_t, orig_im_dt>::type;
589 using col_dt =
590 typename utils::conditional<data_traits<orig_col_dt>::data_type
591 == bf16,
592 uint16_t, orig_col_dt>::type;
593 const im_dt *__restrict im = reinterpret_cast<const im_dt *__restrict>(_im);
594 im_dt *__restrict imtr = reinterpret_cast<im_dt *__restrict>(_imtr);
595 col_dt *__restrict col = reinterpret_cast<col_dt *__restrict>(_col);
596
597 col_dt shift = static_cast<col_dt>(jcp.signed_input ? 128 : 0);
598 const dim_t dh = 1 + jcp.dilate_h;
599 const dim_t dw = 1 + jcp.dilate_w;
600 const dim_t sh = jcp.stride_h;
601 const dim_t sw = jcp.stride_w;
602 const dim_t im_iw_stride = jcp.ic * jcp.ngroups;
603 const dim_t im_ih_stride = jcp.iw * im_iw_stride;
604 const dim_t tp = jcp.t_pad;
605 const dim_t lp = jcp.l_pad;
606
607 if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) {
608 /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */
609 const dim_t hp = hs - tp;
610 const dim_t wp = ws - lp;
611 const dim_t ih_start = saturate(dim_t(0), jcp.ih, hp);
612 const dim_t ih_end = saturate(dim_t(0), jcp.ih, hp + hb + jcp.kh);
613 const dim_t iw_start = saturate(dim_t(0), jcp.iw, wp);
614 const dim_t iw_end = saturate(dim_t(0), jcp.iw, wp + wb + jcp.kw);
615
616 const dim_t ihb = ih_end - ih_start;
617 const dim_t iwb = iw_end - iw_start;
618
619 const dim_t imtr_ic_stride = ihb * iwb;
620 const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start;
621 for (dim_t ic = 0; ic < jcp.ic; ic++) {
622 const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift;
623 for (dim_t ih = ih_start; ih < ih_end; ih++) {
624 const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride;
625 const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb;
626 for (dim_t iw = iw_start; iw < iw_end; iw++)
627 imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride];
628 }
629 }
630
631 const dim_t col_ic_str = hb * wb;
632 const dim_t col_kw_stride = jcp.ic * col_ic_str;
633 const dim_t col_kh_stride = jcp.kw * col_kw_stride;
634
635 const dim_t oh_init = ih_start - hp;
636 const dim_t ow_init = iw_start - wp;
637 for (dim_t kh = 0; kh < jcp.kh; kh++) {
638 const ptrdiff_t col_idx_kh = kh * col_kh_stride;
639 const dim_t oh_kh = oh_init - kh;
640 const dim_t oh_start = saturate(dim_t(0), hb, oh_kh);
641 const dim_t oh_end = saturate(dim_t(0), hb, oh_kh + ihb);
642 for (dim_t kw = 0; kw < jcp.kw; kw++) {
643 const ptrdiff_t col_idx_kw
644 = col_idx_kh + kw * jcp.ic * col_ic_str;
645 const dim_t ow_kw = ow_init - kw;
646 const dim_t imtr_shift = oh_kh * iwb + ow_kw;
647 const dim_t ow_start = saturate(dim_t(0), wb, ow_kw);
648 const dim_t ow_end = saturate(dim_t(0), wb, ow_kw + iwb);
649 for (dim_t ic = 0; ic < jcp.ic; ic++) {
650 const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str;
651 const dim_t imtr_idx_ic = ic * imtr_ic_stride - imtr_shift;
652 for (dim_t oh = 0; oh < oh_start; oh++) {
653 const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
654 for (dim_t ow = 0; ow < wb; ++ow)
655 col[col_idx_oh + ow] = shift;
656 }
657 for (dim_t oh = oh_start; oh < oh_end; oh++) {
658 const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
659 const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb;
660 for (dim_t ow = 0; ow < ow_start; ++ow)
661 col[col_idx_oh + ow] = shift;
662 for (dim_t ow = ow_start; ow < ow_end; ++ow)
663 col[col_idx_oh + ow]
664 = imtr[imtr_idx_oh + ow] + shift;
665 for (dim_t ow = ow_end; ow < wb; ++ow)
666 col[col_idx_oh + ow] = shift;
667 }
668 for (dim_t oh = oh_end; oh < hb; oh++) {
669 const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
670 for (dim_t ow = 0; ow < wb; ++ow)
671 col[col_idx_oh + ow] = shift;
672 }
673 }
674 }
675 }
676 } else {
677 parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb,
678 [&](dim_t kh, dim_t kw, dim_t ic, dim_t oh) {
679 const dim_t hp = tp - kh * dh;
680 const dim_t ih = (oh + hs) * sh - hp;
681 const ptrdiff_t col_idx_base
682 = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh)
683 * wb;
684 if (ih < 0 || ih >= jcp.ih)
685 for (dim_t ow = 0; ow < wb; ow++)
686 col[col_idx_base + ow] = shift;
687 else {
688 const dim_t wp = lp - kw * dw;
689 const dim_t ow_start
690 = saturate(dim_t(0), wb, div_up(wp, sw) - ws);
691 const dim_t ow_end = saturate(
692 dim_t(0), wb, div_up(jcp.iw + wp, sw) - ws);
693 for (dim_t ow = 0; ow < ow_start; ow++)
694 col[col_idx_base + ow] = shift;
695 const dim_t iw_base = ws * sw - wp;
696 const ptrdiff_t im_idx_base = ih * im_ih_stride + ic;
697 for (dim_t ow = ow_start; ow < ow_end; ow++) {
698 const dim_t iw = iw_base + ow * sw;
699 const ptrdiff_t im_idx
700 = im_idx_base + iw * im_iw_stride;
701 col[col_idx_base + ow] = im[im_idx] + shift;
702 }
703 for (dim_t ow = ow_end; ow < wb; ow++)
704 col[col_idx_base + ow] = shift;
705 }
706 });
707 }
708}
709
710template void im2col_dt<int8_t, uint8_t>(const conv_gemm_conf_t &jcp,
711 const void *__restrict im, void *__restrict imtr,
712 uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb);
713template void im2col_dt<uint8_t, uint8_t>(const conv_gemm_conf_t &jcp,
714 const void *__restrict im, void *__restrict imtr,
715 uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb);
716template void im2col_dt<float, float>(const conv_gemm_conf_t &jcp,
717 const void *__restrict im, void *__restrict imtr, float *__restrict col,
718 dim_t hs, dim_t hb, dim_t ws, dim_t wb);
719
720template void im2col_dt<bfloat16_t, bfloat16_t>(const conv_gemm_conf_t &jcp,
721 const void *__restrict im, void *__restrict imtr,
722 bfloat16_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb);
723
724/* im[id][ih][iw][ic] <-- col2im_dt_3d(col[od][oh][ow][kd][kh][kw][ic]) */
725template <typename orig_T>
726void col2im_dt(const conv_gemm_conf_t &jcp, const orig_T *__restrict _col,
727 orig_T *__restrict _im) {
728 // For performance reasons, use uint16_t as a proxy for bfloat16_t
729 using T =
730 typename utils::conditional<data_traits<orig_T>::data_type == bf16,
731 uint16_t, orig_T>::type;
732 const T *__restrict col = reinterpret_cast<const T *__restrict>(_col);
733 T *__restrict im = reinterpret_cast<T *__restrict>(_im);
734
735 parallel(0, [&](const int ithr, const int nthr) {
736 dim_t d_nthr = nstl::min(jcp.id, dim_t(nthr));
737 dim_t h_nthr = nstl::min(jcp.ih, dim_t(nthr) / d_nthr);
738 dim_t w_nthr = nstl::min(jcp.iw, dim_t(nthr) / (d_nthr * h_nthr));
739 dim_t d_ithr = 1, d_s = 0, d_e = 0, h_ithr = 1, h_s = 0, h_e = 0,
740 w_ithr = 1, w_s = 0, w_e = 0;
741 if (ithr < d_nthr * h_nthr * w_nthr) {
742 d_ithr = ithr / (h_nthr * w_nthr);
743 h_ithr = (ithr % (h_nthr * w_nthr)) / w_nthr;
744 w_ithr = (ithr % (h_nthr * w_nthr)) % w_nthr;
745 balance211(jcp.id, d_nthr, d_ithr, d_s, d_e);
746 balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e);
747 balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e);
748 } else {
749 d_nthr = h_ithr = w_ithr = -ithr;
750 d_s = d_e = h_s = h_e = w_s = w_e = -1;
751 }
752
753 for_(dim_t id = d_s; id < d_e; ++id)
754 for_(dim_t ih = h_s; ih < h_e; ++ih)
755 for (dim_t iw = w_s; iw < w_e; ++iw) {
756 PRAGMA_OMP_SIMD()
757 for (dim_t ic = 0; ic < jcp.ic; ++ic) {
758 im[((id * jcp.ih + ih) * jcp.iw + iw) * jcp.ic + ic] = 0;
759 }
760 }
761
762 // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh]
763 for_(dim_t od = 0; od < jcp.od; ++od)
764 for_(dim_t oh = 0; oh < jcp.oh; ++oh)
765 for_(dim_t ow = 0; ow < jcp.ow; ++ow)
766 for (dim_t kd = 0; kd < jcp.kd; ++kd) {
767 const dim_t id
768 = od * jcp.stride_d - jcp.f_pad + kd * (1 + jcp.dilate_d);
769 if (id < d_s || id >= d_e) continue;
770
771 for (dim_t kh = 0; kh < jcp.kh; ++kh) {
772 const dim_t ih = oh * jcp.stride_h - jcp.t_pad
773 + kh * (1 + jcp.dilate_h);
774 if (ih < h_s || ih >= h_e) continue;
775
776 for (dim_t kw = 0; kw < jcp.kw; ++kw) {
777 const dim_t iw = ow * jcp.stride_w - jcp.l_pad
778 + kw * (1 + jcp.dilate_w);
779 if (iw < w_s || iw >= w_e) continue;
780
781 const size_t col_idx
782 = (((((od * jcp.oh + oh) * jcp.ow + ow) * jcp.kd
783 + kd) * jcp.kh
784 + kh) * jcp.kw
785 + kw)
786 * jcp.ic;
787 const size_t im_idx
788 = ((id * jcp.ih + ih) * jcp.iw + iw) * jcp.ic;
789 PRAGMA_OMP_SIMD()
790 for (dim_t ic = 0; ic < jcp.ic; ++ic) {
791 im[im_idx + ic] += col[col_idx + ic];
792 }
793 }
794 }
795 }
796 });
797}
798
799template void col2im_dt<int32_t>(const conv_gemm_conf_t &jcp,
800 const int32_t *__restrict col, int32_t *__restrict im);
801
802template void col2im_dt<float>(const conv_gemm_conf_t &jcp,
803 const float *__restrict col, float *__restrict im);
804
805template void col2im_dt<bfloat16_t>(const conv_gemm_conf_t &jcp,
806 const bfloat16_t *__restrict col, bfloat16_t *__restrict im);
807
808void col2im_3d(const conv_gemm_conf_t &jcp, const float *col, float *im,
809 dim_t od, int spatial_step, int spatial_block) {
810
811 auto sp_blocked_ker = [&](dim_t ic) {
812 const size_t col_step = jcp.ks * spatial_block;
813 const float *__restrict col_ = col + ic * col_step;
814 float *__restrict im_ic = im + ic * jcp.ih * jcp.iw * jcp.id;
815
816 const dim_t first_oh = spatial_step / jcp.ow;
817 const dim_t last_oh = (spatial_step + spatial_block - 1) / jcp.ow;
818 const dim_t oh_begin = first_oh;
819 const dim_t oh_end = last_oh + 1;
820 const dim_t first_ow = spatial_step % jcp.ow;
821 const dim_t last_ow = (spatial_step + spatial_block - 1) % jcp.ow;
822 const dim_t wei_stride
823 = nstl::min(jcp.ow * jcp.oh, dim_t(spatial_block));
824
825 dim_t id = od * jcp.stride_d - jcp.f_pad;
826 for (dim_t kd = 0; kd < jcp.kd; ++kd) {
827 if (id < 0 || id >= jcp.id) {
828 col_ += jcp.kh * jcp.kw * wei_stride;
829 id += (1 + jcp.dilate_d);
830 continue;
831 }
832
833 float *__restrict im_ = im_ic + (size_t)id * jcp.ih * jcp.iw;
834 for_(dim_t kh = 0; kh < jcp.kh; ++kh)
835 for_(dim_t kw = 0; kw < jcp.kw; ++kw)
836 for (dim_t oh = oh_begin, col_off = 0; oh < oh_end; ++oh) {
837
838 const dim_t ow_begin = (oh == first_oh) ? first_ow : 0;
839 const dim_t ow_end = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
840 const dim_t ow_work = ow_end - ow_begin;
841
842 const dim_t ih = oh * jcp.stride_h - jcp.t_pad
843 + kh * (1 + jcp.dilate_h);
844 if (ih < 0 || ih >= jcp.ih) {
845 col_off += ow_work;
846 continue;
847 }
848
849 for (dim_t ow = ow_begin; ow < ow_end; ++ow, ++col_off) {
850 const dim_t iw = ow * jcp.stride_w - jcp.l_pad
851 + kw * (1 + jcp.dilate_w);
852 if (iw < 0 || iw >= jcp.iw) { continue; }
853
854 const size_t col_idx
855 = (kh * jcp.kw + kw) * wei_stride + col_off;
856 const size_t im_idx = ih * jcp.iw + iw;
857 im_[im_idx] += col_[col_idx];
858 }
859 }
860 col_ += jcp.kh * jcp.kw * wei_stride;
861 id += (1 + jcp.dilate_d);
862 }
863 };
864
865 auto ker = [&](dim_t ic) {
866 const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os;
867 float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;
868
869 dim_t id = od * jcp.stride_d - jcp.f_pad;
870 for (dim_t kd = 0; kd < jcp.kd; ++kd) {
871 if (id < 0 || id >= jcp.id) {
872 col_ += jcp.kh * jcp.kw * jcp.os;
873 id += (1 + jcp.dilate_d);
874 continue;
875 }
876
877 float *__restrict im_ = im_ic + (size_t)id * jcp.ih * jcp.iw;
878
879 for_(dim_t oh = 0; oh < jcp.oh; ++oh)
880 for (dim_t kh = 0; kh < jcp.kh; ++kh) {
881 const dim_t ih = oh * jcp.stride_h - jcp.t_pad
882 + kh * (1 + jcp.dilate_h);
883 if (ih < 0 || ih >= jcp.ih) continue;
884
885 for_(dim_t ow = 0; ow < jcp.ow; ++ow)
886 for (dim_t kw = 0; kw < jcp.kw; ++kw) {
887 const dim_t iw = ow * jcp.stride_w - jcp.l_pad
888 + kw * (1 + jcp.dilate_w);
889 if (iw < 0 || iw >= jcp.iw) continue;
890
891 const size_t col_idx
892 = ((kh * jcp.kw + kw) * jcp.oh + oh) * jcp.ow + ow;
893 const size_t im_idx = ih * jcp.iw + iw;
894 im_[im_idx] += col_[col_idx];
895 }
896 }
897
898 col_ += jcp.kh * jcp.kw * jcp.os;
899 id += (1 + jcp.dilate_d);
900 }
901 };
902
903 const bool blocked_kernel = jcp.os_nb_block > 1;
904 if (blocked_kernel)
905 parallel_nd(jcp.ic, sp_blocked_ker);
906 else
907 parallel_nd(jcp.ic, ker);
908}
909
910void col2im(const conv_gemm_conf_t &jcp, const float *col, float *im,
911 int spatial_step, int spatial_block) {
912 const size_t col_step = jcp.ks * spatial_block;
913 const size_t im_step = jcp.ih * jcp.iw;
914 const dim_t iS = jcp.ih * jcp.iw;
915
916 auto sp_blocked_ker = [&](dim_t ic) {
917 const dim_t wei_stride
918 = nstl::min(jcp.ow * jcp.oh, dim_t(spatial_block));
919 const dim_t first_oh = spatial_step / jcp.ow;
920 const dim_t last_oh = (spatial_step + spatial_block - 1) / jcp.ow;
921 const dim_t oh_begin = first_oh;
922 const dim_t oh_end = last_oh + 1;
923 const dim_t first_ow = spatial_step % jcp.ow;
924 const dim_t last_ow = (spatial_step + spatial_block - 1) % jcp.ow;
925
926 float *__restrict img_ithr = im + ic * im_step;
927 const float *__restrict col_icb = col + ic * col_step;
928
929 if (spatial_step == 0) {
930 PRAGMA_OMP_SIMD()
931 for (dim_t is = 0; is < iS; ++is)
932 img_ithr[is] = 0.;
933 }
934
935 float *__restrict img_kh = img_ithr;
936 for (dim_t kh = 0; kh < jcp.kh; ++kh) {
937 float *__restrict im_ = img_kh;
938 for (dim_t kw = 0; kw < jcp.kw; ++kw) {
939 const float *__restrict col_ = col_icb;
940 for (dim_t oh = oh_begin; oh < oh_end; ++oh) {
941 const dim_t ow_begin = (oh == first_oh) ? first_ow : 0;
942 const dim_t ow_end
943 = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
944 const dim_t ow_work = ow_end - ow_begin;
945
946 const dim_t ih = oh * jcp.stride_h - jcp.t_pad;
947 const dim_t ih_ = ih + kh * (1 + jcp.dilate_h);
948 if (ih_ < 0 || ih_ >= jcp.ih) {
949 col_ += ow_work;
950 continue;
951 }
952 for (dim_t ow = ow_begin; ow < ow_end; ++ow, ++col_) {
953 const dim_t iw = ow * jcp.stride_w - jcp.l_pad;
954 const dim_t iw_ = iw + kw * (1 + jcp.dilate_w);
955 if (iw_ < 0 || iw_ >= jcp.iw) continue;
956
957 const size_t im_idx = ih * jcp.iw + iw;
958 im_[im_idx] += *col_;
959 }
960 }
961 col_icb += wei_stride;
962 im_ += (1 + jcp.dilate_w);
963 }
964 img_kh += (jcp.iw * (1 + jcp.dilate_h));
965 }
966 };
967
968 auto ker = [&](dim_t ic) {
969 float *__restrict im_ = im + ic * im_step;
970 const float *__restrict col_ = col + ic * col_step;
971 PRAGMA_OMP_SIMD()
972 for (dim_t is = 0; is < iS; ++is)
973 im_[is] = 0.;
974
975 for_(dim_t kh = 0; kh < jcp.kh; ++kh)
976 for (dim_t oh = 0; oh < jcp.oh; ++oh) {
977 const dim_t ih
978 = oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
979 if (ih < 0 || ih >= jcp.ih) continue;
980
981 for_(dim_t kw = 0; kw < jcp.kw; ++kw)
982 for (dim_t ow = 0; ow < jcp.ow; ++ow) {
983 const dim_t iw = ow * jcp.stride_w - jcp.l_pad
984 + kw * (1 + jcp.dilate_w);
985 if (iw < 0 || iw >= jcp.iw) continue;
986
987 const size_t col_idx
988 = ((kh * jcp.kw + kw) * jcp.oh + oh) * jcp.ow + ow;
989 const size_t im_idx = ih * jcp.iw + iw;
990 im_[im_idx] += col_[col_idx];
991 }
992 }
993 };
994
995 const bool blocked_kernel = jcp.os_nb_block > 1;
996 if (blocked_kernel)
997 parallel_nd(jcp.ic, sp_blocked_ker);
998 else
999 parallel_nd(jcp.ic, ker);
1000}
1001
1002status_t init_conf(conv_gemm_conf_t &jcp,
1003 memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
1004 memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
1005 memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads) {
1006 const memory_desc_wrapper src_d(&src_md);
1007 const memory_desc_wrapper weights_d(&weights_md);
1008 const memory_desc_wrapper dst_d(&dst_md);
1009
1010 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1011 const int ndims = src_d.ndims();
1012 const int is_1d = ndims == 3;
1013 const int is_3d = ndims == 5;
1014
1015 jcp.prop_kind = cd.prop_kind;
1016
1017 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1018 jcp.mb = src_d.dims()[0];
1019
1020 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1021 jcp.ic = src_d.dims()[1] / jcp.ngroups;
1022 jcp.id = is_3d ? src_d.dims()[2] : 1;
1023 jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
1024 jcp.iw = src_d.dims()[ndims - 1];
1025 jcp.od = is_3d ? dst_d.dims()[2] : 1;
1026 jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
1027 jcp.ow = dst_d.dims()[ndims - 1];
1028
1029 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
1030 jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
1031 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1032
1033 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
1034 jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
1035 jcp.l_pad = cd.padding[0][ndims - 3];
1036
1037 jcp.stride_d = is_3d ? cd.strides[0] : 1;
1038 jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
1039 jcp.stride_w = cd.strides[ndims - 3];
1040
1041 jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
1042 jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
1043 jcp.dilate_w = cd.dilates[ndims - 3];
1044
1045 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef
1046 || cd.diff_bias_desc.format_kind != format_kind::undef;
1047
1048 jcp.is = jcp.ih * jcp.iw;
1049 jcp.os = jcp.oh * jcp.ow;
1050 jcp.ks = jcp.kh * jcp.kw * jcp.kd;
1051
1052 jcp.signed_input = src_d.data_type() == data_type::s8;
1053
1054 jcp.outer_threading = false;
1055
1056 jcp.zp = zero_point_config_t(attr);
1057 jcp.b_pad = nstl::max((jcp.oh - 1) * jcp.stride_h
1058 + (jcp.kh - 1) * (jcp.dilate_h + 1)
1059 - (jcp.ih + jcp.t_pad - 1),
1060 dim_t(0));
1061 jcp.r_pad = nstl::max((jcp.ow - 1) * jcp.stride_w
1062 + (jcp.kw - 1) * (jcp.dilate_w + 1)
1063 - (jcp.iw + jcp.l_pad - 1),
1064 dim_t(0));
1065 jcp.e_pad = nstl::max((jcp.od - 1) * jcp.stride_d
1066 + (jcp.kd - 1) * (jcp.dilate_d + 1)
1067 - (jcp.id + jcp.f_pad - 1),
1068 dim_t(0));
1069
1070 const bool zp_src_with_padding = jcp.zp.src_exists && padding_exists(jcp);
1071
1072 if (zp_src_with_padding) {
1073 jcp.zp.src_pad_comp = zero_point_pad_comp_config_t(jcp.f_pad, jcp.e_pad,
1074 jcp.t_pad, jcp.b_pad, jcp.l_pad, jcp.r_pad, jcp.stride_d,
1075 jcp.stride_h, jcp.stride_w, jcp.od, jcp.oh, jcp.ow);
1076 }
1077
1078 const auto set_or_check_tags
1079 = [&](format_tag_t desired_src_tag, format_tag_t desired_dst_tag,
1080 bool is_src_s8) -> status_t {
1081 using namespace format_tag;
1082 auto src_tag = any, dst_tag = any;
1083
1084 if (src_d.format_kind() == format_kind::any) {
1085 CHECK(memory_desc_init_by_tag(src_md, desired_src_tag));
1086 src_tag = desired_src_tag;
1087 } else {
1088 src_tag = memory_desc_matches_one_of_tag(
1089 src_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw);
1090 }
1091
1092 if (dst_d.format_kind() == format_kind::any) {
1093 CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag));
1094 dst_tag = desired_dst_tag;
1095 } else {
1096 dst_tag = memory_desc_matches_one_of_tag(
1097 dst_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw);
1098 }
1099
1100 if (src_tag == format_tag::undef || dst_tag == format_tag::undef)
1101 return status::unimplemented;
1102 if (src_tag != dst_tag) return status::unimplemented;
1103
1104 if (jcp.with_bias && bias_md.format_kind == format_kind::any)
1105 CHECK(memory_desc_init_by_tag(bias_md, x));
1106
1107 const bool is_nspc = utils::one_of(src_tag, nwc, nhwc, ndhwc);
1108 jcp.is_nspc = is_nspc;
1109
1110 memory_desc_t want_wei_md = weights_md;
1111 auto wei_tag = is_nspc
1112 ? (with_groups ? utils::pick(ndims - 3, wigo, hwigo, dhwigo)
1113 : utils::pick(ndims - 3, wio, hwio, dhwio))
1114 : (with_groups ? utils::pick(ndims - 3, goiw, goihw, goidhw)
1115 : utils::pick(ndims - 3, oiw, oihw, oidhw));
1116 CHECK(memory_desc_init_by_tag(want_wei_md, wei_tag));
1117
1118 if (is_src_s8) {
1119 want_wei_md.extra.flags = 0
1120 | memory_extra_flags::compensation_conv_s8s8
1121 | memory_extra_flags::scale_adjust;
1122 want_wei_md.extra.compensation_mask
1123 = (1 << 0) + (with_groups ? (1 << 1) : 0);
1124 want_wei_md.extra.scale_adjust
1125 = platform::s8s8_weights_scale_factor();
1126 }
1127
1128 if (jcp.zp.src_exists) set_zp_src_comp_flags(want_wei_md, with_groups);
1129
1130 if (weights_md.format_kind == format_kind::any) {
1131 weights_md = want_wei_md;
1132 return status::success;
1133 }
1134 return (want_wei_md == weights_md) ? status::success
1135 : status::unimplemented;
1136 };
1137
1138 const bool is_bwd_d = jcp.prop_kind == backward_data;
1139 const bool is_bwd_w = jcp.prop_kind == backward_weights;
1140 const bool is_fwd = !is_bwd_d && !is_bwd_w;
1141
1142 bool is_int8_conv = (is_fwd ? utils::one_of(src_d.data_type(), s8, u8)
1143 : utils::one_of(dst_d.data_type(), s8, u8))
1144 && weights_d.data_type() == s8;
1145
1146 auto default_dat_tag = is_int8_conv
1147 ? utils::pick(ndims - 3, format_tag::nwc, format_tag::nhwc,
1148 format_tag::ndhwc)
1149 : utils::pick(ndims - 3, format_tag::ncw, format_tag::nchw,
1150 format_tag::ncdhw);
1151 CHECK(set_or_check_tags(default_dat_tag, default_dat_tag,
1152 src_md.data_type == data_type::s8));
1153
1154 // Does int8 conv ever need to support ncsp input format
1155 if (is_int8_conv && !src_d.matches_one_of_tag(default_dat_tag))
1156 return status::unimplemented;
1157
1158 CHECK(attr.set_default_formats(&dst_md));
1159
1160 jcp.post_ops = attr.post_ops_;
1161
1162 const int eltwise_ind = jcp.post_ops.find(primitive_kind::eltwise);
1163 jcp.with_eltwise = eltwise_ind != -1;
1164 const int binary_ind = jcp.post_ops.find(primitive_kind::binary);
1165 jcp.with_binary = binary_ind != -1;
1166 const int sum_ind = jcp.post_ops.find(primitive_kind::sum);
1167 jcp.with_sum = sum_ind != -1;
1168
1169 bool is_bf16_conv = false
1170 || (is_fwd
1171 && utils::everyone_is(
1172 bf16, src_d.data_type(), weights_d.data_type()))
1173 || (is_bwd_d
1174 && utils::everyone_is(
1175 bf16, dst_d.data_type(), weights_d.data_type()))
1176 || (is_bwd_w
1177 && utils::everyone_is(
1178 bf16, src_d.data_type(), dst_d.data_type()));
1179 if (is_bf16_conv && !platform::has_data_type_support(bf16))
1180 return status::unimplemented;
1181
1182 const int vlen = std::max(platform::get_vector_register_size(), 4);
1183 const int data_size = (is_int8_conv ? 1 : (is_bf16_conv ? 2 : 4));
1184 const int simd_w = vlen / data_size;
1185
1186 jcp.os_block = jcp.os;
1187 jcp.os_nb_block = 1;
1188 jcp.oc_block = jcp.oc;
1189 jcp.ic_block = jcp.ic;
1190 jcp.loop_order = gemm_loop_rlb;
1191 jcp.nthr_oc = 1;
1192
1193 jcp.oh_block = is_fwd ? jcp.oh : jcp.ih;
1194 jcp.ow_block = is_fwd ? jcp.ow : jcp.iw;
1195
1196 using namespace memory_tracking::names;
1197 bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
1198
1199 // TODO: maybe mitigate blocking restriction
1200 const auto L2 = platform::get_per_core_cache_size(2) / data_size;
1201 const int gemm_thrld = 64 * 1024;
1202
1203 // Heuristic threshold for requested scratchpad memory to avoid
1204 // possible crash on memory allocation:
1205 // 1Gb or size of the buffers already used for this convolution proportional
1206 // to the number of threads and multiplied by a heuristic coefficient (15)
1207 const size_t zp_src_pad_comp_size = zp_src_with_padding
1208 ? (jcp.oc * jcp.ngroups * jcp.zp.src_pad_comp.d
1209 * jcp.zp.src_pad_comp.h * jcp.zp.src_pad_comp.w)
1210 : 0u;
1211 const size_t zp_src_comp_size = jcp.zp.src_is_common
1212 ? utils::rnd_up(jcp.oc * jcp.ngroups,
1213 platform::get_cache_line_size() / sizeof(int))
1214 : 0u;
1215
1216 const size_t weights_size = weights_d.size()
1217 + (zp_src_comp_size + zp_src_pad_comp_size) * sizeof(int32_t);
1218
1219 static constexpr size_t scratchpad_limit_by_absolute_value = (size_t)1
1220 << 30; // 1Gb
1221 const size_t scratchpad_limit_by_tensor_sizes
1222 = 15 * max_threads * (src_d.size() + weights_size + dst_d.size());
1223 const size_t scratchpad_limit
1224 = nstl::min(scratchpad_limit_by_absolute_value,
1225 scratchpad_limit_by_tensor_sizes);
1226
1227 if (is_int8_conv) {
1228 if (is_fwd) {
1229 jcp.im2col_sz
1230 = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih,
1231 jcp.od == jcp.id, jcp.stride_w == 1,
1232 jcp.stride_h == 1, jcp.stride_d == 1, jcp.ks == 1,
1233 !jcp.signed_input)
1234 ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os
1235 : 0;
1236
1237 dim_t wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
1238 bool is_blocking_applicable = true && is_fwd && jcp.im2col_sz
1239 && !is_3d && jcp.dilate_h == 0 && jcp.dilate_w == 0
1240 && !is_depthwise && wei_size < L2 / 2;
1241 if (is_blocking_applicable) {
1242 // looking for oh and ow blocking
1243 dim_t h_block {jcp.oh_block}, w_block {jcp.ow_block};
1244 dim_t ic = jcp.ic;
1245 dim_t oc = jcp.oc;
1246 dim_t iw = jcp.iw;
1247 dim_t ow = jcp.ow;
1248 dim_t oh = jcp.oh;
1249 dim_t os = oh * ow;
1250
1251 // 1. cache requirement
1252 dim_t row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow);
1253 // Heuristic rule: gemm needed a lot of memory for internal
1254 // usage
1255 row_size *= 5;
1256 // memory for accumulators
1257 row_size += oc * ow * sizeof(uint32_t);
1258 // memory for transposition
1259 row_size += ic * iw;
1260
1261 h_block = nstl::max(
1262 dim_t(1), nstl::min(oh, div_up(dim_t(L2), row_size)));
1263 if (h_block == 1) {
1264 dim_t col_size = ic * jcp.ks + 2 * (ic + oc);
1265 if (is_int8_conv) {
1266 col_size *= 5;
1267 col_size += oc * sizeof(uint32_t);
1268 col_size += ic;
1269 }
1270 w_block = nstl::max(dim_t(1),
1271 nstl::min(ow, div_up(dim_t(L2), col_size)));
1272 }
1273
1274 // 2. threading requirement
1275 if (h_block != oh)
1276 h_block = nstl::max(dim_t(1), rnd_dn(h_block, dim_t(4)));
1277 if (w_block != ow)
1278 w_block = nstl::max(dim_t(1), rnd_dn(w_block, simd_w));
1279
1280 float thr_eff = 0.f;
1281 float thr_eff_treshold = 0.9f;
1282 if (w_block == ow) {
1283 do {
1284 dim_t nb_h = div_up(oh, h_block);
1285 dim_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h;
1286 float disb = (float)oh / rnd_up(oh, h_block);
1287 thr_eff = (float)work / rnd_up(work, max_threads);
1288 thr_eff = (thr_eff + disb) / 2.f;
1289 if (thr_eff >= thr_eff_treshold) break;
1290 h_block = rnd_dn(h_block - 4, 4);
1291 } while (h_block > 0);
1292 }
1293 if (thr_eff
1294 < thr_eff_treshold) // we didn't find suitable h_block
1295 {
1296 h_block = 1;
1297 int nb_h = oh;
1298 do {
1299 dim_t nb_w = div_up(ow, w_block);
1300 dim_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w;
1301 float disb = (float)ow / rnd_up(ow, w_block);
1302 thr_eff = (float)work_amount
1303 / rnd_up(work_amount, max_threads);
1304 thr_eff = (thr_eff + disb) / 2.f;
1305 if (thr_eff > thr_eff_treshold) break;
1306 w_block = rnd_dn(w_block - simd_w, simd_w);
1307 } while (w_block > 0);
1308 }
1309 h_block = nstl::max(dim_t(1), h_block);
1310 w_block = nstl::max(dim_t(1), w_block);
1311 dim_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w);
1312 const float inner_thr_eff
1313 = (float)inner_work / rnd_up(inner_work, max_threads);
1314 if (thr_eff >= inner_thr_eff / 2 && h_block > 0
1315 && w_block > 0) {
1316 jcp.oh_block = h_block;
1317 jcp.ow_block = w_block;
1318 jcp.outer_threading = true;
1319 }
1320 // updating jcp.im2col_sz
1321 if (jcp.oh_block != 1) jcp.ow_block = ow;
1322 jcp.im2col_sz
1323 = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block;
1324 }
1325 // For threading selection in bwd_d we do:
1326 // 1. Rough estimation of efficiency for inner and outer threading.
1327 // 2. Gemm size estimation in assumption that it does not work
1328 // so effectively for small sizes.
1329 // 64K - this is heuristic gemm size per thread threshold.
1330 const int gemm_thrld = 64 * 1024;
1331 if (!jcp.outer_threading && !is_3d) {
1332 bool is_depthwise
1333 = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
1334 const dim_t outer_work = jcp.ngroups * jcp.mb;
1335 const float outer_thr_eff
1336 = (float)outer_work / rnd_up(outer_work, max_threads);
1337 const size_t inner_work
1338 = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
1339 const float inner_thr_eff
1340 = (float)inner_work / rnd_up(inner_work, max_threads);
1341 jcp.outer_threading
1342 = (is_depthwise
1343 || (jcp.is / max_threads < 64 && jcp.mb != 1))
1344 && (outer_thr_eff / inner_thr_eff >= 1.f
1345 || (jcp.os * jcp.ic * jcp.oc) / max_threads
1346 < gemm_thrld);
1347 }
1348 jcp.nthr = jcp.outer_threading ? max_threads : 1;
1349 scratchpad.book<int8_t>(
1350 key_conv_gemm_col, jcp.nthr * jcp.im2col_sz);
1351 scratchpad.book<int32_t>(key_conv_int_dat_in_acc_dt,
1352 jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc);
1353 scratchpad.book<int8_t>(
1354 key_conv_gemm_imtr, jcp.nthr * jcp.id * jcp.is * jcp.ic);
1355 } else if (is_bwd_d) {
1356 jcp.im2col_sz
1357 = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih,
1358 jcp.od == jcp.id, jcp.stride_w == 1,
1359 jcp.stride_h == 1, jcp.stride_d == 1, jcp.ks == 1,
1360 !jcp.signed_input)
1361 ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os * jcp.od
1362 : 0;
1363
1364 bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
1365 const size_t outer_work = jcp.ngroups * jcp.mb;
1366 const float outer_thr_eff
1367 = (float)outer_work / rnd_up(outer_work, max_threads);
1368 const size_t inner_work
1369 = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
1370 const float inner_thr_eff
1371 = (float)inner_work / rnd_up(inner_work, max_threads);
1372 jcp.outer_threading = !is_3d
1373 && (is_depthwise
1374 || (jcp.is / max_threads < 64 && jcp.mb != 1))
1375 && (outer_thr_eff / inner_thr_eff >= 1.f
1376 || (jcp.is * jcp.ic * jcp.oc) / max_threads
1377 < gemm_thrld);
1378
1379 jcp.nthr = jcp.outer_threading ? max_threads : 1;
1380 scratchpad.book<int32_t>(
1381 key_conv_gemm_col, jcp.nthr * jcp.im2col_sz);
1382 scratchpad.book<int32_t>(key_conv_int_dat_in_acc_dt,
1383 jcp.nthr * jcp.is * jcp.id * jcp.ic);
1384 } else if (is_bwd_w) {
1385 assert(!"unimplemented prop_kind");
1386 return status::unimplemented;
1387 }
1388 } else {
1389 jcp.im2col_sz = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih,
1390 jcp.od == jcp.id, jcp.stride_w == 1,
1391 jcp.stride_h == 1, jcp.stride_d == 1,
1392 jcp.ks == 1, !jcp.signed_input)
1393 ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os
1394 : 0;
1395 if (jcp.is_nspc && is_fwd) {
1396 const size_t wei_size
1397 = static_cast<size_t>(jcp.oc) * jcp.ic * jcp.kh * jcp.kw;
1398 bool is_blocking_applicable = true && is_fwd && jcp.im2col_sz
1399 && !is_3d && jcp.dilate_h == 0 && jcp.dilate_w == 0
1400 && !is_depthwise && wei_size < static_cast<size_t>(L2) / 2;
1401 // Logic for blocking for f32_nspc gemm convolution follows that of
1402 // int8_nspc gemm convolution. Currently, not optimized for f32
1403 // data type.
1404 if (is_blocking_applicable) {
1405 // looking for oh and ow blocking
1406 size_t h_block = jcp.oh_block;
1407 size_t w_block = jcp.ow_block;
1408
1409 const size_t ic = jcp.ic;
1410 const size_t oc = jcp.oc;
1411 const size_t iw = jcp.iw;
1412 const size_t ow = jcp.ow;
1413 const size_t oh = jcp.oh;
1414 const size_t os = oh * ow;
1415
1416 // 1. cache requirement
1417 size_t row_size = ic * ow * jcp.ks * data_size
1418 + 2 * (ic * iw + oc * ow) * data_size;
1419 // Heuristic rule: gemm needed a lot of memory for internal
1420 // usage
1421 row_size *= 5;
1422 // memory for accumulators
1423 row_size += oc * ow * data_size;
1424 // memory for transposition
1425 row_size += ic * iw * data_size;
1426
1427 const size_t L2_rows = div_up(L2, row_size);
1428 h_block = saturate(size_t {1}, L2_rows, oh);
1429 if (h_block == 1) {
1430 size_t col_size = ic * jcp.ks * data_size
1431 + 2 * (ic + oc) * data_size;
1432 const size_t L2_cols = div_up(L2, col_size);
1433 w_block = saturate(size_t {1}, L2_cols, ow);
1434 }
1435
1436 // 2. threading requirement
1437 if (h_block != oh)
1438 h_block = nstl::max(size_t {1}, rnd_dn(h_block, 4));
1439 if (w_block != ow)
1440 w_block = nstl::max(size_t {1}, rnd_dn(w_block, simd_w));
1441
1442 float thr_eff = 0.f;
1443 float thr_eff_treshold = 0.9f;
1444 if (w_block == ow) {
1445 do {
1446 size_t nb_h = div_up(oh, h_block);
1447 size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h;
1448 float disb = (float)oh / rnd_up(oh, h_block);
1449 thr_eff = (float)work / rnd_up(work, max_threads);
1450 thr_eff = (thr_eff + disb) / 2.f;
1451 if (thr_eff >= thr_eff_treshold) break;
1452
1453 if (h_block < 4)
1454 h_block = 0;
1455 else
1456 h_block = rnd_dn(h_block - 4, 4);
1457 } while (h_block > 0);
1458 }
1459 if (thr_eff
1460 < thr_eff_treshold) // we didn't find suitable h_block
1461 {
1462 h_block = 1;
1463 size_t nb_h = oh;
1464 do {
1465 size_t nb_w = div_up(ow, w_block);
1466 size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w;
1467 float disb = (float)ow / rnd_up(ow, w_block);
1468 thr_eff = (float)work_amount
1469 / rnd_up(work_amount, max_threads);
1470 thr_eff = (thr_eff + disb) / 2.f;
1471 if (thr_eff > thr_eff_treshold) break;
1472
1473 if (w_block < static_cast<size_t>(simd_w))
1474 w_block = 0;
1475 else
1476 w_block = rnd_dn(w_block - simd_w, simd_w);
1477 } while (w_block > 0);
1478 }
1479 h_block = nstl::max(size_t {1}, h_block);
1480 w_block = nstl::max(size_t {1}, w_block);
1481 const size_t inner_work
1482 = div_up(os, simd_w) * div_up(oc, simd_w);
1483 const float inner_thr_eff
1484 = (float)inner_work / rnd_up(inner_work, max_threads);
1485 if (thr_eff >= inner_thr_eff / 2 && h_block > 0
1486 && w_block > 0) {
1487 jcp.oh_block = static_cast<int>(h_block);
1488 jcp.ow_block = static_cast<int>(w_block);
1489 jcp.outer_threading = true;
1490 }
1491 // updating jcp.im2col_sz
1492 if (jcp.oh_block != 1) jcp.ow_block = static_cast<int>(ow);
1493 jcp.im2col_sz
1494 = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block;
1495 }
1496 // For threading selection in fwd_d we do:
1497 // 1. Rough estimation of efficiency for inner and outer threading.
1498 // 2. Gemm size estimation in assumption that it does not work
1499 // so effectively for small sizes.
1500 // 64K - this is heuristic gemm size per thread threshold.
1501 constexpr size_t gemm_thrld = 64 * 1024;
1502 if (!jcp.outer_threading && !is_3d) {
1503 bool is_depthwise
1504 = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
1505 const size_t outer_work = jcp.ngroups * jcp.mb;
1506 const float outer_thr_eff
1507 = (float)outer_work / rnd_up(outer_work, max_threads);
1508 const size_t inner_work
1509 = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
1510 const float inner_thr_eff
1511 = (float)inner_work / rnd_up(inner_work, max_threads);
1512 jcp.outer_threading
1513 = (is_depthwise
1514 || (jcp.is / max_threads < 64 && jcp.mb != 1))
1515 && (outer_thr_eff / inner_thr_eff >= 1.f
1516 || (static_cast<size_t>(jcp.os) * jcp.ic
1517 * jcp.oc)
1518 / max_threads
1519 < gemm_thrld);
1520 }
1521 jcp.nthr = jcp.outer_threading ? max_threads : 1;
1522 const size_t gemm_col_datatype_size
1523 = is_bf16_conv ? sizeof(bfloat16_t) : sizeof(float);
1524
1525 scratchpad.book(key_conv_gemm_col, jcp.nthr * jcp.im2col_sz,
1526 gemm_col_datatype_size);
1527 if (is_bf16_conv) {
1528 scratchpad.book<float>(key_conv_gemm_acc,
1529 jcp.nthr * static_cast<size_t>(jcp.oh_block)
1530 * jcp.ow_block * jcp.oc);
1531 }
1532
1533 scratchpad.book(key_conv_gemm_imtr,
1534 jcp.nthr * static_cast<size_t>(jcp.id) * jcp.is * jcp.ic,
1535 gemm_col_datatype_size);
1536 if (is_bf16_conv && jcp.with_bias
1537 && one_of(data_type::bf16, cd.diff_bias_desc.data_type,
1538 cd.bias_desc.data_type)) {
1539 scratchpad.book<float>(
1540 key_conv_bias_bf16_convert_wsp, jcp.ngroups * jcp.oc);
1541 }
1542
1543 } else if (!jcp.is_nspc && is_fwd) {
1544 const dim_t sh = jcp.stride_h;
1545 const dim_t sw = jcp.stride_w;
1546 const dim_t spatial = jcp.mb * jcp.ngroups * jcp.od * jcp.os;
1547 dim_t K = jcp.ic * jcp.ks;
1548
1549 // There is some heuristics in the definition of
1550 // inner/outer threading cross point due to the nature of the
1551 // gemm implementation which we cannot control
1552 bool is_blocking_applicable = true && !is_3d
1553 && (!jcp.im2col_sz
1554 // spatial is small
1555 || spatial >= max_threads * simd_w
1556 // inner threading work is greater then outer
1557 // threading work
1558 || jcp.os < jcp.mb * jcp.ngroups * jcp.od
1559 // im2col is big
1560 || (sw == 1 && K <= 0.05 * jcp.oc))
1561 // heuristic condition
1562 && (jcp.im2col_sz
1563 || (jcp.ic / jcp.oc < 42
1564 && jcp.ic * jcp.oc * jcp.is < 1024));
1565
1566 if (is_blocking_applicable) {
1567 const dim_t min_oc_block = 8;
1568 const dim_t min_os_block = simd_w;
1569 const float non_cache_access = 20;
1570 const float strided_im2col_k = 8;
1571 const float thr_disb_k = 8;
1572 const float thr_mem_eff_k {1}, oc_disb_k {1}, os_disb_k {1},
1573 ic_disb_k {1}, reg_osb_disb_k {1}, gemm_eff_k {0.5},
1574 gemm_calc_eff_k {1};
1575 const float k_sum = thr_disb_k + oc_disb_k + os_disb_k
1576 + ic_disb_k + reg_osb_disb_k + thr_mem_eff_k
1577 + gemm_eff_k + gemm_calc_eff_k;
1578
1579 auto calc_max_icb = [=](dim_t nthr_oc, dim_t ocb, dim_t osb,
1580 dim_t oc_per_thr,
1581 dim_t os_per_thr) {
1582 const dim_t block_out_size = ocb * osb;
                    // TODO: need a more precise calculation if the stride is
                    // larger than the kernel size
1585 const dim_t inp_row_size = sh * sw * osb;
1586 dim_t max_icb = 1;
1587 if (jcp.im2col_sz) {
1588 const dim_t col_row_size = jcp.ks * osb;
1589 if (osb >= os_per_thr) { // one pass by os
1590 const dim_t wei_col_size = jcp.ks * ocb;
1591 max_icb = L2 / (inp_row_size + col_row_size);
1592 if (ocb < oc_per_thr) {
1593 max_icb = nstl::min(max_icb,
1594 (L2 - block_out_size)
1595 / (col_row_size
1596 + wei_col_size));
1597 }
1598 } else {
1599 const dim_t wei_col_size = jcp.ks * oc_per_thr;
1600 max_icb = (L2 - block_out_size)
1601 / (inp_row_size + col_row_size
1602 + wei_col_size);
1603 }
1604 } else {
1605 if (osb >= os_per_thr)
1606 max_icb = L2 / inp_row_size;
1607 else {
1608 const dim_t wei_col_size = jcp.ks * oc_per_thr;
1609 max_icb = L2 / (inp_row_size + wei_col_size);
1610 }
1611 }
1612 if (max_icb < jcp.ic) {
1613 if (jcp.im2col_sz) {
1614 const dim_t col_row_size = jcp.ks * osb;
1615 const dim_t wei_col_size = jcp.ks * oc_per_thr;
1616 max_icb = (L2 - block_out_size)
1617 / (inp_row_size + col_row_size
1618 + wei_col_size);
1619 }
1620 }
1621 return max_icb;
1622 };
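            // Rough illustration of the cache model above (hypothetical
            // values; L2 is in the units used throughout this function): for
            // jcp.ks = 9, sh = sw = 1, osb = 96, a single pass over os and
            // ocb == oc_per_thr, each input channel needs about
            // inp_row_size + col_row_size = 96 + 864 units of budget, so the
            // initial estimate is max_icb ~= L2 / 960; it is further reduced
            // when that does not cover all of jcp.ic.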
1623
1624 dim_t best_ocb {1}, best_osb {1};
1625 dim_t best_nthr_oc {1};
1626 dim_t best_icb {jcp.ic};
1627 float best_thr_eff = 0;
1628
1629 auto try_cfg = [&](dim_t nthr_oc, dim_t ocb, dim_t osb) {
                // For a given nthr_oc and oc block:
                // 1. find an ic block that fits into the cache
                // 2. estimate efficiency based on rules and heuristics:
                //    - minimize im2col cost
                //    - ratio of FMA count to data size
                //    - gemm works better if M is divisible by 48 and N is
                //      divisible by 8
1636
1637 const dim_t max_oc = div_up(jcp.oc, nthr_oc);
1638 const dim_t min_oc = nstl::max(dim_t(1), jcp.oc / nthr_oc);
1639 const dim_t max_os
1640 = div_up(spatial, (dim_t)(max_threads / nthr_oc));
1641 ocb = utils::saturate(min_oc_block, max_oc, ocb);
1642 osb = utils::saturate(min_os_block, max_os, osb);
1643
1644 // The computation of max_thr_size and min_thr_size is
1645 // based on work balance using:
1646 // balance2D(max_threads, i, spatial, sp_start, sp_end,
1647 // jcp.oc, oc_start, oc_end, nthr_oc);
1648 size_t max_thr_size = 1;
1649 {
1650 const dim_t min_os = div_up(
1651 spatial, (dim_t)div_up(max_threads, nthr_oc));
                    /* --- compute max_thr_size ------------
                       may not necessarily be (max_oc * max_os)
                       thr_size = thr_oc * (spatial / nthrs_in_slice);
                       with spatial constant, thr_size is maximal when
                       (A: thr_oc is max) and (B: nthrs_in_slice is min)
                    */
1658 if (jcp.oc % nthr_oc > max_threads % nthr_oc) {
1659 // If (A) and (B) are true together, then it is the
1660 // global max
1661 max_thr_size = max_oc * max_os;
1662 } else {
1663 const size_t oc_max_os_min = max_oc * min_os;
1664 const size_t oc_min_os_max = min_oc * max_os;
1665 max_thr_size
1666 = nstl::max(oc_max_os_min, oc_min_os_max);
1667 }
1668 }
1669
1670 size_t min_thr_size {1};
1671 {
1672 const dim_t min_os = nstl::max(dim_t(1),
1673 spatial / div_up(max_threads, nthr_oc));
                    /* --- compute min_thr_size ------------
                       may not necessarily be (min_oc * min_os)
                       thr_size = thr_oc * (spatial / nthrs_in_slice);
                       with spatial constant, thr_size is minimal when
                       (A: thr_oc is min) and (B: nthrs_in_slice is max)
                    */
1680 if (max_threads % nthr_oc > jcp.oc % nthr_oc) {
1681 // If (A) and (B) are true together, then it is the
1682 // global min
1683 min_thr_size = min_oc * min_os;
1684 } else {
1685 const size_t oc_max_os_min = max_oc * min_os;
1686 const size_t oc_min_os_max = min_oc
1687 * (size_t)(spatial
1688 / (dim_t)(max_threads / nthr_oc));
1689 min_thr_size
1690 = nstl::min(oc_max_os_min, oc_min_os_max);
1691 }
1692 }
1693 auto thr_disb = (float)min_thr_size / max_thr_size;
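                // A hypothetical illustration of the balance estimate:
                // max_threads = 8, nthr_oc = 3, jcp.oc = 10, spatial = 90
                // gives max_oc = 4, min_oc = 3, max_os = 45 and min_os = 30,
                // hence max_thr_size = max(4 * 30, 3 * 45) = 135,
                // min_thr_size = 3 * 30 = 90 and thr_disb ~= 0.67.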
1694
1695 const dim_t oc_per_thr = max_oc;
1696 const dim_t os_per_thr = max_os;
1697 ocb = nstl::min(oc_per_thr, ocb);
1698 const dim_t os_max = nstl::min(jcp.os, os_per_thr);
1699 osb = nstl::min(os_max, osb);
1700
1701 // -- selecting icb ---------------------
1702 dim_t max_ic_block = calc_max_icb(
1703 nthr_oc, ocb, osb, oc_per_thr, os_per_thr);
1704 // if we don't fit into cache then access to memory is
1705 // expensive
1706 dim_t mem_access_cost
1707 = (max_ic_block < 1) ? non_cache_access : 1;
1708 max_ic_block = nstl::max(dim_t(1), max_ic_block);
1709 dim_t icb = nstl::max(
1710 dim_t(1), jcp.ic / div_up(jcp.ic, max_ic_block));
1711 dim_t nb_ic = div_up(jcp.ic, icb);
1712 dim_t kb = icb * jcp.ks;
1713 dim_t kb_caligned = rnd_up(kb, simd_w);
1714
1715 // -- mem efficiency ------------
1716 const size_t out_size
1717 = oc_per_thr * rnd_up(os_per_thr, simd_w);
1718 const size_t out_ops = mem_access_cost * out_size
1719 * ((icb == jcp.ic) ? 1 : (2 * nb_ic - 1));
1720 const dim_t osb_caligned = rnd_up(osb, simd_w);
1721 const size_t inp_size
1722 = jcp.ic * rnd_up(os_per_thr * sh * sw, simd_w);
1723 size_t inp_ops = 0;
1724 size_t col_ops = 0;
1725 // TODO: simplify calculations
1726 if (jcp.im2col_sz) {
1727 inp_ops = mem_access_cost * jcp.ks * inp_size;
1728 const float col_tail_koeff = (float)osb_caligned / osb;
1729 col_ops = mem_access_cost
1730 * (jcp.ks * inp_size * col_tail_koeff
1731 + jcp.ks * inp_size * col_tail_koeff);
1732 if (sw != 1) // im2col with strides is much slower
1733 col_ops *= strided_im2col_k;
1734 } else {
1735 inp_ops = mem_access_cost * jcp.ks * inp_size;
1736 }
1737 // TODO: what about groups?
1738 const size_t wei_size = oc_per_thr * rnd_up(K, simd_w);
1739 const size_t wei_ops = mem_access_cost * wei_size;
1740 // ratio of real FMA to number of memory ops
1741 const float thr_mem_eff
1742 = (((float)os_per_thr / simd_w) * oc_per_thr * K)
1743 / (inp_ops + col_ops + wei_ops + out_ops);
1744
1745 auto oc_disb = (float)oc_per_thr / rnd_up(oc_per_thr, ocb);
1746 auto os_disb = (float)os_max / rnd_up(os_max, osb);
1747 auto ic_disb = (float)jcp.ic / rnd_up(jcp.ic, icb);
1748
1749 auto reg_osb_disb = (float)osb / rnd_up(osb, 3 * simd_w);
1750
1751 // Heuristics
1752 const float gemm_eff = ((float)osb * ocb * kb)
1753 / ((float)oc_per_thr * os_per_thr * K);
1754
                // ratio of FMA count to memory size
1756 const float gemm_calc_eff
1757 = (((float)osb / simd_w) * ocb * kb)
1758 / (osb_caligned * kb + ocb * kb_caligned
1759 + ocb * osb_caligned);
                // optimization: drop pow() when the corresponding weight is 1
1761 const float res_eff = pow(pow(thr_disb, thr_disb_k)
1762 * oc_disb // pow(oc_disb, oc_disb_k)
1763 * os_disb // pow(os_disb, os_disb_k)
1764 * ic_disb // pow(ic_disb, ic_disb_k)
1765 // pow(reg_osb_disb, reg_osb_disb_k)
1766 * reg_osb_disb
1767 //pow(thr_mem_eff, thr_mem_eff_k)
1768 * thr_mem_eff
1769 //pow(gemm_calc_eff, gemm_calc_eff_k)
1770 * pow(gemm_eff, gemm_eff_k) * gemm_calc_eff,
1771 1.f / k_sum);
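                // In other words, res_eff is a weighted geometric mean of
                // the individual efficiency factors,
                // (prod_i eff_i ^ k_i) ^ (1 / k_sum), where factors whose
                // weight k_i is 1 are multiplied in directly instead of
                // calling pow().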
1772
1773 if (res_eff > best_thr_eff) {
1774 best_thr_eff = res_eff;
1775 best_nthr_oc = nthr_oc;
1776 best_ocb = ocb;
1777 best_osb = osb;
1778 best_icb = icb;
1779 }
1780 };
1781
1782 auto explore_cfg = [&](dim_t nthr_oc, dim_t ocb, dim_t osb) {
1783 try_cfg(nthr_oc, ocb, osb);
                // A few combinations to try, as efficiency is better when
                // ocb is a multiple of 8 and osb is a multiple of 48 or of
                // min_os_block.
1786 try_cfg(nthr_oc, rnd_dn(ocb, 8), rnd_dn(osb, 48));
1787 try_cfg(nthr_oc, rnd_up(ocb, 8), rnd_dn(osb, 48));
1788 try_cfg(nthr_oc, rnd_up(ocb, 8), rnd_up(osb, min_os_block));
1789 try_cfg(nthr_oc, rnd_up(ocb, 8), rnd_up(osb, 48));
1790 };
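            // For example (hypothetical values, with simd_w = 16 so that
            // min_os_block = 16), explore_cfg(nthr_oc, 21, 100) evaluates
            // the candidates (21, 100), (16, 96), (24, 96), (24, 112) and
            // (24, 144).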
1791
1792 for (dim_t nthr_oc = 1; nthr_oc <= max_threads; ++nthr_oc) {
1793 const dim_t max_oc_per_thr = div_up(jcp.oc, nthr_oc);
1794 dim_t max_os_per_thr
1795 = div_up(spatial, max_threads / nthr_oc);
1796 dim_t ocb {1}, osb {1}, icb {1};
1797 if (jcp.im2col_sz) {
1798 try_cfg(nthr_oc, max_oc_per_thr, max_os_per_thr);
1799 if ((best_ocb == max_oc_per_thr)
1800 && (best_osb == max_os_per_thr)
1801 && (best_icb == jcp.ic)) {
1802 // best case scenario
1803 continue;
1804 }
1805
                    /*
                       memory eq from calc_max_icb():
                         max_icb = (L2 - block_out_size)
                                 / (inp_row_size + col_row_size
                                         + wei_col_size);
                       i.e. icb*sh*sw*osb + icb*jcp.ks*osb
                            + jcp.ks*max_oc_per_thr*icb + osb*ocb = L2

                       a_k*icb*osb + b_k*icb + osb*ocb = L2
                       We would like to maximize icb*osb*ocb (FMA).

                       Unfortunately, the equation and constraint above do
                       not have a single solution, so based on experiments we
                       try a few scenarios (Notes 1-3 below).
                       1. icb = jcp.ic
                       2. Solving the constraint eq we get
                          osb = (L2 - 2*b_k*icb)/(2*a_k*icb) >= min_oc_block
                          => icb <= L2 / (2 * min_oc_block * a_k + 2 * b_k)
                       3. Maximize channel compute:
                          ocb = max_oc_per_thr;
                          icb = jcp.ic;
                    */
1828 dim_t a_k = sh * sw + jcp.ks;
1829 dim_t b_k = jcp.ks * max_oc_per_thr;
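                    // Here a_k collects the terms of the memory equation
                    // that scale with icb * osb (input rows, sh * sw * osb,
                    // plus im2col rows, jcp.ks * osb), while b_k is the
                    // per-icb weights term, jcp.ks * max_oc_per_thr.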
1830
1831 // Note 1:
1832 icb = jcp.ic;
1833 ocb = utils::saturate(min_oc_block, max_oc_per_thr,
1834 (L2 - a_k * icb * min_os_block - b_k * icb)
1835 / min_os_block);
1836 osb = utils::saturate(min_os_block, max_os_per_thr,
1837 (L2 - b_k * icb) / (a_k * icb + ocb));
1838 explore_cfg(nthr_oc, ocb, osb);
1839
1840 // Note 2:
1841 const dim_t icb_max = nstl::max(dim_t(1),
1842 L2 / (2 * min_oc_block * a_k + 2 * b_k));
1843 if (icb_max < jcp.ic) {
1844 // adjust icb, such that it is evenly distributed.
1845 icb = jcp.ic
1846 / nstl::max(dim_t(1), jcp.ic / icb_max);
1847 osb = nstl::max(dim_t(1),
1848 (L2 - 2 * b_k * icb) / (2 * icb * a_k));
1849 ocb = L2 / 2 / osb;
1850
1851 if (ocb > max_oc_per_thr) {
1852 ocb = max_oc_per_thr;
                            // with ocb fixed, solving the mem eq for osb gives
1854 osb = utils::saturate(min_os_block,
1855 max_os_per_thr,
1856 (L2 - b_k * icb) / (a_k * icb + ocb));
1857 } else if (osb > max_os_per_thr) {
                            // with osb fixed, solving the mem eq for ocb gives
1859 osb = max_os_per_thr;
1860 ocb = utils::saturate(min_oc_block,
1861 max_oc_per_thr,
1862 (L2 - a_k * icb * osb - b_k * icb)
1863 / (osb));
1864 }
1865
1866 explore_cfg(nthr_oc, ocb, osb);
1867 }
1868
1869 // Note 3:
1870 ocb = max_oc_per_thr;
1871 icb = jcp.ic;
1872 osb = nstl::max(min_os_block,
1873 rnd_dn((L2 - b_k * icb) / (a_k * icb + ocb),
1874 min_os_block));
1875 explore_cfg(nthr_oc, ocb, osb);
1876
1877 } else {
                // From calc_max_icb, the memory eq is independent of ocb,
                // so set ocb to its maximum.
1880 ocb = max_oc_per_thr;
1881 osb = (L2 - jcp.ks * jcp.ic) / (sh * sw * jcp.ic);
1882 explore_cfg(nthr_oc, ocb, osb);
1883 }
1884 }
1885 jcp.outer_threading = true;
1886 jcp.nthr_oc = best_nthr_oc;
1887 jcp.oc_block = best_ocb;
1888 jcp.os_block = best_osb;
1889 jcp.ic_block = best_icb;
1890
            // TODO: define loop order
            // if im2col is used, gemm_loop_rlb and gemm_loop_lrb look
            // preferable; otherwise other loop orders are possible
1894 jcp.loop_order = gemm_loop_rlb;
1895 } else {
1896 const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od;
1897 const float outer_thr_eff = (float)outer_work_amount
1898 / rnd_up(outer_work_amount, max_threads);
1899 const size_t inner_work_amount
1900 = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w);
1901 const float inner_thr_eff = (float)inner_work_amount
1902 / rnd_up(inner_work_amount, max_threads);
1903 jcp.outer_threading = jcp.os / max_threads < 512
1904 && IMPLICATION(
1905 jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2)
1906 && (outer_thr_eff / inner_thr_eff >= 1.f
1907 || (jcp.os * jcp.ic * jcp.oc) / max_threads
1908 < gemm_thrld);
1909 }
1910 jcp.os_nb_block = div_up(jcp.os, jcp.os_block);
1911
        // BF16: other loop orders should be explored for a potential
        // performance speedup, but that would require extending the BF16-dst
        // post-processing implementation.
1915 if (is_bf16_conv) jcp.loop_order = gemm_loop_lbr;
1916
1917 if (jcp.im2col_sz)
1918 jcp.im2col_sz = (ptrdiff_t)jcp.ic_block * jcp.ks * jcp.os_block;
1919 } else if (jcp.is_nspc && is_bwd_d) {
1920 jcp.im2col_sz
1921 = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih,
1922 jcp.od == jcp.id, jcp.stride_w == 1,
1923 jcp.stride_h == 1, jcp.stride_d == 1, jcp.ks == 1,
1924 !jcp.signed_input)
1925 ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os * jcp.od
1926 : 0;
1927
1928 bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
1929 const size_t outer_work = jcp.ngroups * jcp.mb;
1930 const float outer_thr_eff
1931 = (float)outer_work / rnd_up(outer_work, max_threads);
1932 const size_t inner_work
1933 = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
1934 const float inner_thr_eff
1935 = (float)inner_work / rnd_up(inner_work, max_threads);
1936 jcp.outer_threading = !is_3d
1937 && (is_depthwise
1938 || (jcp.is / max_threads < 64 && jcp.mb != 1))
1939 && (outer_thr_eff / inner_thr_eff >= 1.f
1940 || (static_cast<size_t>(jcp.is) * jcp.ic * jcp.oc)
1941 / max_threads
1942 < gemm_thrld);
1943
1944 jcp.nthr = jcp.outer_threading ? max_threads : 1;
1945 scratchpad.book<float>(key_conv_gemm_col, jcp.nthr * jcp.im2col_sz);
1946 if (jcp.ngroups > 1 || is_bf16_conv)
1947 scratchpad.book<float>(key_conv_gemm_acc,
1948 jcp.nthr * static_cast<size_t>(jcp.is) * jcp.id
1949 * jcp.ic);
1950 } else if (!jcp.is_nspc && is_bwd_d) {
1951 const size_t outer_work_amount = jcp.ngroups * jcp.mb;
1952 const float outer_thr_eff = (float)outer_work_amount
1953 / rnd_up(outer_work_amount, max_threads);
1954 const size_t inner_work
1955 = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
1956 const float inner_thr_eff
1957 = (float)inner_work / rnd_up(inner_work, max_threads);
1958 jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64)
1959 && (jcp.mb != 1 || jcp.ngroups > 2)
1960 && (outer_thr_eff / inner_thr_eff >= 1.f
1961 || (jcp.is * jcp.ic * jcp.oc) / max_threads
1962 < gemm_thrld);
1963 } else if (jcp.is_nspc && is_bwd_w) {
1964 jcp.im2col_sz
1965 = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih,
1966 jcp.od == jcp.id, jcp.stride_w == 1,
1967 jcp.stride_h == 1, jcp.stride_d == 1, jcp.ks == 1,
1968 !jcp.signed_input)
1969 ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os
1970 : 0;
1971 const size_t gemm_col_datatype_size
1972 = is_bf16_conv ? sizeof(bfloat16_t) : sizeof(float);
1973
1974 // Potential scratchpad memory requirement when outer threading is
1975 // enabled during f32/bf16 BWD_W nspc convolution
1976 size_t thr_mem_estimate = max_threads
1977 * (gemm_col_datatype_size * jcp.im2col_sz
1978 + gemm_col_datatype_size * jcp.id * jcp.is * jcp.ic
1979 + sizeof(float) * weights_d.size());
1980 if (is_bf16_conv) {
1981 thr_mem_estimate += sizeof(float) * weights_d.size();
1982 if (jcp.with_bias
1983 && one_of(data_type::bf16, cd.diff_bias_desc.data_type,
1984 cd.bias_desc.data_type))
1985 thr_mem_estimate += sizeof(float) * jcp.ngroups * jcp.oc;
1986 }
1987 const bool outer_threading_mem_ok
1988 = thr_mem_estimate < scratchpad_limit;
1989
1990 jcp.outer_threading = outer_threading_mem_ok
1991 && jcp.os / max_threads < 256
1992 && (jcp.mb != 1 || jcp.ngroups > 2);
1993 jcp.nthr = jcp.outer_threading ? max_threads : 1;
1994
1995 scratchpad.book(key_conv_gemm_col, jcp.nthr * jcp.im2col_sz,
1996 gemm_col_datatype_size);
1997
1998 jcp.need_wei_reduction = jcp.mb != 1 && jcp.nthr != 1;
1999 scratchpad.book<float>(
2000 key_conv_wei_reduction, jcp.nthr * weights_d.size());
2001 scratchpad.book(key_conv_gemm_imtr,
2002 static_cast<size_t>(jcp.nthr) * jcp.id * jcp.is * jcp.ic,
2003 gemm_col_datatype_size);
2004 if (is_bf16_conv) {
2005 size_t conv_acc_buffer_size = weights_d.size();
2006 scratchpad.book<float>(
2007 key_conv_int_dat_in_acc_dt, conv_acc_buffer_size);
2008 }
2009 if ((is_bf16_conv) && jcp.with_bias
2010 && one_of(data_type::bf16, cd.diff_bias_desc.data_type,
2011 cd.bias_desc.data_type))
2012 scratchpad.book<float>(
2013 key_conv_bias_bf16_convert_wsp, jcp.ngroups * jcp.oc);
2014 } else if (!jcp.is_nspc && is_bwd_w) {
2015 // Potential scratchpad memory requirement when outer threading is
2016 // enabled during f32/bf16 BWD_W blocked convolution
2017 size_t thr_mem_estimate
2018 = sizeof(float) * max_threads * weights_d.size();
2019 if (is_bf16_conv) {
2020 thr_mem_estimate += sizeof(float) * weights_d.size();
2021 if (jcp.with_bias
2022 && one_of(data_type::bf16, cd.diff_bias_desc.data_type,
2023 cd.bias_desc.data_type))
2024 thr_mem_estimate += sizeof(float) * jcp.ngroups * jcp.oc;
2025 }
2026 const size_t gemm_col_datatype_size
2027 = is_bf16_conv ? sizeof(bfloat16_t) : sizeof(float);
2028 // Minimum memory requirement as os_block >= simd_w
2029 thr_mem_estimate += gemm_col_datatype_size * max_threads * jcp.ic
2030 * jcp.ks * simd_w;
2031
2032 const bool outer_threading_mem_ok
2033 = thr_mem_estimate < scratchpad_limit;
2034 jcp.outer_threading = outer_threading_mem_ok
2035 && jcp.os / max_threads < 256
2036 && (jcp.mb != 1 || jcp.ngroups > 2);
2037 }
2038
2039 if (!jcp.is_nspc) {
2040 jcp.nthr = jcp.outer_threading ? max_threads : 1;
2041 const int sizeof_cacheline_float = 16;
2042 if (is_bwd_w) {
2043 jcp.need_wei_reduction = jcp.mb != 1 && jcp.nthr != 1;
2044 scratchpad.book<float>(
2045 key_conv_wei_reduction, jcp.nthr * weights_d.size());
2046 }
2047
2048 if (is_bf16_conv) {
2049 size_t conv_acc_buffer_size = 0;
2050 if (is_fwd)
2051 conv_acc_buffer_size = jcp.nthr
2052 * rnd_up(jcp.oc_block * jcp.os_block,
2053 sizeof_cacheline_float);
2054 else if (is_bwd_d)
2055 conv_acc_buffer_size = jcp.nthr
2056 * rnd_up(jcp.ic * jcp.ih * jcp.iw * jcp.id,
2057 sizeof_cacheline_float);
2058 else if (is_bwd_w)
2059 conv_acc_buffer_size = weights_d.size();
2060 scratchpad.book<float>(
2061 key_conv_int_dat_in_acc_dt, conv_acc_buffer_size);
2062 if ((is_fwd || is_bwd_w) && jcp.with_bias
2063 && one_of(data_type::bf16, cd.diff_bias_desc.data_type,
2064 cd.bias_desc.data_type))
2065 scratchpad.book<float>(key_conv_bias_bf16_convert_wsp,
2066 jcp.ngroups * jcp.oc);
2067 }
2068
2069 const size_t gemm_col_datatype_size = is_bf16_conv && !is_bwd_d
2070 ? sizeof(bfloat16_t)
2071 : sizeof(float);
2072 size_t gemm_col_memory_sz = jcp.nthr * jcp.im2col_sz;
2073
2074 if (is_bwd_d || is_bwd_w) {
2075 // check available memory
2076 if (scratchpad_limit < scratchpad.size())
2077 return status::unimplemented;
2078 const size_t available_mem
2079 = scratchpad_limit - scratchpad.size();
2080 if (available_mem
2081 < gemm_col_memory_sz * gemm_col_datatype_size) {
2082 // Required memory in this scenario overflows the
2083 // available memory due to the large dimensions.
2084 const int min_os_block = simd_w;
2085 const int max_os_block = (int)available_mem
2086 / ((int)gemm_col_datatype_size * jcp.nthr
2087 * (jcp.im2col_sz / jcp.os));
                // Choose an arbitrarily small coefficient to reduce the
                // spatial dimensions.
                // TODO: better heuristic to determine os_block based
                // on cache efficiency
2092 float _coef = is_bwd_w ? 0.05 : 0.1;
2093 jcp.os_block = nstl::max(
2094 min_os_block, (int)(max_os_block * _coef));
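                // E.g. (hypothetical numbers, simd_w = 16): if available_mem
                // yields max_os_block = 4000, bwd_w uses
                // os_block = max(16, (int)(4000 * 0.05)) = 200.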
2095 jcp.os_nb_block = div_up(jcp.os, jcp.os_block);
2096 jcp.im2col_sz = (ptrdiff_t)jcp.ic * jcp.ks * jcp.os_block;
2097 gemm_col_memory_sz = jcp.nthr * jcp.im2col_sz;
2098 }
2099 }
2100 scratchpad.book(key_conv_gemm_col, gemm_col_memory_sz,
2101 gemm_col_datatype_size);
2102 }
2103 }
2104
2105 jcp.bias_data_type = cd.bias_desc.data_type;
2106 jcp.dst_data_type = dst_md.data_type;
2107 jcp.sum_data_type = jcp.post_ops.get_sum_dt(jcp.dst_data_type);
2108 jcp.dst_os_stride = dst_d.is_blocking_desc()
2109 ? dst_d.blocking_desc().strides[ndims - 1]
2110 : 0;
2111 jcp.scale_idx_mult = (attr.scales_.get(DNNL_ARG_WEIGHTS).mask_
2112 == (1 << (int)with_groups));
2113 jcp.with_dst_scale = !attr.scales_.get(DNNL_ARG_DST).has_default_values();
2114 book_precomputed_scales(scratchpad, attr.scales_, jcp.ngroups * jcp.oc);
2115
2116 if (jcp.zp.src_exists) {
2117 const auto size = zp_src_comp_size + zp_src_pad_comp_size;
2118 if (size) scratchpad.book<int32_t>(key_conv_gemm_zp_src_comp, size);
2119 }
2120
2121 if (scratchpad.size() > scratchpad_limit) return status::unimplemented;
2122 return status::success;
2123}
2124
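// Splits nthr threads between groups and mini-batch for backward-by-weights.
// A minimal illustration (hypothetical sizes): nthr = 8, ngroups = 2, mb = 3
// gives nthr_g = 2 and nthr_mb = 3; threads 0..5 get ithr_g = ithr / 3 and
// ithr_mb = ithr % 3, while threads 6 and 7 are left idle with
// ithr_g = ithr_mb = -1.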
2125void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
2126 int &nthr_g, int &ithr_mb, int &nthr_mb) {
2127 nthr_g = nstl::min(ngroups, nthr);
2128 nthr_mb = nstl::min(mb, nthr / nthr_g);
2129 if (ithr / nthr_mb >= ngroups) {
2130 ithr_g = ithr_mb = -1;
2131 } else {
2132 ithr_g = ithr / nthr_mb;
2133 ithr_mb = ithr % nthr_mb;
2134 }
2135}
2136
2137void bwd_weights_reduction_par_ncsp(int ithr, int nthr,
2138 const conv_gemm_conf_t &jcp, const float *weights_reduce_ws,
2139 float *weights) {
2140 const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
2141
2142 size_t weights_start {0}, weights_end {0};
2143 balance211(weights_g_size, nthr, ithr, weights_start, weights_end);
2144
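    // Each thread reduces its own [weights_start, weights_end) slice across
    // the nthr per-thread partial buffers; the i == 0 pass overwrites the
    // destination, so no prior zeroing of `weights` is needed.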
2145 for (int i = 0; i < nthr; ++i) {
2146 const float *ws_i = weights_reduce_ws + i * weights_g_size;
2147 for (size_t s = weights_start; s < weights_end; ++s)
2148 weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s];
2149 }
2150}
2151
2152void bwd_weights_reduction_par_nspc(int ithr, int nthr, size_t g_start,
2153 size_t g_end, const conv_gemm_conf_t &jcp,
2154 const float *weights_reduce_base, float *diff_weights) {
2155 const dim_t weights_g_size = jcp.oc;
2156 dim_t weights_start {0}, weights_end {0};
2157 balance211(jcp.ks * jcp.ic, nthr, ithr, weights_start, weights_end);
2158
    // Threads divide work w.r.t. mini-batch and groups, therefore
    // - weights_reduce_base format: spatial-input_channels-output_channels
    // - diff_weights format: spatial-input_channels-groups-output_channels
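    // For example (hypothetical indices), for w = 2 and g = 1 the partial
    // result of thread tidx is read from
    // weights_reduce_base[((size_t)tidx * jcp.ks * jcp.ic + 2) * jcp.oc + oc]
    // and accumulated into diff_weights[(2 * jcp.ngroups + 1) * jcp.oc + oc].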
2162 for (auto tidx = 0; tidx < nthr; ++tidx) {
2163 const float *ws_base
2164 = weights_reduce_base + tidx * weights_g_size * jcp.ks * jcp.ic;
2165 for_(auto w = weights_start; w < weights_end; ++w)
2166 for (auto g = g_start; g < g_end; ++g) {
2167 float *__restrict dwei_ptr
2168 = diff_weights + (w * jcp.ngroups + g) * jcp.oc;
2169 const float *__restrict ws_ptr = ws_base + w * jcp.oc;
2170 if (tidx == 0) {
2171 PRAGMA_OMP_SIMD()
2172 for (auto oc = 0; oc < jcp.oc; ++oc) {
2173 dwei_ptr[oc] = ws_ptr[oc];
2174 }
2175 } else {
2176 PRAGMA_OMP_SIMD()
2177 for (auto oc = 0; oc < jcp.oc; ++oc) {
2178 dwei_ptr[oc] += ws_ptr[oc];
2179 }
2180 }
2181 }
2182 }
2183}
2184
2185bool padding_exists(const conv_gemm_conf_t &jcp) noexcept {
2186 return jcp.l_pad || jcp.t_pad || jcp.f_pad || jcp.e_pad || jcp.b_pad
2187 || jcp.r_pad;
2188}
2189
2190} // namespace jit_gemm_convolution_utils
2191} // namespace cpu
2192} // namespace impl
2193} // namespace dnnl
2194