/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <type_traits>

#include "cpu/x64/prelu/jit_uni_prelu_backward_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

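// Base backward kernel: derives the broadcast strategy from the diff_src and
// diff_weights memory descriptors and caches the tensor data types. Note that
// diff_weights is kept in f32 unless the broadcast is `full`, presumably
// because the per-oc cases produce partial sums that are accumulated in f32.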
jit_prelu_backward_kernel_t::jit_prelu_backward_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa, const int vlen,
        const size_t number_vmm_single_compute)
    : jit_prelu_base_kernel_t(isa, vlen,
            prelu::get_bcast_type(memory_desc_wrapper(pd->diff_src_md(0)),
                    memory_desc_wrapper(pd->diff_weights_md(0))),
            memory_desc_wrapper(pd->diff_src_md(0)), number_vmm_single_compute,
            jit_name())
    , pd_(pd)
    , src_dt_(pd->src_md(0)->data_type)
    , wei_dt_(pd->weights_md(0)->data_type)
    , diff_src_dt_(pd->diff_src_md(0)->data_type)
    , diff_dst_dt_(pd->diff_dst_md(0)->data_type)
    , diff_wei_dt_(bcast_ == prelu::bcast::full
                      ? pd->diff_weights_md(0)->data_type
                      : data_type::f32)
    , diff_src_block_tail_(prelu::get_block_tail_size(pd->diff_src_md(0)))
    , diff_wei_block_tail_(prelu::get_block_tail_size(pd->diff_weights_md(0))) {
}

#define PARAM_OFF(x) offsetof(call_params_t, x)

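// Load the kernel arguments (tensor pointers and the number of elements to
// process in this call) from the call_params_t structure passed in abi_param1.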
void jit_prelu_backward_kernel_t::load_kernel_call_params() {
    mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]);
    mov(reg_weights_, ptr[abi_param1 + PARAM_OFF(weights)]);
    mov(reg_src_diff_, ptr[abi_param1 + PARAM_OFF(src_diff)]);
    mov(reg_weights_diff_, ptr[abi_param1 + PARAM_OFF(weights_diff)]);
    mov(reg_dst_diff_, ptr[abi_param1 + PARAM_OFF(dst_diff)]);
    mov(reg_data_size_, ptr[abi_param1 + PARAM_OFF(compute_data_size)]);
}

#undef PARAM_OFF

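// Build the effective address of the given execution argument: the per-tensor
// base register plus the current element offset (reg_offset_ and the static
// unroll offset `offt`), both scaled by the argument's data type size.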
Xbyak::Address jit_prelu_backward_kernel_t::data_ptr(int arg_num, size_t offt) {
    const auto get_addr
            = [&](const Xbyak::Reg64 &reg_base, const data_type_t dt) {
                  const auto dt_size = types::data_type_size(dt);
                  return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size];
              };

    switch (arg_num) {
        case DNNL_ARG_SRC: return get_addr(reg_src_, src_dt_);
        case DNNL_ARG_WEIGHTS: return get_addr(reg_weights_, wei_dt_);
        case DNNL_ARG_DIFF_SRC: return get_addr(reg_src_diff_, diff_src_dt_);
        case DNNL_ARG_DIFF_WEIGHTS:
            return get_addr(reg_weights_diff_, diff_wei_dt_);
        case DNNL_ARG_DIFF_DST: return get_addr(reg_dst_diff_, diff_dst_dt_);

        default: assert(!"unsupported arg_num"); break;
    }
    return Xbyak::Address(0);
}

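// Reports whether any of the tensors is bf16, which determines whether bf16
// emulation support has to be set up.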
bool jit_prelu_backward_kernel_t::any_tensor_bf16() const {
    return utils::one_of(data_type::bf16, src_dt_, wei_dt_, diff_src_dt_,
            diff_dst_dt_, diff_wei_dt_);
}

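// Vector-register budget: the Zmm variant needs only 4 registers per unrolled
// compute group (the sign predicates live in opmasks), the Ymm/Xmm variants
// need 6. The remaining registers are reserved up front for the tail mask
// (pre-AVX-512 only), a zeros constant, saturation upper bounds, a 1.0f
// constant, and, for the per-oc broadcast cases, the preloaded weights and the
// diff-weights accumulator.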
template <typename Vmm>
jit_uni_prelu_backward_kernel_t<Vmm>::jit_uni_prelu_backward_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa)
    : jit_prelu_backward_kernel_t(pd, isa, vreg_traits<Vmm>::vlen,
            std::is_same<Vmm, Xbyak::Zmm>::value ? 4u : 6u)
    , saturation_needed_diff_src_(utils::one_of(
              diff_src_dt_, data_type::u8, data_type::s8, data_type::s32))
    , saturation_needed_diff_weights_(utils::one_of(
              diff_wei_dt_, data_type::u8, data_type::s8, data_type::s32))
    , tail_vmm_mask_(tail_size_ && is_subset(isa, avx2) ? reserve_vmm() : 0)
    , vmm_zeros_(reserve_vmm())
    , saturation_ubound_diff_src_(
              saturation_needed_diff_src_ ? reserve_vmm() : 0)
    , saturation_ubound_diff_weights_(saturation_needed_diff_weights_
                      ? (diff_wei_dt_ == diff_src_dt_
                                      ? saturation_ubound_diff_src_.getIdx()
                                      : reserve_vmm())
                      : 0)
    , vmm_ones_(reserve_vmm())
    , weights_const_vmm_(utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                                 prelu::bcast::per_oc_blocked)
                      ? reserve_vmm()
                      : 0)
    , weights_diff_acc_vmm_(
              utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                      prelu::bcast::per_oc_blocked)
                      ? reserve_vmm()
                      : 0)
    , io_(this, isa,
              {src_dt_, wei_dt_, diff_src_dt_, diff_dst_dt_, diff_wei_dt_}, {},
              io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                      tail_vmm_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {}, create_saturation_vmm_map()) {
    assert(tail_vmm_mask_.getIdx() == 0);
}

template <typename Vmm>
jit_uni_prelu_backward_kernel_t<Vmm>::~jit_uni_prelu_backward_kernel_t()
        = default;

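// Emit the one-time setup executed before the main loop: zero constant, bf16
// helpers, tail mask, f32 saturation bounds, and a broadcast 1.0f. For the
// per-oc broadcast cases the weights (vector or scalar) and the current
// diff-weights accumulator are preloaded here as well.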
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::prepare_kernel_const_vars() {
    uni_vxorps(vmm_zeros_, vmm_zeros_, vmm_zeros_);

    io_.init_bf16();
    if (tail_size_) io_.prepare_tail_mask();
    if (saturation_needed_diff_src_ || saturation_needed_diff_weights_) {
        io_.init_saturate_f32({diff_src_dt_, diff_wei_dt_});
    }
    // load ones
    this->mov(this->reg_tmp_, float2int(1));
    const Xbyak::Xmm xmm_ones_ {vmm_ones_.getIdx()};
    this->uni_vmovq(xmm_ones_, this->reg_tmp_);
    this->uni_vbroadcastss(vmm_ones_, xmm_ones_);

    if (bcast_ == prelu::bcast::per_oc_blocked) {
        io_.at(wei_dt_)->load(
                ptr[reg_weights_], weights_const_vmm_, false /*tail*/);
        vmovups(weights_diff_acc_vmm_, ptr[reg_weights_diff_]);
    } else if (bcast_ == prelu::bcast::per_oc_n_c_spatial) {
        io_.at(wei_dt_)->broadcast(ptr[reg_weights_], weights_const_vmm_);
        uni_vxorps(weights_diff_acc_vmm_, weights_diff_acc_vmm_,
                weights_diff_acc_vmm_);
        uni_vmovss(weights_diff_acc_vmm_, ptr[reg_weights_diff_]);
    }
}

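// Generic (SSE4.1/AVX/AVX2) compute step. For every unrolled vector of
// elements:
//   diff_weights += diff_dst * src            where src <= 0
//   diff_src      = diff_dst * (src > 0 ? 1 : weights)
// The comparisons produce all-ones bit masks that are ANDed with 1.0f, so the
// predicates can be used directly as multiplicative 0/1 factors.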
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::compute_dst(
        size_t unrolling_factor, bool tail) {

    static constexpr size_t dst_diff_idx = 0;
    static constexpr size_t src_idx = 1;
    static constexpr size_t src_le_zero_idx = 2;
    static constexpr size_t src_gt_zero_idx = 3;
    static constexpr size_t weights_diff_idx = 4;
    static constexpr size_t weights_idx = 5;

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {

        const Vmm dst_diff_vmm {get_compute_vmm(dst_diff_idx, unroll_group)};
        const Vmm src_vmm {get_compute_vmm(src_idx, unroll_group)};
        const Vmm src_le_zero_vmm {
                get_compute_vmm(src_le_zero_idx, unroll_group)};
        const Vmm src_gt_zero_vmm {
                get_compute_vmm(src_gt_zero_idx, unroll_group)};
        const Vmm weights_diff_vmm {
                get_compute_vmm(weights_diff_idx, unroll_group)};
        const Vmm weights_vmm {get_compute_vmm(weights_idx, unroll_group)};

        const auto offset = unroll_group * simd_w_;
        io_.at(diff_dst_dt_)
                ->load(data_ptr(DNNL_ARG_DIFF_DST, offset), dst_diff_vmm, tail);
        io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);
        static constexpr int VCMPLEPS = 2;
        uni_vcmpps(src_le_zero_vmm, src_vmm, vmm_zeros_, VCMPLEPS);
        uni_vandps(src_le_zero_vmm, src_le_zero_vmm, vmm_ones_);
        static constexpr int VCMPGTPS = 14;
        uni_vcmpps(src_gt_zero_vmm, src_vmm, vmm_zeros_, VCMPGTPS);
        uni_vandps(src_gt_zero_vmm, src_gt_zero_vmm, vmm_ones_);

        // weights_diff calculations
        uni_vmulps(weights_diff_vmm, dst_diff_vmm, src_vmm);
        uni_vmulps(weights_diff_vmm, weights_diff_vmm, src_le_zero_vmm);

        // src_diff calculations
        const auto weights_operand = get_or_load_weights(
                data_ptr(DNNL_ARG_WEIGHTS, offset), weights_vmm, tail);
        uni_vfmadd231ps(src_gt_zero_vmm, src_le_zero_vmm, weights_operand);
        const auto &src_diff_vmm = src_gt_zero_vmm;
        uni_vmulps(src_diff_vmm, src_diff_vmm, dst_diff_vmm);
        io_.at(diff_src_dt_)
                ->store(src_diff_vmm, data_ptr(DNNL_ARG_DIFF_SRC, offset),
                        tail);
        if (diff_src_block_tail_ && tail)
            prelu::apply_zero_padding(this, tail_size_, diff_src_dt_,
                    diff_src_block_tail_, reg_src_diff_, nullptr);

        accumulate_weights_diff(weights_diff_vmm, src_gt_zero_vmm,
                data_ptr(DNNL_ARG_DIFF_WEIGHTS, offset), tail);
    }
}

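// AVX-512 specialization: the src <= 0 and src > 0 predicates are kept in
// opmask registers (k2..k7, reused round-robin) and applied through masked
// and zero-masked instructions instead of the 0/1 multiplicative masks used
// in the generic path.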
template <>
void jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>::compute_dst(
        size_t unrolling_factor, bool tail) {

    size_t opmask_counter = 2;
    auto get_next_opmask = [opmask_counter]() mutable {
        static constexpr size_t opmask_range_begin = 2;
        static constexpr size_t opmask_range_end = 8;
        const auto opmask = Xbyak::Opmask(opmask_counter++);
        if (opmask_counter == opmask_range_end)
            opmask_counter = opmask_range_begin;
        return opmask;
    };

    static constexpr size_t dst_diff_idx = 0;
    static constexpr size_t src_idx = 1;
    static constexpr size_t weights_diff_idx = 2;
    static constexpr size_t weights_idx = 3;

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {

        const auto offset = unroll_group * simd_w_;
        const Xbyak::Zmm dst_diff_vmm {
                get_compute_vmm(dst_diff_idx, unroll_group)};
        const Xbyak::Zmm src_vmm {get_compute_vmm(src_idx, unroll_group)};

        io_.at(diff_dst_dt_)
                ->load(data_ptr(DNNL_ARG_DIFF_DST, offset), dst_diff_vmm, tail);
        io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);

        const Xbyak::Opmask src_le_zero_opmask = get_next_opmask();
        static constexpr int VCMPLEPS = 2;
        vcmpps(src_le_zero_opmask, src_vmm, vmm_zeros_, VCMPLEPS);
        const Xbyak::Opmask src_gt_zero_vmm_opmask = get_next_opmask();
        static constexpr int VCMPGTPS = 14;
        vcmpps(src_gt_zero_vmm_opmask, src_vmm, vmm_zeros_, VCMPGTPS);

        // weights_diff calculations
        const Xbyak::Zmm weights_diff_vmm {
                get_compute_vmm(weights_diff_idx, unroll_group)};
        vmulps(weights_diff_vmm | src_le_zero_opmask | T_z, dst_diff_vmm,
                src_vmm);
        accumulate_weights_diff(weights_diff_vmm, weights_diff_acc_vmm_,
                data_ptr(DNNL_ARG_DIFF_WEIGHTS, offset), tail);

        // src_diff calculations
        const Xbyak::Zmm weights_vmm {
                get_compute_vmm(weights_idx, unroll_group)};
        const auto &src_diff_vmm = weights_vmm;
        const auto weights_operand = get_or_load_weights(
                data_ptr(DNNL_ARG_WEIGHTS, offset), weights_vmm, tail);

        vmovaps(src_diff_vmm | src_le_zero_opmask | T_z, weights_operand);
        vaddps(src_diff_vmm | src_gt_zero_vmm_opmask, src_diff_vmm, vmm_ones_);
        vmulps(src_diff_vmm, src_diff_vmm, dst_diff_vmm);
        io_.at(diff_src_dt_)
                ->store(src_diff_vmm, data_ptr(DNNL_ARG_DIFF_SRC, offset),
                        tail);
        if (diff_src_block_tail_ && tail)
            prelu::apply_zero_padding(this, tail_size_, diff_src_dt_,
                    diff_src_block_tail_, reg_src_diff_, nullptr);
    }
}

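// Fold one partial diff-weights vector into the running result. For the per-oc
// broadcast cases the sum is kept in weights_diff_acc_vmm_ until finalize();
// for per_oc_n_spatial_c it is added to and stored back in memory directly
// (via tmp_vmm on the older ISAs); for the full broadcast it is stored per
// element with saturation and block-tail zero padding.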
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::accumulate_weights_diff(
        const Vmm &partial_sum_vmm, const Vmm &tmp_vmm,
        const Xbyak::Address &dst_addr, bool tail) {

    if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                prelu::bcast::per_oc_blocked)) {
        uni_vaddps(
                weights_diff_acc_vmm_, weights_diff_acc_vmm_, partial_sum_vmm);
    } else if (bcast_ == prelu::bcast::per_oc_n_spatial_c) {
        if (std::is_same<Vmm, Xbyak::Zmm>::value || isa_ == avx2)
            uni_vaddps(partial_sum_vmm, partial_sum_vmm, dst_addr);
        else {
            uni_vmovups(tmp_vmm, dst_addr);
            uni_vaddps(partial_sum_vmm, partial_sum_vmm, tmp_vmm);
        }
        uni_vmovups(dst_addr, partial_sum_vmm);
    } else {
        io_.at(diff_wei_dt_)->store(partial_sum_vmm, dst_addr, tail);
        if (diff_wei_block_tail_ && tail)
            prelu::apply_zero_padding(this, tail_size_, diff_wei_dt_,
                    diff_wei_block_tail_, reg_weights_diff_, nullptr);
    }
}

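// Per-oc broadcast cases reuse the weights preloaded into weights_const_vmm_;
// otherwise the weights are loaded from memory into weights_vmm for this
// iteration.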
template <typename Vmm>
const Xbyak::Operand &jit_uni_prelu_backward_kernel_t<Vmm>::get_or_load_weights(
        const Xbyak::Address &src_addr, const Vmm &weights_vmm, bool tail) {

    if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                prelu::bcast::per_oc_blocked))
        return weights_const_vmm_;

    io_.at(wei_dt_)->load(src_addr, weights_vmm, tail);
    return weights_vmm;
}

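// Horizontal sum of a vector register: the upper half is folded into the lower
// half (Zmm -> Ymm -> Xmm), then two horizontal adds leave the total in lane 0.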
static void reduce(jit_generator *host, const Xbyak::Xmm &src,
        const Xbyak::Xmm &helper, const cpu_isa_t &isa) {
    UNUSED(helper);
    if (isa == sse41) {
        host->haddps(src, src);
        host->haddps(src, src);
    } else {
        host->vhaddps(src, src, src);
        host->vhaddps(src, src, src);
    }
}

static void reduce(jit_generator *host, const Xbyak::Ymm &src,
        const Xbyak::Ymm &helper, const cpu_isa_t &isa) {
    const Xbyak::Xmm xmm_helper {helper.getIdx()};
    const Xbyak::Xmm xmm_src {src.getIdx()};

    host->vextractf128(xmm_helper, src, 1);
    host->vaddps(xmm_src, xmm_src, xmm_helper);
    reduce(host, xmm_src, xmm_helper, isa);
}

static void reduce(jit_generator *host, const Xbyak::Zmm &src,
        const Xbyak::Zmm &helper, const cpu_isa_t &isa) {
    const Xbyak::Ymm ymm_helper {helper.getIdx()};
    const Xbyak::Ymm ymm_src {src.getIdx()};

    host->vextractf64x4(ymm_helper, src, 1);
    host->vaddps(ymm_src, ymm_src, ymm_helper);
    reduce(host, ymm_src, ymm_helper, isa);
}

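// Epilogue: flush the accumulated diff weights. The blocked per-oc case stores
// the whole accumulator vector; per_oc_n_c_spatial reduces it to a scalar
// first, reusing weights_const_vmm_ as the reduction helper.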
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::finalize() {
    if (bcast_ == prelu::bcast::per_oc_blocked)
        uni_vmovups(ptr[reg_weights_diff_], weights_diff_acc_vmm_);
    else if (bcast_ == prelu::bcast::per_oc_n_c_spatial) {
        reduce(this, weights_diff_acc_vmm_, weights_const_vmm_, isa_);
        uni_vmovss(ptr[reg_weights_diff_], weights_diff_acc_vmm_);
    }
}

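// Describe, per saturating destination type (u8/s8/s32), which vmm holds the
// lower bound (zeros), which holds the upper bound, and the temporary register
// used to initialize them. A single upper-bound register is shared when
// diff_src and diff_weights have the same data type.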
template <typename Vmm>
std::map<data_type_t, io::io_saturation_conf_t>
jit_uni_prelu_backward_kernel_t<Vmm>::create_saturation_vmm_map() const {

    std::map<data_type_t, io::io_saturation_conf_t> saturation_map {};

    if (saturation_needed_diff_src_)
        saturation_map.emplace(diff_src_dt_,
                io::io_saturation_conf_t {vmm_zeros_.getIdx(),
                        saturation_ubound_diff_src_.getIdx(), reg_tmp_});

    if (saturation_needed_diff_weights_ && diff_src_dt_ != diff_wei_dt_)
        saturation_map.emplace(diff_wei_dt_,
                io::io_saturation_conf_t {vmm_zeros_.getIdx(),
                        saturation_ubound_diff_weights_.getIdx(), reg_tmp_});

    return saturation_map;
}

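// Factory: pick the widest vector register the detected ISA supports. On plain
// AVX the all-int8 case falls back to Xmm, likely because AVX lacks the
// 256-bit integer instructions needed for the int8 conversions; otherwise Ymm
// is used.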
jit_prelu_backward_kernel_t *jit_prelu_backward_kernel_t::create(
        const cpu_prelu_bwd_pd_t *pd) {

    const auto isa = prelu::get_supported_isa();

    const auto &src_dt = pd->src_md(0)->data_type;
    const auto &wei_dt = pd->weights_md(0)->data_type;
    const auto &diff_src_dt = pd->diff_src_md(0)->data_type;
    const auto &diff_dst_dt = pd->diff_dst_md(0)->data_type;
    const auto &diff_wei_dt = pd->diff_weights_md(0)->data_type;

    if (is_superset(isa, avx512_core))
        return new jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>(pd, isa);
    else if (is_superset(isa, avx)) {
        if (isa == avx
                && prelu::is_s8u8({src_dt, wei_dt, diff_src_dt, diff_dst_dt,
                        diff_wei_dt}))
            return new jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>(pd, isa);
        else
            return new jit_uni_prelu_backward_kernel_t<Xbyak::Ymm>(pd, isa);
    } else if (isa == sse41)
        return new jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>(pd, isa);

    return nullptr;
}

template class jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>;
template class jit_uni_prelu_backward_kernel_t<Xbyak::Ymm>;
template class jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl