1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <numeric>
18#include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_base.hpp"
19
20namespace dnnl {
21namespace impl {
22namespace cpu {
23namespace x64 {
24namespace lrn {
25
26static constexpr int acc_size = sizeof(acc_data_t);
27static constexpr int acc_bf_16_size = sizeof(acc_data_bf16_t);
28
29template <data_type_t d_type>
30const int32_t
31 jit_avx512_common_lrn_kernel_bwd_t<d_type>::jit_args_bwd_t::mask[48]
32 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 17, 18, 19, 20,
33 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 0, 0, 0, 0, 0, 0,
34 0, 0, 0, 0, 0, 0, 0, 0, 0};
35
36template <data_type_t d_type>
37jit_avx512_common_lrn_kernel_bwd_t<d_type>::jit_args_bwd_t::jit_args_bwd_t()
38 : src(nullptr)
39 , diff_dst(nullptr)
40 , ws0(nullptr)
41 , ws1(nullptr)
42 , diff_src(nullptr)
43 , mask_ptr(&mask[16]) {}
44
45template <>
46void jit_avx512_common_lrn_kernel_bwd_t<f32>::load_data(
47 Xmm reg, const Address p, bool from_stack) {
48 this->vmovups(reg, p);
49}
50
51template <>
52void jit_avx512_common_lrn_kernel_bwd_t<bf16>::load_data(
53 Xmm reg, const Address p, bool from_stack) {
54 if (!from_stack) {
55 this->vpmovzxwd(reg, p);
56 this->vpslld(reg, reg, 0x10);
57 } else
58 this->vmovups(reg, p);
59}
60
61template <>
62void jit_avx512_common_lrn_kernel_bwd_t<f16>::load_data(
63 Xmm reg, const Address p, bool from_stack) {
64 if (!from_stack) {
65 this->vcvtph2psx(reg, p);
66 } else
67 this->vmovups(reg, p);
68}
69
70template <>
71void jit_avx512_common_lrn_kernel_bwd_t<f16>::store_data(
72 bool nt, const Address addr, Zmm zr) {
73 this->vcvtps2ph(addr, zr, this->_op_mxcsr);
74}
75
76template <>
77void jit_avx512_common_lrn_kernel_bwd_t<bf16>::store_data(
78 bool nt, const Address addr, Zmm zr) {
79 const Ymm yr = Ymm(zr.getIdx());
80 if (mayiuse(avx512_core_bf16))
81 vcvtneps2bf16(yr, zr);
82 else
83 bf16_emu_->vcvtneps2bf16(yr, zr);
84 vmovdqu16(addr, yr);
85}
86
87template <>
88void jit_avx512_common_lrn_kernel_bwd_t<f32>::store_data(
89 bool non_temp_hint, const Address addr, Zmm zr) {
90 if (non_temp_hint)
91 uni_vmovntps(addr, zr);
92 else
93 uni_vmovups(addr, zr);
94}
95
96template <data_type_t d_type>
97void jit_avx512_common_lrn_kernel_bwd_t<d_type>::load_tail(int tail_value,
98 Reg64 src, int src_mem_offset, int dst_stack_offset,
99 int tmp_load_to_stack_idx_tail) {
100 // TODO: Investigate if this method can be simplified by using mask or
101 // jit_generator load utilities.
102 static constexpr auto src_acc_size
103 = utils::one_of(d_type, bf16, f16) ? acc_bf_16_size : acc_size;
104 auto tmp_xreg = this->xreg(0, tmp_load_to_stack_idx_tail);
105
106 const auto load_tail_simd = [&](Xmm tmp_reg, int vlen) {
107 this->load_data(tmp_reg, this->EVEX_compress_addr(src, src_mem_offset));
108 this->vmovups(this->EVEX_compress_addr(rsp, dst_stack_offset), tmp_reg);
109 dst_stack_offset += vlen * acc_size;
110 src_mem_offset += vlen * src_acc_size;
111 tail_value -= vlen;
112 };
113
114 if (tail_value >= 8)
115 load_tail_simd(this->yreg(0, tmp_load_to_stack_idx_tail), 8);
116 if (tail_value >= 4) load_tail_simd(tmp_xreg, 4);
117
118 for (int i = 0; i < tail_value; ++i) {
119 if (d_type == bf16) {
120 this->movzx(this->imm_addr64_, word[src + src_mem_offset]);
121 this->vmovq(tmp_xreg, this->imm_addr64_);
122 this->vpslld(tmp_xreg, tmp_xreg, 0x10);
123 } else if (d_type == f16) {
124 this->vxorps(tmp_xreg, tmp_xreg, tmp_xreg);
125 this->vcvtsh2ss(tmp_xreg, tmp_xreg,
126 this->EVEX_compress_addr(src, src_mem_offset));
127 } else
128 this->vmovss(
129 tmp_xreg, this->EVEX_compress_addr(src, src_mem_offset));
130
131 this->vmovss(ptr[rsp + dst_stack_offset], tmp_xreg);
132
133 dst_stack_offset += acc_size;
134 src_mem_offset += src_acc_size;
135 }
136}
137
138template <>
139void jit_avx512_common_lrn_kernel_bwd_t<f32>::store_tail(int tail_value,
140 Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset,
141 int tmp_idx) {
142
143 this->store_data(
144 false, this->EVEX_compress_addr(rsp, tmp_stack_offset), src);
145
146 const auto store_tail_simd = [&](Xmm tmp_reg, int vlen) {
147 this->vmovups(tmp_reg, this->EVEX_compress_addr(rsp, tmp_stack_offset));
148 this->vmovups(this->EVEX_compress_addr(dst, dst_mem_offset), tmp_reg);
149 tmp_stack_offset += vlen * acc_size;
150 dst_mem_offset += vlen * acc_size;
151 tail_value -= vlen;
152 };
153
154 if (tail_value >= 8) store_tail_simd(this->yreg(0, tmp_idx), 8);
155 if (tail_value >= 4) store_tail_simd(this->xreg(0, tmp_idx), 4);
156
157 for (int i = 0; i < tail_value;
158 ++i, tmp_stack_offset += acc_size, dst_mem_offset += acc_size) {
159 this->vmovss(this->xreg(0, tmp_idx),
160 this->EVEX_compress_addr(rsp, tmp_stack_offset));
161 this->vmovss(this->EVEX_compress_addr(dst, dst_mem_offset),
162 this->xreg(0, tmp_idx));
163 }
164}
165
166template <>
167void jit_avx512_common_lrn_kernel_bwd_t<bf16>::store_tail(int tail_value,
168 Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset,
169 int tmp_idx) {
170
171 this->store_data(
172 false, this->EVEX_compress_addr(rsp, tmp_stack_offset), src);
173 const auto res = std::div(tail_value, 4);
174
175 for (int i = 0; i < res.quot; ++i, tmp_stack_offset += 4 * acc_bf_16_size,
176 dst_mem_offset += 4 * acc_bf_16_size) {
177 this->mov(this->imm_addr64_, qword[rsp + tmp_stack_offset]);
178 this->mov(qword[dst + dst_mem_offset], this->imm_addr64_);
179 }
180
181 for (int i = 0; i < res.rem; ++i, tmp_stack_offset += acc_bf_16_size,
182 dst_mem_offset += acc_bf_16_size) {
183 this->mov(this->imm_addr16_, word[rsp + tmp_stack_offset]);
184 this->mov(word[dst + dst_mem_offset], this->imm_addr16_);
185 }
186}
187
188// Copy-paste of bf16 from above
189template <>
190void jit_avx512_common_lrn_kernel_bwd_t<f16>::store_tail(int tail_value,
191 Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset,
192 int tmp_idx) {
193
194 this->store_data(
195 false, this->EVEX_compress_addr(rsp, tmp_stack_offset), src);
196 const auto res = std::div(tail_value, 4);
197
198 for (int i = 0; i < res.quot; ++i, tmp_stack_offset += 4 * acc_bf_16_size,
199 dst_mem_offset += 4 * acc_bf_16_size) {
200 this->mov(this->imm_addr64_, qword[rsp + tmp_stack_offset]);
201 this->mov(qword[dst + dst_mem_offset], this->imm_addr64_);
202 }
203
204 for (int i = 0; i < res.rem; ++i, tmp_stack_offset += acc_bf_16_size,
205 dst_mem_offset += acc_bf_16_size) {
206 this->mov(this->imm_addr16_, word[rsp + tmp_stack_offset]);
207 this->mov(word[dst + dst_mem_offset], this->imm_addr16_);
208 }
209}
210
211template <data_type_t d_type>
212jit_avx512_common_lrn_kernel_bwd_t<d_type>::jit_avx512_common_lrn_kernel_bwd_t(
213 float alpha, float beta, int local_size, void *code_ptr,
214 size_t code_size, const char *name)
215 : jit_generator(name, code_ptr, code_size, true, avx512_core_bf16)
216 , local_size_ {local_size - !(local_size % 2)}
217 , z_prev_ {[this]() {
218 std::vector<int> v(this->local_size_ / 2);
219 std::iota(v.begin(), v.end(), 3);
220 return v;
221 }()}
222 , z_next_ {[this]() {
223 std::vector<int> v(this->local_size_ / 2);
224 std::iota(v.begin(), v.end(), 3 + this->local_size_ / 2);
225 return v;
226 }()}
227 , nalphabeta_(-2 * alpha * beta)
228 , emulateBfloat_(d_type == bf16 && !mayiuse(avx512_core_bf16))
229 , regs_used_per_block_ {std::max(this->local_size_ + 2, 7)}
230 , reg_block_ {[this]() {
231 const int max_possible_reg_block
232 = (emulateBfloat_ ? 27 : 31) / this->regs_used_per_block_;
233 return mayiuse(avx512_core) ? max_possible_reg_block
234 : std::min(max_possible_reg_block, 2);
235 }()} {
236 if (emulateBfloat_) {
237 bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
238 bf16_emu_reserv_1_, bf16_emu_reserv_2_, bf16_emu_reserv_3_,
239 bf16_emu_scratch_, bf16_emu_reserv_4_);
240 }
241}
242
243template <data_type_t d_type>
244Zmm jit_avx512_common_lrn_kernel_bwd_t<d_type>::zreg(int irb, int i) const {
245 return Zmm(irb * regs_used_per_block_ + i);
246}
247
248template <data_type_t d_type>
249Ymm jit_avx512_common_lrn_kernel_bwd_t<d_type>::yreg(int irb, int i) const {
250 return Ymm(irb * regs_used_per_block_ + i);
251}
252
253template <data_type_t d_type>
254Xmm jit_avx512_common_lrn_kernel_bwd_t<d_type>::xreg(int irb, int i) const {
255 return Xmm(irb * regs_used_per_block_ + i);
256}
257
// Explicit instantiations: the generic template members defined in this file
// are emitted for the f32, bf16 and f16 kernels.
template class jit_avx512_common_lrn_kernel_bwd_t<f32>;
template class jit_avx512_common_lrn_kernel_bwd_t<bf16>;
template class jit_avx512_common_lrn_kernel_bwd_t<f16>;
261
262} // namespace lrn
263} // namespace x64
264} // namespace cpu
265} // namespace impl
266} // namespace dnnl
267