1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <stdlib.h>
18
19#include <cmath>
20
21#include "utils/parallel.hpp"
22
23#include "rnn/rnn.hpp"
24#include "rnn/rnn_aux.hpp"
25
26#include "rnn/cells.hpp"
27
28namespace rnn {
29
30void prepare_ws_bwd(const prb_t &prb, std::vector<float> &ws_bwd_buffer,
31 AOC<float> &ws_diff_src_layer, AOC<float> &ws_diff_src_iter,
32 AOC<float> &ws_diff_src_iter_c) {
33 bool is_lstm = prb.alg == VANILLA_LSTM;
34
35 ws_diff_src_layer = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
36 prb.n_iter + 2, prb.mb, prb.wc);
37 ws_diff_src_iter = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
38 prb.n_iter + 2, prb.mb, prb.wc);
39 ws_diff_src_iter_c = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
40 prb.n_iter + 2, prb.mb, prb.wc);
41
42 int64_t size = ws_diff_src_layer.nelems() + ws_diff_src_iter.nelems()
43 + is_lstm * ws_diff_src_iter_c.nelems();
44 ws_bwd_buffer.resize(size, 0);
45
46 ws_diff_src_layer.set_base_ptr(ws_bwd_buffer.data());
47 ws_diff_src_iter.set_base_ptr(
48 ws_bwd_buffer.data() + ws_diff_src_layer.nelems());
49 ws_diff_src_iter_c.set_base_ptr(ws_bwd_buffer.data()
50 + ws_diff_src_layer.nelems() + ws_diff_src_iter.nelems());
51}
52
53/******************************************************************************/
54/******************************* Copy Routines ********************************/
55/******************************************************************************/
56
57void copy_init_bwd(const prb_t &prb, const AOC<float> &ws_diff_src_layer,
58 const AOC<float> &ws_diff_src_iter,
59 const AOC<float> &ws_diff_src_iter_c, const args_t &args,
60 rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
61 int64_t dir_val) {
62 const dnn_mem_t &diff_dst_layer_ = args.find(DNNL_ARG_DIFF_DST_LAYER);
63 const dnn_mem_t &diff_dst_iter_ = args.find(DNNL_ARG_DIFF_DST_ITER);
64 const dnn_mem_t &diff_dst_iter_c_ = args.find(DNNL_ARG_DIFF_DST_ITER_C);
65
66 AOC<const float> diff_dst_layer(
67 diff_dst_layer_, prb.n_iter, prb.mb * prb.dlc(PRIMITIVE));
68 AOC<const float> diff_dst_iter(
69 diff_dst_iter_, prb.n_layer, prb.n_dir(), prb.mb * prb.dic);
70 AOC<const float> diff_dst_iter_c(
71 diff_dst_iter_c_, prb.n_layer, prb.n_dir(), prb.mb * prb.dhc);
72
73 const bool is_concat = prb.direction == dnnl_bidirectional_concat;
74 int64_t lay_dest = (lay_dir == bottom2top) ? 0 : prb.n_layer + 1;
75 int64_t it_dest = (iter_dir == left2right) ? 0 : prb.n_iter + 1;
76
77 for (int64_t it = 0; it < prb.n_iter; it++)
78 copy(prb.mb, prb.dlc(CELL), prb.dlc(PRIMITIVE), prb.wc,
79 &diff_dst_layer(it, dir_val * is_concat * prb.dlc(CELL)),
80 &ws_diff_src_layer(lay_dest, dir_val, it + 1, 0, 0));
81
82 for (int64_t lay = 0; lay < prb.n_layer; lay++) {
83 copy(prb.mb, prb.dic, prb.dic, prb.wc, &diff_dst_iter(lay, dir_val, 0),
84 &ws_diff_src_iter(lay + 1, dir_val, it_dest, 0, 0));
85 if (prb.alg == VANILLA_LSTM) {
86 copy(prb.mb, prb.dhc, prb.dhc, prb.wc,
87 &diff_dst_iter_c(lay, dir_val, 0),
88 &ws_diff_src_iter_c(lay + 1, dir_val, it_dest, 0, 0));
89 }
90 }
91}
92
93void copy_res_bwd(const prb_t &prb, const args_t &args,
94 const AOC<const float> &ws_diff_src_layer,
95 const AOC<const float> &ws_diff_src_iter,
96 const AOC<const float> &ws_diff_src_iter_c,
97 rnn_iter_direction_t iter_dir, rnn_layer_direction_t lay_dir,
98 int64_t dir_val, rnn_action_t action) {
99 const dnn_mem_t &diff_src_layer_ = args.find(DNNL_ARG_DIFF_SRC_LAYER);
100 const dnn_mem_t &diff_src_iter_ = args.find(DNNL_ARG_DIFF_SRC_ITER);
101 const dnn_mem_t &diff_src_iter_c_ = args.find(DNNL_ARG_DIFF_SRC_ITER_C);
102
103 AOC<float> diff_src_iter(
104 diff_src_iter_, prb.n_layer, prb.n_dir(), prb.mb, prb.sic);
105 AOC<float> diff_src_iter_c(
106 diff_src_iter_c_, prb.n_layer, prb.n_dir(), prb.mb, prb.dhc);
107 AOC<float> diff_src_layer(diff_src_layer_, prb.n_iter, prb.mb, prb.slc);
108
109 for (int64_t it = 0; it < prb.n_iter; it++) {
110 for (int64_t nb = 0; nb < prb.mb; nb++) {
111 auto from = &ws_diff_src_layer(1, dir_val, it + 1, nb, 0);
112 auto to = &diff_src_layer(it, nb, 0);
113
114 copy(1, prb.slc, prb.wc, prb.slc, from, to, action);
115 }
116 }
117
118 int64_t it_source = (iter_dir == left2right) ? prb.n_iter : 1;
119
120 for (int64_t lay = 0; lay < prb.n_layer; lay++) {
121 if (prb.alg == VANILLA_LSTM) {
122 copy(prb.mb, prb.dhc, prb.wc, prb.dhc,
123 &ws_diff_src_iter_c(lay + 1, dir_val, it_source, 0, 0),
124 &diff_src_iter_c(lay, dir_val, 0, 0));
125 }
126 copy(prb.mb, prb.sic, prb.wc, prb.sic,
127 &ws_diff_src_iter(lay + 1, dir_val, it_source, 0, 0),
128 &diff_src_iter(lay, dir_val, 0, 0));
129 }
130}
131
132/******************************************************************************/
133/*************************** Computation Routines *****************************/
134/******************************************************************************/
135void gates_reduction(
136 const prb_t &prb, const float *b_gates_, float *diff_bias_) {
137 AOC<const float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);
138 for (int64_t i = 0; i < prb.mb; i++)
139 for (int64_t j = 0; j < prb.n_gates(); j++)
140 for (int64_t k = 0; k < prb.dhc; k++)
141 diff_bias_[j * prb.dhc + k] += b_gates(i, j, k);
142}
143
144void rnn_cell_bwd(const prb_t &prb, float *diff_src_layer,
145 float *diff_src_layer_attention, float *diff_src_iter,
146 float *diff_src_iter_c, float *diff_weights_layer,
147 float *diff_weights_iter, float *diff_weights_peephole,
148 float *diff_weights_projection, float *diff_bias, float *b_gates,
149 const float *src_layer, const float *src_layer_attention,
150 const float *src_iter, const float *src_iter_c,
151 const float *weights_layer, const float *weights_iter,
152 const float *weights_peephole, const float *weights_projection,
153 const float *bias, const float *dst_layer, const float *dst_iter_c,
154 const float *gates, const float *ht, const float *diff_dst_layer,
155 const float *diff_dst_iter, const float *diff_dst_iter_c,
156 float *cell_scratchpad_) {
157
158 switch (prb.alg) {
159 case VANILLA_LSTM:
160 lstm_bwd(prb, diff_src_layer, diff_src_iter, diff_src_iter_c,
161 diff_weights_layer, diff_weights_iter,
162 diff_weights_peephole, diff_weights_projection, diff_bias,
163 b_gates, src_layer, src_iter, src_iter_c, weights_layer,
164 weights_iter, weights_peephole, weights_projection, bias,
165 dst_layer, dst_iter_c, gates, ht, diff_dst_layer,
166 diff_dst_iter, diff_dst_iter_c, cell_scratchpad_);
167 break;
168 case VANILLA_RNN:
169 rnn_bwd(prb, diff_src_layer, diff_src_iter, diff_weights_layer,
170 diff_weights_iter, diff_bias, b_gates, src_layer, src_iter,
171 weights_layer, weights_iter, bias, gates, diff_dst_layer,
172 diff_dst_iter);
173 break;
174 case VANILLA_GRU:
175 case VANILLA_AUGRU:
176 gru_bwd(prb, diff_src_layer, diff_src_layer_attention,
177 diff_src_iter, diff_weights_layer, diff_weights_iter,
178 diff_bias, b_gates, src_layer, src_layer_attention,
179 src_iter, weights_layer, weights_iter, bias, gates,
180 diff_dst_layer, diff_dst_iter, cell_scratchpad_);
181 break;
182 case LBR_GRU:
183 case LBR_AUGRU:
184 lbr_gru_bwd(prb, diff_src_layer, diff_src_layer_attention,
185 diff_src_iter, diff_weights_layer, diff_weights_iter,
186 diff_bias, b_gates, src_layer, src_layer_attention,
187 src_iter, weights_layer, weights_iter, bias, gates,
188 diff_dst_layer, diff_dst_iter, cell_scratchpad_);
189 default: break;
190 }
191}
192
193void rnn_linear_bwd(const prb_t &prb, const args_t &args,
194 const AOC<const float> &ws_src_layer,
195 const AOC<const float> &ws_src_iter,
196 const AOC<const float> &ws_src_iter_c, const AOC<const float> &ws_gates,
197 const AOC<const float> &ws_ht) {
198 const dnn_mem_t &src_layer_attention_ = args.find(DNNL_ARG_AUGRU_ATTENTION);
199 const dnn_mem_t &weights_layer_ = args.find(DNNL_ARG_WEIGHTS_LAYER);
200 const dnn_mem_t &weights_iter_ = args.find(DNNL_ARG_WEIGHTS_ITER);
201 const dnn_mem_t &weights_peephole_ = args.find(DNNL_ARG_WEIGHTS_PEEPHOLE);
202 const dnn_mem_t &weights_projection_
203 = args.find(DNNL_ARG_WEIGHTS_PROJECTION);
204 const dnn_mem_t &bias_ = args.find(DNNL_ARG_BIAS);
205 const dnn_mem_t &diff_src_layer_attention_
206 = args.find(DNNL_ARG_DIFF_AUGRU_ATTENTION);
207 const dnn_mem_t &diff_weights_layer_
208 = args.find(DNNL_ARG_DIFF_WEIGHTS_LAYER);
209 const dnn_mem_t &diff_weights_iter_ = args.find(DNNL_ARG_DIFF_WEIGHTS_ITER);
210 const dnn_mem_t &diff_weights_peephole_
211 = args.find(DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
212 const dnn_mem_t &diff_weights_projection_
213 = args.find(DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
214 const dnn_mem_t &diff_bias_ = args.find(DNNL_ARG_DIFF_BIAS);
215
216 bool is_lbr = prb.alg == LBR_GRU || prb.alg == LBR_AUGRU;
217
218 AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(),
219 prb.n_gates() * prb.dhc, prb.slc);
220 AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(),
221 prb.n_gates() * prb.dhc, prb.sic);
222
223 AOC<float> diff_weights_layer(diff_weights_layer_, prb.n_layer, prb.n_dir(),
224 prb.n_gates() * prb.dhc, prb.slc);
225 AOC<float> diff_weights_iter(diff_weights_iter_, prb.n_layer, prb.n_dir(),
226 prb.n_gates() * prb.dhc, prb.sic);
227
228 AOC<const float> weights_peephole(
229 weights_peephole_, prb.n_layer, prb.n_dir(), 3 * prb.dhc);
230 AOC<float> diff_weights_peephole(
231 diff_weights_peephole_, prb.n_layer, prb.n_dir(), 3 * prb.dhc);
232
233 AOC<const float> weights_projection(
234 weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc * prb.dic);
235 AOC<float> diff_weights_projection(diff_weights_projection_, prb.n_layer,
236 prb.n_dir(), prb.dhc * prb.dic);
237
238 AOC<const float> bias(
239 bias_, prb.n_layer, prb.n_dir(), prb.n_gates() + is_lbr, prb.dhc);
240 AOC<float> diff_bias(diff_bias_, prb.n_layer, prb.n_dir(),
241 prb.n_gates() + is_lbr, prb.dhc);
242
243 std::vector<float> ws_bwd_buffer;
244 AOC<float> ws_diff_src_layer, ws_diff_src_iter, ws_diff_src_iter_c;
245 prepare_ws_bwd(prb, ws_bwd_buffer, ws_diff_src_layer, ws_diff_src_iter,
246 ws_diff_src_iter_c);
247
248 AOC<const float> src_layer_attention(
249 src_layer_attention_, prb.n_iter, prb.mb, 1);
250 AOC<float> diff_src_layer_attention(
251 diff_src_layer_attention_, prb.n_iter, prb.mb, 1);
252
253 int64_t b_gates_size = prb.mb * prb.n_gates() * prb.dhc;
254 auto *b_gates = new float[b_gates_size];
255 for (int i = 0; i < b_gates_size; i++) {
256 b_gates[i] = NAN;
257 }
258
259 int64_t cell_scratchpad_size = 0;
260 switch (prb.alg) {
261 case VANILLA_LSTM: cell_scratchpad_size = prb.mb * prb.dhc; break;
262 case LBR_GRU:
263 case LBR_AUGRU:
264 cell_scratchpad_size = prb.mb * (prb.n_gates() + 1) * prb.dhc;
265 break;
266 case VANILLA_GRU:
267 case VANILLA_AUGRU: cell_scratchpad_size = 2 * prb.mb * prb.dhc; break;
268 default: cell_scratchpad_size = 0;
269 }
270 float *cell_scratchpad_ = new float[cell_scratchpad_size];
271 for (int i = 0; i < cell_scratchpad_size; i++) {
272 cell_scratchpad_[i] = NAN;
273 }
274
275 auto process_direction = [&](rnn_iter_direction_t iter_dir,
276 rnn_layer_direction_t lay_dir,
277 int64_t dir_val, rnn_action_t action) {
278 // we first need to copy the initial diff_dst_layer and
279 // diff_dst_iter{,_c} into ws to simplify the logic of the code
280 copy_init_bwd(prb, ws_diff_src_layer, ws_diff_src_iter,
281 ws_diff_src_iter_c, args, iter_dir, lay_dir, dir_val);
282
283 // We run the grid of computation
284 for (int64_t j = prb.n_layer - 1; j >= 0; j--) {
285 for (int64_t i = 0; i < prb.n_iter; i++) {
286 int64_t iter
287 = (iter_dir == left2right) ? i + 1 : prb.n_iter - i;
288 int64_t prev_iter
289 = (iter_dir == left2right) ? iter - 1 : iter + 1;
290 int64_t lay = j + 1;
291 int64_t prev_lay = lay + 1;
292
293 int64_t ws_iter = iter;
294 int64_t ws_prev_iter
295 = (iter_dir == left2right) ? iter + 1 : iter - 1;
296
297#define SAFE_PTR(FN, ...) CONCAT2(FN, _) ? &(FN(__VA_ARGS__)) : nullptr
298 rnn_cell_bwd(prb, &ws_diff_src_layer(lay, dir_val, iter, 0, 0),
299 SAFE_PTR(diff_src_layer_attention, iter - 1, 0, 0),
300 &ws_diff_src_iter(lay, dir_val, iter, 0, 0),
301 &ws_diff_src_iter_c(lay, dir_val, iter, 0, 0),
302 SAFE_PTR(diff_weights_layer, lay - 1, dir_val, 0, 0),
303 SAFE_PTR(diff_weights_iter, lay - 1, dir_val, 0, 0),
304 SAFE_PTR(diff_weights_peephole, lay - 1, dir_val, 0),
305 SAFE_PTR(diff_weights_projection, lay - 1, dir_val, 0),
306 SAFE_PTR(diff_bias, lay - 1, dir_val, 0, 0), b_gates,
307 &ws_src_layer(lay - 1, dir_val, ws_iter, 0, 0),
308 SAFE_PTR(src_layer_attention, iter - 1, 0, 0),
309 &ws_src_iter(lay, dir_val, ws_prev_iter, 0, 0),
310 &ws_src_iter_c(lay, dir_val, ws_prev_iter, 0, 0),
311 SAFE_PTR(weights_layer, lay - 1, dir_val, 0, 0),
312 SAFE_PTR(weights_iter, lay - 1, dir_val, 0, 0),
313 SAFE_PTR(weights_peephole, lay - 1, dir_val, 0),
314 SAFE_PTR(weights_projection, lay - 1, dir_val, 0),
315 SAFE_PTR(bias, lay - 1, dir_val, 0, 0),
316 &ws_src_layer(lay, dir_val, ws_iter, 0, 0),
317 &ws_src_iter_c(lay, dir_val, ws_iter, 0, 0),
318 &ws_gates(lay - 1, dir_val, ws_iter - 1, 0, 0, 0),
319 &ws_ht(lay - 1, dir_val, ws_iter - 1, 0, 0),
320 &ws_diff_src_layer(prev_lay, dir_val, iter, 0, 0),
321 &ws_diff_src_iter(lay, dir_val, prev_iter, 0, 0),
322 &ws_diff_src_iter_c(lay, dir_val, prev_iter, 0, 0),
323 cell_scratchpad_);
324#undef SAFE_PTR
325 }
326 }
327
328 // Finally we copy the results to the result buffers
329 copy_res_bwd(prb, args, ws_diff_src_layer, ws_diff_src_iter,
330 ws_diff_src_iter_c, iter_dir, lay_dir, dir_val, action);
331 };
332
333 switch (prb.direction) {
334 case dnnl_unidirectional_left2right:
335 process_direction(right2left, top2bottom, 0, action_copy);
336 break;
337 case dnnl_unidirectional_right2left:
338 process_direction(left2right, top2bottom, 0, action_copy);
339 break;
340 case dnnl_bidirectional_sum:
341 process_direction(right2left, top2bottom, 0, action_copy);
342 process_direction(left2right, top2bottom, 1, action_sum);
343 break;
344 case dnnl_bidirectional_concat:
345 process_direction(right2left, top2bottom, 0, action_copy);
346 process_direction(left2right, top2bottom, 1, action_sum);
347 break;
348 default: assert(!"unknown direction"); break;
349 }
350
351 delete[] b_gates;
352 delete[] cell_scratchpad_;
353}
354
355void compute_ref_bwd(const prb_t &prb, const args_t &args) {
356 std::vector<float> ws_fwd_buffer;
357 AOC<float> ws_src_layer, ws_src_iter, ws_src_iter_c, ws_gates, ws_ht;
358 prepare_ws_fwd(prb, ws_fwd_buffer, ws_src_layer, ws_src_iter, ws_src_iter_c,
359 ws_gates, ws_ht);
360
361 rnn_linear_fwd(prb, args, ws_src_layer, ws_src_iter, ws_src_iter_c,
362 ws_gates, ws_ht);
363
364 rnn_linear_bwd(prb, args, ws_src_layer, ws_src_iter, ws_src_iter_c,
365 ws_gates, ws_ht);
366}
367
368} // namespace rnn
369