1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_GEMM_BF16_CONVOLUTION_HPP |
18 | #define CPU_X64_GEMM_BF16_CONVOLUTION_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/memory_tracking.hpp" |
22 | #include "common/primitive.hpp" |
23 | |
24 | #include "cpu/cpu_convolution_pd.hpp" |
25 | #include "cpu/cpu_engine.hpp" |
26 | #include "cpu/gemm/gemm.hpp" |
27 | #include "cpu/gemm_convolution_utils.hpp" |
28 | #include "cpu/x64/cpu_reducer.hpp" |
29 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
30 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
// GEMM-based forward convolution for bf16 source/weights. The destination
// data type is a template parameter (f32 or bf16); accumulation is done in
// f32 (see expect_data_types() in pd_t::init()). An optional JIT
// post-processing kernel (pp_ker_t) applies bias, post-ops and the final
// conversion to the destination type.
template <data_type_t dst_data_type>
struct gemm_bf16_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_fwd_t,
                USE_GLOBAL_SCRATCHPAD);

        // Applicability checks: avx512_core ISA, direct algorithm, bf16
        // src/weights with f32 accumulation, bf16/f32 bias, no zero-dim
        // memory, and only post-op attributes. On success fills jcp_ and
        // registers the scratchpad requirements via init_conf().
        status_t init(engine_t *engine) {
            bool ok = true && is_fwd() && mayiuse(avx512_core)
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::bf16, data_type::bf16,
                            data_type::undef, dst_data_type, data_type::f32)
                    && IMPLICATION(with_bias(),
                            utils::one_of(desc()->bias_desc.data_type,
                                    data_type::bf16, data_type::f32))
                    && !has_zero_dim_memory()
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops,
                            dst_data_type);
            {
                using namespace x64::injector;
                // Supported post-ops are binary, eltwise and sum; a sum
                // post-op must come first in the chain, with scale == 1 and
                // zero point == 0.
                static constexpr bool sum_at_pos_0_only = true;
                static constexpr bool sum_requires_scale_one = true;
                static constexpr bool sum_requires_zp_zero = true;
                const auto dst_md = memory_desc_wrapper(dst_md_);
                ok &= post_ops_ok({avx512_core, {binary, eltwise, sum},
                        attr()->post_ops_, &dst_md, sum_at_pos_0_only,
                        sum_requires_scale_one, sum_requires_zp_zero});
            }
            if (!ok) return status::unimplemented;

            auto scratchpad = scratchpad_registry().registrar();
            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
                    *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_,
                    dnnl_get_max_threads());
        }

        // True when a separate post-processing pass is needed after the
        // GEMM: bf16 destination (needs down-conversion), bias, or any
        // post-op configuration other than a single sum into an f32
        // destination — that one case is handled by the GEMM beta instead
        // (see the primitive's init() below).
        bool is_postprocess_required() const {
            bool post_ops_sum_only_for_dst_f32 = true
                    && dst_data_type == data_type::f32
                    && attr()->post_ops_.len() == 1
                    && attr()->post_ops_.contain(primitive_kind::sum, 0);
            bool is_pp_for_post_ops_required = true
                    && attr()->post_ops_.len() > 0
                    && !post_ops_sum_only_for_dst_f32;
            return dst_data_type == data_type::bf16 || with_bias()
                    || is_pp_for_post_ops_required;
        }

        conv_gemm_conf_t jcp_; // GEMM convolution configuration
    };

    gemm_bf16_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd), pp_ker_(nullptr) {}

    typedef typename prec_traits<dst_data_type>::type dst_data_t;
    typedef typename prec_traits<data_type::f32>::type acc_data_t;
    typedef typename prec_traits<data_type::bf16>::type src_data_t;
    typedef typename prec_traits<data_type::bf16>::type wei_data_t;

    // Selects the GEMM beta (1.0 when a sum post-op accumulates directly
    // into an f32 destination, 0.0 otherwise) and creates the JIT
    // post-processing kernel when pd says one is required.
    status_t init(engine_t *engine) override {
        const auto &post_ops = pd()->attr()->post_ops_;
        const acc_data_t one = 1.0, zero = 0.0;
        beta_ = dst_data_type == data_type::f32
                && post_ops.find(primitive_kind::sum) >= 0
                ? one
                : zero;

        if (this->pd()->is_postprocess_required()) {
            CHECK(safe_ptr_assign(pp_ker_, new pp_ker_t(this->pd())));
            return pp_ker_->create_kernel();
        }
        return status::success;
    }

    // Dispatches on the memory layout chosen at pd-creation time.
    status_t execute(const exec_ctx_t &ctx) const override {
        const bool is_nspc = pd()->jcp_.is_nspc;
        return is_nspc ? execute_forward_nspc(ctx) : execute_forward_ncsp(ctx);
    }

