1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <numeric> |
18 | #include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_base.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace cpu { |
23 | namespace x64 { |
24 | namespace lrn { |
25 | |
26 | static constexpr int acc_size = sizeof(acc_data_t); |
27 | static constexpr int acc_bf_16_size = sizeof(acc_data_bf16_t); |
28 | |
29 | template <data_type_t d_type> |
30 | const int32_t |
31 | jit_avx512_common_lrn_kernel_bwd_t<d_type>::jit_args_bwd_t::mask[48] |
32 | = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 17, 18, 19, 20, |
33 | 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 0, 0, 0, 0, 0, 0, |
34 | 0, 0, 0, 0, 0, 0, 0, 0, 0}; |
35 | |
36 | template <data_type_t d_type> |
37 | jit_avx512_common_lrn_kernel_bwd_t<d_type>::jit_args_bwd_t::jit_args_bwd_t() |
38 | : src(nullptr) |
39 | , diff_dst(nullptr) |
40 | , ws0(nullptr) |
41 | , ws1(nullptr) |
42 | , diff_src(nullptr) |
43 | , mask_ptr(&mask[16]) {} |
44 | |
45 | template <> |
46 | void jit_avx512_common_lrn_kernel_bwd_t<f32>::load_data( |
47 | Xmm reg, const Address p, bool from_stack) { |
48 | this->vmovups(reg, p); |
49 | } |
50 | |
51 | template <> |
52 | void jit_avx512_common_lrn_kernel_bwd_t<bf16>::load_data( |
53 | Xmm reg, const Address p, bool from_stack) { |
54 | if (!from_stack) { |
55 | this->vpmovzxwd(reg, p); |
56 | this->vpslld(reg, reg, 0x10); |
57 | } else |
58 | this->vmovups(reg, p); |
59 | } |
60 | |
61 | template <> |
62 | void jit_avx512_common_lrn_kernel_bwd_t<f16>::load_data( |
63 | Xmm reg, const Address p, bool from_stack) { |
64 | if (!from_stack) { |
65 | this->vcvtph2psx(reg, p); |
66 | } else |
67 | this->vmovups(reg, p); |
68 | } |
69 | |
70 | template <> |
71 | void jit_avx512_common_lrn_kernel_bwd_t<f16>::store_data( |
72 | bool nt, const Address addr, Zmm zr) { |
73 | this->vcvtps2ph(addr, zr, this->_op_mxcsr); |
74 | } |
75 | |
76 | template <> |
77 | void jit_avx512_common_lrn_kernel_bwd_t<bf16>::store_data( |
78 | bool nt, const Address addr, Zmm zr) { |
79 | const Ymm yr = Ymm(zr.getIdx()); |
80 | if (mayiuse(avx512_core_bf16)) |
81 | vcvtneps2bf16(yr, zr); |
82 | else |
83 | bf16_emu_->vcvtneps2bf16(yr, zr); |
84 | vmovdqu16(addr, yr); |
85 | } |
86 | |
87 | template <> |
88 | void jit_avx512_common_lrn_kernel_bwd_t<f32>::store_data( |
89 | bool non_temp_hint, const Address addr, Zmm zr) { |
90 | if (non_temp_hint) |
91 | uni_vmovntps(addr, zr); |
92 | else |
93 | uni_vmovups(addr, zr); |
94 | } |
95 | |
96 | template <data_type_t d_type> |
97 | void jit_avx512_common_lrn_kernel_bwd_t<d_type>::load_tail(int tail_value, |
98 | Reg64 src, int src_mem_offset, int dst_stack_offset, |
99 | int tmp_load_to_stack_idx_tail) { |
100 | // TODO: Investigate if this method can be simplified by using mask or |
101 | // jit_generator load utilities. |
102 | static constexpr auto src_acc_size |
103 | = utils::one_of(d_type, bf16, f16) ? acc_bf_16_size : acc_size; |
104 | auto tmp_xreg = this->xreg(0, tmp_load_to_stack_idx_tail); |
105 | |
106 | const auto load_tail_simd = [&](Xmm tmp_reg, int vlen) { |
107 | this->load_data(tmp_reg, this->EVEX_compress_addr(src, src_mem_offset)); |
108 | this->vmovups(this->EVEX_compress_addr(rsp, dst_stack_offset), tmp_reg); |
109 | dst_stack_offset += vlen * acc_size; |
110 | src_mem_offset += vlen * src_acc_size; |
111 | tail_value -= vlen; |
112 | }; |
113 | |
114 | if (tail_value >= 8) |
115 | load_tail_simd(this->yreg(0, tmp_load_to_stack_idx_tail), 8); |
116 | if (tail_value >= 4) load_tail_simd(tmp_xreg, 4); |
117 | |
118 | for (int i = 0; i < tail_value; ++i) { |
119 | if (d_type == bf16) { |
120 | this->movzx(this->imm_addr64_, word[src + src_mem_offset]); |
121 | this->vmovq(tmp_xreg, this->imm_addr64_); |
122 | this->vpslld(tmp_xreg, tmp_xreg, 0x10); |
123 | } else if (d_type == f16) { |
124 | this->vxorps(tmp_xreg, tmp_xreg, tmp_xreg); |
125 | this->vcvtsh2ss(tmp_xreg, tmp_xreg, |
126 | this->EVEX_compress_addr(src, src_mem_offset)); |
127 | } else |
128 | this->vmovss( |
129 | tmp_xreg, this->EVEX_compress_addr(src, src_mem_offset)); |
130 | |
131 | this->vmovss(ptr[rsp + dst_stack_offset], tmp_xreg); |
132 | |
133 | dst_stack_offset += acc_size; |
134 | src_mem_offset += src_acc_size; |
135 | } |
136 | } |
137 | |
138 | template <> |
139 | void jit_avx512_common_lrn_kernel_bwd_t<f32>::store_tail(int tail_value, |
140 | Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset, |
141 | int tmp_idx) { |
142 | |
143 | this->store_data( |
144 | false, this->EVEX_compress_addr(rsp, tmp_stack_offset), src); |
145 | |
146 | const auto store_tail_simd = [&](Xmm tmp_reg, int vlen) { |
147 | this->vmovups(tmp_reg, this->EVEX_compress_addr(rsp, tmp_stack_offset)); |
148 | this->vmovups(this->EVEX_compress_addr(dst, dst_mem_offset), tmp_reg); |
149 | tmp_stack_offset += vlen * acc_size; |
150 | dst_mem_offset += vlen * acc_size; |
151 | tail_value -= vlen; |
152 | }; |
153 | |
154 | if (tail_value >= 8) store_tail_simd(this->yreg(0, tmp_idx), 8); |
155 | if (tail_value >= 4) store_tail_simd(this->xreg(0, tmp_idx), 4); |
156 | |
157 | for (int i = 0; i < tail_value; |
158 | ++i, tmp_stack_offset += acc_size, dst_mem_offset += acc_size) { |
159 | this->vmovss(this->xreg(0, tmp_idx), |
160 | this->EVEX_compress_addr(rsp, tmp_stack_offset)); |
161 | this->vmovss(this->EVEX_compress_addr(dst, dst_mem_offset), |
162 | this->xreg(0, tmp_idx)); |
163 | } |
164 | } |
165 | |
166 | template <> |
167 | void jit_avx512_common_lrn_kernel_bwd_t<bf16>::store_tail(int tail_value, |
168 | Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset, |
169 | int tmp_idx) { |
170 | |
171 | this->store_data( |
172 | false, this->EVEX_compress_addr(rsp, tmp_stack_offset), src); |
173 | const auto res = std::div(tail_value, 4); |
174 | |
175 | for (int i = 0; i < res.quot; ++i, tmp_stack_offset += 4 * acc_bf_16_size, |
176 | dst_mem_offset += 4 * acc_bf_16_size) { |
177 | this->mov(this->imm_addr64_, qword[rsp + tmp_stack_offset]); |
178 | this->mov(qword[dst + dst_mem_offset], this->imm_addr64_); |
179 | } |
180 | |
181 | for (int i = 0; i < res.rem; ++i, tmp_stack_offset += acc_bf_16_size, |
182 | dst_mem_offset += acc_bf_16_size) { |
183 | this->mov(this->imm_addr16_, word[rsp + tmp_stack_offset]); |
184 | this->mov(word[dst + dst_mem_offset], this->imm_addr16_); |
185 | } |
186 | } |
187 | |
188 | // Copy-paste of bf16 from above |
189 | template <> |
190 | void jit_avx512_common_lrn_kernel_bwd_t<f16>::store_tail(int tail_value, |
191 | Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset, |
192 | int tmp_idx) { |
193 | |
194 | this->store_data( |
195 | false, this->EVEX_compress_addr(rsp, tmp_stack_offset), src); |
196 | const auto res = std::div(tail_value, 4); |
197 | |
198 | for (int i = 0; i < res.quot; ++i, tmp_stack_offset += 4 * acc_bf_16_size, |
199 | dst_mem_offset += 4 * acc_bf_16_size) { |
200 | this->mov(this->imm_addr64_, qword[rsp + tmp_stack_offset]); |
201 | this->mov(qword[dst + dst_mem_offset], this->imm_addr64_); |
202 | } |
203 | |
204 | for (int i = 0; i < res.rem; ++i, tmp_stack_offset += acc_bf_16_size, |
205 | dst_mem_offset += acc_bf_16_size) { |
206 | this->mov(this->imm_addr16_, word[rsp + tmp_stack_offset]); |
207 | this->mov(word[dst + dst_mem_offset], this->imm_addr16_); |
208 | } |
209 | } |
210 | |
// Constructs the LRN backward JIT kernel for data type d_type.
//   alpha, beta - LRN parameters; only the product -2 * alpha * beta is kept
//                 (nalphabeta_).
//   local_size  - LRN window size; an even value is reduced by one.
//   code_ptr, code_size, name - forwarded unchanged to jit_generator.
// NOTE(review): the initializers below read local_size_, emulateBfloat_ and
// regs_used_per_block_. They are only correct if the header declares those
// members before the ones that use them. Confirm the declaration order when
// editing either file.
template <data_type_t d_type>
jit_avx512_common_lrn_kernel_bwd_t<d_type>::jit_avx512_common_lrn_kernel_bwd_t(
        float alpha, float beta, int local_size, void *code_ptr,
        size_t code_size, const char *name)
    : jit_generator(name, code_ptr, code_size, true, avx512_core_bf16)
    // Force an odd window: an even local_size is decremented by one.
    , local_size_ {local_size - !(local_size % 2)}
    // local_size_ / 2 consecutive register indices starting at 3 (by name,
    // presumably for the elements preceding the window centre).
    , z_prev_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(), 3);
        return v;
    }()}
    // The next local_size_ / 2 register indices, directly after z_prev_'s.
    , z_next_ {[this]() {
        std::vector<int> v(this->local_size_ / 2);
        std::iota(v.begin(), v.end(), 3 + this->local_size_ / 2);
        return v;
    }()}
    , nalphabeta_(-2 * alpha * beta)
    // bf16 on a CPU without avx512_core_bf16 falls back to software emulation.
    , emulateBfloat_(d_type == bf16 && !mayiuse(avx512_core_bf16))
    // Register-index stride between blocks (see zreg/yreg/xreg):
    // local_size_ + 2, but at least 7.
    , regs_used_per_block_ {std::max(this->local_size_ + 2, 7)}
    // How many register blocks fit in the vector register budget:
    // - 27 registers when emulating bf16 (presumably the emulator's reserved
    //   registers are excluded), 31 otherwise;
    // - capped at 2 when avx512_core is unavailable.
    , reg_block_ {[this]() {
        const int max_possible_reg_block
                = (emulateBfloat_ ? 27 : 31) / this->regs_used_per_block_;
        return mayiuse(avx512_core) ? max_possible_reg_block
                                    : std::min(max_possible_reg_block, 2);
    }()} {
    // The emulator is created only when emulation is needed, and is given the
    // registers reserved for it.
    if (emulateBfloat_) {
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1_, bf16_emu_reserv_2_, bf16_emu_reserv_3_,
                bf16_emu_scratch_, bf16_emu_reserv_4_);
    }
}
242 | |
243 | template <data_type_t d_type> |
244 | Zmm jit_avx512_common_lrn_kernel_bwd_t<d_type>::zreg(int irb, int i) const { |
245 | return Zmm(irb * regs_used_per_block_ + i); |
246 | } |
247 | |
248 | template <data_type_t d_type> |
249 | Ymm jit_avx512_common_lrn_kernel_bwd_t<d_type>::yreg(int irb, int i) const { |
250 | return Ymm(irb * regs_used_per_block_ + i); |
251 | } |
252 | |
253 | template <data_type_t d_type> |
254 | Xmm jit_avx512_common_lrn_kernel_bwd_t<d_type>::xreg(int irb, int i) const { |
255 | return Xmm(irb * regs_used_per_block_ + i); |
256 | } |
257 | |
// Explicit instantiations for the three data types this file specializes
// (load_data / store_data / store_tail exist for f32, bf16 and f16).
template class jit_avx512_common_lrn_kernel_bwd_t<f32>;
template class jit_avx512_common_lrn_kernel_bwd_t<bf16>;
template class jit_avx512_common_lrn_kernel_bwd_t<f16>;
261 | |
262 | } // namespace lrn |
263 | } // namespace x64 |
264 | } // namespace cpu |
265 | } // namespace impl |
266 | } // namespace dnnl |
267 | |