/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_LRN_KERNEL_HPP
18#define CPU_X64_JIT_UNI_LRN_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/type_helpers.hpp"
22
23#include "cpu/x64/jit_generator.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
// Forward declaration only: within this header the type is used solely as the
// pointee of a std::unique_ptr member of jit_uni_lrn_kernel_t.
struct bf16_emulation_t;
// Pointer arguments for a forward LRN kernel call.
// The struct is a plain aggregate: one read-only input pointer followed by
// three writable pointers.
// NOTE(review): presumably the generated code reads these fields by offset, so
// the declaration order must not change -- confirm against the kernel sources.
struct jit_args_fwd_t {
    const void *src;
    void *dst;
    void *scratch;
    void *bwd_intermediate_res;
};
35
// Pointer arguments for a backward LRN kernel call.
// The struct is a plain aggregate: four read-only input pointers followed by
// the single writable output pointer.
// NOTE(review): presumably the generated code reads these fields by offset, so
// the declaration order must not change -- confirm against the kernel sources.
struct jit_args_bwd_t {
    const void *src;
    const void *diff_dst;
    const void *scratch;
    const void *bwd_intermediate_res;
    void *diff_src;
};
40
// Configuration of an "across channels" kernel selected via
// lrn_config_t::nchw8c_across. Holds the two spatial sizes and a version tag.
struct nchw8c_across_t {
    /* version:
     *  -1: channels 0..7,
     *   1: channels C-8 .. C-1,
     *   0: other channels
     *   3: channels only for this kernel (without prev and next)
     */
    int H, W, version;

    // Stores the three values as given; no validation is performed.
    nchw8c_across_t(int h, int w, int v) : H(h), W(w), version(v) {}

    // Default configuration: all fields zero.
    nchw8c_across_t() : nchw8c_across_t(0, 0, 0) {}
};
52
53struct within_config_t {
54 const int H, W, C, size;
55 const format_tag_t dat_tag;
56 within_config_t(int h, int w, int c, int s, format_tag_t dat_tag)
57 : H(h), W(w), C(c), size(s), dat_tag(dat_tag) {}
58 within_config_t() : within_config_t(0, 0, 0, 0, dnnl_format_tag_undef) {}
59};
60
// Configuration of an "across channels" kernel selected via
// lrn_config_t::nchw_across: three integers stored verbatim.
struct nchw_across_t {
    int C, HW, tail;

    // Stores the three values as given; no validation is performed.
    nchw_across_t(int c, int hw, int t) : C(c), HW(hw), tail(t) {}

    // Default configuration: all fields zero.
    nchw_across_t() : nchw_across_t(0, 0, 0) {}
};
66
// Configuration of an "across channels" kernel selected via
// lrn_config_t::nhwc_across: a single integer stored verbatim.
struct nhwc_across_t {
    int C;

    // NOTE(review): this single-argument constructor is not `explicit`, so an
    // int converts implicitly to nhwc_across_t. Left as-is to keep existing
    // call sites compiling.
    nhwc_across_t(int c) : C(c) {}

