1/*******************************************************************************
2* Copyright 2019-2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include <float.h>
17
18#include "common/dnnl_thread.hpp"
19#include "common/utils.hpp"
20
21#include "cpu/x64/jit_avx512_core_bf16_sum.hpp"
22
23#define GET_OFF(field) offsetof(jit_sum_call_s, field)
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
30using namespace dnnl::impl::prop_kind;
31using namespace dnnl::impl::utils;
32
33using namespace Xbyak;
34void jit_avx512_core_bf16_sum_kernel::loop_iteration(int current_unroll) {
35 Label loop_label, exit_label;
36 const int num_compute_elements = 2 * f32_simd_w * current_unroll;
37 dim_t src_shift = 2 * f32_simd_w * jsp.typesize_in;
38 dim_t dst_shift = f32_simd_w * jsp.typesize_out;
39
40 L(loop_label);
41 cmp(reg_sz, num_compute_elements);
42 jl(exit_label, T_NEAR);
43 for (int u_idx = 0; u_idx < current_unroll; u_idx++) {
44 zmm_t vacc0 = Zmm(acc_vreg_idx(u_idx, 0));
45 zmm_t vacc1 = Zmm(acc_vreg_idx(u_idx, 1));
46 vpxord(vacc0, vacc0, vacc0);
47 vpxord(vacc1, vacc1, vacc1);
48
49 int num_acc_iters = utils::div_up(jsp.num_srcs, 2);
50 for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) {
51 int isrc0 = 2 * acc_iter;
52 int isrc1 = 2 * acc_iter + 1;
53 zmm_t vscale = Zmm(scale_vreg_idx(acc_iter));
54 zmm_t vsrc0 = Zmm(src_vreg_idx(u_idx, isrc0));
55 zmm_t vsrc1 = Zmm(src_vreg_idx(u_idx, isrc1));
56 zmm_t vtmp = Zmm(tmp_vreg_idx(u_idx, acc_iter));
57 vmovups(vsrc0, zword[reg_src[isrc0] + u_idx * src_shift]);
58 if (num_acc_iters * 2 > jsp.num_srcs
59 && acc_iter == num_acc_iters - 1)
60 vpxord(vtmp, vtmp, vtmp); /* imitate additional zero input
61 if number of srcs is odd */
62 else
63 vmovups(vtmp, zword[reg_src[isrc1] + u_idx * src_shift]);
64 vshuff64x2(vsrc1, vsrc0, vtmp, 0xEE);
65 vpermw(vsrc1, zmm_idx, vsrc1);
66 vshuff64x2(vsrc0, vsrc0, vtmp, 0x44);
67 vpermw(vsrc0, zmm_idx, vsrc0);
68
69 if (!isa_has_bf16(jsp.isa)) {
70 bf16_emu_->vdpbf16ps(vacc0, vsrc0, vscale);
71 bf16_emu_->vdpbf16ps(vacc1, vsrc1, vscale);
72 } else {
73 vdpbf16ps(vacc0, vsrc0, vscale);
74 vdpbf16ps(vacc1, vsrc1, vscale);
75 }
76 }
77
78 if (!jsp.is_bf16_dst) {
79 vmovups(zword[reg_dst + 2 * u_idx * dst_shift], vacc0);
80 vmovups(zword[reg_dst + (2 * u_idx + 1) * dst_shift], vacc1);
81 } else {
82 if (isa_has_bf16(jsp.isa)) {
83 zmm_t zmm_str = Zmm(tmp_vreg_idx(u_idx, 0));
84 vcvtne2ps2bf16(zmm_str, vacc1, vacc0);
85 vmovups(zword[reg_dst + 2 * u_idx * dst_shift], zmm_str);
86 } else {
87 auto ymm_str = Ymm(tmp_vreg_idx(u_idx, 0));
88 bf16_emu_->vcvtneps2bf16(ymm_str, vacc0);
89 vmovups(yword[reg_dst + 2 * u_idx * dst_shift], ymm_str);
90 bf16_emu_->vcvtneps2bf16(ymm_str, vacc1);
91 vmovups(yword[reg_dst + (2 * u_idx + 1) * dst_shift], ymm_str);
92 }
93 }
94 }
95 sub(reg_sz, num_compute_elements);
96 for (int s = 0; s < jsp.num_srcs; s++)
97 add(reg_src[s], current_unroll * src_shift);
98 add(reg_dst, 2 * current_unroll * dst_shift);
99 jge(loop_label, T_NEAR);
100
101 L(exit_label);
102}
103
// Emits the complete sum kernel:
//   1) load the jit_sum_call_s arguments (dst, srcs, scales, size),
//   2) broadcast the bf16 scales pairwise into vector registers,
//   3) run the unrolled main loop, then a unit-unroll main loop,
//   4) process the remaining (< 2 * f32_simd_w) elements with a masked
//      tail loop,
//   5) emit the word-interleave permutation table used by vpermw.
void jit_avx512_core_bf16_sum_kernel::generate() {
    preamble();

    mov(reg_dst, ptr[param + GET_OFF(dst)]);
    mov(reg_srcs, ptr[param + GET_OFF(srcs)]);

    // srcs is an array of pointers; load each source base address.
    for (int s = 0; s < jsp.num_srcs; s++)
        mov(reg_src[s], ptr[reg_srcs + sizeof(void *) * s]);

    mov(reg_scales, ptr[param + GET_OFF(scales)]);
    mov(reg_sz, ptr[param + GET_OFF(size)]);

    Label tail_label, exit_label, mask_label;

    // Load the permutation indices used to interleave pairs of sources.
    mov(reg_idx_table, idx_table);
    vmovups(zmm_idx, ptr[reg_idx_table]);

    // Broadcast one dword per source pair: two consecutive bf16 scales
    // packed together, matching vdpbf16ps's pairwise accumulation.
    int num_acc_iters = utils::div_up(jsp.num_srcs, 2);
    for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) {
        zmm_t vscale = Zmm(scale_vreg_idx(acc_iter));
        vpbroadcastd(vscale, ptr[reg_scales + 2 * acc_iter * jsp.typesize_in]);
    }

    if (!isa_has_bf16(jsp.isa)) bf16_emu_->init_vcvtneps2bf16();
    if (jsp.loop_unroll > 1) loop_iteration(jsp.loop_unroll);

    // Mop up full vector chunks the unrolled loop could not cover.
    loop_iteration(1);

    // tail processing
    L(tail_label);
    cmp(reg_sz, 0);
    jle(exit_label, T_NEAR);

    // Each tail pass handles up to f32_simd_w elements under a k-mask.
    const int bf16_half_reg = f32_simd_w;
    mov(reg32_mask, 0xffff);
    cmp(reg_sz, bf16_half_reg);
    jge(mask_label, T_NEAR);

    // Fewer than f32_simd_w elements left: build mask = (1 << reg_sz) - 1.
    // The variable shift count must live in cl.
    mov(reg32_mask, 1);
    mov(rcx, reg_sz);
    shl(reg32_mask, cl);
    sub(reg32_mask, 1);

    L(mask_label);
    kmovd(k_mask, reg32_mask);
    zmm_t vacc = Zmm(acc_vreg_idx(0, 0));
    vpxord(vacc, vacc, vacc);

    // Same pairwise accumulation as loop_iteration, but with masked
    // (zero-filled) 256-bit loads and a single accumulator.
    for (int acc_iter = 0; acc_iter < num_acc_iters; acc_iter++) {
        int isrc0 = 2 * acc_iter;
        int isrc1 = 2 * acc_iter + 1;
        zmm_t vscale = Zmm(scale_vreg_idx(acc_iter));
        zmm_t vsrc = Zmm(src_vreg_idx(0, isrc0));
        ymm_t vysrc0 = Ymm(src_vreg_idx(0, isrc0));
        ymm_t vysrc1 = Ymm(src_vreg_idx(0, isrc1));
        // Zero both halves so masked-off lanes contribute nothing.
        vpxord(vysrc0, vysrc0, vysrc0);
        vpxord(vysrc1, vysrc1, vysrc1);

        vmovdqu16(vysrc0 | k_mask | T_z, yword[reg_src[isrc0]]);
        // With an odd number of sources the last pair keeps its second
        // input zeroed.
        if (!(num_acc_iters * 2 > jsp.num_srcs
                    && acc_iter == num_acc_iters - 1))
            vmovdqu16(vysrc1 | k_mask | T_z, yword[reg_src[isrc1]]);
        // Combine both sources into one zmm and interleave their elements.
        vinserti64x4(vsrc, vsrc, vysrc1, 0x1);
        vpermw(vsrc, zmm_idx, vsrc);

        if (!isa_has_bf16(jsp.isa)) {
            bf16_emu_->vdpbf16ps(vacc, vsrc, vscale);
        } else {
            vdpbf16ps(vacc, vsrc, vscale);
        }
    }
    // Masked store of the tail result, converting to bf16 if required.
    if (!jsp.is_bf16_dst) {
        vmovups(zword[reg_dst] | k_mask, vacc);
    } else {
        if (isa_has_bf16(jsp.isa)) {
            auto ymm_str = Ymm(tmp_vreg_idx(0, 0));
            vcvtneps2bf16(ymm_str, vacc);
            vmovdqu16(yword[reg_dst] | k_mask, ymm_str);
        } else {
            auto ymm_str = Ymm(tmp_vreg_idx(0, 0));
            bf16_emu_->vcvtneps2bf16(ymm_str, vacc);
            vmovdqu16(yword[reg_dst] | k_mask, ymm_str);
        }
    }

    sub(reg_sz, bf16_half_reg);
    cmp(reg_sz, 0);
    jle(exit_label, T_NEAR);

    // More tail remains: advance pointers and re-enter (the mask is
    // rebuilt at tail_label for the new remaining count).
    for (int s = 0; s < jsp.num_srcs; s++)
        add(reg_src[s], bf16_half_reg * jsp.typesize_in);
    add(reg_dst, f32_simd_w * jsp.typesize_out);

    jmp(tail_label, T_NEAR);

    L(exit_label);
    postamble();

    // Data table (after the code): word indices that interleave the two
    // 256-bit halves of a zmm, i.e. {lo[0], hi[0], lo[1], hi[1], ...}.
    align(64);
    L(idx_table);
    const uint16_t _idx[] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7,
            23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
    const dim_t _idx_size = sizeof(_idx) / sizeof(_idx[0]);
    for (dim_t i = 0; i < _idx_size; ++i)
        dw(_idx[i]);
}
210
211status_t jit_avx512_core_bf16_sum_kernel::init_conf(
212 jit_sum_conf_t &jsp, const int num_srcs, const memory_desc_t &dst_d) {
213 jsp.num_srcs = num_srcs;
214 jsp.loop_unroll = 0;
215 jsp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
216 : bf16_emulation_t::get_isa();
217
218 const int max_unroll = 6; // maximum possible value of unroll is 6
219 for (/*continue*/; jsp.loop_unroll < max_unroll; jsp.loop_unroll++) {
220 int num_regs = num_vregs_required(jsp.loop_unroll + 1, jsp.num_srcs);
221 if (num_regs > max_vregs_available(isa_has_bf16(jsp.isa))) break;
222 }
223 if (jsp.loop_unroll == 0) return status::unimplemented;
224 jsp.size_blocking = bf16_simd_w * jsp.loop_unroll;
225
226 const memory_desc_wrapper o_d(&dst_d);
227 jsp.is_bf16_dst = data_type::bf16 == o_d.data_type();
228
229 jsp.typesize_in = sizeof(bfloat16_t);
230 jsp.typesize_out = types::data_type_size(o_d.data_type());
231
232 return status::success;
233}
234
// Runs the JIT sum kernel over all inputs. The element range is split into
// cache-friendly blocks (sized so inputs + output of one block fit in about
// half of a 32 KiB L1, rounded up to the kernel's blocking unit) which are
// distributed across threads; the final partial block is handled by the
// last thread.
template <data_type_t src_data_type, data_type_t dst_data_type>
status_t jit_bf16_sum_t<src_data_type, dst_data_type>::execute(
        const exec_ctx_t &ctx) const {
    auto output = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const memory_desc_wrapper o_d(pd()->dst_md());
    output += o_d.blk_off(0);
    const int num_arrs = pd()->n_inputs();
    const dim_t nelems = o_d.nelems(true);
    const src_data_t *input_ptrs[jit_avx512_core_bf16_sum_kernel::max_num_arrs];
    /* Number of scales needs to be multiple of 2 in order
       to use VNNI instructions */
    src_data_t scales[jit_avx512_core_bf16_sum_kernel::max_num_arrs];
    for (int a = 0; a < num_arrs; ++a) {
        const memory_desc_wrapper i_d(pd()->src_md(a));

        input_ptrs[a]
                = CTX_IN_MEM(const src_data_t *, DNNL_ARG_MULTIPLE_SRC + a)
                + i_d.blk_off(0);
    }
    // The kernel consumes bf16 scales; convert from f32 and zero-pad to an
    // even count so the last vdpbf16ps pair is well defined.
    cvt_float_to_bfloat16(scales, &pd()->scales()[0], num_arrs);
    if (num_arrs % 2 != 0) scales[num_arrs] = 0.0f;

    const dim_t half_L1 = 16 * 1024; // bytes
    // Elements per block: half_L1 / (bytes read + written per element),
    // rounded up to the kernel's blocking unit so full blocks never hit
    // the kernel's tail path.
    const dim_t num_elems_in_block = utils::rnd_up(
            utils::div_up(half_L1,
                    num_arrs * sizeof(src_data_t) + sizeof(dst_data_t)),
            pd()->jsp_.size_blocking);
    const dim_t num_blocks = nelems / num_elems_in_block;
    const dim_t tail = nelems % num_elems_in_block;

#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8 \
        && __GNUC_PATCHLEVEL__ == 3
// GCC issues a false positive warning 'array subscript is above array bounds'
// with gcc 4.8.3 + -march=native option, so disable it for now
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#endif
    parallel(0, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};
        balance211(num_blocks, nthr, ithr, start, end);
        auto arg = jit_sum_call_s();
        // Per-thread copies of the input pointers, offset to the current
        // block before each kernel call.
        const src_data_t *
                local_input_ptrs[jit_avx512_core_bf16_sum_kernel::max_num_arrs];
        dst_data_t *local_output;

        for (dim_t nb = start; nb < end; ++nb) {
            dim_t start_e = nb * num_elems_in_block;
            for (int a = 0; a < num_arrs; ++a) {
                local_input_ptrs[a] = &input_ptrs[a][start_e];
            }
            local_output = &output[start_e];
            arg.srcs = (const void **)local_input_ptrs;
            arg.dst = (const void *)local_output;
            arg.scales = (const void *)scales;
            arg.size = num_elems_in_block;
            (*kernel_)(&arg);
        }

        // The last thread also processes the partial trailing block, if any.
        if (tail != 0 && ithr == nthr - 1) {
            dim_t start_e = nelems - tail;
            for (int a = 0; a < num_arrs; ++a) {
                local_input_ptrs[a] = &input_ptrs[a][start_e];
            }
            local_output = &output[start_e];
            arg.srcs = (const void **)local_input_ptrs;
            arg.dst = (const void *)local_output;
            arg.scales = (const void *)scales;
            arg.size = tail;
            (*kernel_)(&arg);
        }
    });
#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8 \
        && __GNUC_PATCHLEVEL__ == 3
#pragma GCC diagnostic pop
#endif
    return status::success;
}
312
313template struct jit_bf16_sum_t<data_type::bf16, data_type::f32>;
314template struct jit_bf16_sum_t<data_type::bf16, data_type::bf16>;
315
316} // namespace x64
317} // namespace cpu
318} // namespace impl
319} // namespace dnnl
320