1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_blocked.hpp"
17
18namespace dnnl {
19namespace impl {
20namespace cpu {
21namespace x64 {
22namespace lrn {
23
// Element type whose size is used to dimension the on-stack scratch slots
// (see xmm_size_ in the constructor) and the shifted reload offsets in
// compute_loop().
using acc_data_t = float;
25
26template <data_type_t d_type>
27jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>::
28 jit_avx512_common_lrn_kernel_bwd_blocked_t(
29 const struct nChw16c_across_t &J, float alpha, float beta,
30 int local_size, int use_h_parallel, void *code_ptr,
31 size_t code_size)
32 : jit_avx512_common_lrn_kernel_bwd_t<d_type>(
33 alpha, beta, local_size, code_ptr, code_size, jit_name())
34 , xmm_size_ {4 * sizeof(acc_data_t)}
35 , zmm_size_ {64}
36 , buffer_block_ {xmm_size_ + zmm_size_ + xmm_size_}
37 , buffer_nest_offset_ {xmm_size_ + zmm_size_}
38 , src_prev_offset_ {static_cast<int>(this->vlen_ - 4 * sizeof(data_t))}
39 , use_h_parallelism_(use_h_parallel) {
40 W_ = J.W;
41 HW_ = J.H * J.W;
42 version_ = J.version;
43}
44
45template <data_type_t d_type>
46void jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>::generate() {
47
48 this->preamble();
49 if (this->bf16_emu_) this->bf16_emu_->init_vcvtneps2bf16();
50
51#define GET_OFF(field) offsetof(jit_args_bwd_t, field)
52 this->mov(this->src_, ptr[this->param_ + GET_OFF(src)]);
53 this->mov(this->diffdst_, ptr[this->param_ + GET_OFF(diff_dst)]);
54 this->mov(this->workspace0_, ptr[this->param_ + GET_OFF(ws0)]);
55 this->mov(this->workspace1_, ptr[this->param_ + GET_OFF(ws1)]);
56 this->mov(this->diffsrc_, ptr[this->param_ + GET_OFF(diff_src)]);
57#undef GET_OFF
58
59 int LSB = this->use_h_parallelism_ ? W_ : HW_;
60
61 this->sub(this->rsp, this->reg_block_ * buffer_block_);
62 this->mov(this->imm_addr64_, float2int(this->nalphabeta_));
63 this->vmovq(this->xnalphabeta_, this->imm_addr64_);
64 this->vbroadcastss(this->znalphabeta_, this->xnalphabeta_);
65
66 if (version_ == across_version::First
67 || version_ == across_version::Single) {
68 this->uni_vpxor(xmm1, xmm1, xmm1);
69 for (int irb = 0; irb < this->reg_block_; irb++) {
70 this->vmovups(ptr[this->rsp + irb * buffer_block_], xmm1);
71 }
72 }
73 if (version_ == across_version::Last
74 || version_ == across_version::Single) {
75 this->uni_vpxor(xmm1, xmm1, xmm1);
76 for (int irb = 0; irb < this->reg_block_; irb++) {
77 this->vmovups(
78 ptr[this->rsp + irb * buffer_block_ + buffer_nest_offset_],
79 xmm1);
80 }
81 }
82
83 int LSREST = LSB % this->reg_block_;
84 int LS = LSB - LSREST;
85
86 Label lrn_loop;
87
88 if (LS > 0) {
89 this->mov(hw_, LS);
90
91 this->L(lrn_loop);
92 {
93 compute_loop(this->reg_block_);
94
95 this->add(this->src_, this->reg_block_ * this->vlen_);
96 this->add(this->diffsrc_, this->reg_block_ * this->vlen_);
97 this->add(this->diffdst_, this->reg_block_ * this->vlen_);
98 this->add(this->workspace0_, this->reg_block_ * this->vlen_);
99 this->add(this->workspace1_, this->reg_block_ * this->vlen_);
100
101 for (int irb = 0; irb < this->reg_block_; irb++)
102 this->dec(hw_);
103 this->cmp(hw_, 0);
104 this->jne(lrn_loop, this->T_NEAR);
105 }
106 }
107
108 compute_loop(LSREST);
109
110 this->add(this->rsp, this->reg_block_ * buffer_block_);
111 this->postamble();
112}
113
// Emits the code that processes `loop_size_param` consecutive spatial points
// (one vlen_-sized block each) starting at the current tensor pointers.
// Emits nothing when `loop_size_param` is zero.
//
// Outline of the emitted code, per unrolled index irb:
//   1. form the products ws1 * diff_dst for the current block and, depending
//      on version_, for four elements of the preceding / following block;
//   2. stage them in the stack slot [prev xmm | current zmm | next xmm];
//   3. reload the current block shifted by -2, -1, +1 and +2 elements and
//      add all of them to the unshifted product;
//   4. combine: diff_src = (src * nalphabeta) * sum + diff_dst / ws0;
//   5. store diff_src, choosing the store variant at run time from the
//      alignment of the diffsrc_ pointer.
template <data_type_t d_type>
void jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>::compute_loop(
        int loop_size_param) {
    // loop_size is the parameter consumed by the IRB_LOOP macro; the local
    // must keep this exact name.
    int loop_size = loop_size_param;

    if (loop_size_param == 0) return;

    // Preceding-block contribution; skipped for First/Single, whose leading
    // slot part was zeroed once in generate().
    if (version_ != across_version::First
            && version_ != across_version::Single) {
        // Note the different strides: workspace1_ is addressed 2 * HW_ blocks
        // back, diffdst_ HW_ blocks back. src_prev_offset_ selects the last
        // four data_t elements of that block.
        IRB_LOOP(this->load_data(this->xreg(irb, xws1_prev_),
                ptr[this->workspace1_ + (irb - 2 * HW_) * this->vlen_
                        + src_prev_offset_]));
        IRB_LOOP(this->load_data(this->xreg(irb, xdiffdst_prev_),
                ptr[this->diffdst_ + (irb - HW_) * this->vlen_
                        + src_prev_offset_]));
        IRB_LOOP(this->vmulps(this->xreg(irb, xdiffdst_prev_),
                this->xreg(irb, xdiffdst_prev_), this->xreg(irb, xws1_prev_)));
    }

    // Current block: zdiffsrc = diff_dst * ws1.
    IRB_LOOP(this->load_data(this->zreg(irb, zws1_),
            this->EVEX_compress_addr(this->workspace1_, irb * this->vlen_)));
    IRB_LOOP(this->load_data(this->zreg(irb, this->zdiffdst_),
            this->EVEX_compress_addr(this->diffdst_, irb * this->vlen_)));
    IRB_LOOP(this->vmulps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffdst_), this->zreg(irb, zws1_)));

    // Following-block contribution; skipped for Last/Single, whose trailing
    // slot part was zeroed once in generate(). Reads start at the beginning
    // of the following block (no extra byte offset).
    if (version_ != across_version::Last
            && version_ != across_version::Single) {
        IRB_LOOP(this->load_data(this->xreg(irb, xws1_next_),
                ptr[this->workspace1_ + (irb + 2 * HW_) * this->vlen_]));
        IRB_LOOP(this->load_data(this->xreg(irb, xdiffdst_next_),
                ptr[this->diffdst_ + (irb + HW_) * this->vlen_]));
        IRB_LOOP(this->vmulps(this->xreg(irb, xdiffdst_next_),
                this->xreg(irb, xdiffdst_next_), this->xreg(irb, xws1_next_)));
    }

    // Stage the three products in the per-irb stack slot:
    //   offset 0                   : preceding-block part
    //   offset xmm_size_           : current block
    //   offset buffer_nest_offset_ : following-block part
    if (version_ != across_version::First
            && version_ != across_version::Single) {
        IRB_LOOP(this->vmovups(ptr[this->rsp + irb * buffer_block_],
                this->xreg(irb, xdiffdst_prev_)));
    }
    IRB_LOOP(this->vmovups(this->EVEX_compress_addr(
                                   this->rsp, irb * buffer_block_ + xmm_size_),
            this->zreg(irb, this->zdiffsrc_)));
    if (version_ != across_version::Last
            && version_ != across_version::Single) {
        IRB_LOOP(this->vmovups(
                ptr[this->rsp + irb * buffer_block_ + buffer_nest_offset_],
                this->xreg(irb, xdiffdst_next_)));
    }
    // Reload the staged data shifted by -2, -1, +1 and +2 elements relative
    // to the current block, so each load straddles into the neighbouring
    // slot part.
    size_t acc_size = sizeof(acc_data_t);
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_prev_[0]),
            this->EVEX_compress_addr(this->rsp,
                    irb * buffer_block_ + xmm_size_ - 2 * acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_prev_[1]),
            this->EVEX_compress_addr(this->rsp,
                    irb * buffer_block_ + xmm_size_ - 1 * acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_next_[0]),
            this->EVEX_compress_addr(this->rsp,
                    irb * buffer_block_ + xmm_size_ + 1 * acc_size)));
    IRB_LOOP(this->vmovups(this->zreg(irb, this->z_next_[1]),
            this->EVEX_compress_addr(this->rsp,
                    irb * buffer_block_ + xmm_size_ + 2 * acc_size)));
    // Accumulate the shifted copies into zdiffsrc.
    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_prev_[0])));
    // zsrc_ shares its register index with z_prev_[0], so src may only be
    // loaded after z_prev_[0] has been consumed by the addition above.
    assert(this->zsrc_ == this->z_prev_[0]);
    IRB_LOOP(this->load_data(this->zreg(irb, this->zsrc_),
            this->EVEX_compress_addr(this->src_, irb * this->vlen_)));
    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_prev_[1])));
    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_next_[0])));
    IRB_LOOP(this->vaddps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->z_next_[1])));
    // zsrc = src * nalphabeta (constant broadcast in generate()).
    IRB_LOOP(this->vmulps(this->zreg(irb, this->zsrc_),
            this->zreg(irb, this->zsrc_), this->znalphabeta_));

    // zdiffdst = diff_dst / ws0.
    IRB_LOOP(this->load_data(this->zreg(irb, this->zws0_),
            this->EVEX_compress_addr(this->workspace0_, irb * this->vlen_)));
    IRB_LOOP(this->vdivps(this->zreg(irb, this->zdiffdst_),
            this->zreg(irb, this->zdiffdst_), this->zreg(irb, this->zws0_)));
    // vfmadd213ps: zdiffsrc = zsrc * zdiffsrc + zdiffdst.
    IRB_LOOP(this->vfmadd213ps(this->zreg(irb, this->zdiffsrc_),
            this->zreg(irb, this->zsrc_), this->zreg(irb, this->zdiffdst_)));

    // Run-time dispatch on the alignment of diffsrc_ to vlen_ bytes.
    // NOTE(review): the boolean passed to store_data presumably selects the
    // aligned / non-temporal store variant - confirm against the base class.
    Label unaligned_store, end_store;
    this->test(this->diffsrc_, this->vlen_ - 1);
    this->jnz(unaligned_store, this->T_NEAR);
    IRB_LOOP(this->store_data(true,
            this->EVEX_compress_addr(this->diffsrc_, irb * this->vlen_),
            this->zreg(irb, this->zdiffsrc_)));
    this->jmp(end_store, this->T_NEAR);
    this->L(unaligned_store);
    {
        IRB_LOOP(this->store_data(false,
                this->EVEX_compress_addr(this->diffsrc_, irb * this->vlen_),
                this->zreg(irb, this->zdiffsrc_)));
    }
    this->L(end_store);
}
218
// Explicit instantiations: the member definitions above live in this
// translation unit, so every data type the kernel is built for is
// instantiated here.
template class jit_avx512_common_lrn_kernel_bwd_blocked_t<f32>;
template class jit_avx512_common_lrn_kernel_bwd_blocked_t<bf16>;
template class jit_avx512_common_lrn_kernel_bwd_blocked_t<f16>;
222
223} // namespace lrn
224} // namespace x64
225} // namespace cpu
226} // namespace impl
227} // namespace dnnl
228