/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Common code for RNN and LSTM cell execution
 */

// assert() is used by the "unimplemented" stubs below regardless of DNNL_X64,
// so <cassert> is included unconditionally.
#include <cassert>

#include "common/bfloat16.hpp"
#include "cpu/rnn/ref_rnn.hpp"

#if DNNL_X64
#include <functional>
#include "cpu/x64/rnn/brgemm_cell_common_bwd.hpp"
#include "cpu/x64/rnn/brgemm_cell_common_fwd.hpp"
#include "cpu/x64/rnn/brgemm_cell_common_reorders.hpp"
#include "cpu/x64/rnn/brgemm_cell_common_utils.hpp"
#endif

namespace dnnl {
namespace impl {
namespace cpu {

using namespace rnn_utils;
using namespace dnnl::impl::utils;
#if DNNL_X64
using namespace dnnl::impl::cpu::x64;
#endif
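
// Merged-layer execution: when the layer input of several consecutive time
// steps is contiguous in memory, the layer GEMM can be batched across those
// steps in a single brgemm call, i.e. (as a sketch)
//   scratch_gates = src_layer * w_layer
// for the whole merged chunk; the per-step iteration GEMM and postgemm are
// then handled by the regular cell execution.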
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_merged_layer_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::merged_layer_brgemm_fwd)) {
#if DNNL_X64
    using brgemm_merged_layer_t = x64::brgemm_merged_layer_t<src_iter_t,
            weights_t, scratch_t, gemm_acc_t>;
    const brgemm_merged_layer_t layer_calc(rnn_brgemm_, rnn, cell_position,
            src_layer_, w_layer_[0], scratch_gates_, amx_scratchpad,
            addr_batch_global);

    layer_calc.execute();
#endif
    return dnnl_success;
}

template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_f32_t::merged_layer_brgemm_fwd);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_bf16_t::merged_layer_brgemm_fwd);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_u8s8_t::merged_layer_brgemm_fwd);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_s8s8_t::merged_layer_brgemm_fwd);

template <>
rnn_merged_layer_execution_sig(ref_rnn_bwd_f32_t::merged_layer_brgemm_fwd) {
    assert(!"unimplemented");
    return dnnl_success;
}

template <>
rnn_merged_layer_execution_sig(ref_rnn_bwd_bf16_t::merged_layer_brgemm_fwd) {
    assert(!"unimplemented");
    return dnnl_success;
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::cell_execution_brgemm_fwd)) {
#if DNNL_X64
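
    // Weights quantization parameters (int8 configurations): mask == 0 means
    // a single common scale, a non-zero mask means per-output-channel scales,
    // in which case the scales pointer is offset by the column index n below.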
    const auto weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;
    const int mask = pd_->attr()->rnn_weights_qparams_.mask_;
    const auto dst_postgemm = rnn.is_lstm_projection ? proj_ht_ : dst_layer_;
    const auto dst_iter_postgemm = rnn.is_lstm_projection ? nullptr : dst_iter_;
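
    // Leading dimensions are resolved per cell position: the first/last cell
    // of a layer or iteration may read from or write to user memory directly
    // rather than the workspace, so the strides can differ between cells.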
    const auto LDDl = rnn.dst_layer_ld(cell_position);
    const auto LDDi = rnn.dst_iter_ld(cell_position);
    const auto LDDic = rnn.dst_iter_c_ld(cell_position);
    const auto LDAic = rnn.src_iter_c_ld(cell_position);

    using brgemm_dst_layer_iter_t = x64::brgemm_dst_layer_iter_t<src_iter_t,
            weights_t, scratch_t, gemm_acc_t>;

    typename brgemm_dst_layer_iter_t::postgemm_fused_t fused_postgemm;
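
    // Fused postgemm: rather than running the elementwise postgemm over the
    // whole minibatch once all GEMMs have finished, a callback is invoked on
    // each (m, n) tile right after its accumulation completes, while the tile
    // is still hot in cache. The callback receives the row (m), column (n),
    // and column-block (nb_i) coordinates plus the tile width in bytes
    // (block_step), and offsets every cell tensor to that tile.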
    if (!rnn.unfused_post_gemm) {
        fused_postgemm = [&](dim_t m, dim_t n, dim_t nb_i,
                                 const src_iter_t *Ai_m, scratch_t *C_n,
                                 int block_step) {
            const auto Dpg_n = (dst_postgemm != nullptr)
                    ? dst_postgemm + m * LDDl + n
                    : nullptr;
            const auto Di_n = (dst_iter_postgemm != nullptr)
                    ? dst_iter_postgemm + m * LDDi + n
                    : nullptr;
            const auto Dic_n = (dst_iter_c_ != nullptr)
                    ? inc_ptr(dst_iter_c_, rnn.dst_iter_c_dt, m * LDDic + n)
                    : nullptr;

            const auto curr_ws_gates_
                    = ws_gates_ + (m * rnn.ws_gates_ld) + nb_i * rnn.n_block;
            const float *weights_peephole_n = weights_peephole_
                    ? weights_peephole_ + n
                    : weights_peephole_;
            auto weights_scales_n = weights_scales + (mask ? n : 0);
            const auto Aic_n
                    = inc_ptr(src_iter_c_, rnn.src_iter_c_dt, m * LDAic + n);
            const auto bias_n = inc_ptr(bias_[0], rnn.bias_dt, n);
            rnn_postgemm_->execute(rnn, cell_position, curr_ws_gates_, C_n,
                    augru_attention_, Dpg_n, Dic_n, Ai_m, Aic_n,
                    diff_src_layer_, diff_augru_attention_, diff_src_iter_,
                    diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                    diff_dst_iter_c_, weights_peephole_n, bias_n, ws_grid_,
                    scratch_cell_, Di_n, weights_scales_n, block_step);
        };
    }
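
    // Vanilla GRU needs two GEMM + postgemm passes: the candidate state is
    // (as a sketch) h~ = tanh(W_c * x + U_c * (r (*) h)), so its iteration
    // GEMM can only run once the first pass has produced the reset gate r.
    // Hence the two fused callbacks below, dispatching to execute() and
    // execute_part2() respectively.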
    if (rnn.is_orig_gru) {
        using brgemm_gru_t = x64::brgemm_gru_t<src_iter_t, weights_t, scratch_t,
                gemm_acc_t>;
        typename brgemm_gru_t::postgemm_fused_t fused_postgemm_gru_part1,
                fused_postgemm_gru_part2;
        if (!rnn.unfused_post_gemm) {
            fused_postgemm_gru_part1 = [&](dim_t m, dim_t n, dim_t nb_i,
                                               const src_iter_t *Ai_m,
                                               scratch_t *C_gates_n,
                                               scratch_t *C_cell_n,
                                               int block_step) {
                const auto Dpg_n = (dst_postgemm != nullptr)
                        ? dst_postgemm + m * LDDl + n
                        : nullptr;
                const auto Di_n = (dst_iter_postgemm != nullptr)
                        ? dst_iter_postgemm + m * LDDi + n
                        : nullptr;
                const auto Dic_n = (dst_iter_c_ != nullptr)
                        ? inc_ptr(dst_iter_c_, rnn.dst_iter_c_dt, m * LDDic + n)
                        : nullptr;

                const auto curr_ws_gates_ = ws_gates_ + (m * rnn.ws_gates_ld)
                        + nb_i * rnn.n_block;
                const auto Aic_n = inc_ptr(
                        src_iter_c_, rnn.src_iter_c_dt, m * LDAic + n);
                const auto bias_n = inc_ptr(bias_[0], rnn.bias_dt, n);
                auto weights_scales_n = weights_scales + (mask ? n : 0);
                rnn_postgemm_->execute(rnn, cell_position, curr_ws_gates_,
                        C_gates_n, augru_attention_, Dpg_n, Dic_n, Ai_m, Aic_n,
                        diff_src_layer_, diff_augru_attention_, diff_src_iter_,
                        diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                        nullptr, nullptr, bias_n, ws_grid_, C_cell_n, Di_n,
                        weights_scales_n, block_step);
            };
            fused_postgemm_gru_part2 = [&](dim_t m, dim_t n, dim_t nb_i,
                                               const src_iter_t *Ai_m,
                                               scratch_t *C_gates_n,
                                               scratch_t *C_cell_n,
                                               int block_step) {
                const auto Dpg_n = (dst_postgemm != nullptr)
                        ? dst_postgemm + m * LDDl + n
                        : nullptr;
                const auto Di_n = (dst_iter_postgemm != nullptr)
                        ? dst_iter_postgemm + m * LDDi + n
                        : nullptr;
                const auto Dic_n = (dst_iter_c_ != nullptr)
                        ? inc_ptr(dst_iter_c_, rnn.dst_iter_c_dt, m * LDDic + n)
                        : nullptr;

                const auto curr_ws_gates_ = ws_gates_ + (m * rnn.ws_gates_ld)
                        + nb_i * rnn.n_block;
                const auto Aic_n = inc_ptr(
                        src_iter_c_, rnn.src_iter_c_dt, m * LDAic + n);
                const auto bias_n = inc_ptr(bias_[0], rnn.bias_dt, n);
                auto weights_scales_n = weights_scales + (mask ? n : 0);
                rnn_postgemm_->execute_part2(rnn, cell_position, curr_ws_gates_,
                        C_gates_n, augru_attention_, Dpg_n, Dic_n, Ai_m, Aic_n,
                        diff_src_layer_, diff_augru_attention_, diff_src_iter_,
                        diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                        nullptr, nullptr, bias_n, ws_grid_, C_cell_n, Di_n,
                        weights_scales_n, block_step);
            };
        }

        const brgemm_gru_t dst_calc(rnn_brgemm_, rnn, cell_position, src_iter_,
                src_layer_, w_iter_[0], w_iter_[1], w_layer_[0], dst_postgemm,
                scratch_gates_, scratch_cell_, amx_scratchpad,
                addr_batch_global, fused_postgemm_gru_part1,
                fused_postgemm_gru_part2);
        dst_calc.execute();
    } else {
        // calculate
        // scratch_gates_ = src_layer_ * w_layer_ + src_iter_ * w_iter_
        const brgemm_dst_layer_iter_t dst_calc(rnn_brgemm_, rnn, cell_position,
                src_iter_, src_layer_, w_iter_[0], w_layer_[0], scratch_gates_,
                amx_scratchpad, addr_batch_global, fused_postgemm);
        dst_calc.execute();
    }
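
    // Unfused path: run the elementwise postgemm once over the whole gates
    // scratch after all GEMMs are done; the last argument is the per-gate row
    // width in bytes (dhc elements of scratch_t).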
    if (rnn.unfused_post_gemm) {
        const auto wscales_postgemm = pd_->attr()->rnn_weights_qparams_.scales_;

        rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
                augru_attention_, dst_postgemm, dst_iter_c_, src_iter_,
                src_iter_c_, diff_src_layer_, diff_augru_attention_,
                diff_src_iter_, diff_src_iter_c_, diff_dst_layer_,
                diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_[0],
                ws_grid_, scratch_cell_, dst_iter_postgemm, wscales_postgemm,
                rnn.dhc * sizeof(scratch_t));
    }
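
    // LSTM projection: the postgemm output (proj_ht_) is additionally
    // multiplied by the projection weights (output = proj_ht_ * w_projection_),
    // compressing the hidden state from dhc to dlc channels before it is
    // written to dst_layer_ / dst_iter_.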
    if (rnn.is_lstm_projection) {
        const auto wscales_proj_postgemm
                = pd_->attr()->rnn_weights_projection_qparams_.scales_;
        gemm_acc_t *const Cp = (rnn.dt_conf == all_f32)
                ? reinterpret_cast<gemm_acc_t *>(dst_layer_)
                : scratch_gates_;
        const int pLDDl = rnn.dst_layer_ld(cell_position, true);
        const int pmask = pd_->attr()->rnn_weights_projection_qparams_.mask_;

        using brgemm_dst_proj_t
                = x64::brgemm_dst_proj_t<ht_t, weights_t, gemm_acc_t>;
        typename brgemm_dst_proj_t::postgemm_fused_t fused_postgemm_proj;

        if (!rnn.unfused_post_gemm) {
            fused_postgemm_proj = [&](dim_t m, dim_t n, gemm_acc_t *Cp_n,
                                          int block_step) {
                const auto weights_scales_n
                        = wscales_proj_postgemm + (pmask ? n : 0);
                const auto Di_n = (dst_iter_ != nullptr)
                        ? dst_iter_ + m * LDDi + n
                        : nullptr;
                const auto Dl_n = (dst_layer_ != nullptr)
                        ? dst_layer_ + m * pLDDl + n
                        : nullptr;
                const auto Wp_comp_n = w_proj_comp + n;
                rnn_postgemm_->execute_part2(rnn, cell_position, nullptr, Cp_n,
                        nullptr, Dl_n, nullptr, nullptr, Wp_comp_n, nullptr,
                        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
                        nullptr, nullptr, nullptr, nullptr, Di_n,
                        weights_scales_n, block_step);
            };
        }

        // calculate
        // output = proj_ht_ * w_projection_
        const brgemm_dst_proj_t dst_proj_calc(rnn_brgemm_, rnn, cell_position,
                proj_ht_, w_projection_[0], Cp, amx_scratchpad,
                addr_batch_global, fused_postgemm_proj);
        dst_proj_calc.execute();

        if (rnn.unfused_post_gemm) {
            // we have to downconvert the output to dst_layer_t and copy to
            // dst_iter if needed
            rnn_postgemm_->execute_part2(rnn, cell_position, nullptr, Cp,
                    nullptr, dst_layer_, nullptr, nullptr, w_proj_comp, nullptr,
                    nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
                    nullptr, nullptr, nullptr, nullptr, dst_iter_,
                    wscales_proj_postgemm, rnn.dlc * sizeof(dst_layer_t));
        }
    }

#endif
    return dnnl_success;
}

template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_brgemm_fwd);
template rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_brgemm_fwd);
template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_brgemm_fwd);
template rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_brgemm_fwd);

template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_brgemm_fwd) {
    assert(!"unimplemented");
    return dnnl_success;
}
template <>
rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_brgemm_fwd) {
    assert(!"unimplemented");
    return dnnl_success;
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::cell_execution_brgemm_bwd)) {

#if DNNL_X64
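    // Backward runs the elementwise postgemm first: it combines the incoming
    // diff_dst_* gradients with the forward workspace (ws_gates_, ws_grid_)
    // and produces the gate gradients in scratch_gates_, which all GEMMs
    // below consume.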
    rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, dst_iter_c_, src_iter_, src_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            weights_peephole_, bias_[0], ws_grid_, scratch_cell_, dst_iter_,
            nullptr, 0);

    using brgemm_diff_src_calc_t = x64::brgemm_diff_src_layer_iter_t<weights_t,
            scratch_t, gemm_acc_t>;
    using brgemm_diff_weights_calc_t
            = x64::brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t,
                    scratch_t, gemm_acc_t>;

    const brgemm_diff_src_calc_t diff_src_calc(rnn_brgemm_, rnn, cell_position,
            scratch_gates_, w_iter_[0], w_layer_[0], diff_src_iter_,
            diff_src_layer_, amx_scratchpad, addr_batch_global);
    const brgemm_diff_weights_calc_t diff_weights_calc(rnn_brgemm_, rnn,
            cell_position, src_iter_, scratch_src_iter_, src_layer_,
            scratch_src_layer_, scratch_gates_, scratch_gates_blocked_,
            diff_w_iter_, diff_w_layer_, diff_bias_, amx_scratchpad,
            addr_batch_global);

    // calculate
    // diff_src_iter = scratch * w_iter
    // diff_src_layer = scratch * w_layer
    diff_src_calc.execute();
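
    // The diff-weights GEMMs consume src_layer_/src_iter_ transposed. When a
    // global transpose is requested, both are copied into transposed scratch
    // buffers first; for bf16 the destination row count is rounded up to a
    // multiple of 2 so the result matches the 2-row (VNNI) packing the brgemm
    // kernels expect (an assumption based on the rnd_up factor below).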
    if (rnn.diff_wei_brgemm.global_transpose) {
        const auto src_layer_ld = rnn.src_layer_ld(cell_position);
        const auto src_iter_ld = rnn.src_iter_ld(cell_position);
        const auto src_layer_ld_nb = rnn.layer_brgemm_desc(cell_position);
        const auto src_iter_ld_nb = rnn.iter_brgemm_desc(cell_position);
        const auto rnd_up_size = (src_type == data_type::bf16 ? 2 : 1);
        const auto dst_ld = utils::rnd_up(rnn.mb, rnd_up_size);

        const auto layer_transpose = src_layer_iter_transpose_t(src_layer_ld,
                dst_ld, rnn.mb, rnn.slc,
                rnn_brgemm_.kernel_transpose_layer_[src_layer_ld_nb].get());
        const auto iter_transpose = src_layer_iter_transpose_t(src_iter_ld,
                dst_ld, rnn.mb, rnn.sic,
                rnn_brgemm_.kernel_transpose_iter_[src_iter_ld_nb].get());
        layer_transpose.execute(src_layer_, scratch_src_layer_);
        iter_transpose.execute(src_iter_, scratch_src_iter_);
    }
    // calculate
    // diff_weights_layer = src_layer^T * scratch
    // diff_weights_iter = src_iter^T * scratch
    // also performs the gate reductions:
    // diff_bias = scratch reduction over mb
    diff_weights_calc.execute();
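
    // Peephole weight gradients: each peephole weight sees an elementwise
    // product of the corresponding gate gradient with the cell state, reduced
    // over the minibatch (as a sketch, diff_weights_peephole
    // = reduce_mb(scratch_gates (*) c)); delegated to a dedicated brgemm-based
    // kernel.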
    if (rnn.is_lstm_peephole) {
        using brgemm_diff_wei_peep_t = x64::brgemm_diff_wei_peep_t<scratch_t>;
        const brgemm_diff_wei_peep_t diff_wei_peep_calc(rnn_brgemm_, rnn,
                cell_position, scratch_gates_, src_iter_c_, dst_iter_c_,
                diff_weights_peephole_);

        diff_wei_peep_calc.execute();
    }

#endif

    return dnnl_success;
}

template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_brgemm_bwd) {
    assert(!"unimplemented");
    return dnnl_success;
}
template <>
rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_brgemm_bwd) {
    assert(!"unimplemented");
    return dnnl_success;
}
template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_brgemm_bwd) {
    assert(!"unimplemented");
    return dnnl_success;
}
template <>
rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_brgemm_bwd) {
    assert(!"unimplemented");
    return dnnl_success;
}
template rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_brgemm_bwd);
template rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_brgemm_bwd);

} // namespace cpu
} // namespace impl
} // namespace dnnl