1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/cpu_primitive.hpp" |
23 | #include "cpu/scale_utils.hpp" |
24 | |
25 | #include "cpu/x64/amx_tile_configure.hpp" |
26 | #include "cpu/x64/cpu_barrier.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
28 | #include "cpu/x64/jit_brgemm_inner_product.hpp" |
29 | #include "cpu/x64/jit_transpose_utils.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
36 | using namespace dnnl::impl::cpu::x64::brgemm_inner_product_utils; |
37 | using namespace dnnl::impl::data_type; |
38 | using namespace dnnl::impl::format_tag; |
39 | using namespace dnnl::impl::memory_tracking::names; |
40 | using namespace dnnl::impl::status; |
41 | using namespace dnnl::impl::utils; |
42 | |
43 | using namespace nstl; |
44 | |
45 | #define get_blk_off(d, dt, ...) \ |
46 | (types::data_type_size((dt)) * (d).blk_off(__VA_ARGS__)) |
47 | |
48 | namespace { |
49 | template <typename ker_type> |
50 | void copy_data_chunk(ker_type &ker, char *tr_data, const char *data, |
51 | int os_work, bool is_last_blk) { |
52 | auto ctx = jit_brgemm_copy_to_coarse_t::ctx_t(); |
53 | ctx.data = (void *)data; |
54 | ctx.tr_data = (void *)tr_data; |
55 | ctx.os_work = os_work; |
56 | ctx.last_row_blk = is_last_blk ? 1 : 0; |
57 | (*ker)(&ctx); |
58 | } |
59 | } // namespace |
60 | |
// Forward inner-product execution: drives pre-generated BRGEMM kernels over
// blocked [os x oc x ic] chunks of src/weights/dst, applying bias, scales and
// post-ops where configured. When the IC dimension is split across threads
// (jbgp.nthr_ic_b > 1) a second parallel pass reduces per-thread partial
// accumulators into dst and applies post-ops there instead.
template <cpu_isa_t isa>
status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    // Right-hand-side arguments (per post-op) for the binary post-op injector.
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(
                    pd()->attr()->post_ops_, ctx);

    memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    // Blocking/config parameters chosen at primitive-descriptor creation.
    const auto &jbgp = pd()->jbgp_;

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);

    // Combined per-OC output scales (src x wei), precomputed into scratchpad.
    const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
            src_scales, wei_scales, pd()->OC(), pd()->attr());

    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);

    const size_t src_dt_size = types::data_type_size(jbgp.src_dt);
    const size_t bia_dt_size
            = jbgp.with_bias ? types::data_type_size(jbgp.bia_dt) : 0;
    const size_t acc_dt_size = types::data_type_size(jbgp.acc_dt);
    const size_t dst_dt_size = types::data_type_size(jbgp.dst_dt);

    // Scratchpad areas: per-thread BRGEMM batch descriptors, optional copy
    // buffer for A (src), optional accumulation buffer C (in acc_dt), and
    // per-thread AMX tile workspace.
    auto addr_batch_global = scratchpad.template get<brgemm_batch_element_t>(
            key_brgemm_primitive_batch);
    auto a_buffer_global = (jbgp.use_buffer_a)
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer_a)
            : nullptr;
    auto c_buffer_global = (jbgp.use_buffer)
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
            : nullptr;
    const bool is_amx = jbgp.is_amx;
    auto wsp_tile_base = is_amx
            ? ctx.get_scratchpad_grantor().template get<char>(
                    key_conv_amx_tile_buffer)
            : nullptr;

    const int ic_chunks = div_up(jbgp.nb_ic, jbgp.nb_ic_blocking);

    // Any of these conditions forces the post-op path of the BRGEMM kernel
    // (bias/scales/eltwise/binary/sum, dst-type conversion, s8s8 compensation).
    const bool are_post_ops_applicable = one_of(true, jbgp.with_sum,
            jbgp.with_bias, jbgp.with_scales, jbgp.with_eltwise,
            jbgp.with_binary, jbgp.acc_dt != jbgp.dst_dt, jbgp.signed_input);

    // For signed input, s8s8 compensation values are stored right after the
    // weights payload (see additional_buffer_size on the weights descriptor).
    size_t offset = types::data_type_size(jbgp.wei_dt)
            * (weights_d.size() - weights_d.additional_buffer_size());
    auto compensation = (jbgp.signed_input)
            ? reinterpret_cast<const int32_t *>(&weights[offset])
            : nullptr;

    bool is_os_tail = (jbgp.mb < jbgp.os_block);
    bool is_oc_tail = (jbgp.oc < jbgp.oc_block);
    int base_brg_ker_idx = brgemm_inner_product_utils::
            get_brg_kernel_index( // TODO: Can be calculated on initialization stage
                    jbgp, false, false, is_os_tail, is_oc_tail, false);

    // Computes one (os-block x oc-block) piece of the output for one IC chunk.
    //   ithr_oc_mb / nthr_oc_mb : thread id/count in the spatial(os) x oc group
    //   ithr_ic                 : thread id along the IC-split dimension
    //   n, ocb, icc             : os offset, oc block index, ic chunk index
    //   do_init                 : true -> initialize (beta = 0) accumulation
    //   buffer_a_osb            : os-block index inside this thread's A buffer
    //   copy_buffer_a           : whether to fill the A copy buffer first
    const auto ker = [&](int ithr_oc_mb, int nthr_oc_mb, int ithr_ic, int n,
                             int ocb, int icc, bool do_init, int buffer_a_osb,
                             bool copy_buffer_a) {
        const int ithr = nthr_oc_mb * ithr_ic + ithr_oc_mb;
        auto addr_batch = addr_batch_global + ithr * jbgp.adjusted_batch_size;

        const size_t a_buffer_osb_stride
                = src_dt_size * jbgp.LDA * jbgp.os_block;
        const size_t a_buffer_per_thr
                = a_buffer_osb_stride * jbgp.nb_os_blocking;
        auto a_buffer = (jbgp.use_buffer_a)
                ? a_buffer_global + ithr * a_buffer_per_thr
                        + buffer_a_osb * a_buffer_osb_stride
                : nullptr;

        const int oc = ocb * jbgp.oc_block;
        const size_t dst_off = get_blk_off(dst_d, jbgp.dst_dt, n, oc);

        // C buffer is needed for sum post-op (dst must be preserved for the
        // sum term) or when partial results are accumulated in acc_dt.
        const bool use_c_buffer = (jbgp.with_sum)
                || (jbgp.use_buffer && (jbgp.nthr_ic_b == 1 || ithr_ic > 0));

        char *c_buffer = nullptr;
        if (use_c_buffer) {
            // IC-split case: buffers are indexed by IC-thread id; thread 0
            // writes directly to dst unless dtype conversion/sum requires a
            // buffer of its own (hence the ithr_ic vs ithr_ic - 1 choice).
            const size_t c_buf_thr_idx = jbgp.nthr_ic_b <= 1
                    ? ithr
                    : (jbgp.acc_dt != jbgp.dst_dt || jbgp.with_sum
                                    ? ithr_ic
                                    : ithr_ic - 1);
            const size_t c_buf_num_rows = jbgp.nthr_ic_b > 1 ? jbgp.mb : jbgp.M;
            const size_t c_buffer_shift
                    = c_buf_thr_idx * c_buf_num_rows * jbgp.LDC;
            // In the IC-split case the buffer mirrors dst layout, so reuse
            // dst_off rescaled from dst_dt elements to acc_dt elements.
            const size_t c_buffer_off = acc_dt_size * c_buffer_shift
                    + (jbgp.nthr_ic_b > 1 ? acc_dt_size * dst_off / dst_dt_size
                                          : 0);
            c_buffer = c_buffer_global + c_buffer_off;
        }

        char *wsp_tile = is_amx
                ? wsp_tile_base + ithr * jbgp.amx_buf_size_per_thread
                : nullptr;
        int icb = icc * jbgp.nb_ic_blocking;
        int ic = icb * jbgp.ic_block;

        bool kernel_init = do_init;

        bool is_os_tail = (jbgp.mb - n < jbgp.os_block);
        bool is_oc_tail = (jbgp.oc - oc < jbgp.oc_block);
        bool is_last_ic_chunk = icc == ic_chunks - 1;
        bool is_ic_tail = is_last_ic_chunk && jbgp.K_tail > 0;
        // With a copy buffer the A side is zero-padded up to a full ic_block,
        // so the tail is consumed by full-size kernels.
        const int remaining_ic_blks
                = (jbgp.use_buffer_a ? utils::rnd_up(jbgp.ic, jbgp.ic_block)
                                     : jbgp.ic)
                - ic;
        const int gemm_batch
                = nstl::min(jbgp.gemm_batch_size, remaining_ic_blks / jbgp.K);

        auto is_bs_tail = (gemm_batch != jbgp.gemm_batch_size);
        int brg_ker_idx = brgemm_inner_product_utils::get_brg_kernel_index(
                jbgp, is_bs_tail, kernel_init, is_os_tail, is_oc_tail, false);
        auto brg_kernel = brg_kernels_[brg_ker_idx].get();

        if (copy_buffer_a) {
            assert(!jbgp.is_bf32);
            auto src_ptr = src + get_blk_off(src_d, jbgp.src_dt, n, ic);
            copy_data_chunk(copy_src_kernel_, a_buffer, src_ptr,
                    is_os_tail ? jbgp.mb - n : jbgp.os_block, is_last_ic_chunk);
        }
        if (gemm_batch > 0 && brg_kernel != nullptr) {
            // Tail kernels use a different tile shape; reconfigure AMX tiles
            // before the call and restore the base palette afterwards.
            if (is_amx && (is_os_tail || is_oc_tail))
                amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
            const int ic_blocks_per_batch = jbgp.K / jbgp.ic_block;
            // Fill the batch descriptor: one A/B pointer pair per K step.
            for (int b = 0; b < gemm_batch; b++) {
                auto A_ptr = jbgp.use_buffer_a
                        ? (a_buffer + src_dt_size * b * jbgp.K)
                        : (src
                                + get_blk_off(src_d, jbgp.src_dt, n,
                                        ic + b * jbgp.K));
                addr_batch[b].ptr.A = A_ptr;
                addr_batch[b].ptr.B = weights
                        + get_blk_off(weights_d, jbgp.wei_dt, ocb,
                                icb + b * ic_blocks_per_batch);
            }

            auto ptr_D = dst + dst_off;
            auto ptr_C = use_c_buffer ? c_buffer : ptr_D;

            // Post-ops are fused into the kernel call only when this call
            // finishes the IC accumulation and IC is not split across threads
            // (otherwise post-ops happen in the reduction pass below).
            if (jbgp.nthr_ic_b == 1 && are_post_ops_applicable
                    && is_last_ic_chunk && !is_ic_tail) {
                void *scratch = is_amx
                        ? static_cast<void *>(wsp_tile)
                        : (jbgp.signed_input ? static_cast<void *>(
                                   const_cast<int *>(&compensation[oc]))
                                             : nullptr);
                auto ptr_bias
                        = jbgp.with_bias ? bias + bia_dt_size * oc : nullptr;
                const brgemm_post_ops_data_t post_ops_data {
                        static_cast<const void *>(ptr_bias),
                        &oscales[jbgp.is_oc_scale * oc],
                        post_ops_binary_rhs_arg_vec.data(),
                        static_cast<size_t>(oc), 0, dst};

                brgemm_kernel_execute_postops(brg_kernel, gemm_batch,
                        addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data,
                        scratch);
            } else {
                brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
                        (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
            }

            if (is_amx && (is_os_tail || is_oc_tail))
                amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
        }

        // Handle the IC tail (K_tail) with a dedicated tail kernel.
        if (is_ic_tail) {
            assert(!jbgp.use_buffer_a);
            int ic_block = gemm_batch * jbgp.K / jbgp.ic_block;
            addr_batch[0].ptr.A = src
                    + get_blk_off(src_d, jbgp.src_dt, n,
                            ic + ic_block * jbgp.ic_block);
            addr_batch[0].ptr.B = weights
                    + get_blk_off(weights_d, jbgp.wei_dt, ocb, icb + ic_block);

            // Only initialize C here if no main-batch call did it already.
            auto use_init_ker = (kernel_init && gemm_batch == 0);
            int brg_ker_idx = brgemm_inner_product_utils::get_brg_kernel_index(
                    jbgp, false, use_init_ker, is_os_tail, is_oc_tail, true);
            auto brg_kernel_ic_tail = brg_kernels_[brg_ker_idx].get();
            if (is_amx)
                amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
            auto ptr_D = dst + dst_off;
            auto ptr_C = use_c_buffer ? c_buffer : ptr_D;
            if (jbgp.nthr_ic_b == 1 && are_post_ops_applicable) {
                void *scratch = is_amx
                        ? static_cast<void *>(wsp_tile)
                        : (jbgp.signed_input ? static_cast<void *>(
                                   const_cast<int *>(&compensation[oc]))
                                             : nullptr);
                auto ptr_bias
                        = jbgp.with_bias ? bias + bia_dt_size * oc : nullptr;
                const brgemm_post_ops_data_t post_ops_data {
                        static_cast<const void *>(ptr_bias),
                        &oscales[jbgp.is_oc_scale * oc],
                        post_ops_binary_rhs_arg_vec.data(),
                        static_cast<size_t>(oc), 0, dst};

                brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch,
                        (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
            } else {
                brgemm_kernel_execute(brg_kernel_ic_tail, 1, addr_batch,
                        (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
            }
            if (is_amx)
                amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
        }
    };

    const int os_chunks = div_up(jbgp.nb_os, jbgp.nb_os_blocking);
    const int oc_chunks = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking);
    const int work_amount = oc_chunks * os_chunks;

    // Splits `nthr` threads into an IC group (nthr_ic) times an os x oc group
    // (nthr_oc_mb) and derives this thread's coordinates in that grid.
    // Returns false for threads that have no work (grid remainder).
    const auto init_thr_groups
            = [&](const int ithr, const int nthr, int &nthr_ic, int &nthr_oc_mb,
                      int &ithr_ic, int &ithr_oc_mb) {
                  nthr_ic = jbgp.nthr_ic_b <= nthr ? jbgp.nthr_ic_b : 1;
                  nthr_oc_mb = nthr / nthr_ic;
                  ithr_ic = ithr / nthr_oc_mb;
                  ithr_oc_mb = ithr % nthr_oc_mb;
                  if (ithr_oc_mb >= work_amount || ithr_ic >= ic_chunks
                          || ithr >= rnd_dn(nthr, nthr_ic))
                      return false;
                  return true;
              };

    // If work_amount == 1 we limit num_threads to 1 as parallel(1, ...) does
    // not create parallel section at all. We do not limit num_threads
    // for 1 < work_amount < dnnl_get_max_threads() case to avoid potential
    // overhead on spawning different number of OMP threads from layer to layer.
    const int num_threads = (work_amount == 1 ? 1 : jbgp.nthr);
    parallel(num_threads, [&](const int ithr, const int nthr) {
        int nthr_ic {1}, nthr_oc_mb {1}, ithr_ic {0}, ithr_oc_mb {0};
        bool ok = init_thr_groups(
                ithr, nthr, nthr_ic, nthr_oc_mb, ithr_ic, ithr_oc_mb);
        if (!ok) return;

        int start {0}, end {0};
        balance211(work_amount, nthr_oc_mb, ithr_oc_mb, start, end);

        int icc_start {0}, icc_end {ic_chunks};
        if (nthr_ic > 1)
            balance211(ic_chunks, nthr_ic, ithr_ic, icc_start, icc_end);

        const int icc_work = icc_end - icc_start;

        if (is_amx)
            amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);

        int occ {0}, osc {0};
        nd_iterator_init(start, osc, os_chunks, occ, oc_chunks);
        while (start < end) {
            int ocb_s = occ * jbgp.nb_oc_blocking;
            int ocb_e = nstl::min(ocb_s + jbgp.nb_oc_blocking, jbgp.nb_oc);
            int ocb_work = ocb_e - ocb_s;

            int osb_s = osc * jbgp.nb_os_blocking;
            int osb_e = nstl::min(osb_s + jbgp.nb_os_blocking, jbgp.nb_os);
            int osb_work = osb_e - osb_s;

            // Each thread runs the below loops:
            int loop_start = 0, loop_end = icc_work * osb_work * ocb_work;
            int icc = 0, osb = 0, ocb = 0;

            // If buffer is required, then inner-most loop will be over icc_work
            const bool ocb_inner_most
                    = is_f32 && !(jbgp.is_bf32 || jbgp.use_buffer);
            if (ocb_inner_most)
                nd_iterator_init(
                        0, icc, icc_work, osb, osb_work, ocb, ocb_work);
            else
                nd_iterator_init(
                        0, osb, osb_work, ocb, ocb_work, icc, icc_work);

            while (loop_start < loop_end) {
                const int n = (osb + osb_s) * jbgp.os_block;
                const int cur_icc = icc + icc_start;
                // With ocb inner-most the same A chunk is reused across oc
                // blocks, so copy it only on the first oc block.
                const bool copy_buffer_a = jbgp.use_buffer_a
                        && IMPLICATION(ocb_inner_most, ocb == 0);
                ker(ithr_oc_mb, nthr_oc_mb, ithr_ic, n, ocb + ocb_s, cur_icc,
                        cur_icc == icc_start, osb, copy_buffer_a);

                ++loop_start;
                if (ocb_inner_most)
                    nd_iterator_step(
                            icc, icc_work, osb, osb_work, ocb, ocb_work);
                else
                    nd_iterator_step(
                            osb, osb_work, ocb, ocb_work, icc, icc_work);
            }

            ++start;
            nd_iterator_step(osc, os_chunks, occ, oc_chunks);
        }
        if (is_amx) amx_tile_release();
    });

    // Reduction pass for the IC-split case: sum per-IC-thread partial buffers
    // into dst (or into buffer 0 when sum post-op needs dst preserved), then
    // apply post-ops via the kernel's post-op path with accumulation skipped.
    if (jbgp.nthr_ic_b > 1) {
        assert(jbgp.use_buffer && is_f32);

        // Offset of (osb, ocb) for IC-thread `ithr_ic`: thread 0 maps onto
        // dst itself; others map into their slice of the C buffer.
        const auto get_dst_reduced_off = [&](int ithr_ic, int osb, int ocb) {
            assert(jbgp.nthr_ic_b > 1);
            int os = osb * jbgp.os_block;
            int oc = ocb * jbgp.oc_block;
            const size_t dst_off = get_blk_off(dst_d, jbgp.dst_dt, os, oc);
            if (ithr_ic == 0) return dst_off;
            assert(ithr_ic > 0);
            const size_t ic_buf_idx = jbgp.with_sum ? ithr_ic : ithr_ic - 1;
            return dst_off + (ic_buf_idx * jbgp.mb * jbgp.LDC * acc_dt_size);
        };

        parallel(num_threads, [&](const int ithr, const int nthr) {
            int nthr_ic {1}, nthr_oc_mb {1}, ithr_ic {0}, ithr_oc_mb {0};
            bool ok = init_thr_groups(
                    ithr, nthr, nthr_ic, nthr_oc_mb, ithr_ic, ithr_oc_mb);
            if (!ok) return;

            // Split each os x oc range further across the IC-thread group so
            // all threads help with the reduction.
            int ocmb_start {0}, ocmb_end {0};
            int start {0}, end {0};
            balance211(
                    work_amount, nthr_oc_mb, ithr_oc_mb, ocmb_start, ocmb_end);
            balance211(ocmb_end - ocmb_start, nthr_ic, ithr_ic, start, end);

            int occ {0}, osc {0};
            nd_iterator_init(
                    ocmb_start + start, osc, os_chunks, occ, oc_chunks);
            while (start < end) {
                int ocb_s = occ * jbgp.nb_oc_blocking;
                int ocb_e = nstl::min(ocb_s + jbgp.nb_oc_blocking, jbgp.nb_oc);

                int osb_s = osc * jbgp.nb_os_blocking;
                int osb_e = nstl::min(osb_s + jbgp.nb_os_blocking, jbgp.nb_os);

                for (int osb = osb_s; osb < osb_e; ++osb) {
                    int cur_os_block = nstl::min(
                            jbgp.os - osb * jbgp.os_block, jbgp.os_block);
                    const bool is_os_tail = cur_os_block < jbgp.os_block;
                    const int cur_oc_chunk_size
                            = nstl::min(jbgp.LDC, ocb_e * jbgp.oc_block)
                            - ocb_s * jbgp.oc_block;
                    char *dst_reduced = (jbgp.with_sum ? c_buffer_global : dst)
                            + get_dst_reduced_off(0, osb, ocb_s);
                    const size_t os_offset = jbgp.LDC * acc_dt_size;
                    // Accumulate every extra IC-thread buffer row by row.
                    for (int ic_buf = 0; ic_buf < nthr_ic - 1; ++ic_buf) {
                        const char *c_buffer = c_buffer_global
                                + get_dst_reduced_off(ic_buf + 1, osb, ocb_s);
                        for (int os = 0; os < cur_os_block; ++os) {
                            acc_ker_->accumulate(
                                    (float *)(dst_reduced + os * os_offset),
                                    (float *)(c_buffer + os * os_offset),
                                    cur_oc_chunk_size);
                        }
                    }
                    if (are_post_ops_applicable) {
                        for (int ocb = ocb_s; ocb < ocb_e; ++ocb) {
                            const bool is_oc_tail
                                    = (jbgp.oc - ocb * jbgp.oc_block
                                            < jbgp.oc_block);
                            const int brg_ker_idx = brgemm_inner_product_utils::
                                    get_brg_kernel_index(jbgp, false, false,
                                            is_os_tail, is_oc_tail, false);
                            const auto brg_kernel
                                    = brg_kernels_[brg_ker_idx].get();
                            const int os = osb * jbgp.os_block;
                            const int oc = ocb * jbgp.oc_block;
                            const auto ptr_bias = jbgp.with_bias
                                    ? bias + bia_dt_size * oc
                                    : nullptr;
                            auto ptr_D = dst
                                    + get_blk_off(dst_d, jbgp.dst_dt, os, oc);
                            auto ptr_C = (jbgp.with_sum ? c_buffer_global : dst)
                                    + get_dst_reduced_off(0, osb, ocb);

                            char *wsp_tile = is_amx ? wsp_tile_base
                                            + ithr * jbgp.amx_buf_size_per_thread
                                                    : nullptr;

                            void *scratch = is_amx
                                    ? static_cast<void *>(wsp_tile)
                                    : (jbgp.signed_input ? static_cast<void *>(
                                               const_cast<int *>(
                                                       &compensation[oc]))
                                                         : nullptr);

                            // batch_size == 0 with skip_accm: run only the
                            // post-op epilogue on the already-reduced data.
                            const brgemm_post_ops_data_t post_ops_data {
                                    static_cast<const void *>(ptr_bias),
                                    &oscales[jbgp.is_oc_scale * oc],
                                    post_ops_binary_rhs_arg_vec.data(),
                                    static_cast<size_t>(oc), 0, dst, 0, nullptr,
                                    nullptr, nullptr, true /* skip_accm */};
                            brgemm_kernel_execute_postops(brg_kernel, 0,
                                    nullptr, (void *)ptr_C, (void *)ptr_D,
                                    post_ops_data, scratch);
                        }
                    }
                }
                ++start;
                nd_iterator_step(osc, os_chunks, occ, oc_chunks);
            }
        });
    }

    return status::success;
}
475 | |
476 | template struct brgemm_inner_product_fwd_t<avx2_vnni_2>; |
477 | template struct brgemm_inner_product_fwd_t<avx512_core>; |
478 | template struct brgemm_inner_product_fwd_t<avx512_core_bf16>; |
479 | template struct brgemm_inner_product_fwd_t<avx512_core_vnni>; |
480 | template struct brgemm_inner_product_fwd_t<avx512_core_amx>; |
481 | template struct brgemm_inner_product_fwd_t<avx512_core_fp16>; |
482 | template struct brgemm_inner_product_fwd_t<avx512_core_amx_fp16>; |
483 | |
484 | template <cpu_isa_t isa> |
485 | void brgemm_inner_product_bwd_data_t<isa>::execute_backward_data( |
486 | const exec_ctx_t &ctx) const { |
487 | |
488 | auto diff_dst_ = CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST); |
489 | auto weights_ = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS); |
490 | auto diff_src_ = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC); |
491 | |
492 | auto diff_src = const_cast<char *>(diff_src_); |
493 | auto weights = const_cast<char *>(weights_); |
494 | auto diff_dst = const_cast<char *>(diff_dst_); |
495 | |
496 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
497 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
498 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
499 | |
500 | const auto &jbgp = pd()->jbgp_; |
501 | |
502 | const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt); |
503 | const bool is_bf16 = everyone_is(bf16, jbgp.wei_dt, jbgp.dst_dt); |
504 | const bool is_f16 = everyone_is(f16, jbgp.wei_dt, jbgp.dst_dt); |
505 | const bool is_f32_out = jbgp.src_dt == f32; |
506 | const bool is_amx = jbgp.is_amx; |
507 | const size_t buf_dt_size = types::data_type_size( |
508 | isa == avx512_core_fp16 ? f32 : jbgp.wei_dt); |
509 | |
510 | const dim_t wei_dt_size = types::data_type_size(jbgp.wei_dt); |
511 | |
512 | memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor(); |
513 | brgemm_batch_element_t *addr_batch_global |
514 | = scratchpad.template get<brgemm_batch_element_t>( |
515 | key_brgemm_primitive_batch); |
516 | char *c_buffer_global = (jbgp.use_buffer) |
517 | ? scratchpad.template get<char>(key_brgemm_primitive_buffer) |
518 | : nullptr; |
519 | char *b_buffer_global = jbgp.use_buffer_b |
520 | ? scratchpad.template get<char>(key_brgemm_primitive_buffer_b) |
521 | : nullptr; |
522 | char *a_buffer_global = jbgp.use_buffer_a |
523 | ? scratchpad.template get<char>(key_brgemm_primitive_buffer_a) |
524 | : nullptr; |
525 | auto wsp_tile_base = is_amx |
526 | ? ctx.get_scratchpad_grantor().template get<char>( |
527 | key_conv_amx_tile_buffer) |
528 | : nullptr; |
529 | |
530 | const int oc_chunks = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking); |
531 | bool is_os_tail = (jbgp.mb < jbgp.os_block); |
532 | bool is_ic_tail = (jbgp.ic < jbgp.ic_block); |
533 | bool is_oc_tail = (jbgp.oc < jbgp.oc_block) && !jbgp.use_buffer_a; |
534 | |
535 | const dim_t acc_dt_sz = types::data_type_size(jbgp.acc_dt); |
536 | const dim_t src_dt_sz = types::data_type_size(jbgp.src_dt); |
537 | |
538 | const int base_brg_ker_idx = brgemm_inner_product_utils:: |
539 | get_brg_kernel_index( // TODO: Can be calculated on initialization stage |
540 | jbgp, false, false, is_os_tail, is_ic_tail, is_oc_tail); |
541 | |
542 | const int os_chunks = div_up(jbgp.nb_os, jbgp.nb_os_blocking); |
543 | const int work_amount = jbgp.nb_ic * os_chunks; |
544 | const int num_threads = (work_amount == 1 ? 1 : jbgp.nthr); |
545 | |
546 | const auto get_weights_ptr = [&](int icb, int ocb) { |
547 | int fwd_ic_block |
548 | = (is_amx && !jbgp.is_bf32) ? 2 * jbgp.simd_w : jbgp.simd_w; |
549 | int fwd_oc_block = 0; |
550 | switch (jbgp.wei_tag) { |
551 | case OI16i64o: |
552 | case OIw16i64o: |
553 | case OIhw16i64o: |
554 | case OIdhw16i64o: |
555 | case OI8i64o2i: |
556 | case OIw8i64o2i: |
557 | case OIhw8i64o2i: |
558 | case OIdhw8i64o2i: |
559 | case OI16i64o2i: |
560 | case OIw16i64o2i: |
561 | case OIhw16i64o2i: |
562 | case OIdhw16i64o2i: fwd_oc_block = 4 * jbgp.simd_w; break; |
563 | case OI16i32o: |
564 | case OIw16i32o: |
565 | case OIhw16i32o: |
566 | case OIdhw16i32o: |
567 | case OI8i32o2i: |
568 | case OIw8i32o2i: |
569 | case OIhw8i32o2i: |
570 | case OIdhw8i32o2i: |
571 | case OI16i32o2i: |
572 | case OIw16i32o2i: |
573 | case OIhw16i32o2i: |
574 | case OIdhw16i32o2i: fwd_oc_block = 2 * jbgp.simd_w; break; |
575 | default: fwd_oc_block = jbgp.simd_w; |
576 | }; |
577 | int fwd_icb = icb * jbgp.ic_block / fwd_ic_block; |
578 | int fwd_ocb = ocb * jbgp.oc_block / fwd_oc_block; |
579 | char *ptr_wei_local = weights |
580 | + get_blk_off(weights_d, jbgp.wei_dt, fwd_ocb, fwd_icb); |
581 | |
582 | int fwd_ocb_simd = (ocb * jbgp.oc_block) % fwd_oc_block; |
583 | int fwd_icb_simd = (icb * jbgp.ic_block) % fwd_ic_block; |
584 | int blk_sz = is_bf16 || (is_f16 && isa != avx512_core_fp16) ? 2 : 1; |
585 | |
586 | return ptr_wei_local |
587 | + wei_dt_size |
588 | * (fwd_icb_simd / blk_sz * blk_sz * fwd_oc_block |
589 | + blk_sz * fwd_ocb_simd); |
590 | }; |
591 | |
592 | const auto transform_b_chunk |
593 | = [&](char *tr_wei, const char *wei, int trans_batch, int current_N, |
594 | int current_K) { |
595 | auto ctx = jit_brgemm_trans_wei_t::ctx_t(); |
596 | ctx.src = (void *)wei; |
597 | ctx.tr_src = (void *)tr_wei; |
598 | ctx.current_gemm_batch = trans_batch; |
599 | ctx.current_N = current_N; |
600 | ctx.current_K = current_K; |
601 | (*trans_B_kernel_)(&ctx); |
602 | }; |
603 | |
604 | const auto ker = [&](int ithr_ic_mb, int nthr_ic_mb, int ithr_oc, |
605 | int nthr_oc, int n, int icb, int occ, bool do_init, |
606 | bool do_b_transpose) { |
607 | const int ithr = nthr_ic_mb * ithr_oc + ithr_ic_mb; |
608 | brgemm_batch_element_t *addr_batch |
609 | = addr_batch_global + ithr * jbgp.adjusted_batch_size; |
610 | |
611 | const int ic = icb * jbgp.ic_block; |
612 | const int ocb = occ * jbgp.nb_oc_blocking; |
613 | const int oc = ocb * jbgp.oc_block; |
614 | const size_t dsrc_off = get_blk_off(diff_src_d, jbgp.src_dt, n, ic); |
615 | const int adj_buffers = (jbgp.src_dt == f32) ? 1 : 0; |
616 | const size_t c_buf_shift = jbgp.nthr_oc_b > 1 |
617 | ? (ithr_oc - adj_buffers) |
618 | * static_cast<size_t>(jbgp.mb * jbgp.LDC) |
619 | : ithr * static_cast<size_t>(jbgp.LDC * jbgp.M); |
620 | const size_t c_buf_off |
621 | = types::data_type_size(jbgp.acc_dt) * c_buf_shift |
622 | + (jbgp.nthr_oc_b > 1 ? acc_dt_sz * dsrc_off / src_dt_sz : 0); |
623 | bool use_c_buf = false; |
624 | if (is_f32_out && jbgp.use_buffer) { |
625 | use_c_buf = (jbgp.nthr_oc_b == 1 || ithr_oc > 0); |
626 | } else if (!is_f32_out && jbgp.use_buffer) { |
627 | if (jbgp.nthr_oc_b > 1) |
628 | use_c_buf = true; |
629 | else |
630 | use_c_buf = (jbgp.nthr_oc_b == 1 || ithr_oc > 0); |
631 | } |
632 | |
633 | const size_t a_buffer_size_per_thr |
634 | = jbgp.os_block * jbgp.LDA * types::data_type_size(jbgp.dst_dt); |
635 | char *c_buffer = use_c_buf ? c_buffer_global + c_buf_off : nullptr; |
636 | char *a_buffer = jbgp.use_buffer_a |
637 | ? a_buffer_global + ithr * a_buffer_size_per_thr |
638 | : diff_dst; |
639 | char *wsp_tile = is_amx |
640 | ? wsp_tile_base + ithr * jbgp.amx_buf_size_per_thread |
641 | : nullptr; |
642 | |
643 | bool kernel_init = do_init; |
644 | |
645 | const bool is_os_tail = (jbgp.mb - n < jbgp.os_block); |
646 | const bool is_ic_tail = (jbgp.ic - ic < jbgp.ic_block); |
647 | const bool is_last_oc_chunk = occ == oc_chunks - 1; |
648 | const bool is_oc_tail = is_last_oc_chunk && jbgp.K_tail > 0; |
649 | |
650 | const int rnd_oc |
651 | = rnd_up(jbgp.oc, jbgp.use_buffer_a ? jbgp.oc_block : 1); |
652 | const int nb_oc_b |
653 | = nstl::min((rnd_oc - oc) / jbgp.oc_block, jbgp.nb_oc_blocking); |
654 | |
655 | auto is_bs_tail = (nb_oc_b != jbgp.nb_oc_blocking); |
656 | const int brg_ker_idx |
657 | = brgemm_inner_product_utils::get_brg_kernel_index(jbgp, |
658 | is_bs_tail, kernel_init, is_os_tail, is_ic_tail, false); |
659 | auto brg_kernel = brg_kernels_[brg_ker_idx].get(); |
660 | |
661 | const int size_B = jbgp.LDB * rnd_up(jbgp.K, 2); |
662 | |
663 | const size_t b_buf_shift = jbgp.ip_bwd_d_global_b_transpose |
664 | ? icb * jbgp.nb_oc + ocb |
665 | : ithr * jbgp.gemm_batch_size; |
666 | const size_t b_buf_off = buf_dt_size * b_buf_shift * size_B; |
667 | char *b_buffer = b_buffer_global + b_buf_off; |
668 | |
669 | char *ptr_D = diff_src + dsrc_off; |
670 | char *ptr_C = use_c_buf ? c_buffer : ptr_D; |
671 | |
672 | if (jbgp.use_buffer_a) |
673 | copy_data_chunk(copy_diff_dst_kernel_, a_buffer, |
674 | diff_dst + get_blk_off(diff_dst_d, jbgp.dst_dt, n, oc), |
675 | is_os_tail ? jbgp.os - n : jbgp.os_block, is_last_oc_chunk); |
676 | |
677 | if (nb_oc_b > 0 && brg_kernel != nullptr) { |
678 | if (is_amx && (is_os_tail || is_ic_tail)) |
679 | amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]); |
680 | |
681 | for (int oc_block = 0; oc_block < nb_oc_b; oc_block++) { |
682 | addr_batch[oc_block].ptr.A = jbgp.use_buffer_a ? a_buffer |
683 | + oc_block * jbgp.oc_block |
684 | * types::data_type_size(jbgp.dst_dt) |
685 | : diff_dst |
686 | + get_blk_off(diff_dst_d, jbgp.dst_dt, n, |
687 | oc + oc_block * jbgp.oc_block); |
688 | addr_batch[oc_block].ptr.B |
689 | = b_buffer + buf_dt_size * (oc_block * size_B); |
690 | if (!jbgp.ip_bwd_d_global_b_transpose && do_b_transpose) |
691 | transform_b_chunk((char *)addr_batch[oc_block].ptr.B, |
692 | get_weights_ptr(icb, ocb + oc_block), 1, |
693 | is_ic_tail ? jbgp.ic % jbgp.ic_block |
694 | : jbgp.ic_block, |
695 | jbgp.oc_block); |
696 | } |
697 | |
698 | if (jbgp.use_buffer && (jbgp.nthr_oc_b <= 1 || num_threads == 1) |
699 | && is_last_oc_chunk && !is_oc_tail) { |
700 | void *scratch |
701 | = is_amx ? static_cast<void *>(wsp_tile) : nullptr; |
702 | const brgemm_post_ops_data_t empty_po_data {}; |
703 | brgemm_kernel_execute_postops(brg_kernel, nb_oc_b, addr_batch, |
704 | (void *)c_buffer, (void *)ptr_D, empty_po_data, |
705 | scratch); |
706 | |
707 | } else { |
708 | brgemm_kernel_execute(brg_kernel, nb_oc_b, addr_batch, |
709 | (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr); |
710 | } |
711 | if (is_amx && (is_os_tail || is_ic_tail)) |
712 | amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]); |
713 | } |
714 | if (is_oc_tail) { |
715 | assert(!jbgp.use_buffer_a); |
716 | |
717 | const int oc_block = nb_oc_b; |
718 | addr_batch[0].ptr.A = diff_dst |
719 | + get_blk_off(diff_dst_d, jbgp.dst_dt, n, |
720 | oc + oc_block * jbgp.oc_block); |
721 | addr_batch[0].ptr.B = b_buffer + buf_dt_size * (oc_block * size_B); |
722 | if (!jbgp.ip_bwd_d_global_b_transpose && do_b_transpose) { |
723 | transform_b_chunk((char *)addr_batch[0].ptr.B, |
724 | get_weights_ptr(icb, ocb + oc_block), 1, |
725 | is_ic_tail ? jbgp.ic % jbgp.ic_block : jbgp.ic_block, |
726 | jbgp.K_tail); |
727 | } |
728 | |
729 | auto use_init_ker = (kernel_init && nb_oc_b == 0); |
730 | const int brg_kernel_oc_tail_idx |
731 | = brgemm_inner_product_utils::get_brg_kernel_index(jbgp, |
732 | false, use_init_ker, is_os_tail, is_ic_tail, true); |
733 | auto brg_kernel_oc_tail |
734 | = brg_kernels_[brg_kernel_oc_tail_idx].get(); |
735 | if (is_amx) |
736 | amx_tile_configure( |
737 | &brg_kernel_palettes_[brg_kernel_oc_tail_idx][0]); |
738 | if (jbgp.use_buffer && jbgp.nthr_oc_b <= 1) { |
739 | void *scratch |
740 | = is_amx ? static_cast<void *>(wsp_tile) : nullptr; |
741 | const brgemm_post_ops_data_t empty_po_data {}; |
742 | brgemm_kernel_execute_postops(brg_kernel_oc_tail, 1, addr_batch, |
743 | (void *)c_buffer, (void *)ptr_D, empty_po_data, |
744 | scratch); |
745 | |
746 | } else { |
747 | brgemm_kernel_execute(brg_kernel_oc_tail, 1, addr_batch, |
748 | (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr); |
749 | } |
750 | if (is_amx) |
751 | amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]); |
752 | } |
753 | }; |
754 | |
755 | if (jbgp.ip_bwd_d_global_b_transpose && jbgp.use_buffer_b) { |
756 | assert(IMPLICATION( |
757 | jbgp.ip_bwd_d_global_b_transpose, jbgp.nthr_oc_b == 1)); |
758 | parallel(num_threads, [&](const int ithr, const int nthr) { |
759 | int start {0}, end {0}; |
760 | int max_ch_block = nstl::max(jbgp.ic_block, jbgp.oc_block); |
761 | int ic_chunk_sz = max_ch_block / jbgp.ic_block; |
762 | int oc_chunk_sz = max_ch_block / jbgp.oc_block; |
763 | int nc_ic = utils::div_up(jbgp.nb_ic, ic_chunk_sz); |
764 | int nc_oc = utils::div_up(jbgp.nb_oc, oc_chunk_sz); |
765 | int transp_work_amount = nc_ic * nc_oc; |
766 | balance211(transp_work_amount, nthr, ithr, start, end); |
767 | int icc, occ; |
768 | nd_iterator_init(start, icc, nc_ic, occ, nc_oc); |
769 | while (start < end) { |
770 | int icb_start = icc * ic_chunk_sz; |
771 | int icb_end = nstl::min((icc + 1) * ic_chunk_sz, jbgp.nb_ic); |
772 | int ocb_start = occ * oc_chunk_sz; |
773 | int ocb_end = nstl::min((occ + 1) * oc_chunk_sz, jbgp.nb_oc); |
774 | for_(int icb = icb_start; icb < icb_end; icb++) |
775 | for (int ocb = ocb_start; ocb < ocb_end; ocb++) { |
776 | int ic = icb * jbgp.ic_block; |
777 | int oc = ocb * jbgp.oc_block; |
778 | bool is_ic_tail = (jbgp.ic - ic < jbgp.ic_block); |
779 | bool is_oc_tail = (jbgp.oc - oc < jbgp.oc_block); |
780 | const int size_B = jbgp.LDB * rnd_up(jbgp.K, 2); |
781 | char *b_buffer = b_buffer_global |
782 | + buf_dt_size |
783 | * ((dim_t)icb * jbgp.nb_oc * size_B |
784 | + (dim_t)ocb * size_B); |
785 | |
786 | transform_b_chunk(b_buffer, get_weights_ptr(icb, ocb), 1, |
787 | is_ic_tail ? jbgp.ic % jbgp.ic_block |
788 | : jbgp.ic_block, |
789 | is_oc_tail ? jbgp.oc % jbgp.oc_block |
790 | : jbgp.oc_block); |
791 | } |
792 | ++start; |
793 | nd_iterator_step(icc, nc_ic, occ, nc_oc); |
794 | } |
795 | }); |
796 | } |
797 | |
798 | parallel(num_threads, [&](const int ithr, const int nthr) { |
799 | const int nthr_oc = jbgp.nthr_oc_b <= nthr ? jbgp.nthr_oc_b : 1; |
800 | const int nthr_ic_mb = nthr / nthr_oc; |
801 | const int ithr_ic_mb = ithr % nthr_ic_mb; |
802 | const int ithr_oc = ithr / nthr_ic_mb; |
803 | if (ithr_ic_mb >= work_amount || ithr_oc >= oc_chunks |
804 | || ithr >= rnd_dn(nthr, nthr_oc)) |
805 | return; |
806 | |
807 | int start {0}, end {0}; |
808 | balance211(work_amount, nthr_ic_mb, ithr_ic_mb, start, end); |
809 | int occ_start {0}, occ_end {oc_chunks}; |
810 | if (nthr_oc > 1) |
811 | balance211(oc_chunks, nthr_oc, ithr_oc, occ_start, occ_end); |
812 | |
813 | if (is_amx) |
814 | amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]); |
815 | |
816 | int icb {0}, oss {0}; |
817 | nd_iterator_init(start, oss, os_chunks, icb, jbgp.nb_ic); |
818 | while (start < end) { |
819 | const int nb_os_blocking |
820 | = nstl::min(jbgp.nb_os - oss * jbgp.nb_os_blocking, |
821 | jbgp.nb_os_blocking); |
822 | const int occ_work = occ_end - occ_start; |
823 | const int loop_iteration = nb_os_blocking * occ_work; |
824 | |
825 | for (int iter = 0; iter < loop_iteration; ++iter) { |
826 | int osb = 0, occ = occ_start; |
827 | if (jbgp.use_buffer || !is_f32) { |
828 | osb += iter / occ_work; |
829 | occ += iter % occ_work; |
830 | } else { |
831 | occ += iter / nb_os_blocking; |
832 | osb += iter % nb_os_blocking; |
833 | } |
834 | int n = (oss * jbgp.nb_os_blocking + osb) * jbgp.os_block; |
835 | ker(ithr_ic_mb, nthr_ic_mb, ithr_oc, nthr_oc, n, icb, occ, |
836 | occ == occ_start, osb == 0 || occ_work > 1); |
837 | } |
838 | ++start; |
839 | nd_iterator_step(oss, os_chunks, icb, jbgp.nb_ic); |
840 | } |
841 | if (is_amx) amx_tile_release(); |
842 | }); |
843 | |
844 | if (jbgp.nthr_oc_b > 1) { |
845 | parallel(num_threads, [&](const int ithr, const int nthr) { |
846 | const int nthr_oc = jbgp.nthr_oc_b <= nthr ? jbgp.nthr_oc_b : 1; |
847 | if (nthr_oc <= 1) return; |
848 | |
849 | const int ddst_elems = jbgp.LDC * jbgp.os; |
850 | const int reduce_chunk_size = 64; |
851 | int start {0}, end {0}; |
852 | balance211(div_up(ddst_elems, reduce_chunk_size), nthr, ithr, start, |
853 | end); |
854 | const dim_t reduce_start = start * reduce_chunk_size; |
855 | const dim_t reduce_finish |
856 | = nstl::min(end * reduce_chunk_size, ddst_elems); |
857 | if (reduce_finish <= reduce_start) return; |
858 | const dim_t elems_to_reduce = reduce_finish - reduce_start; |
859 | const dim_t acc_dt_sz = types::data_type_size(jbgp.acc_dt); |
860 | |
861 | char *dsrc_reduced = diff_src + src_dt_sz * reduce_start; |
862 | char *c_buffer_start = c_buffer_global + acc_dt_sz * reduce_start; |
863 | |
864 | float *out_buffer = is_f32_out |
865 | ? reinterpret_cast<float *>(dsrc_reduced) |
866 | : reinterpret_cast<float *>(c_buffer_start); |
867 | int oc_buf_idx = !is_f32_out; |
868 | int oc_buf_end = is_f32_out; |
869 | for (int oc_buf = oc_buf_idx; oc_buf < nthr_oc - oc_buf_end; |
870 | oc_buf++) { |
871 | const dim_t c_buf_offt = acc_dt_sz |
872 | * (oc_buf * jbgp.os * jbgp.LDC + reduce_start); |
873 | char *c_buffer = c_buffer_global + c_buf_offt; |
874 | |
875 | acc_ker_->accumulate((float *)out_buffer, (float *)c_buffer, |
876 | elems_to_reduce); |
877 | if (!is_f32_out && oc_buf == (nthr_oc - oc_buf_end) - 1) { |
878 | if (is_bf16) { |
879 | cvt_float_to_bfloat16((bfloat16_t *)dsrc_reduced, |
880 | (const float *)out_buffer, elems_to_reduce); |
881 | } else if (is_f16) { |
882 | cvt_float_to_float16((float16_t *)dsrc_reduced, |
883 | (const float *)out_buffer, elems_to_reduce); |
884 | } |
885 | } |
886 | } |
887 | }); |
888 | } |
889 | } |
890 | |
891 | template struct brgemm_inner_product_bwd_data_t<avx512_core>; |
892 | template struct brgemm_inner_product_bwd_data_t<avx512_core_amx>; |
893 | template struct brgemm_inner_product_bwd_data_t<avx512_core_bf16>; |
894 | template struct brgemm_inner_product_bwd_data_t<avx512_core_amx_fp16>; |
895 | template struct brgemm_inner_product_bwd_data_t<avx512_core_fp16>; |
896 | |
template <cpu_isa_t isa>
struct brgemm_inner_product_bwd_weights_t<isa>::thread_info_t {
    // Per-thread execution context for backward-by-weights: caches tensor
    // pointers from the exec context, resolves scratchpad-backed buffers
    // (A/B copy buffers, C accumulator, bias accumulator, AMX tile
    // workspace), and computes this thread's slice of the parallel
    // (ic-chunk, oc-chunk, os-chunk) work decomposition.
    const char *src;
    const char *diff_dst;
    char *diff_weights;
    char *diff_bias;

