1 | /******************************************************************************* |
2 | * Copyright 2019-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <float.h> |
17 | |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/utils.hpp" |
20 | |
21 | #include "cpu/x64/jit_avx512_core_bf16_sum.hpp" |
22 | |
23 | #define GET_OFF(field) offsetof(jit_sum_call_s, field) |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | using namespace dnnl::impl::prop_kind; |
31 | using namespace dnnl::impl::utils; |
32 | |
33 | using namespace Xbyak; |
34 | void jit_avx512_core_bf16_sum_kernel::loop_iteration(int current_unroll) { |
35 | Label loop_label, exit_label; |
36 | const int num_compute_elements = 2 * f32_simd_w * current_unroll; |
37 | dim_t src_shift = 2 * f32_simd_w * jsp.typesize_in; |
38 | dim_t dst_shift = f32_simd_w * jsp.typesize_out; |
39 | |
40 | L(loop_label); |
41 | cmp(reg_sz, num_compute_elements); |
42 | jl(exit_label, T_NEAR); |
43 | for (int u_idx = 0; u_idx < current_unroll; u_idx++) { |
44 | zmm_t vacc0 = Zmm(acc_vreg_idx(u_idx, 0)); |
45 | zmm_t vacc1 = Zmm(acc_vreg_idx(u_idx, 1)); |
46 | vpxord(vacc0, vacc0, vacc0); |
47 | vpxord(vacc1, vacc1, vacc1); |
48 | |
49 | int num_acc_iters = utils::div_up(jsp.num_srcs, 2); |
50 | for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) { |
51 | int isrc0 = 2 * acc_iter; |
52 | int isrc1 = 2 * acc_iter + 1; |
53 | zmm_t vscale = Zmm(scale_vreg_idx(acc_iter)); |
54 | zmm_t vsrc0 = Zmm(src_vreg_idx(u_idx, isrc0)); |
55 | zmm_t vsrc1 = Zmm(src_vreg_idx(u_idx, isrc1)); |
56 | zmm_t vtmp = Zmm(tmp_vreg_idx(u_idx, acc_iter)); |
57 | vmovups(vsrc0, zword[reg_src[isrc0] + u_idx * src_shift]); |
58 | if (num_acc_iters * 2 > jsp.num_srcs |
59 | && acc_iter == num_acc_iters - 1) |
60 | vpxord(vtmp, vtmp, vtmp); /* imitate additional zero input |
61 | if number of srcs is odd */ |
62 | else |
63 | vmovups(vtmp, zword[reg_src[isrc1] + u_idx * src_shift]); |
64 | vshuff64x2(vsrc1, vsrc0, vtmp, 0xEE); |
65 | vpermw(vsrc1, zmm_idx, vsrc1); |
66 | vshuff64x2(vsrc0, vsrc0, vtmp, 0x44); |
67 | vpermw(vsrc0, zmm_idx, vsrc0); |
68 | |
69 | if (!isa_has_bf16(jsp.isa)) { |
70 | bf16_emu_->vdpbf16ps(vacc0, vsrc0, vscale); |
71 | bf16_emu_->vdpbf16ps(vacc1, vsrc1, vscale); |
72 | } else { |
73 | vdpbf16ps(vacc0, vsrc0, vscale); |
74 | vdpbf16ps(vacc1, vsrc1, vscale); |
75 | } |
76 | } |
77 | |
78 | if (!jsp.is_bf16_dst) { |
79 | vmovups(zword[reg_dst + 2 * u_idx * dst_shift], vacc0); |
80 | vmovups(zword[reg_dst + (2 * u_idx + 1) * dst_shift], vacc1); |
81 | } else { |
82 | if (isa_has_bf16(jsp.isa)) { |
83 | zmm_t zmm_str = Zmm(tmp_vreg_idx(u_idx, 0)); |
84 | vcvtne2ps2bf16(zmm_str, vacc1, vacc0); |
85 | vmovups(zword[reg_dst + 2 * u_idx * dst_shift], zmm_str); |
86 | } else { |
87 | auto ymm_str = Ymm(tmp_vreg_idx(u_idx, 0)); |
88 | bf16_emu_->vcvtneps2bf16(ymm_str, vacc0); |
89 | vmovups(yword[reg_dst + 2 * u_idx * dst_shift], ymm_str); |
90 | bf16_emu_->vcvtneps2bf16(ymm_str, vacc1); |
91 | vmovups(yword[reg_dst + (2 * u_idx + 1) * dst_shift], ymm_str); |
92 | } |
93 | } |
94 | } |
95 | sub(reg_sz, num_compute_elements); |
96 | for (int s = 0; s < jsp.num_srcs; s++) |
97 | add(reg_src[s], current_unroll * src_shift); |
98 | add(reg_dst, 2 * current_unroll * dst_shift); |
99 | jge(loop_label, T_NEAR); |
100 | |
101 | L(exit_label); |
102 | } |
103 | |
104 | void jit_avx512_core_bf16_sum_kernel::generate() { |
105 | preamble(); |
106 | |
107 | mov(reg_dst, ptr[param + GET_OFF(dst)]); |
108 | mov(reg_srcs, ptr[param + GET_OFF(srcs)]); |
109 | |
110 | for (int s = 0; s < jsp.num_srcs; s++) |
111 | mov(reg_src[s], ptr[reg_srcs + sizeof(void *) * s]); |
112 | |
113 | mov(reg_scales, ptr[param + GET_OFF(scales)]); |
114 | mov(reg_sz, ptr[param + GET_OFF(size)]); |
115 | |
116 | Label tail_label, exit_label, mask_label; |
117 | |
118 | mov(reg_idx_table, idx_table); |
119 | vmovups(zmm_idx, ptr[reg_idx_table]); |
120 | |
121 | int num_acc_iters = utils::div_up(jsp.num_srcs, 2); |
122 | for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) { |
123 | zmm_t vscale = Zmm(scale_vreg_idx(acc_iter)); |
124 | vpbroadcastd(vscale, ptr[reg_scales + 2 * acc_iter * jsp.typesize_in]); |
125 | } |
126 | |
127 | if (!isa_has_bf16(jsp.isa)) bf16_emu_->init_vcvtneps2bf16(); |
128 | if (jsp.loop_unroll > 1) loop_iteration(jsp.loop_unroll); |
129 | |
130 | loop_iteration(1); |
131 | |
132 | // tail processing |
133 | L(tail_label); |
134 | cmp(reg_sz, 0); |
135 | jle(exit_label, T_NEAR); |
136 | |
137 | const int bf16_half_reg = f32_simd_w; |
138 | mov(reg32_mask, 0xffff); |
139 | cmp(reg_sz, bf16_half_reg); |
140 | jge(mask_label, T_NEAR); |
141 | |
142 | mov(reg32_mask, 1); |
143 | mov(rcx, reg_sz); |
144 | shl(reg32_mask, cl); |
145 | sub(reg32_mask, 1); |
146 | |
147 | L(mask_label); |
148 | kmovd(k_mask, reg32_mask); |
149 | zmm_t vacc = Zmm(acc_vreg_idx(0, 0)); |
150 | vpxord(vacc, vacc, vacc); |
151 | |
152 | for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) { |
153 | int isrc0 = 2 * acc_iter; |
154 | int isrc1 = 2 * acc_iter + 1; |
155 | zmm_t vscale = Zmm(scale_vreg_idx(acc_iter)); |
156 | zmm_t vsrc = Zmm(src_vreg_idx(0, isrc0)); |
157 | ymm_t vysrc0 = Ymm(src_vreg_idx(0, isrc0)); |
158 | ymm_t vysrc1 = Ymm(src_vreg_idx(0, isrc1)); |
159 | vpxord(vysrc0, vysrc0, vysrc0); |
160 | vpxord(vysrc1, vysrc1, vysrc1); |
161 | |
162 | vmovdqu16(vysrc0 | k_mask | T_z, yword[reg_src[isrc0]]); |
163 | if (!(num_acc_iters * 2 > jsp.num_srcs |
164 | && acc_iter == num_acc_iters - 1)) |
165 | vmovdqu16(vysrc1 | k_mask | T_z, yword[reg_src[isrc1]]); |
166 | vinserti64x4(vsrc, vsrc, vysrc1, 0x1); |
167 | vpermw(vsrc, zmm_idx, vsrc); |
168 | |
169 | if (!isa_has_bf16(jsp.isa)) { |
170 | bf16_emu_->vdpbf16ps(vacc, vsrc, vscale); |
171 | } else { |
172 | vdpbf16ps(vacc, vsrc, vscale); |
173 | } |
174 | } |
175 | if (!jsp.is_bf16_dst) { |
176 | vmovups(zword[reg_dst] | k_mask, vacc); |
177 | } else { |
178 | if (isa_has_bf16(jsp.isa)) { |
179 | auto ymm_str = Ymm(tmp_vreg_idx(0, 0)); |
180 | vcvtneps2bf16(ymm_str, vacc); |
181 | vmovdqu16(yword[reg_dst] | k_mask, ymm_str); |
182 | } else { |
183 | auto ymm_str = Ymm(tmp_vreg_idx(0, 0)); |
184 | bf16_emu_->vcvtneps2bf16(ymm_str, vacc); |
185 | vmovdqu16(yword[reg_dst] | k_mask, ymm_str); |
186 | } |
187 | } |
188 | |
189 | sub(reg_sz, bf16_half_reg); |
190 | cmp(reg_sz, 0); |
191 | jle(exit_label, T_NEAR); |
192 | |
193 | for (int s = 0; s < jsp.num_srcs; s++) |
194 | add(reg_src[s], bf16_half_reg * jsp.typesize_in); |
195 | add(reg_dst, f32_simd_w * jsp.typesize_out); |
196 | |
197 | jmp(tail_label, T_NEAR); |
198 | |
199 | L(exit_label); |
200 | postamble(); |
201 | |
202 | align(64); |
203 | L(idx_table); |
204 | const uint16_t _idx[] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, |
205 | 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; |
206 | const dim_t _idx_size = sizeof(_idx) / sizeof(_idx[0]); |
207 | for (dim_t i = 0; i < _idx_size; ++i) |
208 | dw(_idx[i]); |
209 | } |
210 | |
211 | status_t jit_avx512_core_bf16_sum_kernel::init_conf( |
212 | jit_sum_conf_t &jsp, const int num_srcs, const memory_desc_t &dst_d) { |
213 | jsp.num_srcs = num_srcs; |
214 | jsp.loop_unroll = 0; |
215 | jsp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16 |
216 | : bf16_emulation_t::get_isa(); |
217 | |
218 | const int max_unroll = 6; // maximum possible value of unroll is 6 |
219 | for (/*continue*/; jsp.loop_unroll < max_unroll; jsp.loop_unroll++) { |
220 | int num_regs = num_vregs_required(jsp.loop_unroll + 1, jsp.num_srcs); |
221 | if (num_regs > max_vregs_available(isa_has_bf16(jsp.isa))) break; |
222 | } |
223 | if (jsp.loop_unroll == 0) return status::unimplemented; |
224 | jsp.size_blocking = bf16_simd_w * jsp.loop_unroll; |
225 | |
226 | const memory_desc_wrapper o_d(&dst_d); |
227 | jsp.is_bf16_dst = data_type::bf16 == o_d.data_type(); |
228 | |
229 | jsp.typesize_in = sizeof(bfloat16_t); |
230 | jsp.typesize_out = types::data_type_size(o_d.data_type()); |
231 | |
232 | return status::success; |
233 | } |
234 | |
235 | template <data_type_t src_data_type, data_type_t dst_data_type> |
236 | status_t jit_bf16_sum_t<src_data_type, dst_data_type>::execute( |
237 | const exec_ctx_t &ctx) const { |
238 | auto output = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); |
239 | const memory_desc_wrapper o_d(pd()->dst_md()); |
240 | output += o_d.blk_off(0); |
241 | const int num_arrs = pd()->n_inputs(); |
242 | const dim_t nelems = o_d.nelems(true); |
243 | const src_data_t *input_ptrs[jit_avx512_core_bf16_sum_kernel::max_num_arrs]; |
244 | /* Number of scales needs to be multiple of 2 in order |
245 | to use VNNI instructions */ |
246 | src_data_t scales[jit_avx512_core_bf16_sum_kernel::max_num_arrs]; |
247 | for (int a = 0; a < num_arrs; ++a) { |
248 | const memory_desc_wrapper i_d(pd()->src_md(a)); |
249 | |
250 | input_ptrs[a] |
251 | = CTX_IN_MEM(const src_data_t *, DNNL_ARG_MULTIPLE_SRC + a) |
252 | + i_d.blk_off(0); |
253 | } |
254 | cvt_float_to_bfloat16(scales, &pd()->scales()[0], num_arrs); |
255 | if (num_arrs % 2 != 0) scales[num_arrs] = 0.0f; |
256 | |
257 | const dim_t half_L1 = 16 * 1024; // bytes |
258 | const dim_t num_elems_in_block = utils::rnd_up( |
259 | utils::div_up(half_L1, |
260 | num_arrs * sizeof(src_data_t) + sizeof(dst_data_t)), |
261 | pd()->jsp_.size_blocking); |
262 | const dim_t num_blocks = nelems / num_elems_in_block; |
263 | const dim_t tail = nelems % num_elems_in_block; |
264 | |
265 | #if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8 \ |
266 | && __GNUC_PATCHLEVEL__ == 3 |
267 | // GCC issues a false positive warning 'array subscript is above array bounds' |
268 | // with gcc 4.8.3 + -march=native option, so disable it for now |
269 | #pragma GCC diagnostic push |
270 | #pragma GCC diagnostic ignored "-Warray-bounds" |
271 | #endif |
272 | parallel(0, [&](const int ithr, const int nthr) { |
273 | dim_t start {0}, end {0}; |
274 | balance211(num_blocks, nthr, ithr, start, end); |
275 | auto arg = jit_sum_call_s(); |
276 | const src_data_t * |
277 | local_input_ptrs[jit_avx512_core_bf16_sum_kernel::max_num_arrs]; |
278 | dst_data_t *local_output; |
279 | |
280 | for (dim_t nb = start; nb < end; ++nb) { |
281 | dim_t start_e = nb * num_elems_in_block; |
282 | for (int a = 0; a < num_arrs; ++a) { |
283 | local_input_ptrs[a] = &input_ptrs[a][start_e]; |
284 | } |
285 | local_output = &output[start_e]; |
286 | arg.srcs = (const void **)local_input_ptrs; |
287 | arg.dst = (const void *)local_output; |
288 | arg.scales = (const void *)scales; |
289 | arg.size = num_elems_in_block; |
290 | (*kernel_)(&arg); |
291 | } |
292 | |
293 | if (tail != 0 && ithr == nthr - 1) { |
294 | dim_t start_e = nelems - tail; |
295 | for (int a = 0; a < num_arrs; ++a) { |
296 | local_input_ptrs[a] = &input_ptrs[a][start_e]; |
297 | } |
298 | local_output = &output[start_e]; |
299 | arg.srcs = (const void **)local_input_ptrs; |
300 | arg.dst = (const void *)local_output; |
301 | arg.scales = (const void *)scales; |
302 | arg.size = tail; |
303 | (*kernel_)(&arg); |
304 | } |
305 | }); |
306 | #if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8 \ |
307 | && __GNUC_PATCHLEVEL__ == 3 |
308 | #pragma GCC diagnostic pop |
309 | #endif |
310 | return status::success; |
311 | } |
312 | |
313 | template struct jit_bf16_sum_t<data_type::bf16, data_type::f32>; |
314 | template struct jit_bf16_sum_t<data_type::bf16, data_type::bf16>; |
315 | |
316 | } // namespace x64 |
317 | } // namespace cpu |
318 | } // namespace impl |
319 | } // namespace dnnl |
320 | |