1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include <algorithm>
17#include <cmath>
18
19#include "common/primitive.hpp"
20#include "common/primitive_attr.hpp"
21#include "common/primitive_exec_types.hpp"
22#include "common/utils.hpp"
23#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29namespace binary_injector {
30
31static bcast_set_t get_all_strategies_supported_by_injector() {
32 return bcast_set_t {broadcasting_strategy_t::scalar,
33 broadcasting_strategy_t::per_oc,
34 broadcasting_strategy_t::per_oc_spatial,
35 broadcasting_strategy_t::per_mb_w, broadcasting_strategy_t::per_w,
36 broadcasting_strategy_t::no_broadcast};
37}
38
39bool is_data_supported(cpu_isa_t isa, data_type_t data_type) {
40 switch (data_type) {
41 case data_type::f32:
42 case data_type::s32:
43 case data_type::s8:
44 case data_type::u8: return true;
45 case data_type::bf16:
46 return is_superset(isa, avx512_core)
47 || is_superset(isa, avx2_vnni_2);
48 case data_type::f16:
49 return is_superset(isa, avx512_core_fp16)
50 || is_superset(isa, avx2_vnni_2);
51 default: return true;
52 }
53}
54
55static bool src1_desc_layout_same_as_dst_d(
56 const dnnl::impl::memory_desc_t &src1_desc,
57 const memory_desc_wrapper &dst_d) {
58 if (dst_d.md_ == nullptr) return false;
59 const auto &lhs = src1_desc;
60 const auto &rhs = *(dst_d.md_);
61
62 using namespace dnnl::impl::utils;
63 const bool is_format_any
64 = one_of(format_kind::any, lhs.format_kind, rhs.format_kind);
65
66 return lhs.ndims == rhs.ndims
67 && (is_format_any
68 || (lhs.format_kind == rhs.format_kind
69 && array_cmp(lhs.format_desc.blocking.strides,
70 rhs.format_desc.blocking.strides,
71 lhs.ndims)))
72 && array_cmp(lhs.dims, rhs.dims, lhs.ndims)
73 && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims)
74 && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims)
75 && lhs.offset0 == rhs.offset0;
76}
77
78bool is_bcast_supported(const dnnl::impl::memory_desc_t &src1_desc,
79 const memory_desc_wrapper &dst_d,
80 const bcast_set_t &supported_strategy_set) {
81 const auto bcast_type = get_rhs_arg_broadcasting_strategy(
82 src1_desc, dst_d, supported_strategy_set);
83
84 if (bcast_type == broadcasting_strategy_t::no_broadcast) {
85 // in case of no broadcast data layout of dst and src1 have to be the same
86 if (!src1_desc_layout_same_as_dst_d(src1_desc, dst_d)) return false;
87 }
88
89 return bcast_type != broadcasting_strategy_t::unsupported;
90}
91
92bool is_supported(cpu_isa_t isa, const dnnl::impl::memory_desc_t &src1_desc,
93 const memory_desc_wrapper &dst_d,
94 const bcast_set_t &supported_strategy_set) {
95 return is_data_supported(isa, src1_desc.data_type)
96 && is_bcast_supported(src1_desc, dst_d, supported_strategy_set);
97}
98
99bool binary_args_broadcast_supported(const post_ops_t &post_ops,
100 const memory_desc_wrapper &dst_d,
101 const bcast_set_t &supported_strategy_set) {
102
103 return std::none_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(),
104 [&](const post_ops_t::entry_t &entry) -> bool {
105 if (entry.is_binary()) {
106 const auto bcast_type = get_rhs_arg_broadcasting_strategy(
107 entry.binary.src1_desc, dst_d,
108 supported_strategy_set);
109 return bcast_type == broadcasting_strategy_t::unsupported;
110 }
111 return false;
112 });
113}
114
115bool binary_args_tail_supported(const post_ops_t &post_ops,
116 const memory_desc_wrapper &dst_d, int vlen,
117 const bcast_set_t &supported_strategy_set) {
118 const auto channels = dst_d.dims()[1];
119 const int vmm_l_len = vlen / 4;
120
121 return std::none_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(),
122 [&](const post_ops_t::entry_t &entry) -> bool {
123 if (entry.is_binary()) {
124 const auto bcast_type = get_rhs_arg_broadcasting_strategy(
125 entry.binary.src1_desc, dst_d,
126 supported_strategy_set);
127 return utils::one_of(bcast_type,
128 broadcasting_strategy_t::per_oc,
129 broadcasting_strategy_t::per_oc_spatial)
130 && (channels % vmm_l_len != 0);
131 }
132 return false;
133 });
134}
135
136bool binary_args_matches_tag(format_tag_t tag, const post_ops_t &post_ops) {
137 return std::all_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(),
138 [&](const post_ops_t::entry_t &entry) {
139 if (entry.is_binary()) {
140 const memory_desc_wrapper rhs_arg_d(entry.binary.src1_desc);
141 return rhs_arg_d.matches_tag(tag);
142 }
143 return true;
144 });
145}
146
147bool any_binary_postop_rhs_non_scalar_broadcast(
148 const post_ops_t &post_ops, const memory_desc_wrapper &dst_d) {
149 return std::any_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(),
150 [&](const post_ops_t::entry_t &entry) -> bool {
151 if (entry.is_binary()) {
152 const auto bcast_type = get_rhs_arg_broadcasting_strategy(
153 entry.binary.src1_desc, dst_d,
154 get_all_strategies_supported_by_injector());
155 return !utils::one_of(bcast_type,
156 broadcasting_strategy_t::scalar,
157 broadcasting_strategy_t::unsupported);
158 }
159 return false;
160 });
161}
162
// Convenience overload: checks for per-oc broadcasts against the full
// set of strategies the injector implements.
bool any_binary_postop_rhs_per_oc_broadcast(
        const post_ops_t &post_ops, const memory_desc_wrapper &dst_d) {
    return any_binary_postop_rhs_per_oc_broadcast(
            post_ops, dst_d, get_all_strategies_supported_by_injector());
}
168
169bool any_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
170 const memory_desc_wrapper &dst_d,
171 const bcast_set_t &supported_strategy_set) {
172 return std::any_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(),
173 [&](const post_ops_t::entry_t &entry) -> bool {
174 if (entry.is_binary()) {
175 const auto bcast_type = get_rhs_arg_broadcasting_strategy(
176 entry.binary.src1_desc, dst_d,
177 supported_strategy_set);
178 return bcast_type == broadcasting_strategy_t::per_oc
179 || bcast_type
180 == broadcasting_strategy_t::per_oc_spatial;
181 }
182 return false;
183 });
184}
185
// Convenience overload: evaluates `predicate` over all per-oc-broadcast
// binary post-ops, using the injector's full strategy set.
bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d,
        const std::function<bool(const memory_desc_wrapper &)> &predicate) {
    return all_binary_postop_rhs_per_oc_broadcast(post_ops, dst_d,
            get_all_strategies_supported_by_injector(), predicate);
}
192
193bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
194 const memory_desc_wrapper &dst_d,
195 const bcast_set_t &supported_strategy_set,
196 const std::function<bool(const memory_desc_wrapper &)> &predicate) {
197 return std::all_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(),
198 [&](const post_ops_t::entry_t &entry) -> bool {
199 if (entry.is_binary()) {
200 const auto bcast_type = get_rhs_arg_broadcasting_strategy(
201 entry.binary.src1_desc, dst_d,
202 supported_strategy_set);
203 if (bcast_type == broadcasting_strategy_t::per_oc
204 || bcast_type
205 == broadcasting_strategy_t::per_oc_spatial)
206 return predicate(
207 memory_desc_wrapper(entry.binary.src1_desc));
208 }
209 return true;
210 });
211}
212
// Binds the injector's static configuration: `param1` is the register
// holding the kernel's ABI argument pointer, `supported_strategy_set`
// restricts which rhs broadcast strategies the host kernel can handle,
// and `rhs_arg_static_params` carries the register/offset setup used to
// address the rhs (src1) tensor.
static_params_t::static_params_t(const Xbyak::Reg64 &param1,
        const bcast_set_t &supported_strategy_set,
        const rhs_arg_static_params_t &rhs_arg_static_params)
    : param1(param1)
    , supported_strategy_set(supported_strategy_set)
    , rhs_arg_static_params(rhs_arg_static_params) {}
219
// Convenience overload: defaults to every broadcasting strategy the
// injector implements.
static_params_t::static_params_t(const Xbyak::Reg64 &param1,
        const rhs_arg_static_params_t &rhs_arg_static_params)
    : static_params_t(param1, get_all_strategies_supported_by_injector(),
            rhs_arg_static_params) {}
224
// Overload without a caller-provided tail opmask: delegates to the
// principal constructor with a placeholder Opmask(2), reuses
// rhs_helper_reg as the tail-size register, and records that no opmask
// was actually configured (is_opmask_set == false).
rhs_arg_static_params_t::rhs_arg_static_params_t(
        std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg,
        const Xbyak::Reg64 &rhs_helper_reg,
        const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
        bool preserve_vmm_helper, std::size_t abi_param_offset,
        std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
        std::size_t tail_size, bool use_exact_tail_scalar_bcast)
    : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg,
            rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers,
            preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d,
            tail_size, Xbyak::Opmask(2), use_exact_tail_scalar_bcast,
            rhs_helper_reg, false /*is_opmask_set*/) {}
237
// Overload with a caller-provided tail opmask but no dedicated tail-size
// register: rhs_helper_reg is reused as reg_tail_size and the opmask is
// marked as configured.
rhs_arg_static_params_t::rhs_arg_static_params_t(
        std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg,
        const Xbyak::Reg64 &rhs_helper_reg,
        const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
        bool preserve_vmm_helper, std::size_t abi_param_offset,
        std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
        std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
        bool use_exact_tail_scalar_bcast)
    : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg,
            rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers,
            preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d,
            tail_size, tail_opmask, use_exact_tail_scalar_bcast, rhs_helper_reg,
            true /*is_opmask_set*/) {}
251
// Overload with both a tail opmask and an explicit tail-size register;
// forwards to the principal constructor with is_opmask_set == true.
rhs_arg_static_params_t::rhs_arg_static_params_t(
        std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg,
        const Xbyak::Reg64 &rhs_helper_reg,
        const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
        bool preserve_vmm_helper, std::size_t abi_param_offset,
        std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
        std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
        const Xbyak::Reg64 &reg_tail_size, bool use_exact_tail_scalar_bcast)
    : rhs_arg_static_params_t(rhs_dt_helper_vmm_idx, rhs_addr_reg,
            rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr_helpers,
            preserve_vmm_helper, abi_param_offset, dst_orig_offset, dst_d,
            tail_size, tail_opmask, use_exact_tail_scalar_bcast, reg_tail_size,
            true /*is_opmask_set*/) {}
265
// Principal constructor: every other rhs_arg_static_params_t overload
// delegates here. Simply stores the register assignments, stack/ABI
// offsets and tail configuration consumed at code-generation time.
rhs_arg_static_params_t::rhs_arg_static_params_t(
        std::size_t rhs_dt_helper_vmm_idx, const Xbyak::Reg64 &rhs_addr_reg,
        const Xbyak::Reg64 &rhs_helper_reg,
        const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
        bool preserve_vmm_helper, std::size_t abi_param_offset,
        std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
        std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
        bool use_exact_tail_scalar_bcast, const Xbyak::Reg64 &reg_tail_size,
        bool is_opmask_set)
    : rhs_dt_helper_vmm_idx(rhs_dt_helper_vmm_idx)
    , rhs_addr_reg(rhs_addr_reg)
    , rhs_helper_reg(rhs_helper_reg)
    , rhs_addr_cache_reg(rhs_addr_cache_reg)
    , preserve_gpr_helpers(preserve_gpr_helpers)
    , preserve_vmm_helper(preserve_vmm_helper)
    , abi_param_offset(abi_param_offset)
    , dst_orig_offset(dst_orig_offset)
    , dst_d(dst_d)
    , tail_size(tail_size)
    , tail_opmask(tail_opmask)
    , use_exact_tail_scalar_bcast(use_exact_tail_scalar_bcast)
    , reg_tail_size(reg_tail_size)
    // any non-zero tail_size implies tail processing (bool conversion)
    , is_tail(tail_size)
    , is_opmask_set_(is_opmask_set) {}
290
// Copies the static configuration into the injector instance; `host` is
// the jit_generator the injected instructions will be emitted into.
template <cpu_isa_t isa, typename Vmm>
jit_uni_binary_injector_t<isa, Vmm>::jit_uni_binary_injector_t(
        jit_generator *host, const static_params_t &static_params)
    : host_(host)
    , rhs_arg_static_params_(static_params.rhs_arg_static_params)
    , param1_(static_params.param1)
    , supported_strategy_set_(static_params.supported_strategy_set) {}