    const memory_tracking::grantor_t scratchpad;

    char *buffer_c = nullptr; // diff_weights accumulation buffer (acc_dt)
    char *buffer_bias = nullptr; // diff_bias accumulation / convert buffer
    char *wsp_tile_base = nullptr; // AMX tile scratch base (all threads)

    int ithr;
    int ithr_ic_c, ithr_oc_c, ithr_os_c; // thread coords in (ic, oc, os) grid
    int nthr;
    int nthr_ic_c, nthr_oc_c, nthr_os_c; // thread-grid extents per dimension

    // This thread's chunk ranges along each parallelized dimension
    // (set by balance211 in the constructor).
    int os_c_start = 0, os_c_end = 0, os_c_work;
    int oc_c_start = 0, oc_c_end = 0, oc_c_work;
    int ic_c_start = 0, ic_c_end = 0, ic_c_work;
    simple_barrier::ctx_t *barrier_ctx;

    thread_info_t(const brgemm_inner_product_bwd_weights_t *self,
            const exec_ctx_t &ctx, int ithr)
        : scratchpad(ctx.get_scratchpad_grantor()), ithr(ithr) {

        src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
        diff_dst = CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST);
        diff_weights = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_WEIGHTS);
        diff_bias = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_BIAS);
        const auto &jbgp = self->pd()->jbgp_;

        const bool is_amx = jbgp.is_amx;

        buffer_c = (jbgp.use_buffer)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
                : nullptr;

        // A bias accumulation buffer is needed when the bias must be
        // converted from f32 at the end, or when multiple threads reduce
        // over the minibatch dimension.
        buffer_bias = (jbgp.with_bias
                              && (jbgp.bia_dt != data_type::f32
                                      || jbgp.nthr_mb > 1))
                ? scratchpad.template get<char>(key_iprod_bias_bf16_convert_wsp)
                : nullptr;

        buffer_a_
                = scratchpad.template get<char>(key_brgemm_primitive_buffer_a);
        buffer_b_ = jbgp.use_buffer_b
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer_b)
                : nullptr;

        thread_local_input_buffers_
                = jbgp.ip_bwd_w_local_buffers_for_input_tensors;
        int ic_chunks = utils::div_up(jbgp.nb_ic, jbgp.nb_ic_blocking);
        int os_chunks = utils::div_up(jbgp.nb_os, jbgp.nb_os_blocking);
        nb_ic_blocking_ = jbgp.nb_ic_blocking;
        nb_oc_blocking_ = jbgp.nb_oc_blocking;
        const size_t os_chunks_per_thr = utils::div_up(os_chunks, jbgp.nthr_mb);
        // With thread-local input buffers only one os-chunk worth of copy
        // buffer is kept per thread; otherwise the full per-thread share.
        const size_t num_os_chunks_per_thread
                = thread_local_input_buffers_ ? 1 : os_chunks_per_thr;

        if (jbgp.use_buffer_a) {
            // Compute the strides (icb/osb/osc shifts) inside this
            // thread's region of buffer A and rebase buffer_a_ to it.
            const size_t dt_sz = buf_dt_size(jbgp.src_dt, jbgp.isa);
            const size_t ic_chunks_per_thr
                    = utils::div_up(ic_chunks, jbgp.nthr_ic_b);
            const size_t num_ic_chunks_per_thread
                    = thread_local_input_buffers_ ? 1 : ic_chunks_per_thr;
            const size_t block_A_size = dt_sz * jbgp.LDA * jbgp.M;
            const size_t os_chunk_A_buffer
                    = jbgp.gemm_batch_size * block_A_size;
            const size_t ic_os_chunk_A_buffer
                    = jbgp.nb_ic_blocking * os_chunk_A_buffer;

            buffer_a_icb_shift_ = os_chunk_A_buffer;
            buffer_a_osb_shift_ = block_A_size;
            buffer_a_osc_shift_ = thread_local_input_buffers_
                    ? 0
                    : ic_chunks_per_thr * ic_os_chunk_A_buffer;
            const size_t buffer_a_thread_shift = num_ic_chunks_per_thread
                    * num_os_chunks_per_thread * ic_os_chunk_A_buffer;

            buffer_a_ = buffer_a_ + ithr * buffer_a_thread_shift;
        }

        if (jbgp.use_buffer_b) {
            // NOTE(review): dt_sz is computed from jbgp.dst_dt while
            // buf_dt substitutes f32 for f16 on avx512_core_fp16; the
            // assert below checks the two sizes agree.
            const auto buf_dt = jbgp.dst_dt == f16 && isa == avx512_core_fp16
                    ? data_type::f32
                    : jbgp.dst_dt;
            const size_t dt_sz = buf_dt_size(jbgp.dst_dt, jbgp.isa);
            assert(types::data_type_size(buf_dt) == dt_sz);
            const size_t block_B_size = dt_sz * jbgp.LDB * jbgp.K;
            const size_t os_chunk_B_buffer
                    = jbgp.gemm_batch_size * block_B_size;
            buffer_b_ocb_shift_ = dt_sz * jbgp.oc_block
                    * data_type_vnni_granularity(buf_dt);
            buffer_b_osb_shift_ = block_B_size;
            buffer_b_osc_shift_
                    = thread_local_input_buffers_ ? 0 : os_chunk_B_buffer;

            const size_t buffer_b_thread_shift
                    = num_os_chunks_per_thread * os_chunk_B_buffer;

            buffer_b_ = buffer_b_ + ithr * buffer_b_thread_shift;
        }

        wsp_tile_base = is_amx
                ? ctx.get_scratchpad_grantor().template get<char>(
                        key_conv_amx_tile_buffer)
                : nullptr;

        nthr = jbgp.nthr;
        nthr_ic_c = jbgp.nthr_ic_b;
        nthr_oc_c = jbgp.nthr_oc_b;
        nthr_os_c = jbgp.nthr_mb;

        // Linear thread id -> (ic, oc, os) coordinates, ic fastest.
        ithr_ic_c = ithr % nthr_ic_c;
        ithr_oc_c = ithr / nthr_ic_c % nthr_oc_c;
        ithr_os_c = ithr / nthr_ic_c / nthr_oc_c;

        int oc_chunks = utils::div_up(jbgp.nb_oc, jbgp.nb_oc_blocking);

        /* reduction dimension */
        balance211(os_chunks, nthr_os_c, ithr_os_c, os_c_start, os_c_end);
        os_c_work = os_c_end - os_c_start;

        balance211(oc_chunks, nthr_oc_c, ithr_oc_c, oc_c_start, oc_c_end);
        oc_c_work = oc_c_end - oc_c_start;

        balance211(ic_chunks, nthr_ic_c, ithr_ic_c, ic_c_start, ic_c_end);
        ic_c_work = ic_c_end - ic_c_start;

        // Barrier is only usable when the threading runtime supports
        // synchronization inside a parallel region.
        if (dnnl_thr_syncable())
            barrier_ctx = scratchpad.template get<simple_barrier::ctx_t>(
                    key_conv_wei_bia_reduction_bctx);
    }

    // Pointer into buffer A for the given ic block / os chunk of this
    // thread; nullptr when buffer A is not in use.
    char *get_buffer_a_ptr(int icb, int osc) const {
        if (!buffer_a_) return (char *)nullptr;

        const int icb_idx = thread_local_input_buffers_
                ? icb % nb_ic_blocking_
                : icb - ic_c_start * nb_ic_blocking_;
        const int osc_idx = thread_local_input_buffers_ ? 0 : osc - os_c_start;

        return buffer_a_ + osc_idx * buffer_a_osc_shift_
                + icb_idx * buffer_a_icb_shift_;
    }

    // Pointer into buffer B for the given oc block / os chunk of this
    // thread; nullptr when buffer B is not in use.
    char *get_buffer_b_ptr(int ocb, int osc) const {
        if (!buffer_b_) return (char *)nullptr;

        const int ocb_idx = ocb % nb_oc_blocking_;
        const int osc_idx = thread_local_input_buffers_ ? 0 : osc - os_c_start;

        return buffer_b_ + osc_idx * buffer_b_osc_shift_
                + ocb_idx * buffer_b_ocb_shift_;
    }

    size_t get_buffer_a_osb_shift() const { return buffer_a_osb_shift_; }
    size_t get_buffer_b_osb_shift() const { return buffer_b_osb_shift_; }

