1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <cstddef> |
17 | |
18 | #include "cpu/x64/prelu/jit_uni_prelu_forward_kernel.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace cpu { |
23 | namespace x64 { |
24 | |
// Base forward-kernel constructor: derives the broadcast strategy from the
// src/weights memory descriptors, and caches the src/weights/dst data types
// plus the dst block-tail size (used later for zero-padding the tail block).
jit_prelu_forward_kernel_t::jit_prelu_forward_kernel_t(
        const cpu_prelu_fwd_pd_t *pd, const cpu_isa_t &isa, const int vlen,
        const size_t number_vmm_single_compute)
    : jit_prelu_base_kernel_t(isa, vlen,
            prelu::get_bcast_type(memory_desc_wrapper(pd->src_md(0)),
                    memory_desc_wrapper(pd->weights_md(0))),
            memory_desc_wrapper(pd->src_md(0)), number_vmm_single_compute,
            jit_name())
    , src_dt_(pd->src_md(0)->data_type)
    , wei_dt_(pd->weights_md(0)->data_type)
    , dst_dt_(pd->dst_md(0)->data_type)
    , dst_tail_block_(prelu::get_block_tail_size(pd->dst_md(0)))
    , pd_(pd) {}
38 | |
39 | #define PARAM_OFF(x) offsetof(call_params_t, x) |
40 | |
// Emits code that unpacks the runtime call arguments: abi_param1 points at a
// call_params_t, from which the src/weights/dst pointers and the per-call
// compute size are moved into their dedicated registers.
void jit_prelu_forward_kernel_t::load_kernel_call_params() {
    mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]);
    mov(reg_weights_, ptr[abi_param1 + PARAM_OFF(weights)]);
    mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]);
    mov(reg_data_size_, ptr[abi_param1 + PARAM_OFF(compute_data_size)]);
}
47 | |
48 | #undef PARAM_OFF |
49 | |
50 | Xbyak::Address jit_prelu_forward_kernel_t::data_ptr(int arg_num, size_t offt) { |
51 | |
52 | const auto get_addr |
53 | = [&](const Xbyak::Reg64 ®_base, const data_type_t dt) { |
54 | const auto dt_size = types::data_type_size(dt); |
55 | return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size]; |
56 | }; |
57 | |
58 | switch (arg_num) { |
59 | case DNNL_ARG_SRC: return get_addr(reg_src_, src_dt_); |
60 | case DNNL_ARG_WEIGHTS: return get_addr(reg_weights_, wei_dt_); |
61 | case DNNL_ARG_DST: return get_addr(reg_dst_, dst_dt_); |
62 | default: assert(!"unsupported arg_num" ); break; |
63 | } |
64 | return Xbyak::Address(0); |
65 | } |
66 | |
67 | bool jit_prelu_forward_kernel_t::any_tensor_bf16() const { |
68 | return utils::one_of(data_type::bf16, src_dt_, wei_dt_, dst_dt_); |
69 | } |
70 | |
template <typename Vmm>
jit_uni_prelu_forward_kernel_t<Vmm>::jit_uni_prelu_forward_kernel_t(
        const cpu_prelu_fwd_pd_t *pd, const cpu_isa_t &isa)
    // sse41/avx, or a non-f32 source, needs 4 vmms per compute group instead
    // of 3 (presumably for an extra conversion/scratch register -- the extra
    // register's exact use lives in the base kernel; verify there).
    : jit_prelu_forward_kernel_t(pd, isa, vreg_traits<Vmm>::vlen,
            (utils::one_of(isa, sse41, avx)
                    || pd->src_md(0)->data_type != data_type::f32)
                    ? 4u
                    : 3u)
    // Integral destinations must be saturated before down-conversion.
    , saturation_needed_(utils::one_of(
              dst_dt_, data_type::u8, data_type::s8, data_type::s32))
    // NOTE: the reserve_vmm() calls below are order-sensitive. The tail mask
    // vmm (used on ISAs below avx512, where no opmask exists) must be reserved
    // first so it gets index 0 -- asserted in the body. When not reserved, the
    // default-constructed Vmm also has index 0, so the assert holds either way.
    , tail_vmm_mask_(tail_size_ && is_subset(isa, avx2) ? reserve_vmm() : 0)
    , vmm_zeros_(reserve_vmm())
    , dst_saturate_ubound_(saturation_needed_ ? reserve_vmm() : 0)
    // Per-oc broadcast strategies keep weights resident in a dedicated vmm,
    // preloaded once in prepare_kernel_const_vars().
    , weights_const_vmm_(utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                                 prelu::bcast::per_oc_blocked)
                      ? reserve_vmm()
                      : 0)
    , io_(this, isa, {src_dt_, wei_dt_, dst_dt_}, {},
              io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                      tail_vmm_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {}, create_saturation_vmm_map()) {
    assert(tail_vmm_mask_.getIdx() == 0);
}
94 | |
// Out-of-line defaulted destructor -- likely kept in the .cpp so members'
// complete types are only required here, not in the header (verify against
// the header's forward declarations).
template <typename Vmm>
jit_uni_prelu_forward_kernel_t<Vmm>::~jit_uni_prelu_forward_kernel_t()
        = default;
