1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "oneapi/dnnl/dnnl_types.h" |
18 | |
19 | #include "common/bfloat16.hpp" |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | #include "cpu/gemm_convolution_utils.hpp" |
25 | #include "cpu/scale_utils.hpp" |
26 | #if DNNL_X64 |
27 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
28 | #endif |
29 | |
30 | #include "cpu/platform.hpp" |
31 | |
32 | #if DNNL_X64 |
33 | #include "cpu/x64/cpu_isa_traits.hpp" |
34 | #endif |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace cpu { |
39 | |
40 | using namespace dnnl::impl::status; |
41 | using namespace dnnl::impl::utils; |
42 | using namespace prop_kind; |
43 | using namespace data_type; |
44 | |
45 | single_gemm_conv_chunk_desc_t::single_gemm_conv_chunk_desc_t(dim_t d_off, |
46 | dim_t d_size, dim_t h_off, dim_t h_size, dim_t w_off, dim_t w_size) |
47 | : d_off_(d_off) |
48 | , d_size_(d_size) |
49 | , h_off_(h_off) |
50 | , h_size_(h_size) |
51 | , w_off_(w_off) |
52 | , w_size_(w_size) {} |
53 | |
54 | namespace jit_gemm_convolution_utils { |
55 | |
/* col[ic][kd][kh][kw][oh][ow] <-- im2col_3d(im[ic][id][ih][iw])
 * Gathers input patches for a single output-depth position `od` into the col
 * buffer, covering a block of `spatial_block` flattened (oh, ow) output
 * points starting at offset `spatial_step`.
 * For bf16 the values are only copied or zeroed, never used arithmetically,
 * so uint16_t is used as a bit-wise proxy type for speed. */
template <typename data_type_t>
void im2col_3d(const conv_gemm_conf_t &jcp, const data_type_t *im,
        data_type_t *col, dim_t od, int spatial_step, int spatial_block) {
    using data_t =
            typename conditional<data_traits<data_type_t>::data_type == bf16,
                    uint16_t, data_type_t>::type;
    const data_t *__restrict _im
            = reinterpret_cast<const data_t *__restrict>(im);
    data_t *__restrict _col = reinterpret_cast<data_t *__restrict>(col);

    const size_t OHW = spatial_block; // output spatial points per col block
    const size_t im_step = jcp.ih * jcp.iw * jcp.id; // input elems per channel
    const size_t col_step = jcp.ks * OHW; // col elems per channel

    // Path used when the whole output spatial domain fits in one block
    // (jcp.os_nb_block == 1). Positions whose (ih, iw) falls into padding
    // are deliberately left untouched here -- per the comment at the bottom
    // of this function, zero padding is handled outside im2col.
    auto compute_im2col_outer_padding = [&](dim_t ic) {
        const data_t *__restrict im_loc = _im + ic * im_step;
        data_t *__restrict col_loc = _col + ic * col_step;
        dim_t id = od * jcp.stride_d - jcp.f_pad;
        for (dim_t kd = 0; kd < jcp.kd; ++kd) {
            data_t *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
            if (id < 0 || id >= jcp.id) {
                // Input depth out of range: zero the entries that would
                // otherwise have read valid (ih, iw) input positions.
                dim_t ih_ = -jcp.t_pad;
                for (dim_t kh = 0; kh < jcp.kh; ++kh) {
                    dim_t ih = ih_;
                    for (dim_t oh = 0; oh < jcp.oh; ++oh) {
                        if (ih < 0 || ih >= jcp.ih) {
                            ih += jcp.stride_h;
                            continue;
                        }
                        dim_t iw_ = -jcp.l_pad;
                        for (dim_t kw = 0; kw < jcp.kw; ++kw) {
                            dim_t iw = iw_;
                            for (dim_t ow = 0; ow < jcp.ow; ++ow) {
                                if (iw < 0 || iw >= jcp.iw) {
                                    iw += jcp.stride_w;
                                    continue;
                                }

                                const size_t col_idx
                                        = kw * OHW + oh * jcp.ow + ow;

                                col_[col_idx] = 0;
                                iw += jcp.stride_w;
                            }
                            iw_ += (1 + jcp.dilate_w);
                        }
                        ih += jcp.stride_h;
                    }
                    ih_ += (1 + jcp.dilate_h);
                    col_ += jcp.kw * OHW;
                }
            } else {
                // Input depth in range: copy valid input points; skip
                // positions that fall into h/w padding (see note above).
                const data_t *__restrict im_ = im_loc + id * jcp.ih * jcp.iw;
                dim_t ih_ = -jcp.t_pad;
                for (dim_t kh = 0; kh < jcp.kh; ++kh) {
                    dim_t ih = ih_;
                    for (dim_t oh = 0; oh < jcp.oh; ++oh) {
                        if (ih < 0 || ih >= jcp.ih) {
                            ih += jcp.stride_h;
                            continue;
                        }
                        dim_t iw_ = -jcp.l_pad;
                        for (dim_t kw = 0; kw < jcp.kw; ++kw) {
                            dim_t iw = iw_;
                            for (dim_t ow = 0; ow < jcp.ow; ++ow) {
                                if (iw < 0 || iw >= jcp.iw) {
                                    iw += jcp.stride_w;
                                    continue;
                                }

                                const size_t col_idx
                                        = kw * OHW + oh * jcp.ow + ow;
                                const size_t im_idx = ih * jcp.iw + iw;

                                col_[col_idx] = im_[im_idx];
                                iw += jcp.stride_w;
                            }
                            iw_ += (1 + jcp.dilate_w);
                        }
                        ih += jcp.stride_h;
                    }
                    ih_ += (1 + jcp.dilate_h);
                    col_ += jcp.kw * OHW;
                }
            }
            id += (1 + jcp.dilate_d);
        }
    };
    // Path used when the output spatial domain is split into several blocks:
    // only [spatial_step, spatial_step + spatial_block) is visited, and
    // padded positions are explicitly zeroed here.
    auto compute_im2col_padding = [&](dim_t ic) {
        // oh/ow bounds of the current flattened spatial block.
        const dim_t first_oh = spatial_step / jcp.ow;
        const dim_t last_oh = (spatial_step + spatial_block - 1) / jcp.ow;
        const dim_t oh_begin = first_oh;
        const dim_t oh_end = last_oh + 1;
        const dim_t first_ow = spatial_step % jcp.ow;
        const dim_t last_ow = (spatial_step + spatial_block - 1) % jcp.ow;

        const data_t *__restrict im_loc = _im + ic * im_step;
        data_t *__restrict col_loc = _col + ic * col_step;
        dim_t id = od * jcp.stride_d - jcp.f_pad;
        for (dim_t kd = 0; kd < jcp.kd; ++kd) {
            data_t *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
            if (id < 0 || id >= jcp.id) {
                // Entire depth slice is padding: zero the whole block.
                for (dim_t kh = 0; kh < jcp.kh; ++kh) {
                    for (dim_t oh = oh_begin; oh < oh_end; ++oh) {
                        const dim_t ow_begin = (oh == first_oh) ? first_ow : 0;
                        const dim_t ow_end
                                = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
                        for (dim_t kw = 0; kw < jcp.kw; ++kw) {
                            for (dim_t ow = ow_begin; ow < ow_end; ++ow) {
                                const size_t col_idx = kw * OHW + oh * jcp.ow
                                        + ow - spatial_step;
                                col_[col_idx] = 0;
                            }
                        }
                    }
                    col_ += jcp.kw * OHW;
                }
            } else {
                const data_t *__restrict im_ = im_loc + id * jcp.ih * jcp.iw;
                dim_t ih_ = oh_begin * jcp.stride_h - jcp.t_pad;
                for (dim_t kh = 0; kh < jcp.kh; ++kh) {
                    dim_t ih = ih_;
                    for (dim_t oh = oh_begin; oh < oh_end; ++oh) {
                        const dim_t ow_begin = (oh == first_oh) ? first_ow : 0;
                        const dim_t ow_end
                                = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
                        if (ih < 0 || ih >= jcp.ih) {
                            // Row falls into top/bottom padding: zero it.
                            for (dim_t kw = 0; kw < jcp.kw; ++kw) {
                                for (dim_t ow = ow_begin; ow < ow_end; ++ow) {
                                    const size_t col_idx = kw * OHW
                                            + oh * jcp.ow + ow - spatial_step;
                                    col_[col_idx] = 0;
                                }
                            }
                            ih += jcp.stride_h;
                            continue;
                        }
                        dim_t iw_ = ow_begin * jcp.stride_w - jcp.l_pad;
                        for (dim_t kw = 0; kw < jcp.kw; ++kw) {
                            dim_t iw = iw_;
                            for (dim_t ow = ow_begin; ow < ow_end; ++ow) {
                                const size_t col_idx = kw * OHW + oh * jcp.ow
                                        + ow - spatial_step;
                                if (iw < 0 || iw >= jcp.iw) {
                                    // Left/right padding: zero.
                                    col_[col_idx] = 0;
                                    iw += jcp.stride_w;
                                    continue;
                                }
                                const size_t im_idx = ih * jcp.iw + iw;
                                col_[col_idx] = im_[im_idx];
                                iw += jcp.stride_w;
                            }
                            iw_ += (1 + jcp.dilate_w);
                        }
                        ih += jcp.stride_h;
                    }
                    ih_ += (1 + jcp.dilate_h);
                    col_ += jcp.kw * OHW;
                }
            }
            id += (1 + jcp.dilate_d);
        }
    };

    // zero padding is handled outside im2col
    const bool outer_padding = jcp.os_nb_block == 1;
    if (outer_padding)
        parallel_nd(jcp.ic, compute_im2col_outer_padding);
    else
        parallel_nd(jcp.ic, compute_im2col_padding);
}

// Explicit instantiations: gemm convolution uses im2col_3d for f32 and bf16.
template void im2col_3d(const conv_gemm_conf_t &jcp, const float *im,
        float *col, dim_t od, int spatial_step, int spatial_block);

template void im2col_3d(const conv_gemm_conf_t &jcp, const bfloat16_t *im,
        bfloat16_t *col, dim_t od, int spatial_step, int spatial_block);
233 | |
/* imtr[ic][id][ih][iw] <-- im[id][ih][iw][ic]
 * Transposes a channels-last input image into a channels-first scratch
 * buffer. For signed input (s8), 128 is added to every value to re-center it
 * into the u8 range (the addition wraps when narrowed back to T). Channels
 * are processed in chunks of platform::get_cache_line_size() entries so the
 * reads from `im` stay contiguous. */
template <typename T>
void transpose_dt(const conv_gemm_conf_t &jcp, const T *__restrict im,
        T *__restrict imtr) {
    // s8 -> u8 re-centering value; 0 for unsigned/float inputs.
    uint8_t shift = jcp.signed_input ? 128 : 0;
    const dim_t ic_stride = jcp.id * jcp.ih * jcp.iw;
    const dim_t IC = jcp.ngroups * jcp.ic;
    const dim_t IHW = jcp.ih * jcp.iw;
    constexpr dim_t ic_block = platform::get_cache_line_size();
    const dim_t nb_ic = jcp.ic / ic_block;
    const dim_t ic_blocked = nb_ic * ic_block;
    parallel_nd(jcp.id, jcp.ih, [&](dim_t id, dim_t ih) {
        const T *__restrict im_h = im + id * IHW * IC + ih * jcp.iw * IC;
        T *__restrict imtr_h = imtr + id * IHW + ih * jcp.iw;
        for (dim_t iw = 0; iw < jcp.iw; iw++) {
            const T *__restrict im_w = im_h + iw * IC;
            T *__restrict imtr_w = imtr_h + iw;
            // Full channel blocks, vectorized.
            for (dim_t icb = 0; icb < nb_ic; icb++) {
                const T *__restrict im_icb = im_w + icb * ic_block;
                T *__restrict imtr_icb = imtr_w + icb * ic_block * ic_stride;
                PRAGMA_OMP_SIMD()
                for (dim_t ic = 0; ic < ic_block; ic++) {
                    imtr_icb[ic * ic_stride] = im_icb[ic] + shift;
                }
            }
            // Remaining channel tail.
            for (dim_t ic = ic_blocked; ic < jcp.ic; ic++) {
                imtr_w[ic * ic_stride] = im_w[ic] + shift;
            }
        }
    });
}

// Explicit instantiations for all element types used by gemm convolution.
template void transpose_dt(const conv_gemm_conf_t &jcp,
        const int8_t *__restrict im, int8_t *__restrict imtr);
template void transpose_dt(const conv_gemm_conf_t &jcp,
        const uint8_t *__restrict im, uint8_t *__restrict imtr);
template void transpose_dt(const conv_gemm_conf_t &jcp,
        const char *__restrict im, char *__restrict imtr);
template void transpose_dt(const conv_gemm_conf_t &jcp,
        const float *__restrict im, float *__restrict imtr);
template void transpose_dt(const conv_gemm_conf_t &jcp,
        const bfloat16_t *__restrict im, bfloat16_t *__restrict imtr);
276 | |
/* col[kd][kh][kw][g][ic][od][oh][ow] <-- im2col_dt_3d(im[id][ih][iw][g][ic])
 * 3D im2col over a channels-first transposed image (`_imtr`) for one output
 * depth position `od`. Three code paths: unit strides/dilations, stride 2
 * with unit dilation, and the general case. When the input depth slice is
 * fully out of range, the whole (oh, ow) plane is filled with `shift`;
 * otherwise only positions reading valid input are written -- the h/w border
 * entries are presumably initialized elsewhere by the caller (TODO confirm).
 * Values are copied verbatim; for signed input the +128 shift is presumably
 * already applied when `_imtr` was produced (see transpose_dt) -- confirm. */
template <typename orig_im_dt, typename orig_col_dt>
void im2col_dt_3d(const conv_gemm_conf_t &jcp, const void *__restrict _imtr,
        orig_col_dt *__restrict _col, dim_t od) {
    // For performance reasons, use uint16_t as a proxy for bfloat16_t
    using im_dt = typename utils::conditional<data_traits<orig_im_dt>::data_type
                    == bf16,
            uint16_t, orig_im_dt>::type;
    using col_dt =
            typename utils::conditional<data_traits<orig_col_dt>::data_type
                            == bf16,
                    uint16_t, orig_col_dt>::type;
    const im_dt *__restrict imtr
            = reinterpret_cast<const im_dt *__restrict>(_imtr);
    col_dt *__restrict col = reinterpret_cast<col_dt *__restrict>(_col);

    // Padding fill value: 128 keeps s8 data re-centered in the u8 range.
    col_dt shift = static_cast<col_dt>(jcp.signed_input ? 128 : 0);
    const dim_t dd = 1 + jcp.dilate_d;
    const dim_t dh = 1 + jcp.dilate_h;
    const dim_t dw = 1 + jcp.dilate_w;
    const dim_t sd = jcp.stride_d;
    const dim_t sh = jcp.stride_h;
    const dim_t sw = jcp.stride_w;
    const dim_t fp = jcp.f_pad;
    const dim_t tp = jcp.t_pad;
    const dim_t lp = jcp.l_pad;
    // col buffer strides per index (innermost spatial plane is oh*ow).
    const dim_t col_ic_s = jcp.oh * jcp.ow;
    const dim_t col_kw_s = jcp.ic * col_ic_s;
    const dim_t col_kh_s = jcp.kw * col_kw_s;
    const dim_t col_kd_s = jcp.kh * col_kh_s;
    const dim_t IHW = jcp.ih * jcp.iw;
    const dim_t OHW = jcp.oh * jcp.ow;

    // Specialization: unit strides and dilations.
    if (sd == 1 && sh == 1 && sw == 1 && dd == 1 && dh == 1 && dw == 1)
        parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
                [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
                    col_dt *__restrict col_loc = col + kd * col_kd_s
                            + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
                    const dim_t id = od - fp + kd;
                    if (id < 0 || id >= jcp.id) {
                        for (ptrdiff_t i = 0; i < OHW; i++)
                            col_loc[i] = shift;
                        return;
                    }
                    const im_dt *__restrict imtr_loc
                            = imtr + (ic * jcp.id + id) * IHW;
                    // Clamp oh/ow to the range whose input points exist.
                    const dim_t oh_start = saturate(dim_t(0), jcp.oh, tp - kh);
                    const dim_t oh_end
                            = saturate(dim_t(0), jcp.oh, jcp.ih + tp - kh);
                    const dim_t ow_start = saturate(dim_t(0), jcp.ow, lp - kw);
                    const dim_t ow_end
                            = saturate(dim_t(0), jcp.ow, jcp.iw + lp - kw);
                    for (dim_t oh = oh_start, ih = oh_start - tp + kh;
                            oh < oh_end; oh++, ih++) {
                        col_dt *__restrict col_h = col_loc + oh * jcp.ow;
                        const im_dt *__restrict imtr_h = imtr_loc + ih * jcp.iw;
                        for (dim_t ow = ow_start, iw = ow_start - lp + kw;
                                ow < ow_end; ow++, iw++) {
                            col_h[ow] = imtr_h[iw];
                        }
                    }
                });
    // Specialization: stride 2, unit dilation.
    else if (sd == 2 && sh == 2 && sw == 2 && dd == 1 && dh == 1 && dw == 1)
        parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
                [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
                    col_dt *__restrict col_loc = col + kd * col_kd_s
                            + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
                    const dim_t id = od * 2 - fp + kd;
                    if (id < 0 || id >= jcp.id) {
                        for (ptrdiff_t i = 0; i < OHW; i++)
                            col_loc[i] = shift;
                        return;
                    }
                    const im_dt *__restrict imtr_loc
                            = imtr + (ic * jcp.id + id) * IHW;
                    const dim_t oh_start
                            = saturate(dim_t(0), jcp.oh, div_up(tp - kh, 2));
                    const dim_t oh_end = saturate(
                            dim_t(0), jcp.oh, div_up(jcp.ih + tp - kh, 2));
                    const dim_t ow_start
                            = saturate(dim_t(0), jcp.ow, div_up(lp - kw, 2));
                    const dim_t ow_end = saturate(
                            dim_t(0), jcp.ow, div_up(jcp.iw + lp - kw, 2));
                    for (dim_t oh = oh_start, ih = oh_start * 2 - tp + kh;
                            oh < oh_end; ++oh, ih += 2) {
                        col_dt *__restrict col_h = col_loc + oh * jcp.ow;
                        const im_dt *__restrict imtr_h = imtr_loc + ih * jcp.iw;
                        for (dim_t ow = ow_start, iw = ow_start * 2 - lp + kw;
                                ow < ow_end; ++ow, iw += 2) {
                            col_h[ow] = imtr_h[iw];
                        }
                    }
                });
    // General strides and dilations.
    else
        parallel_nd(jcp.kd, jcp.kh, jcp.kw, jcp.ic,
                [&](dim_t kd, dim_t kh, dim_t kw, dim_t ic) {
                    col_dt *__restrict col_loc = col + kd * col_kd_s
                            + kh * col_kh_s + kw * col_kw_s + ic * col_ic_s;
                    const dim_t id = od * sd - fp + kd * dd;
                    if (id < 0 || id >= jcp.id) {
                        for (ptrdiff_t i = 0; i < OHW; i++)
                            col_loc[i] = shift;
                        return;
                    }
                    const im_dt *__restrict imtr_loc
                            = imtr + (ic * jcp.id + id) * IHW;
                    const dim_t oh_start = saturate(
                            dim_t(0), jcp.oh, div_up(tp - kh * dh, sh));
                    const dim_t oh_end = saturate(dim_t(0), jcp.oh,
                            div_up(jcp.ih + tp - kh * dh, sh));
                    const dim_t ow_start = saturate(
                            dim_t(0), jcp.ow, div_up(lp - kw * dw, sw));
                    const dim_t ow_end = saturate(dim_t(0), jcp.ow,
                            div_up(jcp.iw + lp - kw * dw, sw));
                    for (dim_t oh = oh_start, ih = oh_start * sh - tp + kh * dh;
                            oh < oh_end; ++oh, ih += sh) {
                        col_dt *__restrict col_h = col_loc + oh * jcp.ow;
                        const im_dt *__restrict imtr_h = imtr_loc + ih * jcp.iw;
                        for (dim_t ow = ow_start,
                                    iw = ow_start * sw - lp + kw * dw;
                                ow < ow_end; ++ow, iw += sw) {
                            col_h[ow] = imtr_h[iw];
                        }
                    }
                });
}

