1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <memory> |
18 | |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/math_utils.hpp" |
21 | #include "cpu/simple_q10n.hpp" |
22 | |
23 | #include "cpu/x64/cpu_isa_traits.hpp" |
24 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
25 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
26 | #include "cpu/x64/jit_generator.hpp" |
27 | |
28 | #include "cpu/x64/jit_gemm_inner_product_utils.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | namespace inner_product_utils { |
35 | |
36 | using namespace dnnl::impl::cpu::inner_product_utils; |
37 | using namespace Xbyak; |
38 | using namespace data_type; |
39 | |
// JIT-ed inner-product post-processing kernel: loads accumulator values,
// applies bias/scales/sum/eltwise/binary post-ops and dst zero-points, then
// converts and stores into dst. Template parameter `isa` selects the vector
// register width (Xmm for sse41, Ymm for avx2, Zmm otherwise).
template <cpu_isa_t isa>
struct jit_pp_kernel_t : public pp_kernel_t, public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(inner_product_utils::jit_pp_kernel_t);

    jit_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride,
            const primitive_attr_t *attr, data_type_t bias_dt,
            data_type_t acc_dt, const memory_desc_t *dst_md, bool skip_sum);

    // Runs the generated kernel over elements [start, end) of the
    // (MB x OC) output, with optional runtime OC and non-trivial mb stride.
    void operator()(void *dst, const void *acc, const char *bias,
            const float *scales, float dst_scale, size_t start,
            size_t dst_logical_off, size_t dim1_off, size_t end,
            size_t runtime_oc, dim_t dst_mb_stride,
            const float *dst_zero_points,
            const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
            size_t first_mb_matrix_addr_off, const exec_ctx_t &ctx,
            const memory_desc_t &dst_md) const override;

    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    // Vector register type matching the chosen ISA.
    using Vmm = typename utils::conditional3<isa == sse41, Xbyak::Xmm,
            isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;

    // Identifies which buffer an address/load/store helper operates on.
    // `sum` aliases dst memory but carries its own data type.
    enum class arg_t { dst, acc, bias, stack, scale, sum };
    enum class data_op_t { load, store };

    void apply_postops(const bool apply_mask, const int vmm_idx,
            const size_t offset, bool runtime_tail_mask);
    void prepare_mask(const size_t tail);
    void load_no_tail(const Vmm v, Xbyak::Address op, const data_type_t dt);
    void load_tail(const Vmm v, const arg_t arg_num, const size_t off,
            const data_type_t dt, const size_t tail);
    void load_and_cvt(const Vmm v, const arg_t arg_num, const size_t off,
            const size_t tail, bool do_cvt = true);
    // convert and store instances for each case of Vmm
    void cvt_and_store(const Xbyak::Zmm v, const arg_t arg_num,
            const size_t off, const size_t tail);
    void cvt_and_store(const Xbyak::Ymm v, const arg_t arg_num,
            const size_t off, const size_t tail);
    void cvt_and_store(const Xbyak::Xmm v, const arg_t arg_num,
            const size_t off, const size_t tail);
    // tail handling resolved at runtime via reg_tail (non-avx512 paths only)
    void runtime_tail_load_cvt(const Vmm v, const arg_t arg_num,
            const size_t off, bool cvt = true);
    void runtime_tail_cvt_store(
            const Vmm v, const arg_t arg_num, const size_t off);
    // Unified load/store dispatcher over the helpers above.
    void data_copy(const Vmm v, const arg_t arg_num, const size_t off,
            data_op_t data_op, const size_t tail,
            const bool is_needed_runtime_tail_process,
            const bool do_cvt = true);
    void generate() override;
    void compute_oc_channel_blk();
    void compute_mb_blk(); // vectorize across minibatch

    // Helpers maintaining binary post-op offsets kept in stack slots.
    void advance_binary_postops_off(const size_t offset);

    void advance_binary_postops_off(const Xbyak::Reg64 offset);

    template <typename T>
    void advance_binary_postops_per_oc_off(const T &offset);

    void update_binary_postops_per_tensor_off();

    template <typename T>
    void advance_binary_postops_channel_bcast_off(const T &offset);

    // Argument structure passed to the generated kernel (reg_param points
    // at one instance of this; fields are read via offsetof).
    struct ker_args_t {
        char *dst = nullptr;
        const char *acc = nullptr;
        const char *bias = nullptr;
        const float *scales = nullptr;
        float dst_scale = 1.0f;
        const float *dst_zero_points = nullptr;
        float nslope = 0;
        size_t oc = 0;
        size_t len = 0;
        size_t oc_offset = 0;
        size_t dim1_off = 0;
        size_t dst_logical_off = 0;
        size_t first_mb_matrix_addr_off = 0;
        dim_t dst_mb_stride = 0;
        const void *post_ops_binary_rhs_arg_vec = nullptr;
        const void *dst_orig = nullptr;
    };

    const bool is_avx512_ = utils::one_of(isa, avx512_core, avx512_core_bf16);
    // Post-op injectors are instantiated for avx512_core even when the
    // kernel itself targets avx512_core_bf16.
    static constexpr cpu_isa_t inject_isa_
            = isa == avx512_core_bf16 ? avx512_core : isa;
    std::unique_ptr<injector::jit_uni_postops_injector_t<inject_isa_>>
            postops_injector_;

    // bf16 down-conversion emulation; only set up when dst is bf16 and the
    // native bf16 instruction set is unavailable (see constructor).
    std::unique_ptr<bf16_emulation_t> bf16_emu_;

#ifdef _WIN32
    const Xbyak::Reg64 reg_binary_inj_param_ = abi_not_param1;
#else
    const Xbyak::Reg64 reg_binary_inj_param_ = abi_param1;
