/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cstddef>

#include "cpu/x64/prelu/jit_uni_prelu_forward_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
jit_prelu_forward_kernel_t::jit_prelu_forward_kernel_t(
        const cpu_prelu_fwd_pd_t *pd, const cpu_isa_t &isa, const int vlen,
        const size_t number_vmm_single_compute)
    : jit_prelu_base_kernel_t(isa, vlen,
            prelu::get_bcast_type(memory_desc_wrapper(pd->src_md(0)),
                    memory_desc_wrapper(pd->weights_md(0))),
            memory_desc_wrapper(pd->src_md(0)), number_vmm_single_compute,
            jit_name())
    , src_dt_(pd->src_md(0)->data_type)
    , wei_dt_(pd->weights_md(0)->data_type)
    , dst_dt_(pd->dst_md(0)->data_type)
    , dst_tail_block_(prelu::get_block_tail_size(pd->dst_md(0)))
    , pd_(pd) {}

#define PARAM_OFF(x) offsetof(call_params_t, x)

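// Loads the runtime call parameters (src/weights/dst pointers and the
// per-call compute size) from the call_params_t structure passed in
// abi_param1.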
void jit_prelu_forward_kernel_t::load_kernel_call_params() {
    mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]);
    mov(reg_weights_, ptr[abi_param1 + PARAM_OFF(weights)]);
    mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]);
    mov(reg_data_size_, ptr[abi_param1 + PARAM_OFF(compute_data_size)]);
}

#undef PARAM_OFF

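// Builds an effective address for the given execution argument: the base
// pointer advanced by reg_offset_ elements plus a static element offset,
// both scaled by the argument's data type size.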
Xbyak::Address jit_prelu_forward_kernel_t::data_ptr(int arg_num, size_t offt) {

    const auto get_addr
            = [&](const Xbyak::Reg64 &reg_base, const data_type_t dt) {
                  const auto dt_size = types::data_type_size(dt);
                  return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size];
              };

    switch (arg_num) {
        case DNNL_ARG_SRC: return get_addr(reg_src_, src_dt_);
        case DNNL_ARG_WEIGHTS: return get_addr(reg_weights_, wei_dt_);
        case DNNL_ARG_DST: return get_addr(reg_dst_, dst_dt_);
        default: assert(!"unsupported arg_num"); break;
    }
    return Xbyak::Address(0);
}

bool jit_prelu_forward_kernel_t::any_tensor_bf16() const {
    return utils::one_of(data_type::bf16, src_dt_, wei_dt_, dst_dt_);
}

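// Reserves helper vmm registers up front: an optional tail mask (pre-avx512
// isas only), a zero vector, an optional saturation upper bound for integral
// destinations and, for the per-oc broadcast strategies, a constant weights
// vector. Compute groups get 4 vmms each on sse41/avx or for non-f32
// sources, and 3 otherwise.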
template <typename Vmm>
jit_uni_prelu_forward_kernel_t<Vmm>::jit_uni_prelu_forward_kernel_t(
        const cpu_prelu_fwd_pd_t *pd, const cpu_isa_t &isa)
    : jit_prelu_forward_kernel_t(pd, isa, vreg_traits<Vmm>::vlen,
            (utils::one_of(isa, sse41, avx)
                    || pd->src_md(0)->data_type != data_type::f32)
                    ? 4u
                    : 3u)
    , saturation_needed_(utils::one_of(
              dst_dt_, data_type::u8, data_type::s8, data_type::s32))
    , tail_vmm_mask_(tail_size_ && is_subset(isa, avx2) ? reserve_vmm() : 0)
    , vmm_zeros_(reserve_vmm())
    , dst_saturate_ubound_(saturation_needed_ ? reserve_vmm() : 0)
    , weights_const_vmm_(utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                                 prelu::bcast::per_oc_blocked)
                      ? reserve_vmm()
                      : 0)
    , io_(this, isa, {src_dt_, wei_dt_, dst_dt_}, {},
              io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                      tail_vmm_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {}, create_saturation_vmm_map()) {
    assert(tail_vmm_mask_.getIdx() == 0);
}

template <typename Vmm>
jit_uni_prelu_forward_kernel_t<Vmm>::~jit_uni_prelu_forward_kernel_t()
        = default;

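// Initializes constants that stay live for the whole kernel: the zero
// vector, bf16 emulation state, f32 saturation bounds, the tail mask and,
// for the per-oc broadcast strategies, the preloaded weights vector.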
template <typename Vmm>
void jit_uni_prelu_forward_kernel_t<Vmm>::prepare_kernel_const_vars() {
    uni_vxorps(vmm_zeros_, vmm_zeros_, vmm_zeros_);

    io_.init_bf16();
    if (saturation_needed_) io_.init_saturate_f32({dst_dt_});
    if (tail_size_) io_.prepare_tail_mask();
    if (bcast_ == prelu::bcast::per_oc_n_c_spatial)
        io_.at(wei_dt_)->broadcast(ptr[reg_weights_], weights_const_vmm_);
    else if (bcast_ == prelu::bcast::per_oc_blocked)
        io_.at(wei_dt_)->load(
                ptr[reg_weights_], weights_const_vmm_, false /*tail*/);
}

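// Tells the io injector which vmms to use when saturating f32 results down
// to the integral destination type; the map stays empty when no saturation
// is needed.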
template <typename Vmm>
std::map<data_type_t, io::io_saturation_conf_t>
jit_uni_prelu_forward_kernel_t<Vmm>::create_saturation_vmm_map() const {

    std::map<data_type_t, io::io_saturation_conf_t> saturation_map {};

    if (saturation_needed_) {
        saturation_map.emplace(dst_dt_,
                io::io_saturation_conf_t {vmm_zeros_.getIdx(),
                        dst_saturate_ubound_.getIdx(), reg_tmp_});
    }

    return saturation_map;
}

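// Whether the fma may take the weights straight from memory instead of
// loading them into a vmm first: only for f32 weights without a per-oc
// broadcast. The Ymm variant additionally requires avx2 and a full
// (non-tail) vector; the Xmm variant always loads explicitly.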
template <>
bool jit_uni_prelu_forward_kernel_t<
        Xbyak::Zmm>::can_load_wei_from_addr_directly(bool tail) const noexcept {
    return wei_dt_ == data_type::f32
            && !utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                    prelu::bcast::per_oc_blocked);
}

template <>
bool jit_uni_prelu_forward_kernel_t<
        Xbyak::Ymm>::can_load_wei_from_addr_directly(bool tail) const noexcept {
    return wei_dt_ == data_type::f32 && is_superset(isa_, avx2) && !tail
            && !utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                    prelu::bcast::per_oc_blocked);
}

template <>
bool jit_uni_prelu_forward_kernel_t<
        Xbyak::Xmm>::can_load_wei_from_addr_directly(bool tail) const noexcept {
    return false;
}

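// Returns the preloaded constant weights vmm for the per-oc broadcast
// strategies; otherwise loads the weights from memory into weights_vmm.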
template <typename Vmm>
Vmm jit_uni_prelu_forward_kernel_t<Vmm>::get_or_load_weights(
        const Xbyak::Address &src_addr, const Vmm &weights_vmm, bool tail) {
    if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                prelu::bcast::per_oc_blocked))
        return weights_const_vmm_;

    io_.at(wei_dt_)->load(src_addr, weights_vmm, tail);
    return weights_vmm;
}

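// Computes x1 = x1 * op + x2. The generic version forwards to the base
// three-operand form; the Zmm specialization additionally applies the tail
// opmask to memory operands.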
template <typename Vmm>
void jit_uni_prelu_forward_kernel_t<Vmm>::uni_vfmadd132ps(
        const Vmm &x1, const Vmm &x2, const Xbyak::Operand &op, bool tail) {
    uni_vfmadd132ps(x1, x2, op);
}

template <>
void jit_uni_prelu_forward_kernel_t<Xbyak::Zmm>::uni_vfmadd132ps(
        const Xbyak::Zmm &x1, const Xbyak::Zmm &x2, const Xbyak::Operand &op,
        bool tail) {

    if (op.isMEM()) {
        const Xbyak::Zmm dst = tail ? (x1 | tail_opmask_) : x1;
        // workaround for DataParallelC++ compiler issue converting mem to ZMM
        const Xbyak::Address addr
                = reinterpret_cast<const Xbyak::Address &>(op);
        vfmadd132ps(dst, x2, addr);
    } else {
        vfmadd132ps(x1, x2, op);
    }
}

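// Emits the PReLU body for each unrolled group:
//   dst = max(0, src) + weights * min(0, src)
// implemented as a max, a min and a single fma, followed by a store and,
// for blocked layouts with a tail, zero-padding of the trailing block.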
template <typename Vmm>
void jit_uni_prelu_forward_kernel_t<Vmm>::compute_dst(
        size_t unrolling_factor, bool tail) {
    static constexpr size_t max_idx = 0;
    static constexpr size_t min_idx = 1;
    static constexpr size_t src_idx = 2;
    static constexpr size_t weights_idx = 3;

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {
        const Vmm max_vmm {get_compute_vmm(max_idx, unroll_group)};
        const Vmm min_vmm {get_compute_vmm(min_idx, unroll_group)};
        const Vmm src_vmm {get_compute_vmm(src_idx, unroll_group)};
        const Vmm weights_vmm {get_compute_vmm(weights_idx, unroll_group)};

        const auto offset = unroll_group * simd_w_;
        io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);
        uni_vmaxps(max_vmm, vmm_zeros_, src_vmm);
        uni_vminps(min_vmm, vmm_zeros_, src_vmm);
        const auto &dst_vmm = min_vmm;

        const Xbyak::Address weights_addr = data_ptr(DNNL_ARG_WEIGHTS, offset);
        if (can_load_wei_from_addr_directly(tail)) {
            uni_vfmadd132ps(dst_vmm, max_vmm, weights_addr, tail);
        } else {
            const Vmm weights_operand
                    = get_or_load_weights(weights_addr, weights_vmm, tail);
            uni_vfmadd132ps(dst_vmm, max_vmm, weights_operand, tail);
        }

        io_.at(dst_dt_)->store(dst_vmm, data_ptr(DNNL_ARG_DST, offset), tail);
        if (dst_tail_block_ && tail)
            prelu::apply_zero_padding(this, tail_size_, dst_dt_,
                    dst_tail_block_, reg_dst_, &reg_offset_);
    }
}

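// Factory: picks the widest vector register the detected isa supports.
// Plain avx lacks most 256-bit integer instructions, so the s8/u8 case
// stays on Xmm there.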
jit_prelu_forward_kernel_t *jit_prelu_forward_kernel_t::create(
        const cpu_prelu_fwd_pd_t *pd) {

    const auto isa = prelu::get_supported_isa();
    const auto &src_dt = pd->src_md(0)->data_type;
    const auto &wei_dt = pd->weights_md(0)->data_type;
    const auto &dst_dt = pd->dst_md(0)->data_type;

    if (is_superset(isa, avx512_core))
        return new jit_uni_prelu_forward_kernel_t<Xbyak::Zmm>(pd, isa);
    else if (is_superset(isa, avx)) {
        if (isa == avx && prelu::is_s8u8({src_dt, wei_dt, dst_dt}))
            return new jit_uni_prelu_forward_kernel_t<Xbyak::Xmm>(pd, isa);
        else
            return new jit_uni_prelu_forward_kernel_t<Xbyak::Ymm>(pd, isa);
    } else if (isa == sse41)
        return new jit_uni_prelu_forward_kernel_t<Xbyak::Xmm>(pd, isa);

    return nullptr;
}

template class jit_uni_prelu_forward_kernel_t<Xbyak::Zmm>;
template class jit_uni_prelu_forward_kernel_t<Xbyak::Ymm>;
template class jit_uni_prelu_forward_kernel_t<Xbyak::Xmm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl