1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_blocked.hpp" |
17 | |
18 | namespace dnnl { |
19 | namespace impl { |
20 | namespace cpu { |
21 | namespace x64 { |
22 | namespace lrn { |
23 | |
24 | template <data_type_t d_type> |
25 | jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>:: |
26 | jit_avx512_common_lrn_kernel_fwd_blocked_t( |
27 | const struct nChw16c_across_t &J, prop_kind_t prop_kind, |
28 | int use_h_parallel, float alpha, float beta, float k, |
29 | int local_size, void *code_ptr, size_t code_size) |
30 | : jit_avx512_common_lrn_kernel_fwd_t<d_type>(prop_kind, alpha, beta, k, |
31 | local_size, code_ptr, code_size, jit_name()) |
32 | , use_h_parallelism_(use_h_parallel) { |
33 | // some registers needed for conversion from bf16 to f32 |
34 | src_prev_offset_ = this->vlen_ - 4 * sizeof(data_t); |
35 | version_ = J.version; |
36 | W_ = J.W; |
37 | HW_ = J.W * J.H; |
38 | xmm_size_ = 4 * sizeof(acc_data_t); |
39 | zmm_size_ = 64; |
40 | buffer_block_ = xmm_size_ + zmm_size_ + xmm_size_; |
41 | buffer_nest_offset_ = xmm_size_ + zmm_size_; |
42 | } |
43 | |
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>::generate() {
    // Emits the forward LRN kernel for the nChw16c (blocked) layout.
    this->preamble();
    // The bf16 emulation helper (used on ISAs without native bf16 support)
    // must initialize its auxiliary registers before any conversion is
    // emitted.
    if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();

// Byte offset of `field` inside the runtime argument struct pointed to by
// param_.
#define GET_OFF(field) \
    offsetof(typename jit_avx512_common_lrn_kernel_fwd_t< \
                     d_type>::jit_args_fwd_t, \
            field)
    this->mov(this->src_, ptr[this->param_ + GET_OFF(src)]);
    this->mov(this->dst_, ptr[this->param_ + GET_OFF(dst)]);
    // Workspace tensors are only written when training (they are consumed by
    // the backward pass).
    if (this->pk_ != prop_kind::forward_inference) {
        this->mov(this->ws0_, ptr[this->param_ + GET_OFF(ws0)]);
        this->mov(this->ws1_, ptr[this->param_ + GET_OFF(ws1)]);
    }
#undef GET_OFF

    // Number of 16-channel vectors to process: one row (W) when rows are
    // distributed across threads, otherwise the whole H*W plane.
    int LSB = use_h_parallelism_ ? W_ : HW_;

    // Reserve reg_block_ scratch blocks below t_ (released before
    // postamble).  NOTE(review): t_ appears to be a stack-area pointer set
    // up elsewhere — confirm against the class definition.
    this->sub(t_, this->reg_block_ * buffer_block_);
    // Broadcast alpha and k into full zmm registers for the vectorized
    // normalization arithmetic.
    this->mov(this->imm_addr64_, float2int(this->alpha_));
    this->vmovq(this->xalpha_, this->imm_addr64_);
    this->vbroadcastss(this->zalpha_, this->xalpha_);

    this->mov(this->imm_addr64_, float2int(this->k_));
    this->vmovq(this->xk_, this->imm_addr64_);
    this->vbroadcastss(this->zk_, this->xk_);

    // At the first (or only) channel block there is no preceding block, so
    // zero the "previous" padding slot of every scratch block.
    if (version_ == across_version::First
            || version_ == across_version::Single) {
        this->uni_vpxor(xmm2, xmm2, xmm2);
        for (int irb = 0; irb < this->reg_block_; irb++) {
            this->vmovups(ptr[t_ + irb * buffer_block_], xmm2);
        }
    }
    // Likewise zero the "next" padding slot at the last (or only) block.
    if (version_ == across_version::Last
            || version_ == across_version::Single) {
        this->uni_vpxor(xmm2, xmm2, xmm2);
        for (int irb = 0; irb < this->reg_block_; irb++) {
            this->vmovups(
                    ptr[t_ + irb * buffer_block_ + buffer_nest_offset_], xmm2);
        }
    }

    // Split the work into a main loop handling reg_block_ vectors per
    // iteration plus a remainder handled by a single compute_loop call.
    const int LSREST = LSB % this->reg_block_;
    const int LS = LSB - LSREST;

    Label lrn_loop;

    if (LS > 0) {
        this->mov(hw_, LS);

        this->L(lrn_loop);
        {
            compute_loop(this->reg_block_);

            // Advance all tensor pointers past the vectors just processed.
            this->add(this->src_, this->reg_block_ * this->vlen_);
            this->add(this->dst_, this->reg_block_ * this->vlen_);
            if (this->pk_ != prop_kind::forward_inference) {
                this->add(this->ws0_, this->reg_block_ * this->vlen_);
                this->add(this->ws1_, this->reg_block_ * this->vlen_);
            }

            // Emits reg_block_ dec instructions — equivalent to subtracting
            // reg_block_ from the remaining-vector counter.
            for (int irb = 0; irb < this->reg_block_; irb++)
                this->dec(hw_);
            this->cmp(hw_, 0);
            this->jne(lrn_loop, this->T_NEAR);
        }
    }

    // Tail: fewer than reg_block_ vectors (no-op when LSREST == 0).
    compute_loop(LSREST);

    // Release the scratch area reserved above.
    this->add(t_, this->reg_block_ * buffer_block_);
    this->postamble();
}
119 | |
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>::compute_loop(
        int loop_size_param) {
    // Emits code that processes `loop_size_param` 16-channel vectors:
    // loads them (plus 4-element halos from the neighboring channel
    // blocks), computes the across-channel LRN sum of squares, and stores
    // dst (and ws0/ws1 when training).
    // loop_size - param for IRB_LOOP macro
    int loop_size = loop_size_param;

    if (loop_size == 0) return;

    // --- loading source data to special buffer to form convenient data layout
    // for ACROSS lrn ---
    // Unless this is the first/only channel block, fetch the last 4
    // elements of the previous block (located HW_ vectors behind in
    // nChw16c layout).
    if (version_ != across_version::First
            && version_ != across_version::Single) {
        IRB_LOOP(this->load_data(this->xreg(irb, xsrc_prev_),
                ptr[this->src_ + (irb - HW_) * this->vlen_
                        + src_prev_offset_]));
    }
    IRB_LOOP(this->load_data(this->zreg(irb, this->zsrc_),
            this->EVEX_compress_addr(this->src_, irb * this->vlen_)));
    // Unless this is the last/only channel block, fetch the first 4
    // elements of the next block (HW_ vectors ahead).
    if (version_ != across_version::Last
            && version_ != across_version::Single) {
        IRB_LOOP(this->load_data(this->xreg(irb, xsrc_next_),
                ptr[this->src_ + (irb + HW_) * this->vlen_]));
    }

    // Spill into the scratch buffer so each block reads as a contiguous
    // [prev halo | 16 src values | next halo] run; the edge halos were
    // pre-zeroed in generate().
    if (version_ != across_version::First
            && version_ != across_version::Single) {
        IRB_LOOP(this->vmovups(
                ptr[t_ + irb * buffer_block_], this->xreg(irb, xsrc_prev_)));
    }
    IRB_LOOP(this->vmovups(
            this->EVEX_compress_addr(t_, irb * buffer_block_ + xmm_size_),
            this->zreg(irb, this->zsrc_)));
    if (version_ != across_version::Last
            && version_ != across_version::Single) {
        IRB_LOOP(this->vmovups(
                ptr[t_ + irb * buffer_block_ + buffer_nest_offset_],
                this->xreg(irb, xsrc_next_)));
    }

    // --- perform ACROSS lrn ---
    // Reload the same data shifted by +-1 and +-2 channels (unaligned loads
    // from the scratch buffer) to get the 5-element local window.
    const size_t acc_size = sizeof(acc_data_t);
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_prev_[0]),
            this->EVEX_compress_addr(
                    t_, irb * buffer_block_ + xmm_size_ - 2 * acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_prev_[1]),
            this->EVEX_compress_addr(
                    t_, irb * buffer_block_ + xmm_size_ - acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_next_[0]),
            this->EVEX_compress_addr(
                    t_, irb * buffer_block_ + xmm_size_ + acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_next_[1]),
            this->EVEX_compress_addr(
                    t_, irb * buffer_block_ + xmm_size_ + 2 * acc_size)));

    // zsum = sum of squares over the 5-wide window: center^2 first, then
    // fused multiply-adds of the four neighbors.
    assert(this->zc_ == this->zsrc_);
    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->zc_), this->zreg(irb, this->zc_)));

    IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->z_prev_[0]),
            this->zreg(irb, this->z_prev_[0])));
    IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->z_prev_[1]),
            this->zreg(irb, this->z_prev_[1])));
    IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->z_next_[0]),
            this->zreg(irb, this->z_next_[0])));
    IRB_LOOP(this->vfmadd231ps(this->zreg(irb, this->zsum_),
            this->zreg(irb, this->z_next_[1]),
            this->zreg(irb, this->z_next_[1])));

    // zsum = zsum * alpha + k  (vfmadd132ps: dst = dst * src2 + src1).
    IRB_LOOP(this->vfmadd132ps(
            this->zreg(irb, this->zsum_), this->zk_, this->zalpha_));

    // Keep the raw base value for the backward pass (ws1 below).
    IRB_LOOP(this->vmovaps(
            this->zreg(irb, this->zbase_), this->zreg(irb, this->zsum_)));

    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum2_),
            this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum_)));

    // beta != 1 path hardcodes beta = 0.75: base * base^2 = base^3, then
    // two square roots yield base^(3/4).  For beta == 1, zsum stays base.
    if (this->beta_ != 1) {
        IRB_LOOP(this->vmulps(this->zreg(irb, this->zsum_),
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum2_)));

        IRB_LOOP(this->vsqrtps(
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum_)));
        IRB_LOOP(this->vsqrtps(
                this->zreg(irb, this->zsum_), this->zreg(irb, this->zsum_)));
    }

    // zsum2_ is dead past this point, so its register doubles as scratch
    // for the f32->bf16 down-conversion in store_data.
    const int ytmp = this->zsum2_; // temporary ymm for f32->bf16 conversion
    if (this->pk_ != prop_kind::forward_inference) {
        // save intermediate results for lrn backward
        IRB_LOOP(this->store_data(
                this->EVEX_compress_addr(this->ws0_, irb * this->vlen_),
                this->zreg(irb, this->zsum_), this->yreg(irb, ytmp)));
    }
    // dst = src / base^beta
    IRB_LOOP(this->vdivps(this->zreg(irb, this->zdst_),
            this->zreg(irb, this->zsrc_), this->zreg(irb, this->zsum_)));
    // storing to dst
    IRB_LOOP(this->store_data(
            this->EVEX_compress_addr(this->dst_, irb * this->vlen_),
            this->zreg(irb, this->zdst_), this->yreg(irb, ytmp)));
    if (this->pk_ != prop_kind::forward_inference) {
        // calculate and save more intermediate results for lrn backward
        /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */
        IRB_LOOP(this->vdivps(this->zreg(irb, this->zsum_),
                this->zreg(irb, this->zdst_), this->zreg(irb, this->zbase_)));
        IRB_LOOP(this->store_data(
                this->EVEX_compress_addr(this->ws1_, irb * this->vlen_),
                this->zreg(irb, this->zsum_), this->yreg(irb, ytmp)));
    }
}
233 | |
// Explicit instantiations for every data type this kernel supports.
template class jit_avx512_common_lrn_kernel_fwd_blocked_t<f32>;
template class jit_avx512_common_lrn_kernel_fwd_blocked_t<bf16>;
template class jit_avx512_common_lrn_kernel_fwd_blocked_t<f16>;
237 | |
238 | } // namespace lrn |
239 | } // namespace x64 |
240 | } // namespace cpu |
241 | } // namespace impl |
242 | } // namespace dnnl |
243 | |