/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Cell execution LSTM
 */
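
/*
 * After the GEMM part of the cell has produced the gate pre-activations, the
 * kernels in this file perform the element-wise (post-GEMM) part of the LSTM
 * cell: adding biases (and peephole terms), applying the gate activations,
 * updating the cell state, and computing the output hidden state. Forward
 * kernels are provided for f32, bf16, u8 and s8, backward kernels for f32
 * and bf16.
 */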

#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"

#include "cpu/simple_q10n.hpp"

#include "cpu/rnn/postgemm_dispatcher.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::math;
using namespace rnn_utils;

template <typename T1, typename T2, typename T3, typename T4,
        typename src_data_t, typename scratch_data_t>
void lstm_fwd_postgemm_template(T1 func1, T2 func2, T3 to_src_dt, T4 to_float,
        const float *scales, const float *cscale,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
        src_data_t *dst_iter_, void *dst_iter_c_, const src_data_t *src_iter_,
        const void *src_iter_c_, const float *weights_peephole_,
        const void *bias_, int block_step) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const weights_peephole_aoc_t<const float> weights_peephole(
            rnn, weights_peephole_);
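    // The bias may be stored in a non-f32 data type (rnn.bias_dt), so it is
    // read through a raw array offset calculator and converted to float one
    // element at a time.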
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
    const auto bias = [&](int gate_id, int dhc_id) {
        return rnn_utils::to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };
    // For lstmp, dst_layer is replaced by scratch_ht for inference or ws_ht
    // for training
    const auto dst_layer_ld = rnn.is_lstm_projection
            ? rnn.scratch_ht_ld
            : rnn.dst_layer_ld(cell_position);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
    const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position);
    const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position);

    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    // TODO: we use scratch and not dst_iter for lstmp
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);

    const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_,
            types::data_type_size(rnn.dst_iter_c_dt), rnn.ws_states_iter_c_nld,
            dst_iter_c_ld);
    const auto src_iter_c_aoc = rnn_utils::make_raw_aoc(src_iter_c_,
            types::data_type_size(rnn.src_iter_c_dt), rnn.ws_states_iter_c_nld,
            src_iter_c_ld);

    const auto src_iter_c = [&](int mb_id, int dhc_id) {
        return rnn_utils::to_float(
                src_iter_c_aoc(mb_id, dhc_id), rnn.src_iter_c_dt);
    };
    const auto dst_iter_c_assign = [&](int mb_id, int dhc_id, float c_state) {
        const auto dst_iter_c_ptr
                = const_cast<void *>(dst_iter_c(mb_id, dhc_id));

        if (rnn.dst_iter_c_dt == data_type::f32)
            *static_cast<float *>(dst_iter_c_ptr) = c_state;
        else if (rnn.dst_iter_c_dt == data_type::bf16)
            *static_cast<bfloat16_t *>(dst_iter_c_ptr)
                    = cpu::saturate_and_round<bfloat16_t>(c_state);
    };

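    // The loop below computes, per element, the standard LSTM cell update
    // from the four gate pre-activations G0..G3 produced by the GEMMs, with
    // b the bias and p the optional peephole weights:
    //   i = func1(G0 + b0 [+ p0 * c_{t-1}])
    //   f = func1(G1 + b1 [+ p1 * c_{t-1}])
    //   c~ = func2(G2 + b2)
    //   c_t = f * c_{t-1} + i * c~
    //   o = func1(G3 + b3 [+ p2 * c_t])
    //   h_t = o * func2(c_t)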
    const auto postgemm_call = [&](int i) {
        const int n_elem = block_step / (int)sizeof(scratch_data_t);
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < n_elem; j++) {
            float gate_i_arg
                    = to_float(scratch_gates(i, 0, j), 0, j) + bias(0, j);
            if (rnn.is_lstm_peephole)
                gate_i_arg += weights_peephole(0, j) * src_iter_c(i, j);

            float gate_f_arg
                    = to_float(scratch_gates(i, 1, j), 1, j) + bias(1, j);
            if (rnn.is_lstm_peephole)
                gate_f_arg += weights_peephole(1, j) * src_iter_c(i, j);

            const float gate_c_arg
                    = to_float(scratch_gates(i, 2, j), 2, j) + bias(2, j);

            // by default, func1 is sigmoid and func2 is tanh

            const float gate_i = func1(scales + 0, gate_i_arg);
            const float gate_f = func1(scales + 1, gate_f_arg);
            const float gate_c = func2(scales + 2, gate_c_arg);

            const float c_state = gate_f * src_iter_c(i, j) + gate_i * gate_c;
            dst_iter_c_assign(i, j, c_state);

            float gate_o_arg
                    = to_float(scratch_gates(i, 3, j), 3, j) + bias(3, j);
            if (rnn.is_lstm_peephole)
                gate_o_arg += weights_peephole(2, j) * c_state;

            const float gate_o = func1(scales + 3, gate_o_arg);

            const src_data_t ht = to_src_dt(gate_o * func2(cscale, c_state));
            if (dst_layer_ != nullptr) dst_layer(i, j) = ht;
            if (dst_iter_ != nullptr) dst_iter(i, j) = ht;

            // write gates back to memory for training; we convert them with
            // to_src_dt as they are GEMM inputs in BWD
            if (rnn.is_training) {
                ws_gates(i, 0, j) = to_src_dt(gate_i);
                ws_gates(i, 1, j) = to_src_dt(gate_f);
                ws_gates(i, 2, j) = to_src_dt(gate_c);
                ws_gates(i, 3, j) = to_src_dt(gate_o);
            }
        }
    };

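    // In the fused brgemm path this kernel handles the rnn.m_block rows of
    // the current tile sequentially; otherwise the whole minibatch is
    // processed here in parallel.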
    if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
        for (int i = 0; i < rnn.m_block; i++)
            postgemm_call(i);
    } else {
        parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); });
    }
}

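// The specializations below differ only in how values are converted between
// the f32 compute type and the source/destination data types. When
// rnn_tparams_.test_mode_ is set, the sigmoid/tanh activations are replaced
// by per-gate scaled linear functions (a test-only mode), as implemented by
// linear_f.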
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::lstm_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);

    const auto q_id = [&](float f) { return f; };
    const auto deq_id = [&](float f, int i, int j) { return f; };

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_fwd_postgemm_template(logistic_f, tanh_f, q_id, deq_id, scales,
                cscale, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, dst_iter_c_, src_iter_, src_iter_c_,
                weights_peephole_, bias_, block_step);
    else
        lstm_fwd_postgemm_template(linear_f, linear_f, q_id, deq_id, scales,
                cscale, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, dst_iter_c_, src_iter_, src_iter_c_,
                weights_peephole_, bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::lstm_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);
    const auto round_f32_bf16 = [&](float f) { return bfloat16_t(f); };
    const auto deq_id = [&](float f, int i, int j) { return f; };

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_fwd_postgemm_template(logistic_f, tanh_f, round_f32_bf16, deq_id,
                scales, cscale, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, dst_iter_c_, src_iter_, src_iter_c_,
                weights_peephole_, bias_, block_step);
    else
        lstm_fwd_postgemm_template(linear_f, linear_f, round_f32_bf16, deq_id,
                scales, cscale, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, dst_iter_c_, src_iter_, src_iter_c_,
                weights_peephole_, bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::lstm_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);

    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;

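    // Quantize the f32 result to u8 as q = saturate(round(f * data_scale +
    // data_shift)); dequantize the s32 GEMM accumulator back to f32 by
    // dividing by the combined weights and data scales (the weights scale is
    // per-channel when the qparams mask is non-zero, scalar otherwise).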
    const auto quantize_f32_u8 = [&](float f) {
        float qf = f * data_scale + data_shift;
        return qz_a1b0<float, dst_layer_t>()(qf);
    };

    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
        const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
                ? weights_scales_[0]
                : weights_scales_[gate * rnn.dhc + j];

        return saturate<float>(s) * (1.f / (wscale * data_scale));
    };

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_fwd_postgemm_template(logistic_f, tanh_f, quantize_f32_u8,
                dequantize_s32_f32, scales, cscale, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, dst_iter_c_,
                src_iter_, src_iter_c_, weights_peephole_, bias_, block_step);
    else
        lstm_fwd_postgemm_template(linear_f, linear_f, quantize_f32_u8,
                dequantize_s32_f32, scales, cscale, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, dst_iter_c_,
                src_iter_, src_iter_c_, weights_peephole_, bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::lstm_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);

    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;

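    // Same quantization scheme as the u8 kernel above, except that qz_a1b0
    // now saturates to the s8 range of dst_layer_t.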
    const auto quantize_f32_s8 = [&](float f) {
        float qf = f * data_scale + data_shift;
        return qz_a1b0<float, dst_layer_t>()(qf);
    };

    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
        float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
                ? weights_scales_[0]
                : weights_scales_[gate * rnn.dhc + j];

        return saturate<float>(s) * (1.f / (wscale * data_scale));
    };

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_fwd_postgemm_template(logistic_f, tanh_f, quantize_f32_s8,
                dequantize_s32_f32, scales, cscale, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, dst_iter_c_,
                src_iter_, src_iter_c_, weights_peephole_, bias_, block_step);
    else
        lstm_fwd_postgemm_template(linear_f, linear_f, quantize_f32_s8,
                dequantize_s32_f32, scales, cscale, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, dst_iter_c_,
                src_iter_, src_iter_c_, weights_peephole_, bias_, block_step);
}

template <typename T1, typename T2, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void lstm_bwd_postgemm_template(T1 func1, T2 to_src_dt, const float *cscale,
        const rnn_utils::rnn_conf_t &rnn, const cell_position_t cell_position,
        src_data_t *ws_gates_, scratch_data_t *scratch_gates_,
        void *dst_iter_c_, const void *src_iter_c_,
        acc_data_t *diff_src_iter_c_, acc_data_t *diff_dst_layer_,
        acc_data_t *diff_dst_iter_, acc_data_t *diff_dst_iter_c_,
        const float *weights_peephole_, const void *bias_) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const weights_peephole_aoc_t<const float> weights_peephole(
            rnn, weights_peephole_);
    const int dst_iter_c_ld = rnn.dst_iter_c_ld(cell_position);
    const int src_iter_c_ld = rnn.src_iter_c_ld(cell_position);
    const auto src_iter_c_aoc = rnn_utils::make_raw_aoc(src_iter_c_,
            types::data_type_size(rnn.src_iter_c_dt), rnn.ws_states_iter_c_nld,
            src_iter_c_ld);
    const auto src_iter_c = [&](int mb_id, int dhc_id) {
        return rnn_utils::to_float(
                src_iter_c_aoc(mb_id, dhc_id), rnn.src_iter_c_dt);
    };
    const auto dst_iter_c_aoc = rnn_utils::make_raw_aoc(dst_iter_c_,
            types::data_type_size(rnn.dst_iter_c_dt), rnn.ws_states_iter_c_nld,
            dst_iter_c_ld);
    const auto dst_iter_c = [&](int mb_id, int dhc_id) {
        return rnn_utils::to_float(
                dst_iter_c_aoc(mb_id, dhc_id), rnn.dst_iter_c_dt);
    };

    const ws_diff_states_iter_c_aoc<acc_data_t> diff_src_iter_c(
            rnn, diff_src_iter_c_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);
    const ws_diff_states_iter_c_aoc<acc_data_t> diff_dst_iter_c(
            rnn, diff_dst_iter_c_);

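    // The loop below computes, per element, the LSTM backward update. With
    // i, f, c~, o the forward gate values read back from ws_gates,
    // x_m_square(x) = x * (1 - x) and one_m_square(x) = 1 - x^2:
    //   dHt = d(dst_layer) [+ d(dst_iter) if there is no projection]
    //   dG3 = tanh(Ct) * dHt * o * (1 - o)
    //   dCt = d(dst_iter_c) + (1 - tanh^2(Ct)) * o * dHt [+ dG3 * p2]
    //   dG0 = c~ * dCt * i * (1 - i)
    //   dG1 = c_{t-1} * dCt * f * (1 - f)
    //   dG2 = i * dCt * (1 - c~^2)
    //   d(src_iter_c) = dCt * f [+ dG0 * p0 + dG1 * p1]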
    parallel_nd(rnn.mb, [&](dim_t i) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < rnn.dhc; j++) {
            const float Ct = dst_iter_c(i, j);
            /// @todo save it in the workspace in the fwd pass, or recompute
            /// it, to save bandwidth
            const float tanhCt = func1(cscale, Ct);
            // Ht has 2 incoming diffs if there is no projection; otherwise
            // it has only 1, as the summation already happened before the
            // bwd projection
            float dHt = diff_dst_layer(i, j);
            if (!rnn.is_lstm_projection) dHt += diff_dst_iter(i, j);
            float dCt = diff_dst_iter_c(i, j)
                    + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt;

            const float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j));

            if (rnn.is_lstm_peephole) dCt += dG3 * weights_peephole(2, j);

            const float dG1
                    = src_iter_c(i, j) * dCt * x_m_square(ws_gates(i, 1, j));
            const float dG0
                    = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j));
            const float dG2 = ws_gates(i, 0, j) * dCt
                    * one_m_square(ws_gates(i, 2, j));

            diff_src_iter_c(i, j) = dCt * ws_gates(i, 1, j);

            if (rnn.is_lstm_peephole) {
                diff_src_iter_c(i, j) += dG1 * weights_peephole(1, j);
                diff_src_iter_c(i, j) += dG0 * weights_peephole(0, j);
            }

            scratch_gates(i, 0, j) = to_src_dt(dG0);
            scratch_gates(i, 1, j) = to_src_dt(dG1);
            scratch_gates(i, 2, j) = to_src_dt(dG2);
            scratch_gates(i, 3, j) = to_src_dt(dG3);
        }
    });
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::lstm_postgemm) {
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };
    const auto to_src_dt = [](float a) { return a; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_bwd_postgemm_template(tanh_f, to_src_dt, cscale, rnn,
                cell_position, ws_gates_, scratch_gates_, dst_iter_c_,
                src_iter_c_, diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                diff_dst_iter_c_, weights_peephole_, bias_);
    else
        lstm_bwd_postgemm_template(linear_f, to_src_dt, cscale, rnn,
                cell_position, ws_gates_, scratch_gates_, dst_iter_c_,
                src_iter_c_, diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                diff_dst_iter_c_, weights_peephole_, bias_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::lstm_postgemm) {
    const float *cscale = &(pd_->attr()->rnn_tparams_.cscale_);
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };
    const auto to_src_dt = [](float a) { return bfloat16_t(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        lstm_bwd_postgemm_template(tanh_f, to_src_dt, cscale, rnn,
                cell_position, ws_gates_, scratch_gates_, dst_iter_c_,
                src_iter_c_, diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                diff_dst_iter_c_, weights_peephole_, bias_);
    else
        lstm_bwd_postgemm_template(linear_f, to_src_dt, cscale, rnn,
                cell_position, ws_gates_, scratch_gates_, dst_iter_c_,
                src_iter_c_, diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_,
                diff_dst_iter_c_, weights_peephole_, bias_);
}

} // namespace cpu
} // namespace impl
} // namespace dnnl