/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Cell execution LSTM
 */

#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"

#include "cpu/simple_q10n.hpp"

#include "cpu/rnn/postgemm_dispatcher.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::math;
using namespace rnn_utils;
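
// Generic LSTM forward postgemm. The functor parameters abstract the
// per-datatype details:
//   func1     - gate activation (sigmoid by default, a scaled identity in
//               test mode)
//   func2     - activation used for the candidate gate and the cell state
//               (tanh by default, a scaled identity in test mode)
//   to_src_dt - converts (and, for int8, quantizes) a float value to the
//               source data type
//   to_float  - converts (and, for int8, dequantizes) a scratch gate value
//               to float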
template <typename T1, typename T2, typename T3, typename T4,
        typename src_data_t, typename scratch_data_t>
void lstm_fwd_postgemm_template(T1 func1, T2 func2, T3 to_src_dt, T4 to_float,
        const float *scales, const float *cscale,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
        src_data_t *dst_iter_, void *dst_iter_c_, const src_data_t *src_iter_,
        const void *src_iter_c_, const float *weights_peephole_,
        const void *bias_, int block_step) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const weights_peephole_aoc_t<const float> weights_peephole(
            rnn, weights_peephole_);
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
    const auto bias = [&](int gate_id, int dhc_id) {
        return rnn_utils::to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };
    // For LSTM with projection (lstmp), instead of dst_layer we use
    // scratch_ht for inference or ws_ht for training
    const auto dst_layer_ld = rnn.is_lstm_projection
            ? rnn.scratch_ht_ld
            : rnn.dst_layer_ld(cell_position);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
    const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position);
    const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position);

    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    // TODO: we use scratch and not dst_iter for lstmp
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);

    const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_,
            types::data_type_size(rnn.dst_iter_c_dt), rnn.ws_states_iter_c_nld,
            dst_iter_c_ld);
    const auto src_iter_c_aoc = rnn_utils::make_raw_aoc(src_iter_c_,
            types::data_type_size(rnn.src_iter_c_dt), rnn.ws_states_iter_c_nld,
            src_iter_c_ld);

    const auto src_iter_c = [&](int mb_id, int dhc_id) {
        return rnn_utils::to_float(
                src_iter_c_aoc(mb_id, dhc_id), rnn.src_iter_c_dt);
    };
    const auto dst_iter_c_assign = [&](int mb_id, int dhc_id, float c_state) {
        const auto dst_iter_c_ptr
                = const_cast<void *>(dst_iter_c(mb_id, dhc_id));

        if (rnn.dst_iter_c_dt == data_type::f32)
            *static_cast<float *>(dst_iter_c_ptr) = c_state;
        else if (rnn.dst_iter_c_dt == data_type::bf16)
            *static_cast<bfloat16_t *>(dst_iter_c_ptr)
                    = cpu::saturate_and_round<bfloat16_t>(c_state);
    };
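
    // Per-element LSTM forward equations computed by the loop below
    // (the peephole terms apply only when rnn.is_lstm_peephole):
    //   i   = func1(G0 + b0 [+ p0 * C_{t-1}])
    //   f   = func1(G1 + b1 [+ p1 * C_{t-1}])
    //   c~  = func2(G2 + b2)
    //   C_t = f * C_{t-1} + i * c~
    //   o   = func1(G3 + b3 [+ p2 * C_t])
    //   H_t = o * func2(C_t)
    // where G0..G3 are the (dequantized) GEMM outputs held in scratch_gates.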
    const auto postgemm_call = [&](int i) {
        const int n_elem = block_step / (int)sizeof(scratch_data_t);
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < n_elem; j++) {
            float gate_i_arg
                    = to_float(scratch_gates(i, 0, j), 0, j) + bias(0, j);
            if (rnn.is_lstm_peephole)
                gate_i_arg += weights_peephole(0, j) * src_iter_c(i, j);

            float gate_f_arg
                    = to_float(scratch_gates(i, 1, j), 1, j) + bias(1, j);
            if (rnn.is_lstm_peephole)
                gate_f_arg += weights_peephole(1, j) * src_iter_c(i, j);

            const float gate_c_arg
                    = to_float(scratch_gates(i, 2, j), 2, j) + bias(2, j);

            // default func1 is sigmoid, func2 is tanh

            const float gate_i = func1(scales + 0, gate_i_arg);
            const float gate_f = func1(scales + 1, gate_f_arg);
            const float gate_c = func2(scales + 2, gate_c_arg);

            const float c_state = gate_f * src_iter_c(i, j) + gate_i * gate_c;
            dst_iter_c_assign(i, j, c_state);

            float gate_o_arg
                    = to_float(scratch_gates(i, 3, j), 3, j) + bias(3, j);
            if (rnn.is_lstm_peephole)
                gate_o_arg += weights_peephole(2, j) * c_state;

            const float gate_o = func1(scales + 3, gate_o_arg);

            const src_data_t ht = to_src_dt(gate_o * func2(cscale, c_state));
            if (dst_layer_ != nullptr) dst_layer(i, j) = ht;
            if (dst_iter_ != nullptr) dst_iter(i, j) = ht;

            // Write the gates back to memory for training.
            // We convert them with to_src_dt as they are GEMM inputs in BWD.
            if (rnn.is_training) {
                ws_gates(i, 0, j) = to_src_dt(gate_i);
                ws_gates(i, 1, j) = to_src_dt(gate_f);
                ws_gates(i, 2, j) = to_src_dt(gate_c);
                ws_gates(i, 3, j) = to_src_dt(gate_o);
            }
        }
    };

    if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
        for (int i = 0; i < rnn.m_block; i++)
            postgemm_call(i);
    } else {
        parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); });
    }
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::lstm_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);

    const auto q_id = [&](float f) { return f; };
    const auto deq_id = [&](float f, int i, int j) { return f; };

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };
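
    // In test mode (rnn_tparams), the non-linearities are replaced by
    // per-gate scaled identity functions (linear_f), which keeps the cell
    // computation exactly linear for validation purposes.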
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_fwd_postgemm_template(logistic_f, tanh_f, q_id, deq_id, scales,
                cscale, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, dst_iter_c_, src_iter_, src_iter_c_,
                weights_peephole_, bias_, block_step);
    else
        lstm_fwd_postgemm_template(linear_f, linear_f, q_id, deq_id, scales,
                cscale, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, dst_iter_c_, src_iter_, src_iter_c_,
                weights_peephole_, bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::lstm_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);
    const auto round_f32_bf16 = [&](float f) { return bfloat16_t(f); };
    const auto deq_id = [&](float f, int i, int j) { return f; };

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_fwd_postgemm_template(logistic_f, tanh_f, round_f32_bf16, deq_id,
                scales, cscale, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, dst_iter_c_, src_iter_, src_iter_c_,
                weights_peephole_, bias_, block_step);
    else
        lstm_fwd_postgemm_template(linear_f, linear_f, round_f32_bf16, deq_id,
                scales, cscale, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, dst_iter_c_, src_iter_, src_iter_c_,
                weights_peephole_, bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::lstm_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);

    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;
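
    // Quantize the f32 result to u8 with the data scale and shift; dequantize
    // the s32 GEMM accumulator back to f32 by dividing out the combined
    // weights and data scales (per-channel weights scales when the
    // quantization mask is non-zero).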
    const auto quantize_f32_u8 = [&](float f) {
        float qf = f * data_scale + data_shift;
        return qz_a1b0<float, dst_layer_t>()(qf);
    };

    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
        const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
                ? weights_scales_[0]
                : weights_scales_[gate * rnn.dhc + j];

        return saturate<float>(s) * (1.f / (wscale * data_scale));
    };

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_fwd_postgemm_template(logistic_f, tanh_f, quantize_f32_u8,
                dequantize_s32_f32, scales, cscale, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, dst_iter_c_,
                src_iter_, src_iter_c_, weights_peephole_, bias_, block_step);
    else
        lstm_fwd_postgemm_template(linear_f, linear_f, quantize_f32_u8,
                dequantize_s32_f32, scales, cscale, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, dst_iter_c_,
                src_iter_, src_iter_c_, weights_peephole_, bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::lstm_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);

    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;
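
    // Same quantization and dequantization scheme as the u8 specialization
    // above, with an s8 destination layer data type.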
    const auto quantize_f32_s8 = [&](float f) {
        float qf = f * data_scale + data_shift;
        return qz_a1b0<float, dst_layer_t>()(qf);
    };

    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
        float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
                ? weights_scales_[0]
                : weights_scales_[gate * rnn.dhc + j];

        return saturate<float>(s) * (1.f / (wscale * data_scale));
    };

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_fwd_postgemm_template(logistic_f, tanh_f, quantize_f32_s8,
                dequantize_s32_f32, scales, cscale, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, dst_iter_c_,
                src_iter_, src_iter_c_, weights_peephole_, bias_, block_step);
    else
        lstm_fwd_postgemm_template(linear_f, linear_f, quantize_f32_s8,
                dequantize_s32_f32, scales, cscale, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, dst_iter_c_,
                src_iter_, src_iter_c_, weights_peephole_, bias_, block_step);
}
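
// Generic LSTM backward postgemm. func1 is the cell-state non-linearity
// (tanh, or a scaled identity in test mode) and to_src_dt converts the
// computed gate gradients to the scratch data type consumed by the BWD GEMMs.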
template <typename T1, typename T2, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void lstm_bwd_postgemm_template(T1 func1, T2 to_src_dt, const float *cscale,
        const rnn_utils::rnn_conf_t &rnn, const cell_position_t cell_position,
        src_data_t *ws_gates_, scratch_data_t *scratch_gates_,
        void *dst_iter_c_, const void *src_iter_c_,
        acc_data_t *diff_src_iter_c_, acc_data_t *diff_dst_layer_,
        acc_data_t *diff_dst_iter_, acc_data_t *diff_dst_iter_c_,
        const float *weights_peephole_, const void *bias_) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const weights_peephole_aoc_t<const float> weights_peephole(
            rnn, weights_peephole_);
    const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position);
    const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position);
    const auto src_iter_c_aoc = rnn_utils::make_raw_aoc(src_iter_c_,
            types::data_type_size(rnn.src_iter_c_dt), rnn.ws_states_iter_c_nld,
            src_iter_c_ld);
    const auto src_iter_c = [&](int mb_id, int dhc_id) {
        return rnn_utils::to_float(
                src_iter_c_aoc(mb_id, dhc_id), rnn.src_iter_c_dt);
    };
    const auto dst_iter_c_aoc = rnn_utils::make_raw_aoc(dst_iter_c_,
            types::data_type_size(rnn.dst_iter_c_dt), rnn.ws_states_iter_c_nld,
            dst_iter_c_ld);
    const auto dst_iter_c = [&](int mb_id, int dhc_id) {
        return rnn_utils::to_float(
                dst_iter_c_aoc(mb_id, dhc_id), rnn.dst_iter_c_dt);
    };

    const ws_diff_states_iter_c_aoc<acc_data_t> diff_src_iter_c(
            rnn, diff_src_iter_c_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);
    const ws_diff_states_iter_c_aoc<acc_data_t> diff_dst_iter_c(
            rnn, diff_dst_iter_c_);
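
    // Per-element LSTM backward equations computed by the loop below
    // (gate order: 0 = input, 1 = forget, 2 = candidate, 3 = output;
    // peephole terms apply only when rnn.is_lstm_peephole):
    //   dHt = d(dst_layer) [+ d(dst_iter) when there is no projection]
    //   dCt = d(dst_iter_c) + (1 - tanh(Ct)^2) * o * dHt [+ dG3 * p2]
    //   dG3 = tanh(Ct) * dHt * o * (1 - o)
    //   dG1 = C_{t-1} * dCt * f * (1 - f)
    //   dG0 = c~ * dCt * i * (1 - i)
    //   dG2 = i * dCt * (1 - c~^2)
    //   d(src_iter_c) = dCt * f [+ dG1 * p1 + dG0 * p0]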
    parallel_nd(rnn.mb, [&](dim_t i) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < rnn.dhc; j++) {
            const float Ct = dst_iter_c(i, j);
            /// @todo save it in the workspace during the fwd pass or
            /// recompute it to save bandwidth
            const float tanhCt = func1(cscale, Ct);
            // Without projection there are two incoming diffs on Ht;
            // with projection there is only one, since the summation already
            // happened before the bwd projection
            float dHt = diff_dst_layer(i, j);
            if (!rnn.is_lstm_projection) dHt += diff_dst_iter(i, j);
            float dCt = diff_dst_iter_c(i, j)
                    + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt;

            const float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j));

            if (rnn.is_lstm_peephole) dCt += dG3 * weights_peephole(2, j);

            const float dG1
                    = src_iter_c(i, j) * dCt * x_m_square(ws_gates(i, 1, j));
            const float dG0
                    = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j));
            const float dG2
                    = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j));

            diff_src_iter_c(i, j) = dCt * ws_gates(i, 1, j);

            if (rnn.is_lstm_peephole) {
                diff_src_iter_c(i, j) += dG1 * weights_peephole(1, j);
                diff_src_iter_c(i, j) += dG0 * weights_peephole(0, j);
            }

            scratch_gates(i, 0, j) = to_src_dt(dG0);
            scratch_gates(i, 1, j) = to_src_dt(dG1);
            scratch_gates(i, 2, j) = to_src_dt(dG2);
            scratch_gates(i, 3, j) = to_src_dt(dG3);
        }
    });
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::lstm_postgemm) {
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };
    const auto to_src_dt = [](float a) { return a; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_bwd_postgemm_template(tanh_f, to_src_dt, cscale, rnn,
                cell_position, ws_gates_, scratch_gates_, dst_iter_c_,
                src_iter_c_, diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                diff_dst_iter_c_, weights_peephole_, bias_);
    else
        lstm_bwd_postgemm_template(linear_f, to_src_dt, cscale, rnn,
                cell_position, ws_gates_, scratch_gates_, dst_iter_c_,
                src_iter_c_, diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                diff_dst_iter_c_, weights_peephole_, bias_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::lstm_postgemm) {
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };
    const auto to_src_dt = [](float a) { return bfloat16_t(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_bwd_postgemm_template(tanh_f, to_src_dt, cscale, rnn,
                cell_position, ws_gates_, scratch_gates_, dst_iter_c_,
                src_iter_c_, diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                diff_dst_iter_c_, weights_peephole_, bias_);
    else
        lstm_bwd_postgemm_template(linear_f, to_src_dt, cscale, rnn,
                cell_position, ws_gates_, scratch_gates_, dst_iter_c_,
                src_iter_c_, diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                diff_dst_iter_c_, weights_peephole_, bias_);
}

} // namespace cpu
} // namespace impl
} // namespace dnnl