1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_BRGEMM_CELL_COMMON_FWD_HPP |
18 | #define CPU_X64_RNN_BRGEMM_CELL_COMMON_FWD_HPP |
19 | |
20 | #include <functional> |
21 | #include "common/bfloat16.hpp" |
22 | #include "cpu/rnn/rnn_utils.hpp" |
23 | #include "cpu/x64/rnn/rnn_brgemm_utils.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
// Computes the gates GEMMs of a single forward RNN cell using pre-created
// BRGEMM kernels: a src_layer x w_layer GEMM (skippable, see need_gemm_layer_)
// and a src_iter x w_iter GEMM accumulated into the same scratch_gates
// output, followed by a user-supplied fused post-GEMM callback.
// Declaration only — all definitions live in the corresponding .cpp.
template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
class brgemm_dst_layer_iter_t {
public:
    using ref_rnn_brgemm_t = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::forward>;
    // Epilogue fused after the GEMMs. Argument meaning is fixed by the
    // caller-provided callable; presumably (m, n, gate/block index,
    // src, scratch, block size) — confirm against the .cpp users.
    using postgemm_fused_t = std::function<void(
            dim_t, dim_t, dim_t, const src_t *, scratch_t *, int)>;
    brgemm_dst_layer_iter_t(const ref_rnn_brgemm_t &rnn_brgemm_,
            const rnn_utils::rnn_conf_t &rnn,
            rnn_utils::cell_position_t cell_position, const src_t *src_iter,
            const src_t *src_layer, weights_t *w_iter, weights_t *w_layer,
            scratch_t *scratch_gates, gemm_acc_t *amx_scratchpad,
            x64::brgemm_batch_element_t *addr_batch_global,
            const postgemm_fused_t &fused_postgemm);
    // Entry point; distributes the blocked GEMM work over threads.
    void execute() const;

private:
    // Per-thread worker: ithr is this thread's id, nthr the thread count.
    void kernel(const int ithr, const int nthr) const;
    // Variant used when the iter and layer GEMMs are fused into one brgemm
    // batch (see is_fused_layer_iter_brgemm_).
    void kernel_fused_iter_layer(const int ithr, const int nthr) const;

    const ref_rnn_brgemm_t &rnn_brgemm_;
    const rnn_utils::rnn_conf_t &rnn_;
    // When false the layer GEMM is skipped (its result is already available).
    const bool need_gemm_layer_;
    const dim_t layer_desc_idx_;
    const dim_t iter_desc_idx_;
    // GEMM operands, BLAS-style naming: A* inputs, B* weights, C output.
    // The l/i suffixes distinguish the layer and iter GEMMs.
    const src_t *const Al_;
    const src_t *const Ai_;
    const weights_t *const Bl_;
    const weights_t *const Bi_;
    scratch_t *const C_;
    // Leading dimensions of the A matrices.
    const dim_t LDAl_;
    const dim_t LDAi_;
    const dim_t max_nthr_;
    // Blocking along the N and M GEMM dimensions.
    const dim_t n_blocking_;
    const dim_t m_blocking_;
    const int work_amount_;
    // Precomputed element offsets used to step through the blocked B/A
    // layouts (per n-block, per gate, per k-block, and the k-tail).
    const dim_t Bl_n_offset_;
    const dim_t Bi_n_offset_;
    const dim_t Bl_g_offset_;
    const dim_t Bi_g_offset_;
    const dim_t Al_k_tail_offset_;
    const dim_t Ai_k_tail_offset_;
    const dim_t Bl_kb_offset_;
    const dim_t Bi_kb_offset_;
    const dim_t Bl_k_tail_offset_;
    const dim_t Bi_k_tail_offset_;
    const dim_t n_gates_;
    // Kernel variants covering the full blocks and the N/K remainder tails.
    const brgemm_kernel_t *const brgemm_kernel_iter_main_;
    const brgemm_kernel_t *const brgemm_kernel_iter_n_tail_;
    const brgemm_kernel_t *const brgemm_kernel_iter_k_tail_;
    const brgemm_kernel_t *const brgemm_kernel_iter_nk_tail_;

