1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef RNN_HPP
18#define RNN_HPP
19
20#include <assert.h>
21#include <limits.h>
22#include <stdint.h>
23
24#include <string>
25#include <vector>
26
27#include "common.hpp"
28#include "dnn_types.hpp"
29#include "dnnl_common.hpp"
30#include "dnnl_debug.hpp"
31#include "utils/perf_report.hpp"
32#include "utils/settings.hpp"
33
34#define AOC array_offset_calculator
35
36namespace rnn {
37
// Supported RNN cell algorithms. Values map onto the library's cell kinds
// via alg2kind() below.
enum alg_t {
    VANILLA_RNN,
    VANILLA_LSTM,
    VANILLA_GRU,
    LBR_GRU, // linear-before-reset GRU
    VANILLA_AUGRU,
    LBR_AUGRU // linear-before-reset AUGRU
};
// String <-> alg_t conversions for command-line parsing and reporting.
alg_t str2alg(const char *str);
const char *alg2str(alg_t alg);
// Maps benchdnn's alg_t onto the library's dnnl_alg_kind_t.
dnnl_alg_kind_t alg2kind(alg_t alg);
49
// Element-wise activation kinds for the RNN cell.
enum activation_t { UNDEF, RELU, LOGISTIC, TANH };
// String <-> activation_t conversions, plus mapping to the library's
// dnnl_alg_kind_t.
activation_t str2activation(const char *str);
const char *activation2str(activation_t alg);
dnnl_alg_kind_t activation2kind(activation_t alg);

// String <-> direction conversions (e.g. left2right, bidirectional ...).
dnnl_rnn_direction_t str2direction(const char *str);
const char *direction2str(dnnl_rnn_direction_t direction);
57
// Tensor kinds of the RNN primitive (forward args, then their diffs).
// NOTE: the order of enumerators is load-bearing -- see the FIXME below.
enum rnn_data_kind_t {
    SRC_LAYER,
    SRC_ITER,
    SRC_ITER_C,
    WEIGHTS_LAYER,
    WEIGHTS_ITER,
    BIAS,
    DST_ITER,
    DST_ITER_C,
    DST_LAYER,

    DIFF_SRC_LAYER,
    DIFF_SRC_ITER,
    DIFF_SRC_ITER_C,
    DIFF_WEIGHTS_LAYER,
    DIFF_WEIGHTS_ITER,
    DIFF_BIAS,
    DIFF_DST_ITER,
    DIFF_DST_ITER_C,
    DIFF_DST_LAYER,

    // FIXME: adding peephole related weights to the appropriate places will
    // cause false-positive accuracy check failures in unrelated test cases
    // (e.g. backward vanilla RNN for bf16) due to the data fill seed being
    // dependent on the position of the tensor kind in the enum: adding
    // `WEIGHTS_PEEPHOLE` before `dst_*` and `*diff_*` results in initializing
    // the corresponding tensors differently.
    // We need a more robust way of testing RNN.
    WEIGHTS_PEEPHOLE,
    DIFF_WEIGHTS_PEEPHOLE,
    WEIGHTS_PROJECTION,
    DIFF_WEIGHTS_PROJECTION,
    // AUGRU requires an additional argument for attention.
    AUGRU_ATTENTION,
    DIFF_AUGRU_ATTENTION,
    KIND_TOTAL, // sentinel: total number of kinds
};
const char *rnn_data_kind2str(rnn_data_kind_t kind);
96
// Gates indices within a cell's gates dimension.
enum {
    LSTM_I = 0, // input gate
    LSTM_F = 1, // forget gate
    LSTM_C = 2, // cell (candidate) gate
    LSTM_O = 3, // output gate
    GRU_U = 0, // update gate
    GRU_R = 1, // reset gate
    GRU_O = 2, // output (candidate) gate
    LBR_GRU_U_PRIME = 3, // extra u' gate of the linear-before-reset variants
};

// dlc is different at the cell level and the primitive level
// This enum makes it possible to explicitly query the intended one
enum dlc_type_t { CELL, PRIMITIVE };
112
113template <typename Telem>
114struct array_offset_calculator {
115 array_offset_calculator() = default;
116
117 template <typename... Targs>
118 array_offset_calculator(const dnn_mem_t &mem, Targs... dims)
119 : base_ptr_(mem ? (Telem *)mem : nullptr), dims_({dims...}) {}
120
121 template <typename... Targs>
122 array_offset_calculator(Telem *base_ptr, Targs... dims)
123 : base_ptr_(base_ptr), dims_({dims...}) {}
124
125 // ctor for AOC<const T> based on const AOC<T> &
126 template <typename Uelem>
127 array_offset_calculator(const array_offset_calculator<Uelem> &rhs)
128 : base_ptr_(rhs.base_ptr_), dims_(rhs.dims_) {}
129
130 // to make the above ctor work AOC<const T> should be able to access
131 // private fields of AOC<T>, hence let's friend them
132 friend struct array_offset_calculator<const Telem>;
133
134 template <typename... Targs>
135 Telem &operator()(Targs... Fargs) const {
136 assert(static_cast<bool>(base_ptr_));
137 return *(base_ptr_ + offset(1, Fargs...));
138 }
139
140 int64_t nelems() const {
141 int64_t res = 1;
142 for (auto dim : dims_)
143 res *= dim;
144 return res;
145 }
146
147 void set_base_ptr(Telem *base_ptr) { base_ptr_ = base_ptr; }
148
149private:
150 template <typename... Targs>
151 int64_t offset(int64_t d, int64_t pos) const {
152 return pos;
153 }
154
155 template <typename... Targs>
156 int64_t offset(int64_t d, int64_t off, int64_t pos) const {
157 return off * dims_[d] + pos;
158 }
159
160 template <typename... Targs>
161 int64_t offset(int64_t d, int64_t off, int64_t pos, Targs... rem) const {
162 return offset(d + 1, off * dims_[d] + pos, rem...);
163 }
164
165 Telem *base_ptr_;
166 std::vector<int64_t> dims_;
167};
168
// RNN problem dimensions, parsed from the command line (oneDNN RNN naming).
struct desc_t {
    int64_t sic; // source iteration (recurrent input) channels
    int64_t slc; // source layer (input) channels
    int64_t dhc; // destination hidden channels
    int64_t dic; // destination iteration channels (may differ from dhc with projection)
    int64_t wc; // presumably the widest channel count, used for buffer sizing -- confirm in implementation
    int64_t mb; // minibatch
    int64_t n_layer; // number of stacked layers
    int64_t n_iter; // number of time steps
    std::string name; // optional problem name for reporting
};
// Parses a descriptor string into `desc`; returns a status code.
int str2desc(desc_t *desc, const char *str);
std::ostream &operator<<(std::ostream &s, const desc_t &d);
182
/** configuration structure that controls initial data filling + error check
*
* dt defines precision
*
* for each data kind the values are filled according to the corresponding
* entry_t: drawn from a normal distribution with parameters (f_mean,
* f_stddev) and kept within the [f_min .. f_max] fill range (see the
* implementation for the exact filling procedure)
*
* on final check the resulting values should be in [min .. max] range, the
* relative difference should not exceed eps
*/
196struct dt_conf_t {
197 struct entry_t {
198 dnnl_data_type_t dt;
199 int min, max; // representative
200 float f_min, f_max; // fill range
201 float f_mean, f_stddev; // parameters of normal distribution
202 double eps; // acceptable error
203 };
204
205 dt_conf_t(const std::string &str) : str_(str) {}
206
207 virtual const entry_t &operator[](rnn_data_kind_t kind) const = 0;
208
209 const std::string &str() const { return str_; }
210 bool is_int8() const {
211 return operator[](SRC_LAYER).dt == dnnl_u8
212 || operator[](SRC_LAYER).dt == dnnl_s8;
213 }
214 bool is_s8() const { return operator[](SRC_LAYER).dt == dnnl_s8; }
215
216 static const dt_conf_t &create(const std::string &str, const attr_t &attr);
217
218 std::string str_;
219};
220
// Accumulates all command-line options for the RNN driver; each vector
// holds the set of values to sweep when generating problems.
struct settings_t : public base_settings_t {
    settings_t() = default;

    // ctor to save certain fields from resetting
    settings_t(const char *perf_template) : settings_t() {
        this->perf_template = perf_template;
    }

    desc_t desc {};

    std::vector<dir_t> prop {FWD_I};
    std::vector<std::string> cfg {"f32"};
    std::vector<alg_t> alg {VANILLA_RNN};
    std::vector<dnnl_rnn_direction_t> direction {
            dnnl_unidirectional_left2right};
    std::vector<activation_t> activation {RELU};
    std::vector<bool> skip_nonlinear {false};
    std::vector<bool> trivial_strides {false};
    std::vector<bool> with_peephole {false};
    std::vector<bool> with_projection {false};
    // 0 means "keep the value coming from the problem descriptor"
    // (see the prb_t constructor overrides).
    std::vector<int64_t> n_layer {0}, n_iter {0};
    std::vector<policy_t> scale_policy {policy_t::COMMON};
    std::vector<policy_t> scale_proj_policy {policy_t::COMMON};
    unsigned int flags = 0x0;
    float alpha = 0.9f, beta = 0.0f;

    // Column list used for the CSV flavor of the performance template.
    const char *perf_template_csv() const {
        static const std::string args
                = "%prop%,%cfg%,%alg%,%activation%,%direction%";
        return perf_template_csv_base(args);
    }

    // Restores defaults while preserving the performance template.
    void reset() { *this = settings_t(perf_template); }
};
255
256struct prb_t : public desc_t {
257 prb_t(const desc_t &desc, const dt_conf_t &cfg, dir_t prop, alg_t alg,
258 bool with_peephole, bool with_projection,
259 dnnl_rnn_direction_t direction, policy_t scale_policy,
260 policy_t scale_proj_policy, unsigned int flags,
261 activation_t activation, const attr_t &attr,
262 const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe, float alpha,
263 float beta, bool skip_nonlinear, bool trivial_strides,
264 int64_t n_layer, int64_t n_iter, int64_t mb = 0)
265 : desc_t(desc)
266 , cfg(cfg)
267 , prop(prop2prop_kind(prop))
268 , dir(prop)
269 , alg(alg)
270 , with_peephole(with_peephole)
271 , with_projection(with_projection)
272 , direction(direction)
273 , wei_scales_policy(scale_policy)
274 , wei_proj_scales_policy(scale_proj_policy)
275 , flags(flags)
276 , activation(activation)
277 , attr(attr)
278 , ctx_init(ctx_init)
279 , ctx_exe(ctx_exe)
280 , user_mb(mb)
281 , alpha(alpha)
282 , beta(beta)
283 , skip_nonlinear(skip_nonlinear)
284 , trivial_strides(trivial_strides)
285 , ops(0.0)
286 , linear_cscale(0.0f) {
287
288 if (n_layer) this->n_layer = n_layer;
289 if (n_iter) this->n_iter = n_iter;
290 if (mb) this->mb = mb;
291 count_ops();
292
293 wei_scales = nullptr;
294 wei_proj_scales = nullptr;
295 linear_scales = nullptr;
296
297 // We always allocate linear scales. Even if they are not
298 // used, they get dereferenced when built in debug mode.
299 linear_scales = (float *)zmalloc(sizeof(float) * n_gates(), 64);
300 // Here we use the range of SRC_LAYER to set the scales
301 set_tparams(cfg[SRC_LAYER].f_min, cfg[SRC_LAYER].f_max);
302
303 switch (wei_scales_policy) {
304 case policy_t::COMMON:
305 wei_scales_mask = 0x0;
306 wei_nscales = 1;
307 break;
308 case policy_t::PER_OC:
309 wei_scales_mask = 0x18;
310 wei_nscales = dhc * n_gates();
311 break;
312 default: assert(!"unsupported scaling policy");
313 }
314 wei_scales = (float *)zmalloc(sizeof(float) * wei_nscales, 64);
315
316 if (with_projection) {
317 switch (wei_proj_scales_policy) {
318 case policy_t::PER_OC:
319 wei_proj_scales_mask = 0x8;
320 wei_proj_nscales = dic;
321 break;
322 case policy_t::COMMON:
323 wei_proj_scales_mask = 0x0;
324 wei_proj_nscales = 1;
325 break;
326 default: assert(!"unsupported scaling policy");
327 }
328 wei_proj_scales
329 = (float *)zmalloc(sizeof(float) * wei_proj_nscales, 64);
330 }
331
332 set_qparams(-1., 1.);
333 }
334 ~prb_t() {
335 if (wei_scales) zfree(wei_scales);
336 if (wei_proj_scales) zfree(wei_proj_scales);
337 if (linear_scales) zfree(linear_scales);
338 }
339
340 float get_wei_scale(int idx) const {
341 return wei_scales[MIN2(idx, wei_nscales - 1)];
342 }
343
344 inline float get_wei_proj_scale(int idx) const {
345 return wei_proj_scales[MIN2(idx, wei_proj_nscales - 1)];
346 }
347
348 void count_ops() {
349 // Here, we count only the ops in GEMM portion as there is no
350 // theoretical number of ops for the post-gemm operations
351 int64_t num_cells = (int64_t)n_dir() * n_layer * n_iter;
352 int64_t cell_ops = (int64_t)2 * (n_gates() * dhc) * mb * (sic + slc);
353 if (with_projection) cell_ops += (int64_t)2 * dhc * mb * dic;
354 int64_t prop_multiplier = prop == dnnl_backward ? 2 : 1;
355 ops = prop_multiplier * num_cells * cell_ops;
356 }
357
358 int64_t n_dir() const {
359 return (direction == dnnl_bidirectional_concat
360 || direction == dnnl_bidirectional_sum)
361 ? 2
362 : 1;
363 }
364 int64_t n_states() const { return alg == VANILLA_LSTM ? 2 : 1; }
365 int64_t n_gates() const {
366 return alg == VANILLA_LSTM
367 ? 4
368 : (alg == VANILLA_GRU || alg == LBR_GRU || alg == VANILLA_AUGRU
369 || alg == LBR_AUGRU
370 ? 3
371 : 1);
372 }
373 int64_t n_bias() const {
374 return alg == LBR_GRU || alg == LBR_AUGRU ? n_gates() + 1 : n_gates();
375 }
376
377 int64_t dlc(dlc_type_t type) const {
378 if (type == PRIMITIVE)
379 return (direction == dnnl_bidirectional_concat ? 2 : 1) * dic;
380 if (type == CELL) return dic;
381 assert(!"unsupported dlc type");
382 return 0;
383 }
384
385 bool is_int8() const {
386 return cfg[SRC_LAYER].dt == dnnl_u8 || cfg[SRC_LAYER].dt == dnnl_s8;
387 }
388 bool is_u8() const { return cfg[SRC_LAYER].dt == dnnl_u8; }
389 bool is_s8() const { return cfg[SRC_LAYER].dt == dnnl_s8; }
390 bool is_lstm_peephole() const { return with_peephole; }
391 bool is_lstm_projection() const { return with_projection; }
392 bool is_augru() const { return alg == VANILLA_AUGRU || alg == LBR_AUGRU; }
393
394 const dt_conf_t &cfg;
395 dnnl_prop_kind_t prop;
396 dir_t dir; // Same as `prop`, for compatibility. TODO: remove me;
397 alg_t alg;
398 bool with_peephole, with_projection;
399 dnnl_rnn_direction_t direction;
400 policy_t wei_scales_policy;
401 policy_t wei_proj_scales_policy;
402 unsigned int flags;
403 activation_t activation;
404 attr_t attr;
405 thr_ctx_t ctx_init, ctx_exe;
406 int64_t user_mb;
407 float alpha;
408 float beta;
409
410 float data_scale, data_shift;
411
412 float *wei_scales;
413 int wei_nscales;
414 int wei_scales_mask;
415
416 float *wei_proj_scales;
417 int wei_proj_nscales;
418 int wei_proj_scales_mask;
419
420 bool skip_nonlinear;
421 bool trivial_strides;
422 double ops;
423 float *linear_scales;
424 float linear_cscale;
425
426private:
427 /* Todo: fused the two functions in set_shifts_scales */
428 void set_qparams(float fp_min, float fp_max);
429 void set_tparams(float fp_min, float fp_max);
430 prb_t(const prb_t &) = delete;
431 prb_t &operator=(const prb_t &) = delete;
432};
433std::ostream &operator<<(std::ostream &s, const prb_t &prb);
434
// RNN-specific performance report: resolves the %keyword% placeholders of
// the performance template against the given problem.
struct perf_report_t : public base_perf_report_t {
    perf_report_t(const prb_t *prb, const char *perf_template)
        : base_perf_report_t(perf_template), p_(prb) {}

    void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); }

    void dump_cfg(std::ostream &s) const override { s << p_->cfg.str(); }

    void dump_desc(std::ostream &s) const override {
        s << static_cast<const desc_t &>(*p_);
    }

    // CSV column order: n_layer,n_iter,mb,sic,slc,dhc,dic.
    void dump_desc_csv(std::ostream &s) const override {
        s << p_->n_layer << "," << p_->n_iter << "," << p_->mb << "," << p_->sic
          << "," << p_->slc << "," << p_->dhc << "," << p_->dic;
    }

    void dump_rnn_activation(std::ostream &s) const override {
        s << activation2str(p_->activation);
    }

    void dump_rnn_direction(std::ostream &s) const override {
        s << direction2str(p_->direction);
    }

    double ops() const override { return p_->ops; }
    const int64_t *user_mb() const override { return &p_->user_mb; }
    const attr_t *attr() const override { return &p_->attr; }
    const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
    const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
    const std::string *name() const override { return &p_->name; }
    const dnnl_prop_kind_t *prop() const override { return &p_->prop; }

