1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /* |
18 | * Cell execution GRU |
19 | */ |
20 | |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/math_utils.hpp" |
23 | |
24 | #include "cpu/rnn/ref_rnn.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | |
30 | using namespace dnnl::impl::utils; |
31 | using namespace dnnl::impl::math; |
32 | using namespace rnn_utils; |
33 | |
#define AOC array_offset_calculator
// Forward execution of a single GRU cell. Gate order in the scratch/ws gates
// buffers is: gate 0 = update (z), gate 1 = reset (r), gate 2 = candidate
// (h~) — inferred from the step comments below; confirm against the postgemm
// kernels. The computation is split in two halves around a second iter GEMM
// because the candidate gate needs (r (*) h_{t-1}) as its recurrent input.
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::cell_execution_gru)) {
    // Offset calculators over the gates workspaces (layout owned by rnn_utils).
    const ws_gates_aoc<gates_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_t> scratch_gates(rnn, scratch_gates_);
    // Per-output-channel weight scales for the int8 configurations; passed
    // through to the postgemm so it can dequantize gate accumulators.
    const auto weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;

    // Leading dimensions depend on where this cell sits in the layer/iter
    // grid (first/last layer, first iteration, ...), hence the cell_position
    // queries rather than fixed strides.
    const auto src_layer_ld = rnn.src_layer_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);
    const auto dst_iter_part2_ld = rnn.dst_iter_part2_ld(cell_position);

    // 1. gemm Wx[0-2],x
    // scratch_gates = W_layer * x  (all three gates at once).
    // Skipped when the layer GEMM was already merged across iterations.
    if (rnn.need_gemm_layer(cell_position)) {
        CHECK((this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
                rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, src_layer_,
                src_layer_ld, 0.0f, scratch_gates_, rnn.scratch_gates_ld));
    }

    // 2. gemm Wh[0-1],h
    // Accumulate (beta = 1) the recurrent contribution for the first two
    // gates only: scratch_gates[0..1] += W_iter[0..1] * h_{t-1}.
    CHECK((this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dhc, rnn.mb,
            rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, src_iter_,
            src_iter_ld, 1.0f, scratch_gates_, rnn.scratch_gates_ld));

    // 3. activation zt and rt + elemwise multiplication rt,ht-1
    // Part-1 postgemm: applies the gate activations for z/r and writes
    // (r (*) h_{t-1}) into dst_layer_ for step 4 to consume.
    rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, nullptr, src_iter_, nullptr,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_, nullptr,
            diff_dst_layer_, diff_dst_iter_, nullptr, nullptr, bias_[0],
            nullptr, nullptr, dst_iter_, weights_scales, rnn.dhc);

    // 4. gemm Wh[2],h~t
    // Candidate-gate recurrent GEMM on the intermediate produced in step 3:
    // scratch_gates[2] += W_iter[2] * (r (*) h_{t-1}).
    CHECK((this->*gemm_iter_func)('N', 'N', rnn.dhc, rnn.mb, rnn.sic, 1.0,
            w_iter_[1], rnn.weights_iter_ld, dst_layer_, dst_iter_part2_ld, 1.0,
            &(scratch_gates(0, 2, 0)), rnn.scratch_gates_ld));

    // 5. activation h~t + calculate ht
    // Part-2 postgemm: tanh on the candidate gate and the final blend
    // h_t = z (*) h_{t-1} + (1 - z) (*) h~, written to dst_layer_/dst_iter_.
    rnn_postgemm_->execute_part2(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, dst_iter_c_, src_iter_, src_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_, nullptr,
            diff_dst_layer_, diff_dst_iter_, nullptr, nullptr, bias_[0],
            nullptr, nullptr, dst_iter_, weights_scales, rnn.dhc);

    return dnnl_success;
}
80 | |
// Explicit instantiations of the forward GRU cell for every supported
// data-type configuration: f32, bf16, and the two int8 flavors (u8s8, s8s8).
template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
template rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_gru);
template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
template rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_gru);
85 | |
// Backward pass of one GRU cell, shared by the f32 and bf16 specializations
// below. The GEMM flavors are injected as callables (T1..T4) so this body
// stays independent of the concrete data types and leading dimensions:
//   gemm_layer_f          - dx   = W_layer^T * dG        (diff wrt layer input)
//   gemm_iter_f           - dh   = W_iter^T  * dG        (diff wrt iter input)
//   gemm_weights_layer_f  - dW_layer += dG * x^T
//   gemm_weights_iter_f   - dW_iter  += dG * h^T
// rnn_postgemm_ computes the elementwise gate gradients in two parts,
// mirroring the two-phase forward pass.
template <typename T1, typename T2, typename T3, typename T4, typename T5,
        typename weights_data_t, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
