1/*******************************************************************************
2* Copyright 2019-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_BF16_SUM_HPP
18#define CPU_X64_JIT_AVX512_CORE_BF16_SUM_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22
23#include "cpu/cpu_sum_pd.hpp"
24#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29namespace x64 {
30
// Configuration of jit_avx512_core_bf16_sum_kernel; filled by
// jit_avx512_core_bf16_sum_kernel::init_conf(). Kept as a plain aggregate.
struct jit_sum_conf_t {
    int num_srcs; // number of source arrays being summed
    cpu_isa_t isa; // target ISA; isa_has_bf16(isa) sets the vector-register budget
    int is_bf16_dst; // presumably non-zero when dst is bf16 -- confirm in init_conf
    int typesize_in; // presumably byte size of one source element
    int typesize_out; // presumably byte size of one destination element
    int loop_unroll; // unroll factor of the kernel's main loop
    int size_blocking; /* minimum recommended data blocking size: the number
                          of elements processed by one iteration of the main
                          unrolled loop in the jit kernel */
};
42
// Runtime arguments passed to the generated sum kernel on each call.
// NOTE(review): the JIT code presumably reads these fields by offset, so the
// field order and types are part of the kernel ABI -- do not reorder.
struct jit_sum_call_s {
    const void **srcs; // presumably jsp.num_srcs source pointers
    const void *dst; // destination buffer (presumably written by the kernel)
    const void *scales; // per-source scale factors
    dim_t size; // presumably number of elements to process in this call
};
49
50struct jit_avx512_core_bf16_sum_kernel : public jit_generator {
51 jit_avx512_core_bf16_sum_kernel(jit_sum_conf_t ajsp)
52 : jit_generator(jit_name()), jsp(ajsp), bf16_emu_(nullptr) {
53 if (!mayiuse(avx512_core_bf16))
54 bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserved_1,
55 bf16_emu_reserved_2, bf16_emu_reserved_3, bf16_emu_scratch,
56 bf16_emu_reserved_4, bf16_emu_reserved_5);
57 }
58
59 ~jit_avx512_core_bf16_sum_kernel() { delete bf16_emu_; }
60
61 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_sum_kernel)
62
63 static status_t init_conf(jit_sum_conf_t &jsp, const int num_srcs,
64 const memory_desc_t &dst_d);
65
66 static constexpr int max_num_arrs = 8;
67 jit_sum_conf_t jsp;
68
69private:
70 using reg64_t = const Xbyak::Reg64;
71 using reg32_t = const Xbyak::Reg32;
72 using reg8_t = const Xbyak::Reg8;
73 using zmm_t = const Xbyak::Zmm;
74 using ymm_t = const Xbyak::Ymm;
75 using mask_t = const Xbyak::Opmask;
76
77 enum { f32_simd_w = 16, bf16_simd_w = 32 };
78
79 reg64_t param = abi_param1; /* may be rcx, note that cl is required
80 for mask computation */
81
82 reg64_t reg_srcs = abi_not_param1; /* may be rcx, note that cl is required
83 for mask computation */
84 reg64_t reg_idx_table = abi_not_param1; /* may be rcx, note that cl is
85 required for mask computation */
86 reg64_t reg_mask = rsi;
87 reg32_t reg32_mask = esi;
88
89 reg64_t reg_dst = rax;
90 reg64_t reg_scales = rbx;
91 reg64_t reg_sz = rdx;
92
93 reg64_t reg_src[max_num_arrs] = {r8, r9, r10, r11, r12, r13, r14, r15};
94
95 static int max_vregs_available(bool bf16_isa) {
96 // one vector registers are reserved for vperm index and zero values
97 // additional 5 registers are reserved for bf16 emulation on non-cpx
98 return bf16_isa ? 31 : 26;
99 }
100
101 int acc_vreg_idx(int i_unroll, int i_acc) {
102 // 2 accumulation registers per unroll iteration
103 int idx = 2 * i_unroll + i_acc;
104 assert(idx < max_vregs_available(isa_has_bf16(jsp.isa)));
105 return idx;
106 }
107
108 int scale_vreg_idx(int i_acc_iter) {
109 int scale_idx_start = 2 * jsp.loop_unroll; // reserved for acc registers
110 int idx = scale_idx_start + i_acc_iter;
111 assert(idx < max_vregs_available(isa_has_bf16(jsp.isa)));
112 return idx;
113 }
114
115 int src_vreg_idx(int i_unroll, int i_inp) {
116 // reserved for acc and scale registers
117 int inp_idx_start
118 = 2 * jsp.loop_unroll + utils::div_up(jsp.num_srcs, 2);
119 int idx = inp_idx_start + utils::rnd_up(jsp.num_srcs, 2) * i_unroll
120 + i_inp;
121 assert(idx < max_vregs_available(isa_has_bf16(jsp.isa)));
122 return idx;
123 }
124
125 int tmp_vreg_idx(int i_unroll, int i_acc_iter) {
126 int num_acc_iters = utils::div_up(jsp.num_srcs, 2);
127 // reserved for acc, scale and src registers
128 int tmp_idx_start = utils::div_up(jsp.num_srcs, 2)
129 + (2 + utils::rnd_up(jsp.num_srcs, 2)) * jsp.loop_unroll;
130 int idx = tmp_idx_start + num_acc_iters * i_unroll + i_acc_iter;
131 assert(idx < max_vregs_available(isa_has_bf16(jsp.isa)));
132 return idx;
133 }
134
135 static int num_vregs_required(int unroll, int num_srcs) {
136 int num_acc_iters = utils::div_up(num_srcs, 2);
137 // reserved for acc, scale and src registers
138 int num_regs = utils::div_up(num_srcs, 2)
139 + (2 + utils::rnd_up(num_srcs, 2)) * unroll;
140 // tmp registers
141 num_regs += num_acc_iters * unroll;
142 return num_regs;
143 }
144
145 Xbyak::Zmm bf16_emu_reserved_1 = Xbyak::Zmm(26);
146 Xbyak::Zmm bf16_emu_reserved_2 = Xbyak::Zmm(27);
147 Xbyak::Zmm bf16_emu_reserved_3 = Xbyak::Zmm(28);
148 Xbyak::Zmm bf16_emu_reserved_4 = Xbyak::Zmm(29);
149 Xbyak::Zmm bf16_emu_reserved_5 = Xbyak::Zmm(30);
150 Xbyak::Reg64 bf16_emu_scratch = abi_not_param1;
151
152 Xbyak::Zmm zmm_idx = Xbyak::Zmm(31);
153
154 Xbyak::Label idx_table;
155 const Xbyak::Opmask k_mask = k1;
156
157 void generate() override;
158 void loop_iteration(int current_unroll);
159 bf16_emulation_t *bf16_emu_;
160};
161
162template <data_type_t src_data_type, data_type_t dst_data_type>
163struct jit_bf16_sum_t : public primitive_t {
164 struct pd_t : public cpu_sum_pd_t {
165 using cpu_sum_pd_t::cpu_sum_pd_t;
166
167 DECLARE_SUM_PD_T(JIT_IMPL_NAME_HELPER("jit_bf16_", jsp_.isa, ""),
168 jit_bf16_sum_t);
169
170 status_t init(engine_t *engine) {
171 bool ok = true && mayiuse(avx512_core)
172 && cpu_sum_pd_t::init(engine) == status::success
173 && src_mds_.size()
174 <= jit_avx512_core_bf16_sum_kernel::max_num_arrs;
175 if (!ok) return status::unimplemented;
176
177 const memory_desc_wrapper o_d(&dst_md_);
178 ok = true && o_d.data_type() == dst_data_type && o_d.is_dense(true);
179 if (!ok) return status::unimplemented;
180
181 const auto n = src_mds_.size();
182
183 if (n > jit_avx512_core_bf16_sum_kernel::max_num_arrs)
184 return status::unimplemented;
185
186 for (size_t i = 0; i < n; ++i) {
187 const memory_desc_wrapper i_d(&src_mds_[i]);
188 ok = true && src_data_type == i_d.data_type()
189 && o_d.similar_to(i_d, true, false, 0)
190 && i_d.is_dense(true)
191 // is scales representable in bfloat16: scales will be down
192 // converted to bf16 in order to use bf16 vnni instruction
193 && scales_[i] == float(bfloat16_t(scales_[i]));
194 if (!ok) return status::unimplemented;
195 }
196
197 return jit_avx512_core_bf16_sum_kernel::init_conf(
198 jsp_, src_mds_.size(), dst_md_);
199 }
200 jit_sum_conf_t jsp_;
201 };
202
203 jit_bf16_sum_t(const pd_t *apd) : primitive_t(apd) {}
204
205 status_t init(engine_t *engine) override {
206 CHECK(safe_ptr_assign(
207 kernel_, new jit_avx512_core_bf16_sum_kernel(pd()->jsp_)));
208 return kernel_->create_kernel();
209 }
210
211 status_t execute(const exec_ctx_t &ctx) const override;
212
213 typedef typename prec_traits<src_data_type>::type src_data_t;
214 typedef typename prec_traits<dst_data_type>::type dst_data_t;
215 typedef typename prec_traits<data_type::f32>::type acc_data_t;
216
217private:
218 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
219 std::unique_ptr<jit_avx512_core_bf16_sum_kernel> kernel_;
220};
221
222} // namespace x64
223} // namespace cpu
224} // namespace impl
225} // namespace dnnl
226
227#endif
228