1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_BINARY_INJECTOR_HPP |
18 | #define CPU_X64_JIT_UNI_BINARY_INJECTOR_HPP |
19 | |
20 | #include <array> |
21 | #include <cassert> |
22 | #include <functional> |
23 | #include <map> |
24 | #include <utility> |
25 | #include <vector> |
26 | #include <unordered_set> |
27 | |
28 | #include "common/broadcast_strategy.hpp" |
29 | #include "common/c_types_map.hpp" |
30 | #include "common/primitive_attr.hpp" |
31 | #include "common/primitive_exec_types.hpp" |
32 | #include "cpu/binary_injector_utils.hpp" |
33 | #include "cpu/x64/cpu_isa_traits.hpp" |
34 | #include "cpu/x64/injectors/injector_utils.hpp" |
35 | #include "cpu/x64/jit_generator.hpp" |
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | namespace cpu { |
40 | namespace x64 { |
41 | namespace binary_injector { |
42 | using dnnl::impl::cpu::binary_injector_utils::prepare_binary_args; |
43 | |
// Returns true when the memory layout of every binary post-op rhs (src1)
// tensor found in post_ops matches the given format tag.
bool binary_args_matches_tag(format_tag_t tag, const post_ops_t &post_ops);

// Returns true when the broadcast strategy (w.r.t. dst_d) of every binary
// post-op rhs tensor belongs to supported_strategy_set.
bool binary_args_broadcast_supported(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set);

// Checks whether tail processing of binary post-op rhs tensors is supported
// for the given destination and vector length.
// NOTE(review): vlen is presumably the vector register length in bytes —
// confirm against callers.
bool binary_args_tail_supported(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d, int vlen,
        const bcast_set_t &supported_strategy_set);

// Returns true when at least one binary post-op rhs tensor uses a
// non-scalar broadcast strategy w.r.t. dst_d.
bool any_binary_postop_rhs_non_scalar_broadcast(
        const post_ops_t &post_ops, const memory_desc_wrapper &dst_d);
// Returns true when at least one binary post-op rhs tensor uses a per-oc
// (per output channel) broadcast strategy w.r.t. dst_d. The overload taking
// supported_strategy_set restricts the check to strategies in that set.
bool any_binary_postop_rhs_per_oc_broadcast(
        const post_ops_t &post_ops, const memory_desc_wrapper &dst_d);
bool any_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set);

// Returns true when every binary post-op rhs tensor uses a per-oc broadcast
// strategy w.r.t. dst_d AND satisfies the user-supplied predicate on its
// memory descriptor. The overload taking supported_strategy_set restricts
// the broadcast strategies considered.
bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d,
        const std::function<bool(const memory_desc_wrapper &)> &predicate);
bool all_binary_postop_rhs_per_oc_broadcast(const post_ops_t &post_ops,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set,
        const std::function<bool(const memory_desc_wrapper &)> &predicate);
69 | |
/*
 * Represents params related to all binary post-ops right-hand side arguments
 * (arg1) that don't change during jit_uni_binary_injector_t object lifetime
 * and between compute_vector_range calls.
 *
 * @param rhs_dt_helper_vmm_idx - index of vmm helper used when loading data for
 * calculations. Treated as hint from user. If inside compute_vector_range hint
 * turns out to be invalid, it will be overwritten by register preserving logic
 * inside binary injector.
 * @param rhs_addr_reg - gpr register, used as the currently processed address of
 * rhs tensor slice. Data of rhs(arg1) for the binary operation is loaded from address
 * stored inside rhs_addr_reg.
 * @param rhs_helper_reg - gpr register used as helper for calculations during data
 * loading phase.
 * @param rhs_addr_cache_reg - gpr register used for caching part of calculated
 * offset, this register is always preserved.
 * @param preserve_gpr_helpers - determines whether gpr registers specified above
 * should be preserved (pushed to stack and popped back afterwards) between
 * compute_vector_range calls.
 * @param preserve_vmm_helper - determines whether vmm helper register specified
 * above should be preserved between compute_vector_range calls.
 * @param abi_param_offset - offset to rhs tensor from first binary post-op operation
 * specified by user from runtime structure passed to kernel as abi param 1.
 * @param dst_orig_offset - offset to the origin (element offset 0) of the
 * destination tensor.
 * @param dst_d - descriptor of destination tensor (result after applying all post-ops
 * operations).
 * @param tail_opmask - register with loaded by user mask, used in avx512 for load with
 * tail handling.
 * @param tail_size - size of processed tail in elements.
 * @param use_exact_tail_scalar_bcast - in case of scalar broadcast user can disable
 * loading data with tail, usually bcast through entire vector is faster (usually 1 instruction)
 * vs. broadcasting limited by tail size (potentially several instructions). In case
 * when user during storing ignores values from vmm above tail size, setting this option to
 * false can result in better performance.
 * @param reg_tail_size - register with loaded size of tail, used in sse41/avx/avx2
 * for load with tail in runtime.
 */
struct rhs_arg_static_params_t {
    // Overload without tail opmask/register: static tail handling only.
    rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
            const Xbyak::Reg64 &rhs_addr_reg,
            const Xbyak::Reg64 &rhs_helper_reg,
            const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
            bool preserve_vmm_helper, std::size_t abi_param_offset,
            std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
            std::size_t tail_size = 0u,
            bool use_exact_tail_scalar_bcast = false);
    // Overload with tail opmask (avx512 dynamic tail handling).
    rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
            const Xbyak::Reg64 &rhs_addr_reg,
            const Xbyak::Reg64 &rhs_helper_reg,
            const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
            bool preserve_vmm_helper, std::size_t abi_param_offset,
            std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
            std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
            bool use_exact_tail_scalar_bcast);
    // Overload with both tail opmask and runtime tail-size gpr
    // (sse41/avx/avx2 dynamic tail handling).
    rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
            const Xbyak::Reg64 &rhs_addr_reg,
            const Xbyak::Reg64 &rhs_helper_reg,
            const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
            bool preserve_vmm_helper, std::size_t abi_param_offset,
            std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
            std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
            const Xbyak::Reg64 &reg_tail_size,
            bool use_exact_tail_scalar_bcast);

    // True when a tail opmask was supplied at construction time.
    bool is_opmask_set() const noexcept { return is_opmask_set_; }

    // mutable: may be overwritten by the injector's register-preserving
    // logic when the user-supplied hint turns out to be invalid.
    mutable std::size_t rhs_dt_helper_vmm_idx;
    Xbyak::Reg64 rhs_addr_reg;
    Xbyak::Reg64 rhs_helper_reg;
    Xbyak::Reg64 rhs_addr_cache_reg;
    bool preserve_gpr_helpers;
    bool preserve_vmm_helper;
    std::size_t abi_param_offset;
    std::size_t dst_orig_offset;
    memory_desc_wrapper dst_d;
    std::size_t tail_size;
    Xbyak::Opmask tail_opmask;
    bool use_exact_tail_scalar_bcast;
    Xbyak::Reg64 reg_tail_size;
    // NOTE(review): presumably derived from tail_size != 0 by the
    // constructors — confirm in the implementation file.
    bool is_tail;

