1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stdlib.h> |
18 | |
19 | #include <cmath> |
20 | |
21 | #include "rnn/rnn.hpp" |
22 | #include "rnn/rnn_aux.hpp" |
23 | |
24 | #include "rnn/cells.hpp" |
25 | |
26 | namespace rnn { |
27 | |
// Sizes a single flat workspace buffer for the forward pass and binds the
// AOC (array-of-contiguous) accessors onto disjoint (or deliberately shared)
// regions of it.
//
// Layout notes:
//  - ws_src_layer and ws_src_iter are set to the SAME base pointer: in the
//    forward pass a cell's layer output and its recurrent output share
//    storage, so only one region is allocated for both.
//  - The src_iter_c region is only reserved when the cell is an LSTM
//    (is_lstm multiplies its contribution to `size` and the pointer bump).
//  - The ws_ht region is only reserved for LSTM projection (is_lstmp);
//    otherwise ws_ht's base overlaps ws_gates but is never written/read.
//  - The +2 on n_layer/n_iter leaves border cells for initial states
//    (filled by copy_init_fwd) at both ends of the grid.
void prepare_ws_fwd(const prb_t &prb, std::vector<float> &ws_fwd_buffer,
        AOC<float> &ws_src_layer, AOC<float> &ws_src_iter,
        AOC<float> &ws_src_iter_c, AOC<float> &ws_gates, AOC<float> &ws_ht) {
    bool is_lstm = prb.alg == VANILLA_LSTM;
    bool is_lstmp = prb.is_lstm_projection();

    // Build the accessors with a null base first; only their shapes (and thus
    // nelems()) are needed to compute the total buffer size below.
    ws_src_layer = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_src_iter = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_src_iter_c = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_gates = AOC<float>(nullptr, prb.n_layer, prb.n_dir(), prb.n_iter, prb.mb,
            prb.n_gates(), prb.dhc);
    ws_ht = AOC<float>(
            nullptr, prb.n_layer, prb.n_dir(), prb.n_iter, prb.mb, prb.wc);

    // Total size counts ws_src_layer once (shared with ws_src_iter) and the
    // optional regions only when their feature is enabled.
    int64_t size = ws_src_layer.nelems() + is_lstm * ws_src_iter_c.nelems()
            + ws_gates.nelems() + is_lstmp * ws_ht.nelems();
    ws_fwd_buffer.resize(size);

    // Carve the buffer into regions by advancing `ptr` past each one.
    float *ptr = ws_fwd_buffer.data();
    ws_src_layer.set_base_ptr(ptr);
    ws_src_iter.set_base_ptr(ptr);

    ptr += ws_src_iter.nelems();
    ws_src_iter_c.set_base_ptr(ptr);

    ptr += is_lstm * ws_src_iter_c.nelems();
    ws_gates.set_base_ptr(ptr);

    ptr += is_lstmp * ws_gates.nelems();
    ws_ht.set_base_ptr(ptr);
}
62 | |
63 | /******************************************************************************/ |
64 | /******************************* Copy Routines ********************************/ |
65 | /******************************************************************************/ |
66 | void prepare_projection_compensation(const prb_t &prb, |
67 | float *weights_projection_compensation_, const args_t &args) { |
68 | const dnn_mem_t &weights_projection_ |
69 | = args.find(DNNL_ARG_WEIGHTS_PROJECTION); |
70 | |
71 | AOC<float> weights_projection_compensation(weights_projection_compensation_, |
72 | prb.n_layer, prb.n_dir(), prb.dic); |
73 | AOC<const float> weights_projection( |
74 | weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc, prb.dic); |
75 | for (int layer = 0; layer < prb.n_layer; ++layer) |
76 | for (int dir = 0; dir < prb.n_dir(); ++dir) |
77 | for (int dic = 0; dic < prb.dic; ++dic) { |
78 | float weights_compensation = 0; |
79 | for (int dhc = 0; dhc < prb.dhc; ++dhc) |
80 | weights_compensation |
81 | += weights_projection(layer, dir, dhc, dic); |
82 | weights_projection_compensation(layer, dir, dic) |
83 | = weights_compensation; |
84 | } |
85 | } |
86 | |
87 | void prepare_bias( |
88 | const prb_t &prb, float *bias_with_compensation_, const args_t &args) { |
89 | const dnn_mem_t &bias_ = args.find(DNNL_ARG_BIAS); |
90 | const dnn_mem_t &weights_layer_ = args.find(DNNL_ARG_WEIGHTS_LAYER); |
91 | const dnn_mem_t &weights_iter_ = args.find(DNNL_ARG_WEIGHTS_ITER); |
92 | |
93 | AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(), |
94 | prb.slc, prb.n_gates(), prb.dhc); |
95 | AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(), |
96 | prb.sic, prb.n_gates(), prb.dhc); |
97 | |
98 | AOC<const float> bias( |
99 | bias_, prb.n_layer, prb.n_dir(), prb.n_gates(), prb.dhc); |
100 | AOC<float> bias_with_compensation(bias_with_compensation_, prb.n_layer, |
101 | prb.n_dir(), prb.n_gates(), prb.dhc); |
102 | |
103 | for (int layer = 0; layer < prb.n_layer; ++layer) |
104 | for (int dir = 0; dir < prb.n_dir(); ++dir) |
105 | for (int gate = 0; gate < prb.n_gates(); ++gate) |
106 | for (int dhc = 0; dhc < prb.dhc; ++dhc) { |
107 | float weights_compensation = 0; |
108 | for (int sic = 0; sic < prb.sic; ++sic) |
109 | weights_compensation |
110 | += weights_iter(layer, dir, sic, gate, dhc); |
111 | for (int slc = 0; slc < prb.slc; ++slc) |
112 | weights_compensation |
113 | += weights_layer(layer, dir, slc, gate, dhc); |
114 | |
115 | float scale = prb.data_scale |
116 | * prb.get_wei_scale(gate * prb.dhc + dhc); |
117 | bias_with_compensation(layer, dir, gate, dhc) |
118 | = bias(layer, dir, gate, dhc) |
119 | - weights_compensation * prb.data_shift / scale; |
120 | } |
121 | } |
122 | |
123 | void copy_init_fwd(const prb_t &prb, const AOC<float> &ws_src_layer, |
124 | const AOC<float> &ws_src_iter, const AOC<float> &ws_src_iter_c, |
125 | const args_t &args, rnn_iter_direction_t iter_dir, |
126 | rnn_layer_direction_t lay_dir, int64_t dir_val) { |
127 | const dnn_mem_t &src_layer_ = args.find(DNNL_ARG_SRC_LAYER); |
128 | const dnn_mem_t &src_iter_ = args.find(DNNL_ARG_SRC_ITER); |
129 | const dnn_mem_t &src_iter_c_ = args.find(DNNL_ARG_SRC_ITER_C); |
130 | |
131 | AOC<const float> src_layer(src_layer_, prb.n_iter, prb.mb * prb.slc); |
132 | AOC<const float> src_iter( |
133 | src_iter_, prb.n_layer, prb.n_dir(), prb.mb * prb.sic); |
134 | AOC<const float> src_iter_c( |
135 | src_iter_c_, prb.n_layer, prb.n_dir(), prb.mb * prb.dhc); |
136 | |
137 | int64_t lay_dest = (lay_dir == bottom2top) ? 0 : prb.n_layer + 1; |
138 | int64_t it_dest = (iter_dir == left2right) ? 0 : prb.n_iter + 1; |
139 | |
140 | // Copy src_layer |
141 | for (int64_t it = 0; it < prb.n_iter; it++) { |
142 | copy(prb.mb, prb.slc, prb.slc, prb.wc, &src_layer(it, 0), |
143 | &ws_src_layer(lay_dest, dir_val, it + 1, 0, 0)); |
144 | if (prb.is_int8()) |
145 | data_q10n(prb.mb, prb.slc, prb.wc, |
146 | &ws_src_layer(lay_dest, dir_val, it + 1, 0, 0), |
147 | prb.data_scale, prb.data_shift); |
148 | } |
149 | |
150 | // Copy src_iter (and src_iter_c) |
151 | for (int64_t lay = 0; lay < prb.n_layer; lay++) { |
152 | copy(prb.mb, prb.sic, prb.sic, prb.wc, &src_iter(lay, dir_val, 0), |
153 | &ws_src_iter(lay + 1, dir_val, it_dest, 0, 0)); |
154 | if (prb.is_int8()) |
155 | data_q10n(prb.mb, prb.sic, prb.wc, |
156 | &ws_src_iter(lay + 1, dir_val, it_dest, 0, 0), |
157 | prb.data_scale, prb.data_shift); |
158 | |
159 | if (prb.alg == VANILLA_LSTM) |
160 | copy(prb.mb, prb.dhc, prb.dhc, prb.wc, &src_iter_c(lay, dir_val, 0), |
161 | &ws_src_iter_c(lay + 1, dir_val, it_dest, 0, 0)); |
162 | } |
163 | } |
164 | |
// Copies the final workspace contents back into the user result buffers:
// dst_layer from the top layer's outputs at every iteration, and
// dst_iter{,_c} from the last iteration of every layer. For int8 flows the
// results are dequantized unless the destination itself is an int8 tensor.
// `action` controls how a second (right-to-left) direction merges into
// dst_layer: plain copy, elementwise sum, or concat into the second half.
void copy_res_fwd(const prb_t &prb, const args_t &args,
        const AOC<const float> &ws_src_layer,
        const AOC<const float> &ws_src_iter,
        const AOC<const float> &ws_src_iter_c, rnn_iter_direction_t iter_dir,
        rnn_layer_direction_t lay_dir, int64_t dir_val, rnn_action_t action) {
    const dnn_mem_t &dst_layer_ = args.find(DNNL_ARG_DST_LAYER);
    const dnn_mem_t &dst_iter_ = args.find(DNNL_ARG_DST_ITER);
    const dnn_mem_t &dst_iter_c_ = args.find(DNNL_ARG_DST_ITER_C);

    AOC<float> dst_iter(dst_iter_, prb.n_layer, prb.n_dir(), prb.mb, prb.dic);
    AOC<float> dst_iter_c(
            dst_iter_c_, prb.n_layer, prb.n_dir(), prb.mb, prb.dhc);
    AOC<float> dst_layer(dst_layer_, prb.n_iter, prb.mb, prb.dlc(PRIMITIVE));
    // Dequantize only when the computation ran in int8 but the destination
    // tensor is not the matching int8 type.
    const bool is_layer_deq = (prb.is_u8() && prb.cfg[DST_LAYER].dt != dnnl_u8)
            || (prb.is_s8() && prb.cfg[DST_LAYER].dt != dnnl_s8);
    const bool is_iter_deq = (prb.is_u8() && prb.cfg[DST_ITER].dt != dnnl_u8)
            || (prb.is_s8() && prb.cfg[DST_ITER].dt != dnnl_s8);
    // Copy dst_layer
    for (int64_t it = 0; it < prb.n_iter; it++) {
        for (int64_t nb = 0; nb < prb.mb; nb++) {
            // Top layer output for this (iteration, batch) entry; `it + 1`
            // skips the workspace border cell holding the initial state.
            auto from = &ws_src_layer(prb.n_layer, dir_val, it + 1, nb, 0);
            // For concat, the second direction lands in the second half of
            // the channel dimension (offset by one cell's dlc).
            auto to = &dst_layer(
                    it, nb, action == action_concat ? prb.dlc(CELL) : 0);
            copy(1, prb.dlc(CELL), prb.wc, prb.dlc(PRIMITIVE), from, to, action,
                    prb.is_int8());

            if (is_layer_deq) {
                float data_shift = prb.data_shift;
                bool do_deq10n = true;

                if (prb.direction == dnnl_bidirectional_sum) {
                    // In `bidir_sum` case, we need to dequantize data only
                    // after the final summation. Also, since we sum two shifted
                    // tensors, we need to enlarge the shift by 2x.
                    do_deq10n = action == action_sum;
                    data_shift *= 2;
                }

                if (do_deq10n)
                    data_deq10n(1, prb.dlc(CELL), prb.dlc(PRIMITIVE), to,
                            prb.data_scale, data_shift);
            }
        }
    }

    // Last computed iteration for this direction; right-to-left ends at
    // workspace index 1 (index 0 is the border/initial-state cell).
    int64_t it_source = (iter_dir == left2right) ? prb.n_iter : 1;

    // Copy dst_iter (and dst_iter_c)
    for (int64_t lay = 0; lay < prb.n_layer; lay++) {
        if (prb.alg == VANILLA_LSTM) {
            copy(prb.mb, prb.dhc, prb.wc, prb.dhc,
                    &ws_src_iter_c(lay + 1, dir_val, it_source, 0, 0),
                    &dst_iter_c(lay, dir_val, 0, 0));
        }

        copy(prb.mb, prb.dic, prb.wc, prb.dic,
                &ws_src_iter(lay + 1, dir_val, it_source, 0, 0),
                &dst_iter(lay, dir_val, 0, 0));
        if (is_iter_deq)
            data_deq10n(prb.mb, prb.dic, prb.dic, &dst_iter(lay, dir_val, 0, 0),
                    prb.data_scale, prb.data_shift);
    }
}
228 | |
229 | /******************************************************************************/ |
230 | /*************************** Computation Routines *****************************/ |
231 | /******************************************************************************/ |
232 | |
233 | void rnn_cell_fwd(const prb_t &prb, float *dst_layer, float *dst_iter, |
234 | float *dst_iter_c, float *gates, float *ht, const float *weights_layer, |
235 | const float *weights_iter, const float *weights_peephole, |
236 | const float *weights_projection, |
237 | const float *weights_projection_compensation, const float *bias, |
238 | const float *src_layer, const float *src_layer_attention, |
239 | const float *src_iter, const float *src_iter_c, |
240 | float *cell_scratchpad_) { |
241 | if (prb.alg != VANILLA_LSTM) assert(dst_layer == dst_iter); |
242 | |
243 | switch (prb.alg) { |
244 | case VANILLA_GRU: |
245 | case VANILLA_AUGRU: |
246 | gru_fwd(prb, dst_layer, gates, weights_layer, weights_iter, bias, |
247 | src_layer, src_layer_attention, src_iter); |
248 | break; |
249 | case LBR_GRU: |
250 | case LBR_AUGRU: |
251 | lbr_gru_fwd(prb, dst_layer, gates, weights_layer, weights_iter, |
252 | bias, src_layer, src_layer_attention, src_iter, |
253 | cell_scratchpad_); |
254 | break; |
255 | case VANILLA_LSTM: |
256 | lstm_fwd(prb, dst_layer, dst_iter, dst_iter_c, gates, ht, |
257 | weights_layer, weights_iter, weights_peephole, |
258 | weights_projection, weights_projection_compensation, bias, |
259 | src_layer, src_iter, src_iter_c); |
260 | break; |
261 | case VANILLA_RNN: |
262 | rnn_fwd(prb, dst_layer, gates, weights_layer, weights_iter, bias, |
263 | src_layer, src_iter); |
264 | break; |
265 | default: break; |
266 | } |
267 | } |
268 | |
269 | void rnn_linear_fwd(const prb_t &prb, const args_t &args, |
270 | const AOC<float> &ws_src_layer, const AOC<float> &ws_src_iter, |
271 | const AOC<float> &ws_src_iter_c, const AOC<float> &ws_gates, |
272 | const AOC<float> &ws_ht) { |
273 | const dnn_mem_t &src_layer_attention_ = args.find(DNNL_ARG_AUGRU_ATTENTION); |
274 | const dnn_mem_t &weights_layer_ = args.find(DNNL_ARG_WEIGHTS_LAYER); |
275 | const dnn_mem_t &weights_iter_ = args.find(DNNL_ARG_WEIGHTS_ITER); |
276 | const dnn_mem_t &weights_peephole_ = args.find(DNNL_ARG_WEIGHTS_PEEPHOLE); |
277 | const dnn_mem_t &weights_projection_ |
278 | = args.find(DNNL_ARG_WEIGHTS_PROJECTION); |
279 | const dnn_mem_t &bias_ = args.find(DNNL_ARG_BIAS); |
280 | |
281 | float *bias_ptr = (float *)bias_; |
282 | |
283 | bool is_lbr = prb.alg == LBR_GRU || prb.alg == LBR_AUGRU; |
284 | |
285 | float *bias_with_compensation = nullptr; |
286 | float *weights_projection_compensation_ = nullptr; |
287 | if (prb.is_int8()) { |
288 | bias_with_compensation = new float[prb.n_layer * prb.n_dir() |
289 | * (prb.n_gates() + is_lbr) * prb.dhc]; |
290 | prepare_bias(prb, bias_with_compensation, args); |
291 | bias_ptr = bias_with_compensation; |
292 | if (prb.is_lstm_projection()) { |
293 | weights_projection_compensation_ |
294 | = new float[prb.n_layer * prb.n_dir() * prb.dic]; |
295 | prepare_projection_compensation( |
296 | prb, weights_projection_compensation_, args); |
297 | } |
298 | } |
299 | |
300 | AOC<const float> weights_peephole( |
301 | weights_peephole_, prb.n_layer, prb.n_dir(), 3 * prb.dhc); |
302 | AOC<const float> weights_projection( |
303 | weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc * prb.dic); |
304 | AOC<const float> weights_projection_compensation( |
305 | weights_projection_compensation_, prb.n_layer, prb.n_dir(), |
306 | prb.dic); |
307 | AOC<const float> bias(bias_ptr, prb.n_layer, prb.n_dir(), |
308 | (prb.n_gates() + is_lbr) * prb.dhc); |
309 | AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(), |
310 | prb.n_gates() * prb.dhc, prb.slc); |
311 | AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(), |
312 | prb.n_gates() * prb.dhc, prb.sic); |
313 | |
314 | AOC<const float> src_layer_attention( |
315 | src_layer_attention_, prb.n_iter, prb.mb, 1); |
316 | |
317 | int64_t cell_scratchpad_size = is_lbr * prb.mb * prb.n_gates() * prb.dhc; |
318 | float *cell_scratchpad_ |
319 | = (float *)zmalloc(cell_scratchpad_size * sizeof(float), 4096); |
320 | SAFE_V(cell_scratchpad_ != nullptr ? OK : FAIL); |
321 | for (int i = 0; i < cell_scratchpad_size; i++) { |
322 | cell_scratchpad_[i] = NAN; |
323 | } |
324 | |
325 | auto process_direction = [&](rnn_iter_direction_t iter_dir, |
326 | rnn_layer_direction_t lay_dir, |
327 | int64_t dir_val, rnn_action_t action) { |
328 | // we first need to copy the initial src_layer and src_iter{,_c} into |
329 | // ws to simplify the logic of the code |
330 | BENCHDNN_PRINT(80, |
331 | "rnn_linear_fwd: call copy_init dir_val = " IFMT "\n" , dir_val); |
332 | copy_init_fwd(prb, ws_src_layer, ws_src_iter, ws_src_iter_c, args, |
333 | iter_dir, lay_dir, dir_val); |
334 | |
335 | // We run the grid of computation |
336 | for (int64_t il = 0; il < prb.n_layer; il++) { |
337 | for (int64_t it = 0; it < prb.n_iter; it++) { |
338 | BENCHDNN_PRINT(80, |
339 | "==== layer = " IFMT " iter = " IFMT " ===\n" , il, it); |
340 | int64_t iter |
341 | = (iter_dir == left2right) ? it + 1 : prb.n_iter - it; |
342 | int64_t prev_iter |
343 | = (iter_dir == left2right) ? iter - 1 : iter + 1; |
344 | int64_t lay = il + 1; |
345 | #define SAFE_PTR(FN, ...) CONCAT2(FN, _) ? &(FN(__VA_ARGS__)) : nullptr |
346 | rnn_cell_fwd(prb, &ws_src_layer(lay, dir_val, iter, 0, 0), |
347 | &ws_src_iter(lay, dir_val, iter, 0, 0), |
348 | &ws_src_iter_c(lay, dir_val, iter, 0, 0), |
349 | &ws_gates(lay - 1, dir_val, iter - 1, 0, 0, 0), |
350 | &ws_ht(lay - 1, dir_val, iter - 1, 0, 0), |
351 | SAFE_PTR(weights_layer, lay - 1, dir_val, 0, 0), |
352 | SAFE_PTR(weights_iter, lay - 1, dir_val, 0, 0), |
353 | SAFE_PTR(weights_peephole, lay - 1, dir_val, 0), |
354 | SAFE_PTR(weights_projection, lay - 1, dir_val, 0), |
355 | SAFE_PTR(weights_projection_compensation, lay - 1, |
356 | dir_val, 0), |
357 | SAFE_PTR(bias, lay - 1, dir_val, 0), |
358 | &ws_src_layer(lay - 1, dir_val, iter, 0, 0), |
359 | SAFE_PTR(src_layer_attention, iter - 1, 0, 0), |
360 | &ws_src_iter(lay, dir_val, prev_iter, 0, 0), |
361 | &ws_src_iter_c(lay, dir_val, prev_iter, 0, 0), |
362 | cell_scratchpad_); |
363 | #undef SAFE_PTR |
364 | } |
365 | } |
366 | |
367 | // Finally we copy the results to the result buffers |
368 | copy_res_fwd(prb, args, ws_src_layer, ws_src_iter, ws_src_iter_c, |
369 | iter_dir, lay_dir, dir_val, action); |
370 | }; |
371 | |
372 | switch (prb.direction) { |
373 | case dnnl_unidirectional_left2right: |
374 | process_direction(left2right, bottom2top, 0, action_copy); |
375 | break; |
376 | case dnnl_unidirectional_right2left: |
377 | process_direction(right2left, bottom2top, 0, action_copy); |
378 | break; |
379 | case dnnl_bidirectional_sum: |
380 | process_direction(left2right, bottom2top, 0, action_copy); |
381 | process_direction(right2left, bottom2top, 1, action_sum); |
382 | break; |
383 | case dnnl_bidirectional_concat: |
384 | process_direction(left2right, bottom2top, 0, action_copy); |
385 | process_direction(right2left, bottom2top, 1, action_concat); |
386 | break; |
387 | default: assert(!"unknown direction" ); break; |
388 | } |
389 | |
390 | zfree(cell_scratchpad_); |
391 | delete[] bias_with_compensation; |
392 | delete[] weights_projection_compensation_; |
393 | } |
394 | |
395 | void compute_ref_fwd(const prb_t &prb, const args_t &args) { |
396 | std::vector<float> ws_fwd_buffer; |
397 | AOC<float> ws_src_layer, ws_src_iter, ws_src_iter_c, ws_gates, ws_ht; |
398 | prepare_ws_fwd(prb, ws_fwd_buffer, ws_src_layer, ws_src_iter, ws_src_iter_c, |
399 | ws_gates, ws_ht); |
400 | |
401 | rnn_linear_fwd(prb, args, ws_src_layer, ws_src_iter, ws_src_iter_c, |
402 | ws_gates, ws_ht); |
403 | } |
404 | |
405 | } // namespace rnn |
406 | |