1 | /******************************************************************************* |
2 | * Copyright 2019-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_BF16_SUM_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_BF16_SUM_HPP |
19 | |
#include <memory>

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"

#include "cpu/cpu_sum_pd.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | struct jit_sum_conf_t { |
32 | int num_srcs; |
33 | cpu_isa_t isa; |
34 | int is_bf16_dst; |
35 | int typesize_in; |
36 | int typesize_out; |
37 | int loop_unroll; |
38 | int size_blocking; /* minimum recommended data blocking size as this |
39 | number of elements computes main unrolled loop |
40 | in jit kernel per iteration */ |
41 | }; |
42 | |
43 | struct jit_sum_call_s { |
44 | const void **srcs; |
45 | const void *dst; |
46 | const void *scales; |
47 | dim_t size; |
48 | }; |
49 | |
50 | struct jit_avx512_core_bf16_sum_kernel : public jit_generator { |
51 | jit_avx512_core_bf16_sum_kernel(jit_sum_conf_t ajsp) |
52 | : jit_generator(jit_name()), jsp(ajsp), bf16_emu_(nullptr) { |
53 | if (!mayiuse(avx512_core_bf16)) |
54 | bf16_emu_ = new bf16_emulation_t(this, bf16_emu_reserved_1, |
55 | bf16_emu_reserved_2, bf16_emu_reserved_3, bf16_emu_scratch, |
56 | bf16_emu_reserved_4, bf16_emu_reserved_5); |
57 | } |
58 | |
59 | ~jit_avx512_core_bf16_sum_kernel() { delete bf16_emu_; } |
60 | |
61 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_sum_kernel) |
62 | |
63 | static status_t init_conf(jit_sum_conf_t &jsp, const int num_srcs, |
64 | const memory_desc_t &dst_d); |
65 | |
66 | static constexpr int max_num_arrs = 8; |
67 | jit_sum_conf_t jsp; |
68 | |
69 | private: |
70 | using reg64_t = const Xbyak::Reg64; |
71 | using reg32_t = const Xbyak::Reg32; |
72 | using reg8_t = const Xbyak::Reg8; |
73 | using zmm_t = const Xbyak::Zmm; |
74 | using ymm_t = const Xbyak::Ymm; |
75 | using mask_t = const Xbyak::Opmask; |
76 | |
77 | enum { f32_simd_w = 16, bf16_simd_w = 32 }; |
78 | |
79 | reg64_t param = abi_param1; /* may be rcx, note that cl is required |
80 | for mask computation */ |
81 | |
82 | reg64_t reg_srcs = abi_not_param1; /* may be rcx, note that cl is required |
83 | for mask computation */ |
84 | reg64_t reg_idx_table = abi_not_param1; /* may be rcx, note that cl is |
85 | required for mask computation */ |
86 | reg64_t reg_mask = rsi; |
87 | reg32_t reg32_mask = esi; |
88 | |
89 | reg64_t reg_dst = rax; |
90 | reg64_t reg_scales = rbx; |
91 | reg64_t reg_sz = rdx; |
92 | |
93 | reg64_t reg_src[max_num_arrs] = {r8, r9, r10, r11, r12, r13, r14, r15}; |
94 | |
95 | static int max_vregs_available(bool bf16_isa) { |
96 | // one vector registers are reserved for vperm index and zero values |
97 | // additional 5 registers are reserved for bf16 emulation on non-cpx |
98 | return bf16_isa ? 31 : 26; |
99 | } |
100 | |
101 | int acc_vreg_idx(int i_unroll, int i_acc) { |
102 | // 2 accumulation registers per unroll iteration |
103 | int idx = 2 * i_unroll + i_acc; |
104 | assert(idx < max_vregs_available(isa_has_bf16(jsp.isa))); |
105 | return idx; |
106 | } |
107 | |
108 | int scale_vreg_idx(int i_acc_iter) { |
109 | int scale_idx_start = 2 * jsp.loop_unroll; // reserved for acc registers |
110 | int idx = scale_idx_start + i_acc_iter; |
111 | assert(idx < max_vregs_available(isa_has_bf16(jsp.isa))); |
112 | return idx; |
113 | } |
114 | |
115 | int src_vreg_idx(int i_unroll, int i_inp) { |
116 | // reserved for acc and scale registers |
117 | int inp_idx_start |
118 | = 2 * jsp.loop_unroll + utils::div_up(jsp.num_srcs, 2); |
119 | int idx = inp_idx_start + utils::rnd_up(jsp.num_srcs, 2) * i_unroll |
120 | + i_inp; |
121 | assert(idx < max_vregs_available(isa_has_bf16(jsp.isa))); |
122 | return idx; |
123 | } |
124 | |
125 | int tmp_vreg_idx(int i_unroll, int i_acc_iter) { |
126 | int num_acc_iters = utils::div_up(jsp.num_srcs, 2); |
127 | // reserved for acc, scale and src registers |
128 | int tmp_idx_start = utils::div_up(jsp.num_srcs, 2) |
129 | + (2 + utils::rnd_up(jsp.num_srcs, 2)) * jsp.loop_unroll; |
130 | int idx = tmp_idx_start + num_acc_iters * i_unroll + i_acc_iter; |
131 | assert(idx < max_vregs_available(isa_has_bf16(jsp.isa))); |
132 | return idx; |
133 | } |
134 | |
135 | static int num_vregs_required(int unroll, int num_srcs) { |
136 | int num_acc_iters = utils::div_up(num_srcs, 2); |
137 | // reserved for acc, scale and src registers |
138 | int num_regs = utils::div_up(num_srcs, 2) |
139 | + (2 + utils::rnd_up(num_srcs, 2)) * unroll; |
140 | // tmp registers |
141 | num_regs += num_acc_iters * unroll; |
142 | return num_regs; |
143 | } |
144 | |
145 | Xbyak::Zmm bf16_emu_reserved_1 = Xbyak::Zmm(26); |
146 | Xbyak::Zmm bf16_emu_reserved_2 = Xbyak::Zmm(27); |
147 | Xbyak::Zmm bf16_emu_reserved_3 = Xbyak::Zmm(28); |
148 | Xbyak::Zmm bf16_emu_reserved_4 = Xbyak::Zmm(29); |
149 | Xbyak::Zmm bf16_emu_reserved_5 = Xbyak::Zmm(30); |
150 | Xbyak::Reg64 bf16_emu_scratch = abi_not_param1; |
151 | |
152 | Xbyak::Zmm zmm_idx = Xbyak::Zmm(31); |
153 | |
154 | Xbyak::Label idx_table; |
155 | const Xbyak::Opmask k_mask = k1; |
156 | |
157 | void generate() override; |
158 | void loop_iteration(int current_unroll); |
159 | bf16_emulation_t *bf16_emu_; |
160 | }; |
161 | |
162 | template <data_type_t src_data_type, data_type_t dst_data_type> |
163 | struct jit_bf16_sum_t : public primitive_t { |
164 | struct pd_t : public cpu_sum_pd_t { |
165 | using cpu_sum_pd_t::cpu_sum_pd_t; |
166 | |
167 | DECLARE_SUM_PD_T(JIT_IMPL_NAME_HELPER("jit_bf16_" , jsp_.isa, "" ), |
168 | jit_bf16_sum_t); |
169 | |
170 | status_t init(engine_t *engine) { |
171 | bool ok = true && mayiuse(avx512_core) |
172 | && cpu_sum_pd_t::init(engine) == status::success |
173 | && src_mds_.size() |
174 | <= jit_avx512_core_bf16_sum_kernel::max_num_arrs; |
175 | if (!ok) return status::unimplemented; |
176 | |
177 | const memory_desc_wrapper o_d(&dst_md_); |
178 | ok = true && o_d.data_type() == dst_data_type && o_d.is_dense(true); |
179 | if (!ok) return status::unimplemented; |
180 | |
181 | const auto n = src_mds_.size(); |
182 | |
183 | if (n > jit_avx512_core_bf16_sum_kernel::max_num_arrs) |
184 | return status::unimplemented; |
185 | |
186 | for (size_t i = 0; i < n; ++i) { |
187 | const memory_desc_wrapper i_d(&src_mds_[i]); |
188 | ok = true && src_data_type == i_d.data_type() |
189 | && o_d.similar_to(i_d, true, false, 0) |
190 | && i_d.is_dense(true) |
191 | // is scales representable in bfloat16: scales will be down |
192 | // converted to bf16 in order to use bf16 vnni instruction |
193 | && scales_[i] == float(bfloat16_t(scales_[i])); |
194 | if (!ok) return status::unimplemented; |
195 | } |
196 | |
197 | return jit_avx512_core_bf16_sum_kernel::init_conf( |
198 | jsp_, src_mds_.size(), dst_md_); |
199 | } |
200 | jit_sum_conf_t jsp_; |
201 | }; |
202 | |
203 | jit_bf16_sum_t(const pd_t *apd) : primitive_t(apd) {} |
204 | |
205 | status_t init(engine_t *engine) override { |
206 | CHECK(safe_ptr_assign( |
207 | kernel_, new jit_avx512_core_bf16_sum_kernel(pd()->jsp_))); |
208 | return kernel_->create_kernel(); |
209 | } |
210 | |
211 | status_t execute(const exec_ctx_t &ctx) const override; |
212 | |
213 | typedef typename prec_traits<src_data_type>::type src_data_t; |
214 | typedef typename prec_traits<dst_data_type>::type dst_data_t; |
215 | typedef typename prec_traits<data_type::f32>::type acc_data_t; |
216 | |
217 | private: |
218 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
219 | std::unique_ptr<jit_avx512_core_bf16_sum_kernel> kernel_; |
220 | }; |
221 | |
222 | } // namespace x64 |
223 | } // namespace cpu |
224 | } // namespace impl |
225 | } // namespace dnnl |
226 | |
227 | #endif |
228 | |