/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/prelu/jit_prelu_reduction_kernel.hpp"

#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "cpu/platform.hpp" // platform::get_cache_line_size()
#include "cpu/x64/prelu/jit_prelu_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

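// Partial diff-weights rows in the scratchpad are rounded up to a full cache
// line of f32 elements, likely so that concurrent writers never share a
// cache line.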
static constexpr dim_t alignment
        = platform::get_cache_line_size() / sizeof(float);
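
// Returns the number of channels C (dim 1 of diff_src); tensors with fewer
// than two dimensions are treated as having a single channel.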
static dim_t get_C(const cpu_prelu_bwd_pd_t *pd) {
    const memory_desc_wrapper src_diff_d {pd->diff_src_md(0)};
    return src_diff_d.ndims() >= 2 ? src_diff_d.dims()[1] : 1;
}

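// The reduction kernel sums the per-thread partial diff-weights blocks
// (presumably produced by the PReLU backward kernel) from the scratchpad
// into the final diff_weights buffer.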
jit_prelu_reduction_kernel_t::jit_prelu_reduction_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, int simd_w)
    : jit_generator(jit_name())
    , scratchpad_c_block_offset_(
              utils::rnd_up(get_C(pd), alignment) * sizeof(float))
    , simd_w_(simd_w)
    , data_type_(pd->diff_weights_md(0)->data_type)
    , tail_size_(get_C(pd) % simd_w)
    , tail_block_size_(prelu::get_block_tail_size(pd->diff_weights_md(0)))
    , c_blk_nelems_(prelu::c_blk_nelems(pd->diff_weights_md(0), false)) {}

#define PARAM_OFF(x) offsetof(call_params_t, x)

size_t jit_prelu_reduction_kernel_t::simd_w() const {
    return simd_w_;
}

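// Fetch the runtime arguments (call_params_t) passed in abi_param1.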
void jit_prelu_reduction_kernel_t::load_kernel_call_params() {
    mov(reg_reduction_blocks_, ptr[abi_param1 + PARAM_OFF(reduction_blocks)]);
    mov(reg_weights_diff_scratch_,
            ptr[abi_param1 + PARAM_OFF(weights_diff_scratch)]);
    mov(reg_weights_diff_, ptr[abi_param1 + PARAM_OFF(weights_diff)]);
    mov(reg_tail_, byte[abi_param1 + PARAM_OFF(tail)]);
    mov(reg_last_c_blk_byte_, byte[abi_param1 + PARAM_OFF(is_last_c_blk)]);
}

#undef PARAM_OFF

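// When C is not a multiple of simd_w, both a full-vector and a tail variant
// of the reduction body are emitted; the `tail` byte selects between them at
// runtime.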
void jit_prelu_reduction_kernel_t::generate() {
    Xbyak::Label tail, end;

    preamble();
    load_kernel_call_params();

    if (tail_size_) {
        cmp(reg_tail_, 1);
        je(tail, T_NEAR);

        generate(false /* tail */);
        jmp(end, T_NEAR);

        L(tail);
        generate(true /* tail */);

        L(end);
    } else
        generate(false /* tail */);

    postamble();
}

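// Reduction body: an unrolled loop sums unrolling_factor scratchpad blocks
// per iteration; a remainder loop then handles the leftover blocks one at a
// time.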
void jit_prelu_reduction_kernel_t::generate(bool tail) {

    Xbyak::Label unroll_loop, unroll_loop_tail, end;
    const auto unrolling_factor = get_unrolling_factor(tail);

    prepare_kernel_const_vars(tail);
    xor_(reg_offset_, reg_offset_);
    L(unroll_loop);
    {
        const size_t offt = unrolling_factor * scratchpad_c_block_offset_;
        cmp(reg_reduction_blocks_, unrolling_factor);
        jl(unroll_loop_tail, T_NEAR);
        compute_dst(unrolling_factor, tail);
        sub(reg_reduction_blocks_, unrolling_factor);
        add(reg_offset_, offt);
        jmp(unroll_loop);
    }

    L(unroll_loop_tail);
    {
        cmp(reg_reduction_blocks_, 0);
        jle(end, T_NEAR);
        compute_dst(1, tail);
        sub(reg_reduction_blocks_, 1);
        add(reg_offset_, scratchpad_c_block_offset_);
        jmp(unroll_loop_tail);
    }

    L(end);

    finalize(tail);
}

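// Vmm indices reserved here hold kernel-lifetime values (tail mask,
// accumulator, saturation bounds); all remaining vector registers are left
// free for unrolling in compute_dst().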
int jit_prelu_reduction_kernel_t::reserve_vmm() {
    return number_reserved_vmms_++;
}

Xbyak::Address jit_prelu_reduction_kernel_t::diff_scratch_ptr(
        int unrolling_group) const {
    return ptr[reg_weights_diff_scratch_ + reg_offset_
            + unrolling_group * scratchpad_c_block_offset_];
}

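// Note: when a vmm tail mask is needed it is reserved first, so it must land
// in vmm 0 (see the assert below). AVX-512 handles the tail with an opmask
// register instead, and SSE4.1 needs no vmm mask at all.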
template <typename Vmm>
jit_uni_prelu_reduction_kernel_t<Vmm>::jit_uni_prelu_reduction_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa)
    : jit_prelu_reduction_kernel_t(pd, vreg_traits<Vmm>::vlen / sizeof(float))
    , isa_(isa)
    , saturation_needed_(utils::one_of(
              data_type_, data_type::s8, data_type::u8, data_type::s32))
    , tail_vmm_mask_(
              tail_size_ && utils::one_of(isa, avx, avx2) ? reserve_vmm() : 0)
    , accumulator_(reserve_vmm())
    , saturation_lower_bound_(saturation_needed_ ? reserve_vmm() : 0)
    , saturation_upper_bound_(saturation_needed_ ? reserve_vmm() : 0)
    , io_(this, isa_, data_type_, {},
              io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                      tail_vmm_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {},
              io::io_saturation_conf_t {saturation_lower_bound_.getIdx(),
                      saturation_upper_bound_.getIdx(), reg_tmp_}) {
    assert(tail_vmm_mask_.getIdx() == 0);
}

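// The unrolling factor is capped by the number of free vector registers and
// by the maximum thread count, since there is at most one scratchpad block
// per worker thread to reduce. bf16 emulation on plain avx512_core reserves
// four extra vmms.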
template <typename Vmm>
size_t jit_uni_prelu_reduction_kernel_t<Vmm>::get_unrolling_factor(
        bool tail) const {
    const size_t max_num_threads = dnnl_get_max_threads();
    const size_t n_vregs = prelu::get_n_vregs(isa_);
    const size_t number_of_available_regs = n_vregs
            - (number_reserved_vmms_
                    + (data_type_ == data_type::bf16 && isa_ == avx512_core
                                    ? 4
                                    : 0));

    return nstl::min(number_of_available_regs, max_num_threads);
}

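// Store the accumulated result to diff_weights; for blocked layouts, also
// zero-fill the padded tail of the last channel block.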
template <typename Vmm>
void jit_uni_prelu_reduction_kernel_t<Vmm>::finalize(bool tail) {
    io_.store(accumulator_, ptr[reg_weights_diff_], tail);

    if (!tail_block_size_) return;
    Xbyak::Label end;
    cmp(reg_last_c_blk_byte_, 1);
    jne(end, T_NEAR);
    const auto base_off = (c_blk_nelems_ % simd_w_) ? tail_size_ : simd_w_;
    prelu::apply_zero_padding(this, base_off, data_type_, tail_block_size_,
            reg_weights_diff_, nullptr);
    L(end);
}

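// Zero the accumulator and set up the loop invariants: bf16 emulation state,
// the tail mask, and f32 saturation bounds.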
template <typename Vmm>
void jit_uni_prelu_reduction_kernel_t<Vmm>::prepare_kernel_const_vars(
        bool tail) {
    uni_vxorps(accumulator_, accumulator_, accumulator_);

    io_.init_bf16();
    if (tail) io_.prepare_tail_mask();
    if (saturation_needed_) io_.init_saturate_f32();
}

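// Load unrolling_factor scratchpad blocks into distinct registers and add
// them to the running sum. Full-width loads are safe even in the tail case,
// as the scratchpad rows are padded to a cache line; only the final store in
// finalize() is masked.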
template <typename Vmm>
void jit_uni_prelu_reduction_kernel_t<Vmm>::compute_dst(
        int unrolling_factor, bool tail) {

    const int vmm_begin = number_reserved_vmms_;

    for (int unrolling_group = 0; unrolling_group < unrolling_factor;
            ++unrolling_group) {
        const Vmm load_vmm {vmm_begin + unrolling_group};
        uni_vmovups(load_vmm, diff_scratch_ptr(unrolling_group));
        uni_vaddps(accumulator_, accumulator_, load_vmm);
    }
}

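// Dispatch on the widest supported ISA. With plain AVX, s8/u8 diff-weights
// use Xmm, since AVX lacks the 256-bit integer instructions (added in AVX2)
// needed for the int8 conversions.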
jit_prelu_reduction_kernel_t *jit_prelu_reduction_kernel_t::create(
        const cpu_prelu_bwd_pd_t *pd) {

    const auto isa = prelu::get_supported_isa();

    if (is_superset(isa, avx512_core))
        return new jit_uni_prelu_reduction_kernel_t<Xbyak::Zmm>(pd, isa);
    else if (is_superset(isa, avx))
        if (isa == avx && prelu::is_s8u8({pd->diff_weights_md(0)->data_type}))
            return new jit_uni_prelu_reduction_kernel_t<Xbyak::Xmm>(pd, isa);
        else
            return new jit_uni_prelu_reduction_kernel_t<Xbyak::Ymm>(pd, isa);
    else if (isa == sse41)
        return new jit_uni_prelu_reduction_kernel_t<Xbyak::Xmm>(pd, isa);

    return nullptr;
}

template class jit_uni_prelu_reduction_kernel_t<Xbyak::Zmm>;
template class jit_uni_prelu_reduction_kernel_t<Xbyak::Ymm>;
template class jit_uni_prelu_reduction_kernel_t<Xbyak::Xmm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl