/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_GEMM_BF16_CONVOLUTION_HPP
#define CPU_X64_GEMM_BF16_CONVOLUTION_HPP

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"

#include "cpu/cpu_convolution_pd.hpp"
#include "cpu/cpu_engine.hpp"
#include "cpu/gemm/gemm.hpp"
#include "cpu/gemm_convolution_utils.hpp"
#include "cpu/x64/cpu_reducer.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

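// GEMM-based bf16 convolution, forward propagation. The convolution is
// lowered to a matrix multiplication: an im2col transform gathers the bf16
// source into a column buffer, a gemm_bf16bf16f32 call multiplies it by the
// weights accumulating in f32, and a JIT post-processing kernel (pp_ker_t
// below) applies bias, sum and attribute post-ops and converts the result
// to the destination data type. Conceptually (ncsp case, names informal):
//
//   col = im2col(src);                        // bf16 scratchpad buffer
//   acc = gemm_bf16bf16f32(wei, col, beta_);  // f32 accumulation
//   dst = pp_ker_ ? pp_ker_(acc, bias, post_ops) : acc;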
template <data_type_t dst_data_type>
struct gemm_bf16_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_fwd_t,
                USE_GLOBAL_SCRATCHPAD);

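        // The implementation requires avx512_core, forward direct
        // convolution with bf16 source/weights and f32 accumulation, an
        // optional bf16/f32 bias, and attributes limited to post-ops
        // accepted by post_ops_ok() below.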
        status_t init(engine_t *engine) {
            bool ok = true && is_fwd() && mayiuse(avx512_core)
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::bf16, data_type::bf16,
                            data_type::undef, dst_data_type, data_type::f32)
                    && IMPLICATION(with_bias(),
                            utils::one_of(desc()->bias_desc.data_type,
                                    data_type::bf16, data_type::f32))
                    && !has_zero_dim_memory()
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops,
                            dst_data_type);
            {
                using namespace x64::injector;
                static constexpr bool sum_at_pos_0_only = true;
                static constexpr bool sum_requires_scale_one = true;
                static constexpr bool sum_requires_zp_zero = true;
                const auto dst_md = memory_desc_wrapper(dst_md_);
                ok &= post_ops_ok({avx512_core, {binary, eltwise, sum},
                        attr()->post_ops_, &dst_md, sum_at_pos_0_only,
                        sum_requires_scale_one, sum_requires_zp_zero});
            }
            if (!ok) return status::unimplemented;

            auto scratchpad = scratchpad_registry().registrar();
            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
                    *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_,
                    dnnl_get_max_threads());
        }

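        // Post-processing is required whenever the f32 GEMM output cannot be
        // used as the final destination directly: dst is bf16 and needs
        // down-conversion, a bias must be added, or there are post-ops other
        // than a single sum into an f32 dst (that one case is folded into
        // the GEMM call itself via beta = 1, see init() below).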
        bool is_postprocess_required() const {
            bool post_ops_sum_only_for_dst_f32 = true
                    && dst_data_type == data_type::f32
                    && attr()->post_ops_.len() == 1
                    && attr()->post_ops_.contain(primitive_kind::sum, 0);
            bool is_pp_for_post_ops_required = true
                    && attr()->post_ops_.len() > 0
                    && !post_ops_sum_only_for_dst_f32;
            return dst_data_type == data_type::bf16 || with_bias()
                    || is_pp_for_post_ops_required;
        }

        conv_gemm_conf_t jcp_;
    };

    gemm_bf16_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd), pp_ker_(nullptr) {}

    typedef typename prec_traits<dst_data_type>::type dst_data_t;
    typedef typename prec_traits<data_type::f32>::type acc_data_t;
    typedef typename prec_traits<data_type::bf16>::type src_data_t;
    typedef typename prec_traits<data_type::bf16>::type wei_data_t;

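    // A lone sum post-op into an f32 destination is handled by the GEMM
    // itself: beta_ = 1 makes it accumulate into dst, so no separate
    // post-processing pass is needed in that case.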
    status_t init(engine_t *engine) override {
        const auto &post_ops = pd()->attr()->post_ops_;
        const acc_data_t one = 1.0, zero = 0.0;
        beta_ = dst_data_type == data_type::f32
                        && post_ops.find(primitive_kind::sum) >= 0
                ? one
                : zero;

        if (this->pd()->is_postprocess_required()) {
            CHECK(safe_ptr_assign(pp_ker_, new pp_ker_t(this->pd())));
            return pp_ker_->create_kernel();
        }
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        const bool is_nspc = pd()->jcp_.is_nspc;
        return is_nspc ? execute_forward_nspc(ctx) : execute_forward_ncsp(ctx);
    }

private:
    status_t execute_forward_ncsp(const exec_ctx_t &ctx) const;
    status_t execute_forward_nspc(const exec_ctx_t &ctx) const;
    status_t execute_forward_thr_nspc(const int ithr, const int nthr,
            const src_data_t *src_base, const wei_data_t *wei_base,
            const float *bia_base, dst_data_t *dst_base,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

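    // JIT post-processing kernel: loads the f32 accumulator, adds bias,
    // applies the sum (scaled by sum_scale) and remaining post-ops, and
    // stores the result converted to dst_data_t. One entry point handles a
    // contiguous run of oc_work values; the other walks an (sp_len x oc)
    // block through separate dst/acc strides.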
    class pp_ker_t : public jit_generator {
    public:
        DECLARE_CPU_JIT_AUX_FUNCTIONS(gemm_bf16_convolution_fwd_t::pp_kernel);
        pp_ker_t(const pd_t *pd);

        void operator()(dst_data_t *dst, const acc_data_t *acc,
                const acc_data_t *bias, float sum_scale, size_t oc_work,
                const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
                const size_t g_oc_offset);
        void operator()(dst_data_t *dst, const acc_data_t *acc,
                const acc_data_t *bias, float sum_scale, size_t dst_str,
                size_t acc_str, size_t sp_len, size_t oc,
                const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
                const size_t g_oc_offset);

    private:
        struct ker_args {
            dst_data_t *dst;
            const acc_data_t *acc;
            const acc_data_t *bias;
            float sum_scale;
            size_t dst_stride_in_bytes;
            size_t acc_stride_in_bytes;
            size_t spatial_length;
            size_t oc_work;

            size_t g_oc_offset;
            const void *post_ops_binary_rhs_arg_vec;
            const void *dst_orig;
        };

        enum { default_unroll_2_pow_ = 2 };

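        // Fixed GPR assignment for the generated code; reg_param receives
        // the pointer to ker_args on entry (rdi, the first integer argument
        // in the System V x86-64 calling convention).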
        Xbyak::Reg64 reg_param = rdi;
        Xbyak::Reg64 reg_dst_base = rdx;
        Xbyak::Reg64 reg_acc_base = rax;
        Xbyak::Reg64 reg_dst = rsi;
        Xbyak::Reg64 reg_acc = rbp;
        Xbyak::Reg64 reg_bias = rbx;

        Xbyak::Reg64 reg_len = r8;
        Xbyak::Reg64 reg_tmp = rcx; // intentional: variable shifts take cl
        Xbyak::Reg64 reg_rem_mask = r9;
        Xbyak::Opmask kreg_rem_mask = k1;
        Xbyak::Reg64 reg_oc_iter = r11;
        Xbyak::Reg64 reg_len_iter = r12;
        Xbyak::Reg64 reg_dst_str = r13;
        Xbyak::Reg64 reg_acc_str = r14;

        Xbyak::Reg64 reserved_eltwise_gpr = r10;
        Xbyak::Opmask reserved_eltwise_maskr = k2;

        Xbyak::Zmm vreg_sum_scale, vreg_bias;

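        // Registers reserved for bf16_emulation_t, which emulates the bf16
        // down-conversion instructions (e.g. vcvtneps2bf16) on avx512_core
        // machines without native avx512_core_bf16 support.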
        Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(27);
        Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(28);
        Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(29);
        Xbyak::Reg64 bf16_emu_reserv_4 = r15;
        Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(30);
        Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(31);

        constexpr static int reg64_size = sizeof(int64_t);
        constexpr static int reg_binary_post_op_acc_off = 0;
        constexpr static int stack_space_needed = reg64_size;

        const conv_gemm_conf_t &jcp_;
        const bool do_sum_;
        int max_data_reg_idx_, max_unroll_, compute_reg_step_;
        int data_reg_base_idx_;
        size_t vlen_;
        cpu_isa_t isa_;
        std::unique_ptr<bf16_emulation_t> bf16_emu_;
        std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
                postops_injector_;

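        // Vector register indexing scheme: each unroll iteration owns
        // compute_reg_step_ consecutive registers starting at
        // data_reg_base_idx_; slot 0 holds the value being computed and
        // slot 1 the previously stored dst value loaded for the sum post-op.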
        void apply_postops(const bool apply_mask, const size_t out_offset,
                const int vmm_idx);
        void generate() override;
        int vreg_dst_idx(int iter) {
            int idx = data_reg_base_idx_ + iter * compute_reg_step_ + 0;
            assert(idx <= max_data_reg_idx_);
            return idx;
        }
        int vreg_prev_dst_idx(int iter) {
            int idx = data_reg_base_idx_ + iter * compute_reg_step_ + 1;
            assert(idx <= max_data_reg_idx_);
            return idx;
        }

        Xbyak::Zmm vreg_dst(int iter) {
            return Xbyak::Zmm(vreg_dst_idx(iter));
        };

        Xbyak::Ymm vreg_dst_ymm(int iter) {
            return Xbyak::Ymm(vreg_dst_idx(iter));
        };

        Xbyak::Zmm vreg_prev_dst(int iter) {
            return Xbyak::Zmm(vreg_prev_dst_idx(iter));
        };

        Xbyak::Ymm vreg_prev_dst_ymm(int iter) {
            return Xbyak::Ymm(vreg_prev_dst_idx(iter));
        };
    };

    acc_data_t beta_;
    std::unique_ptr<pp_ker_t> pp_ker_;
};

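// GEMM-based bf16 convolution, backward by data: diff_dst is multiplied by
// the weights with a bf16 GEMM accumulating in f32, and the result is
// scattered back into diff_src via a col2im transform, down-converting when
// diff_src_data_type is bf16.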
template <data_type_t diff_src_data_type>
struct gemm_bf16_convolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_bwd_data_t,
                USE_GLOBAL_SCRATCHPAD);

        status_t init(engine_t *engine) {
            bool ok = true && mayiuse(avx512_core)
                    && desc()->prop_kind == prop_kind::backward_data
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(diff_src_data_type, data_type::bf16,
                            data_type::undef, data_type::bf16, data_type::f32)
                    && !has_zero_dim_memory() && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            auto scratchpad = scratchpad_registry().registrar();
            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
                    *desc(), diff_src_md_, weights_md_, diff_dst_md_, bias_md_,
                    attr_, dnnl_get_max_threads());
        }

        conv_gemm_conf_t jcp_;
    };

    gemm_bf16_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
    typedef typename prec_traits<data_type::f32>::type acc_data_t;
    typedef typename prec_traits<diff_src_data_type>::type diff_src_data_t;
    typedef typename prec_traits<data_type::bf16>::type wei_data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        const bool is_nspc = pd()->jcp_.is_nspc;
        return is_nspc ? execute_backward_data_nspc(ctx)
                       : execute_backward_data_ncsp(ctx);
    }

private:
    status_t execute_backward_data_ncsp(const exec_ctx_t &ctx) const;
    status_t execute_backward_data_nspc(const exec_ctx_t &ctx) const;
    status_t execute_backward_data_thr_nspc(const int ithr, const int nthr,
            diff_src_data_t *diff_src_base, const wei_data_t *wei_base,
            const diff_dst_data_t *diff_dst_base,
            const memory_tracking::grantor_t &scratchpad) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

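// GEMM-based bf16 convolution, backward by weights: each thread computes an
// f32 partial diff_weights with bf16 GEMMs over its share of the minibatch;
// the partial results are then reduced across threads (acc_ker_) and
// down-converted when diff_wei_data_type is bf16.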
template <data_type_t diff_wei_data_type>
struct gemm_bf16_convolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_() {}

        DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_convolution_bwd_weights_t,
                USE_GLOBAL_SCRATCHPAD);

        status_t init(engine_t *engine) {
            bool ok = true && mayiuse(avx512_core)
                    && desc()->prop_kind == prop_kind::backward_weights
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::bf16, diff_wei_data_type,
                            data_type::undef, data_type::bf16, data_type::f32)
                    && IMPLICATION(with_bias(),
                            utils::one_of(desc()->diff_bias_desc.data_type,
                                    data_type::bf16, data_type::f32))
                    && !has_zero_dim_memory() && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            auto scratchpad = scratchpad_registry().registrar();
            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
                    *desc(), src_md_, diff_weights_md_, diff_dst_md_,
                    diff_bias_md_, attr_, dnnl_get_max_threads());
        }

        conv_gemm_conf_t jcp_;
    };

    gemm_bf16_convolution_bwd_weights_t(const pd_t *apd)
        : primitive_t(apd), acc_ker_(nullptr) {}

    typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
    typedef typename prec_traits<data_type::f32>::type acc_data_t;
    typedef typename prec_traits<data_type::bf16>::type src_data_t;
    typedef typename prec_traits<diff_wei_data_type>::type diff_wei_data_t;

    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(
                acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
        return acc_ker_->create_kernel();
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        const bool is_nspc = pd()->jcp_.is_nspc;
        return is_nspc ? execute_backward_weights_nspc(ctx)
                       : execute_backward_weights_ncsp(ctx);
    }

private:
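    // Reduce per-thread f32 partial weight gradients from the scratchpad
    // into the final diff_weights buffer, converting to bf16 on the fly
    // when diff_wei_data_type is bf16.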
    void bf16_bwd_weights_reduction_par_ncsp(int ithr_mb, int nthr_mb,
            const conv_gemm_conf_t &jcp, const acc_data_t *weights_reduce_base,
            diff_wei_data_t *weights_base) const;
    void bf16_bwd_weights_reduction_par_nspc(int ithr_mb, int nthr_mb,
            size_t g_start, size_t g_end, const conv_gemm_conf_t &jcp,
            const acc_data_t *weights_reduce_base,
            diff_wei_data_t *weights_base) const;

    status_t execute_backward_weights_ncsp(const exec_ctx_t &ctx) const;
    status_t execute_backward_weights_nspc(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif