1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_AMX_CONVOLUTION_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_AMX_CONVOLUTION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/cpu_convolution_pd.hpp" |
27 | |
28 | #include "cpu/x64/amx_tile_configure.hpp" |
29 | #include "cpu/x64/cpu_barrier.hpp" |
30 | #include "cpu/x64/cpu_reducer.hpp" |
31 | #include "cpu/x64/jit_avx512_core_amx_conv_kernel.hpp" |
32 | #include "cpu/x64/jit_transpose_utils.hpp" |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | namespace x64 { |
38 | |
// JIT forward convolution primitive using Intel AMX (Advanced Matrix
// Extensions) tile instructions on avx512_core_amx-capable CPUs. Supports
// bf16 and int8 (s8/u8) convolutions; an optional reduced-lowering ("relo")
// execution path is selected by the kernel configuration.
struct jit_avx512_core_amx_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:" , jcp_.isa, "" ),
                jit_avx512_core_amx_convolution_fwd_t);

        // Validates data types / attributes for the bf16 and int8 flavors,
        // fills the kernel configuration (jcp_) and books scratchpad space.
        status_t init(engine_t *engine) {
            using namespace data_type;
            using smask_t = primitive_attr_t::skip_mask_t;
            // bf16 flavor: bf16 src/weights, f32 or bf16 dst/bias; post-ops
            // are the only non-default attribute accepted.
            bool is_bf16_convolution
                    = (src_md(0)->data_type == bf16
                              && weights_md(0)->data_type == bf16
                              && utils::one_of(dst_md(0)->data_type, f32, bf16))
                    && IMPLICATION(with_bias(),
                            utils::one_of(weights_md(1)->data_type, f32, bf16))
                    && attr()->has_default_values(smask_t::post_ops);
            // int8 flavor: s8/u8 src, s8 weights, wide dst/bias type range;
            // runtime scales, runtime zero points, post-ops and a sum data
            // type may additionally be set on the attributes.
            bool is_int8_convolution
                    = utils::one_of(src_md(0)->data_type, s8, u8)
                    && weights_md(0)->data_type == s8
                    && utils::one_of(
                            dst_md(0)->data_type, s8, u8, s32, f32, bf16)
                    && IMPLICATION(with_bias(),
                            utils::one_of(
                                    weights_md(1)->data_type, f32, s32, s8, u8))
                    && attr()->has_default_values(smask_t::scales_runtime
                                    | smask_t::post_ops
                                    | smask_t::zero_points_runtime
                                    | smask_t::sum_dt,
                            dst_md(0)->data_type)
                    && attr()->post_ops_.check_sum_consistent_dt(
                            dst_md(0)->data_type);

            bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && (is_bf16_convolution || is_int8_convolution)
                    && !has_zero_dim_memory() && zero_points_ok();
            if (!ok) return status::unimplemented;

            CHECK(jit_avx512_core_amx_fwd_kernel_t::init_conf(jcp_, *desc(),
                    src_md_, weights_md_, dst_md_, bias_md_, attr_,
                    dnnl_get_max_threads()));

            auto scratchpad = scratchpad_registry().registrar();
            CHECK(jit_avx512_core_amx_fwd_kernel_t::init_scratchpad(
                    scratchpad, jcp_, *attr()));

            return status::success;
        }

        jit_conv_conf_t jcp_; // kernel configuration filled by init_conf()

    protected:
        // The kernel handles only per-tensor ("common") zero points on
        // src/dst (mask == 0) and no zero points at all on weights.
        bool zero_points_ok() const {
            // Only common zero points are supported -> mask should only be 0
            int mask_src = 0, mask_dst = 0;
            attr()->zero_points_.get(DNNL_ARG_SRC, &mask_src);
            attr()->zero_points_.get(DNNL_ARG_DST, &mask_dst);
            return attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS)
                    && mask_src == 0 && mask_dst == 0;
        }
    };

    jit_avx512_core_amx_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    // Builds and code-generates the JIT kernel from the configuration
    // prepared in pd_t::init().
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_core_amx_fwd_kernel_t(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        return kernel_->create_kernel();
    }

    // Dispatches between the reduced-lowering and regular forward paths;
    // depthwise convolution is not implemented here.
    status_t execute(const exec_ctx_t &ctx) const override {
        const auto &_pd = pd();
        if (_pd->jcp_.is_depthwise)
            return status::unimplemented;
        else if (_pd->jcp_.is_relo)
            return execute_forward_reduced_lowering(ctx);
        return execute_forward(ctx);
    }

