/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <stdlib.h>

#include <cmath>

#include "rnn/rnn.hpp"
#include "rnn/rnn_aux.hpp"

#include "rnn/cells.hpp"

namespace rnn {

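// Allocates one flat buffer and binds the forward workspace AOCs to it. The
// layer and iteration dimensions are padded by 2 so that the border entries
// can hold the initial and final states; ws_src_iter_c is materialized only
// for LSTM and ws_ht only for LSTM with projection.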
void prepare_ws_fwd(const prb_t &prb, std::vector<float> &ws_fwd_buffer,
        AOC<float> &ws_src_layer, AOC<float> &ws_src_iter,
        AOC<float> &ws_src_iter_c, AOC<float> &ws_gates, AOC<float> &ws_ht) {
    bool is_lstm = prb.alg == VANILLA_LSTM;
    bool is_lstmp = prb.is_lstm_projection();

    ws_src_layer = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_src_iter = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_src_iter_c = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_gates = AOC<float>(nullptr, prb.n_layer, prb.n_dir(), prb.n_iter, prb.mb,
            prb.n_gates(), prb.dhc);
    ws_ht = AOC<float>(
            nullptr, prb.n_layer, prb.n_dir(), prb.n_iter, prb.mb, prb.wc);

    int64_t size = ws_src_layer.nelems() + is_lstm * ws_src_iter_c.nelems()
            + ws_gates.nelems() + is_lstmp * ws_ht.nelems();
    ws_fwd_buffer.resize(size);

    float *ptr = ws_fwd_buffer.data();
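    // ws_src_layer and ws_src_iter are bound to the same storage: the hidden
    // state written by a cell is read both as the layer input of the layer
    // above and as the recurrent input of the next iteration.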
    ws_src_layer.set_base_ptr(ptr);
    ws_src_iter.set_base_ptr(ptr);

    ptr += ws_src_iter.nelems();
    ws_src_iter_c.set_base_ptr(ptr);

    ptr += is_lstm * ws_src_iter_c.nelems();
    ws_gates.set_base_ptr(ptr);

    ptr += is_lstmp * ws_gates.nelems();
    ws_ht.set_base_ptr(ptr);
}

/******************************************************************************/
/******************************* Copy Routines ********************************/
/******************************************************************************/
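// Precomputes, for each output channel of the projection weights, the sum of
// the weights over the dhc dimension. The int8 LSTM projection path uses this
// to compensate for the data shift after the projection GEMM.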
void prepare_projection_compensation(const prb_t &prb,
        float *weights_projection_compensation_, const args_t &args) {
    const dnn_mem_t &weights_projection_
            = args.find(DNNL_ARG_WEIGHTS_PROJECTION);

    AOC<float> weights_projection_compensation(weights_projection_compensation_,
            prb.n_layer, prb.n_dir(), prb.dic);
    AOC<const float> weights_projection(
            weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc, prb.dic);
    for (int layer = 0; layer < prb.n_layer; ++layer)
        for (int dir = 0; dir < prb.n_dir(); ++dir)
            for (int dic = 0; dic < prb.dic; ++dic) {
                float weights_compensation = 0;
                for (int dhc = 0; dhc < prb.dhc; ++dhc)
                    weights_compensation
                            += weights_projection(layer, dir, dhc, dic);
                weights_projection_compensation(layer, dir, dic)
                        = weights_compensation;
            }
}

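// Builds the compensated bias used on the int8 path: the shift applied to the
// quantized activations is folded out of the GEMM result by subtracting
// (sum of weights) * data_shift / (data_scale * weights_scale) from the
// original bias.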
void prepare_bias(
        const prb_t &prb, float *bias_with_compensation_, const args_t &args) {
    const dnn_mem_t &bias_ = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &weights_layer_ = args.find(DNNL_ARG_WEIGHTS_LAYER);
    const dnn_mem_t &weights_iter_ = args.find(DNNL_ARG_WEIGHTS_ITER);

    AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(),
            prb.slc, prb.n_gates(), prb.dhc);
    AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(),
            prb.sic, prb.n_gates(), prb.dhc);

    AOC<const float> bias(
            bias_, prb.n_layer, prb.n_dir(), prb.n_gates(), prb.dhc);
    AOC<float> bias_with_compensation(bias_with_compensation_, prb.n_layer,
            prb.n_dir(), prb.n_gates(), prb.dhc);

    for (int layer = 0; layer < prb.n_layer; ++layer)
        for (int dir = 0; dir < prb.n_dir(); ++dir)
            for (int gate = 0; gate < prb.n_gates(); ++gate)
                for (int dhc = 0; dhc < prb.dhc; ++dhc) {
                    float weights_compensation = 0;
                    for (int sic = 0; sic < prb.sic; ++sic)
                        weights_compensation
                                += weights_iter(layer, dir, sic, gate, dhc);
                    for (int slc = 0; slc < prb.slc; ++slc)
                        weights_compensation
                                += weights_layer(layer, dir, slc, gate, dhc);

                    float scale = prb.data_scale
                            * prb.get_wei_scale(gate * prb.dhc + dhc);
                    bias_with_compensation(layer, dir, gate, dhc)
                            = bias(layer, dir, gate, dhc)
                            - weights_compensation * prb.data_shift / scale;
                }
}

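// Copies the user-provided inputs into the workspace borders: src_layer goes
// to the extra layer row (0 or n_layer + 1) and src_iter{,_c} go to the extra
// iteration column (0 or n_iter + 1), depending on the traversal direction.
// src_layer and src_iter are quantized on the fly for int8 configurations.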
void copy_init_fwd(const prb_t &prb, const AOC<float> &ws_src_layer,
        const AOC<float> &ws_src_iter, const AOC<float> &ws_src_iter_c,
        const args_t &args, rnn_iter_direction_t iter_dir,
        rnn_layer_direction_t lay_dir, int64_t dir_val) {
    const dnn_mem_t &src_layer_ = args.find(DNNL_ARG_SRC_LAYER);
    const dnn_mem_t &src_iter_ = args.find(DNNL_ARG_SRC_ITER);
    const dnn_mem_t &src_iter_c_ = args.find(DNNL_ARG_SRC_ITER_C);

    AOC<const float> src_layer(src_layer_, prb.n_iter, prb.mb * prb.slc);
    AOC<const float> src_iter(
            src_iter_, prb.n_layer, prb.n_dir(), prb.mb * prb.sic);
    AOC<const float> src_iter_c(
            src_iter_c_, prb.n_layer, prb.n_dir(), prb.mb * prb.dhc);

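    // Border indices in the padded workspace that receive the initial states.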
    int64_t lay_dest = (lay_dir == bottom2top) ? 0 : prb.n_layer + 1;
    int64_t it_dest = (iter_dir == left2right) ? 0 : prb.n_iter + 1;

    // Copy src_layer
    for (int64_t it = 0; it < prb.n_iter; it++) {
        copy(prb.mb, prb.slc, prb.slc, prb.wc, &src_layer(it, 0),
                &ws_src_layer(lay_dest, dir_val, it + 1, 0, 0));
        if (prb.is_int8())
            data_q10n(prb.mb, prb.slc, prb.wc,
                    &ws_src_layer(lay_dest, dir_val, it + 1, 0, 0),
                    prb.data_scale, prb.data_shift);
    }

    // Copy src_iter (and src_iter_c)
    for (int64_t lay = 0; lay < prb.n_layer; lay++) {
        copy(prb.mb, prb.sic, prb.sic, prb.wc, &src_iter(lay, dir_val, 0),
                &ws_src_iter(lay + 1, dir_val, it_dest, 0, 0));
        if (prb.is_int8())
            data_q10n(prb.mb, prb.sic, prb.wc,
                    &ws_src_iter(lay + 1, dir_val, it_dest, 0, 0),
                    prb.data_scale, prb.data_shift);

        if (prb.alg == VANILLA_LSTM)
            copy(prb.mb, prb.dhc, prb.dhc, prb.wc, &src_iter_c(lay, dir_val, 0),
                    &ws_src_iter_c(lay + 1, dir_val, it_dest, 0, 0));
    }
}

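// Copies the computed states from the workspace back to the user buffers:
// the last layer row goes to dst_layer (summed or concatenated for the
// bidirectional cases) and the last iteration column goes to dst_iter{,_c},
// with dequantization where the destination data type requires it.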
void copy_res_fwd(const prb_t &prb, const args_t &args,
        const AOC<const float> &ws_src_layer,
        const AOC<const float> &ws_src_iter,
        const AOC<const float> &ws_src_iter_c, rnn_iter_direction_t iter_dir,
        rnn_layer_direction_t lay_dir, int64_t dir_val, rnn_action_t action) {
    const dnn_mem_t &dst_layer_ = args.find(DNNL_ARG_DST_LAYER);
    const dnn_mem_t &dst_iter_ = args.find(DNNL_ARG_DST_ITER);
    const dnn_mem_t &dst_iter_c_ = args.find(DNNL_ARG_DST_ITER_C);

    AOC<float> dst_iter(dst_iter_, prb.n_layer, prb.n_dir(), prb.mb, prb.dic);
    AOC<float> dst_iter_c(
            dst_iter_c_, prb.n_layer, prb.n_dir(), prb.mb, prb.dhc);
    AOC<float> dst_layer(dst_layer_, prb.n_iter, prb.mb, prb.dlc(PRIMITIVE));
    const bool is_layer_deq = (prb.is_u8() && prb.cfg[DST_LAYER].dt != dnnl_u8)
            || (prb.is_s8() && prb.cfg[DST_LAYER].dt != dnnl_s8);
    const bool is_iter_deq = (prb.is_u8() && prb.cfg[DST_ITER].dt != dnnl_u8)
            || (prb.is_s8() && prb.cfg[DST_ITER].dt != dnnl_s8);

    // Copy dst_layer
    for (int64_t it = 0; it < prb.n_iter; it++) {
        for (int64_t nb = 0; nb < prb.mb; nb++) {
            auto from = &ws_src_layer(prb.n_layer, dir_val, it + 1, nb, 0);
            auto to = &dst_layer(
                    it, nb, action == action_concat ? prb.dlc(CELL) : 0);
            copy(1, prb.dlc(CELL), prb.wc, prb.dlc(PRIMITIVE), from, to, action,
                    prb.is_int8());

            if (is_layer_deq) {
                float data_shift = prb.data_shift;
                bool do_deq10n = true;

                if (prb.direction == dnnl_bidirectional_sum) {
                    // In `bidir_sum` case, we need to dequantize data only
                    // after the final summation. Also, since we sum two
                    // shifted tensors, we need to enlarge the shift by 2x.
                    do_deq10n = action == action_sum;
                    data_shift *= 2;
                }

                if (do_deq10n)
                    data_deq10n(1, prb.dlc(CELL), prb.dlc(PRIMITIVE), to,
                            prb.data_scale, data_shift);
            }
        }
    }

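    // The last iteration computed by this direction: n_iter for left2right,
    // 1 for right2left (iteration 0 holds the initial state).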
    int64_t it_source = (iter_dir == left2right) ? prb.n_iter : 1;

    // Copy dst_iter (and dst_iter_c)
    for (int64_t lay = 0; lay < prb.n_layer; lay++) {
        if (prb.alg == VANILLA_LSTM) {
            copy(prb.mb, prb.dhc, prb.wc, prb.dhc,
                    &ws_src_iter_c(lay + 1, dir_val, it_source, 0, 0),
                    &dst_iter_c(lay, dir_val, 0, 0));
        }

        copy(prb.mb, prb.dic, prb.wc, prb.dic,
                &ws_src_iter(lay + 1, dir_val, it_source, 0, 0),
                &dst_iter(lay, dir_val, 0, 0));
        if (is_iter_deq)
            data_deq10n(prb.mb, prb.dic, prb.dic, &dst_iter(lay, dir_val, 0, 0),
                    prb.data_scale, prb.data_shift);
    }
}

/******************************************************************************/
/*************************** Computation Routines *****************************/
/******************************************************************************/

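// Single-cell computation: dispatches to the reference implementation of the
// requested cell kind. For all cells except LSTM the hidden state serves as
// both the layer and the iteration output, hence the dst_layer == dst_iter
// assertion.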
void rnn_cell_fwd(const prb_t &prb, float *dst_layer, float *dst_iter,
        float *dst_iter_c, float *gates, float *ht, const float *weights_layer,
        const float *weights_iter, const float *weights_peephole,
        const float *weights_projection,
        const float *weights_projection_compensation, const float *bias,
        const float *src_layer, const float *src_layer_attention,
        const float *src_iter, const float *src_iter_c,
        float *cell_scratchpad_) {
    if (prb.alg != VANILLA_LSTM) assert(dst_layer == dst_iter);

    switch (prb.alg) {
        case VANILLA_GRU:
        case VANILLA_AUGRU:
            gru_fwd(prb, dst_layer, gates, weights_layer, weights_iter, bias,
                    src_layer, src_layer_attention, src_iter);
            break;
        case LBR_GRU:
        case LBR_AUGRU:
            lbr_gru_fwd(prb, dst_layer, gates, weights_layer, weights_iter,
                    bias, src_layer, src_layer_attention, src_iter,
                    cell_scratchpad_);
            break;
        case VANILLA_LSTM:
            lstm_fwd(prb, dst_layer, dst_iter, dst_iter_c, gates, ht,
                    weights_layer, weights_iter, weights_peephole,
                    weights_projection, weights_projection_compensation, bias,
                    src_layer, src_iter, src_iter_c);
            break;
        case VANILLA_RNN:
            rnn_fwd(prb, dst_layer, gates, weights_layer, weights_iter, bias,
                    src_layer, src_iter);
            break;
        default: break;
    }
}

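// Runs the full forward pass: for each direction, the initial states are
// copied into the workspace, the (layer x iteration) grid of cells is
// executed, and the results are copied back to the user buffers.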
void rnn_linear_fwd(const prb_t &prb, const args_t &args,
        const AOC<float> &ws_src_layer, const AOC<float> &ws_src_iter,
        const AOC<float> &ws_src_iter_c, const AOC<float> &ws_gates,
        const AOC<float> &ws_ht) {
    const dnn_mem_t &src_layer_attention_ = args.find(DNNL_ARG_AUGRU_ATTENTION);
    const dnn_mem_t &weights_layer_ = args.find(DNNL_ARG_WEIGHTS_LAYER);
    const dnn_mem_t &weights_iter_ = args.find(DNNL_ARG_WEIGHTS_ITER);
    const dnn_mem_t &weights_peephole_ = args.find(DNNL_ARG_WEIGHTS_PEEPHOLE);
    const dnn_mem_t &weights_projection_
            = args.find(DNNL_ARG_WEIGHTS_PROJECTION);
    const dnn_mem_t &bias_ = args.find(DNNL_ARG_BIAS);

    float *bias_ptr = (float *)bias_;

    bool is_lbr = prb.alg == LBR_GRU || prb.alg == LBR_AUGRU;

    float *bias_with_compensation = nullptr;
    float *weights_projection_compensation_ = nullptr;
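    // For int8, the bias is replaced with a compensated copy, and an extra
    // per-channel compensation term is precomputed for the LSTM projection
    // weights (see prepare_bias and prepare_projection_compensation above).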
    if (prb.is_int8()) {
        bias_with_compensation = new float[prb.n_layer * prb.n_dir()
                * (prb.n_gates() + is_lbr) * prb.dhc];
        prepare_bias(prb, bias_with_compensation, args);
        bias_ptr = bias_with_compensation;
        if (prb.is_lstm_projection()) {
            weights_projection_compensation_
                    = new float[prb.n_layer * prb.n_dir() * prb.dic];
            prepare_projection_compensation(
                    prb, weights_projection_compensation_, args);
        }
    }

    AOC<const float> weights_peephole(
            weights_peephole_, prb.n_layer, prb.n_dir(), 3 * prb.dhc);
    AOC<const float> weights_projection(
            weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc * prb.dic);
    AOC<const float> weights_projection_compensation(
            weights_projection_compensation_, prb.n_layer, prb.n_dir(),
            prb.dic);
    AOC<const float> bias(bias_ptr, prb.n_layer, prb.n_dir(),
            (prb.n_gates() + is_lbr) * prb.dhc);
    AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(),
            prb.n_gates() * prb.dhc, prb.slc);
    AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(),
            prb.n_gates() * prb.dhc, prb.sic);

    AOC<const float> src_layer_attention(
            src_layer_attention_, prb.n_iter, prb.mb, 1);

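    // LBR GRU/AUGRU cells need extra scratch memory; it is poisoned with NaNs
    // to catch reads of uninitialized values.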
    int64_t cell_scratchpad_size = is_lbr * prb.mb * prb.n_gates() * prb.dhc;
    float *cell_scratchpad_
            = (float *)zmalloc(cell_scratchpad_size * sizeof(float), 4096);
    SAFE_V(cell_scratchpad_ != nullptr ? OK : FAIL);
    for (int i = 0; i < cell_scratchpad_size; i++) {
        cell_scratchpad_[i] = NAN;
    }

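    // Processes one direction end to end: copy_init -> cell grid -> copy_res.
    // `dir_val` selects the direction slot in the workspace; `action` tells
    // copy_res_fwd how to combine dst_layer results for bidirectional runs.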
    auto process_direction = [&](rnn_iter_direction_t iter_dir,
                                     rnn_layer_direction_t lay_dir,
                                     int64_t dir_val, rnn_action_t action) {
        // We first copy the initial src_layer and src_iter{,_c} into the
        // workspace to simplify the logic of the code.
        BENCHDNN_PRINT(80,
                "rnn_linear_fwd: call copy_init dir_val = " IFMT "\n", dir_val);
        copy_init_fwd(prb, ws_src_layer, ws_src_iter, ws_src_iter_c, args,
                iter_dir, lay_dir, dir_val);

        // Run the grid of computation over layers and iterations.
        for (int64_t il = 0; il < prb.n_layer; il++) {
            for (int64_t it = 0; it < prb.n_iter; it++) {
                BENCHDNN_PRINT(80,
                        "==== layer = " IFMT " iter = " IFMT " ===\n", il, it);
                int64_t iter
                        = (iter_dir == left2right) ? it + 1 : prb.n_iter - it;
                int64_t prev_iter
                        = (iter_dir == left2right) ? iter - 1 : iter + 1;
                int64_t lay = il + 1;
#define SAFE_PTR(FN, ...) CONCAT2(FN, _) ? &(FN(__VA_ARGS__)) : nullptr
                rnn_cell_fwd(prb, &ws_src_layer(lay, dir_val, iter, 0, 0),
                        &ws_src_iter(lay, dir_val, iter, 0, 0),
                        &ws_src_iter_c(lay, dir_val, iter, 0, 0),
                        &ws_gates(lay - 1, dir_val, iter - 1, 0, 0, 0),
                        &ws_ht(lay - 1, dir_val, iter - 1, 0, 0),
                        SAFE_PTR(weights_layer, lay - 1, dir_val, 0, 0),
                        SAFE_PTR(weights_iter, lay - 1, dir_val, 0, 0),
                        SAFE_PTR(weights_peephole, lay - 1, dir_val, 0),
                        SAFE_PTR(weights_projection, lay - 1, dir_val, 0),
                        SAFE_PTR(weights_projection_compensation, lay - 1,
                                dir_val, 0),
                        SAFE_PTR(bias, lay - 1, dir_val, 0),
                        &ws_src_layer(lay - 1, dir_val, iter, 0, 0),
                        SAFE_PTR(src_layer_attention, iter - 1, 0, 0),
                        &ws_src_iter(lay, dir_val, prev_iter, 0, 0),
                        &ws_src_iter_c(lay, dir_val, prev_iter, 0, 0),
                        cell_scratchpad_);
#undef SAFE_PTR
            }
        }

        // Finally, copy the results to the destination buffers.
        copy_res_fwd(prb, args, ws_src_layer, ws_src_iter, ws_src_iter_c,
                iter_dir, lay_dir, dir_val, action);
    };

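    // Bidirectional runs execute both passes over the same destination
    // buffers: the second pass either sums into or concatenates next to the
    // results of the first pass.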
    switch (prb.direction) {
        case dnnl_unidirectional_left2right:
            process_direction(left2right, bottom2top, 0, action_copy);
            break;
        case dnnl_unidirectional_right2left:
            process_direction(right2left, bottom2top, 0, action_copy);
            break;
        case dnnl_bidirectional_sum:
            process_direction(left2right, bottom2top, 0, action_copy);
            process_direction(right2left, bottom2top, 1, action_sum);
            break;
        case dnnl_bidirectional_concat:
            process_direction(left2right, bottom2top, 0, action_copy);
            process_direction(right2left, bottom2top, 1, action_concat);
            break;
        default: assert(!"unknown direction"); break;
    }

    zfree(cell_scratchpad_);
    delete[] bias_with_compensation;
    delete[] weights_projection_compensation_;
}

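// Reference forward entry point: sets up the forward workspace and runs the
// linear (layer-by-layer, iteration-by-iteration) forward pass over it.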
void compute_ref_fwd(const prb_t &prb, const args_t &args) {
    std::vector<float> ws_fwd_buffer;
    AOC<float> ws_src_layer, ws_src_iter, ws_src_iter_c, ws_gates, ws_ht;
    prepare_ws_fwd(prb, ws_fwd_buffer, ws_src_layer, ws_src_iter, ws_src_iter_c,
            ws_gates, ws_ht);

    rnn_linear_fwd(prb, args, ws_src_layer, ws_src_iter, ws_src_iter_c,
            ws_gates, ws_ht);
}

} // namespace rnn