298
// Returns true when the two keys map to different values in `params`.
// A key that is absent from the map differs from any present key, while
// two absent keys are considered equal.
//
// Fixes over the previous version: the map is taken by const reference
// (it is only read), the iterators are plain values instead of
// references bound to temporaries, and the end-iterator test is spelled
// out directly instead of going through utils::one_of.
template <typename ParamsMap>
static bool params_differ(const ParamsMap &params,
        const typename ParamsMap::key_type key1,
        const typename ParamsMap::key_type key2) {
    const auto it1 = params.find(key1);
    const auto it2 = params.find(key2);
    const bool missing1 = it1 == params.end();
    const bool missing2 = it2 == params.end();
    // Present vs absent -> differ; absent vs absent -> equal.
    if (missing1 || missing2) return missing1 != missing2;
    return it1->second != it2->second;
}
308
309static bool rhs_arg_params_differ(size_t vmm_idx1, size_t vmm_idx2,
310 const rhs_arg_dynamic_params_t &rhs_arg_params,
311 broadcasting_strategy_t rhs_broadcasting_strategy) {
312
313 const auto &out_addr = rhs_arg_params.vmm_idx_to_out_addr;
314 const auto &out_reg = rhs_arg_params.vmm_idx_to_out_reg;
315 const auto &out_elem_off_val = rhs_arg_params.vmm_idx_to_out_elem_off_val;
316
317 if (rhs_broadcasting_strategy != broadcasting_strategy_t::scalar) {
318 return params_differ(out_addr, vmm_idx1, vmm_idx2)
319 || params_differ(out_reg, vmm_idx1, vmm_idx2)
320 || params_differ(out_elem_off_val, vmm_idx1, vmm_idx2);
321 }
322 return false;
323}
324
325template <cpu_isa_t isa, typename Vmm>
326int jit_uni_binary_injector_t<isa, Vmm>::adjust_temp_vmm_hint(
327 int user_hint, int start_idx, int end_idx, int max_vmm_idx) const {
328 const bool user_hint_in_vector_range
329 = user_hint >= start_idx && user_hint <= end_idx;
330 const bool user_hint_exceeded_limit = user_hint > max_vmm_idx;
331 const bool user_hint_invalid
332 = user_hint_in_vector_range || user_hint_exceeded_limit;
333
334 if (user_hint_invalid) {
335 const bool max_vmm_idx_in_vector_range
336 = max_vmm_idx >= start_idx && max_vmm_idx <= end_idx;
337
338 if (max_vmm_idx_in_vector_range || user_hint_exceeded_limit
339 || user_hint == max_vmm_idx)
340 return 0;
341 else
342 return max_vmm_idx;
343 }
344
345 return user_hint;
346}
347
348template <typename Vmm>
349static void push_vmm(jit_generator *host, const Vmm &vmm) {
350 host->sub(host->rsp, vreg_traits<Vmm>::vlen);
351 host->uni_vmovups(host->ptr[host->rsp], vmm);
352}
353
354template <typename Vmm>
355static void pop_vmm(jit_generator *host, const Vmm &vmm) {
356 host->uni_vmovups(vmm, host->ptr[host->rsp]);
357 host->add(host->rsp, vreg_traits<Vmm>::vlen);
358}
359
360static void push_opmask(jit_generator *host, const Xbyak::Opmask &k) {
361 static constexpr int k_mask_size = 8;
362 host->sub(host->rsp, k_mask_size);
363 if (mayiuse(avx512_core))
364 host->kmovq(host->ptr[host->rsp], k);
365 else
366 host->kmovw(host->ptr[host->rsp], k);
367}
368
369static void pop_opmask(jit_generator *host, const Xbyak::Opmask &k) {
370 static constexpr int k_mask_size = 8;
371 if (mayiuse(avx512_core))
372 host->kmovq(k, host->ptr[host->rsp]);
373 else
374 host->kmovw(k, host->ptr[host->rsp]);
375 host->add(host->rsp, k_mask_size);
376}
377
378template <typename Vmm>
379static void restore_stack(jit_generator *host, const Vmm &vmm) {
380 host->add(host->rsp, vreg_traits<Vmm>::vlen);
381}
382
383template <cpu_isa_t isa, typename Vmm>
384std::pair<bool, int> jit_uni_binary_injector_t<isa, Vmm>::should_preserve_vmm(
385 int curr_idx, int vmm_hint, int max_vmm_idx,
386 bool dt_helper_vmm_needed) const {
387 if (dt_helper_vmm_needed && vmm_hint == curr_idx) {
388 if (curr_idx == 0)
389 return std::make_pair(true, max_vmm_idx);
390 else
391 return std::make_pair(true, 0);
392 }
393 return std::make_pair(false, vmm_hint);
394}
395
396template <cpu_isa_t isa, typename Vmm>
397void jit_uni_binary_injector_t<isa, Vmm>::compute_vector_range(size_t start_idx,
398 size_t end_idx, std::size_t rhs_arg_idx,
399 const dnnl_post_ops::entry_t &post_op,
400 const rhs_arg_dynamic_params_t &rhs_arg_params) const {
401 injector_utils::vmm_index_set_t vmm_idxs;
402 for (size_t i = start_idx; i < end_idx; i++)
403 vmm_idxs.emplace(i);
404 compute_vector_range(vmm_idxs, rhs_arg_idx, post_op, rhs_arg_params);
405}
406
// Applies one binary post-op to every vmm in `vmm_idxs`:
//  Phase 1 - validate/repair the user's temporary-vmm hint,
//  Phase 2 - save the GPR/vmm helper registers the address computation
//            will clobber (selection depends on the broadcast strategy),
//  Phase 3 - per vmm: (re)compute the rhs address if it differs from the
//            previous vmm's, optionally preserve a colliding helper vmm,
//            and emit the actual binary operation via inject_binary().
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::compute_vector_range(
        const injector_utils::vmm_index_set_t &vmm_idxs,
        std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op,
        const rhs_arg_dynamic_params_t &rhs_arg_params) const {

    if (vmm_idxs.empty()) return;
    // vmm_index_set_t is ordered: begin/rbegin give min and max indices.
    const auto start_idx = *(vmm_idxs.begin());
    const auto end_idx = *(vmm_idxs.rbegin());

    // Phase 1 Validate temporary vmm user hint
    static constexpr int max_vmm_idx = cpu_isa_traits<isa>::n_vregs - 1;
    auto &vmm_hint = rhs_arg_static_params_.rhs_dt_helper_vmm_idx;
    vmm_hint = adjust_temp_vmm_hint(vmm_hint, start_idx, end_idx, max_vmm_idx);

    const auto rhs_broadcasting_strategy
            = get_rhs_arg_broadcasting_strategy(post_op.binary.src1_desc,
                    rhs_arg_static_params_.dst_d, supported_strategy_set_);
    const auto rhs_arg_data_type = post_op.binary.src1_desc.data_type;
    const auto &vmm_tail_idx = rhs_arg_params.vmm_tail_idx_;
    const bool tail_exists_in_range = !vmm_tail_idx.empty();
    // Pre-avx512 f32 scalar/per_oc_spatial broadcasts still need a helper
    // vmm to hold the broadcast value.
    const bool bcast_f32_non_avx512 = !is_avx512_
            && utils::one_of(rhs_broadcasting_strategy,
                    broadcasting_strategy_t::scalar,
                    broadcasting_strategy_t::per_oc_spatial)
            && rhs_arg_data_type == data_type::f32;
    const bool should_preserve_vmm_tail = tail_exists_in_range
            && (!is_avx512_
                    || !utils::one_of(rhs_broadcasting_strategy,
                            broadcasting_strategy_t::scalar,
                            broadcasting_strategy_t::per_oc_spatial)
                    || rhs_arg_data_type != data_type::f32);
    // A helper vmm is required whenever the rhs cannot be used directly
    // as an (unaligned) memory operand of the binary instruction.
    const bool dt_helper_vmm_needed
            = !binary_op_with_unaligned_mem_operand_allowed_
            || rhs_arg_data_type != data_type::f32 || bcast_f32_non_avx512
            || should_preserve_vmm_tail;
    const auto tail_load_mode = rhs_arg_params.tail_load_mode;
    const auto dst_d = rhs_arg_static_params_.dst_d;
    const int simd_w = cpu_isa_traits<isa>::vlen
            / types::data_type_size(dst_d.data_type());
    const int blk_size = dst_d.blocking_desc().inner_blks[0];
    // Offset conversions recompute rhs offsets from the dst address; the
    // div/mul sequences they emit clobber rax/rdx (and sometimes r8/r9).
    const bool use_offset_conversions
            = (!rhs_arg_params.vmm_idx_to_out_addr.empty()
                    || !rhs_arg_params.vmm_idx_to_out_reg.empty());
    const bool should_preserve_oc_offset_conversion_regs
            = use_offset_conversions
            && utils::one_of(rhs_broadcasting_strategy,
                    broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial)
            && blk_size > simd_w;
    const bool should_preserve_mb_sp_offset_conversion_regs
            = use_offset_conversions
            && utils::one_of(rhs_broadcasting_strategy,
                    broadcasting_strategy_t::per_mb_spatial,
                    broadcasting_strategy_t::per_mb_w);
    const bool should_preserve_w_offset_conversion_regs = use_offset_conversions
            && rhs_broadcasting_strategy == broadcasting_strategy_t::per_w;
    const bool should_preserve_w_or_oc_offset_conversion_regs
            = should_preserve_oc_offset_conversion_regs
            || should_preserve_w_offset_conversion_regs;

    // Phase 2 Protect temporary registers content.
    // The nested conditionals pick the exact GPR set to save for the
    // current broadcast strategy / preserve_gpr_helpers combination.
    const injector_utils::register_preserve_guard_t register_guard {host_,
            (rhs_arg_static_params_.preserve_gpr_helpers
                            && should_preserve_w_or_oc_offset_conversion_regs
                    ? std::initializer_list<Xbyak::Reg64>(
                            {rhs_arg_static_params_.rhs_addr_reg,
                                    rhs_arg_static_params_.rhs_helper_reg,
                                    rhs_arg_static_params_.rhs_addr_cache_reg,
                                    host_->rax, host_->rdx, host_->r8})
                    : rhs_arg_static_params_.preserve_gpr_helpers
                                    && should_preserve_mb_sp_offset_conversion_regs
                            ? std::initializer_list<Xbyak::Reg64>(
                                    {rhs_arg_static_params_.rhs_addr_reg,
                                            rhs_arg_static_params_
                                                    .rhs_helper_reg,
                                            rhs_arg_static_params_
                                                    .rhs_addr_cache_reg,
                                            host_->rax, host_->rdx,
                                            host_->r8, host_->r9})
                            : rhs_arg_static_params_.preserve_gpr_helpers
                                    ? std::initializer_list<Xbyak::Reg64>(
                                            {rhs_arg_static_params_
                                                            .rhs_addr_reg,
                                                    rhs_arg_static_params_
                                                            .rhs_helper_reg,
                                                    rhs_arg_static_params_
                                                            .rhs_addr_cache_reg,
                                                    host_->rax, host_->rdx})
                                    : should_preserve_w_or_oc_offset_conversion_regs
                                            ? std::initializer_list<
                                                    Xbyak::Reg64>(
                                                    {rhs_arg_static_params_
                                                                    .rhs_addr_cache_reg,
                                                            host_->rax,
                                                            host_->rdx,
                                                            host_->r8})
                                            : should_preserve_mb_sp_offset_conversion_regs
                                                    ? std::initializer_list<
                                                            Xbyak::Reg64>(
                                                            {rhs_arg_static_params_
                                                                            .rhs_addr_cache_reg,
                                                                    host_->rax,
                                                                    host_->rdx,
                                                                    host_->r8,
                                                                    host_->r9})
                                                    : use_offset_conversions
                                                            ? std::initializer_list<
                                                                    Xbyak::Reg64>(
                                                                    {rhs_arg_static_params_
                                                                                    .rhs_addr_cache_reg,
                                                                            host_->rax,
                                                                            host_->rdx})
                                                            : std::initializer_list<
                                                                    Xbyak::Reg64>()),
            (rhs_arg_static_params_.preserve_vmm_helper && dt_helper_vmm_needed
                            ? std::initializer_list<Xbyak::Xmm>({Vmm(vmm_hint)})
                            : std::initializer_list<Xbyak::Xmm>())};

    bool vmm0_was_preserved = false;
    static const Vmm zero_vmm(0);

    Xbyak::Address rhs_arg_addr(0);

    // Phase 3 Apply binary post-op over all vmms.
    for (const auto vmm_idx : vmm_idxs) {
        const bool is_start_idx = vmm_idx == start_idx;
        // Recompute the rhs address only when it differs from the one
        // prepared for the previous vmm.
        if (is_start_idx
                || rhs_arg_params_differ(vmm_idx, vmm_idx - 1, rhs_arg_params,
                        rhs_broadcasting_strategy)) {
            rhs_arg_addr = prepare_rhs_arg_addr(vmm_idx, rhs_arg_idx, post_op,
                    rhs_arg_params, rhs_broadcasting_strategy, is_start_idx);
        }

        const auto local_vmm_preservation = should_preserve_vmm(
                vmm_idx, vmm_hint, max_vmm_idx, dt_helper_vmm_needed);
        const bool &vmm_preservation_needed = local_vmm_preservation.first;
        const Vmm dst_vmm(vmm_idx);
        const bool with_tail = rhs_arg_static_params_.is_tail
                && vmm_tail_idx.find(vmm_idx) != vmm_tail_idx.cend()
                && IMPLICATION(rhs_broadcasting_strategy
                                == broadcasting_strategy_t::scalar,
                        rhs_arg_static_params_.use_exact_tail_scalar_bcast);

        // Preservation can trigger for at most one vmm_idx in the set
        // (the one equal to vmm_hint), so the single pop below balances
        // the push of zero_vmm.
        if (vmm_preservation_needed) {
            const Vmm vmm_to_preserve(local_vmm_preservation.second);
            push_vmm(host_, vmm_to_preserve);
            inject_binary(
                    post_op, dst_vmm, rhs_arg_addr, with_tail, tail_load_mode);
            pop_vmm(host_, vmm_to_preserve);
            // in case all Vmm are occupied, Vmm(0) is chosen for tmp by default,
            // so it's content needs to be preserved...

            push_vmm(host_, zero_vmm);
            vmm0_was_preserved = true;
        } else
            inject_binary(
                    post_op, dst_vmm, rhs_arg_addr, with_tail, tail_load_mode);
    }
    // ...and restored afterwards
    if (vmm0_was_preserved) pop_vmm(host_, zero_vmm);
}
571
// Builds the effective address of the rhs (src1) operand for `vmm_idx`
// according to the resolved broadcasting strategy. On the first call
// (is_first) the rhs base pointer is fetched from the post-op binary
// args array reachable through param1_ + abi_param_offset; afterwards
// the append_* helpers adjust/reuse rhs_addr_reg (via the cache reg).
// Returns a broadcast address (ptr_b) for scalar and per_oc_spatial
// strategies, a plain address for the rest.
template <cpu_isa_t isa, typename Vmm>
Xbyak::Address jit_uni_binary_injector_t<isa, Vmm>::prepare_rhs_arg_addr(
        std::size_t vmm_idx, std::size_t rhs_arg_idx,
        const dnnl_post_ops::entry_t &post_op,
        const rhs_arg_dynamic_params_t &rhs_arg_params,
        const broadcasting_strategy_t rhs_broadcasting_strategy,
        bool is_first) const {

    static constexpr auto rhs_arg_ptr_size = sizeof(const void *);
    const auto &rhs_addr_reg = rhs_arg_static_params_.rhs_addr_reg;
    const auto &abi_param_offset = rhs_arg_static_params_.abi_param_offset;
    const auto &rhs_helper_reg = rhs_arg_static_params_.rhs_helper_reg;
    const auto rhs_arg_elem_size
            = types::data_type_size(post_op.binary.src1_desc.data_type);

    if (is_first) {
        // rhs_addr_reg = ((const void **)args)[rhs_arg_idx]
        host_->mov(rhs_addr_reg, host_->ptr[param1_ + abi_param_offset]);
        host_->mov(rhs_addr_reg,
                host_->ptr[rhs_addr_reg + rhs_arg_idx * rhs_arg_ptr_size]);
    }

    switch (rhs_broadcasting_strategy) {
        // scalar: one value broadcast to the whole vector.
        case broadcasting_strategy_t::scalar: return host_->ptr_b[rhs_addr_reg];
        case broadcasting_strategy_t::no_broadcast: {
            // element-wise rhs: advance by the dst element offset.
            append_no_broadcast_offset(rhs_arg_params.vmm_idx_to_out_addr,
                    rhs_arg_params.vmm_idx_to_out_reg,
                    rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx,
                    rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first);

            return host_->ptr[rhs_addr_reg];
        }
        case broadcasting_strategy_t::per_oc:
        case broadcasting_strategy_t::per_oc_spatial: {
            // advance by the output-channel index derived from dst offset.
            append_oc_offset(rhs_arg_params.vmm_idx_to_out_addr,
                    rhs_arg_params.vmm_idx_to_out_reg,
                    rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx,
                    rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first);

            // per_oc_spatial broadcasts one channel value across lanes.
            return rhs_broadcasting_strategy
                            == broadcasting_strategy_t::per_oc_spatial
                    ? host_->ptr_b[rhs_addr_reg]
                    : host_->ptr[rhs_addr_reg];
        }
        case broadcasting_strategy_t::per_mb_spatial: {
            append_mb_sp_offset(rhs_arg_params.vmm_idx_to_out_addr,
                    rhs_arg_params.vmm_idx_to_out_reg,
                    rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx,
                    rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first);

            return host_->ptr[rhs_addr_reg];
        }
        case broadcasting_strategy_t::per_mb_w: {
            append_mb_w_offset(rhs_arg_params.vmm_idx_to_out_addr,
                    rhs_arg_params.vmm_idx_to_out_reg,
                    rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx,
                    rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first);

            return host_->ptr[rhs_addr_reg];
        }
        case broadcasting_strategy_t::per_w: {
            append_w_offset(rhs_arg_params.vmm_idx_to_out_addr,
                    rhs_arg_params.vmm_idx_to_out_reg,
                    rhs_arg_params.vmm_idx_to_out_elem_off_val, vmm_idx,
                    rhs_addr_reg, rhs_helper_reg, rhs_arg_elem_size, is_first);

            return host_->ptr[rhs_addr_reg];
        }
        default: assert(false && "Broadcasting type not supported");
    }

    // Unreachable in a correct build; keeps release builds well-formed.
    return host_->ptr[rhs_addr_reg];
}
644
// Advances rhs_addr_reg by the element offset of the current dst
// position (no-broadcast case). On the first call the base offset is
// computed from the dst address, scaled to rhs element size, added to
// addr_reg and cached in rhs_addr_cache_reg; later calls restore the
// cached base. A per-vmm constant offset, when present, is then added.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::append_no_broadcast_offset(
        const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
        const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
        const std::map<int, size_t> &vmm_idx_to_out_elem_off_val, int vmm_idx,
        const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg,
        std::size_t elem_size_bytes, bool is_first) const {

    const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx);
    const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx);

    const bool is_out_addr = it_out_addr != vmm_idx_to_out_addr.end();
    const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end();
    if (is_out_addr || is_out_reg) {
        // dst position, either given as an address or as a register.
        Xbyak::Address out_addr = is_out_addr ? it_out_addr->second
                                              : host_->ptr[it_out_reg->second];
        const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx);
        const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg;

        if (is_first) {
            // tmp_reg = dst element offset relative to the dst origin.
            calculate_no_broadcast_base(out_addr, tmp_reg);
            if (elem_size_bytes > 1) {
                // rescale element offset to rhs element size (power of 2).
                const int shift_val = std::log2(elem_size_bytes);
                host_->sal(tmp_reg, shift_val);
            }
            host_->add(addr_reg, tmp_reg);
            // cache the adjusted base for subsequent vmms.
            host_->mov(addr_cache_reg, addr_reg);
        } else {
            host_->mov(addr_reg, addr_cache_reg);
        }

        if (it_off_val != vmm_idx_to_out_elem_off_val.end()) {
            // add the compile-time per-vmm element offset.
            calculate_no_broadcast_partial(
                    it_off_val->second, tmp_reg, elem_size_bytes);
            host_->add(addr_reg, tmp_reg);
        }
    }
}
683
// Emits code computing, into out_reg, the linear dst element offset:
//   out_reg = (effective_addr(addr) - dst_orig) >> log2(sizeof(dst_dt))
// where dst_orig is read from the args structure at dst_orig_offset.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_no_broadcast_base(
        Xbyak::Address addr, const Xbyak::Reg64 &out_reg) const {
    host_->lea(out_reg, addr);
    host_->sub(out_reg,
            host_->ptr[param1_ + rhs_arg_static_params_.dst_orig_offset]);
    // byte distance -> element distance (dst data type size is a power of 2).
    host_->shr(out_reg,
            std::log2(types::data_type_size(
                    rhs_arg_static_params_.dst_d.data_type())));
}
694
695template <cpu_isa_t isa, typename Vmm>
696void jit_uni_binary_injector_t<isa, Vmm>::calculate_no_broadcast_partial(
697 const std::size_t offset, const Xbyak::Reg64 &out_reg,
698 std::size_t elem_size_bytes) const {
699 const auto offset_adj = offset >> math::ilog2q(types::data_type_size(
700 rhs_arg_static_params_.dst_d.data_type()));
701 host_->mov(out_reg,
702 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
703 : offset_adj);
704}
705
// Advances rhs_addr_reg by the output-channel index of the current dst
// position (per_oc / per_oc_spatial case). On the first call the
// channel index is derived from the dst element offset via a
// layout-specific div/mul sequence (result in rax), rescaled to rhs
// element size, added to addr_reg and cached; later calls restore the
// cached address. A per-vmm constant offset, when present, is added via
// the layout-specific *_partial helpers.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::append_oc_offset(
        const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
        const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
        const std::map<int, size_t> &vmm_idx_to_out_elem_off_val, int vmm_idx,
        const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg,
        std::size_t elem_size_bytes, bool is_first) const {

    const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx);
    const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx);

    const bool is_out_addr = it_out_addr != vmm_idx_to_out_addr.end();
    const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end();

    if (is_out_addr || is_out_reg) {
        Xbyak::Address out_addr = is_out_addr ? it_out_addr->second
                                              : host_->ptr[it_out_reg->second];
        const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx);
        const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg;

        const auto dst_d = rhs_arg_static_params_.dst_d;
        const auto strides = dst_d.blocking_desc().strides;
        const auto layout = injector_utils::get_layout_type(dst_d);

        if (is_first) {
            // tmp_reg = dst element offset relative to the dst origin.
            calculate_no_broadcast_base(out_addr, tmp_reg);

            const auto rax = host_->rax;
            const auto rdx = host_->rdx;
            const auto r8 = host_->r8;

            // The div/mul sequences below clobber rax/rdx (and r8); save
            // the caller's output register if it is one of those.
            // NOTE(review): {it_out_reg->second} is built even when
            // is_out_reg is false (it_out_reg == end()) — the guard is
            // disabled then, but the dereference itself looks suspect;
            // confirm against conditional_register_preserve_guard_t.
            const injector_utils::conditional_register_preserve_guard_t
                    register_guard {is_out_reg ? utils::one_of(
                                            it_out_reg->second, rax, rdx, r8)
                                               : false,
                            host_, {it_out_reg->second}};

            // channel index -> rax, per dst memory layout.
            switch (layout) {
                case injector_utils::layout_t::ncsp:
                    calculate_oc_ncsp_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::c_blocked:
                    calculate_oc_blocked_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::nspc:
                    calculate_oc_nspc_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::cspn:
                    calculate_oc_cspn_base(strides, tmp_reg);
                    break;
                default: assert(!"Unknown layout");
            }

            if (elem_size_bytes == 1) {
                host_->add(addr_reg, rax);
            } else {
                // scale the channel index by rhs element size (power of 2).
                const int shift_val = std::log2(elem_size_bytes);
                host_->mov(tmp_reg, rax);
                host_->sal(tmp_reg, shift_val);
                host_->add(addr_reg, tmp_reg);
            }
            // cache the adjusted rhs address for subsequent vmms.
            host_->mov(addr_cache_reg, addr_reg);
        } else {
            host_->mov(addr_reg, addr_cache_reg);
        }

        if (it_off_val != vmm_idx_to_out_elem_off_val.end()) {
            // add the compile-time per-vmm channel offset.
            switch (layout) {
                case injector_utils::layout_t::ncsp:
                    calculate_oc_ncsp_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::c_blocked:
                    calculate_oc_blocked_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::nspc:
                    calculate_oc_nspc_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::cspn:
                    calculate_oc_cspn_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                default: assert(!"Unknown layout");
            }
            host_->add(addr_reg, tmp_reg);
        }
    }
}
796
// Emits code computing the channel index for an ncsp layout:
//   c = (offset % strides[0]) / strides[1]
// Input: element offset in tmp_reg. Output: c in rax.
// Uses two unsigned divisions; clobbers rax, rdx and tmp_reg.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_oc_ncsp_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // c = (offset % strides[0]) / strides[1]
    // output = rax
    const auto rax = host_->rax;
    const auto rdx = host_->rdx;

    // rdx = offset % strides[0] (remainder of first div)
    host_->mov(rax, tmp_reg);
    host_->mov(tmp_reg, strides[0]);
    host_->xor_(rdx, rdx);
    host_->div(tmp_reg);
    // rax = rdx / strides[1]
    host_->mov(tmp_reg, strides[1]);
    host_->mov(rax, rdx);
    host_->xor_(rdx, rdx);
    host_->div(tmp_reg);
}
814
815template <cpu_isa_t isa, typename Vmm>
816void jit_uni_binary_injector_t<isa, Vmm>::calculate_oc_ncsp_partial(
817 const dim_t *strides, const std::size_t offset,
818 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
819 // c = (offset % strides[0]) / strides[1]
820 const auto offset_adj
821 = ((offset >> math::ilog2q(types::data_type_size(
822 rhs_arg_static_params_.dst_d.data_type())))
823 % strides[0])
824 / strides[1];
825 host_->mov(tmp_reg,
826 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
827 : offset_adj);
828}
829
// Emits code computing the channel index for a channel-blocked layout:
//   c = ((offset % strides[0]) / strides[1]) * blk_size + offset % blk_size
// Input: element offset in tmp_reg. Output: c in rax.
// Clobbers rax, rdx, tmp_reg and (when blk_size > simd_w) r8.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_oc_blocked_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // c = ((offset % strides[0]) / strides[1]) * strides[ndims - 1] + offset % blk_size
    // output = rax
    const auto dst_d = rhs_arg_static_params_.dst_d;
    const int simd_w = cpu_isa_traits<isa>::vlen
            / types::data_type_size(dst_d.data_type());
    const int blk_size = dst_d.blocking_desc().inner_blks[0];
    const auto rax = host_->rax;
    const auto rdx = host_->rdx;
    const auto r8 = host_->r8;

    // rax = (offset % strides[0]) / strides[1]; rdx = remainder
    calculate_oc_ncsp_base(strides, tmp_reg);

    if (blk_size > simd_w) {
        // extract c % blk_size
        // (computed from the remainder left in rdx; result kept in r8)
        host_->mov(r8, rax);
        host_->mov(rax, rdx);
        host_->mov(tmp_reg, blk_size);
        host_->xor_(rdx, rdx);
        host_->div(tmp_reg);
        host_->mov(rax, r8);
        host_->mov(r8, rdx);
    }

    // rax = block index * blk_size (+ intra-block channel when blocked
    // wider than one vector)
    host_->mov(tmp_reg, blk_size);
    host_->mul(tmp_reg);
    if (blk_size > simd_w) host_->add(rax, r8);
}
860
861template <cpu_isa_t isa, typename Vmm>
862void jit_uni_binary_injector_t<isa, Vmm>::calculate_oc_blocked_partial(
863 const dim_t *strides, const std::size_t offset,
864 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
865 // c = ((offset % strides[0]) / strides[1]) * strides[ndims - 1] + offset % blk_size
866 const auto dst_d = rhs_arg_static_params_.dst_d;
867 const int blk_size = dst_d.blocking_desc().inner_blks[0];
868 const auto offset_shr = offset >> math::ilog2q(types::data_type_size(
869 rhs_arg_static_params_.dst_d.data_type()));
870 const auto offset_adj = ((offset_shr % strides[0]) / strides[1]) * blk_size
871 + offset_shr % blk_size;
872 host_->mov(tmp_reg,
873 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
874 : offset_adj);
875}
876
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_oc_nspc_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the channel index for a channels-last (nspc)
    // layout from a linear element offset held in tmp_reg:
    //   c = offset % C
    // output = rax
    // Clobbers rax, rdx and tmp_reg. Note: uses dims()[1] (logical C), not
    // padded dims.
    const auto rax = host_->rax;
    const auto rdx = host_->rdx;
    const auto C = rhs_arg_static_params_.dst_d.dims()[1];

    host_->mov(rax, tmp_reg);
    host_->mov(tmp_reg, C);
    host_->xor_(rdx, rdx);
    host_->div(tmp_reg); // rdx = offset % C
    host_->mov(rax, rdx);
}
892
893template <cpu_isa_t isa, typename Vmm>
894void jit_uni_binary_injector_t<isa, Vmm>::calculate_oc_nspc_partial(
895 const dim_t *strides, const std::size_t offset,
896 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
897 // c = offset % C
898 const auto C = rhs_arg_static_params_.dst_d.dims()[1];
899 const auto offset_adj = (offset >> math::ilog2q(types::data_type_size(
900 rhs_arg_static_params_.dst_d.data_type())))
901 % C;
902 host_->mov(tmp_reg,
903 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
904 : offset_adj);
905}
906
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_oc_cspn_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the channel index for a cspn layout from a
    // linear element offset held in tmp_reg:
    //   c = offset / strides[1]
    // output = rax (quotient of the division)
    // Clobbers rax, rdx and tmp_reg.
    const auto rax = host_->rax;
    const auto rdx = host_->rdx;

    host_->mov(rax, tmp_reg);
    host_->mov(tmp_reg, strides[1]);
    host_->xor_(rdx, rdx);
    host_->div(tmp_reg); // rax = offset / strides[1]
}
920
921template <cpu_isa_t isa, typename Vmm>
922void jit_uni_binary_injector_t<isa, Vmm>::calculate_oc_cspn_partial(
923 const dim_t *strides, const std::size_t offset,
924 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
925 // c = offset / strides[1]
926 const auto offset_adj = (offset >> math::ilog2q(types::data_type_size(
927 rhs_arg_static_params_.dst_d.data_type())))
928 / strides[1];
929 host_->mov(tmp_reg,
930 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
931 : offset_adj);
932}
933
// Advances addr_reg by the (mb, spatial) component of the current output
// offset for the per_oc_spatial-style broadcast. The runtime part is
// computed once per unrolled group (is_first) and cached in
// rhs_addr_cache_reg; the compile-time element offset (if registered for
// this vmm index) is added as an immediate on every call.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::append_mb_sp_offset(
        const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
        const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
        const std::map<int, size_t> &vmm_idx_to_out_elem_off_val, int vmm_idx,
        const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg,
        std::size_t elem_size_bytes, bool is_first) const {

    const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx);
    const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx);

    const bool is_out_addr = it_out_addr != vmm_idx_to_out_addr.end();
    const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end();

    if (is_out_addr || is_out_reg) {
        Xbyak::Address out_addr = is_out_addr ? it_out_addr->second
                                              : host_->ptr[it_out_reg->second];
        const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx);
        const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg;

        const auto dst_d = rhs_arg_static_params_.dst_d;
        const auto strides = dst_d.blocking_desc().strides;
        const auto layout = injector_utils::get_layout_type(dst_d);

        if (is_first) {
            // tmp_reg = linear element offset of the current output point
            calculate_no_broadcast_base(out_addr, tmp_reg);

            const auto rax = host_->rax;
            const auto rdx = host_->rdx;
            const auto r8 = host_->r8;
            const auto r9 = host_->r9;

            // Preserve the output register if the layout helpers below would
            // clobber it.
            // NOTE(review): the braced initializer evaluates
            // it_out_reg->second even when is_out_reg is false (iterator at
            // end()) — presumably the guard ignores the list when the
            // condition is false; confirm against the guard's contract.
            const injector_utils::conditional_register_preserve_guard_t
                    register_guard {is_out_reg
                                    ? utils::one_of(it_out_reg->second, rax,
                                            rdx, r8, r9)
                                    : false,
                            host_, {it_out_reg->second}};

            // Layout-specific runtime computation; result lands in rax.
            switch (layout) {
                case injector_utils::layout_t::ncsp:
                    calculate_mb_sp_ncsp_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::c_blocked:
                    calculate_mb_sp_blocked_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::nspc:
                    calculate_mb_sp_nspc_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::cspn:
                    calculate_mb_sp_cspn_base(strides, tmp_reg);
                    break;
                default: assert(!"Unknown layout");
            }

            // Convert element offset to bytes (shift by log2 of elem size).
            if (elem_size_bytes == 1) {
                host_->add(addr_reg, rax);
            } else {
                const int shift_val = std::log2(elem_size_bytes);
                host_->mov(tmp_reg, rax);
                host_->sal(tmp_reg, shift_val);
                host_->add(addr_reg, tmp_reg);
            }
            // Cache the fully-resolved base address for subsequent vmms.
            host_->mov(addr_cache_reg, addr_reg);
        } else {
            host_->mov(addr_reg, addr_cache_reg);
        }

        // Add the compile-time per-vmm element offset, if any.
        if (it_off_val != vmm_idx_to_out_elem_off_val.end()) {
            switch (layout) {
                case injector_utils::layout_t::ncsp:
                    calculate_mb_sp_ncsp_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::c_blocked:
                    calculate_mb_sp_blocked_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::nspc:
                    calculate_mb_sp_nspc_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::cspn:
                    calculate_mb_sp_cspn_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                default: assert(!"Unknown layout");
            }
            host_->add(addr_reg, tmp_reg);
        }
    }
}
1026
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_sp_ncsp_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code collapsing a linear ncsp element offset (in tmp_reg) onto
    // the (mb, spatial) plane, i.e. removing the channel contribution:
    // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w)
    // mb_sp_off = (n * (stride_n/C)) + (d * stride_d) + (h * stride_h) + (w * stride_w)
    // mb_sp_off = offset - (c * stride_c) - (n * (C - 1)DHW)
    // output = rax
    // Clobbers rax, rdx, r8, r9 and tmp_reg. Uses padded C.
    const auto dst_d = rhs_arg_static_params_.dst_d;
    const auto ndims = dst_d.ndims();
    const auto C_padded = dst_d.padded_dims()[1];
    const auto D = (ndims >= 5) ? dst_d.dims()[ndims - 3] : 1;
    const auto H = (ndims >= 4) ? dst_d.dims()[ndims - 2] : 1;
    const auto W = (ndims >= 3) ? dst_d.dims()[ndims - 1] : 1;

    const auto rax = host_->rax;
    const auto rdx = host_->rdx;
    const auto r8 = host_->r8;
    const auto r9 = host_->r9;

    host_->mov(rax, tmp_reg);
    host_->mov(r9, strides[0]);
    host_->xor_(rdx, rdx);
    host_->div(r9);
    host_->mov(r8, rax);
    // r8 = n
    host_->mov(r9, strides[1]);
    host_->mov(rax, rdx);
    host_->xor_(rdx, rdx);
    host_->div(r9); // rax = c
    host_->mul(r9);
    // rax = c * stride_c
    host_->sub(tmp_reg, rax);
    // tmp_reg = offset - c * stride_c
    host_->mov(rax, r8);
    // rax = n
    host_->mov(r9, (C_padded - 1) * D * H * W);
    // n(C - 1)DHW = nCDHW - nDHW
    host_->mul(r9);
    // rax = n(C - 1)DHW
    host_->sub(tmp_reg, rax);
    host_->mov(rax, tmp_reg);
    // rax = offset - (c * stride_c) - (n * (C - 1)DHW)
}
1070
1071template <cpu_isa_t isa, typename Vmm>
1072void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_sp_ncsp_partial(
1073 const dim_t *strides, const std::size_t offset,
1074 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
1075 // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w)
1076 // mb_sp_off = (n * (stride_n/C)) + (d * stride_d) + (h * stride_h) + (w * stride_w)
1077 // mb_sp_off = offset - (c * stride_c) - (n * (C - 1)DHW)
1078
1079 const auto dst_d = rhs_arg_static_params_.dst_d;
1080 const auto ndims = dst_d.ndims();
1081 const auto C_padded = dst_d.padded_dims()[1];
1082 const auto D = (ndims >= 5) ? dst_d.dims()[ndims - 3] : 1;
1083 const auto H = (ndims >= 4) ? dst_d.dims()[ndims - 2] : 1;
1084 const auto W = (ndims >= 3) ? dst_d.dims()[ndims - 1] : 1;
1085
1086 const auto offset_shr = offset >> math::ilog2q(types::data_type_size(
1087 rhs_arg_static_params_.dst_d.data_type()));
1088 const auto c = (offset_shr % strides[0]) / strides[1];
1089 const auto n = offset_shr / strides[0];
1090 const auto offset_adj
1091 = offset_shr - (c * strides[1]) - (n * (C_padded - 1) * D * H * W);
1092 host_->mov(tmp_reg,
1093 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
1094 : offset_adj);
1095}
1096
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_sp_blocked_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the (mb, spatial) offset for a blocked layout:
    // mb_sp_off = offset - (c * stride_c) - (n * (C - 1)DHW) - c % blk_size
    // output = rax
    // Clobbers rax, rdx, r8, r9 and tmp_reg.
    const auto dst_d = rhs_arg_static_params_.dst_d;
    const int simd_w = cpu_isa_traits<isa>::vlen
            / types::data_type_size(dst_d.data_type());
    const int blk_size = dst_d.blocking_desc().inner_blks[0];

    const auto rax = host_->rax;
    const auto rdx = host_->rdx;
    const auto r8 = host_->r8;

    if (blk_size > simd_w) {
        // subtract c % blk_size from the offset before delegating to the
        // ncsp computation (only needed when one block spans more than a
        // single simd vector)
        host_->mov(r8, tmp_reg); // save original offset
        host_->mov(rax, tmp_reg);
        host_->mov(tmp_reg, blk_size);
        host_->xor_(rdx, rdx);
        host_->div(tmp_reg); // rdx = offset % blk_size
        host_->mov(tmp_reg, r8);
        host_->sub(tmp_reg, rdx); // tmp_reg = offset - offset % blk_size
    }

    // Remaining terms are identical to the plain ncsp case.
    calculate_mb_sp_ncsp_base(strides, tmp_reg);
}
1124
1125template <cpu_isa_t isa, typename Vmm>
1126void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_sp_blocked_partial(
1127 const dim_t *strides, const std::size_t offset,
1128 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
1129 // mb_sp_off = offset - (c * stride_c) - (n * (C - 1)DHW) - c % blk_size
1130
1131 const auto dst_d = rhs_arg_static_params_.dst_d;
1132 const auto ndims = dst_d.ndims();
1133 const auto C_padded = dst_d.padded_dims()[1];
1134 const auto D = (ndims >= 5) ? dst_d.dims()[ndims - 3] : 1;
1135 const auto H = (ndims >= 4) ? dst_d.dims()[ndims - 2] : 1;
1136 const auto W = (ndims >= 3) ? dst_d.dims()[ndims - 1] : 1;
1137 const int blk_size = dst_d.blocking_desc().inner_blks[0];
1138
1139 const auto offset_shr = offset >> math::ilog2q(types::data_type_size(
1140 rhs_arg_static_params_.dst_d.data_type()));
1141 const auto c = (offset_shr % strides[0]) / strides[1];
1142 const auto n = offset_shr / strides[0];
1143 const auto offset_adj = offset_shr - (c * strides[1])
1144 - (n * (C_padded - 1) * D * H * W) - c % blk_size;
1145 host_->mov(tmp_reg,
1146 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
1147 : offset_adj);
1148}
1149
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_sp_nspc_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the (mb, spatial) offset for a channels-last
    // layout from a linear element offset held in tmp_reg:
    // offset = nDHWC + dHWC + hWC + wC + c
    // mb_sp_off = nDHW + dHW + hW + w
    // mb_sp_off = offset / C
    // output = rax
    // Clobbers rax, rdx and tmp_reg. Uses padded C.
    const auto rax = host_->rax;
    const auto rdx = host_->rdx;
    const auto C = rhs_arg_static_params_.dst_d.padded_dims()[1];

    host_->mov(rax, tmp_reg);
    host_->mov(tmp_reg, C);
    host_->xor_(rdx, rdx);
    host_->div(tmp_reg); // rax = offset / C
}
1166
1167template <cpu_isa_t isa, typename Vmm>
1168void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_sp_nspc_partial(
1169 const dim_t *strides, const std::size_t offset,
1170 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
1171 // offset = nDHWC + dHWC + hWC + wC + c
1172 // mb_sp_off = nDHW + dHW + hW + w
1173 // mb_sp_off = offset / C
1174 const auto C = rhs_arg_static_params_.dst_d.padded_dims()[1];
1175 const auto offset_adj = (offset >> math::ilog2q(types::data_type_size(
1176 rhs_arg_static_params_.dst_d.data_type())))
1177 / C;
1178 host_->mov(tmp_reg,
1179 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
1180 : offset_adj);
1181}
1182
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_sp_cspn_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the (mb, spatial) offset for a cspn layout from
    // a linear element offset held in tmp_reg:
    // offset = cDHWN + dHWN + hWN + wN + n
    // mb_sp_off = dHWN + hWN + wN + n
    // mb_sp_off = offset % stride_c
    // output = rax
    // Clobbers rax, rdx and tmp_reg.
    const auto rax = host_->rax;
    const auto rdx = host_->rdx;

    host_->mov(rax, tmp_reg);
    host_->mov(tmp_reg, strides[1]);
    host_->xor_(rdx, rdx);
    host_->div(tmp_reg); // rdx = offset % stride_c
    host_->mov(rax, rdx);
}
1199
1200template <cpu_isa_t isa, typename Vmm>
1201void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_sp_cspn_partial(
1202 const dim_t *strides, const std::size_t offset,
1203 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
1204 // offset = cDHWN + dHWN + hWN + wN + n
1205 // mb_sp_off = dHWN + hWN + wN + n
1206 // mb_sp_off = offset % stride_c
1207 const auto offset_adj = (offset >> math::ilog2q(types::data_type_size(
1208 rhs_arg_static_params_.dst_d.data_type())))
1209 % strides[1];
1210 host_->mov(tmp_reg,
1211 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
1212 : offset_adj);
1213}
1214
// Advances addr_reg by the (mb, w) component of the current output offset
// for the per_mb_w broadcast. The runtime part is computed once per
// unrolled group (is_first) and cached in rhs_addr_cache_reg; the
// compile-time element offset (if registered for this vmm index) is added
// as an immediate on every call.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::append_mb_w_offset(
        const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
        const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
        const std::map<int, size_t> &vmm_idx_to_out_elem_off_val, int vmm_idx,
        const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg,
        std::size_t elem_size_bytes, bool is_first) const {

    const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx);
    const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx);

    const bool is_out_addr = it_out_addr != vmm_idx_to_out_addr.end();
    const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end();

    if (is_out_addr || is_out_reg) {
        Xbyak::Address out_addr = is_out_addr ? it_out_addr->second
                                              : host_->ptr[it_out_reg->second];
        const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx);
        const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg;

        const auto dst_d = rhs_arg_static_params_.dst_d;
        const auto strides = dst_d.blocking_desc().strides;
        const auto layout = injector_utils::get_layout_type(dst_d);

        if (is_first) {
            // tmp_reg = linear element offset of the current output point
            calculate_no_broadcast_base(out_addr, tmp_reg);

            const auto rax = host_->rax;
            const auto rdx = host_->rdx;
            const auto r8 = host_->r8;
            const auto r9 = host_->r9;

            // Preserve the output register if the layout helpers below would
            // clobber it.
            // NOTE(review): it_out_reg->second in the braced list is
            // evaluated even when is_out_reg is false (iterator at end()) —
            // confirm against the guard's contract.
            const injector_utils::conditional_register_preserve_guard_t
                    register_guard {is_out_reg
                                    ? utils::one_of(it_out_reg->second, rax,
                                            rdx, r8, r9)
                                    : false,
                            host_, {it_out_reg->second}};

            // Layout-specific runtime computation; result lands in rax.
            switch (layout) {
                case injector_utils::layout_t::ncsp:
                    calculate_mb_w_ncsp_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::c_blocked:
                    calculate_mb_w_blocked_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::nspc:
                    calculate_mb_w_nspc_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::cspn:
                    calculate_mb_w_cspn_base(strides, tmp_reg);
                    break;
                default: assert(!"Unknown layout");
            }

            // Convert element offset to bytes (shift by log2 of elem size).
            if (elem_size_bytes == 1) {
                host_->add(addr_reg, rax);
            } else {
                const int shift_val = std::log2(elem_size_bytes);
                host_->mov(tmp_reg, rax);
                host_->sal(tmp_reg, shift_val);
                host_->add(addr_reg, tmp_reg);
            }
            // Cache the fully-resolved base address for subsequent vmms.
            host_->mov(addr_cache_reg, addr_reg);
        } else {
            host_->mov(addr_reg, addr_cache_reg);
        }

        // Add the compile-time per-vmm element offset, if any.
        if (it_off_val != vmm_idx_to_out_elem_off_val.end()) {
            switch (layout) {
                case injector_utils::layout_t::ncsp:
                    calculate_mb_w_ncsp_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::c_blocked:
                    calculate_mb_w_blocked_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::nspc:
                    calculate_mb_w_nspc_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::cspn:
                    calculate_mb_w_cspn_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                default: assert(!"Unknown layout");
            }
            host_->add(addr_reg, tmp_reg);
        }
    }
}
1307
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_w_ncsp_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the (mb, w) offset for a plain ncsp layout from
    // a linear element offset held in tmp_reg:
    // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w)
    // mb_w_off = (n * (stride_n/(C*D*H))) + (w * stride_w)
    // output = rax
    // Clobbers rax, rdx, r8, r9 and tmp_reg. Uses padded C.
    const auto dst_d = rhs_arg_static_params_.dst_d;
    const auto ndims = dst_d.ndims();
    const auto C_padded = dst_d.padded_dims()[1];
    const auto D = (ndims >= 5) ? dst_d.dims()[ndims - 3] : 1;
    const auto H = (ndims >= 4) ? dst_d.dims()[ndims - 2] : 1;

    const auto rax = host_->rax;
    const auto rdx = host_->rdx;
    const auto r8 = host_->r8;
    const auto r9 = host_->r9;

    host_->mov(rax, tmp_reg);
    host_->mov(r9, strides[0]);
    host_->xor_(rdx, rdx);
    host_->div(r9);
    host_->mov(r8, rax);
    // r8 = n

    // Successive divisions peel off c, then d, then h from the remainder.
    host_->mov(r9, strides[1]);
    host_->mov(rax, rdx);
    host_->xor_(rdx, rdx);
    host_->div(r9);

    if (ndims >= 5) {
        host_->mov(r9, strides[ndims - 3]);
        host_->mov(rax, rdx);
        host_->xor_(rdx, rdx);
        host_->div(r9);
    }
    if (ndims >= 4) {
        host_->mov(r9, strides[ndims - 2]);
        host_->mov(rax, rdx);
        host_->xor_(rdx, rdx);
        host_->div(r9);
    }
    if (ndims >= 3) {
        host_->mov(r9, strides[ndims - 1]);
        host_->mov(rax, rdx);
        host_->xor_(rdx, rdx);
        host_->div(r9); // rax = w
        host_->mul(r9);
        host_->mov(tmp_reg, rax);
        // tmp_reg = w * stride_w
    }
    // tmp_reg = w * stride_w
    host_->mov(rax, r8);
    // rax = n
    host_->mov(r9, strides[0] / (C_padded * D * H));
    host_->mul(r9);
    // rax = n * (stride_n/(C*D*H))
    if (ndims >= 3) host_->add(rax, tmp_reg);
    // rax = (n * (stride_n/(C*D*H))) + (w * stride_w)
}
1367
1368template <cpu_isa_t isa, typename Vmm>
1369void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_w_ncsp_partial(
1370 const dim_t *strides, const std::size_t offset,
1371 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
1372 // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w)
1373 // mb_w_off = (n * (stride_n/(C*D*H))) + (w * stride_w)
1374 const auto dst_d = rhs_arg_static_params_.dst_d;
1375 const auto ndims = dst_d.ndims();
1376 const auto C_padded = dst_d.padded_dims()[1];
1377 const auto D = (ndims >= 5) ? dst_d.dims()[ndims - 3] : 1;
1378 const auto H = (ndims >= 4) ? dst_d.dims()[ndims - 2] : 1;
1379
1380 const auto offset_shr = offset >> math::ilog2q(types::data_type_size(
1381 rhs_arg_static_params_.dst_d.data_type()));
1382 const auto n = offset_shr / strides[0];
1383 const auto w = (offset_shr % strides[ndims - 2]) / strides[ndims - 1];
1384 const auto offset_adj = (n * (strides[0] / (C_padded * D * H)))
1385 + (w * strides[ndims - 1]);
1386 host_->mov(tmp_reg,
1387 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
1388 : offset_adj);
1389}
1390
1391template <cpu_isa_t isa, typename Vmm>
1392void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_w_blocked_base(
1393 const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
1394 // mb_w_off = (n * (stride_n/(C*D*H))) + (w * stride_w)
1395 // output = rax
1396 calculate_mb_sp_ncsp_base(strides, tmp_reg);
1397}
1398
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_w_blocked_partial(
        const dim_t *strides, const std::size_t offset,
        const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
    // mb_w_off = (n * (stride_n/(C*D*H))) + (w * stride_w)
    // For the compile-time offset the blocked layout is handled identically
    // to plain ncsp, so delegate directly.
    calculate_mb_w_ncsp_partial(strides, offset, tmp_reg, elem_size_bytes);
}
1406
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_w_nspc_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the (mb, w) offset for a channels-last layout
    // from a linear element offset held in tmp_reg:
    // offset = nDHWC + dHWC + hWC + wC + c
    // mb_w_off = nW + w
    // output = rax
    // Clobbers rax, rdx, r8, r9 and tmp_reg. Uses padded C.

    const auto dst_d = rhs_arg_static_params_.dst_d;
    const auto ndims = dst_d.ndims();
    const auto C_padded = dst_d.padded_dims()[1];
    const auto D = (ndims >= 5) ? dst_d.dims()[ndims - 3] : 1;
    const auto H = (ndims >= 4) ? dst_d.dims()[ndims - 2] : 1;

    const auto rax = host_->rax;
    const auto rdx = host_->rdx;
    const auto r8 = host_->r8;
    const auto r9 = host_->r9;

    host_->mov(rax, tmp_reg);
    host_->mov(r9, strides[0]);
    host_->xor_(rdx, rdx);
    host_->div(r9);
    host_->mov(r8, rax);
    // r8 = n
    // Successive divisions peel off d, then h from the remainder.
    if (ndims >= 5) {
        host_->mov(r9, strides[ndims - 3]);
        host_->mov(rax, rdx);
        host_->xor_(rdx, rdx);
        host_->div(r9);
    }
    if (ndims >= 4) {
        host_->mov(r9, strides[ndims - 2]);
        host_->mov(rax, rdx);
        host_->xor_(rdx, rdx);
        host_->div(r9);
    }
    if (ndims >= 3) {
        host_->mov(r9, strides[ndims - 1]);
        host_->mov(rax, rdx);
        host_->xor_(rdx, rdx);
        host_->div(r9);
        host_->mov(tmp_reg, rax);
        // tmp_reg = w
    }
    host_->mov(rax, r8);
    // rax = n
    // stride_n / (D*H*C) = W for a dense nspc layout
    host_->mov(r9, strides[0] / (D * H * C_padded));
    host_->mul(r9);
    // rax = nW
    if (ndims >= 3) host_->add(rax, tmp_reg);
}
1458
1459template <cpu_isa_t isa, typename Vmm>
1460void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_w_nspc_partial(
1461 const dim_t *strides, const std::size_t offset,
1462 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
1463 // offset = nDHWC + dHWC + hWC + wC + c
1464 // mb_w_off = nW + w
1465 const auto dst_d = rhs_arg_static_params_.dst_d;
1466 const auto ndims = dst_d.ndims();
1467 const auto W = (ndims >= 3) ? dst_d.dims()[ndims - 1] : 1;
1468
1469 const auto offset_shr = offset >> math::ilog2q(types::data_type_size(
1470 rhs_arg_static_params_.dst_d.data_type()));
1471 const auto n = offset_shr / strides[0];
1472 const auto w = (offset_shr % strides[ndims >= 4 ? ndims - 2 : 0])
1473 / strides[ndims - 1];
1474 const auto offset_adj = n * W + w;
1475 host_->mov(tmp_reg,
1476 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
1477 : offset_adj);
1478}
1479
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_w_cspn_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the (mb, w) offset for a cspn layout from a
    // linear element offset held in tmp_reg:
    // offset = cDHWN + dHWN + hWN + wN + n
    // mb_w_off = wN + n
    // output = rax
    // Clobbers rax, rdx and tmp_reg.
    const auto ndims = rhs_arg_static_params_.dst_d.ndims();
    const auto rax = host_->rax;
    const auto rdx = host_->rdx;

    host_->mov(rax, tmp_reg);
    host_->mov(tmp_reg, strides[1]);
    host_->xor_(rdx, rdx);
    host_->div(tmp_reg); // rdx = offset % stride_c (drops the c term)
    host_->mov(rax, rdx);
    // Successive divisions peel off d, then h; rdx ends up as wN + n.
    if (ndims >= 5) {
        host_->mov(tmp_reg, strides[ndims - 3]);
        host_->mov(rax, rdx);
        host_->xor_(rdx, rdx);
        host_->div(tmp_reg);
    }
    if (ndims >= 4) {
        host_->mov(tmp_reg, strides[ndims - 2]);
        host_->mov(rax, rdx);
        host_->xor_(rdx, rdx);
        host_->div(tmp_reg);
    }
}
1508
1509template <cpu_isa_t isa, typename Vmm>
1510void jit_uni_binary_injector_t<isa, Vmm>::calculate_mb_w_cspn_partial(
1511 const dim_t *strides, const std::size_t offset,
1512 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
1513 // offset = cDHWN + dHWN + hWN + wN + n
1514 // mb_w_off = wN + n
1515 const auto ndims = rhs_arg_static_params_.dst_d.ndims();
1516 const auto offset_shr = offset >> math::ilog2q(types::data_type_size(
1517 rhs_arg_static_params_.dst_d.data_type()));
1518 const auto offset_adj
1519 = ndims >= 4 ? offset_shr % strides[ndims - 2] : offset_shr;
1520 host_->mov(tmp_reg,
1521 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
1522 : offset_adj);
1523}
1524
// Advances addr_reg by the w component of the current output offset for
// the per_w broadcast. The runtime part is computed once per unrolled
// group (is_first) and cached in rhs_addr_cache_reg; the compile-time
// element offset (if registered for this vmm index) is added as an
// immediate on every call.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::append_w_offset(
        const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
        const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
        const std::map<int, size_t> &vmm_idx_to_out_elem_off_val, int vmm_idx,
        const Xbyak::Reg64 &addr_reg, const Xbyak::Reg64 &tmp_reg,
        std::size_t elem_size_bytes, bool is_first) const {

    const auto it_out_addr = vmm_idx_to_out_addr.find(vmm_idx);
    const auto it_out_reg = vmm_idx_to_out_reg.find(vmm_idx);

    const bool is_out_addr = it_out_addr != vmm_idx_to_out_addr.end();
    const bool is_out_reg = it_out_reg != vmm_idx_to_out_reg.end();

    if (is_out_addr || is_out_reg) {
        Xbyak::Address out_addr = is_out_addr ? it_out_addr->second
                                              : host_->ptr[it_out_reg->second];
        const auto it_off_val = vmm_idx_to_out_elem_off_val.find(vmm_idx);
        const auto &addr_cache_reg = rhs_arg_static_params_.rhs_addr_cache_reg;

        const auto dst_d = rhs_arg_static_params_.dst_d;
        const auto strides = dst_d.blocking_desc().strides;
        const auto layout = injector_utils::get_layout_type(dst_d);

        if (is_first) {
            // tmp_reg = linear element offset of the current output point
            calculate_no_broadcast_base(out_addr, tmp_reg);

            const auto rax = host_->rax;
            const auto rdx = host_->rdx;
            const auto r8 = host_->r8;

            // Preserve the output register if the layout helpers below would
            // clobber it (the w calculators use rax, rdx, r8 only).
            // NOTE(review): it_out_reg->second in the braced list is
            // evaluated even when is_out_reg is false (iterator at end()) —
            // confirm against the guard's contract.
            const injector_utils::conditional_register_preserve_guard_t
                    register_guard {is_out_reg ? utils::one_of(
                                            it_out_reg->second, rax, rdx, r8)
                                               : false,
                            host_, {it_out_reg->second}};

            // Layout-specific runtime computation; result lands in rax.
            switch (layout) {
                case injector_utils::layout_t::ncsp:
                    calculate_w_ncsp_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::c_blocked:
                    calculate_w_blocked_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::nspc:
                    calculate_w_nspc_base(strides, tmp_reg);
                    break;
                case injector_utils::layout_t::cspn:
                    calculate_w_cspn_base(strides, tmp_reg);
                    break;
                default: assert(!"Unknown layout");
            }

            // Convert element offset to bytes (shift by log2 of elem size).
            if (elem_size_bytes == 1) {
                host_->add(addr_reg, rax);
            } else {
                const int shift_val = std::log2(elem_size_bytes);
                host_->mov(tmp_reg, rax);
                host_->sal(tmp_reg, shift_val);
                host_->add(addr_reg, tmp_reg);
            }
            // Cache the fully-resolved base address for subsequent vmms.
            host_->mov(addr_cache_reg, addr_reg);
        } else {
            host_->mov(addr_reg, addr_cache_reg);
        }

        // Add the compile-time per-vmm element offset, if any.
        if (it_off_val != vmm_idx_to_out_elem_off_val.end()) {
            switch (layout) {
                case injector_utils::layout_t::ncsp:
                    calculate_w_ncsp_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::c_blocked:
                    calculate_w_blocked_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::nspc:
                    calculate_w_nspc_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                case injector_utils::layout_t::cspn:
                    calculate_w_cspn_partial(strides, it_off_val->second,
                            tmp_reg, elem_size_bytes);
                    break;
                default: assert(!"Unknown layout");
            }
            host_->add(addr_reg, tmp_reg);
        }
    }
}
1615
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_w_ncsp_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the w offset for a plain ncsp layout from a
    // linear element offset held in tmp_reg:
    // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w)
    // w_off = w * stride_w
    // output = rax
    // Clobbers rax, rdx, r8 and tmp_reg. Requires ndims >= 3 so that
    // strides[ndims - 2] / strides[ndims - 1] are the h/w strides.
    const auto dst_d = rhs_arg_static_params_.dst_d;
    const auto ndims = dst_d.ndims();

    const auto rax = host_->rax;
    const auto rdx = host_->rdx;
    const auto r8 = host_->r8;

    assert(ndims >= 3);

    host_->mov(rax, tmp_reg);
    host_->mov(r8, strides[ndims - 2]);
    host_->xor_(rdx, rdx);
    host_->div(r8); // rdx = offset % stride_h

    host_->mov(r8, strides[ndims - 1]);
    host_->mov(rax, rdx);
    host_->xor_(rdx, rdx);
    host_->div(r8); // rax = w
    host_->mul(r8);
    // rax = w * stride_w
}
1643
1644template <cpu_isa_t isa, typename Vmm>
1645void jit_uni_binary_injector_t<isa, Vmm>::calculate_w_ncsp_partial(
1646 const dim_t *strides, const std::size_t offset,
1647 const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
1648 // offset = (n * stride_n) + (c * stride_c) + (d * stride_d) + (h * stride_h) + (w * stride_w)
1649 // w_off = w * stride_w
1650 const auto ndims = rhs_arg_static_params_.dst_d.ndims();
1651 const auto offset_shr = offset >> math::ilog2q(types::data_type_size(
1652 rhs_arg_static_params_.dst_d.data_type()));
1653 const auto w = (offset_shr % strides[ndims - 2]) / strides[ndims - 1];
1654 const auto offset_adj = w * strides[ndims - 1];
1655 host_->mov(tmp_reg,
1656 elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
1657 : offset_adj);
1658}
1659
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_w_blocked_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // w_off = w * stride_w; the blocked layout shares the inner-dim stride
    // arithmetic with plain ncsp, so delegate.
    calculate_w_ncsp_base(strides, tmp_reg);
}
1665
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_w_blocked_partial(
        const dim_t *strides, const std::size_t offset,
        const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
    // w_off = w * stride_w; compile-time counterpart of
    // calculate_w_blocked_base — delegates to the ncsp variant.
    calculate_w_ncsp_partial(strides, offset, tmp_reg, elem_size_bytes);
}
1672
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_w_nspc_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // Emits code computing the w offset for a channels-last layout from a
    // linear element offset held in tmp_reg:
    // offset = nDHWC + dHWC + hWC + wC + c
    // w_off = w
    // output = rax
    // Clobbers rax, rdx, r8 and tmp_reg. Requires ndims >= 3.
    const auto dst_d = rhs_arg_static_params_.dst_d;
    const auto ndims = dst_d.ndims();

    const auto rax = host_->rax;
    const auto rdx = host_->rdx;
    const auto r8 = host_->r8;

    assert(ndims >= 3);

    host_->mov(rax, tmp_reg);
    host_->mov(r8, strides[ndims - 2]);
    host_->xor_(rdx, rdx);
    host_->div(r8); // rdx = offset % stride_h

    host_->mov(r8, strides[ndims - 1]);
    host_->mov(rax, rdx);
    host_->xor_(rdx, rdx);
    host_->div(r8);
    // rax = w
}
1699
// Compile-time counterpart of calculate_w_nspc_base: the offset is a known
// constant, so the W offset is computed on the host and moved into tmp_reg.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_w_nspc_partial(
        const dim_t *strides, const std::size_t offset,
        const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
    // offset = nDHWC + dHWC + hWC + wC + c
    // w_off = w
    const auto ndims = rhs_arg_static_params_.dst_d.ndims();
    // convert the byte offset into an element offset of the dst tensor
    const auto offset_shr = offset >> math::ilog2q(types::data_type_size(
                                    rhs_arg_static_params_.dst_d.data_type()));
    const auto offset_adj
            = (offset_shr % strides[ndims - 2]) / strides[ndims - 1];
    // rescale to bytes of the rhs operand when its elements are wider than 1
    host_->mov(tmp_reg,
            elem_size_bytes > 1 ? offset_adj << math::ilog2q(elem_size_bytes)
                                : offset_adj);
}
1715
// cspn layout: the stride pattern for the W dim matches nspc, so reuse it.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_w_cspn_base(
        const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const {
    // offset = cDHWN + dHWN + hWN + wN + n
    // w_off = w
    calculate_w_nspc_base(strides, tmp_reg);
}
1723
// cspn layout, compile-time partial offset: same arithmetic as nspc.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::calculate_w_cspn_partial(
        const dim_t *strides, const std::size_t offset,
        const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes) const {
    // offset = cDHWN + dHWN + hWN + wN + n
    // w_off = w
    calculate_w_nspc_partial(strides, offset, tmp_reg, elem_size_bytes);
}
1732
// Emits code for one binary post-op: dst = dst <alg> rhs.
// The rhs operand is used directly as a memory operand of the binary
// instruction when that is legal; otherwise it is first loaded (and, for
// integral types, converted to f32) into a temporary vmm.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::inject_binary(
        const dnnl_post_ops::entry_t &post_op, Vmm dst,
        const Xbyak::Address &rhs_addr, bool with_tail,
        const tail_lode_mode_t tail_load_mode) const {

    const auto &alg = post_op.binary.alg;
    // comparison algs need a dedicated emission path below avx512
    const bool cmp_op = utils::one_of(alg, alg_kind::binary_ge,
            alg_kind::binary_gt, alg_kind::binary_le, alg_kind::binary_lt,
            alg_kind::binary_eq, alg_kind::binary_ne);
    const auto &rhs_arg_data_type = post_op.binary.src1_desc.data_type;
    const bool scalar_f32
            = rhs_addr.isBroadcast() && rhs_arg_data_type == data_type::f32;
    // on avx512 a broadcast f32 scalar can be fused into the binary
    // instruction even with a tail (handled via opmask on dst below)
    const bool with_tail_not_fusable_to_binary_op
            = with_tail && !(scalar_f32 && is_avx512_);
    const bool process_rhs_arg_using_tmp_vmm
            = rhs_arg_data_type != data_type::f32 || (scalar_f32 && !is_avx512_)
            || with_tail_not_fusable_to_binary_op
            || !binary_op_with_unaligned_mem_operand_allowed_
            || (cmp_op && !is_avx512_);

    if (process_rhs_arg_using_tmp_vmm) {

        const Vmm tmp_vmm = Vmm(rhs_arg_static_params_.rhs_dt_helper_vmm_idx);

        if (rhs_addr.isBroadcast())
            execute_broadcast(rhs_arg_data_type, tmp_vmm,
                    remove_bcast_bit(rhs_addr), tail_load_mode, with_tail);
        else
            load_rhs(rhs_arg_data_type, tmp_vmm, rhs_addr, tail_load_mode,
                    with_tail);

        if (types::is_integral_dt(rhs_arg_data_type)) cvt_to_f32(tmp_vmm);

        execute_binary(alg, dst, dst, tmp_vmm);
    } else {
        const auto lhs = dst;
        const bool with_tail_fusable_to_binary_op
                = with_tail && scalar_f32 && is_avx512_;
        if (with_tail_fusable_to_binary_op) {
            assert(rhs_arg_static_params_.is_opmask_set()
                    && "Opmask is not set for tail loading avx512");
            const auto &tail_opmask = rhs_arg_static_params_.tail_opmask;
            // zero out the lanes beyond the tail via masked destination
            dst = dst | tail_opmask | host_->T_z;
        }

        execute_binary(alg, dst, lhs, rhs_addr);
    }
}
1782
// Broadcasts the rhs scalar into tmp_reg, dispatching to the opmask-, gpr-
// or statically-unrolled tail variant depending on ISA and load mode.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::execute_broadcast(
        const data_type_t &data_type, const Vmm &tmp_reg,
        const Xbyak::Address &rhs_addr, const tail_lode_mode_t tail_load_mode,
        bool with_tail) const {
    if (with_tail) {
        // DEFAULT resolves to dynamic masked loads on avx512 only
        if (tail_load_mode == tail_lode_mode_t::DYNAMIC
                || (tail_load_mode == tail_lode_mode_t::DEFAULT
                        && is_avx512_)) {
            if (is_avx512_)
                execute_broadcast_tail_with_opmask(
                        data_type, tmp_reg, rhs_addr);
            else
                execute_broadcast_tail_with_gpr(data_type, tmp_reg, rhs_addr);
        } else
            execute_broadcast_tail_statically(data_type, tmp_reg, rhs_addr,
                    rhs_arg_static_params_.tail_size);
    } else
        execute_broadcast_no_tail(data_type, tmp_reg, rhs_addr);
}
1803
// Loads a full vector of rhs elements into tmp_reg; tail dispatch mirrors
// execute_broadcast above.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::load_rhs(const data_type_t &data_type,
        const Vmm &tmp_reg, const Xbyak::Address &rhs_addr,
        const tail_lode_mode_t tail_load_mode, bool with_tail) const {
    if (with_tail) {
        if (tail_load_mode == tail_lode_mode_t::DYNAMIC
                || (tail_load_mode == tail_lode_mode_t::DEFAULT
                        && is_avx512_)) {
            if (is_avx512_)
                load_rhs_tail_dynamically_with_opmask(
                        data_type, tmp_reg, rhs_addr);
            else
                load_rhs_tail_dynamically_with_gpr(data_type, tmp_reg);
        } else
            load_rhs_tail_statically(data_type, tmp_reg, rhs_addr);
    } else
        load_rhs_no_tail(data_type, tmp_reg, rhs_addr);
}
1822
// Returns a copy of rhs_addr with the broadcast flag cleared, so the
// address can be used by plain (non-broadcast) load instructions.
template <cpu_isa_t isa, typename Vmm>
Xbyak::Address jit_uni_binary_injector_t<isa, Vmm>::remove_bcast_bit(
        const Xbyak::Address &rhs_addr) const {
    return Xbyak::Address(rhs_addr.getBit(), false, rhs_addr.getRegExp());
}
1828
// In-place conversion of packed s32 lanes to f32.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::cvt_to_f32(const Vmm &tmp_vmm) const {
    host_->vcvtdq2ps(tmp_vmm, tmp_vmm);
}
1833
// sse41 has no VEX encoding, so use the legacy cvtdq2ps form.
template <>
void jit_uni_binary_injector_t<sse41, Xbyak::Xmm>::cvt_to_f32(
        const Xbyak::Xmm &tmp_vmm) const {
    host_->cvtdq2ps(tmp_vmm, tmp_vmm);
}
1839
// Broadcasts a single rhs scalar of the given data type across tmp_vmm
// (no tail handling). bf16/f16 are widened to f32 lanes.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::execute_broadcast_no_tail(
        const data_type_t &data_type, const Vmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {
    assert(is_data_supported(isa, data_type) && "unsupported data type");
    switch (data_type) {
        case data_type::f32: host_->uni_vbroadcastss(tmp_vmm, rhs_addr); break;
        case data_type::s32: host_->uni_vpbroadcastd(tmp_vmm, rhs_addr); break;
        case data_type::s8:
        case data_type::u8:
            execute_broadcast_s8u8_no_tail(data_type, tmp_vmm, rhs_addr);
            break;
        case data_type::f16:
            if (is_avx512_core_fp16_)
                // embedded-broadcast f16 -> f32 convert
                host_->vcvtph2psx(tmp_vmm, host_->ptr_b[rhs_addr.getRegExp()]);
            else if (isa == avx2_vnni_2)
                host_->vbcstnesh2ps(tmp_vmm, rhs_addr);
            else
                assert(!"unsupported ISA for given data type");
            break;
        case data_type::bf16:
            if (is_avx512_) {
                // broadcast the 16-bit value, then shift into the f32
                // mantissa/exponent position (bf16 is the upper half of f32)
                host_->vpbroadcastw(tmp_vmm, rhs_addr);
                host_->vpslld(tmp_vmm, tmp_vmm, 0x10);
            } else if (isa == avx2_vnni_2) {
                host_->vbcstnebf162ps(tmp_vmm, rhs_addr);
            } else
                assert(!"unsupported ISA for given data type");
            break;
        default: assert(!"unsupported data type");
    }
}
1872
1873template <cpu_isa_t isa, typename Vmm>
1874void jit_uni_binary_injector_t<isa, Vmm>::execute_broadcast_s8u8_no_tail(
1875 const data_type_t &data_type, const Vmm &tmp_vmm,
1876 const Xbyak::Address &rhs_addr) const {
1877 assert(utils::one_of(data_type, data_type::s8, data_type::u8)
1878 && "unsupported data type");
1879
1880 const Xbyak::Xmm xmm(tmp_vmm.getIdx());
1881
1882 host_->uni_vpinsrb(xmm, xmm, rhs_addr, 0);
1883 if (data_type == data_type::s8)
1884 host_->uni_vpmovsxbd(xmm, xmm);
1885 else if (data_type == data_type::u8)
1886 host_->uni_vpmovzxbd(tmp_vmm, xmm);
1887 host_->uni_vpbroadcastd(tmp_vmm, xmm);
1888}
1889
// Primary template (empty): only specializations below provide the
// AVX s8/u8 broadcast helper.
template <cpu_isa_t isa, typename Vmm>
struct helper_broadcast_s8u8_t {};
1892
// AVX helper: broadcasts one s8/u8 byte to packed s32 lanes of an xmm via a
// gpr round-trip (AVX1 lacks vpbroadcast*), then runs an optional
// post-processing step (e.g. widening the xmm result to a ymm).
template <typename Vmm>
struct helper_broadcast_s8u8_t<avx, Vmm> {
    static void execute_broadcast_s8u8_no_tail(jit_generator *host,
            const int rhs_helper_reg_idx, const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr,
            const std::function<void()> &post_process) {

        if (data_type != data_type::s8 && data_type != data_type::u8)
            assert(!"unsupported data type");

        const Xbyak::Reg8 tmp_reg8 = Xbyak::Reg8(rhs_helper_reg_idx);
        const Xbyak::Reg32 tmp_reg32 = Xbyak::Reg32(rhs_helper_reg_idx);
        const auto tmp_xmm = Xbyak::Xmm(tmp_vmm.getIdx());
        host->mov(tmp_reg8, rhs_addr);
        host->vmovd(tmp_xmm, tmp_reg32);
        // replicate the byte into the low 4 byte positions...
        host->vpunpcklbw(tmp_xmm, tmp_xmm, tmp_xmm);
        host->vpshuflw(tmp_xmm, tmp_xmm, 0);
        // ...then widen each byte to a dword
        if (data_type == data_type::s8)
            host->vpmovsxbd(tmp_xmm, tmp_xmm);
        else
            host->vpmovzxbd(tmp_xmm, tmp_xmm);

        if (post_process) post_process();
    }
};
1918
// AVX/Ymm: broadcast into the lower xmm, then mirror it into the upper lane.
template <>
void jit_uni_binary_injector_t<avx, Xbyak::Ymm>::execute_broadcast_s8u8_no_tail(
        const data_type_t &data_type, const Xbyak::Ymm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {

    const auto rhs_helper_reg_idx
            = rhs_arg_static_params_.rhs_helper_reg.getIdx();
    const auto expand_xmm_to_ymm = [&] {
        const auto tmp_xmm = Xbyak::Xmm(tmp_vmm.getIdx());
        host_->vinsertf128(tmp_vmm, tmp_vmm, tmp_xmm, 1);
    };

    helper_broadcast_s8u8_t<avx, Xbyak::Ymm>::execute_broadcast_s8u8_no_tail(
            host_, rhs_helper_reg_idx, data_type, tmp_vmm, rhs_addr,
            expand_xmm_to_ymm);
}
1935
// AVX/Xmm: the helper's xmm result is already the full vector width, so no
// post-processing is needed.
template <>
void jit_uni_binary_injector_t<avx, Xbyak::Xmm>::execute_broadcast_s8u8_no_tail(
        const data_type_t &data_type, const Xbyak::Xmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {

    const auto rhs_helper_reg_idx
            = rhs_arg_static_params_.rhs_helper_reg.getIdx();
    helper_broadcast_s8u8_t<avx, Xbyak::Xmm>::execute_broadcast_s8u8_no_tail(
            host_, rhs_helper_reg_idx, data_type, tmp_vmm, rhs_addr, nullptr);
}
1946
// SSE4.1: same gpr round-trip scheme as the AVX helper, using legacy
// (non-VEX) instruction forms.
template <>
void jit_uni_binary_injector_t<sse41,
        Xbyak::Xmm>::execute_broadcast_s8u8_no_tail(const data_type_t
                                                            &data_type,
        const Xbyak::Xmm &tmp_vmm, const Xbyak::Address &rhs_addr) const {

    if (data_type == data_type::s8 || data_type == data_type::u8) {
        const auto tmp_reg64_idx
                = rhs_arg_static_params_.rhs_helper_reg.getIdx();
        const Xbyak::Reg8 tmp_reg8 = Xbyak::Reg8(tmp_reg64_idx);
        host_->mov(tmp_reg8, rhs_addr);
        const Xbyak::Reg32 tmp_reg32 = Xbyak::Reg32(tmp_reg64_idx);
        host_->movd(tmp_vmm, tmp_reg32);
        // replicate the byte into the low 4 byte positions, then widen
        host_->punpcklbw(tmp_vmm, tmp_vmm);
        host_->pshuflw(tmp_vmm, tmp_vmm, 0);
        if (data_type == data_type::s8)
            host_->pmovsxbd(tmp_vmm, tmp_vmm);
        else
            host_->pmovzxbd(tmp_vmm, tmp_vmm);
    } else
        assert(!"unsupported data type");
}
1969
// avx512 broadcast with tail: broadcast the rhs scalar and zero lanes past
// the tail using the precomputed tail opmask (T_z = zeroing-masking).
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::execute_broadcast_tail_with_opmask(
        const data_type_t &data_type, const Vmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {

    assert(is_data_supported(isa, data_type) && "unsupported data type");
    assert(rhs_arg_static_params_.is_opmask_set()
            && "Opmask is not set for tail loading avx512");
    const auto &tail_opmask = rhs_arg_static_params_.tail_opmask;

    switch (data_type) {
        case data_type::f32:
            host_->vbroadcastss(tmp_vmm | tail_opmask | host_->T_z, rhs_addr);
            break;
        case data_type::s32:
            host_->vpbroadcastd(tmp_vmm | tail_opmask | host_->T_z, rhs_addr);
            break;
        case data_type::s8:
        case data_type::u8: {
            const Xbyak::Xmm xmm(tmp_vmm.getIdx());

            // widen the byte to a dword in lane 0, then masked-broadcast
            host_->uni_vpinsrb(xmm, xmm, rhs_addr, 0);
            if (data_type == data_type::s8)
                host_->uni_vpmovsxbd(xmm, xmm);
            else if (data_type == data_type::u8)
                host_->uni_vpmovzxbd(xmm, xmm);
            host_->uni_vpbroadcastd(tmp_vmm | tail_opmask | host_->T_z, xmm);
            break;
        }
        case data_type::f16:
            if (is_avx512_core_fp16_)
                host_->vcvtph2psx(tmp_vmm | tail_opmask | host_->T_z,
                        host_->ptr_b[rhs_addr.getRegExp()]);
            else
                assert(!"unsupported masked tail processing");
            break;
        case data_type::bf16:
            // unmasked broadcast; masking is applied by the widening shift
            host_->vpbroadcastw(tmp_vmm, rhs_addr);
            host_->vpslld(tmp_vmm | tail_opmask | host_->T_z, tmp_vmm, 0x10);
            break;
        default: return;
    }
}
2013
// Number of 32-bit elements held by an xmm register.
static constexpr int xmm_size_elem = 4;

// Generic driver for AVX tail loads into a ymm (no masked loads on AVX1).
// With tail_size = quot * 4 + rem:
//  - the rem leftover elements are produced by ymm_upper_half_op,
//  - if quot != 0 a full xmm worth is produced by ymm_lower_half_op; when
//    both are needed the partially filled xmm is spilled to the stack around
//    the lower-half load and re-inserted as the upper ymm lane.
static void load_tail_avx(jit_generator *host, std::size_t ymm_idx,
        std::size_t tail_size, const std::function<void()> &init_op,
        const std::function<void(int, bool)> &ymm_upper_half_op,
        const std::function<void(int)> &ymm_lower_half_op) {

    if (init_op) init_op();

    const auto res = std::div(tail_size, xmm_size_elem);
    const auto &ymm_upper_half_op_data_size = res.rem;
    const bool should_load_lower_half = res.quot;

    if (ymm_upper_half_op_data_size && ymm_upper_half_op)
        ymm_upper_half_op(ymm_upper_half_op_data_size, should_load_lower_half);

    if (should_load_lower_half) {
        const auto tmp_xmm = Xbyak::Xmm(ymm_idx);

        // preserve the partial upper-half data while xmm is reused below
        if (ymm_upper_half_op_data_size) push_vmm(host, tmp_xmm);

        if (ymm_lower_half_op) ymm_lower_half_op(ymm_upper_half_op_data_size);

        if (ymm_upper_half_op_data_size) {
            const auto tmp_ymm = Xbyak::Ymm(ymm_idx);
            host->vinsertf128(tmp_ymm, tmp_ymm, host->ptr[host->rsp], 1);
            restore_stack(host, tmp_xmm);
        }
    }
}
2044
// Convenience overload without an init step.
static void load_tail_avx(jit_generator *host, std::size_t ymm_idx,
        std::size_t tail_size,
        const std::function<void(int, bool)> &ymm_upper_half_op,
        const std::function<void(int)> &ymm_lower_half_op) {
    load_tail_avx(host, ymm_idx, tail_size, nullptr, ymm_upper_half_op,
            ymm_lower_half_op);
}
2052
2053static Xbyak::uint8 MM_SHUFFLE(
2054 Xbyak::uint8 z, Xbyak::uint8 y, Xbyak::uint8 x, Xbyak::uint8 w) {
2055 return (((z) << 6) | ((y) << 4) | ((x) << 2) | (w));
2056}
2057
// AVX/Ymm: replicates a single f32/s32 scalar into the first tail_size
// lanes of vmm, leaving the remaining lanes untouched by the shuffles.
static void execute_broadcast_f32_tail_avx(jit_generator *host,
        const Xbyak::Ymm &vmm, const Xbyak::Address &rhs_addr,
        std::size_t tail_size) {

    const auto vmm_idx = vmm.getIdx();
    const auto tmp_xmm = Xbyak::Xmm(vmm_idx);
    // imms[0] fills lanes {0,1}; imms[1] fills lanes {0,1,2}
    static const std::array<Xbyak::uint8, 2> imms {
            {MM_SHUFFLE(3, 2, 0, 0), MM_SHUFFLE(3, 0, 0, 0)}};

    const auto init_op = [&] { host->vmovss(tmp_xmm, rhs_addr); };
    const auto upper_half_op
            = [&](int upper_half_data_size, bool should_load_lower_half) {
                  // one element is already loaded
                  if (upper_half_data_size > 1)
                      host->vshufps(tmp_xmm, tmp_xmm, tmp_xmm,
                              imms.at(upper_half_data_size - 2));
              };
    const auto lower_half_op = [&](int upper_half_data_size) {
        // fill all four lower lanes with the scalar
        host->vshufps(tmp_xmm, tmp_xmm, tmp_xmm, 0);
    };

    load_tail_avx(
            host, vmm_idx, tail_size, init_op, upper_half_op, lower_half_op);
}
2082
// AVX/Xmm: same scalar replication, limited to at most 4 lanes so a single
// shuffle suffices.
static void execute_broadcast_f32_tail_avx(jit_generator *host,
        const Xbyak::Xmm &vmm, const Xbyak::Address &rhs_addr,
        std::size_t tail_size) {

    const auto vmm_idx = vmm.getIdx();
    const auto tmp_xmm = Xbyak::Xmm(vmm_idx);
    // imms[0] fills lanes {0,1}; imms[1] fills lanes {0,1,2}
    static const std::array<Xbyak::uint8, 2> imms {
            {MM_SHUFFLE(3, 2, 0, 0), MM_SHUFFLE(3, 0, 0, 0)}};

    host->vmovss(tmp_xmm, rhs_addr);
    // one element is already loaded
    if (tail_size > 1)
        host->vshufps(tmp_xmm, tmp_xmm, tmp_xmm, imms.at(tail_size - 2));
}
2097
// Primary template (empty): only the avx2 specialization below is defined.
template <cpu_isa_t isa, typename Vmm>
struct helper_bcast_tail_t {};
2100
// avx2 helper: statically-unrolled broadcast of a scalar into tail_size
// lanes, shared by the avx2 and avx2_vnni_2 injector specializations.
template <typename Vmm>
struct helper_bcast_tail_t<avx2, Vmm> {
    static void execute_broadcast_tail_statically(jit_generator *host,
            const size_t tail_size, const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr) {
        // zero first so untouched (non-tail) lanes are well defined
        host->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);

        if (data_type == data_type::f32 || data_type == data_type::s32) {
            execute_broadcast_f32_tail_avx(host, tmp_vmm, rhs_addr, tail_size);
        } else if (data_type == data_type::u8 || data_type == data_type::s8) {
            const auto tmp_xmm = Xbyak::Xmm(tmp_vmm.getIdx());
            // insert the same byte tail_size times, then widen to dwords
            for (std::size_t i = 0; i < tail_size; i++)
                host->vpinsrb(tmp_xmm, tmp_xmm, rhs_addr, i);

            if (data_type == data_type::s8)
                host->vpmovsxbd(tmp_vmm, tmp_xmm);
            else
                host->vpmovzxbd(tmp_vmm, tmp_xmm);
        } else
            assert(!"unsupported data type");
    }
};
2123
// Generic fallback: static tail broadcast is only available through the
// explicit specializations below.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::execute_broadcast_tail_statically(
        const data_type_t &data_type, const Vmm &tmp_vmm,
        const Xbyak::Address &rhs_addr, const std::size_t tail_size) const {
    assert(!"unsupported tail load mode");
}
2130
// avx2_vnni_2: adds bf16/f16 support (gather 16-bit words, then widen to
// f32); all other types fall back to the shared avx2 helper.
template <>
void jit_uni_binary_injector_t<avx2_vnni_2,
        Xbyak::Ymm>::execute_broadcast_tail_statically(const data_type_t
                                                               &data_type,
        const Xbyak::Ymm &tmp_vmm, const Xbyak::Address &rhs_addr,
        const std::size_t tail_size) const {
    if (utils::one_of(data_type, data_type::bf16, data_type::f16)) {
        const auto tmp_xmm = Xbyak::Xmm(tmp_vmm.getIdx());
        host_->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);
        for (std::size_t i = 0; i < tail_size; i++)
            host_->vpinsrw(tmp_xmm, tmp_xmm, rhs_addr, i);
        if (data_type == data_type::bf16) {
            // bf16 -> f32: zero-extend to dwords and shift into position
            host_->vpmovzxwd(tmp_vmm, tmp_xmm);
            host_->vpslld(tmp_vmm, tmp_vmm, 16);
        } else // f16
            host_->vcvtph2ps(tmp_vmm, tmp_xmm);
    } else {
        helper_bcast_tail_t<avx2,
                Xbyak::Ymm>::execute_broadcast_tail_statically(host_, tail_size,
                data_type, tmp_vmm, rhs_addr);
    }
}
2153
// avx2/Ymm: delegate to the shared avx2 helper.
template <>
void jit_uni_binary_injector_t<avx2,
        Xbyak::Ymm>::execute_broadcast_tail_statically(const data_type_t
                                                               &data_type,
        const Xbyak::Ymm &tmp_vmm, const Xbyak::Address &rhs_addr,
        const std::size_t tail_size) const {
    helper_bcast_tail_t<avx2, Xbyak::Ymm>::execute_broadcast_tail_statically(
            host_, tail_size, data_type, tmp_vmm, rhs_addr);
}
2163
// avx2/Xmm: delegate to the shared avx2 helper.
template <>
void jit_uni_binary_injector_t<avx2,
        Xbyak::Xmm>::execute_broadcast_tail_statically(const data_type_t
                                                               &data_type,
        const Xbyak::Xmm &tmp_vmm, const Xbyak::Address &rhs_addr,
        const std::size_t tail_size) const {
    helper_bcast_tail_t<avx2, Xbyak::Xmm>::execute_broadcast_tail_statically(
            host_, tail_size, data_type, tmp_vmm, rhs_addr);
}
2173
// AVX1/Ymm: no vpbroadcast/masked loads, so the tail broadcast is assembled
// lane by lane with insert + shuffle through the load_tail_avx driver.
template <>
void jit_uni_binary_injector_t<avx,
        Xbyak::Ymm>::execute_broadcast_tail_statically(const data_type_t
                                                               &data_type,
        const Xbyak::Ymm &tmp_vmm, const Xbyak::Address &rhs_addr,
        const std::size_t tail_size) const {

    // zero first so untouched (non-tail) lanes are well defined
    host_->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);

    if (data_type == data_type::f32 || data_type == data_type::s32) {
        execute_broadcast_f32_tail_avx(host_, tmp_vmm, rhs_addr, tail_size);
    } else if (data_type == data_type::u8 || data_type == data_type::s8) {
        const auto vmm_idx = tmp_vmm.getIdx();
        const auto tmp_xmm = Xbyak::Xmm(vmm_idx);
        // imms[0] fills lanes {0,1}; imms[1] fills lanes {0,1,2}
        static const std::array<Xbyak::uint8, 2> imms {
                {MM_SHUFFLE(3, 2, 0, 0), MM_SHUFFLE(3, 0, 0, 0)}};

        const auto cvt_to_dword = [&] {
            if (data_type == data_type::s8)
                host_->vpmovsxbd(tmp_xmm, tmp_xmm);
            else
                host_->vpmovzxbd(tmp_xmm, tmp_xmm);
        };

        // load the byte once and widen; shuffles below replicate the dword
        const auto init_op = [&] {
            host_->vpinsrb(tmp_xmm, tmp_xmm, rhs_addr, 0);
            cvt_to_dword();
        };

        const auto upper_half_op
                = [&](int upper_half_data_size, bool should_load_lower_half) {
                      if (upper_half_data_size > 1)
                          host_->vshufps(tmp_xmm, tmp_xmm, tmp_xmm,
                                  imms.at(upper_half_data_size - 2));
                  };

        const auto lower_half_op = [&](int upper_half_data_size) {
            host_->vshufps(tmp_xmm, tmp_xmm, tmp_xmm, 0);
        };

        load_tail_avx(host_, vmm_idx, tail_size, init_op, upper_half_op,
                lower_half_op);
    } else
        assert(!"unsupported data type");
}
2219
// AVX1/Xmm: at most 4 lanes, so the scalar is inserted tail_size times and
// then widened; no stack spill machinery needed.
template <>
void jit_uni_binary_injector_t<avx,
        Xbyak::Xmm>::execute_broadcast_tail_statically(const data_type_t
                                                               &data_type,
        const Xbyak::Xmm &tmp_vmm, const Xbyak::Address &rhs_addr,
        const std::size_t tail_size) const {

    // zero first so untouched (non-tail) lanes are well defined
    host_->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);

    const auto load_tail_avx_xmm = [&]() {
        for (size_t i = 0; i < tail_size; i++)
            host_->vpinsrb(tmp_vmm, tmp_vmm, rhs_addr, i);
    };

    if (data_type == data_type::f32 || data_type == data_type::s32) {
        execute_broadcast_f32_tail_avx(host_, tmp_vmm, rhs_addr, tail_size);
    } else if (data_type == data_type::u8 || data_type == data_type::s8) {
        load_tail_avx_xmm();
        if (data_type == data_type::s8)
            host_->vpmovsxbd(tmp_vmm, tmp_vmm);
        else
            host_->vpmovzxbd(tmp_vmm, tmp_vmm);
    } else
        assert(!"unsupported data type");
}
2245
// SSE4.1: legacy-encoded variant of the xmm tail broadcast.
template <>
void jit_uni_binary_injector_t<sse41,
        Xbyak::Xmm>::execute_broadcast_tail_statically(const data_type_t
                                                               &data_type,
        const Xbyak::Xmm &tmp_vmm, const Xbyak::Address &rhs_addr,
        const std::size_t tail_size) const {

    // zero first so untouched (non-tail) lanes are well defined
    host_->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);
    if (data_type == data_type::f32 || data_type == data_type::s32) {
        // imms[0] fills lanes {0,1}; imms[1] fills lanes {0,1,2}
        static const std::array<Xbyak::uint8, 2> imms {
                {MM_SHUFFLE(3, 2, 0, 0), MM_SHUFFLE(3, 0, 0, 0)}};

        host_->movss(tmp_vmm, rhs_addr);
        if (tail_size > 1) host_->shufps(tmp_vmm, tmp_vmm, imms[tail_size - 2]);

    } else if (data_type == data_type::u8 || data_type == data_type::s8) {
        for (std::size_t i = 0; i < tail_size; i++)
            host_->pinsrb(tmp_vmm, rhs_addr, i);

        if (data_type == data_type::s8)
            host_->pmovsxbd(tmp_vmm, tmp_vmm);
        else
            host_->pmovzxbd(tmp_vmm, tmp_vmm);
    } else
        assert(!"unsupported data type");
}
2272
// Runtime-tail broadcast for non-avx512: dispatches at runtime on the tail
// size held in reg_tail_size, emitting one static variant per possible size.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::execute_broadcast_tail_with_gpr(
        const data_type_t &data_type, const Vmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {

    const Xbyak::Reg64 &reg_tmp = rhs_arg_static_params_.rhs_helper_reg;
    const Xbyak::Reg64 &reg_tail_size = rhs_arg_static_params_.reg_tail_size;

    auto runtime_tail_load = [&](int load_size) {
        execute_broadcast_tail_statically(
                data_type, tmp_vmm, rhs_addr, load_size);
    };
    host_->runtime_tail_process<Vmm>(reg_tail_size, reg_tmp, runtime_tail_load);
}
2287
// Loads a full vector of rhs elements into tmp_vmm; bf16/f16 lanes are
// widened to f32.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::load_rhs_no_tail(
        const data_type_t &data_type, const Vmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {
    assert(is_data_supported(isa, data_type) && "unsupported data type");
    switch (data_type) {
        case data_type::f32:
        case data_type::s32: host_->uni_vmovups(tmp_vmm, rhs_addr); break;
        case data_type::s8:
        case data_type::u8:
            load_rhs_i8_no_tail(data_type, tmp_vmm, rhs_addr);
            break;
        case data_type::f16:
            if (is_avx512_core_fp16_)
                host_->vcvtph2psx(tmp_vmm, rhs_addr);
            else if (isa == avx2_vnni_2)
                host_->vcvtph2ps(tmp_vmm, rhs_addr);
            else
                assert(!"unsupported ISA for given data type");
            break;
        case data_type::bf16:
            if (is_avx512_ || isa == avx2_vnni_2) {
                // bf16 -> f32: zero-extend to dwords and shift into position
                host_->vpmovzxwd(tmp_vmm, rhs_addr);
                host_->vpslld(tmp_vmm, tmp_vmm, 0x10);
                break;
            }
            // deliberate fallthrough: bf16 on other ISAs is unsupported
        default: assert(!"unsupported data type");
    }
}
2317
// Loads packed s8/u8 elements and widens each to an s32 lane.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::load_rhs_i8_no_tail(
        const data_type_t &data_type, const Vmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {
    if (data_type == data_type::s8)
        host_->uni_vpmovsxbd(tmp_vmm, rhs_addr);
    else if (data_type == data_type::u8)
        host_->uni_vpmovzxbd(tmp_vmm, rhs_addr);
    else
        assert(!"unsupported data type");
}
2329
// AVX1/Ymm: vpmovsx/zxbd only fills an xmm, so the ymm is assembled from
// two xmm loads — the upper 4 bytes first (spilled to the stack), then the
// lower 4, with vinsertf128 combining them.
template <>
void jit_uni_binary_injector_t<avx, Xbyak::Ymm>::load_rhs_i8_no_tail(
        const data_type_t &data_type, const Xbyak::Ymm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {
    static constexpr int xmm_size_elem = 4;
    static constexpr int one_load_size = xmm_size_elem * sizeof(uint8_t);
    const auto &rhs_addr_reg = rhs_arg_static_params_.rhs_addr_reg;
    const auto tmp_xmm = Xbyak::Xmm(tmp_vmm.getIdx());

    auto load_i8_fn = [&](const Xbyak::Address &addr) {
        if (data_type == data_type::s8)
            host_->uni_vpmovsxbd(tmp_xmm, addr);
        else if (data_type == data_type::u8)
            host_->uni_vpmovzxbd(tmp_xmm, addr);
        else
            assert(!"unsupported data type");
    };

    // upper half first, kept on the stack while the lower half loads
    load_i8_fn(host_->ptr[rhs_addr_reg + one_load_size]);
    push_vmm(host_, tmp_xmm);
    load_i8_fn(rhs_addr);
    host_->vinsertf128(tmp_vmm, tmp_vmm, host_->ptr[host_->rsp], 1);
    restore_stack(host_, tmp_xmm);
}
2354
// avx512 tail load: masked vector load with zeroing of out-of-tail lanes
// via the precomputed opmask.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::load_rhs_tail_dynamically_with_opmask(
        const data_type_t &data_type, const Vmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {
    assert(is_data_supported(isa, data_type) && "unsupported data type");
    assert(rhs_arg_static_params_.is_opmask_set()
            && "Opmask is not set for tail loading avx512");

    const auto &tail_opmask = rhs_arg_static_params_.tail_opmask;

    switch (data_type) {
        case data_type::f32:
        case data_type::s32:
            host_->vmovups(tmp_vmm | tail_opmask | host_->T_z, rhs_addr);
            break;
        case data_type::s8:
            host_->vpmovsxbd(tmp_vmm | tail_opmask | host_->T_z, rhs_addr);
            break;
        case data_type::u8:
            host_->vpmovzxbd(tmp_vmm | tail_opmask | host_->T_z, rhs_addr);
            break;
        case data_type::f16:
            if (is_avx512_core_fp16_)
                host_->vcvtph2psx(tmp_vmm | tail_opmask | host_->T_z, rhs_addr);
            else
                assert(!"unsupported masked tail processing");
            break;
        case data_type::bf16:
            // bf16 -> f32: zero-extend to dwords and shift into position
            host_->vpmovzxwd(tmp_vmm | tail_opmask | host_->T_z, rhs_addr);
            host_->vpslld(tmp_vmm | tail_opmask | host_->T_z, tmp_vmm, 0x10);
            break;
        default: return;
    }
}
2389
// Non-avx512 runtime tail load: zero the vector, then branch on the runtime
// tail size and load that many elements via jit_generator::load_data.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::load_rhs_tail_dynamically_with_gpr(
        const data_type_t &data_type, const Vmm &tmp_vmm) const {

    const bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
    const Xbyak::Reg64 &reg_addr = rhs_arg_static_params_.rhs_addr_reg;
    const Xbyak::Reg64 &reg_tmp = rhs_arg_static_params_.rhs_helper_reg;
    const Xbyak::Reg64 &reg_tail_size = rhs_arg_static_params_.reg_tail_size;
    const Xbyak::Xmm x = Xbyak::Xmm(tmp_vmm.getIdx());
    const Xbyak::Ymm y = Xbyak::Ymm(tmp_vmm.getIdx());

    auto runtime_tail_load = [&](int load_size) {
        if (is_ymm)
            host_->load_data(data_type, y, reg_addr, 0, load_size);
        else
            host_->load_data(data_type, x, reg_addr, 0, load_size);
    };

    host_->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);
    host_->runtime_tail_process<Vmm>(reg_tail_size, reg_tmp, runtime_tail_load);
}
2411
// Generic fallback: static tail loads are only available through the
// explicit specializations below.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::load_rhs_tail_statically(
        const data_type_t &data_type, const Vmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {
    assert(!"unsupported tail load mode");
}
// Primary template (empty): only the avx2 specialization below is defined.
template <cpu_isa_t isa, typename Vmm>
struct helper_load_tail_t {};