#endif

    const Xbyak::Reg64 reg_param = abi_param1;
    const Xbyak::Reg64 reg_stack_frame_ = rbp;
    const Xbyak::Reg64 reg_dst = rdx;
    const Xbyak::Reg64 reg_acc = rax;
    const Xbyak::Reg64 reg_bias = rbx;
    const Xbyak::Reg64 reg_scales = rsi;

    const Xbyak::Reg64 reg_oc = r13;
    const Xbyak::Reg64 reg_len = r8;
    const Xbyak::Reg64 reg_tmp = rcx; // intentional for shifting purposes
    const Xbyak::Reg64 reg_tail = reg_tmp;
    const Xbyak::Reg64 reg_oc_offset = r9;
    const Xbyak::Reg64 reg_rem_mask = r10;
    const Xbyak::Opmask kreg_rem_mask = k1;
    const Xbyak::Opmask opmask_binary = k3;
    const Vmm vmm_rem_mask = Vmm(0);
    // register used for temp computation, needs not to be preserved
    const Xbyak::Reg64 reg_tmp_comp = r15;

    // *mb_stride used only in matmul_pp_kernel && compute_oc_channel_blk()
    const Xbyak::Reg64 reg_dst_mb_stride = r12;
    const Xbyak::Reg64 reg_acc_mb_stride = r14;

    // Will be assigned in constructor
    Vmm vreg_zero, vreg_saturation_ubound, vreg_scale, vreg_dst_scale,
            vreg_sum_scale, vreg_sum_zp, vreg_dst_zero_points;

    const Xbyak::Reg64 eltwise_reserved_gpr_ = r11;
    const Xbyak::Opmask eltwise_reserved_opmask_ = k2;

    // Registers reserved for bf16_emulation_t (avx512 only).
    Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(28);
    Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(29);
    Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(30);
    Xbyak::Reg64 bf16_emu_reserv_4 = reg_tmp_comp;
    Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(31);

    // Vector-register budget for the unrolled compute loop; the constructor
    // shrinks these based on which auxiliary vregs are claimed.
    int max_OC_loop_unroll_ = 13;
    int idx_compute_vreg_start_ = is_avx512_ ? 0 : 1;
    int idx_compute_vreg_max_ = is_avx512_ ? 31 : 15;
    int compute_vregs_per_iter_ = 1;
    int compute_vreg_bias_shift_ = 0;
    int compute_vreg_prev_dst_shift_ = 0;

    // Number of f32 lanes per vector register.
    const size_t vlen = cpu_isa_traits<isa>::vlen / sizeof(float);
    // Stack frame layout (byte offsets from rsp) used to spill binary
    // post-op bookkeeping values.
    static constexpr int reg64_size_ = sizeof(int64_t);
    static constexpr int reg_binary_post_op_oc_off_ = 0;
    static constexpr int reg_binary_post_op_offset_ = 1 * reg64_size_;
    static constexpr int reg_binary_post_op_sp_off_ = 2 * reg64_size_;
    static constexpr int reg_origin_dst_ptr_ = 3 * reg64_size_;
    static constexpr int stack_space_needed_ = 4 * reg64_size_;

    // Which binary post-op broadcast strategies are present (set once in
    // the constructor from the attr post-ops).
    bool any_binary_postop_is_no_bcast_type_ = false;
    bool any_binary_postop_is_per_oc_bcast_type_ = false;
    bool any_binary_postop_is_per_oc_sp_bcast_type_ = false;
    bool any_binary_postop_is_oc_bcast_type_ = false;

    // Index of the dst vreg for unroll iteration `iter`.
    int vreg_dst_idx(const int iter) const {
        int idx = idx_compute_vreg_start_ + iter * compute_vregs_per_iter_;
        assert(idx <= idx_compute_vreg_max_);
        return idx;
    }

    Vmm vreg_dst(int iter) { return Vmm(vreg_dst_idx(iter)); }

    // vreg holding the previous dst value (sum post-op) for iteration `iter`.
    Vmm vreg_prev_dst(int iter) {
        int idx = idx_compute_vreg_start_ + iter * compute_vregs_per_iter_
                + compute_vreg_prev_dst_shift_;
        assert(idx <= idx_compute_vreg_max_);
        return Vmm(idx);
    }

    // vreg holding the bias value for iteration `iter`.
    Vmm vreg_bias(int iter) {
        int idx = idx_compute_vreg_start_ + iter * compute_vregs_per_iter_
                + compute_vreg_bias_shift_;
        assert(idx <= idx_compute_vreg_max_);
        return Vmm(idx);
    }

    Xbyak::Address dst_ptr(const size_t offt) { return ptr[reg_dst + offt]; }

    Xbyak::Address acc_ptr(const size_t offt) { return ptr[reg_acc + offt]; }

    Xbyak::Address bias_ptr(const size_t offt) { return ptr[reg_bias + offt]; }

    Xbyak::Address stack_ptr(const size_t offt) { return ptr[rsp + offt]; }

    Xbyak::Address scale_ptr(const size_t offt) {
        return ptr[reg_scales + offt];
    }

    // Maps an arg_t to an address `off` bytes into the corresponding buffer.
    Xbyak::Address get_address(const arg_t arg_num, const size_t off) {
        switch (arg_num) {
            case arg_t::dst:
            case arg_t::sum: return dst_ptr(off);
            case arg_t::acc: return acc_ptr(off);
            case arg_t::bias: return bias_ptr(off);
            case arg_t::stack: return stack_ptr(off);
            case arg_t::scale: return scale_ptr(off);
            default: assert(!"unsupported arg_num" ); break;
        }
        return Xbyak::Address(0);
    }

    // Base register carrying the pointer for a given arg_t.
    Xbyak::Reg64 get_reg_address(const arg_t arg_num) {
        switch (arg_num) {
            case arg_t::dst:
            case arg_t::sum: return reg_dst;
            case arg_t::acc: return reg_acc;
            case arg_t::bias: return reg_bias;
            case arg_t::stack: return rsp;
            case arg_t::scale: return reg_scales;
            default: assert(!"unsupported arg_num" ); break;
        }
        return rsp;
    }

    // Data type stored in the buffer behind a given arg_t.
    data_type_t get_data_type(const arg_t arg_num) {
        switch (arg_num) {
            case arg_t::dst: return this->dst_data_type_;
            case arg_t::sum: return this->sum_data_type_;
            case arg_t::acc: return this->acc_data_type_;
            case arg_t::bias: return this->bias_data_type_;
            // default for stack or scale operation
            default: return f32;
        }
        // unreachable: the switch above always returns
        return data_type::undef;
    }
};
266 | |
// Constructor: claims auxiliary vector registers for each enabled feature
// (scale, saturation, sum, dst scale/zero-point, ...), derives the maximum
// OC-loop unroll factor from the remaining register budget, and configures
// the post-ops injector when eltwise/binary post-ops are present.
// NOTE: the order of `idx_compute_vreg_start_++` claims below is
// significant -- it fixes each auxiliary vreg's index.
template <cpu_isa_t isa>
jit_pp_kernel_t<isa>::jit_pp_kernel_t(size_t OC, size_t MB, dim_t dst_mb_stride,
        const primitive_attr_t *attr, data_type_t bias_dt, data_type_t acc_dt,
        const memory_desc_t *dst_md, bool skip_sum)
    : pp_kernel_t(
            OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum)
    , jit_generator(jit_name()) {
    assert(IMPLICATION(this->dst_data_type_ == bf16, mayiuse(avx512_core)));

    if (this->do_scale_) vreg_scale = Vmm(idx_compute_vreg_start_++);

    // u8 needs a zero vreg for lower-bound saturation; all int dsts need an
    // upper-bound vreg.
    if (this->dst_data_type_ == u8) vreg_zero = Vmm(idx_compute_vreg_start_++);
    if (utils::one_of(this->dst_data_type_, u8, s8, s32))
        vreg_saturation_ubound = Vmm(idx_compute_vreg_start_++);

    if (this->do_sum_) {
        // sum needs an extra per-iteration vreg to hold the previous dst
        compute_vreg_prev_dst_shift_ = compute_vregs_per_iter_++;
        if (this->sum_scale_ != 1.f)
            vreg_sum_scale = Vmm(idx_compute_vreg_start_++);
        if (this->sum_zp_ != 0) vreg_sum_zp = Vmm(idx_compute_vreg_start_++);
    }

    if (this->do_bias()) compute_vreg_bias_shift_ = compute_vregs_per_iter_++;

    if (!attr->scales_.get(DNNL_ARG_DST).has_default_values()) {
        this->do_dst_scale_ = true;
        vreg_dst_scale = Vmm(idx_compute_vreg_start_++);
    }

    if (!attr->zero_points_.has_default_values(DNNL_ARG_DST)) {
        this->do_dst_zero_points_ = true;
        vreg_dst_zero_points = Vmm(idx_compute_vreg_start_++);
    }

    // Without native bf16 support, reserve Zmm28-31 for the emulation
    // helper, shrinking the compute register budget accordingly.
    if (this->dst_data_type_ == bf16 && isa != avx512_core_bf16) {
        idx_compute_vreg_max_ = 27;
        bf16_emu_.reset(new bf16_emulation_t(this, bf16_emu_reserv_1,
                bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_reserv_4,
                bf16_emu_reserv_5));
    }

    // Unroll factor limited by how many full per-iteration vreg groups fit
    // into the remaining register range.
    int max_unroll = (idx_compute_vreg_max_ - idx_compute_vreg_start_ + 1)
            / compute_vregs_per_iter_;
    max_OC_loop_unroll_ = nstl::min(max_OC_loop_unroll_, max_unroll);
    if (this->do_eltwise_ || this->do_binary_) {
#define PARAM_OFF(field) offsetof(ker_args_t, field)
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static const size_t helper_vmm_idx = is_avx512_ ? 31 : 15;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        const auto dst_md_wrapper = memory_desc_wrapper(*dst_md);

        // Mirror of the loop structure used at code-generation time, only
        // to compute the OC tail size for the binary injector.
        size_t OC_loop, OC_tail;
        if (OC < max_OC_loop_unroll_ * vlen) {
            // Fully unroll small loops
            OC_loop = 0;
            OC_tail = OC;
        } else {
            OC_loop = vlen * max_OC_loop_unroll_;
            OC_tail = OC % OC_loop;
        }
        size_t tail_size = OC_tail % vlen;
        // enable tail processing for runtime load even if there is no tail
        // for the OC
        tail_size = !!tail_size ? tail_size : 1;
        const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
                helper_vmm_idx, eltwise_reserved_gpr_, r14, r15, preserve_gpr,
                preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
                PARAM_OFF(dst_orig), dst_md_wrapper, tail_size, opmask_binary,
                reg_tmp, use_exact_tail_scalar_bcast};
        static const bcast_set_t enabled_bcast_strategy
                = {broadcasting_strategy_t::scalar,
                        broadcasting_strategy_t::per_oc,
                        broadcasting_strategy_t::per_oc_spatial,
                        broadcasting_strategy_t::per_mb_spatial,
                        broadcasting_strategy_t::per_mb_w,
                        broadcasting_strategy_t::per_w,
                        broadcasting_strategy_t::no_broadcast};
        const binary_injector::static_params_t binary_static_params {
                reg_binary_inj_param_, enabled_bcast_strategy,
                rhs_arg_static_params};
        static constexpr bool save_state = true;
        const eltwise_injector::static_params_t eltwise_static_params {
                save_state, reg_tmp_comp, eltwise_reserved_opmask_};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<inject_isa_>>(this,
                this->post_ops_, binary_static_params, eltwise_static_params);

        // Record which broadcast strategies the attached binary post-ops
        // actually use, so offset bookkeeping can be skipped when unneeded.
        using namespace dnnl::impl::cpu::binary_injector_utils;
        std::tie(any_binary_postop_is_no_bcast_type_,
                any_binary_postop_is_per_oc_bcast_type_,
                any_binary_postop_is_per_oc_sp_bcast_type_,
                any_binary_postop_is_oc_bcast_type_)
                = bcast_strategies_present_tup(this->post_ops_.entry_,
                        dst_md_wrapper, broadcasting_strategy_t::no_broadcast,
                        broadcasting_strategy_t::per_oc,
                        broadcasting_strategy_t::per_oc_spatial,
                        broadcasting_strategy_t::per_mb_spatial);
    }
