1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <cassert> |
17 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace cpu { |
22 | namespace x64 { |
23 | namespace injector { |
24 | |
25 | bool is_supported(const post_ops_ok_args_t &post_ops_ok_args) { |
26 | const cpu_isa_t isa = post_ops_ok_args.isa; |
27 | const post_ops_t &post_ops = post_ops_ok_args.post_ops; |
28 | const memory_desc_wrapper *dst_d = post_ops_ok_args.dst_d; |
29 | const auto &enabled_bcast_strategy |
30 | = post_ops_ok_args.enabled_bcast_strategy; |
31 | |
32 | for (const auto &post_op : post_ops.entry_) { |
33 | if (post_op.is_eltwise()) { |
34 | const auto res |
35 | = eltwise_injector::is_supported(isa, post_op.eltwise.alg); |
36 | if (!res) return false; |
37 | } else if (post_op.is_binary()) { |
38 | const auto &src1_desc = post_op.binary.src1_desc; |
39 | const auto res = binary_injector::is_supported( |
40 | isa, src1_desc, *dst_d, enabled_bcast_strategy); |
41 | if (!res) return false; |
42 | } |
43 | } |
44 | return true; |
45 | } |
46 | |
47 | template <cpu_isa_t isa, typename Vmm> |
48 | jit_uni_postops_injector_t<isa, Vmm>::jit_uni_postops_injector_t( |
49 | jit_generator *host, const post_ops_t &post_ops, |
50 | const binary_injector::static_params_t &binary_static_params, |
51 | const eltwise_injector::static_params_t &eltwise_static_params, |
52 | const lambda_jit_injectors_t &lambda_jit_injectors) |
53 | : post_ops_(post_ops) |
54 | , host_(host) |
55 | , binary_injector_(nullptr) |
56 | , lambda_jit_injectors_(lambda_jit_injectors) { |
57 | |
58 | const auto &esp = eltwise_static_params; |
59 | bool is_binary = false; |
60 | bool is_eltwise = false; |
61 | |
62 | for (int i = 0; i < post_ops.len(); i++) { |
63 | const auto &post_op = post_ops.entry_[i]; |
64 | if (post_op.is_eltwise()) { |
65 | is_eltwise = true; |
66 | alg_to_eltwise_injector_.emplace(i, |
67 | jit_uni_eltwise_injector_f32<isa, Vmm>(host_, |
68 | post_op.eltwise, esp.save_state, esp.p_table, |
69 | esp.k_mask, esp.is_fwd, esp.use_dst, |
70 | esp.preserve_vmm, esp.preserve_p_table)); |
71 | } else if (post_op.is_binary()) { |
72 | is_binary = true; |
73 | } |
74 | } |
75 | |
76 | if (is_superset(isa, avx512_core) && is_eltwise && is_binary |
77 | && binary_static_params.rhs_arg_static_params.tail_size) |
78 | assert(eltwise_static_params.k_mask |
79 | != binary_static_params.rhs_arg_static_params.tail_opmask && |
80 | "Binary tail opmask should be different than eltwise injector \ |
81 | opmask. Otherwise eltwise injector will overwrite binary tail \ |
82 | opmask." ); |
83 | |
84 | if (is_binary) |
85 | binary_injector_ = utils::make_unique< |
86 | binary_injector::jit_uni_binary_injector_t<isa, Vmm>>( |
87 | host, binary_static_params); |
88 | } |
89 | |
90 | template <cpu_isa_t isa, typename Vmm> |
91 | jit_uni_postops_injector_t<isa, Vmm>::jit_uni_postops_injector_t( |
92 | jit_generator *host, const post_ops_t &post_ops, |
93 | const binary_injector::static_params_t &binary_static_params) |
94 | : jit_uni_postops_injector_t(host, post_ops, binary_static_params, |
95 | eltwise_injector::static_params_t(), lambda_jit_injectors_t()) {} |
96 | |
97 | template <cpu_isa_t isa, typename Vmm> |
98 | jit_uni_postops_injector_t<isa, Vmm>::jit_uni_postops_injector_t( |
99 | jit_generator *host, const post_ops_t &post_ops, |
100 | const binary_injector::static_params_t &binary_static_params, |
101 | const lambda_jit_injectors_t &lambda_jit_injectors) |
102 | : jit_uni_postops_injector_t(host, post_ops, binary_static_params, |
103 | eltwise_injector::static_params_t(), lambda_jit_injectors) {} |
104 | |
105 | template <cpu_isa_t isa, typename Vmm> |
106 | jit_uni_postops_injector_t<isa, Vmm>::jit_uni_postops_injector_t( |
107 | jit_generator *host, const post_ops_t &post_ops, |
108 | const binary_injector::static_params_t &binary_static_params, |
109 | const eltwise_injector::static_params_t &eltwise_static_params) |
110 | : jit_uni_postops_injector_t(host, post_ops, binary_static_params, |
111 | eltwise_static_params, lambda_jit_injectors_t()) {} |
112 | |
113 | template <cpu_isa_t isa, typename Vmm> |
114 | void jit_uni_postops_injector_t<isa, Vmm>::compute_vector_range( |
115 | size_t start_idx, size_t end_idx, |
116 | const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) { |
117 | |
118 | injector_utils::vmm_index_set_t vmm_idxs; |
119 | for (size_t i = start_idx; i < end_idx; i++) |
120 | vmm_idxs.emplace(i); |
121 | compute_vector_range(vmm_idxs, rhs_arg_params); |
122 | } |
123 | |
124 | template <cpu_isa_t isa, typename Vmm> |
125 | void jit_uni_postops_injector_t<isa, Vmm>::compute_vector_range( |
126 | size_t start_idx, size_t end_idx) { |
127 | compute_vector_range( |
128 | start_idx, end_idx, binary_injector::rhs_arg_dynamic_params_t()); |
129 | } |
130 | |
131 | template <cpu_isa_t isa, typename Vmm> |
132 | void jit_uni_postops_injector_t<isa, Vmm>::compute_vector_range( |
133 | const injector_utils::vmm_index_set_t &vmm_idxs, |
134 | const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) { |
135 | |
136 | std::size_t rhs_arg_idx = 0; |
137 | for (int i = 0; i < post_ops_.len(); i++) { |
138 | const auto &post_op = post_ops_.entry_[i]; |
139 | if (post_op.is_eltwise()) { |
140 | alg_to_eltwise_injector_.at(i).compute_vector_range(vmm_idxs); |
141 | } else if (post_op.is_binary()) { |
142 | binary_injector_->compute_vector_range( |
143 | vmm_idxs, rhs_arg_idx, post_op, rhs_arg_params); |
144 | ++rhs_arg_idx; |
145 | } else { |
146 | const auto lam = lambda_jit_injectors_.find(post_op.kind); |
147 | if (lam != lambda_jit_injectors_.end()) lam->second(); |
148 | } |
149 | } |
150 | } |
151 | template <cpu_isa_t isa, typename Vmm> |
152 | void jit_uni_postops_injector_t<isa, Vmm>::compute_vector_range( |
153 | const injector_utils::vmm_index_set_t &vmm_idxs) { |
154 | compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t()); |
155 | } |
156 | |
157 | template <cpu_isa_t isa, typename Vmm> |
158 | void jit_uni_postops_injector_t<isa, Vmm>::prepare_table(bool gen_table) { |
159 | for (auto &alg_elt_inject : alg_to_eltwise_injector_) |
160 | alg_elt_inject.second.prepare_table(gen_table); |
161 | } |
162 | |
163 | template <cpu_isa_t isa, typename Vmm> |
164 | void jit_uni_postops_injector_t<isa, Vmm>::compute_vector(size_t idx, |
165 | const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) { |
166 | compute_vector_range({idx}, rhs_arg_params); |
167 | } |
168 | |
169 | template <cpu_isa_t isa, typename Vmm> |
170 | void jit_uni_postops_injector_t<isa, Vmm>::compute_vector(size_t idx) { |
171 | compute_vector_range({idx}); |
172 | } |
173 | |
174 | template <cpu_isa_t isa, typename Vmm> |
175 | void jit_uni_postops_injector_t<isa, Vmm>::set_lambda_injector( |
176 | dnnl_primitive_kind_t kind, const std::function<void()> &jit_injector) { |
177 | lambda_jit_injectors_[kind] = jit_injector; |
178 | } |
179 | |
180 | post_ops_ok_args_t::post_ops_ok_args_t(const cpu_isa_t isa, |
181 | const std::vector<post_op_type> &accepted_post_op_types, |
182 | const post_ops_t &post_ops, const memory_desc_wrapper *dst_d, |
183 | const bool sum_at_pos_0_only, const bool sum_requires_scale_one, |
184 | const bool sum_requires_zp_zero, |
185 | const bcast_set_t &enabled_bcast_strategy) |
186 | : isa(isa) |
187 | , accepted_post_op_types(accepted_post_op_types) |
188 | , post_ops(post_ops) |
189 | , dst_d(dst_d) |
190 | , sum_at_pos_0_only(sum_at_pos_0_only) |
191 | , sum_requires_scale_one(sum_requires_scale_one) |
192 | , sum_requires_zp_zero(sum_requires_zp_zero) |
193 | , enabled_bcast_strategy(enabled_bcast_strategy) {}; |
194 | |
195 | bool post_ops_ok(const post_ops_ok_args_t &post_ops_ok_args) { |
196 | const cpu_isa_t isa = post_ops_ok_args.isa; |
197 | const std::vector<post_op_type> &accepted_post_op_types |
198 | = post_ops_ok_args.accepted_post_op_types; |
199 | const post_ops_t &post_ops = post_ops_ok_args.post_ops; |
200 | const memory_desc_wrapper *dst_d = post_ops_ok_args.dst_d; |
201 | const bool sum_at_pos_0_only = post_ops_ok_args.sum_at_pos_0_only; |
202 | const bool sum_requires_scale_one = post_ops_ok_args.sum_requires_scale_one; |
203 | const bool sum_requires_zp_zero = post_ops_ok_args.sum_requires_zp_zero; |
204 | const auto &enabled_bcast_strategy |
205 | = post_ops_ok_args.enabled_bcast_strategy; |
206 | |
207 | const auto is_accepted_postop = [&](const int idx) { |
208 | for (const auto &post_op : accepted_post_op_types) { |
209 | const auto &entry = post_ops.entry_[idx]; |
210 | switch (post_op) { |
211 | case sum: |
212 | if (entry.is_sum(false, false)) { |
213 | if (sum_requires_scale_one && entry.sum.scale != 1) |
214 | return false; |
215 | if (sum_requires_zp_zero && entry.sum.zero_point != 0) |
216 | return false; |
217 | return IMPLICATION(sum_at_pos_0_only, idx == 0); |
218 | } |
219 | break; |
220 | case eltwise: |
221 | if (entry.is_eltwise()) { |
222 | const auto alg = entry.eltwise.alg; |
223 | return eltwise_injector::is_supported(isa, alg); |
224 | } |
225 | break; |
226 | case binary: |
227 | if (entry.is_binary()) { |
228 | assert(dst_d != nullptr && "dst_d is null" ); |
229 | return binary_injector::is_supported(isa, |
230 | entry.binary.src1_desc, *dst_d, |
231 | enabled_bcast_strategy); |
232 | } |
233 | break; |
234 | default: assert(false && "Unhandled post_op type" ); |
235 | } |
236 | } |
237 | return false; |
238 | }; |
239 | |
240 | for (int i = 0; i < post_ops.len(); i++) { |
241 | if (!is_accepted_postop(i)) return false; |
242 | } |
243 | |
244 | return true; |
245 | } |
246 | |
247 | template class jit_uni_postops_injector_t<avx512_core_fp16>; |
248 | template class jit_uni_postops_injector_t<avx512_core_fp16, Xbyak::Ymm>; |
249 | template class jit_uni_postops_injector_t<avx512_core_fp16, Xbyak::Xmm>; |
250 | template class jit_uni_postops_injector_t<avx512_core_bf16>; |
251 | template class jit_uni_postops_injector_t<avx512_core>; |
252 | template class jit_uni_postops_injector_t<avx512_core, Xbyak::Ymm>; |
253 | template class jit_uni_postops_injector_t<avx512_core, Xbyak::Xmm>; |
254 | template class jit_uni_postops_injector_t<avx2_vnni_2>; |
255 | template class jit_uni_postops_injector_t<avx2>; |
256 | template class jit_uni_postops_injector_t<avx2, Xbyak::Xmm>; |
257 | template class jit_uni_postops_injector_t<avx>; |
258 | template class jit_uni_postops_injector_t<avx, Xbyak::Xmm>; |
259 | template class jit_uni_postops_injector_t<sse41>; |
260 | |
261 | } // namespace injector |
262 | } // namespace x64 |
263 | } // namespace cpu |
264 | } // namespace impl |
265 | } // namespace dnnl |
266 | |