    const brgemm_kernel_t *const brgemm_kernel_layer_main_;
    const brgemm_kernel_t *const brgemm_kernel_layer_n_tail_;
    const brgemm_kernel_t *const brgemm_kernel_layer_k_tail_;
    const brgemm_kernel_t *const brgemm_kernel_layer_nk_tail_;

    // Tile configuration buffers, one per kernel variant — presumably AMX
    // tile palettes (note: "pallete" spelling kept; renaming would break the
    // out-of-view .cpp definitions).
    const char *pallete_buff_iter_main_;
    const char *pallete_buff_iter_n_tail_;
    const char *pallete_buff_iter_k_tail_;
    const char *pallete_buff_iter_nk_tail_;

    const char *pallete_buff_layer_main_;
    const char *pallete_buff_layer_n_tail_;
    const char *pallete_buff_layer_k_tail_;
    const char *pallete_buff_layer_nk_tail_;

    gemm_acc_t *const amx_scratchpad_;
    // Global array of batch elements, partitioned per thread by the workers.
    brgemm_batch_element_t *const addr_batch_global_;
    const postgemm_fused_t fused_postgemm_;
    const bool is_fused_layer_iter_brgemm_;
};
102 | |
103 | template <typename src_t, typename weights_t, typename gemm_acc_t> |
104 | class brgemm_dst_proj_t { |
105 | public: |
106 | using ref_rnn_brgemm_t = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::forward>; |
107 | using postgemm_fused_t |
108 | = std::function<void(dim_t, dim_t, gemm_acc_t *, int)>; |
109 | brgemm_dst_proj_t(const ref_rnn_brgemm_t &rnn_brgemm_, |
110 | const rnn_utils::rnn_conf_t &rnn, |
111 | rnn_utils::cell_position_t cell_position, const src_t *proj_ht, |
112 | const weights_t *w_projection, gemm_acc_t *output, |
113 | gemm_acc_t *amx_scratchpad, |
114 | x64::brgemm_batch_element_t *addr_batch_global, |
115 | const postgemm_fused_t &fused_postgemm); |
116 | |
117 | void execute() const; |
118 | |
119 | private: |
120 | void kernel(const int ithr, const int nthr) const; |
121 | |
122 | private: |
123 | const ref_rnn_brgemm_t &rnn_brgemm_; |
124 | const rnn_utils::rnn_conf_t &rnn_; |
125 | const int proj_desc_idx_; |
126 | const src_t *const A_; |
127 | const weights_t *const B_; |
128 | gemm_acc_t *const C_; |
129 | const dim_t LDC_; |
130 | const dim_t max_nthr_; |
131 | const int work_amount_proj_; |
132 | const dim_t B_n_offset_; |
133 | const dim_t Bp_kb_offset_; |
134 | gemm_acc_t *const amx_scratchpad_; |
135 | brgemm_batch_element_t *const addr_batch_global_; |
136 | |
137 | const brgemm_kernel_t *const brgemm_kernel_main_; |
138 | const brgemm_kernel_t *const brgemm_kernel_n_tail_; |
139 | const brgemm_kernel_t *const brgemm_kernel_nk_tail_; |
140 | const brgemm_kernel_t *const brgemm_kernel_k_tail_; |
141 | const postgemm_fused_t fused_postgemm_; |
142 | }; |
143 | |
// Computes the gates GEMMs of a forward GRU cell using pre-created BRGEMM
// kernels. The iter GEMM is split into two parts (w_iter and w_iter1, see
// the p0/p1 kernel sets below) with a fused post-GEMM callback after each
// part. Declaration only — all definitions live in the corresponding .cpp.
template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
class brgemm_gru_t {
public:
    using ref_rnn_brgemm_t = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::forward>;
    // Epilogue fused after a GEMM part. Argument meaning is fixed by the
    // caller-provided callable; presumably (m, n, gate/block index, src,
    // scratch gates, scratch cell, block size) — confirm against the .cpp.
    using postgemm_fused_t = std::function<void(
            dim_t, dim_t, dim_t, const src_t *, scratch_t *, scratch_t *, int)>;
    brgemm_gru_t(const ref_rnn_brgemm_t &rnn_brgemm_,
            const rnn_utils::rnn_conf_t &rnn,
            rnn_utils::cell_position_t cell_position, const src_t *src_iter,
            const src_t *src_layer, weights_t *w_iter, weights_t *w_iter1,
            weights_t *w_layer, src_t *d_layer, scratch_t *scratch_gates,
            scratch_t *scratch_cell, gemm_acc_t *amx_scratchpad,
            x64::brgemm_batch_element_t *addr_batch_global,
            const postgemm_fused_t &fused_postgemm_part1,
            const postgemm_fused_t &fused_postgemm_part2);
    // Entry point; distributes the blocked GEMM work over threads.
    void execute() const;

private:
    // Per-thread worker: ithr is this thread's id, nthr the thread count.
    void kernel(const int ithr, const int nthr) const;