private:
    // Delegation target for the public constructors; is_opmask_set records
    // whether a tail opmask was provided by the caller.
    rhs_arg_static_params_t(std::size_t rhs_dt_helper_vmm_idx,
            const Xbyak::Reg64 &rhs_addr_reg,
            const Xbyak::Reg64 &rhs_helper_reg,
            const Xbyak::Reg64 &rhs_addr_cache_reg, bool preserve_gpr_helpers,
            bool preserve_vmm_helper, std::size_t abi_param_offset,
            std::size_t dst_orig_offset, const memory_desc_wrapper &dst_d,
            std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
            bool use_exact_tail_scalar_bcast, const Xbyak::Reg64 &reg_tail_size,
            bool is_opmask_set);

    bool is_opmask_set_;
};
164 | |
/*
 * Represents params required by jit_uni_binary_injector_t that don't change
 * during its entire lifetime.
 *
 * @param param1 - register storing abi param1. At the moment of calling
 * compute_vector_range method can be different than the default one defined
 * inside jit_generator.
 * @param bcast_set_t supported_strategy_set - set allowing disabling particular
 * bcast strategies
 * @param rhs_arg_static_params - params related to all binary post-ops right-hand side
 * arguments that don't change during entire lifetime of jit_uni_binary_injector_t
 * object.
 */
struct static_params_t {
    static_params_t(const Xbyak::Reg64 &param1,
            const bcast_set_t &supported_strategy_set,
            const rhs_arg_static_params_t &rhs_arg_static_params);
    // Convenience overload; presumably uses a default strategy set —
    // confirm in the implementation file.
    static_params_t(const Xbyak::Reg64 &param1,
            const rhs_arg_static_params_t &rhs_arg_static_params);

