1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/prelu/jit_prelu_reduction_kernel.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "cpu/x64/prelu/jit_prelu_utils.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
27 | static constexpr dim_t alignment |
28 | = platform::get_cache_line_size() / sizeof(float); |
29 | static dim_t get_C(const cpu_prelu_bwd_pd_t *pd) { |
30 | const memory_desc_wrapper src_diff_d {pd->diff_src_md(0)}; |
31 | return src_diff_d.ndims() >= 2 ? src_diff_d.dims()[1] : 1; |
32 | } |
33 | |
// Captures per-primitive-descriptor constants used during code generation:
// - scratchpad_c_block_offset_: byte stride between consecutive per-thread
//   scratchpad blocks (C rounded up to `alignment` floats, in bytes);
// - tail_size_: channels left over in the last partial simd chunk (C % simd_w);
// - tail_block_size_ / c_blk_nelems_: blocked-layout tail info derived from
//   the diff_weights descriptor via the prelu helpers.
jit_prelu_reduction_kernel_t::jit_prelu_reduction_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, int simd_w)
    : jit_generator(jit_name())
    , scratchpad_c_block_offset_(
              utils::rnd_up(get_C(pd), alignment) * sizeof(float))
    , simd_w_(simd_w)
    , data_type_(pd->diff_weights_md(0)->data_type)
    , tail_size_(get_C(pd) % simd_w)
    , tail_block_size_(prelu::get_block_tail_size(pd->diff_weights_md(0)))
    , c_blk_nelems_(prelu::c_blk_nelems(pd->diff_weights_md(0), false)) {}