    const ref_rnn_brgemm_t &rnn_brgemm_;
    const rnn_utils::rnn_conf_t &rnn_;
    // When false the layer GEMM is skipped (its result is already available).
    const bool need_gemm_layer_;
    const dim_t layer_desc_idx_;
    const dim_t iter_desc_idx_;
    // Descriptor index for the second part of the iter GEMM (w_iter1).
    const dim_t iter_part2_desc_idx_;
    // GEMM operands, BLAS-style naming: A* inputs, B* weights, C* outputs,
    // D = destination layer. The l/i suffixes distinguish layer and iter;
    // Bi2_ is the second iter weights matrix (w_iter1).
    const src_t *const Al_;
    const src_t *const Ai_;
    const weights_t *const Bl_;
    const weights_t *const Bi_;
    const weights_t *const Bi2_;
    scratch_t *const C_gates_;
    scratch_t *const C_cell_;
    src_t *const Dl_;
    // Leading dimensions of the A matrices.
    const dim_t LDAl_;
    const dim_t LDAi_;
    const dim_t max_nthr_;
    // Blocking along the N and M GEMM dimensions.
    const dim_t n_blocking_;
    const dim_t m_blocking_;
    const int work_amount_;
    // Precomputed element offsets used to step through the blocked B/A
    // layouts (per n-block, per gate, per k-block, and the k-tail).
    const dim_t Bl_n_offset_;
    const dim_t Bi_n_offset_;
    const dim_t Bl_g_offset_;
    const dim_t Bi_g_offset_;
    const dim_t Al_k_tail_offset_;
    const dim_t Ai_k_tail_offset_;
    const dim_t Bl_kb_offset_;
    const dim_t Bi_kb_offset_;
    const dim_t Bl_k_tail_offset_;
    const dim_t Bi_k_tail_offset_;
    const dim_t n_gates_;
    // Kernel variants for the two iter GEMM parts (p0/p1), each covering
    // the full blocks and the N/K remainder tails.
    const brgemm_kernel_t *const brgemm_kernel_iter_p0_main_;
    const brgemm_kernel_t *const brgemm_kernel_iter_p0_n_tail_;
    const brgemm_kernel_t *const brgemm_kernel_iter_p0_k_tail_;
    const brgemm_kernel_t *const brgemm_kernel_iter_p0_nk_tail_;
    const brgemm_kernel_t *const brgemm_kernel_iter_p1_main_;
    const brgemm_kernel_t *const brgemm_kernel_iter_p1_n_tail_;
    const brgemm_kernel_t *const brgemm_kernel_iter_p1_k_tail_;
    const brgemm_kernel_t *const brgemm_kernel_iter_p1_nk_tail_;

    const brgemm_kernel_t *const brgemm_kernel_layer_main_;
    const brgemm_kernel_t *const brgemm_kernel_layer_n_tail_;
    const brgemm_kernel_t *const brgemm_kernel_layer_k_tail_;
    const brgemm_kernel_t *const brgemm_kernel_layer_nk_tail_;

