/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_CONV_UTILS_HPP
18#define CPU_X64_JIT_BRGEMM_CONV_UTILS_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23
24#include "cpu/cpu_convolution_pd.hpp"
25#include "cpu/cpu_engine.hpp"
26#include "cpu/platform.hpp"
27
28#include "cpu/x64/brgemm/brgemm.hpp"
29#include "cpu/x64/jit_primitive_conf.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace x64 {
35
36namespace brgemm_convolution_utils {
37
// Compile-time constant with the value 4096.
// NOTE(review): the name suggests a 4 KiB page size in bytes, presumably used
// for buffer sizing/alignment; the use sites are not visible in this header,
// so confirm there.
constexpr size_t P4K = 4096;
39
40bool is_amx(cpu_isa_t isa);
41
42status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
43 const convolution_desc_t &cd, memory_desc_t &src_md,
44 memory_desc_t &weights_md, memory_desc_t &dst_md,
45 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads);
46
47status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
48 const convolution_desc_t &cd, memory_desc_t &src_md,
49 memory_desc_t &weights_md, memory_desc_t &dst_md,
50 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads);
51
52void set_amx_wsp_per_thread(jit_brgemm_conv_conf_t &jcp);
53
54void init_scratchpad(memory_tracking::registrar_t &scratchpad,
55 const jit_brgemm_conv_conf_t &jcp);
56
57status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp,
58 const convolution_desc_t &cd, memory_desc_t &src_md,
59 memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md,
60 memory_desc_t &diff_dst_md, primitive_attr_t &attr, int nthreads);
61
62status_t init_scratchpad_bwd_w(memory_tracking::registrar_t &scratchpad,
63 const jit_brgemm_conv_conf_t &jcp, memory_desc_t &src_md,
64 memory_desc_t &diff_weights_md, memory_desc_t &diff_dst_md);
65
66} // namespace brgemm_convolution_utils
67
68} // namespace x64
69} // namespace cpu
70} // namespace impl
71} // namespace dnnl
72
73#endif
74