1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /* |
18 | * Cell execution GRU with linear before reset |
19 | */ |
20 | #pragma warning(disable : 4503) /* name is too long */ |
21 | |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/math_utils.hpp" |
24 | |
25 | #include "cpu/rnn/ref_rnn.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | |
31 | using namespace dnnl::impl::utils; |
32 | using namespace dnnl::impl::math; |
33 | using namespace rnn_utils; |
34 | #define AOC array_offset_calculator |
35 | |
// Forward cell execution for GRU with linear-before-reset (LBR).
// Computes the layer and iteration GEMMs that produce the pre-activation
// gates, then delegates all elementwise gate math to rnn_postgemm_.
// Note the asymmetry: the layer GEMM accumulates into scratch_gates_, while
// the iteration GEMM writes into scratch_cell_ (LBR needs the recurrent
// contribution kept separate so the reset gate can be applied after the
// bias add inside the postgemm).
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::cell_execution_gru_lbr)) {
    // Leading dimensions depend on the cell's position in the grid
    // (first/last layer or iteration may use external buffers).
    const auto src_layer_ld = rnn.src_layer_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

    // scratch_gates_ = W_layer * src_layer (skipped when a merged layer
    // GEMM already produced it for this cell position).
    if (rnn.need_gemm_layer(cell_position)) {
        CHECK((this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
                rnn.slc, 1.0f, w_layer_[0], rnn.weights_layer_ld, src_layer_,
                src_layer_ld, 0.0f, scratch_gates_, rnn.scratch_gates_ld));
    }
    // scratch_cell_ = W_iter * src_iter; stored separately from
    // scratch_gates_ (leading dimension ws_gates_ld).
    CHECK((this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
            rnn.sic, 1.0f, w_iter_[0], rnn.weights_iter_ld, src_iter_,
            src_iter_ld, 0.0f, scratch_cell_, rnn.ws_gates_ld));

    // Elementwise part: gate nonlinearities, bias, reset-gate application
    // and dst computation. Unused (cell-state / diff) arguments are nullptr
    // in this forward path.
    rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, dst_iter_c_, src_iter_, src_iter_c_,
            diff_src_layer_, diff_augru_attention_, diff_src_iter_,
            diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_, nullptr, nullptr,
            bias_[0], ws_grid_, scratch_cell_, dst_iter_, nullptr, 0);

    return dnnl_success;
}
60 | |
// Explicit instantiations for the supported forward data types.
template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
template rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_gru_lbr);
63 | template <> |
64 | rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) { |
65 | assert(!"GRU LBR int8 is not supported" ); |
66 | return dnnl_unimplemented; |
67 | } |
68 | |
69 | template <> |
70 | rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_gru_lbr) { |
71 | assert(!"GRU LBR int8 is not supported" ); |
72 | return dnnl_unimplemented; |
73 | } |
74 | |
// Backward cell execution for GRU-LBR, shared between f32 and bf16 via the
// gemm functors T1..T4 supplied by the caller:
//   gemm_layer_f         - dx  = dG  * Wx^t
//   gemm_iter_f          - dh += dGr * Wh^t
//   gemm_weights_layer_f - dWx += dG^t  * x
//   gemm_weights_iter_f  - dWh += dGr^t * h
// The elementwise backward (rnn_postgemm) runs FIRST: it consumes the saved
// forward gates (ws_gates_) and produces the gate gradients in
// scratch_gates_ / scratch_cell_ that the GEMMs below then consume.
template <typename T1, typename T2, typename T3, typename T4, typename T5,
        typename weights_data_t, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
