1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
/*
 * Cell execution of the GRU cell with linear-before-reset (LBR)
 */
20 | |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/math_utils.hpp" |
23 | |
24 | #include "cpu/rnn/postgemm_dispatcher.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | |
30 | using namespace dnnl::impl::utils; |
31 | using namespace dnnl::impl::math; |
32 | using namespace rnn_utils; |
33 | #define AOC array_offset_calculator |
34 | |
/* Elementwise ("post-GEMM") part of the linear-before-reset GRU forward
 * pass, templated on the activation functors so one implementation serves
 * both the regular path (sigmoid/tanh) and the scaled-linear test mode.
 *
 * func1, func2: gate activations; each takes an optional per-gate scale
 *               pointer and the pre-activation value (defaults per callers:
 *               func1 = sigmoid, func2 = tanh).
 * to_src:       converts a float result to src_data_t.
 * scales:       optional per-gate scale array; may be nullptr.
 * scratch_gates_ holds the input-GEMM results, scratch_cell_ the
 * hidden-state-GEMM results; gate index 0/1/2 = update/reset/candidate.
 */
template <typename T1, typename T2, typename T3, typename src_data_t,
        typename scratch_data_t>
void gru_lbr_fwd_postgemm_template(T1 func1, T2 func2, T3 to_src,
        const float *scales, const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, const src_data_t *augru_attention_,
        src_data_t *dst_layer_, src_data_t *dst_iter_,
        const src_data_t *src_iter_, const void *bias_, src_data_t *ws_grid_,
        scratch_data_t *scratch_cell_) {

    // Leading dimensions depend on the cell's position in the grid.
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);

    const augru_attention_aoc<const src_data_t> augru_attention(
            rnn, augru_attention_);
    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);
    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    // Bias is stored as rnn.bias_dt; access it through a raw aoc and convert
    // each element to float on read.
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);

    const auto bias = [&](int gate_id, int dhc_id) {
        return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };
    const ws_gates_aoc<scratch_data_t> scratch_cell(rnn, scratch_cell_);
    // Per-element Wh_b values, saved into ws_grid_ for the backward pass.
    const AOC<src_data_t, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dhc);

    // Per-gate scale pointers; propagate nullptr when no scales are given.
    const auto get_scales = [](const float *scales, int idx) {
        return scales ? scales + idx : nullptr;
    };
    const float *scales_G1 = get_scales(scales, 1);
    const float *scales_G2 = get_scales(scales, 2);

    parallel_nd(rnn.mb, [&](dim_t i) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < rnn.dhc; j++) {
            // Wh_b: hidden-state GEMM contribution to the candidate gate
            // plus the extra (4th) bias of the LBR formulation.
            const float Wh_b = scratch_cell(i, 2, j) + bias(3, j);
            auto G0 = func1(scales, // default func1 is sigmoid
                    scratch_gates(i, 0, j) + scratch_cell(i, 0, j)
                            + bias(0, j));
            const auto G1 = func1(scales_G1, // default func1 is sigmoid
                    scratch_gates(i, 1, j) + scratch_cell(i, 1, j)
                            + bias(1, j));
            // Candidate gate: the reset gate G1 scales the hidden-state
            // contribution Wh_b (linear-before-reset).
            const auto G2 = func2(scales_G2, // default func2 is tanh
                    scratch_gates(i, 2, j) + G1 * Wh_b + bias(2, j));
            if (rnn.is_training) {
                // Stash activated gates and Wh_b for the backward pass.
                ws_gates(i, 0, j) = to_src(G0);
                ws_gates(i, 1, j) = to_src(G1);
                ws_gates(i, 2, j) = to_src(G2);
                ws_Wh_b(i, j) = to_src(Wh_b);
            }
            if (rnn.is_augru) {
                // AUGRU: attenuate the update gate by (1 - attention).
                const auto a = to_src(augru_attention(i));
                G0 = (1.0f - a) * G0;
            }
            // h_t = G0 * h_{t-1} + (1 - G0) * G2
            const auto tmp = to_src(src_iter(i, j) * G0 + (1.0f - G0) * G2);
            if (dst_layer_ != nullptr) dst_layer(i, j) = tmp;
            if (dst_iter_ != nullptr) dst_iter(i, j) = tmp;
        }
    });
}
101 | |
102 | template <> |
103 | rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_lbr_postgemm) { |
104 | const float *scales = pd_->attr()->rnn_tparams_.scales_; |
105 | |
106 | const auto linear_f |
107 | = [](const float *scale, float a) { return *scale * a; }; |
108 | const auto logistic_f = [](const float *scale, float a) { |
109 | return logistic_fwd<float>(a); |
110 | }; |
111 | const auto tanh_f |
112 | = [](const float *scale, float a) { return tanh_fwd<float>(a); }; |
113 | const auto to_src = [](float a) { return a; }; |
114 | |
115 | if (!pd_->attr()->rnn_tparams_.test_mode_) |
116 | gru_lbr_fwd_postgemm_template(logistic_f, tanh_f, to_src, scales, rnn, |
117 | cell_position, ws_gates_, scratch_gates_, augru_attention_, |
118 | dst_layer_, dst_iter_, src_iter_, bias_, ws_grid_, |
119 | scratch_cell_); |
120 | else |
121 | gru_lbr_fwd_postgemm_template(linear_f, linear_f, to_src, scales, rnn, |
122 | cell_position, ws_gates_, scratch_gates_, augru_attention_, |
123 | dst_layer_, dst_iter_, src_iter_, bias_, ws_grid_, |
124 | scratch_cell_); |
125 | } |
126 | |
127 | template <> |
128 | rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_lbr_postgemm) { |
129 | const float *scales = pd_->attr()->rnn_tparams_.scales_; |
130 | |
131 | const auto linear_f |
132 | = [](const float *scale, float a) { return *scale * a; }; |
133 | const auto logistic_f = [](const float *scale, float a) { |
134 | return logistic_fwd<float>(a); |
135 | }; |
136 | const auto tanh_f |
137 | = [](const float *scale, float a) { return tanh_fwd<float>(a); }; |
138 | const auto to_src = [](float a) { return bfloat16_t(a); }; |
139 | |
140 | if (!pd_->attr()->rnn_tparams_.test_mode_) |
141 | gru_lbr_fwd_postgemm_template(logistic_f, tanh_f, to_src, scales, rnn, |
142 | cell_position, ws_gates_, scratch_gates_, augru_attention_, |
143 | dst_layer_, dst_iter_, src_iter_, bias_, ws_grid_, |
144 | scratch_cell_); |
145 | else |
146 | gru_lbr_fwd_postgemm_template(linear_f, linear_f, to_src, scales, rnn, |
147 | cell_position, ws_gates_, scratch_gates_, augru_attention_, |
148 | dst_layer_, dst_iter_, src_iter_, bias_, ws_grid_, |
149 | scratch_cell_); |
150 | } |
151 | |
// No u8 (unsigned int8) implementation exists for the GRU-LBR cell;
// reaching this dispatch entry indicates a bug in primitive creation.
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_lbr_postgemm) {
    assert(!"GRU LBR int8 is not supported" );
}
156 | |
// No s8 (signed int8) implementation exists for the GRU-LBR cell;
// reaching this dispatch entry indicates a bug in primitive creation.
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_lbr_postgemm) {
    assert(!"GRU LBR signed int8 is not supported" );
}
161 | |
/* Elementwise ("post-GEMM") part of the linear-before-reset GRU backward
 * pass: computes gate gradients from the activated gates that forward
 * stored in ws_gates_, and writes them to the scratch buffers consumed by
 * the subsequent GEMMs.
 *
 * to_src:  converts a float gradient to scratch_data_t.
 * ws_grid_ holds the per-element Wh_b values saved by the forward pass.
 * Gate index 0/1/2 = update/reset/candidate, as in forward.
 */
