1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stdlib.h> |
18 | |
19 | #include <cmath> |
20 | |
21 | #include "utils/parallel.hpp" |
22 | |
23 | #include "rnn/rnn.hpp" |
24 | #include "rnn/rnn_aux.hpp" |
25 | |
26 | #include "rnn/cells.hpp" |
27 | |
28 | namespace rnn { |
29 | |
30 | void prepare_ws_bwd(const prb_t &prb, std::vector<float> &ws_bwd_buffer, |
31 | AOC<float> &ws_diff_src_layer, AOC<float> &ws_diff_src_iter, |
32 | AOC<float> &ws_diff_src_iter_c) { |
33 | bool is_lstm = prb.alg == VANILLA_LSTM; |
34 | |
35 | ws_diff_src_layer = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(), |
36 | prb.n_iter + 2, prb.mb, prb.wc); |
37 | ws_diff_src_iter = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(), |
38 | prb.n_iter + 2, prb.mb, prb.wc); |
39 | ws_diff_src_iter_c = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(), |
40 | prb.n_iter + 2, prb.mb, prb.wc); |
41 | |
42 | int64_t size = ws_diff_src_layer.nelems() + ws_diff_src_iter.nelems() |
43 | + is_lstm * ws_diff_src_iter_c.nelems(); |
44 | ws_bwd_buffer.resize(size, 0); |
45 | |
46 | ws_diff_src_layer.set_base_ptr(ws_bwd_buffer.data()); |
47 | ws_diff_src_iter.set_base_ptr( |
48 | ws_bwd_buffer.data() + ws_diff_src_layer.nelems()); |
49 | ws_diff_src_iter_c.set_base_ptr(ws_bwd_buffer.data() |
50 | + ws_diff_src_layer.nelems() + ws_diff_src_iter.nelems()); |
51 | } |
52 | |
53 | /******************************************************************************/ |
54 | /******************************* Copy Routines ********************************/ |
55 | /******************************************************************************/ |
56 | |
57 | void copy_init_bwd(const prb_t &prb, const AOC<float> &ws_diff_src_layer, |
58 | const AOC<float> &ws_diff_src_iter, |
59 | const AOC<float> &ws_diff_src_iter_c, const args_t &args, |
60 | rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir, |
61 | int64_t dir_val) { |
62 | const dnn_mem_t &diff_dst_layer_ = args.find(DNNL_ARG_DIFF_DST_LAYER); |
63 | const dnn_mem_t &diff_dst_iter_ = args.find(DNNL_ARG_DIFF_DST_ITER); |
64 | const dnn_mem_t &diff_dst_iter_c_ = args.find(DNNL_ARG_DIFF_DST_ITER_C); |
65 | |
66 | AOC<const float> diff_dst_layer( |
67 | diff_dst_layer_, prb.n_iter, prb.mb * prb.dlc(PRIMITIVE)); |
68 | AOC<const float> diff_dst_iter( |
69 | diff_dst_iter_, prb.n_layer, prb.n_dir(), prb.mb * prb.dic); |
70 | AOC<const float> diff_dst_iter_c( |
71 | diff_dst_iter_c_, prb.n_layer, prb.n_dir(), prb.mb * prb.dhc); |
72 | |
73 | const bool is_concat = prb.direction == dnnl_bidirectional_concat; |
74 | int64_t lay_dest = (lay_dir == bottom2top) ? 0 : prb.n_layer + 1; |
75 | int64_t it_dest = (iter_dir == left2right) ? 0 : prb.n_iter + 1; |
76 | |
77 | for (int64_t it = 0; it < prb.n_iter; it++) |
78 | copy(prb.mb, prb.dlc(CELL), prb.dlc(PRIMITIVE), prb.wc, |
79 | &diff_dst_layer(it, dir_val * is_concat * prb.dlc(CELL)), |
80 | &ws_diff_src_layer(lay_dest, dir_val, it + 1, 0, 0)); |
81 | |
82 | for (int64_t lay = 0; lay < prb.n_layer; lay++) { |
83 | copy(prb.mb, prb.dic, prb.dic, prb.wc, &diff_dst_iter(lay, dir_val, 0), |
84 | &ws_diff_src_iter(lay + 1, dir_val, it_dest, 0, 0)); |
85 | if (prb.alg == VANILLA_LSTM) { |
86 | copy(prb.mb, prb.dhc, prb.dhc, prb.wc, |
87 | &diff_dst_iter_c(lay, dir_val, 0), |
88 | &ws_diff_src_iter_c(lay + 1, dir_val, it_dest, 0, 0)); |
89 | } |
90 | } |
91 | } |
92 | |
93 | void copy_res_bwd(const prb_t &prb, const args_t &args, |
94 | const AOC<const float> &ws_diff_src_layer, |
95 | const AOC<const float> &ws_diff_src_iter, |
96 | const AOC<const float> &ws_diff_src_iter_c, |
97 | rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir, |
98 | int64_t dir_val, rnn_action_t action) { |
99 | const dnn_mem_t &diff_src_layer_ = args.find(DNNL_ARG_DIFF_SRC_LAYER); |
100 | const dnn_mem_t &diff_src_iter_ = args.find(DNNL_ARG_DIFF_SRC_ITER); |
101 | const dnn_mem_t &diff_src_iter_c_ = args.find(DNNL_ARG_DIFF_SRC_ITER_C); |
102 | |
103 | AOC<float> diff_src_iter( |
104 | diff_src_iter_, prb.n_layer, prb.n_dir(), prb.mb, prb.sic); |
105 | AOC<float> diff_src_iter_c( |
106 | diff_src_iter_c_, prb.n_layer, prb.n_dir(), prb.mb, prb.dhc); |
107 | AOC<float> diff_src_layer(diff_src_layer_, prb.n_iter, prb.mb, prb.slc); |
108 | |
109 | for (int64_t it = 0; it < prb.n_iter; it++) { |
110 | for (int64_t nb = 0; nb < prb.mb; nb++) { |
111 | auto from = &ws_diff_src_layer(1, dir_val, it + 1, nb, 0); |
112 | auto to = &diff_src_layer(it, nb, 0); |
113 | |
114 | copy(1, prb.slc, prb.wc, prb.slc, from, to, action); |
115 | } |
116 | } |
117 | |
118 | int64_t it_source = (iter_dir == left2right) ? prb.n_iter : 1; |
119 | |
120 | for (int64_t lay = 0; lay < prb.n_layer; lay++) { |
121 | if (prb.alg == VANILLA_LSTM) { |
122 | copy(prb.mb, prb.dhc, prb.wc, prb.dhc, |
123 | &ws_diff_src_iter_c(lay + 1, dir_val, it_source, 0, 0), |
124 | &diff_src_iter_c(lay, dir_val, 0, 0)); |
125 | } |
126 | copy(prb.mb, prb.sic, prb.wc, prb.sic, |
127 | &ws_diff_src_iter(lay + 1, dir_val, it_source, 0, 0), |
128 | &diff_src_iter(lay, dir_val, 0, 0)); |
129 | } |
130 | } |
131 | |
132 | /******************************************************************************/ |
133 | /*************************** Computation Routines *****************************/ |
134 | /******************************************************************************/ |
135 | void gates_reduction( |
136 | const prb_t &prb, const float *b_gates_, float *diff_bias_) { |
137 | AOC<const float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc); |
138 | for (int64_t i = 0; i < prb.mb; i++) |
139 | for (int64_t j = 0; j < prb.n_gates(); j++) |
140 | for (int64_t k = 0; k < prb.dhc; k++) |
141 | diff_bias_[j * prb.dhc + k] += b_gates(i, j, k); |
142 | } |
143 | |
144 | void rnn_cell_bwd(const prb_t &prb, float *diff_src_layer, |
145 | float *diff_src_layer_attention, float *diff_src_iter, |
146 | float *diff_src_iter_c, float *diff_weights_layer, |
147 | float *diff_weights_iter, float *diff_weights_peephole, |
148 | float *diff_weights_projection, float *diff_bias, float *b_gates, |
149 | const float *src_layer, const float *src_layer_attention, |
150 | const float *src_iter, const float *src_iter_c, |
151 | const float *weights_layer, const float *weights_iter, |
152 | const float *weights_peephole, const float *weights_projection, |
153 | const float *bias, const float *dst_layer, const float *dst_iter_c, |
154 | const float *gates, const float *ht, const float *diff_dst_layer, |
155 | const float *diff_dst_iter, const float *diff_dst_iter_c, |
156 | float *cell_scratchpad_) { |
157 | |
158 | switch (prb.alg) { |
159 | case VANILLA_LSTM: |
160 | lstm_bwd(prb, diff_src_layer, diff_src_iter, diff_src_iter_c, |
161 | diff_weights_layer, diff_weights_iter, |
162 | diff_weights_peephole, diff_weights_projection, diff_bias, |
163 | b_gates, src_layer, src_iter, src_iter_c, weights_layer, |
164 | weights_iter, weights_peephole, weights_projection, bias, |
165 | dst_layer, dst_iter_c, gates, ht, diff_dst_layer, |
166 | diff_dst_iter, diff_dst_iter_c, cell_scratchpad_); |
167 | break; |
168 | case VANILLA_RNN: |
169 | rnn_bwd(prb, diff_src_layer, diff_src_iter, diff_weights_layer, |
170 | diff_weights_iter, diff_bias, b_gates, src_layer, src_iter, |
171 | weights_layer, weights_iter, bias, gates, diff_dst_layer, |
172 | diff_dst_iter); |
173 | break; |
174 | case VANILLA_GRU: |
175 | case VANILLA_AUGRU: |
176 | gru_bwd(prb, diff_src_layer, diff_src_layer_attention, |
177 | diff_src_iter, diff_weights_layer, diff_weights_iter, |
178 | diff_bias, b_gates, src_layer, src_layer_attention, |
179 | src_iter, weights_layer, weights_iter, bias, gates, |
180 | diff_dst_layer, diff_dst_iter, cell_scratchpad_); |
181 | break; |
182 | case LBR_GRU: |
183 | case LBR_AUGRU: |
184 | lbr_gru_bwd(prb, diff_src_layer, diff_src_layer_attention, |
185 | diff_src_iter, diff_weights_layer, diff_weights_iter, |
186 | diff_bias, b_gates, src_layer, src_layer_attention, |
187 | src_iter, weights_layer, weights_iter, bias, gates, |
188 | diff_dst_layer, diff_dst_iter, cell_scratchpad_); |
189 | default: break; |
190 | } |
191 | } |
192 | |
193 | void rnn_linear_bwd(const prb_t &prb, const args_t &args, |
194 | const AOC<const float> &ws_src_layer, |
195 | const AOC<const float> &ws_src_iter, |
196 | const AOC<const float> &ws_src_iter_c, const AOC<const float> &ws_gates, |
197 | const AOC<const float> &ws_ht) { |
198 | const dnn_mem_t &src_layer_attention_ = args.find(DNNL_ARG_AUGRU_ATTENTION); |
199 | const dnn_mem_t &weights_layer_ = args.find(DNNL_ARG_WEIGHTS_LAYER); |
200 | const dnn_mem_t &weights_iter_ = args.find(DNNL_ARG_WEIGHTS_ITER); |
201 | const dnn_mem_t &weights_peephole_ = args.find(DNNL_ARG_WEIGHTS_PEEPHOLE); |
202 | const dnn_mem_t &weights_projection_ |
203 | = args.find(DNNL_ARG_WEIGHTS_PROJECTION); |
204 | const dnn_mem_t &bias_ = args.find(DNNL_ARG_BIAS); |
205 | const dnn_mem_t &diff_src_layer_attention_ |
206 | = args.find(DNNL_ARG_DIFF_AUGRU_ATTENTION); |
207 | const dnn_mem_t &diff_weights_layer_ |
208 | = args.find(DNNL_ARG_DIFF_WEIGHTS_LAYER); |
209 | const dnn_mem_t &diff_weights_iter_ = args.find(DNNL_ARG_DIFF_WEIGHTS_ITER); |
210 | const dnn_mem_t &diff_weights_peephole_ |
211 | = args.find(DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE); |
212 | const dnn_mem_t &diff_weights_projection_ |
213 | = args.find(DNNL_ARG_DIFF_WEIGHTS_PROJECTION); |
214 | const dnn_mem_t &diff_bias_ = args.find(DNNL_ARG_DIFF_BIAS); |
215 | |
216 | bool is_lbr = prb.alg == LBR_GRU || prb.alg == LBR_AUGRU; |
217 | |
218 | AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(), |
219 | prb.n_gates() * prb.dhc, prb.slc); |
220 | AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(), |
221 | prb.n_gates() * prb.dhc, prb.sic); |
222 | |
223 | AOC<float> diff_weights_layer(diff_weights_layer_, prb.n_layer, prb.n_dir(), |
224 | prb.n_gates() * prb.dhc, prb.slc); |
225 | AOC<float> diff_weights_iter(diff_weights_iter_, prb.n_layer, prb.n_dir(), |
226 | prb.n_gates() * prb.dhc, prb.sic); |
227 | |
228 | AOC<const float> weights_peephole( |
229 | weights_peephole_, prb.n_layer, prb.n_dir(), 3 * prb.dhc); |
230 | AOC<float> diff_weights_peephole( |
231 | diff_weights_peephole_, prb.n_layer, prb.n_dir(), 3 * prb.dhc); |
232 | |
233 | AOC<const float> weights_projection( |
234 | weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc * prb.dic); |
235 | AOC<float> diff_weights_projection(diff_weights_projection_, prb.n_layer, |
236 | prb.n_dir(), prb.dhc * prb.dic); |
237 | |
238 | AOC<const float> bias( |
239 | bias_, prb.n_layer, prb.n_dir(), prb.n_gates() + is_lbr, prb.dhc); |
240 | AOC<float> diff_bias(diff_bias_, prb.n_layer, prb.n_dir(), |
241 | prb.n_gates() + is_lbr, prb.dhc); |
242 | |
243 | std::vector<float> ws_bwd_buffer; |
244 | AOC<float> ws_diff_src_layer, ws_diff_src_iter, ws_diff_src_iter_c; |
245 | prepare_ws_bwd(prb, ws_bwd_buffer, ws_diff_src_layer, ws_diff_src_iter, |
246 | ws_diff_src_iter_c); |
247 | |
248 | AOC<const float> src_layer_attention( |
249 | src_layer_attention_, prb.n_iter, prb.mb, 1); |
250 | AOC<float> diff_src_layer_attention( |
251 | diff_src_layer_attention_, prb.n_iter, prb.mb, 1); |
252 | |
253 | int64_t b_gates_size = prb.mb * prb.n_gates() * prb.dhc; |
254 | auto *b_gates = new float[b_gates_size]; |
255 | for (int i = 0; i < b_gates_size; i++) { |
256 | b_gates[i] = NAN; |
257 | } |
258 | |
259 | int64_t cell_scratchpad_size = 0; |
260 | switch (prb.alg) { |
261 | case VANILLA_LSTM: cell_scratchpad_size = prb.mb * prb.dhc; break; |
262 | case LBR_GRU: |
263 | case LBR_AUGRU: |
264 | cell_scratchpad_size = prb.mb * (prb.n_gates() + 1) * prb.dhc; |
265 | break; |
266 | case VANILLA_GRU: |
267 | case VANILLA_AUGRU: cell_scratchpad_size = 2 * prb.mb * prb.dhc; break; |
268 | default: cell_scratchpad_size = 0; |
269 | } |
270 | float *cell_scratchpad_ = new float[cell_scratchpad_size]; |
271 | for (int i = 0; i < cell_scratchpad_size; i++) { |
272 | cell_scratchpad_[i] = NAN; |
273 | } |
274 | |
275 | auto process_direction = [&](rnn_iter_direction_t iter_dir, |
276 | rnn_layer_direction_t lay_dir, |
277 | int64_t dir_val, rnn_action_t action) { |
278 | // we first need to copy the initial diff_dst_layer and |
279 | // diff_dst_iter{,_c} into ws to simplify the logic of the code |
280 | copy_init_bwd(prb, ws_diff_src_layer, ws_diff_src_iter, |
281 | ws_diff_src_iter_c, args, iter_dir, lay_dir, dir_val); |
282 | |
283 | // We run the grid of computation |
284 | for (int64_t j = prb.n_layer - 1; j >= 0; j--) { |
285 | for (int64_t i = 0; i < prb.n_iter; i++) { |
286 | int64_t iter |
287 | = (iter_dir == left2right) ? i + 1 : prb.n_iter - i; |
288 | int64_t prev_iter |
289 | = (iter_dir == left2right) ? iter - 1 : iter + 1; |
290 | int64_t lay = j + 1; |
291 | int64_t prev_lay = lay + 1; |
292 | |
293 | int64_t ws_iter = iter; |
294 | int64_t ws_prev_iter |
295 | = (iter_dir == left2right) ? iter + 1 : iter - 1; |
296 | |
297 | #define SAFE_PTR(FN, ...) CONCAT2(FN, _) ? &(FN(__VA_ARGS__)) : nullptr |
298 | rnn_cell_bwd(prb, &ws_diff_src_layer(lay, dir_val, iter, 0, 0), |
299 | SAFE_PTR(diff_src_layer_attention, iter - 1, 0, 0), |
300 | &ws_diff_src_iter(lay, dir_val, iter, 0, 0), |
301 | &ws_diff_src_iter_c(lay, dir_val, iter, 0, 0), |
302 | SAFE_PTR(diff_weights_layer, lay - 1, dir_val, 0, 0), |
303 | SAFE_PTR(diff_weights_iter, lay - 1, dir_val, 0, 0), |
304 | SAFE_PTR(diff_weights_peephole, lay - 1, dir_val, 0), |
305 | SAFE_PTR(diff_weights_projection, lay - 1, dir_val, 0), |
306 | SAFE_PTR(diff_bias, lay - 1, dir_val, 0, 0), b_gates, |
307 | &ws_src_layer(lay - 1, dir_val, ws_iter, 0, 0), |
308 | SAFE_PTR(src_layer_attention, iter - 1, 0, 0), |
309 | &ws_src_iter(lay, dir_val, ws_prev_iter, 0, 0), |
310 | &ws_src_iter_c(lay, dir_val, ws_prev_iter, 0, 0), |
311 | SAFE_PTR(weights_layer, lay - 1, dir_val, 0, 0), |
312 | SAFE_PTR(weights_iter, lay - 1, dir_val, 0, 0), |
313 | SAFE_PTR(weights_peephole, lay - 1, dir_val, 0), |
314 | SAFE_PTR(weights_projection, lay - 1, dir_val, 0), |
315 | SAFE_PTR(bias, lay - 1, dir_val, 0, 0), |
316 | &ws_src_layer(lay, dir_val, ws_iter, 0, 0), |
317 | &ws_src_iter_c(lay, dir_val, ws_iter, 0, 0), |
318 | &ws_gates(lay - 1, dir_val, ws_iter - 1, 0, 0, 0), |
319 | &ws_ht(lay - 1, dir_val, ws_iter - 1, 0, 0), |
320 | &ws_diff_src_layer(prev_lay, dir_val, iter, 0, 0), |
321 | &ws_diff_src_iter(lay, dir_val, prev_iter, 0, 0), |
322 | &ws_diff_src_iter_c(lay, dir_val, prev_iter, 0, 0), |
323 | cell_scratchpad_); |
324 | #undef SAFE_PTR |
325 | } |
326 | } |
327 | |
328 | // Finally we copy the results to the result buffers |
329 | copy_res_bwd(prb, args, ws_diff_src_layer, ws_diff_src_iter, |
330 | ws_diff_src_iter_c, iter_dir, lay_dir, dir_val, action); |
331 | }; |
332 | |
333 | switch (prb.direction) { |
334 | case dnnl_unidirectional_left2right: |
335 | process_direction(right2left, top2bottom, 0, action_copy); |
336 | break; |
337 | case dnnl_unidirectional_right2left: |
338 | process_direction(left2right, top2bottom, 0, action_copy); |
339 | break; |
340 | case dnnl_bidirectional_sum: |
341 | process_direction(right2left, top2bottom, 0, action_copy); |
342 | process_direction(left2right, top2bottom, 1, action_sum); |
343 | break; |
344 | case dnnl_bidirectional_concat: |
345 | process_direction(right2left, top2bottom, 0, action_copy); |
346 | process_direction(left2right, top2bottom, 1, action_sum); |
347 | break; |
348 | default: assert(!"unknown direction" ); break; |
349 | } |
350 | |
351 | delete[] b_gates; |
352 | delete[] cell_scratchpad_; |
353 | } |
354 | |
355 | void compute_ref_bwd(const prb_t &prb, const args_t &args) { |
356 | std::vector<float> ws_fwd_buffer; |
357 | AOC<float> ws_src_layer, ws_src_iter, ws_src_iter_c, ws_gates, ws_ht; |
358 | prepare_ws_fwd(prb, ws_fwd_buffer, ws_src_layer, ws_src_iter, ws_src_iter_c, |
359 | ws_gates, ws_ht); |
360 | |
361 | rnn_linear_fwd(prb, args, ws_src_layer, ws_src_iter, ws_src_iter_c, |
362 | ws_gates, ws_ht); |
363 | |
364 | rnn_linear_bwd(prb, args, ws_src_layer, ws_src_iter, ws_src_iter_c, |
365 | ws_gates, ws_ht); |
366 | } |
367 | |
368 | } // namespace rnn |
369 | |