1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /* |
18 | * Common for RNN and LSTM cell execution |
19 | */ |
20 | |
21 | #include "common/bfloat16.hpp" |
22 | #include "cpu/rnn/ref_rnn.hpp" |
23 | |
24 | #if DNNL_X64 |
25 | #include <cassert> |
26 | #include <functional> |
27 | #include "cpu/x64/rnn/brgemm_cell_common_bwd.hpp" |
28 | #include "cpu/x64/rnn/brgemm_cell_common_fwd.hpp" |
29 | #include "cpu/x64/rnn/brgemm_cell_common_reorders.hpp" |
30 | #include "cpu/x64/rnn/brgemm_cell_common_utils.hpp" |
31 | #endif |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | |
37 | using namespace rnn_utils; |
38 | using namespace dnnl::impl::utils; |
39 | #if DNNL_X64 |
40 | using namespace dnnl::impl::cpu::x64; |
41 | #endif |
42 | |
43 | template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type, |
44 | data_type_t acc_type> |
45 | rnn_merged_layer_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type, |
46 | acc_type>::merged_layer_brgemm_fwd)) { |
47 | #if DNNL_X64 |
48 | using brgemm_merged_layer_t = x64::brgemm_merged_layer_t<src_iter_t, |
49 | weights_t, scratch_t, gemm_acc_t>; |
50 | const brgemm_merged_layer_t layer_calc(rnn_brgemm_, rnn, cell_position, |
51 | src_layer_, w_layer_[0], scratch_gates_, amx_scratchpad, |
52 | addr_batch_global); |
53 | |
54 | layer_calc.execute(); |
55 | #endif |
56 | return dnnl_success; |
57 | } |
58 | |
// Explicit instantiations of the merged-layer execution for every supported
// forward data-type configuration (f32, bf16, u8/s8, s8/s8).
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_f32_t::merged_layer_brgemm_fwd);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_bf16_t::merged_layer_brgemm_fwd);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_u8s8_t::merged_layer_brgemm_fwd);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_s8s8_t::merged_layer_brgemm_fwd);
67 | |
68 | template <> |
69 | rnn_merged_layer_execution_sig(ref_rnn_bwd_f32_t::merged_layer_brgemm_fwd) { |
70 | assert(!"unimplemented" ); |
71 | return dnnl_success; |
72 | } |
73 | |
74 | template <> |
75 | rnn_merged_layer_execution_sig(ref_rnn_bwd_bf16_t::merged_layer_brgemm_fwd) { |
76 | assert(!"unimplemented" ); |
77 | return dnnl_success; |
78 | } |
79 | |
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::cell_execution_brgemm_fwd)) {
#if DNNL_X64

    // Forward execution of one RNN/LSTM/GRU cell using brgemm kernels.
    // The GEMM part and the elementwise post-GEMM part can either be fused
    // (post-GEMM runs as a callback per output block while it is still hot in
    // cache) or run unfused as a separate full-tensor pass, depending on
    // rnn.unfused_post_gemm.

    // Per-output-channel weights quantization scales and mask (int8 configs).
    const auto weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;
    const int mask = pd_->attr()->rnn_weights_qparams_.mask_;
    // With LSTM projection, the post-GEMM result goes to the intermediate
    // proj_ht_ buffer and dst_iter_ is produced later by the projection GEMM;
    // otherwise post-GEMM writes dst_layer_/dst_iter_ directly.
    const auto dst_postgemm = rnn.is_lstm_projection ? proj_ht_ : dst_layer_;
    const auto dst_iter_postgemm = rnn.is_lstm_projection ? nullptr : dst_iter_;

    // Leading dimensions of the destination tensors; they depend on the
    // cell's position in the layer/iteration grid.
    const auto LDDl = rnn.dst_layer_ld(cell_position);
    const auto LDDi = rnn.dst_iter_ld(cell_position);
    const auto LDDic = rnn.dst_iter_c_ld(cell_position);
    const auto LDAic = rnn.src_iter_c_ld(cell_position);

    using brgemm_dst_layer_iter_t = x64::brgemm_dst_layer_iter_t<src_iter_t,
            weights_t, scratch_t, gemm_acc_t>;

    typename brgemm_dst_layer_iter_t::postgemm_fused_t fused_postgemm;

    if (!rnn.unfused_post_gemm) {
        // Fused post-GEMM callback: invoked by the brgemm driver for each
        // (m, n-block) output tile. All pointer arguments are offset to the
        // tile's top-left element before calling the postgemm kernel.
        fused_postgemm = [&](dim_t m, dim_t n, dim_t nb_i,
                                 const src_iter_t *Ai_m, scratch_t *C_n,
                                 int block_step) {
            const auto Dpg_n = (dst_postgemm != nullptr)
                    ? dst_postgemm + m * LDDl + n
                    : nullptr;
            const auto Di_n = (dst_iter_postgemm != nullptr)
                    ? dst_iter_postgemm + m * LDDi + n
                    : nullptr;
            // dst_iter_c_ has a runtime data type, hence the byte-aware
            // inc_ptr instead of plain pointer arithmetic.
            const auto Dic_n = (dst_iter_c_ != nullptr)
                    ? inc_ptr(dst_iter_c_, rnn.dst_iter_c_dt, m * LDDic + n)
                    : nullptr;

            const auto curr_ws_gates_
                    = ws_gates_ + (m * rnn.ws_gates_ld) + nb_i * rnn.n_block;
            const float *weights_peephole_n = weights_peephole_
                    ? weights_peephole_ + n
                    : weights_peephole_;
            // mask != 0 means per-channel scales; otherwise a single scale.
            auto weights_scales_n = weights_scales + (mask ? n : 0);
            const auto Aic_n
                    = inc_ptr(src_iter_c_, rnn.src_iter_c_dt, m * LDAic + n);
            const auto bias_n = inc_ptr(bias_[0], rnn.bias_dt, n);
            rnn_postgemm_->execute(rnn, cell_position, curr_ws_gates_, C_n,
                    augru_attention_, Dpg_n, Dic_n, Ai_m, Aic_n,
                    diff_src_layer_, diff_augru_attention_, diff_src_iter_,
                    diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                    diff_dst_iter_c_, weights_peephole_n, bias_n, ws_grid_,
                    scratch_cell_, Di_n, weights_scales_n, block_step);
        };
    }

    if (rnn.is_orig_gru) {
        // Vanilla (non-linear-before-reset) GRU needs a two-part post-GEMM:
        // part1 after the gates GEMM, part2 after the candidate-state GEMM.
        using brgemm_gru_t = x64::brgemm_gru_t<src_iter_t, weights_t, scratch_t,
                gemm_acc_t>;
        typename brgemm_gru_t::postgemm_fused_t fused_postgemm_gru_part1,
                fused_postgemm_gru_part2;
        if (!rnn.unfused_post_gemm) {
            // Same tile-offset logic as the generic fused callback above,
            // but with separate gates (C_gates_n) and cell (C_cell_n)
            // scratch pointers.
            fused_postgemm_gru_part1 = [&](dim_t m, dim_t n, dim_t nb_i,
                                               const src_iter_t *Ai_m,
                                               scratch_t *C_gates_n,
                                               scratch_t *C_cell_n,
                                               int block_step) {
                const auto Dpg_n = (dst_postgemm != nullptr)
                        ? dst_postgemm + m * LDDl + n
                        : nullptr;
                const auto Di_n = (dst_iter_postgemm != nullptr)
                        ? dst_iter_postgemm + m * LDDi + n
                        : nullptr;
                const auto Dic_n = (dst_iter_c_ != nullptr)
                        ? inc_ptr(dst_iter_c_, rnn.dst_iter_c_dt, m * LDDic + n)
                        : nullptr;

                const auto curr_ws_gates_ = ws_gates_ + (m * rnn.ws_gates_ld)
                        + nb_i * rnn.n_block;
                const auto Aic_n = inc_ptr(
                        src_iter_c_, rnn.src_iter_c_dt, m * LDAic + n);
                const auto bias_n = inc_ptr(bias_[0], rnn.bias_dt, n);
                auto weights_scales_n = weights_scales + (mask ? n : 0);
                rnn_postgemm_->execute(rnn, cell_position, curr_ws_gates_,
                        C_gates_n, augru_attention_, Dpg_n, Dic_n, Ai_m, Aic_n,
                        diff_src_layer_, diff_augru_attention_, diff_src_iter_,
                        diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                        nullptr, nullptr, bias_n, ws_grid_, C_cell_n, Di_n,
                        weights_scales_n, block_step);
            };
            fused_postgemm_gru_part2 = [&](dim_t m, dim_t n, dim_t nb_i,
                                               const src_iter_t *Ai_m,
                                               scratch_t *C_gates_n,
                                               scratch_t *C_cell_n,
                                               int block_step) {
                const auto Dpg_n = (dst_postgemm != nullptr)
                        ? dst_postgemm + m * LDDl + n
                        : nullptr;
                const auto Di_n = (dst_iter_postgemm != nullptr)
                        ? dst_iter_postgemm + m * LDDi + n
                        : nullptr;
                const auto Dic_n = (dst_iter_c_ != nullptr)
                        ? inc_ptr(dst_iter_c_, rnn.dst_iter_c_dt, m * LDDic + n)
                        : nullptr;

                const auto curr_ws_gates_ = ws_gates_ + (m * rnn.ws_gates_ld)
                        + nb_i * rnn.n_block;
                const auto Aic_n = inc_ptr(
                        src_iter_c_, rnn.src_iter_c_dt, m * LDAic + n);
                const auto bias_n = inc_ptr(bias_[0], rnn.bias_dt, n);
                auto weights_scales_n = weights_scales + (mask ? n : 0);
                rnn_postgemm_->execute_part2(rnn, cell_position, curr_ws_gates_,
                        C_gates_n, augru_attention_, Dpg_n, Dic_n, Ai_m, Aic_n,
                        diff_src_layer_, diff_augru_attention_, diff_src_iter_,
                        diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                        nullptr, nullptr, bias_n, ws_grid_, C_cell_n, Di_n,
                        weights_scales_n, block_step);
            };
        }

        // GRU uses two iteration weight tensors (w_iter_[0]/w_iter_[1]),
        // one per GEMM stage.
        const brgemm_gru_t dst_calc(rnn_brgemm_, rnn, cell_position, src_iter_,
                src_layer_, w_iter_[0], w_iter_[1], w_layer_[0], dst_postgemm,
                scratch_gates_, scratch_cell_, amx_scratchpad,
                addr_batch_global, fused_postgemm_gru_part1,
                fused_postgemm_gru_part2);
        dst_calc.execute();
    } else {
        // calculate
        // scratch_gates_ = src_layer_ * w_layer_ + src_iter_ * w_iter_
        const brgemm_dst_layer_iter_t dst_calc(rnn_brgemm_, rnn, cell_position,
                src_iter_, src_layer_, w_iter_[0], w_layer_[0], scratch_gates_,
                amx_scratchpad, addr_batch_global, fused_postgemm);
        dst_calc.execute();
    }

    if (rnn.unfused_post_gemm) {
        // Unfused path: a single post-GEMM pass over the whole cell output
        // (rnn.dhc elements per row) after all GEMMs have completed.
        const auto wscales_postgemm = pd_->attr()->rnn_weights_qparams_.scales_;

        rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
                augru_attention_, dst_postgemm, dst_iter_c_, src_iter_,
                src_iter_c_, diff_src_layer_, diff_augru_attention_,
                diff_src_iter_, diff_src_iter_c_, diff_dst_layer_,
                diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_[0],
                ws_grid_, scratch_cell_, dst_iter_postgemm, wscales_postgemm,
                rnn.dhc * sizeof(scratch_t));
    }

    if (rnn.is_lstm_projection) {
        // LSTM projection: dst = proj_ht_ * w_projection_, with its own
        // quantization parameters and a part2 post-GEMM for the final
        // down-conversion / copy to dst_iter_.
        const auto wscales_proj_postgemm
                = pd_->attr()->rnn_weights_projection_qparams_.scales_;
        // For pure f32 the projection can accumulate straight into
        // dst_layer_; otherwise it goes through scratch_gates_ first.
        gemm_acc_t *const Cp = (rnn.dt_conf == all_f32)
                ? reinterpret_cast<gemm_acc_t *>(dst_layer_)
                : scratch_gates_;
        const int pLDDl = rnn.dst_layer_ld(cell_position, true);
        const int pmask = pd_->attr()->rnn_weights_projection_qparams_.mask_;

        using brgemm_dst_proj_t
                = x64::brgemm_dst_proj_t<ht_t, weights_t, gemm_acc_t>;
        typename brgemm_dst_proj_t::postgemm_fused_t fused_postgemm_proj;

        if (!rnn.unfused_post_gemm) {
            // Fused projection post-GEMM: per-tile down-conversion to
            // dst_layer_/dst_iter_ with projection weights compensation.
            fused_postgemm_proj = [&](dim_t m, dim_t n, gemm_acc_t *Cp_n,
                                          int block_step) {
                const auto weights_scales_n
                        = wscales_proj_postgemm + (pmask ? n : 0);
                const auto Di_n = (dst_iter_ != nullptr)
                        ? dst_iter_ + m * LDDi + n
                        : nullptr;
                const auto Dl_n = (dst_layer_ != nullptr)
                        ? dst_layer_ + m * pLDDl + n
                        : nullptr;
                const auto Wp_comp_n = w_proj_comp + n;
                rnn_postgemm_->execute_part2(rnn, cell_position, nullptr, Cp_n,
                        nullptr, Dl_n, nullptr, nullptr, Wp_comp_n, nullptr,
                        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
                        nullptr, nullptr, nullptr, nullptr, Di_n,
                        weights_scales_n, block_step);
            };
        }

        // calculate
        // output = proj_ht_ * w_projection_
        const brgemm_dst_proj_t dst_proj_calc(rnn_brgemm_, rnn, cell_position,
                proj_ht_, w_projection_[0], Cp, amx_scratchpad,
                addr_batch_global, fused_postgemm_proj);
        dst_proj_calc.execute();

        if (rnn.unfused_post_gemm) {
            // we have to downconvert the output to dst_layer_t and copy to
            // dst_iter if needed
            rnn_postgemm_->execute_part2(rnn, cell_position, nullptr, Cp,
                    nullptr, dst_layer_, nullptr, nullptr, w_proj_comp, nullptr,
                    nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
                    nullptr, nullptr, nullptr, nullptr, dst_iter_,
                    wscales_proj_postgemm, rnn.dlc * sizeof(dst_layer_t));
        }
    }