    Xbyak::Reg64 param1;
    const bcast_set_t supported_strategy_set;
    rhs_arg_static_params_t rhs_arg_static_params;
};
189 | |
/*
 * Mode of data load with tail for rhs:
 * STATIC - load based on given integer.
 * DYNAMIC - load based on opmask or 64-bit register.
 * DEFAULT - DYNAMIC for avx512, STATIC for others.
 *
 * NOTE: "lode" is a historical misspelling of "load"; the name is kept
 * unchanged for API compatibility with existing users of this header.
 */

enum class tail_lode_mode_t { STATIC, DYNAMIC, DEFAULT };
198 | |
/*
 * Represents params passed to compute_vector_range method of
 * jit_uni_binary_injector_t that can be different for each call.
 * Contains configurable std::maps where key is vmm index and value is
 * offset in elements. The offset value identifies tensor slice in particular
 * vmm. This is utilized by broadcasting mechanism. Offset, depending on the
 * implementation particular kernels, can be passed as value (usually during
 * unrolling), inside operand, under memory address.
 *
 * @param vmm_idx_to_out_addr - vmm mapped to address of destination tensor with offset,
 * used to calculate offset in no_broadcast strategy, but also in other strategies whose
 * calculations are based on no_broadcast strategy.
 * @param vmm_idx_to_out_reg - vmm mapped to register containing address of destination
 * with offset, used to calculate offset in no_broadcast strategy, but also in other
 * strategies whose calculations are based on no_broadcast strategy.
 * @param vmm_idx_to_out_elem_off_val - vmm mapped to offset in elements passed as raw
 * value intended to use in no_broadcast strategy, but also in other
 * strategies whose calculations are based on no_broadcast strategy.
 * @param vmm_tail_idx_ - vmm indices that contain data which doesn't fill the
 * whole vector (tail).
 * @param tail_load_mode - determines whether the tail is loaded based on a
 * runtime value (reg_tail_size or opmask) or on a compile-time integer; see
 * tail_lode_mode_t.
 */

struct rhs_arg_dynamic_params_t {
    std::map<int, Xbyak::Address> vmm_idx_to_out_addr;
    std::map<int, Xbyak::Reg64> vmm_idx_to_out_reg;
    std::map<int, size_t> vmm_idx_to_out_elem_off_val;

    // NOTE(review): trailing underscore is inconsistent with the other
    // public members here; kept as-is for API compatibility.
    std::unordered_set<int> vmm_tail_idx_;
    tail_lode_mode_t tail_load_mode = tail_lode_mode_t::DEFAULT;
};
231 | |
/*
 * Checks if src1 data type is supported by binary injector on the given isa.
 */
bool is_data_supported(cpu_isa_t isa, data_type_t data_type);

/*
 * Checks if broadcast of src1 (w.r.t. the destination descriptor and the
 * allowed strategy set) is supported by binary injector.
 */
bool is_bcast_supported(const dnnl::impl::memory_desc_t &src1_desc,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set);

/*
 * Checks if binary injection for given args is supported
 * (combines the data-type and broadcast checks above).
 */
bool is_supported(cpu_isa_t isa, const dnnl::impl::memory_desc_t &src1_desc,
        const memory_desc_wrapper &dst_d,
        const bcast_set_t &supported_strategy_set);
250 | |
/*
 * Main mechanism responsible for injecting binary postops supporting various
 * isa: sse41, avx, avx2, avx512 with core, bf16 extensions as well as data
 * types: f32, bf16, s32, u8, s8.
 */
template <cpu_isa_t isa, typename Vmm = typename cpu_isa_traits<isa>::Vmm>
class jit_uni_binary_injector_t {
public:
    jit_uni_binary_injector_t(
            jit_generator *host, const static_params_t &static_params);

    /*
     * Generates code of binary post_op injected to host primitive. Applied to
     * ordered set of vector registers' indexes. Function loads appropriate
     * slice of rhs tensor for computations based on internally determined
     * broadcast strategy and information about stored data in particular vmm
     * described inside rhs_arg_params.
     */
    void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs,
            std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op,
            const rhs_arg_dynamic_params_t &rhs_arg_params) const;

    /*
     * Generates code of binary post_op injected to host primitive. Applied to
     * range <start_idx, end_idx) of vector registers' indexes. Function loads
     * appropriate slice of rhs tensor for computations based on internally
     * determined broadcast strategy and information about stored data in particular
     * vmm described inside rhs_arg_params.
     */
    void compute_vector_range(size_t start_idx, size_t end_idx,
            std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op,
            const rhs_arg_dynamic_params_t &rhs_arg_params) const;

    /*
     * Generates code of binary post_op injected to host primitive. Applied to
     * a single vector register index. Function loads appropriate slice of rhs tensor
     * for computations based on internally determined broadcast strategy and information
     * about stored data in particular vmm described inside rhs_arg_params.
     */
    void compute_vector(size_t idx, std::size_t rhs_arg_idx,
            const dnnl_post_ops::entry_t &post_op,
            const rhs_arg_dynamic_params_t &rhs_arg_params) const;

