1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_RNN_BRGEMM_CELL_COMMON_FWD_HPP
18#define CPU_X64_RNN_BRGEMM_CELL_COMMON_FWD_HPP
19
20#include <functional>
21#include "common/bfloat16.hpp"
22#include "cpu/rnn/rnn_utils.hpp"
23#include "cpu/x64/rnn/rnn_brgemm_utils.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
// brgemm_dst_layer_iter_t: class template parameterized on the source,
// weights, scratch and GEMM-accumulator element types. This section only
// declares the class; the constructor and all member functions are defined
// elsewhere.
// NOTE(review): judging by the names, this presumably drives the BRGEMM
// computations over the "layer" and "iter" inputs of a forward RNN cell and
// writes into the gates scratch buffer -- confirm against the implementation.
// The class has reference members and const data members, so its implicit
// copy/move assignment operators are deleted. The referenced objects
// (presumably the constructor arguments) must outlive the instance.
30template <typename src_t, typename weights_t, typename scratch_t,
31 typename gemm_acc_t>
32class brgemm_dst_layer_iter_t {
33public:
    // Alias for rnn_brgemm_t specialized for prop_kind::forward.
34 using ref_rnn_brgemm_t = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::forward>;
    // Type-erased callback with signature
    // void(dim_t, dim_t, dim_t, const src_t *, scratch_t *, int).
    // NOTE(review): presumably the post-GEMM step fused after each computed
    // block; the meaning of the individual arguments is not visible here.
35 using postgemm_fused_t = std::function<void(
36 dim_t, dim_t, dim_t, const src_t *, scratch_t *, int)>;
    // Constructor (defined elsewhere). Takes the BRGEMM and RNN configuration
    // objects by const reference, a cell position, two source pointers, two
    // weights pointers, the gates scratch pointer, an accumulator scratchpad,
    // a batch-element array and the post-GEMM callback.
    // NOTE(review): the first parameter is spelled `rnn_brgemm_`, the same as
    // the data member declared below; harmless in a declaration, but check
    // that the definition does not rely on the shadowed name.
37 brgemm_dst_layer_iter_t(const ref_rnn_brgemm_t &rnn_brgemm_,
38 const rnn_utils::rnn_conf_t &rnn,
39 rnn_utils::cell_position_t cell_position, const src_t *src_iter,
40 const src_t *src_layer, weights_t *w_iter, weights_t *w_layer,
41 scratch_t *scratch_gates, gemm_acc_t *amx_scratchpad,
42 x64::brgemm_batch_element_t *addr_batch_global,
43 const postgemm_fused_t &fused_postgemm);
    // Public entry point. It is const, so it cannot modify the data members
    // themselves; the buffers reached through the non-const-pointee members
    // (C_, amx_scratchpad_, addr_batch_global_) remain writable.
    // NOTE(review): presumably runs one of the kernel functions below on each
    // thread -- confirm in the definition.
44 void execute() const;
45
46private:
    // Worker routines taking two ints.
    // NOTE(review): presumably ithr is the calling thread's index and nthr the
    // number of threads; which routine is used is presumably selected by
    // is_fused_layer_iter_brgemm_ -- confirm.
47 void kernel(const int ithr, const int nthr) const;
48 void kernel_fused_iter_layer(const int ithr, const int nthr) const;
49
    // Configuration objects held by reference (not owned, not copied).
50 const ref_rnn_brgemm_t &rnn_brgemm_;
51 const rnn_utils::rnn_conf_t &rnn_;
    // NOTE(review): presumably whether a separate layer GEMM is required and
    // which BRGEMM descriptors to use for the layer / iter parts -- confirm.
52 const bool need_gemm_layer_;
53 const dim_t layer_desc_idx_;
54 const dim_t iter_desc_idx_;
    // Const pointers to read-only source (src_t) and weights (weights_t) data
    // and to a writable scratch_t buffer.
    // NOTE(review): presumably A = source, B = weights, C = output, with
    // suffix l = layer and i = iter -- confirm.
55 const src_t *const Al_;
56 const src_t *const Ai_;
57 const weights_t *const Bl_;
58 const weights_t *const Bi_;
59 scratch_t *const C_;
    // NOTE(review): presumably the leading dimensions of Al_ / Ai_ -- confirm.
60 const dim_t LDAl_;
61 const dim_t LDAi_;
    // NOTE(review): presumably threading / blocking parameters and the total
    // number of work items split across threads -- confirm. Note that
    // work_amount_ is int while the neighbouring values are dim_t.
62 const dim_t max_nthr_;
63 const dim_t n_blocking_;
64 const dim_t m_blocking_;
65 const int work_amount_;
    // Precomputed offsets into the A/B buffers (units are not visible here).
    // NOTE(review): suffixes presumably mean n = N block, g = gate,
    // kb = K block, k_tail = K remainder -- confirm.
66 const dim_t Bl_n_offset_;
67 const dim_t Bi_n_offset_;
68 const dim_t Bl_g_offset_;
69 const dim_t Bi_g_offset_;
70 const dim_t Al_k_tail_offset_;
71 const dim_t Ai_k_tail_offset_;
72 const dim_t Bl_kb_offset_;
73 const dim_t Bi_kb_offset_;
74 const dim_t Bl_k_tail_offset_;
75 const dim_t Bi_k_tail_offset_;
76 const dim_t n_gates_;
    // Const raw pointers to brgemm_kernel_t objects for the iter part, one per
    // variant: main, n_tail, k_tail, nk_tail. Ownership is not expressed by
    // these declarations.
77 const brgemm_kernel_t *const brgemm_kernel_iter_main_;
78 const brgemm_kernel_t *const brgemm_kernel_iter_n_tail_;
79 const brgemm_kernel_t *const brgemm_kernel_iter_k_tail_;
80 const brgemm_kernel_t *const brgemm_kernel_iter_nk_tail_;
81
    // Same four variants for the layer part.
82 const brgemm_kernel_t *const brgemm_kernel_layer_main_;
83 const brgemm_kernel_t *const brgemm_kernel_layer_n_tail_;
84 const brgemm_kernel_t *const brgemm_kernel_layer_k_tail_;
85 const brgemm_kernel_t *const brgemm_kernel_layer_nk_tail_;
86
    // Pointers to const char buffers, one per iter kernel variant. Unlike the
    // kernel pointers above, the pointers themselves are not const.
    // NOTE(review): "pallete" is presumably a misspelling of "palette" (AMX
    // tile configuration); the names are used by definitions outside this
    // header, so do not rename here.
87 const char *pallete_buff_iter_main_;
88 const char *pallete_buff_iter_n_tail_;
89 const char *pallete_buff_iter_k_tail_;
90 const char *pallete_buff_iter_nk_tail_;
91
    // Same, for the layer kernel variants.
92 const char *pallete_buff_layer_main_;
93 const char *pallete_buff_layer_n_tail_;
94 const char *pallete_buff_layer_k_tail_;
95 const char *pallete_buff_layer_nk_tail_;
96
    // Const pointers to writable accumulator scratch and batch-element
    // storage.
97 gemm_acc_t *const amx_scratchpad_;
98 brgemm_batch_element_t *const addr_batch_global_;
    // The callback is stored by value (a copy), not by reference.
99 const postgemm_fused_t fused_postgemm_;
100 const bool is_fused_layer_iter_brgemm_;
101};
102
// brgemm_dst_proj_t: class template parameterized on the source, weights and
// GEMM-accumulator element types. This section only declares the class; the
// constructor and member functions are defined elsewhere.
// NOTE(review): judging by the names, this presumably computes the
// projection GEMM of a forward RNN cell -- confirm against the
// implementation.
// The class has reference members and const data members, so its implicit
// copy/move assignment operators are deleted. The referenced objects
// (presumably the constructor arguments) must outlive the instance.
103template <typename src_t, typename weights_t, typename gemm_acc_t>
104class brgemm_dst_proj_t {
105public:
    // Alias for rnn_brgemm_t specialized for prop_kind::forward.
106 using ref_rnn_brgemm_t = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::forward>;
    // Type-erased callback with signature
    // void(dim_t, dim_t, gemm_acc_t *, int).
    // NOTE(review): presumably the post-GEMM step fused after each computed
    // block; the meaning of the individual arguments is not visible here.
107 using postgemm_fused_t
108 = std::function<void(dim_t, dim_t, gemm_acc_t *, int)>;
    // Constructor (defined elsewhere). Takes the BRGEMM and RNN configuration
    // objects by const reference, a cell position, a read-only source
    // pointer, a read-only weights pointer, a writable output pointer, an
    // accumulator scratchpad, a batch-element array and the post-GEMM
    // callback.
    // NOTE(review): the first parameter is spelled `rnn_brgemm_`, the same as
    // the data member declared below.
109 brgemm_dst_proj_t(const ref_rnn_brgemm_t &rnn_brgemm_,
110 const rnn_utils::rnn_conf_t &rnn,
111 rnn_utils::cell_position_t cell_position, const src_t *proj_ht,
112 const weights_t *w_projection, gemm_acc_t *output,
113 gemm_acc_t *amx_scratchpad,
114 x64::brgemm_batch_element_t *addr_batch_global,
115 const postgemm_fused_t &fused_postgemm);
116
    // Public entry point. It is const, so it cannot modify the data members
    // themselves; buffers reached through C_, amx_scratchpad_ and
    // addr_batch_global_ remain writable.
117 void execute() const;
118
119private:
    // Worker routine taking two ints.
    // NOTE(review): presumably ithr is the calling thread's index and nthr
    // the number of threads -- confirm.
120 void kernel(const int ithr, const int nthr) const;
121
    // This second `private:` repeats the access level already in effect; it
    // is redundant but harmless.
122private:
    // Configuration objects held by reference (not owned, not copied).
123 const ref_rnn_brgemm_t &rnn_brgemm_;
124 const rnn_utils::rnn_conf_t &rnn_;
    // NOTE(review): presumably selects the BRGEMM descriptor used for the
    // projection -- confirm.
125 const int proj_desc_idx_;
    // Const pointers to read-only source and weights data and to a writable
    // gemm_acc_t buffer.
    // NOTE(review): presumably A = source, B = weights, C = output.
126 const src_t *const A_;
127 const weights_t *const B_;
128 gemm_acc_t *const C_;
    // NOTE(review): presumably the leading dimension of C_, the thread cap
    // and the number of work items -- confirm.
129 const dim_t LDC_;
130 const dim_t max_nthr_;
131 const int work_amount_proj_;
    // Precomputed offsets into the weights buffer (units not visible here).
132 const dim_t B_n_offset_;
133 const dim_t Bp_kb_offset_;
    // Const pointers to writable accumulator scratch and batch-element
    // storage.
134 gemm_acc_t *const amx_scratchpad_;
135 brgemm_batch_element_t *const addr_batch_global_;
136
    // Const raw pointers to brgemm_kernel_t objects, one per variant: main,
    // n_tail, nk_tail, k_tail. Ownership is not expressed by these
    // declarations.
137 const brgemm_kernel_t *const brgemm_kernel_main_;
138 const brgemm_kernel_t *const brgemm_kernel_n_tail_;
139 const brgemm_kernel_t *const brgemm_kernel_nk_tail_;
140 const brgemm_kernel_t *const brgemm_kernel_k_tail_;
    // The callback is stored by value (a copy), not by reference.
141 const postgemm_fused_t fused_postgemm_;
142};
143
// brgemm_gru_t: class template parameterized on the source, weights, scratch
// and GEMM-accumulator element types. This section only declares the class;
// the constructor and member functions are defined elsewhere.
// NOTE(review): judging by the names, this presumably drives the BRGEMM
// computations of a forward GRU cell in two parts, each followed by its own
// post-GEMM callback -- confirm against the implementation.
// The class has reference members and const data members, so its implicit
// copy/move assignment operators are deleted. The referenced objects
// (presumably the constructor arguments) must outlive the instance.
144template <typename src_t, typename weights_t, typename scratch_t,
145 typename gemm_acc_t>
146class brgemm_gru_t {
147public:
    // Alias for rnn_brgemm_t specialized for prop_kind::forward.
148 using ref_rnn_brgemm_t = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::forward>;
    // Type-erased callback with signature
    // void(dim_t, dim_t, dim_t, const src_t *, scratch_t *, scratch_t *, int).
    // The same type is used for both callbacks passed to the constructor.
    // NOTE(review): the meaning of the individual arguments is not visible
    // here.
149 using postgemm_fused_t = std::function<void(
150 dim_t, dim_t, dim_t, const src_t *, scratch_t *, scratch_t *, int)>;
    // Constructor (defined elsewhere). Takes the BRGEMM and RNN configuration
    // objects by const reference, a cell position, two read-only source
    // pointers, three weights pointers (w_iter, w_iter1, w_layer), a writable
    // src_t pointer (d_layer), two scratch pointers, an accumulator
    // scratchpad, a batch-element array and two post-GEMM callbacks.
    // NOTE(review): the first parameter is spelled `rnn_brgemm_`, the same as
    // the data member declared below.
151 brgemm_gru_t(const ref_rnn_brgemm_t &rnn_brgemm_,
152 const rnn_utils::rnn_conf_t &rnn,
153 rnn_utils::cell_position_t cell_position, const src_t *src_iter,
154 const src_t *src_layer, weights_t *w_iter, weights_t *w_iter1,
155 weights_t *w_layer, src_t *d_layer, scratch_t *scratch_gates,
156 scratch_t *scratch_cell, gemm_acc_t *amx_scratchpad,
157 x64::brgemm_batch_element_t *addr_batch_global,
158 const postgemm_fused_t &fused_postgemm_part1,
159 const postgemm_fused_t &fused_postgemm_part2);
    // Public entry point. It is const, so it cannot modify the data members
    // themselves; buffers reached through C_gates_, C_cell_, Dl_,
    // amx_scratchpad_ and addr_batch_global_ remain writable.
160 void execute() const;
161
162private:
    // Worker routine taking two ints.
    // NOTE(review): presumably ithr is the calling thread's index and nthr
    // the number of threads -- confirm.
163 void kernel(const int ithr, const int nthr) const;
164
    // Configuration objects held by reference (not owned, not copied).
165 const ref_rnn_brgemm_t &rnn_brgemm_;
166 const rnn_utils::rnn_conf_t &rnn_;
    // NOTE(review): presumably whether a separate layer GEMM is required and
    // which BRGEMM descriptors to use for the layer part and the two iter
    // parts -- confirm.
167 const bool need_gemm_layer_;
168 const dim_t layer_desc_idx_;
169 const dim_t iter_desc_idx_;
170 const dim_t iter_part2_desc_idx_;
    // Const pointers to read-only source and weights data.
    // NOTE(review): presumably A = source, B = weights, with suffix
    // l = layer, i = iter and Bi2_ the iter weights of the second part.
171 const src_t *const Al_;
172 const src_t *const Ai_;
173 const weights_t *const Bl_;
174 const weights_t *const Bi_;
175 const weights_t *const Bi2_;
    // Const pointers to writable scratch_t and src_t buffers.
    // NOTE(review): presumably the gates and cell scratch outputs and the
    // destination-layer buffer -- confirm.
176 scratch_t *const C_gates_;
177 scratch_t *const C_cell_;
178 src_t *const Dl_;
    // NOTE(review): presumably the leading dimensions of Al_ / Ai_ -- confirm.
179 const dim_t LDAl_;
180 const dim_t LDAi_;
    // NOTE(review): presumably threading / blocking parameters and the total
    // number of work items -- confirm. work_amount_ is int while the
    // neighbouring values are dim_t.
181 const dim_t max_nthr_;
182 const dim_t n_blocking_;
183 const dim_t m_blocking_;
184 const int work_amount_;
    // Precomputed offsets into the A/B buffers (units are not visible here).
    // NOTE(review): suffixes presumably mean n = N block, g = gate,
    // kb = K block, k_tail = K remainder -- confirm.
185 const dim_t Bl_n_offset_;
186 const dim_t Bi_n_offset_;
187 const dim_t Bl_g_offset_;
188 const dim_t Bi_g_offset_;
189 const dim_t Al_k_tail_offset_;
190 const dim_t Ai_k_tail_offset_;
191 const dim_t Bl_kb_offset_;
192 const dim_t Bi_kb_offset_;
193 const dim_t Bl_k_tail_offset_;
194 const dim_t Bi_k_tail_offset_;
195 const dim_t n_gates_;
    // Const raw pointers to brgemm_kernel_t objects for the iter part, in two
    // groups (p0, p1) of four variants each: main, n_tail, k_tail, nk_tail.
    // Ownership is not expressed by these declarations.
    // NOTE(review): p0 / p1 presumably correspond to the "part1" / "part2"
    // naming used by the callbacks and iter_part2_desc_idx_ -- confirm.
196 const brgemm_kernel_t *const brgemm_kernel_iter_p0_main_;
197 const brgemm_kernel_t *const brgemm_kernel_iter_p0_n_tail_;
198 const brgemm_kernel_t *const brgemm_kernel_iter_p0_k_tail_;
199 const brgemm_kernel_t *const brgemm_kernel_iter_p0_nk_tail_;
200 const brgemm_kernel_t *const brgemm_kernel_iter_p1_main_;
201 const brgemm_kernel_t *const brgemm_kernel_iter_p1_n_tail_;
202 const brgemm_kernel_t *const brgemm_kernel_iter_p1_k_tail_;
203 const brgemm_kernel_t *const brgemm_kernel_iter_p1_nk_tail_;
204
    // Four kernel variants for the layer part.
205 const brgemm_kernel_t *const brgemm_kernel_layer_main_;
206 const brgemm_kernel_t *const brgemm_kernel_layer_n_tail_;
207 const brgemm_kernel_t *const brgemm_kernel_layer_k_tail_;
208 const brgemm_kernel_t *const brgemm_kernel_layer_nk_tail_;
209
    // Pointers to const char buffers for the iter kernel variants. Unlike the
    // kernel pointers above, the pointers themselves are not const, and there
    // is a single iter set here rather than separate p0 / p1 sets.
    // NOTE(review): "pallete" is presumably a misspelling of "palette" (AMX
    // tile configuration); the names are used by definitions outside this
    // header, so do not rename here.
210 const char *pallete_buff_iter_main_;
211 const char *pallete_buff_iter_n_tail_;
212 const char *pallete_buff_iter_k_tail_;
213 const char *pallete_buff_iter_nk_tail_;
214
    // Same, for the layer kernel variants.
215 const char *pallete_buff_layer_main_;
216 const char *pallete_buff_layer_n_tail_;
217 const char *pallete_buff_layer_k_tail_;
218 const char *pallete_buff_layer_nk_tail_;
219
    // Const pointers to writable accumulator scratch and batch-element
    // storage.
220 gemm_acc_t *const amx_scratchpad_;
221 brgemm_batch_element_t *const addr_batch_global_;
    // Both callbacks are stored by value (copies), not by reference.
222 const postgemm_fused_t fused_postgemm_part1_;
223 const postgemm_fused_t fused_postgemm_part2_;
224 const bool is_fused_layer_iter_brgemm_;
225};
226
// brgemm_merged_layer_t: class template parameterized on the source, weights,
// scratch and GEMM-accumulator element types. This section only declares the
// class; the constructor and member functions are defined elsewhere.
// NOTE(review): judging by the names, this presumably computes only the
// "layer" GEMM of a forward RNN cell (merged across iterations) into the
// gates scratch buffer -- confirm against the implementation.
// No post-GEMM callback is declared for this class.
// The class has reference members and const data members, so its implicit
// copy/move assignment operators are deleted. The referenced objects
// (presumably the constructor arguments) must outlive the instance.
227template <typename src_t, typename weights_t, typename scratch_t,
228 typename gemm_acc_t>
229class brgemm_merged_layer_t {
230public:
    // Alias for rnn_brgemm_t specialized for prop_kind::forward.
231 using ref_rnn_brgemm_t = rnn_brgemm_utils::rnn_brgemm_t<prop_kind::forward>;
    // Constructor (defined elsewhere). Takes the BRGEMM and RNN configuration
    // objects by const reference, a cell position, a read-only source
    // pointer, a weights pointer, the gates scratch pointer, an accumulator
    // scratchpad and a batch-element array.
    // NOTE(review): the first parameter is spelled `rnn_brgemm_`, the same as
    // the data member declared below.
232 brgemm_merged_layer_t(const ref_rnn_brgemm_t &rnn_brgemm_,
233 const rnn_utils::rnn_conf_t &rnn,
234 rnn_utils::cell_position_t cell_position, const src_t *src_layer,
235 weights_t *w_layer, scratch_t *scratch_gates,
236 gemm_acc_t *amx_scratchpad,
237 x64::brgemm_batch_element_t *addr_batch_global);
    // Public entry point. It is const, so it cannot modify the data members
    // themselves; buffers reached through C_, amx_scratchpad_ and
    // addr_batch_global_ remain writable.
238 void execute() const;
239
240private:
    // Worker routine taking two ints.
    // NOTE(review): presumably ithr is the calling thread's index and nthr
    // the number of threads -- confirm.
241 void kernel(const int ithr, const int nthr) const;
242
    // Configuration objects held by reference (not owned, not copied).
243 const ref_rnn_brgemm_t &rnn_brgemm_;
244 const rnn_utils::rnn_conf_t &rnn_;
    // NOTE(review): presumably selects the BRGEMM descriptor used for the
    // layer part -- confirm.
245 const dim_t layer_desc_idx_;
    // Const pointers to read-only source and weights data and to a writable
    // scratch_t buffer.
    // NOTE(review): presumably A = source, B = weights, C = output.
246 const src_t *const Al_;
247 const weights_t *const Bl_;
248 scratch_t *const C_;
    // NOTE(review): presumably the leading dimension of Al_, followed by
    // threading / blocking parameters and the number of work items --
    // confirm. work_amount_ is int while the neighbouring values are dim_t.
249 const dim_t LDAl_;
250 const dim_t max_nthr_;
251 const dim_t n_blocking_;
252 const dim_t m_blocking_;
253 const int work_amount_;
    // Precomputed offsets into the A/B buffers (units are not visible here).
    // NOTE(review): suffixes presumably mean n = N block, g = gate,
    // kb = K block, k_tail = K remainder -- confirm.
254 const dim_t Bl_n_offset_;
255 const dim_t Bl_g_offset_;
256 const dim_t Al_k_tail_offset_;
257 const dim_t Bl_kb_offset_;
258 const dim_t Bl_k_tail_offset_;
259 const dim_t n_gates_;
260
    // Const raw pointers to brgemm_kernel_t objects, one per variant: main,
    // n_tail, k_tail, nk_tail. Ownership is not expressed by these
    // declarations.
261 const brgemm_kernel_t *const brgemm_kernel_layer_main_;
262 const brgemm_kernel_t *const brgemm_kernel_layer_n_tail_;
263 const brgemm_kernel_t *const brgemm_kernel_layer_k_tail_;
264 const brgemm_kernel_t *const brgemm_kernel_layer_nk_tail_;
265
    // Pointers to const char buffers, one per kernel variant. Unlike the
    // kernel pointers above, the pointers themselves are not const.
    // NOTE(review): "pallete" is presumably a misspelling of "palette" (AMX
    // tile configuration); the names are used by definitions outside this
    // header, so do not rename here.
266 const char *pallete_buff_layer_main_;
267 const char *pallete_buff_layer_n_tail_;
268 const char *pallete_buff_layer_k_tail_;
269 const char *pallete_buff_layer_nk_tail_;
270
    // Const pointers to writable accumulator scratch and batch-element
    // storage.
271 gemm_acc_t *const amx_scratchpad_;
272 brgemm_batch_element_t *const addr_batch_global_;
273};
274
275} // namespace x64
276} // namespace cpu
277} // namespace impl
278} // namespace dnnl
279
280#endif
281