1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include <cassert>
17#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace cpu {
22namespace x64 {
23namespace injector {
24
25bool is_supported(const post_ops_ok_args_t &post_ops_ok_args) {
26 const cpu_isa_t isa = post_ops_ok_args.isa;
27 const post_ops_t &post_ops = post_ops_ok_args.post_ops;
28 const memory_desc_wrapper *dst_d = post_ops_ok_args.dst_d;
29 const auto &enabled_bcast_strategy
30 = post_ops_ok_args.enabled_bcast_strategy;
31
32 for (const auto &post_op : post_ops.entry_) {
33 if (post_op.is_eltwise()) {
34 const auto res
35 = eltwise_injector::is_supported(isa, post_op.eltwise.alg);
36 if (!res) return false;
37 } else if (post_op.is_binary()) {
38 const auto &src1_desc = post_op.binary.src1_desc;
39 const auto res = binary_injector::is_supported(
40 isa, src1_desc, *dst_d, enabled_bcast_strategy);
41 if (!res) return false;
42 }
43 }
44 return true;
45}
46
47template <cpu_isa_t isa, typename Vmm>
48jit_uni_postops_injector_t<isa, Vmm>::jit_uni_postops_injector_t(
49 jit_generator *host, const post_ops_t &post_ops,
50 const binary_injector::static_params_t &binary_static_params,
51 const eltwise_injector::static_params_t &eltwise_static_params,
52 const lambda_jit_injectors_t &lambda_jit_injectors)
53 : post_ops_(post_ops)
54 , host_(host)
55 , binary_injector_(nullptr)
56 , lambda_jit_injectors_(lambda_jit_injectors) {
57
58 const auto &esp = eltwise_static_params;
59 bool is_binary = false;
60 bool is_eltwise = false;
61
62 for (int i = 0; i < post_ops.len(); i++) {
63 const auto &post_op = post_ops.entry_[i];
64 if (post_op.is_eltwise()) {
65 is_eltwise = true;
66 alg_to_eltwise_injector_.emplace(i,
67 jit_uni_eltwise_injector_f32<isa, Vmm>(host_,
68 post_op.eltwise, esp.save_state, esp.p_table,
69 esp.k_mask, esp.is_fwd, esp.use_dst,
70 esp.preserve_vmm, esp.preserve_p_table));
71 } else if (post_op.is_binary()) {
72 is_binary = true;
73 }
74 }
75
76 if (is_superset(isa, avx512_core) && is_eltwise && is_binary
77 && binary_static_params.rhs_arg_static_params.tail_size)
78 assert(eltwise_static_params.k_mask
79 != binary_static_params.rhs_arg_static_params.tail_opmask &&
80 "Binary tail opmask should be different than eltwise injector \
81 opmask. Otherwise eltwise injector will overwrite binary tail \
82 opmask.");
83
84 if (is_binary)
85 binary_injector_ = utils::make_unique<
86 binary_injector::jit_uni_binary_injector_t<isa, Vmm>>(
87 host, binary_static_params);
88}
89
90template <cpu_isa_t isa, typename Vmm>
91jit_uni_postops_injector_t<isa, Vmm>::jit_uni_postops_injector_t(
92 jit_generator *host, const post_ops_t &post_ops,
93 const binary_injector::static_params_t &binary_static_params)
94 : jit_uni_postops_injector_t(host, post_ops, binary_static_params,
95 eltwise_injector::static_params_t(), lambda_jit_injectors_t()) {}
96
97template <cpu_isa_t isa, typename Vmm>
98jit_uni_postops_injector_t<isa, Vmm>::jit_uni_postops_injector_t(
99 jit_generator *host, const post_ops_t &post_ops,
100 const binary_injector::static_params_t &binary_static_params,
101 const lambda_jit_injectors_t &lambda_jit_injectors)
102 : jit_uni_postops_injector_t(host, post_ops, binary_static_params,
103 eltwise_injector::static_params_t(), lambda_jit_injectors) {}
104
105template <cpu_isa_t isa, typename Vmm>
106jit_uni_postops_injector_t<isa, Vmm>::jit_uni_postops_injector_t(
107 jit_generator *host, const post_ops_t &post_ops,
108 const binary_injector::static_params_t &binary_static_params,
109 const eltwise_injector::static_params_t &eltwise_static_params)
110 : jit_uni_postops_injector_t(host, post_ops, binary_static_params,
111 eltwise_static_params, lambda_jit_injectors_t()) {}
112
113template <cpu_isa_t isa, typename Vmm>
114void jit_uni_postops_injector_t<isa, Vmm>::compute_vector_range(
115 size_t start_idx, size_t end_idx,
116 const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) {
117
118 injector_utils::vmm_index_set_t vmm_idxs;
119 for (size_t i = start_idx; i < end_idx; i++)
120 vmm_idxs.emplace(i);
121 compute_vector_range(vmm_idxs, rhs_arg_params);
122}
123
124template <cpu_isa_t isa, typename Vmm>
125void jit_uni_postops_injector_t<isa, Vmm>::compute_vector_range(
126 size_t start_idx, size_t end_idx) {
127 compute_vector_range(
128 start_idx, end_idx, binary_injector::rhs_arg_dynamic_params_t());
129}
130
131template <cpu_isa_t isa, typename Vmm>
132void jit_uni_postops_injector_t<isa, Vmm>::compute_vector_range(
133 const injector_utils::vmm_index_set_t &vmm_idxs,
134 const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) {
135
136 std::size_t rhs_arg_idx = 0;
137 for (int i = 0; i < post_ops_.len(); i++) {
138 const auto &post_op = post_ops_.entry_[i];
139 if (post_op.is_eltwise()) {
140 alg_to_eltwise_injector_.at(i).compute_vector_range(vmm_idxs);
141 } else if (post_op.is_binary()) {
142 binary_injector_->compute_vector_range(
143 vmm_idxs, rhs_arg_idx, post_op, rhs_arg_params);
144 ++rhs_arg_idx;
145 } else {
146 const auto lam = lambda_jit_injectors_.find(post_op.kind);
147 if (lam != lambda_jit_injectors_.end()) lam->second();
148 }
149 }
150}
151template <cpu_isa_t isa, typename Vmm>
152void jit_uni_postops_injector_t<isa, Vmm>::compute_vector_range(
153 const injector_utils::vmm_index_set_t &vmm_idxs) {
154 compute_vector_range(vmm_idxs, binary_injector::rhs_arg_dynamic_params_t());
155}
156
157template <cpu_isa_t isa, typename Vmm>
158void jit_uni_postops_injector_t<isa, Vmm>::prepare_table(bool gen_table) {
159 for (auto &alg_elt_inject : alg_to_eltwise_injector_)
160 alg_elt_inject.second.prepare_table(gen_table);
161}
162
163template <cpu_isa_t isa, typename Vmm>
164void jit_uni_postops_injector_t<isa, Vmm>::compute_vector(size_t idx,
165 const binary_injector::rhs_arg_dynamic_params_t &rhs_arg_params) {
166 compute_vector_range({idx}, rhs_arg_params);
167}
168
169template <cpu_isa_t isa, typename Vmm>
170void jit_uni_postops_injector_t<isa, Vmm>::compute_vector(size_t idx) {
171 compute_vector_range({idx});
172}
173
174template <cpu_isa_t isa, typename Vmm>
175void jit_uni_postops_injector_t<isa, Vmm>::set_lambda_injector(
176 dnnl_primitive_kind_t kind, const std::function<void()> &jit_injector) {
177 lambda_jit_injectors_[kind] = jit_injector;
178}
179
180post_ops_ok_args_t::post_ops_ok_args_t(const cpu_isa_t isa,
181 const std::vector<post_op_type> &accepted_post_op_types,
182 const post_ops_t &post_ops, const memory_desc_wrapper *dst_d,
183 const bool sum_at_pos_0_only, const bool sum_requires_scale_one,
184 const bool sum_requires_zp_zero,
185 const bcast_set_t &enabled_bcast_strategy)
186 : isa(isa)
187 , accepted_post_op_types(accepted_post_op_types)
188 , post_ops(post_ops)
189 , dst_d(dst_d)
190 , sum_at_pos_0_only(sum_at_pos_0_only)
191 , sum_requires_scale_one(sum_requires_scale_one)
192 , sum_requires_zp_zero(sum_requires_zp_zero)
193 , enabled_bcast_strategy(enabled_bcast_strategy) {};
194
195bool post_ops_ok(const post_ops_ok_args_t &post_ops_ok_args) {
196 const cpu_isa_t isa = post_ops_ok_args.isa;
197 const std::vector<post_op_type> &accepted_post_op_types
198 = post_ops_ok_args.accepted_post_op_types;
199 const post_ops_t &post_ops = post_ops_ok_args.post_ops;
200 const memory_desc_wrapper *dst_d = post_ops_ok_args.dst_d;
201 const bool sum_at_pos_0_only = post_ops_ok_args.sum_at_pos_0_only;
202 const bool sum_requires_scale_one = post_ops_ok_args.sum_requires_scale_one;
203 const bool sum_requires_zp_zero = post_ops_ok_args.sum_requires_zp_zero;
204 const auto &enabled_bcast_strategy
205 = post_ops_ok_args.enabled_bcast_strategy;
206
207 const auto is_accepted_postop = [&](const int idx) {
208 for (const auto &post_op : accepted_post_op_types) {
209 const auto &entry = post_ops.entry_[idx];
210 switch (post_op) {
211 case sum:
212 if (entry.is_sum(false, false)) {
213 if (sum_requires_scale_one && entry.sum.scale != 1)
214 return false;
215 if (sum_requires_zp_zero && entry.sum.zero_point != 0)
216 return false;
217 return IMPLICATION(sum_at_pos_0_only, idx == 0);
218 }
219 break;
220 case eltwise:
221 if (entry.is_eltwise()) {
222 const auto alg = entry.eltwise.alg;
223 return eltwise_injector::is_supported(isa, alg);
224 }
225 break;
226 case binary:
227 if (entry.is_binary()) {
228 assert(dst_d != nullptr && "dst_d is null");
229 return binary_injector::is_supported(isa,
230 entry.binary.src1_desc, *dst_d,
231 enabled_bcast_strategy);
232 }
233 break;
234 default: assert(false && "Unhandled post_op type");
235 }
236 }
237 return false;
238 };
239
240 for (int i = 0; i < post_ops.len(); i++) {
241 if (!is_accepted_postop(i)) return false;
242 }
243
244 return true;
245}
246
247template class jit_uni_postops_injector_t<avx512_core_fp16>;
248template class jit_uni_postops_injector_t<avx512_core_fp16, Xbyak::Ymm>;
249template class jit_uni_postops_injector_t<avx512_core_fp16, Xbyak::Xmm>;
250template class jit_uni_postops_injector_t<avx512_core_bf16>;
251template class jit_uni_postops_injector_t<avx512_core>;
252template class jit_uni_postops_injector_t<avx512_core, Xbyak::Ymm>;
253template class jit_uni_postops_injector_t<avx512_core, Xbyak::Xmm>;
254template class jit_uni_postops_injector_t<avx2_vnni_2>;
255template class jit_uni_postops_injector_t<avx2>;
256template class jit_uni_postops_injector_t<avx2, Xbyak::Xmm>;
257template class jit_uni_postops_injector_t<avx>;
258template class jit_uni_postops_injector_t<avx, Xbyak::Xmm>;
259template class jit_uni_postops_injector_t<sse41>;
260
261} // namespace injector
262} // namespace x64
263} // namespace cpu
264} // namespace impl
265} // namespace dnnl
266