1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <numeric>
18#include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_base.hpp"
19
20namespace dnnl {
21namespace impl {
22namespace cpu {
23namespace x64 {
24namespace lrn {
25
26static constexpr int acc_size = sizeof(acc_data_t);
27static constexpr int acc_bf_16_size = sizeof(acc_data_bf16_t);
28
29template <data_type_t d_type>
30const int32_t
31 jit_avx512_common_lrn_kernel_fwd_t<d_type>::jit_args_fwd_t::mask[48]
32 = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 17, 18, 19, 20,
33 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 0, 0, 0, 0, 0, 0, 0,
34 0, 0, 0, 0, 0, 0, 0, 0, 0};
35
36template <data_type_t d_type>
37jit_avx512_common_lrn_kernel_fwd_t<d_type>::jit_args_fwd_t::jit_args_fwd_t()
38 : src(nullptr)
39 , dst(nullptr)
40 , ws0(nullptr)
41 , ws1(nullptr)
42 , mask_ptr(&mask[16]) {}
43
44template <>
45void jit_avx512_common_lrn_kernel_fwd_t<f32>::load_data(
46 Xmm reg, const Address p, bool from_stack) {
47 this->vmovups(reg, p);
48}
49
50template <>
51void jit_avx512_common_lrn_kernel_fwd_t<bf16>::load_data(
52 Xmm reg, const Address p, bool from_stack) {
53 if (!from_stack) {
54 this->vpmovzxwd(reg, p);
55 this->vpslld(reg, reg, 0x10);
56 } else
57 this->vmovups(reg, p);
58}
59
60template <>
61void jit_avx512_common_lrn_kernel_fwd_t<f16>::load_data(
62 Xmm reg, const Address p, bool from_stack) {
63 if (!from_stack) {
64 this->vcvtph2psx(reg, p);
65 } else
66 this->vmovups(reg, p);
67}
68
69template <data_type_t d_type>
70void jit_avx512_common_lrn_kernel_fwd_t<d_type>::load_tail(int tail_value,
71 Reg64 src, int src_mem_offset, int dst_stack_offset,
72 int tmp_load_to_stack_idx_tail) {
73
74 // TODO: Investigate if this method can be simplified by using mask or
75 // jit_generator load utilities.
76 static constexpr auto src_size = sizeof(data_t);
77 auto tmp_xreg = this->xreg(0, tmp_load_to_stack_idx_tail);
78
79 const auto load_tail_simd = [&](Xmm tmp_reg, int vlen) {
80 this->load_data(tmp_reg, this->EVEX_compress_addr(src, src_mem_offset));
81 this->vmovups(this->EVEX_compress_addr(rsp, dst_stack_offset), tmp_reg);
82 dst_stack_offset += vlen * acc_size;
83 src_mem_offset += vlen * src_size;
84 tail_value -= vlen;
85 };
86
87 if (tail_value >= 8)
88 load_tail_simd(this->yreg(0, tmp_load_to_stack_idx_tail), 8);
89 if (tail_value >= 4) load_tail_simd(tmp_xreg, 4);
90
91 for (int i = 0; i < tail_value; ++i) {
92 if (d_type == bf16) {
93 this->movzx(this->imm_addr64_, word[src + src_mem_offset]);
94 this->vmovq(tmp_xreg, this->imm_addr64_);
95 this->vpslld(tmp_xreg, tmp_xreg, 0x10);
96 } else if (d_type == f16) {
97 this->vxorps(tmp_xreg, tmp_xreg, tmp_xreg);
98 this->vcvtsh2ss(tmp_xreg, tmp_xreg,
99 this->EVEX_compress_addr(src, src_mem_offset));
100 } else
101 this->vmovss(
102 tmp_xreg, this->EVEX_compress_addr(src, src_mem_offset));
103
104 this->vmovss(ptr[rsp + dst_stack_offset], tmp_xreg);
105
106 dst_stack_offset += acc_size;
107 src_mem_offset += src_size;
108 }
109}
110
111template <>
112void jit_avx512_common_lrn_kernel_fwd_t<f16>::store_data(
113 const Address addr, Zmm zr, Ymm yr) {
114 this->vcvtps2ph(addr, zr, this->_op_mxcsr);
115}
116
117template <>
118void jit_avx512_common_lrn_kernel_fwd_t<bf16>::store_data(
119 const Address addr, Zmm zr, Ymm yr) {
120 if (emulateBfloat_)
121 this->bf16_emu_->vcvtneps2bf16(yr, zr);
122 else
123 this->vcvtneps2bf16(yr, zr);
124
125 this->vmovdqu16(addr, yr);
126}
127
128template <>
129void jit_avx512_common_lrn_kernel_fwd_t<f32>::store_data(
130 const Address addr, Zmm zr, Ymm yr) {
131 this->vmovups(addr, zr);
132}
133
134template <>
135void jit_avx512_common_lrn_kernel_fwd_t<f32>::store_tail(int tail_value,
136 Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset,
137 int tmp_idx) {
138
139 this->store_data(this->EVEX_compress_addr(rsp, tmp_stack_offset), src,
140 this->yreg(0, tmp_idx));
141
142 const auto store_tail_simd = [&](Xmm tmp_reg, int vlen) {
143 this->vmovups(tmp_reg, this->EVEX_compress_addr(rsp, tmp_stack_offset));
144 this->vmovups(this->EVEX_compress_addr(dst, dst_mem_offset), tmp_reg);
145 tmp_stack_offset += vlen * acc_size;
146 dst_mem_offset += vlen * acc_size;
147 tail_value -= vlen;
148 };
149
150 if (tail_value >= 8) store_tail_simd(this->yreg(0, tmp_idx), 8);
151 if (tail_value >= 4) store_tail_simd(this->xreg(0, tmp_idx), 4);
152
153 for (int i = 0; i < tail_value;
154 ++i, tmp_stack_offset += acc_size, dst_mem_offset += acc_size) {
155 this->vmovss(this->xreg(0, tmp_idx),
156 this->EVEX_compress_addr(rsp, tmp_stack_offset));
157 this->vmovss(this->EVEX_compress_addr(dst, dst_mem_offset),
158 this->xreg(0, tmp_idx));
159 }
160}
161
162template <>
163void jit_avx512_common_lrn_kernel_fwd_t<bf16>::store_tail(int tail_value,
164 Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset,
165 int tmp_idx) {
166
167 this->store_data(this->EVEX_compress_addr(rsp, tmp_stack_offset), src,
168 this->yreg(0, tmp_idx));
169 const auto res = std::div(tail_value, 4);
170
171 for (int i = 0; i < res.quot; ++i, tmp_stack_offset += 4 * acc_bf_16_size,
172 dst_mem_offset += 4 * acc_bf_16_size) {
173 this->mov(this->imm_addr64_, qword[rsp + tmp_stack_offset]);
174 this->mov(qword[dst + dst_mem_offset], this->imm_addr64_);
175 }
176
177 for (int i = 0; i < res.rem; ++i, tmp_stack_offset += acc_bf_16_size,
178 dst_mem_offset += acc_bf_16_size) {
179 this->mov(this->imm_addr16_, word[rsp + tmp_stack_offset]);
180 this->mov(word[dst + dst_mem_offset], this->imm_addr16_);
181 }
182}
183
184// Copy-paste from above bf16 implementation
185template <>
186void jit_avx512_common_lrn_kernel_fwd_t<f16>::store_tail(int tail_value,
187 Zmm src, Reg64 dst, int dst_mem_offset, int tmp_stack_offset,
188 int tmp_idx) {
189
190 this->store_data(this->EVEX_compress_addr(rsp, tmp_stack_offset), src,
191 this->yreg(0, tmp_idx));
192 const auto res = std::div(tail_value, 4);
193
194 for (int i = 0; i < res.quot; ++i, tmp_stack_offset += 4 * acc_bf_16_size,
195 dst_mem_offset += 4 * acc_bf_16_size) {
196 this->mov(this->imm_addr64_, qword[rsp + tmp_stack_offset]);
197 this->mov(qword[dst + dst_mem_offset], this->imm_addr64_);
198 }
199
200 for (int i = 0; i < res.rem; ++i, tmp_stack_offset += acc_bf_16_size,
201 dst_mem_offset += acc_bf_16_size) {
202 this->mov(this->imm_addr16_, word[rsp + tmp_stack_offset]);
203 this->mov(word[dst + dst_mem_offset], this->imm_addr16_);
204 }
205}
206
207template <data_type_t d_type>
208jit_avx512_common_lrn_kernel_fwd_t<d_type>::jit_avx512_common_lrn_kernel_fwd_t(
209 prop_kind_t prop_kind, float alpha, float beta, float k, int local_size,
210 void *code_ptr, size_t code_size, const char *name)
211 : jit_generator(name, code_ptr, code_size, true, avx512_core_bf16)
212 , pk_(prop_kind)
213 , alpha_(alpha)
214 , beta_(beta)
215 , k_(k)
216 , local_size_ {local_size - !(local_size % 2)}
217 , z_prev_ {[this]() {
218 std::vector<int> v(this->local_size_ / 2);
219 std::iota(v.begin(), v.end(), 3);
220 return v;
221 }()}
222 , z_next_ {[this]() {
223 std::vector<int> v(this->local_size_ / 2);
224 std::iota(v.begin(), v.end(), 3 + this->local_size_ / 2);
225 return v;
226 }()}
227 , zsum_ {std::max(local_size_ + 2, 6)}
228 , emulateBfloat_(d_type == bf16 && !mayiuse(avx512_core_bf16))
229 , regs_used_per_block_ {std::max(this->local_size_ + 2, 6)}
230 , reg_block_ {[this]() {
231 const int max_possible_reg_block
232 = (emulateBfloat_ ? 26 : 30) / this->regs_used_per_block_;
233 return mayiuse(avx512_core) ? max_possible_reg_block
234 : std::min(max_possible_reg_block, 2);
235 }()} {
236 if (emulateBfloat_) {
237 bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
238 bf16_emu_reserv_1_, bf16_emu_reserv_2_, bf16_emu_reserv_3_,
239 bf16_emu_scratch_, bf16_emu_reserv_4_);
240 }
241}
242
243template <data_type_t d_type>
244Zmm jit_avx512_common_lrn_kernel_fwd_t<d_type>::zreg(int irb, int i) const {
245 return Zmm(irb * regs_used_per_block_ + i);
246}
247
248template <data_type_t d_type>
249Ymm jit_avx512_common_lrn_kernel_fwd_t<d_type>::yreg(int irb, int i) const {
250 return Ymm(irb * regs_used_per_block_ + i);
251}
252
253template <data_type_t d_type>
254Xmm jit_avx512_common_lrn_kernel_fwd_t<d_type>::xreg(int irb, int i) const {
255 return Xmm(irb * regs_used_per_block_ + i);
256}
257
258template class jit_avx512_common_lrn_kernel_fwd_t<f32>;
259template class jit_avx512_common_lrn_kernel_fwd_t<bf16>;
260template class jit_avx512_common_lrn_kernel_fwd_t<f16>;
261
262} // namespace lrn
263} // namespace x64
264} // namespace cpu
265} // namespace impl
266} // namespace dnnl
267