dnnl_status_t common_bwd_cell_exec_template(T1 gemm_layer_f, T2 gemm_iter_f,
        T3 gemm_weights_layer_f, T4 gemm_weights_iter_f, T5 rnn_postgemm,
        const rnn_utils::rnn_conf_t &rnn, cell_position_t cell_position,
        src_data_t *dst_layer_, acc_data_t *diff_src_layer_,
        acc_data_t *diff_augru_attention_, acc_data_t *diff_src_iter_,
        weights_data_t **w_layer_, weights_data_t **w_iter_, void **bias_,
        const src_data_t *src_layer_, const src_data_t *augru_attention_,
        const src_data_t *src_iter_, acc_data_t *diff_dst_layer_,
        acc_data_t *diff_dst_iter_, acc_data_t *diff_w_layer_,
        acc_data_t *diff_w_iter_, acc_data_t *diff_bias_, src_data_t *ws_gates_,
        src_data_t *ws_grid_, scratch_data_t *scratch_gates_,
        scratch_data_t *scratch_cell_, src_data_t *dst_iter_) {
    const auto src_layer_ld = rnn.src_layer_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

    // NOTE: despite the name, this accessor views scratch_cell_ (the
    // recurrent-side gate gradients dGr), not scratch_gates_. It is used
    // below for the extra LBR bias term.
    const ws_gates_aoc<scratch_data_t> scratch_gates_r(rnn, scratch_cell_);

    // Elementwise backward: fills scratch_gates_ (dG) and scratch_cell_
    // (dGr) from the incoming diff_dst_* and the saved forward state.
    rnn_postgemm->execute(rnn, cell_position, ws_gates_, scratch_gates_,
            augru_attention_, dst_layer_, nullptr, src_iter_, nullptr, nullptr,
            diff_augru_attention_, diff_src_iter_, nullptr, diff_dst_layer_,
            diff_dst_iter_, nullptr, nullptr, bias_[0], ws_grid_, scratch_cell_,
            dst_iter_, nullptr, 0);

    // dWx += dG^t * x
    if (rnn.need_gemm_layer(cell_position))
        CHECK(gemm_weights_layer_f(
                scratch_gates_, src_layer_, src_layer_ld, diff_w_layer_));

    // dx = dG * Wx^t (done later in one shot when the layer GEMM is merged
    // across iterations)
    if (!rnn.merge_gemm_layer)
        CHECK(gemm_layer_f(w_layer_[0], scratch_gates_, diff_src_layer_));

    // dh += dGr * Wh^t
    CHECK(gemm_iter_f(w_iter_[0], scratch_cell_, diff_src_iter_));

    // dWh += dGr^t * h
    CHECK(gemm_weights_iter_f(
            scratch_cell_, src_iter_, src_iter_ld, diff_w_iter_));

    // db1-3 += e * dG
    // db4 += e * (r * dG2)
    gates_reduction(rnn, scratch_gates_, diff_bias_);

    // Extra LBR bias: the 4th bias (hidden-candidate recurrent bias) gets
    // gate 2 of the recurrent-side gradients, reduced over the minibatch.
    // Parallelized over dhc, so each diff_bias_ element is written by
    // exactly one thread.
    parallel_nd(rnn.dhc, [&](dim_t j) {
        for (int i = 0; i < rnn.mb; i++) {
            diff_bias_[3 * rnn.dhc + j] += scratch_gates_r(i, 2, j);
        }
    });

    return dnnl_success;
}
129 | |
130 | #undef AOC |
131 | |
132 | template <> |
133 | rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) { |
134 | const auto gemm_layer = [&](const float *A, const float *B, float *C) { |
135 | return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, |
136 | rnn.n_gates * rnn.dhc, 1.0f, A, rnn.weights_layer_ld, B, |
137 | rnn.scratch_gates_ld, 0.0f, C, rnn.ws_diff_states_layer_ld); |
138 | }; |
139 | const auto gemm_iter = [&](const float *A, const float *B, float *C) { |
140 | return (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, |
141 | rnn.n_gates * rnn.dhc, 1.0f, A, rnn.weights_iter_ld, B, |
142 | rnn.ws_gates_ld, 1.0f, C, rnn.ws_diff_states_iter_ld); |
143 | }; |
144 | const auto gemm_weights_layer |
145 | = [&](const float *A, const float *B, int ldb, float *C) { |
146 | return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb, |
147 | 1.0f, A, rnn.scratch_gates_ld, B, ldb, 1.0f, C, |
148 | rnn.diff_weights_layer_ld); |
149 | }; |
150 | const auto gemm_weights_iter = [&](const float *A, const float *B, int ldb, |
151 | float *C) { |
152 | return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, rnn.mb, 1.0f, A, |
153 | rnn.ws_gates_ld, B, ldb, 1.0f, C, rnn.diff_weights_iter_ld); |
154 | }; |
155 | |
156 | common_bwd_cell_exec_template(gemm_layer, gemm_iter, gemm_weights_layer, |
157 | gemm_weights_iter, rnn_postgemm_, rnn, cell_position, dst_layer_, |
158 | diff_src_layer_, diff_augru_attention_, diff_src_iter_, w_layer_, |
159 | w_iter_, bias_, src_layer_, augru_attention_, src_iter_, |
160 | diff_dst_layer_, diff_dst_iter_, diff_w_layer_, diff_w_iter_, |
161 | diff_bias_, ws_gates_, ws_grid_, scratch_gates_, scratch_cell_, |
162 | dst_iter_); |
163 | |
164 | return dnnl_success; |
165 | } |
166 | |
167 | template <> |
168 | rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_gru_lbr) { |
169 | const auto gemm_layer = [&](const bfloat16_t *A, const bfloat16_t *B, |
170 | float *C) { |
171 | return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb, |
172 | rnn.n_gates * rnn.dhc, 1.0f, A, rnn.weights_layer_ld, B, |
173 | rnn.scratch_gates_ld, 0.0f, C, rnn.ws_diff_states_layer_ld); |
174 | }; |
175 | const auto gemm_iter = [&](const bfloat16_t *A, const bfloat16_t *B, |
176 | float *C) { |
177 | return (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, |
178 | rnn.n_gates * rnn.dhc, 1.0f, A, rnn.weights_iter_ld, B, |
179 | rnn.ws_gates_ld, 1.0f, C, rnn.ws_diff_states_iter_ld); |
180 | }; |
181 | const auto gemm_weights_layer |
182 | = [&](const bfloat16_t *A, const bfloat16_t *B, int ldb, float *C) { |
183 | return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb, |
184 | 1.0f, A, rnn.scratch_gates_ld, B, ldb, 1.0f, C, |
185 | rnn.diff_weights_layer_ld); |
186 | }; |
187 | const auto gemm_weights_iter = [&](const bfloat16_t *A, const bfloat16_t *B, |
188 | int ldb, float *C) { |
189 | return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, rnn.mb, 1.0f, A, |
190 | rnn.ws_gates_ld, B, ldb, 1.0f, C, rnn.diff_weights_iter_ld); |
191 | }; |
192 | |
193 | common_bwd_cell_exec_template(gemm_layer, gemm_iter, gemm_weights_layer, |
194 | gemm_weights_iter, rnn_postgemm_, rnn, cell_position, dst_layer_, |
195 | diff_src_layer_, diff_augru_attention_, diff_src_iter_, w_layer_, |
196 | w_iter_, bias_, src_layer_, augru_attention_, src_iter_, |
197 | diff_dst_layer_, diff_dst_iter_, diff_w_layer_, diff_w_iter_, |
198 | diff_bias_, ws_gates_, ws_grid_, scratch_gates_, scratch_cell_, |
199 | dst_iter_); |
200 | return dnnl_success; |
201 | } |
202 | |
203 | } // namespace cpu |
204 | } // namespace impl |
205 | } // namespace dnnl |
206 | |