/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Common for RNN and LSTM cell execution
 */
#include "common/bfloat16.hpp"
#include "common/dnnl_thread.hpp"

#include "cpu/rnn/ref_rnn.hpp"
#include "cpu/simple_q10n.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
using namespace rnn_utils;
using namespace dnnl::impl::utils;

template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::cell_execution_ref)) {
    const auto weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;
    const auto weights_projection_scales = rnn.is_lstm_projection
            ? pd_->attr()->rnn_weights_projection_qparams_.scales_
            : nullptr;

    const auto src_layer_ld = rnn.src_layer_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

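    // Compute the pre-activation gates in scratch_gates_:
    //   scratch_gates = W_layer * src_layer (skipped when the layer GEMM was
    //   already performed, e.g. merged across iterations), then
    //   scratch_gates += W_iter * src_iter (beta = 1.0f accumulates).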
    if (rnn.need_gemm_layer(cell_position)) {
        CHECK((this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
                rnn.slc, 1.0f, w_layer_[0], rnn.weights_layer_ld, src_layer_,
                src_layer_ld, 0.0f, scratch_gates_, rnn.scratch_gates_ld));
    }
    CHECK((this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
            rnn.sic, 1.0f, w_iter_[0], rnn.weights_iter_ld, src_iter_,
            src_iter_ld, 1.0f, scratch_gates_, rnn.scratch_gates_ld));

    // Note: proj_ht is a scratchpad buffer for inference and a workspace
    // buffer for training.
    const auto dst_postgemm = rnn.is_lstm_projection ? proj_ht_ : dst_layer_;
    // For LSTMP, the copy to dst_iter happens after the projection.
    const auto dst_iter_postgemm = rnn.is_lstm_projection ? nullptr : dst_iter_;
    rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_postgemm, dst_iter_c_, src_iter_, src_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            weights_peephole_, bias_[0], ws_grid_, scratch_cell_,
            dst_iter_postgemm, weights_scales, rnn.dhc * sizeof(scratch_t));

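    // LSTM projection (LSTMP): the dhc-wide hidden state computed by the
    // postgemm is multiplied by the projection weights to produce the
    // projected output written to dst_layer (and to dst_iter if needed,
    // handled in execute_part2 below).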
    if (rnn.is_lstm_projection) {
        const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);

        // Since the accumulation type differs from the dst_layer data type,
        // we have to use scratch to hold the temporary accumulators.
        assert(rnn.scratch_gates_ld >= rnn.dlc);
        gemm_acc_t *dst_proj = rnn.dt_conf == all_f32 ? (gemm_acc_t *)dst_layer_
                                                      : scratch_gates_;
        const int dst_proj_ld
                = rnn.dt_conf == all_f32 ? dst_layer_ld : rnn.scratch_gates_ld;

        CHECK((this->*gemm_projection_func)('N', 'N', rnn.dic, rnn.mb, rnn.dhc,
                1.0f, w_projection_[0], rnn.weights_projection_ld, dst_postgemm,
                rnn.proj_ht_ld, 0.0f, dst_proj, dst_proj_ld));

        // Downconvert the output to dst_layer_t and copy it to dst_iter if
        // needed.
        rnn_postgemm_->execute_part2(rnn, cell_position, nullptr, dst_proj,
                nullptr, dst_layer_, nullptr, nullptr, w_proj_comp, nullptr,
                nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
                nullptr, nullptr, nullptr, dst_iter_, weights_projection_scales,
                rnn.dlc * sizeof(dst_layer_t));
    }

    return dnnl_success;
}

template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_ref);
template rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_ref);
template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_ref);
template rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_ref);

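// Accumulates the LSTM peephole weights and bias gradients:
// diff_weights_peephole gets the sum over the minibatch of cell state times
// gate diff, and diff_bias gets the gate diffs reduced over the minibatch.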
template <typename scratch_data_t, typename acc_data_t>
void lstm_bwd_weights_peephole_and_bias(const rnn_utils::rnn_conf_t &rnn,
        cell_position_t cell_position, const void *src_iter_c_,
        const void *dst_iter_c_, const scratch_data_t *scratch_gates_,
        float *diff_weights_peephole_, acc_data_t *diff_bias_) {
    const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position);
    const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position);

    const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_,
            types::data_type_size(rnn.dst_iter_c_dt), rnn.ws_states_iter_c_nld,
            dst_iter_c_ld);
    const auto src_iter_c = rnn_utils::make_raw_aoc(src_iter_c_,
            types::data_type_size(rnn.src_iter_c_dt), rnn.ws_states_iter_c_nld,
            src_iter_c_ld);

    const ws_gates_aoc<const scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const weights_peephole_aoc_t<float> diff_weights_peephole(
            rnn, diff_weights_peephole_);

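    // The work is split across threads as 5 * dhc independent columns:
    // g in [0, 3) accumulates the three peephole weight gradients, and
    // g in [3, 5) accumulates the bias gradients (two gates per step).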
    parallel(0, [&](int ithr, int nthr) {
        int g_dhc_start {}, g_dhc_stop {};
        const int gates_to_process = 5; // 3 -- weights peephole +
                                        // 2 -- bias (process a pair at once)
        balance211(gates_to_process * rnn.dhc, nthr, ithr, g_dhc_start,
                g_dhc_stop);
        int g = g_dhc_start / rnn.dhc;
        int dhc = g_dhc_start % rnn.dhc;
        while (g_dhc_start++ < g_dhc_stop) {
            if (g < 3) {
                // weights peephole
                auto &c_states = g < 2 ? src_iter_c : dst_iter_c;
                const auto c_states_dt
                        = g < 2 ? rnn.src_iter_c_dt : rnn.dst_iter_c_dt;

                const int scratch_g = g < 2 ? g : 3;
                for (int mb = 0; mb < rnn.mb; ++mb) {
                    diff_weights_peephole(g, dhc)
                            += to_float(c_states(mb, dhc), c_states_dt)
                            * scratch_gates(mb, scratch_g, dhc);
                }
            } else {
                // bias
                const int bias_g_start = 2 * (g - 3);
                const int bias_g_end = bias_g_start + 2;
                for_(int bias_g = bias_g_start; bias_g < bias_g_end; ++bias_g)
                for (int mb = 0; mb < rnn.mb; ++mb)
                    diff_bias_[bias_g * rnn.dhc + dhc]
                            += scratch_gates(mb, bias_g, dhc);
            }
            if (++dhc == rnn.dhc) {
                dhc = 0;
                g++;
            }
        }
    });
}

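// Common backward cell driver: the data-type-specific GEMM calls are passed
// in as callables (T1..T6) so that the f32 and bf16 specializations below
// share the same control flow.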
template <typename T1, typename T2, typename T3, typename T4, typename T5,
        typename T6, typename T7, typename weights_data_t, typename src_data_t,
        typename acc_data_t, typename scratch_data_t>
dnnl_status_t common_bwd_cell_exec_template(T1 gemm_layer_f, T2 gemm_iter_f,
        T3 gemm_proj_f, T4 gemm_weights_layer_f, T5 gemm_weights_iter_f,
        T6 gemm_weights_proj_f, T7 rnn_postgemm,
        const rnn_utils::rnn_conf_t &rnn, const cell_position_t cell_position,
        src_data_t *dst_layer_, void *dst_iter_c_, acc_data_t *diff_src_layer_,
        acc_data_t *diff_augru_attention_, acc_data_t *diff_src_iter_,
        acc_data_t *diff_src_iter_c_, weights_data_t **w_layer_,
        weights_data_t **w_iter_, weights_data_t **w_proj_,
        const float *weights_peephole_, void **bias_,
        const src_data_t *src_layer_, const src_data_t *augru_attention_,
        const src_data_t *src_iter_, const void *src_iter_c_,
        acc_data_t *diff_dst_layer_, acc_data_t *diff_dst_iter_,
        acc_data_t *diff_dst_iter_c_, acc_data_t *diff_w_layer_,
        acc_data_t *diff_w_iter_, float *diff_weights_projection_,
        float *diff_weights_peephole_, acc_data_t *diff_bias_,
        src_data_t *ws_gates_, scratch_data_t *scratch_gates_,
        src_data_t *ws_ht_, acc_data_t *scratch_diff_ht_, src_data_t *ws_grid_,
        scratch_data_t *scratch_cell_, src_data_t *dst_iter_) {

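    // For LSTMP, first combine the dst_layer and dst_iter gradients into
    // scratch_diff_ht, then accumulate the projection weights gradient and
    // backpropagate through the projection into diff_dst_layer.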
    if (rnn.is_lstm_projection) {
        parallel_nd(rnn.mb, [&](dim_t i) {
            PRAGMA_OMP_SIMD()
            for (int j = 0; j < rnn.dlc; j++)
                scratch_diff_ht_[i * rnn.scratch_diff_ht_ld + j]
                        = diff_dst_layer_[i * rnn.ws_diff_states_layer_ld + j]
                        + diff_dst_iter_[i * rnn.ws_diff_states_iter_ld + j];
        });

        CHECK(gemm_weights_proj_f(
                scratch_diff_ht_, ws_ht_, diff_weights_projection_));
        CHECK(gemm_proj_f(w_proj_[0], scratch_diff_ht_, diff_dst_layer_));
    }

    rnn_postgemm->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, dst_iter_c_, src_iter_, src_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            weights_peephole_, bias_[0], ws_grid_, scratch_cell_, dst_iter_,
            nullptr, 0);

    /// bwd by data on the cell
    CHECK(gemm_iter_f(w_iter_[0], scratch_gates_, diff_src_iter_));

    /// bwd by weights on the cell
    if (rnn.need_gemm_layer(cell_position))
        CHECK(gemm_weights_layer_f(scratch_gates_, src_layer_, diff_w_layer_));

    if (!rnn.merge_gemm_layer)
        CHECK(gemm_layer_f(w_layer_[0], scratch_gates_, diff_src_layer_));

    if (!rnn.merge_gemm_iter)
        CHECK(gemm_weights_iter_f(scratch_gates_, src_iter_, diff_w_iter_));

    if (rnn.is_lstm_peephole) {
        /// bwd by weights peephole and bias
        lstm_bwd_weights_peephole_and_bias(rnn, cell_position, src_iter_c_,
                dst_iter_c_, scratch_gates_, diff_weights_peephole_,
                diff_bias_);
    } else {
        /// bwd by bias: we just accumulate the diffs from the gates
        gates_reduction(rnn, scratch_gates_, diff_bias_);
    }
    return dnnl_success;
}

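// Backward f32 cell: binds the f32 GEMMs for diff data (layer, iter,
// projection) and diff weights, then delegates to the common backward driver.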
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_ref) {
    const auto gemm_layer = [&](const float *A, const float *B, float *C) {
        return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
                rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_layer_ld, B,
                rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_layer_ld);
    };
    const auto gemm_iter = [&](const float *A, const float *B, float *C) {
        return (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
                rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_iter_ld, B,
                rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_iter_ld);
    };
    const auto gemm_proj = [&](const float *A, const float *B, float *C) {
        return (this->*gemm_projection_func)('N', 'N', rnn.dhc, rnn.mb, rnn.dic,
                1.0, A, rnn.weights_projection_ld, B, rnn.scratch_diff_ht_ld,
                0.0f, C, rnn.ws_diff_states_layer_ld);
    };
    const auto gemm_weights_layer
            = [&](const float *A, const float *B, float *C) {
                  auto src_layer_ld = rnn.src_layer_ld(cell_position);
                  return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb,
                          1.0, A, rnn.scratch_gates_ld, B, src_layer_ld, 1.0, C,
                          rnn.diff_weights_layer_ld);
              };
    const auto gemm_weights_iter
            = [&](const float *A, const float *B, float *C) {
                  auto src_iter_ld = rnn.src_iter_ld(cell_position);
                  return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, rnn.mb,
                          1.0, A, rnn.scratch_gates_ld, B, src_iter_ld, 1.0, C,
                          rnn.diff_weights_iter_ld);
              };
    const auto gemm_weights_proj
            = [&](const float *A, const float *B, float *C) {
                  return gemm('N', 'T', rnn.dlc, rnn.dhc, rnn.mb, 1.0f, A,
                          rnn.scratch_diff_ht_ld, B, rnn.ws_ht_ld, 1.0f, C,
                          rnn.diff_weights_projection_ld);
              };
    return common_bwd_cell_exec_template(gemm_layer, gemm_iter, gemm_proj,
            gemm_weights_layer, gemm_weights_iter, gemm_weights_proj,
            rnn_postgemm_, rnn, cell_position, dst_layer_, dst_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, w_layer_, w_iter_, w_projection_,
            weights_peephole_, bias_, src_layer_, augru_attention_, src_iter_,
            src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            diff_w_layer_, diff_w_iter_, diff_weights_projection_,
            diff_weights_peephole_, diff_bias_, ws_gates_, scratch_gates_,
            proj_ht_, scratch_diff_ht_, ws_grid_, scratch_cell_, dst_iter_);
}

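// Backward bf16 cell: same structure as the f32 specialization, except that
// the projection GEMMs are not implemented for bf16 and assert if reached.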
template <>
rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_ref) {
    const auto gemm_layer = [&](const bfloat16_t *A, const bfloat16_t *B,
                                    float *C) {
        return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
                rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_layer_ld, B,
                rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_layer_ld);
    };
    const auto gemm_iter = [&](const bfloat16_t *A, const bfloat16_t *B,
                                   float *C) {
        return (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
                rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_iter_ld, B,
                rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_iter_ld);
    };
    const auto gemm_proj = [&](const bfloat16_t *, const float *, float *) {
        assert(!"unimplemented");
        return dnnl_unimplemented;
    };
    const auto gemm_weights_layer
            = [&](const bfloat16_t *A, const bfloat16_t *B, float *C) {
                  auto src_layer_ld = rnn.src_layer_ld(cell_position);
                  return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb,
                          1.0, A, rnn.scratch_gates_ld, B, src_layer_ld, 1.0, C,
                          rnn.diff_weights_layer_ld);
              };
    const auto gemm_weights_iter
            = [&](const bfloat16_t *A, const bfloat16_t *B, float *C) {
                  auto src_iter_ld = rnn.src_iter_ld(cell_position);
                  return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, rnn.mb,
                          1.0, A, rnn.scratch_gates_ld, B, src_iter_ld, 1.0, C,
                          rnn.diff_weights_iter_ld);
              };
    const auto gemm_weights_proj
            = [&](const float *, const bfloat16_t *, float *) {
                  assert(!"unimplemented");
                  return dnnl_unimplemented;
              };
    return common_bwd_cell_exec_template(gemm_layer, gemm_iter, gemm_proj,
            gemm_weights_layer, gemm_weights_iter, gemm_weights_proj,
            rnn_postgemm_, rnn, cell_position, dst_layer_, dst_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, w_layer_, w_iter_, w_projection_,
            weights_peephole_, bias_, src_layer_, augru_attention_, src_iter_,
            src_iter_c_, diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
            diff_w_layer_, diff_w_iter_, diff_weights_projection_,
            diff_weights_peephole_, diff_bias_, ws_gates_, scratch_gates_,
            proj_ht_, scratch_diff_ht_, ws_grid_, scratch_cell_, dst_iter_);
}

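// Merged layer execution: performs the layer GEMM for multiple iterations in
// a single call (over rnn.mb * n_iter rows) instead of once per cell.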
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_merged_layer_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::merged_layer_execution_ref)) {
    const auto src_layer_ld = rnn.src_layer_ld(cell_position);
    // If we avoid copying the last iteration, the corresponding
    // input states appear in `dst_iter_` instead of `ws_states_layer`,
    // hence we cannot merge all iterations.
    // This is not applicable for the first layer though, since
    // all the states come from the user's `src_layer_`.
    const int n_iter
            = (cell_position & first_layer) && rnn.skip_src_layer_copy()
            ? rnn.n_iter
            : rnn.n_iter - (rnn.skip_dst_iter_copy() ? 1 : 0);

    if (aprop == prop_kind::forward) {
        CHECK((this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dhc,
                rnn.mb * n_iter, rnn.slc, 1.0, w_layer_[0],
                rnn.weights_layer_ld, src_layer_, src_layer_ld, 0.0,
                (gemm_acc_t *)scratch_gates_, rnn.scratch_gates_ld));
    } else if (aprop == prop_kind::backward) {
        CHECK((this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter,
                rnn.n_gates * rnn.dhc, 1.0, w_layer_[0], rnn.weights_layer_ld,
                (gates_t *)scratch_gates_, rnn.scratch_gates_ld, 0.0,
                diff_src_layer_, rnn.ws_diff_states_layer_ld));
        CHECK(gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb * n_iter,
                1.0, (weights_t *)scratch_gates_, rnn.scratch_gates_ld,
                src_layer_, src_layer_ld, 1.0, diff_w_layer_,
                rnn.diff_weights_layer_ld));
    } else {
        assert(!"unimplemented");
    }

    return dnnl_success;
}

template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_f32_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_bf16_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_u8s8_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_fwd_s8s8_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_bwd_f32_t::merged_layer_execution_ref);
template rnn_merged_layer_execution_sig(
        ref_rnn_bwd_bf16_t::merged_layer_execution_ref);

} // namespace cpu
} // namespace impl
} // namespace dnnl