// Explicit instantiations: s8/u8 (both producing u8 col), f32, and bf16.
template void im2col_dt_3d<int8_t, uint8_t>(const conv_gemm_conf_t &jcp,
        const void *__restrict im, uint8_t *__restrict col, dim_t od);
template void im2col_dt_3d<uint8_t, uint8_t>(const conv_gemm_conf_t &jcp,
        const void *__restrict im, uint8_t *__restrict col, dim_t od);
template void im2col_dt_3d<float, float>(const conv_gemm_conf_t &jcp,
        const void *__restrict im, float *__restrict col, dim_t od);
template void im2col_dt_3d<bfloat16_t, bfloat16_t>(const conv_gemm_conf_t &jcp,
        const void *__restrict im, bfloat16_t *__restrict col, dim_t od);
412 | |
/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw])
 * 2D im2col for a channel block [cs, cs + cb) and a flattened output
 * spatial block [ss, ss + sb). Padded positions are written as zero.
 * For bf16 only raw copies are performed, so uint16_t serves as a bit-wise
 * proxy type. */
template <typename data_type_t>
void im2col(const conv_gemm_conf_t &jcp, const data_type_t *__restrict im,
        data_type_t *__restrict col, dim_t ss, dim_t sb, dim_t cs, dim_t cb) {

    using data_t =
            typename utils::conditional<data_traits<data_type_t>::data_type
                            == bf16,
                    uint16_t, data_type_t>::type;
    const data_t *__restrict _im
            = reinterpret_cast<const data_t *__restrict>(im);
    data_t *__restrict _col = reinterpret_cast<data_t *__restrict>(col);

    const size_t im_step = jcp.is; // input elems per channel
    const size_t col_step = jcp.ks * sb; // col elems per channel
    const dim_t dh = 1 + jcp.dilate_h;
    const dim_t dw = 1 + jcp.dilate_w;
    const dim_t sh = jcp.stride_h;
    const dim_t sw = jcp.stride_w;
    const dim_t tp = jcp.t_pad;
    const dim_t lp = jcp.l_pad;
    // oh/ow bounds of the flattened spatial block [ss, ss + sb).
    const dim_t first_oh = ss / jcp.ow;
    const dim_t last_oh = (ss + sb - 1) / jcp.ow;
    const dim_t oh_begin = first_oh;
    const dim_t oh_end = last_oh + 1;
    const dim_t first_ow = ss % jcp.ow;
    const dim_t last_ow = (ss + sb - 1) % jcp.ow;

    const data_t zero_val = 0;

    // Serial path: the caller is presumably threading at an outer level
    // when jcp.outer_threading is set.
    if (jcp.outer_threading) {
        if (sw == 1) {
            // Generated code is more optimized for stride_w == 1
            // because innermost loop is by width
            for (dim_t ic = 0; ic < cb; ic++) {
                const data_t *__restrict im_ic = _im + (ic + cs) * im_step;
                for (dim_t kh = 0; kh < jcp.kh; kh++) {
                    for (dim_t kw = 0; kw < jcp.kw; kw++) {
                        data_t *__restrict col_k = _col + ic * col_step
                                + (kh * jcp.kw + kw) * sb;
                        for (dim_t oh = oh_begin; oh < oh_end; oh++) {
                            const dim_t ih = oh * sh - tp + kh * dh;
                            // im_ is pre-offset by the kw/pad shift so that
                            // iw == ow below indexes it directly.
                            const data_t *__restrict im_
                                    = im_ic + ih * jcp.iw - lp + kw * dw;
                            const dim_t ow_begin
                                    = (oh == first_oh) ? first_ow : 0;
                            const dim_t ow_end
                                    = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
                            data_t *__restrict col_ = col_k + oh * jcp.ow - ss;
                            if (ih < 0 || ih >= jcp.ih)
                                for (dim_t ow = ow_begin; ow < ow_end; ow++)
                                    col_[ow] = zero_val;
                            else {
                                for (dim_t ow = ow_begin; ow < ow_end; ++ow) {
                                    const dim_t iw = ow;
                                    if (iw < lp - kw * dw
                                            || iw >= jcp.iw + lp - kw * dw)
                                        col_[ow] = zero_val;
                                    else
                                        col_[ow] = im_[iw];
                                }
                            }
                        }
                    }
                }
            }
        } else {
            for (dim_t ic = 0; ic < cb; ic++) {
                const data_t *__restrict im_ = _im + (ic + cs) * im_step;
                for (dim_t kh = 0; kh < jcp.kh; kh++) {
                    for (dim_t kw = 0; kw < jcp.kw; kw++) {
                        data_t *__restrict col_k = _col + ic * col_step
                                + (kh * jcp.kw + kw) * sb;
                        for (dim_t oh = oh_begin; oh < oh_end; oh++) {
                            const dim_t ih = oh * sh - tp + kh * dh;
                            const dim_t ow_begin
                                    = (oh == first_oh) ? first_ow : 0;
                            const dim_t ow_end
                                    = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
                            data_t *__restrict col_oh
                                    = col_k + oh * jcp.ow - ss;
                            if (ih < 0 || ih >= jcp.ih)
                                for (dim_t ow = ow_begin; ow < ow_end; ow++)
                                    col_oh[ow] = zero_val;
                            else
                                for (dim_t ow = ow_begin; ow < ow_end; ow++) {
                                    const dim_t iw = ow * sw - lp + kw * dw;
                                    if (iw < 0 || iw >= jcp.iw)
                                        col_oh[ow] = zero_val;
                                    else {
                                        const ptrdiff_t im_idx
                                                = ih * jcp.iw + iw;
                                        col_oh[ow] = im_[im_idx];
                                    }
                                }
                        }
                    }
                }
            }
        }
    } else {
        // TODO: optimize threading if jcp.ic*jcp.kh*jcp.kw*oh_range is small
        // comparing to number of threads
        const dim_t oh_range = oh_end - oh_begin;
        // Generated code is more optimized for stride_w == 1
        // because innermost loop is by width
        if (sw == 1)
            parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
                    [&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) {
                        const dim_t oh = ohr + oh_begin;
                        const dim_t ih = oh * sh - tp + kh * dh;
                        const dim_t ow_start = (oh == first_oh) ? first_ow : 0;
                        const dim_t ow_end
                                = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
                        data_t *__restrict col_oh = _col + ic * col_step
                                + (kh * jcp.kw + kw) * sb + oh * jcp.ow - ss;
                        const data_t *__restrict im_
                                = _im + (ic + cs) * im_step + ih * jcp.iw;
                        const dim_t iw_shift = kw * dw - lp;
                        if (ih < 0 || ih >= jcp.ih)
                            for (dim_t ow = ow_start; ow < ow_end; ow++)
                                col_oh[ow] = zero_val;
                        else
                            for (dim_t ow = ow_start; ow < ow_end; ow++) {
                                const dim_t iw = ow + iw_shift;
                                if (iw < 0 || iw >= jcp.iw)
                                    col_oh[ow] = zero_val;
                                else
                                    col_oh[ow] = im_[iw];
                            }
                    });
        else
            parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
                    [&](dim_t ic, dim_t kh, dim_t kw, dim_t ohr) {
                        const dim_t oh = ohr + oh_begin;
                        const dim_t ih = oh * sh - tp + kh * dh;
                        const dim_t ow_start = (oh == first_oh) ? first_ow : 0;
                        const dim_t ow_end
                                = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
                        data_t *__restrict col_oh = _col + ic * col_step
                                + (kh * jcp.kw + kw) * sb + oh * jcp.ow - ss;
                        const data_t *__restrict im_
                                = _im + (ic + cs) * im_step;
                        if (ih < 0 || ih >= jcp.ih)
                            for (dim_t ow = ow_start; ow < ow_end; ow++)
                                col_oh[ow] = zero_val;
                        else
                            for (dim_t ow = ow_start; ow < ow_end; ow++) {
                                const dim_t iw = ow * sw - lp + kw * dw;
                                if (iw < 0 || iw >= jcp.iw)
                                    col_oh[ow] = zero_val;
                                else {
                                    const ptrdiff_t im_idx = ih * jcp.iw + iw;
                                    col_oh[ow] = im_[im_idx];
                                }
                            }
                    });
    }
}

// Explicit instantiations: f32 and bf16.
template void im2col(const conv_gemm_conf_t &jcp, const float *__restrict im,
        float *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb);

template void im2col(const conv_gemm_conf_t &jcp,
        const bfloat16_t *__restrict im, bfloat16_t *__restrict col, dim_t hs,
        dim_t hb, dim_t ws, dim_t wb);
579 | |
/* col[kh][kw][ic][oh][ow] <-- im2col_dt(im[ih][iw][ic])
 * 2D im2col from a channels-last image for the output tile
 * [hs, hs + hb) x [ws, ws + wb). Padded positions are filled with `shift`
 * (128 for signed input, re-centering s8 into the u8 range), and the same
 * shift is added to every copied value. The outer-threading fast path first
 * transposes the needed input region into the `_imtr` scratch buffer so the
 * col-filling loops read contiguously. */
