/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <type_traits>

#include "cpu/x64/prelu/jit_uni_prelu_backward_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

jit_prelu_backward_kernel_t::jit_prelu_backward_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa, const int vlen,
        const size_t number_vmm_single_compute)
    : jit_prelu_base_kernel_t(isa, vlen,
            prelu::get_bcast_type(memory_desc_wrapper(pd->diff_src_md(0)),
                    memory_desc_wrapper(pd->diff_weights_md(0))),
            memory_desc_wrapper(pd->diff_src_md(0)), number_vmm_single_compute,
            jit_name())
    , pd_(pd)
    , src_dt_(pd->src_md(0)->data_type)
    , wei_dt_(pd->weights_md(0)->data_type)
    , diff_src_dt_(pd->diff_src_md(0)->data_type)
    , diff_dst_dt_(pd->diff_dst_md(0)->data_type)
    , diff_wei_dt_(bcast_ == prelu::bcast::full
                      ? pd->diff_weights_md(0)->data_type
                      : data_type::f32)
    , diff_src_block_tail_(prelu::get_block_tail_size(pd->diff_src_md(0)))
    , diff_wei_block_tail_(prelu::get_block_tail_size(pd->diff_weights_md(0))) {
}

#define PARAM_OFF(x) offsetof(call_params_t, x)

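// Load the runtime call arguments (tensor pointers and the number of
// elements this invocation should process) from the call_params_t struct
// passed as the first ABI argument.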
void jit_prelu_backward_kernel_t::load_kernel_call_params() {
    mov(reg_src_, ptr[abi_param1 + PARAM_OFF(src)]);
    mov(reg_weights_, ptr[abi_param1 + PARAM_OFF(weights)]);
    mov(reg_src_diff_, ptr[abi_param1 + PARAM_OFF(src_diff)]);
    mov(reg_weights_diff_, ptr[abi_param1 + PARAM_OFF(weights_diff)]);
    mov(reg_dst_diff_, ptr[abi_param1 + PARAM_OFF(dst_diff)]);
    mov(reg_data_size_, ptr[abi_param1 + PARAM_OFF(compute_data_size)]);
}

#undef PARAM_OFF

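// Compute the effective address of the given tensor argument:
// base + (reg_offset_ + offt) elements, scaled by the argument's data type
// size.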
Xbyak::Address jit_prelu_backward_kernel_t::data_ptr(int arg_num, size_t offt) {
    const auto get_addr
            = [&](const Xbyak::Reg64 &reg_base, const data_type_t dt) {
                  const auto dt_size = types::data_type_size(dt);
                  return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size];
              };

    switch (arg_num) {
        case DNNL_ARG_SRC: return get_addr(reg_src_, src_dt_);
        case DNNL_ARG_WEIGHTS: return get_addr(reg_weights_, wei_dt_);
        case DNNL_ARG_DIFF_SRC: return get_addr(reg_src_diff_, diff_src_dt_);
        case DNNL_ARG_DIFF_WEIGHTS:
            return get_addr(reg_weights_diff_, diff_wei_dt_);
        case DNNL_ARG_DIFF_DST: return get_addr(reg_dst_diff_, diff_dst_dt_);

        default: assert(!"unsupported arg_num"); break;
    }
    return Xbyak::Address(0);
}

bool jit_prelu_backward_kernel_t::any_tensor_bf16() const {
    return utils::one_of(data_type::bf16, src_dt_, wei_dt_, diff_src_dt_,
            diff_dst_dt_, diff_wei_dt_);
}

template <typename Vmm>
jit_uni_prelu_backward_kernel_t<Vmm>::jit_uni_prelu_backward_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa)
    : jit_prelu_backward_kernel_t(pd, isa, vreg_traits<Vmm>::vlen,
            std::is_same<Vmm, Xbyak::Zmm>::value ? 4u : 6u)
    , saturation_needed_diff_src_(utils::one_of(
              diff_src_dt_, data_type::u8, data_type::s8, data_type::s32))
    , saturation_needed_diff_weights_(utils::one_of(
              diff_wei_dt_, data_type::u8, data_type::s8, data_type::s32))
    , tail_vmm_mask_(tail_size_ && is_subset(isa, avx2) ? reserve_vmm() : 0)
    , vmm_zeros_(reserve_vmm())
    , saturation_ubound_diff_src_(
              saturation_needed_diff_src_ ? reserve_vmm() : 0)
    , saturation_ubound_diff_weights_(saturation_needed_diff_weights_
                      ? (diff_wei_dt_ == diff_src_dt_
                                      ? saturation_ubound_diff_src_.getIdx()
                                      : reserve_vmm())
                      : 0)
    , vmm_ones_(reserve_vmm())
    , weights_const_vmm_(utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                                 prelu::bcast::per_oc_blocked)
                      ? reserve_vmm()
                      : 0)
    , weights_diff_acc_vmm_(
              utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                      prelu::bcast::per_oc_blocked)
                      ? reserve_vmm()
                      : 0)
    , io_(this, isa,
            {src_dt_, wei_dt_, diff_src_dt_, diff_dst_dt_, diff_wei_dt_}, {},
            io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                    tail_vmm_mask_.getIdx(), reg_tmp_},
            io::io_emu_bf16_conf_t {}, create_saturation_vmm_map()) {
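    // tail_vmm_mask_ is reserved first (or defaults to index 0 when no mask
    // is needed), since the io helper expects the tail mask in vmm0.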
    assert(tail_vmm_mask_.getIdx() == 0);
}

template <typename Vmm>
jit_uni_prelu_backward_kernel_t<Vmm>::~jit_uni_prelu_backward_kernel_t()
        = default;

template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::prepare_kernel_const_vars() {
    uni_vxorps(vmm_zeros_, vmm_zeros_, vmm_zeros_);

    io_.init_bf16();
    if (tail_size_) io_.prepare_tail_mask();
    if (saturation_needed_diff_src_ || saturation_needed_diff_weights_) {
        io_.init_saturate_f32({diff_src_dt_, diff_wei_dt_});
    }
    // broadcast 1.0f across all lanes of vmm_ones_
    this->mov(this->reg_tmp_, float2int(1.0f));
    const Xbyak::Xmm xmm_ones {vmm_ones_.getIdx()};
    this->uni_vmovq(xmm_ones, this->reg_tmp_);
    this->uni_vbroadcastss(vmm_ones_, xmm_ones);

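    // For the per-OC broadcast cases, preload the weights (constant for the
    // whole kernel run) and initialize the weights-diff accumulator from
    // memory; other broadcast strategies handle weights inside the compute
    // loop.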
    if (bcast_ == prelu::bcast::per_oc_blocked) {
        io_.at(wei_dt_)->load(
                ptr[reg_weights_], weights_const_vmm_, false /*tail*/);
        vmovups(weights_diff_acc_vmm_, ptr[reg_weights_diff_]);
    } else if (bcast_ == prelu::bcast::per_oc_n_c_spatial) {
        io_.at(wei_dt_)->broadcast(ptr[reg_weights_], weights_const_vmm_);
        uni_vxorps(weights_diff_acc_vmm_, weights_diff_acc_vmm_,
                weights_diff_acc_vmm_);
        uni_vmovss(weights_diff_acc_vmm_, ptr[reg_weights_diff_]);
    }
}

template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::compute_dst(
        size_t unrolling_factor, bool tail) {

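    // PReLU backward, computed branch-free:
    //   diff_src     = diff_dst * (src > 0 ? 1 : weights)
    //   diff_weights = diff_dst * (src <= 0 ? src : 0)
    // The comparison results are ANDed with 1.0f so they act as 0/1
    // selectors in the mul/fma sequence below.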
    static constexpr size_t dst_diff_idx = 0;
    static constexpr size_t src_idx = 1;
    static constexpr size_t src_le_zero_idx = 2;
    static constexpr size_t src_gt_zero_idx = 3;
    static constexpr size_t weights_diff_idx = 4;
    static constexpr size_t weights_idx = 5;

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {

        const Vmm dst_diff_vmm {get_compute_vmm(dst_diff_idx, unroll_group)};
        const Vmm src_vmm {get_compute_vmm(src_idx, unroll_group)};
        const Vmm src_le_zero_vmm {
                get_compute_vmm(src_le_zero_idx, unroll_group)};
        const Vmm src_gt_zero_vmm {
                get_compute_vmm(src_gt_zero_idx, unroll_group)};
        const Vmm weights_diff_vmm {
                get_compute_vmm(weights_diff_idx, unroll_group)};
        const Vmm weights_vmm {get_compute_vmm(weights_idx, unroll_group)};

        const auto offset = unroll_group * simd_w_;
        io_.at(diff_dst_dt_)
                ->load(data_ptr(DNNL_ARG_DIFF_DST, offset), dst_diff_vmm, tail);
        io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);
        static constexpr int VCMPLEPS = 2;
        uni_vcmpps(src_le_zero_vmm, src_vmm, vmm_zeros_, VCMPLEPS);
        uni_vandps(src_le_zero_vmm, src_le_zero_vmm, vmm_ones_);
        static constexpr int VCMPGTPS = 14;
        uni_vcmpps(src_gt_zero_vmm, src_vmm, vmm_zeros_, VCMPGTPS);
        uni_vandps(src_gt_zero_vmm, src_gt_zero_vmm, vmm_ones_);

        // weights_diff calculations
        uni_vmulps(weights_diff_vmm, dst_diff_vmm, src_vmm);
        uni_vmulps(weights_diff_vmm, weights_diff_vmm, src_le_zero_vmm);

        // src_diff calculations
        const auto weights_operand = get_or_load_weights(
                data_ptr(DNNL_ARG_WEIGHTS, offset), weights_vmm, tail);
        uni_vfmadd231ps(src_gt_zero_vmm, src_le_zero_vmm, weights_operand);
        const auto &src_diff_vmm = src_gt_zero_vmm;
        uni_vmulps(src_diff_vmm, src_diff_vmm, dst_diff_vmm);
        io_.at(diff_src_dt_)
                ->store(src_diff_vmm, data_ptr(DNNL_ARG_DIFF_SRC, offset),
                        tail);
        if (diff_src_block_tail_ && tail)
            prelu::apply_zero_padding(this, tail_size_, diff_src_dt_,
                    diff_src_block_tail_, reg_src_diff_, nullptr);

        accumulate_weights_diff(weights_diff_vmm, src_gt_zero_vmm,
                data_ptr(DNNL_ARG_DIFF_WEIGHTS, offset), tail);
    }
}

template <>
void jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>::compute_dst(
        size_t unrolling_factor, bool tail) {

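    // The AVX-512 variant keeps the comparison results in opmask registers
    // (k2..k7, handed out round-robin below) and applies them through masked
    // instructions instead of materializing 0/1 selector vectors.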
    size_t opmask_counter = 2;
    auto get_next_opmask = [opmask_counter]() mutable {
        static constexpr size_t opmask_range_begin = 2;
        static constexpr size_t opmask_range_end = 8;
        const auto opmask = Xbyak::Opmask(opmask_counter++);
        if (opmask_counter == opmask_range_end)
            opmask_counter = opmask_range_begin;
        return opmask;
    };

    static constexpr size_t dst_diff_idx = 0;
    static constexpr size_t src_idx = 1;
    static constexpr size_t weights_diff_idx = 2;
    static constexpr size_t weights_idx = 3;

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {

        const auto offset = unroll_group * simd_w_;
        const Xbyak::Zmm dst_diff_vmm {
                get_compute_vmm(dst_diff_idx, unroll_group)};
        const Xbyak::Zmm src_vmm {get_compute_vmm(src_idx, unroll_group)};

        io_.at(diff_dst_dt_)
                ->load(data_ptr(DNNL_ARG_DIFF_DST, offset), dst_diff_vmm, tail);
        io_.at(src_dt_)->load(data_ptr(DNNL_ARG_SRC, offset), src_vmm, tail);

        const Xbyak::Opmask src_le_zero_opmask = get_next_opmask();
        static constexpr int VCMPLEPS = 2;
        vcmpps(src_le_zero_opmask, src_vmm, vmm_zeros_, VCMPLEPS);
        const Xbyak::Opmask src_gt_zero_opmask = get_next_opmask();
        static constexpr int VCMPGTPS = 14;
        vcmpps(src_gt_zero_opmask, src_vmm, vmm_zeros_, VCMPGTPS);

        // weights_diff calculations
        const Xbyak::Zmm weights_diff_vmm {
                get_compute_vmm(weights_diff_idx, unroll_group)};
        vmulps(weights_diff_vmm | src_le_zero_opmask | T_z, dst_diff_vmm,
                src_vmm);
        accumulate_weights_diff(weights_diff_vmm, weights_diff_acc_vmm_,
                data_ptr(DNNL_ARG_DIFF_WEIGHTS, offset), tail);

        // src_diff calculations
        const Xbyak::Zmm weights_vmm {
                get_compute_vmm(weights_idx, unroll_group)};
        const auto &src_diff_vmm = weights_vmm;
        const auto weights_operand = get_or_load_weights(
                data_ptr(DNNL_ARG_WEIGHTS, offset), weights_vmm, tail);

        vmovaps(src_diff_vmm | src_le_zero_opmask | T_z, weights_operand);
        vaddps(src_diff_vmm | src_gt_zero_opmask, src_diff_vmm, vmm_ones_);
        vmulps(src_diff_vmm, src_diff_vmm, dst_diff_vmm);
        io_.at(diff_src_dt_)
                ->store(src_diff_vmm, data_ptr(DNNL_ARG_DIFF_SRC, offset),
                        tail);
        if (diff_src_block_tail_ && tail)
            prelu::apply_zero_padding(this, tail_size_, diff_src_dt_,
                    diff_src_block_tail_, reg_src_diff_, nullptr);
    }
}

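// Accumulate the per-lane weights-diff partial sums. The per-OC broadcast
// strategies keep the running sum in weights_diff_acc_vmm_ (flushed in
// finalize()); per_oc_n_spatial_c accumulates directly into memory; the
// full-broadcast case stores each element to its own location.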
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::accumulate_weights_diff(
        const Vmm &partial_sum_vmm, const Vmm &tmp_vmm,
        const Xbyak::Address &dst_addr, bool tail) {

    if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                prelu::bcast::per_oc_blocked)) {
        uni_vaddps(
                weights_diff_acc_vmm_, weights_diff_acc_vmm_, partial_sum_vmm);
    } else if (bcast_ == prelu::bcast::per_oc_n_spatial_c) {
        if (std::is_same<Vmm, Xbyak::Zmm>::value || isa_ == avx2)
            uni_vaddps(partial_sum_vmm, partial_sum_vmm, dst_addr);
        else {
            uni_vmovups(tmp_vmm, dst_addr);
            uni_vaddps(partial_sum_vmm, partial_sum_vmm, tmp_vmm);
        }
        uni_vmovups(dst_addr, partial_sum_vmm);
    } else {
        io_.at(diff_wei_dt_)->store(partial_sum_vmm, dst_addr, tail);
        if (diff_wei_block_tail_ && tail)
            prelu::apply_zero_padding(this, tail_size_, diff_wei_dt_,
                    diff_wei_block_tail_, reg_weights_diff_, nullptr);
    }
}

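// Return the preloaded broadcast weights register when the weights are
// constant for the whole compute block; otherwise load them from memory
// into weights_vmm.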
template <typename Vmm>
const Xbyak::Operand &jit_uni_prelu_backward_kernel_t<Vmm>::get_or_load_weights(
        const Xbyak::Address &src_addr, const Vmm &weights_vmm, bool tail) {

    if (utils::one_of(bcast_, prelu::bcast::per_oc_n_c_spatial,
                prelu::bcast::per_oc_blocked))
        return weights_const_vmm_;

    io_.at(wei_dt_)->load(src_addr, weights_vmm, tail);
    return weights_vmm;
}

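// Horizontal sum of all f32 lanes of src into lane 0: the Ymm/Zmm overloads
// fold the upper half onto the lower half, then two hadd passes finish the
// 128-bit reduction.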
static void reduce(jit_generator *host, const Xbyak::Xmm &src,
        const Xbyak::Xmm &helper, const cpu_isa_t &isa) {
    UNUSED(helper);
    if (isa == sse41) {
        host->haddps(src, src);
        host->haddps(src, src);
    } else {
        host->vhaddps(src, src, src);
        host->vhaddps(src, src, src);
    }
}

static void reduce(jit_generator *host, const Xbyak::Ymm &src,
        const Xbyak::Ymm &helper, const cpu_isa_t &isa) {
    const Xbyak::Xmm xmm_helper {helper.getIdx()};
    const Xbyak::Xmm xmm_src {src.getIdx()};

    host->vextractf128(xmm_helper, src, 1);
    host->vaddps(xmm_src, xmm_src, xmm_helper);
    reduce(host, xmm_src, xmm_helper, isa);
}

static void reduce(jit_generator *host, const Xbyak::Zmm &src,
        const Xbyak::Zmm &helper, const cpu_isa_t &isa) {
    const Xbyak::Ymm ymm_helper {helper.getIdx()};
    const Xbyak::Ymm ymm_src {src.getIdx()};

    host->vextractf64x4(ymm_helper, src, 1);
    host->vaddps(ymm_src, ymm_src, ymm_helper);
    reduce(host, ymm_src, ymm_helper, isa);
}

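// Flush the weights-diff accumulator: a full vector store for the blocked
// per-OC layout, or a horizontal reduction to a scalar when the whole
// compute region shares a single weight.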
template <typename Vmm>
void jit_uni_prelu_backward_kernel_t<Vmm>::finalize() {
    if (bcast_ == prelu::bcast::per_oc_blocked)
        uni_vmovups(ptr[reg_weights_diff_], weights_diff_acc_vmm_);
    else if (bcast_ == prelu::bcast::per_oc_n_c_spatial) {
        reduce(this, weights_diff_acc_vmm_, weights_const_vmm_, isa_);
        uni_vmovss(ptr[reg_weights_diff_], weights_diff_acc_vmm_);
    }
}

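// Build the data_type -> saturation-bounds mapping consumed by the io
// helper. When diff_src and diff_weights share an integer data type, a
// single upper-bound register serves both, so only one entry is emitted.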
template <typename Vmm>
std::map<data_type_t, io::io_saturation_conf_t>
jit_uni_prelu_backward_kernel_t<Vmm>::create_saturation_vmm_map() const {

    std::map<data_type_t, io::io_saturation_conf_t> saturation_map {};

    if (saturation_needed_diff_src_)
        saturation_map.emplace(diff_src_dt_,
                io::io_saturation_conf_t {vmm_zeros_.getIdx(),
                        saturation_ubound_diff_src_.getIdx(), reg_tmp_});

    if (saturation_needed_diff_weights_ && diff_src_dt_ != diff_wei_dt_)
        saturation_map.emplace(diff_wei_dt_,
                io::io_saturation_conf_t {vmm_zeros_.getIdx(),
                        saturation_ubound_diff_weights_.getIdx(), reg_tmp_});

    return saturation_map;
}

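// ISA dispatch: avx512_core and wider use Zmm; avx/avx2 use Ymm, except
// s8/u8 data on plain AVX (which lacks 256-bit integer instructions), where
// the kernel falls back to Xmm; sse41 uses Xmm.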
jit_prelu_backward_kernel_t *jit_prelu_backward_kernel_t::create(
        const cpu_prelu_bwd_pd_t *pd) {

    const auto isa = prelu::get_supported_isa();

    const auto &src_dt = pd->src_md(0)->data_type;
    const auto &wei_dt = pd->weights_md(0)->data_type;
    const auto &diff_src_dt = pd->diff_src_md(0)->data_type;
    const auto &diff_dst_dt = pd->diff_dst_md(0)->data_type;
    const auto &diff_wei_dt = pd->diff_weights_md(0)->data_type;

    if (is_superset(isa, avx512_core))
        return new jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>(pd, isa);
    else if (is_superset(isa, avx)) {
        if (isa == avx
                && prelu::is_s8u8({src_dt, wei_dt, diff_src_dt, diff_dst_dt,
                        diff_wei_dt}))
            return new jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>(pd, isa);
        else
            return new jit_uni_prelu_backward_kernel_t<Xbyak::Ymm>(pd, isa);
    } else if (isa == sse41)
        return new jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>(pd, isa);

    return nullptr;
}

template class jit_uni_prelu_backward_kernel_t<Xbyak::Zmm>;
template class jit_uni_prelu_backward_kernel_t<Xbyak::Ymm>;
template class jit_uni_prelu_backward_kernel_t<Xbyak::Xmm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl