1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <numeric> |
18 | #include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_base.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace cpu { |
23 | namespace x64 { |
24 | namespace lrn { |
25 | |
// Byte sizes of the accumulator element types used for on-stack temporaries:
// the f32 accumulator and its 16-bit counterpart (types declared in the
// header).
static constexpr int acc_size = sizeof(acc_data_t);
static constexpr int acc_bf_16_size = sizeof(acc_data_bf16_t);
28 | |
// Static index table referenced through jit_args_fwd_t::mask_ptr.
// Layout: 16 leading zeros, the values 16..31, then 16 trailing zeros.
// mask_ptr is anchored at &mask[16] (see the jit_args_fwd_t constructor), so
// reads at negative or past-the-end offsets from the anchor yield zeros.
// NOTE(review): the consumer of this table is outside this file — presumably
// it handles the across-channel window borders; confirm against the kernel
// body that uses mask_ptr.
template <data_type_t d_type>
const int32_t
        jit_avx512_common_lrn_kernel_fwd_t<d_type>::jit_args_fwd_t::mask[48]
        = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 17, 18, 19, 20,
                21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0};
35 | |
// Default-constructs the kernel argument pack: all data pointers start null,
// and mask_ptr is anchored at the middle of the zero-padded `mask` table so
// the kernel can index it with signed offsets in both directions.
template <data_type_t d_type>
jit_avx512_common_lrn_kernel_fwd_t<d_type>::jit_args_fwd_t::jit_args_fwd_t()
    : src(nullptr)
    , dst(nullptr)
    , ws0(nullptr)
    , ws1(nullptr)
    , mask_ptr(&mask[16]) {}
43 | |
// f32 specialization: the source is already in f32 layout, so a plain
// unaligned vector load suffices; `from_stack` is irrelevant here.
template <>
void jit_avx512_common_lrn_kernel_fwd_t<f32>::load_data(
        Xmm reg, const Address p, bool from_stack) {
    this->vmovups(reg, p);
}
49 | |
50 | template <> |
51 | void jit_avx512_common_lrn_kernel_fwd_t<bf16>::load_data( |
52 | Xmm reg, const Address p, bool from_stack) { |
53 | if (!from_stack) { |
54 | this->vpmovzxwd(reg, p); |
55 | this->vpslld(reg, reg, 0x10); |
56 | } else |
57 | this->vmovups(reg, p); |
58 | } |
59 | |
60 | template <> |
61 | void jit_avx512_common_lrn_kernel_fwd_t<f16>::load_data( |
62 | Xmm reg, const Address p, bool from_stack) { |
63 | if (!from_stack) { |
64 | this->vcvtph2psx(reg, p); |
65 | } else |
66 | this->vmovups(reg, p); |
67 | } |
68 | |
// Emits code that copies `tail_value` (less than a full simd width) trailing
// elements from memory at src + src_mem_offset to an f32 scratch area on the
// stack at rsp + dst_stack_offset, up-converting bf16/f16 inputs to f32.
// Works in progressively smaller chunks: one 8-wide load, then one 4-wide
// load, then element-by-element copies for the remainder.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_t<d_type>::load_tail(int tail_value,
        Reg64 src, int src_mem_offset, int dst_stack_offset,
        int tmp_load_to_stack_idx_tail) {

    // TODO: Investigate if this method can be simplified by using mask or
    // jit_generator load utilities.
    static constexpr auto src_size = sizeof(data_t);
    auto tmp_xreg = this->xreg(0, tmp_load_to_stack_idx_tail);

    // Emits a `vlen`-wide load (converted to f32 by load_data) followed by a
    // spill to the stack, then advances both offsets and the element count.
    const auto load_tail_simd = [&](Xmm tmp_reg, int vlen) {
        this->load_data(tmp_reg, this->EVEX_compress_addr(src, src_mem_offset));
        this->vmovups(this->EVEX_compress_addr(rsp, dst_stack_offset), tmp_reg);
        dst_stack_offset += vlen * acc_size;
        src_mem_offset += vlen * src_size;
        tail_value -= vlen;
    };

    if (tail_value >= 8)
        load_tail_simd(this->yreg(0, tmp_load_to_stack_idx_tail), 8);
    if (tail_value >= 4) load_tail_simd(tmp_xreg, 4);

    // Scalar remainder: load one element, up-convert, store one f32.
    for (int i = 0; i < tail_value; ++i) {
        if (d_type == bf16) {
            // bf16 -> f32: zero-extend the 16-bit value into a GPR, move it
            // into lane 0 of the xmm, then shift left by 16 so the bf16 bits
            // form the high half of the f32 lane.
            this->movzx(this->imm_addr64_, word[src + src_mem_offset]);
            this->vmovq(tmp_xreg, this->imm_addr64_);
            this->vpslld(tmp_xreg, tmp_xreg, 0x10);
        } else if (d_type == f16) {
            // Clear the register first: vcvtsh2ss merges the upper lanes
            // from its second source operand (tmp_xreg here).
            this->vxorps(tmp_xreg, tmp_xreg, tmp_xreg);
            this->vcvtsh2ss(tmp_xreg, tmp_xreg,
                    this->EVEX_compress_addr(src, src_mem_offset));
        } else
            this->vmovss(
                    tmp_xreg, this->EVEX_compress_addr(src, src_mem_offset));

        this->vmovss(ptr[rsp + dst_stack_offset], tmp_xreg);

        dst_stack_offset += acc_size;
        src_mem_offset += src_size;
    }
}
110 | |
// f16 specialization: down-converts an f32 zmm to packed f16 directly into
// memory; `yr` is unused. NOTE(review): `_op_mxcsr` is declared in the
// header — presumably the imm8 rounding-control operand of vcvtps2ph;
// confirm its value there.
template <>
void jit_avx512_common_lrn_kernel_fwd_t<f16>::store_data(
        const Address addr, Zmm zr, Ymm yr) {
    this->vcvtps2ph(addr, zr, this->_op_mxcsr);
}
116 | |
// bf16 specialization: converts an f32 zmm to 16 packed bf16 values in `yr`
// — through the software emulation helper when the CPU lacks native
// avx512_core_bf16 support — then stores the ymm to memory.
template <>
void jit_avx512_common_lrn_kernel_fwd_t<bf16>::store_data(
        const Address addr, Zmm zr, Ymm yr) {
    if (emulateBfloat_)
        this->bf16_emu_->vcvtneps2bf16(yr, zr);
    else
        this->vcvtneps2bf16(yr, zr);

    this->vmovdqu16(addr, yr);
}
127 | |
// f32 specialization: no conversion needed — store the full zmm; `yr` is
// unused.
template <>
void jit_avx512_common_lrn_kernel_fwd_t<f32>::store_data(
        const Address addr, Zmm zr, Ymm yr) {
    this->vmovups(addr, zr);
}
133 | |
// f32 specialization: emits code that writes the `tail_value` leading
// elements of `src` to dst memory. The full zmm is first spilled to a stack
// scratch area, then copied out in 8-, 4- and 1-element chunks.
template <>
void jit_avx512_common_lrn_kernel_fwd_t<f32>::store_tail(int tail_value,
        Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset,
        int tmp_idx) {

    this->store_data(this->EVEX_compress_addr(rsp, tmp_stack_offset), src,
            this->yreg(0, tmp_idx));

    // Emits a `vlen`-wide stack -> dst copy and advances the bookkeeping.
    const auto store_tail_simd = [&](Xmm tmp_reg, int vlen) {
        this->vmovups(tmp_reg, this->EVEX_compress_addr(rsp, tmp_stack_offset));
        this->vmovups(this->EVEX_compress_addr(dst, dst_mem_offset), tmp_reg);
        tmp_stack_offset += vlen * acc_size;
        dst_mem_offset += vlen * acc_size;
        tail_value -= vlen;
    };

    if (tail_value >= 8) store_tail_simd(this->yreg(0, tmp_idx), 8);
    if (tail_value >= 4) store_tail_simd(this->xreg(0, tmp_idx), 4);

    // Scalar remainder: one f32 at a time.
    for (int i = 0; i < tail_value;
            ++i, tmp_stack_offset += acc_size, dst_mem_offset += acc_size) {
        this->vmovss(this->xreg(0, tmp_idx),
                this->EVEX_compress_addr(rsp, tmp_stack_offset));
        this->vmovss(this->EVEX_compress_addr(dst, dst_mem_offset),
                this->xreg(0, tmp_idx));
    }
}
161 | |
// bf16 specialization: converts `src` to packed bf16, spills the result to a
// stack scratch area, then copies the `tail_value` leading 16-bit elements
// to dst through GPRs — four elements (one qword) at a time, then the 0..3
// remaining elements word by word.
template <>
void jit_avx512_common_lrn_kernel_fwd_t<bf16>::store_tail(int tail_value,
        Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset,
        int tmp_idx) {

    this->store_data(this->EVEX_compress_addr(rsp, tmp_stack_offset), src,
            this->yreg(0, tmp_idx));
    const auto res = std::div(tail_value, 4);

    // Whole groups of 4 bf16 values fit into one 64-bit transfer.
    for (int i = 0; i < res.quot; ++i, tmp_stack_offset += 4 * acc_bf_16_size,
            dst_mem_offset += 4 * acc_bf_16_size) {
        this->mov(this->imm_addr64_, qword[rsp + tmp_stack_offset]);
        this->mov(qword[dst + dst_mem_offset], this->imm_addr64_);
    }

    // Remainder: one 16-bit word at a time.
    for (int i = 0; i < res.rem; ++i, tmp_stack_offset += acc_bf_16_size,
            dst_mem_offset += acc_bf_16_size) {
        this->mov(this->imm_addr16_, word[rsp + tmp_stack_offset]);
        this->mov(word[dst + dst_mem_offset], this->imm_addr16_);
    }
}
183 | |
// Copy-paste from above bf16 implementation: both element types are 16-bit,
// so acc_bf_16_size (presumably 2 bytes — declared via acc_data_bf16_t in
// the header; confirm) doubles as the f16 element size. TODO(review):
// consider factoring the two specializations into a shared helper.
template <>
void jit_avx512_common_lrn_kernel_fwd_t<f16>::store_tail(int tail_value,
        Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset,
        int tmp_idx) {

    this->store_data(this->EVEX_compress_addr(rsp, tmp_stack_offset), src,
            this->yreg(0, tmp_idx));
    const auto res = std::div(tail_value, 4);

    // Whole groups of 4 f16 values fit into one 64-bit transfer.
    for (int i = 0; i < res.quot; ++i, tmp_stack_offset += 4 * acc_bf_16_size,
            dst_mem_offset += 4 * acc_bf_16_size) {
        this->mov(this->imm_addr64_, qword[rsp + tmp_stack_offset]);
        this->mov(qword[dst + dst_mem_offset], this->imm_addr64_);
    }

    // Remainder: one 16-bit word at a time.
    for (int i = 0; i < res.rem; ++i, tmp_stack_offset += acc_bf_16_size,
            dst_mem_offset += acc_bf_16_size) {
        this->mov(this->imm_addr16_, word[rsp + tmp_stack_offset]);
        this->mov(word[dst + dst_mem_offset], this->imm_addr16_);
    }
}
206 | |
// Kernel generator constructor: records the LRN parameters and derives the
// vector-register layout (per-block register indices and how many register
// blocks fit) from the across-channel window size and the available ISA.
template <data_type_t d_type>
jit_avx512_common_lrn_kernel_fwd_t<d_type>::jit_avx512_common_lrn_kernel_fwd_t(
        prop_kind_t prop_kind, float alpha, float beta, float k, int local_size,
        void *code_ptr, size_t code_size, const char *name)
    : jit_generator(name, code_ptr, code_size, true, avx512_core_bf16)
    , pk_(prop_kind)
    , alpha_(alpha)
    , beta_(beta)
    , k_(k)
    // Force an odd window: an even `local_size` is decremented by one.
    , local_size_ {local_size - !(local_size % 2)}
    // Register indices for the preceding half-window channels: 3, 4, ...
    , z_prev_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(), 3);
        return v;
    }()}
    // Register indices for the following half-window channels, placed right
    // after z_prev_.
    , z_next_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(), 3 + this->local_size_ / 2);
        return v;
    }()}
    // Index of the sum register: after all window registers, at least 6.
    , zsum_ {std::max(local_size_ + 2, 6)}
    // bf16 requested but the CPU lacks native bf16 instructions -> emulate.
    , emulateBfloat_(d_type == bf16 && !mayiuse(avx512_core_bf16))
    // Registers consumed by one interleaved register block (matches zsum_).
    , regs_used_per_block_ {std::max(this->local_size_ + 2, 6)}
    // Number of register blocks that fit in the available zmm registers
    // (26 instead of 30 when the bf16 emulator reserves scratch registers);
    // capped at 2 without avx512_core.
    , reg_block_ {[this]() {
        const int max_possible_reg_block
                = (emulateBfloat_ ? 26 : 30) / this->regs_used_per_block_;
        return mayiuse(avx512_core) ? max_possible_reg_block
                                    : std::min(max_possible_reg_block, 2);
    }()} {
    if (emulateBfloat_) {
        // Set up the software bf16 conversion helper with its reserved
        // scratch registers (declared in the header).
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1_, bf16_emu_reserv_2_, bf16_emu_reserv_3_,
                bf16_emu_scratch_, bf16_emu_reserv_4_);
    }
}
242 | |
243 | template <data_type_t d_type> |
244 | Zmm jit_avx512_common_lrn_kernel_fwd_t<d_type>::zreg(int irb, int i) const { |
245 | return Zmm(irb * regs_used_per_block_ + i); |
246 | } |
247 | |
248 | template <data_type_t d_type> |
249 | Ymm jit_avx512_common_lrn_kernel_fwd_t<d_type>::yreg(int irb, int i) const { |
250 | return Ymm(irb * regs_used_per_block_ + i); |
251 | } |
252 | |
253 | template <data_type_t d_type> |
254 | Xmm jit_avx512_common_lrn_kernel_fwd_t<d_type>::xreg(int irb, int i) const { |
255 | return Xmm(irb * regs_used_per_block_ + i); |
256 | } |
257 | |
// Explicit instantiations for the data types supported by the forward LRN
// AVX-512 kernel.
template class jit_avx512_common_lrn_kernel_fwd_t<f32>;
template class jit_avx512_common_lrn_kernel_fwd_t<bf16>;
template class jit_avx512_common_lrn_kernel_fwd_t<f16>;
261 | |
262 | } // namespace lrn |
263 | } // namespace x64 |
264 | } // namespace cpu |
265 | } // namespace impl |
266 | } // namespace dnnl |
267 | |