1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_LRN_KERNEL_HPP |
18 | #define CPU_X64_JIT_UNI_LRN_KERNEL_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | |
23 | #include "cpu/x64/jit_generator.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
// Incomplete type on purpose: this header only ever refers to
// bf16_emulation_t through a pointer, so a forward declaration is enough.
struct bf16_emulation_t;

/* Bundle of raw pointers for the forward direction.
 * Only `src` is read-only; the remaining three pointers are writable.
 * NOTE(review): presumably this is the argument block handed to the
 * generated forward kernel at call time -- confirm against the callers. */
struct jit_args_fwd_t {
    const void *src;
    void *dst;
    void *scratch;
    void *bwd_intermediate_res;
};
35 | |
/* Bundle of raw pointers for the backward direction.
 * The first four pointers are read-only; `diff_src` is the only writable one.
 * NOTE(review): presumably this is the argument block handed to the
 * generated backward kernel at call time -- confirm against the callers. */
struct jit_args_bwd_t {
    const void *src;
    const void *diff_dst;
    const void *scratch;
    const void *bwd_intermediate_res;
    void *diff_src;
};
40 | |
/* Parameters of the `nchw8c_across` configuration.
 *
 * `version` selects which group of channels the kernel covers:
 *   -1 : channels 0..7
 *    1 : channels C-8 .. C-1
 *    0 : other channels
 *    3 : channels only for this kernel (without prev and next)
 */
struct nchw8c_across_t {
    int H;
    int W;
    int version;

    // Default state: every field is zero.
    nchw8c_across_t() : H(0), W(0), version(0) {}

    // Stores h, w and v into H, W and version respectively.
    nchw8c_across_t(int h, int w, int v) : H(h), W(w), version(v) {}
};
52 | |
53 | struct within_config_t { |
54 | const int H, W, C, size; |
55 | const format_tag_t dat_tag; |
56 | within_config_t(int h, int w, int c, int s, format_tag_t dat_tag) |
57 | : H(h), W(w), C(c), size(s), dat_tag(dat_tag) {} |
58 | within_config_t() : within_config_t(0, 0, 0, 0, dnnl_format_tag_undef) {} |
59 | }; |
60 | |
/* Parameters of the `nchw_across` configuration: three plain integers
 * named C, HW and tail. */
struct nchw_across_t {
    int C;
    int HW;
    int tail;

    // Default state: every field is zero.
    nchw_across_t() : C(0), HW(0), tail(0) {}

    // Stores c, hw and t into C, HW and tail respectively.
    nchw_across_t(int c, int hw, int t) : C(c), HW(hw), tail(t) {}
};
66 | |
/* Parameters of the `nhwc_across` configuration: a single integer C. */
struct nhwc_across_t {
    int C;

    // Default state: C is zero.
    nhwc_across_t() : C(0) {}

    // Converting constructor; deliberately left non-explicit so existing
    // implicit conversions from int keep compiling.
    nhwc_across_t(int c) : C(c) {}
};
72 | |
/* Tag identifying which configuration a kernel object was built for.
 * Values are spelled out explicitly; they match the implicit numbering
 * that follows from `none = 0`. */
