/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_UNI_BINARY_INJECTOR_HPP
#define CPU_X64_JIT_UNI_BINARY_INJECTOR_HPP

#include <array>
#include <cassert>
#include <functional>
#include <map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "common/broadcast_strategy.hpp"
#include "common/c_types_map.hpp"
#include "common/primitive_attr.hpp"
#include "common/primitive_exec_types.hpp"
#include "cpu/binary_injector_utils.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/jit_generator.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace binary_injector {
using dnnl::impl::cpu::binary_injector_utils::prepare_binary_args;

bool binary_args_matches_tag(format_tag_t tag, const post_ops_t &post_ops);

bool binary_args_broadcast_supported(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set);

bool binary_args_tail_supported(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d, int vlen,
        const bcast_set_t &supported_strategy_set);

bool any_binary_postop_rhs_non_scalar_broadcast(
        const post_ops_t &post_ops, const memory_desc_wrapper &dst_d);
bool any_binary_postop_rhs_per_oc_broadcast(
        const post_ops_t &post_ops, const memory_desc_wrapper &dst_d);
bool any_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set);

bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d,
        const std::function<bool(const memory_desc_wrapper &)> &predicate);
bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set,
        const std::function<bool(const memory_desc_wrapper &)> &predicate);
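
/*
 * Example query (a sketch; the lambda below is a hypothetical predicate that
 * checks every per-oc rhs tensor is dense, which a kernel might require in
 * order to vectorize over channels):
 *
 *   const bool all_dense_per_oc = all_binary_postop_rhs_per_oc_broadcast(
 *           post_ops, dst_d, [](const memory_desc_wrapper &rhs_d) {
 *               return rhs_d.is_dense(true); // with_padding = true
 *           });
 */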

/*
 * Represents params related to all binary post-ops right-hand side arguments
 * (arg1) that don't change during jit_uni_binary_injector_t object lifetime
 * and between compute_vector_range calls.
 *
 * @param rhs_dt_helper_vmm_idx - index of the vmm helper used when loading data
 * for calculations. Treated as a hint from the user. If the hint turns out to
 * be invalid inside compute_vector_range, it will be overwritten by the
 * register preserving logic inside the binary injector.
 * @param rhs_addr_reg - gpr register used as the currently processed address of
 * the rhs tensor slice. Data of rhs (arg1) for the binary operation is loaded
 * from the address stored inside rhs_addr_reg.
 * @param rhs_helper_reg - gpr register used as a helper for calculations during
 * the data loading phase.
 * @param rhs_addr_cache_reg - gpr register used for caching part of the
 * calculated offset; this register is always preserved.
 * @param preserve_gpr_helpers - determines whether the gpr registers specified
 * above should be preserved (pushed to the stack and popped back afterwards)
 * between compute_vector_range calls.
 * @param preserve_vmm_helper - determines whether the vmm helper register
 * specified above should be preserved between compute_vector_range calls.
 * @param abi_param_offset - offset to the rhs tensor of the first binary
 * post-op operation specified by the user, inside the runtime structure passed
 * to the kernel as abi param 1.
 * @param dst_orig_offset - offset to the origin (element 0) of the destination
 * tensor.
 * @param dst_d - descriptor of the destination tensor (the result after
 * applying all post-op operations).
 * @param tail_opmask - opmask register holding a user-loaded mask, used on
 * avx512 for loads with tail handling.
 * @param tail_size - size of the processed tail in elements.
 * @param use_exact_tail_scalar_bcast - in case of scalar broadcast the user can
 * disable loading data with tail; a broadcast through the entire vector is
 * usually faster (typically a single instruction) than a broadcast limited by
 * the tail size (potentially several instructions). When the user ignores
 * values from the vmm above the tail size during stores, setting this option
 * to false can result in better performance.
 * @param reg_tail_size - register holding the loaded tail size, used on
 * sse41/avx/avx2 for loads with tail at runtime.
 */
struct rhs_arg_static_params_t {
    rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
            const Xbyak::Reg64 &rhs_addr_reg,
            const Xbyak::Reg64 &rhs_helper_reg,
            const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
            bool preserve_vmm_helper, std::size_t abi_param_offset,
            std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
            std::size_t tail_size = 0u,
            bool use_exact_tail_scalar_bcast = false);
    rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
            const Xbyak::Reg64 &rhs_addr_reg,
            const Xbyak::Reg64 &rhs_helper_reg,
            const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
            bool preserve_vmm_helper, std::size_t abi_param_offset,
            std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
            std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
            bool use_exact_tail_scalar_bcast);
    rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
            const Xbyak::Reg64 &rhs_addr_reg,
            const Xbyak::Reg64 &rhs_helper_reg,
            const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
            bool preserve_vmm_helper, std::size_t abi_param_offset,
            std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
            std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
            const Xbyak::Reg64 &reg_tail_size,
            bool use_exact_tail_scalar_bcast);

    bool is_opmask_set() const noexcept { return is_opmask_set_; }

    mutable std::size_t rhs_dt_helper_vmm_idx;
    Xbyak::Reg64 rhs_addr_reg;
    Xbyak::Reg64 rhs_helper_reg;
    Xbyak::Reg64 rhs_addr_cache_reg;
    bool preserve_gpr_helpers;
    bool preserve_vmm_helper;
    std::size_t abi_param_offset;
    std::size_t dst_orig_offset;
    memory_desc_wrapper dst_d;
    std::size_t tail_size;
    Xbyak::Opmask tail_opmask;
    bool use_exact_tail_scalar_bcast;
    Xbyak::Reg64 reg_tail_size;
    bool is_tail;

private:
    rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
            const Xbyak::Reg64 &rhs_addr_reg,
            const Xbyak::Reg64 &rhs_helper_reg,
            const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
            bool preserve_vmm_helper, std::size_t abi_param_offset,
            std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
            std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
            bool use_exact_tail_scalar_bcast, const Xbyak::Reg64 &reg_tail_size,
            bool is_opmask_set);

    bool is_opmask_set_;
};
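
/*
 * Example construction (an illustrative sketch only: the gpr choices, helper
 * vmm index, and the kernel's call-params struct `jit_params_t` below are
 * hypothetical and depend on the host kernel's register allocation):
 *
 *   rhs_arg_static_params_t rhs_arg_params(
 *           15, // rhs_dt_helper_vmm_idx, treated as a hint
 *           r13, r14, r15, // rhs_addr / rhs_helper / rhs_addr_cache gprs
 *           true, // preserve_gpr_helpers
 *           false, // preserve_vmm_helper
 *           offsetof(jit_params_t, post_ops_binary_rhs_arg_vec),
 *           offsetof(jit_params_t, dst_orig),
 *           memory_desc_wrapper(dst_md),
 *           tail_size); // elements left over after full vectors
 */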

/*
 * Represents params required by jit_uni_binary_injector_t that don't change
 * during its entire lifetime.
 *
 * @param param1 - register storing abi param1. At the moment of calling the
 * compute_vector_range method it can differ from the default one defined
 * inside jit_generator.
 * @param supported_strategy_set - set of allowed bcast strategies, enabling
 * particular strategies to be disabled.
 * @param rhs_arg_static_params - params related to all binary post-ops
 * right-hand side arguments that don't change during the entire lifetime of
 * the jit_uni_binary_injector_t object.
 */
struct static_params_t {
    static_params_t(const Xbyak::Reg64 &param1,
            const bcast_set_t &supported_strategy_set,
            const rhs_arg_static_params_t &rhs_arg_static_params);
    static_params_t(const Xbyak::Reg64 &param1,
            const rhs_arg_static_params_t &rhs_arg_static_params);

    Xbyak::Reg64 param1;
    const bcast_set_t supported_strategy_set;
    rhs_arg_static_params_t rhs_arg_static_params;
};
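
/*
 * Example wiring inside a host kernel (a sketch; assumes the rhs_arg_params
 * object from the example above, a kernel class deriving from jit_generator
 * where param1 is an inherited member, and a hypothetical `binary_injector_`
 * member owned by the kernel):
 *
 *   const bcast_set_t strategies = {broadcasting_strategy_t::scalar,
 *           broadcasting_strategy_t::per_oc,
 *           broadcasting_strategy_t::no_broadcast};
 *   const static_params_t static_params(param1, strategies, rhs_arg_params);
 *   binary_injector_ = utils::make_unique<
 *           jit_uni_binary_injector_t<avx512_core>>(this, static_params);
 */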

/*
 * Mode of data load with tail for rhs:
 * STATIC - load based on a given integer (tail size known at creation time).
 * DYNAMIC - load based on the opmask or a 64-bit register set at runtime.
 * DEFAULT - DYNAMIC for avx512, STATIC for other isa.
 */
enum class tail_load_mode_t { STATIC, DYNAMIC, DEFAULT };

/*
 * Represents params passed to the compute_vector_range method of
 * jit_uni_binary_injector_t that can differ between calls.
 * Contains configurable std::maps where the key is a vmm index and the value
 * is an offset in elements. The offset value identifies the tensor slice held
 * in a particular vmm; this is utilized by the broadcasting mechanism.
 * Depending on the particular kernel implementation, the offset can be passed
 * as a raw value (usually during unrolling), inside an operand, or under a
 * memory address.
 *
 * @param vmm_idx_to_out_addr - vmm mapped to the address (with offset) of the
 * destination tensor, used to calculate the offset in the no_broadcast
 * strategy, but also in other strategies whose calculations are based on it.
 * @param vmm_idx_to_out_reg - vmm mapped to a register containing the address
 * (with offset) of the destination, used to calculate the offset in the
 * no_broadcast strategy, but also in other strategies whose calculations are
 * based on it.
 * @param vmm_idx_to_out_elem_off_val - vmm mapped to an offset in elements
 * passed as a raw value, intended for the no_broadcast strategy, but also for
 * other strategies whose calculations are based on it.
 * @param vmm_tail_idx - indices of vmms whose data doesn't fill the whole
 * vector (tail).
 * @param tail_load_mode - determines whether the tail load happens at runtime
 * (based on the value from reg_tail_size or the opmask) or is based on a
 * given integer (see tail_load_mode_t).
 */

struct rhs_arg_dynamic_params_t {
    std::map<int, Xbyak::Address> vmm_idx_to_out_addr;
    std::map<int, Xbyak::Reg64> vmm_idx_to_out_reg;
    std::map<int, size_t> vmm_idx_to_out_elem_off_val;

    std::unordered_set<int> vmm_tail_idx_;
    tail_load_mode_t tail_load_mode = tail_load_mode_t::DEFAULT;
};
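
/*
 * Example of filling per-call params during unrolling (a sketch; `unroll`,
 * `simd_w`, `reg_dst`, and `tail_vmm_idx` are hypothetical values owned by
 * the host kernel):
 *
 *   rhs_arg_dynamic_params_t rhs_arg_params;
 *   for (int vmm_idx = 0; vmm_idx < unroll; vmm_idx++) {
 *       // vmm number vmm_idx holds the dst slice that starts
 *       // vmm_idx * simd_w elements past the address kept in reg_dst
 *       rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst);
 *       rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
 *               vmm_idx, vmm_idx * simd_w);
 *   }
 *   rhs_arg_params.vmm_tail_idx_.emplace(tail_vmm_idx);
 */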

/*
 * Checks if src1 data type is supported by binary injector.
 */
bool is_data_supported(cpu_isa_t isa, data_type_t data_type);

/*
 * Checks if broadcast of src1 is supported by binary injector.
 */
bool is_bcast_supported(const dnnl::impl::memory_desc_t &src1_desc,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set);

/*
 * Checks if binary injection for given args is supported.
 */
bool is_supported(cpu_isa_t isa, const dnnl::impl::memory_desc_t &src1_desc,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set);
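
/*
 * Typical creation-time guard in a primitive that accepts binary post-ops
 * (a sketch; src1_md would come from the corresponding post-op entry, e.g.
 * post_ops.entry_[i].binary.src1_desc):
 *
 *   if (!binary_injector::is_supported(avx512_core, src1_md,
 *               memory_desc_wrapper(dst_md), strategies))
 *       return status::unimplemented;
 */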

/*
 * Main mechanism responsible for injecting binary post-ops, supporting various
 * isa: sse41, avx, avx2, avx512 with core and bf16 extensions, as well as data
 * types: f32, bf16, s32, u8, s8.
 */
template <cpu_isa_t isa, typename Vmm = typename cpu_isa_traits<isa>::Vmm>
class jit_uni_binary_injector_t {
public:
    jit_uni_binary_injector_t(
            jit_generator *host, const static_params_t &static_params);

    /*
     * Generates code of the binary post_op injected into the host primitive.
     * Applied to an ordered set of vector register indices. The function
     * loads the appropriate slice of the rhs tensor for the computations,
     * based on the internally determined broadcast strategy and the
     * information about the data stored in each vmm described inside
     * rhs_arg_params.
     */
    void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs,
            std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op,
            const rhs_arg_dynamic_params_t &rhs_arg_params) const;

    /*
     * Generates code of the binary post_op injected into the host primitive.
     * Applied to the range <start_idx, end_idx) of vector register indices.
     * The function loads the appropriate slice of the rhs tensor for the
     * computations, based on the internally determined broadcast strategy and
     * the information about the data stored in each vmm described inside
     * rhs_arg_params.
     */
    void compute_vector_range(size_t start_idx, size_t end_idx,
            std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op,
            const rhs_arg_dynamic_params_t &rhs_arg_params) const;

    /*
     * Generates code of the binary post_op injected into the host primitive.
     * Applied to a single vector register index. The function loads the
     * appropriate slice of the rhs tensor for the computations, based on the
     * internally determined broadcast strategy and the information about the
     * data stored in the particular vmm described inside rhs_arg_params.
     */
    void compute_vector(size_t idx, std::size_t rhs_arg_idx,
            const dnnl_post_ops::entry_t &post_op,
            const rhs_arg_dynamic_params_t &rhs_arg_params) const;
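
    /*
     * Example call site inside the host kernel's generate() (a sketch;
     * assumes `p` is the primitive's post_ops_t, the injector applies the
     * i-th post-op to vmm indices [0, unroll), and rhs_arg_params was filled
     * as shown above):
     *
     *   for (int i = 0; i < p.len(); i++) {
     *       const auto &entry = p.entry_[i];
     *       if (entry.is_binary())
     *           binary_injector_->compute_vector_range(
     *                   0, unroll, i, entry, rhs_arg_params);
     *   }
     */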

private:
    /*
     * Determines whether the hint passed by the user is valid (i.e. is inside
     * the range <start_idx, end_idx>). If not, returns a new vmm idx value
     * that will be used as a temporary vmm in future computations.
     */
    int adjust_temp_vmm_hint(
            int user_hint, int start_idx, int end_idx, int max_vmm_idx) const;
    /*
     * Taking into account rhs_broadcasting_strategy and the information from
     * the user about the tensor slice (rhs_arg_params) stored in Vmm(vmm_idx),
     * calculates the address of the rhs tensor slice needed for the binary
     * operation and returns a ptr to it.
     */
    Xbyak::Address prepare_rhs_arg_addr(std::size_t vmm_idx,
            std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op,
            const rhs_arg_dynamic_params_t &rhs_arg_params,
            const broadcasting_strategy_t rhs_broadcasting_strategy,
            bool is_first) const;
    /*
     * Loads data and applies the particular binary operation.
     */
    void inject_binary(const dnnl_post_ops::entry_t &post_op, Vmm dst,
            const Xbyak::Address &rhs_addr, bool with_tail,
            const tail_load_mode_t tail_load_mode) const;

    /*
     * Helper functions responsible for preparing rhs tensor slice address.
     */
    void append_no_broadcast_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_no_broadcast_base(
            Xbyak::Address addr, const Xbyak::Reg64 &out_reg) const;
    void calculate_no_broadcast_partial(const std::size_t offset,
            const Xbyak::Reg64 &out_reg, std::size_t elem_size_bytes) const;

    void append_oc_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_oc_ncsp_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_oc_ncsp_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_oc_blocked_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_oc_blocked_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_oc_nspc_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_oc_nspc_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_oc_cspn_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_oc_cspn_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;

    void append_mb_sp_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_mb_sp_ncsp_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_sp_ncsp_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_sp_blocked_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_sp_blocked_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_sp_nspc_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_sp_nspc_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_sp_cspn_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_sp_cspn_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;

    void append_mb_w_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_mb_w_ncsp_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_w_ncsp_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_w_blocked_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_w_blocked_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_w_nspc_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_w_nspc_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_w_cspn_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_w_cspn_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;

    void append_w_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_w_ncsp_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_w_ncsp_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_w_blocked_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_w_blocked_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_w_nspc_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_w_nspc_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_w_cspn_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_w_cspn_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;

    template <typename T>
    typename std::enable_if<std::is_same<T, Xbyak::Zmm>::value
            || std::is_same<T, Xbyak::Address>::value>::type
    execute_cmp_binary(const Vmm &dst, const Vmm &lhs, const T &rhs,
            const unsigned int cmp_predicate) const;
    template <typename T>
    typename std::enable_if<!(std::is_same<T, Xbyak::Zmm>::value
            || std::is_same<T, Xbyak::Address>::value)>::type
    execute_cmp_binary(const Vmm &dst, const Vmm &lhs, const T &rhs,
            const unsigned int cmp_predicate) const;
    template <typename T>
    void execute_binary(alg_kind_t binary_alg, const Vmm &dst, const Vmm &lhs,
            const T &rhs) const;
    /*
     * Used in the scalar broadcast strategy; broadcasts a single value of the
     * given data type over the entire Vmm vector register.
     */
    void execute_broadcast(const data_type_t &data_type, const Vmm &tmp_reg,
            const Xbyak::Address &rhs_addr,
            const tail_load_mode_t tail_load_mode,
            bool with_tail = false) const;
    void load_rhs(const data_type_t &data_type, const Vmm &tmp_reg,
            const Xbyak::Address &rhs_addr,
            const tail_load_mode_t tail_load_mode,
            bool with_tail = false) const;
    void execute_broadcast_tail_with_opmask(const data_type_t &data_type,
            const Vmm &tmp_reg, const Xbyak::Address &rhs_addr) const;
    void execute_broadcast_tail_statically(const data_type_t &data_type,
            const Vmm &tmp_reg, const Xbyak::Address &rhs_addr,
            const std::size_t tail_size) const;
    void execute_broadcast_tail_with_gpr(const data_type_t &data_type,
            const Vmm &tmp_reg, const Xbyak::Address &rhs_addr) const;
    void load_rhs_tail_dynamically_with_opmask(const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr) const;
    void load_rhs_tail_dynamically_with_gpr(
            const data_type_t &data_type, const Vmm &tmp_vmm) const;
    void load_rhs_tail_statically(const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr) const;
    void execute_broadcast_no_tail(const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr) const;
    void execute_broadcast_s8u8_no_tail(const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr) const;
    void load_rhs_no_tail(const data_type_t &data_type, const Vmm &tmp_reg,
            const Xbyak::Address &rhs_addr) const;
    void load_rhs_i8_no_tail(const data_type_t &data_type, const Vmm &tmp_reg,
            const Xbyak::Address &rhs_addr) const;
    void cvt_to_f32(const Vmm &tmp_reg) const;
    /*
     * Returns a pair: the first member is a flag indicating whether
     * preservation is needed; the second member is the vmm index that should
     * be used as a temporary vmm inside inject_binary.
     */
    std::pair<bool, int> should_preserve_vmm(int curr_idx, int vmm_hint,
            int max_vmm_idx, bool dt_helper_vmm_needed) const;
    /*
     * Used when isa != avx512, where m32bcst is not supported; replaces ptr_b
     * with ptr.
     */
    Xbyak::Address remove_bcast_bit(const Xbyak::Address &rhs_addr) const;

    jit_generator *host_;
    const rhs_arg_static_params_t rhs_arg_static_params_;
    const Xbyak::Reg64 param1_;
    const bcast_set_t supported_strategy_set_;
    const bool is_avx512_ = is_superset(isa, avx512_core);
    const bool is_avx512_core_fp16_ = is_superset(isa, avx512_core_fp16);

    static constexpr int sizeof_reg64 = 8;
    /*
     * SSE/AVX instructions that compute a binary result with a memory second
     * operand (e.g. vaddps) require the memory operand to be explicitly
     * aligned to 16/32 bytes (Intel Manual, chapter 2.4). The rule is relaxed
     * starting with AVX2 (Intel Manual, chapter 14.9). When using benchdnn,
     * zmalloc_protect doesn't guarantee that the tensor memory address is
     * 64-byte aligned, which can cause a segmentation fault.
     */
    static constexpr bool binary_op_with_unaligned_mem_operand_allowed_
            = !utils::one_of(isa, avx, sse41);
};

} // namespace binary_injector
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif