1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stdlib.h> |
18 | |
19 | #include "utils/parallel.hpp" |
20 | |
21 | #include "rnn/rnn.hpp" |
22 | #include "rnn/rnn_aux.hpp" |
23 | |
24 | #include "rnn/cells.hpp" |
25 | |
26 | namespace rnn { |
27 | |
28 | template <typename T1, typename T2> |
29 | void lstm_fwd_postgemm_template(T1 func1, T2 func2, const prb_t &prb, |
30 | float *gates_, const float *weights_peephole_, const float *bias_, |
31 | const float *src_iter_c_, float *dst_layer_, float *dst_iter_c_) { |
32 | AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc); |
33 | AOC<const float> weights_peephole(weights_peephole_, 3, prb.dhc); |
34 | AOC<const float> bias(bias_, prb.n_gates(), prb.dhc); |
35 | AOC<const float> src_iter_c(src_iter_c_, prb.mb, prb.wc); |
36 | AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc); |
37 | AOC<float> dst_iter_c(dst_iter_c_, prb.mb, prb.wc); |
38 | |
39 | // run the eltwise |
40 | benchdnn_parallel_nd(prb.mb, [&](int64_t ib) { |
41 | for (int64_t ih = 0; ih < prb.dhc; ih++) { |
42 | float = 0, = 0; |
43 | if (prb.is_lstm_peephole()) { |
44 | peephole_extra_i = weights_peephole(0, ih) * src_iter_c(ib, ih); |
45 | peephole_extra_f = weights_peephole(1, ih) * src_iter_c(ib, ih); |
46 | } |
47 | |
48 | gates(ib, LSTM_I, ih) = func1(prb.linear_scales[LSTM_I], |
49 | maybe_deq(prb, gates(ib, LSTM_I, ih), LSTM_I * prb.dhc + ih) |
50 | + peephole_extra_i + bias(LSTM_I, ih)); |
51 | gates(ib, LSTM_F, ih) = func1(prb.linear_scales[LSTM_F], |
52 | maybe_deq(prb, gates(ib, LSTM_F, ih), LSTM_F * prb.dhc + ih) |
53 | + peephole_extra_f + bias(LSTM_F, ih)); |
54 | |
55 | gates(ib, LSTM_C, ih) = func2(prb.linear_scales[LSTM_C], |
56 | maybe_deq(prb, gates(ib, LSTM_C, ih), LSTM_C * prb.dhc + ih) |
57 | + bias(LSTM_C, ih)); |
58 | |
59 | // compute C_t_l and H_t_l |
60 | float tmp = gates(ib, LSTM_F, ih) * src_iter_c(ib, ih) |
61 | + gates(ib, LSTM_I, ih) * gates(ib, LSTM_C, ih); |
62 | dst_iter_c(ib, ih) = tmp; |
63 | |
64 | float = 0; |
65 | if (prb.is_lstm_peephole()) |
66 | peephole_extra_o = weights_peephole(2, ih) * tmp; |
67 | |
68 | gates(ib, LSTM_O, ih) = func1(prb.linear_scales[LSTM_O], |
69 | maybe_deq(prb, gates(ib, LSTM_O, ih), LSTM_O * prb.dhc + ih) |
70 | + peephole_extra_o + bias(LSTM_O, ih)); |
71 | |
72 | dst_layer(ib, ih) = maybe_q( |
73 | prb, gates(ib, LSTM_O, ih) * func2(prb.linear_cscale, tmp)); |
74 | |
75 | for (int64_t ig = 0; ig < 4; ig++) { |
76 | BENCHDNN_PRINT(80, |
77 | "activation 1 a[" IFMT "][" IFMT "][" IFMT "] = %.7f\n" , |
78 | ib, ig, ih, gates(ib, ig, ih)); |
79 | } |
80 | BENCHDNN_PRINT(80, "recomp tmp(%a) cin(%a) ht(%a)\n" , tmp, |
81 | src_iter_c(ib, ih), dst_layer(ib, ih)); |
82 | } |
83 | }); |
84 | } |
85 | |
86 | void lstm_fwd_postgemm(const prb_t &prb, float *gates_, |
87 | const float *weights_peephole_, const float *bias_, |
88 | const float *src_iter_c_, float *dst_layer_, float *dst_iter_c_) { |
89 | if (prb.skip_nonlinear) |
90 | lstm_fwd_postgemm_template( |
91 | [](float scale, float val) { return scale * val; }, |
92 | [](float scale, float val) { return scale * val; }, prb, gates_, |
93 | weights_peephole_, bias_, src_iter_c_, dst_layer_, dst_iter_c_); |
94 | else |
95 | lstm_fwd_postgemm_template( |
96 | [](float scale, float val) { return logistic(val); }, |
97 | [](float scale, float val) { return tanhf(val); }, prb, gates_, |
98 | weights_peephole_, bias_, src_iter_c_, dst_layer_, dst_iter_c_); |
99 | } |
100 | |
// Reference LSTM forward for a single cell iteration.
//
// Computes gates = src_layer * W_layer + src_iter * W_iter via two GEMMs,
// applies the elementwise postgemm (activations, bias, optional peephole)
// producing the hidden state and dst_iter_c_, then — for projection LSTM —
// multiplies the hidden state by W_projection, optionally simulating an
// int8 dequantize/requantize round-trip on the result.
void lstm_fwd(const prb_t &prb, float *dst_layer_, float *dst_iter_,
        float *dst_iter_c_, float *gates_, float *ht_,
        const float *weights_layer_, const float *weights_iter_,
        const float *weights_peephole_, const float *weights_projection_,
        const float *weights_projection_compensation, const float *bias_,
        const float *src_layer_, const float *src_iter_,
        const float *src_iter_c_) {

    // gates = src_layer * weights_layer (beta = 0.0 overwrites gates_).
    gemm("C" , "N" , "N" , prb.mb, prb.n_gates() * prb.dhc, prb.slc, 1.0,
            src_layer_, prb.wc, weights_layer_, prb.n_gates() * prb.dhc, 0.0,
            gates_, prb.n_gates() * prb.dhc);
    // gates += src_iter * weights_iter (beta = 1.0 accumulates).
    gemm("C" , "N" , "N" , prb.mb, prb.n_gates() * prb.dhc, prb.sic, 1.0,
            src_iter_, prb.wc, weights_iter_, prb.n_gates() * prb.dhc, 1.0,
            gates_, prb.n_gates() * prb.dhc);

    // if lstmp, we use the workspace to write the postgemm output
    auto dst_postgemm = prb.is_lstm_projection() ? ht_ : dst_layer_;
    lstm_fwd_postgemm(prb, gates_, weights_peephole_, bias_, src_iter_c_,
            dst_postgemm, dst_iter_c_);

    // This reference path expects dst_iter_ to alias dst_layer_.
    assert(dst_layer_ == dst_iter_);
    if (prb.is_lstm_projection()) {
        // dst_layer = ht * weights_projection (dhc -> dic channels).
        gemm("C" , "N" , "N" , prb.mb, prb.dic, prb.dhc, 1.0, dst_postgemm, prb.wc,
                weights_projection_, prb.dic, 0.0, dst_layer_, prb.wc);

        if (prb.cfg.is_int8()) {
            // Here we simulate int8 usage by dequantizing and requantizing the buffer
            benchdnn_parallel_nd(prb.mb, [&](int64_t i) {
                for (int j = 0; j < prb.dic; j++) {
                    // Buffer is laid out with row stride prb.wc.
                    int64_t addr = i * prb.wc + j;
                    float d_tmp = maybe_deq_proj(prb, dst_layer_[addr],
                            weights_projection_compensation[j], j);
                    dst_layer_[addr] = maybe_q(prb, d_tmp);
                }
            });
        }
    } else {
        // Without projection the hidden and output channel counts match.
        assert(prb.dic == prb.dhc);
    }
}
141 | |
// Elementwise (pre-GEMM) part of LSTM backward for one cell iteration.
//
// func1 is the cell-state activation (tanhf, or a linear scaling when
// nonlinearities are skipped). From the saved forward gates and cell
// states it computes the per-gate diffs b_gates_ and the diff wrt the
// previous cell state diff_src_iter_c_.
template <typename T1>
void lstm_bwd_pregemm_template(T1 func1, const prb_t &prb,
        const float *src_iter_c_, const float *dst_iter_c_,
        const float *weights_peephole_, const float *diff_hidden_state_,
        const float *diff_dst_iter_c_, const float *gates_,
        float *diff_src_iter_c_, float *b_gates_) {
    AOC<const float> src_iter_c(src_iter_c_, prb.mb, prb.wc);
    AOC<const float> dst_iter_c(dst_iter_c_, prb.mb, prb.wc);
    AOC<const float> weights_peephole(weights_peephole_, 3, prb.dhc);
    AOC<const float> diff_hidden_state(diff_hidden_state_, prb.mb, prb.dhc);
    AOC<const float> diff_dst_iter_c(diff_dst_iter_c_, prb.mb, prb.wc);
    AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
    AOC<float> diff_src_iter_c(diff_src_iter_c_, prb.mb, prb.wc);
    AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);

    for (int64_t ib = 0; ib < prb.mb; ib++)
        for (int64_t ih = 0; ih < prb.dhc; ih++) {
            BENCHDNN_PRINT(80, "rnn_single_bwd: ib = " IFMT " ih = " IFMT "\n" ,
                    ib, ih);
            // Forward gate values (already activated) saved by the fwd pass.
            float hi = gates(ib, LSTM_I, ih);
            float hf = gates(ib, LSTM_F, ih);
            float hc = gates(ib, LSTM_C, ih);
            float ho = gates(ib, LSTM_O, ih);

            float dh = diff_hidden_state(ib, ih);

            // H = o * act(C): chain rule through the o gate first.
            float tanhC = func1(prb.linear_cscale, dst_iter_c(ib, ih));
            float dho = tanhC * dh;
            b_gates(ib, LSTM_O, ih) = x_m_square(ho) * dho;

            // Accumulate the diff wrt the current cell state C_t: the
            // explicit dC_t from the next iteration plus the path
            // through H = o * act(C).
            float dc = diff_dst_iter_c(ib, ih);
            dc += ho * dh * one_m_square(tanhC);

            // The o-gate peephole reads C_t, so it also feeds into dC_t.
            if (prb.is_lstm_peephole())
                dc += b_gates(ib, LSTM_O, ih) * weights_peephole(2, ih);

            // C_t = f * C_{t-1} + i * c~, so dC_{t-1} starts as f * dC_t.
            float dc_tm1 = hf * dc;

            float c_old = src_iter_c(ib, ih);
            float dhf = c_old * dc;
            b_gates(ib, LSTM_F, ih) = x_m_square(hf) * dhf;

            float dhi = hc * dc;
            b_gates(ib, LSTM_I, ih) = x_m_square(hi) * dhi;

            float dhc = hi * dc;
            b_gates(ib, LSTM_C, ih) = one_m_square(hc) * dhc;

            // i/f peephole weights read C_{t-1}, adding to dC_{t-1}.
            if (prb.is_lstm_peephole()) {
                dc_tm1 += b_gates(ib, LSTM_F, ih) * weights_peephole(1, ih);
                dc_tm1 += b_gates(ib, LSTM_I, ih) * weights_peephole(0, ih);
            }

            diff_src_iter_c(ib, ih) = dc_tm1;
        }
}
198 | |
199 | void lstm_bwd_pregemm(const prb_t &prb, const float *src_iter_c_, |
200 | const float *dst_iter_c_, const float *weights_peephole_, |
201 | const float *diff_hidden_state_, const float *diff_dst_iter_c_, |
202 | const float *gates_, float *diff_src_iter_c_, float *b_gates_) { |
203 | if (prb.skip_nonlinear) |
204 | lstm_bwd_pregemm_template( |
205 | [](float scale, float val) { return scale * val; }, prb, |
206 | src_iter_c_, dst_iter_c_, weights_peephole_, diff_hidden_state_, |
207 | diff_dst_iter_c_, gates_, diff_src_iter_c_, b_gates_); |
208 | |
209 | else |
210 | lstm_bwd_pregemm_template( |
211 | [](float scale, float val) { return tanhf(val); }, prb, |
212 | src_iter_c_, dst_iter_c_, weights_peephole_, diff_hidden_state_, |
213 | diff_dst_iter_c_, gates_, diff_src_iter_c_, b_gates_); |
214 | } |
215 | |
216 | void lstm_bwd_weights_peephole(const prb_t &prb, const float *src_iter_c_, |
217 | const float *dst_iter_c_, const float *b_gates_, |
218 | float *diff_weights_peephole_) { |
219 | AOC<const float> src_iter_c(src_iter_c_, prb.mb, prb.wc); |
220 | AOC<const float> dst_iter_c(dst_iter_c_, prb.mb, prb.wc); |
221 | AOC<const float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc); |
222 | AOC<float> diff_weights_peephole(diff_weights_peephole_, 3, prb.dhc); |
223 | |
224 | for_(int64_t ib = 0; ib < prb.mb; ++ib) |
225 | for (int64_t ih = 0; ih < prb.dhc; ++ih) |
226 | diff_weights_peephole(2, ih) |
227 | += b_gates(ib, LSTM_O, ih) * dst_iter_c(ib, ih); |
228 | |
229 | for_(int64_t ib = 0; ib < prb.mb; ++ib) |
230 | for_(int64_t j = 0; j < 2; ++j) |
231 | for (int64_t ih = 0; ih < prb.dhc; ++ih) |
232 | diff_weights_peephole(j, ih) += b_gates(ib, j, ih) * src_iter_c(ib, ih); |
233 | } |
234 | |
// Reference LSTM backward for a single cell iteration.
//
// Combines diff_dst_layer and diff_dst_iter into the hidden-state diff
// (going through the projection weights for LSTMP), runs the elementwise
// backward to get the gate diffs b_gates_, then uses GEMMs to accumulate
// the weight diffs and compute the source diffs; finishes with the
// peephole-weight diffs (if enabled) and the bias reduction.
void lstm_bwd(const prb_t &prb, float *diff_src_layer_, float *diff_src_iter_,
        float *diff_src_iter_c_, float *diff_weights_layer_,
        float *diff_weights_iter_, float *diff_weights_peephole_,
        float *diff_weights_projection_, float *diff_bias_, float *b_gates_,
        const float *src_layer_, const float *src_iter_,
        const float *src_iter_c_, const float *weights_layer_,
        const float *weights_iter_, const float *weights_peephole_,
        const float *weights_projection_, const float *bias_,
        const float *dst_layer_, const float *dst_iter_c_, const float *gates_,
        const float *ht_, const float *diff_dst_layer_,
        const float *diff_dst_iter_, const float *diff_dst_iter_c_,
        float *cell_scratchpad_) {
    // Scratchpad holds the (mb x dhc) hidden-state diff.
    float *diff_hidden_state_ = cell_scratchpad_;

    AOC<float> diff_hidden_state(diff_hidden_state_, prb.mb, prb.dhc);
    AOC<const float> diff_dst_layer(diff_dst_layer_, prb.mb, prb.wc);
    AOC<const float> diff_dst_iter(diff_dst_iter_, prb.mb, prb.wc);

    if (prb.is_lstm_projection()) {
        // Temporary densely-packed buffer for the combined output diff.
        float *diff_dst
                = (float *)zmalloc(prb.mb * prb.dic * sizeof(float), 64);
        DNN_SAFE_V(diff_dst == nullptr ? dnnl_out_of_memory : dnnl_success);

        // The loop below relies on this property
        assert(prb.dic == prb.dlc(CELL));
        // diff_dst = diff_dst_layer + diff_dst_iter
        for_(int64_t ib = 0; ib < prb.mb; ib++)
        for (int64_t ih = 0; ih < prb.dic; ih++)
            diff_dst[ib * prb.dic + ih]
                    = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih);

        // diff_weights_projection += ht^T * diff_dst (beta = 1 accumulates).
        gemm("C" , "T" , "N" , prb.dhc, prb.dic, prb.mb, 1.0, ht_, prb.wc,
                diff_dst, prb.dic, 1.0, diff_weights_projection_, prb.dic);
        // diff_hidden_state = diff_dst * weights_projection^T (beta = 0).
        gemm("C" , "N" , "T" , prb.mb, prb.dhc, prb.dic, 1.0, diff_dst, prb.dic,
                weights_projection_, prb.dic, 0.0, diff_hidden_state_, prb.dhc);
        zfree(diff_dst);
    } else {
        // No projection: diff_hidden_state = diff_dst_layer + diff_dst_iter.
        for_(int64_t ib = 0; ib < prb.mb; ib++)
        for (int64_t ih = 0; ih < prb.dhc; ih++)
            diff_hidden_state(ib, ih)
                    = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih);
    }

    // Elementwise backward: fills b_gates_ and diff_src_iter_c_.
    lstm_bwd_pregemm(prb, src_iter_c_, dst_iter_c_, weights_peephole_,
            diff_hidden_state_, diff_dst_iter_c_, gates_, diff_src_iter_c_,
            b_gates_);

    // Weight diffs, accumulated (beta = 1): src^T * b_gates.
    gemm("C" , "T" , "N" , prb.sic, prb.n_gates() * prb.dhc, prb.mb, 1.0,
            src_iter_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
            diff_weights_iter_, prb.n_gates() * prb.dhc);
    gemm("C" , "T" , "N" , prb.slc, prb.n_gates() * prb.dhc, prb.mb, 1.0,
            src_layer_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
            diff_weights_layer_, prb.n_gates() * prb.dhc);

    // Source diffs, overwritten (beta = 0): b_gates * weights^T.
    gemm("C" , "N" , "T" , prb.mb, prb.sic, prb.n_gates() * prb.dhc, 1.0, b_gates_,
            prb.n_gates() * prb.dhc, weights_iter_, prb.n_gates() * prb.dhc,
            0.0, diff_src_iter_, prb.wc);
    gemm("C" , "N" , "T" , prb.mb, prb.slc, prb.n_gates() * prb.dhc, 1.0, b_gates_,
            prb.n_gates() * prb.dhc, weights_layer_, prb.n_gates() * prb.dhc,
            0.0, diff_src_layer_, prb.wc);

    if (prb.is_lstm_peephole())
        lstm_bwd_weights_peephole(prb, src_iter_c_, dst_iter_c_, b_gates_,
                diff_weights_peephole_);

    // diff_bias = reduction of b_gates over the minibatch.
    gates_reduction(prb, b_gates_, diff_bias_);
}
301 | |
302 | } // namespace rnn |
303 | |