1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_BRGEMM_UTILS_RNN_HPP |
18 | #define CPU_X64_RNN_BRGEMM_UTILS_RNN_HPP |
19 | |
20 | #include <memory> |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "cpu/x64/brgemm/brgemm.hpp" |
24 | #include "cpu/x64/jit_brgemm_transpose_utils.hpp" |
25 | #include "cpu/x64/matmul/brgemm_matmul_copy_utils.hpp" |
26 | #include "cpu/x64/rnn/jit_brgemm_transpose_single_row.hpp" |
27 | #include "cpu/x64/rnn/jit_diff_weights_peephole.hpp" |
28 | #include "cpu/x64/rnn/jit_gates_reduction.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
// Forward declaration only: every use of rnn_conf_t in this header is through
// a (const or non-const) reference, so the complete type is not required here.
namespace rnn_utils { struct rnn_conf_t; } // namespace rnn_utils
37 | |
38 | namespace x64 { |
39 | |
// Forward declaration. In this header the type appears only as the pointee of
// the std::unique_ptr members kernel_transpose_iter_ / kernel_transpose_layer_
// of rnn_brgemm_t<prop_kind::backward>.
// NOTE(review): it may already be declared by
// cpu/x64/jit_brgemm_transpose_utils.hpp (included above); confirm whether this
// line is redundant.
struct jit_brgemm_trans_src_t;
41 | |
42 | namespace rnn_brgemm_utils { |
43 | |
44 | using brgemm_ker_ptr_t = std::unique_ptr<brgemm_kernel_t>; |
45 | using brgemm_pallete_t = char[64]; |
46 | using srcatch_gates_reorder_ker_ptr_t |
47 | = std::unique_ptr<matmul::jit_brgemm_matmul_copy_b_t>; |
48 | |
49 | struct rnn_brgemm_base_t { |
50 | static void init_scratchpad(const cpu::rnn_utils::rnn_conf_t &rnn, |
51 | memory_tracking::registrar_t &scratchpad, dim_t gemm_acc_type_size, |
52 | dim_t gemm_acc_align); |
53 | static constexpr dim_t num_base_kernels_ = 3; |
54 | static constexpr dim_t num_proj_kernels_ = 4; |
55 | static constexpr dim_t num_vanilla_gru_iter_part2_kernels_ = 4; |
56 | }; |
57 | |
58 | template <prop_kind_t aprop> |
59 | struct rnn_brgemm_t; |
60 | |
61 | template <> |
62 | struct rnn_brgemm_t<prop_kind::forward> : public rnn_brgemm_base_t { |
63 | using rnn_brgemm_base_t::init_scratchpad; |
64 | |
65 | static status_t configure_brgemm(cpu::rnn_utils::rnn_conf_t &rnn, |
66 | alg_kind_t cell_kind, dim_t src_layer_type_size, |
67 | dim_t scratch_type_size); |
68 | status_t init_kernels(const cpu::rnn_utils::rnn_conf_t &rnn, |
69 | data_type_t src_type, data_type_t weights_type); |
70 | |
71 | brgemm_t desc_layer_b0_[num_base_kernels_]; |
72 | brgemm_t desc_iter_b0_[num_base_kernels_]; |
73 | brgemm_t desc_iter_b1_[num_base_kernels_]; |
74 | brgemm_t desc_layer_N_tail_b0_[num_base_kernels_]; |
75 | brgemm_t desc_iter_N_tail_b0_[num_base_kernels_]; |
76 | brgemm_t desc_iter_N_tail_b1_[num_base_kernels_]; |
77 | |
78 | brgemm_t desc_layer_K1_tail_b1_[num_base_kernels_]; |
79 | brgemm_t desc_layer_NK1_tail_b1_[num_base_kernels_]; |
80 | brgemm_t desc_iter_K2_tail_b1_[num_base_kernels_]; |
81 | brgemm_t desc_iter_NK2_tail_b1_[num_base_kernels_]; |
82 | |
83 | brgemm_t desc_layermerged_b0_[num_base_kernels_]; |
84 | brgemm_t desc_layermerged_N_tail_b0_[num_base_kernels_]; |
85 | brgemm_t desc_layermerged_K1_tail_b1_[num_base_kernels_]; |
86 | brgemm_t desc_layermerged_NK1_tail_b1_[num_base_kernels_]; |
87 | |
88 | brgemm_t desc_proj_b0_[num_proj_kernels_]; |
89 | brgemm_t desc_proj_N_tail_b0_[num_proj_kernels_]; |
90 | brgemm_t desc_proj_N_tail_b1_[num_proj_kernels_]; |
91 | brgemm_t desc_proj_K_tail_b1_[num_proj_kernels_]; |
92 | brgemm_t desc_proj_NK_tail_b1_[num_proj_kernels_]; |
93 | |
94 | // Set of brgemm descriptor for 2nd part of iteration gemm in vanulla GRU |
95 | // cell |
96 | brgemm_t desc_iter_p2_b1_[num_vanilla_gru_iter_part2_kernels_]; |
97 | brgemm_t desc_iter_p2_N_tail_b1_[num_vanilla_gru_iter_part2_kernels_]; |
98 | brgemm_t desc_iter_p2_K2_tail_b1_[num_vanilla_gru_iter_part2_kernels_]; |
99 | brgemm_t desc_iter_p2_NK2_tail_b1_[num_vanilla_gru_iter_part2_kernels_]; |
100 | |
101 | brgemm_ker_ptr_t kernel_layer_b0_[num_base_kernels_]; |
102 | brgemm_ker_ptr_t kernel_layer_b1_[num_base_kernels_]; |
103 | brgemm_ker_ptr_t kernel_iter_b0_[num_base_kernels_]; |
104 | brgemm_ker_ptr_t kernel_iter_b1_[num_base_kernels_]; |
105 | brgemm_ker_ptr_t kernel_layer_N_tail_b0_[num_base_kernels_]; |
106 | brgemm_ker_ptr_t kernel_layer_N_tail_b1_[num_base_kernels_]; |
107 | brgemm_ker_ptr_t kernel_iter_N_tail_b0_[num_base_kernels_]; |
108 | brgemm_ker_ptr_t kernel_iter_N_tail_b1_[num_base_kernels_]; |
109 | |
110 | brgemm_ker_ptr_t kernel_layer_K1_tail_b1_[num_base_kernels_]; |
111 | brgemm_ker_ptr_t kernel_layer_NK1_tail_b1_[num_base_kernels_]; |
112 | brgemm_ker_ptr_t kernel_iter_K2_tail_b1_[num_base_kernels_]; |
113 | brgemm_ker_ptr_t kernel_iter_NK2_tail_b1_[num_base_kernels_]; |
114 | |
115 | brgemm_ker_ptr_t kernel_layermerged_b0_[num_base_kernels_]; |
116 | brgemm_ker_ptr_t kernel_layermerged_b1_[num_base_kernels_]; |
117 | brgemm_ker_ptr_t kernel_layermerged_N_tail_b0_[num_base_kernels_]; |
118 | brgemm_ker_ptr_t kernel_layermerged_N_tail_b1_[num_base_kernels_]; |
119 | brgemm_ker_ptr_t kernel_layermerged_K1_tail_b1_[num_base_kernels_]; |
120 | brgemm_ker_ptr_t kernel_layermerged_NK1_tail_b1_[num_base_kernels_]; |
121 | |
122 | brgemm_ker_ptr_t kernel_proj_b0_[num_proj_kernels_]; |
123 | brgemm_ker_ptr_t kernel_proj_N_tail_b0_[num_proj_kernels_]; |
124 | brgemm_ker_ptr_t kernel_proj_N_tail_b1_[num_proj_kernels_]; |
125 | brgemm_ker_ptr_t kernel_proj_K_tail_b1_[num_proj_kernels_]; |
126 | brgemm_ker_ptr_t kernel_proj_NK_tail_b1_[num_proj_kernels_]; |
127 | |
128 | // Set of brgemm kernels for 2nd part of iteration gemm in vanulla GRU cell |
129 | brgemm_ker_ptr_t kernel_iter_p2_b1_[num_vanilla_gru_iter_part2_kernels_]; |
130 | brgemm_ker_ptr_t |
131 | kernel_iter_p2_N_tail_b1_[num_vanilla_gru_iter_part2_kernels_]; |
132 | brgemm_ker_ptr_t |
133 | kernel_iter_p2_K2_tail_b1_[num_vanilla_gru_iter_part2_kernels_]; |
134 | brgemm_ker_ptr_t |
135 | kernel_iter_p2_NK2_tail_b1_[num_vanilla_gru_iter_part2_kernels_]; |
136 | |
137 | brgemm_pallete_t pallete_buff_iter_; |
138 | brgemm_pallete_t pallete_buff_iter_n_tail_; |
139 | brgemm_pallete_t pallete_buff_layer_; |
140 | brgemm_pallete_t pallete_buff_layer_n_tail_; |
141 | |
142 | brgemm_pallete_t pallete_buff_k1_tail_; |
143 | brgemm_pallete_t pallete_buff_k2_tail_; |
144 | brgemm_pallete_t pallete_buff_nk1_tail_; |
145 | brgemm_pallete_t pallete_buff_nk2_tail_; |
146 | brgemm_pallete_t pallete_buff_proj_; |
147 | brgemm_pallete_t pallete_buff_nproj_tail_; |
148 | brgemm_pallete_t pallete_buff_kproj_tail_; |
149 | brgemm_pallete_t pallete_buff_nkproj_tail_; |
150 | |
151 | brgemm_pallete_t pallete_buff_layermerged_; |
152 | brgemm_pallete_t pallete_buff_layermerged_n_tail_; |
153 | brgemm_pallete_t pallete_buff_layermerged_k1_tail_; |
154 | brgemm_pallete_t pallete_buff_layermerged_nk1_tail_; |
155 | |
156 | private: |
157 | status_t brgemm_rnn_init_tiles(brgemm_t *desc, brgemm_pallete_t pallete); |
158 | status_t brgemm_rnn_init_tiles_proj( |
159 | brgemm_t *desc, brgemm_pallete_t pallete); |
160 | status_t brgemm_rnn_init_tiles( |
161 | brgemm_t *desc, dim_t size, brgemm_pallete_t pallete); |
162 | }; |
163 | |
164 | struct rnn_diff_src_brgemm_t { |
165 | brgemm_t desc_iter_layer_beta0_; |
166 | brgemm_t desc_iter_layer_beta1_; |
167 | brgemm_t desc_layer_N_tail_beta0_; |
168 | brgemm_t desc_layer_N_tail_beta1_; |
169 | brgemm_t desc_iter_N_tail_beta0_; |
170 | brgemm_t desc_iter_N_tail_beta1_; |
171 | brgemm_t desc_iter_layer_K_tail_beta1_; |
172 | brgemm_t desc_layer_NK_tail_beta1_; |
173 | brgemm_t desc_iter_NK_tail_beta1_; |
174 | |
175 | brgemm_ker_ptr_t kernel_iter_layer_beta0_ = nullptr; |
176 | brgemm_ker_ptr_t kernel_iter_layer_beta1_ = nullptr; |
177 | brgemm_ker_ptr_t kernel_layer_N_tail_beta0_ = nullptr; |
178 | brgemm_ker_ptr_t kernel_layer_N_tail_beta1_ = nullptr; |
179 | brgemm_ker_ptr_t kernel_iter_N_tail_beta0_ = nullptr; |
180 | brgemm_ker_ptr_t kernel_iter_N_tail_beta1_ = nullptr; |
181 | brgemm_ker_ptr_t kernel_iter_layer_K_tail_beta1_ = nullptr; |
182 | brgemm_ker_ptr_t kernel_layer_NK_tail_beta1_ = nullptr; |
183 | brgemm_ker_ptr_t kernel_iter_NK_tail_beta1_ = nullptr; |
184 | |
185 | brgemm_pallete_t pallete_buff_iter_layer_ = {}; |
186 | brgemm_pallete_t pallete_buff_iter_layer_k_tail_ = {}; |
187 | brgemm_pallete_t pallete_buff_iter_n_tail_ = {}; |
188 | brgemm_pallete_t pallete_buff_layer_n_tail_ = {}; |
189 | brgemm_pallete_t pallete_buff_iter_nk_tail_ = {}; |
190 | brgemm_pallete_t pallete_buff_layer_nk_tail_ = {}; |
191 | }; |
192 | |
193 | struct rnn_diff_wei_brgemm_t { |
194 | brgemm_t desc_iter_beta1_; |
195 | brgemm_t desc_layer_beta1_; |
196 | brgemm_t desc_iter_N_tail_beta1_; |
197 | brgemm_t desc_layer_N_tail_beta1_; |
198 | brgemm_t desc_iter_NK_tail_beta1_; |
199 | brgemm_t desc_layer_NK_tail_beta1_; |
200 | brgemm_t desc_iter_K_tail_beta1_; |
201 | brgemm_t desc_layer_K_tail_beta1_; |
202 | |
203 | brgemm_ker_ptr_t kernel_iter_beta1_ = nullptr; |
204 | brgemm_ker_ptr_t kernel_layer_beta1_ = nullptr; |
205 | brgemm_ker_ptr_t kernel_iter_N_tail_beta1_ = nullptr; |
206 | brgemm_ker_ptr_t kernel_layer_N_tail_beta1_ = nullptr; |
207 | brgemm_ker_ptr_t kernel_iter_NK_tail_beta1_ = nullptr; |
208 | brgemm_ker_ptr_t kernel_layer_NK_tail_beta1_ = nullptr; |
209 | brgemm_ker_ptr_t kernel_iter_K_tail_beta1_ = nullptr; |
210 | brgemm_ker_ptr_t kernel_layer_K_tail_beta1_ = nullptr; |
211 | |
212 | brgemm_pallete_t pallete_buff_iter_ = {}; |
213 | brgemm_pallete_t pallete_buff_layer_ = {}; |
214 | brgemm_pallete_t pallete_buff_iter_n_tail_ = {}; |
215 | brgemm_pallete_t pallete_buff_layer_n_tail_ = {}; |
216 | brgemm_pallete_t pallete_buff_iter_nk_tail_ = {}; |
217 | brgemm_pallete_t pallete_buff_layer_nk_tail_ = {}; |
218 | brgemm_pallete_t pallete_buff_iter_k_tail_ = {}; |
219 | brgemm_pallete_t pallete_buff_layer_k_tail_ = {}; |
220 | |
221 | srcatch_gates_reorder_ker_ptr_t srcatch_gates_reorder_kernel_; |
222 | }; |
223 | |
224 | template <> |
225 | struct rnn_brgemm_t<prop_kind::backward> : public rnn_brgemm_base_t { |
226 | public: |
227 | static void init_scratchpad(const cpu::rnn_utils::rnn_conf_t &rnn, |
228 | memory_tracking::registrar_t &scratchpad, dim_t gemm_acc_type_size, |
229 | dim_t gemm_acc_align); |
230 | static status_t configure_brgemm(cpu::rnn_utils::rnn_conf_t &rnn, |
231 | alg_kind_t cell_kind, dim_t src_layer_type_size, |
232 | dim_t scratch_type_size); |
233 | |
234 | status_t init_kernels(const cpu::rnn_utils::rnn_conf_t &rnn, |
235 | data_type_t src_type, data_type_t weights_type); |
236 | |
237 | rnn_diff_src_brgemm_t diff_src_; |
238 | rnn_diff_wei_brgemm_t diff_wei_; |
239 | |
240 | std::unique_ptr<jit_gates_reduction_t> kernel_gates_reduction_; |
241 | std::unique_ptr<jit_gates_reduction_t> kernel_gates_reduction_tail_; |
242 | |
243 | std::unique_ptr<jit_brgemm_transpose_single_row_t> |
244 | kernel_transpose_single_row_iter_; |
245 | std::unique_ptr<jit_brgemm_transpose_single_row_t> |
246 | kernel_transpose_single_row_layer_; |
247 | |
248 | std::unique_ptr<jit_brgemm_trans_src_t> |
249 | kernel_transpose_iter_[num_base_kernels_]; |
250 | std::unique_ptr<jit_brgemm_trans_src_t> |
251 | kernel_transpose_layer_[num_base_kernels_]; |
252 | |
253 | std::unique_ptr<jit_diff_weights_peephole_t> kernel_peephole_; |
254 | std::unique_ptr<jit_diff_weights_peephole_t> kernel_peephole_tail_; |
255 | |
256 | private: |
257 | static void configure_brgemm_peephole(cpu::rnn_utils::rnn_conf_t &rnn); |
258 | |
259 | status_t init_peephole_kernels(const cpu::rnn_utils::rnn_conf_t &rnn); |
260 | }; |
261 | |
262 | } // namespace rnn_brgemm_utils |
263 | } // namespace x64 |
264 | } // namespace cpu |
265 | } // namespace impl |
266 | } // namespace dnnl |
267 | |
268 | #endif |
269 | |