/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Common for RNN and LSTM cell execution
 */
#include "common/bfloat16.hpp"
#include "common/dnnl_thread.hpp"

#include "cpu/rnn/ref_rnn.hpp"
#include "cpu/simple_q10n.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
using namespace rnn_utils;
using namespace dnnl::impl::utils;

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::cell_execution_ref)) {
    const auto weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;
    const auto weights_projection_scales = rnn.is_lstm_projection
            ? pd_->attr()->rnn_weights_projection_qparams_.scales_
            : nullptr;

    const auto src_layer_ld = rnn.src_layer_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

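    // Pre-activation gates: scratch_gates_ is first filled with
    // W_layer * src_layer (skipped when the layer part was already computed,
    // e.g. by a merged layer GEMM), then accumulated with W_iter * src_iter
    // (note beta = 1.0f in the second GEMM).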
    if (rnn.need_gemm_layer(cell_position)) {
        CHECK((this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
                rnn.slc, 1.0f, w_layer_[0], rnn.weights_layer_ld, src_layer_,
                src_layer_ld, 0.0f, scratch_gates_, rnn.scratch_gates_ld));
    }
    CHECK((this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
            rnn.sic, 1.0f, w_iter_[0], rnn.weights_iter_ld, src_iter_,
            src_iter_ld, 1.0f, scratch_gates_, rnn.scratch_gates_ld));

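    // The postgemm kernel applies the cell-specific elementwise part on top
    // of scratch_gates_ (bias, gate nonlinearities, state update) and writes
    // the result to dst_postgemm and, for LSTM, to dst_iter_c_.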
    // Note: here proj_ht_ is a scratchpad buffer for inference or part of the
    // workspace for training
    const auto dst_postgemm = rnn.is_lstm_projection ? proj_ht_ : dst_layer_;
    // for lstmp, the copy to dst_iter happens after the projection
    const auto dst_iter_postgemm = rnn.is_lstm_projection ? nullptr : dst_iter_;
    rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_postgemm, dst_iter_c_, src_iter_, src_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            weights_peephole_, bias_[0], ws_grid_, scratch_cell_,
            dst_iter_postgemm, weights_scales, rnn.dhc * sizeof(scratch_t));

    if (rnn.is_lstm_projection) {
        const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);

        // Here, because the accumulation data type differs from the dst_layer
        // data type, we have to use scratch memory to hold the temporary
        // accumulators
        assert(rnn.scratch_gates_ld >= rnn.dlc);
        gemm_acc_t *dst_proj = rnn.dt_conf == all_f32 ? (gemm_acc_t *)dst_layer_
                                                      : scratch_gates_;
        const int dst_proj_ld
                = rnn.dt_conf == all_f32 ? dst_layer_ld : rnn.scratch_gates_ld;

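        // Projection GEMM (m = dic, n = mb, k = dhc):
        // dst_proj = W_projection * dst_postgemm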
        CHECK((this->*gemm_projection_func)('N', 'N', rnn.dic, rnn.mb, rnn.dhc,
                1.0f, w_projection_[0], rnn.weights_projection_ld, dst_postgemm,
                rnn.proj_ht_ld, 0.0f, dst_proj, dst_proj_ld));

        // we have to downconvert the output to dst_layer_t and copy to dst_iter if needed
        rnn_postgemm_->execute_part2(rnn, cell_position, nullptr, dst_proj,
                nullptr, dst_layer_, nullptr, nullptr, w_proj_comp, nullptr,
                nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
                nullptr, nullptr, nullptr, dst_iter_, weights_projection_scales,
                rnn.dlc * sizeof(dst_layer_t));
    }

    return dnnl_success;
}

template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_ref);
template rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_ref);
template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_ref);
template rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_ref);

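// LSTM backward for the peephole weights and the bias: each peephole weight
// diff accumulates c_state * d_gate over the minibatch (c_{t-1} for the first
// two peephole weights, c_t for the third, paired with gate gradients 0, 1
// and 3 respectively), and each bias diff accumulates d_gate over the
// minibatch. The 5 * dhc work units (3 peephole chunks + 2 bias chunks, each
// bias chunk covering a pair of gates) are balanced across threads.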
template <typename scratch_data_t, typename acc_data_t>
void lstm_bwd_weights_peephole_and_bias(const rnn_utils::rnn_conf_t &rnn,
        cell_position_t cell_position, const void *src_iter_c_,
        const void *dst_iter_c_, const scratch_data_t *scratch_gates_,
        float *diff_weights_peephole_, acc_data_t *diff_bias_) {
    const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position);
    const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position);

    const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_,
            types::data_type_size(rnn.dst_iter_c_dt), rnn.ws_states_iter_c_nld,
            dst_iter_c_ld);
    const auto src_iter_c = rnn_utils::make_raw_aoc(src_iter_c_,
            types::data_type_size(rnn.src_iter_c_dt), rnn.ws_states_iter_c_nld,
            src_iter_c_ld);

    const ws_gates_aoc<const scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const weights_peephole_aoc_t<float> diff_weights_peephole(
            rnn, diff_weights_peephole_);

    parallel(0, [&](int ithr, int nthr) {
        int g_dhc_start {}, g_dhc_stop {};
        const int gates_to_process = 5; // 3 -- weights peephole +
                // 2 -- bias (process a pair at once)
        balance211(gates_to_process * rnn.dhc, nthr, ithr, g_dhc_start,
                g_dhc_stop);
        int g = g_dhc_start / rnn.dhc;
        int dhc = g_dhc_start % rnn.dhc;
        while (g_dhc_start++ < g_dhc_stop) {
            if (g < 3) {
                // weights peephole
                auto &c_states = g < 2 ? src_iter_c : dst_iter_c;
                const auto c_states_dt
                        = g < 2 ? rnn.src_iter_c_dt : rnn.dst_iter_c_dt;

                const int scratch_g = g < 2 ? g : 3;
                for (int mb = 0; mb < rnn.mb; ++mb) {
                    diff_weights_peephole(g, dhc)
                            += to_float(c_states(mb, dhc), c_states_dt)
                            * scratch_gates(mb, scratch_g, dhc);
                }
            } else {
                // bias
                const int bias_g_start = 2 * (g - 3);
                const int bias_g_end = bias_g_start + 2;
                for_(int bias_g = bias_g_start; bias_g < bias_g_end; ++bias_g)
                for (int mb = 0; mb < rnn.mb; ++mb)
                    diff_bias_[bias_g * rnn.dhc + dhc]
                            += scratch_gates(mb, bias_g, dhc);
            }
            if (++dhc == rnn.dhc) {
                dhc = 0;
                g++;
            }
        }
    });
}

template <typename T1, typename T2, typename T3, typename T4, typename T5,
        typename T6, typename T7, typename weights_data_t, typename src_data_t,
        typename acc_data_t, typename scratch_data_t>
dnnl_status_t common_bwd_cell_exec_template(T1 gemm_layer_f, T2 gemm_iter_f,
        T3 gemm_proj_f, T4 gemm_weights_layer_f, T5 gemm_weights_iter_f,
        T6 gemm_weights_proj_f, T7 rnn_postgemm,
        const rnn_utils::rnn_conf_t &rnn, const cell_position_t cell_position,
        src_data_t *dst_layer_, void *dst_iter_c_, acc_data_t *diff_src_layer_,
        acc_data_t *diff_augru_attention_, acc_data_t *diff_src_iter_,
        acc_data_t *diff_src_iter_c_, weights_data_t **w_layer_,
        weights_data_t **w_iter_, weights_data_t **w_proj_,
        const float *weights_peephole_, void **bias_,
        const src_data_t *src_layer_, const src_data_t *augru_attention_,
        const src_data_t *src_iter_, const void *src_iter_c_,
        acc_data_t *diff_dst_layer_, acc_data_t *diff_dst_iter_,
        acc_data_t *diff_dst_iter_c_, acc_data_t *diff_w_layer_,
        acc_data_t *diff_w_iter_, float *diff_weights_projection_,
        float *diff_weights_peephole_, acc_data_t *diff_bias_,
        src_data_t *ws_gates_, scratch_data_t *scratch_gates_,
        src_data_t *ws_ht_, acc_data_t *scratch_diff_ht_, src_data_t *ws_grid_,
        scratch_data_t *scratch_cell_, src_data_t *dst_iter_) {

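    // LSTM projection backward: scratch_diff_ht_ accumulates the gradient
    // w.r.t. the projected hidden state from both dst_layer and dst_iter,
    // then the two GEMMs below compute the projection weights diff and
    // propagate the gradient back to the pre-projection hidden state
    // (reusing diff_dst_layer_ as storage).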
    if (rnn.is_lstm_projection) {
        parallel_nd(rnn.mb, [&](dim_t i) {
            PRAGMA_OMP_SIMD()
            for (int j = 0; j < rnn.dlc; j++)
                scratch_diff_ht_[i * rnn.scratch_diff_ht_ld + j]
                        = diff_dst_layer_[i * rnn.ws_diff_states_layer_ld + j]
                        + diff_dst_iter_[i * rnn.ws_diff_states_iter_ld + j];
        });

        CHECK(gemm_weights_proj_f(
                scratch_diff_ht_, ws_ht_, diff_weights_projection_));
        CHECK(gemm_proj_f(w_proj_[0], scratch_diff_ht_, diff_dst_layer_));
    }

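    // The backward postgemm computes the gate gradients into scratch_gates_
    // and the cell state gradient (diff_src_iter_c_) from the incoming
    // diff_dst_* tensors.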
    rnn_postgemm->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, dst_iter_c_, src_iter_, src_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            weights_peephole_, bias_[0], ws_grid_, scratch_cell_, dst_iter_,
            nullptr, 0);

    /// bwd by data on the cell
    CHECK(gemm_iter_f(w_iter_[0], scratch_gates_, diff_src_iter_));

    /// bwd by weights on the cell
    if (rnn.need_gemm_layer(cell_position))
        CHECK(gemm_weights_layer_f(scratch_gates_, src_layer_, diff_w_layer_));

    if (!rnn.merge_gemm_layer)
        CHECK(gemm_layer_f(w_layer_[0], scratch_gates_, diff_src_layer_));

    if (!rnn.merge_gemm_iter)
        CHECK(gemm_weights_iter_f(scratch_gates_, src_iter_, diff_w_iter_));

    if (rnn.is_lstm_peephole) {
        /// bwd by weights peephole and bias
        lstm_bwd_weights_peephole_and_bias(rnn, cell_position, src_iter_c_,
                dst_iter_c_, scratch_gates_, diff_weights_peephole_,
                diff_bias_);
    } else {
        /// bwd by bias: we just accumulate diffs from the gates
        gates_reduction(rnn, scratch_gates_, diff_bias_);
    }
    return dnnl_success;
}

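// The backward specializations below differ only in the GEMM operand data
// types: each wraps the GEMM calls in lambdas that carry the dimensions and
// leading dimensions for this cell and forwards them to
// common_bwd_cell_exec_template.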
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_ref) {
    const auto gemm_layer = [&](const float *A, const float *B, float *C) {
        return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
                rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_layer_ld, B,
                rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_layer_ld);
    };
    const auto gemm_iter = [&](const float *A, const float *B, float *C) {
        return (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
                rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_iter_ld, B,
                rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_iter_ld);
    };
    const auto gemm_proj = [&](const float *A, const float *B, float *C) {
        return (this->*gemm_projection_func)('N', 'N', rnn.dhc, rnn.mb, rnn.dic,
                1.0, A, rnn.weights_projection_ld, B, rnn.scratch_diff_ht_ld,
                0.0f, C, rnn.ws_diff_states_layer_ld);
    };
    const auto gemm_weights_layer
            = [&](const float *A, const float *B, float *C) {
                  auto src_layer_ld = rnn.src_layer_ld(cell_position);
                  return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb,
                          1.0, A, rnn.scratch_gates_ld, B, src_layer_ld, 1.0, C,
                          rnn.diff_weights_layer_ld);
              };
    const auto gemm_weights_iter
            = [&](const float *A, const float *B, float *C) {
                  auto src_iter_ld = rnn.src_iter_ld(cell_position);
                  return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, rnn.mb,
                          1.0, A, rnn.scratch_gates_ld, B, src_iter_ld, 1.0, C,
                          rnn.diff_weights_iter_ld);
              };
    const auto gemm_weights_proj
            = [&](const float *A, const float *B, float *C) {
                  return gemm('N', 'T', rnn.dlc, rnn.dhc, rnn.mb, 1.0f, A,
                          rnn.scratch_diff_ht_ld, B, rnn.ws_ht_ld, 1.0f, C,
                          rnn.diff_weights_projection_ld);
              };
    return common_bwd_cell_exec_template(gemm_layer, gemm_iter, gemm_proj,
            gemm_weights_layer, gemm_weights_iter, gemm_weights_proj,
            rnn_postgemm_, rnn, cell_position, dst_layer_, dst_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, w_layer_, w_iter_, w_projection_,
            weights_peephole_, bias_, src_layer_, augru_attention_, src_iter_,
            src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            diff_w_layer_, diff_w_iter_, diff_weights_projection_,
            diff_weights_peephole_, diff_bias_, ws_gates_, scratch_gates_,
            proj_ht_, scratch_diff_ht_, ws_grid_, scratch_cell_, dst_iter_);
}

template <>
rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_ref) {
    const auto gemm_layer = [&](const bfloat16_t *A, const bfloat16_t *B,
                                    float *C) {
        return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
                rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_layer_ld, B,
                rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_layer_ld);
    };
    const auto gemm_iter = [&](const bfloat16_t *A, const bfloat16_t *B,
                                   float *C) {
        return (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
                rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_iter_ld, B,
                rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_iter_ld);
    };
    const auto gemm_proj = [&](const bfloat16_t *, const float *, float *) {
        assert(!"unimplemented");
        return dnnl_unimplemented;
    };
    const auto gemm_weights_layer
            = [&](const bfloat16_t *A, const bfloat16_t *B, float *C) {
                  auto src_layer_ld = rnn.src_layer_ld(cell_position);
                  return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb,
                          1.0, A, rnn.scratch_gates_ld, B, src_layer_ld, 1.0, C,
                          rnn.diff_weights_layer_ld);
              };
    const auto gemm_weights_iter
            = [&](const bfloat16_t *A, const bfloat16_t *B, float *C) {
                  auto src_iter_ld = rnn.src_iter_ld(cell_position);
                  return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, rnn.mb,
                          1.0, A, rnn.scratch_gates_ld, B, src_iter_ld, 1.0, C,
                          rnn.diff_weights_iter_ld);
              };
    const auto gemm_weights_proj
            = [&](const float *, const bfloat16_t *, float *) {
                  assert(!"unimplemented");
                  return dnnl_unimplemented;
              };
    return common_bwd_cell_exec_template(gemm_layer, gemm_iter, gemm_proj,
            gemm_weights_layer, gemm_weights_iter, gemm_weights_proj,
            rnn_postgemm_, rnn, cell_position, dst_layer_, dst_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, w_layer_, w_iter_, w_projection_,
            weights_peephole_, bias_, src_layer_, augru_attention_, src_iter_,
            src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            diff_w_layer_, diff_w_iter_, diff_weights_projection_,
            diff_weights_peephole_, diff_bias_, ws_gates_, scratch_gates_,
            proj_ht_, scratch_diff_ht_, ws_grid_, scratch_cell_, dst_iter_);
}

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_merged_layer_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::merged_layer_execution_ref)) {
    const auto src_layer_ld = rnn.src_layer_ld(cell_position);
    // If we avoid copying the last iteration, the corresponding
    // input states appear in `dst_iter_` instead of `ws_states_layer`,
    // hence we cannot merge all iterations.
    // This is not applicable for the first layer though, since
    // all the states come from the user's `src_layer_`.
    const int n_iter
            = (cell_position & first_layer) && rnn.skip_src_layer_copy()
            ? rnn.n_iter
            : rnn.n_iter - (rnn.skip_dst_iter_copy() ? 1 : 0);

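    // The layer GEMM is merged across iterations by folding the iteration
    // dimension into the minibatch: forward computes the pre-activation gates
    // for mb * n_iter samples in one call, backward computes diff_src_layer
    // and accumulates diff_weights_layer over the merged batch.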
    if (aprop == prop_kind::forward) {
        CHECK((this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dhc,
                rnn.mb * n_iter, rnn.slc, 1.0, w_layer_[0],
                rnn.weights_layer_ld, src_layer_, src_layer_ld, 0.0,
                (gemm_acc_t *)scratch_gates_, rnn.scratch_gates_ld));
    } else if (aprop == prop_kind::backward) {
        CHECK((this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter,
                rnn.n_gates * rnn.dhc, 1.0, w_layer_[0], rnn.weights_layer_ld,
                (gates_t *)scratch_gates_, rnn.scratch_gates_ld, 0.0,
                diff_src_layer_, rnn.ws_diff_states_layer_ld));
        CHECK(gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb * n_iter,
                1.0, (weights_t *)scratch_gates_, rnn.scratch_gates_ld,
                src_layer_, src_layer_ld, 1.0, diff_w_layer_,
                rnn.diff_weights_layer_ld));
    } else {
        assert(!"unimplemented");
    }

    return dnnl_success;
}

template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_f32_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_bf16_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_u8s8_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_s8s8_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_bwd_f32_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_bwd_bf16_t::merged_layer_execution_ref);

} // namespace cpu
} // namespace impl
} // namespace dnnl