private:
    status_t execute_forward_reduced_lowering(const exec_ctx_t &ctx) const;
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    void prepare_padded_bias(const char *&bias,
            const memory_tracking::grantor_t &scratchpad) const;

    std::unique_ptr<jit_avx512_core_amx_fwd_kernel_t> kernel_;
};
131 | |
// JIT backward-by-data convolution primitive using Intel AMX tiles.
// Template parameters fix the diff_src / weights / diff_dst data types;
// pd_t::init() accepts only the bf16 flavor (bf16 weights and diff_dst,
// f32 or bf16 diff_src) with default attributes.
template <impl::data_type_t diff_src_type, impl::data_type_t wei_type,
        impl::data_type_t diff_dst_type>
struct jit_avx512_core_amx_convolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:" , jcp_.isa, "" ),
                jit_avx512_core_amx_convolution_bwd_data_t);

        // Validates the supported bf16 configuration, fills the kernel
        // configuration (jcp_) and books scratchpad space.
        status_t init(engine_t *engine) {
            bool is_bf16_convolution = true
                    && (diff_dst_md_.data_type == data_type::bf16
                            && weights_md_.data_type == data_type::bf16
                            && utils::one_of(diff_src_md_.data_type,
                                    data_type::f32, data_type::bf16))
                    && attr()->has_default_values();

            bool ok = true && desc()->prop_kind == prop_kind::backward_data
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && is_bf16_convolution && !has_zero_dim_memory();
            if (!ok) return status::unimplemented;

            status_t status = jit_avx512_core_amx_bwd_data_kernel_t::init_conf(
                    jcp_, *desc(), diff_src_md_, weights_md_, diff_dst_md_,
                    nullptr /* no bias */, attr_, dnnl_get_max_threads());
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_core_amx_bwd_data_kernel_t::init_scratchpad(
                    scratchpad, jcp_, *attr());

            return status;
        }

        jit_conv_conf_t jcp_; // kernel configuration filled by init_conf()
    };

    jit_avx512_core_amx_convolution_bwd_data_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
    typedef typename prec_traits<wei_type>::type wei_data_t;
    typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;

    // Builds and code-generates the JIT kernel from the configuration
    // prepared in pd_t::init().
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_core_amx_bwd_data_kernel_t(
                        pd()->jcp_, *pd()->attr())));
        return kernel_->create_kernel();
    }

    // Depthwise configurations are rejected (should not have been created);
    // everything else goes through the single backward-data path.
    status_t execute(const exec_ctx_t &ctx) const override {
        const auto &_pd = pd();
        if (_pd->jcp_.is_depthwise) {
            assert(!"_pd->jcp_.is_depthwise not implemented" );
            return status::unimplemented;
        } else
            return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_core_amx_bwd_data_kernel_t> kernel_;
};
200 | |
201 | struct jit_avx512_core_amx_convolution_bwd_weights_t : public primitive_t { |
202 | struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
203 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
204 | const convolution_fwd_pd_t *hint_fwd_pd) |
205 | : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) |
206 | , jcp_() {} |
207 | |
208 | DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:" , jcp_.isa, "" ), |
209 | jit_avx512_core_amx_convolution_bwd_weights_t); |
210 | |
211 | status_t init(engine_t *engine) { |
212 | bool ok = true && is_bwd_w() |
213 | && set_default_alg_kind(alg_kind::convolution_direct) |
214 | && (expect_data_types(data_type::bf16, data_type::bf16, |
215 | data_type::undef, data_type::bf16, |
216 | data_type::undef) |
217 | || expect_data_types(data_type::bf16, |
218 | data_type::f32, data_type::undef, |
219 | data_type::bf16, data_type::undef)) |
220 | && IMPLICATION(with_bias(), |
221 | utils::one_of(diff_bias_md_.data_type, |
222 | data_type::f32, data_type::bf16)) |
223 | && attr()->has_default_values() && !has_zero_dim_memory(); |
224 | if (!ok) return status::unimplemented; |
225 | |
226 | status_t status |
227 | = jit_avx512_core_amx_bwd_weights_kernel_t::init_conf(jcp_, |
228 | *desc(), src_md_, diff_weights_md_, diff_bias_md_, |
229 | diff_dst_md_, dnnl_get_max_threads()); |
230 | if (status != status::success) return status; |
231 | |
232 | auto scratchpad = scratchpad_registry().registrar(); |
233 | status = jit_avx512_core_amx_bwd_weights_kernel_t::init_scratchpad( |
234 | scratchpad, jcp_, src_md_, diff_weights_md_, diff_dst_md_); |
235 | if (status != status::success) return status; |
236 | |
237 | return status; |
238 | } |
239 | |
240 | jit_conv_conf_t jcp_; |
241 | }; |
242 | |
243 | jit_avx512_core_amx_convolution_bwd_weights_t(const pd_t *apd) |
244 | : primitive_t(apd) {} |
245 | |
246 | typedef typename prec_traits<data_type::bf16>::type src_data_t; |
247 | typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t; |
248 | |
249 | status_t init(engine_t *engine) override; |
250 | |
251 | status_t execute(const exec_ctx_t &ctx) const override { |
252 | execute_backward_weights(ctx); |
253 | return status::success; |
254 | } |
255 | |
256 | private: |
257 | struct thread_info_t; |
258 | |
259 | void execute_backward_weights(const exec_ctx_t &ctx) const; |
260 | void prepare_scratchpad_data(const exec_ctx_t &ctx) const; |
261 | void compute_diff_weights_2d(const thread_info_t *) const; |
262 | void compute_diff_weights_3d(const thread_info_t *) const; |
263 | void compute_diff_weights(const thread_info_t *) const; |
264 | void reduce_and_convert_diff_weights_and_bias(const thread_info_t *) const; |
265 | void store_in_vnni_format(const thread_info_t *) const; |
266 | |
267 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
268 | |
269 | size_t tr_src_buf_number(const thread_info_t *ti, int g, int ic) const; |
270 | size_t tr_diff_dst_buf_number(const thread_info_t *ti, int g, int oc) const; |
271 | void trans_src_nxc(src_data_t *tr_src, const src_data_t *src_base, |
272 | int spatial_start, dim_t spatial_start_offset, int icb_start, |
273 | dim_t chb_stride, int my_work) const; |
274 | void trans_dst_nxc(diff_dst_data_t *tr_diff_dst, |
275 | const diff_dst_data_t *diff_dst_base, int spatial_start, |
276 | dim_t spatial_start_offset, int ocb_start, dim_t chb_stride, |
277 | int my_work) const; |
278 | |
279 | int nthr_ = 0, nthr_mb_ = 0, nthr_g_ = 0, nthr_oc_b_ = 0, nthr_ic_b_ = 0; |
280 | |
281 | std::unique_ptr<jit_avx512_core_amx_bwd_weights_kernel_t> kernel_; |
282 | |
283 | std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_; |
284 | |
285 | std::unique_ptr<jit_diff_wei_trans_to_vnni_t> diff_wei_trans_kernel_; |
286 | std::unique_ptr<jit_trans_src_t> trans_kernel_; |
287 | std::unique_ptr<jit_trans_dst_t> trans_dst_kernel_; |
288 | |
289 | inline dim_t wei_offset_int(int g, int oc_b, int ic_b, int kX) const { |
290 | const auto &jcp = kernel_->jcp; |
291 | const dim_t = jcp.kw * jcp.ic_block * jcp.oc_block; |
292 | dim_t = (jcp.ndims == 5) ? kX * jcp.kh * const_extra_offset |
293 | : kX * const_extra_offset; |
294 | return (dim_t)((g * jcp.nb_oc + oc_b) * jcp.nb_ic + ic_b) * jcp.kd |
295 | * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block |
296 | + extra_offset; |
297 | } |
298 | inline dim_t wei_offset_ext(int g, int oc_b, int ic_b, int kX) const { |
299 | const auto &jcp = kernel_->jcp; |
300 | const int nb_ic = utils::div_up(jcp.ic, 2 * jcp.ic_block); |
301 | const dim_t |
302 | = jcp.kw * jcp.ic_block * jcp.oc_block * 2; |
303 | dim_t = (jcp.ndims == 5) ? kX * jcp.kh * const_extra_offset |
304 | : kX * const_extra_offset; |
305 | return (dim_t)((g * jcp.nb_oc + oc_b) * nb_ic + ic_b) * jcp.kd * jcp.kh |
306 | * jcp.kw * jcp.ic_block * jcp.oc_block * 2 |
307 | + extra_offset; |
308 | } |
309 | }; |
310 | |
311 | } // namespace x64 |
312 | } // namespace cpu |
313 | } // namespace impl |
314 | } // namespace dnnl |
315 | |
316 | #endif |
317 | |
318 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
319 | |