1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_RNN_BRGEMM_UTILS_RNN_HPP
18#define CPU_X64_RNN_BRGEMM_UTILS_RNN_HPP
19
20#include <memory>
21#include "common/c_types_map.hpp"
22#include "common/memory_tracking.hpp"
23#include "cpu/x64/brgemm/brgemm.hpp"
24#include "cpu/x64/jit_brgemm_transpose_utils.hpp"
25#include "cpu/x64/matmul/brgemm_matmul_copy_utils.hpp"
26#include "cpu/x64/rnn/jit_brgemm_transpose_single_row.hpp"
27#include "cpu/x64/rnn/jit_diff_weights_peephole.hpp"
28#include "cpu/x64/rnn/jit_gates_reduction.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33
// Forward declaration only: this header refers to rnn_conf_t solely by
// reference in function signatures, so the full definition is not needed here.
namespace rnn_utils { struct rnn_conf_t; } // namespace rnn_utils
37
38namespace x64 {
39
// Forward declaration of the source-transpose kernel type that the backward
// rnn_brgemm_t specialization below owns through std::unique_ptr.
struct jit_brgemm_trans_src_t;
41
42namespace rnn_brgemm_utils {
43
44using brgemm_ker_ptr_t = std::unique_ptr<brgemm_kernel_t>;
45using brgemm_pallete_t = char[64];
46using srcatch_gates_reorder_ker_ptr_t
47 = std::unique_ptr<matmul::jit_brgemm_matmul_copy_b_t>;
48
49struct rnn_brgemm_base_t {
50 static void init_scratchpad(const cpu::rnn_utils::rnn_conf_t &rnn,
51 memory_tracking::registrar_t &scratchpad, dim_t gemm_acc_type_size,
52 dim_t gemm_acc_align);
53 static constexpr dim_t num_base_kernels_ = 3;
54 static constexpr dim_t num_proj_kernels_ = 4;
55 static constexpr dim_t num_vanilla_gru_iter_part2_kernels_ = 4;
56};
57
58template <prop_kind_t aprop>
59struct rnn_brgemm_t;
60
61template <>
62struct rnn_brgemm_t<prop_kind::forward> : public rnn_brgemm_base_t {
63 using rnn_brgemm_base_t::init_scratchpad;
64
65 static status_t configure_brgemm(cpu::rnn_utils::rnn_conf_t &rnn,
66 alg_kind_t cell_kind, dim_t src_layer_type_size,
67 dim_t scratch_type_size);
68 status_t init_kernels(const cpu::rnn_utils::rnn_conf_t &rnn,
69 data_type_t src_type, data_type_t weights_type);
70
71 brgemm_t desc_layer_b0_[num_base_kernels_];
72 brgemm_t desc_iter_b0_[num_base_kernels_];
73 brgemm_t desc_iter_b1_[num_base_kernels_];
74 brgemm_t desc_layer_N_tail_b0_[num_base_kernels_];
75 brgemm_t desc_iter_N_tail_b0_[num_base_kernels_];
76 brgemm_t desc_iter_N_tail_b1_[num_base_kernels_];
77
78 brgemm_t desc_layer_K1_tail_b1_[num_base_kernels_];
79 brgemm_t desc_layer_NK1_tail_b1_[num_base_kernels_];
80 brgemm_t desc_iter_K2_tail_b1_[num_base_kernels_];
81 brgemm_t desc_iter_NK2_tail_b1_[num_base_kernels_];
82
83 brgemm_t desc_layermerged_b0_[num_base_kernels_];
84 brgemm_t desc_layermerged_N_tail_b0_[num_base_kernels_];
85 brgemm_t desc_layermerged_K1_tail_b1_[num_base_kernels_];
86 brgemm_t desc_layermerged_NK1_tail_b1_[num_base_kernels_];
87
88 brgemm_t desc_proj_b0_[num_proj_kernels_];
89 brgemm_t desc_proj_N_tail_b0_[num_proj_kernels_];
90 brgemm_t desc_proj_N_tail_b1_[num_proj_kernels_];
91 brgemm_t desc_proj_K_tail_b1_[num_proj_kernels_];
92 brgemm_t desc_proj_NK_tail_b1_[num_proj_kernels_];
93
94 // Set of brgemm descriptor for 2nd part of iteration gemm in vanulla GRU
95 // cell
96 brgemm_t desc_iter_p2_b1_[num_vanilla_gru_iter_part2_kernels_];
97 brgemm_t desc_iter_p2_N_tail_b1_[num_vanilla_gru_iter_part2_kernels_];
98 brgemm_t desc_iter_p2_K2_tail_b1_[num_vanilla_gru_iter_part2_kernels_];
99 brgemm_t desc_iter_p2_NK2_tail_b1_[num_vanilla_gru_iter_part2_kernels_];
100
101 brgemm_ker_ptr_t kernel_layer_b0_[num_base_kernels_];
102 brgemm_ker_ptr_t kernel_layer_b1_[num_base_kernels_];
103 brgemm_ker_ptr_t kernel_iter_b0_[num_base_kernels_];
104 brgemm_ker_ptr_t kernel_iter_b1_[num_base_kernels_];
105 brgemm_ker_ptr_t kernel_layer_N_tail_b0_[num_base_kernels_];
106 brgemm_ker_ptr_t kernel_layer_N_tail_b1_[num_base_kernels_];
107 brgemm_ker_ptr_t kernel_iter_N_tail_b0_[num_base_kernels_];
108 brgemm_ker_ptr_t kernel_iter_N_tail_b1_[num_base_kernels_];
109
110 brgemm_ker_ptr_t kernel_layer_K1_tail_b1_[num_base_kernels_];
111 brgemm_ker_ptr_t kernel_layer_NK1_tail_b1_[num_base_kernels_];
112 brgemm_ker_ptr_t kernel_iter_K2_tail_b1_[num_base_kernels_];
113 brgemm_ker_ptr_t kernel_iter_NK2_tail_b1_[num_base_kernels_];
114
115 brgemm_ker_ptr_t kernel_layermerged_b0_[num_base_kernels_];
116 brgemm_ker_ptr_t kernel_layermerged_b1_[num_base_kernels_];
117 brgemm_ker_ptr_t kernel_layermerged_N_tail_b0_[num_base_kernels_];
118 brgemm_ker_ptr_t kernel_layermerged_N_tail_b1_[num_base_kernels_];
119 brgemm_ker_ptr_t kernel_layermerged_K1_tail_b1_[num_base_kernels_];
120 brgemm_ker_ptr_t kernel_layermerged_NK1_tail_b1_[num_base_kernels_];
121
122 brgemm_ker_ptr_t kernel_proj_b0_[num_proj_kernels_];
123 brgemm_ker_ptr_t kernel_proj_N_tail_b0_[num_proj_kernels_];
124 brgemm_ker_ptr_t kernel_proj_N_tail_b1_[num_proj_kernels_];
125 brgemm_ker_ptr_t kernel_proj_K_tail_b1_[num_proj_kernels_];
126 brgemm_ker_ptr_t kernel_proj_NK_tail_b1_[num_proj_kernels_];
127
128 // Set of brgemm kernels for 2nd part of iteration gemm in vanulla GRU cell
129 brgemm_ker_ptr_t kernel_iter_p2_b1_[num_vanilla_gru_iter_part2_kernels_];
130 brgemm_ker_ptr_t
131 kernel_iter_p2_N_tail_b1_[num_vanilla_gru_iter_part2_kernels_];
132 brgemm_ker_ptr_t
133 kernel_iter_p2_K2_tail_b1_[num_vanilla_gru_iter_part2_kernels_];
134 brgemm_ker_ptr_t
135 kernel_iter_p2_NK2_tail_b1_[num_vanilla_gru_iter_part2_kernels_];
136
137 brgemm_pallete_t pallete_buff_iter_;
138 brgemm_pallete_t pallete_buff_iter_n_tail_;
139 brgemm_pallete_t pallete_buff_layer_;
140 brgemm_pallete_t pallete_buff_layer_n_tail_;
141
142 brgemm_pallete_t pallete_buff_k1_tail_;
143 brgemm_pallete_t pallete_buff_k2_tail_;
144 brgemm_pallete_t pallete_buff_nk1_tail_;
145 brgemm_pallete_t pallete_buff_nk2_tail_;
146 brgemm_pallete_t pallete_buff_proj_;
147 brgemm_pallete_t pallete_buff_nproj_tail_;
148 brgemm_pallete_t pallete_buff_kproj_tail_;
149 brgemm_pallete_t pallete_buff_nkproj_tail_;
150
151 brgemm_pallete_t pallete_buff_layermerged_;
152 brgemm_pallete_t pallete_buff_layermerged_n_tail_;
153 brgemm_pallete_t pallete_buff_layermerged_k1_tail_;
154 brgemm_pallete_t pallete_buff_layermerged_nk1_tail_;
155
156private:
157 status_t brgemm_rnn_init_tiles(brgemm_t *desc, brgemm_pallete_t pallete);
158 status_t brgemm_rnn_init_tiles_proj(
159 brgemm_t *desc, brgemm_pallete_t pallete);
160 status_t brgemm_rnn_init_tiles(
161 brgemm_t *desc, dim_t size, brgemm_pallete_t pallete);
162};
163
164struct rnn_diff_src_brgemm_t {
165 brgemm_t desc_iter_layer_beta0_;
166 brgemm_t desc_iter_layer_beta1_;
167 brgemm_t desc_layer_N_tail_beta0_;
168 brgemm_t desc_layer_N_tail_beta1_;
169 brgemm_t desc_iter_N_tail_beta0_;
170 brgemm_t desc_iter_N_tail_beta1_;
171 brgemm_t desc_iter_layer_K_tail_beta1_;
172 brgemm_t desc_layer_NK_tail_beta1_;
173 brgemm_t desc_iter_NK_tail_beta1_;
174
175 brgemm_ker_ptr_t kernel_iter_layer_beta0_ = nullptr;
176 brgemm_ker_ptr_t kernel_iter_layer_beta1_ = nullptr;
177 brgemm_ker_ptr_t kernel_layer_N_tail_beta0_ = nullptr;
178 brgemm_ker_ptr_t kernel_layer_N_tail_beta1_ = nullptr;
179 brgemm_ker_ptr_t kernel_iter_N_tail_beta0_ = nullptr;
180 brgemm_ker_ptr_t kernel_iter_N_tail_beta1_ = nullptr;
181 brgemm_ker_ptr_t kernel_iter_layer_K_tail_beta1_ = nullptr;
182 brgemm_ker_ptr_t kernel_layer_NK_tail_beta1_ = nullptr;
183 brgemm_ker_ptr_t kernel_iter_NK_tail_beta1_ = nullptr;
184
185 brgemm_pallete_t pallete_buff_iter_layer_ = {};
186 brgemm_pallete_t pallete_buff_iter_layer_k_tail_ = {};
187 brgemm_pallete_t pallete_buff_iter_n_tail_ = {};
188 brgemm_pallete_t pallete_buff_layer_n_tail_ = {};
189 brgemm_pallete_t pallete_buff_iter_nk_tail_ = {};
190 brgemm_pallete_t pallete_buff_layer_nk_tail_ = {};
191};
192
193struct rnn_diff_wei_brgemm_t {
194 brgemm_t desc_iter_beta1_;
195 brgemm_t desc_layer_beta1_;
196 brgemm_t desc_iter_N_tail_beta1_;
197 brgemm_t desc_layer_N_tail_beta1_;
198 brgemm_t desc_iter_NK_tail_beta1_;
199 brgemm_t desc_layer_NK_tail_beta1_;
200 brgemm_t desc_iter_K_tail_beta1_;
201 brgemm_t desc_layer_K_tail_beta1_;
202
203 brgemm_ker_ptr_t kernel_iter_beta1_ = nullptr;
204 brgemm_ker_ptr_t kernel_layer_beta1_ = nullptr;
205 brgemm_ker_ptr_t kernel_iter_N_tail_beta1_ = nullptr;
206 brgemm_ker_ptr_t kernel_layer_N_tail_beta1_ = nullptr;
207 brgemm_ker_ptr_t kernel_iter_NK_tail_beta1_ = nullptr;
208 brgemm_ker_ptr_t kernel_layer_NK_tail_beta1_ = nullptr;
209 brgemm_ker_ptr_t kernel_iter_K_tail_beta1_ = nullptr;
210 brgemm_ker_ptr_t kernel_layer_K_tail_beta1_ = nullptr;
211
212 brgemm_pallete_t pallete_buff_iter_ = {};
213 brgemm_pallete_t pallete_buff_layer_ = {};
214 brgemm_pallete_t pallete_buff_iter_n_tail_ = {};
215 brgemm_pallete_t pallete_buff_layer_n_tail_ = {};
216 brgemm_pallete_t pallete_buff_iter_nk_tail_ = {};
217 brgemm_pallete_t pallete_buff_layer_nk_tail_ = {};
218 brgemm_pallete_t pallete_buff_iter_k_tail_ = {};
219 brgemm_pallete_t pallete_buff_layer_k_tail_ = {};
220
221 srcatch_gates_reorder_ker_ptr_t srcatch_gates_reorder_kernel_;
222};
223
224template <>
225struct rnn_brgemm_t<prop_kind::backward> : public rnn_brgemm_base_t {
226public:
227 static void init_scratchpad(const cpu::rnn_utils::rnn_conf_t &rnn,
228 memory_tracking::registrar_t &scratchpad, dim_t gemm_acc_type_size,
229 dim_t gemm_acc_align);
230 static status_t configure_brgemm(cpu::rnn_utils::rnn_conf_t &rnn,
231 alg_kind_t cell_kind, dim_t src_layer_type_size,
232 dim_t scratch_type_size);
233
234 status_t init_kernels(const cpu::rnn_utils::rnn_conf_t &rnn,
235 data_type_t src_type, data_type_t weights_type);
236
237 rnn_diff_src_brgemm_t diff_src_;
238 rnn_diff_wei_brgemm_t diff_wei_;
239
240 std::unique_ptr<jit_gates_reduction_t> kernel_gates_reduction_;
241 std::unique_ptr<jit_gates_reduction_t> kernel_gates_reduction_tail_;
242
243 std::unique_ptr<jit_brgemm_transpose_single_row_t>
244 kernel_transpose_single_row_iter_;
245 std::unique_ptr<jit_brgemm_transpose_single_row_t>
246 kernel_transpose_single_row_layer_;
247
248 std::unique_ptr<jit_brgemm_trans_src_t>
249 kernel_transpose_iter_[num_base_kernels_];
250 std::unique_ptr<jit_brgemm_trans_src_t>
251 kernel_transpose_layer_[num_base_kernels_];
252
253 std::unique_ptr<jit_diff_weights_peephole_t> kernel_peephole_;
254 std::unique_ptr<jit_diff_weights_peephole_t> kernel_peephole_tail_;
255
256private:
257 static void configure_brgemm_peephole(cpu::rnn_utils::rnn_conf_t &rnn);
258
259 status_t init_peephole_kernels(const cpu::rnn_utils::rnn_conf_t &rnn);
260};
261
262} // namespace rnn_brgemm_utils
263} // namespace x64
264} // namespace cpu
265} // namespace impl
266} // namespace dnnl
267
268#endif
269