dnnl_status_t gru_bwd_cell_exec_template(T1 gemm_layer_f, T2 gemm_iter_f,
        T3 gemm_weights_layer_f, T4 gemm_weights_iter_f, T5 rnn_postgemm_,
        const rnn_utils::rnn_conf_t &rnn, cell_position_t cell_position,
        src_data_t *ws_gates_, scratch_data_t *scratch_gates_,
        src_data_t *dst_layer_, const src_data_t *src_iter_,
        const src_data_t *src_layer_, const src_data_t *augru_attention_,
        weights_data_t **w_layer_, weights_data_t **w_iter_,
        acc_data_t *diff_w_layer_, acc_data_t *diff_w_iter_,
        acc_data_t *diff_src_layer_, acc_data_t *diff_augru_attention_,
        acc_data_t *diff_src_iter_, acc_data_t *diff_dst_iter_,
        acc_data_t *diff_dst_layer_, acc_data_t *diff_bias_,
        scratch_data_t *scratch_cell_, src_data_t *dst_iter_) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);

    const auto src_layer_ld = rnn.src_layer_ld(cell_position);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);
    // The last layer writes to the user dst_layer buffer, which has its own
    // leading dimension; inner layers share the dst_iter layout.
    const ws_states_layer_aoc<src_data_t> dst_layer(rnn, dst_layer_,
            (cell_position & last_layer) ? dst_layer_ld : dst_iter_ld);
    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);
    const ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_);

    // use state memory for intermediate computations
    // TODO: use cell ws for that
    // NOTE: dhG1_ deliberately aliases diff_src_layer_ — it is scratch at
    // this point and is fully overwritten by later steps. Do not reorder the
    // steps below without revisiting this aliasing.
    float *dhG1_ = diff_src_layer_;
    const AOC<acc_data_t, 2> dhG1(
            dhG1_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld);
    // hg1 needs to be bf16 as it is used as gemm output
    // hence it cannot alias to dhG1, and should use scratch_cell
    const AOC<scratch_data_t, 2> hG1(
            scratch_cell_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld);

    // 1. calculate dG2, dG1, and part of dht-1
    // Part-1 postgemm also reproduces hG1 = (G1 (*) h_{t-1}) into
    // scratch_cell_ for the weight-gradient GEMM in step 4.
    rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, nullptr, src_iter_, nullptr,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_, nullptr,
            diff_dst_layer_, diff_dst_iter_, nullptr, nullptr, nullptr, nullptr,
            scratch_cell_, dst_iter_, nullptr, 0);

    // 2. calculate intermediate d(hG1)
    // d(hG1) = dG2 * W2h^t   (beta = 0: overwrite the aliased buffer)
    CHECK(gemm_iter_f(rnn.sic, rnn.mb, rnn.dhc, w_iter_[1],
            &(scratch_gates(0, 2, 0)), 0.0f, dhG1_));

    // 3. calculate dG1^ and part of dht-1
    rnn_postgemm_->execute_part2(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, nullptr, src_iter_, nullptr,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_, nullptr,
            diff_dst_layer_, diff_dst_iter_, nullptr, nullptr, nullptr, nullptr,
            scratch_cell_, dst_iter_, nullptr, 0);

    // 4. calculate diff weights
    // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h)
    CHECK(gemm_weights_iter_f((rnn.n_gates - 1) * rnn.dhc, rnn.sic, rnn.mb,
            scratch_gates_, src_iter_, src_iter_ld, 1.0f, diff_w_iter_));
    CHECK(gemm_weights_iter_f(rnn.dhc, rnn.sic, rnn.mb,
            &(scratch_gates(0, 2, 0)), scratch_cell_, rnn.ws_states_layer_ld,
            1.0f, &(diff_w_iter(0, 2, 0))));

    // 5. calculate diff states
    // dht-1 += dG1 * W1h + dG0 * W0h   (beta = 1: accumulate into dht-1)
    CHECK(gemm_iter_f(rnn.sic, rnn.mb, (rnn.n_gates - 1) * rnn.dhc, w_iter_[0],
            scratch_gates_, 1.0f, diff_src_iter_));

    // dWx += [dG0 dG1 dG2] * [x]
    // Skipped when the layer GEMM is merged across iterations elsewhere.
    if (rnn.need_gemm_layer(cell_position))
        CHECK(gemm_weights_layer_f(
                scratch_gates_, src_layer_, src_layer_ld, diff_w_layer_));

    // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x
    // This overwrites diff_src_layer_ (the dhG1_ alias), finalizing dx.
    if (!rnn.merge_gemm_layer)
        CHECK(gemm_layer_f(w_layer_[0], scratch_gates_, diff_src_layer_));

    // 6. calculate diff bias
    // db += sum over the minibatch of the gate gradients.
    gates_reduction(rnn, scratch_gates_, diff_bias_);

    return dnnl_success;
}
170 | |
171 | template <> |
172 | rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) { |
173 | auto gemm_iter_f |
174 | = [&](int m, int n, int k, const weights_t *A, const gemm_data_t *B, |
175 | float beta, gemm_acc_t *C) { |
176 | return (this->*gemm_iter_func)('N', 'N', m, n, k, 1.0f, A, |
177 | rnn.weights_iter_ld, B, rnn.scratch_gates_ld, beta, C, |
178 | rnn.ws_diff_states_iter_ld); |
179 | }; |
180 | auto gemm_layer_f = [&](const weights_t *A, const gemm_data_t *B, |
181 | gemm_acc_t *C) { |
182 | return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, |
183 | rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_layer_ld, B, |
184 | rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_layer_ld); |
185 | }; |
186 | auto gemm_weights_layer_f = [&](const gemm_data_t *A, const weights_t *B, |
187 | int ldb, gemm_acc_t *C) { |
188 | return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb, 1.0, A, |
189 | rnn.scratch_gates_ld, B, ldb, 1.0, C, |
190 | rnn.diff_weights_layer_ld); |
191 | }; |
192 | auto gemm_weights_iter_f |
193 | = [&](int m, int n, int k, const weights_t *A, const gemm_data_t *B, |
194 | int ldb, float beta, gemm_acc_t *C) { |
195 | return gemm('N', 'T', m, n, k, 1.0f, A, rnn.ws_gates_ld, B, |
196 | ldb, 1.0f, C, rnn.diff_weights_iter_ld); |
197 | }; |
198 | |
199 | return gru_bwd_cell_exec_template(gemm_layer_f, gemm_iter_f, |
200 | gemm_weights_layer_f, gemm_weights_iter_f, this->rnn_postgemm_, rnn, |
201 | cell_position, ws_gates_, scratch_gates_, dst_layer_, src_iter_, |
202 | src_layer_, augru_attention_, w_layer_, w_iter_, diff_w_layer_, |
203 | diff_w_iter_, diff_src_layer_, diff_augru_attention_, |
204 | diff_src_iter_, diff_dst_iter_, diff_dst_layer_, diff_bias_, |
205 | scratch_cell_, dst_iter_); |
206 | } |
207 | |
208 | template <> |
209 | rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_gru) { |
210 | auto gemm_iter_f |
211 | = [&](int m, int n, int k, const weights_t *A, const gemm_data_t *B, |
212 | float beta, gemm_acc_t *C) { |
213 | return (this->*gemm_iter_func)('N', 'N', m, n, k, 1.0f, A, |
214 | rnn.weights_iter_ld, B, rnn.scratch_gates_ld, beta, C, |
215 | rnn.ws_diff_states_iter_ld); |
216 | }; |
217 | auto gemm_layer_f = [&](const weights_t *A, const gemm_data_t *B, |
218 | gemm_acc_t *C) { |
219 | return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, |
220 | rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_layer_ld, B, |
221 | rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_layer_ld); |
222 | }; |
223 | auto gemm_weights_layer_f = [&](const gemm_data_t *A, const weights_t *B, |
224 | int ldb, gemm_acc_t *C) { |
225 | return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb, 1.0, A, |
226 | rnn.scratch_gates_ld, B, ldb, 1.0, C, |
227 | rnn.diff_weights_layer_ld); |
228 | }; |
229 | auto gemm_weights_iter_f |
230 | = [&](int m, int n, int k, const weights_t *A, const gemm_data_t *B, |
231 | int ldb, float beta, gemm_acc_t *C) { |
232 | return gemm('N', 'T', m, n, k, 1.0f, A, rnn.ws_gates_ld, B, |
233 | ldb, 1.0f, C, rnn.diff_weights_iter_ld); |
234 | }; |
235 | |
236 | return gru_bwd_cell_exec_template(gemm_layer_f, gemm_iter_f, |
237 | gemm_weights_layer_f, gemm_weights_iter_f, this->rnn_postgemm_, rnn, |
238 | cell_position, ws_gates_, scratch_gates_, dst_layer_, src_iter_, |
239 | src_layer_, augru_attention_, w_layer_, w_iter_, diff_w_layer_, |
240 | diff_w_iter_, diff_src_layer_, diff_augru_attention_, |
241 | diff_src_iter_, diff_dst_iter_, diff_dst_layer_, diff_bias_, |
242 | scratch_cell_, dst_iter_); |
243 | } |
244 | |
245 | #undef AOC |
246 | } // namespace cpu |
247 | } // namespace impl |
248 | } // namespace dnnl |
249 | |