1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <float.h> |
17 | |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/memory_tracking.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/platform.hpp" |
26 | #include "cpu/x64/injectors/injector_utils.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
28 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
29 | #include "cpu/x64/jit_avx512_core_bf16_1x1_conv_kernel.hpp" |
30 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
31 | |
32 | #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | namespace x64 { |
38 | |
39 | using namespace dnnl::impl::prop_kind; |
40 | using namespace dnnl::impl::utils; |
41 | |
42 | using namespace Xbyak; |
43 | |
44 | jit_avx512_core_bf16_1x1_conv_kernel::jit_avx512_core_bf16_1x1_conv_kernel( |
45 | const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, |
46 | const memory_desc_t &dst_md) |
47 | : jit_generator(jit_name(), nullptr, ker_code_size, true, avx512_core_bf16) |
48 | , jcp(ajcp) |
49 | , attr_(attr) { |
50 | if (jcp.with_eltwise || jcp.with_binary) { |
51 | using namespace binary_injector; |
52 | static constexpr bool preserve_gpr = true; |
53 | static constexpr bool preserve_vmm = false; |
54 | static constexpr size_t helper_vmm_idx = 31; |
55 | const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; |
56 | static constexpr bool use_exact_tail_scalar_bcast = true; |
57 | |
58 | const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, |
59 | r14, r15, r12, preserve_gpr, preserve_vmm, |
60 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
61 | memory_desc_wrapper(dst_md), tail_size, k_load_dim_tail_mask, |
62 | use_exact_tail_scalar_bcast}; |
63 | const static_params_t static_params { |
64 | this->param1, rhs_arg_static_params}; |
65 | |
66 | postops_injector_ = utils::make_unique< |
67 | injector::jit_uni_postops_injector_t<avx512_core>>( |
68 | this, jcp.post_ops, static_params); |
69 | } |
70 | |
71 | if (!isa_has_bf16(jcp.isa)) |
72 | bf16_emu_ = utils::make_unique<bf16_emulation_t>(this, |
73 | bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3, |
74 | bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_6); |
75 | } |
76 | |
// Emits the loop over the broadcast (spatial) dimension: processes full
// bcast blocks of jcp.ur rows via reduce_loop, advancing the bcast/output/
// store-buffer pointers between substeps, then handles the ur tail.
void jit_avx512_core_bf16_1x1_conv_kernel::bcast_loop(int load_loop_blk) {
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_bcast_data, reg_bcast_data);

    mov(aux_reg_output_data, reg_output_data);
    mov(aux_reg_store_buf, reg_store_buf);

    // The remaining bcast work counter was spilled to the stack.
    mov(reg_bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt));

    Label bcast_loop;
    Label bcast_loop_tail;
    Label long_tail;

    // Fewer than jcp.ur rows left: go straight to tail processing.
    cmp(reg_bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            // The last substep doubles as the entry point for a long tail
            // (>= jcp.ur rows remaining).
            if (i + 1 == num_substeps) L(long_tail);
            reduce_loop(load_loop_blk, jcp.ur, i, false);
            if (i < num_substeps - 1) {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);

                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
                add(aux_reg_store_buf, jcp.bcast_loop_output_substep);
            } else {
                // End of the block: advance by a full block step, undoing the
                // substep increments applied above.
                add(aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep);

                add(aux_reg_output_data,
                        jcp.bcast_loop_output_step * jcp.typesize_out
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_output_substep);
                add(aux_reg_store_buf,
                        jcp.bcast_loop_output_step * jcp.typesize_acc
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_output_substep);
            }
            sub(reg_bcast_loop_iter, jcp.ur);
        }
        cmp(reg_bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        // A tail of at least jcp.ur rows reuses the last substep above.
        if (jcp.ur_tail >= jcp.ur) {
            cmp(reg_bcast_loop_iter, jcp.ur);
            jge(long_tail, T_NEAR);
        }
        // Remaining sub-ur rows get a dedicated masked reduce_loop.
        if (jcp.ur_tail % jcp.ur) {
            cmp(reg_bcast_loop_iter, 0);
            jle(bcast_loop_tail_out, T_NEAR);
            reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur, 0, true);
            L(bcast_loop_tail_out);
        }
    }
}
142 | |
143 | static int vreg_accum_idx( |
144 | const int load_loop_blk, const int i_load, const int i_ur) { |
145 | int idx = i_ur * load_loop_blk + i_load; |
146 | assert(idx < 31); |
147 | return idx; |
148 | } |
149 | |
// Returns the address of the output element for load-block index i_load and
// unroll step i_ur, relative to aux_reg_output_data.
Address jit_avx512_core_bf16_1x1_conv_kernel::output_ptr(
        const int i_load, const int i_ur) {
    if (one_of(jcp.prop_kind, forward_training, forward_inference,
                backward_data)) {
        // fwd / bwd_d: both strides are compile-time constants chosen by the
        // output layout (nxc vs blocked).
        const bool is_output_layout_nxc = is_out_layout_nxc();
        int i_load_shift = is_output_layout_nxc
                ? jcp.load_block
                : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block;
        int i_ur_shift = is_output_layout_nxc ? jcp.load_dim : jcp.load_block;
        int offset = (i_load * i_load_shift + i_ur * i_ur_shift)
                * jcp.typesize_out;
        return EVEX_compress_addr(aux_reg_output_data, offset);
    } else
        // bwd_w: the load-dimension stride is only known at run time, so it
        // comes from reg_output_stride scaled by i_load.
        return ptr[aux_reg_output_data
                + (i_load ? reg_output_stride * i_load
                          : 0) // TODO: Xbyak should allow 0 scale
                + jcp.typesize_out * jcp.load_block * i_ur];
}
168 | |
// Invokes f(mask_flag, i_load, i_ur) for every (load block, unroll step)
// pair; mask_flag is raised only for the last load block when mask_tail
// is set.
template <typename F>
static void iterate(const int load_loop_blk, const int ur, const bool mask_tail,
        const F &f) {
    for (int i_load = 0; i_load < load_loop_blk; i_load++) {
        const bool is_last_block = i_load + 1 == load_loop_blk;
        const bool mask_flag = mask_tail && is_last_block;
        for (int i_ur = 0; i_ur < ur; i_ur++)
            f(mask_flag, i_load, i_ur);
    }
}
// Convenience overload without tail masking: mask_flag is always false.
template <typename F>
static void iterate(const int load_loop_blk, const int ur, const F &f) {
    iterate(load_loop_blk, ur, false, f);
}
182 | |
// Emits post-op (eltwise/binary) application over all live accumulator
// registers. For binary post-ops with a channel tail, emits a runtime
// dispatch between the masked (tail) and unmasked variants.
void jit_avx512_core_bf16_1x1_conv_kernel::apply_postops(
        const int load_loop_blk, const int ur) {
    if (jcp.with_eltwise || jcp.with_binary) {
        injector_utils::vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
                    rhs_arg_params_tail;
            const auto mask_tail = jcp.oc_without_padding % isa_simd_width_;
            // Collect every accumulator register plus its output offset; the
            // tail variant additionally marks the last load block as masked.
            iterate(load_loop_blk, ur, mask_tail,
                    [&](const bool mask_flag, const int i_load,
                            const int i_ur) {
                        const int aux_output_l_off
                                = get_output_offset(i_load, i_ur);
                        const auto vmm_idx
                                = vreg_accum_idx(load_loop_blk, i_load, i_ur);
                        vmm_idxs.emplace(vmm_idx);

                        rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
                                vmm_idx, aux_reg_output_data);
                        rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
                                vmm_idx, aux_output_l_off);
                        if (mask_flag)
                            rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
                    });
            // The non-tail variant is identical minus the tail index set.
            rhs_arg_params = rhs_arg_params_tail;
            rhs_arg_params.vmm_tail_idx_.clear();

            // reg_tmp is clobbered below; preserve it around the injection.
            const injector_utils::register_preserve_guard_t register_guard(
                    this, {reg_tmp});
            const size_t reg_guard_stack_occupied
                    = register_guard.stack_space_occupied();

            // Restore abi_param1 from its stack backup, accounting for the
            // extra stack space taken by the preserve guard.
            mov(abi_param1,
                    EVEX_compress_addr(rsp,
                            reg_abi_param1_backup + reg_guard_stack_occupied));

            Label postops_done;
            if (mask_tail) {
                // Runtime check: use the masked variant only when the current
                // load-loop work does not cover a full register block.
                Label postops_no_tail;
                mov(reg_tmp,
                        ptr[rsp + reg_load_loop_work_off
                                + reg_guard_stack_occupied]);
                cmp(reg_tmp, jcp.oc_block * load_loop_blk);
                jge(postops_no_tail, T_NEAR);
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
                jmp(postops_done, T_NEAR);
                L(postops_no_tail);
            }
            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
            L(postops_done);

        } else {
            // Eltwise-only: no per-register output metadata is needed.
            iterate(load_loop_blk, ur,
                    [&](const bool, const int i_load, const int i_ur) {
                        vmm_idxs.emplace(
                                vreg_accum_idx(load_loop_blk, i_load, i_ur));
                    });
            postops_injector_->compute_vector_range(vmm_idxs);
        }
    }
}
245 | |
// Emits the reduce (accumulation) loop for one register block: initializes
// the accumulators, iterates over the reduce dimension in chunks of
// jcp.reduce_loop_unroll issuing bf16 dot-products (native vdpbf16ps or
// emulated), then stores results — applying bias and post-ops when the
// FLAG_REDUCE_LAST runtime flag is set.
void jit_avx512_core_bf16_1x1_conv_kernel::reduce_loop(
        int load_loop_blk, int ur, int substep, bool wraparound) {
    const bool load_layout_nxc = is_load_layout_nxc();
    const bool bcast_layout_nxc = is_bcast_layout_nxc();
    const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block;
    const int load_dim_tail = jcp.load_dim % jcp.load_block;

    // Registers above the accumulator range hold loaded weights/sources;
    // register 31 is reserved as store/scratch space.
    auto vreg_load = [=](int i_load) {
        int idx = ur * load_loop_blk + i_load;
        assert(idx < 31);
        return Zmm(idx);
    };
    auto ymm_store = [=]() { return Xbyak::Ymm(31); };
    auto zmm_store = [=]() { return Xbyak::Zmm(31); };

    const auto vreg_accum = [=](int i_load, int i_ur) {
        return Zmm(vreg_accum_idx(load_loop_blk, i_load, i_ur));
    };

    // Address of the bias element for a given load block.
    auto bias_ptr = [=](int i_load) {
        return EVEX_compress_addr(
                reg_bias_data, jcp.typesize_bia * jcp.oc_block * i_load);
    };

    // Address of the broadcast (source) element; `bcast` selects a
    // broadcast-load (zword_b) form of the address.
    auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        int offt;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            assert(jcp.reduce_loop_unroll == jcp.reduce_block);
            const int reduce_mul = bcast_layout_nxc ? jcp.reduce_dim
                                                    : jcp.reduce_loop_unroll;
            // i_reduce == reduce_loop_unroll wraps to the next bcast row.
            offt = (i_reduce == jcp.reduce_loop_unroll)
                    ? (jcp.bcast_dim + i_ur) * reduce_mul
                    : i_ur * reduce_mul + i_reduce;
        } else {
            if (jcp.uses_permw_transposition) {
                int rmul = bcast_layout_nxc ? jcp.ngroups * jcp.ic
                                            : jcp.ic_block;
                offt = i_reduce * rmul + i_ur;
            } else {
                // Data pre-transposed in pairs of reduce elements.
                offt = (i_reduce / 2) * 2 * jcp.ic_block + 2 * i_ur;
            }
        }
        return EVEX_compress_addr(
                aux_reg_bcast_data, jcp.typesize_in * offt, bcast);
    };

    // Address of the weights element for a given (reduce, load) pair.
    auto load_ptr = [=](int i_reduce, int i_load) {
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;
        int lmul = jcp.load_block
                * (load_layout_nxc ? 1
                                   : utils::rnd_up(
                                           jcp.reduce_dim, jcp.reduce_block));
        int rmul = load_layout_nxc ? jcp.load_dim : jcp.load_block;
        int offt = i_load * lmul + u0 * rmul;
        return EVEX_compress_addr(aux_reg_load_data,
                u1 * jcp.reduce_loop_load_step + jcp.typesize_in * offt);
    };

    // Address into the f32 accumulation buffer used between reduce chunks.
    auto store_buffer_ptr = [=](int i_load, int i_ur) {
        const bool is_output_layout_nxc = is_out_layout_nxc();
        int i_load_shift
                = jcp.load_block * (is_output_layout_nxc ? 1 : jcp.bcast_dim);
        int i_ur_shift = is_output_layout_nxc ? jcp.load_dim : jcp.load_block;
        int offset = (i_load * i_load_shift + i_ur * i_ur_shift)
                * jcp.typesize_acc;
        return EVEX_compress_addr(aux_reg_store_buf, offset);
    };

    // Zero all accumulator registers.
    auto init = [=]() {
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                vpxord(r, r, r);
            }
    };

    // Emits the full store phase: read-back of partial sums, bias/post-ops
    // on the last reduce chunk, and the final (possibly converting) store.
    auto store = [=]() {
        if (!isa_has_bf16(jcp.isa)) bf16_emu_->init_vcvtneps2bf16();

        // Adds previously accumulated values (from the f32 store buffer or
        // from dst, depending on `from_buf`) into the accumulators.
        auto preamble = [=]() {
            auto preamble_read = [=](bool from_buf) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                        auto r = vreg_accum(i_load, i_ur);
                        if (jcp.prop_kind == backward_weights)
                            vaddps(r, r, output_ptr(i_load, i_ur));
                        else {
                            bool mask_flag = load_dim_tail
                                    && i_load + 1 == load_loop_blk;
                            auto zmm_prev_dst = Xbyak::Zmm(31);
                            auto zmm_prev_dst_masked = may_be_mask_zmm(
                                    zmm_prev_dst, mask_flag, true);
                            if (from_buf) {
                                vmovups(zmm_prev_dst,
                                        store_buffer_ptr(i_load, i_ur));
                            } else if (jcp.dst_dt == data_type::bf16) {
                                // bf16 -> f32: widen then shift into the
                                // high mantissa bits.
                                vpmovzxwd(zmm_prev_dst_masked,
                                        output_ptr(i_load, i_ur));
                                vpslld(zmm_prev_dst, zmm_prev_dst, 16);
                            } else {
                                vmovups(zmm_prev_dst_masked,
                                        output_ptr(i_load, i_ur));
                            }
                            vaddps(r, zmm_prev_dst);
                        }
                    }
                }
            };
            if (one_of(jcp.prop_kind, forward_training, forward_inference,
                        backward_data)) {
                Label read_from_output;
                Label read_done;

                if (!jcp.with_sum) {
                    test(reg_reduce_pos_flag,
                            FLAG_REDUCE_FIRST); // If FLAG_REDUCE_FIRST
                    jnz(read_done, T_NEAR);
                } else {
                    // With sum post-op the first chunk reads dst instead.
                    test(reg_reduce_pos_flag,
                            FLAG_REDUCE_FIRST); // If FLAG_REDUCE_FIRST
                    jnz(read_from_output, T_NEAR);
                }
                preamble_read(true);
                jmp(read_done, T_NEAR);

                L(read_from_output);
                preamble_read(false);

                L(read_done);
            } else if (jcp.prop_kind == backward_weights) {
                Label read_done;

                test(reg_reduce_pos_flag,
                        FLAG_REDUCE_FIRST); // If FLAG_REDUCE_FIRST
                jnz(read_done, T_NEAR);
                preamble_read(false);

                L(read_done);
            }
        };

        // Adds bias to each accumulator (fwd only), widening bf16 bias.
        const auto apply_bias = [=]() {
            if (jcp.with_bias
                    && one_of(jcp.prop_kind, forward_training,
                            forward_inference)) {
                auto zmm_bias = Xbyak::Zmm(31);
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    bool mask_flag
                            = load_dim_tail && i_load + 1 == load_loop_blk;
                    for (int i_ur = 0; i_ur < ur; ++i_ur) {
                        auto vreg_acc = vreg_accum(i_load, i_ur);
                        if (jcp.bia_dt == data_type::bf16) {
                            vpmovzxwd(
                                    may_be_mask_zmm(zmm_bias, mask_flag, true),
                                    bias_ptr(i_load));
                            vpslld(zmm_bias, zmm_bias, 16);
                            vaddps(vreg_acc, zmm_bias);
                        } else {
                            vaddps(may_be_mask_zmm(vreg_acc, mask_flag, true),
                                    bias_ptr(i_load));
                        }
                    }
                }
            }
        };

        // Bias and post-ops run only on the last reduce chunk.
        const auto apply_bias_and_postops = [=]() {
            Label store_no_post_ops;

            test(reg_reduce_pos_flag,
                    FLAG_REDUCE_LAST); // skip unless FLAG_REDUCE_LAST is set
            jz(store_no_post_ops, T_NEAR);

            apply_bias();
            apply_postops(load_loop_blk, ur);
            L(store_no_post_ops);
        };

        // Writes accumulators either to the f32 buffer (intermediate chunk)
        // or to dst with the required datatype conversion (last chunk).
        auto store_output = [=](bool to_buf) {
            if (to_buf) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                        vmovups(store_buffer_ptr(i_load, i_ur),
                                vreg_accum(i_load, i_ur));
                    }
                }
            } else if (jcp.prop_kind == backward_weights
                    || jcp.dst_dt == data_type::f32) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                        auto vreg_acc = vreg_accum(i_load, i_ur);
                        // for nxc_layout-bwd_w, weights are still padded and
                        // the output_ptr here can be uninitialized scratchpad.
                        // To ensure final output (after reduction) is zero
                        // padded, here we zero-pad output by omitting the mask.
                        bool mask_flag = (jcp.prop_kind != backward_weights)
                                && load_dim_tail && i_load + 1 == load_loop_blk;
                        vmovups(output_ptr(i_load, i_ur),
                                may_be_mask_zmm(vreg_acc, mask_flag, false));
                    }
                }
            } else if (jcp.dst_dt == data_type::bf16) {
                if (isa_has_bf16(jcp.isa) && is_out_layout_nxc()) {
                    // Optimization: use single store instruction for pair
                    // of the nearest vectors along LOAD dimension
                    for (int i_ur = 0; i_ur < ur; i_ur++) {
                        int i_load = 0;
                        for (; i_load < rnd_dn(load_loop_blk, 2); i_load += 2) {
                            auto zmm = vreg_accum(i_load, i_ur);
                            auto zmm_next = vreg_accum(i_load + 1, i_ur);
                            vcvtne2ps2bf16(zmm, zmm_next, zmm);
                            bool mask_flag = load_dim_tail
                                    && i_load + 2 == load_loop_blk;
                            vmovdqu16(output_ptr(i_load, i_ur),
                                    may_be_mask_zmm(
                                            zmm, mask_flag, false, true));
                        }
                        // Odd trailing load block converted on its own.
                        if (load_loop_blk % 2 != 0) {
                            auto zmm = vreg_accum(i_load, i_ur);
                            auto ymm = Ymm(zmm.getIdx());
                            vcvtneps2bf16(ymm, zmm);
                            vmovdqu16(output_ptr(i_load, i_ur),
                                    may_be_mask_ymm(ymm, load_dim_tail));
                        }
                    }
                } else if (isa_has_bf16(jcp.isa)) {
                    // Optimization: use single store instruction for pair
                    // of the nearest vectors along BCAST dimension
                    assert(!is_out_layout_nxc());
                    for (int i_load = 0; i_load < load_loop_blk; i_load++) {
                        int n_2bf2ps = (ur / 2) * 2, i_ur = 0;
                        for (i_ur = 0; i_ur < n_2bf2ps; i_ur += 2) {
                            auto zmm = zmm_store();
                            vcvtne2ps2bf16(zmm, vreg_accum(i_load, i_ur + 1),
                                    vreg_accum(i_load, i_ur));
                            vmovups(output_ptr(i_load, i_ur), zmm);
                        }
                        if (i_ur < ur) {
                            auto ymm = ymm_store();
                            vcvtneps2bf16(ymm, vreg_accum(i_load, i_ur));
                            vmovups(output_ptr(i_load, i_ur), ymm);
                        }
                    }
                } else {
                    // No native bf16: convert via the software emulator.
                    for (int i_load = 0; i_load < load_loop_blk; i_load++) {
                        for (int i_ur = 0; i_ur < ur; ++i_ur) {
                            auto ymm = ymm_store();
                            bf16_emu_->vcvtneps2bf16(
                                    ymm, vreg_accum(i_load, i_ur));
                            bool mask_flag = load_dim_tail
                                    && i_load + 1 == load_loop_blk;
                            vmovdqu16(output_ptr(i_load, i_ur),
                                    may_be_mask_ymm(ymm, mask_flag));
                        }
                    }
                }
            } else {
                assert(!"unsupported destination type" );
            }
        };

        preamble();

        apply_bias_and_postops();

        if (jcp.prop_kind == backward_weights) {
            store_output(false);
        } else {
            // Intermediate chunks go to the f32 buffer; the last chunk is
            // written to dst.
            Label store_to_output;
            Label store_done;

            test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); // If FLAG_REDUCE_LAST
            jnz(store_to_output, T_NEAR);

            store_output(true);
            jmp(store_done, T_NEAR);

            L(store_to_output);
            store_output(false);

            L(store_done);
        }
    };

    // bwd_w FMA block: software-pipelines the (permw-transposed) bcast and
    // weight loads ahead of the dot-product issue.
    auto fma_block_bwd_w = [=](bool is_tail) {
        int n_reduce_tail = jcp.reduce_dim % jcp.reduce_loop_unroll;
        int n_reduce
                = is_tail && n_reduce_tail > 0 && !jcp.uses_permw_transposition
                ? n_reduce_tail
                : jcp.reduce_loop_unroll;
        int bcast_count = 0;
        int pipeline_length_max = 1;
        if (isa_has_bf16(jcp.isa)) {
            // Size the load pipeline by the registers left over after the
            // accumulators (and the permw scratch register) are accounted for.
            const int max_regs = 32;
            const int regs_for_accum = ur * load_loop_blk;
            const int regs_for_pipeline_total
                    = max_regs - regs_for_accum - jcp.uses_permw_transposition;
            const int regs_for_pipeline_iter
                    = load_loop_blk + jcp.uses_permw_transposition;
            assert(regs_for_pipeline_total >= regs_for_pipeline_iter);
            pipeline_length_max = nstl::min(
                    regs_for_pipeline_total / regs_for_pipeline_iter,
                    n_reduce / 2);
        }

        const int pipeline = saturate(1, 4, pipeline_length_max);
        auto zmm_prm = [=]() { return zmm_store(); };
        // First vreg_load index of the pipeline slot for a given iteration.
        auto get_load_start_idx = [=](int bcast_count) {
            return pipeline * jcp.uses_permw_transposition
                    + (bcast_count % pipeline) * load_loop_blk;
        };
        // Broadcast address: the permw path reads transposed data staged on
        // the stack, otherwise falls back to bcast_ptr.
        auto pipeline_bcast_ptr = [=](int i_reduce, int i_ur, bool bcast,
                                          int pipeline_idx) {
            if (jcp.uses_permw_transposition) {
                int offset = 64 * pipeline_idx + jcp.typesize_in * 2 * i_ur;
                auto p = rsp + broadcast_space + offset;
                return bcast ? zword_b[p] : ptr[p];
            } else {
                return bcast_ptr(i_reduce, i_ur, bcast);
            }
        };

        // Half mask when only one reduce element of the pair is valid or
        // when the nxc layout requires split 16-element loads.
        auto get_load_mask = [=](int i_reduce) {
            bool is_reduce_tail = jcp.reduce_loop_unroll % 2
                    && i_reduce + 2 >= jcp.reduce_loop_unroll;
            return is_load_layout_nxc() || is_reduce_tail ? half_mask
                                                          : full_mask;
        };
        auto get_bcast_mask = [=](int i_reduce) {
            bool is_reduce_tail = jcp.reduce_loop_unroll % 2
                    && i_reduce + 2 >= jcp.reduce_loop_unroll;
            return is_bcast_layout_nxc() || is_reduce_tail ? half_mask
                                                           : full_mask;
        };

        if (jcp.uses_permw_transposition) {
            // reg_reduce_pos_flag doubles as a temporary; spill it first.
            mov(EVEX_compress_addr(rsp, perm_reg_offset), reg_reduce_pos_flag);
            mov(reg_trans_tmp, dst_prm_table);
            vmovups(zmm_prm(), ptr[reg_trans_tmp]);

            // Warm up the pipeline: transpose and stage the first `pipeline`
            // bcast rows and their weight vectors.
            for (; bcast_count < pipeline; bcast_count++) {
                int i_reduce = 2 * bcast_count;
                bool is_reduce_tail = jcp.reduce_loop_unroll % 2
                        && i_reduce + 2 >= jcp.reduce_loop_unroll;
                int load_idx = get_load_start_idx(bcast_count);
                Opmask bcast_mask = get_bcast_mask(i_reduce);
                auto bcast_values = vreg_load(bcast_count);

                vmovdqu16(bcast_values | bcast_mask | T_z,
                        bcast_ptr(i_reduce, 0, false));
                if (is_bcast_layout_nxc() && !is_reduce_tail) {
                    // Reuse i_ur argument to shift back pointer
                    vmovdqu16(bcast_values | half_mask_hi,
                            bcast_ptr(i_reduce + 1, -jcp.ic_block, false));
                }
                vpermw(bcast_values, zmm_prm(), bcast_values);
                vmovups(pipeline_bcast_ptr(i_reduce, 0, false, bcast_count),
                        bcast_values);

                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    vmovdqu16(vreg_load(load_idx + i_load) | bcast_mask | T_z,
                            load_ptr(i_reduce, i_load));
                    if (is_load_layout_nxc() && !is_reduce_tail) {
                        vmovdqu16(vreg_load(load_idx + i_load) | half_mask_hi,
                                load_ptr(i_reduce + 1, i_load - 1));
                    }
                    vpermw(vreg_load(load_idx + i_load), zmm_prm(),
                            vreg_load(load_idx + i_load));
                }
            }
        } else {
            // No transposition: just prefetch weight vectors for the first
            // `pipeline` iterations.
            for (; bcast_count < pipeline; bcast_count++) {
                int i_reduce = 2 * bcast_count;
                int load_idx = get_load_start_idx(bcast_count);

                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    auto vreg = vreg_load(load_idx + i_load);
                    bool mask_flag
                            = load_dim_tail && i_load + 1 == load_loop_blk;
                    vmovups(may_be_mask_zmm(vreg, mask_flag, true),
                            load_ptr(i_reduce, i_load));
                }
            }
        }

        // Steady state: consume one staged iteration, refill one.
        int use_bcast_count = 0;
        for (int i_reduce = 0; i_reduce < n_reduce; i_reduce += 2) {
            int bcast_pl_idx = use_bcast_count % pipeline;
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                // TODO: try to enable jcp.expl_bcast version
                if (jcp.expl_bcast && load_loop_blk > 1) {
                    vpbroadcastd(vreg_bcast,
                            pipeline_bcast_ptr(
                                    i_reduce, i_ur, false, bcast_pl_idx));
                }
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    int load_idx = get_load_start_idx(use_bcast_count) + i_load;
                    if (!isa_has_bf16(jcp.isa)) {
                        // Emulation path: reload weights in place (no
                        // pipeline registers available).
                        if (jcp.uses_permw_transposition) {
                            bool is_reduce_tail = jcp.reduce_loop_unroll % 2
                                    && i_reduce + 2 >= jcp.reduce_loop_unroll;
                            Opmask load_mask = get_load_mask(i_reduce);
                            vmovdqu16(vreg_load(i_load) | load_mask | T_z,
                                    load_ptr(i_reduce, i_load));
                            if (is_load_layout_nxc() && !is_reduce_tail) {
                                vmovdqu16(vreg_load(i_load) | half_mask_hi,
                                        load_ptr(i_reduce + 1, i_load - 1));
                            }
                            vpermw(vreg_load(i_load), zmm_prm(),
                                    vreg_load(i_load));
                        } else
                            vmovups(vreg_load(i_load),
                                    load_ptr(i_reduce, i_load));
                    }
                    if (jcp.expl_bcast && load_loop_blk > 1) {
                        if (!isa_has_bf16(jcp.isa)) {
                            auto acc = vreg_accum(i_load, i_ur);
                            auto wei = vreg_load(i_load);
                            bf16_emu_->vdpbf16ps(acc, wei, vreg_bcast);
                        } else
                            vdpbf16ps(vreg_accum(i_load, i_ur),
                                    vreg_load(load_idx), vreg_bcast);
                    } else {
                        if (!isa_has_bf16(jcp.isa)) {
                            vpbroadcastd(zmm_tmp2,
                                    pipeline_bcast_ptr(
                                            i_reduce, i_ur, false, 0));
                            auto acc = vreg_accum(i_load, i_ur);
                            auto wei = vreg_load(i_load);
                            bf16_emu_->vdpbf16ps(acc, wei, zmm_tmp2);
                        } else
                            vdpbf16ps(vreg_accum(i_load, i_ur),
                                    vreg_load(load_idx),
                                    pipeline_bcast_ptr(i_reduce, i_ur, true,
                                            bcast_pl_idx));
                    }
                }
            }
            use_bcast_count++;
            // Refill the pipeline slot just consumed, if work remains.
            if (bcast_count < div_up(n_reduce, 2)) {
                int load_idx = get_load_start_idx(bcast_count);
                int i_reduce = bcast_count * 2;

                if (jcp.uses_permw_transposition) {
                    bool is_reduce_tail = jcp.reduce_loop_unroll % 2
                            && i_reduce + 2 >= jcp.reduce_loop_unroll;
                    Opmask bcast_mask = get_bcast_mask(i_reduce);
                    int bcast_pl_idx = bcast_count % pipeline;
                    auto bcast_values = vreg_load(bcast_pl_idx);

                    vmovdqu16(bcast_values | bcast_mask | T_z,
                            bcast_ptr(i_reduce, 0, false));
                    if (is_bcast_layout_nxc() && !is_reduce_tail) {
                        vmovdqu16(bcast_values | half_mask_hi,
                                bcast_ptr(i_reduce + 1, -jcp.ic_block, false));
                    }
                    vpermw(bcast_values, zmm_prm(), bcast_values);
                    vmovups(pipeline_bcast_ptr(
                                    i_reduce, 0, false, bcast_pl_idx),
                            bcast_values);

                    for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                        vmovdqu16(
                                vreg_load(load_idx + i_load) | bcast_mask | T_z,
                                load_ptr(i_reduce, i_load));
                        if (is_load_layout_nxc() && !is_reduce_tail) {
                            vmovdqu16(
                                    vreg_load(load_idx + i_load) | half_mask_hi,
                                    load_ptr(i_reduce + 1, i_load - 1));
                        }
                        vpermw(vreg_load(load_idx + i_load), zmm_prm(),
                                vreg_load(load_idx + i_load));
                    }
                } else {
                    for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                        vmovups(vreg_load(load_idx + i_load),
                                load_ptr(i_reduce, i_load));
                    }
                }
                bcast_count++;
            }
        }
        if (jcp.uses_permw_transposition)
            mov(reg_reduce_pos_flag, EVEX_compress_addr(rsp, perm_reg_offset));
    };

    // fwd/bwd_d FMA block: loads weights per reduce pair and broadcasts
    // source values, guarding the odd-tail reduce load at runtime.
    auto fma_block_fwd_bwd_d = [=](bool is_tail) {
        int n_reduce_tail = jcp.reduce_dim % jcp.reduce_loop_unroll;
        int n_reduce = is_tail && n_reduce_tail > 0 ? n_reduce_tail
                                                    : jcp.reduce_loop_unroll;
        const int reduce_step = 2;
        for (int i_reduce = 0; i_reduce < n_reduce; i_reduce += 2) {
            if (isa_has_bf16(jcp.isa)) {
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    vmovdqu16(vreg_load(i_load), load_ptr(i_reduce, i_load));
                }
            }
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                // An odd reduce tail must not read the second bf16 element
                // of the pair past the valid range.
                const bool need_safe_reduce_dim_load
                        = (i_reduce == rnd_dn(reduce_dim_tail, reduce_step))
                        && reduce_dim_tail % reduce_step;
                if (jcp.expl_bcast && load_loop_blk > 1) {
                    Label reduce_load_done;
                    if (need_safe_reduce_dim_load) {
                        Label skip_tail_load;
                        cmp(reduce_loop_iter, i_reduce + reduce_step);
                        jge(skip_tail_load, T_NEAR);
                        vpbroadcastw(
                                vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
                        // clear duplicate high word
                        vpsrld(vreg_bcast, vreg_bcast, 16);
                        jmp(reduce_load_done, T_NEAR);
                        L(skip_tail_load);
                    }
                    vpbroadcastd(vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
                    L(reduce_load_done);
                }
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    if (!isa_has_bf16(jcp.isa)) {
                        vmovdqu16(
                                vreg_load(i_load), load_ptr(i_reduce, i_load));
                    }
                    if (jcp.expl_bcast && load_loop_blk > 1) {
                        if (!isa_has_bf16(jcp.isa)) {
                            auto acc = vreg_accum(i_load, i_ur);
                            auto wei = vreg_load(i_load);
                            bf16_emu_->vdpbf16ps(acc, wei, vreg_bcast);
                        } else
                            vdpbf16ps(vreg_accum(i_load, i_ur),
                                    vreg_load(i_load), vreg_bcast);
                    } else {
                        if (!isa_has_bf16(jcp.isa)) {
                            Label reduce_load_done;
                            if (need_safe_reduce_dim_load) {
                                Label skip_tail_load;
                                cmp(reduce_loop_iter, i_reduce + reduce_step);
                                jge(skip_tail_load, T_NEAR);
                                vpbroadcastw(zmm_tmp2,
                                        bcast_ptr(i_reduce, i_ur, false));
                                // clear duplicate high word
                                vpsrld(zmm_tmp2, zmm_tmp2, 16);
                                jmp(reduce_load_done, T_NEAR);
                                L(skip_tail_load);
                            }
                            vpbroadcastd(
                                    zmm_tmp2, bcast_ptr(i_reduce, i_ur, false));
                            L(reduce_load_done);
                            auto acc = vreg_accum(i_load, i_ur);
                            auto wei = vreg_load(i_load);
                            bf16_emu_->vdpbf16ps(acc, wei, zmm_tmp2);
                        } else {
                            auto vreg_acc = vreg_accum(i_load, i_ur);
                            bool mask_flag = i_load + 1 == load_loop_blk
                                    && load_dim_tail;
                            vreg_acc = may_be_mask_zmm(
                                    vreg_acc, mask_flag, true);
                            if (need_safe_reduce_dim_load) {
                                Label reduce_load_done;
                                Label skip_tail_load;
                                cmp(reduce_loop_iter, i_reduce + reduce_step);
                                jge(skip_tail_load, T_NEAR);
                                vpbroadcastw(zmm_tmp2,
                                        bcast_ptr(i_reduce, i_ur, false));
                                // clear duplicate high word
                                vpsrld(zmm_tmp2, zmm_tmp2, 16);
                                jmp(reduce_load_done, T_NEAR);
                                L(skip_tail_load);
                                vpbroadcastd(zmm_tmp2,
                                        bcast_ptr(i_reduce, i_ur, false));
                                L(reduce_load_done);
                                vdpbf16ps(
                                        vreg_acc, vreg_load(i_load), zmm_tmp2);
                            } else {
                                vdpbf16ps(vreg_acc, vreg_load(i_load),
                                        bcast_ptr(i_reduce, i_ur, true));
                            }
                        }
                    }
                }
            }
        }
    };

    // Dispatch between the bwd_w and fwd/bwd_d FMA emitters.
    auto fma_block = [=](bool is_tail) {
        return (jcp.prop_kind == backward_weights)
                ? fma_block_bwd_w(is_tail)
                : fma_block_fwd_bwd_d(is_tail);
    };

    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);

    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    // Main reduce loop over full unroll chunks, then a tail chunk.
    mov(reduce_loop_iter, reg_reduce_loop_work);
    Label reduce_loop_exit;
    cmp(reduce_loop_iter, jcp.reduce_loop_unroll);
    jl(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        cmp(reduce_loop_iter, jcp.reduce_loop_unroll);
        jge(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    cmp(reduce_loop_iter, 0);
    jle(reduce_loop_exit, T_NEAR);

    fma_block(true);
    L(reduce_loop_exit);
    store();
}
870 | |
// Emits code that accumulates the bias gradient (diff_bias) by reducing the
// loaded bf16 tensor over the reduce dimension. Only emitted for
// backward-weights kernels with bias; at runtime the FLAG_COMPUTE_BIAS bit in
// reg_reduce_pos_flag additionally gates whether this section executes.
// Register budget: Zmm(0..load_loop_blk-1) hold the f32 accumulators,
// Zmm(load_loop_blk..2*load_loop_blk-1) hold freshly loaded data, Zmm(30/31)
// hold the permute table and the bf16 constant 1.0.
void jit_avx512_core_bf16_1x1_conv_kernel::compute_diff_bias(
        int load_loop_blk) {
    // Nothing to do unless this is backward-weights with bias.
    if (IMPLICATION(jcp.with_bias, jcp.prop_kind != backward_weights)) return;
    Label skip_diff_bias;
    // Runtime gate: only iterations flagged with FLAG_COMPUTE_BIAS reduce bias.
    test(reg_reduce_pos_flag, FLAG_COMPUTE_BIAS);
    jz(skip_diff_bias, T_NEAR);

    auto vunit = Zmm(31); // broadcast bf16 1.0 — vdpbf16ps multiplier
    auto vreg_prm = Zmm(30); // vpermw table (used only with permw transposition)

    // Byte offset into the load tensor for block (i_reduce, i_load);
    // the multipliers depend on whether the load layout is nxc or blocked.
    auto get_load_offset = [=](int i_reduce, int i_load) {
        dim_t lmul
                = jcp.load_block * (is_load_layout_nxc() ? 1 : jcp.reduce_dim);
        dim_t rmul = (is_load_layout_nxc() ? jcp.load_dim : jcp.load_block);
        return (i_load * lmul + i_reduce * rmul) * jcp.typesize_in;
    };
    // EVEX-compressed address of a load-tensor element, plus raw byte offset.
    auto load_ptr = [=](int i_reduce, int i_load, int offset = 0) {
        return EVEX_compress_addr(
                aux_reg_load_data, get_load_offset(i_reduce, i_load) + offset);
    };
    // Address of the i_load-th f32 bias block in the output bias buffer.
    auto bias_ptr = [=](int i_load) {
        return ptr[reg_bias_data + i_load * jcp.load_block * jcp.typesize_acc];
    };

    auto vreg_acc = [=](int i_load) { return Zmm(i_load); };
    auto vreg_load = [=](int i_load) { return Zmm(load_loop_blk + i_load); };

    // Loads one reduce step (two bf16 rows) per load block and accumulates it
    // into vacc via vdpbf16ps against vunit == 1.0, i.e. vacc += sum of the
    // loaded bf16 pairs. Masked (zeroing) loads cover the nxc layout and the
    // reduce-dim tail where a full 32-element read would overrun.
    auto compute_diff_bias_block = [=](bool is_tail) {
        for (int i_load = 0; i_load < load_loop_blk; i_load++) {
            auto vacc = vreg_acc(i_load);
            auto vload = vreg_load(i_load);
            auto vload_masked = is_load_layout_nxc() || is_tail
                    ? vload | half_mask | T_z
                    : vload;
            if (jcp.uses_permw_transposition) {
                vmovdqu16(vload_masked, load_ptr(0, i_load));
                if (is_load_layout_nxc() && !is_tail) {
                    // nxc full step: fetch the preceding 16 bf16 elements into
                    // the upper half of vload (half_mask_hi selects the high
                    // 16 words) so one register carries both reduce rows.
                    const int shift_16_elems = 16 * jcp.typesize_in;
                    vmovdqu16(vload | half_mask_hi,
                            load_ptr(0, i_load, -shift_16_elems));
                }
                // Interleave the two rows into the pair layout vdpbf16ps needs.
                vpermw(vload, vreg_prm, vload);
            } else {
                vmovups(vload_masked, load_ptr(0, i_load));
            }
            // Use the software emulation when the ISA lacks native bf16 FMA.
            if (!isa_has_bf16(jcp.isa))
                bf16_emu_->vdpbf16ps(vacc, vload, vunit);
            else
                vdpbf16ps(vacc, vload, vunit);
        }
    };

    // Materialize bf16 1.0 in every lane of vunit (reg_bcast_loop_iter is
    // free to reuse as scratch here).
    auto reg_unit_val = reg_bcast_loop_iter.cvt16();
    mov(reg_unit_val, 0x3f80); // bf16 value of 1.
    vpbroadcastw(vunit, reg_unit_val);

    if (jcp.uses_permw_transposition) {
        // Load the word-interleave permutation emitted at dst_prm_table.
        mov(reg_bcast_loop_iter, dst_prm_table);
        vmovups(vreg_prm, ptr[reg_bcast_loop_iter]);
    }
    // Zero the per-load-block accumulators.
    for (int i_load = 0; i_load < load_loop_blk; i_load++) {
        auto vacc = vreg_acc(i_load);
        vpxord(vacc, vacc, vacc);
    }

    // Main reduction: consume the reduce dimension two bf16 rows at a time.
    mov(aux_reg_load_data, reg_load_data);
    mov(reduce_loop_iter, reg_reduce_loop_work);
    const int reduce_step = 2;
    Label reduce_loop, reduce_loop_tail, reduce_loop_exit;
    cmp(reduce_loop_iter, reduce_step);
    jl(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        compute_diff_bias_block(false);
        add(aux_reg_load_data, get_load_offset(reduce_step, 0));
        sub(reduce_loop_iter, reduce_step);
        cmp(reduce_loop_iter, reduce_step);
        jge(reduce_loop, T_NEAR);
    }

    // Tail: a single remaining row, handled with masked (zeroing) loads.
    L(reduce_loop_tail);
    cmp(reduce_loop_iter, 0);
    jle(reduce_loop_exit, T_NEAR);

    compute_diff_bias_block(true);
    L(reduce_loop_exit);

    // Unless this is the first reduce chunk, add the partial diff_bias
    // already stored in memory so chunks accumulate across calls.
    Label skip_reading;
    test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); // If FLAG_REDUCE_FIRST
    jnz(skip_reading, T_NEAR);

    const int load_dim_tail = jcp.load_dim % jcp.load_block;
    for (int i_load = 0; i_load < load_loop_blk; i_load++) {
        // Only the last load block may be partial; mask it if so.
        bool mask_flag = load_dim_tail && i_load + 1 == load_loop_blk;
        auto vacc = vreg_acc(i_load);
        vaddps(may_be_mask_zmm(vacc, mask_flag, true), vacc, bias_ptr(i_load));
    }

    L(skip_reading);
    // Write the (possibly tail-masked) accumulated bias back to memory.
    for (int i_load = 0; i_load < load_loop_blk; i_load++) {
        bool mask_flag = load_dim_tail && i_load + 1 == load_loop_blk;
        auto vacc = vreg_acc(i_load);
        vmovups(bias_ptr(i_load), may_be_mask_zmm(vacc, mask_flag, false));
    }

    L(skip_diff_bias);
}
979 | |
// Top-level code emitter: builds the whole kernel — prologue, argument
// unpacking, opmask setup, the load-loop dispatch over unroll variants,
// epilogue, and (for backward-weights) the vpermw data table. The generated
// kernel is driven by a jit_1x1_conv_call_s argument block (see GET_OFF).
void jit_avx512_core_bf16_1x1_conv_kernel::generate() {
    preamble();

    // Reserve scratch stack slots used throughout the kernel.
    sub(rsp, stack_space_needed);
    if (jcp.with_binary) {
        // Zero the running oc offset for binary post-ops and back up the
        // call-args pointer (abi_param1) so post-op code can reload it.
        const auto zeroed_reg = r15;
        xor_(zeroed_reg, zeroed_reg);
        mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off), zeroed_reg);
        mov(EVEX_compress_addr(rsp, reg_abi_param1_backup), abi_param1);
    }

    // Unpack the main tensor pointers from the call-args struct.
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);

    // Static opmasks: low 16 words, all 32 words, high 16 words.
    mov(reg_trans_tmp.cvt32(), 0x0000ffff);
    kmovd(half_mask, reg_trans_tmp.cvt32());
    mov(reg_trans_tmp.cvt32(), 0xffffffff);
    kmovd(full_mask, reg_trans_tmp.cvt32());
    mov(reg_trans_tmp.cvt32(), 0xffff0000);
    kmovd(half_mask_hi, reg_trans_tmp.cvt32());

    // Partial last load block: forward kernels measure the tail against the
    // unpadded oc, other kernels against load_dim.
    const int load_dim_tail
            = (one_of(jcp.prop_kind, forward_training, forward_inference)
                              ? jcp.oc_without_padding
                              : jcp.load_dim)
            % jcp.load_block;
    if (load_dim_tail) {
        mov(reg_trans_tmp.cvt32(), (1 << load_dim_tail) - 1);
        kmovw(k_load_dim_tail_mask, reg_trans_tmp.cvt32());

        // Extended (32-bit) tail mask: nxc extends the tail past one block;
        // blocked layout duplicates the 16-bit tail into both word halves.
        if (is_out_layout_nxc())
            mov(reg_trans_tmp.cvt32(),
                    (1 << (load_dim_tail + jcp.load_block)) - 1);
        else {
            // NOTE(review): this local shadows the opmask register
            // `half_mask` set up above — intentional but easy to misread.
            const auto half_mask = (1 << load_dim_tail) - 1;
            mov(reg_trans_tmp.cvt32(), ((half_mask << 16) + half_mask));
        }

        kmovd(k_load_dim_tail_mask_extended, reg_trans_tmp.cvt32());
    }

    if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);

    // Loop extents; bcast extent is spilled to the stack for later reloads.
    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);

    // Propagation-kind-specific extras: first/last-chunk flags, the f32
    // store buffer (fwd/bwd_d), and the output stride (bwd_w).
    if (one_of(jcp.prop_kind, backward_data, forward_training,
                forward_inference)) {
        mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
        mov(reg_store_buf, ptr[param1 + GET_OFF(store_buffer)]);
    }
    if (jcp.prop_kind == backward_weights) {
        mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
        mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
    }

    // One load-loop iteration for a given unroll factor load_loop_blk:
    // refresh the tail masks, run diff-bias + the bcast loop, then advance
    // all data pointers for the next group of load blocks.
    auto load_loop_body = [=](int load_loop_blk) {
        // Select tail masks when this is the last (partial) iteration,
        // otherwise set the masks to all-ones via kxnor.
        Label no_update_mask, update_mask_done;
        if (load_dim_tail) {
            cmp(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
            jge(no_update_mask, T_NEAR);
            kmovd(k_load_dim_mask, k_load_dim_tail_mask);
            kmovd(k_load_dim_mask_extended, k_load_dim_tail_mask_extended);
            jmp(update_mask_done, T_NEAR);
            L(no_update_mask);
        }
        kxnord(k_load_dim_mask, k_load_dim_mask, k_load_dim_mask);
        kxnord(k_load_dim_mask_extended, k_load_dim_mask_extended,
                k_load_dim_mask_extended);
        L(update_mask_done);

        // compute_diff_bias/bcast_loop may clobber reg_load_loop_work;
        // preserve it across the call.
        mov(ptr[rsp + reg_load_loop_work_off], reg_load_loop_work);
        compute_diff_bias(load_loop_blk);
        bcast_loop(load_loop_blk);
        mov(reg_load_loop_work, ptr[rsp + reg_load_loop_work_off]);

        // Advance pointers by the processed load blocks; strides differ per
        // propagation kind and output layout (nxc vs blocked).
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
            case forward_training:
            case forward_inference:
                add(reg_bias_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_bia);
                add(reg_output_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out
                                * (is_out_layout_nxc()
                                                ? 1
                                                : (jcp.with_dw_conv
                                                                ? jcp.ow
                                                                : jcp.bcast_dim)));
                add(reg_store_buf,
                        load_loop_blk * jcp.load_block * jcp.typesize_acc
                                * (is_out_layout_nxc() ? 1 : jcp.bcast_dim));
                if (jcp.with_binary) {
                    // Bump the accumulated oc offset consumed by the binary
                    // post-op injector (kept in a stack slot).
                    const auto oc_off_oprnd = rcx;
                    mov(oc_off_oprnd,
                            EVEX_compress_addr(
                                    rsp, reg_binary_post_op_acc_off));
                    add(oc_off_oprnd, jcp.load_block * load_loop_blk);
                    mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
                            oc_off_oprnd);
                }
                break;
            case backward_data:
                add(reg_output_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out
                                * (is_out_layout_nxc() ? 1 : jcp.bcast_dim));
                add(reg_store_buf,
                        load_loop_blk * jcp.load_block * jcp.typesize_acc
                                * (is_out_layout_nxc() ? 1 : jcp.bcast_dim));
                break;
            case backward_weights:
                // Output stride is a runtime value: add it once per block.
                for (int i_load = 0; i_load < load_loop_blk; i_load++)
                    add(reg_output_data, reg_output_stride);
                add(reg_bias_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_acc);
                break;
            default: assert(!"invalid prop_kind" );
        }
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };

    const int simd_w = 16;

    // Per-unroll-variant labels; load_loop_blk[num_ur_cases] is the exit.
    Label load_loop_blk[7];

    // Max jcp.ur per unroll variant; entry i corresponds to processing
    // (num_ur_cases - i) load blocks at once. bwd_w halves the expl_bcast
    // thresholds.
    int ur_cases_fma_embd_bcast[] = {2, 4, 5, 8, 14, 32};
    int ur_cases_fma_expl_bcast[] = {2, 5, 6, 9, 14, 32};
    if (jcp.prop_kind == backward_weights)
        for (int i = 1; i < 6; i++)
            ur_cases_fma_expl_bcast[i] /= 2;

    const int size_ur_cases_fma = jcp.expl_bcast
            ? sizeof(ur_cases_fma_expl_bcast)
            : sizeof(ur_cases_fma_embd_bcast);
    const int *ur_cases_fma = jcp.expl_bcast ? ur_cases_fma_expl_bcast
                                             : ur_cases_fma_embd_bcast;
    const int *ur_cases = ur_cases_fma;
    const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);

    // Dispatch: jump to the widest applicable unroll variant based on the
    // remaining load work.
    for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
            cmp(reg_load_loop_work, simd_w * (label_idx + 1));
            jle(load_loop_blk[label_idx], T_NEAR);
        }
    }

    // Emit each unroll variant's loop body plus the fall-through chain to
    // narrower variants as reg_load_loop_work shrinks.
    for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
            L(load_loop_blk[label_idx]);
            {
                if (label_idx == 0) {
                    // Narrowest variant: exit once no work remains.
                    cmp(reg_load_loop_work, 0);
                    jle(load_loop_blk[num_ur_cases], T_NEAR);
                }
                load_loop_body(label_idx + 1);
                if (label_idx - 1 > 0) {
                    cmp(reg_load_loop_work, 2 * label_idx * simd_w);
                    je(load_loop_blk[label_idx - 1], T_NEAR);
                }
                cmp(reg_load_loop_work, label_idx * simd_w);
                jg(load_loop_blk[label_idx]);
            }
            // Route leftover work to the appropriate narrower variant.
            for (int idx = label_idx - 1; idx >= 0; --idx) {
                cmp(reg_load_loop_work, simd_w * (idx + 1));
                jge(load_loop_blk[idx], T_NEAR);
            }
            if (ur_idx < num_ur_cases - 2) {
                cmp(reg_load_loop_work, simd_w);
                jle(load_loop_blk[0], T_NEAR);
            }
        }
    }
    L(load_loop_blk[num_ur_cases]);

    // Release the scratch stack area and emit the epilogue.
    add(rsp, stack_space_needed);

    postamble();

    // Eltwise injector lookup tables must follow the code body.
    if (jcp.with_eltwise) postops_injector_->prepare_table();

    if (jcp.prop_kind == backward_weights) {
        // vpermw control table: interleaves words of the two 16-element
        // halves (0,16,1,17,...) — consumed by compute_diff_bias and the
        // bwd_w transposition paths.
        const uint16_t dst_prm_array[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20,
                5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
                29, 14, 30, 15, 31};

        align(64);
        L(dst_prm_table);
        for (int i = 0; i < 32; ++i)
            dw(dst_prm_array[i]);
    }
}
1176 | |
1177 | status_t jit_avx512_core_bf16_1x1_conv_kernel::init_conf( |
1178 | jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd, |
1179 | const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, |
1180 | const memory_desc_wrapper &dst_d, primitive_attr_t &attr, int nthreads, |
1181 | bool reduce_src) { |
1182 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
1183 | const int simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float); |
1184 | const int ndims = src_d.ndims(); |
1185 | |
1186 | jcp.nthr = nthreads; |
1187 | jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16 |
1188 | : bf16_emulation_t::get_isa(); |
1189 | jcp.prop_kind = cd.prop_kind; |
1190 | |
1191 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
1192 | jcp.mb = src_d.dims()[0]; |
1193 | |
1194 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
1195 | jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; |
1196 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
1197 | jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; |
1198 | |
1199 | jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; |
1200 | jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; |
1201 | jcp.iw = src_d.dims()[ndims - 1]; |
1202 | jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; |
1203 | jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; |
1204 | jcp.ow = dst_d.dims()[ndims - 1]; |
1205 | |
1206 | jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; |
1207 | jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
1208 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
1209 | |
1210 | jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; |
1211 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; |
1212 | jcp.l_pad = cd.padding[0][ndims - 3]; |
1213 | |
1214 | jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; |
1215 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4]; |
1216 | jcp.stride_w = cd.strides[ndims - 3]; |
1217 | |
1218 | jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, |
1219 | format_kind::undef, cd.diff_bias_desc.format_kind) |
1220 | != format_kind::undef; |
1221 | |
1222 | jcp.bia_dt = jcp.with_bias |
1223 | ? pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.data_type, |
1224 | data_type::undef, cd.diff_bias_desc.data_type) |
1225 | : data_type::undef; |
1226 | jcp.typesize_bia = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0; |
1227 | |
1228 | jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow; |
1229 | jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw; |
1230 | |
1231 | const auto &post_ops = attr.post_ops_; |
1232 | const int dw_conv_ind = post_ops.find(primitive_kind::convolution); |
1233 | jcp.with_dw_conv = dw_conv_ind != -1; |
1234 | // Using dw_conv_ind as upper-bound below, as post-ops after it will be |
1235 | // handled in depthwise convolution. |
1236 | const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); |
1237 | jcp.with_sum = sum_ind != -1; |
1238 | |
1239 | const int eltwise_ind |
1240 | = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind); |
1241 | jcp.with_eltwise = eltwise_ind != -1; |
1242 | if (jcp.with_eltwise) { |
1243 | if (dst_d.data_type() == data_type::s32) return status::unimplemented; |
1244 | } |
1245 | const int binary_ind |
1246 | = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); |
1247 | jcp.with_binary = binary_ind != -1; |
1248 | |
1249 | if (dw_conv_ind >= 0) { |
1250 | // dw_conv and post_ops after it are handled externally, so skip them |
1251 | jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), |
1252 | post_ops.entry_.cbegin() + dw_conv_ind); |
1253 | } else { |
1254 | jcp.post_ops = post_ops; |
1255 | } |
1256 | |
1257 | using namespace injector; |
1258 | static constexpr bool sum_at_pos_0_only = true; |
1259 | static constexpr bool sum_requires_scale_one = true; |
1260 | static constexpr bool sum_requires_zp_zero = true; |
1261 | const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum}, |
1262 | jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, |
1263 | sum_requires_zp_zero}); |
1264 | if (!post_ops_ok_) return status::unimplemented; |
1265 | |
1266 | using namespace format_tag; |
1267 | const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); |
1268 | const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); |
1269 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); |
1270 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); |
1271 | bool is_data_layout_nxc |
1272 | = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); |
1273 | auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; |
1274 | bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1 |
1275 | && (src_d.data_type() == data_type::f32 |
1276 | || src_d.data_type() == data_type::bf16); |
1277 | if (ok_to_pad_channels) { |
1278 | jcp.oc = rnd_up(jcp.oc, simd_w); |
1279 | jcp.ic = rnd_up(jcp.ic, simd_w); |
1280 | } |
1281 | |
1282 | const int is_bwd_d = jcp.prop_kind == backward_data; |
1283 | const int is_bwd_w = jcp.prop_kind == backward_weights; |
1284 | |
1285 | auto wei_tag = utils::pick( |
1286 | 2 * ndims - 6 + with_groups + 6 * is_bwd_d + 12 * is_bwd_w, |
1287 | OIw8i16o2i, gOIw8i16o2i, OIhw8i16o2i, gOIhw8i16o2i, OIdhw8i16o2i, |
1288 | gOIdhw8i16o2i, IOw8o16i2o, gIOw8o16i2o, IOhw8o16i2o, gIOhw8o16i2o, |
1289 | IOdhw8o16i2o, gIOdhw8o16i2o, OIw16i16o, gOIw16i16o, OIhw16i16o, |
1290 | gOIhw16i16o, OIdhw16i16o, gOIdhw16i16o); |
1291 | jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); |
1292 | |
1293 | bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == required_dat_tag |
1294 | && jcp.dst_tag == required_dat_tag && jcp.wei_tag == wei_tag; |
1295 | if (!args_ok) return status::unimplemented; |
1296 | |
1297 | args_ok = true |
1298 | && IMPLICATION(!is_data_layout_nxc, |
1299 | jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0) |
1300 | && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0 |
1301 | && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1 |
1302 | && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1 && jcp.ow == jcp.iw |
1303 | && jcp.oh == jcp.ih && jcp.od == jcp.id; // enforce rpad=0 |
1304 | if (!args_ok) return status::unimplemented; |
1305 | |
1306 | jcp.ic_block = jcp.oc_block = simd_w; |
1307 | |
1308 | jcp.typesize_acc = sizeof(float); |
1309 | if (one_of(jcp.prop_kind, forward_training, forward_inference)) { |
1310 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
1311 | jcp.typesize_out = types::data_type_size(dst_d.data_type()); |
1312 | jcp.dst_dt = dst_d.data_type(); |
1313 | } else if (jcp.prop_kind == backward_data) { |
1314 | jcp.typesize_in = types::data_type_size(dst_d.data_type()); |
1315 | jcp.typesize_out = types::data_type_size(src_d.data_type()); |
1316 | jcp.dst_dt = src_d.data_type(); |
1317 | } else if (jcp.prop_kind == backward_weights) { |
1318 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
1319 | jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type); |
1320 | jcp.dst_dt = weights_d.data_type(); |
1321 | } |
1322 | |
1323 | /* once all the formats are set, check the padding consistency */ |
1324 | args_ok = true && jcp.ic <= src_d.padded_dims()[1] |
1325 | && jcp.oc <= dst_d.padded_dims()[1] |
1326 | && jcp.ic <= weights_d.padded_dims()[with_groups + 1] |
1327 | && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; |
1328 | if (!args_ok) return status::unimplemented; |
1329 | |
1330 | const int SMALL_SPATIAL = 10; |
1331 | const int BIG_SPATIAL = 28; |
1332 | |
1333 | const int BIG_REDUCE_DIM = 1024; |
1334 | const int BIG_LOAD_DIM = 256; |
1335 | |
1336 | int load_blocking {0}; |
1337 | int load_blocking_max {0}; |
1338 | int bcast_blocking {0}; |
1339 | int bcast_blocking_max {0}; |
1340 | int reduce_blocking {0}; |
1341 | int reduce_blocking_max {0}; |
1342 | |
1343 | jcp.load_grp_count = 1; |
1344 | |
1345 | const int L1_capacity |
1346 | = platform::get_per_core_cache_size(1) / jcp.typesize_in; |
1347 | const int L2_size = platform::get_per_core_cache_size(2) / jcp.typesize_in; |
1348 | const int L2_capacity = (L2_size * 3) / 4; |
1349 | |
1350 | if (one_of(jcp.prop_kind, forward_training, forward_inference, |
1351 | backward_data)) { |
1352 | jcp.nthr = nthreads; |
1353 | if (one_of(jcp.prop_kind, forward_inference, forward_training)) { |
1354 | if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur); |
1355 | jcp.reduce_dim = jcp.ic; |
1356 | jcp.reduce_block = jcp.ic_block; |
1357 | |
1358 | jcp.load_dim = jcp.oc; |
1359 | jcp.load_block = jcp.oc_block; |
1360 | |
1361 | jcp.bcast_dim = jcp.is; |
1362 | } else { |
1363 | jcp.reduce_dim = jcp.oc; |
1364 | jcp.reduce_block = jcp.oc_block; |
1365 | |
1366 | jcp.load_dim = jcp.ic; |
1367 | jcp.load_block = jcp.ic_block; |
1368 | |
1369 | jcp.bcast_dim = jcp.os; |
1370 | } |
1371 | jcp.reduce_loop_unroll = jcp.reduce_block; |
1372 | jcp.reduce_loop_bcast_step = jcp.typesize_in * jcp.reduce_loop_unroll |
1373 | * (is_data_layout_nxc ? 1 : jcp.bcast_dim); |
1374 | jcp.reduce_loop_load_step |
1375 | = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; |
1376 | jcp.load_loop_load_step |
1377 | = utils::rnd_up(jcp.reduce_dim, jcp.reduce_block) |
1378 | * jcp.load_block * jcp.typesize_in; |
1379 | |
1380 | // adjusting registry blocking |
1381 | int max_regs, min_regs, size_treshold, ur_step; |
1382 | const int spatial |
1383 | = (one_of(jcp.prop_kind, forward_training, forward_inference)) |
1384 | ? jcp.od * jcp.oh |
1385 | : jcp.id * jcp.ih; |
1386 | const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block; |
1387 | if (reduce_dim_tail % 2 == 0 // cannot expl_bcast odd tail |
1388 | && (8 * jcp.mb) / jcp.nthr >= 1) { |
1389 | max_regs = 9; |
1390 | min_regs = 6; |
1391 | size_treshold = 14; |
1392 | ur_step = 1; |
1393 | jcp.expl_bcast = true; |
1394 | |
1395 | if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM |
1396 | && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) { |
1397 | max_regs = 6; |
1398 | min_regs = isa_has_bf16(jcp.isa) ? 5 : 4; |
1399 | } |
1400 | } else { |
1401 | max_regs = 30; |
1402 | min_regs = 9; |
1403 | size_treshold = 14; |
1404 | ur_step = 1; |
1405 | jcp.expl_bcast = false; |
1406 | } |
1407 | jcp.ur = 1; |
1408 | if (!isa_has_bf16(jcp.isa)) { |
1409 | int adj_max_regs = max_regs / 3; |
1410 | max_regs = (adj_max_regs < min_regs) ? min_regs : adj_max_regs; |
1411 | } |
1412 | for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) { |
1413 | if ((spatial >= size_treshold && spatial % ur_w == 0) |
1414 | || (spatial < size_treshold && jcp.os % ur_w == 0)) { |
1415 | jcp.ur = ur_w; |
1416 | break; |
1417 | } |
1418 | } |
1419 | if (jcp.ur == 1) { |
1420 | jcp.ur = nstl::min<dim_t>(max_regs, jcp.os); |
1421 | int os_tail = jcp.os % max_regs; |
1422 | for (int i = max_regs; i >= min_regs; i -= ur_step) { |
1423 | int i_tail = jcp.os % i; |
1424 | if (i_tail > os_tail || i_tail == 0) { |
1425 | jcp.ur = i; |
1426 | os_tail = i_tail; |
1427 | if (i_tail == 0) break; |
1428 | } |
1429 | } |
1430 | } |
1431 | |
1432 | jcp.bcast_block = jcp.ur; |
1433 | jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur; |
1434 | |
1435 | jcp.bcast_loop_output_step |
1436 | = jcp.ur * (is_data_layout_nxc ? jcp.load_dim : jcp.load_block); |
1437 | jcp.bcast_loop_output_substep = -1; // unused |
1438 | jcp.bcast_loop_bcast_step = jcp.typesize_in * jcp.ur |
1439 | * (is_data_layout_nxc ? jcp.reduce_dim : jcp.reduce_block); |
1440 | jcp.bcast_loop_bcast_substep = -1; // unused |
1441 | |
1442 | jcp.load_loop_iter_step = jcp.load_block; |
1443 | |
1444 | if (jcp.prop_kind == backward_data) |
1445 | jcp.loop_order = loop_lbr; |
1446 | else |
1447 | jcp.loop_order = reduce_src ? loop_blr : loop_lbr; |
1448 | |
1449 | int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
1450 | int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
1451 | int nb_load = div_up(jcp.load_dim, jcp.load_block); |
1452 | |
1453 | if (is_data_layout_nxc |
1454 | || (jcp.prop_kind == backward_data && reduce_src)) { |
1455 | reduce_blocking = jcp.reduce_dim; |
1456 | } else { |
1457 | if (jcp.expl_bcast) { |
1458 | if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL |
1459 | && spatial < BIG_SPATIAL) |
1460 | reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 160); |
1461 | else if (spatial > SMALL_SPATIAL) |
1462 | reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 1024); |
1463 | else |
1464 | reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 512); |
1465 | } else { |
1466 | reduce_blocking = nb_reduce; |
1467 | if (spatial <= SMALL_SPATIAL |
1468 | && jcp.reduce_dim >= BIG_REDUCE_DIM) |
1469 | reduce_blocking = 32; |
1470 | else if (spatial > SMALL_SPATIAL |
1471 | && jcp.reduce_dim >= BIG_REDUCE_DIM) |
1472 | reduce_blocking = 16; |
1473 | reduce_blocking |
1474 | = best_divider(nb_reduce, 1, reduce_blocking, true); |
1475 | reduce_blocking *= jcp.reduce_block; |
1476 | } |
1477 | // Check input data cache aliasing. |
1478 | // For other ISA constants may be updated. |
1479 | // 64 * 1024 is chosen due to 1MB L2 16-way cache. |
1480 | // 7 is empirical value. It is about half of 16. |
1481 | // So we leave about half of the set for other data - weights, dst |
1482 | int way_size = (64 * 1024) / jcp.typesize_in; |
1483 | int max_hits = 7; |
1484 | if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) { |
1485 | int nrb = reduce_blocking / simd_w; |
1486 | int sp = jcp.bcast_dim; |
1487 | int wl = way_size / simd_w; |
1488 | for (int start_off = 0; start_off < jcp.ur; start_off++) { |
1489 | for (int off = start_off, hits = 0; off < sp * nrb; |
1490 | off += wl) { |
1491 | if (off % sp >= jcp.ur || ++hits < max_hits) continue; |
1492 | int max_r_blocking |
1493 | = simd_w * nstl::max(1, (off + wl) / sp); |
1494 | reduce_blocking |
1495 | = nstl::min(reduce_blocking, max_r_blocking); |
1496 | break; |
1497 | } |
1498 | } |
1499 | } |
1500 | } |
1501 | load_blocking = jcp.load_dim; |
1502 | |
1503 | int load_size = jcp.load_dim * jcp.reduce_dim; |
1504 | auto bcast_size |
1505 | = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; |
1506 | |
1507 | if (jcp.nthr <= 28 && jcp.mb < jcp.nthr |
1508 | && nb_load * nb_bcast > jcp.nthr) { |
1509 | // Some heuristic here |
1510 | float calc_koef = 0.01, best_cost = FLT_MAX; |
1511 | int n_lgc = jcp.nthr; |
1512 | float ratio = (float)load_size / (float)bcast_size; |
1513 | int best_lgc = ratio > 1 ? n_lgc : 1; |
1514 | auto calc_job_cost = [&](int lb, int tg, float mem_k) { |
1515 | int bb_size = jcp.mb * div_up(nb_bcast, tg); |
1516 | float calc_size = (float)(bb_size * jcp.ur) |
1517 | * (lb * jcp.load_block) * jcp.reduce_dim; |
1518 | float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block) |
1519 | * jcp.reduce_dim; |
1520 | return calc_koef * calc_size + mem_k * mem_size; |
1521 | }; |
1522 | for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) { |
1523 | lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1; |
1524 | int min_lb = nb_load / lgc; |
1525 | int max_lb = div_up(nb_load, lgc); |
1526 | int min_tg = jcp.nthr / lgc; |
1527 | int max_tg = div_up(jcp.nthr, lgc); |
1528 | // Some heuristic here |
1529 | float mem_koef = (max_tg == 1) ? 1.f : 1.3f; |
1530 | float job_cost = 0.; |
1531 | if (jcp.nthr % lgc < nb_load % lgc) { |
1532 | job_cost = calc_job_cost(max_lb, min_tg, mem_koef); |
1533 | } else { |
1534 | auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef); |
1535 | auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef); |
1536 | job_cost = nstl::max(job_cost1, job_cost2); |
1537 | } |
1538 | |
1539 | if (job_cost < best_cost) { |
1540 | best_lgc = lgc; |
1541 | best_cost = job_cost; |
1542 | } |
1543 | } |
1544 | jcp.load_grp_count = best_lgc; |
1545 | load_blocking |
1546 | = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; |
1547 | } else { |
1548 | jcp.load_grp_count |
1549 | = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast); |
1550 | jcp.load_grp_count = best_divider(jcp.nthr, jcp.load_grp_count, |
1551 | 2 * jcp.load_grp_count, false); |
1552 | } |
1553 | |
1554 | if (jcp.expl_bcast && jcp.bcast_dim <= 64 && load_size >= L2_size) { |
1555 | jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); |
1556 | } else if (jcp.bcast_dim <= 49 && jcp.mb <= jcp.nthr |
1557 | && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { |
1558 | jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); |
1559 | load_blocking = jcp.load_block; |
1560 | } |
1561 | |
1562 | bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, |
1563 | div_up(jcp.nthr, jcp.load_grp_count)) |
1564 | * jcp.bcast_block; |
1565 | bcast_blocking = nstl::min<dim_t>(jcp.bcast_dim, bcast_blocking); |
1566 | bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); |
1567 | |
1568 | int space_for_bcast = (L2_capacity - /* kernel_size - */ |
1569 | 2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking |
1570 | - 3 * 1024); |
1571 | if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2; |
1572 | |
1573 | int bcast_in_cache |
1574 | = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); |
1575 | bcast_blocking = nstl::min( |
1576 | bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); |
1577 | |
1578 | load_blocking_max = load_blocking; |
1579 | bcast_blocking_max = bcast_blocking * 3 / 2; |
1580 | reduce_blocking_max = reduce_blocking; |
1581 | } else if (jcp.prop_kind == backward_weights) { |
1582 | jcp.use_vmovntps = false; |
1583 | jcp.reduce_dim = jcp.is; |
1584 | |
1585 | // cross-point for rn50 v1.5 blocked layout |
1586 | // in case of is_data_layout_nxc = true |
1587 | // jcp.uses_permw_transposition = false shows better performance for |
1588 | // the most problems according to performance measurements |
1589 | jcp.uses_permw_transposition = !is_data_layout_nxc && jcp.oh > 14 |
1590 | && jcp.ow > 14 |
1591 | // Performance improvement for i3d shapes |
1592 | && IMPLICATION(ndims == 5, jcp.oc <= 64); |
1593 | |
1594 | if (jcp.uses_permw_transposition) { |
1595 | int rdim = nstl::min<dim_t>(256, jcp.reduce_dim); |
1596 | jcp.reduce_block = best_divider(jcp.reduce_dim, 7, rdim, true, 2); |
1597 | } else |
1598 | jcp.reduce_block = best_divider(jcp.reduce_dim, 8, 16, true, 2); |
1599 | |
1600 | if (jcp.reduce_dim % jcp.reduce_block != 0) { |
1601 | jcp.reduce_block = best_divider( |
1602 | jcp.iw, 4, jcp.iw, jcp.uses_permw_transposition, 2); |
1603 | } |
1604 | if (jcp.reduce_block > 256) { jcp.reduce_block = 1; } |
1605 | if (!jcp.uses_permw_transposition) |
1606 | jcp.reduce_block = rnd_up(jcp.reduce_block, 2); |
1607 | |
1608 | jcp.load_dim = jcp.oc; |
1609 | jcp.load_block = jcp.oc_block; |
1610 | |
1611 | jcp.bcast_dim = jcp.ic; |
1612 | jcp.bcast_block = jcp.ic_block; |
1613 | |
1614 | jcp.ur = jcp.bcast_block; |
1615 | jcp.ur_tail = jcp.bcast_dim % jcp.bcast_block; |
1616 | // TODO: try to enable jcp.expl_bcast version |
1617 | jcp.expl_bcast = false; |
1618 | |
1619 | jcp.reduce_loop_unroll = jcp.reduce_block; |
1620 | jcp.reduce_loop_bcast_step = jcp.typesize_in * jcp.reduce_loop_unroll |
1621 | * (is_data_layout_nxc && jcp.uses_permw_transposition |
1622 | ? jcp.ngroups * jcp.ic |
1623 | : jcp.ic_block); |
1624 | jcp.reduce_loop_load_step = jcp.typesize_in * jcp.reduce_loop_unroll |
1625 | * (is_data_layout_nxc && jcp.uses_permw_transposition |
1626 | ? jcp.ngroups * jcp.oc |
1627 | : jcp.oc_block); |
1628 | |
1629 | jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block; |
1630 | jcp.bcast_loop_output_substep |
1631 | = jcp.oc_block * jcp.ur * jcp.typesize_out; |
1632 | jcp.bcast_loop_bcast_step = jcp.typesize_in * jcp.ic_block |
1633 | * (jcp.uses_permw_transposition |
1634 | ? (is_data_layout_nxc ? 1 : jcp.reduce_dim) |
1635 | : rnd_up(jcp.reduce_dim, 2)); |
1636 | jcp.bcast_loop_bcast_substep = jcp.typesize_in * jcp.ur |
1637 | * (jcp.uses_permw_transposition ? 1 : 2); |
1638 | jcp.load_loop_load_step = jcp.typesize_in * jcp.oc_block |
1639 | * (jcp.uses_permw_transposition |
1640 | ? (is_data_layout_nxc ? 1 : jcp.os) |
1641 | : rnd_up(jcp.reduce_dim, 2)); |
1642 | jcp.load_loop_iter_step = jcp.oc_block; |
1643 | |
1644 | /* --- */ |
1645 | balance(jcp, jcp.nthr); |
1646 | |
1647 | load_blocking = div_up(jcp.load_dim, jcp.load_block); |
1648 | load_blocking = best_divider(load_blocking, 16, load_blocking, false); |
1649 | load_blocking *= jcp.load_block; |
1650 | |
1651 | load_blocking_max = load_blocking; |
1652 | |
1653 | int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); |
1654 | int min_bcast_blocking = 5; |
1655 | |
1656 | bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); |
1657 | bcast_blocking = best_divider( |
1658 | bcast_blocking, min_bcast_blocking, max_bcast_blocking, false); |
1659 | bcast_blocking *= jcp.bcast_block; |
1660 | bcast_blocking_max = bcast_blocking; |
1661 | if (!is_data_layout_nxc) { |
1662 | assert(jcp.bcast_dim % bcast_blocking == 0); |
1663 | assert(jcp.load_dim % load_blocking == 0); |
1664 | } |
1665 | |
1666 | // for reduction balance |
1667 | int max_reduce_blocking |
1668 | = nstl::min<dim_t>(L1_capacity / jcp.ur, jcp.reduce_dim); |
1669 | int min_reduce_blocking |
1670 | = nstl::min(L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih)); |
1671 | reduce_blocking = best_divider( |
1672 | jcp.reduce_dim, min_reduce_blocking, max_reduce_blocking, true); |
1673 | reduce_blocking = nstl::max( |
1674 | rnd_dn(reduce_blocking, jcp.reduce_block), jcp.reduce_block); |
1675 | |
1676 | reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block); |
1677 | } else |
1678 | return status::unimplemented; |
1679 | |
1680 | assert(load_blocking); |
1681 | assert(load_blocking_max); |
1682 | assert(bcast_blocking); |
1683 | assert(bcast_blocking_max); |
1684 | assert(reduce_blocking); |
1685 | assert(reduce_blocking_max); |
1686 | if (!is_data_layout_nxc) { |
1687 | assert(load_blocking % jcp.load_block == 0); |
1688 | assert(load_blocking_max % jcp.load_block == 0); |
1689 | } |
1690 | assert(IMPLICATION(jcp.uses_permw_transposition, |
1691 | reduce_blocking % jcp.reduce_block == 0)); |
1692 | assert(IMPLICATION(jcp.uses_permw_transposition, |
1693 | reduce_blocking_max % jcp.reduce_block == 0)); |
1694 | |
1695 | assert(jcp.bcast_block % jcp.ur == 0); |
1696 | assert(IMPLICATION(jcp.uses_permw_transposition, |
1697 | jcp.reduce_dim % jcp.reduce_block == 0)); |
1698 | |
1699 | jcp.nb_bcast_blocking = utils::div_up(bcast_blocking, jcp.bcast_block); |
1700 | jcp.nb_bcast_blocking_max |
1701 | = utils::div_up(bcast_blocking_max, jcp.bcast_block); |
1702 | jcp.nb_load_blocking = utils::div_up(load_blocking, jcp.load_block); |
1703 | jcp.nb_load_blocking_max = utils::div_up(load_blocking_max, jcp.load_block); |
1704 | jcp.nb_reduce_blocking = utils::div_up(reduce_blocking, jcp.reduce_block); |
1705 | jcp.nb_reduce_blocking_max |
1706 | = utils::div_up(reduce_blocking_max, jcp.reduce_block); |
1707 | |
1708 | jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
1709 | jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); |
1710 | jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
1711 | |
    /* adjust the thread decomposition
     * to improve performance for small-size problems:
     * for now, simply cap the thread count at max(nb_bcast, nb_load)
     * TODO: add a get_thr_eff function to compute the optimal thread count
     * TODO: the threshold can be increased when the initial stride > 1 */
