1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_BRGEMM_INNER_PRODUCT_UTILS_HPP
18#define CPU_X64_BRGEMM_INNER_PRODUCT_UTILS_HPP
19
20#include "dnnl_types.h"
21
22#include "common/bfloat16.hpp"
23#include "common/c_types_map.hpp"
24#include "common/dnnl_thread.hpp"
25#include "common/memory_tracking.hpp"
26#include "common/type_helpers.hpp"
27#include "common/utils.hpp"
28
29#include "cpu/cpu_engine.hpp"
30#include "cpu/cpu_inner_product_pd.hpp"
31#include "cpu/platform.hpp"
32
33#include "cpu/x64/cpu_barrier.hpp"
34#include "cpu/x64/cpu_isa_traits.hpp"
35#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
36#include "cpu/x64/jit_brgemm_primitive_conf.hpp"
37#include "cpu/x64/jit_generator.hpp"
38
39namespace dnnl {
40namespace impl {
41namespace cpu {
42namespace x64 {
43
44namespace brgemm_inner_product_utils {
45
46status_t init_ip_conf(cpu_isa_t isa, jit_brgemm_primitive_conf_t &jbgp,
47 const inner_product_desc_t &ipd, memory_desc_t &src_md,
48 memory_desc_t &weights_md, memory_desc_t &dst_md,
49 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads);
50
51void init_scratchpad(memory_tracking::registrar_t &scratchpad,
52 const jit_brgemm_primitive_conf_t &jbgp);
53
// Number of brgemm kernel slots reserved for an inner-product primitive
// (2^5 == 32). The five factors of 2 match the five boolean flags accepted
// by get_brg_kernel_index() (bs tail, initialization, M tail, N tail, K tail),
// presumably one slot per flag combination. Keep the two in sync if a flag
// is added or removed.
// `constexpr` guarantees compile-time evaluation; at namespace scope it keeps
// the same internal linkage as the previous `static const`.
static constexpr int max_num_brg_kernels_ip = 2 * 2 * 2 * 2 * 2;
55
56int get_brg_kernel_index(const jit_brgemm_primitive_conf_t &jbgp,
57 bool is_bs_tail, bool do_initialization, bool is_M_tail, bool is_N_tail,
58 bool is_K_tail);
59
60int get_os_block(const jit_brgemm_primitive_conf_t &jbgp, bool try_to_adjust,
61 bool is_adjustment);
62int get_oc_block(
63 const jit_brgemm_primitive_conf_t &jbgp, bool try_to_adjust = false);
64
65int ip_fwd_get_oc_block(const jit_brgemm_primitive_conf_t &jbgp);
66int ip_fwd_get_nb_oc_blocking(
67 const jit_brgemm_primitive_conf_t &jbgp, bool is_adjustment = false);
68bool ip_fwd_adjust_thread_balance(const jit_brgemm_primitive_conf_t &jbgp);
69int ip_fwd_get_adjusted_oc_block(const jit_brgemm_primitive_conf_t &jbgp);
70
71format_tag_t get_brgemm_ip_weights_tag(
72 cpu_isa_t isa, const jit_brgemm_primitive_conf_t &jbgp);
73bool post_ops_ok(jit_brgemm_primitive_conf_t &jbgp,
74 const primitive_attr_t &attr, const memory_desc_wrapper &dst_d);
75void thread_balance(const jit_brgemm_primitive_conf_t &j, int &nb_os_blocking_,
76 int &nthr_, int &nthr_mb_, int &nthr_oc_b_, int &nthr_ic_b_);
77status_t init_ip_conf_fwd(jit_brgemm_primitive_conf_t &jbgp,
78 const primitive_attr_t &attr, const memory_desc_wrapper &dst_d);
79status_t init_ip_conf_bwd_d(jit_brgemm_primitive_conf_t &jbgp);
80status_t init_ip_conf_bwd_w(jit_brgemm_primitive_conf_t &jbgp);
81size_t buf_dt_size(data_type_t dt, cpu_isa_t isa);
82
83} // namespace brgemm_inner_product_utils
84
85} // namespace x64
86} // namespace cpu
87} // namespace impl
88} // namespace dnnl
89
90#endif
91