1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #ifndef CPU_X64_JIT_UNI_POSTOPS_INJECTOR_HPP |
17 | #define CPU_X64_JIT_UNI_POSTOPS_INJECTOR_HPP |
18 | |
#include <functional>
#include <initializer_list>
#include <map>
#include <memory>
#include <vector>

#include "common/c_types_map.hpp"
#include "common/primitive_attr.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | namespace x64 { |
37 | namespace injector { |
38 | |
39 | /* |
40 | * Allows specifying custom injector function for given post-op type - one |
41 | * function per primitive. There are post-ops type (example: sum) that don't |
42 | * have specialized injector. They heavily rely on kernel specific intrnals, |
43 | * which makes the generalization unreasonable. As so user can prepare internal |
44 | * kernel lambda and pass it explicitly to injector. |
45 | */ |
46 | using lambda_jit_injectors_t |
47 | = std::map<dnnl_primitive_kind_t, std::function<void()>>; |
48 | |
49 | struct post_ops_ok_args_t; |
50 | /* |
51 | * Checks if postops injection for given args is supported. |
52 | */ |
53 | bool is_supported(const post_ops_ok_args_t &post_ops_ok_args); |
54 | |
55 | /* |
56 | * Main mechanism of handling various post-ops types. It utilizes internally |
57 | * specialized injectors to generate post-ops code to host primitive. Random |
58 | * order of post-ops is supported. |
59 | */ |
60 | template <cpu_isa_t isa, typename Vmm = typename cpu_isa_traits<isa>::Vmm> |
61 | class jit_uni_postops_injector_t { |
62 | public: |
63 | /* |
64 | * @param host <required> - user primitive where post-ops generated code is |
65 | * injected |
66 | * @param post_ops <required> - struct representing requested post-ops chain |
67 | * @binary_static_params <reguired> - static params needed for binary_injector. |
68 | * see: jit_uni_binary_injector.hpp for more info. |
69 | * @param eltwise_static_params <optional> - allows user specify non default |
70 | * params for eltwise_injector |
71 | * @param lambda_jit_injectors <optional> - allows user specify custom injector |
72 | * function for given post-op type |
73 | */ |
74 | jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, |
75 | const binary_injector::static_params_t &binary_static_params); |
76 | jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, |
77 | const binary_injector::static_params_t &binary_static_params, |
78 | const lambda_jit_injectors_t &lambda_jit_injectors); |
79 | jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, |
80 | const binary_injector::static_params_t &binary_static_params, |
81 | const eltwise_injector::static_params_t &eltwise_static_params); |
82 | jit_uni_postops_injector_t(jit_generator *host, const post_ops_t &post_ops, |
83 | const binary_injector::static_params_t &binary_static_params, |
84 | const eltwise_injector::static_params_t &eltwise_static_params, |
85 | const lambda_jit_injectors_t &lambda_jit_injectors); |
86 | |
87 | /* |
88 | * Generates code of post_ops chain injected to host primitive. Applied to |
89 | * ordered set of vector registers' indexes. |
90 | * |
91 | * @rhs_arg_params: see jit_uni_binary_injector description |
92 | */ |
93 | void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs, |
94 | const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params); |
95 | |
96 | void compute_vector_range(const injector_utils::vmm_index_set_t &vmm_idxs); |
97 | |
98 | /* |
99 | * Generates code of post_ops chain injected to host primitive. Applied to |
100 | * range <start_idx, end_idx) of vector registers' indexes. |
101 | * |
102 | * @rhs_arg_params: see jit_uni_binary_injector description |
103 | */ |
104 | void compute_vector_range(size_t start_idx, size_t end_idx, |
105 | const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params); |
106 | |
107 | void compute_vector_range(size_t start_idx, size_t end_idx); |
108 | |
109 | /* |
110 | * Generates code of post_ops chain injected to host primitive. Applied to |
111 | * a single vector register index. |
112 | * |
113 | * @rhs_arg_params: see jit_uni_binary_injector description |
114 | */ |
115 | void compute_vector(size_t idx, |
116 | const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params); |
117 | void compute_vector(size_t idx); |
118 | |
119 | /* |
120 | * Thin wrapper for eltwise injector specific function |
121 | */ |
122 | void prepare_table(bool gen_table = true); |
123 | void set_lambda_injector(lambda_jit_injectors_t::key_type, |
124 | const lambda_jit_injectors_t::mapped_type &jit_injector); |
125 | |
126 | private: |
127 | post_ops_t post_ops_; |
128 | jit_generator *host_; |
129 | // Key is a numerical order of a post-op in attributes. |
130 | std::map<int, jit_uni_eltwise_injector_f32<isa, Vmm>> |
131 | alg_to_eltwise_injector_; |
132 | std::unique_ptr<binary_injector::jit_uni_binary_injector_t<isa, Vmm>> |
133 | binary_injector_; |
134 | lambda_jit_injectors_t lambda_jit_injectors_; |
135 | }; |
136 | |
// Post-op categories a caller can list in
// post_ops_ok_args_t::accepted_post_op_types.
enum post_op_type {
    sum = 0,
    eltwise = 1,
    binary = 2,
};
138 | |
139 | struct post_ops_ok_args_t { |
140 | post_ops_ok_args_t(const cpu_isa_t isa, |
141 | const std::vector<post_op_type> &accepted_post_op_types, |
142 | const post_ops_t &post_ops, |
143 | const memory_desc_wrapper *dst_d = nullptr, |
144 | const bool sum_at_pos_0_only = false, |
145 | const bool sum_requires_scale_one = false, |
146 | const bool sum_requires_zp_zero = true, |
147 | const bcast_set_t &enabled_bcast_strategy = default_strategies()); |
148 | |
149 | const cpu_isa_t isa; |
150 | const std::vector<post_op_type> &accepted_post_op_types; |
151 | const post_ops_t &post_ops; |
152 | const memory_desc_wrapper *dst_d; |
153 | const bool sum_at_pos_0_only; |
154 | const bool sum_requires_scale_one; |
155 | const bool sum_requires_zp_zero; |
156 | const bcast_set_t enabled_bcast_strategy; |
157 | }; |
158 | |
159 | bool post_ops_ok(const post_ops_ok_args_t &args); |
160 | |
161 | } // namespace injector |
162 | } // namespace x64 |
163 | } // namespace cpu |
164 | } // namespace impl |
165 | } // namespace dnnl |
166 | |
167 | #endif |
168 | |