private:
    char *buffer_a_ = nullptr; // transposed copy of src (this thread's base)
    char *buffer_b_ = nullptr; // vnni copy of diff_dst (this thread's base)

    bool thread_local_input_buffers_ = false;
    int nb_ic_blocking_ = 1;
    int nb_oc_blocking_ = 1;
    // Byte strides within the copy buffers along each blocking dimension.
    size_t buffer_a_icb_shift_ = 0;
    size_t buffer_a_osc_shift_ = 0;
    size_t buffer_a_osb_shift_ = 0;
    size_t buffer_b_ocb_shift_ = 0;
    size_t buffer_b_osc_shift_ = 0;
    size_t buffer_b_osb_shift_ = 0;
};
1072 | |
1073 | template <cpu_isa_t isa> |
1074 | void brgemm_inner_product_bwd_weights_t<isa>::transform_matrix_a_chunk( |
1075 | char *tr_src, const char *src, int trans_batch, int current_m, |
1076 | int current_k) const { |
1077 | auto ctx = jit_brgemm_trans_src_t::ctx_t(); |
1078 | ctx.src = (void *)src; |
1079 | ctx.tr_src = (void *)tr_src; |
1080 | ctx.current_gemm_batch = trans_batch; |
1081 | ctx.current_M = current_m; |
1082 | ctx.current_K = current_k; |
1083 | (*trans_A_kernel_)(&ctx); |
1084 | } |
1085 | |
1086 | template <cpu_isa_t isa> |
1087 | void brgemm_inner_product_bwd_weights_t<isa>::transform_matrix_b_chunk( |
1088 | char *tr_diff_dst, const char *diff_dst, int trans_batch, |
1089 | int current_col_size, int current_row_size) const { |
1090 | auto ctx = jit_brgemm_trans_to_vnni_t::ctx_t(); |
1091 | ctx.src = (void *)diff_dst; |
1092 | ctx.tr_src = (void *)tr_diff_dst; |
1093 | ctx.current_gemm_batch = trans_batch; |
1094 | ctx.current_col_size = current_col_size; |
1095 | ctx.current_row_size = current_row_size; |
1096 | (*trans_B_kernel_)(&ctx); |
1097 | } |
1098 | |
1099 | template <cpu_isa_t isa> |
1100 | void brgemm_inner_product_bwd_weights_t<isa>::transpose_matrix_c_chunk( |
1101 | const thread_info_t *ti, const int ocb, const int icb, int oc_size, |
1102 | int ic_size, bool is_reduction) const { |
1103 | const auto &jbgp = pd()->jbgp_; |
1104 | |
1105 | if (jbgp.is_amx) { |
1106 | auto p = jit_amx_ip_trans_diff_wei::ctx_t(); |
1107 | |
1108 | const dim_t ext_nb_ic = div_up(jbgp.ic, ext_ic_block_); |
1109 | dim_t icb_shift = (icb * (jbgp.ic_block / ext_ic_block_)) |
1110 | * ext_ic_block_ * ext_oc_block_; |
1111 | |
1112 | dim_t ocb_shift = (ocb * (jbgp.oc_block / ext_oc_block_)) * ext_nb_ic |
1113 | * ext_ic_block_ * ext_oc_block_; |
1114 | dim_t out_offset = ocb_shift + icb_shift; |
1115 | |
1116 | p.src = get_wei_acc_ptr(ti, ocb, icb, 0); |
1117 | p.dst = (void *)(ti->diff_weights |
1118 | + types::data_type_size(jbgp.wei_dt) * out_offset); |
1119 | |
1120 | p.last_ic_block = (jbgp.ic <= ext_ic_block_ |
1121 | || (jbgp.nb_ic > 1 && icb == jbgp.nb_ic - 1)) |
1122 | ? 1 |
1123 | : 0; |
1124 | p.last_oc_block = (jbgp.oc <= ext_oc_block_ |
1125 | || (jbgp.nb_oc > 1 && ocb == jbgp.nb_oc - 1)) |
1126 | ? 1 |
1127 | : 0; |
1128 | (*diff_wei_trans_kernel_)(&p); |
1129 | } else { |
1130 | auto ctx = jit_brgemm_trans_to_vnni_t::ctx_t(); |
1131 | ctx.src = (void *)(get_wei_acc_ptr(ti, ocb, icb, 0)); |
1132 | |
1133 | ctx.tr_src = (void *)(ti->diff_weights |
1134 | + types::data_type_size(jbgp.wei_dt) |
1135 | * get_wei_offset(ocb, icb)); |
1136 | |
1137 | ctx.current_gemm_batch = 1; |
1138 | ctx.current_col_size = oc_size; |
1139 | ctx.current_row_size = ic_size; |
1140 | (*trans_C_kernel_)(&ctx); |
1141 | } |
1142 | } |
1143 | |
1144 | template <cpu_isa_t isa> |
1145 | dim_t brgemm_inner_product_bwd_weights_t<isa>::get_wei_offset( |
1146 | int ocb, int icb) const { |
1147 | const auto &jbgp = pd()->jbgp_; |
1148 | if (jbgp.is_amx) { |
1149 | const dim_t offset |
1150 | = jbgp.kd * jbgp.kh * jbgp.kw * jbgp.ic_block * jbgp.oc_block; |
1151 | return (ocb * jbgp.nb_ic + icb) * offset; |
1152 | } else { |
1153 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
1154 | return diff_weights_d.blk_off(ocb, icb); |
1155 | } |
1156 | } |
1157 | |
1158 | template <cpu_isa_t isa> |
1159 | char *brgemm_inner_product_bwd_weights_t<isa>::get_wei_acc_ptr( |
1160 | const thread_info_t *ti, int ocb, int icb, |
1161 | int reduction_buf_idx) const { |
1162 | const auto &jbgp = pd()->jbgp_; |
1163 | |
1164 | const int reduction_buf_start_idx = jbgp.wei_dt == f32; |
1165 | // reduction_buf_idx argument allows manually set up required reduction |
1166 | // buffer index, required for reduction and transform diff_weights parts. |
1167 | // It has value -1 by default. If reduction_buf_idx < 0 then ti->ithr_os_c |
1168 | // is used for calculation of the current reduction index. |
1169 | const int buf_idx = reduction_buf_idx >= 0 |
1170 | ? reduction_buf_idx |
1171 | : (ti->ithr_os_c - reduction_buf_start_idx); |
1172 | const size_t acc_dt_size = types::data_type_size(jbgp.acc_dt); |
1173 | |
1174 | if ((jbgp.nthr_mb > 1 && buf_idx < 0) |
1175 | || (jbgp.wei_dt == jbgp.acc_dt && reduction_buf_idx < 0 |
1176 | && ti->ithr_os_c == 0)) { |
1177 | MAYBE_UNUSED(reduction_buf_idx); |
1178 | const int icb_scale = (!jbgp.is_amx || jbgp.wei_dt == jbgp.acc_dt) |
1179 | ? jbgp.ic_block / jbgp.simd_w |
1180 | : 1; |
1181 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
1182 | return (char *)ti->diff_weights |
1183 | + get_blk_off( |
1184 | diff_weights_d, jbgp.wei_dt, ocb, icb * icb_scale); |
1185 | } |
1186 | |
1187 | if (!jbgp.use_buffer) return nullptr; |
1188 | |
1189 | const int ocb_l = ocb % jbgp.nb_oc_blocking; |
1190 | const int icb_l = icb % jbgp.nb_ic_blocking; |
1191 | |
1192 | if (jbgp.nthr_mb > 1 || jbgp.harness == harness_mb_reduction) { |
1193 | const size_t icc = icb / jbgp.nb_ic_blocking; |
1194 | const size_t occ = ocb / jbgp.nb_oc_blocking; |
1195 | const size_t num_ic_chunks = div_up(jbgp.nb_ic, jbgp.nb_ic_blocking); |
1196 | const size_t num_oc_chunks = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking); |
1197 | const size_t block_size = acc_dt_size * jbgp.ic_block * jbgp.oc_block; |
1198 | const size_t chunk_size |
1199 | = block_size * jbgp.nb_ic_blocking * jbgp.nb_oc_blocking; |
1200 | const size_t reduction_buf_shift |
1201 | = num_ic_chunks * num_oc_chunks * chunk_size * buf_idx; |
1202 | return ti->buffer_c + reduction_buf_shift |
1203 | + (occ * num_ic_chunks + icc) * chunk_size |
1204 | + (ocb_l * jbgp.nb_ic_blocking + icb_l) * block_size; |
1205 | } else if (jbgp.nthr_mb == 1) { |
1206 | MAYBE_UNUSED(reduction_buf_idx); |
1207 | const size_t blk_size = acc_dt_size * jbgp.ic_block * jbgp.oc_block; |
1208 | const size_t buf_size_per_thread |
1209 | = blk_size * jbgp.nb_ic_blocking * jbgp.nb_oc_blocking; |
1210 | const size_t offset_within_thread_buf |
1211 | = blk_size * (jbgp.nb_ic_blocking * ocb_l + icb_l); |
1212 | const size_t offset |
1213 | = ti->ithr * buf_size_per_thread + offset_within_thread_buf; |
1214 | return ti->buffer_c + offset; |
1215 | } |
1216 | |
1217 | assert(!"unsupported case" ); |
1218 | return nullptr; |
1219 | }; |
1220 | |
template <cpu_isa_t isa>
void brgemm_inner_product_bwd_weights_t<isa>::compute_diff_weights_and_bias(
        const thread_info_t *ti) const {
    // Main backward-by-weights compute loop for one thread: iterates over
    // this thread's (os, ic, oc) chunks, transposing src (matrix A) and
    // diff_dst (matrix B) into scratch buffers as needed, running the
    // brgemm kernels into the weights accumulator, and accumulating the
    // bias reduction along the way.
    auto diff_dst = const_cast<char *>(ti->diff_dst);
    auto diff_bias = ti->diff_bias;

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const auto &jbgp = pd()->jbgp_;

    const size_t bia_dt_size
            = jbgp.with_bias ? types::data_type_size(jbgp.bia_dt) : 0;
    const size_t acc_dt_size = types::data_type_size(jbgp.acc_dt);

    const int oc_chunk_sz = jbgp.oc_block * jbgp.nb_oc_blocking;

    brgemm_batch_element_t *addr_batch_global
            = ti->scratchpad.template get<brgemm_batch_element_t>(
                    key_brgemm_primitive_batch);

    const bool is_bf16 = jbgp.wei_dt == bf16;
    const bool is_amx_bf16 = is_bf16 && isa == avx512_core_amx;
    char *wsp_tile_global = (is_amx_bf16) ? ti->wsp_tile_base : nullptr;
    int os_chunks = utils::div_up(jbgp.nb_os, jbgp.nb_os_blocking);

    // Bias accumulation destination: the f32 scratch buffer when conversion
    // or multi-buffer reduction is needed, otherwise directly diff_bias.
    const auto get_bia_acc_ptr = [&](int oc) {
        const int reduction_buf_start_idx = jbgp.bia_dt == f32;
        if (jbgp.bia_dt != data_type::f32
                || (jbgp.nthr_mb > 1
                        && ti->ithr_os_c >= reduction_buf_start_idx)) {
            return ti->buffer_bias
                    + acc_dt_size * (ti->ithr_os_c - reduction_buf_start_idx)
                    * jbgp.oc
                    + acc_dt_size * oc;
        } else {
            return ti->diff_bias + bia_dt_size * oc;
        }
    };

    const auto a_buf_osb_shift = ti->get_buffer_a_osb_shift();
    const auto b_buf_osb_shift = ti->get_buffer_b_osb_shift();

    // Kernel body for one (osc, icb, ocb) iteration. The *_prev arguments
    // carry the previous iteration's chunk indices so the thread-local
    // input-buffer mode can skip redundant A/B transforms.
    const auto ker = [&](const int osc, const int icb, const int ocb,
                             const int osc_prev, const int icc_prev,
                             const int occ_prev) {
        brgemm_batch_element_t *addr_batch
                = addr_batch_global + ti->ithr * jbgp.adjusted_batch_size;
        char *wsp_tile = is_amx_bf16
                ? wsp_tile_global + ti->ithr * jbgp.amx_buf_size_per_thread
                : nullptr;
        int ic = icb * jbgp.ic_block;
        int oc = ocb * jbgp.oc_block;
        int osb = osc * jbgp.nb_os_blocking;
        int n = osb * jbgp.os_block;

        // Base buffer pointers for the current kernel iteration,
        // x_buf_osb_shift values are used for shifting wrt osb iter variable
        char *a_buffer = ti->get_buffer_a_ptr(icb, osc);
        char *b_buffer = ti->get_buffer_b_ptr(ocb, osc);

        // First os-chunk for this thread initializes (rather than
        // accumulates into) the weights accumulator.
        bool kernel_init = (osc == ti->os_c_start);

        bool is_os_tail = jbgp.mb - n < jbgp.os_block * jbgp.nb_os_blocking;
        bool is_ic_tail = jbgp.ic - ic < jbgp.ic_block;
        bool is_oc_tail = jbgp.oc - oc < jbgp.oc_block;
        const int oc_chunk_tail = jbgp.oc % oc_chunk_sz;
        const bool is_last_oc_chunk = jbgp.oc - oc < oc_chunk_sz;
        const int curr_oc_chunk_sz = oc_chunk_tail > 0 && is_last_oc_chunk
                ? oc_chunk_tail
                : oc_chunk_sz;

        // Weights are converted to wei_dt at the last os chunk only when no
        // cross-thread reduction will follow (nthr_mb == 1 or a single
        // os chunk).
        const bool transform_weights = jbgp.wei_dt != jbgp.acc_dt
                && (jbgp.nthr_mb == 1 || os_chunks == 1)
                && osc == (os_chunks - 1);
        // Decide whether this iteration must (re)fill the B copy buffer:
        // with thread-local buffers, refill on each new (osc, occ) pair;
        // otherwise only on the first ic chunk of this thread.
        const bool transform_b = jbgp.ip_bwd_w_local_buffers_for_input_tensors
                ? jbgp.use_buffer_b && icb % jbgp.nb_ic_blocking == 0
                        && ocb % jbgp.nb_oc_blocking == 0
                        && IMPLICATION(osc_prev == osc,
                                occ_prev != ocb / jbgp.nb_oc_blocking)
                : jbgp.use_buffer_b
                        && icb == ti->ic_c_start * jbgp.nb_ic_blocking
                        && ocb % jbgp.nb_oc_blocking == 0;
        // Same decision for the A copy buffer, keyed on the ic chunk.
        const bool transform_a = jbgp.ip_bwd_w_local_buffers_for_input_tensors
                ? jbgp.use_buffer_a && ocb % jbgp.nb_oc_blocking == 0
                        && IMPLICATION(osc_prev == osc,
                                icc_prev != icb / jbgp.nb_ic_blocking)
                : jbgp.use_buffer_a
                        && ocb == ti->oc_c_start * jbgp.nb_oc_blocking;

        auto nb_os_b = is_os_tail ? (jbgp.mb - n) / jbgp.os_block
                                  : jbgp.nb_os_blocking;

        auto is_bs_tail = (nb_os_b != jbgp.nb_os_blocking);
        const int brg_ker_idx
                = brgemm_inner_product_utils::get_brg_kernel_index(jbgp,
                        is_bs_tail, kernel_init, is_ic_tail, is_oc_tail, false);
        auto brg_kernel = brg_kernels_[brg_ker_idx].get();

        // Tail blocks only partially overwrite the accumulator, so zero it
        // explicitly before the first accumulation.
        if (kernel_init && (is_ic_tail || is_oc_tail))
            utils::array_set(get_wei_acc_ptr(ti, ocb, icb), 0,
                    types::data_type_size(jbgp.acc_dt) * jbgp.ic_block
                            * jbgp.oc_block);
        if (nb_os_b > 0 && brg_kernel != nullptr) {
            if (jbgp.is_amx)
                amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
            if (transform_a) {
                const memory_desc_wrapper src_d(pd()->src_md());
                auto src_ptr = ti->src
                        + types::data_type_size(jbgp.src_dt)
                                * src_d.blk_off(n, ic);

                transform_matrix_a_chunk(a_buffer, src_ptr, nb_os_b,
                        is_ic_tail ? jbgp.ic % jbgp.ic_block : jbgp.ic_block,
                        jbgp.os_block);
            }

            if (transform_b) {
                auto diff_dst_ptr = diff_dst
                        + types::data_type_size(jbgp.dst_dt)
                                * diff_dst_d.blk_off(n, oc);
                transform_matrix_b_chunk(b_buffer, diff_dst_ptr, nb_os_b,
                        curr_oc_chunk_sz, jbgp.os_block);
            }

            // Build the brgemm batch: A from the copy buffer, B either from
            // the copy buffer or directly from diff_dst. Bias reduction is
            // interleaved here (only on the first ic block).
            for (int os_block = 0; os_block < nb_os_b; os_block++) {
                auto a_ptr = a_buffer + os_block * a_buf_osb_shift;
                addr_batch[os_block].ptr.A = a_ptr;
                auto diff_dst_ptr = diff_dst
                        + types::data_type_size(jbgp.dst_dt)
                                * diff_dst_d.blk_off(
                                        n + os_block * jbgp.os_block, oc);
                if (jbgp.use_buffer_b) {
                    auto b_ptr = b_buffer + os_block * b_buf_osb_shift;
                    addr_batch[os_block].ptr.B = b_ptr;
                } else {
                    addr_batch[os_block].ptr.B = diff_dst_ptr;
                }
                if (jbgp.with_bias && icb == 0) {
                    brgemm_kernel_diff_bias_t p;
                    auto bias_ptr = diff_bias + bia_dt_size * oc;
                    p.ptr_diff_dst = (void *)addr_batch[os_block].ptr.B;
                    p.ptr_diff_bias_acc = (void *)get_bia_acc_ptr(oc);
                    p.ptr_diff_bias = (void *)bias_ptr;
                    bool is_first = kernel_init && os_block == 0;
                    bool is_last = (jbgp.nthr_mb == 1 || os_chunks == 1)
                            && osc == os_chunks - 1 && os_block == nb_os_b - 1
                            && !is_os_tail;
                    p.flags = 0 | (is_first ? FLAG_REDUCE_FIRST : 0)
                            | (is_last ? FLAG_REDUCE_LAST : 0);

                    (*kernels_db_[false][is_oc_tail])(&p);
                }
            }
            brgemm_kernel_execute(brg_kernel, nb_os_b, addr_batch,
                    (void *)get_wei_acc_ptr(ti, ocb, icb), wsp_tile);
        }

        // Handle the trailing partial os block (mb not divisible by
        // os_block * nb_os_blocking) with the dedicated tail kernel.
        if (is_os_tail) {
            int os_block = nb_os_b;
            auto a_ptr = a_buffer + os_block * a_buf_osb_shift;

            if (transform_a) {
                const memory_desc_wrapper src_d(pd()->src_md());
                auto src_ptr = ti->src
                        + types::data_type_size(jbgp.src_dt)
                                * src_d.blk_off(
                                        n + os_block * jbgp.os_block, ic);
                transform_matrix_a_chunk(a_ptr, src_ptr, 1,
                        is_ic_tail ? jbgp.ic % jbgp.ic_block : jbgp.ic_block,
                        jbgp.mb % jbgp.os_block);
            }

            addr_batch[0].ptr.A = a_ptr;
            auto diff_dst_ptr = diff_dst
                    + types::data_type_size(jbgp.dst_dt)
                            * diff_dst_d.blk_off(
                                    n + os_block * jbgp.os_block, oc);
            if (jbgp.use_buffer_b) {
                auto b_ptr = b_buffer + os_block * b_buf_osb_shift;

                if (transform_b)
                    transform_matrix_b_chunk(b_ptr, diff_dst_ptr, 1,
                            curr_oc_chunk_sz, jbgp.mb % jbgp.os_block);
                addr_batch[0].ptr.B = b_ptr;
            } else {
                addr_batch[0].ptr.B = diff_dst_ptr;
            }

            if (jbgp.with_bias && icb == 0) {
                brgemm_kernel_diff_bias_t p;
                auto bias_ptr = diff_bias + bia_dt_size * oc;
                p.ptr_diff_dst = (void *)addr_batch[0].ptr.B;
                p.ptr_diff_bias_acc = (void *)get_bia_acc_ptr(oc);
                p.ptr_diff_bias = (void *)bias_ptr;
                bool is_first = kernel_init && os_block == 0;
                bool is_last = (jbgp.nthr_mb == 1 || os_chunks == 1)
                        && osc == os_chunks - 1;
                p.flags = 0 | (is_first ? FLAG_REDUCE_FIRST : 0)
                        | (is_last ? FLAG_REDUCE_LAST : 0);

                (*kernels_db_[true][is_oc_tail])(&p);
            }

            auto use_init_ker = (kernel_init && nb_os_b == 0);
            const int brg_ker_idx_os_tail
                    = brgemm_inner_product_utils::get_brg_kernel_index(jbgp,
                            false, use_init_ker, is_ic_tail, is_oc_tail, true);
            auto brg_kernel_os_tail = brg_kernels_[brg_ker_idx_os_tail].get();
            if (brg_kernel_os_tail != nullptr) {
                if (jbgp.is_amx)
                    amx_tile_configure(
                            &brg_kernel_palettes_[brg_ker_idx_os_tail][0]);
                brgemm_kernel_execute(brg_kernel_os_tail, 1, addr_batch,
                        (void *)get_wei_acc_ptr(ti, ocb, icb), wsp_tile);
            }
        }

        if (transform_weights) {
            transpose_matrix_c_chunk(ti, ocb, icb,
                    is_oc_tail ? jbgp.oc % jbgp.oc_block : jbgp.oc_block,
                    is_ic_tail ? jbgp.ic % jbgp.ic_block : jbgp.ic_block);
        }
    };

    const auto occ_work = (ti->oc_c_end - ti->oc_c_start);
    const auto icc_work = (ti->ic_c_end - ti->ic_c_start);
    const auto osc_work = (ti->os_c_end - ti->os_c_start);

    auto loop_idx = 0;
    const auto loop_end = occ_work * icc_work * osc_work;

    // Loop order is chosen to match the buffering strategy: os outermost
    // when input copies are reused across oc/ic chunks.
    int occ_idx = 0, icc_idx = 0, osc_idx = 0;
    loop_order_t loop_order = jbgp.ip_bwd_w_local_buffers_for_input_tensors
            ? loop_order_t::osc_icc_occ
            : jbgp.harness == harness_mb_reduction ? loop_order_t::osc_occ_icc
                                                   : loop_order_t::occ_icc_osc;

    switch (loop_order) {
        case loop_order_t::osc_icc_occ:
            nd_iterator_init(loop_idx, osc_idx, osc_work, icc_idx, icc_work,
                    occ_idx, occ_work);
            break;
        case loop_order_t::osc_occ_icc:
            nd_iterator_init(loop_idx, osc_idx, osc_work, occ_idx, occ_work,
                    icc_idx, icc_work);
            break;
        case loop_order_t::occ_icc_osc:
            nd_iterator_init(loop_idx, occ_idx, occ_work, icc_idx, icc_work,
                    osc_idx, osc_work);
    };

    int osc_prev = -1, icc_prev = -1, occ_prev = -1;
    while (loop_idx < loop_end) {
        const int occ = ti->oc_c_start + occ_idx;
        const int icc = ti->ic_c_start + icc_idx;
        const int osc = ti->os_c_start + osc_idx;

        const int ocb_work = nstl::min(
                jbgp.nb_oc_blocking, jbgp.nb_oc - occ * jbgp.nb_oc_blocking);
        const int icb_work = nstl::min(
                jbgp.nb_ic_blocking, jbgp.nb_ic - icc * jbgp.nb_ic_blocking);

        for_(int ocb = 0; ocb < ocb_work; ocb++)
        for (int icb = 0; icb < icb_work; icb++) {
            ker(osc, icc * jbgp.nb_ic_blocking + icb,
                    occ * jbgp.nb_oc_blocking + ocb, osc_prev, icc_prev,
                    occ_prev);
        }
        osc_prev = osc;
        icc_prev = icc;
        occ_prev = occ;

        ++loop_idx;

        switch (loop_order) {
            case loop_order_t::osc_icc_occ:
                nd_iterator_step(osc_idx, osc_work, icc_idx, icc_work, occ_idx,
                        occ_work);
                break;
            case loop_order_t::osc_occ_icc:
                nd_iterator_step(osc_idx, osc_work, occ_idx, occ_work, icc_idx,
                        icc_work);
                break;
            case loop_order_t::occ_icc_osc:
                nd_iterator_step(occ_idx, occ_work, icc_idx, icc_work, osc_idx,
                        osc_work);
        };
    }
    if (jbgp.is_amx) amx_tile_release();
}
1511 | |
template <cpu_isa_t isa>
void brgemm_inner_product_bwd_weights_t<
        isa>::reduce_and_convert_diff_weights_and_bias(const thread_info_t *ti)
        const {
    // Cross-thread reduction over the minibatch (os) dimension: sums the
    // per-reduction-buffer weight accumulators, converts to wei_dt when it
    // differs from f32, and performs the matching bias reduction/convert.
    const auto &jbgp = pd()->jbgp_;

    // With a syncable threading runtime all threads meet here so every
    // partial accumulator is complete before the reduction starts.
    if (dnnl_thr_syncable() && jbgp.nthr > 1)
        simple_barrier::barrier(ti->barrier_ctx, jbgp.nthr);
    if (ti->nthr_os_c == 1) return;

    const bool is_f32_out = jbgp.wei_dt == data_type::f32;
    const int icb_scale = is_f32_out ? jbgp.ic_block / jbgp.simd_w : 1;

    const int icb_work = ti->ic_c_work * jbgp.nb_ic_blocking;
    const int ocb_work = ti->oc_c_work * jbgp.nb_oc_blocking;
    const int work = ocb_work * icb_work;

    int os_chunks = utils::div_up(jbgp.nb_os, jbgp.nb_os_blocking);
    int reduce_buffers = nstl::min(ti->nthr_os_c, os_chunks);
    // For f32 output, buffer 0 is the destination itself, so reduction
    // starts at 0 and skips the last buffer; otherwise buffer 0 is the
    // reduction target and reduction starts at 1.
    int reduce_buf_idx_start = !is_f32_out;
    int reduce_buf_idx_end = reduce_buffers - is_f32_out;

    // Split the (ocb, icb) block grid among the os-reduction threads.
    int start = 0, end = 0;
    balance211(work, ti->nthr_os_c, ti->ithr_os_c, start, end);
    if (start == end) return;

    int icb_l = 0, ocb_l = 0;
    const int acc_size = jbgp.ic_block * jbgp.oc_block;

    for (int ir = reduce_buf_idx_start; ir < reduce_buf_idx_end; ++ir) {
        int counter = start;
        nd_iterator_init(start, ocb_l, ocb_work, icb_l, icb_work);
        while (counter < end) {
            const int ocb = ti->oc_c_start * jbgp.nb_oc_blocking + ocb_l;
            const int icb = ti->ic_c_start * jbgp.nb_ic_blocking + icb_l;
            char *wei_to_reduce = get_wei_acc_ptr(ti, ocb, icb, ir);
            const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
            // Reduce into buffer 0 (non-f32 output) or straight into the
            // user's diff_weights (f32 output).
            char *wei_reduced = !is_f32_out ? get_wei_acc_ptr(ti, ocb, icb, 0)
                                            : ti->diff_weights
                            + get_blk_off(diff_weights_d, jbgp.wei_dt, ocb,
                                    icb * icb_scale);
            acc_ker_->accumulate(
                    (float *)(wei_reduced), (float *)(wei_to_reduce), acc_size);
            // After the final partial sum, convert the block to wei_dt.
            if (!is_f32_out && ir + 1 == reduce_buf_idx_end) {
                transpose_matrix_c_chunk(ti, ocb, icb * icb_scale,
                        jbgp.oc_block, jbgp.ic_block, true);
            }
            ++counter;
            nd_iterator_step(ocb_l, ocb_work, icb_l, icb_work);
        }
    }

    // Bias reduction: done once per oc range by the thread at the origin of
    // the ic and os thread grids.
    if (jbgp.with_bias && ti->ithr_ic_c == 0 && ti->ic_c_work > 0
            && ti->ithr_os_c == 0 && ti->os_c_work > 0 && ti->oc_c_work > 0) {
        const bool is_f32_bias = jbgp.bia_dt == data_type::f32;
        float *bias_reduced = is_f32_bias ? (float *)ti->diff_bias
                                          : (float *)ti->buffer_bias;
        int reduce_buf_idx_start = !is_f32_bias;
        int reduce_buf_idx_end = reduce_buffers - 1;
        int oc_chunk_size = jbgp.nb_oc_blocking * jbgp.oc_block;
        int oc = ti->oc_c_start * oc_chunk_size;
        int acc_size = nstl::min(ti->oc_c_work * oc_chunk_size, jbgp.oc - oc);

        int ir = reduce_buf_idx_start;
        for (; ir < reduce_buf_idx_end; ++ir) {
            float *bias_to_reduce = (float *)ti->buffer_bias + ir * jbgp.oc;
            acc_ker_->accumulate(
                    &bias_reduced[oc], &bias_to_reduce[oc], acc_size);
        }

        // Last buffer is folded in together with the downconversion to the
        // requested bias data type.
        if (!is_f32_bias) {
            float *bias_to_reduce = (float *)ti->buffer_bias + ir * jbgp.oc;
            switch (jbgp.bia_dt) {
                case data_type::bf16:
                    add_floats_and_cvt_to_bfloat16(
                            (bfloat16_t *)(ti->diff_bias) + oc,
                            &bias_reduced[oc], &bias_to_reduce[oc], acc_size);
                    break;
                case data_type::f16:
                    add_floats_and_cvt_to_float16(
                            (float16_t *)(ti->diff_bias) + oc,
                            &bias_reduced[oc], &bias_to_reduce[oc], acc_size);
                    break;
                default: assert(!"invalid data type" );
            }
        }
    }
}
1600 | |
1601 | template <cpu_isa_t isa> |
1602 | void brgemm_inner_product_bwd_weights_t<isa>::execute_backward_weights( |
1603 | const exec_ctx_t &ctx) const { |
1604 | const auto &jbgp = pd()->jbgp_; |
1605 | |
1606 | if (dnnl_thr_syncable() && jbgp.nthr > 1) { |
1607 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1608 | simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>( |
1609 | key_conv_wei_bia_reduction_bctx)); |
1610 | } |
1611 | |
1612 | parallel(jbgp.nthr, [&](const int ithr, const int nthr) { |
1613 | thread_info_t thread_info(this, ctx, ithr); |
1614 | compute_diff_weights_and_bias(&thread_info); |
1615 | |
1616 | if (dnnl_thr_syncable()) { |
1617 | reduce_and_convert_diff_weights_and_bias(&thread_info); |
1618 | } |
1619 | }); |
1620 | |
1621 | if (!dnnl_thr_syncable()) { |
1622 | parallel(jbgp.nthr, [&](const int ithr, const int nthr) { |
1623 | thread_info_t thread_info(this, ctx, ithr); |
1624 | reduce_and_convert_diff_weights_and_bias(&thread_info); |
1625 | }); |
1626 | } |
1627 | } |
1628 | |
// Explicit instantiations for every ISA this primitive supports.
template struct brgemm_inner_product_bwd_weights_t<avx512_core_amx_fp16>;
template struct brgemm_inner_product_bwd_weights_t<avx512_core_fp16>;
template struct brgemm_inner_product_bwd_weights_t<avx512_core_amx>;
template struct brgemm_inner_product_bwd_weights_t<avx512_core_bf16>;
template struct brgemm_inner_product_bwd_weights_t<avx512_core>;
1634 | |
1635 | } // namespace x64 |
1636 | } // namespace cpu |
1637 | } // namespace impl |
1638 | } // namespace dnnl |
1639 | |
1640 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1641 | |