1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_1X1_CONV_HPP
18#define CPU_X64_JIT_BRGEMM_1X1_CONV_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23#include "common/primitive.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/cpu_convolution_pd.hpp"
27#include "cpu/platform.hpp"
28
29#include "cpu/x64/amx_tile_configure.hpp"
30#include "cpu/x64/brgemm/brgemm.hpp"
31#include "cpu/x64/cpu_barrier.hpp"
32#include "cpu/x64/cpu_reducer.hpp"
33#include "cpu/x64/jit_brgemm_conv_trans_kernel.hpp"
34#include "cpu/x64/jit_brgemm_conv_utils.hpp"
35#include "cpu/x64/jit_brgemm_post_ops.hpp"
36
37namespace dnnl {
38namespace impl {
39namespace cpu {
40namespace x64 {
41
42template <cpu_isa_t isa>
43struct brgemm_1x1_convolution_fwd_t : public primitive_t {
44 struct pd_t : public cpu_convolution_fwd_pd_t {
45 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
46 const typename pd_t::base_class *hint_fwd_pd)
47 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
48 , with_sum(false)
49 , sum_scale(0) {}
50
51 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgconv_1x1:", isa, ""),
52 brgemm_1x1_convolution_fwd_t);
53
54 status_t init(engine_t *engine);
55
56 brgemm_t brgs_[16];
57 bool with_sum;
58 float sum_scale;
59
60 bool need_postwork;
61 int ic_chunks;
62
63 jit_brgemm_conv_conf_t jcp_;
64
65 protected:
66 bool arg_scales_ok() const {
67 std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
68 const int with_g = static_cast<int>(with_groups());
69 bool ok = true;
70 ok = ok && attr()->scales_.has_default_values(supported_args);
71 for (int arg : supported_args) {
72 const auto &mask = attr()->scales_.get(arg).mask_;
73 if (arg == DNNL_ARG_WEIGHTS)
74 ok = ok && (mask == 0 || mask == (1 << with_g));
75 else
76 ok = ok && (mask == 0);
77 }
78 return ok;
79 }
80 bool zero_points_ok() const {
81 // Only common zero points are supported -> mask should only be 0
82 int mask_src = 0, mask_dst = 0;
83 attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
84 attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
85 return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
86 && mask_src == 0 && mask_dst == 0;
87 }
88 };
89
90 brgemm_1x1_convolution_fwd_t(const pd_t *apd)
91 : primitive_t(apd), bias_d(pd()->weights_md(1)) {}
92
93 ~brgemm_1x1_convolution_fwd_t() {}
94
95 status_t execute(const exec_ctx_t &ctx) const override {
96 execute_forward_all(ctx);
97
98 if (pd()->wants_zero_pad_dst()) ctx.memory(DNNL_ARG_DST)->zero_pad(ctx);
99
100 return status::success;
101 }
102
103protected:
104 status_t init(engine_t *engine) override;
105
106private:
107 // brgemm convolution execution context
108 struct brgemm_exec_ctx_t {
109 brgemm_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd)
110 : src(CTX_IN_MEM(const char *, DNNL_ARG_SRC))
111 , weights(CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS))
112 , bias(CTX_IN_MEM(const char *, DNNL_ARG_BIAS))
113 , dst(CTX_OUT_MEM(char *, DNNL_ARG_DST))
114 , post_ops_binary_rhs_arg_vec(binary_injector::prepare_binary_args(
115 pd->attr()->post_ops_, ctx))
116 , wsp_tile(ctx.get_scratchpad_grantor().template get<char>(
117 memory_tracking::names::key_conv_amx_tile_buffer)) {}
118 const char *const __restrict src;
119 const char *const __restrict weights;
120 const char *const __restrict bias;
121 char *const __restrict dst;
122 const std::vector<const void *> post_ops_binary_rhs_arg_vec;
123 char *const wsp_tile;
124 };
125
126 void maybe_rtus(int ithr, const char *__restrict src,
127 char *__restrict inp_buffer, uint8_t *__restrict inp_buffer_mask,
128 int g, int n, int icc, int od, int oh, int ow) const;
129 void exec_ker(const brgemm_exec_ctx_t &brgemm_ctx, int ithr,
130 brgemm_batch_element_t *const __restrict brg_batch,
131 char *const c_buffer, const char *inp_buffer, int g, int n, int ocb,
132 int od, int oh, int ow, int icc, int *last_brg_idx,
133 const float *oscales, int32_t src_zp_vals, int32_t *src_zp_comp,
134 int32_t *dst_zp_vals, int32_t *s8s8_compensation) const;
135 status_t execute_forward_all(const exec_ctx_t &ctx) const;
136 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
137
138 static int get_brg_idx(bool do_initialization, int is_M_tail,
139 bool is_N_tail, bool is_K_tail) {
140 return (((int)do_initialization * 2 + (int)is_M_tail) * 2
141 + (int)is_N_tail)
142 * 2
143 + (int)is_K_tail;
144 }
145
146 static int get_ker_po_idx(int is_M_tail, bool is_N_tail) {
147 return (int)is_M_tail * 2 + (int)is_N_tail;
148 }
149
150 std::unique_ptr<brgemm_kernel_t> brg_kernels_[16];
151 struct amx_palette_t {
152 char p[AMX_PALETTE_SIZE];
153 };
154 std::vector<amx_palette_t> brg_kernel_palette_;
155 int brg_kernel_palette_idx_[16];
156 std::unique_ptr<jit_avx512_core_brgemm_conv_trans_kernel::
157 jit_avx512_core_brgemm_conv_rtus_kernel_t>
158 rtus_kernel_;
159
160 const memory_desc_wrapper bias_d;
161
162 int ID, IH, IW, OD, OH, OW, SD, SH, SW;
163 size_t bia_dsz, acc_dsz, src_dsz, wei_dsz;
164 // const variables used for address calculations
165 dim_t src_w_sz, src_h_sz, src_d_sz, dst_w_sz, dst_h_sz, dst_d_sz, wei_oc_sz,
166 wei_ic_sz, wei_ocb_sz;
167};
168
169} // namespace x64
170} // namespace cpu
171} // namespace impl
172} // namespace dnnl
173
174#endif
175
176// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
177