/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "oneapi/dnnl/dnnl.h"

#include "rnn/rnn_aux.hpp"

namespace rnn {

alg_t str2alg(const char *str) {
#define CASE(_alg) \
    if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
    CASE(VANILLA_RNN);
    CASE(VANILLA_LSTM);
    CASE(VANILLA_GRU);
    CASE(LBR_GRU);
    CASE(VANILLA_AUGRU);
    CASE(LBR_AUGRU);
#undef CASE
    assert(!"unknown algorithm");
    return VANILLA_RNN;
}

const char *alg2str(alg_t alg) {
    if (alg == VANILLA_RNN) return "VANILLA_RNN";
    if (alg == VANILLA_LSTM) return "VANILLA_LSTM";
    if (alg == VANILLA_GRU) return "VANILLA_GRU";
    if (alg == LBR_GRU) return "LBR_GRU";
    if (alg == VANILLA_AUGRU) return "VANILLA_AUGRU";
    if (alg == LBR_AUGRU) return "LBR_AUGRU";
    assert(!"unknown algorithm");
    return "unknown algorithm";
}

dnnl_alg_kind_t alg2kind(alg_t alg) {
    if (alg == VANILLA_RNN) return dnnl_vanilla_rnn;
    if (alg == VANILLA_LSTM) return dnnl_vanilla_lstm;
    if (alg == VANILLA_GRU) return dnnl_vanilla_gru;
    if (alg == LBR_GRU) return dnnl_lbr_gru;
    if (alg == VANILLA_AUGRU) return dnnl_vanilla_augru;
    if (alg == LBR_AUGRU) return dnnl_lbr_augru;
    assert(!"unknown algorithm");
    return dnnl_alg_kind_undef;
}

activation_t str2activation(const char *str) {
#define CASE(_act) \
    if (!strcasecmp(STRINGIFY(_act), str)) return _act
    CASE(RELU);
    CASE(LOGISTIC);
    CASE(TANH);
    CASE(UNDEF);
#undef CASE
    assert(!"unknown activation");
    return UNDEF;
}

const char *activation2str(activation_t act) {
    const char *str = "unknown activation";
    switch (act) {
        case RELU: str = "RELU"; break;
        case LOGISTIC: str = "LOGISTIC"; break;
        case TANH: str = "TANH"; break;
        case UNDEF: str = "UNDEF"; break;
        default: assert(!"unknown activation");
    }
    return str;
}

dnnl_alg_kind_t activation2kind(activation_t act) {
    dnnl_alg_kind_t alg_kind = dnnl_alg_kind_undef;
    switch (act) {
        case RELU: alg_kind = dnnl_eltwise_relu; break;
        case LOGISTIC: alg_kind = dnnl_eltwise_logistic; break;
        case TANH: alg_kind = dnnl_eltwise_tanh; break;
        case UNDEF: alg_kind = dnnl_alg_kind_undef; break;
        default: assert(!"unknown activation");
    }
    return alg_kind;
}

dnnl_rnn_direction_t str2direction(const char *str) {
    if (!strcasecmp("left2right", str)) return dnnl_unidirectional_left2right;
    if (!strcasecmp("right2left", str)) return dnnl_unidirectional_right2left;
    if (!strcasecmp("concat", str)) return dnnl_bidirectional_concat;
    if (!strcasecmp("sum", str)) return dnnl_bidirectional_sum;
    assert(!"unknown direction");
    return dnnl_unidirectional_left2right;
}

const char *direction2str(dnnl_rnn_direction_t direction) {
    if (direction == dnnl_unidirectional_left2right) return "left2right";
    if (direction == dnnl_unidirectional_right2left) return "right2left";
    if (direction == dnnl_bidirectional_concat) return "concat";
    if (direction == dnnl_bidirectional_sum) return "sum";
    assert(!"unknown direction");
    return "unknown direction";
}

const char *rnn_data_kind2str(rnn_data_kind_t kind) {
#define CASE(KIND) \
    if (kind == (KIND)) return STRINGIFY(KIND)
    CASE(SRC_LAYER);
    CASE(AUGRU_ATTENTION);
    CASE(SRC_ITER);
    CASE(SRC_ITER_C);
    CASE(WEIGHTS_LAYER);
    CASE(WEIGHTS_ITER);
    CASE(WEIGHTS_PEEPHOLE);
    CASE(WEIGHTS_PROJECTION);
    CASE(BIAS);
    CASE(DST_LAYER);
    CASE(DST_ITER);
    CASE(DST_ITER_C);

    CASE(DIFF_SRC_LAYER);
    CASE(DIFF_AUGRU_ATTENTION);
    CASE(DIFF_SRC_ITER);
    CASE(DIFF_SRC_ITER_C);
    CASE(DIFF_WEIGHTS_LAYER);
    CASE(DIFF_WEIGHTS_ITER);
    CASE(DIFF_WEIGHTS_PEEPHOLE);
    CASE(DIFF_WEIGHTS_PROJECTION);
    CASE(DIFF_BIAS);
    CASE(DIFF_DST_LAYER);
    CASE(DIFF_DST_ITER);
    CASE(DIFF_DST_ITER_C);
#undef CASE

    assert(!"incorrect rnn data kind");
    return "incorrect rnn data kind";
}

int str2desc(desc_t *desc, const char *str) {
    desc_t d {0};

    /* canonical form:
     * lXtXmbXsicXslcXdhcXdicXnS
     *
     * where: X is a number, S is a string
     * note: the symbol `_` is ignored
     *
     * implicit rules:
     *  - default values:
     *    l = 1, t = 1, mb = 2
     *  - if slc/dhc is undefined => slc/dhc = sic
     *  - if dic is undefined => dic = dhc
     */
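    // A hypothetical example: the descriptor string "l2t5mb16sic64" parses to
    // n_layer = 2, n_iter = 5, mb = 16, sic = 64, with slc and dhc defaulting
    // to sic and dic defaulting to dhc per the rules above.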

    d.n_layer = 1;
    d.n_iter = 1;
    d.mb = 2;

    const char *s = str;
    assert(s);

#define CASE_NN(prb, c) \
    do { \
        if (!strncmp(prb, s, strlen(prb))) { \
            ok = 1; \
            s += strlen(prb); \
            char *end_s; \
            d.c = strtol(s, &end_s, 10); \
            s += (end_s - s); \
            if (d.c < 0) return FAIL; \
        } \
    } while (0)
#define CASE_N(c) CASE_NN(#c, c)
    while (*s) {
        int ok = 0;
        CASE_NN("l", n_layer);
        CASE_NN("t", n_iter);
        CASE_N(mb);
        CASE_N(sic);
        CASE_N(slc);
        CASE_N(dhc);
        CASE_N(dic);
        if (!strncmp("dlc", s, 3)) {
            BENCHDNN_PRINT(0, "%s\n",
                    "WARNING: the RNN descriptor symbol `dlc` is no longer "
                    "supported. Please adjust the RNN descriptor and try "
                    "again. Note: usually it is enough to simply remove `dlc` "
                    "from the descriptor string.");
            return FAIL;
        }
        if (*s == 'n') {
            d.name = s + 1;
            break;
        }
        if (*s == '_') ++s;
        if (!ok) return FAIL;
    }
#undef CASE_NN
#undef CASE_N

    if (d.sic == 0) return FAIL;
    if (d.slc == 0) d.slc = d.sic;
    if (d.dhc == 0) d.dhc = d.sic;
    if (d.dic == 0) d.dic = d.dhc;
    d.wc = MAX2(MAX2(d.sic, d.slc), MAX2(d.dic, d.dhc));

    *desc = d;

    return OK;
}

std::ostream &operator<<(std::ostream &s, const desc_t &d) {
    s << "l" << d.n_layer << "t" << d.n_iter << "mb" << d.mb << "sic" << d.sic
      << "slc" << d.slc << "dhc" << d.dhc << "dic" << d.dic;

    if (!d.name.empty()) s << "n" << d.name;

    return s;
}

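// Prints the problem as a reproducer command line: each option is emitted only
// when it differs from its default, unless canonical output is requested.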
std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
    dump_global_params(s);
    settings_t def;

    if (canonical || prb.prop != prop2prop_kind(def.prop[0]))
        s << "--prop=" << prop2str(prb.prop) << " ";
    if (canonical || prb.cfg.str() != def.cfg[0])
        s << "--cfg=" << prb.cfg.str() << " ";
    if (canonical || prb.alg != def.alg[0])
        s << "--alg=" << alg2str(prb.alg) << " ";
    if (canonical || prb.direction != def.direction[0])
        s << "--direction=" << direction2str(prb.direction) << " ";
    if (canonical || prb.activation != def.activation[0])
        s << "--activation=" << activation2str(prb.activation) << " ";
    if (canonical || prb.skip_nonlinear != def.skip_nonlinear[0])
        s << "--skip-nonlinear=" << bool2str(prb.skip_nonlinear) << " ";
    if (canonical || prb.with_peephole != def.with_peephole[0])
        s << "--with-peephole=" << bool2str(prb.with_peephole) << " ";
    if (canonical || prb.with_projection != def.with_projection[0])
        s << "--with-projection=" << bool2str(prb.with_projection) << " ";
    if (canonical || prb.wei_scales_policy != def.scale_policy[0])
        s << "--scaling=" << prb.wei_scales_policy << " ";
    if (canonical || prb.wei_proj_scales_policy != def.scale_proj_policy[0])
        s << "--scaling-proj=" << prb.wei_proj_scales_policy << " ";
    if (canonical || prb.trivial_strides != def.trivial_strides[0])
        s << "--trivial-strides=" << bool2str(prb.trivial_strides) << " ";

    s << prb.attr;
    if (canonical || prb.ctx_init != def.ctx_init[0])
        s << "--ctx-init=" << prb.ctx_init << " ";
    if (canonical || prb.ctx_exe != def.ctx_exe[0])
        s << "--ctx-exe=" << prb.ctx_exe << " ";

    s << static_cast<const desc_t &>(prb);

    return s;
}

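// Creates a forward RNN primitive descriptor for the algorithm requested in
// `prb`, dispatching to the matching oneDNN C API constructor.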
dnnl_status_t init_rnn_fwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
        const prb_t &prb, dnnl_prop_kind_t prop_kind,
        const_dnnl_memory_desc_t src_layer_d,
        const_dnnl_memory_desc_t src_iter_d,
        const_dnnl_memory_desc_t src_iter_c_d,
        const_dnnl_memory_desc_t attention_d,
        const_dnnl_memory_desc_t weights_layer_d,
        const_dnnl_memory_desc_t weights_iter_d,
        const_dnnl_memory_desc_t weights_peephole_d,
        const_dnnl_memory_desc_t weights_projection_d,
        const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
        const_dnnl_memory_desc_t dst_iter_d,
        const_dnnl_memory_desc_t dst_iter_c_d, dnnl_primitive_attr_t attr) {
    dnnl_alg_kind_t kind = alg2kind(prb.alg);
    dnnl_alg_kind_t f = activation2kind(prb.activation);

    dnnl_status_t status;
    switch (kind) {
        case dnnl_vanilla_rnn:
            status = dnnl_vanilla_rnn_forward_primitive_desc_create(pd, engine,
                    prop_kind, f, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, prb.flags, prb.alpha, prb.beta, attr);
            break;
        case dnnl_vanilla_lstm:
            status = dnnl_lstm_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    src_iter_c_d, weights_layer_d, weights_iter_d,
                    weights_peephole_d, weights_projection_d, bias_d,
                    dst_layer_d, dst_iter_d, dst_iter_c_d, prb.flags, attr);
            break;
        case dnnl_vanilla_gru:
            status = dnnl_gru_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, prb.flags, attr);
            break;
        case dnnl_lbr_gru:
            status = dnnl_lbr_gru_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, prb.flags, attr);
            break;
        case dnnl_vanilla_augru:
            status = dnnl_augru_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    attention_d, weights_layer_d, weights_iter_d, bias_d,
                    dst_layer_d, dst_iter_d, prb.flags, attr);
            break;
        case dnnl_lbr_augru:
            status = dnnl_lbr_augru_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    attention_d, weights_layer_d, weights_iter_d, bias_d,
                    dst_layer_d, dst_iter_d, prb.flags, attr);
            break;
        default: status = dnnl_unimplemented;
    }
    return status;
}

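// Creates a backward RNN primitive descriptor analogously to the forward case,
// additionally passing the diff memory descriptors and the forward `hint`.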
dnnl_status_t init_rnn_bwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
        const prb_t &prb, dnnl_prop_kind_t prop_kind,
        const_dnnl_memory_desc_t src_layer_d,
        const_dnnl_memory_desc_t src_iter_d,
        const_dnnl_memory_desc_t src_iter_c_d,
        const_dnnl_memory_desc_t attention_d,
        const_dnnl_memory_desc_t weights_layer_d,
        const_dnnl_memory_desc_t weights_iter_d,
        const_dnnl_memory_desc_t weights_peephole_d,
        const_dnnl_memory_desc_t weights_projection_d,
        const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
        const_dnnl_memory_desc_t dst_iter_d,
        const_dnnl_memory_desc_t dst_iter_c_d,
        const_dnnl_memory_desc_t diff_src_layer_d,
        const_dnnl_memory_desc_t diff_src_iter_d,
        const_dnnl_memory_desc_t diff_src_iter_c_d,
        const_dnnl_memory_desc_t diff_attention_d,
        const_dnnl_memory_desc_t diff_weights_layer_d,
        const_dnnl_memory_desc_t diff_weights_iter_d,
        const_dnnl_memory_desc_t diff_weights_peephole_d,
        const_dnnl_memory_desc_t diff_weights_projection_d,
        const_dnnl_memory_desc_t diff_bias_d,
        const_dnnl_memory_desc_t diff_dst_layer_d,
        const_dnnl_memory_desc_t diff_dst_iter_d,
        const_dnnl_memory_desc_t diff_dst_iter_c_d,
        const_dnnl_primitive_desc_t hint, dnnl_primitive_attr_t attr) {
    dnnl_alg_kind_t kind = alg2kind(prb.alg);
    dnnl_alg_kind_t f = activation2kind(prb.activation);

    dnnl_status_t status;
    switch (kind) {
        case dnnl_vanilla_rnn:
            status = dnnl_vanilla_rnn_backward_primitive_desc_create(pd, engine,
                    prop_kind, f, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_weights_layer_d, diff_weights_iter_d, diff_bias_d,
                    diff_dst_layer_d, diff_dst_iter_d, prb.flags, prb.alpha,
                    prb.beta, hint, attr);
            break;
        case dnnl_vanilla_lstm:
            status = dnnl_lstm_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    src_iter_c_d, weights_layer_d, weights_iter_d,
                    weights_peephole_d, weights_projection_d, bias_d,
                    dst_layer_d, dst_iter_d, dst_iter_c_d, diff_src_layer_d,
                    diff_src_iter_d, diff_src_iter_c_d, diff_weights_layer_d,
                    diff_weights_iter_d, diff_weights_peephole_d,
                    diff_weights_projection_d, diff_bias_d, diff_dst_layer_d,
                    diff_dst_iter_d, diff_dst_iter_c_d, prb.flags, hint, attr);
            break;
        case dnnl_vanilla_gru:
            status = dnnl_gru_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_weights_layer_d, diff_weights_iter_d, diff_bias_d,
                    diff_dst_layer_d, diff_dst_iter_d, prb.flags, hint, attr);
            break;
        case dnnl_lbr_gru:
            status = dnnl_lbr_gru_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_weights_layer_d, diff_weights_iter_d, diff_bias_d,
                    diff_dst_layer_d, diff_dst_iter_d, prb.flags, hint, attr);
            break;
        case dnnl_vanilla_augru:
            status = dnnl_augru_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    attention_d, weights_layer_d, weights_iter_d, bias_d,
                    dst_layer_d, dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_attention_d, diff_weights_layer_d, diff_weights_iter_d,
                    diff_bias_d, diff_dst_layer_d, diff_dst_iter_d, prb.flags,
                    hint, attr);
            break;
        case dnnl_lbr_augru:
            status = dnnl_lbr_augru_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    attention_d, weights_layer_d, weights_iter_d, bias_d,
                    dst_layer_d, dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_attention_d, diff_weights_layer_d, diff_weights_iter_d,
                    diff_bias_d, diff_dst_layer_d, diff_dst_iter_d, prb.flags,
                    hint, attr);
            break;
        default: status = dnnl_unimplemented;
    }
    return status;
}


void init_buffer(float *buf, int64_t size, float value) {
    for (int64_t i = 0; i < size; i++)
        buf[i] = value;
}

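// int8 helpers: data is quantized as q = data_scale * x + data_shift, while
// weights use (possibly per-output-channel) scales without a shift. The
// maybe_* helpers below apply or undo this mapping only for int8 configs.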
// If needed, dequantize u8 data to f32 via data scale and shift
float maybe_deq(const prb_t &prb, const float in) {
    if (!prb.cfg.is_int8()) return in;
    return (in - prb.data_shift) * (1.0f / prb.data_scale);
}

// If needed, dequantize s32 accumulators to f32 via data and weights scales
// (no data shift is needed as it is handled by the compensation in the bias)
float maybe_deq(const prb_t &prb, const float in, int64_t oc) {
    if (!prb.cfg.is_int8()) return in;
    float scale = prb.get_wei_scale(oc);
    return in * (1.0f / (scale * prb.data_scale));
}

// If needed, dequantize s32 accumulators to f32 via data, weights scales
// and compensation.
float maybe_deq(
        const prb_t &prb, const float in, float scale, float compensation) {
    if (!prb.cfg.is_int8()) return in;
    return (in - compensation * prb.data_shift)
            * (1.0f / (scale * prb.data_scale));
}

float maybe_deq_proj(
        const prb_t &prb, const float in, float compensation, int64_t oc) {
    return maybe_deq(prb, in, prb.get_wei_proj_scale(oc), compensation);
}

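// If needed, quantize f32 data back to the SRC_LAYER int8 range: scale, shift,
// saturate to [min, max], and round (the inverse of the first maybe_deq above).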
float maybe_q(const prb_t &prb, float h) {
    if (!prb.cfg.is_int8()) return h;
    float fp = prb.data_scale * h + prb.data_shift;
    if (fp > prb.cfg[SRC_LAYER].max) fp = prb.cfg[SRC_LAYER].max;
    if (fp < prb.cfg[SRC_LAYER].min) fp = prb.cfg[SRC_LAYER].min;
    fp = mxcsr_cvt(fp);
    return fp;
}

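// Numerically stable sigmoid: exp() is only ever taken of a non-positive
// argument, so it cannot overflow.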
float logistic(float x) {
    if (x < 0)
        return (expf(x) / (1 + expf(x)));
    else
        return 1.0f - (expf(-x) / (1 + expf(-x)));
}
float dlogistic(float x) {
    float tmp = logistic(x);
    return tmp * (1 - tmp);
}
float dtanhf(float x) {
    return (1 - tanhf(x)) * (1 + tanhf(x));
}
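// Derivatives expressed in terms of the activation output y rather than the
// input: x_m_square(y) = y * (1 - y) for logistic, one_m_square(y) = 1 - y^2
// for tanh.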
float x_m_square(float x) {
    return x - x * x;
}
float relu(float x, float alpha) {
    return x > 0 ? x : alpha * x;
}
float drelu(float x, float alpha) {
    return x > 0 ? 1.0f : alpha;
}
float one_m_square(float x) {
    return 1 - x * x;
}

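// Maps benchdnn's generic data kinds onto RNN-specific kinds: destinations map
// to forward DST_* kinds, while sources, weights, and bias map to the DIFF_*
// kinds that are compared for the backward passes.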
rnn_data_kind_t data_kind2rnn_data_kind(data_kind_t data_kind) {
    switch (data_kind) {
        case data_kind_t::DST: return rnn_data_kind_t::DST_LAYER;
        case data_kind_t::DST_ITER: return rnn_data_kind_t::DST_ITER;
        case data_kind_t::DST_ITER_C: return rnn_data_kind_t::DST_ITER_C;
        case data_kind_t::SRC: return rnn_data_kind_t::DIFF_SRC_LAYER;
        case data_kind_t::AUGRU_ATTENTION:
            return rnn_data_kind_t::DIFF_AUGRU_ATTENTION;
        case data_kind_t::SRC_ITER: return rnn_data_kind_t::DIFF_SRC_ITER;
        case data_kind_t::SRC_ITER_C: return rnn_data_kind_t::DIFF_SRC_ITER_C;
        case data_kind_t::WEI: return rnn_data_kind_t::DIFF_WEIGHTS_LAYER;
        case data_kind_t::WEI_ITER: return rnn_data_kind_t::DIFF_WEIGHTS_ITER;
        case data_kind_t::WEI_PEEPHOLE:
            return rnn_data_kind_t::DIFF_WEIGHTS_PEEPHOLE;
        case data_kind_t::WEI_PROJECTION:
            return rnn_data_kind_t::DIFF_WEIGHTS_PROJECTION;
        case data_kind_t::BIA: return rnn_data_kind_t::DIFF_BIAS;
        default: assert(!"unknown data kind");
    }
    return KIND_TOTAL;
}

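// Picks quantization parameters so that fp32 data in [fp_min, fp_max] maps
// onto the integer ranges given by the config; weight scales are spread over
// [K, 2K) to exercise per-output-channel scaling.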
void prb_t::set_qparams(float fp_min, float fp_max) {
    if (!cfg.is_int8()) {
        data_shift = 0.;
        data_scale = 1.;
        wei_scales[0] = 1.;
        return;
    }

    /* Set parameters for quantization of src and weights from fp32 data
     * in [-1, 1] to int8 data in a range specified in cfg */
    float fp_range = fp_max - fp_min;
    float int8_src_range = cfg[SRC_LAYER].f_max - cfg[SRC_LAYER].f_min,
          int8_wei_range = cfg[WEIGHTS_LAYER].f_max - cfg[WEIGHTS_LAYER].f_min;

    // No shift needed for s8s8 AMX LSTM
    data_shift = cfg.is_s8() ? 0 : cfg[SRC_LAYER].f_mean;
    data_scale = int8_src_range / fp_range;

    float K = int8_wei_range / fp_range;
    auto set_wei_scales = [&](float *scales, int nelems) {
        for (int64_t i = 0; i < nelems; i++)
            scales[i] = K * (1. + (float)i / nelems);
    };

    set_wei_scales(wei_scales, wei_nscales);
    if (with_projection) set_wei_scales(wei_proj_scales, wei_proj_nscales);
}

void prb_t::set_tparams(float fp_min, float fp_max) {
    if (skip_nonlinear) {
        assert(linear_scales != nullptr);
        // Here, we assume that the inputs of the cells are in [fp_min,fp_max].
        // We pick the scaling factors to ensure that the output of the linear
        // pre/post gemm is in [fp_min,fp_max].

        // Also, we rely on the fact that, for forward, the weights matrices
        // are sparse and contain coefficients equal to 1/n_gates() to
        // compensate for the gemm accumulation. So here, we account only for
        // the post-gemm accumulation and for the fact that we want to use
        // different scales per gate.

        // For BWD_W, we cannot assume sparseness though, since the gates and
        // diff_dst_* are dense.
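        // A hypothetical worked example: a forward f32 LSTM has n_gates() = 4,
        // so acc_dim = 4, n_cscale = 1, divisor = next_pow2(5) = 8, giving
        // linear_scales = {1, 2, 3, 4} / 8 and linear_cscale = 5 / 8.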
        int64_t fwd_acc_dim = n_gates();
        int64_t bwdd_acc_dim = dhc;
        int64_t bwdw_acc_dim = mb;
        int64_t acc_dim = 0;
        if (prop == dnnl_backward)
            acc_dim = n_gates()
                    * MAX2(fwd_acc_dim, MAX2(bwdd_acc_dim, bwdw_acc_dim));
        else
            acc_dim = fwd_acc_dim;
        // make scaling exact by choosing powers of two.
        int64_t n_cscale = (alg == VANILLA_LSTM);
        int64_t divisor = next_pow2((acc_dim + n_cscale) * (is_int8() ? 2 : 1));
        float factor = (1.0f / (float)(divisor));
        for (int64_t i = 0; i < n_gates(); i++)
            linear_scales[i] = (i + 1) * factor;
        if (n_cscale) linear_cscale = (n_gates() + 1) * factor;
    }
}

} // namespace rnn