template <typename orig_im_dt, typename orig_col_dt>
void im2col_dt(const conv_gemm_conf_t &jcp, const void *__restrict _im,
        void *__restrict _imtr, orig_col_dt *__restrict _col, dim_t hs,
        dim_t hb, dim_t ws, dim_t wb) {
    // For performance reasons, use uint16_t as a proxy for bfloat16_t
    using im_dt = typename utils::conditional<data_traits<orig_im_dt>::data_type
                    == bf16,
            uint16_t, orig_im_dt>::type;
    using col_dt =
            typename utils::conditional<data_traits<orig_col_dt>::data_type
                            == bf16,
                    uint16_t, orig_col_dt>::type;
    const im_dt *__restrict im = reinterpret_cast<const im_dt *__restrict>(_im);
    im_dt *__restrict imtr = reinterpret_cast<im_dt *__restrict>(_imtr);
    col_dt *__restrict col = reinterpret_cast<col_dt *__restrict>(_col);

    // s8 -> u8 re-centering value; 0 for unsigned/float inputs.
    col_dt shift = static_cast<col_dt>(jcp.signed_input ? 128 : 0);
    const dim_t dh = 1 + jcp.dilate_h;
    const dim_t dw = 1 + jcp.dilate_w;
    const dim_t sh = jcp.stride_h;
    const dim_t sw = jcp.stride_w;
    const dim_t im_iw_stride = jcp.ic * jcp.ngroups;
    const dim_t im_ih_stride = jcp.iw * im_iw_stride;
    const dim_t tp = jcp.t_pad;
    const dim_t lp = jcp.l_pad;

    if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) {
        /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */
        // Bounding box of input rows/cols this tile can touch (kernel
        // extent included).
        const dim_t hp = hs - tp;
        const dim_t wp = ws - lp;
        const dim_t ih_start = saturate(dim_t(0), jcp.ih, hp);
        const dim_t ih_end = saturate(dim_t(0), jcp.ih, hp + hb + jcp.kh);
        const dim_t iw_start = saturate(dim_t(0), jcp.iw, wp);
        const dim_t iw_end = saturate(dim_t(0), jcp.iw, wp + wb + jcp.kw);

        const dim_t ihb = ih_end - ih_start;
        const dim_t iwb = iw_end - iw_start;

        // Transpose the needed region into channels-first scratch.
        const dim_t imtr_ic_stride = ihb * iwb;
        const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start;
        for (dim_t ic = 0; ic < jcp.ic; ic++) {
            const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift;
            for (dim_t ih = ih_start; ih < ih_end; ih++) {
                const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride;
                const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb;
                for (dim_t iw = iw_start; iw < iw_end; iw++)
                    imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride];
            }
        }

        const dim_t col_ic_str = hb * wb;
        const dim_t col_kw_stride = jcp.ic * col_ic_str;
        const dim_t col_kh_stride = jcp.kw * col_kw_stride;

        const dim_t oh_init = ih_start - hp;
        const dim_t ow_init = iw_start - wp;
        for (dim_t kh = 0; kh < jcp.kh; kh++) {
            const ptrdiff_t col_idx_kh = kh * col_kh_stride;
            const dim_t oh_kh = oh_init - kh;
            const dim_t oh_start = saturate(dim_t(0), hb, oh_kh);
            const dim_t oh_end = saturate(dim_t(0), hb, oh_kh + ihb);
            for (dim_t kw = 0; kw < jcp.kw; kw++) {
                const ptrdiff_t col_idx_kw
                        = col_idx_kh + kw * jcp.ic * col_ic_str;
                const dim_t ow_kw = ow_init - kw;
                const dim_t imtr_shift = oh_kh * iwb + ow_kw;
                const dim_t ow_start = saturate(dim_t(0), wb, ow_kw);
                const dim_t ow_end = saturate(dim_t(0), wb, ow_kw + iwb);
                for (dim_t ic = 0; ic < jcp.ic; ic++) {
                    const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str;
                    const dim_t imtr_idx_ic = ic * imtr_ic_stride - imtr_shift;
                    // Rows entirely above the valid input: pad.
                    for (dim_t oh = 0; oh < oh_start; oh++) {
                        const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
                        for (dim_t ow = 0; ow < wb; ++ow)
                            col[col_idx_oh + ow] = shift;
                    }
                    // Valid rows: pad left edge, copy middle, pad right edge.
                    for (dim_t oh = oh_start; oh < oh_end; oh++) {
                        const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
                        const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb;
                        for (dim_t ow = 0; ow < ow_start; ++ow)
                            col[col_idx_oh + ow] = shift;
                        for (dim_t ow = ow_start; ow < ow_end; ++ow)
                            col[col_idx_oh + ow]
                                    = imtr[imtr_idx_oh + ow] + shift;
                        for (dim_t ow = ow_end; ow < wb; ++ow)
                            col[col_idx_oh + ow] = shift;
                    }
                    // Rows entirely below the valid input: pad.
                    for (dim_t oh = oh_end; oh < hb; oh++) {
                        const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
                        for (dim_t ow = 0; ow < wb; ++ow)
                            col[col_idx_oh + ow] = shift;
                    }
                }
            }
        }
    } else {
        // General path: parallel over kernel points, channels, and rows;
        // reads the channels-last image directly.
        parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb,
                [&](dim_t kh, dim_t kw, dim_t ic, dim_t oh) {
                    const dim_t hp = tp - kh * dh;
                    const dim_t ih = (oh + hs) * sh - hp;
                    const ptrdiff_t col_idx_base
                            = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh)
                            * wb;
                    if (ih < 0 || ih >= jcp.ih)
                        for (dim_t ow = 0; ow < wb; ow++)
                            col[col_idx_base + ow] = shift;
                    else {
                        const dim_t wp = lp - kw * dw;
                        const dim_t ow_start
                                = saturate(dim_t(0), wb, div_up(wp, sw) - ws);
                        const dim_t ow_end = saturate(
                                dim_t(0), wb, div_up(jcp.iw + wp, sw) - ws);
                        for (dim_t ow = 0; ow < ow_start; ow++)
                            col[col_idx_base + ow] = shift;
                        const dim_t iw_base = ws * sw - wp;
                        const ptrdiff_t im_idx_base = ih * im_ih_stride + ic;
                        for (dim_t ow = ow_start; ow < ow_end; ow++) {
                            const dim_t iw = iw_base + ow * sw;
                            const ptrdiff_t im_idx
                                    = im_idx_base + iw * im_iw_stride;
                            col[col_idx_base + ow] = im[im_idx] + shift;
                        }
                        for (dim_t ow = ow_end; ow < wb; ow++)
                            col[col_idx_base + ow] = shift;
                    }
                });
    }
}

// Explicit instantiations: s8/u8 (both producing u8 col), f32, and bf16.
template void im2col_dt<int8_t, uint8_t>(const conv_gemm_conf_t &jcp,
        const void *__restrict im, void *__restrict imtr,
        uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb);
template void im2col_dt<uint8_t, uint8_t>(const conv_gemm_conf_t &jcp,
        const void *__restrict im, void *__restrict imtr,
        uint8_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb);
template void im2col_dt<float, float>(const conv_gemm_conf_t &jcp,
        const void *__restrict im, void *__restrict imtr, float *__restrict col,
        dim_t hs, dim_t hb, dim_t ws, dim_t wb);

template void im2col_dt<bfloat16_t, bfloat16_t>(const conv_gemm_conf_t &jcp,
        const void *__restrict im, void *__restrict imtr,
        bfloat16_t *__restrict col, dim_t hs, dim_t hb, dim_t ws, dim_t wb);
723 | |
724 | /* im[id][ih][iw][ic] <-- col2im_dt_3d(col[od][oh][ow][kd][kh][kw][ic]) */ |
725 | template <typename orig_T> |
726 | void col2im_dt(const conv_gemm_conf_t &jcp, const orig_T *__restrict _col, |
727 | orig_T *__restrict _im) { |
728 | // For performance reasons, use uint16_t as a proxy for bfloat16_t |
729 | using T = |
730 | typename utils::conditional<data_traits<orig_T>::data_type == bf16, |
731 | uint16_t, orig_T>::type; |
732 | const T *__restrict col = reinterpret_cast<const T *__restrict>(_col); |
733 | T *__restrict im = reinterpret_cast<T *__restrict>(_im); |
734 | |
735 | parallel(0, [&](const int ithr, const int nthr) { |
736 | dim_t d_nthr = nstl::min(jcp.id, dim_t(nthr)); |
737 | dim_t h_nthr = nstl::min(jcp.ih, dim_t(nthr) / d_nthr); |
738 | dim_t w_nthr = nstl::min(jcp.iw, dim_t(nthr) / (d_nthr * h_nthr)); |
739 | dim_t d_ithr = 1, d_s = 0, d_e = 0, h_ithr = 1, h_s = 0, h_e = 0, |
740 | w_ithr = 1, w_s = 0, w_e = 0; |
741 | if (ithr < d_nthr * h_nthr * w_nthr) { |
742 | d_ithr = ithr / (h_nthr * w_nthr); |
743 | h_ithr = (ithr % (h_nthr * w_nthr)) / w_nthr; |
744 | w_ithr = (ithr % (h_nthr * w_nthr)) % w_nthr; |
745 | balance211(jcp.id, d_nthr, d_ithr, d_s, d_e); |
746 | balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e); |
747 | balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e); |
748 | } else { |
749 | d_nthr = h_ithr = w_ithr = -ithr; |
750 | d_s = d_e = h_s = h_e = w_s = w_e = -1; |
751 | } |
752 | |
753 | for_(dim_t id = d_s; id < d_e; ++id) |
754 | for_(dim_t ih = h_s; ih < h_e; ++ih) |
755 | for (dim_t iw = w_s; iw < w_e; ++iw) { |
756 | PRAGMA_OMP_SIMD() |
757 | for (dim_t ic = 0; ic < jcp.ic; ++ic) { |
758 | im[((id * jcp.ih + ih) * jcp.iw + iw) * jcp.ic + ic] = 0; |
759 | } |
760 | } |
761 | |
762 | // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh] |
763 | for_(dim_t od = 0; od < jcp.od; ++od) |
764 | for_(dim_t oh = 0; oh < jcp.oh; ++oh) |
765 | for_(dim_t ow = 0; ow < jcp.ow; ++ow) |
766 | for (dim_t kd = 0; kd < jcp.kd; ++kd) { |
767 | const dim_t id |
768 | = od * jcp.stride_d - jcp.f_pad + kd * (1 + jcp.dilate_d); |
769 | if (id < d_s || id >= d_e) continue; |
770 | |
771 | for (dim_t kh = 0; kh < jcp.kh; ++kh) { |
772 | const dim_t ih = oh * jcp.stride_h - jcp.t_pad |
773 | + kh * (1 + jcp.dilate_h); |
774 | if (ih < h_s || ih >= h_e) continue; |
775 | |
776 | for (dim_t kw = 0; kw < jcp.kw; ++kw) { |
777 | const dim_t iw = ow * jcp.stride_w - jcp.l_pad |
778 | + kw * (1 + jcp.dilate_w); |
779 | if (iw < w_s || iw >= w_e) continue; |
780 | |
781 | const size_t col_idx |
782 | = (((((od * jcp.oh + oh) * jcp.ow + ow) * jcp.kd |
783 | + kd) * jcp.kh |
784 | + kh) * jcp.kw |
785 | + kw) |
786 | * jcp.ic; |
787 | const size_t im_idx |
788 | = ((id * jcp.ih + ih) * jcp.iw + iw) * jcp.ic; |
789 | PRAGMA_OMP_SIMD() |
790 | for (dim_t ic = 0; ic < jcp.ic; ++ic) { |
791 | im[im_idx + ic] += col[col_idx + ic]; |
792 | } |
793 | } |
794 | } |
795 | } |
796 | }); |
797 | } |
798 | |
// Explicit instantiations of col2im_dt for the data types used by the gemm
// convolution backward-data paths: s32 (int8 accumulator), f32 and bf16.
template void col2im_dt<int32_t>(const conv_gemm_conf_t &jcp,
        const int32_t *__restrict col, int32_t *__restrict im);

template void col2im_dt<float>(const conv_gemm_conf_t &jcp,
        const float *__restrict col, float *__restrict im);

template void col2im_dt<bfloat16_t>(const conv_gemm_conf_t &jcp,
        const bfloat16_t *__restrict col, bfloat16_t *__restrict im);