#endif
    return dnnl_success;
}
278 | |
// Explicit instantiations of the forward cell execution for every supported
// forward data-type configuration (f32, bf16, u8/s8, s8/s8).
template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_brgemm_fwd);
template rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_brgemm_fwd);
template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_brgemm_fwd);
template rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_brgemm_fwd);
283 | |
284 | template <> |
285 | rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_brgemm_fwd) { |
286 | assert(!"unimplemented" ); |
287 | return dnnl_success; |
288 | } |
289 | template <> |
290 | rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_brgemm_fwd) { |
291 | assert(!"unimplemented" ); |
292 | return dnnl_success; |
293 | } |
294 | |
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::cell_execution_brgemm_bwd)) {

#if DNNL_X64
    // Backward execution of one cell. Order matters:
    //   1) post-GEMM produces the gate gradients into scratch_gates_,
    //   2) diff_src GEMMs consume scratch_gates_,
    //   3) (optional) transposes prepare src_layer_/src_iter_ copies,
    //   4) diff_weights GEMMs consume the (possibly transposed) sources.
    rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, dst_iter_c_, src_iter_, src_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            weights_peephole_, bias_[0], ws_grid_, scratch_cell_, dst_iter_,
            nullptr, 0);

    using brgemm_diff_src_calc_t = x64::brgemm_diff_src_layer_iter_t<weights_t,
            scratch_t, gemm_acc_t>;
    using brgemm_diff_weights_calc_t
            = x64::brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t,
                    scratch_t, gemm_acc_t>;

    const brgemm_diff_src_calc_t diff_src_calc(rnn_brgemm_, rnn, cell_position,
            scratch_gates_, w_iter_[0], w_layer_[0], diff_src_iter_,
            diff_src_layer_, amx_scratchpad, addr_batch_global);
    const brgemm_diff_weights_calc_t diff_weights_calc(rnn_brgemm_, rnn,
            cell_position, src_iter_, scratch_src_iter_, src_layer_,
            scratch_src_layer_, scratch_gates_, scratch_gates_blocked_,
            diff_w_iter_, diff_w_layer_, diff_bias_, amx_scratchpad,
            addr_batch_global);

    // calculate
    // diff_src_iter = scratch * w_iter
    // diff_src_layer = scratch * w_layer
    diff_src_calc.execute();

    if (rnn.diff_wei_brgemm.global_transpose) {
        // Transpose src_layer_/src_iter_ into scratch buffers so the
        // diff-weights brgemm can consume them. The leading dimensions
        // depend on the cell position; mb is rounded up to an even size for
        // bf16 (presumably to satisfy the vnni-pair layout — confirm against
        // the transpose kernel).
        const auto src_layer_ld = rnn.src_layer_ld(cell_position);
        const auto src_iter_ld = rnn.src_iter_ld(cell_position);
        const auto src_layer_ld_nb = rnn.layer_brgemm_desc(cell_position);
        const auto src_iter_ld_nb = rnn.iter_brgemm_desc(cell_position);
        const auto rnd_up_size = (src_type == data_type::bf16 ? 2 : 1);
        const auto dst_ld = utils::rnd_up(rnn.mb, rnd_up_size);

        const auto layer_transpose = src_layer_iter_transpose_t(src_layer_ld,
                dst_ld, rnn.mb, rnn.slc,
                rnn_brgemm_.kernel_transpose_layer_[src_layer_ld_nb].get());
        const auto iter_transpose = src_layer_iter_transpose_t(src_iter_ld,
                dst_ld, rnn.mb, rnn.sic,
                rnn_brgemm_.kernel_transpose_iter_[src_iter_ld_nb].get());
        layer_transpose.execute(src_layer_, scratch_src_layer_);
        iter_transpose.execute(src_iter_, scratch_src_iter_);
    }
    // calculate
    // diff_weights_layer = src_layer^T * scratch
    // diff_weights_iter = src_iter^T * scratch
    // performs gates reductions
    // diff_bias = scratch reduction over mb
    diff_weights_calc.execute();

    if (rnn.is_lstm_peephole) {
        // Peephole weights get their own gradient kernel fed by the gate
        // gradients and the cell states.
        using brgemm_diff_wei_peep_t = x64::brgemm_diff_wei_peep_t<scratch_t>;
        const brgemm_diff_wei_peep_t diff_wei_peep_calc(rnn_brgemm_, rnn,
                cell_position, scratch_gates_, src_iter_c_, dst_iter_c_,
                diff_weights_peephole_);

        diff_wei_peep_calc.execute();
    }

#endif

    return dnnl_success;
}
365 | |
366 | template <> |
367 | rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_brgemm_bwd) { |
368 | assert(!"unimplemented" ); |
369 | return dnnl_success; |
370 | } |
371 | template <> |
372 | rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_brgemm_bwd) { |
373 | assert(!"unimplemented" ); |
374 | return dnnl_success; |
375 | } |
376 | template <> |
377 | rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_brgemm_bwd) { |
378 | assert(!"unimplemented" ); |
379 | return dnnl_success; |
380 | } |
381 | template <> |
382 | rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_brgemm_bwd) { |
383 | assert(!"unimplemented" ); |
384 | return dnnl_success; |
385 | } |
386 | template rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_brgemm_bwd); |
387 | template rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_brgemm_bwd); |
388 | |
389 | } // namespace cpu |
390 | } // namespace impl |
391 | } // namespace dnnl |
392 | |