1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_BRGEMM_INNER_PRODUCT_UTILS_HPP |
18 | #define CPU_X64_BRGEMM_INNER_PRODUCT_UTILS_HPP |
19 | |
20 | #include "dnnl_types.h" |
21 | |
22 | #include "common/bfloat16.hpp" |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/dnnl_thread.hpp" |
25 | #include "common/memory_tracking.hpp" |
26 | #include "common/type_helpers.hpp" |
27 | #include "common/utils.hpp" |
28 | |
29 | #include "cpu/cpu_engine.hpp" |
30 | #include "cpu/cpu_inner_product_pd.hpp" |
31 | #include "cpu/platform.hpp" |
32 | |
33 | #include "cpu/x64/cpu_barrier.hpp" |
34 | #include "cpu/x64/cpu_isa_traits.hpp" |
35 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
36 | #include "cpu/x64/jit_brgemm_primitive_conf.hpp" |
37 | #include "cpu/x64/jit_generator.hpp" |
38 | |
39 | namespace dnnl { |
40 | namespace impl { |
41 | namespace cpu { |
42 | namespace x64 { |
43 | |
44 | namespace brgemm_inner_product_utils { |
45 | |
46 | status_t init_ip_conf(cpu_isa_t isa, jit_brgemm_primitive_conf_t &jbgp, |
47 | const inner_product_desc_t &ipd, memory_desc_t &src_md, |
48 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
49 | memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads); |
50 | |
51 | void init_scratchpad(memory_tracking::registrar_t &scratchpad, |
52 | const jit_brgemm_primitive_conf_t &jbgp); |
53 | |
54 | static const int max_num_brg_kernels_ip = 2 * 2 * 2 * 2 * 2; |
55 | |
56 | int get_brg_kernel_index(const jit_brgemm_primitive_conf_t &jbgp, |
57 | bool is_bs_tail, bool do_initialization, bool is_M_tail, bool is_N_tail, |
58 | bool is_K_tail); |
59 | |
60 | int get_os_block(const jit_brgemm_primitive_conf_t &jbgp, bool try_to_adjust, |
61 | bool is_adjustment); |
62 | int get_oc_block( |
63 | const jit_brgemm_primitive_conf_t &jbgp, bool try_to_adjust = false); |
64 | |
65 | int ip_fwd_get_oc_block(const jit_brgemm_primitive_conf_t &jbgp); |
66 | int ip_fwd_get_nb_oc_blocking( |
67 | const jit_brgemm_primitive_conf_t &jbgp, bool is_adjustment = false); |
68 | bool ip_fwd_adjust_thread_balance(const jit_brgemm_primitive_conf_t &jbgp); |
69 | int ip_fwd_get_adjusted_oc_block(const jit_brgemm_primitive_conf_t &jbgp); |
70 | |
71 | format_tag_t get_brgemm_ip_weights_tag( |
72 | cpu_isa_t isa, const jit_brgemm_primitive_conf_t &jbgp); |
73 | bool post_ops_ok(jit_brgemm_primitive_conf_t &jbgp, |
74 | const primitive_attr_t &attr, const memory_desc_wrapper &dst_d); |
75 | void thread_balance(const jit_brgemm_primitive_conf_t &j, int &nb_os_blocking_, |
76 | int &nthr_, int &nthr_mb_, int &nthr_oc_b_, int &nthr_ic_b_); |
77 | status_t init_ip_conf_fwd(jit_brgemm_primitive_conf_t &jbgp, |
78 | const primitive_attr_t &attr, const memory_desc_wrapper &dst_d); |
79 | status_t init_ip_conf_bwd_d(jit_brgemm_primitive_conf_t &jbgp); |
80 | status_t init_ip_conf_bwd_w(jit_brgemm_primitive_conf_t &jbgp); |
81 | size_t buf_dt_size(data_type_t dt, cpu_isa_t isa); |
82 | |
83 | } // namespace brgemm_inner_product_utils |
84 | |
85 | } // namespace x64 |
86 | } // namespace cpu |
87 | } // namespace impl |
88 | } // namespace dnnl |
89 | |
90 | #endif |
91 | |