1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <float.h> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory.hpp" |
23 | #include "common/memory_tracking.hpp" |
24 | #include "common/nstl.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/platform.hpp" |
29 | #include "cpu/x64/cpu_barrier.hpp" |
30 | |
31 | #include "cpu/x64/injectors/injector_utils.hpp" |
32 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
33 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
34 | #include "cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp" |
35 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
36 | |
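// Byte offset of a field within the runtime call-arguments structure
// (jit_1x1_conv_call_s); used to address kernel arguments relative to
// the ABI parameter register.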
37 | #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) |
38 | |
39 | namespace dnnl { |
40 | namespace impl { |
41 | namespace cpu { |
42 | namespace x64 { |
43 | |
44 | using namespace dnnl::impl::format_tag; |
45 | using namespace dnnl::impl::prop_kind; |
46 | using namespace dnnl::impl::utils; |
47 | |
48 | using namespace Xbyak; |
49 | |
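// The constructor instantiates a post-ops injector when eltwise or binary
// post-ops are present: zmm31 is reserved as the helper vector register,
// r14/r15/r12 serve as helper GPRs, and k_load_dim_mask handles the tail of
// the output-channel dimension.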
50 | jit_avx512_common_1x1_conv_kernel::jit_avx512_common_1x1_conv_kernel( |
51 | const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr, |
52 | const memory_desc_t &dst_md) |
53 | : jit_generator(jit_name()), jcp(ajcp), attr_(attr) { |
54 | if (jcp.with_eltwise || jcp.with_binary) { |
55 | using namespace binary_injector; |
56 | static constexpr bool preserve_gpr = true; |
57 | static constexpr bool preserve_vmm = false; |
58 | static constexpr size_t helper_vmm_idx = 31; |
59 | const size_t tail_size = jcp.oc_without_padding % isa_simd_width_; |
60 | static constexpr bool use_exact_tail_scalar_bcast = true; |
61 | |
62 | const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, |
63 | r14, r15, r12, preserve_gpr, preserve_vmm, |
64 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
65 | memory_desc_wrapper(dst_md), tail_size, k_load_dim_mask, |
66 | use_exact_tail_scalar_bcast}; |
67 | const static_params_t static_params { |
68 | this->param1, rhs_arg_static_params}; |
69 | |
70 | postops_injector_ = utils::make_unique< |
71 | injector::jit_uni_postops_injector_t<avx512_core>>( |
72 | this, jcp.post_ops, static_params); |
73 | } |
74 | } |
75 | |
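// Outer loop over the broadcast (spatial) dimension: each iteration handles
// jcp.bcast_block elements as num_substeps chunks of jcp.ur, advancing the
// bcast and output pointers between substeps; the remaining jcp.ur_tail
// elements are processed by the tail path at the end.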
76 | void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk) { |
77 | mov(aux1_reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); |
78 | mov(aux_reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off)); |
79 | |
80 | mov(aux_reg_output_data, reg_output_data); |
81 | mov(reg_bcast_loop_iter, EVEX_compress_addr(rsp, reg_bcast_loop_work_offt)); |
82 | |
83 | Label bcast_loop; |
84 | Label bcast_loop_tail; |
85 | Label large_tail; |
86 | |
87 | cmp(reg_bcast_loop_iter, jcp.bcast_block); |
88 | jl(bcast_loop_tail, T_NEAR); |
89 | |
90 | L(bcast_loop); |
91 | { |
92 | assert(jcp.bcast_block % jcp.ur == 0); |
93 | int num_substeps = jcp.bcast_block / jcp.ur; |
94 | assert(num_substeps > 0 && num_substeps < 10); |
95 | for (int i = 0; i < num_substeps; i++) { |
96 | if (i + 1 == num_substeps) L(large_tail); |
97 | reduce_loop(load_loop_blk, jcp.ur, i, false); |
98 | if (i < num_substeps - 1) { |
99 | add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep); |
100 | add(aux_reg_output_data, jcp.bcast_loop_output_substep); |
101 | } else { |
102 | add(aux1_reg_bcast_data, |
103 | jcp.bcast_loop_bcast_step |
104 | - (num_substeps - 1) |
105 | * jcp.bcast_loop_bcast_substep); |
106 | add(aux_reg_output_data, |
107 | jcp.bcast_loop_output_step |
108 | - (num_substeps - 1) |
109 | * jcp.bcast_loop_output_substep); |
110 | } |
111 | sub(reg_bcast_loop_iter, jcp.ur); |
112 | } |
113 | cmp(reg_bcast_loop_iter, jcp.bcast_block); |
114 | jge(bcast_loop, T_NEAR); |
115 | } |
116 | |
117 | L(bcast_loop_tail); |
118 | if (jcp.ur_tail) { |
119 | Label bcast_loop_tail_out; |
120 | if (jcp.ur_tail >= jcp.ur) { |
121 | cmp(reg_bcast_loop_iter, jcp.ur); |
122 | jge(large_tail, T_NEAR); |
123 | } |
124 | if (jcp.ur_tail % jcp.ur) { |
125 | cmp(reg_bcast_loop_iter, 0); |
126 | jle(bcast_loop_tail_out, T_NEAR); |
127 | reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur, 0, true); |
128 | L(bcast_loop_tail_out); |
129 | } |
130 | } |
131 | } |
132 | |
133 | Address jit_avx512_common_1x1_conv_kernel::output_ptr( |
134 | const bool is_out_layout_nxc, const int i_load, const int i_ur) { |
135 | if (one_of(jcp.prop_kind, forward_training, forward_inference, |
136 | backward_data)) { |
137 | auto i_load_shift = is_out_layout_nxc |
138 | ? jcp.load_block |
139 | : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block; |
140 | int i_ur_shift = is_out_layout_nxc ? jcp.load_dim : jcp.load_block; |
141 | auto offset = (i_load * i_load_shift + i_ur * i_ur_shift) |
142 | * jcp.typesize_out; |
143 | return EVEX_compress_addr(aux_reg_output_data, offset); |
144 | } else |
145 | return ptr[aux_reg_output_data |
146 | + (i_load ? reg_output_stride * i_load |
147 | : 0) // TODO: Xbyak should allow 0 scale |
148 | + jcp.typesize_out * jcp.load_block * i_ur]; |
149 | } |
150 | |
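// Accumulator register layout: acc(i_load, i_ur) lives in
// zmm[i_ur * load_loop_blk + i_load], so the accumulators occupy
// zmm0 .. zmm(ur * load_loop_blk - 1) and the load-side registers follow
// them (see vreg_load in reduce_loop). E.g., with load_loop_blk = 2 and
// ur = 3, the accumulators are zmm0..zmm5 and the loads zmm6..zmm7.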
151 | static int vreg_accum_idx( |
152 | const int load_loop_blk, const int i_load, const int i_ur) { |
153 | return (i_ur * load_loop_blk + i_load); |
154 | } |
155 | |
156 | template <typename F> |
157 | static void iterate(const int load_loop_blk, const int ur, const bool mask_tail, |
158 | const F &fun) { |
159 | for (int i_load = 0; i_load < load_loop_blk; ++i_load) { |
160 | const bool mask_flag = mask_tail && i_load + 1 == load_loop_blk; |
161 | for (int i_ur = 0; i_ur < ur; ++i_ur) |
162 | fun(mask_flag, i_load, i_ur); |
163 | } |
164 | } |
165 | template <typename F> |
166 | static void iterate(const int load_loop_blk, const int ur, const F &fun) { |
167 | iterate(load_loop_blk, ur, false, fun); |
168 | } |
169 | |
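// Applies the eltwise/binary post-ops to all live accumulator registers.
// For binary post-ops, the injector additionally needs the per-vmm output
// register/offset pairs and the tail-mask set; abi_param1 is restored from
// its stack backup because the injector reads rhs arguments through it.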
170 | void jit_avx512_common_1x1_conv_kernel::apply_postops( |
171 | const bool is_out_layout_nxc, const int load_loop_blk, const int ur) { |
172 | injector_utils::vmm_index_set_t vmm_idxs; |
173 | if (jcp.with_binary) { |
174 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
175 | const auto mask_tail = jcp.oc_without_padding % jcp.load_block; |
176 | iterate(load_loop_blk, ur, mask_tail, |
177 | [&](const bool mask_flag, const int i_load, const int i_ur) { |
178 | const auto vmm_idx |
179 | = vreg_accum_idx(load_loop_blk, i_load, i_ur); |
180 | vmm_idxs.emplace(vmm_idx); |
181 | |
182 | rhs_arg_params.vmm_idx_to_out_reg.emplace( |
183 | vmm_idx, aux_reg_output_data); |
184 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, |
185 | get_output_offset(is_out_layout_nxc, i_load, i_ur)); |
186 | if (mask_flag) |
187 | rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
188 | }); |
189 | |
190 | mov(abi_param1, ptr[rsp + reg_abi_param1_backup]); |
191 | |
192 | postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); |
193 | } else { |
194 | iterate(load_loop_blk, ur, |
195 | [&](const bool, const int i_load, const int i_ur) { |
196 | vmm_idxs.emplace( |
197 | vreg_accum_idx(load_loop_blk, i_load, i_ur)); |
198 | }); |
199 | postops_injector_->compute_vector_range(vmm_idxs); |
200 | } |
201 | } |
202 | |
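// Computes one ur x load_loop_blk output tile: initialize the accumulators
// (bias or zero), run the FMA loop over the reduction dimension in chunks of
// jcp.reduce_loop_unroll, process the reduction tail, then store the results
// with optional accumulation and post-ops.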
203 | void jit_avx512_common_1x1_conv_kernel::reduce_loop( |
204 | int load_loop_blk, int ur, int substep, bool wraparound) { |
205 | const bool out_layout_nxc = is_out_layout_nxc(jcp); |
206 | const bool load_layout_nxc = is_load_layout_nxc(jcp); |
207 | const bool bcast_layout_nxc = is_bcast_layout_nxc(jcp); |
208 | const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block; |
209 | const int load_dim_tail = jcp.load_dim % jcp.load_block; |
210 | |
211 | auto vreg_load |
212 | = [=](int i_load) { return Zmm(ur * load_loop_blk + i_load); }; |
213 | |
214 | auto vreg_accum = [=](int i_load, int i_ur) { |
215 | return Zmm(vreg_accum_idx(load_loop_blk, i_load, i_ur)); |
216 | }; |
217 | |
218 | auto bias_ptr = [=](int i_load) { |
219 | return EVEX_compress_addr( |
220 | reg_bias_data, jcp.typesize_out * jcp.oc_block * i_load); |
221 | }; |
222 | |
223 | auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) { |
224 | assert(i_ur < jcp.ur); |
225 | assert(i_reduce <= jcp.reduce_loop_unroll); |
226 | dim_t offt; |
227 | if (one_of(jcp.prop_kind, forward_training, forward_inference, |
228 | backward_data)) { |
229 | assert(jcp.reduce_loop_unroll == jcp.reduce_block); |
230 | const int reduce_mul = bcast_layout_nxc ? jcp.reduce_dim |
231 | : jcp.reduce_loop_unroll; |
232 | offt = (i_reduce == jcp.reduce_loop_unroll) |
233 | ? (jcp.bcast_dim + i_ur) * reduce_mul |
234 | : i_ur * reduce_mul + i_reduce; |
235 | } else { |
236 | int rmul = bcast_layout_nxc ? jcp.ic : jcp.ic_block; |
237 | offt = i_reduce * rmul + i_ur; |
238 | } |
239 | return EVEX_compress_addr( |
240 | aux_reg_bcast_data, jcp.typesize_in * offt, bcast); |
241 | }; |
242 | |
243 | auto load_ptr = [=](int i_reduce, int i_load) { |
244 | int offt; |
245 | int u0 = i_reduce % jcp.reduce_loop_unroll; |
246 | int u1 = i_reduce / jcp.reduce_loop_unroll; |
247 | int lmul = jcp.load_block |
248 | * (load_layout_nxc ? 1 |
249 | : utils::rnd_up( |
250 | jcp.reduce_dim, jcp.reduce_block)); |
251 | int rmul = load_layout_nxc ? jcp.load_dim : jcp.load_block; |
252 | offt = i_load * lmul + u0 * rmul; |
253 | return EVEX_compress_addr(aux_reg_load_data, |
254 | u1 * jcp.reduce_loop_load_step + jcp.typesize_in * offt); |
255 | }; |
256 | |
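    // Initialize the accumulators: preload the bias on the first reduction
    // chunk of a forward pass with bias (tail-masked for the last load
    // block); otherwise zero them.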
257 | auto init = [=]() { |
258 | Label init_done; |
259 | Label init_zero; |
260 | |
261 | if (jcp.with_bias |
262 | && one_of(jcp.prop_kind, forward_training, forward_inference)) { |
263 | test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); |
264 | jz(init_zero, T_NEAR); |
265 | |
266 | for (int i_load = 0; i_load < load_loop_blk; i_load++) |
267 | for (int i_ur = 0; i_ur < ur; ++i_ur) { |
268 | auto vreg_acc = vreg_accum(i_load, i_ur); |
269 | if (i_load + 1 == load_loop_blk && load_dim_tail) |
270 | vreg_acc = vreg_acc | k_load_dim_mask | T_z; |
271 | |
272 | vmovups(vreg_acc, bias_ptr(i_load)); |
273 | } |
274 | jmp(init_done, T_NEAR); |
275 | } |
276 | |
277 | L(init_zero); |
278 | for (int i_load = 0; i_load < load_loop_blk; ++i_load) |
279 | for (int i_ur = 0; i_ur < ur; ++i_ur) { |
280 | auto r = vreg_accum(i_load, i_ur); |
281 | vpxord(r, r, r); |
282 | } |
283 | L(init_done); |
284 | }; |
285 | |
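    // Store the tile: the existing output is accumulated unconditionally
    // when a sum post-op is present, and on non-first reduction chunks
    // otherwise; post-ops are applied only on the last reduction chunk;
    // finally the accumulators are written back with tail masking, except
    // for bwd_w where the mask is omitted to keep the output zero-padded.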
286 | auto store = [=]() { |
287 | Label store_noadd; |
288 | if (!jcp.with_sum) { |
289 | test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST); |
290 | jnz(store_noadd, T_NEAR); |
291 | } |
292 | |
293 | for (int i_ur = 0; i_ur < ur; ++i_ur) |
294 | for (int i_load = 0; i_load < load_loop_blk; ++i_load) { |
295 | auto r = vreg_accum(i_load, i_ur); |
296 | if (i_load + 1 == load_loop_blk && load_dim_tail) |
297 | r = r | k_load_dim_mask | T_z; |
298 | vaddps(r, r, output_ptr(out_layout_nxc, i_load, i_ur)); |
299 | } |
300 | |
301 | L(store_noadd); |
302 | |
303 | if (jcp.with_eltwise || jcp.with_binary) { |
304 | Label store_nopostops; |
305 | test(reg_reduce_pos_flag, FLAG_REDUCE_LAST); |
306 | jz(store_nopostops, T_NEAR); |
307 | |
308 | apply_postops(out_layout_nxc, load_loop_blk, ur); |
309 | |
310 | L(store_nopostops); |
311 | } |
312 | |
313 | auto store_output = [=](bool output_is_aligned) { |
314 | const auto mask_flag = load_dim_tail; |
315 | for (int i_ur = 0; i_ur < ur; ++i_ur) { |
316 | for (int i_load = 0; i_load < load_loop_blk; ++i_load) { |
317 | auto vreg_acc = vreg_accum(i_load, i_ur); |
                // For nxc-layout bwd_w, the weights are still padded and
                // output_ptr here may point to uninitialized scratchpad.
                // To ensure the final output (after reduction) is
                // zero-padded, we zero-pad the output by omitting the mask.
322 | if (jcp.prop_kind != backward_weights |
323 | && (i_load + 1 == load_loop_blk && mask_flag)) { |
324 | vreg_acc = vreg_acc | k_load_dim_mask; |
325 | } |
326 | vmovups(output_ptr(out_layout_nxc, i_load, i_ur), vreg_acc); |
327 | } |
328 | } |
329 | }; |
330 | |
331 | Label unaligned_store, end_store; |
332 | test(aux_reg_output_data, cpu_isa_traits<avx512_core>::vlen - 1); |
333 | jnz(unaligned_store, T_NEAR); |
334 | store_output(true); |
335 | jmp(end_store, T_NEAR); |
336 | L(unaligned_store); |
337 | { store_output(false); } |
338 | L(end_store); |
339 | }; |
340 | |
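    // One unrolled chunk of the reduction dimension: first load a strip of
    // the load-side operand per i_load (tail-masked), then issue one
    // vfmadd231ps per (i_ur, i_load) pair, broadcasting the bcast operand
    // either through vreg_bcast (expl_bcast) or as an embedded broadcast
    // memory operand.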
341 | auto fma_block = [=](bool last_block) { |
342 | const int i_reduce_end = reduce_dim_tail && last_block |
343 | ? reduce_dim_tail |
344 | : jcp.reduce_loop_unroll; |
345 | |
346 | for (int i_reduce = 0; i_reduce < i_reduce_end; i_reduce++) { |
347 | for (int i_load = 0; i_load < load_loop_blk; ++i_load) { |
348 | auto vreg = vreg_load(i_load); |
349 | if (i_load + 1 == load_loop_blk && load_dim_tail) |
350 | vreg = vreg | k_load_dim_mask | T_z; |
351 | |
352 | vmovups(vreg, load_ptr(i_reduce, i_load)); |
353 | } |
354 | |
355 | for (int i_ur = 0; i_ur < ur; ++i_ur) { |
356 | if (jcp.expl_bcast && load_loop_blk > 1) |
357 | vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false)); |
358 | for (int i_load = 0; i_load < load_loop_blk; ++i_load) { |
359 | auto vreg_acc = vreg_accum(i_load, i_ur); |
360 | if (i_load + 1 == load_loop_blk && load_dim_tail) |
361 | vreg_acc = vreg_acc | k_load_dim_mask | T_z; |
362 | if (jcp.expl_bcast && load_loop_blk > 1) |
363 | vfmadd231ps(vreg_acc, vreg_load(i_load), vreg_bcast); |
364 | else |
365 | vfmadd231ps(vreg_acc, vreg_load(i_load), |
366 | bcast_ptr(i_reduce, i_ur, true)); |
367 | } |
368 | } |
369 | } |
370 | }; |
371 | |
372 | Label reduce_loop; |
373 | Label reduce_loop_tail; |
374 | |
375 | mov(aux_reg_load_data, reg_load_data); |
376 | |
377 | mov(aux_reg_bcast_data, aux1_reg_bcast_data); |
378 | init(); |
379 | |
380 | mov(reduce_loop_iter, reg_reduce_loop_work); |
381 | sub(reduce_loop_iter, jcp.reduce_loop_unroll); |
382 | jle(reduce_loop_tail, T_NEAR); |
383 | |
384 | L(reduce_loop); |
385 | { |
386 | fma_block(false); |
387 | safe_add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step, reg_long_offt); |
388 | add(aux_reg_load_data, jcp.reduce_loop_load_step); |
389 | sub(reduce_loop_iter, jcp.reduce_loop_unroll); |
390 | jg(reduce_loop, T_NEAR); |
391 | } |
392 | |
393 | L(reduce_loop_tail); |
394 | fma_block(true); |
395 | |
396 | store(); |
397 | } |
398 | |
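// Kernel entry point: backs up the arguments on the stack, builds the
// load-dim tail mask, and dispatches into one of several load-loop bodies,
// each specialized for a load_loop_blk of 1..6 depending on the remaining
// load dimension.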
399 | void jit_avx512_common_1x1_conv_kernel::generate() { |
400 | preamble(); |
401 | |
402 | sub(rsp, stack_space_needed); |
403 | if (jcp.with_binary) { |
404 | const auto zeroed_reg = r15; |
405 | xor_(zeroed_reg, zeroed_reg); |
406 | mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off), zeroed_reg); |
407 | mov(EVEX_compress_addr(rsp, reg_abi_param1_backup), param1); |
408 | } |
409 | |
410 | mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]); |
411 | mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data); |
412 | mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]); |
413 | mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]); |
414 | |
415 | if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]); |
416 | |
417 | mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]); |
418 | mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]); |
419 | mov(EVEX_compress_addr(rsp, reg_bcast_loop_work_offt), reg_bcast_loop_work); |
420 | mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]); |
421 | mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]); |
422 | if (jcp.prop_kind == backward_weights) |
423 | mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]); |
424 | |
425 | const int load_dim_tail |
426 | = (one_of(jcp.prop_kind, forward_training, forward_inference) |
427 | ? jcp.oc_without_padding |
428 | : jcp.load_dim) |
429 | % jcp.load_block; |
430 | if (load_dim_tail) { |
431 | Reg32 reg_tail_32 = reg_load_dim_tail_mask.cvt32(); |
432 | mov(reg_tail_32, (1 << load_dim_tail) - 1); |
433 | kmovw(k_load_dim_tail_mask, reg_tail_32); |
434 | } |
435 | |
436 | auto load_loop_body = [=](int load_loop_blk) { |
437 | if (load_dim_tail) |
438 | kxnorw(k_load_dim_mask, k_load_dim_mask, k_load_dim_mask); |
439 | sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step); |
440 | if (load_dim_tail) { |
441 | Label no_update_mask; |
442 | jge(no_update_mask, T_NEAR); |
443 | kmovw(k_load_dim_mask, k_load_dim_tail_mask); |
444 | L(no_update_mask); |
445 | } |
446 | bcast_loop(load_loop_blk); |
447 | add(reg_load_data, load_loop_blk * jcp.load_loop_load_step); |
448 | switch (jcp.prop_kind) { |
449 | case forward_training: |
450 | case forward_inference: |
451 | add(reg_bias_data, |
452 | load_loop_blk * jcp.load_block * jcp.typesize_out); |
453 | safe_add(reg_output_data, |
454 | load_loop_blk * jcp.load_block * jcp.typesize_out |
455 | * (is_out_layout_nxc(jcp) |
456 | ? 1 |
457 | : (jcp.with_dw_conv |
458 | ? jcp.ow |
459 | : jcp.bcast_dim)), |
460 | reg_long_offt); |
461 | if (jcp.with_binary) { |
462 | const auto oc_off_oprnd = aux_reg_load_data; |
463 | mov(oc_off_oprnd, |
464 | EVEX_compress_addr( |
465 | rsp, reg_binary_post_op_acc_off)); |
466 | add(oc_off_oprnd, jcp.load_block * load_loop_blk); |
467 | mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off), |
468 | oc_off_oprnd); |
469 | } |
470 | break; |
471 | case backward_data: |
472 | safe_add(reg_output_data, |
473 | load_loop_blk * jcp.load_block * jcp.typesize_out |
474 | * (is_out_layout_nxc(jcp) ? 1 : jcp.bcast_dim), |
475 | reg_long_offt); |
476 | break; |
477 | case backward_weights: |
478 | for (int i_load = 0; i_load < load_loop_blk; i_load++) |
479 | add(reg_output_data, reg_output_stride); |
480 | break; |
            default: assert(!"invalid prop_kind");
482 | } |
483 | }; |
484 | |
485 | const int simd_w = 16; |
486 | |
487 | Label load_loop_blk[7]; |
488 | |
489 | // with an implicit load_loop_block {6, 5, 4, 3, 2, 1} |
490 | static const int ur_cases_fma_embd_bcast[] = {2, 4, 5, 8, 14, 32}; |
491 | static const int ur_cases_fma_expl_bcast[] = {2, 5, 6, 9, 14, 32}; |
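    // These thresholds are empirical upper bounds on jcp.ur for each block
    // size, keeping the ur x blk accumulator tile plus the blk load
    // registers (and the broadcast register) within the 32 zmm registers;
    // the last entry acts as a catch-all.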
492 | |
493 | const int size_ur_cases_fma = jcp.expl_bcast |
494 | ? sizeof(ur_cases_fma_expl_bcast) |
495 | : sizeof(ur_cases_fma_embd_bcast); |
496 | |
497 | const int *ur_cases_fma = jcp.expl_bcast ? ur_cases_fma_expl_bcast |
498 | : ur_cases_fma_embd_bcast; |
499 | const int *ur_cases = ur_cases_fma; |
500 | const int num_ur_cases = size_ur_cases_fma / sizeof(*ur_cases); |
501 | |
502 | for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) { |
503 | int label_idx = num_ur_cases - ur_idx - 1; |
504 | if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) { |
505 | cmp(reg_load_loop_work, simd_w * (label_idx + 1)); |
506 | jle(load_loop_blk[label_idx], T_NEAR); |
507 | } |
508 | } |
509 | |
510 | for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) { |
511 | int label_idx = num_ur_cases - ur_idx - 1; |
512 | if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) { |
513 | L(load_loop_blk[label_idx]); |
514 | { |
515 | if (label_idx == 0) { |
516 | cmp(reg_load_loop_work, 0); |
517 | jle(load_loop_blk[num_ur_cases], T_NEAR); |
518 | } |
519 | load_loop_body(label_idx + 1); |
520 | if (label_idx - 1 > 0) { |
521 | cmp(reg_load_loop_work, 2 * label_idx * simd_w); |
522 | je(load_loop_blk[label_idx - 1], T_NEAR); |
523 | } |
524 | cmp(reg_load_loop_work, label_idx * simd_w); |
525 | jg(load_loop_blk[label_idx]); |
526 | } |
527 | for (int idx = label_idx - 1; idx >= 0; --idx) { |
528 | cmp(reg_load_loop_work, simd_w * (idx + 1)); |
529 | jge(load_loop_blk[idx], T_NEAR); |
530 | } |
531 | if (ur_idx < num_ur_cases - 2) { |
532 | cmp(reg_load_loop_work, simd_w); |
533 | jle(load_loop_blk[0], T_NEAR); |
534 | } |
535 | } |
536 | } |
537 | L(load_loop_blk[num_ur_cases]); |
538 | |
539 | add(rsp, stack_space_needed); |
540 | |
541 | postamble(); |
542 | |
543 | if (jcp.with_eltwise) postops_injector_->prepare_table(); |
544 | } |
545 | |
546 | status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp, |
547 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
548 | const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d, |
549 | const primitive_attr_t &attr, int nthreads, bool reduce_src) { |
550 | if (!mayiuse(avx512_core)) return status::unimplemented; |
551 | |
552 | if (!everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(), |
553 | dst_d.data_type())) |
554 | return status::unimplemented; |
555 | |
556 | jcp.nthr = nthreads; |
557 | |
558 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
559 | const int simd_w = cpu_isa_traits<avx512_core>::vlen / sizeof(float); |
560 | const int ndims = src_d.ndims(); |
561 | |
562 | jcp.prop_kind = cd.prop_kind; |
563 | |
564 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
565 | jcp.mb = src_d.dims()[0]; |
566 | |
567 | jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups; |
568 | jcp.oc = jcp.oc_without_padding; |
569 | jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups; |
570 | jcp.ic = jcp.ic_without_padding; |
571 | |
572 | jcp.id = (ndims == 5) ? src_d.dims()[2] : 1; |
573 | jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2]; |
574 | jcp.iw = src_d.dims()[ndims - 1]; |
575 | jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1; |
576 | jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2]; |
577 | jcp.ow = dst_d.dims()[ndims - 1]; |
578 | |
579 | jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1; |
580 | jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
581 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
582 | |
583 | jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0; |
584 | jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4]; |
585 | jcp.l_pad = cd.padding[0][ndims - 3]; |
586 | |
587 | jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1; |
588 | jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4]; |
589 | jcp.stride_w = cd.strides[ndims - 3]; |
590 | |
591 | jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind, |
592 | format_kind::undef, cd.diff_bias_desc.format_kind) |
593 | != format_kind::undef; |
594 | |
595 | jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow; |
596 | jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw; |
597 | |
598 | const auto &post_ops = attr.post_ops_; |
599 | const int dw_conv_ind = post_ops.find(primitive_kind::convolution); |
600 | jcp.with_dw_conv = dw_conv_ind != -1; |
    // Using dw_conv_ind as the upper bound below, as post-ops after it will
    // be handled in the depthwise convolution.
603 | const int eltwise_ind |
604 | = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind); |
605 | jcp.with_eltwise = eltwise_ind != -1; |
606 | if (jcp.with_eltwise) { |
607 | if (dst_d.data_type() == data_type::s32) return status::unimplemented; |
608 | } |
609 | |
610 | const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); |
611 | jcp.with_sum = sum_ind != -1; |
612 | |
613 | const int binary_ind |
614 | = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); |
615 | jcp.with_binary = binary_ind != -1; |
616 | |
617 | if (dw_conv_ind >= 0) { |
618 | // dw_conv and post_ops after it are handled externally, so skip them |
619 | jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), |
620 | post_ops.entry_.cbegin() + dw_conv_ind); |
621 | } else { |
622 | jcp.post_ops = post_ops; |
623 | } |
624 | |
625 | const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc); |
626 | const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c); |
627 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); |
628 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); |
629 | bool is_data_layout_nxc |
630 | = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag); |
631 | auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; |
632 | |
633 | bool ok_to_pad_channels = true && !is_data_layout_nxc && jcp.ngroups == 1 |
634 | && src_d.data_type() == data_type::f32; |
635 | |
636 | if (ok_to_pad_channels) { |
637 | jcp.oc = rnd_up(jcp.oc, simd_w); |
638 | jcp.ic = rnd_up(jcp.ic, simd_w); |
639 | } |
640 | |
641 | using namespace injector; |
642 | static constexpr bool sum_at_pos_0_only = true; |
643 | static constexpr bool sum_requires_scale_one = true; |
644 | static constexpr bool sum_requires_zp_zero = true; |
645 | const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum}, |
646 | jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, |
647 | sum_requires_zp_zero}); |
648 | if (!post_ops_ok_) return status::unimplemented; |
649 | |
650 | bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == required_dat_tag |
651 | && jcp.dst_tag == required_dat_tag |
652 | && IMPLICATION(!is_data_layout_nxc, |
653 | jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0) |
654 | && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0 |
655 | && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1 |
656 | && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1 && jcp.ow == jcp.iw |
657 | && jcp.oh == jcp.ih && jcp.od == jcp.id; // enforce rpad=0 |
658 | if (!args_ok) return status::unimplemented; |
659 | |
660 | jcp.ic_block = jcp.oc_block = simd_w; |
661 | |
662 | const int is_bwd_d = jcp.prop_kind == backward_data; |
663 | format_tag_t wei_tag = with_groups |
664 | ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i, |
665 | gOIhw16i16o, gIOhw16o16i, gOIdhw16i16o, gIOdhw16o16i) |
666 | : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i, OIhw16i16o, |
667 | IOhw16o16i, OIdhw16i16o, IOdhw16o16i); |
668 | |
669 | jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag); |
670 | if (jcp.wei_tag != wei_tag) return status::unimplemented; |
671 | |
672 | jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type); |
673 | jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type); |
674 | |
675 | /* once all the formats are set, check the padding consistency */ |
676 | if (!is_data_layout_nxc) { |
677 | args_ok = true && jcp.ic <= src_d.padded_dims()[1] |
678 | && jcp.oc <= dst_d.padded_dims()[1] |
679 | && jcp.ic <= weights_d.padded_dims()[with_groups + 1] |
680 | && jcp.oc <= weights_d.padded_dims()[with_groups + 0]; |
681 | if (!args_ok) return status::unimplemented; |
682 | } |
683 | |
684 | const int SMALL_SPATIAL = 10; |
685 | const int BIG_SPATIAL = 28; |
686 | const int BIG_REDUCE_DIM = 1024; |
687 | const int BIG_LOAD_DIM = 256; |
688 | |
689 | int load_blocking {0}; |
690 | int load_blocking_max {0}; |
691 | int bcast_blocking {0}; |
692 | int bcast_blocking_max {0}; |
693 | int reduce_blocking {0}; |
694 | int reduce_blocking_max {0}; |
695 | |
696 | jcp.load_grp_count = 1; |
697 | |
698 | const int L1_capacity |
699 | = platform::get_per_core_cache_size(1) / sizeof(float); |
700 | const int L2_size = platform::get_per_core_cache_size(2) / sizeof(float); |
701 | const int L2_capacity = (L2_size * 3) / 4; |
702 | |
703 | if (one_of(jcp.prop_kind, forward_training, forward_inference, |
704 | backward_data)) { |
705 | if (one_of(jcp.prop_kind, forward_training, forward_inference)) { |
706 | if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur); |
707 | jcp.reduce_dim = jcp.ic; |
708 | jcp.reduce_block = jcp.ic_block; |
709 | |
710 | jcp.load_dim = jcp.oc; |
711 | jcp.load_block = jcp.oc_block; |
712 | |
713 | jcp.bcast_dim = jcp.is; |
714 | } else { |
715 | jcp.reduce_dim = jcp.oc; |
716 | jcp.reduce_block = jcp.oc_block; |
717 | |
718 | jcp.load_dim = jcp.ic; |
719 | jcp.load_block = jcp.ic_block; |
720 | |
721 | jcp.bcast_dim = jcp.os; |
722 | } |
723 | jcp.reduce_loop_unroll = jcp.reduce_block; |
724 | jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll |
725 | * (is_data_layout_nxc ? 1 : jcp.bcast_dim) * jcp.typesize_in; |
726 | |
727 | jcp.reduce_loop_load_step |
728 | = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; |
729 | jcp.load_loop_load_step |
730 | = (utils::rnd_up(jcp.reduce_dim, jcp.reduce_block)) |
731 | * jcp.load_block * jcp.typesize_in; |
732 | |
        // adjusting register blocking
        int max_regs, min_regs, size_threshold;
735 | const int spatial |
736 | = (one_of(jcp.prop_kind, forward_training, forward_inference)) |
737 | ? jcp.od * jcp.oh |
738 | : jcp.id * jcp.ih; |
739 | if ((8 * jcp.mb) / jcp.nthr >= 1 |
740 | // NHWC perf: RN50 mb=1 |
741 | || (is_data_layout_nxc && jcp.mb == 1)) { |
742 | max_regs = 9; |
743 | min_regs = 6; |
            size_threshold = 14;
745 | jcp.expl_bcast = true; |
746 | |
747 | if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM |
748 | && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) { |
749 | max_regs = 6; |
750 | min_regs = 5; |
751 | } |
752 | } else { |
753 | max_regs = 30; |
754 | min_regs = 9; |
            size_threshold = 14;
756 | jcp.expl_bcast = false; |
757 | jcp.use_vmovntps = true; |
758 | } |
759 | jcp.ur = 1; |
760 | for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { |
            if ((spatial >= size_threshold && spatial % ur_w == 0)
                    || (spatial < size_threshold && jcp.os % ur_w == 0)) {
763 | jcp.ur = ur_w; |
764 | break; |
765 | } |
766 | } |
767 | if (jcp.ur == 1) { |
768 | jcp.ur = nstl::min<dim_t>(max_regs, jcp.os); |
769 | int os_tail = jcp.os % max_regs; |
770 | for (int i = max_regs; i >= min_regs; i--) { |
771 | int i_tail = jcp.os % i; |
772 | if (i_tail > os_tail || i_tail == 0) { |
773 | jcp.ur = i; |
774 | os_tail = i_tail; |
775 | if (i_tail == 0) break; |
776 | } |
777 | } |
778 | } |
779 | jcp.bcast_block = jcp.ur; |
780 | |
781 | jcp.bcast_loop_output_step = jcp.ur * jcp.typesize_out |
782 | * (is_data_layout_nxc ? jcp.load_dim : jcp.load_block); |
783 | jcp.bcast_loop_output_substep = -1; // unused |
784 | jcp.bcast_loop_bcast_step = jcp.ur * jcp.typesize_in |
785 | * (is_data_layout_nxc ? jcp.reduce_dim : jcp.reduce_block); |
786 | jcp.bcast_loop_bcast_substep = -1; // unused |
787 | |
788 | jcp.load_loop_iter_step = jcp.load_block; |
789 | |
790 | if (jcp.prop_kind == backward_data) |
791 | jcp.loop_order = loop_lbr; |
792 | else |
793 | jcp.loop_order = reduce_src ? loop_blr : loop_lbr; |
794 | |
795 | int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
796 | int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
797 | int nb_load = div_up(jcp.load_dim, jcp.load_block); |
798 | if (is_data_layout_nxc) { |
799 | reduce_blocking = jcp.reduce_dim; |
800 | } else if (jcp.expl_bcast) { |
801 | if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL |
802 | && spatial < BIG_SPATIAL) |
803 | reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 80); |
804 | else if (spatial > SMALL_SPATIAL) |
805 | reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 512); |
806 | else |
807 | reduce_blocking = nstl::min<dim_t>(jcp.reduce_dim, 256); |
808 | } else { |
809 | reduce_blocking = nb_reduce; |
810 | if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) |
811 | reduce_blocking = 16; |
812 | else if (spatial > SMALL_SPATIAL |
813 | && jcp.reduce_dim >= BIG_REDUCE_DIM) |
814 | reduce_blocking = 8; |
815 | reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); |
816 | reduce_blocking *= jcp.reduce_block; |
817 | } |
818 | |
        // Check for input data cache aliasing.
        // For other ISAs these constants may need to be updated.
        // 64 * 1024 is chosen based on a 1MB, 16-way L2 cache (way size =
        // 1MB / 16 = 64KB). 7 is an empirical value, about half of the 16
        // ways, leaving roughly half of each set for other data (weights,
        // dst).
824 | int way_size = (64 * 1024) / jcp.typesize_in; |
825 | int max_hits = 7; |
826 | if (!is_data_layout_nxc |
827 | && jcp.bcast_dim * reduce_blocking |
828 | > static_cast<dim_t>(way_size) * max_hits) { |
829 | int nrb = reduce_blocking / simd_w; |
830 | auto sp = jcp.bcast_dim; |
831 | int wl = way_size / simd_w; |
832 | for (int start_off = 0; start_off < jcp.ur; start_off++) { |
833 | for (dim_t off = start_off, hits = 0; off < sp * nrb; |
834 | off += wl) { |
835 | if (off % sp >= jcp.ur || ++hits < max_hits) continue; |
836 | int max_r_blocking |
837 | = simd_w * nstl::max<dim_t>(1, (off + wl) / sp); |
838 | reduce_blocking |
839 | = nstl::min(reduce_blocking, max_r_blocking); |
840 | break; |
841 | } |
842 | } |
843 | } |
844 | |
845 | if (reduce_blocking < jcp.reduce_dim) { |
846 | if (jcp.prop_kind == backward_data) |
847 | jcp.loop_order = reduce_src ? loop_lbr : loop_rlb; |
848 | else |
849 | jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; |
850 | } |
851 | load_blocking = jcp.load_dim; |
852 | |
853 | int load_size = jcp.load_dim * jcp.reduce_dim; |
854 | auto bcast_size |
855 | = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; |
856 | |
857 | if (jcp.nthr <= 28 && jcp.mb < jcp.nthr |
858 | && nb_load * nb_bcast > jcp.nthr) { |
859 | // Some heuristic here |
860 | float calc_koef = 0.01, best_cost = FLT_MAX; |
861 | int n_lgc = jcp.nthr; |
862 | float ratio = (float)load_size / (float)bcast_size; |
863 | int best_lgc = ratio > 1 ? n_lgc : 1; |
864 | auto calc_job_cost = [&](int lb, int tg, float mem_k) { |
865 | int bb_size = jcp.mb * div_up(nb_bcast, tg); |
866 | float calc_size = (float)(bb_size * jcp.ur) |
867 | * (lb * jcp.load_block) * jcp.reduce_dim; |
868 | float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block) |
869 | * jcp.reduce_dim; |
870 | return calc_koef * calc_size + mem_k * mem_size; |
871 | }; |
872 | for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) { |
873 | lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1; |
874 | int min_lb = nb_load / lgc; |
875 | int max_lb = div_up(nb_load, lgc); |
876 | int min_tg = jcp.nthr / lgc; |
877 | int max_tg = div_up(jcp.nthr, lgc); |
878 | // Some heuristic here |
879 | float mem_koef = (max_tg == 1) ? 1.f : 1.3f; |
880 | float job_cost = 0.; |
881 | if (jcp.nthr % lgc < nb_load % lgc) { |
882 | job_cost = calc_job_cost(max_lb, min_tg, mem_koef); |
883 | } else { |
884 | auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef); |
885 | auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef); |
886 | job_cost = nstl::max(job_cost1, job_cost2); |
887 | } |
888 | |
889 | if (job_cost < best_cost) { |
890 | best_lgc = lgc; |
891 | best_cost = job_cost; |
892 | } |
893 | } |
894 | jcp.load_grp_count = best_lgc; |
895 | load_blocking |
896 | = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; |
897 | } else { |
898 | jcp.load_grp_count |
899 | = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast); |
900 | jcp.load_grp_count = best_divider(jcp.nthr, jcp.load_grp_count, |
901 | 2 * jcp.load_grp_count, false); |
902 | } |
903 | |
904 | if (jcp.expl_bcast && jcp.bcast_dim <= 64 && load_size >= L2_size) { |
905 | jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); |
906 | } else if (jcp.bcast_dim <= 49 && jcp.mb <= jcp.nthr |
907 | && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { |
908 | jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); |
909 | load_blocking = jcp.load_block; |
910 | } |
911 | |
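        // Score a candidate (load_chunk, nthr) decomposition: data_eff
        // favors balanced per-thread bcast vs. load volume, thr_eff
        // penalizes idle threads across and within load groups, and
        // load_eff penalizes uneven division of the load blocks.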
912 | auto get_thr_eff = [=](int load_chunk, int nthr) { |
913 | int lgc = div_up(nb_load, load_chunk); |
914 | int thr_per_grp = div_up(nthr, lgc); |
915 | int bcast_per_thr |
916 | = div_up(jcp.mb * nb_bcast, thr_per_grp) * jcp.bcast_block; |
917 | int load_per_thr = load_chunk * simd_w; |
918 | float data_norm = (bcast_per_thr + load_per_thr) / 2.f; |
919 | float data_eff |
920 | = (bcast_per_thr * load_per_thr) / (data_norm * data_norm); |
921 | float thr_eff_over_grp |
922 | = (float)nstl::max(1, nthr / lgc) / div_up(nthr, lgc); |
923 | float thr_eff_in_grp = ((float)jcp.mb * nb_bcast) |
924 | / rnd_up(jcp.mb * nb_bcast, thr_per_grp); |
925 | float thr_eff = thr_eff_over_grp * thr_eff_in_grp; |
926 | float load_eff = (float)nb_load / rnd_up(nb_load, lgc); |
927 | float overall_eff = data_eff + thr_eff + load_eff; |
928 | return overall_eff; |
929 | }; |
930 | |
931 | auto get_load_chunk = [=](int nthr) { |
932 | float best_eff = -1.0f; |
933 | int best_lgc = 1; |
934 | float eff; |
935 | |
936 | for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) { |
937 | int lgc = div_up(nb_load, load_chunk); |
938 | if (lgc > nthr) continue; |
939 | eff = get_thr_eff(load_chunk, nthr); |
940 | if (eff > best_eff) { |
941 | best_eff = eff; |
942 | best_lgc = lgc; |
943 | } |
944 | } |
945 | return best_lgc; |
946 | }; |
947 | |
        /* adjust the thread decomposition to improve thr_eff for small
         * problem sizes; the threshold of 8192 bytes is empirical
         * TODO: the threshold can be increased for init stride > 1 */
952 | if (sizeof(float) * bcast_size < 8192 && jcp.mb < jcp.nthr |
953 | && nb_load * nb_bcast < jcp.nthr) { |
954 | float best_thr_eff = -1.0f; |
955 | float thr_eff = -1.0f; |
956 | int overall_lgc = jcp.load_grp_count; |
957 | int lgc = 1; |
958 | int best_nthr = jcp.nthr; |
959 | int end_nthr = with_groups ? jcp.ngroups : 1; |
960 | for (int nthr = jcp.nthr / 2; nthr >= end_nthr; nthr--) { |
961 | lgc = get_load_chunk(nthr); |
962 | thr_eff = get_thr_eff(lgc, nthr); |
963 | if (best_thr_eff < thr_eff) { |
964 | best_thr_eff = thr_eff; |
965 | overall_lgc = lgc; |
966 | best_nthr = nthr; |
967 | } |
968 | } |
969 | jcp.nthr = best_nthr; |
970 | jcp.load_grp_count = overall_lgc; |
971 | load_blocking |
972 | = div_up(nb_load, jcp.load_grp_count) * jcp.load_block; |
973 | } |
974 | |
975 | bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, |
976 | div_up(jcp.nthr, jcp.load_grp_count)) |
977 | * jcp.bcast_block; |
978 | bcast_blocking = nstl::min<dim_t>(jcp.bcast_dim, bcast_blocking); |
979 | bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); |
980 | |
981 | int space_for_bcast = (L2_capacity - /* kernel_size - */ |
982 | 2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking |
983 | - 3 * 1024); |
984 | if (jcp.reduce_dim * jcp.bcast_dim > static_cast<dim_t>(L2_capacity)) |
985 | space_for_bcast /= 2; |
986 | |
987 | int bcast_in_cache |
988 | = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); |
989 | bcast_blocking = nstl::min( |
990 | bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); |
991 | // NHWC perf |
992 | if (is_data_layout_nxc) bcast_blocking = jcp.bcast_block; |
993 | |
994 | load_blocking_max = load_blocking; |
995 | bcast_blocking_max = bcast_blocking * 3 / 2; |
996 | reduce_blocking_max = reduce_blocking; |
997 | |
998 | jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur; |
999 | |
1000 | } else if (jcp.prop_kind == backward_weights) { |
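        // Backward by weights: the reduction runs over the input spatial
        // size, the load side over oc, and the broadcast side over ic;
        // reduce_block is picked as a small divisor of the spatial size to
        // keep the unrolled JIT body small.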
1001 | jcp.reduce_dim = jcp.is; |
1002 | |
1003 | jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true); |
1004 | if (jcp.reduce_dim % jcp.reduce_block != 0) |
1005 | jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false); |
1006 | if (jcp.reduce_block > 256) { jcp.reduce_block = 1; } |
1007 | |
1008 | jcp.load_dim = jcp.oc; |
1009 | jcp.load_block = jcp.oc_block; |
1010 | |
1011 | jcp.bcast_dim = jcp.ic; |
1012 | jcp.bcast_block = jcp.ic_block; |
1013 | |
1014 | if (jcp.reduce_block <= 19 && |
1015 | // maskrcnn optimization for nxc; don't reduce ur when ocb<=1 |
1016 | !(is_data_layout_nxc && jcp.load_dim <= jcp.load_block)) { |
            // if reduce_block is big, the generated JIT code may get large
            // for small values of ur, because reduce_loop_unroll = reduce_block
1019 | jcp.ur = jcp.bcast_block / 2; |
1020 | jcp.expl_bcast = true; |
1021 | } else { |
1022 | jcp.ur = jcp.bcast_block; |
1023 | jcp.expl_bcast = false; |
1024 | } |
1025 | |
1026 | jcp.ur_tail = jcp.bcast_dim % jcp.bcast_block; |
1027 | jcp.reduce_loop_unroll = jcp.reduce_block; |
1028 | jcp.reduce_loop_bcast_step = static_cast<dim_t>(jcp.typesize_in) |
1029 | * jcp.reduce_loop_unroll |
1030 | * (is_data_layout_nxc ? jcp.ic : jcp.ic_block); |
1031 | jcp.reduce_loop_load_step = jcp.typesize_in * jcp.reduce_loop_unroll |
1032 | * (is_data_layout_nxc ? jcp.oc : jcp.oc_block); |
1033 | |
1034 | jcp.bcast_loop_output_step |
1035 | = jcp.oc_block * jcp.ic_block * jcp.typesize_out; |
1036 | jcp.bcast_loop_output_substep |
1037 | = jcp.oc_block * jcp.ur * jcp.typesize_out; |
1038 | jcp.bcast_loop_bcast_step = jcp.ic_block |
1039 | * (is_data_layout_nxc ? 1 |
1040 | : utils::rnd_up( |
1041 | jcp.reduce_dim, jcp.reduce_block)) |
1042 | * jcp.typesize_in; |
1043 | jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in; |
1044 | |
1045 | jcp.load_loop_load_step = jcp.typesize_in * jcp.oc_block |
1046 | * (is_data_layout_nxc ? 1 : jcp.os); |
1047 | jcp.load_loop_iter_step = jcp.oc_block; |
1048 | |
1049 | /* --- */ |
1050 | balance(jcp); |
1051 | |
1052 | load_blocking = div_up(jcp.load_dim, jcp.load_block); |
1053 | load_blocking = best_divider(load_blocking, 16, load_blocking, false); |
1054 | load_blocking *= jcp.load_block; |
1055 | |
1056 | load_blocking_max = load_blocking; |
1057 | assert(IMPLICATION( |
1058 | !is_data_layout_nxc, jcp.load_dim % load_blocking == 0)); |
1059 | |
1060 | int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); |
1061 | int min_bcast_blocking = 5; |
1062 | |
1063 | bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block); |
1064 | bcast_blocking = best_divider( |
1065 | bcast_blocking, min_bcast_blocking, max_bcast_blocking, false); |
1066 | bcast_blocking *= jcp.bcast_block; |
1067 | bcast_blocking_max = bcast_blocking; |
1068 | assert(IMPLICATION( |
1069 | !is_data_layout_nxc, jcp.bcast_dim % bcast_blocking == 0)); |
1070 | |
1071 | // for reduction balance |
1072 | if (is_data_layout_nxc && jcp.reduce_dim >= BIG_SPATIAL * BIG_SPATIAL |
1073 | && jcp.load_dim >= BIG_LOAD_DIM / 2) { |
1074 | reduce_blocking = rnd_up(nstl::min(jcp.ow, 256), jcp.reduce_block); |
1075 | } else { |
1076 | int max_reduce_blocking |
1077 | = nstl::min<dim_t>(L1_capacity / jcp.ur, jcp.reduce_dim); |
1078 | int min_reduce_blocking = nstl::min( |
1079 | L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih)); |
1080 | reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking, |
1081 | max_reduce_blocking, true); |
1082 | reduce_blocking |
1083 | = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block), |
1084 | jcp.reduce_block); |
1085 | } |
1086 | |
1087 | reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block); |
1088 | } else |
1089 | return status::unimplemented; |
1090 | |
1091 | assert(load_blocking); |
1092 | assert(load_blocking_max); |
1093 | assert(bcast_blocking); |
1094 | assert(bcast_blocking_max); |
1095 | assert(reduce_blocking); |
1096 | assert(reduce_blocking_max); |
1097 | |
1098 | if (!is_data_layout_nxc) { |
1099 | assert(load_blocking % jcp.load_block == 0); |
1100 | assert(reduce_blocking % jcp.reduce_block == 0); |
1101 | assert(load_blocking_max % jcp.load_block == 0); |
1102 | assert(reduce_blocking_max % jcp.reduce_block == 0); |
1103 | assert(jcp.reduce_dim % jcp.reduce_block == 0); |
1104 | } |
1105 | |
1106 | assert(jcp.bcast_block % jcp.ur == 0); |
1107 | |
1108 | jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; |
1109 | jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; |
1110 | jcp.nb_load_blocking = utils::div_up(load_blocking, jcp.load_block); |
1111 | jcp.nb_load_blocking_max = utils::div_up(load_blocking_max, jcp.load_block); |
1112 | jcp.nb_reduce_blocking = utils::div_up(reduce_blocking, jcp.reduce_block); |
1113 | jcp.nb_reduce_blocking_max |
1114 | = utils::div_up(reduce_blocking_max, jcp.reduce_block); |
1115 | |
1116 | jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
1117 | jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); |
1118 | jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
1119 | |
1120 | return status::success; |
1121 | } |
1122 | |
1123 | void jit_avx512_common_1x1_conv_kernel::init_scratchpad( |
1124 | memory_tracking::registrar_t &scratchpad, |
1125 | const jit_1x1_conv_conf_t &jcp) { |
1126 | using namespace dnnl::impl::memory_tracking::names; |
1127 | |
    // For the nxc layout, bias is padded only for the bwd_wb direction, as
    // the bias reduction kernels can't handle tails yet.
1130 | if (jcp.with_bias && jcp.prop_kind != backward_data |
1131 | && (jcp.oc != jcp.oc_without_padding // blocked layout |
1132 | || (jcp.prop_kind == backward_weights // nxc layout |
1133 | && jcp.oc % jcp.oc_block != 0))) { |
1134 | |
1135 | const size_t nelems_padded_bias |
1136 | = jcp.ngroups * utils::rnd_up(jcp.oc, jcp.oc_block); |
1137 | scratchpad.book( |
1138 | key_conv_padded_bias, nelems_padded_bias, jcp.typesize_out); |
1139 | } |
1140 | |
1141 | if (jcp.prop_kind == backward_weights) { |
1142 | const size_t wei_size = (size_t)jcp.ngroups |
1143 | * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block); |
1144 | scratchpad.book(key_conv_wei_reduction, wei_size * (jcp.nthr_mb - 1), |
1145 | jcp.typesize_out); |
1146 | } |
1147 | } |
1148 | |
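// For bwd_w, splits the available threads across groups (nthr_g),
// minibatch-and-reduction chunks (nthr_mb), load blocks (nthr_oc_b), and
// bcast blocks (nthr_ic_b), picking the combination with the lowest
// estimated per-thread memory traffic (calc_mem_cost).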
1149 | void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp) { |
1150 | int nthreads = jcp.nthr; |
1151 | // initialize jcp reduction threading properties |
1152 | jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1; |
1153 | if (nthreads < jcp.ngroups) { |
1154 | /* simplification... fortunately it doesn't hurt much */ |
1155 | return; |
1156 | } |
1157 | const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
1158 | const int nb_load = div_up(jcp.load_dim, jcp.load_block); |
1159 | const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
1160 | |
1161 | jcp.nthr_g = jcp.ngroups; |
1162 | const int nthr = nthreads / jcp.nthr_g; |
1163 | |
1164 | auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) { |
        /* calculate per-thread memory cost (read/write). The high-level
         * optimizer tries to minimize memory consumption. A few notes:
         * (n1) unclear why, but this essentially helps the first
         * convolution...
         * (n2) assuming the reduction over the minibatch is always there:
         * - instead of 8 it should be 5 here (write ~= 2 reads):
         *   kernel: temporal workspace, 1 write
         *   reduction: 1 read from the workspace and 1 write to the diff_wei
         * - but experiments showed 8 works better than 5 or 6... */
1173 | int bcast_koeff = 1; |
1174 | int load_koeff = 1; |
1175 | int output_koeff = 12; |
1176 | return 0 |
1177 | + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) |
1178 | * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_bcast, nthr_ic_b) |
1179 | * jcp.ic_block * jcp.reduce_block / jcp.stride_h |
1180 | / jcp.stride_w /* (n1) */ |
1181 | + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb) |
1182 | * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) |
1183 | * jcp.oc_block * jcp.reduce_block |
1184 | + (size_t)output_koeff /* (n2) */ |
1185 | * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b) |
1186 | * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.oc_block; |
1187 | }; |
1188 | |
1189 | int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1; |
1190 | auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); |
1191 | |
1192 | /* step 1: find the best thread distribution with lowest memory cost */ |
1193 | const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce); |
1194 | for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) { |
1195 | const int nthr_par = nthr / nthr_mb; |
1196 | const int nthr_oc_b_max = nstl::min(nthr_par, nb_load); |
1197 | for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) { |
1198 | nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast); |
1199 | auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b); |
1200 | if (mem_cost <= best_mem_cost) { |
1201 | best_mem_cost = mem_cost; |
1202 | jcp.nthr_mb = nthr_mb; |
1203 | jcp.nthr_oc_b = nthr_oc_b; |
1204 | jcp.nthr_ic_b = nthr_ic_b; |
1205 | } |
1206 | } |
1207 | } |
1208 | if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads) |
1209 | jcp.nthr_mb = nstl::min(jcp.mb, nthreads); |
1210 | |
1211 | jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b; |
1212 | assert(jcp.nthr <= nthreads); |
1213 | } |
1214 | |
1215 | } // namespace x64 |
1216 | } // namespace cpu |
1217 | } // namespace impl |
1218 | } // namespace dnnl |
1219 | |