1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_CONV_HPP
18#define CPU_X64_JIT_BRGEMM_CONV_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23#include "common/primitive.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/cpu_convolution_pd.hpp"
27#include "cpu/platform.hpp"
28
29#include "cpu/x64/amx_tile_configure.hpp"
30#include "cpu/x64/brgemm/brgemm.hpp"
31#include "cpu/x64/cpu_barrier.hpp"
32#include "cpu/x64/cpu_reducer.hpp"
33#include "cpu/x64/jit_brgemm_conv_comp_pad_kernel.hpp"
34#include "cpu/x64/jit_brgemm_conv_trans_kernel.hpp"
35#include "cpu/x64/jit_brgemm_conv_utils.hpp"
36#include "cpu/x64/jit_brgemm_post_ops.hpp"
37
38namespace dnnl {
39namespace impl {
40namespace cpu {
41namespace x64 {
42
// Forward convolution primitive implemented on top of the BRGEMM (batch-reduce
// GEMM) micro-kernels. `isa` selects the JIT target; when `use_inversion` is
// true the spatial kernel indices are mirrored (see maybe_invert) —
// presumably to reuse this implementation for deconvolution; confirm against
// the callers that instantiate it with use_inversion = true.
template <cpu_isa_t isa, bool use_inversion = false>
struct brgemm_convolution_fwd_t : public primitive_t {

    // Primitive descriptor: validates the convolution problem/attributes and
    // owns the pre-created BRGEMM descriptors shared with the primitive.
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::hint_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , with_sum(false) {}

        ~pd_t() = default;

        // ------- DECLARE_COMMON_PD_t -----
        // Hand-expanded version of the DECLARE_COMMON_PD_T macro: a custom
        // clone() is needed to (re)copy the brgs_ / bd_masks vectors of
        // shared pointers into the new descriptor.
        pd_t *clone() const override {
            auto new_pd = utils::make_unique<pd_t>(*this);
            if (!new_pd->is_initialized()) return nullptr;
            new_pd->brgs_.resize(brgs_sz_);
            for (int i = 0; i < brgs_sz_; i++) {
                new_pd->brgs_[i] = brgs_[i];
                new_pd->bd_masks[i] = bd_masks[i];
            }
            return new_pd.release();
        }

        status_t create_primitive(
                std::pair<std::shared_ptr<primitive_t>, bool> &primitive,
                engine_t *engine,
                const cache_blob_t &cache_blob) const override {
            return primitive_t::create_primitive_common<
                    brgemm_convolution_fwd_t, pd_t>(
                    primitive, this, engine, false, cache_blob);
        }

        const char *name() const override {
            return JIT_IMPL_NAME_HELPER("brgconv:", isa, "");
        }
        // ---------------------------------

        status_t init(engine_t *engine);

        int brgs_sz_; // number of entries in brgs_ / bd_masks
        std::vector<std::shared_ptr<brgemm_t>> brgs_; // BRGEMM descriptors
        std::vector<std::shared_ptr<std::vector<char>>> bd_masks;
        bool with_sum; // presumably set in init() when post-ops contain sum
        jit_brgemm_conv_conf_t jcp_;

        int ic_chunks;
        bool need_postwork;

        // batch sizes info for unrolled kernels
        int bs_c, first_bs;
        std::vector<int> batchsizes;
        // Flattens (bs, m, do_initialization, is_N_tail, is_K_tail) into a
        // linear index into brgs_: three trailing binary dimensions on top of
        // an (m, bs_idx) pair. The batch-size dimension is only used by the
        // unrolled kernel (jcp_.use_uker); otherwise it collapses to 0.
        int get_brg_idx(int bs, int m, bool do_initialization, bool is_N_tail,
                bool is_K_tail) const {
            auto bs_idx = jcp_.use_uker ? batchsizes[bs] : 0;
            assert(bs_idx >= 0);
            return (((m * bs_c + bs_idx) * 2
                            + static_cast<int>(do_initialization))
                           * 2
                    + static_cast<int>(is_N_tail))
                    * 2
                    + static_cast<int>(is_K_tail);
        }

    protected:
        // Accepts only the supported scales configuration: src scales must be
        // common (mask == 0); weights scales either common or per-output-
        // channel (mask == 1 << with_groups()); no other args may have scales.
        bool arg_scales_ok() const {
            std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
            const int with_g = static_cast<int>(with_groups());
            bool ok = true;
            ok = ok && attr()->scales_.has_default_values(supported_args);
            for (int arg : supported_args) {
                const auto &mask = attr()->scales_.get(arg).mask_;
                if (arg == DNNL_ARG_WEIGHTS)
                    ok = ok && (mask == 0 || mask == (1 << with_g));
                else
                    ok = ok && (mask == 0);
            }
            return ok;
        }
        // Only common zero points are supported -> mask should only be 0;
        // weights zero points must be absent (default values).
        bool zero_points_ok() const {
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
            return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && mask_src == 0 && mask_dst == 0;
        }
    };

    brgemm_convolution_fwd_t(const pd_t *apd);

    ~brgemm_convolution_fwd_t() = default;

    status_t execute(const exec_ctx_t &ctx) const override;

protected:
    status_t init(engine_t *engine) override;

private:
    // Fixed-size wrapper for one AMX tile-configuration palette so palettes
    // can be stored by value in a std::vector (see brg_kernel_palettes_).
    struct S_t {
        char a[AMX_PALETTE_SIZE];
    };

    // brgemm convolution execution context: caches the raw argument pointers
    // and the prepared binary post-op arguments for one execute() call.
    struct brgemm_exec_ctx_t {
        brgemm_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd)
            : src(CTX_IN_MEM(const char *, DNNL_ARG_SRC))
            , weights(CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS))
            , bias(CTX_IN_MEM(const char *, DNNL_ARG_BIAS))
            , dst(CTX_OUT_MEM(char *, DNNL_ARG_DST))
            , post_ops_binary_rhs_arg_vec(binary_injector::prepare_binary_args(
                      pd->attr()->post_ops_, ctx)) {}
        const char *const __restrict src;
        const char *const __restrict weights;
        const char *const __restrict bias;
        char *const __restrict dst;
        const std::vector<const void *> post_ops_binary_rhs_arg_vec;
    };

    struct brgemm_thread_ctx_t;

    // Flat index for the post-op kernels: (m, do_postwork, is_N_tail) packed
    // with two trailing binary dimensions.
    static int get_ker_po_idx(int m, bool do_postwork, bool is_N_tail) {
        return (m * 2 + static_cast<int>(do_postwork)) * 2
                + static_cast<int>(is_N_tail);
    }

    // Input extent required to compute `dst_size` outputs for a filter of
    // size `k` with the given stride/dilation, clamped to `max_src_size`.
    static int get_inp_size(
            int max_src_size, int dst_size, int k, int stride, int dilate) {
        const auto res = nstl::min(max_src_size,
                calculate_end_padding(0, dst_size, 0, stride,
                        calculate_extended_filter_size(k, dilate)));
        return res;
    }

    // Mirrors kernel index k within [0, K) when use_inversion is enabled;
    // identity otherwise.
    int maybe_invert(int k, int K) const {
        return use_inversion ? K - 1 - k : k;
    };
    // Computes the kw range ([kw_s, kw_e), full sub-range [kw_full_s,
    // kw_full_e)) contributing to output column `ow` — defined in the .cpp.
    void get_kw_range(
            int ow, int &kw_s, int &kw_full_s, int &kw_full_e, int &kw_e) const;
    // Output-column range [ow_s, ow_e) covered by filter column `kw`.
    void get_ow_range(int ow, int kw, int &ow_s, int &ow_e) const;

    // Per-thread kernel drivers (defined in the .cpp): plain, with input
    // transposition buffer, and with vertical-padding handling respectively.
    void ker_base(brgemm_thread_ctx_t &btc) const;
    void ker_trans(brgemm_thread_ctx_t &btc, char *inp_buffer) const;
    void ker_vpad(brgemm_thread_ctx_t &btc) const;

    void perform_outwork(char *dst_base, char *dst, char *c_buffer,
            const char *bias_w, int od, int oh, int ow, int g_oc,
            bool is_oc_tail, int ker_ow_s, int ker_ow_f, int kd_l, int kh_l,
            const void *post_ops_binary_rhs_arg_vec, const float *oscales,
            int32_t src_zp_vals, int32_t *src_zp_ptr, int32_t *dst_zp_ptr,
            int32_t *s8s8_compensation, bool maybe_do_init, bool do_postwork,
            bool do_post_comp) const;

    void call_brgemm_kernel(brgemm_thread_ctx_t &btc, int brg_idx,
            int batch_size, char *ptr_C, char *ptr_D, const char *bias_w,
            int g_oc, bool do_postops, const void *binary_post_ops_rhs,
            int32_t src_zp_vals, int32_t *src_zp_ptr, int32_t *dst_zp_ptr,
            int32_t *s8s8_comp, bool do_only_comp) const;

    // Conditionally copies the needed part of src into the per-thread input
    // buffer; the `last_*` arguments track what the buffer already holds
    // (inp_buffer_mask) to skip redundant copies.
    void maybe_conv_inp(int ithr, const char *__restrict src,
            char *__restrict inp_buffer, uint8_t *__restrict inp_buffer_mask,
            int g, int n, int icc, int odb, int ohb, int owb, int last_g,
            int last_n, int last_icc, int last_odb, int last_ohb,
            int last_owb) const;

    // Kernel-creation helpers used by init() — defined in the .cpp.
    status_t add_po_kernel(brgemm_t *bcfg, int ker_idx, bool is_init);
    void add_po_kernels(int i_N, int init_bcast_dim, int po_bcast_dim);
    status_t add_brg_kernel(int bs, int M, int i_N, int i_K, int i_init);

    // Zero-point / s8s8 compensation helpers.
    status_t cal_compensation(const char *__restrict weights,
            int32_t *src_zp_buffer, int32_t *s8s8_comp_buffer) const;
    int get_comp_ker_idx(const int kd_b, const int kd_e, const int kh_b,
            const int kh_e, const int kw_b, const int kw_e) const;
    int get_comp_offset(const int g, const int ocb, const int ow,
            const int kd_b, const int kd_e, const int kh_b, const int kh_e,
            const int kw_b, const int kw_e) const;
    const pd_t *pd() const {
        return static_cast<const pd_t *>(primitive_t::pd().get());
    }

    // Generated kernels, indexed by get_brg_idx() / get_ker_po_idx().
    std::vector<std::unique_ptr<brgemm_kernel_t>> brg_kernels_;
    std::vector<std::unique_ptr<jit_brgemm_kernel_post_ops<isa>>> kernels_po_;
    std::unique_ptr<jit_avx512_core_brgemm_conv_trans_kernel::
                    jit_avx512_core_brgemm_conv_trans_kernel_t>
            copy_to_pbuffer_;
    std::unique_ptr<jit_avx512_core_brgemm_conv_comp_pad_kernel::
                    jit_avx512_core_brgemm_conv_comp_pad_kernel_t>
            comp_vpad_pbuffer_;
    // One AMX palette per brgemm kernel (only meaningful when is_amx).
    std::vector<S_t> brg_kernel_palettes_;

    // Data-type sizes in bytes: accumulator, bias, src, weights, dst.
    size_t acc_dsz, bia_dsz, src_dsz, wei_dsz, dst_dsz;

    const memory_desc_wrapper bias_d;

    // pre - calculated values
    std::vector<dim_t> owb_kw_top_vpads;
    std::vector<dim_t> owb_kw_bottom_vpads;
    std::vector<dim_t> kd_bs, kd_es, kh_bs, kh_es, kw_bs, kw_es;

    // Cached problem geometry: kernel sizes (K*), extended (dilated) kernel
    // sizes (EXT_K*), blocking (K*_BLOCK), input/padded-input/output spatial
    // dims, strides (S*), front/top/left padding (FP/TP/LP), dilations (D*).
    int KD, KH, KW, EXT_KD, EXT_KH, EXT_KW, KS, KD_BLOCK, KH_BLOCK, KW_BLOCK,
            KD_BLOCK_PAD, KH_BLOCK_PAD, ID, IH, IW, IDP, IHP, IWP, OD, OH, OW,
            SD, SH, SW, FP, TP, LP, DD, DH, DW;
    // Pre-computed strides (in elements) through src/dst/weights tensors.
    dim_t src_w_sz, src_h_sz, src_d_sz, dst_w_sz, dst_h_sz, dst_d_sz, wei_ic_sz,
            wei_kw_sz, wei_kh_sz, wei_kd_sz, wei_ocb_sz;
    // Strides through the packed input buffer used by ker_trans.
    dim_t pbuf_w_sz, pbuf_h_sz, pbuf_d_sz;
    // Sizes of the compensation buffers.
    dim_t ker_vpad_sz, comp_ocb_sz, comp_ker_sz, comp_kw_sz;

    bool need_compensation;
    bool is_amx; // true when `isa` targets Intel AMX (tile) instructions
};
252
253} // namespace x64
254} // namespace cpu
255} // namespace impl
256} // namespace dnnl
257
258#endif
259
260// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
261