1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_BRGEMM_1X1_CONV_HPP |
18 | #define CPU_X64_JIT_BRGEMM_1X1_CONV_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/cpu_convolution_pd.hpp" |
27 | #include "cpu/platform.hpp" |
28 | |
29 | #include "cpu/x64/amx_tile_configure.hpp" |
30 | #include "cpu/x64/brgemm/brgemm.hpp" |
31 | #include "cpu/x64/cpu_barrier.hpp" |
32 | #include "cpu/x64/cpu_reducer.hpp" |
33 | #include "cpu/x64/jit_brgemm_conv_trans_kernel.hpp" |
34 | #include "cpu/x64/jit_brgemm_conv_utils.hpp" |
35 | #include "cpu/x64/jit_brgemm_post_ops.hpp" |
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | namespace cpu { |
40 | namespace x64 { |
41 | |
42 | template <cpu_isa_t isa> |
43 | struct brgemm_1x1_convolution_fwd_t : public primitive_t { |
44 | struct pd_t : public cpu_convolution_fwd_pd_t { |
45 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
46 | const typename pd_t::base_class *hint_fwd_pd) |
47 | : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) |
48 | , with_sum(false) |
49 | , sum_scale(0) {} |
50 | |
51 | DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgconv_1x1:" , isa, "" ), |
52 | brgemm_1x1_convolution_fwd_t); |
53 | |
54 | status_t init(engine_t *engine); |
55 | |
56 | brgemm_t brgs_[16]; |
57 | bool with_sum; |
58 | float sum_scale; |
59 | |
60 | bool need_postwork; |
61 | int ic_chunks; |
62 | |
63 | jit_brgemm_conv_conf_t jcp_; |
64 | |
65 | protected: |
66 | bool arg_scales_ok() const { |
67 | std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS}; |
68 | const int with_g = static_cast<int>(with_groups()); |
69 | bool ok = true; |
70 | ok = ok && attr()->scales_.has_default_values(supported_args); |
71 | for (int arg : supported_args) { |
72 | const auto &mask = attr()->scales_.get(arg).mask_; |
73 | if (arg == DNNL_ARG_WEIGHTS) |
74 | ok = ok && (mask == 0 || mask == (1 << with_g)); |
75 | else |
76 | ok = ok && (mask == 0); |
77 | } |
78 | return ok; |
79 | } |
80 | bool zero_points_ok() const { |
81 | // Only common zero points are supported -> mask should only be 0 |
82 | int mask_src = 0, mask_dst = 0; |
83 | attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src); |
84 | attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst); |
85 | return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS) |
86 | && mask_src == 0 && mask_dst == 0; |
87 | } |
88 | }; |
89 | |
90 | brgemm_1x1_convolution_fwd_t(const pd_t *apd) |
91 | : primitive_t(apd), bias_d(pd()->weights_md(1)) {} |
92 | |
93 | ~brgemm_1x1_convolution_fwd_t() {} |
94 | |
95 | status_t execute(const exec_ctx_t &ctx) const override { |
96 | execute_forward_all(ctx); |
97 | |
98 | if (pd()->wants_zero_pad_dst()) ctx.memory(DNNL_ARG_DST)->zero_pad(ctx); |
99 | |
100 | return status::success; |
101 | } |
102 | |
103 | protected: |
104 | status_t init(engine_t *engine) override; |
105 | |
106 | private: |
107 | // brgemm convolution execution context |
108 | struct brgemm_exec_ctx_t { |
109 | brgemm_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd) |
110 | : src(CTX_IN_MEM(const char *, DNNL_ARG_SRC)) |
111 | , weights(CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS)) |
112 | , bias(CTX_IN_MEM(const char *, DNNL_ARG_BIAS)) |
113 | , dst(CTX_OUT_MEM(char *, DNNL_ARG_DST)) |
114 | , post_ops_binary_rhs_arg_vec(binary_injector::prepare_binary_args( |
115 | pd->attr()->post_ops_, ctx)) |
116 | , wsp_tile(ctx.get_scratchpad_grantor().template get<char>( |
117 | memory_tracking::names::key_conv_amx_tile_buffer)) {} |
118 | const char *const __restrict src; |
119 | const char *const __restrict weights; |
120 | const char *const __restrict bias; |
121 | char *const __restrict dst; |
122 | const std::vector<const void *> post_ops_binary_rhs_arg_vec; |
123 | char *const wsp_tile; |
124 | }; |
125 | |
126 | void maybe_rtus(int ithr, const char *__restrict src, |
127 | char *__restrict inp_buffer, uint8_t *__restrict inp_buffer_mask, |
128 | int g, int n, int icc, int od, int oh, int ow) const; |
129 | void exec_ker(const brgemm_exec_ctx_t &brgemm_ctx, int ithr, |
130 | brgemm_batch_element_t *const __restrict brg_batch, |
131 | char *const c_buffer, const char *inp_buffer, int g, int n, int ocb, |
132 | int od, int oh, int ow, int icc, int *last_brg_idx, |
133 | const float *oscales, int32_t src_zp_vals, int32_t *src_zp_comp, |
134 | int32_t *dst_zp_vals, int32_t *s8s8_compensation) const; |
135 | status_t execute_forward_all(const exec_ctx_t &ctx) const; |
136 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
137 | |
138 | static int get_brg_idx(bool do_initialization, int is_M_tail, |
139 | bool is_N_tail, bool is_K_tail) { |
140 | return (((int)do_initialization * 2 + (int)is_M_tail) * 2 |
141 | + (int)is_N_tail) |
142 | * 2 |
143 | + (int)is_K_tail; |
144 | } |
145 | |
146 | static int get_ker_po_idx(int is_M_tail, bool is_N_tail) { |
147 | return (int)is_M_tail * 2 + (int)is_N_tail; |
148 | } |
149 | |
150 | std::unique_ptr<brgemm_kernel_t> brg_kernels_[16]; |
151 | struct amx_palette_t { |
152 | char p[AMX_PALETTE_SIZE]; |
153 | }; |
154 | std::vector<amx_palette_t> brg_kernel_palette_; |
155 | int brg_kernel_palette_idx_[16]; |
156 | std::unique_ptr<jit_avx512_core_brgemm_conv_trans_kernel:: |
157 | jit_avx512_core_brgemm_conv_rtus_kernel_t> |
158 | rtus_kernel_; |
159 | |
160 | const memory_desc_wrapper bias_d; |
161 | |
162 | int ID, IH, IW, OD, OH, OW, SD, SH, SW; |
163 | size_t bia_dsz, acc_dsz, src_dsz, wei_dsz; |
164 | // const variables used for address calculations |
165 | dim_t src_w_sz, src_h_sz, src_d_sz, dst_w_sz, dst_h_sz, dst_d_sz, wei_oc_sz, |
166 | wei_ic_sz, wei_ocb_sz; |
167 | }; |
168 | |
169 | } // namespace x64 |
170 | } // namespace cpu |
171 | } // namespace impl |
172 | } // namespace dnnl |
173 | |
174 | #endif |
175 | |
176 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
177 | |