1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/dnnl_thread.hpp" |
18 | #include "common/memory_tracking.hpp" |
19 | #include "common/utils.hpp" |
20 | |
21 | #include "cpu/cpu_primitive.hpp" |
22 | #include "cpu/scale_utils.hpp" |
23 | |
24 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
25 | #include "cpu/x64/jit_brdgmm_dw_conv.hpp" |
#include "cpu/x64/cpu_isa_traits.hpp"
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
33 | using namespace dnnl::impl::memory_tracking::names; |
34 | using namespace dnnl::impl::status; |
35 | using namespace dnnl::impl::utils; |
36 | |
37 | using namespace nstl; |
38 | using namespace data_type; |
39 | |
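// Sets md to tag_value when the user left the format as "any" (if the
// implementation is allowed to pick it); otherwise checks that the
// user-provided format matches tag_value.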
40 | inline status_t init_tag(memory_desc_t &md, const memory_desc_wrapper &mdw, |
41 | const format_tag_t tag_value, bool any_eligible) { |
42 | |
43 | format_tag_t tag; |
44 | if (mdw.format_kind() == format_kind::any) { |
45 | if (any_eligible) { |
46 | CHECK(memory_desc_init_by_tag(md, tag_value)); |
47 | tag = tag_value; |
48 | } else { |
49 | tag = format_tag::undef; |
50 | } |
51 | } else { |
52 | tag = mdw.matches_one_of_tag(tag_value); |
53 | } |
54 | |
55 | if (tag != tag_value) return status::unimplemented; |
56 | |
57 | return status::success; |
58 | } |
59 | |
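// Accepts sum, eltwise and binary post-ops; binary arguments may use the
// per-oc, scalar or no-broadcast strategies.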
60 | bool post_ops_ok(jit_brdgmm_conv_conf_t &jcp, const primitive_attr_t &attr, |
61 | const memory_desc_wrapper &dst_d) { |
62 | using namespace injector; |
63 | |
64 | const auto &post_ops = attr.post_ops_; |
65 | |
66 | return injector::post_ops_ok(post_ops_ok_args_t(get_max_cpu_isa(), |
67 | {sum, eltwise, binary}, post_ops, &dst_d, |
68 | false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/, |
69 | false /*sum_requires_zp_zero*/, |
70 | {broadcasting_strategy_t::per_oc, broadcasting_strategy_t::scalar, |
71 | broadcasting_strategy_t::no_broadcast})); |
72 | } |
73 | |
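// Returns the first ISA in the preference list (most preferred first) that
// this machine supports for the given data type configuration.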
74 | cpu_isa_t get_supported_isa( |
75 | bool is_f32, bool is_int8, bool is_bf16, bool is_f16) { |
76 | std::vector<cpu_isa_t> isa_list; |
77 | if (is_f32) { |
        // Note: avx2 support is temporarily disabled pending a performance
        // study.
79 | isa_list = {avx512_core /*, avx2*/}; |
80 | } else if (is_int8) { |
81 | isa_list = {avx512_core_vnni}; |
82 | } else if (is_bf16) { |
83 | isa_list = {avx512_core_bf16, avx2_vnni_2}; |
84 | } else if (is_f16) { |
85 | isa_list = {avx512_core_fp16, avx2_vnni_2}; |
86 | } |
87 | |
88 | for (auto isa : isa_list) { |
89 | if (mayiuse(isa)) return isa; |
90 | } |
91 | return isa_undef; |
92 | } |
93 | |
94 | status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) { |
95 | |
96 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
97 | |
98 | const auto &cd = *desc(); |
99 | const auto src_type = cd.src_desc.data_type; |
100 | const auto wei_type = cd.weights_desc.data_type; |
101 | const auto bia_type = cd.bias_desc.data_type; |
102 | const auto dst_type = cd.dst_desc.data_type; |
103 | |
104 | // TODO: support s8s8 conv |
105 | const bool is_f32 = everyone_is(f32, src_type, wei_type, dst_type); |
106 | const bool is_int8 = one_of(src_type, u8) && wei_type == s8 |
107 | && one_of(dst_type, s32, f32, u8, s8, bf16); |
108 | const bool is_bf16 = everyone_is(bf16, src_type, wei_type) |
109 | && one_of(dst_type, bf16, f32); |
110 | const bool is_f16 = everyone_is(f16, src_type, wei_type) |
111 | && one_of(dst_type, f16, f32); |
112 | const cpu_isa_t isa = get_supported_isa(is_f32, is_int8, is_bf16, is_f16); |
113 | |
114 | auto skip_mask = skip_mask_t::post_ops; |
115 | if (is_int8) skip_mask |= skip_mask_t::scales_runtime; |
116 | |
117 | bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct) |
118 | && one_of(true, is_f32, is_int8, is_bf16, is_f16) |
119 | && (isa != isa_undef) && mayiuse(isa) |
120 | && IMPLICATION(is_int8, |
121 | one_of(bia_type, data_type::undef, f32, s32, s8, u8)) |
122 | && IMPLICATION(!is_int8, |
123 | one_of(bia_type, data_type::undef, src_type, dst_type)) |
124 | && attr()->has_default_values(skip_mask) && !has_zero_dim_memory(); |
125 | if (!ok) return status::unimplemented; |
126 | |
127 | auto &jcp = jcp_; |
128 | |
129 | const memory_desc_wrapper src_d(&src_md_); |
130 | const memory_desc_wrapper weights_d(&weights_md_); |
131 | const memory_desc_wrapper dst_d(&dst_md_); |
132 | const memory_desc_wrapper bias_d(&bias_md_); |
133 | |
134 | const int ndims = src_d.ndims(); |
135 | // Currently this kernel only supports 2D convolutions. |
136 | if (ndims != 4) return status::unimplemented; |
137 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
138 | if (!with_groups) return status::unimplemented; |
139 | // dilations are not supported |
140 | if (cd.dilates[0] != 0 || cd.dilates[1] != 0) return status::unimplemented; |
141 | |
142 | jcp = zero<decltype(jcp)>(); |
143 | jcp.ngroups = weights_d.dims()[0]; |
144 | jcp.mb = src_d.dims()[0]; |
145 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
146 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
147 | jcp.ih = src_d.dims()[2]; |
148 | jcp.iw = src_d.dims()[3]; |
149 | jcp.oh = dst_d.dims()[2]; |
150 | jcp.ow = dst_d.dims()[3]; |
151 | jcp.kh = weights_d.dims()[3]; |
152 | jcp.kw = weights_d.dims()[4]; |
153 | jcp.t_pad = cd.padding[0][0]; |
154 | jcp.l_pad = cd.padding[0][1]; |
155 | jcp.stride_h = cd.strides[0]; |
156 | jcp.stride_w = cd.strides[1]; |
157 | jcp.b_pad = calculate_end_padding( |
158 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, jcp.kh); |
159 | jcp.r_pad = calculate_end_padding( |
160 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, jcp.kw); |
161 | jcp.src_dt = cd.src_desc.data_type; |
162 | jcp.dst_dt = cd.dst_desc.data_type; |
163 | jcp.wei_dt = cd.weights_desc.data_type; |
164 | jcp.with_bias = with_bias(); |
165 | jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; |
166 | |
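    // This implementation covers depthwise convolutions only: exactly one
    // input and one output channel per group.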
167 | if (!(everyone_is(1, jcp.ic, jcp.oc))) return status::unimplemented; |
168 | |
169 | const auto def_data_tag = format_tag::nhwc; |
170 | const bool any_eligible = (cd.prop_kind == prop_kind::forward_inference |
171 | || is_int8 || is_f16 || (isa == avx2_vnni_2 && is_bf16)); |
172 | CHECK(init_tag(src_md_, src_d, def_data_tag, any_eligible)); |
173 | CHECK(init_tag(dst_md_, dst_d, def_data_tag, any_eligible)); |
174 | |
175 | if (jcp.with_bias) { |
176 | if (bias_d.format_kind() == format_kind::any) |
177 | CHECK(memory_desc_init_by_tag(bias_md_, format_tag::x)); |
178 | } |
179 | |
180 | CHECK(attr_.set_default_formats(dst_md())); |
181 | if (!post_ops_ok(jcp, *attr(), dst_d)) return status::unimplemented; |
182 | jcp.with_post_ops = attr()->post_ops_.len() > 0; |
183 | |
184 | jcp.isa = isa; |
185 | jcp.nthr = dnnl_get_max_threads(); |
186 | jcp.src_dsz = types::data_type_size(jcp.src_dt); |
187 | jcp.wei_dsz = types::data_type_size(jcp.wei_dt); |
188 | jcp.bia_dsz |
189 | = jcp.with_bias ? types::data_type_size(cd.bias_desc.data_type) : 0; |
190 | jcp.dst_dsz = types::data_type_size(jcp.dst_dt); |
191 | |
192 | const auto &src_scales = attr_.scales_.get(DNNL_ARG_SRC); |
193 | const auto &wei_scales = attr_.scales_.get(DNNL_ARG_WEIGHTS); |
194 | const auto &dst_scales = attr_.scales_.get(DNNL_ARG_DST); |
195 | jcp.with_scale = !src_scales.has_default_values() |
196 | || !wei_scales.has_default_values(); |
197 | const int wei_mask_per_oc = 1 << (int)with_groups; |
198 | jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc; |
199 | |
    // Only common (mask = 0) and per-oc-channel weights scales are
    // supported; src scales must be common and dst scales default.
201 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc) |
202 | && src_scales.mask_ == 0 && dst_scales.has_default_values(); |
203 | if (!scales_ok) return status::unimplemented; |
204 | |
    // The brgemm_strd batch kind is only feasible for 1D kernels (i.e.,
    // kernel height is one) and when there is no channel tail, since the
    // matrix_B strides must be known up front and we cannot always predict
    // whether the channel blocking will be 8 or 16.
208 | if (jcp.kh == 1 && jcp.ngroups % 16 == 0) { |
209 | jcp.batch_kind = brgemm_strd; |
210 | } else if ((jcp.mb * jcp.oh) % jcp.nthr != 0) { |
211 | jcp.batch_kind = brgemm_offs; |
212 | } else { |
213 | jcp.batch_kind = brgemm_addr; |
214 | } |
215 | |
    // Round the per-thread batch buffer up to a 4K boundary to avoid
    // concurrent cache access from different threads.
217 | size_t sc_size = sizeof(brgemm_batch_element_t); |
218 | jcp.adjusted_batch_size |
219 | = div_up(rnd_up(jcp.kh * jcp.kw * sc_size, 4096), sc_size); |
220 | CHECK(init_brdgmm_conf()); |
221 | CHECK(init_scratchpad()); |
222 | if (jcp.with_scale) { |
223 | auto scratchpad = scratchpad_registry().registrar(); |
224 | book_precomputed_scales(scratchpad, attr_.scales_, OC()); |
225 | } |
226 | |
227 | return status::success; |
228 | } |
229 | |
230 | status_t brdgmm_dw_convolution_fwd_t::pd_t::init_brdgmm_conf() { |
231 | |
232 | auto &jcp = jcp_; |
233 | |
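    // Initializes the brgemm descriptor at bcps_[idx] for an M x N tile
    // (M: output-width points, N: channels) and advances idx on success.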
234 | auto init_bcp = [&](int &idx, const int M, const int N) { |
235 | const float alpha = 1.f; |
236 | const float beta = 0.f; |
237 | const int LDA = jcp.ngroups * jcp.stride_w; |
238 | const int LDC = jcp.ngroups; |
239 | const int LDD = jcp.ngroups; |
240 | |
241 | brgemm_attr_t brg_attr; |
242 | brg_attr.max_bs = jcp.kw * jcp.kh; |
243 | brg_attr.max_top_vpad = nstl::max(0, jcp.l_pad); |
244 | brg_attr.max_bottom_vpad = nstl::max(0, jcp.r_pad); |
245 | |
246 | // only needed for strd batch_kind |
247 | const brgemm_strides_t strides |
248 | = {static_cast<dim_t>(jcp.src_dsz) * jcp.ngroups, |
249 | static_cast<dim_t>(jcp.wei_dsz) * jcp.ngroups}; |
250 | |
251 | auto &bcp = bcps_[idx]; |
252 | CHECK(brdgmm_desc_init(&bcp, jcp.isa, jcp.batch_kind, jcp.src_dt, |
253 | jcp.wei_dt, false /*transA*/, brgemm_row_major, alpha, beta, |
254 | LDA, LDC, M, N, &strides)); |
255 | CHECK(brgemm_desc_set_attr(&bcp, brg_attr)); |
256 | CHECK(brgemm_desc_set_postops(&bcp, attr(), dst_md(), LDD, jcp.bia_dt)); |
257 | ++idx; |
258 | return status::success; |
259 | }; |
260 | |
261 | bcps_.resize(1); |
262 | jcp.ow_block = jcp.ow; |
263 | jcp.nb_ow = 1; |
264 | jcp.nb_ch_blocking = jcp.ngroups; |
265 | jcp.chb_tail = 0; |
266 | int ker_idx = 0; |
267 | CHECK(init_bcp(ker_idx, jcp.ow, jcp.ngroups)); // default full row kernel. |
268 | |
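    // The channel blocking is dictated by the generated descriptor (16 with
    // 512-bit registers, 8 with 256-bit ones) and determines the weights
    // format tag below.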
269 | const auto &bcp_0 = bcps_[0]; |
270 | jcp.ch_block = bcp_0.ld_block; |
271 | jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); |
272 | |
273 | const auto wei_tag |
274 | = jcp.ch_block == 16 ? format_tag::hwioG16g : format_tag::hwioG8g; |
275 | const memory_desc_wrapper weights_d(&weights_md_); |
276 | CHECK(init_tag(weights_md_, weights_d, wei_tag, true)); |
277 | |
278 | if ((jcp.mb * jcp.oh) % jcp.nthr != 0) { |
279 | // determine ow_block |
280 | { |
281 | const size_t work_amount = jcp.mb * jcp.oh * jcp.ow; |
282 | if (work_amount % jcp.nthr == 0) { |
283 | const size_t work_per_thr = div_up(work_amount, jcp.nthr); |
284 | const size_t ow_tail_block |
285 | = (work_per_thr / jcp.nb_ch) % jcp.ow; |
286 | if (ow_tail_block && (jcp.ow % ow_tail_block == 0)) |
287 | jcp.ow_block = ow_tail_block; |
288 | else { |
289 | jcp.ow_block = jcp.ow; |
290 | } |
291 | } else { |
292 | const int max_ow_block = is_superset(jcp.isa, avx512_core) |
293 | ? 6 |
294 | : bcp_0.bd_block2 /*TODO: Tune for avx2*/; |
295 | jcp.ow_block = nstl::min(max_ow_block, jcp.ow); |
296 | } |
297 | jcp.ow_tail = jcp.ow % jcp.ow_block; |
298 | } |
299 | jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); |
300 | |
        // determine nb_ch_blocking
302 | { |
303 | const size_t work_amount = jcp.mb * jcp.nb_ch * jcp.oh * jcp.nb_ow; |
304 | if (work_amount % jcp.nthr == 0) { |
305 | const size_t work_per_thr = div_up(work_amount, jcp.nthr); |
306 | const size_t ch_tail_block = work_per_thr % jcp.nb_ch; |
307 | if (ch_tail_block && (jcp.nb_ch % ch_tail_block == 0)) |
308 | jcp.nb_ch_blocking = ch_tail_block * jcp.ch_block; |
309 | else |
310 | jcp.nb_ch_blocking = jcp.ngroups; |
311 | } else { |
312 | const int max_ch_block2 = is_superset(jcp.isa, avx512_core) |
313 | ? 4 |
314 | : bcp_0.ld_block2 /*TODO: Tune for avx2*/; |
315 | jcp.nb_ch_blocking |
316 | = nstl::min(max_ch_block2 * jcp.ch_block, jcp.ngroups); |
317 | } |
318 | jcp.chb_tail = jcp.ngroups % jcp.nb_ch_blocking; |
319 | } |
320 | |
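        // Create kernels processing ow_block * 2^i output points so that
        // any remaining span of ow blocks in a row can be covered by
        // O(log2(nb_ow)) kernels, e.g. a span of 5 blocks runs as 4 + 1.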
321 | const int n_owb_kernels = std::ceil(log2(jcp.nb_ow)); |
322 | const int num_kernels = 1 /*full ow*/ + n_owb_kernels |
323 | + (jcp.chb_tail != 0) + (jcp.nb_ch_blocking != jcp.ngroups) |
324 | + (jcp.ow_tail != 0); |
325 | bcps_.resize(num_kernels); |
326 | |
327 | for (int i = 0; i < n_owb_kernels; ++i) { |
328 | CHECK(init_bcp(ker_idx, jcp.ow_block * (1 << i), jcp.ngroups)); |
329 | } |
330 | |
331 | if (jcp.chb_tail) { |
332 | jcp.chb_tail_idx = ker_idx; |
333 | CHECK(init_bcp(ker_idx, jcp.ow_block, jcp.chb_tail)); |
334 | } |
335 | |
336 | if (jcp.ow_tail) { |
337 | jcp.ow_tail_idx = ker_idx; |
338 | CHECK(init_bcp(ker_idx, jcp.ow_tail, jcp.ngroups)); |
339 | } |
340 | |
341 | if (jcp.nb_ch_blocking != jcp.ngroups) { |
342 | jcp.nb_ch_blocking_idx = ker_idx; |
343 | CHECK(init_bcp(ker_idx, jcp.ow_block, jcp.nb_ch_blocking)); |
344 | } |
345 | assert(num_kernels == ker_idx); |
346 | } |
347 | |
348 | return status::success; |
349 | } |
350 | |
351 | status_t brdgmm_dw_convolution_fwd_t::pd_t::init_scratchpad() { |
352 | const auto &jcp = jcp_; |
353 | auto scratchpad = scratchpad_registry().registrar(); |
354 | |
355 | scratchpad.book(key_brgemm_primitive_batch, |
356 | static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size, |
357 | sizeof(brgemm_batch_element_t), 64); |
358 | return status::success; |
359 | } |
360 | |
361 | status_t brdgmm_dw_convolution_fwd_t::init(engine_t *engine) { |
362 | const auto &bcps = pd()->bcps_; |
363 | brdgmm_kernels_.resize(bcps.size()); |
364 | |
365 | for (size_t idx = 0; idx < bcps.size(); ++idx) { |
366 | const auto &bcp = bcps[idx]; |
367 | if (bcp.bcast_dim * bcp.load_dim /* M*N */ == 0) continue; |
368 | brgemm_kernel_t *brg_kernel = nullptr; |
369 | CHECK(brgemm_kernel_create(&brg_kernel, pd()->bcps_[idx])); |
370 | CHECK(safe_ptr_assign(brdgmm_kernels_[idx], brg_kernel)); |
371 | } |
372 | |
373 | return status::success; |
374 | } |
375 | |
376 | status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { |
377 | |
378 | const char *const __restrict src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
379 | const char *const __restrict weights |
380 | = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS); |
381 | const char *const __restrict bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
    char *const __restrict dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
383 | const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor(); |
384 | brgemm_batch_element_t *const __restrict brg_batch_global |
385 | = scratchpad.template get<brgemm_batch_element_t>( |
386 | key_brgemm_primitive_batch); |
387 | const std::vector<const void *> post_ops_binary_rhs_arg_vec |
388 | = binary_injector::prepare_binary_args( |
389 | pd()->attr()->post_ops_, ctx); |
390 | |
391 | const auto &jcp = pd()->jcp_; |
392 | |
393 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
394 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
395 | |
396 | const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(), |
397 | src_scales, wei_scales, pd()->OC(), pd()->attr()); |
398 | |
399 | const int chb_step = jcp.nb_ch_blocking; |
400 | const int chb_work = div_up(jcp.ngroups, chb_step); |
401 | const int ow_step = jcp.ow_block; |
402 | const int work_amount = jcp.mb * jcp.oh * jcp.nb_ow * chb_work; |
403 | |
404 | const size_t src_ch_stride = jcp.src_dsz; |
405 | const size_t src_w_stride = jcp.ngroups * jcp.src_dsz; |
406 | const size_t src_h_stride = jcp.ngroups * jcp.iw * jcp.src_dsz; |
407 | const size_t src_mb_stride = jcp.ngroups * jcp.iw * jcp.ih * jcp.src_dsz; |
408 | const size_t wei_ch_stride = jcp.wei_dsz; |
409 | const size_t wei_w_stride = rnd_up(jcp.ngroups, jcp.ch_block) * jcp.wei_dsz; |
410 | const size_t wei_h_stride = wei_w_stride * jcp.kw; |
411 | const size_t dst_ch_stride = jcp.dst_dsz; |
412 | const size_t dst_w_stride = jcp.ngroups * jcp.dst_dsz; |
413 | const size_t dst_h_stride = jcp.ngroups * jcp.ow * jcp.dst_dsz; |
414 | const size_t dst_mb_stride = jcp.ngroups * jcp.ow * jcp.oh * jcp.dst_dsz; |
415 | |
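    // Parallelize over the flattened (mb, oh, ow-block, channel-block)
    // space; each thread walks its contiguous chunk of iterations.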
416 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
417 | int start {0}, end {0}; |
418 | balance211(work_amount, nthr, ithr, start, end); |
419 | int n {0}, chb {0}, oh {0}, owb {0}; |
420 | |
421 | auto iwork = start; |
422 | brgemm_batch_element_t *const __restrict brg_batch = brg_batch_global |
423 | + static_cast<size_t>(ithr) * jcp.adjusted_batch_size; |
424 | const brgemm_kernel_t *kernel = nullptr; |
425 | const brgemm_kernel_t *kernel_chb_tail |
426 | = brdgmm_kernels_[jcp.chb_tail_idx].get(); |
427 | brgemm_post_ops_data_t post_ops_data; |
428 | post_ops_data.binary_post_ops_rhs = post_ops_binary_rhs_arg_vec.data(); |
429 | post_ops_data.data_C_ptr_ = dst; |
430 | |
431 | while (iwork < end) { |
432 | nd_iterator_init(iwork, n, jcp.mb, oh, jcp.oh, owb, jcp.nb_ow, chb, |
433 | chb_work); |
434 | const bool is_m_tail = jcp.ow_tail != 0 && (owb + 1 == jcp.nb_ow); |
435 | const bool is_n_tail = jcp.chb_tail != 0 && (chb + 1 == chb_work); |
436 | if (is_m_tail && chb != 0) { |
                // The tail ow_block is not split between threads, to reduce
                // the number of required kernels.
439 | utils::nd_iterator_jump(iwork, end, n, jcp.mb, oh, jcp.oh, owb, |
440 | jcp.nb_ow, chb, chb_work); |
441 | continue; |
442 | } |
443 | |
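            // Number of consecutive ow blocks of the current row that fall
            // within this thread's share of the work (at least one).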
444 | const auto rem_work = end - iwork; |
445 | const int rem_row_owb |
446 | = saturate(1, jcp.nb_ow - owb, rem_work / chb_work); |
447 | int cur_n_owb = 1; |
448 | int ker_idx = 0; |
449 | if (is_n_tail) { |
450 | ker_idx = jcp.chb_tail_idx; |
451 | } else if (is_m_tail) { |
452 | ker_idx = jcp.ow_tail_idx; |
453 | } else if (chb != 0 || rem_work < chb_work) { |
454 | ker_idx = jcp.nb_ch_blocking_idx; |
455 | } else if (rem_row_owb == jcp.nb_ow) { |
456 | ker_idx = 0; |
457 | cur_n_owb = jcp.nb_ow; |
458 | } else { |
                // The ow_tail kernel is handled separately; exclude it from
                // the count if present.
460 | const int log_rem_owb = log2(rem_row_owb |
461 | - (owb + rem_row_owb >= jcp.nb_ow) |
462 | * (jcp.ow_tail != 0)); |
463 | cur_n_owb = (1 << log_rem_owb); |
464 | ker_idx = log_rem_owb + 1; // add 1 as 0th is full row. |
465 | } |
466 | |
467 | kernel = brdgmm_kernels_[ker_idx].get(); |
468 | |
469 | int ch = chb * chb_step; |
470 | const int ow = owb * ow_step; |
471 | auto *ptr_A = src; |
472 | auto *ptr_B = weights; |
473 | int bs = 0; |
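            // Build the batch of A (src) / B (weights) addresses or offsets,
            // one element per kernel spatial point, skipping rows that fall
            // entirely into the top/bottom padding.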
474 | for (int kh = 0; kh < jcp.kh; ++kh) { |
475 | for (int kw = 0; kw < jcp.kw; ++kw) { |
476 | const int ih = (oh * jcp.stride_h - jcp.t_pad) + kh; |
477 | if (ih < 0 || ih >= jcp.ih) continue; |
478 | const int iw_s = ow * jcp.stride_w - jcp.l_pad + kw; |
479 | const int ow_e |
480 | = nstl::min(jcp.ow, ow + cur_n_owb * jcp.ow_block) |
481 | - 1; |
482 | const int iw_e = ow_e * jcp.stride_w - jcp.l_pad + kw; |
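                    // Virtual padding: the number of output points at the
                    // start (top) and end (bottom) of the block whose input
                    // index falls outside [0, iw).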
483 | auto &batch = brg_batch[bs]; |
484 | batch.vvpad.top = nstl::max(0, div_up(-iw_s, jcp.stride_w)); |
485 | batch.vvpad.bottom = nstl::max<dim_t>( |
486 | 0, div_up(iw_e - (jcp.iw - 1), jcp.stride_w)); |
487 | const dim_t offs_A = n * src_mb_stride + ih * src_h_stride |
488 | + iw_s * src_w_stride + ch * src_ch_stride; |
489 | const dim_t offs_B = kh * wei_h_stride + kw * wei_w_stride |
490 | + ch * wei_ch_stride; |
491 | if (jcp.batch_kind == brgemm_offs) { |
492 | batch.offset.A = offs_A; |
493 | batch.offset.B = offs_B; |
494 | } else if (jcp.batch_kind == brgemm_addr) { |
495 | batch.ptr.A = src + offs_A; |
496 | batch.ptr.B = weights + offs_B; |
497 | } else { |
498 | assert(jcp.batch_kind == brgemm_strd); |
499 | if (bs == 0) { |
500 | ptr_A = src + offs_A; |
501 | ptr_B = weights + offs_B; |
502 | } |
503 | } |
504 | ++bs; |
505 | } |
506 | } |
507 | auto ptr_C = dst + n * dst_mb_stride + oh * dst_h_stride |
508 | + ow * dst_w_stride + ch * dst_ch_stride; |
509 | const int rem_chb_work = chb_work - chb; |
510 | int chb_loop_work = is_m_tail || (chb == 0 && rem_work >= chb_work) |
511 | ? 1 // Compute entire chb_work in single jit call |
512 | : nstl::min(rem_work, rem_chb_work); |
513 | iwork += cur_n_owb * nstl::min(rem_work, rem_chb_work); |
514 | |
515 | while (chb_loop_work) { |
                // The brgemm_offs and brgemm_strd batch kinds allow running
                // this loop without updating the brg_batch elements.
518 | assert(IMPLICATION(chb != 0, |
519 | one_of(jcp.batch_kind, brgemm_offs, brgemm_strd))); |
520 | post_ops_data.bias = bias + ch * jcp.bia_dsz; |
521 | post_ops_data.scales = &oscales[jcp.is_oc_scale * ch]; |
522 | post_ops_data.oc_logical_off = ch; |
523 | brgemm_kernel_execute_postops(kernel, bs, ptr_A, ptr_B, |
524 | brg_batch, ptr_C, ptr_C, post_ops_data, |
525 | nullptr /*scratch*/); |
526 | ++chb; |
527 | if (jcp.chb_tail != 0 && chb + 1 == chb_work) |
528 | kernel = kernel_chb_tail; |
529 | ch += chb_step; |
530 | ptr_A += chb_step * src_ch_stride; |
531 | ptr_B += chb_step * wei_ch_stride; |
532 | ptr_C += chb_step * dst_ch_stride; |
533 | --chb_loop_work; |
534 | } |
535 | } |
536 | }); |
537 | return status::success; |
538 | } |
539 | } // namespace x64 |
540 | } // namespace cpu |
541 | } // namespace impl |
542 | } // namespace dnnl |
543 | |