private:
    // Layout-specific drivers (defined in the corresponding .cpp file).
    status_t execute_forward_ncsp(const exec_ctx_t &ctx) const;
    status_t execute_forward_nspc(const exec_ctx_t &ctx) const;
    // Per-thread worker for the nspc path.
    status_t execute_forward_thr_nspc(const int ithr, const int nthr,
            const src_data_t *src_base, const wei_data_t *wei_base,
            const float *bia_base, dst_data_t *dst_base,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    // JIT post-processing kernel: reads the f32 accumulator, adds bias,
    // applies the post-ops and stores to the destination (converting to
    // bf16 when dst_data_type is bf16, using bf16_emulation_t where
    // needed).
    class pp_ker_t : public jit_generator {
    public:
        DECLARE_CPU_JIT_AUX_FUNCTIONS(gemm_bf16_convolution_fwd_t::pp_kernel);
        pp_ker_t(const pd_t *pd);

        // Contiguous variant: processes oc_work consecutive elements.
        void operator()(dst_data_t *dst, const acc_data_t *acc,
                const acc_data_t *bias, float sum_scale, size_t oc_work,
                const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
                const size_t g_oc_offset);
        // Strided variant: takes separate dst/acc strides, a spatial length
        // and an oc count.
        void operator()(dst_data_t *dst, const acc_data_t *acc,
                const acc_data_t *bias, float sum_scale, size_t dst_str,
                size_t acc_str, size_t sp_len, size_t oc,
                const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
                const size_t g_oc_offset);

    private:
        // Argument bundle passed to the generated code; layout must stay in
        // sync with the offsets used in generate().
        struct ker_args {
            dst_data_t *dst;
            const acc_data_t *acc;
            const acc_data_t *bias;
            float sum_scale;
            size_t dst_stride_in_bytes;
            size_t acc_stride_in_bytes;
            size_t spatial_length;
            size_t oc_work;

            size_t g_oc_offset;
            const void *post_ops_binary_rhs_arg_vec;
            const void *dst_orig;
        };

        // log2 of the default unroll factor (i.e. unroll by 4).
        enum { default_unroll_2_pow_ = 2 };

        // GPR assignment for the generated kernel. reg_param carries the
        // ker_args pointer (rdi — first integer argument in the SysV
        // x86-64 ABI).
        Xbyak::Reg64 reg_param = rdi;
        Xbyak::Reg64 reg_dst_base = rdx;
        Xbyak::Reg64 reg_acc_base = rax;
        Xbyak::Reg64 reg_dst = rsi;
        Xbyak::Reg64 reg_acc = rbp;
        Xbyak::Reg64 reg_bias = rbx;

        Xbyak::Reg64 reg_len = r8;
        Xbyak::Reg64 reg_tmp = rcx; // intentional for shifting purposes
        Xbyak::Reg64 reg_rem_mask = r9;
        Xbyak::Opmask kreg_rem_mask = k1; // tail mask for partial vectors
        Xbyak::Reg64 reg_oc_iter = r11;
        Xbyak::Reg64 reg_len_iter = r12;
        Xbyak::Reg64 reg_dst_str = r13;
        Xbyak::Reg64 reg_acc_str = r14;

        // Reserved for the eltwise injector.
        Xbyak::Reg64 reserved_eltwise_gpr = r10;
        Xbyak::Opmask reserved_eltwise_maskr = k2;

        Xbyak::Zmm vreg_sum_scale, vreg_bias;

        // Scratch registers handed to bf16_emulation_t (bf16_emu_ below).
        Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(27);
        Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(28);
        Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(29);
        Xbyak::Reg64 bf16_emu_reserv_4 = r15;
        Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);
        Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(31);

        // Stack frame layout for the generated code.
        constexpr static int reg64_size = sizeof(int64_t);
        constexpr static int reg_binary_post_op_acc_off = 0;
        constexpr static int stack_space_needed = reg64_size;

        const conv_gemm_conf_t &jcp_;
        const bool do_sum_; // whether a sum post-op is applied in this kernel
        int max_data_reg_idx_, max_unroll_, compute_reg_step_;
        int data_reg_base_idx_;
        size_t vlen_; // vector length in elements
        cpu_isa_t isa_;
        std::unique_ptr<bf16_emulation_t> bf16_emu_;
        std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
                postops_injector_;

        void apply_postops(const bool apply_mask, const size_t out_offset,
                const int vmm_idx);
        void generate() override;
        // Index of the destination vector register for unroll iteration
        // `iter`.
        int vreg_dst_idx(int iter) {
            int idx = data_reg_base_idx_ + iter * compute_reg_step_ + 0;
            assert(idx <= max_data_reg_idx_);
            return idx;
        }
        // Index of the register holding the previous destination value
        // (used by the sum post-op) for unroll iteration `iter`.
        int vreg_prev_dst_idx(int iter) {
            int idx = data_reg_base_idx_ + iter * compute_reg_step_ + 1;
            assert(idx <= max_data_reg_idx_);
            return idx;
        }

        Xbyak::Zmm vreg_dst(int iter) {
            return Xbyak::Zmm(vreg_dst_idx(iter));
        };

        Xbyak::Ymm vreg_dst_ymm(int iter) {
            return Xbyak::Ymm(vreg_dst_idx(iter));
        };

        Xbyak::Zmm vreg_prev_dst(int iter) {
            return Xbyak::Zmm(vreg_prev_dst_idx(iter));
        };

        Xbyak::Ymm vreg_prev_dst_ymm(int iter) {
            return Xbyak::Ymm(vreg_prev_dst_idx(iter));
        };
    };

    // GEMM beta: 1.0 accumulates into an f32 destination (sum post-op
    // folded into the GEMM), 0.0 otherwise. Set in init().
    acc_data_t beta_;
    // Null when is_postprocess_required() is false.
    std::unique_ptr<pp_ker_t> pp_ker_;
};
240 | |
241 | template <data_type_t diff_src_data_type> |
242 | struct gemm_bf16_convolution_bwd_data_t : public primitive_t { |
243 | struct pd_t : public cpu_convolution_bwd_data_pd_t { |
244 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
245 | const convolution_fwd_pd_t *hint_fwd_pd) |
246 | : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} |
247 | |
248 | DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_bwd_data_t, |
249 | USE_GLOBAL_SCRATCHPAD); |
250 | |
251 | status_t init(engine_t *engine) { |
252 | bool ok = true && mayiuse(avx512_core) |
253 | && desc()->prop_kind == prop_kind::backward_data |
254 | && set_default_alg_kind(alg_kind::convolution_direct) |
255 | && expect_data_types(diff_src_data_type, data_type::bf16, |
256 | data_type::undef, data_type::bf16, data_type::f32) |
257 | && !has_zero_dim_memory() && attr()->has_default_values(); |
258 | if (!ok) return status::unimplemented; |
259 | |
260 | auto scratchpad = scratchpad_registry().registrar(); |
261 | return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, |
262 | *desc(), diff_src_md_, weights_md_, diff_dst_md_, bias_md_, |
263 | attr_, dnnl_get_max_threads()); |
264 | } |
265 | |
266 | conv_gemm_conf_t jcp_; |
267 | }; |
268 | |
269 | gemm_bf16_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} |
270 | |
271 | typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t; |
272 | typedef typename prec_traits<data_type::f32>::type acc_data_t; |
273 | typedef typename prec_traits<diff_src_data_type>::type diff_src_data_t; |
274 | typedef typename prec_traits<data_type::bf16>::type wei_data_t; |
275 | |
276 | status_t execute(const exec_ctx_t &ctx) const override { |
277 | const bool is_nspc = pd()->jcp_.is_nspc; |
278 | return is_nspc ? execute_backward_data_nspc(ctx) |
279 | : execute_backward_data_ncsp(ctx); |
280 | } |
281 | |
282 | private: |
283 | status_t execute_backward_data_ncsp(const exec_ctx_t &ctx) const; |
284 | status_t execute_backward_data_nspc(const exec_ctx_t &ctx) const; |
285 | status_t execute_backward_data_thr_nspc(const int ithr, const int nthr, |
286 | diff_src_data_t *diff_src_base, const wei_data_t *wei_base, |
287 | const diff_dst_data_t *diff_dst_base, |
288 | const memory_tracking::grantor_t &scratchpad) const; |
289 | |
290 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
291 | }; |
292 | |
293 | template <data_type_t diff_wei_data_type> |
294 | struct gemm_bf16_convolution_bwd_weights_t : public primitive_t { |
295 | struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
296 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
297 | const convolution_fwd_pd_t *hint_fwd_pd) |
298 | : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) |
299 | , jcp_() {} |
300 | |
301 | DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_bwd_weights_t, |
302 | USE_GLOBAL_SCRATCHPAD); |
303 | |
304 | status_t init(engine_t *engine) { |
305 | bool ok = true && mayiuse(avx512_core) |
306 | && desc()->prop_kind == prop_kind::backward_weights |
307 | && set_default_alg_kind(alg_kind::convolution_direct) |
308 | && expect_data_types(data_type::bf16, diff_wei_data_type, |
309 | data_type::undef, data_type::bf16, data_type::f32) |
310 | && IMPLICATION(with_bias(), |
311 | utils::one_of(desc()->diff_bias_desc.data_type, |
312 | data_type::bf16, data_type::f32)) |
313 | && !has_zero_dim_memory() && attr()->has_default_values(); |
314 | if (!ok) return status::unimplemented; |
315 | |
316 | auto scratchpad = scratchpad_registry().registrar(); |
317 | return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, |
318 | *desc(), src_md_, diff_weights_md_, diff_dst_md_, |
319 | diff_bias_md_, attr_, dnnl_get_max_threads()); |
320 | } |
321 | |
322 | conv_gemm_conf_t jcp_; |
323 | }; |
324 | |
325 | gemm_bf16_convolution_bwd_weights_t(const pd_t *apd) |
326 | : primitive_t(apd), acc_ker_(nullptr) {} |
327 | |
328 | typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t; |
329 | typedef typename prec_traits<data_type::f32>::type acc_data_t; |
330 | typedef typename prec_traits<data_type::bf16>::type src_data_t; |
331 | typedef typename prec_traits<diff_wei_data_type>::type diff_wei_data_t; |
332 | |
333 | status_t init(engine_t *engine) override { |
334 | CHECK(safe_ptr_assign( |
335 | acc_ker_, new cpu_accumulator_1d_t<data_type::f32>())); |
336 | return acc_ker_->create_kernel(); |
337 | } |
338 | |
339 | status_t execute(const exec_ctx_t &ctx) const override { |
340 | const bool is_nspc = pd()->jcp_.is_nspc; |
341 | return is_nspc ? execute_backward_weights_nspc(ctx) |
342 | : execute_backward_weights_ncsp(ctx); |
343 | } |
344 | |
345 | private: |
346 | void bf16_bwd_weights_reduction_par_ncsp(int ithr_mb, int nthr_mb, |
347 | const conv_gemm_conf_t &jcp, const acc_data_t *weights_reduce_base, |
348 | diff_wei_data_t *weights_base) const; |
349 | void bf16_bwd_weights_reduction_par_nspc(int ithr_mb, int nthr_mb, |
350 | size_t g_start, size_t g_end, const conv_gemm_conf_t &jcp, |
351 | const acc_data_t *weights_reduce_base, |
352 | diff_wei_data_t *weights_base) const; |
353 | |
354 | status_t execute_backward_weights_ncsp(const exec_ctx_t &ctx) const; |
355 | status_t execute_backward_weights_nspc(const exec_ctx_t &ctx) const; |
356 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
357 | |
358 | std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_; |
359 | }; |
360 | |
361 | } // namespace x64 |
362 | } // namespace cpu |
363 | } // namespace impl |
364 | } // namespace dnnl |
365 | |
366 | #endif |
367 | |