1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef RNN_HPP |
18 | #define RNN_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <limits.h> |
22 | #include <stdint.h> |
23 | |
24 | #include <string> |
25 | #include <vector> |
26 | |
27 | #include "common.hpp" |
28 | #include "dnn_types.hpp" |
29 | #include "dnnl_common.hpp" |
30 | #include "dnnl_debug.hpp" |
31 | #include "utils/perf_report.hpp" |
32 | #include "utils/settings.hpp" |
33 | |
34 | #define AOC array_offset_calculator |
35 | |
36 | namespace rnn { |
37 | |
// Supported RNN cell kinds. LBR_* are the linear-before-reset GRU flavors
// (per oneDNN naming); the *AUGRU variants additionally consume an
// attention input (see AUGRU_ATTENTION below).
enum alg_t {
    VANILLA_RNN,
    VANILLA_LSTM,
    VANILLA_GRU,
    LBR_GRU,
    VANILLA_AUGRU,
    LBR_AUGRU
};
// String <-> enum conversions used by the command-line parser and reports.
alg_t str2alg(const char *str);
const char *alg2str(alg_t alg);
// Maps the benchdnn algorithm to the corresponding library alg kind.
dnnl_alg_kind_t alg2kind(alg_t alg);
49 | |
// Cell activation functions selectable on the command line.
enum activation_t { UNDEF, RELU, LOGISTIC, TANH };
activation_t str2activation(const char *str);
const char *activation2str(activation_t alg);
// Maps the benchdnn activation to the corresponding library alg kind.
dnnl_alg_kind_t activation2kind(activation_t alg);

// String <-> enum conversions for the RNN iteration direction.
dnnl_rnn_direction_t str2direction(const char *str);
const char *direction2str(dnnl_rnn_direction_t direction);
57 | |
// Tensor kinds of an RNN problem. Used to index per-tensor fill/check
// configuration (see dt_conf_t) and to label tensors in reports.
// NOTE: the enumerator ORDER matters — it feeds the data fill seed (see
// the FIXME below), so new kinds must be appended, not inserted.
enum rnn_data_kind_t {
    SRC_LAYER,
    SRC_ITER,
    SRC_ITER_C,
    WEIGHTS_LAYER,
    WEIGHTS_ITER,
    BIAS,
    DST_ITER,
    DST_ITER_C,
    DST_LAYER,

    DIFF_SRC_LAYER,
    DIFF_SRC_ITER,
    DIFF_SRC_ITER_C,
    DIFF_WEIGHTS_LAYER,
    DIFF_WEIGHTS_ITER,
    DIFF_BIAS,
    DIFF_DST_ITER,
    DIFF_DST_ITER_C,
    DIFF_DST_LAYER,

    // FIXME: adding peephole related weights to the appropriate places will
    // cause false-positive accuracy check failures in unrelated test cases
    // (e.g. backward vanilla RNN for bf16) due to the data fill seed being
    // dependent on the position of the tensor kind in the enum: adding
    // `WEIGHTS_PEEPHOLE` before `dst_*` and `*diff_*` results in initializing
    // the corresponding tensors differently.
    // We need a more robust way of testing RNN.
    WEIGHTS_PEEPHOLE,
    DIFF_WEIGHTS_PEEPHOLE,
    WEIGHTS_PROJECTION,
    DIFF_WEIGHTS_PROJECTION,
    // AUGRU requires an additional argument for attention.
    AUGRU_ATTENTION,
    DIFF_AUGRU_ATTENTION,
    KIND_TOTAL,
};
const char *rnn_data_kind2str(rnn_data_kind_t kind);
96 | |
// Gates indices within the gates dimension.
// LSTM order: input (i), forget (f), cell (c), output (o).
// GRU order: update (u), reset (r), output (o); the LBR GRU flavor uses an
// extra u' slot after them (see LBR_GRU_U_PRIME).
enum {
    LSTM_I = 0,
    LSTM_F = 1,
    LSTM_C = 2,
    LSTM_O = 3,
    GRU_U = 0,
    GRU_R = 1,
    GRU_O = 2,
    LBR_GRU_U_PRIME = 3,
};
108 | |
// dlc is different at the cell level and the primitive level.
// This enum makes it possible to explicitly query the intended one.
enum dlc_type_t { CELL, PRIMITIVE };
112 | |
113 | template <typename Telem> |
114 | struct array_offset_calculator { |
115 | array_offset_calculator() = default; |
116 | |
117 | template <typename... Targs> |
118 | array_offset_calculator(const dnn_mem_t &mem, Targs... dims) |
119 | : base_ptr_(mem ? (Telem *)mem : nullptr), dims_({dims...}) {} |
120 | |
121 | template <typename... Targs> |
122 | array_offset_calculator(Telem *base_ptr, Targs... dims) |
123 | : base_ptr_(base_ptr), dims_({dims...}) {} |
124 | |
125 | // ctor for AOC<const T> based on const AOC<T> & |
126 | template <typename Uelem> |
127 | array_offset_calculator(const array_offset_calculator<Uelem> &rhs) |
128 | : base_ptr_(rhs.base_ptr_), dims_(rhs.dims_) {} |
129 | |
130 | // to make the above ctor work AOC<const T> should be able to access |
131 | // private fields of AOC<T>, hence let's friend them |
132 | friend struct array_offset_calculator<const Telem>; |
133 | |
134 | template <typename... Targs> |
135 | Telem &operator()(Targs... Fargs) const { |
136 | assert(static_cast<bool>(base_ptr_)); |
137 | return *(base_ptr_ + offset(1, Fargs...)); |
138 | } |
139 | |
140 | int64_t nelems() const { |
141 | int64_t res = 1; |
142 | for (auto dim : dims_) |
143 | res *= dim; |
144 | return res; |
145 | } |
146 | |
147 | void set_base_ptr(Telem *base_ptr) { base_ptr_ = base_ptr; } |
148 | |
149 | private: |
150 | template <typename... Targs> |
151 | int64_t offset(int64_t d, int64_t pos) const { |
152 | return pos; |
153 | } |
154 | |
155 | template <typename... Targs> |
156 | int64_t offset(int64_t d, int64_t off, int64_t pos) const { |
157 | return off * dims_[d] + pos; |
158 | } |
159 | |
160 | template <typename... Targs> |
161 | int64_t offset(int64_t d, int64_t off, int64_t pos, Targs... rem) const { |
162 | return offset(d + 1, off * dims_[d] + pos, rem...); |
163 | } |
164 | |
165 | Telem *base_ptr_; |
166 | std::vector<int64_t> dims_; |
167 | }; |
168 | |
// RNN problem dimensions, parsed from the descriptor string by str2desc().
struct desc_t {
    int64_t sic; // channels in src_iter (hidden state input)
    int64_t slc; // channels in src_layer (layer input)
    int64_t dhc; // channels of the hidden state / gates output
    int64_t dic; // channels in dst_iter (differs from dhc with projection — confirm in .cpp)
    int64_t wc; // presumably the max of the channel dims, for shared buffers — TODO confirm
    int64_t mb; // minibatch size
    int64_t n_layer; // number of stacked layers
    int64_t n_iter; // number of time steps
    std::string name; // optional user-provided problem name
};
int str2desc(desc_t *desc, const char *str);
std::ostream &operator<<(std::ostream &s, const desc_t &d);
182 | |
183 | /** configuration structure, that controls initial data filling + error check |
184 | * |
185 | * dt defines precision |
186 | * |
187 | * for each lst data kind the values are filled as follows: |
188 | * if (rand() > f_sparsity) then: |
189 | * v <-- f_base |
190 | * else: |
191 | * v <-- f_min + rand() * f_step % (f_max - f_min) |
192 | * |
193 | * on final check the resulting values should be in [min .. max] range, the |
194 | * relative difference should not exceed eps |
195 | */ |
196 | struct dt_conf_t { |
197 | struct entry_t { |
198 | dnnl_data_type_t dt; |
199 | int min, max; // representative |
200 | float f_min, f_max; // fill range |
201 | float f_mean, f_stddev; // parameters of normal distribution |
202 | double eps; // acceptable error |
203 | }; |
204 | |
205 | dt_conf_t(const std::string &str) : str_(str) {} |
206 | |
207 | virtual const entry_t &operator[](rnn_data_kind_t kind) const = 0; |
208 | |
209 | const std::string &str() const { return str_; } |
210 | bool is_int8() const { |
211 | return operator[](SRC_LAYER).dt == dnnl_u8 |
212 | || operator[](SRC_LAYER).dt == dnnl_s8; |
213 | } |
214 | bool is_s8() const { return operator[](SRC_LAYER).dt == dnnl_s8; } |
215 | |
216 | static const dt_conf_t &create(const std::string &str, const attr_t &attr); |
217 | |
218 | std::string str_; |
219 | }; |
220 | |
221 | struct settings_t : public base_settings_t { |
222 | settings_t() = default; |
223 | |
224 | // ctor to save certain fields from resetting |
225 | settings_t(const char *perf_template) : settings_t() { |
226 | this->perf_template = perf_template; |
227 | } |
228 | |
229 | desc_t desc {}; |
230 | |
231 | std::vector<dir_t> prop {FWD_I}; |
232 | std::vector<std::string> cfg {"f32" }; |
233 | std::vector<alg_t> alg {VANILLA_RNN}; |
234 | std::vector<dnnl_rnn_direction_t> direction { |
235 | dnnl_unidirectional_left2right}; |
236 | std::vector<activation_t> activation {RELU}; |
237 | std::vector<bool> skip_nonlinear {false}; |
238 | std::vector<bool> trivial_strides {false}; |
239 | std::vector<bool> with_peephole {false}; |
240 | std::vector<bool> with_projection {false}; |
241 | std::vector<int64_t> n_layer {0}, n_iter {0}; |
242 | std::vector<policy_t> scale_policy {policy_t::COMMON}; |
243 | std::vector<policy_t> scale_proj_policy {policy_t::COMMON}; |
244 | unsigned int flags = 0x0; |
245 | float alpha = 0.9f, beta = 0.0f; |
246 | |
247 | const char *perf_template_csv() const { |
248 | static const std::string args |
249 | = "%prop%,%cfg%,%alg%,%activation%,%direction%" ; |
250 | return perf_template_csv_base(args); |
251 | } |
252 | |
253 | void reset() { *this = settings_t(perf_template); } |
254 | }; |
255 | |
256 | struct prb_t : public desc_t { |
257 | prb_t(const desc_t &desc, const dt_conf_t &cfg, dir_t prop, alg_t alg, |
258 | bool with_peephole, bool with_projection, |
259 | dnnl_rnn_direction_t direction, policy_t scale_policy, |
260 | policy_t scale_proj_policy, unsigned int flags, |
261 | activation_t activation, const attr_t &attr, |
262 | const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe, float alpha, |
263 | float beta, bool skip_nonlinear, bool trivial_strides, |
264 | int64_t n_layer, int64_t n_iter, int64_t mb = 0) |
265 | : desc_t(desc) |
266 | , cfg(cfg) |
267 | , prop(prop2prop_kind(prop)) |
268 | , dir(prop) |
269 | , alg(alg) |
270 | , with_peephole(with_peephole) |
271 | , with_projection(with_projection) |
272 | , direction(direction) |
273 | , wei_scales_policy(scale_policy) |
274 | , wei_proj_scales_policy(scale_proj_policy) |
275 | , flags(flags) |
276 | , activation(activation) |
277 | , attr(attr) |
278 | , ctx_init(ctx_init) |
279 | , ctx_exe(ctx_exe) |
280 | , user_mb(mb) |
281 | , alpha(alpha) |
282 | , beta(beta) |
283 | , skip_nonlinear(skip_nonlinear) |
284 | , trivial_strides(trivial_strides) |
285 | , ops(0.0) |
286 | , linear_cscale(0.0f) { |
287 | |
288 | if (n_layer) this->n_layer = n_layer; |
289 | if (n_iter) this->n_iter = n_iter; |
290 | if (mb) this->mb = mb; |
291 | count_ops(); |
292 | |
293 | wei_scales = nullptr; |
294 | wei_proj_scales = nullptr; |
295 | linear_scales = nullptr; |
296 | |
297 | // We always allocate linear scales. Even if they are not |
298 | // used, they get dereferenced when built in debug mode. |
299 | linear_scales = (float *)zmalloc(sizeof(float) * n_gates(), 64); |
300 | // Here we use the range of SRC_LAYER to set the scales |
301 | set_tparams(cfg[SRC_LAYER].f_min, cfg[SRC_LAYER].f_max); |
302 | |
303 | switch (wei_scales_policy) { |
304 | case policy_t::COMMON: |
305 | wei_scales_mask = 0x0; |
306 | wei_nscales = 1; |
307 | break; |
308 | case policy_t::PER_OC: |
309 | wei_scales_mask = 0x18; |
310 | wei_nscales = dhc * n_gates(); |
311 | break; |
312 | default: assert(!"unsupported scaling policy" ); |
313 | } |
314 | wei_scales = (float *)zmalloc(sizeof(float) * wei_nscales, 64); |
315 | |
316 | if (with_projection) { |
317 | switch (wei_proj_scales_policy) { |
318 | case policy_t::PER_OC: |
319 | wei_proj_scales_mask = 0x8; |
320 | wei_proj_nscales = dic; |
321 | break; |
322 | case policy_t::COMMON: |
323 | wei_proj_scales_mask = 0x0; |
324 | wei_proj_nscales = 1; |
325 | break; |
326 | default: assert(!"unsupported scaling policy" ); |
327 | } |
328 | wei_proj_scales |
329 | = (float *)zmalloc(sizeof(float) * wei_proj_nscales, 64); |
330 | } |
331 | |
332 | set_qparams(-1., 1.); |
333 | } |
334 | ~prb_t() { |
335 | if (wei_scales) zfree(wei_scales); |
336 | if (wei_proj_scales) zfree(wei_proj_scales); |
337 | if (linear_scales) zfree(linear_scales); |
338 | } |
339 | |
340 | float get_wei_scale(int idx) const { |
341 | return wei_scales[MIN2(idx, wei_nscales - 1)]; |
342 | } |
343 | |
344 | inline float get_wei_proj_scale(int idx) const { |
345 | return wei_proj_scales[MIN2(idx, wei_proj_nscales - 1)]; |
346 | } |
347 | |
348 | void count_ops() { |
349 | // Here, we count only the ops in GEMM portion as there is no |
350 | // theoretical number of ops for the post-gemm operations |
351 | int64_t num_cells = (int64_t)n_dir() * n_layer * n_iter; |
352 | int64_t cell_ops = (int64_t)2 * (n_gates() * dhc) * mb * (sic + slc); |
353 | if (with_projection) cell_ops += (int64_t)2 * dhc * mb * dic; |
354 | int64_t prop_multiplier = prop == dnnl_backward ? 2 : 1; |
355 | ops = prop_multiplier * num_cells * cell_ops; |
356 | } |
357 | |
358 | int64_t n_dir() const { |
359 | return (direction == dnnl_bidirectional_concat |
360 | || direction == dnnl_bidirectional_sum) |
361 | ? 2 |
362 | : 1; |
363 | } |
364 | int64_t n_states() const { return alg == VANILLA_LSTM ? 2 : 1; } |
365 | int64_t n_gates() const { |
366 | return alg == VANILLA_LSTM |
367 | ? 4 |
368 | : (alg == VANILLA_GRU || alg == LBR_GRU || alg == VANILLA_AUGRU |
369 | || alg == LBR_AUGRU |
370 | ? 3 |
371 | : 1); |
372 | } |
373 | int64_t n_bias() const { |
374 | return alg == LBR_GRU || alg == LBR_AUGRU ? n_gates() + 1 : n_gates(); |
375 | } |
376 | |
377 | int64_t dlc(dlc_type_t type) const { |
378 | if (type == PRIMITIVE) |
379 | return (direction == dnnl_bidirectional_concat ? 2 : 1) * dic; |
380 | if (type == CELL) return dic; |
381 | assert(!"unsupported dlc type" ); |
382 | return 0; |
383 | } |
384 | |
385 | bool is_int8() const { |
386 | return cfg[SRC_LAYER].dt == dnnl_u8 || cfg[SRC_LAYER].dt == dnnl_s8; |
387 | } |
388 | bool is_u8() const { return cfg[SRC_LAYER].dt == dnnl_u8; } |
389 | bool is_s8() const { return cfg[SRC_LAYER].dt == dnnl_s8; } |
390 | bool is_lstm_peephole() const { return with_peephole; } |
391 | bool is_lstm_projection() const { return with_projection; } |
392 | bool is_augru() const { return alg == VANILLA_AUGRU || alg == LBR_AUGRU; } |
393 | |
394 | const dt_conf_t &cfg; |
395 | dnnl_prop_kind_t prop; |
396 | dir_t dir; // Same as `prop`, for compatibility. TODO: remove me; |
397 | alg_t alg; |
398 | bool with_peephole, with_projection; |
399 | dnnl_rnn_direction_t direction; |
400 | policy_t wei_scales_policy; |
401 | policy_t wei_proj_scales_policy; |
402 | unsigned int flags; |
403 | activation_t activation; |
404 | attr_t attr; |
405 | thr_ctx_t ctx_init, ctx_exe; |
406 | int64_t user_mb; |
407 | float alpha; |
408 | float beta; |
409 | |
410 | float data_scale, data_shift; |
411 | |
412 | float *wei_scales; |
413 | int wei_nscales; |
414 | int wei_scales_mask; |
415 | |
416 | float *wei_proj_scales; |
417 | int wei_proj_nscales; |
418 | int wei_proj_scales_mask; |
419 | |
420 | bool skip_nonlinear; |
421 | bool trivial_strides; |
422 | double ops; |
423 | float *linear_scales; |
424 | float linear_cscale; |
425 | |
426 | private: |
427 | /* Todo: fused the two functions in set_shifts_scales */ |
428 | void set_qparams(float fp_min, float fp_max); |
429 | void set_tparams(float fp_min, float fp_max); |
430 | prb_t(const prb_t &) = delete; |
431 | prb_t &operator=(const prb_t &) = delete; |
432 | }; |
433 | std::ostream &operator<<(std::ostream &s, const prb_t &prb); |
434 | |
435 | struct perf_report_t : public base_perf_report_t { |
436 | perf_report_t(const prb_t *prb, const char *perf_template) |
437 | : base_perf_report_t(perf_template), p_(prb) {} |
438 | |
439 | void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); } |
440 | |
441 | void dump_cfg(std::ostream &s) const override { s << p_->cfg.str(); } |
442 | |
443 | void dump_desc(std::ostream &s) const override { |
444 | s << static_cast<const desc_t &>(*p_); |
445 | } |
446 | |
447 | void dump_desc_csv(std::ostream &s) const override { |
448 | s << p_->n_layer << "," << p_->n_iter << "," << p_->mb << "," << p_->sic |
449 | << "," << p_->slc << "," << p_->dhc << "," << p_->dic; |
450 | } |
451 | |
452 | void dump_rnn_activation(std::ostream &s) const override { |
453 | s << activation2str(p_->activation); |
454 | } |
455 | |
456 | void dump_rnn_direction(std::ostream &s) const override { |
457 | s << direction2str(p_->direction); |
458 | } |
459 | |
460 | double ops() const override { return p_->ops; } |
461 | const int64_t *user_mb() const override { return &p_->user_mb; } |
462 | const attr_t *attr() const override { return &p_->attr; } |
463 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
464 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
465 | const std::string *name() const override { return &p_->name; } |
466 | const dnnl_prop_kind_t *prop() const override { return &p_->prop; } |
467 | |
468 | private: |
469 | const prb_t *p_; |
470 | }; |
471 | |
// Allocates the forward workspace buffer and sets up the AOC views over it
// that the reference implementation uses.
void prepare_ws_fwd(const prb_t &prb, std::vector<float> &ws_fwd_buffer,
        AOC<float> &ws_src_layer, AOC<float> &ws_src_iter,
        AOC<float> &ws_src_iter_c, AOC<float> &ws_gates, AOC<float> &ws_ht);

// Reference forward pass over the workspace views prepared above.
void rnn_linear_fwd(const prb_t &prb, const args_t &args,
        const AOC<float> &ws_src_layer, const AOC<float> &ws_src_iter,
        const AOC<float> &ws_src_iter_c, const AOC<float> &ws_gates,
        const AOC<float> &ws_ht);

// Marks `res` as skipped for unsupported / invalid problems.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Reference computation entry points (dispatches fwd/bwd by prb->prop).
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);
void compute_ref_fwd(const prb_t &prb, const args_t &args);
void compute_ref_bwd(const prb_t &prb, const args_t &args);

// Driver entry points: run one problem / parse and run the command line.
int doit(const prb_t &prb, res_t *res);
int bench(int argc, char **argv);
490 | |
491 | } // namespace rnn |
492 | |
493 | #endif |
494 | |