private:
    // Non-owning; the problem must outlive the report.
    const prb_t *p_;
};
471
// Prepares the forward workspace buffer and the AOC offset calculators
// that point into it (presumably sizing ws_fwd_buffer and binding the
// views -- confirm exact behavior in the implementation).
void prepare_ws_fwd(const prb_t &prb, std::vector<float> &ws_fwd_buffer,
        AOC<float> &ws_src_layer, AOC<float> &ws_src_iter,
        AOC<float> &ws_src_iter_c, AOC<float> &ws_gates, AOC<float> &ws_ht);

// Reference forward pass operating on the workspace prepared above.
void rnn_linear_fwd(const prb_t &prb, const args_t &args,
        const AOC<float> &ws_src_layer, const AOC<float> &ws_src_iter,
        const AOC<float> &ws_src_iter_c, const AOC<float> &ws_gates,
        const AOC<float> &ws_ht);

// Problem-filtering hooks invoked by the driver before execution.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Reference computation entry points used for correctness validation.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);
void compute_ref_fwd(const prb_t &prb, const args_t &args);
void compute_ref_bwd(const prb_t &prb, const args_t &args);

// doit() runs a single problem end-to-end; bench() parses the command
// line and drives the whole batch.
int doit(const prb_t &prb, res_t *res);
int bench(int argc, char **argv);
490
491} // namespace rnn
492
493#endif
494