// avx2 helper for static tail loads: zero the vector, then load tail_size
// elements through jit_generator::load_data.
template <typename Vmm>
struct helper_load_tail_t<avx2, Vmm> {
    static void load_rhs_tail_statically(jit_generator *host,
            const size_t tail_size, const Xbyak::Reg64 &rhs_addr_reg,
            const data_type_t &data_type, const Vmm &tmp_vmm,
            const Xbyak::Address &rhs_addr) {

        if (!utils::one_of(data_type, data_type::f32, data_type::s32,
                    data_type::s8, data_type::u8))
            assert(!"unsupported data type");

        host->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);
        host->load_data(data_type, tmp_vmm, rhs_addr_reg, 0, tail_size);
    }
};
2436
// avx2/Ymm: delegate to the shared avx2 helper.
template <>
void jit_uni_binary_injector_t<avx2, Xbyak::Ymm>::load_rhs_tail_statically(
        const data_type_t &data_type, const Xbyak::Ymm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {

    const auto &tail_size = rhs_arg_static_params_.tail_size;
    const auto &rhs_addr_reg = rhs_arg_static_params_.rhs_addr_reg;
    helper_load_tail_t<avx2, Xbyak::Ymm>::load_rhs_tail_statically(
            host_, tail_size, rhs_addr_reg, data_type, tmp_vmm, rhs_addr);
}
2447
// avx2/Xmm: delegate to the shared avx2 helper.
template <>
void jit_uni_binary_injector_t<avx2, Xbyak::Xmm>::load_rhs_tail_statically(
        const data_type_t &data_type, const Xbyak::Xmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {

    const auto &tail_size = rhs_arg_static_params_.tail_size;
    const auto &rhs_addr_reg = rhs_arg_static_params_.rhs_addr_reg;
    helper_load_tail_t<avx2, Xbyak::Xmm>::load_rhs_tail_statically(
            host_, tail_size, rhs_addr_reg, data_type, tmp_vmm, rhs_addr);
}
2458
// avx2_vnni_2: adds bf16/f16 tail loads (16-bit elements widened to f32);
// other types fall back to the shared avx2 helper.
template <>
void jit_uni_binary_injector_t<avx2_vnni_2,
        Xbyak::Ymm>::load_rhs_tail_statically(const data_type_t &data_type,
        const Xbyak::Ymm &tmp_vmm, const Xbyak::Address &rhs_addr) const {

    const auto &tail_size = rhs_arg_static_params_.tail_size;
    const auto &rhs_addr_reg = rhs_arg_static_params_.rhs_addr_reg;
    const Xbyak::Xmm tmp_xmm = Xbyak::Xmm(tmp_vmm.getIdx());

    if (utils::one_of(data_type, data_type::bf16, data_type::f16)) {
        // bf16 and f16 are both 2 bytes wide, so bfloat16_t sizing is valid
        host_->load_bytes(
                tmp_xmm, rhs_addr_reg, 0, tail_size * sizeof(bfloat16_t));
        if (data_type == data_type::bf16) {
            host_->vpmovzxwd(tmp_vmm, tmp_xmm);
            host_->vpslld(tmp_vmm, tmp_vmm, 16);
        } else //f16
            host_->vcvtph2ps(tmp_vmm, tmp_xmm);
    } else {
        helper_load_tail_t<avx2, Xbyak::Ymm>::load_rhs_tail_statically(
                host_, tail_size, rhs_addr_reg, data_type, tmp_vmm, rhs_addr);
    }
}
2481
// AVX1/Ymm: static tail load assembled element by element (vpinsrd/vpinsrb)
// through the load_tail_avx driver, since AVX1 lacks masked loads.
template <>
void jit_uni_binary_injector_t<avx, Xbyak::Ymm>::load_rhs_tail_statically(
        const data_type_t &data_type, const Xbyak::Ymm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {
    const auto &tail_size = rhs_arg_static_params_.tail_size;
    const auto &rhs_addr_reg = rhs_arg_static_params_.rhs_addr_reg;

    // zero first so untouched (non-tail) lanes are well defined
    host_->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);
    static constexpr int xmm_size_elem = 4;
    const auto res = std::div(tail_size, xmm_size_elem);
    const auto vmm_idx = tmp_vmm.getIdx();
    const auto tmp_xmm = Xbyak::Xmm(vmm_idx);

    if (data_type == data_type::f32 || data_type == data_type::s32) {
        // leftover (res.rem) dwords go into the upper ymm lane; the offset
        // skips the full lower-lane elements when those are also loaded
        const auto upper_half_op = [&](int upper_half_data_size,
                                           bool should_load_lower_half) {
            const int offset = should_load_lower_half
                    ? xmm_size_elem * sizeof(float)
                    : 0;
            for (int i = 0; i < res.rem; i++)
                host_->vpinsrd(tmp_xmm, tmp_xmm,
                        host_->ptr[rhs_addr_reg + offset + i * sizeof(float)],
                        i);
        };

        const auto lower_half_op = [&](int upper_half_data_size) {
            host_->vmovups(tmp_xmm, rhs_addr);
        };
        load_tail_avx(host_, vmm_idx, tail_size, upper_half_op, lower_half_op);

    } else if (data_type == data_type::u8 || data_type == data_type::s8) {
        const auto cvt_to_dword = [&](const Xbyak::Operand &operand) {
            if (data_type == data_type::s8)
                host_->vpmovsxbd(tmp_xmm, operand);
            else
                host_->vpmovzxbd(tmp_xmm, operand);
        };

        const auto upper_half_op = [&](int upper_half_data_size,
                                           bool should_load_lower_half) {
            const int offset = should_load_lower_half ? xmm_size_elem : 0;
            for (int i = 0; i < upper_half_data_size; i++)
                host_->vpinsrb(tmp_xmm, tmp_xmm,
                        host_->ptr[rhs_addr_reg + offset + i * sizeof(int8_t)],
                        i);
            cvt_to_dword(tmp_xmm);
        };

        const auto lower_half_op
                = [&](int upper_half_data_size) { cvt_to_dword(rhs_addr); };

        load_tail_avx(host_, vmm_idx, tail_size, upper_half_op, lower_half_op);
    } else
        assert(!"unsupported data type");
}
2537
// AVX1/Xmm: at most 4 elements, inserted one at a time; s8/u8 widened after.
template <>
void jit_uni_binary_injector_t<avx, Xbyak::Xmm>::load_rhs_tail_statically(
        const data_type_t &data_type, const Xbyak::Xmm &tmp_vmm,
        const Xbyak::Address &rhs_addr) const {
    const auto &tail_size = rhs_arg_static_params_.tail_size;
    const auto &rhs_addr_reg = rhs_arg_static_params_.rhs_addr_reg;
    // zero first so untouched (non-tail) lanes are well defined
    host_->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);

    if (data_type == data_type::f32 || data_type == data_type::s32) {
        for (size_t i = 0; i < tail_size; i++)
            host_->vpinsrd(tmp_vmm, tmp_vmm,
                    host_->ptr[rhs_addr_reg + i * sizeof(float)], i);
    } else if (data_type == data_type::u8 || data_type == data_type::s8) {
        for (size_t i = 0; i < tail_size; i++)
            host_->vpinsrb(tmp_vmm, tmp_vmm,
                    host_->ptr[rhs_addr_reg + i * sizeof(int8_t)], i);
        if (data_type == data_type::s8)
            host_->vpmovsxbd(tmp_vmm, tmp_vmm);
        else
            host_->vpmovzxbd(tmp_vmm, tmp_vmm);
    } else
        assert(!"unsupported data type");
}
2561
2562template <>
2563void jit_uni_binary_injector_t<sse41, Xbyak::Xmm>::load_rhs_tail_statically(
2564 const data_type_t &data_type, const Xbyak::Xmm &tmp_vmm,
2565 const Xbyak::Address &rhs_addr) const {
2566 if (!utils::one_of(data_type, data_type::f32, data_type::s32, data_type::s8,
2567 data_type::u8))
2568 assert(!"unsupported data type");
2569
2570 const auto &tail_size = rhs_arg_static_params_.tail_size;
2571 host_->uni_vxorps(tmp_vmm, tmp_vmm, tmp_vmm);
2572 host_->load_data(data_type, tmp_vmm, rhs_arg_static_params_.rhs_addr_reg, 0,
2573 tail_size);
2574}
2575
// Support compare with Address param only when isa is avx512.
// AVX512 implementation of comparison post-ops (binary_ge/gt/le/lt/eq/ne):
// emits vcmpps into an opmask register, then broadcasts 1.0f under that
// mask with zeroing (T_z), so dst holds 1.0f where the predicate was true
// and 0.0f elsewhere. Enabled only for Zmm/Address rhs operands.
template <cpu_isa_t isa, typename Vmm>
template <typename T>
typename std::enable_if<std::is_same<T, Xbyak::Zmm>::value
        || std::is_same<T, Xbyak::Address>::value>::type
jit_uni_binary_injector_t<isa, Vmm>::execute_cmp_binary(const Vmm &dst,
        const Vmm &lhs, const T &rhs, const unsigned int cmp_predicate) const {
    // For GreaterEqual op, replace 0xFFFFFFFF by 1
    // which was returned by vcmpps.
    // The tail opmask is reused as the compare mask; its current value is
    // preserved on the stack around the compare (push/pop below).
    const auto &cmp_mask = rhs_arg_static_params_.tail_opmask;
    const Xbyak::Xmm xreg_one
            = Xbyak::Xmm(rhs_arg_static_params_.rhs_dt_helper_vmm_idx);
    const Xbyak::Reg64 reg_tmp = rhs_arg_static_params_.rhs_helper_reg;

    // save the tail opmask before clobbering it with the compare result
    push_opmask(host_, cmp_mask);
    host_->vcmpps(cmp_mask, lhs, rhs, cmp_predicate);
    // materialize the constant 1.0f in the low lane of a scratch xmm
    host_->mov(reg_tmp, float2int(1));
    host_->uni_vmovq(xreg_one, reg_tmp);
    // broadcast 1.0f with mask; T_z zeroes lanes where the predicate failed
    host_->vbroadcastss(dst | cmp_mask | host_->T_z, xreg_one);
    // pop tail mask from stack
    pop_opmask(host_, cmp_mask);
}
2600
// SSE4.1., AVX and AVX2 implementation of comparison post-ops.
// uni_vcmpps writes 0xFFFFFFFF to lanes where the predicate holds and 0
// elsewhere. 0xFFFFFFFF is a NaN bit pattern, and minps returns its second
// source operand when the first is NaN, so min-ing against 1.0f maps
// 0xFFFFFFFF -> 1.0f while leaving 0.0f unchanged.
template <cpu_isa_t isa, typename Vmm>
template <typename T>
typename std::enable_if<!(std::is_same<T, Xbyak::Zmm>::value
        || std::is_same<T, Xbyak::Address>::value)>::type
jit_uni_binary_injector_t<isa, Vmm>::execute_cmp_binary(const Vmm &dst,
        const Vmm &lhs, const T &rhs, const unsigned int cmp_predicate) const {
    // scratch vector register used to hold the constant 1.0f
    const int vmm_idx = rhs_arg_static_params_.rhs_dt_helper_vmm_idx;
    const Vmm vreg_one = Vmm(vmm_idx);
    const Xbyak::Xmm xreg_one = Xbyak::Xmm(vmm_idx);
    const Xbyak::Reg64 reg_tmp = rhs_arg_static_params_.rhs_helper_reg;

    host_->uni_vcmpps(dst, lhs, rhs, cmp_predicate);
    // materialize 1.0f in every lane of vreg_one
    host_->mov(reg_tmp, float2int(1));
    host_->uni_vmovq(xreg_one, reg_tmp);
    host_->uni_vbroadcastss(vreg_one, xreg_one);
    // clamp the all-ones compare result to 1.0f (see header comment)
    host_->uni_vminps(dst, dst, vreg_one);
}
2619
2620template <cpu_isa_t isa, typename Vmm>
2621template <typename T>
2622void jit_uni_binary_injector_t<isa, Vmm>::execute_binary(alg_kind_t binary_alg,
2623 const Vmm &dst, const Vmm &lhs, const T &rhs) const {
2624 switch (binary_alg) {
2625 case alg_kind::binary_add: host_->uni_vaddps(dst, lhs, rhs); break;
2626 case alg_kind::binary_mul: host_->uni_vmulps(dst, lhs, rhs); break;
2627 case alg_kind::binary_max: host_->uni_vmaxps(dst, lhs, rhs); break;
2628 case alg_kind::binary_min: host_->uni_vminps(dst, lhs, rhs); break;
2629 case alg_kind::binary_div: host_->uni_vdivps(dst, lhs, rhs); break;
2630 case alg_kind::binary_sub: host_->uni_vsubps(dst, lhs, rhs); break;
2631 case alg_kind::binary_ge:
2632 execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_nlt_us);
2633 break;
2634 case alg_kind::binary_gt:
2635 execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_nle_us);
2636 break;
2637 case alg_kind::binary_le:
2638 execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_le_os);
2639 break;
2640 case alg_kind::binary_lt:
2641 execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_lt_os);
2642 break;
2643 case alg_kind::binary_eq:
2644 execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_eq_oq);
2645 break;
2646 case alg_kind::binary_ne:
2647 execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_neq_uq);
2648 break;
2649 default: assert(!"unsupported algorithm");
2650 }
2651}
2652
// Dispatch helper for the avx-specialized execute_binary overloads.
// The primary template is intentionally empty; only the avx specialization
// is defined and used.
template <cpu_isa_t isa, typename Vmm>
struct helper_binary_t {};
2655
// avx specialization: implements the binary-alg dispatch shared by the
// avx Xmm/Ymm member specializations of execute_binary. Arithmetic algs
// emit one uni_* instruction; comparison algs are forwarded to the
// caller-supplied execute_cmp_binary callable.
template <typename Vmm>
struct helper_binary_t<avx, Vmm> {
    // @param host generator the instructions are emitted into
    // @param execute_cmp_binary callable (dst, lhs, rhs, predicate) that
    //        lowers a comparison post-op
    // @param binary_alg the post-op algorithm to emit
    template <typename T, typename F>
    static void execute_binary(jit_generator *host, F execute_cmp_binary,
            alg_kind_t binary_alg, const Vmm &dst, const Vmm &lhs,
            const T &rhs) {
        switch (binary_alg) {
            case alg_kind::binary_add: host->uni_vaddps(dst, lhs, rhs); break;
            case alg_kind::binary_mul: host->uni_vmulps(dst, lhs, rhs); break;
            case alg_kind::binary_max: host->uni_vmaxps(dst, lhs, rhs); break;
            case alg_kind::binary_min: host->uni_vminps(dst, lhs, rhs); break;
            case alg_kind::binary_div: host->uni_vdivps(dst, lhs, rhs); break;
            case alg_kind::binary_sub: host->uni_vsubps(dst, lhs, rhs); break;
            case alg_kind::binary_ge:
                execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_nlt_us);
                break;
            case alg_kind::binary_gt:
                execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_nle_us);
                break;
            case alg_kind::binary_le:
                execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_le_os);
                break;
            case alg_kind::binary_lt:
                execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_lt_os);
                break;
            case alg_kind::binary_eq:
                execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_eq_oq);
                break;
            case alg_kind::binary_ne:
                execute_cmp_binary(dst, lhs, rhs, jit_generator::_cmp_neq_uq);
                break;
            default: assert(!"unsupported algorithm");
        }
    }
};
2691
// avx/Ymm specialization: forwards to helper_binary_t with a lambda that
// wraps this->execute_cmp_binary, so both avx specializations share one
// dispatch implementation.
template <>
template <typename T>
void jit_uni_binary_injector_t<avx, Xbyak::Ymm>::execute_binary(
        alg_kind_t binary_alg, const Xbyak::Ymm &dst, const Xbyak::Ymm &lhs,
        const T &rhs) const {

    // adapt the member function to the callable the helper expects
    const auto execute_cmp_binary_lam
            = [this](const Xbyak::Ymm &dst, const Xbyak::Ymm &lhs, const T &rhs,
                      const unsigned int cmp_predicate) {
                  this->execute_cmp_binary<T>(dst, lhs, rhs, cmp_predicate);
              };
    helper_binary_t<avx, Xbyak::Ymm>::execute_binary<T>(
            host_, execute_cmp_binary_lam, binary_alg, dst, lhs, rhs);
}
2706
// avx/Xmm specialization: forwards to helper_binary_t with a lambda that
// wraps this->execute_cmp_binary, so both avx specializations share one
// dispatch implementation.
template <>
template <typename T>
void jit_uni_binary_injector_t<avx, Xbyak::Xmm>::execute_binary(
        alg_kind_t binary_alg, const Xbyak::Xmm &dst, const Xbyak::Xmm &lhs,
        const T &rhs) const {

    // adapt the member function to the callable the helper expects
    const auto execute_cmp_binary_lam
            = [this](const Xbyak::Xmm &dst, const Xbyak::Xmm &lhs, const T &rhs,
                      const unsigned int cmp_predicate) {
                  this->execute_cmp_binary<T>(dst, lhs, rhs, cmp_predicate);
              };
    helper_binary_t<avx, Xbyak::Xmm>::execute_binary<T>(
            host_, execute_cmp_binary_lam, binary_alg, dst, lhs, rhs);
}
2721
// Convenience overload: applies a binary post-op to a single vector
// register by delegating to compute_vector_range with a one-element set.
template <cpu_isa_t isa, typename Vmm>
void jit_uni_binary_injector_t<isa, Vmm>::compute_vector(size_t idx,
        std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op,
        const rhs_arg_dynamic_params_t &rhs_arg_params) const {
    compute_vector_range({idx}, rhs_arg_idx, post_op, rhs_arg_params);
}
2728
// Explicit instantiations for every isa / vector-register combination the
// library's jit kernels instantiate the injector with.
template class jit_uni_binary_injector_t<avx512_core_fp16>;
template class jit_uni_binary_injector_t<avx512_core_fp16, Xbyak::Ymm>;
template class jit_uni_binary_injector_t<avx512_core_fp16, Xbyak::Xmm>;
template class jit_uni_binary_injector_t<avx512_core_bf16>;
template class jit_uni_binary_injector_t<avx512_core>;
template class jit_uni_binary_injector_t<avx512_core, Xbyak::Ymm>;
template class jit_uni_binary_injector_t<avx512_core, Xbyak::Xmm>;
template class jit_uni_binary_injector_t<avx2_vnni_2>;
template class jit_uni_binary_injector_t<avx2, Xbyak::Ymm>;
template class jit_uni_binary_injector_t<avx2, Xbyak::Xmm>;
template class jit_uni_binary_injector_t<avx, Xbyak::Ymm>;
template class jit_uni_binary_injector_t<avx, Xbyak::Xmm>;
template class jit_uni_binary_injector_t<sse41>;
2742
2743} // namespace binary_injector
2744} // namespace x64
2745} // namespace cpu
2746} // namespace impl
2747} // namespace dnnl
2748