1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_blocked.hpp"
17
18namespace dnnl {
19namespace impl {
20namespace cpu {
21namespace x64 {
22namespace lrn {
23
24template <data_type_t d_type>
25jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>::
26 jit_avx512_common_lrn_kernel_fwd_blocked_t(
27 const struct nChw16c_across_t &J, prop_kind_t prop_kind,
28 int use_h_parallel, float alpha, float beta, float k,
29 int local_size, void *code_ptr, size_t code_size)
30 : jit_avx512_common_lrn_kernel_fwd_t<d_type>(prop_kind, alpha, beta, k,
31 local_size, code_ptr, code_size, jit_name())
32 , use_h_parallelism_(use_h_parallel) {
33 // some registers needed for conversion from bf16 to f32
34 src_prev_offset_ = this->vlen_ - 4 * sizeof(data_t);
35 version_ = J.version;
36 W_ = J.W;
37 HW_ = J.W * J.H;
38 xmm_size_ = 4 * sizeof(acc_data_t);
39 zmm_size_ = 64;
40 buffer_block_ = xmm_size_ + zmm_size_ + xmm_size_;
41 buffer_nest_offset_ = xmm_size_ + zmm_size_;
42}
43
44template <data_type_t d_type>
45void jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>::generate() {
46 this->preamble();
47 if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
48
49#define GET_OFF(field) \
50 offsetof(typename jit_avx512_common_lrn_kernel_fwd_t< \
51 d_type>::jit_args_fwd_t, \
52 field)
53 this->mov(this->src_, ptr[this->param_ + GET_OFF(src)]);
54 this->mov(this->dst_, ptr[this->param_ + GET_OFF(dst)]);
55 if (this->pk_ != prop_kind::forward_inference) {
56 this->mov(this->ws0_, ptr[this->param_ + GET_OFF(ws0)]);
57 this->mov(this->ws1_, ptr[this->param_ + GET_OFF(ws1)]);
58 }
59#undef GET_OFF
60
61 int LSB = use_h_parallelism_ ? W_ : HW_;
62
63 this->sub(t_, this->reg_block_ * buffer_block_);
64 this->mov(this->imm_addr64_, float2int(this->alpha_));
65 this->vmovq(this->xalpha_, this->imm_addr64_);
66 this->vbroadcastss(this->zalpha_, this->xalpha_);
67
68 this->mov(this->imm_addr64_, float2int(this->k_));
69 this->vmovq(this->xk_, this->imm_addr64_);
70 this->vbroadcastss(this->zk_, this->xk_);
71
72 if (version_ == across_version::First
73 || version_ == across_version::Single) {
74 this->uni_vpxor(xmm2, xmm2, xmm2);
75 for (int irb = 0; irb < this->reg_block_; irb++) {
76 this->vmovups(ptr[t_ + irb * buffer_block_], xmm2);
77 }
78 }
79 if (version_ == across_version::Last
80 || version_ == across_version::Single) {
81 this->uni_vpxor(xmm2, xmm2, xmm2);
82 for (int irb = 0; irb < this->reg_block_; irb++) {
83 this->vmovups(
84 ptr[t_ + irb * buffer_block_ + buffer_nest_offset_], xmm2);
85 }
86 }
87
88 const int LSREST = LSB % this->reg_block_;
89 const int LS = LSB - LSREST;
90
91 Label lrn_loop;
92
93 if (LS > 0) {
94 this->mov(hw_, LS);
95
96 this->L(lrn_loop);
97 {
98 compute_loop(this->reg_block_);
99
100 this->add(this->src_, this->reg_block_ * this->vlen_);
101 this->add(this->dst_, this->reg_block_ * this->vlen_);
102 if (this->pk_ != prop_kind::forward_inference) {
103 this->add(this->ws0_, this->reg_block_ * this->vlen_);
104 this->add(this->ws1_, this->reg_block_ * this->vlen_);
105 }
106
107 for (int irb = 0; irb < this->reg_block_; irb++)
108 this->dec(hw_);
109 this->cmp(hw_, 0);
110 this->jne(lrn_loop, this->T_NEAR);
111 }
112 }
113
114 compute_loop(LSREST);
115
116 this->add(t_, this->reg_block_ * buffer_block_);
117 this->postamble();
118}
119
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>::compute_loop(
        int loop_size_param) {
    // Emits the code that processes `loop_size_param` consecutive 16-channel
    // vectors: stages src (plus the across-channel neighbors from the
    // previous/next 16c block when present) into the stack buffer at t_,
    // computes the LRN normalization, and stores dst (and ws0/ws1 for
    // training).
    //
    // loop_size - param for IRB_LOOP macro; IRB_LOOP presumably repeats its
    // statement for irb in [0, loop_size) — defined in the header, confirm
    // there.
    int loop_size = loop_size_param;

    if (loop_size == 0) return;

    // --- loading source data to special buffer to form convenient data layout
    // for ACROSS lrn ---
    // Previous channel block lives HW_ vectors behind in nChw16c; only its
    // last 4 elements (src_prev_offset_) are needed as left neighbors.
    if (version_ != across_version::First
            && version_ != across_version::Single) {
        IRB_LOOP(this->load_data(this->xreg(irb, xsrc_prev_),
                ptr[this->src_ + (irb - HW_) * this->vlen_
                        + src_prev_offset_]));
    }
    // The current full vector of 16 channels.
    IRB_LOOP(this->load_data(this->zreg(irb, this->zsrc_),
            this->EVEX_compress_addr(this->src_, irb * this->vlen_)));
    // First 4 elements of the next channel block (right neighbors).
    if (version_ != across_version::Last
            && version_ != across_version::Single) {
        IRB_LOOP(this->load_data(this->xreg(irb, xsrc_next_),
                ptr[this->src_ + (irb + HW_) * this->vlen_]));
    }

    // Spill the three pieces contiguously into the stack buffer:
    // [prev tail (xmm)][current (zmm)][next head (xmm)]. Missing edges were
    // pre-zeroed by generate().
    if (version_ != across_version::First
            && version_ != across_version::Single) {
        IRB_LOOP(this->vmovups(
                ptr[t_ + irb * buffer_block_], this->xreg(irb, xsrc_prev_)));
    }
    IRB_LOOP(this->vmovups(
            this->EVEX_compress_addr(t_, irb * buffer_block_ + xmm_size_),
            this->zreg(irb, this->zsrc_)));
    if (version_ != across_version::Last
            && version_ != across_version::Single) {
        IRB_LOOP(this->vmovups(
                ptr[t_ + irb * buffer_block_ + buffer_nest_offset_],
                this->xreg(irb, xsrc_next_)));
    }

    // --- perform ACROSS lrn ---
    // Reload the buffer at byte offsets -2..-1 and +1..+2 elements relative
    // to the current vector, giving the four shifted neighbor vectors of the
    // 5-point local-size window.
    const size_t acc_size = sizeof(acc_data_t);
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_prev_[0]),
            this->EVEX_compress_addr(
                    t_, irb * buffer_block_ + xmm_size_ - 2 * acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_prev_[1]),
            this->EVEX_compress_addr(
                    t_, irb * buffer_block_ + xmm_size_ - acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_next_[0]),
            this->EVEX_compress_addr(
                    t_, irb * buffer_block_ + xmm_size_ + acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_next_[1]),
            this->EVEX_compress_addr(
                    t_, irb * buffer_block_ + xmm_size_ + 2 * acc_size)));

    // zc_ aliases zsrc_, so squaring zc_ starts the sum with the center term.
    assert(this->zc_ == this->zsrc_);
    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->zc_), this->zreg(irb, this->zc_)));

    // zsum += x_i^2 for the four across-channel neighbors.
    IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->z_prev_[0]),
            this->zreg(irb, this->z_prev_[0])));
    IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->z_prev_[1]),
            this->zreg(irb, this->z_prev_[1])));
    IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->z_next_[0]),
            this->zreg(irb, this->z_next_[0])));
    IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->z_next_[1]),
            this->zreg(irb, this->z_next_[1])));

    // zsum = k + alpha * sum(x_i^2)  (the normalization base).
    IRB_LOOP(this->vfmadd132ps(
            this->zreg(irb, this->zsum_), this->zk_, this->zalpha_));

    // Keep the raw base: needed below for the ws1 workspace.
    IRB_LOOP(this->vmovaps(
            this->zreg(irb, this->zbase_), this->zreg(irb, this->zsum_)));

    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum2_),
            this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum_)));

    // beta != 1 path: zsum = sqrt(sqrt(base^3)) = base^0.75, i.e. this branch
    // hardcodes beta == 0.75 (consistent with the zbase^1.75 comment below).
    if (this->beta_ != 1) {
        IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum_),
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum2_)));

        IRB_LOOP(this->vsqrtps(
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum_)));
        IRB_LOOP(this->vsqrtps(
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum_)));
    }

    // zsum2_ is dead past this point, so its register doubles as the
    // temporary for the f32->bf16 down-conversion inside store_data.
    const int ytmp = this->zsum2_; // temporary ymm for f32->bf16 conversion
    if (this->pk_ != prop_kind::forward_inference) {
        // save intermediate results for lrn backward: ws0 = base^beta
        IRB_LOOP(this->store_data(
                this->EVEX_compress_addr(this->ws0_, irb * this->vlen_),
                this->zreg(irb, this->zsum_), this->yreg(irb, ytmp)));
    }
    // dst = src / base^beta
    IRB_LOOP(this->vdivps(this->zreg(irb, this->zdst_),
            this->zreg(irb, this->zsrc_), this->zreg(irb, this->zsum_)));
    // storing to dst
    IRB_LOOP(this->store_data(
            this->EVEX_compress_addr(this->dst_, irb * this->vlen_),
            this->zreg(irb, this->zdst_), this->yreg(irb, ytmp)));
    if (this->pk_ != prop_kind::forward_inference) {
        // calculate and save more intermediate results for lrn backward
        /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */
        IRB_LOOP(this->vdivps(this->zreg(irb, this->zsum_),
                this->zreg(irb, this->zdst_), this->zreg(irb, this->zbase_)));
        IRB_LOOP(this->store_data(
                this->EVEX_compress_addr(this->ws1_, irb * this->vlen_),
                this->zreg(irb, this->zsum_), this->yreg(irb, ytmp)));
    }
}
233
// Explicit instantiations for every supported source/destination data type.
template class jit_avx512_common_lrn_kernel_fwd_blocked_t<f32>;
template class jit_avx512_common_lrn_kernel_fwd_blocked_t<bf16>;
template class jit_avx512_common_lrn_kernel_fwd_blocked_t<f16>;
237
238} // namespace lrn
239} // namespace x64
240} // namespace cpu
241} // namespace impl
242} // namespace dnnl
243