template <typename T1, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void gru_lbr_bwd_postgemm_template(T1 to_src, const rnn_utils::rnn_conf_t &rnn,
        cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, const src_data_t *augru_attention_,
        const src_data_t *src_iter_, acc_data_t *diff_src_iter_,
        acc_data_t *diff_dst_iter_, acc_data_t *diff_augru_attention_,
        acc_data_t *diff_dst_layer_, scratch_data_t *scratch_cell_,
        src_data_t *ws_grid_) {
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

    const augru_attention_aoc<const src_data_t> augru_attention(
            rnn, augru_attention_);
    const augru_attention_aoc<acc_data_t> diff_augru_attention(
            rnn, diff_augru_attention_);

    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter(
            rnn, diff_src_iter_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);
    const ws_gates_aoc<scratch_data_t> scratch_gates_r(rnn, scratch_cell_);
    const AOC<src_data_t, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dhc);

    // 1. calculate dG0, dG1 and dG2
    // dG0 = (h - G2) * dHt * (1 - G0) * G0
    // dG1 = Wh_b * dG2 * (1 - G1) * G1
    // dG2 = (1 - G0) * dHt * (1 - G2*G2)
    // with x_m_square(x) = x - x^2 and one_m_square(x) = 1 - x^2, i.e. the
    // derivatives of sigmoid and tanh expressed via their outputs.
    parallel_nd(rnn.mb, [&](dim_t i) {
        acc_data_t diff_attention = 0.0f;
        PRAGMA_OMP_SIMD(reduction(+ : diff_attention))
        for (int j = 0; j < rnn.dhc; j++) {
            const float h = src_iter(i, j);
            // Gradient reaching h_t from both the iter and layer outputs.
            const float dHt = diff_dst_iter(i, j) + diff_dst_layer(i, j);
            float dG0 = (h - ws_gates(i, 2, j)) * dHt
                    * x_m_square(ws_gates(i, 0, j));
            const float dG2 = (1.0f - ws_gates(i, 0, j))
                    * one_m_square(ws_gates(i, 2, j)) * dHt;
            const float dG1
                    = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j));

            if (rnn.is_augru) {
                // Attention gradient accumulates -dG0 * G0 (before the
                // update-gate gradient is itself scaled by 1 - attention).
                diff_attention -= dG0 * ws_gates(i, 0, j);
                dG0 *= 1.0f - augru_attention(i);
            }

            // Direct contribution of h_{t-1} to h_t through the update
            // gate; the remaining contributions flow through the GEMMs.
            diff_src_iter(i, j) = dHt * ws_gates(i, 0, j);
            scratch_gates(i, 2, j) = to_src(dG2);
            // scratch_gates_r stores dG2 scaled by the reset gate for the
            // hidden-state GEMM of the LBR formulation.
            scratch_gates_r(i, 2, j) = to_src(dG2 * ws_gates(i, 1, j));
            scratch_gates(i, 0, j) = scratch_gates_r(i, 0, j) = to_src(dG0);
            scratch_gates(i, 1, j) = scratch_gates_r(i, 1, j) = to_src(dG1);
        }
        if (rnn.is_augru) diff_augru_attention(i) = diff_attention;
    });
}
222 | |
223 | template <> |
224 | rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_lbr_postgemm) { |
225 | auto to_src = [&](float a) { return a; }; |
226 | gru_lbr_bwd_postgemm_template(to_src, rnn, cell_position, ws_gates_, |
227 | scratch_gates_, augru_attention_, src_iter_, diff_src_iter_, |
228 | diff_dst_iter_, diff_augru_attention_, diff_dst_layer_, |
229 | scratch_cell_, ws_grid_); |
230 | } |
231 | |
232 | template <> |
233 | rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_lbr_postgemm) { |
234 | auto to_src = [&](float a) { return bfloat16_t(a); }; |
235 | gru_lbr_bwd_postgemm_template(to_src, rnn, cell_position, ws_gates_, |
236 | scratch_gates_, augru_attention_, src_iter_, diff_src_iter_, |
237 | diff_dst_iter_, diff_augru_attention_, diff_dst_layer_, |
238 | scratch_cell_, ws_grid_); |
239 | } |
240 | |
241 | } // namespace cpu |
242 | } // namespace impl |
243 | } // namespace dnnl |
244 | |