    // Tile configuration buffers, one per kernel variant — presumably AMX
    // tile palettes (spelling kept to match the out-of-view .cpp).
    const char *pallete_buff_iter_main_;
    const char *pallete_buff_iter_n_tail_;
    const char *pallete_buff_iter_k_tail_;
    const char *pallete_buff_iter_nk_tail_;

    const char *pallete_buff_layer_main_;
    const char *pallete_buff_layer_n_tail_;
    const char *pallete_buff_layer_k_tail_;
    const char *pallete_buff_layer_nk_tail_;

    gemm_acc_t *const amx_scratchpad_;
    // Global array of batch elements, partitioned per thread by the worker.
    brgemm_batch_element_t *const addr_batch_global_;
    const postgemm_fused_t fused_postgemm_part1_;
    const postgemm_fused_t fused_postgemm_part2_;
    const bool is_fused_layer_iter_brgemm_;
};
226 | |
// Computes the layer GEMM merged across cells (C = src_layer x w_layer into
// scratch_gates) using pre-created BRGEMM kernels. Unlike the classes above
// there is no iter GEMM and no fused post-GEMM callback. Declaration only —
// all definitions live in the corresponding .cpp.
template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
class brgemm_merged_layer_t {
public:
    using ref_rnn_brgemm_t = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::forward>;
    brgemm_merged_layer_t(const ref_rnn_brgemm_t &rnn_brgemm_,
            const rnn_utils::rnn_conf_t &rnn,
            rnn_utils::cell_position_t cell_position, const src_t *src_layer,
            weights_t *w_layer, scratch_t *scratch_gates,
            gemm_acc_t *amx_scratchpad,
            x64::brgemm_batch_element_t *addr_batch_global);
    // Entry point; distributes the blocked GEMM work over threads.
    void execute() const;

private:
    // Per-thread worker: ithr is this thread's id, nthr the thread count.
    void kernel(const int ithr, const int nthr) const;

    const ref_rnn_brgemm_t &rnn_brgemm_;
    const rnn_utils::rnn_conf_t &rnn_;
    const dim_t layer_desc_idx_;
    // GEMM operands, BLAS-style naming: Al = src_layer, Bl = w_layer,
    // C = scratch_gates.
    const src_t *const Al_;
    const weights_t *const Bl_;
    scratch_t *const C_;
    // Leading dimension of Al.
    const dim_t LDAl_;
    const dim_t max_nthr_;
    // Blocking along the N and M GEMM dimensions.
    const dim_t n_blocking_;
    const dim_t m_blocking_;
    const int work_amount_;
    // Precomputed element offsets for the blocked B/A layouts
    // (per n-block, per gate, per k-block, and the k-tail).
    const dim_t Bl_n_offset_;
    const dim_t Bl_g_offset_;
    const dim_t Al_k_tail_offset_;
    const dim_t Bl_kb_offset_;
    const dim_t Bl_k_tail_offset_;
    const dim_t n_gates_;

    // Kernel variants covering the full blocks and the N/K remainder tails.
    const brgemm_kernel_t *const brgemm_kernel_layer_main_;
    const brgemm_kernel_t *const brgemm_kernel_layer_n_tail_;
    const brgemm_kernel_t *const brgemm_kernel_layer_k_tail_;
    const brgemm_kernel_t *const brgemm_kernel_layer_nk_tail_;

    // Tile configuration buffers, one per kernel variant — presumably AMX
    // tile palettes (spelling kept to match the out-of-view .cpp).
    const char *pallete_buff_layer_main_;
    const char *pallete_buff_layer_n_tail_;
    const char *pallete_buff_layer_k_tail_;
    const char *pallete_buff_layer_nk_tail_;

    gemm_acc_t *const amx_scratchpad_;
    // Global array of batch elements, partitioned per thread by the worker.
    brgemm_batch_element_t *const addr_batch_global_;
};
274 | |
275 | } // namespace x64 |
276 | } // namespace cpu |
277 | } // namespace impl |
278 | } // namespace dnnl |
279 | |
280 | #endif |
281 | |