98 | |
// Emits the one-time setup executed before the compute loop: zeroes the
// reference vmm, initializes bf16 emulation and (if needed) f32 saturation
// bounds, prepares the tail mask, and for the per-oc broadcast strategies
// preloads the weights into the dedicated constant vmm.
template <typename Vmm>
void jit_uni_prelu_forward_kernel_t<Vmm>::prepare_kernel_const_vars() {
    uni_vxorps(vmm_zeros_, vmm_zeros_, vmm_zeros_);

    io_.init_bf16();
    if (saturation_needed_) io_.init_saturate_f32({dst_dt_});
    if (tail_size_) io_.prepare_tail_mask();
    // per_oc_n_c_spatial: one weight value broadcast to every lane;
    // per_oc_blocked: a full block of weights loaded once, no tail handling.
    if (bcast_ == prelu::bcast::per_oc_n_c_spatial)
        io_.at(wei_dt_)->broadcast(ptr[reg_weights_], weights_const_vmm_);
    else if (bcast_ == prelu::bcast::per_oc_blocked)
        io_.at(wei_dt_)->load(
                ptr[reg_weights_], weights_const_vmm_, false /*tail*/);
}
112 | |
113 | template <typename Vmm> |
114 | std::map<data_type_t, io::io_saturation_conf_t> |
115 | jit_uni_prelu_forward_kernel_t<Vmm>::create_saturation_vmm_map() const { |
116 | |
117 | std::map<data_type_t, io::io_saturation_conf_t> saturation_map {}; |
118 | |
119 | if (saturation_needed_) { |
120 | saturation_map.emplace(dst_dt_, |
121 | io::io_saturation_conf_t {vmm_zeros_.getIdx(), |
122 | dst_saturate_ubound_.getIdx(), reg_tmp_}); |
123 | } |
124 | |
125 | return saturation_map; |
126 | } |
127 | |
128 | template <> |
129 | bool jit_uni_prelu_forward_kernel_t< |
130 | Xbyak::Zmm>::can_load_wei_from_addr_directly(bool tail) const noexcept { |
131 | return wei_dt_ == data_type::f32 |
132 | && !utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial, |
133 | prelu::bcast::per_oc_blocked); |
134 | } |
135 | |
136 | template <> |
137 | bool jit_uni_prelu_forward_kernel_t< |
138 | Xbyak::Ymm>::can_load_wei_from_addr_directly(bool tail) const noexcept { |
139 | return wei_dt_ == data_type::f32 && is_superset(isa_, avx2) && !tail |
140 | && !utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial, |
141 | prelu::bcast::per_oc_blocked); |
142 | } |
143 | |
// Xmm kernels (sse41, or plain avx with s8/u8 data -- see create()) never
// consume weights straight from memory; they always load into a register
// first.
template <>
bool jit_uni_prelu_forward_kernel_t<
        Xbyak::Xmm>::can_load_wei_from_addr_directly(bool tail) const noexcept {
    return false;
}
149 | |
150 | template <> |
151 | Xbyak::Zmm jit_uni_prelu_forward_kernel_t<Xbyak::Zmm>::get_or_load_weights( |
152 | const Xbyak::Address &src_addr, const Xbyak::Zmm &weights_vmm, |
153 | bool tail) { |
154 | if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial, |
155 | prelu::bcast::per_oc_blocked)) |
156 | return weights_const_vmm_; |
157 | |
158 | io_.at(wei_dt_)->load(src_addr, weights_vmm, tail); |
159 | return weights_vmm; |
160 | } |
161 | |
162 | template <> |
163 | Xbyak::Ymm jit_uni_prelu_forward_kernel_t<Xbyak::Ymm>::get_or_load_weights( |
164 | const Xbyak::Address &src_addr, const Xbyak::Ymm &weights_vmm, |
165 | bool tail) { |
166 | if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial, |
167 | prelu::bcast::per_oc_blocked)) |
168 | return weights_const_vmm_; |
169 | |
170 | io_.at(wei_dt_)->load(src_addr, weights_vmm, tail); |
171 | return weights_vmm; |
172 | } |
173 | |
174 | template <> |
175 | Xbyak::Xmm jit_uni_prelu_forward_kernel_t<Xbyak::Xmm>::get_or_load_weights( |
176 | const Xbyak::Address &src_addr, const Xbyak::Xmm &weights_vmm, |
177 | bool tail) { |
178 | |
179 | if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial, |
180 | prelu::bcast::per_oc_blocked)) |
181 | return weights_const_vmm_; |
182 | |
183 | io_.at(wei_dt_)->load(src_addr, weights_vmm, tail); |
184 | return weights_vmm; |
185 | } |
186 | |
// Generic fmadd wrapper: x1 = x1 * op + x2. The tail flag is intentionally
// unused here -- on pre-avx512 ISAs any tail handling has already happened
// when the operands were loaded (op is never a tailed memory operand; see
// the Ymm can_load_wei_from_addr_directly).
template <typename Vmm>
void jit_uni_prelu_forward_kernel_t<Vmm>::uni_vfmadd132ps(
        const Vmm &x1, const Vmm &x2, const Xbyak::Operand &op, bool tail) {
    uni_vfmadd132ps(x1, x2, op);
}
192 | |
// Zmm fmadd wrapper: x1 = x1 * op + x2. When op is a memory operand and a
// tail is active, the avx512 opmask is attached to the destination so only
// the in-tail lanes are written (merge-masking leaves the rest untouched).
template <>
void jit_uni_prelu_forward_kernel_t<Xbyak::Zmm>::uni_vfmadd132ps(
        const Xbyak::Zmm &x1, const Xbyak::Zmm &x2, const Xbyak::Operand &op,
        bool tail) {

    if (op.isMEM()) {
        const Xbyak::Zmm dst = tail ? (x1 | tail_opmask_) : x1;
        // workaround for DataParallelC++ compiler issue converting mem to ZMM
        const Xbyak::Address addr
                = reinterpret_cast<const Xbyak::Address &>(op);
        vfmadd132ps(dst, x2, addr);
    } else {
        vfmadd132ps(x1, x2, op);
    }
}
208 | |
// Emits the PReLU computation for `unrolling_factor` vector groups:
//     dst = max(src, 0) + min(src, 0) * weights
// computed as max_vmm = max(0, src), min_vmm = min(0, src), then
// dst = min_vmm * weights + max_vmm via a single fmadd (dst aliases min_vmm).
template <typename Vmm>
void jit_uni_prelu_forward_kernel_t<Vmm>::compute_dst(
        size_t unrolling_factor, bool tail) {
    // Roles of the per-group compute vmms handed out by get_compute_vmm().
    static constexpr size_t max_idx = 0;
    static constexpr size_t min_idx = 1;
    static constexpr size_t src_idx = 2;
    static constexpr size_t weights_idx = 3;

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {
        const Vmm max_vmm {get_compute_vmm(max_idx, unroll_group)};
        const Vmm min_vmm {get_compute_vmm(min_idx, unroll_group)};
        const Vmm src_vmm {get_compute_vmm(src_idx, unroll_group)};
        const Vmm weights_vmm {get_compute_vmm(weights_idx, unroll_group)};

        // Each unroll group advances by one full vector of elements.
        const auto offset = unroll_group * simd_w_;
        io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);
        uni_vmaxps(max_vmm, vmm_zeros_, src_vmm);
        uni_vminps(min_vmm, vmm_zeros_, src_vmm);
        // min_vmm is overwritten in place by the fmadd below.
        const auto &dst_vmm = min_vmm;

        const Xbyak::Address weights_addr = data_ptr(DNNL_ARG_WEIGHTS, offset);
        // Use a memory operand when the ISA/layout allows it; otherwise go
        // through a register (either the preloaded per-oc vmm or a fresh load).
        if (can_load_wei_from_addr_directly(tail)) {
            uni_vfmadd132ps(dst_vmm, max_vmm, weights_addr, tail);
        } else {
            const Vmm weights_operand
                    = get_or_load_weights(weights_addr, weights_vmm, tail);
            uni_vfmadd132ps(dst_vmm, max_vmm, weights_operand, tail);
        }

        io_.at(dst_dt_)->store(dst_vmm, data_ptr(DNNL_ARG_DST, offset), tail);
        // Zero-fill the remainder of a blocked-layout tail block in dst.
        if (dst_tail_block_ && tail)
            prelu::apply_zero_padding(this, tail_size_, dst_dt_,
                    dst_tail_block_, reg_dst_, &reg_offset_);
    }
}
245 | |
246 | jit_prelu_forward_kernel_t *jit_prelu_forward_kernel_t::create( |
247 | const cpu_prelu_fwd_pd_t *pd) { |
248 | |
249 | const auto isa = prelu::get_supported_isa(); |
250 | const auto &src_dt = pd->src_md(0)->data_type; |
251 | const auto &wei_dt = pd->weights_md(0)->data_type; |
252 | const auto &dst_dt = pd->dst_md(0)->data_type; |
253 | |
254 | if (is_superset(isa, avx512_core)) |
255 | return new jit_uni_prelu_forward_kernel_t<Xbyak::Zmm>(pd, isa); |
256 | else if (is_superset(isa, avx)) |
257 | if (isa == avx && prelu::is_s8u8({src_dt, wei_dt, dst_dt})) |
258 | return new jit_uni_prelu_forward_kernel_t<Xbyak::Xmm>(pd, isa); |
259 | else |
260 | return new jit_uni_prelu_forward_kernel_t<Xbyak::Ymm>(pd, isa); |
261 | else if (isa == sse41) |
262 | return new jit_uni_prelu_forward_kernel_t<Xbyak::Xmm>(pd, isa); |
263 | |
264 | return nullptr; |
265 | } |
266 | |
267 | template class jit_uni_prelu_forward_kernel_t<Xbyak::Zmm>; |
268 | template class jit_uni_prelu_forward_kernel_t<Xbyak::Ymm>; |
269 | template class jit_uni_prelu_forward_kernel_t<Xbyak::Xmm>; |
270 | |
271 | } // namespace x64 |
272 | } // namespace cpu |
273 | } // namespace impl |
274 | } // namespace dnnl |
275 | |