private:
    /*
     * Determines if hint passed by user is valid (is inside range
     * <start_idx, end_idx>). If not it returns new vmm idx value that will be
     * used as temporary vmm in future computations.
     */
    int adjust_temp_vmm_hint(
            int user_hint, int start_idx, int end_idx, int max_vmm_idx) const;
    /*
     * Taking into account rhs_broadcasting_strategy and information from user
     * about tensor slice (rhs_arg_params) stored in Vmm(vmm_idx) calculates
     * address of rhs tensor slice needed for binary operation and returns
     * ptr to it.
     */
    Xbyak::Address prepare_rhs_arg_addr(std::size_t vmm_idx,
            std::size_t rhs_arg_idx, const dnnl_post_ops::entry_t &post_op,
            const rhs_arg_dynamic_params_t &rhs_arg_params,
            const broadcasting_strategy_t rhs_broadcasting_strategy,
            bool is_first) const;
    /*
     * Loads data and applies particular binary operation.
     */
    void inject_binary(const dnnl_post_ops::entry_t &post_op, Vmm dst,
            const Xbyak::Address &rhs_addr, bool with_tail,
            const tail_lode_mode_t tail_load_mode) const;

    /*
     * Helper functions responsible for preparing rhs tensor slice address.
     * Each append_*_offset function below dispatches on the destination
     * layout; the corresponding calculate_*_{base,partial} functions emit
     * the address arithmetic for a specific layout family
     * (ncsp / blocked / nspc / cspn).
     */

    // no_broadcast strategy: offset follows the destination element offset.
    void append_no_broadcast_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_no_broadcast_base(
            Xbyak::Address addr, const Xbyak::Reg64 &out_reg) const;
    void calculate_no_broadcast_partial(const std::size_t offset,
            const Xbyak::Reg64 &out_reg, std::size_t elem_size_bytes) const;

    // per-oc (output channel) broadcast offsets.
    void append_oc_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_oc_ncsp_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_oc_ncsp_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_oc_blocked_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_oc_blocked_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_oc_nspc_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_oc_nspc_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_oc_cspn_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_oc_cspn_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;

    // per-mb-spatial broadcast offsets.
    void append_mb_sp_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_mb_sp_ncsp_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_sp_ncsp_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_sp_blocked_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_sp_blocked_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_sp_nspc_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_sp_nspc_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_sp_cspn_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_sp_cspn_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;

    // per-mb-w broadcast offsets.
    void append_mb_w_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_mb_w_ncsp_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_w_ncsp_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_w_blocked_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_w_blocked_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_w_nspc_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_w_nspc_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_mb_w_cspn_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_mb_w_cspn_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;

    // per-w broadcast offsets.
    void append_w_offset(
            const std::map<int, Xbyak::Address> &vmm_idx_to_out_addr,
            const std::map<int, Xbyak::Reg64> &vmm_idx_to_out_reg,
            const std::map<int, size_t> &vmm_idx_to_out_elem_off_val,
            int vmm_idx, const Xbyak::Reg64 &addr_reg,
            const Xbyak::Reg64 &tmp_reg, std::size_t elem_size_bytes,
            bool is_first) const;
    void calculate_w_ncsp_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_w_ncsp_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_w_blocked_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_w_blocked_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_w_nspc_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_w_nspc_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;
    void calculate_w_cspn_base(
            const dim_t *strides, const Xbyak::Reg64 &tmp_reg) const;
    void calculate_w_cspn_partial(const dim_t *strides,
            const std::size_t offset, const Xbyak::Reg64 &tmp_reg,
            std::size_t elem_size_bytes) const;

    // Comparison (cmp) binary ops: the first overload is selected for
    // Zmm/Address rhs operands, the second for all other operand types.
    template <typename T>
    typename std::enable_if<std::is_same<T, Xbyak::Zmm>::value
            || std::is_same<T, Xbyak::Address>::value>::type
    execute_cmp_binary(const Vmm &dst, const Vmm &lhs, const T &rhs,
            const unsigned int cmp_predicate) const;
    template <typename T>
    typename std::enable_if<!(std::is_same<T, Xbyak::Zmm>::value
            || std::is_same<T, Xbyak::Address>::value)>::type
    execute_cmp_binary(const Vmm &dst, const Vmm &lhs, const T &rhs,
            const unsigned int cmp_predicate) const;
    // Emits the instruction(s) for the given binary algorithm kind.
    template <typename T>
    void execute_binary(alg_kind_t binary_alg, const Vmm &dst, const Vmm &lhs,
            const T &rhs) const;
    /*
     * Used in scalar broadcast strategy, broadcasting single value of given
     * data type over entire vector Vmm register.
     */
    void execute_broadcast(const data_type_t &data_type, const Vmm &tmp_reg,
            const Xbyak::Address &rhs_addr,
            const tail_lode_mode_t tail_load_mode,
            bool with_tail = false) const;
    // Loads a (possibly tailed) rhs tensor slice into tmp_reg.
    void load_rhs(const data_type_t &data_type, const Vmm &tmp_reg,
            const Xbyak::Address &rhs_addr,
            const tail_lode_mode_t tail_load_mode,
            bool with_tail = false) const;
    // Tail variants of broadcast/load: with_opmask (avx512),
    // statically (compile-time tail size), with_gpr (runtime tail size).
    void execute_broadcast_tail_with_opmask(const data_type_t &data_type,
            const Vmm &tmp_reg, const Xbyak::Address &rhs_addr) const;
    void execute_broadcast_tail_statically(const data_type_t &data_type,
            const Vmm &tmp_reg, const Xbyak::Address &rhs_addr,
            const std::size_t tail_size) const;
    void execute_broadcast_tail_with_gpr(const data_type_t &data_type,
            const Vmm &tmp_reg, const Xbyak::Address &rhs_addr) const;
    void load_rhs_tail_dynamically_with_opmask(const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr) const;
    void load_rhs_tail_dynamically_with_gpr(
            const data_type_t &data_type, const Vmm &tmp_vmm) const;
    void load_rhs_tail_statically(const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr) const;
    // Full-vector (no tail) variants.
    void execute_broadcast_no_tail(const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr) const;
    void execute_broadcast_s8u8_no_tail(const data_type_t &data_type,
            const Vmm &tmp_vmm, const Xbyak::Address &rhs_addr) const;
    void load_rhs_no_tail(const data_type_t &data_type, const Vmm &tmp_reg,
            const Xbyak::Address &rhs_addr) const;
    void load_rhs_i8_no_tail(const data_type_t &data_type, const Vmm &tmp_reg,
            const Xbyak::Address &rhs_addr) const;
    // Converts loaded integer data in tmp_reg to f32 in place.
    void cvt_to_f32(const Vmm &tmp_reg) const;
    /*
     * Returns pair consisting of flag indication preservation is needed for vmm
     * index in second member that should be used as temporary vmm inside inject
     * binary.
     */
    std::pair<bool, int> should_preserve_vmm(int curr_idx, int vmm_hint,
            int max_vmm_idx, bool dt_helper_vmm_needed) const;
    /*
     * Used in isa != avx512 where m32bcst is not supported, replaces ptr_b
     * with ptr.
     */
    Xbyak::Address remove_bcast_bit(const Xbyak::Address &rhs_addr) const;

    jit_generator *host_; // non-owning; code is emitted into this generator
    const rhs_arg_static_params_t rhs_arg_static_params_;
    const Xbyak::Reg64 param1_;
    const bcast_set_t supported_strategy_set_;
    const bool is_avx512_ = is_superset(isa, avx512_core);
    const bool is_avx512_core_fp16_ = is_superset(isa, avx512_core_fp16);

    static constexpr int sizeof_reg64 = 8;
    /*
     * Instructions from SSE/AVX used to compute binary result like vaddps where
     * second operand is memory, require mem operand to be 16/32 byte explicitly
     * aligned. (Intel Manual chapter 2.4).
     * Rule is relaxed from AVX2 (Intel Manual chapter 14.9).
     * When using benchdnn zmalloc_protect doesn't guarantee that tensor memory
     * address is 64 byte aligned, which can cause segmentation fault.
     */
    static constexpr bool binary_op_with_unaligned_mem_operand_allowed_
            = !utils::one_of(isa, avx, sse41);
};
526 | |
527 | } // namespace binary_injector |
528 | } // namespace x64 |
529 | } // namespace cpu |
530 | } // namespace impl |
531 | } // namespace dnnl |
532 | |
533 | #endif |
534 | |