1717 | auto bcast_size |
1718 | = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; |
1719 | bool is_adjust_thread = one_of(jcp.prop_kind, forward_training, |
1720 | forward_inference, backward_data) |
1721 | && jcp.typesize_in * bcast_size < 8192 && jcp.ngroups < jcp.nthr |
1722 | && jcp.nb_bcast * jcp.nb_load < jcp.nthr; |
1723 | if (is_adjust_thread) { |
1724 | int nthr = nstl::max(jcp.nb_bcast, jcp.nb_load); |
1725 | jcp.nthr = nstl::min(jcp.nthr, nthr); |
1726 | } |
1727 | |
1728 | return status::success; |
1729 | } |
1730 | |
1731 | status_t jit_avx512_core_bf16_1x1_conv_kernel::init_scratchpad( |
1732 | memory_tracking::registrar_t &scratchpad, |
1733 | const jit_1x1_conv_conf_t &jcp) { |
1734 | using namespace dnnl::impl::memory_tracking::names; |
1735 | using namespace dnnl::impl::format_tag; |
1736 | |
1737 | if (jcp.with_bias && jcp.oc_without_padding % jcp.oc_block |
1738 | && utils::one_of(jcp.prop_kind, forward_inference, forward_training, |
1739 | backward_weights) |
1740 | // nxc layout bias padding is only needed during bwd_wb prop |
1741 | && IMPLICATION(utils::one_of(jcp.dst_tag, nwc, nhwc, ndhwc), |
1742 | jcp.prop_kind == backward_weights)) { |
1743 | scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia); |
1744 | } |
1745 | if (jcp.prop_kind == backward_weights) { |
1746 | const size_t wei_size = (size_t)jcp.ngroups |
1747 | * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block); |
1748 | const int n_wei_buffers |
1749 | = jcp.dst_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1; |
1750 | const size_t bias_size = (size_t)jcp.with_bias * jcp.ngroups |
1751 | * rnd_up(jcp.oc, jcp.oc_block); |
1752 | const int n_bias_buffers = jcp.with_bias |
1753 | ? (jcp.bia_dt == data_type::bf16 ? jcp.nthr_mb |
1754 | : jcp.nthr_mb - 1) |
1755 | : 0; |
1756 | const size_t wei_bia_size |
1757 | = wei_size * n_wei_buffers + bias_size * n_bias_buffers; |
1758 | scratchpad.book(key_conv_wei_reduction, wei_bia_size, jcp.typesize_acc); |
1759 | |
1760 | if (!jcp.uses_permw_transposition) { |
1761 | const size_t dst_diff_tr_size_per_thr |
1762 | = (size_t)rnd_up(jcp.reduce_dim, 2) * jcp.oc_block |
1763 | * jcp.nb_load_blocking_max; |
1764 | scratchpad.book(key_conv_tr_diff_dst, |
1765 | jcp.nthr * dst_diff_tr_size_per_thr, jcp.typesize_in); |
1766 | const size_t src_tr_size_per_thr = (size_t)rnd_up(jcp.reduce_dim, 2) |
1767 | * jcp.ic_block * jcp.nb_bcast_blocking_max; |
1768 | scratchpad.book(key_conv_tr_src, jcp.nthr * src_tr_size_per_thr, |
1769 | jcp.typesize_in); |
1770 | } |
1771 | } |
1772 | |
1773 | // TODO: Check - do we need this buffer for ALL cases? |
1774 | if (jcp.prop_kind != backward_weights) { |
1775 | const size_t grp_count = utils::div_up( |
1776 | jcp.nthr, utils::div_up(jcp.nthr, jcp.load_grp_count)); |
1777 | const bool is_out_layout_nxc |
1778 | = (utils::one_of(jcp.prop_kind, prop_kind::forward_training, |
1779 | prop_kind::forward_inference) |
1780 | && utils::one_of(jcp.dst_tag, format_tag::ndhwc, |
1781 | format_tag::nhwc, format_tag::nwc)) |
1782 | || (jcp.prop_kind == prop_kind::backward_data |
1783 | && utils::one_of(jcp.src_tag, format_tag::ndhwc, |
1784 | format_tag::nhwc, format_tag::nwc)); |
1785 | const size_t max_load_per_thread = is_out_layout_nxc |
1786 | ? utils::rnd_up(jcp.load_dim, jcp.load_block) |
1787 | : rnd_up((utils::div_up(jcp.load_dim, grp_count)), |
1788 | jcp.load_block); |
1789 | const size_t store_buffer_size = (size_t)jcp.nthr |
1790 | * utils::rnd_up(jcp.bcast_dim, jcp.bcast_block) |
1791 | * max_load_per_thread; |
1792 | scratchpad.book( |
1793 | key_conv_store_wsp, store_buffer_size, jcp.typesize_acc); |
1794 | } |
1795 | // Heuristic threshold for requested scratchpad memory to avoid |
1796 | // possible crash on memory allocation. |
1797 | size_t scratchpad_limit_by_absolute_value = (size_t)20 * (1 << 30); // 20Gb |
1798 | |
1799 | // Note: currently ignore this check for depthwise-fusion implementation |
1800 | // because of memory requirements in the latter. |
1801 | if (!jcp.with_dw_conv |
1802 | && scratchpad.size() > scratchpad_limit_by_absolute_value) |
1803 | return status::unimplemented; |
1804 | |
1805 | return status::success; |
1806 | } |
1807 | |
1808 | void jit_avx512_core_bf16_1x1_conv_kernel::balance( |
1809 | jit_1x1_conv_conf_t &jcp, int nthreads) { |
1810 | // initialize jcp reduction threading properties |
1811 | jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1; |
1812 | if (nthreads < jcp.ngroups) { |
1813 | /* simplification... fortunately it doesn't hurt much */ |
1814 | return; |
1815 | } |
1816 | const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
1817 | const int nb_load = div_up(jcp.load_dim, jcp.load_block); |
1818 | const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
1819 | |
1820 | jcp.nthr_g = jcp.ngroups; |
1821 | const int nthr = nthreads / jcp.nthr_g; |
1822 | |
1823 | auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { |
1824 | /* calculate per thread memory cost (read/write). high level |
1825 | * optimizer tries to minimize memory consumption. few notes: (n1) |
1826 | * unclear why, but that essentially helps first convolution... |
1827 | * (n2) assuming the reduction over minibatch is always there: |
1828 | * - instead of 8 it should be 5 here (write ~= 2 read): |
1829 | * kernel: temporal workspace 1 write |
1830 | * reduction: 1 read from workspace and 1 write to the diff_wei |
1831 | * - but experiments showed 8 works better than 5 or 6... */ |
1832 | int bcast_koeff = 1; |
1833 | int load_koeff = 1; |
1834 | int output_koeff = 12; |
1835 | |
1836 | if (jcp.prop_kind == backward_weights) { |
1837 | int mult = (jcp.stride_h == 1 && jcp.stride_w == 1) |
1838 | ? nstl::max(1, (jcp.oc / jcp.ic)) |
1839 | : 1; |
1840 | output_koeff = 4 * mult; |
1841 | } |
1842 | |
1843 | return 0 |
1844 | + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) |
1845 | * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_bcast, nthr_ic_b) |
1846 | * jcp.ic_block * jcp.reduce_block / jcp.stride_h |
1847 | / jcp.stride_w /* (n1) */ |
1848 | + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) |
1849 | * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) |
1850 | * jcp.oc_block * jcp.reduce_block |
1851 | + (size_t)output_koeff /* (n2) */ |
1852 | * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) |
1853 | * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.oc_block; |
1854 | }; |
1855 | |
1856 | int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1; |
1857 | auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); |
1858 | |
1859 | /* step 1: find the best thread distribution with lowest memory cost */ |
1860 | const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce); |
1861 | for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { |
1862 | const int nthr_par = nthr / nthr_mb; |
1863 | const int nthr_oc_b_max = nstl::min(nthr_par, nb_load); |
1864 | for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { |
1865 | nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast); |
1866 | auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); |
1867 | if (mem_cost <= best_mem_cost) { |
1868 | best_mem_cost = mem_cost; |
1869 | jcp.nthr_mb = nthr_mb; |
1870 | jcp.nthr_oc_b = nthr_oc_b; |
1871 | jcp.nthr_ic_b = nthr_ic_b; |
1872 | } |
1873 | } |
1874 | } |
1875 | if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads) |
1876 | jcp.nthr_mb = nstl::min(jcp.mb, nthreads); |
1877 | |
1878 | jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b; |
1879 | assert(jcp.nthr <= nthreads); |
1880 | } |
1881 | } // namespace x64 |
1882 | } // namespace cpu |
1883 | } // namespace impl |
1884 | } // namespace dnnl |
1885 | |