807 | |
// col2im for 3D f32 convolution: scatter-adds the column buffer `col`
// (gemm layout [ic][kd][kh][kw][os-slice]) back into the image tensor `im`
// (layout [ic][id][ih][iw]) for a single output depth index `od`.
// `spatial_step`/`spatial_block` select the contiguous slice of the oh*ow
// output space currently held in `col` when os is processed in blocks.
// NOTE(review): `im` is only accumulated into (+=) here — it is assumed to
// be zero-initialized by the caller before the first od/os iteration;
// confirm against the calling primitive.
void col2im_3d(const conv_gemm_conf_t &jcp, const float *col, float *im,
        dim_t od, int spatial_step, int spatial_block) {

    // Kernel for the os-blocked case (jcp.os_nb_block > 1): the column
    // buffer holds only `spatial_block` output points starting at linear
    // output offset `spatial_step`.
    auto sp_blocked_ker = [&](dim_t ic) {
        const size_t col_step = jcp.ks * spatial_block;
        const float *__restrict col_ = col + ic * col_step;
        float *__restrict im_ic = im + ic * jcp.ih * jcp.iw * jcp.id;

        // Decompose the linear os-range [spatial_step, spatial_step +
        // spatial_block) into inclusive (oh, ow) boundaries.
        const dim_t first_oh = spatial_step / jcp.ow;
        const dim_t last_oh = (spatial_step + spatial_block - 1) / jcp.ow;
        const dim_t oh_begin = first_oh;
        const dim_t oh_end = last_oh + 1;
        const dim_t first_ow = spatial_step % jcp.ow;
        const dim_t last_ow = (spatial_step + spatial_block - 1) % jcp.ow;
        // Leading dimension of one (kh, kw) row inside the column buffer.
        const dim_t wei_stride
                = nstl::min(jcp.ow * jcp.oh, dim_t(spatial_block));

        dim_t id = od * jcp.stride_d - jcp.f_pad;
        for (dim_t kd = 0; kd < jcp.kd; ++kd) {
            // Input depth falls into padding: skip this kd slice, but still
            // advance the column pointer past it.
            if (id < 0 || id >= jcp.id) {
                col_ += jcp.kh * jcp.kw * wei_stride;
                id += (1 + jcp.dilate_d);
                continue;
            }

            float *__restrict im_ = im_ic + (size_t)id * jcp.ih * jcp.iw;
            // `col_off` restarts at 0 for every (kh, kw) pair (the for_
            // macros nest the oh loop innermost) and tracks the position
            // inside the current (kh, kw) row of the column buffer.
            for_(dim_t kh = 0; kh < jcp.kh; ++kh)
            for_(dim_t kw = 0; kw < jcp.kw; ++kw)
            for (dim_t oh = oh_begin, col_off = 0; oh < oh_end; ++oh) {

                // First/last output rows of the block may be partial.
                const dim_t ow_begin = (oh == first_oh) ? first_ow : 0;
                const dim_t ow_end = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
                const dim_t ow_work = ow_end - ow_begin;

                const dim_t ih = oh * jcp.stride_h - jcp.t_pad
                        + kh * (1 + jcp.dilate_h);
                // Row in padding: skip it, keeping col_off in sync.
                if (ih < 0 || ih >= jcp.ih) {
                    col_off += ow_work;
                    continue;
                }

                for (dim_t ow = ow_begin; ow < ow_end; ++ow, ++col_off) {
                    const dim_t iw = ow * jcp.stride_w - jcp.l_pad
                            + kw * (1 + jcp.dilate_w);
                    if (iw < 0 || iw >= jcp.iw) { continue; }

                    const size_t col_idx
                            = (kh * jcp.kw + kw) * wei_stride + col_off;
                    const size_t im_idx = ih * jcp.iw + iw;
                    im_[im_idx] += col_[col_idx];
                }
            }
            // Move to the next kd slice of the column buffer / next dilated
            // input depth.
            col_ += jcp.kh * jcp.kw * wei_stride;
            id += (1 + jcp.dilate_d);
        }
    };

    // Kernel for the non-blocked case: the column buffer holds the full
    // oh*ow output space (jcp.os points) for each (kd, kh, kw).
    auto ker = [&](dim_t ic) {
        const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os;
        float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;

        dim_t id = od * jcp.stride_d - jcp.f_pad;
        for (dim_t kd = 0; kd < jcp.kd; ++kd) {
            // Input depth falls into padding: skip the slice, advance col.
            if (id < 0 || id >= jcp.id) {
                col_ += jcp.kh * jcp.kw * jcp.os;
                id += (1 + jcp.dilate_d);
                continue;
            }

            float *__restrict im_ = im_ic + (size_t)id * jcp.ih * jcp.iw;

            for_(dim_t oh = 0; oh < jcp.oh; ++oh)
            for (dim_t kh = 0; kh < jcp.kh; ++kh) {
                const dim_t ih = oh * jcp.stride_h - jcp.t_pad
                        + kh * (1 + jcp.dilate_h);
                if (ih < 0 || ih >= jcp.ih) continue;

                for_(dim_t ow = 0; ow < jcp.ow; ++ow)
                for (dim_t kw = 0; kw < jcp.kw; ++kw) {
                    const dim_t iw = ow * jcp.stride_w - jcp.l_pad
                            + kw * (1 + jcp.dilate_w);
                    if (iw < 0 || iw >= jcp.iw) continue;

                    const size_t col_idx
                            = ((kh * jcp.kw + kw) * jcp.oh + oh) * jcp.ow + ow;
                    const size_t im_idx = ih * jcp.iw + iw;
                    im_[im_idx] += col_[col_idx];
                }
            }

            col_ += jcp.kh * jcp.kw * jcp.os;
            id += (1 + jcp.dilate_d);
        }
    };

    // Parallelize over input channels; pick the kernel matching the os
    // blocking decided in init_conf.
    const bool blocked_kernel = jcp.os_nb_block > 1;
    if (blocked_kernel)
        parallel_nd(jcp.ic, sp_blocked_ker);
    else
        parallel_nd(jcp.ic, ker);
}
909 | |
// col2im for 2D f32 convolution: scatter-adds the column buffer `col`
// (gemm layout [ic][kh][kw][os-slice]) back into the image tensor `im`
// (layout [ic][ih][iw]). `spatial_step`/`spatial_block` select the slice of
// the oh*ow output space currently held in `col` when os is blocked.
void col2im(const conv_gemm_conf_t &jcp, const float *col, float *im,
        int spatial_step, int spatial_block) {
    const size_t col_step = jcp.ks * spatial_block;
    const size_t im_step = jcp.ih * jcp.iw;
    const dim_t iS = jcp.ih * jcp.iw;

    // Kernel for the os-blocked case (jcp.os_nb_block > 1). The image
    // plane is zeroed only on the first block (spatial_step == 0); later
    // blocks accumulate on top of it.
    auto sp_blocked_ker = [&](dim_t ic) {
        // Leading dimension of one (kh, kw) row inside the column buffer.
        const dim_t wei_stride
                = nstl::min(jcp.ow * jcp.oh, dim_t(spatial_block));
        // Decompose the linear os-range [spatial_step, spatial_step +
        // spatial_block) into inclusive (oh, ow) boundaries.
        const dim_t first_oh = spatial_step / jcp.ow;
        const dim_t last_oh = (spatial_step + spatial_block - 1) / jcp.ow;
        const dim_t oh_begin = first_oh;
        const dim_t oh_end = last_oh + 1;
        const dim_t first_ow = spatial_step % jcp.ow;
        const dim_t last_ow = (spatial_step + spatial_block - 1) % jcp.ow;

        float *__restrict img_ithr = im + ic * im_step;
        const float *__restrict col_icb = col + ic * col_step;

        if (spatial_step == 0) {
            PRAGMA_OMP_SIMD()
            for (dim_t is = 0; is < iS; ++is)
                img_ithr[is] = 0.;
        }

        // The kh/kw dilation offsets are folded into the image base
        // pointers (img_kh advances by iw*(1+dilate_h) per kh, im_ by
        // (1+dilate_w) per kw), so the inner loops index with the
        // un-dilated (ih, iw) while bounds are checked on the true
        // dilated coordinates (ih_, iw_).
        float *__restrict img_kh = img_ithr;
        for (dim_t kh = 0; kh < jcp.kh; ++kh) {
            float *__restrict im_ = img_kh;
            for (dim_t kw = 0; kw < jcp.kw; ++kw) {
                const float *__restrict col_ = col_icb;
                for (dim_t oh = oh_begin; oh < oh_end; ++oh) {
                    // First/last output rows of the block may be partial.
                    const dim_t ow_begin = (oh == first_oh) ? first_ow : 0;
                    const dim_t ow_end
                            = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
                    const dim_t ow_work = ow_end - ow_begin;

                    const dim_t ih = oh * jcp.stride_h - jcp.t_pad;
                    const dim_t ih_ = ih + kh * (1 + jcp.dilate_h);
                    // Row in padding: skip it, keeping col_ in sync.
                    if (ih_ < 0 || ih_ >= jcp.ih) {
                        col_ += ow_work;
                        continue;
                    }
                    for (dim_t ow = ow_begin; ow < ow_end; ++ow, ++col_) {
                        const dim_t iw = ow * jcp.stride_w - jcp.l_pad;
                        const dim_t iw_ = iw + kw * (1 + jcp.dilate_w);
                        if (iw_ < 0 || iw_ >= jcp.iw) continue;

                        // NOTE(review): ih/iw may be negative here (padding);
                        // the pre-offset im_ pointer compensates so the
                        // effective address is img_ithr[ih_*iw + iw_], but
                        // the signed->size_t conversion relies on wraparound
                        // arithmetic — worth confirming intent.
                        const size_t im_idx = ih * jcp.iw + iw;
                        im_[im_idx] += *col_;
                    }
                }
                col_icb += wei_stride;
                im_ += (1 + jcp.dilate_w);
            }
            img_kh += (jcp.iw * (1 + jcp.dilate_h));
        }
    };

    // Kernel for the non-blocked case: `col` holds the full oh*ow output
    // space; the image plane is zeroed unconditionally, then accumulated.
    auto ker = [&](dim_t ic) {
        float *__restrict im_ = im + ic * im_step;
        const float *__restrict col_ = col + ic * col_step;
        PRAGMA_OMP_SIMD()
        for (dim_t is = 0; is < iS; ++is)
            im_[is] = 0.;

        for_(dim_t kh = 0; kh < jcp.kh; ++kh)
        for (dim_t oh = 0; oh < jcp.oh; ++oh) {
            const dim_t ih
                    = oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
            if (ih < 0 || ih >= jcp.ih) continue;

            for_(dim_t kw = 0; kw < jcp.kw; ++kw)
            for (dim_t ow = 0; ow < jcp.ow; ++ow) {
                const dim_t iw = ow * jcp.stride_w - jcp.l_pad
                        + kw * (1 + jcp.dilate_w);
                if (iw < 0 || iw >= jcp.iw) continue;

                const size_t col_idx
                        = ((kh * jcp.kw + kw) * jcp.oh + oh) * jcp.ow + ow;
                const size_t im_idx = ih * jcp.iw + iw;
                im_[im_idx] += col_[col_idx];
            }
        }
    };

    // Parallelize over input channels; pick the kernel matching the os
    // blocking decided in init_conf.
    const bool blocked_kernel = jcp.os_nb_block > 1;
    if (blocked_kernel)
        parallel_nd(jcp.ic, sp_blocked_ker);
    else
        parallel_nd(jcp.ic, ker);
}
1001 | |
1002 | status_t init_conf(conv_gemm_conf_t &jcp, |
1003 | memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd, |
1004 | memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md, |
1005 | memory_desc_t &bias_md, primitive_attr_t &attr, int max_threads) { |
1006 | const memory_desc_wrapper src_d(&src_md); |
1007 | const memory_desc_wrapper weights_d(&weights_md); |
1008 | const memory_desc_wrapper dst_d(&dst_md); |
1009 | |
1010 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
1011 | const int ndims = src_d.ndims(); |
1012 | const int is_1d = ndims == 3; |
1013 | const int is_3d = ndims == 5; |
1014 | |
1015 | jcp.prop_kind = cd.prop_kind; |
1016 | |
1017 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
1018 | jcp.mb = src_d.dims()[0]; |
1019 | |
1020 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
1021 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
1022 | jcp.id = is_3d ? src_d.dims()[2] : 1; |
1023 | jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; |
1024 | jcp.iw = src_d.dims()[ndims - 1]; |
1025 | jcp.od = is_3d ? dst_d.dims()[2] : 1; |
1026 | jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; |
1027 | jcp.ow = dst_d.dims()[ndims - 1]; |
1028 | |
1029 | jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; |
1030 | jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
1031 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
1032 | |
1033 | jcp.f_pad = is_3d ? cd.padding[0][0] : 0; |
1034 | jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; |
1035 | jcp.l_pad = cd.padding[0][ndims - 3]; |
1036 | |
1037 | jcp.stride_d = is_3d ? cd.strides[0] : 1; |
1038 | jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; |
1039 | jcp.stride_w = cd.strides[ndims - 3]; |
1040 | |
1041 | jcp.dilate_d = is_3d ? cd.dilates[0] : 0; |
1042 | jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; |
1043 | jcp.dilate_w = cd.dilates[ndims - 3]; |
1044 | |
1045 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef |
1046 | || cd.diff_bias_desc.format_kind != format_kind::undef; |
1047 | |
1048 | jcp.is = jcp.ih * jcp.iw; |
1049 | jcp.os = jcp.oh * jcp.ow; |
1050 | jcp.ks = jcp.kh * jcp.kw * jcp.kd; |
1051 | |
1052 | jcp.signed_input = src_d.data_type() == data_type::s8; |
1053 | |
1054 | jcp.outer_threading = false; |
1055 | |
1056 | jcp.zp = zero_point_config_t(attr); |
1057 | jcp.b_pad = nstl::max((jcp.oh - 1) * jcp.stride_h |
1058 | + (jcp.kh - 1) * (jcp.dilate_h + 1) |
1059 | - (jcp.ih + jcp.t_pad - 1), |
1060 | dim_t(0)); |
1061 | jcp.r_pad = nstl::max((jcp.ow - 1) * jcp.stride_w |
1062 | + (jcp.kw - 1) * (jcp.dilate_w + 1) |
1063 | - (jcp.iw + jcp.l_pad - 1), |
1064 | dim_t(0)); |
1065 | jcp.e_pad = nstl::max((jcp.od - 1) * jcp.stride_d |
1066 | + (jcp.kd - 1) * (jcp.dilate_d + 1) |
1067 | - (jcp.id + jcp.f_pad - 1), |
1068 | dim_t(0)); |
1069 | |
1070 | const bool zp_src_with_padding = jcp.zp.src_exists && padding_exists(jcp); |
1071 | |
1072 | if (zp_src_with_padding) { |
1073 | jcp.zp.src_pad_comp = zero_point_pad_comp_config_t(jcp.f_pad, jcp.e_pad, |
1074 | jcp.t_pad, jcp.b_pad, jcp.l_pad, jcp.r_pad, jcp.stride_d, |
1075 | jcp.stride_h, jcp.stride_w, jcp.od, jcp.oh, jcp.ow); |
1076 | } |
1077 | |
1078 | const auto set_or_check_tags |
1079 | = [&](format_tag_t desired_src_tag, format_tag_t desired_dst_tag, |
1080 | bool is_src_s8) -> status_t { |
1081 | using namespace format_tag; |
1082 | auto src_tag = any, dst_tag = any; |
1083 | |
1084 | if (src_d.format_kind() == format_kind::any) { |
1085 | CHECK(memory_desc_init_by_tag(src_md, desired_src_tag)); |
1086 | src_tag = desired_src_tag; |
1087 | } else { |
1088 | src_tag = memory_desc_matches_one_of_tag( |
1089 | src_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw); |
1090 | } |
1091 | |
1092 | if (dst_d.format_kind() == format_kind::any) { |
1093 | CHECK(memory_desc_init_by_tag(dst_md, desired_dst_tag)); |
1094 | dst_tag = desired_dst_tag; |
1095 | } else { |
1096 | dst_tag = memory_desc_matches_one_of_tag( |
1097 | dst_md, nwc, nhwc, ndhwc, ncw, nchw, ncdhw); |
1098 | } |
1099 | |
1100 | if (src_tag == format_tag::undef || dst_tag == format_tag::undef) |
1101 | return status::unimplemented; |
1102 | if (src_tag != dst_tag) return status::unimplemented; |
1103 | |
1104 | if (jcp.with_bias && bias_md.format_kind == format_kind::any) |
1105 | CHECK(memory_desc_init_by_tag(bias_md, x)); |
1106 | |
1107 | const bool is_nspc = utils::one_of(src_tag, nwc, nhwc, ndhwc); |
1108 | jcp.is_nspc = is_nspc; |
1109 | |
1110 | memory_desc_t want_wei_md = weights_md; |
1111 | auto wei_tag = is_nspc |
1112 | ? (with_groups ? utils::pick(ndims - 3, wigo, hwigo, dhwigo) |
1113 | : utils::pick(ndims - 3, wio, hwio, dhwio)) |
1114 | : (with_groups ? utils::pick(ndims - 3, goiw, goihw, goidhw) |
1115 | : utils::pick(ndims - 3, oiw, oihw, oidhw)); |
1116 | CHECK(memory_desc_init_by_tag(want_wei_md, wei_tag)); |
1117 | |
1118 | if (is_src_s8) { |
1119 | want_wei_md.extra.flags = 0 |
1120 | | memory_extra_flags::compensation_conv_s8s8 |
1121 | | memory_extra_flags::scale_adjust; |
1122 | want_wei_md.extra.compensation_mask |
1123 | = (1 << 0) + (with_groups ? (1 << 1) : 0); |
1124 | want_wei_md.extra.scale_adjust |
1125 | = platform::s8s8_weights_scale_factor(); |
1126 | } |
1127 | |
1128 | if (jcp.zp.src_exists) set_zp_src_comp_flags(want_wei_md, with_groups); |
1129 | |
1130 | if (weights_md.format_kind == format_kind::any) { |
1131 | weights_md = want_wei_md; |
1132 | return status::success; |
1133 | } |
1134 | return (want_wei_md == weights_md) ? status::success |
1135 | : status::unimplemented; |
1136 | }; |
1137 | |
1138 | const bool is_bwd_d = jcp.prop_kind == backward_data; |
1139 | const bool is_bwd_w = jcp.prop_kind == backward_weights; |
1140 | const bool is_fwd = !is_bwd_d && !is_bwd_w; |
1141 | |
1142 | bool is_int8_conv = (is_fwd ? utils::one_of(src_d.data_type(), s8, u8) |
1143 | : utils::one_of(dst_d.data_type(), s8, u8)) |
1144 | && weights_d.data_type() == s8; |
1145 | |
1146 | auto default_dat_tag = is_int8_conv |
1147 | ? utils::pick(ndims - 3, format_tag::nwc, format_tag::nhwc, |
1148 | format_tag::ndhwc) |
1149 | : utils::pick(ndims - 3, format_tag::ncw, format_tag::nchw, |
1150 | format_tag::ncdhw); |
1151 | CHECK(set_or_check_tags(default_dat_tag, default_dat_tag, |
1152 | src_md.data_type == data_type::s8)); |
1153 | |
1154 | // Does int8 conv ever need to support ncsp input format |
1155 | if (is_int8_conv && !src_d.matches_one_of_tag(default_dat_tag)) |
1156 | return status::unimplemented; |
1157 | |
1158 | CHECK(attr.set_default_formats(&dst_md)); |
1159 | |
1160 | jcp.post_ops = attr.post_ops_; |
1161 | |
1162 | const int eltwise_ind = jcp.post_ops.find(primitive_kind::eltwise); |
1163 | jcp.with_eltwise = eltwise_ind != -1; |
1164 | const int binary_ind = jcp.post_ops.find(primitive_kind::binary); |
1165 | jcp.with_binary = binary_ind != -1; |
1166 | const int sum_ind = jcp.post_ops.find(primitive_kind::sum); |
1167 | jcp.with_sum = sum_ind != -1; |
1168 | |
1169 | bool is_bf16_conv = false |
1170 | || (is_fwd |
1171 | && utils::everyone_is( |
1172 | bf16, src_d.data_type(), weights_d.data_type())) |
1173 | || (is_bwd_d |
1174 | && utils::everyone_is( |
1175 | bf16, dst_d.data_type(), weights_d.data_type())) |
1176 | || (is_bwd_w |
1177 | && utils::everyone_is( |
1178 | bf16, src_d.data_type(), dst_d.data_type())); |
1179 | if (is_bf16_conv && !platform::has_data_type_support(bf16)) |
1180 | return status::unimplemented; |
1181 | |
1182 | const int vlen = std::max(platform::get_vector_register_size(), 4); |
1183 | const int data_size = (is_int8_conv ? 1 : (is_bf16_conv ? 2 : 4)); |
1184 | const int simd_w = vlen / data_size; |
1185 | |
1186 | jcp.os_block = jcp.os; |
1187 | jcp.os_nb_block = 1; |
1188 | jcp.oc_block = jcp.oc; |
1189 | jcp.ic_block = jcp.ic; |
1190 | jcp.loop_order = gemm_loop_rlb; |
1191 | jcp.nthr_oc = 1; |
1192 | |
1193 | jcp.oh_block = is_fwd ? jcp.oh : jcp.ih; |
1194 | jcp.ow_block = is_fwd ? jcp.ow : jcp.iw; |
1195 | |
1196 | using namespace memory_tracking::names; |
1197 | bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; |
1198 | |
1199 | // TODO: maybe mitigate blocking restriction |
1200 | const auto L2 = platform::get_per_core_cache_size(2) / data_size; |
1201 | const int gemm_thrld = 64 * 1024; |
1202 | |
1203 | // Heuristic threshold for requested scratchpad memory to avoid |
1204 | // possible crash on memory allocation: |
1205 | // 1Gb or size of the buffers already used for this convolution proportional |
1206 | // to the number of threads and multiplied by a heuristic coefficient (15) |
1207 | const size_t zp_src_pad_comp_size = zp_src_with_padding |
1208 | ? (jcp.oc * jcp.ngroups * jcp.zp.src_pad_comp.d |
1209 | * jcp.zp.src_pad_comp.h * jcp.zp.src_pad_comp.w) |
1210 | : 0u; |
1211 | const size_t zp_src_comp_size = jcp.zp.src_is_common |
1212 | ? utils::rnd_up(jcp.oc * jcp.ngroups, |
1213 | platform::get_cache_line_size() / sizeof(int)) |
1214 | : 0u; |
1215 | |
1216 | const size_t weights_size = weights_d.size() |
1217 | + (zp_src_comp_size + zp_src_pad_comp_size) * sizeof(int32_t); |
1218 | |
1219 | static constexpr size_t scratchpad_limit_by_absolute_value = (size_t)1 |
1220 | << 30; // 1Gb |
1221 | const size_t scratchpad_limit_by_tensor_sizes |
1222 | = 15 * max_threads * (src_d.size() + weights_size + dst_d.size()); |
1223 | const size_t scratchpad_limit |
1224 | = nstl::min(scratchpad_limit_by_absolute_value, |
1225 | scratchpad_limit_by_tensor_sizes); |
1226 | |
1227 | if (is_int8_conv) { |
1228 | if (is_fwd) { |
1229 | jcp.im2col_sz |
1230 | = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih, |
1231 | jcp.od == jcp.id, jcp.stride_w == 1, |
1232 | jcp.stride_h == 1, jcp.stride_d == 1, jcp.ks == 1, |
1233 | !jcp.signed_input) |
1234 | ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os |
1235 | : 0; |
1236 | |
1237 | dim_t wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw; |
1238 | bool is_blocking_applicable = true && is_fwd && jcp.im2col_sz |
1239 | && !is_3d && jcp.dilate_h == 0 && jcp.dilate_w == 0 |
1240 | && !is_depthwise && wei_size < L2 / 2; |
1241 | if (is_blocking_applicable) { |
1242 | // looking for oh and ow blocking |
1243 | dim_t h_block {jcp.oh_block}, w_block {jcp.ow_block}; |
1244 | dim_t ic = jcp.ic; |
1245 | dim_t oc = jcp.oc; |
1246 | dim_t iw = jcp.iw; |
1247 | dim_t ow = jcp.ow; |
1248 | dim_t oh = jcp.oh; |
1249 | dim_t os = oh * ow; |
1250 | |
1251 | // 1. cache requirement |
1252 | dim_t row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow); |
1253 | // Heuristic rule: gemm needed a lot of memory for internal |
1254 | // usage |
1255 | row_size *= 5; |
1256 | // memory for accumulators |
1257 | row_size += oc * ow * sizeof(uint32_t); |
1258 | // memory for transposition |
1259 | row_size += ic * iw; |
1260 | |
1261 | h_block = nstl::max( |
1262 | dim_t(1), nstl::min(oh, div_up(dim_t(L2), row_size))); |
1263 | if (h_block == 1) { |
1264 | dim_t col_size = ic * jcp.ks + 2 * (ic + oc); |
1265 | if (is_int8_conv) { |
1266 | col_size *= 5; |
1267 | col_size += oc * sizeof(uint32_t); |
1268 | col_size += ic; |
1269 | } |
1270 | w_block = nstl::max(dim_t(1), |
1271 | nstl::min(ow, div_up(dim_t(L2), col_size))); |
1272 | } |
1273 | |
1274 | // 2. threading requirement |
1275 | if (h_block != oh) |
1276 | h_block = nstl::max(dim_t(1), rnd_dn(h_block, dim_t(4))); |
1277 | if (w_block != ow) |
1278 | w_block = nstl::max(dim_t(1), rnd_dn(w_block, simd_w)); |
1279 | |
1280 | float thr_eff = 0.f; |
1281 | float thr_eff_treshold = 0.9f; |
1282 | if (w_block == ow) { |
1283 | do { |
1284 | dim_t nb_h = div_up(oh, h_block); |
1285 | dim_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h; |
1286 | float disb = (float)oh / rnd_up(oh, h_block); |
1287 | thr_eff = (float)work / rnd_up(work, max_threads); |
1288 | thr_eff = (thr_eff + disb) / 2.f; |
1289 | if (thr_eff >= thr_eff_treshold) break; |
1290 | h_block = rnd_dn(h_block - 4, 4); |
1291 | } while (h_block > 0); |
1292 | } |
1293 | if (thr_eff |
1294 | < thr_eff_treshold) // we didn't find suitable h_block |
1295 | { |
1296 | h_block = 1; |
1297 | int nb_h = oh; |
1298 | do { |
1299 | dim_t nb_w = div_up(ow, w_block); |
1300 | dim_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w; |
1301 | float disb = (float)ow / rnd_up(ow, w_block); |
1302 | thr_eff = (float)work_amount |
1303 | / rnd_up(work_amount, max_threads); |
1304 | thr_eff = (thr_eff + disb) / 2.f; |
1305 | if (thr_eff > thr_eff_treshold) break; |
1306 | w_block = rnd_dn(w_block - simd_w, simd_w); |
1307 | } while (w_block > 0); |
1308 | } |
1309 | h_block = nstl::max(dim_t(1), h_block); |
1310 | w_block = nstl::max(dim_t(1), w_block); |
1311 | dim_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w); |
1312 | const float inner_thr_eff |
1313 | = (float)inner_work / rnd_up(inner_work, max_threads); |
1314 | if (thr_eff >= inner_thr_eff / 2 && h_block > 0 |
1315 | && w_block > 0) { |
1316 | jcp.oh_block = h_block; |
1317 | jcp.ow_block = w_block; |
1318 | jcp.outer_threading = true; |
1319 | } |
1320 | // updating jcp.im2col_sz |
1321 | if (jcp.oh_block != 1) jcp.ow_block = ow; |
1322 | jcp.im2col_sz |
1323 | = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block; |
1324 | } |
1325 | // For threading selection in bwd_d we do: |
1326 | // 1. Rough estimation of efficiency for inner and outer threading. |
1327 | // 2. Gemm size estimation in assumption that it does not work |
1328 | // so effectively for small sizes. |
1329 | // 64K - this is heuristic gemm size per thread threshold. |
1330 | const int gemm_thrld = 64 * 1024; |
1331 | if (!jcp.outer_threading && !is_3d) { |
1332 | bool is_depthwise |
1333 | = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; |
1334 | const dim_t outer_work = jcp.ngroups * jcp.mb; |
1335 | const float outer_thr_eff |
1336 | = (float)outer_work / rnd_up(outer_work, max_threads); |
1337 | const size_t inner_work |
1338 | = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); |
1339 | const float inner_thr_eff |
1340 | = (float)inner_work / rnd_up(inner_work, max_threads); |
1341 | jcp.outer_threading |
1342 | = (is_depthwise |
1343 | || (jcp.is / max_threads < 64 && jcp.mb != 1)) |
1344 | && (outer_thr_eff / inner_thr_eff >= 1.f |
1345 | || (jcp.os * jcp.ic * jcp.oc) / max_threads |
1346 | < gemm_thrld); |
1347 | } |
1348 | jcp.nthr = jcp.outer_threading ? max_threads : 1; |
1349 | scratchpad.book<int8_t>( |
1350 | key_conv_gemm_col, jcp.nthr * jcp.im2col_sz); |
1351 | scratchpad.book<int32_t>(key_conv_int_dat_in_acc_dt, |
1352 | jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc); |
1353 | scratchpad.book<int8_t>( |
1354 | key_conv_gemm_imtr, jcp.nthr * jcp.id * jcp.is * jcp.ic); |
1355 | } else if (is_bwd_d) { |
1356 | jcp.im2col_sz |
1357 | = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih, |
1358 | jcp.od == jcp.id, jcp.stride_w == 1, |
1359 | jcp.stride_h == 1, jcp.stride_d == 1, jcp.ks == 1, |
1360 | !jcp.signed_input) |
1361 | ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os * jcp.od |
1362 | : 0; |
1363 | |
1364 | bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; |
1365 | const size_t outer_work = jcp.ngroups * jcp.mb; |
1366 | const float outer_thr_eff |
1367 | = (float)outer_work / rnd_up(outer_work, max_threads); |
1368 | const size_t inner_work |
1369 | = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); |
1370 | const float inner_thr_eff |
1371 | = (float)inner_work / rnd_up(inner_work, max_threads); |
1372 | jcp.outer_threading = !is_3d |
1373 | && (is_depthwise |
1374 | || (jcp.is / max_threads < 64 && jcp.mb != 1)) |
1375 | && (outer_thr_eff / inner_thr_eff >= 1.f |
1376 | || (jcp.is * jcp.ic * jcp.oc) / max_threads |
1377 | < gemm_thrld); |
1378 | |
1379 | jcp.nthr = jcp.outer_threading ? max_threads : 1; |
1380 | scratchpad.book<int32_t>( |
1381 | key_conv_gemm_col, jcp.nthr * jcp.im2col_sz); |
1382 | scratchpad.book<int32_t>(key_conv_int_dat_in_acc_dt, |
1383 | jcp.nthr * jcp.is * jcp.id * jcp.ic); |
1384 | } else if (is_bwd_w) { |
1385 | assert(!"unimplemented prop_kind" ); |
1386 | return status::unimplemented; |
1387 | } |
1388 | } else { |
1389 | jcp.im2col_sz = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih, |
1390 | jcp.od == jcp.id, jcp.stride_w == 1, |
1391 | jcp.stride_h == 1, jcp.stride_d == 1, |
1392 | jcp.ks == 1, !jcp.signed_input) |
1393 | ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os |
1394 | : 0; |
1395 | if (jcp.is_nspc && is_fwd) { |
1396 | const size_t wei_size |
1397 | = static_cast<size_t>(jcp.oc) * jcp.ic * jcp.kh * jcp.kw; |
1398 | bool is_blocking_applicable = true && is_fwd && jcp.im2col_sz |
1399 | && !is_3d && jcp.dilate_h == 0 && jcp.dilate_w == 0 |
1400 | && !is_depthwise && wei_size < static_cast<size_t>(L2) / 2; |
1401 | // Logic for blocking for f32_nspc gemm convolution follows that of |
1402 | // int8_nspc gemm convolution. Currently, not optimized for f32 |
1403 | // data type. |
1404 | if (is_blocking_applicable) { |
1405 | // looking for oh and ow blocking |
1406 | size_t h_block = jcp.oh_block; |
1407 | size_t w_block = jcp.ow_block; |
1408 | |
1409 | const size_t ic = jcp.ic; |
1410 | const size_t oc = jcp.oc; |
1411 | const size_t iw = jcp.iw; |
1412 | const size_t ow = jcp.ow; |
1413 | const size_t oh = jcp.oh; |
1414 | const size_t os = oh * ow; |
1415 | |
1416 | // 1. cache requirement |
1417 | size_t row_size = ic * ow * jcp.ks * data_size |
1418 | + 2 * (ic * iw + oc * ow) * data_size; |
1419 | // Heuristic rule: gemm needed a lot of memory for internal |
1420 | // usage |
1421 | row_size *= 5; |
1422 | // memory for accumulators |
1423 | row_size += oc * ow * data_size; |
1424 | // memory for transposition |
1425 | row_size += ic * iw * data_size; |
1426 | |
1427 | const size_t L2_rows = div_up(L2, row_size); |
1428 | h_block = saturate(size_t {1}, L2_rows, oh); |
1429 | if (h_block == 1) { |
1430 | size_t col_size = ic * jcp.ks * data_size |
1431 | + 2 * (ic + oc) * data_size; |
1432 | const size_t L2_cols = div_up(L2, col_size); |
1433 | w_block = saturate(size_t {1}, L2_cols, ow); |
1434 | } |
1435 | |
1436 | // 2. threading requirement |
1437 | if (h_block != oh) |
1438 | h_block = nstl::max(size_t {1}, rnd_dn(h_block, 4)); |
1439 | if (w_block != ow) |
1440 | w_block = nstl::max(size_t {1}, rnd_dn(w_block, simd_w)); |
1441 | |
1442 | float thr_eff = 0.f; |
1443 | float thr_eff_treshold = 0.9f; |
1444 | if (w_block == ow) { |
1445 | do { |
1446 | size_t nb_h = div_up(oh, h_block); |
1447 | size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h; |
1448 | float disb = (float)oh / rnd_up(oh, h_block); |
1449 | thr_eff = (float)work / rnd_up(work, max_threads); |
1450 | thr_eff = (thr_eff + disb) / 2.f; |
1451 | if (thr_eff >= thr_eff_treshold) break; |
1452 | |
1453 | if (h_block < 4) |
1454 | h_block = 0; |
1455 | else |
1456 | h_block = rnd_dn(h_block - 4, 4); |
1457 | } while (h_block > 0); |
1458 | } |
1459 | if (thr_eff |
1460 | < thr_eff_treshold) // we didn't find suitable h_block |
1461 | { |
1462 | h_block = 1; |
1463 | size_t nb_h = oh; |
1464 | do { |
1465 | size_t nb_w = div_up(ow, w_block); |
1466 | size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w; |
1467 | float disb = (float)ow / rnd_up(ow, w_block); |
1468 | thr_eff = (float)work_amount |
1469 | / rnd_up(work_amount, max_threads); |
1470 | thr_eff = (thr_eff + disb) / 2.f; |
1471 | if (thr_eff > thr_eff_treshold) break; |
1472 | |
1473 | if (w_block < static_cast<size_t>(simd_w)) |
1474 | w_block = 0; |
1475 | else |
1476 | w_block = rnd_dn(w_block - simd_w, simd_w); |
1477 | } while (w_block > 0); |
1478 | } |
1479 | h_block = nstl::max(size_t {1}, h_block); |
1480 | w_block = nstl::max(size_t {1}, w_block); |
1481 | const size_t inner_work |
1482 | = div_up(os, simd_w) * div_up(oc, simd_w); |
1483 | const float inner_thr_eff |
1484 | = (float)inner_work / rnd_up(inner_work, max_threads); |
1485 | if (thr_eff >= inner_thr_eff / 2 && h_block > 0 |
1486 | && w_block > 0) { |
1487 | jcp.oh_block = static_cast<int>(h_block); |
1488 | jcp.ow_block = static_cast<int>(w_block); |
1489 | jcp.outer_threading = true; |
1490 | } |
1491 | // updating jcp.im2col_sz |
1492 | if (jcp.oh_block != 1) jcp.ow_block = static_cast<int>(ow); |
1493 | jcp.im2col_sz |
1494 | = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block; |
1495 | } |
1496 | // For threading selection in fwd_d we do: |
1497 | // 1. Rough estimation of efficiency for inner and outer threading. |
1498 | // 2. Gemm size estimation in assumption that it does not work |
1499 | // so effectively for small sizes. |
1500 | // 64K - this is heuristic gemm size per thread threshold. |
1501 | constexpr size_t gemm_thrld = 64 * 1024; |
1502 | if (!jcp.outer_threading && !is_3d) { |
1503 | bool is_depthwise |
1504 | = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; |
1505 | const size_t outer_work = jcp.ngroups * jcp.mb; |
1506 | const float outer_thr_eff |
1507 | = (float)outer_work / rnd_up(outer_work, max_threads); |
1508 | const size_t inner_work |
1509 | = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); |
1510 | const float inner_thr_eff |
1511 | = (float)inner_work / rnd_up(inner_work, max_threads); |
1512 | jcp.outer_threading |
1513 | = (is_depthwise |
1514 | || (jcp.is / max_threads < 64 && jcp.mb != 1)) |
1515 | && (outer_thr_eff / inner_thr_eff >= 1.f |
1516 | || (static_cast<size_t>(jcp.os) * jcp.ic |
1517 | * jcp.oc) |
1518 | / max_threads |
1519 | < gemm_thrld); |
1520 | } |
1521 | jcp.nthr = jcp.outer_threading ? max_threads : 1; |
1522 | const size_t gemm_col_datatype_size |
1523 | = is_bf16_conv ? sizeof(bfloat16_t) : sizeof(float); |
1524 | |
1525 | scratchpad.book(key_conv_gemm_col, jcp.nthr * jcp.im2col_sz, |
1526 | gemm_col_datatype_size); |
1527 | if (is_bf16_conv) { |
1528 | scratchpad.book<float>(key_conv_gemm_acc, |
1529 | jcp.nthr * static_cast<size_t>(jcp.oh_block) |
1530 | * jcp.ow_block * jcp.oc); |
1531 | } |
1532 | |
1533 | scratchpad.book(key_conv_gemm_imtr, |
1534 | jcp.nthr * static_cast<size_t>(jcp.id) * jcp.is * jcp.ic, |
1535 | gemm_col_datatype_size); |
1536 | if (is_bf16_conv && jcp.with_bias |
1537 | && one_of(data_type::bf16, cd.diff_bias_desc.data_type, |
1538 | cd.bias_desc.data_type)) { |
1539 | scratchpad.book<float>( |
1540 | key_conv_bias_bf16_convert_wsp, jcp.ngroups * jcp.oc); |
1541 | } |
1542 | |
1543 | } else if (!jcp.is_nspc && is_fwd) { |
1544 | const dim_t sh = jcp.stride_h; |
1545 | const dim_t sw = jcp.stride_w; |
1546 | const dim_t spatial = jcp.mb * jcp.ngroups * jcp.od * jcp.os; |
1547 | dim_t K = jcp.ic * jcp.ks; |
1548 | |
1549 | // There is some heuristics in the definition of |
1550 | // inner/outer threading cross point due to the nature of the |
1551 | // gemm implementation which we cannot control |
1552 | bool is_blocking_applicable = true && !is_3d |
1553 | && (!jcp.im2col_sz |
1554 | // spatial is small |
1555 | || spatial >= max_threads * simd_w |
1556 | // inner threading work is greater then outer |
1557 | // threading work |
1558 | || jcp.os < jcp.mb * jcp.ngroups * jcp.od |
1559 | // im2col is big |
1560 | || (sw == 1 && K <= 0.05 * jcp.oc)) |
1561 | // heuristic condition |
1562 | && (jcp.im2col_sz |
1563 | || (jcp.ic / jcp.oc < 42 |
1564 | && jcp.ic * jcp.oc * jcp.is < 1024)); |
1565 | |
1566 | if (is_blocking_applicable) { |
1567 | const dim_t min_oc_block = 8; |
1568 | const dim_t min_os_block = simd_w; |
1569 | const float non_cache_access = 20; |
1570 | const float strided_im2col_k = 8; |
1571 | const float thr_disb_k = 8; |
1572 | const float thr_mem_eff_k {1}, oc_disb_k {1}, os_disb_k {1}, |
1573 | ic_disb_k {1}, reg_osb_disb_k {1}, gemm_eff_k {0.5}, |
1574 | gemm_calc_eff_k {1}; |
1575 | const float k_sum = thr_disb_k + oc_disb_k + os_disb_k |
1576 | + ic_disb_k + reg_osb_disb_k + thr_mem_eff_k |
1577 | + gemm_eff_k + gemm_calc_eff_k; |
1578 | |
1579 | auto calc_max_icb = [=](dim_t nthr_oc, dim_t ocb, dim_t osb, |
1580 | dim_t oc_per_thr, |
1581 | dim_t os_per_thr) { |
1582 | const dim_t block_out_size = ocb * osb; |
1583 | // TODO: need more precise calculation if stride more than |
1584 | // kernel size |
1585 | const dim_t inp_row_size = sh * sw * osb; |
1586 | dim_t max_icb = 1; |
1587 | if (jcp.im2col_sz) { |
1588 | const dim_t col_row_size = jcp.ks * osb; |
1589 | if (osb >= os_per_thr) { // one pass by os |
1590 | const dim_t wei_col_size = jcp.ks * ocb; |
1591 | max_icb = L2 / (inp_row_size + col_row_size); |
1592 | if (ocb < oc_per_thr) { |
1593 | max_icb = nstl::min(max_icb, |
1594 | (L2 - block_out_size) |
1595 | / (col_row_size |
1596 | + wei_col_size)); |
1597 | } |
1598 | } else { |
1599 | const dim_t wei_col_size = jcp.ks * oc_per_thr; |
1600 | max_icb = (L2 - block_out_size) |
1601 | / (inp_row_size + col_row_size |
1602 | + wei_col_size); |
1603 | } |
1604 | } else { |
1605 | if (osb >= os_per_thr) |
1606 | max_icb = L2 / inp_row_size; |
1607 | else { |
1608 | const dim_t wei_col_size = jcp.ks * oc_per_thr; |
1609 | max_icb = L2 / (inp_row_size + wei_col_size); |
1610 | } |
1611 | } |
1612 | if (max_icb < jcp.ic) { |
1613 | if (jcp.im2col_sz) { |
1614 | const dim_t col_row_size = jcp.ks * osb; |
1615 | const dim_t wei_col_size = jcp.ks * oc_per_thr; |
1616 | max_icb = (L2 - block_out_size) |
1617 | / (inp_row_size + col_row_size |
1618 | + wei_col_size); |
1619 | } |
1620 | } |
1621 | return max_icb; |
1622 | }; |
1623 | |
1624 | dim_t best_ocb {1}, best_osb {1}; |
1625 | dim_t best_nthr_oc {1}; |
1626 | dim_t best_icb {jcp.ic}; |
1627 | float best_thr_eff = 0; |
1628 | |
1629 | auto try_cfg = [&](dim_t nthr_oc, dim_t ocb, dim_t osb) { |
1630 | // for given nthr_oc, oc block: |
1631 | // 1. find ic block to fit into cache |
1632 | // 2. estimate efficiency basing on rules and heuristic: |
1633 | // - Minimize im2col cost |
1634 | // - ratio of FMA number to data size |
1635 | // - gemm works better if M divided by 48 and N divided by 8 |
1636 | |
1637 | const dim_t max_oc = div_up(jcp.oc, nthr_oc); |
1638 | const dim_t min_oc = nstl::max(dim_t(1), jcp.oc / nthr_oc); |
1639 | const dim_t max_os |
1640 | = div_up(spatial, (dim_t)(max_threads / nthr_oc)); |
1641 | ocb = utils::saturate(min_oc_block, max_oc, ocb); |
1642 | osb = utils::saturate(min_os_block, max_os, osb); |
1643 | |
1644 | // The computation of max_thr_size and min_thr_size is |
1645 | // based on work balance using: |
1646 | // balance2D(max_threads, i, spatial, sp_start, sp_end, |
1647 | // jcp.oc, oc_start, oc_end, nthr_oc); |
1648 | size_t max_thr_size = 1; |
1649 | { |
1650 | const dim_t min_os = div_up( |
1651 | spatial, (dim_t)div_up(max_threads, nthr_oc)); |
1652 | /* --- compute max_thr_size ------------ |
1653 | may not necessarily be (max_oc * max_os) |
1654 | thr_size = thr_oc * (spatial /nthrs_in_slice); |
1655 | with spatial as const, thr_size has maxima when |
1656 | (A: thr_oc is max) and (B: nthrs_in_slice is min) |
1657 | */ |
1658 | if (jcp.oc % nthr_oc > max_threads % nthr_oc) { |
1659 | // If (A) and (B) are true together, then it is the |
1660 | // global max |
1661 | max_thr_size = max_oc * max_os; |
1662 | } else { |
1663 | const size_t oc_max_os_min = max_oc * min_os; |
1664 | const size_t oc_min_os_max = min_oc * max_os; |
1665 | max_thr_size |
1666 | = nstl::max(oc_max_os_min, oc_min_os_max); |
1667 | } |
1668 | } |
1669 | |
1670 | size_t min_thr_size {1}; |
1671 | { |
1672 | const dim_t min_os = nstl::max(dim_t(1), |
1673 | spatial / div_up(max_threads, nthr_oc)); |
1674 | /* --- compute min_thr_size ------------ |
                         may not necessarily be (min_oc * min_os)
1676 | thr_size = thr_oc * (spatial /nthrs_in_slice); |
1677 | with spatial as const, thr_size has minima when |
1678 | (A: thr_oc is min) and (B: nthrs_in_slice is max) |
1679 | */ |
1680 | if (max_threads % nthr_oc > jcp.oc % nthr_oc) { |
1681 | // If (A) and (B) are true together, then it is the |
1682 | // global min |
1683 | min_thr_size = min_oc * min_os; |
1684 | } else { |
1685 | const size_t oc_max_os_min = max_oc * min_os; |
1686 | const size_t oc_min_os_max = min_oc |
1687 | * (size_t)(spatial |
1688 | / (dim_t)(max_threads / nthr_oc)); |
1689 | min_thr_size |
1690 | = nstl::min(oc_max_os_min, oc_min_os_max); |
1691 | } |
1692 | } |
1693 | auto thr_disb = (float)min_thr_size / max_thr_size; |
1694 | |
1695 | const dim_t oc_per_thr = max_oc; |
1696 | const dim_t os_per_thr = max_os; |
1697 | ocb = nstl::min(oc_per_thr, ocb); |
1698 | const dim_t os_max = nstl::min(jcp.os, os_per_thr); |
1699 | osb = nstl::min(os_max, osb); |
1700 | |
1701 | // -- selecting icb --------------------- |
1702 | dim_t max_ic_block = calc_max_icb( |
1703 | nthr_oc, ocb, osb, oc_per_thr, os_per_thr); |
1704 | // if we don't fit into cache then access to memory is |
1705 | // expensive |
1706 | dim_t mem_access_cost |
1707 | = (max_ic_block < 1) ? non_cache_access : 1; |
1708 | max_ic_block = nstl::max(dim_t(1), max_ic_block); |
1709 | dim_t icb = nstl::max( |
1710 | dim_t(1), jcp.ic / div_up(jcp.ic, max_ic_block)); |
1711 | dim_t nb_ic = div_up(jcp.ic, icb); |
1712 | dim_t kb = icb * jcp.ks; |
1713 | dim_t kb_caligned = rnd_up(kb, simd_w); |
1714 | |
1715 | // -- mem efficiency ------------ |
1716 | const size_t out_size |
1717 | = oc_per_thr * rnd_up(os_per_thr, simd_w); |
1718 | const size_t out_ops = mem_access_cost * out_size |
1719 | * ((icb == jcp.ic) ? 1 : (2 * nb_ic - 1)); |
1720 | const dim_t osb_caligned = rnd_up(osb, simd_w); |
1721 | const size_t inp_size |
1722 | = jcp.ic * rnd_up(os_per_thr * sh * sw, simd_w); |
1723 | size_t inp_ops = 0; |
1724 | size_t col_ops = 0; |
1725 | // TODO: simplify calculations |
1726 | if (jcp.im2col_sz) { |
1727 | inp_ops = mem_access_cost * jcp.ks * inp_size; |
1728 | const float col_tail_koeff = (float)osb_caligned / osb; |
1729 | col_ops = mem_access_cost |
1730 | * (jcp.ks * inp_size * col_tail_koeff |
1731 | + jcp.ks * inp_size * col_tail_koeff); |
1732 | if (sw != 1) // im2col with strides is much slower |
1733 | col_ops *= strided_im2col_k; |
1734 | } else { |
1735 | inp_ops = mem_access_cost * jcp.ks * inp_size; |
1736 | } |
1737 | // TODO: what about groups? |
1738 | const size_t wei_size = oc_per_thr * rnd_up(K, simd_w); |
1739 | const size_t wei_ops = mem_access_cost * wei_size; |
1740 | // ratio of real FMA to number of memory ops |
1741 | const float thr_mem_eff |
1742 | = (((float)os_per_thr / simd_w) * oc_per_thr * K) |
1743 | / (inp_ops + col_ops + wei_ops + out_ops); |
1744 | |
1745 | auto oc_disb = (float)oc_per_thr / rnd_up(oc_per_thr, ocb); |
1746 | auto os_disb = (float)os_max / rnd_up(os_max, osb); |
1747 | auto ic_disb = (float)jcp.ic / rnd_up(jcp.ic, icb); |
1748 | |
1749 | auto reg_osb_disb = (float)osb / rnd_up(osb, 3 * simd_w); |
1750 | |
1751 | // Heuristics |
1752 | const float gemm_eff = ((float)osb * ocb * kb) |
1753 | / ((float)oc_per_thr * os_per_thr * K); |
1754 | |
1755 | // number of FMA to memory size |
1756 | const float gemm_calc_eff |
1757 | = (((float)osb / simd_w) * ocb * kb) |
1758 | / (osb_caligned * kb + ocb * kb_caligned |
1759 | + ocb * osb_caligned); |
1760 | // optimization: remove pow, when corresponding weight is 1 |
1761 | const float res_eff = pow(pow(thr_disb, thr_disb_k) |
1762 | * oc_disb // pow(oc_disb, oc_disb_k) |
1763 | * os_disb // pow(os_disb, os_disb_k) |
1764 | * ic_disb // pow(ic_disb, ic_disb_k) |
1765 | // pow(reg_osb_disb, reg_osb_disb_k) |
1766 | * reg_osb_disb |
1767 | //pow(thr_mem_eff, thr_mem_eff_k) |
1768 | * thr_mem_eff |
1769 | //pow(gemm_calc_eff, gemm_calc_eff_k) |
1770 | * pow(gemm_eff, gemm_eff_k) * gemm_calc_eff, |
1771 | 1.f / k_sum); |
1772 | |
1773 | if (res_eff > best_thr_eff) { |
1774 | best_thr_eff = res_eff; |
1775 | best_nthr_oc = nthr_oc; |
1776 | best_ocb = ocb; |
1777 | best_osb = osb; |
1778 | best_icb = icb; |
1779 | } |
1780 | }; |
1781 | |
1782 | auto explore_cfg = [&](dim_t nthr_oc, dim_t ocb, dim_t osb) { |
1783 | try_cfg(nthr_oc, ocb, osb); |
1784 | // few combinations to try, as the eff is better when ocb is |
1785 | // multiple of 8 and osb is multiple of 48 or min_os_block. |
1786 | try_cfg(nthr_oc, rnd_dn(ocb, 8), rnd_dn(osb, 48)); |
1787 | try_cfg(nthr_oc, rnd_up(ocb, 8), rnd_dn(osb, 48)); |
1788 | try_cfg(nthr_oc, rnd_up(ocb, 8), rnd_up(osb, min_os_block)); |
1789 | try_cfg(nthr_oc, rnd_up(ocb, 8), rnd_up(osb, 48)); |
1790 | }; |
1791 | |
1792 | for (dim_t nthr_oc = 1; nthr_oc <= max_threads; ++nthr_oc) { |
1793 | const dim_t max_oc_per_thr = div_up(jcp.oc, nthr_oc); |
1794 | dim_t max_os_per_thr |
1795 | = div_up(spatial, max_threads / nthr_oc); |
1796 | dim_t ocb {1}, osb {1}, icb {1}; |
1797 | if (jcp.im2col_sz) { |
1798 | try_cfg(nthr_oc, max_oc_per_thr, max_os_per_thr); |
1799 | if ((best_ocb == max_oc_per_thr) |
1800 | && (best_osb == max_os_per_thr) |
1801 | && (best_icb == jcp.ic)) { |
1802 | // best case scenario |
1803 | continue; |
1804 | } |
1805 | |
1806 | /* |
1807 | memory eq from calc_max_icb(): |
1808 | max_icb = (L2 - block_out_size) |
1809 | / (inp_row_size + col_row_size |
1810 | + wei_col_size); |
1811 | icb*sh*sw*osb + icb*jcp.ks*osb + |
1812 | jcp.ks*max_oc_per_thr*icb + osb *ocb = L2 |
1813 | |
1814 | a_k*icb*osb + b_k*icb + osb*ocb = L2 |
1815 | We would like to maximize icb*osb*ocb (FMA). |
1816 | |
1817 | Unfortunately, above eq and constraint doesn't have |
1818 | a single solution. So, based on experiments we try |
1819 | few scenarios. |
1820 | 1. icb = jcp.ic |
1821 | 2. Solving the constraint eq we get |
1822 | osb = (L2 - 2*b_k*icb)/(2*a_k*icb) >= min_oc_block |
1823 | => icb <= (L2)/(2* min_oc_block * a_k + 2 * b_k) |
1824 | 3. Maximize channel compute: |
1825 | ocb = max_oc_per_thr; |
1826 | icb = jcp.ic; |
1827 | */ |
1828 | dim_t a_k = sh * sw + jcp.ks; |
1829 | dim_t b_k = jcp.ks * max_oc_per_thr; |
1830 | |
1831 | // Note 1: |
1832 | icb = jcp.ic; |
1833 | ocb = utils::saturate(min_oc_block, max_oc_per_thr, |
1834 | (L2 - a_k * icb * min_os_block - b_k * icb) |
1835 | / min_os_block); |
1836 | osb = utils::saturate(min_os_block, max_os_per_thr, |
1837 | (L2 - b_k * icb) / (a_k * icb + ocb)); |
1838 | explore_cfg(nthr_oc, ocb, osb); |
1839 | |
1840 | // Note 2: |
1841 | const dim_t icb_max = nstl::max(dim_t(1), |
1842 | L2 / (2 * min_oc_block * a_k + 2 * b_k)); |
1843 | if (icb_max < jcp.ic) { |
1844 | // adjust icb, such that it is evenly distributed. |
1845 | icb = jcp.ic |
1846 | / nstl::max(dim_t(1), jcp.ic / icb_max); |
1847 | osb = nstl::max(dim_t(1), |
1848 | (L2 - 2 * b_k * icb) / (2 * icb * a_k)); |
1849 | ocb = L2 / 2 / osb; |
1850 | |
1851 | if (ocb > max_oc_per_thr) { |
1852 | ocb = max_oc_per_thr; |
1853 | // reduce mem eq by making ocb constant. we get |
1854 | osb = utils::saturate(min_os_block, |
1855 | max_os_per_thr, |
1856 | (L2 - b_k * icb) / (a_k * icb + ocb)); |
1857 | } else if (osb > max_os_per_thr) { |
1858 | // reduce mem eq by making osb constant. we get |
1859 | osb = max_os_per_thr; |
1860 | ocb = utils::saturate(min_oc_block, |
1861 | max_oc_per_thr, |
1862 | (L2 - a_k * icb * osb - b_k * icb) |
1863 | / (osb)); |
1864 | } |
1865 | |
1866 | explore_cfg(nthr_oc, ocb, osb); |
1867 | } |
1868 | |
1869 | // Note 3: |
1870 | ocb = max_oc_per_thr; |
1871 | icb = jcp.ic; |
1872 | osb = nstl::max(min_os_block, |
1873 | rnd_dn((L2 - b_k * icb) / (a_k * icb + ocb), |
1874 | min_os_block)); |
1875 | explore_cfg(nthr_oc, ocb, osb); |
1876 | |
1877 | } else { |
1878 | // from calc_max_icb, memory eq is independent of ocb. |
1879 | // So, set it to maximum. |
1880 | ocb = max_oc_per_thr; |
1881 | osb = (L2 - jcp.ks * jcp.ic) / (sh * sw * jcp.ic); |
1882 | explore_cfg(nthr_oc, ocb, osb); |
1883 | } |
1884 | } |
1885 | jcp.outer_threading = true; |
1886 | jcp.nthr_oc = best_nthr_oc; |
1887 | jcp.oc_block = best_ocb; |
1888 | jcp.os_block = best_osb; |
1889 | jcp.ic_block = best_icb; |
1890 | |
1891 | // TODO: define loop order |
1892 | // if im2col then gemm_loop_rlb and gemm_loop_lrb looks |
1893 | // preferable otherwise other loop orders are possible |
1894 | jcp.loop_order = gemm_loop_rlb; |
1895 | } else { |
1896 | const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od; |
1897 | const float outer_thr_eff = (float)outer_work_amount |
1898 | / rnd_up(outer_work_amount, max_threads); |
1899 | const size_t inner_work_amount |
1900 | = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w); |
1901 | const float inner_thr_eff = (float)inner_work_amount |
1902 | / rnd_up(inner_work_amount, max_threads); |
1903 | jcp.outer_threading = jcp.os / max_threads < 512 |
1904 | && IMPLICATION( |
1905 | jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2) |
1906 | && (outer_thr_eff / inner_thr_eff >= 1.f |
1907 | || (jcp.os * jcp.ic * jcp.oc) / max_threads |
1908 | < gemm_thrld); |
1909 | } |
1910 | jcp.os_nb_block = div_up(jcp.os, jcp.os_block); |
1911 | |
1912 | // BF16: other loops should be explored for potential |
1913 | // performance speedup, but BF16-dst post-processing implementation |
1914 | // would require enabling this support. |
1915 | if (is_bf16_conv) jcp.loop_order = gemm_loop_lbr; |
1916 | |
1917 | if (jcp.im2col_sz) |
1918 | jcp.im2col_sz = (ptrdiff_t)jcp.ic_block * jcp.ks * jcp.os_block; |
1919 | } else if (jcp.is_nspc && is_bwd_d) { |
1920 | jcp.im2col_sz |
1921 | = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih, |
1922 | jcp.od == jcp.id, jcp.stride_w == 1, |
1923 | jcp.stride_h == 1, jcp.stride_d == 1, jcp.ks == 1, |
1924 | !jcp.signed_input) |
1925 | ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os * jcp.od |
1926 | : 0; |
1927 | |
1928 | bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; |
1929 | const size_t outer_work = jcp.ngroups * jcp.mb; |
1930 | const float outer_thr_eff |
1931 | = (float)outer_work / rnd_up(outer_work, max_threads); |
1932 | const size_t inner_work |
1933 | = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); |
1934 | const float inner_thr_eff |
1935 | = (float)inner_work / rnd_up(inner_work, max_threads); |
1936 | jcp.outer_threading = !is_3d |
1937 | && (is_depthwise |
1938 | || (jcp.is / max_threads < 64 && jcp.mb != 1)) |
1939 | && (outer_thr_eff / inner_thr_eff >= 1.f |
1940 | || (static_cast<size_t>(jcp.is) * jcp.ic * jcp.oc) |
1941 | / max_threads |
1942 | < gemm_thrld); |
1943 | |
1944 | jcp.nthr = jcp.outer_threading ? max_threads : 1; |
1945 | scratchpad.book<float>(key_conv_gemm_col, jcp.nthr * jcp.im2col_sz); |
1946 | if (jcp.ngroups > 1 || is_bf16_conv) |
1947 | scratchpad.book<float>(key_conv_gemm_acc, |
1948 | jcp.nthr * static_cast<size_t>(jcp.is) * jcp.id |
1949 | * jcp.ic); |
1950 | } else if (!jcp.is_nspc && is_bwd_d) { |
1951 | const size_t outer_work_amount = jcp.ngroups * jcp.mb; |
1952 | const float outer_thr_eff = (float)outer_work_amount |
1953 | / rnd_up(outer_work_amount, max_threads); |
1954 | const size_t inner_work |
1955 | = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); |
1956 | const float inner_thr_eff |
1957 | = (float)inner_work / rnd_up(inner_work, max_threads); |
1958 | jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64) |
1959 | && (jcp.mb != 1 || jcp.ngroups > 2) |
1960 | && (outer_thr_eff / inner_thr_eff >= 1.f |
1961 | || (jcp.is * jcp.ic * jcp.oc) / max_threads |
1962 | < gemm_thrld); |
1963 | } else if (jcp.is_nspc && is_bwd_w) { |
1964 | jcp.im2col_sz |
1965 | = !everyone_is(true, jcp.ow == jcp.iw, jcp.oh == jcp.ih, |
1966 | jcp.od == jcp.id, jcp.stride_w == 1, |
1967 | jcp.stride_h == 1, jcp.stride_d == 1, jcp.ks == 1, |
1968 | !jcp.signed_input) |
1969 | ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os |
1970 | : 0; |
1971 | const size_t gemm_col_datatype_size |
1972 | = is_bf16_conv ? sizeof(bfloat16_t) : sizeof(float); |
1973 | |
1974 | // Potential scratchpad memory requirement when outer threading is |
1975 | // enabled during f32/bf16 BWD_W nspc convolution |
1976 | size_t thr_mem_estimate = max_threads |
1977 | * (gemm_col_datatype_size * jcp.im2col_sz |
1978 | + gemm_col_datatype_size * jcp.id * jcp.is * jcp.ic |
1979 | + sizeof(float) * weights_d.size()); |
1980 | if (is_bf16_conv) { |
1981 | thr_mem_estimate += sizeof(float) * weights_d.size(); |
1982 | if (jcp.with_bias |
1983 | && one_of(data_type::bf16, cd.diff_bias_desc.data_type, |
1984 | cd.bias_desc.data_type)) |
1985 | thr_mem_estimate += sizeof(float) * jcp.ngroups * jcp.oc; |
1986 | } |
1987 | const bool outer_threading_mem_ok |
1988 | = thr_mem_estimate < scratchpad_limit; |
1989 | |
1990 | jcp.outer_threading = outer_threading_mem_ok |
1991 | && jcp.os / max_threads < 256 |
1992 | && (jcp.mb != 1 || jcp.ngroups > 2); |
1993 | jcp.nthr = jcp.outer_threading ? max_threads : 1; |
1994 | |
1995 | scratchpad.book(key_conv_gemm_col, jcp.nthr * jcp.im2col_sz, |
1996 | gemm_col_datatype_size); |
1997 | |
1998 | jcp.need_wei_reduction = jcp.mb != 1 && jcp.nthr != 1; |
1999 | scratchpad.book<float>( |
2000 | key_conv_wei_reduction, jcp.nthr * weights_d.size()); |
2001 | scratchpad.book(key_conv_gemm_imtr, |
2002 | static_cast<size_t>(jcp.nthr) * jcp.id * jcp.is * jcp.ic, |
2003 | gemm_col_datatype_size); |
2004 | if (is_bf16_conv) { |
2005 | size_t conv_acc_buffer_size = weights_d.size(); |
2006 | scratchpad.book<float>( |
2007 | key_conv_int_dat_in_acc_dt, conv_acc_buffer_size); |
2008 | } |
2009 | if ((is_bf16_conv) && jcp.with_bias |
2010 | && one_of(data_type::bf16, cd.diff_bias_desc.data_type, |
2011 | cd.bias_desc.data_type)) |
2012 | scratchpad.book<float>( |
2013 | key_conv_bias_bf16_convert_wsp, jcp.ngroups * jcp.oc); |
2014 | } else if (!jcp.is_nspc && is_bwd_w) { |
2015 | // Potential scratchpad memory requirement when outer threading is |
2016 | // enabled during f32/bf16 BWD_W blocked convolution |
2017 | size_t thr_mem_estimate |
2018 | = sizeof(float) * max_threads * weights_d.size(); |
2019 | if (is_bf16_conv) { |
2020 | thr_mem_estimate += sizeof(float) * weights_d.size(); |
2021 | if (jcp.with_bias |
2022 | && one_of(data_type::bf16, cd.diff_bias_desc.data_type, |
2023 | cd.bias_desc.data_type)) |
2024 | thr_mem_estimate += sizeof(float) * jcp.ngroups * jcp.oc; |
2025 | } |
2026 | const size_t gemm_col_datatype_size |
2027 | = is_bf16_conv ? sizeof(bfloat16_t) : sizeof(float); |
2028 | // Minimum memory requirement as os_block >= simd_w |
2029 | thr_mem_estimate += gemm_col_datatype_size * max_threads * jcp.ic |
2030 | * jcp.ks * simd_w; |
2031 | |
2032 | const bool outer_threading_mem_ok |
2033 | = thr_mem_estimate < scratchpad_limit; |
2034 | jcp.outer_threading = outer_threading_mem_ok |
2035 | && jcp.os / max_threads < 256 |
2036 | && (jcp.mb != 1 || jcp.ngroups > 2); |
2037 | } |
2038 | |
2039 | if (!jcp.is_nspc) { |
2040 | jcp.nthr = jcp.outer_threading ? max_threads : 1; |
2041 | const int sizeof_cacheline_float = 16; |
2042 | if (is_bwd_w) { |
2043 | jcp.need_wei_reduction = jcp.mb != 1 && jcp.nthr != 1; |
2044 | scratchpad.book<float>( |
2045 | key_conv_wei_reduction, jcp.nthr * weights_d.size()); |
2046 | } |
2047 | |
2048 | if (is_bf16_conv) { |
2049 | size_t conv_acc_buffer_size = 0; |
2050 | if (is_fwd) |
2051 | conv_acc_buffer_size = jcp.nthr |
2052 | * rnd_up(jcp.oc_block * jcp.os_block, |
2053 | sizeof_cacheline_float); |
2054 | else if (is_bwd_d) |
2055 | conv_acc_buffer_size = jcp.nthr |
2056 | * rnd_up(jcp.ic * jcp.ih * jcp.iw * jcp.id, |
2057 | sizeof_cacheline_float); |
2058 | else if (is_bwd_w) |
2059 | conv_acc_buffer_size = weights_d.size(); |
2060 | scratchpad.book<float>( |
2061 | key_conv_int_dat_in_acc_dt, conv_acc_buffer_size); |
2062 | if ((is_fwd || is_bwd_w) && jcp.with_bias |
2063 | && one_of(data_type::bf16, cd.diff_bias_desc.data_type, |
2064 | cd.bias_desc.data_type)) |
2065 | scratchpad.book<float>(key_conv_bias_bf16_convert_wsp, |
2066 | jcp.ngroups * jcp.oc); |
2067 | } |
2068 | |
2069 | const size_t gemm_col_datatype_size = is_bf16_conv && !is_bwd_d |
2070 | ? sizeof(bfloat16_t) |
2071 | : sizeof(float); |
2072 | size_t gemm_col_memory_sz = jcp.nthr * jcp.im2col_sz; |
2073 | |
2074 | if (is_bwd_d || is_bwd_w) { |
2075 | // check available memory |
2076 | if (scratchpad_limit < scratchpad.size()) |
2077 | return status::unimplemented; |
2078 | const size_t available_mem |
2079 | = scratchpad_limit - scratchpad.size(); |
2080 | if (available_mem |
2081 | < gemm_col_memory_sz * gemm_col_datatype_size) { |
2082 | // Required memory in this scenario overflows the |
2083 | // available memory due to the large dimensions. |
2084 | const int min_os_block = simd_w; |
2085 | const int max_os_block = (int)available_mem |
2086 | / ((int)gemm_col_datatype_size * jcp.nthr |
2087 | * (jcp.im2col_sz / jcp.os)); |
                    // Choose an arbitrarily small coefficient to reduce
                    // spatial dimensions.
2090 | // TODO: better heuristic to determine os_block based |
2091 | // on cache efficiency |
2092 | float _coef = is_bwd_w ? 0.05 : 0.1; |
2093 | jcp.os_block = nstl::max( |
2094 | min_os_block, (int)(max_os_block * _coef)); |
2095 | jcp.os_nb_block = div_up(jcp.os, jcp.os_block); |
2096 | jcp.im2col_sz = (ptrdiff_t)jcp.ic * jcp.ks * jcp.os_block; |
2097 | gemm_col_memory_sz = jcp.nthr * jcp.im2col_sz; |
2098 | } |
2099 | } |
2100 | scratchpad.book(key_conv_gemm_col, gemm_col_memory_sz, |
2101 | gemm_col_datatype_size); |
2102 | } |
2103 | } |
2104 | |
2105 | jcp.bias_data_type = cd.bias_desc.data_type; |
2106 | jcp.dst_data_type = dst_md.data_type; |
2107 | jcp.sum_data_type = jcp.post_ops.get_sum_dt(jcp.dst_data_type); |
2108 | jcp.dst_os_stride = dst_d.is_blocking_desc() |
2109 | ? dst_d.blocking_desc().strides[ndims - 1] |
2110 | : 0; |
2111 | jcp.scale_idx_mult = (attr.scales_.get(DNNL_ARG_WEIGHTS).mask_ |
2112 | == (1 << (int)with_groups)); |
2113 | jcp.with_dst_scale = !attr.scales_.get(DNNL_ARG_DST).has_default_values(); |
2114 | book_precomputed_scales(scratchpad, attr.scales_, jcp.ngroups * jcp.oc); |
2115 | |
2116 | if (jcp.zp.src_exists) { |
2117 | const auto size = zp_src_comp_size + zp_src_pad_comp_size; |
2118 | if (size) scratchpad.book<int32_t>(key_conv_gemm_zp_src_comp, size); |
2119 | } |
2120 | |
2121 | if (scratchpad.size() > scratchpad_limit) return status::unimplemented; |
2122 | return status::success; |
2123 | } |
2124 | |
void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
        int &nthr_g, int &ithr_mb, int &nthr_mb) {
    // Partition `nthr` threads over a (groups x minibatch) 2D grid and map
    // thread `ithr` to its (group, minibatch) coordinates. Groups are split
    // first; remaining parallelism is given to the minibatch dimension.
    nthr_g = ngroups < nthr ? ngroups : nthr;
    const int mb_slots = nthr / nthr_g;
    nthr_mb = mb < mb_slots ? mb : mb_slots;

    const int group_idx = ithr / nthr_mb;
    if (group_idx < ngroups) {
        ithr_g = group_idx;
        ithr_mb = ithr - group_idx * nthr_mb;
    } else {
        // Thread falls outside the grid: no work assigned.
        ithr_g = -1;
        ithr_mb = -1;
    }
}
2136 | |
2137 | void bwd_weights_reduction_par_ncsp(int ithr, int nthr, |
2138 | const conv_gemm_conf_t &jcp, const float *weights_reduce_ws, |
2139 | float *weights) { |
2140 | const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; |
2141 | |
2142 | size_t weights_start {0}, weights_end {0}; |
2143 | balance211(weights_g_size, nthr, ithr, weights_start, weights_end); |
2144 | |
2145 | for (int i = 0; i < nthr; ++i) { |
2146 | const float *ws_i = weights_reduce_ws + i * weights_g_size; |
2147 | for (size_t s = weights_start; s < weights_end; ++s) |
2148 | weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s]; |
2149 | } |
2150 | } |
2151 | |
2152 | void bwd_weights_reduction_par_nspc(int ithr, int nthr, size_t g_start, |
2153 | size_t g_end, const conv_gemm_conf_t &jcp, |
2154 | const float *weights_reduce_base, float *diff_weights) { |
2155 | const dim_t weights_g_size = jcp.oc; |
2156 | dim_t weights_start {0}, weights_end {0}; |
2157 | balance211(jcp.ks * jcp.ic, nthr, ithr, weights_start, weights_end); |
2158 | |
2159 | // Threads divide work w.r.t. min-batch and groups, therefore |
2160 | // - weights_reduce_base format: spatial-input_channels-output_channels |
2161 | // - diff_weights format: spatial-input_channels-groups-output_channels |
2162 | for (auto tidx = 0; tidx < nthr; ++tidx) { |
2163 | const float *ws_base |
2164 | = weights_reduce_base + tidx * weights_g_size * jcp.ks * jcp.ic; |
2165 | for_(auto w = weights_start; w < weights_end; ++w) |
2166 | for (auto g = g_start; g < g_end; ++g) { |
2167 | float *__restrict dwei_ptr |
2168 | = diff_weights + (w * jcp.ngroups + g) * jcp.oc; |
2169 | const float *__restrict ws_ptr = ws_base + w * jcp.oc; |
2170 | if (tidx == 0) { |
2171 | PRAGMA_OMP_SIMD() |
2172 | for (auto oc = 0; oc < jcp.oc; ++oc) { |
2173 | dwei_ptr[oc] = ws_ptr[oc]; |
2174 | } |
2175 | } else { |
2176 | PRAGMA_OMP_SIMD() |
2177 | for (auto oc = 0; oc < jcp.oc; ++oc) { |
2178 | dwei_ptr[oc] += ws_ptr[oc]; |
2179 | } |
2180 | } |
2181 | } |
2182 | } |
2183 | } |
2184 | |
2185 | bool padding_exists(const conv_gemm_conf_t &jcp) noexcept { |
2186 | return jcp.l_pad || jcp.t_pad || jcp.f_pad || jcp.e_pad || jcp.b_pad |
2187 | || jcp.r_pad; |
2188 | } |
2189 | |
2190 | } // namespace jit_gemm_convolution_utils |
2191 | } // namespace cpu |
2192 | } // namespace impl |
2193 | } // namespace dnnl |
2194 | |