1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_blocked.hpp" |
17 | |
18 | namespace dnnl { |
19 | namespace impl { |
20 | namespace cpu { |
21 | namespace x64 { |
22 | namespace lrn { |
23 | |
24 | using acc_data_t = float; |
25 | |
// Constructor: captures the spatial geometry of the nChw16c blocked layout
// and precomputes the per-register-block scratch-buffer layout used by
// generate()/compute_loop().
//
// J              - geometry descriptor (H, W, and the across-channel position:
//                  First/Middle/Last/Single).
// alpha, beta    - LRN hyper-parameters, forwarded to the common bwd base.
// local_size     - LRN window size, forwarded to the base.
// use_h_parallel - nonzero when rows (H) are parallelized, so each kernel
//                  invocation covers W elements instead of H*W.
// code_ptr/code_size - forwarded to the base JIT generator (code placement).
template <data_type_t d_type>
jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>::
        jit_avx512_common_lrn_kernel_bwd_blocked_t(
                const struct nChw16c_across_t &J, float alpha, float beta,
                int local_size, int use_h_parallel, void *code_ptr,
                size_t code_size)
    : jit_avx512_common_lrn_kernel_bwd_t<d_type>(
            alpha, beta, local_size, code_ptr, code_size, jit_name())
    // One xmm (4 floats) of halo on each side of a full zmm: the stack
    // scratch line per register block is [xmm pad | zmm | xmm pad].
    , xmm_size_ {4 * sizeof(acc_data_t)}
    , zmm_size_ {64}
    , buffer_block_ {xmm_size_ + zmm_size_ + xmm_size_}
    // Offset of the trailing xmm pad within one scratch line.
    , buffer_nest_offset_ {xmm_size_ + zmm_size_}
    // Offset of the last 4 elements of a vector row; NOTE(review): for bf16
    // data_t this is vlen_ - 8 bytes, not a 4-float tail — presumably
    // intentional since loads are widened to f32; confirm against load_data.
    , src_prev_offset_ {static_cast<int>(this->vlen_ - 4 * sizeof(data_t))}
    , use_h_parallelism_(use_h_parallel) {
    W_ = J.W;
    HW_ = J.H * J.W;
    version_ = J.version; // position across channels: First/Middle/Last/Single
}
44 | |
// Emits the backward-LRN kernel body: loads the five kernel arguments,
// carves a per-register-block scratch area off the stack, zeroes the halo
// pads when there is no neighboring channel block, then emits the main
// unrolled loop (reg_block_ vectors per iteration) plus a remainder call.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>::generate() {

    this->preamble();
    // bf16 emulation (pre-Cooper Lake) needs its helper constants set up
    // before any down-conversion sequence is emitted.
    if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();

// Fetch kernel arguments from the jit_args_bwd_t struct pointed to by param_.
#define GET_OFF(field) offsetof(jit_args_bwd_t, field)
    this->mov(this->src_, ptr[this->param_ + GET_OFF(src)]);
    this->mov(this->diffdst_, ptr[this->param_ + GET_OFF(diff_dst)]);
    this->mov(this->workspace0_, ptr[this->param_ + GET_OFF(ws0)]);
    this->mov(this->workspace1_, ptr[this->param_ + GET_OFF(ws1)]);
    this->mov(this->diffsrc_, ptr[this->param_ + GET_OFF(diff_src)]);
#undef GET_OFF

    // Loop trip count: one row (W) when rows are parallelized across
    // threads, otherwise the whole plane (H*W).
    int LSB = this->use_h_parallelism_ ? W_ : HW_;

    // Stack scratch: reg_block_ lines of [xmm pad | zmm | xmm pad]
    // (see buffer_block_ in the constructor). Released before postamble.
    this->sub(this->rsp, this->reg_block_ * buffer_block_);
    // Broadcast -2*alpha*beta into znalphabeta_ via a GPR->xmm->zmm hop.
    // NOTE(review): nalphabeta_ is computed in the base class — presumably
    // -2*alpha*beta per the LRN bwd formula; confirm there.
    this->mov(this->imm_addr64_, float2int(this->nalphabeta_));
    this->vmovq(this->xnalphabeta_, this->imm_addr64_);
    this->vbroadcastss(this->znalphabeta_, this->xnalphabeta_);

    // No channel block before this one: zero the leading halo pad of every
    // scratch line so compute_loop reads zeros for the "previous" neighbors.
    if (version_ == across_version::First
            || version_ == across_version::Single) {
        this->uni_vpxor(xmm1, xmm1, xmm1);
        for (int irb = 0; irb < this->reg_block_; irb++) {
            this->vmovups(ptr[this->rsp + irb * buffer_block_], xmm1);
        }
    }
    // No channel block after this one: zero the trailing halo pad likewise.
    if (version_ == across_version::Last
            || version_ == across_version::Single) {
        this->uni_vpxor(xmm1, xmm1, xmm1);
        for (int irb = 0; irb < this->reg_block_; irb++) {
            this->vmovups(
                    ptr[this->rsp + irb * buffer_block_ + buffer_nest_offset_],
                    xmm1);
        }
    }

    // Split the trip count into a multiple of reg_block_ (main loop) and a
    // remainder handled by a single tail call below.
    int LSREST = LSB % this->reg_block_;
    int LS = LSB - LSREST;

    Label lrn_loop;

    if (LS > 0) {
        this->mov(hw_, LS);

        this->L(lrn_loop);
        {
            // Process reg_block_ 16-channel vectors per iteration.
            compute_loop(this->reg_block_);

            // Advance all five pointers by reg_block_ vectors.
            this->add(this->src_, this->reg_block_ * this->vlen_);
            this->add(this->diffsrc_, this->reg_block_ * this->vlen_);
            this->add(this->diffdst_, this->reg_block_ * this->vlen_);
            this->add(this->workspace0_, this->reg_block_ * this->vlen_);
            this->add(this->workspace1_, this->reg_block_ * this->vlen_);

            // Decrement the counter by reg_block_ (emitted as reg_block_
            // `dec`s) and loop until it reaches zero.
            for (int irb = 0; irb < this->reg_block_; irb++)
                this->dec(hw_);
            this->cmp(hw_, 0);
            this->jne(lrn_loop, this->T_NEAR);
        }
    }

    // Tail: remaining LSB % reg_block_ vectors (no-op when LSREST == 0,
    // compute_loop returns immediately on 0).
    compute_loop(LSREST);

    this->add(this->rsp, this->reg_block_ * buffer_block_);
    this->postamble();
}
113 | |
// Emits the per-iteration body: for `loop_size_param` register blocks,
// computes diff_src = src * (-2*alpha*beta) * sum_window(diffdst*ws1)
//                     + diffdst / ws0
// using the stack scratch lines to stitch cross-channel-block halos.
// NOTE(review): the exact bwd formula lives in ws0/ws1 produced by the fwd
// pass — the algebra above is inferred from the instruction sequence
// (vmulps/vaddps/vdivps/vfmadd213ps); confirm against the fwd kernel.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>::compute_loop(
        int loop_size_param) {
    // loop_size - this->param_ for IRB_LOOP macro
    // (IRB_LOOP is defined outside this file; it repeats its statement for
    // irb in [0, loop_size).)
    int loop_size = loop_size_param;

    // Tail call with an empty remainder: emit nothing.
    if (loop_size_param == 0) return;

    // Middle/Last blocks have a real previous channel block: load the last
    // 4 lanes (src_prev_offset_) of the previous block's ws1 and diffdst
    // (ws1 lives 2*HW_ vectors back, diffdst HW_ vectors back — the
    // workspace holds two planes per block, presumably; TODO confirm),
    // and form their product (the "previous" halo contribution).
    if (version_ != across_version::First
            && version_ != across_version::Single) {
        IRB_LOOP(this->load_data(this->xreg(irb, xws1_prev_),
                ptr[this->workspace1_ + (irb - 2 * HW_) * this->vlen_
                        + src_prev_offset_]));
        IRB_LOOP(this->load_data(this->xreg(irb, xdiffdst_prev_),
                ptr[this->diffdst_ + (irb - HW_) * this->vlen_
                        + src_prev_offset_]));
        IRB_LOOP(this->vmulps(this->xreg(irb, xdiffdst_prev_),
                this->xreg(irb, xdiffdst_prev_), this->xreg(irb, xws1_prev_)));
    }

    // Center: zdiffsrc = diffdst * ws1 for the current block.
    IRB_LOOP(this->load_data(this->zreg(irb, zws1_),
            this->EVEX_compress_addr(this->workspace1_, irb * this->vlen_)));
    IRB_LOOP(this->load_data(this->zreg(irb, this->zdiffdst_),
            this->EVEX_compress_addr(this->diffdst_, irb * this->vlen_)));
    IRB_LOOP(this->vmulps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffdst_), this->zreg(irb, zws1_)));

    // Middle/First blocks have a real next channel block: load the first
    // 4 lanes of the next block's ws1/diffdst and form their product.
    if (version_ != across_version::Last
            && version_ != across_version::Single) {
        IRB_LOOP(this->load_data(this->xreg(irb, xws1_next_),
                ptr[this->workspace1_ + (irb + 2 * HW_) * this->vlen_]));
        IRB_LOOP(this->load_data(this->xreg(irb, xdiffdst_next_),
                ptr[this->diffdst_ + (irb + HW_) * this->vlen_]));
        IRB_LOOP(this->vmulps(this->xreg(irb, xdiffdst_next_),
                this->xreg(irb, xdiffdst_next_), this->xreg(irb, xws1_next_)));
    }

    // Spill the halos and the center product to the stack scratch line
    // [prev xmm | center zmm | next xmm] so the shifted reloads below can
    // read across the xmm/zmm boundary. When a halo is absent the pad was
    // pre-zeroed in generate().
    if (version_ != across_version::First
            && version_ != across_version::Single) {
        IRB_LOOP(this->vmovups(ptr[this->rsp + irb * buffer_block_],
                this->xreg(irb, xdiffdst_prev_)));
    }
    IRB_LOOP(this->vmovups(this->EVEX_compress_addr(
                                   this->rsp, irb * buffer_block_ + xmm_size_),
            this->zreg(irb, this->zdiffsrc_)));
    if (version_ != across_version::Last
            && version_ != across_version::Single) {
        IRB_LOOP(this->vmovups(
                ptr[this->rsp + irb * buffer_block_ + buffer_nest_offset_],
                this->xreg(irb, xdiffdst_next_)));
    }
    // Reload the center vector shifted by -2, -1, +1, +2 channels (one
    // acc_data_t = one channel lane) to gather the 5-point LRN window.
    size_t acc_size = sizeof(acc_data_t);
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_prev_[0]),
            this->EVEX_compress_addr(this->rsp,
                    irb * buffer_block_ + xmm_size_ - 2 * acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_prev_[1]),
            this->EVEX_compress_addr(this->rsp,
                    irb * buffer_block_ + xmm_size_ - 1 * acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_next_[0]),
            this->EVEX_compress_addr(this->rsp,
                    irb * buffer_block_ + xmm_size_ + 1 * acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_next_[1]),
            this->EVEX_compress_addr(this->rsp,
                    irb * buffer_block_ + xmm_size_ + 2 * acc_size)));
    // Accumulate the window. z_prev_[0] is consumed first because zsrc_
    // aliases it (asserted below), letting the src load reuse the register.
    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_prev_[0])));
    assert(this->zsrc_ == this->z_prev_[0]);
    IRB_LOOP(this->load_data(this->zreg(irb, this->zsrc_),
            this->EVEX_compress_addr(this->src_, irb * this->vlen_)));
    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_prev_[1])));
    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_next_[0])));
    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_next_[1])));
    // zsrc *= -2*alpha*beta (broadcast in generate()).
    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsrc_),
            this->zreg(irb, this->zsrc_), this->znalphabeta_));

    // zdiffdst /= ws0, then fused: zdiffsrc = zsrc * zdiffsrc + zdiffdst.
    IRB_LOOP(this->load_data(this->zreg(irb, this->zws0_),
            this->EVEX_compress_addr(this->workspace0_, irb * this->vlen_)));
    IRB_LOOP(this->vdivps(this->zreg(irb, this->zdiffdst_),
            this->zreg(irb, this->zdiffdst_), this->zreg(irb, this->zws0_)));
    IRB_LOOP(this->vfmadd213ps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zsrc_), this->zreg(irb, this->zdiffdst_)));

    // Store results: runtime-dispatch on diffsrc_ alignment so the aligned
    // path can use aligned (faster / fault-checking) stores.
    Label unaligned_store, end_store;
    this->test(this->diffsrc_, this->vlen_ - 1);
    this->jnz(unaligned_store, this->T_NEAR);
    IRB_LOOP(this->store_data(true,
            this->EVEX_compress_addr(this->diffsrc_, irb * this->vlen_),
            this->zreg(irb, this->zdiffsrc_)));
    this->jmp(end_store, this->T_NEAR);
    this->L(unaligned_store);
    {
        IRB_LOOP(this->store_data(false,
                this->EVEX_compress_addr(this->diffsrc_, irb * this->vlen_),
                this->zreg(irb, this->zdiffsrc_)));
    }
    this->L(end_store);
}
218 | |
// Explicit instantiations for the supported data types; keeps the template
// definitions in this .cpp instead of the header.
template class jit_avx512_common_lrn_kernel_bwd_blocked_t<f32>;
template class jit_avx512_common_lrn_kernel_bwd_blocked_t<bf16>;
template class jit_avx512_common_lrn_kernel_bwd_blocked_t<f16>;
222 | |
223 | } // namespace lrn |
224 | } // namespace x64 |
225 | } // namespace cpu |
226 | } // namespace impl |
227 | } // namespace dnnl |
228 | |