enum class lrn_config_t {
    none = 0,
    nchw8c_across = 1,
    within_config = 2,
    nchw_across = 3,
    nhwc_across = 4,
};
80 | |
81 | template <class Derived> |
82 | class jit_uni_lrn_kernel_t; // primary template |
83 | |
84 | template <template <cpu_isa_t isa, data_type_t d_type> class Derived, |
85 | cpu_isa_t isa, data_type_t d_type> |
86 | class jit_uni_lrn_kernel_t<Derived<isa, d_type>> : public jit_generator { |
87 | public: |
88 | jit_uni_lrn_kernel_t(void *code_ptr = nullptr, |
89 | size_t code_size = MAX_CODE_SIZE, const char *name = jit_name()); |
90 | jit_uni_lrn_kernel_t(const within_config_t &J, void *code_ptr = nullptr, |
91 | size_t code_size = MAX_CODE_SIZE, const char *name = jit_name()); |
92 | |
93 | ~jit_uni_lrn_kernel_t(); |
94 | |
95 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_kernel_t); |
96 | static constexpr int VECTOR_LENGTH |
97 | = (isa & avx512_core_bit) == avx512_core_bit ? 16 : 8; |
98 | |
99 | protected: |
100 | using Vmm = typename utils::conditional<isa == avx2, Xbyak::Ymm, |
101 | Xbyak::Zmm>::type; |
102 | |
103 | void load_constant(float constant, const Vmm &v_constant, |
104 | const Xbyak::Xmm &x_constant); |
105 | void load_data(const Vmm ®, const Xbyak::Address &p); |
106 | void store_data(const Xbyak::Address &p, const Vmm ®); |
107 | void within_loop( |
108 | const within_config_t &config, int max_reg_blocks, prop_kind_t pk); |
109 | void within_body_reg_blocked(int loop_count, int max_reg_block, int hoff, |
110 | int Hoff, int woff, int Woff, int stride, prop_kind_t pk); |
111 | |
112 | const bool emulate_bfloat_ = false; |
113 | const Xbyak::Zmm bf16_emu_reserv_1_ = Xbyak::Zmm(28); |
114 | const Xbyak::Zmm bf16_emu_reserv_2_ = Xbyak::Zmm(29); |
115 | const Xbyak::Reg64 bf16_emu_scratch_ = this->rax; |
116 | const Xbyak::Zmm bf16_emu_reserv_3_ = Xbyak::Zmm(30); |
117 | const Xbyak::Zmm bf16_emu_reserv_4_ = Xbyak::Zmm(31); |
118 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
119 | const Xbyak::Reg64 h_ = this->r9; |
120 | const Xbyak::Reg64 w_ = this->r10; |
121 | const Xbyak::Reg64 imm_addr64_ = this->rbx; |
122 | int single_pixel_offset_ |
123 | = VECTOR_LENGTH * sizeof(typename prec_traits<d_type>::type); |
124 | int tempIdx_ = 0; |
125 | int reg_block_idx_ = 0; |
126 | }; |
127 | |
128 | template <cpu_isa_t isa, data_type_t d_type> |
129 | class jit_uni_lrn_fwd_kernel_t |
130 | : public jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<isa, d_type>> { |
131 | public: |
132 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_t) |
133 | |
134 | jit_uni_lrn_fwd_kernel_t(const within_config_t &J, float A, float K, |
135 | prop_kind_t pk, void *code_ptr = nullptr, |
136 | size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE); |
137 | jit_uni_lrn_fwd_kernel_t(const nchw8c_across_t &J, float A, float K, |
138 | prop_kind_t pk, void *code_ptr = nullptr, |
139 | size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); |
140 | jit_uni_lrn_fwd_kernel_t(const nhwc_across_t &J, float A, float K, |
141 | prop_kind_t pk, void *code_ptr = nullptr, |
142 | size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); |
143 | jit_uni_lrn_fwd_kernel_t(const nchw_across_t &J, float A, float K, |
144 | prop_kind_t pk, void *code_ptr = nullptr, |
145 | size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE); |
146 | ~jit_uni_lrn_fwd_kernel_t(); |
147 | |
148 | private: |
149 | using Base = jit_uni_lrn_kernel_t<jit_uni_lrn_fwd_kernel_t<isa, d_type>>; |
150 | |
151 | void generate() override { |
152 | switch (config_) { |
153 | case lrn_config_t::within_config: |
154 | generate(this->within_config_); |
155 | return; |
156 | case lrn_config_t::nchw8c_across: |
157 | generate(this->nchw8c_across_); |
158 | return; |
159 | case lrn_config_t::nhwc_across: |
160 | generate(this->nhwc_across_); |
161 | return; |
162 | case lrn_config_t::nchw_across: |
163 | generate(this->nchw_across_); |
164 | return; |
165 | default: assert(!"Configuration not supported" ); return; |
166 | } |
167 | } |
168 | void generate(const within_config_t &config); |
169 | void generate(const nchw8c_across_t &config); |
170 | void generate(const nhwc_across_t &config); |
171 | void generate(const nchw_across_t &config); |
172 | |
173 | public: |
174 | using Base::VECTOR_LENGTH; |
175 | |
176 | private: |
177 | friend Base; |
178 | using typename Base::Vmm; |
179 | |
180 | void within_body(int hoff, int Hoff, int woff, int Woff, int stride, |
181 | prop_kind_t pk, int reg_block = 1, int single_pixel_offset = 0); |
182 | void nchw_body(int tail, int HW, prop_kind_t pk, Xbyak::Ymm ymask, |
183 | Xbyak::Ymm ya, Xbyak::Ymm yb, Xbyak::Ymm yc, Xbyak::Ymm yd, |
184 | Xbyak::Ymm ye, Xbyak::Ymm ysum); |
185 | void nchw_body_sse41(int tail, int HW, prop_kind_t pk, Xbyak::Xmm xe_lo, |
186 | Xbyak::Xmm xe_hi, Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi); |
187 | void nchw_tail_sse41(int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, |
188 | Xbyak::Xmm xtail_hi); |
189 | void move_data_pointers(int pixel_count, prop_kind_t pk); |
190 | |
191 | const Xbyak::Reg64 src_ = this->rax; |
192 | const Xbyak::Reg64 dst_ = this->r8; |
193 | const Xbyak::Reg64 scratch_ = this->r14; |
194 | const Xbyak::Reg64 bwd_intermediate_res_ = this->rdx; |
195 | const Xbyak::Reg64 store_addr_ = this->rbp; |
196 | |
197 | const Xbyak::Xmm xalpha_ = this->xmm0; |
198 | const Xbyak::Xmm xk_ = this->xmm1; |
199 | const Xbyak::Ymm yk_ = this->ymm1; |
200 | const Vmm valpha_ = Vmm(0); |
201 | const Vmm vk_ = Vmm(1); |
202 | |
203 | lrn_config_t config_; |
204 | const nchw8c_across_t nchw8c_across_; |
205 | const within_config_t within_config_; |
206 | const nchw_across_t nchw_across_; |
207 | const nhwc_across_t nhwc_across_; |
208 | float alpha_; |
209 | float k_; |
210 | prop_kind_t pk_; |
211 | static constexpr int stack_space_needed_ = 11 * 4 * sizeof(float) + 16; |
212 | }; |
213 | |
214 | template <cpu_isa_t isa, data_type_t d_type> |
215 | class jit_uni_lrn_bwd_kernel_t |
216 | : public jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<isa, d_type>> { |
217 | public: |
218 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_t) |
219 | |
220 | jit_uni_lrn_bwd_kernel_t(const nchw8c_across_t &J, float A, float B, |
221 | int use_h_parallel, void *code_ptr = nullptr, |
222 | size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE); |
223 | jit_uni_lrn_bwd_kernel_t(const within_config_t &J, float A, float B, |
224 | void *code_ptr = nullptr, |
225 | size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE); |
226 | |
227 | private: |
228 | using Base = jit_uni_lrn_kernel_t<jit_uni_lrn_bwd_kernel_t<isa, d_type>>; |
229 | |
230 | void generate() override { |
231 | switch (config_) { |
232 | case lrn_config_t::nchw8c_across: |
233 | generate(this->nchw8c_across_); |
234 | return; |
235 | case lrn_config_t::within_config: |
236 | generate(this->within_config_); |
237 | return; |
238 | default: assert(!"Configuration not supported" ); return; |
239 | } |
240 | } |
241 | void generate(const nchw8c_across_t &config); |
242 | void generate(const within_config_t &config); |
243 | |
244 | public: |
245 | using Base::VECTOR_LENGTH; |
246 | |
247 | private: |
248 | friend Base; |
249 | using typename Base::Vmm; |
250 | |
251 | void within_body(int hoff, int Hoff, int woff, int Woff, int stride, |
252 | prop_kind_t pk, int reg_block = 1, int single_pixel_offset = 0); |
253 | void move_data_pointers(int pixel_count, prop_kind_t pk); |
254 | |
255 | lrn_config_t config_; |
256 | const nchw8c_across_t nchw8c_across_; |
257 | const within_config_t within_config_; |
258 | prop_kind_t pk_ = prop_kind::backward; |
259 | |
260 | float nalphabeta_; |
261 | int use_h_parallelizm_; |
262 | const Xbyak::Reg64 src_ = this->rax; |
263 | const Xbyak::Reg64 diffsrc_ = this->r13; |
264 | const Xbyak::Reg64 diffdst_ = this->r14; |
265 | const Xbyak::Reg64 scratch_ = this->r15; |
266 | const Xbyak::Reg64 bwd_intermediate_res_ = this->rdx; |
267 | const Xbyak::Xmm xnalphabeta_ = this->xmm0; |
268 | const Vmm vnalphabeta_ = Vmm(0); |
269 | }; |
270 | |
271 | } // namespace x64 |
272 | } // namespace cpu |
273 | } // namespace impl |
274 | } // namespace dnnl |
275 | |
276 | #endif |
277 | |
278 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
279 | |