1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | #include "cpu/cpu_primitive.hpp" |
22 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
23 | |
24 | #include "cpu/x64/jit_brgemm_conv_bwd_w.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | using namespace dnnl::impl::status; |
32 | using namespace dnnl::impl::memory_tracking::names; |
33 | using namespace dnnl::impl::utils; |
34 | |
35 | using namespace nstl; |
36 | using namespace data_type; |
37 | |
// Block offset into the weights memory descriptor wrapper `d`.
// The group index `g` is forwarded to blk_off() only when pd()->with_groups()
// is true; otherwise it is dropped and only the remaining indices are used.
// Requires a `pd()` accessor to be in scope at the expansion site.
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))
41 | |
42 | status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) { |
43 | const auto src_type = src_md(0)->data_type; |
44 | const auto diff_wei_type = diff_weights_md(0)->data_type; |
45 | const auto diff_bia_type = diff_weights_md(1)->data_type; |
46 | const auto diff_dst_type = diff_dst_md(0)->data_type; |
47 | bool ok = true && is_bwd_w() |
48 | && set_default_alg_kind(alg_kind::convolution_direct) |
49 | && utils::one_of(src_type, bf16, f16) && diff_dst_type == src_type |
50 | && utils::one_of(diff_wei_type, f32, src_type) |
51 | && utils::one_of(diff_bia_type, data_type::undef, f32, src_type) |
52 | && attr()->has_default_values() && !has_zero_dim_memory(); |
53 | if (!ok) return status::unimplemented; |
54 | |
55 | auto scratchpad = scratchpad_registry().registrar(); |
56 | |
57 | status_t status = brgemm_convolution_utils::init_conf_bwd_w(jcp_, *desc(), |
58 | src_md_, diff_weights_md_, diff_bias_md_, diff_dst_md_, attr_, |
59 | dnnl_get_max_threads()); |
60 | if (status != status::success) return status; |
61 | |
62 | status = brgemm_convolution_utils::init_scratchpad_bwd_w( |
63 | scratchpad, jcp_, src_md_, diff_weights_md_, diff_dst_md_); |
64 | |
65 | if (status != status::success) return status; |
66 | copy2jit_jcp(); |
67 | |
68 | bs_c = jcp_.var_bs ? 1 : (jcp_.max_batch + 1); |
69 | batchsizes.resize(bs_c + 1); |
70 | for (int i = 0; i <= bs_c; i++) |
71 | batchsizes[i] = -1; |
72 | |
73 | batchsizes[1] = 0; |
74 | |
75 | const auto adj_M = nstl::max(jcp_.M, jcp_.M_tail); |
76 | brgs_sz_ = bs_c * (adj_M + 1) * 2 * 2 * 2; |
77 | brgs_.resize(brgs_sz_); |
78 | bd_masks.resize(brgs_sz_); |
79 | |
80 | const float alpha = 1.0; |
81 | const float beta = 1.0; |
82 | |
83 | int M_begin = 0; |
84 | int M_end = (jcp_.M_tail == jcp_.M || jcp_.M_tail == 0) ? 1 : 2; |
85 | int N_begin = 0; |
86 | int N_end = (jcp_.N_tail == jcp_.N || jcp_.N_tail == 0) ? 1 : 2; |
87 | int K_begin = 0; |
88 | int K_end = (jcp_.K_tail == jcp_.K || jcp_.K_tail == 0) ? 1 : 2; |
89 | int init_begin = 0; |
90 | int init_end = 2; |
91 | |
92 | const auto wei_type = src_type; |
93 | |
94 | for (int i = M_begin; i < M_end; i++) { |
95 | auto M = (i) ? jcp_.M_tail : jcp_.M; |
96 | if (M <= 0) continue; |
97 | // init only needed brgemm descriptors |
98 | for (int bs = 0; bs <= jcp_.max_batch; bs++) { |
99 | if (batchsizes[bs] == -1) continue; |
100 | for_(int i_init = init_begin; i_init < init_end; i_init++) |
101 | for_(int i_N = N_begin; i_N < N_end; i_N++) |
102 | for (int i_K = K_begin; i_K < K_end; i_K++) { |
103 | auto vbeta = (i_init) ? 0 : beta; |
104 | auto vN = (i_N) ? jcp_.N_tail : jcp_.N; |
105 | auto vK = (i_K) ? jcp_.K_tail : jcp_.K; |
106 | if (vN == 0 || vK == 0) continue; |
107 | auto brg_idx = get_brg_idx(bs, M, i_init, i_N, i_K); |
108 | // if brgemm_t already created then skip this iteration |
109 | if (brgs_[brg_idx] != nullptr) continue; |
110 | brgs_[brg_idx] = std::make_shared<brgemm_t>(); |
111 | brgemm_t *brg = brgs_[brg_idx].get(); |
112 | CHECK(brgemm_desc_init(brg, jcp_.isa, jcp_.brg_type, src_type, |
113 | wei_type, false, false, brgemm_row_major, alpha, vbeta, |
114 | jcp_.LDA, jcp_.LDB, jcp_.LDC, M, vN, vK, nullptr)); |
115 | |
116 | brgemm_attr_t brgattr; |
117 | brgattr.use_uker = jcp_.use_uker; |
118 | brgattr.use_interleave_stores = jcp_.use_interleave_stores; |
119 | brgattr.hint_prefetching = jcp_.hint_prefetching; |
120 | brgattr.var_bs = jcp_.var_bs; |
121 | brgattr.max_bs = jcp_.max_batch; |
122 | brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost |
123 | ? brgemm_bd_loop_innermost |
124 | : brgemm_ld_loop_innermost; |
125 | |
126 | brgattr.hint_expected_A_size = 0; |
127 | brgattr.hint_expected_B_size = 0; |
128 | brgattr.hint_expected_C_size = 0; |
129 | |
130 | brgattr.wary_tail_read = false; |
131 | brgattr.bd_mask_level = jcp_.use_M_mask; |
132 | |
133 | brgattr.max_top_vpad = 0; |
134 | brgattr.max_bottom_vpad = 0; |
135 | |
136 | brgattr.LDA2 = jcp_.tr_iw * jcp_.ih * jcp_.id; |
137 | brgattr.LDB2 = jcp_.tr_ow * jcp_.oc_block * jcp_.oh * jcp_.od; |
138 | brgattr.LDC2_M = jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw; |
139 | brgattr.LDC2_N = jcp_.nb_ic * jcp_.ic_block * jcp_.oc_block |
140 | * jcp_.kd * jcp_.kh * jcp_.kw; |
141 | |
142 | CHECK(brgemm_desc_set_attr(brg, brgattr)); |
143 | } |
144 | } |
145 | } |
146 | return status; |
147 | } |
148 | |
149 | // jit_jcp used to initialize transpose kernels shared with jit implementation |
150 | void brgemm_convolution_bwd_weights_t::pd_t::copy2jit_jcp() { |
151 | jit_jcp_ = zero<decltype(jit_jcp_)>(); |
152 | jit_jcp_.prop_kind = jcp_.prop_kind; |
153 | jit_jcp_.has_vnni = jcp_.has_vnni; |
154 | jit_jcp_.harness = jcp_.harness; |
155 | jit_jcp_.simd_w = jcp_.simd_w; |
156 | jit_jcp_.ndims = jcp_.ndims; |
157 | jit_jcp_.mb = jcp_.mb; |
158 | jit_jcp_.ngroups = jcp_.ngroups; |
159 | jit_jcp_.ic = jcp_.ic; |
160 | jit_jcp_.oc = jcp_.oc; |
161 | jit_jcp_.oc_without_padding = jcp_.oc; |
162 | jit_jcp_.ic_without_padding = jcp_.ic_without_padding; |
163 | jit_jcp_.id = jcp_.id; |
164 | jit_jcp_.ih = jcp_.ih; |
165 | jit_jcp_.iw = jcp_.iw; |
166 | jit_jcp_.od = jcp_.od; |
167 | jit_jcp_.oh = jcp_.oh; |
168 | jit_jcp_.ow = jcp_.ow; |
169 | jit_jcp_.f_pad = jcp_.f_pad; |
170 | jit_jcp_.l_pad = jcp_.l_pad; |
171 | jit_jcp_.t_pad = jcp_.t_pad; |
172 | jit_jcp_.back_pad = jcp_.back_pad; |
173 | jit_jcp_.r_pad = jcp_.r_pad; |
174 | jit_jcp_.b_pad = jcp_.b_pad; |
175 | jit_jcp_.kd = jcp_.kd; |
176 | jit_jcp_.kh = jcp_.kh; |
177 | jit_jcp_.kw = jcp_.kw; |
178 | jit_jcp_.stride_d = jcp_.stride_d; |
179 | jit_jcp_.stride_h = jcp_.stride_h; |
180 | jit_jcp_.stride_w = jcp_.stride_w; |
181 | jit_jcp_.dilate_d = jcp_.dilate_d; |
182 | jit_jcp_.dilate_h = jcp_.dilate_h; |
183 | jit_jcp_.dilate_w = jcp_.dilate_w; |
184 | jit_jcp_.src_tag = jcp_.src_tag; |
185 | jit_jcp_.wei_tag = jcp_.wei_tag; |
186 | jit_jcp_.dst_tag = jcp_.dst_tag; |
187 | jit_jcp_.with_bias = jcp_.with_bias; |
188 | jit_jcp_.with_sum = jcp_.with_sum; |
189 | jit_jcp_.with_eltwise = jcp_.with_eltwise; |
190 | jit_jcp_.with_binary = jcp_.with_binary; |
191 | jit_jcp_.is_fused_conv = jcp_.is_fused_conv; |
192 | jit_jcp_.nb_ic = jcp_.nb_ic; |
193 | jit_jcp_.ic_block = jcp_.ic_block; |
194 | jit_jcp_.nb_oc = jcp_.nb_oc; |
195 | jit_jcp_.oc_block = jcp_.oc_block; |
196 | jit_jcp_.nb_oc_blocking = jcp_.nb_oc_blocking; |
197 | |
198 | jit_jcp_.ic_tail = jcp_.ic_tail; |
199 | jit_jcp_.oc_tail = jcp_.oc_tail; |
200 | |
201 | jit_jcp_.tr_iw = jcp_.tr_iw; |
202 | jit_jcp_.tr_ow = jcp_.tr_ow; |
203 | jit_jcp_.tr_diff_dst_buf_size = jcp_.tr_diff_dst_buf_size; |
204 | jit_jcp_.typesize_in = jcp_.typesize_in; |
205 | jit_jcp_.typesize_out = jcp_.typesize_out; |
206 | jit_jcp_.ddst_dt = jcp_.dst_dt; |
207 | } |
208 | |
209 | status_t brgemm_convolution_bwd_weights_t::add_brg_kernel( |
210 | int bs, int M, int i_N, int i_K, int i_init) { |
211 | if (M <= 0) return status::success; |
212 | const auto _pd = pd(); |
213 | const auto &jcp = _pd->jcp_; |
214 | const auto &brgs = _pd->brgs_; |
215 | |
216 | auto N = (i_N) ? jcp.N_tail : jcp.N; |
217 | auto K = (i_K) ? jcp.K_tail : jcp.K; |
218 | if (N <= 0 || K <= 0) return status::success; |
219 | auto brg_idx = _pd->get_brg_idx(bs, M, i_init, i_N, i_K); |
220 | auto brg = brgs[brg_idx]; |
221 | if (!brg_kernels_[brg_idx] && brg && brg->bcast_dim > 0 && brg->load_dim > 0 |
222 | && brg->reduce_dim > 0) { |
223 | brgemm_kernel_t *brg_kernel = nullptr; |
224 | CHECK(brgemm_kernel_create(&brg_kernel, *brg)); |
225 | CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel)); |
226 | CHECK(brgemm_init_tiles(*brg, &brg_kernel_palettes_[brg_idx].a[0])); |
227 | } |
228 | return status::success; |
229 | } |
230 | |
231 | status_t brgemm_convolution_bwd_weights_t::init(engine_t *engine) { |
232 | const auto _pd = pd(); |
233 | const auto &jcp = _pd->jcp_; |
234 | const auto &jit_jcp = pd()->jit_jcp_; |
235 | |
236 | CHECK(safe_ptr_assign(trans_kernel_, create_trans_src(&jit_jcp))); |
237 | CHECK(trans_kernel_->create_kernel()); |
238 | CHECK(safe_ptr_assign(trans_dst_kernel_, create_trans_dst(&jit_jcp))); |
239 | CHECK(trans_dst_kernel_->create_kernel()); |
240 | |
241 | if (jcp.with_bias) { |
242 | CHECK(safe_ptr_assign(diff_bias_kernel_, |
243 | new jit_avx512_core_amx_bwd_bias_kernel_t(jit_jcp))); |
244 | CHECK(diff_bias_kernel_->create_kernel()); |
245 | } |
246 | |
247 | if (jcp.nthr_mb > 1) { |
248 | CHECK(safe_ptr_assign( |
249 | acc_ker_, new cpu_accumulator_1d_t<data_type::f32>())); |
250 | CHECK(acc_ker_->create_kernel()); |
251 | } |
252 | if (jcp.transform_to_vnni) { |
253 | CHECK(safe_ptr_assign(diff_wei_trans_kernel_, |
254 | new jit_diff_wei_trans_to_vnni_t(jcp.wei_dt, jcp.kd, jcp.kh, |
255 | jcp.kw, jcp.ic_block, jcp.oc_block))); |
256 | CHECK(diff_wei_trans_kernel_->create_kernel()); |
257 | } |
258 | |
259 | brg_kernels_.resize(_pd->brgs_sz_); |
260 | brg_kernel_palettes_.resize(_pd->brgs_sz_); |
261 | |
262 | for (int i = 0; i < _pd->brgs_sz_; i++) |
263 | brg_kernels_[i] = nullptr; |
264 | |
265 | int M_begin = 0; |
266 | int M_end = (jcp.M_tail == jcp.M || jcp.M_tail == 0) ? 1 : 2; |
267 | int N_begin = 0; |
268 | int N_end = (jcp.N_tail == jcp.N || jcp.N_tail == 0) ? 1 : 2; |
269 | int K_begin = 0; |
270 | int K_end = (jcp.K_tail == jcp.K || jcp.K_tail == 0) ? 1 : 2; |
271 | int init_begin = 0; |
272 | int init_end = 2; |
273 | |
274 | for (int bs = 0; bs <= jcp.max_batch; bs++) { |
275 | if (_pd->batchsizes[bs] == -1) continue; |
276 | |
277 | for_(int i_N = N_begin; i_N < N_end; i_N++) |
278 | for_(int i_M = M_begin; i_M < M_end; i_M++) |
279 | for_(int i_init = init_begin; i_init < init_end; i_init++) |
280 | for (int i_K = K_begin; i_K < K_end; i_K++) { |
281 | auto M = (i_M) ? jcp.M_tail : jcp.M; |
282 | if (M <= 0) continue; |
283 | add_brg_kernel(bs, M, i_N, i_K, i_init); |
284 | } |
285 | } |
286 | |
287 | return status::success; |
288 | } |
289 | |
290 | struct brgemm_convolution_bwd_weights_t::thread_info_t { |
291 | const src_data_t *src = nullptr; |
292 | const diff_dst_data_t *diff_dst = nullptr; |
293 | const void *diff_weights = nullptr; |
294 | const void *diff_bias = nullptr; |
295 | |
296 | const brgemm_convolution_bwd_weights_t *self; |
297 | const memory_tracking::grantor_t scratchpad; |
298 | |
299 | src_data_t *tr_src = nullptr; |
300 | diff_dst_data_t *tr_diff_dst = nullptr; |
301 | simple_barrier::ctx_t *tr_src_bctx = nullptr; |
302 | simple_barrier::ctx_t *tr_diff_dst_bctx = nullptr; |
303 | |
304 | float *wei_bia_reduction = nullptr; |
305 | float *bia_reduction = nullptr; |
306 | simple_barrier::ctx_t *wei_bia_reduction_bctx = nullptr; |
307 | |
308 | // All nthreads are mapped to a multidimensional "cube" with sizes: |
309 | // (nthr_mb, nthr_g, nthr_oc, nthr_ic). |
310 | // Variables ithr_* define the coordinates and "layers" of the current |
311 | // thread in this "cube" |
312 | int ithr = 0; |
313 | int ithr_ic_b = 0, ithr_oc_b = 0, ithr_g = 0, ithr_mb = 0; |
314 | int ithr_but_oc = 0; |
315 | int ithr_but_ic = 0; |
316 | |
317 | int img_start = 0, img_end = 0, img_work = 0; |
318 | int g_start = 0, g_end = 0, g_work = 0; |
319 | int oc_b_start = 0, oc_b_end = 0, oc_b_work = 0; |
320 | int ic_b_start = 0, ic_b_end = 0, ic_b_work = 0; |
321 | |
322 | S_t cur_palette; |
323 | brgemm_batch_element_t *__restrict brg_batch; |
324 | char *wsp_tile; |
325 | const exec_ctx_t &exec_ctx; |
326 | const jit_brgemm_conv_conf_t &jcp; |
327 | const memory_desc_wrapper src_d; |
328 | const memory_desc_wrapper diff_dst_d; |
329 | const memory_desc_wrapper diff_weights_d; |
330 | |
331 | thread_info_t(const brgemm_convolution_bwd_weights_t *pcnv, |
332 | const exec_ctx_t &ctx, int ithr) |
333 | : self(pcnv) |
334 | , scratchpad(ctx.get_scratchpad_grantor()) |
335 | , ithr(ithr) |
336 | , exec_ctx(ctx) |
337 | , jcp(self->pd()->jcp_) |
338 | , src_d(self->pd()->src_md()) |
339 | , diff_dst_d(self->pd()->diff_dst_md()) |
340 | , diff_weights_d(self->pd()->diff_weights_md(0)) { |
341 | diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
342 | src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
343 | diff_weights = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_WEIGHTS); |
344 | |
345 | diff_bias = self->pd()->with_bias() && (jcp.oc % jcp.oc_block != 0) |
346 | && self->pd()->jcp_.bia_dt == data_type::f32 |
347 | ? (void *)scratchpad.template get<float>(key_conv_padded_bias) |
348 | : CTX_OUT_MEM(void *, DNNL_ARG_DIFF_BIAS); |
349 | |
350 | tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src); |
351 | if (jcp.global_transpose) |
352 | tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>( |
353 | key_conv_tr_src_bctx); |
354 | |
355 | tr_diff_dst = scratchpad.template get<diff_dst_data_t>( |
356 | key_conv_tr_diff_dst); |
357 | if (jcp.global_transpose) |
358 | tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>( |
359 | key_conv_tr_diff_dst_bctx); |
360 | wei_bia_reduction |
361 | = scratchpad.template get<float>(key_conv_wei_bia_reduction); |
362 | bia_reduction = nullptr; |
363 | if (jcp.with_bias) { |
364 | const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block |
365 | * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd; |
366 | const int num_wei_buffers = jcp.wei_dt != data_type::f32 |
367 | ? jcp.nthr_mb |
368 | : jcp.nthr_mb - 1; |
369 | bia_reduction = wei_bia_reduction + wei_size * num_wei_buffers; |
370 | } |
371 | |
372 | wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>( |
373 | key_conv_wei_bia_reduction_bctx); |
374 | |
375 | ithr_ic_b = ithr % jcp.nthr_ic_b; |
376 | ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; |
377 | ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; |
378 | ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g; |
379 | |
380 | ithr_but_oc |
381 | = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b; |
382 | |
383 | ithr_but_ic |
384 | = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_oc_b + ithr_oc_b; |
385 | |
386 | int work_amount = jcp.nthr_mb_work; |
387 | /* reduction dimension */ |
388 | balance211(work_amount, jcp.nthr_mb, ithr_mb, img_start, img_end); |
389 | img_work = img_end - img_start; |
390 | |
391 | /* independent dimensions */ |
392 | balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); |
393 | g_work = g_end - g_start; |
394 | |
395 | balance211(jcp.nb_oc, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end); |
396 | oc_b_work = oc_b_end - oc_b_start; |
397 | |
398 | balance211(jcp.nb_ic, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end); |
399 | if (jcp.transform_to_vnni) { |
400 | if (ic_b_start % 2 != 0) ic_b_start++; |
401 | if (ic_b_end != jcp.nb_ic && ic_b_end % 2 != 0) ic_b_end++; |
402 | } |
403 | ic_b_work = ic_b_end - ic_b_start; |
404 | |
405 | std::memset(cur_palette.a, 0, AMX_PALETTE_SIZE); |
406 | brgemm_batch_element_t *const __restrict brg_batch_global |
407 | = (jcp.brg_type == brgemm_strd) |
408 | ? nullptr |
409 | : scratchpad.template get<brgemm_batch_element_t>( |
410 | key_brgemm_primitive_batch); |
411 | brg_batch = brg_batch_global |
412 | + static_cast<size_t>(ithr) * jcp.adjusted_batch_size; |
413 | |
414 | auto wsp_tile_global |
415 | = scratchpad.template get<char>(key_conv_amx_tile_buffer); |
416 | wsp_tile = wsp_tile_global + ithr * 2 * brgemm_convolution_utils::P4K; |
417 | } |
418 | |
419 | const pd_t *pd() const { return self->pd(); } |
420 | |
421 | inline int get_inp_start(int out_s, int pad, int str) const { |
422 | return nstl::max(0, -pad + out_s * str); |
423 | } |
424 | |
425 | inline int get_inp_end(int out_e, int is, int pad, int str, int ek) const { |
426 | return nstl::min(is, -pad + (out_e - 1) * str + ek); |
427 | } |
428 | |
429 | inline int get_id_start(int od_s) const { |
430 | return get_inp_start(od_s, jcp.f_pad, jcp.stride_d); |
431 | } |
432 | inline int get_ih_start(int oh_s) const { |
433 | return get_inp_start(oh_s, jcp.t_pad, jcp.stride_h); |
434 | } |
435 | |
436 | inline int get_id_end(int od_e) const { |
437 | return get_inp_end(od_e, jcp.id, jcp.f_pad, jcp.stride_d, jcp.ext_kd); |
438 | } |
439 | inline int get_ih_end(int oh_e) const { |
440 | return get_inp_end(oh_e, jcp.ih, jcp.t_pad, jcp.stride_h, jcp.ext_kh); |
441 | } |
442 | |
443 | size_t tr_src_buf_number(int g, int icb) const { |
444 | return jcp.global_transpose |
445 | ? ithr_mb * jcp.nb_ic * jcp.ngroups + g * jcp.nb_ic + icb |
446 | : ithr; |
447 | } |
448 | |
449 | size_t tr_diff_dst_buf_number(int g, int ocb) const { |
450 | // for current loop order (xoi) if jcp.tr_ocb_chunk then we can reuse |
451 | // same area in tr_diff_dst buffer |
452 | if (jcp.tr_ocb_chunk) |
453 | return jcp.global_transpose |
454 | ? ((ithr_mb * jcp.ngroups + g) * jcp.nthr_oc_b + ithr_oc_b) |
455 | * jcp.nb_oc_blocking |
456 | + (ocb - oc_b_start) % jcp.nb_oc_blocking |
457 | : ithr; |
458 | else |
459 | return jcp.global_transpose |
460 | ? ithr_mb * jcp.nb_oc * jcp.ngroups + g * jcp.nb_oc + ocb |
461 | : ithr; |
462 | } |
463 | |
464 | size_t tr_src_off(int g, int icb, int id, int ih) const { |
465 | const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; |
466 | const size_t tr_3d_size = tr_row_size * jcp.ih; |
467 | int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking; |
468 | // Aligned to buffer end to use guard elements |
469 | return tr_src_buf_number(g, icb) * adj * jcp.tr_src_buf_size |
470 | + id * tr_3d_size + ih * tr_row_size; |
471 | } |
472 | |
473 | size_t tr_diff_dst_off(int g, int ocb, int od, int oh) const { |
474 | const size_t tr_row_size = jcp.tr_ow * jcp.oc_block; |
475 | const size_t tr_3d_size = tr_row_size * jcp.oh; |
476 | int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking; |
477 | return tr_diff_dst_buf_number(g, ocb) * adj * jcp.tr_diff_dst_buf_size |
478 | + od * tr_3d_size + oh * tr_row_size; |
479 | } |
480 | |
481 | void trans_src_nxc(src_data_t *tr_src, const src_data_t *src_base, |
482 | int spatial_start, dim_t spatial_start_offset, int icb_start, |
483 | dim_t chb_stride, int row_count) const { |
484 | const int src_stride = jcp.iw * jcp.ngroups * jcp.ic; |
485 | const int tr_src_stride = jcp.tr_iw * jcp.ic_block; |
486 | |
487 | int work_rest = row_count; |
488 | int max_spatial_work = jcp.id * jcp.ih; |
489 | int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start); |
490 | const src_data_t *src = src_base + spatial_start_offset; |
491 | int icb = 0; |
492 | const int ic_tail_work = jcp.ic_tail ? jcp.ic_tail : jcp.ic_block; |
493 | while (work_rest > 0) { |
494 | for (int iwork = 0; iwork < sp_work; iwork++) { |
495 | // For 1x1 convolutions with strides we transpose only |
496 | // needed lines |
497 | if (IMPLICATION(jcp.kh == 1, iwork % jcp.stride_h == 0)) { |
498 | auto ctx = jit_trans_src_t::ctx_t(); |
499 | ctx.src = src; |
500 | ctx.tr_src = tr_src; |
501 | assert(icb_start + icb < jcp.nb_ic); |
502 | ctx.ch_work = (icb_start + icb + 1) == jcp.nb_ic |
503 | ? ic_tail_work |
504 | : jcp.ic_block; |
505 | ctx.src_prf = nullptr; |
506 | ctx.tr_src_prf = nullptr; |
507 | (*self->trans_kernel_)(&ctx); |
508 | } |
509 | src += src_stride; |
510 | tr_src += tr_src_stride; |
511 | } |
512 | work_rest -= sp_work; |
513 | sp_work = nstl::min(work_rest, max_spatial_work); |
514 | icb++; |
515 | src = src_base + icb * chb_stride; |
516 | } |
517 | } |
518 | |
519 | void trans_dst_nxc(diff_dst_data_t *tr_diff_dst, |
520 | const diff_dst_data_t *diff_dst_base, int spatial_start, |
521 | dim_t spatial_start_offset, int ocb_start, dim_t chb_stride, |
522 | int row_count) const { |
523 | const int diff_dst_stride = jcp.ow * jcp.ngroups * jcp.oc; |
524 | const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block; |
525 | int work_rest = row_count; |
526 | int max_spatial_work = jcp.od * jcp.oh; |
527 | int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start); |
528 | const src_data_t *diff_dst = diff_dst_base + spatial_start_offset; |
529 | int ocb = 0; |
530 | const int oc_tail_work = jcp.oc_tail ? jcp.oc_tail : jcp.oc_block; |
531 | while (work_rest > 0) { |
532 | for (int iwork = 0; iwork < sp_work; iwork++) { |
533 | auto ctx = jit_trans_dst_t::ctx_t(); |
534 | ctx.src = diff_dst; |
535 | ctx.tr_src = tr_diff_dst; |
536 | assert(ocb_start + ocb < jcp.nb_oc); |
537 | ctx.ch_work = (ocb_start + ocb + 1) == jcp.nb_oc ? oc_tail_work |
538 | : jcp.oc_block; |
539 | ctx.src_prf = nullptr; |
540 | ctx.tr_src_prf = nullptr; |
541 | (*self->trans_dst_kernel_)(&ctx); |
542 | diff_dst += diff_dst_stride; |
543 | tr_diff_dst += tr_diff_dst_stride; |
544 | } |
545 | work_rest -= sp_work; |
546 | sp_work = nstl::min(work_rest, max_spatial_work); |
547 | ocb++; |
548 | diff_dst = diff_dst_base + ocb * chb_stride; |
549 | } |
550 | } |
551 | |
552 | void maybe_global_transpose(int img, int ocb_s, int ocb_e, int icb_s, |
553 | int icb_e, int od_s, int odb_s, int odb_e, int oh_s, int ohb_s, |
554 | int ohb_e) const { |
555 | if (!jcp.global_transpose) return; |
556 | |
557 | using simple_barrier::barrier; |
558 | const int icb_work = icb_e - icb_s; |
559 | const int ocb_work = ocb_e - ocb_s; |
560 | |
561 | // The barrier should stay outside of work condition to avoid |
562 | // possible hang |
563 | if (jcp.nthr_oc_b > 1) |
564 | barrier(&tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); |
565 | |
566 | if (icb_work > 0) { |
567 | const auto id_s = get_id_start(od_s); |
568 | const auto ih_s = get_ih_start(oh_s); |
569 | |
570 | const auto idb_s = get_id_start(odb_s); |
571 | const auto idb_e = get_id_end(odb_e); |
572 | |
573 | const auto ihb_s = get_ih_start(ohb_s); |
574 | const auto ihb_e = get_ih_end(ohb_e); |
575 | |
576 | int work_amount |
577 | = g_work * icb_work * (idb_e - idb_s) * (ihb_e - ihb_s); |
578 | int tr_start {0}, tr_end {0}; |
579 | balance211(work_amount, jcp.nthr_oc_b, ithr_oc_b, tr_start, tr_end); |
580 | |
581 | int g {0}, ic_b {0}, jd {0}, jh {0}; |
582 | nd_iterator_init(tr_start, g, g_work, ic_b, icb_work, jd, |
583 | idb_e - idb_s, jh, ihb_e - ihb_s); |
584 | |
585 | while (tr_start < tr_end) { |
586 | int g_ = g + g_start; |
587 | int ic_b_ = ic_b + icb_s; |
588 | |
589 | int jd_s = jd + idb_s; |
590 | |
591 | int jh_s = jh + ihb_s; |
592 | int jh_e = jh_s + nstl::min(tr_end - tr_start, ihb_e - jh_s); |
593 | |
594 | const int ic_off_idx = g_ * jcp.ic + ic_b_ * jcp.ic_block; |
595 | |
596 | const src_data_t *p_src {nullptr}; |
597 | if (jcp.harness == harness_2d_reduction) { |
598 | p_src = &src[src_d.blk_off(img, ic_off_idx, jh_s)]; |
599 | } else if (jcp.harness == harness_3d_reduction) { |
600 | p_src = &src[src_d.blk_off(img, ic_off_idx, jd_s, jh_s)]; |
601 | } else |
602 | assert(!"Invalid harness type" ); |
603 | |
604 | src_data_t *p_tr_src = &tr_src[tr_src_off( |
605 | g_, ic_b_, jd_s - id_s, jh_s - ih_s)]; |
606 | trans_src_nxc(p_tr_src, p_src, 0, 0, ic_b_, 0, jh_e - jh_s); |
607 | |
608 | nd_iterator_jump(tr_start, tr_end, g, g_work, ic_b, icb_work, |
609 | jd, idb_e - idb_s, jh, ihb_e - ihb_s); |
610 | } |
611 | } |
612 | if (jcp.nthr_oc_b > 1) |
613 | barrier(&tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b); |
614 | |
615 | // The barrier should stay outside of work condition to avoid |
616 | // possible hang |
617 | if (jcp.nthr_ic_b > 1) |
618 | barrier(&tr_diff_dst_bctx[ithr_but_ic], jcp.nthr_ic_b); |
619 | |
620 | if (ocb_work > 0) { |
621 | int jd = 0; |
622 | int jh = 0; |
623 | int work_amount |
624 | = g_work * ocb_work * (odb_e - odb_s) * (ohb_e - ohb_s); |
625 | int tr_start = 0; |
626 | int tr_end = 0; |
627 | balance211(work_amount, jcp.nthr_ic_b, ithr_ic_b, tr_start, tr_end); |
628 | |
629 | int g = 0; |
630 | int oc_b = 0; |
631 | nd_iterator_init(tr_start, g, g_work, oc_b, ocb_work, jd, |
632 | odb_e - odb_s, jh, ohb_e - ohb_s); |
633 | |
634 | while (tr_start < tr_end) { |
635 | int g_ = g + g_start; |
636 | int oc_b_ = oc_b + ocb_s; |
637 | int jd_s = jd + odb_s; |
638 | int jh_s = jh + ohb_s; |
639 | int jh_e = jh_s + nstl::min(tr_end - tr_start, ohb_e - jh_s); |
640 | const int oc_off_idx = g_ * jcp.oc + oc_b_ * jcp.oc_block; |
641 | |
642 | const diff_dst_data_t *p_diff_dst {nullptr}; |
643 | if (jcp.harness == harness_2d_reduction) { |
644 | p_diff_dst = &diff_dst[diff_dst_d.blk_off( |
645 | img, oc_off_idx, jh_s)]; |
646 | } else if (jcp.harness == harness_3d_reduction) { |
647 | p_diff_dst = &diff_dst[diff_dst_d.blk_off( |
648 | img, oc_off_idx, jd_s, jh_s)]; |
649 | } else |
650 | assert(!"Invalid harness type" ); |
651 | |
652 | diff_dst_data_t *p_tr_diff_dst = &tr_diff_dst[tr_diff_dst_off( |
653 | g_, oc_b_, jd_s - od_s, jh_s - oh_s)]; |
654 | trans_dst_nxc( |
655 | p_tr_diff_dst, p_diff_dst, 0, 0, oc_b_, 0, jh_e - jh_s); |
656 | |
657 | nd_iterator_jump(tr_start, tr_end, g, g_work, oc_b, ocb_work, |
658 | jd, odb_e - odb_s, jh, ohb_e - ohb_s); |
659 | } |
660 | } |
661 | if (jcp.nthr_ic_b > 1) |
662 | barrier(&tr_diff_dst_bctx[ithr_but_ic], jcp.nthr_ic_b); |
663 | } |
664 | |
665 | void maybe_local_traspose(void *&p_src, void *&p_dst, int img, int g, |
666 | int ic_b, int oc_b, int od_s, int odb_s, int odb_e, int oh_s, |
667 | int ohb_s, int ohb_e) const { |
668 | |
669 | const int idb_s = get_id_start(odb_s); |
670 | const int ihb_s = get_ih_start(ohb_s); |
671 | |
672 | const int idb_e = get_id_end(odb_e); |
673 | const int ihb_e = get_ih_end(ohb_e); |
674 | |
675 | const int id_s = get_id_start(od_s); |
676 | const int ih_s = get_ih_start(oh_s); |
677 | |
678 | if (jcp.global_transpose) { |
679 | p_src = &tr_src[tr_src_off(g, ic_b, 0, 0)]; |
680 | p_dst = &tr_diff_dst[tr_diff_dst_off(g, oc_b, 0, 0)]; |
681 | return; |
682 | } |
683 | |
684 | const int nb_ic_blocks = (ic_b + jcp.nb_ic_blocking > ic_b_end) |
685 | ? 1 |
686 | : jcp.nb_ic_blocking; |
687 | |
688 | const int nb_oc_blocks = (oc_b + jcp.nb_oc_blocking > oc_b_end) |
689 | ? 1 |
690 | : jcp.nb_oc_blocking; |
691 | |
692 | for_(int idb = idb_s; idb < idb_e; idb++) |
693 | for (int icb = 0; icb < nb_ic_blocks; icb++) { |
694 | const int ic_off_idx = g * jcp.ic + (ic_b + icb) * jcp.ic_block; |
695 | src_data_t *p_tr_src |
696 | = &tr_src[tr_src_off(0, 0, idb - id_s, ihb_s - ih_s)]; |
697 | src_data_t *tr_src_local = p_tr_src + icb * jcp.tr_src_buf_size; |
698 | const src_data_t *p_raw_src {nullptr}; |
699 | if (jcp.harness == harness_2d_reduction) { |
700 | p_raw_src = (src_data_t |
701 | *)&src[src_d.blk_off(img, ic_off_idx, ihb_s)]; |
702 | } else if (jcp.harness == harness_3d_reduction) { |
703 | p_raw_src = (src_data_t *)&src[src_d.blk_off( |
704 | img, ic_off_idx, idb, ihb_s)]; |
705 | } else |
706 | assert(!"Invalid harness type" ); |
707 | trans_src_nxc(tr_src_local, p_raw_src, 0, 0, (ic_b + icb), 0, |
708 | (ihb_e - ihb_s)); |
709 | } |
710 | |
711 | p_src = &tr_src[tr_src_off(0, 0, 0, 0)]; // p_tr_src; |
712 | |
713 | for_(int odb = odb_s; odb < odb_e; odb++) |
714 | for (int ocb = 0; ocb < nb_oc_blocks; ocb++) { |
715 | const int oc_off_idx = g * jcp.oc + (oc_b + ocb) * jcp.oc_block; |
716 | const diff_dst_data_t *p_raw_diff_dst {nullptr}; |
717 | if (jcp.harness == harness_2d_reduction) { |
718 | p_raw_diff_dst |
719 | = &diff_dst[diff_dst_d.blk_off(img, oc_off_idx, ohb_s)]; |
720 | } else if (jcp.harness == harness_3d_reduction) { |
721 | p_raw_diff_dst = &diff_dst[diff_dst_d.blk_off( |
722 | img, oc_off_idx, odb, ohb_s)]; |
723 | } else |
724 | assert(!"Invalid harness type" ); |
725 | diff_dst_data_t *p_tr_diff_dst = &tr_diff_dst[tr_diff_dst_off( |
726 | 0, 0, odb - od_s, ohb_s - oh_s)]; |
727 | diff_dst_data_t *tr_diff_dst_local |
728 | = p_tr_diff_dst + ocb * jcp.tr_diff_dst_buf_size; |
729 | trans_dst_nxc(tr_diff_dst_local, p_raw_diff_dst, 0, 0, (oc_b + ocb), |
730 | 0, (ohb_e - ohb_s)); |
731 | } |
732 | p_dst = &tr_diff_dst[tr_diff_dst_off(0, 0, 0, 0)]; // p_tr_diff_dst; |
733 | } |
734 | |
735 | bool just_init_output( |
736 | int start, int end, float *diff_wei, float *diff_bias) { |
737 | if (start < end || g_start >= g_end || oc_b_start >= oc_b_end |
738 | || ic_b_start >= ic_b_end) |
739 | return false; |
740 | // for rare case if thread has no work by spatial dimension then we |
741 | // need to initialize the output at least |
742 | if (jcp.with_bias) { |
743 | for_(int g = g_start; g < g_end; ++g) |
744 | { |
745 | void *p_bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block) |
746 | + oc_b_start * jcp.oc_block; |
747 | auto bias_amount = (oc_b_end - oc_b_start) * jcp.oc_block; |
748 | std::memset(p_bias, 0, bias_amount * jcp.acc_dsz); |
749 | } |
750 | } |
751 | |
752 | for_(int g = g_start; g < g_end; ++g) |
753 | for (int oc_b = oc_b_start; oc_b < oc_b_end; oc_b++) { |
754 | auto wei_offs_ext = pd()->ndims() == 3 |
755 | ? wht_blk_off(diff_weights_d, g, oc_b, ic_b_start, 0) |
756 | : (pd()->ndims() == 4 ? wht_blk_off( |
757 | diff_weights_d, g, oc_b, ic_b_start, 0, 0) |
758 | : wht_blk_off(diff_weights_d, g, oc_b, |
759 | ic_b_start, 0, 0, 0)); |
760 | void *ptr_C = (jcp.transform_to_vnni) ? diff_wei |
761 | + self->wei_offset_int(g, oc_b, ic_b_start, 0, 0, 0) |
762 | : diff_wei + wei_offs_ext; |
763 | |
764 | auto C_amount = jcp.kd * jcp.kh * jcp.kw * (ic_b_end - ic_b_start) |
765 | * jcp.ic_block * jcp.oc_block; |
766 | |
767 | std::memset(ptr_C, 0, C_amount * jcp.acc_dsz); |
768 | } |
769 | return true; |
770 | } |
771 | }; |
772 | |
773 | void brgemm_convolution_bwd_weights_t::call_brgemm_kernel( |
774 | thread_info_t &btc, int brg_idx, int batch_size, void *ptr_C) const { |
775 | |
776 | const auto brg_ker = brg_kernels_[brg_idx].get(); |
777 | assert(brg_ker != nullptr); |
778 | |
779 | // TODO: avoid costly tile reconfigurations |
780 | if (std::memcmp(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a, |
781 | AMX_PALETTE_SIZE) |
782 | != 0) { |
783 | amx_tile_configure(brg_kernel_palettes_[brg_idx].a); |
784 | std::memcpy(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a, |
785 | AMX_PALETTE_SIZE); |
786 | } |
787 | |
788 | brgemm_kernel_execute(brg_ker, batch_size, btc.brg_batch, ptr_C, |
789 | static_cast<void *>(btc.wsp_tile)); |
790 | } |
791 | |
792 | void brgemm_convolution_bwd_weights_t::compute_diff_weights_2d( |
793 | thread_info_t *ti) const { |
794 | |
795 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
796 | const auto _pd = pd(); |
797 | const auto &jcp = _pd->jcp_; |
798 | |
799 | const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic |
800 | * jcp.ic_block * jcp.kd * jcp.kh * jcp.kw; |
801 | const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
802 | const int optimal_spblock = jcp.spatial_blk_size; |
803 | |
804 | float *diff_wei; |
805 | if (diff_weights_d.data_type() != data_type::f32) |
806 | diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size; |
807 | else |
808 | diff_wei = ti->ithr_mb == 0 |
809 | ? (float *)ti->diff_weights |
810 | : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; |
811 | |
812 | float *diff_bias = nullptr; |
813 | if (jcp.with_bias) { |
814 | if (jcp.bia_dt != data_type::f32) |
815 | diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size; |
816 | else |
817 | diff_bias = ti->ithr_mb == 0 |
818 | ? (float *)ti->diff_bias |
819 | : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size; |
820 | } |
821 | |
822 | int img {0}, oh_s {0}; |
823 | int start = ti->img_start; |
824 | int end = ti->img_end; |
825 | |
826 | int img_s {0}; |
827 | nd_iterator_init(start, img_s, jcp.mb, oh_s, jcp.oh); |
828 | img = img_s; |
829 | |
830 | auto do_brgemm_call = [&](int g, int bs, int ic_b, int oc_b, int ohb_s, |
831 | int bs_ih_s, const void *p_src, |
832 | const void *p_dst, int kh, int kw, |
833 | bool do_init) { |
834 | const int ihb_s = ti->get_ih_start(ohb_s); |
835 | |
836 | const int bs_oh_s = utils::saturate(0, jcp.oh, |
837 | (bs_ih_s + jcp.t_pad - kh * (jcp.dilate_h + 1)) / jcp.stride_h); |
838 | |
839 | auto ocb_end = get_end(oc_b, jcp.nb_oc_blocking, ti->oc_b_end); |
840 | auto icb_end = get_end(ic_b, jcp.nb_ic_blocking, ti->ic_b_end); |
841 | const int src_stride_w_shift = jcp.tr_iw / jcp.stride_w; |
842 | const void *ptr_A = ((src_data_t *)p_src) |
843 | + _pd->filter_w_to_src(kw) / jcp.stride_w |
844 | + (kw % jcp.stride_w) * src_stride_w_shift |
845 | + (bs_ih_s - ihb_s) * jcp.tr_iw * jcp.ic_block; |
846 | const void *ptr_B = ((diff_dst_data_t *)p_dst) |
847 | + (bs_oh_s - ohb_s) * jcp.tr_ow * jcp.oc_block; |
848 | |
849 | void *ptr_C = (jcp.transform_to_vnni) |
850 | ? diff_wei + wei_offset_int(g, oc_b, ic_b, 0, kh, kw) |
851 | : diff_wei |
852 | + (pd()->ndims() == 3 ? wht_blk_off( |
853 | diff_weights_d, g, oc_b, ic_b, kw) |
854 | : wht_blk_off(diff_weights_d, g, |
855 | oc_b, ic_b, kh, kw)); |
856 | bool M_tail = (icb_end < ic_b + jcp.nb_ic_blocking); |
857 | bool N_tail = (ocb_end < oc_b + jcp.nb_oc_blocking); |
858 | |
859 | auto brg_idx = _pd->get_brg_idx( |
860 | bs, M_tail ? jcp.M_tail : jcp.M, do_init, N_tail, false); |
861 | |
862 | for (int ohb = 0; ohb < bs; ohb++) { |
863 | ti->brg_batch[ohb].ptr.A = (char *)ptr_A |
864 | + ohb * jcp.typesize_in * jcp.tr_iw * jcp.ic_block |
865 | * jcp.stride_h; |
866 | ti->brg_batch[ohb].ptr.B = (char *)ptr_B |
867 | + ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block; |
868 | } |
869 | |
870 | call_brgemm_kernel(*ti, brg_idx, bs, ptr_C); |
871 | }; |
872 | |
873 | if (ti->just_init_output(start, end, diff_wei, diff_bias)) return; |
874 | |
875 | while (start < end) { |
876 | const int oh_e = _pd->get_finish_oh( |
877 | oh_s, start, get_end(start, jcp.oh_block, end)); |
878 | int height_block = jcp.global_transpose ? oh_e - oh_s : optimal_spblock; |
879 | |
880 | // loop by ohb_s have only one iteration for global_transpose case |
881 | // because height_block = oh_e - oh_s |
882 | for (int ohb_s = oh_s; ohb_s < oh_e; ohb_s += height_block) { |
883 | const int ohb_e = get_end(ohb_s, height_block, oh_e); |
884 | assert(ohb_e <= jcp.oh); |
885 | |
886 | ti->maybe_global_transpose(img, |
887 | jcp.tr_ocb_chunk ? 0 : ti->oc_b_start, |
888 | jcp.tr_ocb_chunk ? 0 : ti->oc_b_end, |
889 | jcp.tr_icb_chunk ? 0 : ti->ic_b_start, |
890 | jcp.tr_icb_chunk ? 0 : ti->ic_b_end, 0, 0, 1, oh_s, ohb_s, |
891 | ohb_e); |
892 | |
893 | for_(int g = ti->g_start; g < ti->g_end; ++g) |
894 | for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; |
895 | oc_b += jcp.nb_oc_blocking) { |
896 | const int oc_b_e |
897 | = get_end(oc_b, jcp.nb_oc_blocking, ti->oc_b_end); |
898 | |
899 | if (jcp.tr_ocb_chunk) |
900 | ti->maybe_global_transpose(img, oc_b, oc_b_e, 0, 0, 0, 0, 1, |
901 | oh_s, ohb_s, ohb_e); |
902 | |
903 | for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; |
904 | ic_b += jcp.nb_ic_blocking) { |
905 | |
906 | const int ic_b_e |
907 | = get_end(ic_b, jcp.nb_ic_blocking, ti->ic_b_end); |
908 | |
909 | if (oc_b == ti->oc_b_start && jcp.tr_icb_chunk) |
910 | ti->maybe_global_transpose(img, 0, 0, ic_b, ic_b_e, 0, |
911 | 0, 1, oh_s, ohb_s, ohb_e); |
912 | |
913 | void *p_src {nullptr}; |
914 | void *p_dst {nullptr}; |
915 | ti->maybe_local_traspose(p_src, p_dst, img, g, ic_b, oc_b, |
916 | 0, 0, 1, oh_s, ohb_s, ohb_e); |
917 | |
918 | if (jcp.with_bias && ic_b == 0) { |
919 | auto bp = jit_conv_call_s(); |
920 | |
921 | bp.bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block) |
922 | + oc_b * jcp.oc_block; |
923 | bp.channel |
924 | = (start == ti->img_start) && (ohb_s == oh_s); |
925 | |
926 | bp.os_index_begin = ohb_s; |
927 | bp.os_index_end = ohb_e; |
928 | |
929 | bp.last_oc_block |
930 | = ((oc_b_e - oc_b) == jcp.nb_oc_blocking) ? 0 |
931 | : 1; |
932 | |
933 | bp.dst = p_dst; |
934 | |
935 | (*diff_bias_kernel_)(&bp); |
936 | } |
937 | |
938 | if (ti->g_start == ti->g_end |
939 | || ti->oc_b_start == ti->oc_b_end |
940 | || ti->ic_b_start == ti->ic_b_end) |
941 | continue; |
942 | |
943 | const auto do_init = (start == ti->img_start); |
944 | |
945 | for (int kh = 0; kh < jcp.kh; kh++) { |
946 | const int bs_ih_s = _pd->get_start_ih(kh, ohb_s); |
947 | const int bs_ih_e = _pd->get_finish_ih(kh, ohb_e); |
948 | const auto bs = div_up(bs_ih_e - bs_ih_s, jcp.stride_h); |
949 | if (bs == 0 && !do_init) continue; |
950 | |
951 | for_(int s = 0; s < jcp.stride_w; s++) |
952 | for (int kw = s; kw < jcp.kw; kw += jcp.stride_w) |
953 | do_brgemm_call(g, bs, ic_b, oc_b, ohb_s, bs_ih_s, |
954 | p_src, p_dst, kh, kw, do_init); |
955 | } |
956 | } |
957 | } |
958 | } |
959 | |
960 | nd_iterator_jump(start, get_end(start, jcp.oh_block, end), img, jcp.mb, |
961 | oh_s, jcp.oh); |
962 | } |
963 | } |
964 | |
965 | void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d( |
966 | thread_info_t *ti) const { |
967 | |
968 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
969 | const auto _pd = pd(); |
970 | const auto &jcp = _pd->jcp_; |
971 | |
972 | const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic |
973 | * jcp.ic_block * jcp.kd * jcp.kh * jcp.kw; |
974 | const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
975 | const int optimal_spblock = jcp.spatial_blk_size; |
976 | |
977 | float *diff_wei; |
978 | if (diff_weights_d.data_type() != data_type::f32) |
979 | diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size; |
980 | else |
981 | diff_wei = ti->ithr_mb == 0 |
982 | ? (float *)ti->diff_weights |
983 | : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; |
984 | |
985 | float *diff_bias = nullptr; |
986 | if (jcp.with_bias) { |
987 | if (jcp.bia_dt != data_type::f32) |
988 | diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size; |
989 | else |
990 | diff_bias = ti->ithr_mb == 0 |
991 | ? (float *)ti->diff_bias |
992 | : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size; |
993 | } |
994 | |
995 | int img {0}, od_s {0}; |
996 | int start = ti->img_start; |
997 | int end = ti->img_end; |
998 | |
999 | int img_s {0}; |
1000 | nd_iterator_init(start, img_s, jcp.mb, od_s, jcp.od); |
1001 | img = img_s; |
1002 | |
1003 | auto do_brgemm_call = [&](int g, int bs_d, int bs_h, int ic_b, int oc_b, |
1004 | int od_s, int oh_s, int bs_id_s, int bs_ih_s, |
1005 | const void *p_src, const void *p_dst, int kd, |
1006 | int kh, int kw, bool do_init) { |
1007 | const int id_s = ti->get_id_start(od_s); |
1008 | const int ih_s = ti->get_ih_start(oh_s); |
1009 | |
1010 | const int bs_od_s = utils::saturate(0, jcp.od, |
1011 | (bs_id_s + jcp.f_pad - kd * (jcp.dilate_d + 1)) / jcp.stride_d); |
1012 | |
1013 | const int bs_oh_s = utils::saturate(0, jcp.oh, |
1014 | (bs_ih_s + jcp.t_pad - kh * (jcp.dilate_h + 1)) / jcp.stride_h); |
1015 | |
1016 | auto ocb_end = get_end(oc_b, jcp.nb_oc_blocking, ti->oc_b_end); |
1017 | auto icb_end = get_end(ic_b, jcp.nb_ic_blocking, ti->ic_b_end); |
1018 | const int src_stride_w_shift = jcp.tr_iw / jcp.stride_w; |
1019 | const void *ptr_A = ((src_data_t *)p_src) |
1020 | + _pd->filter_w_to_src(kw) / jcp.stride_w |
1021 | + (kw % jcp.stride_w) * src_stride_w_shift |
1022 | + (bs_ih_s - ih_s) * jcp.tr_iw * jcp.ic_block |
1023 | + (bs_id_s - id_s) * jcp.ih * jcp.tr_iw * jcp.ic_block; |
1024 | const void *ptr_B = ((diff_dst_data_t *)p_dst) |
1025 | + (bs_oh_s - oh_s) * jcp.tr_ow * jcp.oc_block |
1026 | + (bs_od_s - od_s) * jcp.oh * jcp.tr_ow * jcp.oc_block; |
1027 | void *ptr_C = (jcp.transform_to_vnni) |
1028 | ? diff_wei + wei_offset_int(g, oc_b, ic_b, kd, kh, kw) |
1029 | : diff_wei |
1030 | + wht_blk_off( |
1031 | diff_weights_d, g, oc_b, ic_b, kd, kh, kw); |
1032 | bool M_tail = (icb_end < ic_b + jcp.nb_ic_blocking); |
1033 | bool N_tail = (ocb_end < oc_b + jcp.nb_oc_blocking); |
1034 | |
1035 | const auto bs = bs_d * bs_h; |
1036 | auto brg_idx = _pd->get_brg_idx( |
1037 | bs, M_tail ? jcp.M_tail : jcp.M, do_init, N_tail, false); |
1038 | |
1039 | for (int odb = 0; odb < bs_d; odb++) { |
1040 | for (int ohb = 0; ohb < bs_h; ohb++) { |
1041 | ti->brg_batch[odb * bs_h + ohb].ptr.A = (char *)ptr_A |
1042 | + ohb * jcp.typesize_in * jcp.tr_iw * jcp.ic_block |
1043 | * jcp.stride_h |
1044 | + odb * jcp.typesize_in * jcp.ih * jcp.tr_iw |
1045 | * jcp.ic_block * jcp.stride_d; |
1046 | ti->brg_batch[odb * bs_h + ohb].ptr.B = (char *)ptr_B |
1047 | + ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block |
1048 | + odb * jcp.typesize_in * jcp.oh * jcp.tr_ow |
1049 | * jcp.oc_block; |
1050 | } |
1051 | } |
1052 | |
1053 | call_brgemm_kernel(*ti, brg_idx, bs, ptr_C); |
1054 | }; |
1055 | |
1056 | if (ti->just_init_output(start, end, diff_wei, diff_bias)) return; |
1057 | |
1058 | const auto oh_s = 0; |
1059 | const auto oh_e = jcp.oh; |
1060 | |
1061 | while (start < end) { |
1062 | const int od_e = _pd->get_finish_od( |
1063 | od_s, start, get_end(start, jcp.od_block, end)); |
1064 | int sp_block = jcp.global_transpose ? od_e - od_s : optimal_spblock; |
1065 | |
1066 | // loop by odb_s have only one iteration for global_transpose case |
1067 | // because sp_block = od_e - od_s |
1068 | for (int odb_s = od_s; odb_s < od_e; odb_s += sp_block) { |
1069 | const int odb_e = get_end(odb_s, sp_block, od_e); |
1070 | assert(odb_e <= jcp.od); |
1071 | |
1072 | for (int ohb_s = oh_s; ohb_s < oh_e; ohb_s += jcp.oh_block) { |
1073 | const auto ohb_e = get_end(ohb_s, jcp.oh_block, jcp.oh); |
1074 | |
1075 | ti->maybe_global_transpose(img, |
1076 | jcp.tr_ocb_chunk ? 0 : ti->oc_b_start, |
1077 | jcp.tr_ocb_chunk ? 0 : ti->oc_b_end, |
1078 | jcp.tr_icb_chunk ? 0 : ti->ic_b_start, |
1079 | jcp.tr_icb_chunk ? 0 : ti->ic_b_end, od_s, odb_s, odb_e, |
1080 | oh_s, ohb_s, ohb_e); |
1081 | |
1082 | for_(int g = ti->g_start; g < ti->g_end; ++g) |
1083 | for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; |
1084 | oc_b += jcp.nb_oc_blocking) { |
1085 | const int oc_b_e |
1086 | = get_end(oc_b, jcp.nb_oc_blocking, ti->oc_b_end); |
1087 | if (jcp.tr_ocb_chunk) |
1088 | ti->maybe_global_transpose(img, oc_b, oc_b_e, 0, 0, |
1089 | od_s, odb_s, odb_e, oh_s, ohb_s, ohb_e); |
1090 | |
1091 | for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; |
1092 | ic_b += jcp.nb_ic_blocking) { |
1093 | |
1094 | const int ic_b_e = get_end( |
1095 | ic_b, jcp.nb_ic_blocking, ti->ic_b_end); |
1096 | |
1097 | if (oc_b == ti->oc_b_start && jcp.tr_icb_chunk) |
1098 | ti->maybe_global_transpose(img, 0, 0, ic_b, ic_b_e, |
1099 | od_s, odb_s, odb_e, oh_s, ohb_s, ohb_e); |
1100 | |
1101 | void *p_src {nullptr}; |
1102 | void *p_dst {nullptr}; |
1103 | ti->maybe_local_traspose(p_src, p_dst, img, g, ic_b, |
1104 | oc_b, od_s, odb_s, odb_e, oh_s, ohb_s, ohb_e); |
1105 | |
1106 | if (jcp.with_bias && ic_b == 0) { |
1107 | for (int iodb = odb_s; iodb < odb_e; iodb++) { |
1108 | auto bp = jit_conv_call_s(); |
1109 | |
1110 | bp.bias = diff_bias |
1111 | + g * rnd_up(jcp.oc, jcp.oc_block) |
1112 | + oc_b * jcp.oc_block; |
1113 | bp.os_index_begin = ohb_s; |
1114 | bp.os_index_end = ohb_e; |
1115 | |
1116 | bp.last_oc_block |
1117 | = ((oc_b_e - oc_b) |
1118 | == jcp.nb_oc_blocking) |
1119 | ? 0 |
1120 | : 1; |
1121 | |
1122 | bp.channel = (start == ti->img_start) |
1123 | && (odb_s == od_s) && (iodb == odb_s) |
1124 | && (ohb_s == oh_s); |
1125 | bp.dst = ((diff_dst_data_t *)p_dst) |
1126 | + (iodb - od_s) * jcp.oh * jcp.tr_ow |
1127 | * jcp.oc_block |
1128 | + (ohb_s - oh_s) * jcp.tr_ow |
1129 | * jcp.oc_block; |
1130 | (*diff_bias_kernel_)(&bp); |
1131 | } |
1132 | } |
1133 | |
1134 | if (ti->g_start == ti->g_end |
1135 | || ti->oc_b_start == ti->oc_b_end |
1136 | || ti->ic_b_start == ti->ic_b_end) |
1137 | continue; |
1138 | |
1139 | const auto do_init |
1140 | = (start == ti->img_start && ohb_s == oh_s); |
1141 | |
1142 | for (int kd = 0; kd < jcp.kd; kd++) { |
1143 | const int bs_id_s = _pd->get_start_id(kd, odb_s); |
1144 | const int bs_id_e = _pd->get_finish_id(kd, odb_e); |
1145 | const auto bs_d |
1146 | = div_up(bs_id_e - bs_id_s, jcp.stride_d); |
1147 | // bs_d may be 0 but we may still need to call brgemm to |
1148 | // initialize output |
1149 | if (bs_d == 0 && !do_init) continue; |
1150 | |
1151 | for (int kh = 0; kh < jcp.kh; kh++) { |
1152 | const int bs_ih_s |
1153 | = _pd->get_start_ih(kh, ohb_s); |
1154 | const int bs_ih_e |
1155 | = _pd->get_finish_ih(kh, ohb_e); |
1156 | const auto bs_h = div_up( |
1157 | bs_ih_e - bs_ih_s, jcp.stride_h); |
1158 | if (bs_h == 0 && !do_init) continue; |
1159 | |
1160 | for_(int s = 0; s < jcp.stride_w; s++) |
1161 | for (int kw = s; kw < jcp.kw; |
1162 | kw += jcp.stride_w) |
1163 | do_brgemm_call(g, bs_d, bs_h, ic_b, oc_b, |
1164 | od_s, oh_s, bs_id_s, bs_ih_s, p_src, |
1165 | p_dst, kd, kh, kw, do_init); |
1166 | } |
1167 | } |
1168 | } |
1169 | } |
1170 | } |
1171 | } |
1172 | |
1173 | nd_iterator_jump(start, get_end(start, jcp.od_block, end), img, jcp.mb, |
1174 | od_s, jcp.od); |
1175 | } |
1176 | } |
1177 | |
1178 | void brgemm_convolution_bwd_weights_t::store_in_vnni_format( |
1179 | thread_info_t *ti) const { |
1180 | const auto &jcp = pd()->jcp_; |
1181 | |
1182 | for_(int g = ti->g_start; g < ti->g_end; g++) |
1183 | for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++) |
1184 | for_(int ic_b = ti->ic_b_start; ic_b < ti->ic_b_start + ti->ic_b_work; |
1185 | ic_b += 2) |
1186 | { |
1187 | jit_conv_call_s p = jit_conv_call_s(); |
1188 | |
1189 | bfloat16_t *output = (bfloat16_t *)ti->diff_weights |
1190 | + wei_offset_ext(g, oc_b, (ic_b / 2), 0); |
1191 | float *input = ti->wei_bia_reduction + wei_offset_int(g, oc_b, ic_b, 0); |
1192 | |
1193 | p.src = (void *)input; |
1194 | p.dst = (void *)output; |
1195 | p.last_ic_block = ((ic_b + 1) >= jcp.nb_ic) ? 1 : 0; |
1196 | (*diff_wei_trans_kernel_)(&p); |
1197 | } |
1198 | } |
1199 | |
1200 | void brgemm_convolution_bwd_weights_t::reduce_and_convert_diff_weights_and_bias( |
1201 | thread_info_t *ti) const { |
1202 | |
1203 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
1204 | |
1205 | const auto &jcp = pd()->jcp_; |
1206 | const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic |
1207 | * jcp.ic_block * jcp.kh * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1); |
1208 | |
1209 | const auto wei_dt = diff_weights_d.data_type(); |
1210 | const auto bia_dt = jcp.bia_dt; |
1211 | const bool is_f32_out = wei_dt == data_type::f32; |
1212 | const bool is_f32_bias = bia_dt == data_type::f32; |
1213 | |
1214 | if (jcp.nthr_mb == 1) { |
1215 | if (!is_f32_out) { |
1216 | // reduction is not required, only conversion |
1217 | if (jcp.transform_to_vnni) { |
1218 | store_in_vnni_format(ti); |
1219 | } else { |
1220 | for_(int g = ti->g_start; g < ti->g_end; g++) |
1221 | for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++) { |
1222 | const size_t acc_size = (size_t)ti->ic_b_work * jcp.kh |
1223 | * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1) |
1224 | * jcp.ic_block * jcp.oc_block; |
1225 | const size_t off = wht_blk_off( |
1226 | diff_weights_d, g, oc_b, ti->ic_b_start); |
1227 | types::cvt_from_float(wei_dt, |
1228 | (void *)((char *)ti->diff_weights |
1229 | + off * types::data_type_size(wei_dt)), |
1230 | (ti->wei_bia_reduction + off), acc_size); |
1231 | } |
1232 | } |
1233 | } |
1234 | if (pd()->with_bias() && !is_f32_bias && ti->ithr_ic_b == 0 |
1235 | && ti->ic_b_work > 0) { |
1236 | for (int g = ti->g_start; g < ti->g_end; g++) { |
1237 | int result_start_idx |
1238 | = g * jcp.oc + ti->oc_b_start * jcp.oc_block; |
1239 | int buffer_start_idx = g * rnd_up(jcp.oc, jcp.oc_block) |
1240 | + ti->oc_b_start * jcp.oc_block; |
1241 | const size_t acc_size |
1242 | = nstl::min(jcp.oc, ti->oc_b_end * jcp.oc_block) |
1243 | - ti->oc_b_start * jcp.oc_block; |
1244 | void *diff_bias = (char *)ti->diff_bias |
1245 | + result_start_idx * types::data_type_size(bia_dt); |
1246 | float *buffer = ti->bia_reduction + buffer_start_idx; |
1247 | types::cvt_from_float( |
1248 | bia_dt, diff_bias, (const float *)buffer, acc_size); |
1249 | } |
1250 | } |
1251 | return; |
1252 | } |
1253 | |
1254 | /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */ |
1255 | if (jcp.global_transpose) |
1256 | simple_barrier::barrier(ti->wei_bia_reduction_bctx, jcp.nthr); |
1257 | |
1258 | const int ic_b_kh_work |
1259 | = ti->ic_b_work * ((jcp.ndims == 5) ? jcp.kd : jcp.kh); |
1260 | if (ic_b_kh_work <= 0 || ti->oc_b_work == 0 || ti->g_work == 0) { |
1261 | // TODO: double check if a barrier is needed here |
1262 | // and at the end of function |
1263 | if (jcp.transform_to_vnni && jcp.global_transpose) |
1264 | simple_barrier::barrier(ti->wei_bia_reduction_bctx, jcp.nthr); |
1265 | return; |
1266 | } |
1267 | |
1268 | const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work; |
1269 | |
1270 | int start {0}, end {0}; |
1271 | balance211(work, jcp.nthr_mb, ti->ithr_mb, start, end); |
1272 | if (!jcp.transform_to_vnni && start == end) return; |
1273 | |
1274 | const int _start_nthr_mb = 1; |
1275 | for (int thr_mb = _start_nthr_mb; thr_mb < jcp.nthr_mb; ++thr_mb) { |
1276 | int w = start; |
1277 | int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0}; |
1278 | nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start, |
1279 | ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); |
1280 | while (w < end) { |
1281 | const int g = ti->g_start + sub_g_start; |
1282 | const int oc_b = ti->oc_b_start + sub_oc_b_start; |
1283 | const int ic_b = ti->ic_b_start |
1284 | + sub_ic_b_kh_start / ((jcp.ndims == 5) ? jcp.kd : jcp.kh); |
1285 | const int kX |
1286 | = sub_ic_b_kh_start % ((jcp.ndims == 5) ? jcp.kd : jcp.kh); |
1287 | |
1288 | const size_t acc_size = (size_t)jcp.kw * jcp.ic_block * jcp.oc_block |
1289 | * ((jcp.ndims == 5) ? jcp.kh : 1) |
1290 | * nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start); |
1291 | |
1292 | const size_t off_ext |
1293 | = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kX); |
1294 | const size_t off_int = (jcp.transform_to_vnni) |
1295 | ? wei_offset_int(g, oc_b, ic_b, kX) |
1296 | : off_ext; |
1297 | |
1298 | float *wei_reduced = is_f32_out |
1299 | ? (float *)(ti->diff_weights) + off_ext |
1300 | : ti->wei_bia_reduction + off_int; |
1301 | |
1302 | int thr_mb_buffer_idx = is_f32_out ? thr_mb - 1 : thr_mb; |
1303 | float *wei_to_reduce = ti->wei_bia_reduction |
1304 | + thr_mb_buffer_idx * wei_size + off_int; |
1305 | |
1306 | if (!jcp.transform_to_vnni && !is_f32_out |
1307 | && thr_mb == jcp.nthr_mb - 1) { |
1308 | // the last iteration for bfloat16 requires conversion and |
1309 | // store to diff_weights array |
1310 | if (wei_dt == bf16) |
1311 | add_floats_and_cvt_to_bfloat16( |
1312 | (bfloat16_t *)(ti->diff_weights) + off_ext, |
1313 | wei_reduced, wei_to_reduce, acc_size); |
1314 | else if (wei_dt == f16) |
1315 | add_floats_and_cvt_to_float16( |
1316 | (float16_t *)(ti->diff_weights) + off_ext, |
1317 | wei_reduced, wei_to_reduce, acc_size); |
1318 | } else |
1319 | acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size); |
1320 | |
1321 | nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start, |
1322 | ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work); |
1323 | } |
1324 | if (jcp.with_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0 |
1325 | && ti->ithr_mb == 0 && ti->img_work > 0) { |
1326 | for (int g = ti->g_start; g < ti->g_end; g++) { |
1327 | float *bias_reduced = is_f32_bias ? (float *)(ti->diff_bias) |
1328 | : ti->bia_reduction; |
1329 | int thr_mb_buffer_idx = is_f32_bias ? thr_mb - 1 : thr_mb; |
1330 | int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
1331 | float *bias_to_reduce |
1332 | = ti->bia_reduction + thr_mb_buffer_idx * bias_buf_size; |
1333 | const size_t acc_size |
1334 | = nstl::min(jcp.oc, ti->oc_b_end * jcp.oc_block) |
1335 | - ti->oc_b_start * jcp.oc_block; |
1336 | int idx = g * rnd_up(jcp.oc, jcp.oc_block) |
1337 | + ti->oc_b_start * jcp.oc_block; |
1338 | if (!is_f32_bias && thr_mb == jcp.nthr_mb - 1) { |
1339 | // the last iteration for bfloat16 requires conversion and |
1340 | // store to diff_weights array |
1341 | int diff_bias_idx |
1342 | = g * jcp.oc + ti->oc_b_start * jcp.oc_block; |
1343 | if (bia_dt == bf16) |
1344 | add_floats_and_cvt_to_bfloat16( |
1345 | (bfloat16_t *)(ti->diff_bias) + diff_bias_idx, |
1346 | &bias_reduced[idx], &bias_to_reduce[idx], |
1347 | acc_size); |
1348 | else if (bia_dt == f16) |
1349 | add_floats_and_cvt_to_float16( |
1350 | (float16_t *)(ti->diff_bias) + diff_bias_idx, |
1351 | &bias_reduced[idx], &bias_to_reduce[idx], |
1352 | acc_size); |
1353 | } else { |
1354 | acc_ker_->accumulate( |
1355 | &bias_reduced[idx], &bias_to_reduce[idx], acc_size); |
1356 | } |
1357 | } |
1358 | } |
1359 | } |
1360 | |
1361 | if (jcp.transform_to_vnni && jcp.global_transpose) { |
1362 | simple_barrier::barrier(ti->wei_bia_reduction_bctx, jcp.nthr); |
1363 | store_in_vnni_format(ti); |
1364 | } |
1365 | } |
1366 | |
1367 | void brgemm_convolution_bwd_weights_t::prepare_scratchpad_data( |
1368 | const exec_ctx_t &ctx) const { |
1369 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1370 | |
1371 | const auto &jcp = pd()->jcp_; |
1372 | |
1373 | auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src); |
1374 | // Zero out guard elements that cross a buffer boundary to prevent a |
1375 | // race condition due to buffer overflows from memory optimization where |
1376 | // buffers sharing padding |
1377 | // TODO: optimize it |
1378 | for (size_t isb = 1; isb <= jcp.tr_src_buf_count; ++isb) { |
1379 | src_data_t *ts |
1380 | = &tr_src[isb * jcp.tr_src_buf_size * jcp.nb_ic_blocking]; |
1381 | for (int i = 0; i < jcp.tr_src_num_guard_elems; ++i) |
1382 | ts[i] = 0; |
1383 | } |
1384 | |
1385 | if (jcp.global_transpose && jcp.nthr_oc_b > 1) { |
1386 | const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; |
1387 | auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>( |
1388 | key_conv_tr_src_bctx); |
1389 | for (int i = 0; i < tr_src_bctx_size; ++i) |
1390 | simple_barrier::ctx_init(&tr_src_bctx[i]); |
1391 | } |
1392 | if (jcp.global_transpose) { |
1393 | if (jcp.nthr_ic_b > 1) { |
1394 | const int tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b; |
1395 | auto tr_diff_dst_bctx |
1396 | = scratchpad.template get<simple_barrier::ctx_t>( |
1397 | key_conv_tr_diff_dst_bctx); |
1398 | for (int i = 0; i < tr_diff_dst_bctx_size; ++i) |
1399 | simple_barrier::ctx_init(&tr_diff_dst_bctx[i]); |
1400 | } |
1401 | } |
1402 | |
1403 | if (jcp.nthr_mb > 1 |
1404 | || pd()->diff_weights_md(0)->data_type != data_type::f32) { |
1405 | // TODO: don't use barrier for case |
1406 | // diff_weights_type != data_type::f32 && nthr_mb_ == 1 |
1407 | simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>( |
1408 | key_conv_wei_bia_reduction_bctx)); |
1409 | } |
1410 | } |
1411 | |
1412 | void brgemm_convolution_bwd_weights_t::execute_backward_weights( |
1413 | const exec_ctx_t &ctx) const { |
1414 | prepare_scratchpad_data(ctx); |
1415 | |
1416 | const auto &jcp = pd()->jcp_; |
1417 | |
1418 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1419 | assert(jcp.nthr == nthr); |
1420 | assert(utils::one_of(pd()->ndims(), 3, 4, 5)); |
1421 | |
1422 | thread_info_t thread_info(this, ctx, ithr); |
1423 | switch (jcp.harness) { |
1424 | case harness_2d_reduction: |
1425 | compute_diff_weights_2d(&thread_info); |
1426 | if (jcp.global_transpose) |
1427 | reduce_and_convert_diff_weights_and_bias(&thread_info); |
1428 | break; |
1429 | case harness_3d_reduction: |
1430 | compute_diff_weights_3d(&thread_info); |
1431 | if (jcp.global_transpose) |
1432 | reduce_and_convert_diff_weights_and_bias(&thread_info); |
1433 | break; |
1434 | default: assert(!"Invalid harness type" ); |
1435 | } |
1436 | |
1437 | amx_tile_release(); |
1438 | }); |
1439 | |
1440 | if (!jcp.global_transpose) { |
1441 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1442 | assert(jcp.nthr == nthr); |
1443 | thread_info_t thread_info(this, ctx, ithr); |
1444 | reduce_and_convert_diff_weights_and_bias(&thread_info); |
1445 | }); |
1446 | } |
1447 | |
1448 | if (jcp.transform_to_vnni && !jcp.global_transpose) { |
1449 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1450 | assert(jcp.nthr == nthr); |
1451 | thread_info_t thread_info(this, ctx, ithr); |
1452 | store_in_vnni_format(&thread_info); |
1453 | }); |
1454 | } |
1455 | |
1456 | if (pd()->with_bias() && (jcp.oc % jcp.oc_block != 0) |
1457 | && jcp.bia_dt == data_type::f32) { |
1458 | auto diff_bias = ctx.get_scratchpad_grantor().template get<const float>( |
1459 | key_conv_padded_bias); |
1460 | auto diff_bias_in = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS); |
1461 | const int padded_stride = rnd_up(jcp.oc, jcp.oc_block); |
1462 | const int stride = jcp.oc; |
1463 | for (int g = 0; g < jcp.ngroups; ++g) { |
1464 | utils::array_copy(diff_bias_in + g * stride, |
1465 | diff_bias + g * padded_stride, stride); |
1466 | } |
1467 | } |
1468 | } |
1469 | |
1470 | } // namespace x64 |
1471 | |
1472 | } // namespace cpu |
1473 | } // namespace impl |
1474 | } // namespace dnnl |
1475 | |
1476 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1477 | |