#undef PARAM_OFF
}
369 | |
// Emits code advancing the per-OC binary post-op offset (stack slot at
// reg_binary_post_op_oc_off_) by `offset`, which may be an immediate or a
// Reg64. For 2D cases the offset wraps around at OC_ so it always indexes
// within one row of per-OC data.
template <cpu_isa_t isa>
template <typename T>
void jit_pp_kernel_t<isa>::advance_binary_postops_per_oc_off(const T &offset) {

    const auto binary_post_op_oc_off_reg = reg_tmp_comp;
    const auto binary_post_op_current_offset_on_stack
            = ptr[rsp + reg_binary_post_op_oc_off_];

    mov(binary_post_op_oc_off_reg, binary_post_op_current_offset_on_stack);
    add(binary_post_op_oc_off_reg, offset);

    if (this->ndims_ == 2) {
        // wrap the offset back to 0 once it reaches OC_
        Xbyak::Label end;
        cmp(binary_post_op_oc_off_reg, this->OC_);
        jl(end, T_NEAR);
        xor_(binary_post_op_oc_off_reg, binary_post_op_oc_off_reg);
        L(end);
    }

    mov(binary_post_op_current_offset_on_stack, binary_post_op_oc_off_reg);
}
391 | |
// Recomputes the per-tensor (no-broadcast) binary post-op element offset:
// (reg_dst - original dst base) / dst data type size, stored into the
// stack slot at reg_binary_post_op_offset_.
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::update_binary_postops_per_tensor_off() {
    // subtract dst_origin from current dst and divide it by dst data type
    // size to get the correct offset
    const auto binary_post_op_offset_reg = reg_tmp_comp;
    const auto binary_post_op_current_offset_on_stack
            = ptr[rsp + reg_binary_post_op_offset_];
    mov(binary_post_op_offset_reg, reg_dst);
    sub(binary_post_op_offset_reg, ptr[rsp + reg_origin_dst_ptr_]);
    // data type sizes are powers of two, so log2 yields an exact integral
    // shift amount; shift right == divide by the dst type size
    sar(binary_post_op_offset_reg,
            std::log2(types::data_type_size(get_data_type(arg_t::dst))));
    mov(binary_post_op_current_offset_on_stack, binary_post_op_offset_reg);
}
405 | |
406 | template <cpu_isa_t isa> |
407 | template <typename T> |
408 | void jit_pp_kernel_t<isa>::advance_binary_postops_channel_bcast_off( |
409 | const T &offset) { |
410 | |
411 | const auto binary_post_op_offset_reg = reg_tmp_comp; |
412 | const auto binary_post_op_current_offset_on_stack |
413 | = ptr[rsp + reg_binary_post_op_sp_off_]; |
414 | mov(binary_post_op_offset_reg, binary_post_op_current_offset_on_stack); |
415 | add(binary_post_op_offset_reg, offset); |
416 | mov(binary_post_op_current_offset_on_stack, binary_post_op_offset_reg); |
417 | } |
418 | |
419 | /* |
420 | * Advance binary postops offsets with per_tensor_offset passed as plain value |
421 | * type (const offset value). |
422 | */ |
423 | template <cpu_isa_t isa> |
424 | void jit_pp_kernel_t<isa>::advance_binary_postops_off(const size_t offset) { |
425 | if (offset) { |
426 | if (any_binary_postop_is_per_oc_bcast_type_) |
427 | advance_binary_postops_per_oc_off(offset); |
428 | if (any_binary_postop_is_no_bcast_type_) |
429 | update_binary_postops_per_tensor_off(); |
430 | if (any_binary_postop_is_oc_bcast_type_) |
431 | advance_binary_postops_channel_bcast_off(offset); |
432 | } |
433 | } |
434 | |
435 | /* |
436 | * Advance binary postops offsets with per_tensor_offset passed in Reg64. |
437 | */ |
438 | template <cpu_isa_t isa> |
439 | void jit_pp_kernel_t<isa>::advance_binary_postops_off( |
440 | const Xbyak::Reg64 reg_offset) { |
441 | if (any_binary_postop_is_per_oc_bcast_type_) |
442 | advance_binary_postops_per_oc_off(reg_offset); |
443 | if (any_binary_postop_is_no_bcast_type_) |
444 | update_binary_postops_per_tensor_off(); |
445 | if (any_binary_postop_is_oc_bcast_type_) |
446 | advance_binary_postops_channel_bcast_off(reg_offset); |
447 | } |
448 | |
449 | template <cpu_isa_t isa> |
450 | void jit_pp_kernel_t<isa>::apply_postops(const bool apply_mask, |
451 | const int vmm_idx, const size_t offset, const bool runtime_tail_mask) { |
452 | if (this->do_eltwise_ || this->do_binary_) { |
453 | if (this->do_binary_) { |
454 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
455 | if (apply_mask) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
456 | rhs_arg_params.tail_load_mode = runtime_tail_mask |
457 | ? binary_injector::tail_lode_mode_t::DYNAMIC |
458 | : binary_injector::tail_lode_mode_t::DEFAULT; |
459 | |
460 | rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst); |
461 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, offset); |
462 | |
463 | postops_injector_->compute_vector(vmm_idx, rhs_arg_params); |
464 | } else |
465 | postops_injector_->compute_vector(vmm_idx); |
466 | } |
467 | } |
468 | |
// Prepares the tail mask for `tail` remaining elements: an opmask register
// (kreg_rem_mask) on avx512, or a f32 lane mask in vmm_rem_mask on avx2.
// sse41 needs no mask (tails are handled element-wise).
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::prepare_mask(const size_t tail) {
    assert(tail > 0 && tail <= vlen - 1);
    if (is_avx512_) {
        // lowest `tail` bits set
        const size_t tail_mask = (1 << tail) - 1;
        mov(reg_tmp, tail_mask);
        kmovq(kreg_rem_mask, reg_tmp);
    } else if (isa == avx2) {
        // Sliding window over this table: starting at index (7 - tail)
        // gives exactly `tail` all-ones lanes followed by zero lanes.
        static const uint32_t mask_f32[14]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                        0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0};

        mov(reg_tmp, reinterpret_cast<size_t>(&mask_f32[7 - tail]));
        vmovups(vmm_rem_mask, ptr[reg_tmp]);
    }
}
485 | |
// Emits a full-width load of data type `dt` from `op` into `v`, widening
// integer types to 32-bit lanes and bf16 to f32 (zero-extend + shift into
// the high mantissa bits). No f32 conversion is done for integers here.
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::load_no_tail(
        const Vmm v, Xbyak::Address op, const data_type_t dt) {
    using namespace data_type;
    switch (dt) {
        case s8: uni_vpmovsxbd(v, op); break;
        case u8: uni_vpmovzxbd(v, op); break;
        case s32:
        case f32: uni_vmovups(v, op); break;
        case bf16:
            // bf16 -> f32: place the 16 payload bits into the upper half
            vpmovzxwd(v, op);
            vpslld(v, v, 0x10);
            break;
        default: assert(!"unimplemented" );
    }
}
502 | |
// Emits a partial (tail) load of `tail` elements of type `dt`.
// avx512: masked full load via kreg_rem_mask. avx2/sse41: byte-wise
// inserts for s8/u8 followed by widening; masked move (Ymm) or per-element
// dword inserts (Xmm) for 32-bit types.
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::load_tail(const Vmm v, const arg_t arg_num,
        const size_t off, const data_type_t dt, const size_t tail) {
    using namespace data_type;
    if (is_avx512_) {
        auto v_dst = tail ? v | kreg_rem_mask : v;
        load_no_tail(v_dst, get_address(arg_num, off), dt);
    } else {
        if (utils::one_of(dt, s8, u8)) {
            // gather `tail` bytes into the low lanes, then sign/zero-extend
            const Xbyak::Xmm x = Xbyak::Xmm(v.getIdx());
            for (size_t i = 0; i < tail; i++)
                uni_vpinsrb(x, x, get_address(arg_num, i + off), i);
            if (dt == s8)
                uni_vpmovsxbd(v, v);
            else
                uni_vpmovzxbd(v, v);
        } else {
            const bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
            if (is_ymm) {
                // masked 8x32-bit load; vmm_rem_mask set by prepare_mask()
                vmaskmovps(v, vmm_rem_mask, get_address(arg_num, off));
            } else {
                // sse41: insert each 32-bit element individually
                const size_t dt_size = types::data_type_size(f32);
                for (size_t i = 0; i < tail; i++)
                    uni_vpinsrd(
                            v, v, get_address(arg_num, i * dt_size + off), i);
            }
        }
    }
}
532 | |
533 | template <cpu_isa_t isa> |
534 | void jit_pp_kernel_t<isa>::load_and_cvt(const Vmm v, const arg_t arg_num, |
535 | const size_t off, const size_t tail, bool do_cvt) { |
536 | using namespace data_type; |
537 | const data_type_t dt = get_data_type(arg_num); |
538 | if (tail) |
539 | load_tail(v, arg_num, off, dt, tail); |
540 | else |
541 | load_no_tail(v, get_address(arg_num, off), dt); |
542 | |
543 | if (do_cvt && utils::one_of(dt, u8, s8, s32)) uni_vcvtdq2ps(v, v); |
544 | } |
545 | |
// Zmm variant: saturates/converts `v` from f32 to the destination data
// type of `arg_num` and stores it, using kreg_rem_mask for tails.
// bf16 is down-converted to a Ymm (native vcvtneps2bf16 or emulation).
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::cvt_and_store(const Xbyak::Zmm v,
        const arg_t arg_num, const size_t off, const size_t tail) {
    using namespace data_type;
    const data_type_t dt = get_data_type(arg_num);
    if (!utils::one_of(dt, f32, bf16)) {
        // integer dst: clamp to the dst type's range, then f32 -> s32
        Vmm vreg = Vmm(v.getIdx()); // in case of use Ymm for bf16
        saturate_f32(vreg, vreg_zero, vreg_saturation_ubound, dt);
        vcvtps2dq(v, v);
    } else if (dt == bf16) {
        if (isa == avx512_core_bf16)
            vcvtneps2bf16(Ymm(v.getIdx()), v);
        else
            bf16_emu_->vcvtneps2bf16(Ymm(v.getIdx()), v);
    }

    // masked store covers the tail case on avx512
    auto v_src = tail ? v | kreg_rem_mask : v;
    const Xbyak::Address dst = get_address(arg_num, off);
    switch (dt) {
        case s8: vpmovsdb(dst, v_src); break; // s32 -> s8 with saturation
        case u8: vpmovusdb(dst, v_src); break; // s32 -> u8 with saturation
        case f32:
        case s32: uni_vmovups(dst, v_src); break;
        case bf16:
            // bf16 result lives in the Ymm half after down-conversion
            vmovdqu16(dst,
                    tail ? Ymm(v.getIdx()) | kreg_rem_mask : Ymm(v.getIdx()));
            break;
        default: assert(!"unimplemented" );
    }
}
576 | |
// Ymm (avx2) variant: saturates/converts `v` from f32 to the destination
// data type of `arg_num` and stores it. Integer narrowing is done by
// pack + permute; tails use byte inserts (s8/u8) or masked moves (32-bit).
// bf16 is delegated to the Zmm implementation.
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::cvt_and_store(const Xbyak::Ymm v,
        const arg_t arg_num, const size_t off, const size_t tail) {
    using namespace data_type;
    const data_type_t dt = get_data_type(arg_num);
    const Xbyak::Address dst = get_address(arg_num, off);
    const Xbyak::Xmm x = Xbyak::Xmm(v.getIdx());
    if (dt == bf16) {
        // use Zmm implementation for bf16 with Ymm
        cvt_and_store(Xbyak::Zmm(v.getIdx()), arg_num, off, tail);
        return;
    } else if (utils::one_of(dt, s8, u8, s32)) {
        saturate_f32(v, vreg_zero, vreg_saturation_ubound, dt);
        vcvtps2dq(v, v);

        if (dt != s32) {
            // narrow 8 x s32 down to 8 x s8/u8 in the low qword
            // v = { 8x32 }
            vpackssdw(v, v, vreg_zero);
            // v = { 4x16, 0, 4x16, 0 }
            vpermq(v, v, 0x58);
            // v = { 8x16, 0 }
            if (dt == s8)
                vpacksswb(v, v, vreg_zero);
            else
                vpackuswb(v, v, vreg_zero);
        }
    }

    if (tail) {
        switch (dt) {
            case s8:
            case u8:
                // store the packed bytes one at a time
                for (size_t i = 0; i < tail; i++)
                    vpextrb(get_address(arg_num, off + i), x, i);
                break;
            case f32:
            case s32: vmaskmovps(dst, vmm_rem_mask, v); break;
            default: assert(!"unimplemented" );
        }
    } else {
        switch (dt) {
            case s8:
            case u8: vmovq(dst, x); break; // 8 packed bytes = one qword
            case f32:
            case s32: vmovups(dst, v); break;
            default: assert(!"unimplemented" );
        }
    }
}
626 | |
// Xmm (sse41) variant: saturates/converts `v` from f32 to the destination
// data type of `arg_num` and stores it. Integer narrowing uses two packs;
// tails are stored element-wise (no masked stores on sse41).
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::cvt_and_store(const Xbyak::Xmm v,
        const arg_t arg_num, const size_t off, const size_t tail) {
    using namespace data_type;
    const data_type_t dt = get_data_type(arg_num);
    const Xbyak::Address dst = get_address(arg_num, off);
    if (utils::one_of(dt, s8, u8, s32)) {
        saturate_f32(v, vreg_zero, vreg_saturation_ubound, dt);
        uni_vcvtps2dq(v, v);

        if (dt != s32) {
            // narrow 4 x s32 down to 4 x s8/u8 in the low dword
            // v = { 8x32 }
            uni_vpackssdw(v, v, vreg_zero);
            // v = { 4x16, 0}
            if (dt == s8)
                uni_vpacksswb(v, v, vreg_zero);
            else
                uni_vpackuswb(v, v, vreg_zero);
        }
    }

    if (tail) {
        switch (dt) {
            case s8:
            case u8:
                for (size_t i = 0; i < tail; i++)
                    uni_vpextrb(get_address(arg_num, off + i), v, i);
                break;
            case f32:
            case s32: {
                const size_t dt_size = types::data_type_size(f32);
                for (size_t i = 0; i < tail; i++)
                    uni_vpextrd(get_address(arg_num, off + i * dt_size), v, i);
            } break;
            default: assert(!"unimplemented" );
        }
    } else {
        switch (dt) {
            case s8:
            case u8: uni_vmovd(dst, v); break; // 4 packed bytes = one dword
            case f32:
            case s32: uni_vmovups(dst, v); break;
            default: assert(!"unimplemented" );
        }
    }
}
673 | |
// Emits a load whose element count is only known at run time (held in
// reg_tail), optionally converting integer data to f32 afterwards.
// Non-avx512 only: avx512 handles runtime tails via opmasks instead.
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::runtime_tail_load_cvt(
        const Vmm v, const arg_t arg_num, const size_t off, bool cvt) {
    assert(!is_avx512_);
    const data_type_t dt = get_data_type(arg_num);
    const bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
    const Xbyak::Xmm x = Xbyak::Xmm(v.getIdx());
    const Xbyak::Ymm y = Xbyak::Ymm(v.getIdx());
    const Xbyak::Reg64 reg_addr = get_reg_address(arg_num);

    // callback invoked by runtime_tail_process for each possible tail size
    auto runtime_tail_load = [&](int load_size) {
        if (is_ymm)
            load_data(dt, y, reg_addr, off, load_size);
        else
            load_data(dt, x, reg_addr, off, load_size);
    };

    runtime_tail_process<Vmm>(reg_tail, reg_rem_mask, runtime_tail_load);

    if (cvt && utils::one_of(dt, u8, s8, s32)) uni_vcvtdq2ps(v, v);
}
695 | |
// Emits a store whose element count is only known at run time (held in
// reg_tail), saturating/converting f32 to an integer dst type first.
// Non-avx512 only: avx512 handles runtime tails via opmasks instead.
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::runtime_tail_cvt_store(
        const Vmm v, const arg_t arg_num, const size_t off) {
    assert(!is_avx512_);
    const data_type_t dt = get_data_type(arg_num);
    const bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
    const Xbyak::Xmm x = Xbyak::Xmm(v.getIdx());
    const Xbyak::Ymm y = Xbyak::Ymm(v.getIdx());
    const Xbyak::Reg64 reg_addr = get_reg_address(arg_num);

    if (utils::one_of(dt, u8, s8, s32)) {
        // clamp to the dst type's range, then f32 -> s32
        saturate_f32(v, vreg_zero, vreg_saturation_ubound, dt);
        uni_vcvtps2dq(v, v);
    }

    // callback invoked by runtime_tail_process for each possible tail size
    auto runtime_tail_store = [&](int store_size) {
        if (is_ymm)
            store_data(dt, y, reg_addr, off, store_size);
        else
            store_data(dt, x, reg_addr, off, store_size);
    };

    runtime_tail_process<Vmm>(reg_tail, reg_rem_mask, runtime_tail_store);
}
720 | |
721 | template <cpu_isa_t isa> |
722 | void jit_pp_kernel_t<isa>::data_copy(const Vmm v, const arg_t arg_num, |
723 | const size_t off, data_op_t data_op, const size_t tail, |
724 | const bool is_needed_runtime_tail_process, const bool do_cvt) { |
725 | if (data_op == data_op_t::load) { |
726 | if (is_needed_runtime_tail_process) |
727 | runtime_tail_load_cvt(v, arg_num, off, do_cvt); |
728 | else |
729 | load_and_cvt(v, arg_num, off, tail, do_cvt); |
730 | } else { |
731 | if (is_needed_runtime_tail_process) |
732 | runtime_tail_cvt_store(v, arg_num, off); |
733 | else |
734 | cvt_and_store(v, arg_num, off, tail); |
735 | } |
736 | } |
737 | |
738 | template <cpu_isa_t isa> |
739 | void jit_pp_kernel_t<isa>::compute_oc_channel_blk() { |
740 | // Load accumulated value, convert to float, apply bias (if any), scaling, |
741 | // and eltwise (if any); then convert to destination type and store |
742 | |
743 | auto compute = [&](size_t offset, int idx, bool runtime_tail_mask, |
744 | int tail = 0) { |
745 | const bool is_needed_runtime_tail_process |
746 | = runtime_tail_mask && tail && !is_avx512_; |
747 | |
748 | if (this->do_scale_ && this->scale_idx_mult_ == 1) |
749 | data_copy(vreg_scale, arg_t::scale, offset * sizeof(float), |
750 | data_op_t::load, tail, is_needed_runtime_tail_process, |
751 | false); |
752 | |
753 | if (this->do_binary_ && tail && is_avx512_) |
754 | kmovq(opmask_binary, kreg_rem_mask); |
755 | |
756 | const int dst_idx = vreg_dst_idx(idx); |
757 | auto vreg_dst_ = Vmm(dst_idx); |
758 | data_copy(vreg_dst_, arg_t::acc, offset * this->acc_data_type_size_, |
759 | data_op_t::load, tail, is_needed_runtime_tail_process); |
760 | |
761 | if (this->do_scale_) uni_vmulps(vreg_dst_, vreg_dst_, vreg_scale); |
762 | |
763 | if (this->do_bias()) { |
764 | auto vreg_bias_ = vreg_bias(idx); |
765 | data_copy(vreg_bias_, arg_t::bias, |
766 | offset * this->bias_data_type_size_, data_op_t::load, tail, |
767 | is_needed_runtime_tail_process); |
768 | uni_vaddps(vreg_dst_, vreg_dst_, vreg_bias_); |
769 | } |
770 | |
771 | if (this->do_sum_) { |
772 | auto vreg_prev_dst_ = vreg_prev_dst(idx); |
773 | data_copy(vreg_prev_dst_, arg_t::sum, |
774 | offset * this->dst_data_type_size_, data_op_t::load, tail, |
775 | is_needed_runtime_tail_process); |
776 | if (this->sum_zp_ != 0) |
777 | uni_vsubps(vreg_prev_dst_, vreg_prev_dst_, vreg_sum_zp); |
778 | if (this->sum_scale_ != 1.f) |
779 | uni_vfmadd231ps(vreg_dst_, vreg_prev_dst_, vreg_sum_scale); |
780 | else |
781 | uni_vaddps(vreg_dst_, vreg_dst_, vreg_prev_dst_); |
782 | } |
783 | |
784 | apply_postops(!!tail, dst_idx, offset * this->dst_data_type_size_, |
785 | is_needed_runtime_tail_process); |
786 | |
787 | if (this->do_dst_scale_) |
788 | uni_vmulps(vreg_dst_, vreg_dst_, vreg_dst_scale); |
789 | |
790 | if (this->do_dst_zero_points_) |
791 | uni_vaddps(vreg_dst_, vreg_dst_, vreg_dst_zero_points); |
792 | |
793 | data_copy(vreg_dst_, arg_t::dst, offset * this->dst_data_type_size_, |
794 | data_op_t::store, tail, is_needed_runtime_tail_process); |
795 | }; |
796 | |
797 | // Advance all pointers by an immediate |
798 | auto advance_ptrs_imm = [&](size_t offset) { |
799 | add(reg_dst, offset * this->dst_data_type_size_); |
800 | add(reg_acc, offset * this->acc_data_type_size_); |
801 | if (this->do_scale_ && this->scale_idx_mult_ == 1) |
802 | add(reg_scales, offset * sizeof(float)); |
803 | if (this->do_bias()) add(reg_bias, offset * this->bias_data_type_size_); |
804 | if (this->do_binary_) { advance_binary_postops_off(offset); } |
805 | }; |
806 | |
807 | // Advance all pointers by a value stored in a register |
808 | auto advance_ptrs_reg = [&](const Reg64 offset) { |
809 | lea(reg_dst, ptr[reg_dst + offset * this->dst_data_type_size_]); |
810 | lea(reg_acc, ptr[reg_acc + offset * this->acc_data_type_size_]); |
811 | if (this->do_scale_ && this->scale_idx_mult_ == 1) |
812 | lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]); |
813 | if (this->do_bias()) |
814 | lea(reg_bias, ptr[reg_bias + offset * this->bias_data_type_size_]); |
815 | if (this->do_binary_) advance_binary_postops_off(offset); |
816 | }; |
817 | |
    // In case of non-trivial dst_mb_strides, fix up reg_dst and reg_acc