44 | |
45 | #define PARAM_OFF(x) offsetof(call_params_t, x) |
46 | |
// Number of f32 lanes processed per vector instruction (set by the derived
// class from the Vmm width).
size_t jit_prelu_reduction_kernel_t::simd_w() const {
    return simd_w_;
}
50 | |
// Emits loads of all call_params_t fields from the single ABI argument
// (abi_param1 points at the call_params_t struct) into dedicated registers.
// `tail` and `is_last_c_blk` are single-byte flags, hence the byte loads.
void jit_prelu_reduction_kernel_t::load_kernel_call_params() {
    mov(reg_reduction_blocks_, ptr[abi_param1 + PARAM_OFF(reduction_blocks)]);
    mov(reg_weights_diff_scratch_,
            ptr[abi_param1 + PARAM_OFF(weights_diff_scratch)]);
    mov(reg_weights_diff_, ptr[abi_param1 + PARAM_OFF(weights_diff)]);
    mov(reg_tail_, byte[abi_param1 + PARAM_OFF(tail)]);
    mov(reg_last_c_blk_byte_, byte[abi_param1 + PARAM_OFF(is_last_c_blk)]);
}
59 | |
60 | #undef PARAM_OFF |
61 | |
// Top-level kernel emission. When the channel count has a simd tail, two
// variants of the reduction body are emitted (full-vector and tail) and the
// run-time `tail` call parameter selects between them; otherwise only the
// full-vector body is emitted.
void jit_prelu_reduction_kernel_t::generate() {
    Xbyak::Label tail, end;

    preamble();
    load_kernel_call_params();

    if (tail_size_) {
        // reg_tail_ == 1 at run time means this invocation must use the
        // masked/tail variant of the reduction.
        cmp(reg_tail_, 1);
        je(tail, T_NEAR);

        generate(false /* tail*/);
        jmp(end, T_NEAR);

        L(tail);
        generate(true /* tail*/);

        L(end);
    } else
        generate(false /* tail*/);

    postamble();
}
84 | |
// Emits one variant of the reduction body: sums reg_reduction_blocks_
// per-thread scratchpad blocks into the accumulator, processing
// `unrolling_factor` blocks per main-loop iteration and the remainder one
// block at a time, then stores/finalizes the result.
void jit_prelu_reduction_kernel_t::generate(bool tail) {

    Xbyak::Label unroll_loop, unroll_loop_tail, end;
    const auto unrolling_factor = get_unrolling_factor(tail);

    prepare_kernel_const_vars(tail);
    xor_(reg_offset_, reg_offset_);
    L(unroll_loop);
    {
        // Bytes consumed per fully-unrolled iteration.
        const size_t offt = unrolling_factor * scratchpad_c_block_offset_;
        cmp(reg_reduction_blocks_, unrolling_factor);
        jl(unroll_loop_tail, T_NEAR);
        compute_dst(unrolling_factor, tail);
        sub(reg_reduction_blocks_, unrolling_factor);
        add(reg_offset_, offt);
        jmp(unroll_loop);
    }

    // Remainder loop: fewer than unrolling_factor blocks left; reduce them
    // one scratchpad block per iteration.
    L(unroll_loop_tail);
    {
        cmp(reg_reduction_blocks_, 0);
        jle(end, T_NEAR);
        compute_dst(1, tail);
        sub(reg_reduction_blocks_, 1);
        add(reg_offset_, scratchpad_c_block_offset_);
        jmp(unroll_loop_tail);
    }

    L(end);

    finalize(tail);
}
117 | |
// Hands out the next free vector-register index; called from the derived
// class's init list, so reservation order fixes the indices (tail mask first,
// then accumulator, then saturation bounds).
int jit_prelu_reduction_kernel_t::reserve_vmm() {
    return number_reserved_vmms_++;
}
121 | |
// Address of the scratchpad block for the given unrolling group:
// base + run-time loop offset + compile-time group offset.
Xbyak::Address jit_prelu_reduction_kernel_t::diff_scratch_ptr(
        int unrolling_group) const {
    return ptr[reg_weights_diff_scratch_ + reg_offset_
            + unrolling_group * scratchpad_c_block_offset_];
}
127 | |
// Derives simd_w from the Vmm width and reserves the helper vector registers
// needed by the io injector. Reservation order matters: the tail mask vmm
// (avx/avx2 only) must be reserved first so it lands in vmm0 — the assert
// below checks that invariant (on avx512 an opmask is used instead and the
// member defaults to index 0).
template <typename Vmm>
jit_uni_prelu_reduction_kernel_t<Vmm>::jit_uni_prelu_reduction_kernel_t(
        const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa)
    : jit_prelu_reduction_kernel_t(pd, vreg_traits<Vmm>::vlen / sizeof(float))
    , isa_(isa)
    , saturation_needed_(utils::one_of(
              data_type_, data_type::s8, data_type::u8, data_type::s32))
    , tail_vmm_mask_(tail_size_ && is_subset(isa, avx2) ? reserve_vmm() : 0)
    , accumulator_(reserve_vmm())
    , saturation_lower_bound_(saturation_needed_ ? reserve_vmm() : 0)
    , saturation_upper_bound_(saturation_needed_ ? reserve_vmm() : 0)
    , io_(this, isa_, data_type_, {},
              io::io_tail_conf_t {simd_w_, tail_size_, tail_opmask_,
                      tail_vmm_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {},
              io::io_saturation_conf_t {saturation_lower_bound_.getIdx(),
                      saturation_upper_bound_.getIdx(), reg_tmp_}) {
    assert(tail_vmm_mask_.getIdx() == 0);
}
147 | |
148 | template <typename Vmm> |
149 | size_t jit_uni_prelu_reduction_kernel_t<Vmm>::get_unrolling_factor( |
150 | bool tail) const { |
151 | const size_t max_num_threads = dnnl_get_max_threads(); |
152 | const size_t n_vregs = prelu::get_n_vregs(isa_); |
153 | const size_t number_of_available_regs = n_vregs |
154 | - (number_reserved_vmms_ |
155 | + (data_type_ == data_type::bf16 && isa_ == avx512_core |
156 | ? 4 |
157 | : 0)); |
158 | |
159 | return nstl::min(number_of_available_regs, max_num_threads); |
160 | } |
161 | |
162 | template <typename Vmm> |
163 | void jit_uni_prelu_reduction_kernel_t<Vmm>::finalize(bool tail) { |
164 | io_.store(accumulator_, ptr[reg_weights_diff_], tail); |
165 | |
166 | if (!tail_block_size_) return; |
167 | Xbyak::Label end; |
168 | cmp(reg_last_c_blk_byte_, 1); |
169 | jne(end, T_NEAR); |
170 | const auto base_off = (c_blk_nelems_ % simd_w_) ? tail_size_ : simd_w_; |
171 | prelu::apply_zero_padding(this, base_off, data_type_, tail_block_size_, |
172 | reg_weights_diff_, nullptr); |
173 | L(end); |
174 | } |
175 | |
// One-time setup emitted before the reduction loops: zero the accumulator
// and initialize the io helper's constants (bf16 emulation state, tail mask
// when this is the tail variant, f32 saturation bounds for int outputs).
template <typename Vmm>
void jit_uni_prelu_reduction_kernel_t<Vmm>::prepare_kernel_const_vars(
        bool tail) {
    uni_vxorps(accumulator_, accumulator_, accumulator_);

    io_.init_bf16();
    if (tail) io_.prepare_tail_mask();
    if (saturation_needed_) io_.init_saturate_f32();
}
185 | |
186 | template <typename Vmm> |
187 | void jit_uni_prelu_reduction_kernel_t<Vmm>::compute_dst( |
188 | int unrolling_factor, bool tail) { |
189 | |
190 | const int vmm_begin = number_reserved_vmms_; |
191 | |
192 | for (int unrolling_group = 0; unrolling_group < unrolling_factor; |
193 | ++unrolling_group) { |
194 | const Vmm load_vmm {vmm_begin + unrolling_group}; |
195 | uni_vmovups(load_vmm, diff_scratch_ptr(unrolling_group)); |
196 | uni_vaddps(accumulator_, accumulator_, load_vmm); |
197 | } |
198 | } |
199 | |
200 | jit_prelu_reduction_kernel_t *jit_prelu_reduction_kernel_t::create( |
201 | const cpu_prelu_bwd_pd_t *pd) { |
202 | |
203 | const auto isa = prelu::get_supported_isa(); |
204 | |
205 | if (is_superset(isa, avx512_core)) |
206 | return new jit_uni_prelu_reduction_kernel_t<Xbyak::Zmm>(pd, isa); |
207 | else if (is_superset(isa, avx)) |
208 | if (isa == avx && prelu::is_s8u8({pd->diff_weights_md(0)->data_type})) |
209 | return new jit_uni_prelu_reduction_kernel_t<Xbyak::Xmm>(pd, isa); |
210 | else |
211 | return new jit_uni_prelu_reduction_kernel_t<Xbyak::Ymm>(pd, isa); |
212 | else if (isa == sse41) |
213 | return new jit_uni_prelu_reduction_kernel_t<Xbyak::Xmm>(pd, isa); |
214 | |
215 | return nullptr; |
216 | } |
217 | |
// Explicit instantiations for every vector width `create` can return.
template class jit_uni_prelu_reduction_kernel_t<Xbyak::Zmm>;
template class jit_uni_prelu_reduction_kernel_t<Xbyak::Ymm>;
template class jit_uni_prelu_reduction_kernel_t<Xbyak::Xmm>;
221 | |
222 | } // namespace x64 |
223 | } // namespace cpu |
224 | } // namespace impl |
225 | } // namespace dnnl |
226 | |