1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <stdlib.h>
18
19#include "utils/parallel.hpp"
20
21#include "rnn/rnn.hpp"
22#include "rnn/rnn_aux.hpp"
23
24#include "rnn/cells.hpp"
25
26namespace rnn {
27
28template <typename T1, typename T2>
29void lstm_fwd_postgemm_template(T1 func1, T2 func2, const prb_t &prb,
30 float *gates_, const float *weights_peephole_, const float *bias_,
31 const float *src_iter_c_, float *dst_layer_, float *dst_iter_c_) {
32 AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
33 AOC<const float> weights_peephole(weights_peephole_, 3, prb.dhc);
34 AOC<const float> bias(bias_, prb.n_gates(), prb.dhc);
35 AOC<const float> src_iter_c(src_iter_c_, prb.mb, prb.wc);
36 AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc);
37 AOC<float> dst_iter_c(dst_iter_c_, prb.mb, prb.wc);
38
39 // run the eltwise
40 benchdnn_parallel_nd(prb.mb, [&](int64_t ib) {
41 for (int64_t ih = 0; ih < prb.dhc; ih++) {
42 float peephole_extra_i = 0, peephole_extra_f = 0;
43 if (prb.is_lstm_peephole()) {
44 peephole_extra_i = weights_peephole(0, ih) * src_iter_c(ib, ih);
45 peephole_extra_f = weights_peephole(1, ih) * src_iter_c(ib, ih);
46 }
47
48 gates(ib, LSTM_I, ih) = func1(prb.linear_scales[LSTM_I],
49 maybe_deq(prb, gates(ib, LSTM_I, ih), LSTM_I * prb.dhc + ih)
50 + peephole_extra_i + bias(LSTM_I, ih));
51 gates(ib, LSTM_F, ih) = func1(prb.linear_scales[LSTM_F],
52 maybe_deq(prb, gates(ib, LSTM_F, ih), LSTM_F * prb.dhc + ih)
53 + peephole_extra_f + bias(LSTM_F, ih));
54
55 gates(ib, LSTM_C, ih) = func2(prb.linear_scales[LSTM_C],
56 maybe_deq(prb, gates(ib, LSTM_C, ih), LSTM_C * prb.dhc + ih)
57 + bias(LSTM_C, ih));
58
59 // compute C_t_l and H_t_l
60 float tmp = gates(ib, LSTM_F, ih) * src_iter_c(ib, ih)
61 + gates(ib, LSTM_I, ih) * gates(ib, LSTM_C, ih);
62 dst_iter_c(ib, ih) = tmp;
63
64 float peephole_extra_o = 0;
65 if (prb.is_lstm_peephole())
66 peephole_extra_o = weights_peephole(2, ih) * tmp;
67
68 gates(ib, LSTM_O, ih) = func1(prb.linear_scales[LSTM_O],
69 maybe_deq(prb, gates(ib, LSTM_O, ih), LSTM_O * prb.dhc + ih)
70 + peephole_extra_o + bias(LSTM_O, ih));
71
72 dst_layer(ib, ih) = maybe_q(
73 prb, gates(ib, LSTM_O, ih) * func2(prb.linear_cscale, tmp));
74
75 for (int64_t ig = 0; ig < 4; ig++) {
76 BENCHDNN_PRINT(80,
77 "activation 1 a[" IFMT "][" IFMT "][" IFMT "] = %.7f\n",
78 ib, ig, ih, gates(ib, ig, ih));
79 }
80 BENCHDNN_PRINT(80, "recomp tmp(%a) cin(%a) ht(%a)\n", tmp,
81 src_iter_c(ib, ih), dst_layer(ib, ih));
82 }
83 });
84}
85
86void lstm_fwd_postgemm(const prb_t &prb, float *gates_,
87 const float *weights_peephole_, const float *bias_,
88 const float *src_iter_c_, float *dst_layer_, float *dst_iter_c_) {
89 if (prb.skip_nonlinear)
90 lstm_fwd_postgemm_template(
91 [](float scale, float val) { return scale * val; },
92 [](float scale, float val) { return scale * val; }, prb, gates_,
93 weights_peephole_, bias_, src_iter_c_, dst_layer_, dst_iter_c_);
94 else
95 lstm_fwd_postgemm_template(
96 [](float scale, float val) { return logistic(val); },
97 [](float scale, float val) { return tanhf(val); }, prb, gates_,
98 weights_peephole_, bias_, src_iter_c_, dst_layer_, dst_iter_c_);
99}
100
101void lstm_fwd(const prb_t &prb, float *dst_layer_, float *dst_iter_,
102 float *dst_iter_c_, float *gates_, float *ht_,
103 const float *weights_layer_, const float *weights_iter_,
104 const float *weights_peephole_, const float *weights_projection_,
105 const float *weights_projection_compensation, const float *bias_,
106 const float *src_layer_, const float *src_iter_,
107 const float *src_iter_c_) {
108
109 gemm("C", "N", "N", prb.mb, prb.n_gates() * prb.dhc, prb.slc, 1.0,
110 src_layer_, prb.wc, weights_layer_, prb.n_gates() * prb.dhc, 0.0,
111 gates_, prb.n_gates() * prb.dhc);
112 gemm("C", "N", "N", prb.mb, prb.n_gates() * prb.dhc, prb.sic, 1.0,
113 src_iter_, prb.wc, weights_iter_, prb.n_gates() * prb.dhc, 1.0,
114 gates_, prb.n_gates() * prb.dhc);
115
116 // if lstmp, we use the workspace to write the postgemm output
117 auto dst_postgemm = prb.is_lstm_projection() ? ht_ : dst_layer_;
118 lstm_fwd_postgemm(prb, gates_, weights_peephole_, bias_, src_iter_c_,
119 dst_postgemm, dst_iter_c_);
120
121 assert(dst_layer_ == dst_iter_);
122 if (prb.is_lstm_projection()) {
123 gemm("C", "N", "N", prb.mb, prb.dic, prb.dhc, 1.0, dst_postgemm, prb.wc,
124 weights_projection_, prb.dic, 0.0, dst_layer_, prb.wc);
125
126 if (prb.cfg.is_int8()) {
127 // Here we simulate int8 usage by dequantizing and requantizing the buffer
128 benchdnn_parallel_nd(prb.mb, [&](int64_t i) {
129 for (int j = 0; j < prb.dic; j++) {
130 int64_t addr = i * prb.wc + j;
131 float d_tmp = maybe_deq_proj(prb, dst_layer_[addr],
132 weights_projection_compensation[j], j);
133 dst_layer_[addr] = maybe_q(prb, d_tmp);
134 }
135 });
136 }
137 } else {
138 assert(prb.dic == prb.dhc);
139 }
140}
141
142template <typename T1>
143void lstm_bwd_pregemm_template(T1 func1, const prb_t &prb,
144 const float *src_iter_c_, const float *dst_iter_c_,
145 const float *weights_peephole_, const float *diff_hidden_state_,
146 const float *diff_dst_iter_c_, const float *gates_,
147 float *diff_src_iter_c_, float *b_gates_) {
148 AOC<const float> src_iter_c(src_iter_c_, prb.mb, prb.wc);
149 AOC<const float> dst_iter_c(dst_iter_c_, prb.mb, prb.wc);
150 AOC<const float> weights_peephole(weights_peephole_, 3, prb.dhc);
151 AOC<const float> diff_hidden_state(diff_hidden_state_, prb.mb, prb.dhc);
152 AOC<const float> diff_dst_iter_c(diff_dst_iter_c_, prb.mb, prb.wc);
153 AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
154 AOC<float> diff_src_iter_c(diff_src_iter_c_, prb.mb, prb.wc);
155 AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);
156
157 for (int64_t ib = 0; ib < prb.mb; ib++)
158 for (int64_t ih = 0; ih < prb.dhc; ih++) {
159 BENCHDNN_PRINT(80, "rnn_single_bwd: ib = " IFMT " ih = " IFMT "\n",
160 ib, ih);
161 float hi = gates(ib, LSTM_I, ih);
162 float hf = gates(ib, LSTM_F, ih);
163 float hc = gates(ib, LSTM_C, ih);
164 float ho = gates(ib, LSTM_O, ih);
165
166 float dh = diff_hidden_state(ib, ih);
167
168 float tanhC = func1(prb.linear_cscale, dst_iter_c(ib, ih));
169 float dho = tanhC * dh;
170 b_gates(ib, LSTM_O, ih) = x_m_square(ho) * dho;
171
172 float dc = diff_dst_iter_c(ib, ih);
173 dc += ho * dh * one_m_square(tanhC);
174
175 if (prb.is_lstm_peephole())
176 dc += b_gates(ib, LSTM_O, ih) * weights_peephole(2, ih);
177
178 float dc_tm1 = hf * dc;
179
180 float c_old = src_iter_c(ib, ih);
181 float dhf = c_old * dc;
182 b_gates(ib, LSTM_F, ih) = x_m_square(hf) * dhf;
183
184 float dhi = hc * dc;
185 b_gates(ib, LSTM_I, ih) = x_m_square(hi) * dhi;
186
187 float dhc = hi * dc;
188 b_gates(ib, LSTM_C, ih) = one_m_square(hc) * dhc;
189
190 if (prb.is_lstm_peephole()) {
191 dc_tm1 += b_gates(ib, LSTM_F, ih) * weights_peephole(1, ih);
192 dc_tm1 += b_gates(ib, LSTM_I, ih) * weights_peephole(0, ih);
193 }
194
195 diff_src_iter_c(ib, ih) = dc_tm1;
196 }
197}
198
199void lstm_bwd_pregemm(const prb_t &prb, const float *src_iter_c_,
200 const float *dst_iter_c_, const float *weights_peephole_,
201 const float *diff_hidden_state_, const float *diff_dst_iter_c_,
202 const float *gates_, float *diff_src_iter_c_, float *b_gates_) {
203 if (prb.skip_nonlinear)
204 lstm_bwd_pregemm_template(
205 [](float scale, float val) { return scale * val; }, prb,
206 src_iter_c_, dst_iter_c_, weights_peephole_, diff_hidden_state_,
207 diff_dst_iter_c_, gates_, diff_src_iter_c_, b_gates_);
208
209 else
210 lstm_bwd_pregemm_template(
211 [](float scale, float val) { return tanhf(val); }, prb,
212 src_iter_c_, dst_iter_c_, weights_peephole_, diff_hidden_state_,
213 diff_dst_iter_c_, gates_, diff_src_iter_c_, b_gates_);
214}
215
216void lstm_bwd_weights_peephole(const prb_t &prb, const float *src_iter_c_,
217 const float *dst_iter_c_, const float *b_gates_,
218 float *diff_weights_peephole_) {
219 AOC<const float> src_iter_c(src_iter_c_, prb.mb, prb.wc);
220 AOC<const float> dst_iter_c(dst_iter_c_, prb.mb, prb.wc);
221 AOC<const float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);
222 AOC<float> diff_weights_peephole(diff_weights_peephole_, 3, prb.dhc);
223
224 for_(int64_t ib = 0; ib < prb.mb; ++ib)
225 for (int64_t ih = 0; ih < prb.dhc; ++ih)
226 diff_weights_peephole(2, ih)
227 += b_gates(ib, LSTM_O, ih) * dst_iter_c(ib, ih);
228
229 for_(int64_t ib = 0; ib < prb.mb; ++ib)
230 for_(int64_t j = 0; j < 2; ++j)
231 for (int64_t ih = 0; ih < prb.dhc; ++ih)
232 diff_weights_peephole(j, ih) += b_gates(ib, j, ih) * src_iter_c(ib, ih);
233}
234
235void lstm_bwd(const prb_t &prb, float *diff_src_layer_, float *diff_src_iter_,
236 float *diff_src_iter_c_, float *diff_weights_layer_,
237 float *diff_weights_iter_, float *diff_weights_peephole_,
238 float *diff_weights_projection_, float *diff_bias_, float *b_gates_,
239 const float *src_layer_, const float *src_iter_,
240 const float *src_iter_c_, const float *weights_layer_,
241 const float *weights_iter_, const float *weights_peephole_,
242 const float *weights_projection_, const float *bias_,
243 const float *dst_layer_, const float *dst_iter_c_, const float *gates_,
244 const float *ht_, const float *diff_dst_layer_,
245 const float *diff_dst_iter_, const float *diff_dst_iter_c_,
246 float *cell_scratchpad_) {
247 float *diff_hidden_state_ = cell_scratchpad_;
248
249 AOC<float> diff_hidden_state(diff_hidden_state_, prb.mb, prb.dhc);
250 AOC<const float> diff_dst_layer(diff_dst_layer_, prb.mb, prb.wc);
251 AOC<const float> diff_dst_iter(diff_dst_iter_, prb.mb, prb.wc);
252
253 if (prb.is_lstm_projection()) {
254 float *diff_dst
255 = (float *)zmalloc(prb.mb * prb.dic * sizeof(float), 64);
256 DNN_SAFE_V(diff_dst == nullptr ? dnnl_out_of_memory : dnnl_success);
257
258 // The loop below relies on this property
259 assert(prb.dic == prb.dlc(CELL));
260 for_(int64_t ib = 0; ib < prb.mb; ib++)
261 for (int64_t ih = 0; ih < prb.dic; ih++)
262 diff_dst[ib * prb.dic + ih]
263 = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih);
264
265 gemm("C", "T", "N", prb.dhc, prb.dic, prb.mb, 1.0, ht_, prb.wc,
266 diff_dst, prb.dic, 1.0, diff_weights_projection_, prb.dic);
267 gemm("C", "N", "T", prb.mb, prb.dhc, prb.dic, 1.0, diff_dst, prb.dic,
268 weights_projection_, prb.dic, 0.0, diff_hidden_state_, prb.dhc);
269 zfree(diff_dst);
270 } else {
271 for_(int64_t ib = 0; ib < prb.mb; ib++)
272 for (int64_t ih = 0; ih < prb.dhc; ih++)
273 diff_hidden_state(ib, ih)
274 = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih);
275 }
276
277 lstm_bwd_pregemm(prb, src_iter_c_, dst_iter_c_, weights_peephole_,
278 diff_hidden_state_, diff_dst_iter_c_, gates_, diff_src_iter_c_,
279 b_gates_);
280
281 gemm("C", "T", "N", prb.sic, prb.n_gates() * prb.dhc, prb.mb, 1.0,
282 src_iter_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
283 diff_weights_iter_, prb.n_gates() * prb.dhc);
284 gemm("C", "T", "N", prb.slc, prb.n_gates() * prb.dhc, prb.mb, 1.0,
285 src_layer_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
286 diff_weights_layer_, prb.n_gates() * prb.dhc);
287
288 gemm("C", "N", "T", prb.mb, prb.sic, prb.n_gates() * prb.dhc, 1.0, b_gates_,
289 prb.n_gates() * prb.dhc, weights_iter_, prb.n_gates() * prb.dhc,
290 0.0, diff_src_iter_, prb.wc);
291 gemm("C", "N", "T", prb.mb, prb.slc, prb.n_gates() * prb.dhc, 1.0, b_gates_,
292 prb.n_gates() * prb.dhc, weights_layer_, prb.n_gates() * prb.dhc,
293 0.0, diff_src_layer_, prb.wc);
294
295 if (prb.is_lstm_peephole())
296 lstm_bwd_weights_peephole(prb, src_iter_c_, dst_iter_c_, b_gates_,
297 diff_weights_peephole_);
298
299 gates_reduction(prb, b_gates_, diff_bias_);
300}
301
302} // namespace rnn
303