    // Default configuration: C is zero.
    nhwc_across_t() : nhwc_across_t(0) {}
};
72
// Selects which of the configuration structs a kernel was built from.
// The kernels' generate() methods switch on this value; `none` (0) matches no
// supported configuration.
enum class lrn_config_t {
    none = 0,
    nchw8c_across,
    within_config,
    nchw_across,
    nhwc_across,
};
80
// Primary template: declared but never defined. Only the partial
// specialization for Derived<isa, d_type> provides a class body.
template <class Derived>
class jit_uni_lrn_kernel_t;
83
84template <template <cpu_isa_t isa, data_type_t d_type> class Derived,
85 cpu_isa_t isa, data_type_t d_type>
86class jit_uni_lrn_kernel_t<Derived<isa, d_type>> : public jit_generator {
87public:
88 jit_uni_lrn_kernel_t(void *code_ptr = nullptr,
89 size_t code_size = MAX_CODE_SIZE, const char *name = jit_name());
90 jit_uni_lrn_kernel_t(const within_config_t &J, void *code_ptr = nullptr,
91 size_t code_size = MAX_CODE_SIZE, const char *name = jit_name());
92
93 ~jit_uni_lrn_kernel_t();
94
95 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_kernel_t);
96 static constexpr int VECTOR_LENGTH
97 = (isa & avx512_core_bit) == avx512_core_bit ? 16 : 8;
98
99protected:
100 using Vmm = typename utils::conditional<isa == avx2, Xbyak::Ymm,
101 Xbyak::Zmm>::type;
102
103 void load_constant(float constant, const Vmm &v_constant,
104 const Xbyak::Xmm &x_constant);
105 void load_data(const Vmm &reg, const Xbyak::Address &p);
106 void store_data(const Xbyak::Address &p, const Vmm &reg);
107 void within_loop(
108 const within_config_t &config, int max_reg_blocks, prop_kind_t pk);
109 void within_body_reg_blocked(int loop_count, int max_reg_block, int hoff,
110 int Hoff, int woff, int Woff, int stride, prop_kind_t pk);
111
112 const bool emulate_bfloat_ = false;
113 const Xbyak::Zmm bf16_emu_reserv_1_ = Xbyak::Zmm(28);
114 const Xbyak::Zmm bf16_emu_reserv_2_ = Xbyak::Zmm(29);
115 const Xbyak::Reg64 bf16_emu_scratch_ = this->rax;
116 const Xbyak::Zmm bf16_emu_reserv_3_ = Xbyak::Zmm(30);
117 const Xbyak::Zmm bf16_emu_reserv_4_ = Xbyak::Zmm(31);
118 std::unique_ptr<bf16_emulation_t> bf16_emu_;
119 const Xbyak::Reg64 h_ = this->r9;
120 const Xbyak::Reg64 w_ = this->r10;
121 const Xbyak::Reg64 imm_addr64_ = this->rbx;
122 int single_pixel_offset_
123 = VECTOR_LENGTH * sizeof(typename prec_traits<d_type>::type);
124 int tempIdx_ = 0;
125 int reg_block_idx_ = 0;
126};
127
128template <cpu_isa_t isa, data_type_t d_type>
129class jit_uni_lrn_fwd_kernel_t
130 : public jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<isa, d_type>> {
131public:
132 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_t)
133
134 jit_uni_lrn_fwd_kernel_t(const within_config_t &J, float A, float K,
135 prop_kind_t pk, void *code_ptr = nullptr,
136 size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE);
137 jit_uni_lrn_fwd_kernel_t(const nchw8c_across_t &J, float A, float K,
138 prop_kind_t pk, void *code_ptr = nullptr,
139 size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
140 jit_uni_lrn_fwd_kernel_t(const nhwc_across_t &J, float A, float K,
141 prop_kind_t pk, void *code_ptr = nullptr,
142 size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
143 jit_uni_lrn_fwd_kernel_t(const nchw_across_t &J, float A, float K,
144 prop_kind_t pk, void *code_ptr = nullptr,
145 size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE);
146 ~jit_uni_lrn_fwd_kernel_t();
147
148private:
149 using Base = jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<isa, d_type>>;
150
151 void generate() override {
152 switch (config_) {
153 case lrn_config_t::within_config:
154 generate(this->within_config_);
155 return;
156 case lrn_config_t::nchw8c_across:
157 generate(this->nchw8c_across_);
158 return;
159 case lrn_config_t::nhwc_across:
160 generate(this->nhwc_across_);
161 return;
162 case lrn_config_t::nchw_across:
163 generate(this->nchw_across_);
164 return;
165 default: assert(!"Configuration not supported"); return;
166 }
167 }
168 void generate(const within_config_t &config);
169 void generate(const nchw8c_across_t &config);
170 void generate(const nhwc_across_t &config);
171 void generate(const nchw_across_t &config);
172
173public:
174 using Base::VECTOR_LENGTH;
175
176private:
177 friend Base;
178 using typename Base::Vmm;
179
180 void within_body(int hoff, int Hoff, int woff, int Woff, int stride,
181 prop_kind_t pk, int reg_block = 1, int single_pixel_offset = 0);
182 void nchw_body(int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask,
183 Xbyak::Ymm ya, Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd,
184 Xbyak::Ymm ye, Xbyak::Ymm ysum);
185 void nchw_body_sse41(int tail, int HW, prop_kind_t pk, Xbyak::Xmm xe_lo,
186 Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi);
187 void nchw_tail_sse41(int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo,
188 Xbyak::Xmm xtail_hi);
189 void move_data_pointers(int pixel_count, prop_kind_t pk);
190
191 const Xbyak::Reg64 src_ = this->rax;
192 const Xbyak::Reg64 dst_ = this->r8;
193 const Xbyak::Reg64 scratch_ = this->r14;
194 const Xbyak::Reg64 bwd_intermediate_res_ = this->rdx;
195 const Xbyak::Reg64 store_addr_ = this->rbp;
196
197 const Xbyak::Xmm xalpha_ = this->xmm0;
198 const Xbyak::Xmm xk_ = this->xmm1;
199 const Xbyak::Ymm yk_ = this->ymm1;
200 const Vmm valpha_ = Vmm(0);
201 const Vmm vk_ = Vmm(1);
202
203 lrn_config_t config_;
204 const nchw8c_across_t nchw8c_across_;
205 const within_config_t within_config_;
206 const nchw_across_t nchw_across_;
207 const nhwc_across_t nhwc_across_;
208 float alpha_;
209 float k_;
210 prop_kind_t pk_;
211 static constexpr int stack_space_needed_ = 11 * 4 * sizeof(float) + 16;
212};
213
214template <cpu_isa_t isa, data_type_t d_type>
215class jit_uni_lrn_bwd_kernel_t
216 : public jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<isa, d_type>> {
217public:
218 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_t)
219
220 jit_uni_lrn_bwd_kernel_t(const nchw8c_across_t &J, float A, float B,
221 int use_h_parallel, void *code_ptr = nullptr,
222 size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
223 jit_uni_lrn_bwd_kernel_t(const within_config_t &J, float A, float B,
224 void *code_ptr = nullptr,
225 size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE);
226
227private:
228 using Base = jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<isa, d_type>>;
229
230 void generate() override {
231 switch (config_) {
232 case lrn_config_t::nchw8c_across:
233 generate(this->nchw8c_across_);
234 return;
235 case lrn_config_t::within_config:
236 generate(this->within_config_);
237 return;
238 default: assert(!"Configuration not supported"); return;
239 }
240 }
241 void generate(const nchw8c_across_t &config);
242 void generate(const within_config_t &config);
243
244public:
245 using Base::VECTOR_LENGTH;
246
247private:
248 friend Base;
249 using typename Base::Vmm;
250
251 void within_body(int hoff, int Hoff, int woff, int Woff, int stride,
252 prop_kind_t pk, int reg_block = 1, int single_pixel_offset = 0);
253 void move_data_pointers(int pixel_count, prop_kind_t pk);
254
255 lrn_config_t config_;
256 const nchw8c_across_t nchw8c_across_;
257 const within_config_t within_config_;
258 prop_kind_t pk_ = prop_kind::backward;
259
260 float nalphabeta_;
261 int use_h_parallelizm_;
262 const Xbyak::Reg64 src_ = this->rax;
263 const Xbyak::Reg64 diffsrc_ = this->r13;
264 const Xbyak::Reg64 diffdst_ = this->r14;
265 const Xbyak::Reg64 scratch_ = this->r15;
266 const Xbyak::Reg64 bwd_intermediate_res_ = this->rdx;
267 const Xbyak::Xmm xnalphabeta_ = this->xmm0;
268 const Vmm vnalphabeta_ = Vmm(0);
269};
270
271} // namespace x64
272} // namespace cpu
273} // namespace impl
274} // namespace dnnl
275
276#endif
277
// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
279