819 | auto maybe_advance_mb_stride = [&]() { |
820 | if (!this->has_trivial_mb_stride()) { |
821 | lea(reg_dst, |
822 | ptr[reg_dst |
823 | + reg_dst_mb_stride * this->dst_data_type_size_]); |
824 | lea(reg_acc, |
825 | ptr[reg_acc |
826 | + reg_acc_mb_stride * this->acc_data_type_size_]); |
827 | } |
828 | if (this->do_binary_ && any_binary_postop_is_no_bcast_type_) |
829 | update_binary_postops_per_tensor_off(); |
830 | }; |
831 | |
832 | // Rewind pointers that point to data that is indexed by output channel |
833 | // (bias or per-oc scaling factors) |
834 | auto rewind_ptrs = [&]() { |
835 | neg(reg_oc); |
836 | if (this->do_bias()) |
837 | lea(reg_bias, ptr[reg_bias + reg_oc * this->bias_data_type_size_]); |
838 | if (this->do_scale_ && this->scale_idx_mult_ == 1) |
839 | lea(reg_scales, ptr[reg_scales + reg_oc * sizeof(float)]); |
840 | |
841 | neg(reg_oc); |
842 | }; |
843 | |
844 | // Process one row of reg_tmp elements |
845 | auto process_runtime_oc = [&]() { |
846 | Label l_loop, l_loop_tail, l_loop_end; |
847 | cmp(reg_tmp, vlen); |
848 | jl(l_loop_tail, T_NEAR); |
849 | |
850 | L(l_loop); |
851 | { |
852 | compute(0, 0, true); |
853 | advance_ptrs_imm(vlen); |
854 | |
855 | sub(reg_tmp, vlen); |
856 | cmp(reg_tmp, vlen); |
857 | jge(l_loop, T_NEAR); |
858 | } |
859 | |
860 | L(l_loop_tail); |
861 | cmp(reg_tmp, 0); |
862 | je(l_loop_end, T_NEAR); |
863 | |
864 | if (is_avx512_) { |
865 | mov(reg_rem_mask, 1); |
866 | shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here |
867 | sub(reg_rem_mask, 1); |
868 | kmovq(kreg_rem_mask, reg_rem_mask); |
869 | } |
870 | // tail size does not matter for runtime load |
871 | compute(0, 0, true, true); |
872 | advance_ptrs_reg(reg_tmp); |
873 | |
874 | L(l_loop_end); |
875 | }; |
876 | |
877 | // <-------------------- OC -------------------------------> |
878 | // |
879 | // ^ +....................+----------------------------------+ |
880 | // | : not accessed | Prologue loop | |
881 | // | +--------------------+----------------------------------+ |
882 | // | | |
883 | // M | Main loop (unrolled) | |
884 | // B | | |
885 | // +--------------------------------+----------------------+ |
886 | // | | Epilogue loop | not accessed : |
887 | // v +--------------------------------+......................+ |
888 | |
889 | if (this->dst_data_type_ == bf16 && isa != avx512_core_bf16) |
890 | bf16_emu_->init_vcvtneps2bf16(); |
891 | |
892 | // Prologue loop |
893 | Label l_prologue_end; |
894 | cmp(reg_oc_offset, 0); |
895 | je(l_prologue_end, T_NEAR); |
896 | { |
897 | mov(reg_tmp, reg_oc); |
898 | sub(reg_tmp, reg_oc_offset); |
899 | cmp(reg_tmp, reg_len); |
900 | cmovg(reg_tmp, reg_len); |
901 | sub(reg_len, reg_tmp); |
902 | process_runtime_oc(); |
903 | rewind_ptrs(); |
904 | maybe_advance_mb_stride(); |
905 | } |
906 | L(l_prologue_end); |
907 | |
908 | // Main loop |
909 | Label l_main_loop_end; |
910 | cmp(reg_len, reg_oc); |
911 | jle(l_main_loop_end, T_NEAR); |
912 | if (this->runtime_oc()) { |
913 | Label l_main_loop; |
914 | L(l_main_loop); |
915 | { |
916 | mov(reg_tmp, reg_oc); |
917 | |
918 | process_runtime_oc(); |
919 | rewind_ptrs(); |
920 | |
921 | sub(reg_len, reg_oc); |
922 | maybe_advance_mb_stride(); |
923 | cmp(reg_len, reg_oc); |
924 | jge(l_main_loop, T_NEAR); |
925 | } |
926 | } else { |
927 | Label l_main_loop; |
928 | L(l_main_loop); |
929 | { |
930 | size_t OC_loop, OC_tail; |
931 | if (this->OC_ < max_OC_loop_unroll_ * vlen) { |
932 | // Fully unroll small loops |
933 | OC_loop = 0; |
934 | OC_tail = this->OC_; |
935 | } else { |
936 | OC_loop = vlen * max_OC_loop_unroll_; |
937 | OC_tail = this->OC_ % OC_loop; |
938 | } |
939 | |
940 | assert(!!OC_loop || !!OC_tail); |
941 | |
942 | const int vlen_tail = OC_tail % vlen; |
943 | if (vlen_tail) prepare_mask(vlen_tail); |
944 | |
945 | if (OC_loop) { |
946 | mov(reg_tmp, utils::rnd_dn(this->OC_, OC_loop)); |
947 | Label l_oc_loop; |
948 | L(l_oc_loop); |
949 | { |
950 | for (size_t offset = 0; offset < OC_loop; offset += vlen) |
951 | compute(offset, offset / vlen, false); |
952 | advance_ptrs_imm(OC_loop); |
953 | sub(reg_tmp, OC_loop); |
954 | jnz(l_oc_loop); |
955 | } |
956 | } |
957 | |
958 | if (OC_tail) { |
959 | for (size_t offset = 0; offset < OC_tail; offset += vlen) { |
960 | const bool use_mask = (offset + vlen) > OC_tail; |
961 | compute(offset, offset / vlen, false, |
962 | use_mask ? vlen_tail : 0); |
963 | } |
964 | advance_ptrs_imm(OC_tail); |
965 | } |
966 | |
967 | if (any_binary_postop_is_per_oc_sp_bcast_type_ |
968 | && this->ndims_ <= 3) { |
969 | static constexpr size_t offset_oc_spatial = 1; |
970 | advance_binary_postops_per_oc_off(offset_oc_spatial); |
971 | } |
972 | |
973 | rewind_ptrs(); |
974 | sub(reg_len, reg_oc); |
975 | maybe_advance_mb_stride(); |
976 | cmp(reg_len, reg_oc); |
977 | jge(l_main_loop, T_NEAR); |
978 | } |
979 | } |
980 | L(l_main_loop_end); |
981 | |
982 | // Epilogue loop |
983 | Label l_epilogue_end; |
984 | cmp(reg_len, 0); |
985 | je(l_epilogue_end, T_NEAR); |
986 | { |
987 | mov(reg_tmp, reg_len); |
988 | process_runtime_oc(); |
989 | } |
990 | L(l_epilogue_end); |
991 | } |
992 | |
// Specialized bias-only path: processes several MB rows per vector register
// by pre-replicating the bias pattern. Selected from generate() only when
// bias is the sole post-work, OC_ <= vlen / 2 and MB_ >= vlen (see the
// dim_restrict / supported_postops checks there).
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::compute_mb_blk() {
    // Load one vector of accumulator values, add the pre-broadcast bias and
    // store to dst. 'tail' selects a masked load/store of that many elements;
    // 'runtime_tail' requests tail handling driven by a runtime register.
    auto compute = [&](int tail, bool runtime_tail = false) {
        auto vmm_bias = vreg_bias(0);
        auto vmm_dst = vreg_dst(0);
        assert(utils::one_of(this->acc_data_type_, s32, f32));
        data_copy(vmm_dst, arg_t::acc, 0, data_op_t::load, tail, runtime_tail);
        uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
        data_copy(vmm_dst, arg_t::dst, 0, data_op_t::store, tail, runtime_tail);
    };

    Label mb_main_loop, end_main_loop;

    // With OC == 1 and a 4-byte bias type the bias can be broadcast straight
    // into the vector register; otherwise it is replicated via the stack.
    bool expl_broadcast
            = this->OC_ == 1 && utils::one_of(this->bias_data_type_, s32, f32);
    size_t mb_step = vlen / this->OC_; // MB rows covered by one vector
    size_t mb_tail = this->MB_ % mb_step; // leftover MB rows
    size_t mb_oc_blk = mb_step * this->OC_; // elements handled per iteration
    size_t tail_size = mb_oc_blk % vlen; // active lanes when block < vlen
    auto vmm_bias = vreg_bias(0);

    // Native vcvtneps2bf16 is unavailable below avx512_core_bf16 — set up
    // the emulation sequence for bf16 stores.
    if (this->dst_data_type_ == bf16 && isa != avx512_core_bf16)
        bf16_emu_->init_vcvtneps2bf16();

    if (expl_broadcast) {
        // when OC == 1 bias can be loaded directly into simd
        switch (this->bias_data_type_) {
            case s32: uni_vpbroadcastd(vmm_bias, ptr[reg_bias]); break;
            case f32: uni_vbroadcastss(vmm_bias, ptr[reg_bias]); break;
            // TODO: enable broadcast for other data types
            default: assert(!"unimplemented" );
        }
    } else {
        // prepare bias data for simd computation
        prepare_mask(this->OC_); // this->OC will never be larger than vlen / 2
        load_and_cvt(vmm_bias, arg_t::bias, 0, this->OC_, false);

        // write repeated MB*OC entries into stack
        sub(rsp, mb_oc_blk * sizeof(uint32_t));
        for (size_t i = 0; i < mb_step; ++i)
            cvt_and_store(vmm_bias, arg_t::stack,
                    i * this->OC_ * sizeof(uint32_t), this->OC_);

        // load into simd
        if (tail_size) prepare_mask(tail_size);
        load_and_cvt(vmm_bias, arg_t::stack, 0, tail_size, false);
    }
    // Integer bias types were loaded as int32 lanes — convert to f32 once,
    // outside the loop.
    if (utils::one_of(this->bias_data_type_, u8, s8, s32))
        uni_vcvtdq2ps(vmm_bias, vmm_bias);
    // Main loop: consume full mb_oc_blk chunks while enough elements remain.
    L(mb_main_loop);
    {
        cmp(reg_len, mb_oc_blk);
        jl(end_main_loop, T_NEAR);

        compute(!expl_broadcast ? tail_size : 0);
        add(reg_dst, mb_oc_blk * this->dst_data_type_size_);
        add(reg_acc, mb_oc_blk * this->acc_data_type_size_);
        sub(reg_len, mb_oc_blk);
        jmp(mb_main_loop, T_NEAR);
    }
    L(end_main_loop);

    if (mb_tail > 0) {
        Label mb_tail_loop, runtime_tail, end_runtime_tail;
        tail_size = (mb_tail * this->OC_);
        if (tail_size) prepare_mask(tail_size);
        // Tail loop: consume compile-time-known tail_size chunks.
        L(mb_tail_loop);
        {
            cmp(reg_len, tail_size);
            jl(runtime_tail, T_NEAR);
            compute(tail_size);
            add(reg_dst, tail_size * this->dst_data_type_size_);
            add(reg_acc, tail_size * this->acc_data_type_size_);
            sub(reg_len, tail_size);
            jmp(mb_tail_loop, T_NEAR);
        }
        // Load tail in runtime if len < mb_tail * oc
        L(runtime_tail);
        {
            cmp(reg_len, 0);
            jle(end_runtime_tail, T_NEAR);
            mov(reg_tail, reg_len); // save tail
            if (is_avx512_) {
                // Build an opmask of reg_len low bits: (1 << len) - 1.
                mov(reg_rem_mask, 1);
                shl(reg_rem_mask, cl); // cl == last 8 bits of reg_tail
                sub(reg_rem_mask, 1);
                kmovq(kreg_rem_mask, reg_rem_mask);
            }
            compute(tail_size, !is_avx512_);
        }
        L(end_runtime_tail);
    }

    // Release the stack area used for the replicated bias pattern.
    if (!expl_broadcast) add(rsp, mb_oc_blk * sizeof(uint32_t));
}
1088 | |
// Emit the complete kernel body: unpack ker_args_t fields into registers,
// broadcast scalar parameters into vector registers, then generate either
// the specialized MB-block (bias-only) path or the generic OC-channel path.
template <cpu_isa_t isa>
void jit_pp_kernel_t<isa>::generate() {
    preamble();

#ifdef _WIN32
    // binary postops injector needs params held (in case of WIN32)
    // in rcx register that is also used as a temp reg, so the pointer to
    // params needs to be stored in extra reg
    if (this->do_binary_) mov(reg_binary_inj_param_, param1);
#endif

#define PARAM_OFF(x) offsetof(ker_args_t, x)
    mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
    mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
    mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
    if (this->do_scale_) mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
    if (this->do_dst_scale_) {
        // don't overwrite reg_param
        assert(reg_tmp_comp.getIdx() != reg_param.getIdx());
        // Broadcast the scalar dst scale across a whole vector register.
        mov(reg_tmp_comp, ptr[reg_param + PARAM_OFF(dst_scale)]);
        auto xreg_dst_scale = Xmm(vreg_dst_scale.getIdx());
        uni_vmovq(xreg_dst_scale, reg_tmp_comp);
        uni_vbroadcastss(vreg_dst_scale, xreg_dst_scale);
    }
    if (this->do_dst_zero_points_) {
        // use reg_oc as a temporary one (alas, reg_tmp = reg_param on Windows)
        mov(reg_oc, ptr[reg_param + PARAM_OFF(dst_zero_points)]);
        uni_vbroadcastss(vreg_dst_zero_points, ptr[reg_oc]);
    }
    // OC is either read from the args (runtime) or baked in as an immediate.
    if (this->runtime_oc())
        mov(reg_oc, ptr[reg_param + PARAM_OFF(oc)]);
    else
        mov(reg_oc, this->OC_);
    mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
    mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
    if (this->do_binary_) {
        // Reserve a stack frame for binary post-op offset bookkeeping.
        mov(reg_stack_frame_, rsp);
        sub(rsp, stack_space_needed_);
        if (any_binary_postop_is_per_oc_sp_bcast_type_
                || any_binary_postop_is_per_oc_bcast_type_) {
            mov(reg_tmp_comp, ptr[reg_param + PARAM_OFF(dim1_off)]);
            mov(ptr[rsp + reg_binary_post_op_oc_off_], reg_tmp_comp);
        }
        if (any_binary_postop_is_no_bcast_type_) {
            // store origin dst pointer to calculate proper binary src1 offset
            mov(reg_tmp_comp, ptr[reg_param + PARAM_OFF(dst_orig)]);
            mov(ptr[rsp + reg_origin_dst_ptr_], reg_tmp_comp);
            // init offset
            update_binary_postops_per_tensor_off();
        }
        if (any_binary_postop_is_oc_bcast_type_) {
            // initialize binary post_ops no bcast offset accumulator
            mov(reg_tmp_comp,
                    ptr[reg_param + PARAM_OFF(first_mb_matrix_addr_off)]);
            mov(ptr[rsp + reg_binary_post_op_sp_off_], reg_tmp_comp);
        }
    }
    // With a common (non per-oc) scale, broadcast it once up front.
    if (this->do_scale_ && this->scale_idx_mult_ == 0)
        uni_vbroadcastss(vreg_scale, dword[reg_scales]);
    if (!this->has_trivial_mb_stride()) {
        mov(reg_dst_mb_stride, ptr[reg_param + PARAM_OFF(dst_mb_stride)]);
        sub(reg_dst_mb_stride, reg_oc);
        // if dst and acc point to the same address (in-place), then strides
        // must match, else assume acc buffer is dense.
        xor_(reg_acc_mb_stride, reg_acc_mb_stride);
        cmp(reg_dst, reg_acc);
        cmove(reg_acc_mb_stride, reg_dst_mb_stride);
    }
#undef PARAM_OFF

    if (this->do_sum_) {
        if (this->sum_scale_ != 1.f) {
            // Broadcast the sum scale (as raw f32 bits) across the register.
            mov(reg_tmp, float2int(this->sum_scale_));
            auto xreg_sum_scale = Xmm(vreg_sum_scale.getIdx());
            uni_vmovq(xreg_sum_scale, reg_tmp);
            uni_vbroadcastss(vreg_sum_scale, xreg_sum_scale);
        }
        if (this->sum_zp_ != 0) {
            // Broadcast the integer sum zero point and convert it to f32.
            mov(reg_tmp, this->sum_zp_);
            auto xreg_sum_zp = Xmm(vreg_sum_zp.getIdx());
            uni_vmovq(xreg_sum_zp, reg_tmp);
            uni_vbroadcastss(vreg_sum_zp, xreg_sum_zp);
            uni_vcvtdq2ps(vreg_sum_zp, vreg_sum_zp);
        }
    }

    // Prepare zero and upper-bound vectors for saturating f32 -> dst
    // data type conversion on store.
    init_saturate_f32(vreg_zero, vreg_saturation_ubound, reg_tmp_comp, f32,
            this->dst_data_type_);

    // at least 2 blocks of mb within vlen
    bool dim_restrict = !this->runtime_oc() && !this->runtime_mb()
            && (this->OC_ <= vlen / 2) && (this->MB_ >= vlen);
    bool supported_postops = this->do_scale_ || this->do_eltwise_
            || this->do_binary_ || this->do_sum_ || this->do_dst_zero_points_
            || this->do_dst_scale_;
    // The fast MB-block kernel applies only when bias is the sole post-work
    // and shapes/strides allow processing multiple MB rows per vector.
    if (this->do_bias() && !supported_postops && dim_restrict
            && this->has_trivial_mb_stride()) {
        this->mb_blk_kernel_ = true;
        compute_mb_blk();
    } else {
        compute_oc_channel_blk();
    }

    if (this->do_binary_) add(rsp, stack_space_needed_);
    postamble();

    // Eltwise injector lookup tables must be emitted after the code body.
    if (this->do_eltwise_) postops_injector_->prepare_table();
}
1197 | |
1198 | template <cpu_isa_t isa> |
1199 | void jit_pp_kernel_t<isa>::operator()(void *dst, const void *acc, |
1200 | const char *bias, const float *scales, float dst_scale, size_t start, |
1201 | size_t dst_logical_off, size_t dim1_off, size_t end, size_t runtime_oc, |
1202 | dim_t dst_mb_stride, const float *dst_zero_points, |
1203 | const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, |
1204 | size_t first_mb_matrix_addr_off, const exec_ctx_t & /* ctx */, |
1205 | const memory_desc_t & /* dst_md */) const { |
1206 | |
1207 | if (end <= start) return; |
1208 | const size_t OC = this->runtime_oc() ? runtime_oc : this->OC_; |
1209 | |
1210 | ker_args_t args; |
1211 | size_t oc_offset = start % OC; |
1212 | if (this->has_trivial_mb_stride()) { |
1213 | args.dst = static_cast<char *>(dst) + this->dst_data_type_size_ * start; |
1214 | args.acc = static_cast<const char *>(acc) |
1215 | + this->acc_data_type_size_ * start; |
1216 | } else { |
1217 | const dim_t offt = (start / OC) * dst_mb_stride + oc_offset; |
1218 | args.dst = static_cast<char *>(dst) + this->dst_data_type_size_ * offt; |
1219 | // if dst and acc point to same address (inplace), then strides |
1220 | // must be similar, else assume acc buffer is dense. |
1221 | const auto stride = dst == acc ? offt : start; |
1222 | args.acc = static_cast<const char *>(acc) |
1223 | + this->acc_data_type_size_ * stride; |
1224 | } |
1225 | args.bias = bias + oc_offset * this->bias_data_type_size_; |
1226 | args.scales = scales + this->scale_idx_mult_ * oc_offset; |
1227 | args.dst_scale = dst_scale; |
1228 | args.dst_zero_points = dst_zero_points; |
1229 | args.oc = OC; |
1230 | args.len = end - start; |
1231 | args.oc_offset = oc_offset; |
1232 | args.dst_logical_off = dst_logical_off; |
1233 | args.dim1_off = dim1_off; |
1234 | args.dst_mb_stride = dst_mb_stride; |
1235 | args.first_mb_matrix_addr_off = first_mb_matrix_addr_off; |
1236 | |
1237 | args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; |
1238 | args.dst_orig = dst_orig; |
1239 | jit_generator::operator()(&args); |
1240 | } |
1241 | |
1242 | pp_kernel_t *jit_pp_kernel_create(size_t OC, size_t MB, dim_t dst_mb_stride, |
1243 | const primitive_attr_t *attr, data_type_t bias_dt, data_type_t acc_dt, |
1244 | const memory_desc_t *dst_md, bool skip_sum) { |
1245 | if (mayiuse(avx512_core_bf16)) { |
1246 | return new jit_pp_kernel_t<avx512_core_bf16>( |
1247 | OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum); |
1248 | } else if (mayiuse(avx512_core)) { |
1249 | return new jit_pp_kernel_t<avx512_core>( |
1250 | OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum); |
1251 | } else if (mayiuse(avx2)) { |
1252 | return new jit_pp_kernel_t<avx2>( |
1253 | OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum); |
1254 | } else if (mayiuse(sse41)) { |
1255 | return new jit_pp_kernel_t<sse41>( |
1256 | OC, MB, dst_mb_stride, attr, bias_dt, acc_dt, dst_md, skip_sum); |
1257 | } else { |
1258 | return nullptr; |
1259 | } |
1260 | } |
1261 | |
1262 | } // namespace inner_product_utils |
1263 | } // namespace x64 |
1264 | } // namespace cpu |
1265 | } // namespace impl |
1266 | } // namespace dnnl |
1267 | |