1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <algorithm>

#include "oneapi/dnnl/dnnl.h"

#include "rnn/rnn_aux.hpp"
20 | |
21 | namespace rnn { |
22 | |
23 | alg_t str2alg(const char *str) { |
24 | #define CASE(_alg) \ |
25 | if (!strcasecmp(STRINGIFY(_alg), str)) return _alg |
26 | CASE(VANILLA_RNN); |
27 | CASE(VANILLA_LSTM); |
28 | CASE(VANILLA_GRU); |
29 | CASE(LBR_GRU); |
30 | CASE(VANILLA_AUGRU); |
31 | CASE(LBR_AUGRU); |
32 | #undef CASE |
33 | assert(!"unknown algorithm" ); |
34 | return VANILLA_RNN; |
35 | } |
36 | |
37 | const char *alg2str(alg_t alg) { |
38 | if (alg == VANILLA_RNN) return "VANILLA_RNN" ; |
39 | if (alg == VANILLA_LSTM) return "VANILLA_LSTM" ; |
40 | if (alg == VANILLA_GRU) return "VANILLA_GRU" ; |
41 | if (alg == LBR_GRU) return "LBR_GRU" ; |
42 | if (alg == VANILLA_AUGRU) return "VANILLA_AUGRU" ; |
43 | if (alg == LBR_AUGRU) return "LBR_AUGRU" ; |
44 | assert(!"unknown algorithm" ); |
45 | return "unknown algorithm" ; |
46 | } |
47 | |
48 | dnnl_alg_kind_t alg2kind(alg_t alg) { |
49 | if (alg == VANILLA_RNN) return dnnl_vanilla_rnn; |
50 | if (alg == VANILLA_LSTM) return dnnl_vanilla_lstm; |
51 | if (alg == VANILLA_GRU) return dnnl_vanilla_gru; |
52 | if (alg == LBR_GRU) return dnnl_lbr_gru; |
53 | if (alg == VANILLA_AUGRU) return dnnl_vanilla_augru; |
54 | if (alg == LBR_AUGRU) return dnnl_lbr_augru; |
55 | assert(!"unknown algorithm" ); |
56 | return dnnl_alg_kind_undef; |
57 | } |
58 | |
59 | activation_t str2activation(const char *str) { |
60 | #define CASE(_act) \ |
61 | if (!strcasecmp(STRINGIFY(_act), str)) return _act |
62 | CASE(RELU); |
63 | CASE(LOGISTIC); |
64 | CASE(TANH); |
65 | CASE(UNDEF); |
66 | #undef CASE |
67 | assert(!"unknown activation" ); |
68 | return UNDEF; |
69 | } |
70 | |
71 | const char *activation2str(activation_t act) { |
72 | const char *str = "unknown activation" ; |
73 | switch (act) { |
74 | case RELU: str = "RELU" ; break; |
75 | case LOGISTIC: str = "LOGISTIC" ; break; |
76 | case TANH: str = "TANH" ; break; |
77 | case UNDEF: str = "UNDEF" ; break; |
78 | default: assert(!"unknown activation" ); |
79 | } |
80 | return str; |
81 | } |
82 | |
83 | dnnl_alg_kind_t activation2kind(activation_t act) { |
84 | dnnl_alg_kind_t alg_kind = dnnl_alg_kind_undef; |
85 | switch (act) { |
86 | case RELU: alg_kind = dnnl_eltwise_relu; break; |
87 | case LOGISTIC: alg_kind = dnnl_eltwise_logistic; break; |
88 | case TANH: alg_kind = dnnl_eltwise_tanh; break; |
89 | case UNDEF: alg_kind = dnnl_alg_kind_undef; break; |
90 | default: assert(!"unknown activation" ); |
91 | } |
92 | return alg_kind; |
93 | } |
94 | |
95 | dnnl_rnn_direction_t str2direction(const char *str) { |
96 | if (!strcasecmp("left2right" , str)) return dnnl_unidirectional_left2right; |
97 | if (!strcasecmp("right2left" , str)) return dnnl_unidirectional_right2left; |
98 | if (!strcasecmp("concat" , str)) return dnnl_bidirectional_concat; |
99 | if (!strcasecmp("sum" , str)) return dnnl_bidirectional_sum; |
100 | assert(!"unknown direction" ); |
101 | return dnnl_unidirectional_left2right; |
102 | } |
103 | |
104 | const char *direction2str(dnnl_rnn_direction_t direction) { |
105 | if (direction == dnnl_unidirectional_left2right) return "left2right" ; |
106 | if (direction == dnnl_unidirectional_right2left) return "right2left" ; |
107 | if (direction == dnnl_bidirectional_concat) return "concat" ; |
108 | if (direction == dnnl_bidirectional_sum) return "sum" ; |
109 | assert(!"unknown direction" ); |
110 | return "unknown direction" ; |
111 | } |
112 | |
113 | const char *rnn_data_kind2str(rnn_data_kind_t kind) { |
114 | #define CASE(KIND) \ |
115 | if (kind == (KIND)) return STRINGIFY(KIND) |
116 | CASE(SRC_LAYER); |
117 | CASE(AUGRU_ATTENTION); |
118 | CASE(SRC_ITER); |
119 | CASE(SRC_ITER_C); |
120 | CASE(WEIGHTS_LAYER); |
121 | CASE(WEIGHTS_ITER); |
122 | CASE(WEIGHTS_PEEPHOLE); |
123 | CASE(WEIGHTS_PROJECTION); |
124 | CASE(BIAS); |
125 | CASE(DST_LAYER); |
126 | CASE(DST_ITER); |
127 | CASE(DST_ITER_C); |
128 | |
129 | CASE(DIFF_SRC_LAYER); |
130 | CASE(DIFF_AUGRU_ATTENTION); |
131 | CASE(DIFF_SRC_ITER); |
132 | CASE(DIFF_SRC_ITER_C); |
133 | CASE(DIFF_WEIGHTS_LAYER); |
134 | CASE(DIFF_WEIGHTS_ITER); |
135 | CASE(DIFF_WEIGHTS_PEEPHOLE); |
136 | CASE(DIFF_WEIGHTS_PROJECTION); |
137 | CASE(DIFF_BIAS); |
138 | CASE(DIFF_DST_LAYER); |
139 | CASE(DIFF_DST_ITER); |
140 | CASE(DIFF_DST_ITER_C); |
141 | #undef CASE |
142 | |
143 | assert(!"incorrect rnn data kind" ); |
144 | return "incorrect rnn data kind" ; |
145 | } |
146 | |
// Parses a benchdnn RNN problem descriptor string into `*desc`.
// Returns OK on success; FAIL on an unrecognized token, a negative value,
// a removed `dlc` token, or a missing mandatory `sic`.
int str2desc(desc_t *desc, const char *str) {
    desc_t d {0};

    /* canonical form:
     * lXtXmbXsicXslcXdhcXdicX
     *
     * where: X is number, S - string
     * note: symbol `_` is ignored
     *
     * implicit rules:
     *  - default values:
     *      l = 1, t = 1, mb = 2
     *  - if slc/dhc is undefined => slc/dhc = sic
     *  - if dic is undefined => dic = dhc
     */

    // Defaults per the rules above; the channel dims stay 0 until parsed.
    d.n_layer = 1;
    d.n_iter = 1;
    d.mb = 2;

    const char *s = str;
    assert(s);

// If the input at `s` begins with the literal token `prb`, consume it, parse
// the decimal number that follows into field `d.c`, advance `s` past the
// digits, and mark progress via `ok`. Negative values abort the parse.
#define CASE_NN(prb, c) \
    do { \
        if (!strncmp(prb, s, strlen(prb))) { \
            ok = 1; \
            s += strlen(prb); \
            char *end_s; \
            d.c = strtol(s, &end_s, 10); \
            s += (end_s - s); \
            if (d.c < 0) return FAIL; \
        } \
    } while (0)
#define CASE_N(c) CASE_NN(#c, c)
    while (*s) {
        int ok = 0;
        // Note: tokens are tried in this fixed order; each iteration may
        // consume several tokens in sequence.
        CASE_NN("l", n_layer);
        CASE_NN("t", n_iter);
        CASE_N(mb);
        CASE_N(sic);
        CASE_N(slc);
        CASE_N(dhc);
        CASE_N(dic);
        // `dlc` was removed from the descriptor grammar; reject it with a
        // dedicated hint instead of a generic parse failure.
        if (!strncmp("dlc", s, 3)) {
            BENCHDNN_PRINT(0, "%s\n",
                    "WARNING: the RNN descriptor symbol `dlc` is no longer "
                    "supported. Please adjust the RNN descriptor and try "
                    "again. Note: usually it is enough to simply remove `dlc` "
                    "from the descriptor string.");
            return FAIL;
        }
        // `n` starts the problem name: everything after it is taken verbatim.
        if (*s == 'n') {
            d.name = s + 1;
            break;
        }
        if (*s == '_') ++s; // underscores are cosmetic separators
        if (!ok) return FAIL; // nothing matched at this position
    }
#undef CASE_NN
#undef CASE_N

    // `sic` is the only mandatory dimension; the others default from it
    // (see the implicit rules above).
    if (d.sic == 0) return FAIL;
    if (d.slc == 0) d.slc = d.sic;
    if (d.dhc == 0) d.dhc = d.sic;
    if (d.dic == 0) d.dic = d.dhc;
    // Common working-channel width: large enough for any channel dimension.
    d.wc = MAX2(MAX2(d.sic, d.slc), MAX2(d.dic, d.dhc));

    *desc = d;

    return OK;
}
219 | |
220 | std::ostream &operator<<(std::ostream &s, const desc_t &d) { |
221 | s << "l" << d.n_layer << "t" << d.n_iter << "mb" << d.mb << "sic" << d.sic |
222 | << "slc" << d.slc << "dhc" << d.dhc << "dic" << d.dic; |
223 | |
224 | if (!d.name.empty()) s << "n" << d.name; |
225 | |
226 | return s; |
227 | } |
228 | |
// Serializes a problem back into a benchdnn command line fragment.
// Each option is printed only when it differs from the default settings_t
// value, unless the global `canonical` flag forces printing everything.
// The descriptor itself is appended last via the desc_t printer.
std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
    dump_global_params(s);
    settings_t def; // freshly-constructed defaults to compare against

    if (canonical || prb.prop != prop2prop_kind(def.prop[0]))
        s << "--prop=" << prop2str(prb.prop) << " ";
    if (canonical || prb.cfg.str() != def.cfg[0])
        s << "--cfg=" << prb.cfg.str() << " ";
    if (canonical || prb.alg != def.alg[0])
        s << "--alg=" << alg2str(prb.alg) << " ";
    if (canonical || prb.direction != def.direction[0])
        s << "--direction=" << direction2str(prb.direction) << " ";
    if (canonical || prb.activation != def.activation[0])
        s << "--activation=" << activation2str(prb.activation) << " ";
    if (canonical || prb.skip_nonlinear != def.skip_nonlinear[0])
        s << "--skip-nonlinear=" << bool2str(prb.skip_nonlinear) << " ";
    if (canonical || prb.with_peephole != def.with_peephole[0])
        s << "--with-peephole=" << bool2str(prb.with_peephole) << " ";
    if (canonical || prb.with_projection != def.with_projection[0])
        s << "--with-projection=" << bool2str(prb.with_projection) << " ";
    if (canonical || prb.wei_scales_policy != def.scale_policy[0])
        s << "--scaling=" << prb.wei_scales_policy << " ";
    if (canonical || prb.wei_proj_scales_policy != def.scale_proj_policy[0])
        s << "--scaling-proj=" << prb.wei_proj_scales_policy << " ";
    if (canonical || prb.trivial_strides != def.trivial_strides[0])
        s << "--trivial-strides=" << bool2str(prb.trivial_strides) << " ";

    // Attributes print themselves (including their own trailing separator).
    s << prb.attr;
    if (canonical || prb.ctx_init != def.ctx_init[0])
        s << "--ctx-init=" << prb.ctx_init << " ";
    if (canonical || prb.ctx_exe != def.ctx_exe[0])
        s << "--ctx-exe=" << prb.ctx_exe << " ";

    // Finish with the problem descriptor (prb_t derives from desc_t).
    s << static_cast<const desc_t &>(prb);

    return s;
}
266 | |
// Creates a forward RNN primitive descriptor in `*pd`, dispatching on
// prb.alg to the matching oneDNN C-API creation function. Memory descriptors
// an algorithm does not use (e.g. attention for non-AUGRU, peephole/
// projection for non-LSTM) are simply not forwarded to the call.
// Returns dnnl_unimplemented for an unrecognized algorithm kind.
dnnl_status_t init_rnn_fwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
        const prb_t &prb, dnnl_prop_kind_t prop_kind,
        const_dnnl_memory_desc_t src_layer_d,
        const_dnnl_memory_desc_t src_iter_d,
        const_dnnl_memory_desc_t src_iter_c_d,
        const_dnnl_memory_desc_t attention_d,
        const_dnnl_memory_desc_t weights_layer_d,
        const_dnnl_memory_desc_t weights_iter_d,
        const_dnnl_memory_desc_t weights_peephole_d,
        const_dnnl_memory_desc_t weights_projection_d,
        const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
        const_dnnl_memory_desc_t dst_iter_d,
        const_dnnl_memory_desc_t dst_iter_c_d, dnnl_primitive_attr_t attr) {
    dnnl_alg_kind_t kind = alg2kind(prb.alg);
    // `f` is the eltwise activation kind; only vanilla RNN consumes it.
    dnnl_alg_kind_t f = activation2kind(prb.activation);

    dnnl_status_t status;
    switch (kind) {
        case dnnl_vanilla_rnn:
            status = dnnl_vanilla_rnn_forward_primitive_desc_create(pd, engine,
                    prop_kind, f, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, prb.flags, prb.alpha, prb.beta, attr);
            break;
        case dnnl_vanilla_lstm:
            status = dnnl_lstm_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    src_iter_c_d, weights_layer_d, weights_iter_d,
                    weights_peephole_d, weights_projection_d, bias_d,
                    dst_layer_d, dst_iter_d, dst_iter_c_d, prb.flags, attr);
            break;
        case dnnl_vanilla_gru:
            status = dnnl_gru_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, prb.flags, attr);
            break;
        case dnnl_lbr_gru:
            status = dnnl_lbr_gru_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, prb.flags, attr);
            break;
        case dnnl_vanilla_augru:
            status = dnnl_augru_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    attention_d, weights_layer_d, weights_iter_d, bias_d,
                    dst_layer_d, dst_iter_d, prb.flags, attr);
            break;
        case dnnl_lbr_augru:
            status = dnnl_lbr_augru_forward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    attention_d, weights_layer_d, weights_iter_d, bias_d,
                    dst_layer_d, dst_iter_d, prb.flags, attr);
            break;
        default: status = dnnl_unimplemented;
    }
    return status;
}
326 | |
// Creates a backward RNN primitive descriptor in `*pd`, mirroring
// init_rnn_fwd_pd() but additionally passing the diff_* memory descriptors
// and the forward primitive descriptor `hint` required by backward creation.
// Memory descriptors an algorithm does not use are not forwarded.
// Returns dnnl_unimplemented for an unrecognized algorithm kind.
dnnl_status_t init_rnn_bwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
        const prb_t &prb, dnnl_prop_kind_t prop_kind,
        const_dnnl_memory_desc_t src_layer_d,
        const_dnnl_memory_desc_t src_iter_d,
        const_dnnl_memory_desc_t src_iter_c_d,
        const_dnnl_memory_desc_t attention_d,
        const_dnnl_memory_desc_t weights_layer_d,
        const_dnnl_memory_desc_t weights_iter_d,
        const_dnnl_memory_desc_t weights_peephole_d,
        const_dnnl_memory_desc_t weights_projection_d,
        const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
        const_dnnl_memory_desc_t dst_iter_d,
        const_dnnl_memory_desc_t dst_iter_c_d,
        const_dnnl_memory_desc_t diff_src_layer_d,
        const_dnnl_memory_desc_t diff_src_iter_d,
        const_dnnl_memory_desc_t diff_src_iter_c_d,
        const_dnnl_memory_desc_t diff_attention_d,
        const_dnnl_memory_desc_t diff_weights_layer_d,
        const_dnnl_memory_desc_t diff_weights_iter_d,
        const_dnnl_memory_desc_t diff_weights_peephole_d,
        const_dnnl_memory_desc_t diff_weights_projection_d,
        const_dnnl_memory_desc_t diff_bias_d,
        const_dnnl_memory_desc_t diff_dst_layer_d,
        const_dnnl_memory_desc_t diff_dst_iter_d,
        const_dnnl_memory_desc_t diff_dst_iter_c_d,
        const_dnnl_primitive_desc_t hint, dnnl_primitive_attr_t attr) {
    dnnl_alg_kind_t kind = alg2kind(prb.alg);
    // `f` is the eltwise activation kind; only vanilla RNN consumes it.
    dnnl_alg_kind_t f = activation2kind(prb.activation);

    dnnl_status_t status;
    switch (kind) {
        case dnnl_vanilla_rnn:
            status = dnnl_vanilla_rnn_backward_primitive_desc_create(pd, engine,
                    prop_kind, f, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_weights_layer_d, diff_weights_iter_d, diff_bias_d,
                    diff_dst_layer_d, diff_dst_iter_d, prb.flags, prb.alpha,
                    prb.beta, hint, attr);
            break;
        case dnnl_vanilla_lstm:
            status = dnnl_lstm_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    src_iter_c_d, weights_layer_d, weights_iter_d,
                    weights_peephole_d, weights_projection_d, bias_d,
                    dst_layer_d, dst_iter_d, dst_iter_c_d, diff_src_layer_d,
                    diff_src_iter_d, diff_src_iter_c_d, diff_weights_layer_d,
                    diff_weights_iter_d, diff_weights_peephole_d,
                    diff_weights_projection_d, diff_bias_d, diff_dst_layer_d,
                    diff_dst_iter_d, diff_dst_iter_c_d, prb.flags, hint, attr);
            break;
        case dnnl_vanilla_gru:
            status = dnnl_gru_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_weights_layer_d, diff_weights_iter_d, diff_bias_d,
                    diff_dst_layer_d, diff_dst_iter_d, prb.flags, hint, attr);
            break;
        case dnnl_lbr_gru:
            status = dnnl_lbr_gru_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    weights_layer_d, weights_iter_d, bias_d, dst_layer_d,
                    dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_weights_layer_d, diff_weights_iter_d, diff_bias_d,
                    diff_dst_layer_d, diff_dst_iter_d, prb.flags, hint, attr);
            break;
        case dnnl_vanilla_augru:
            status = dnnl_augru_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    attention_d, weights_layer_d, weights_iter_d, bias_d,
                    dst_layer_d, dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_attention_d, diff_weights_layer_d, diff_weights_iter_d,
                    diff_bias_d, diff_dst_layer_d, diff_dst_iter_d, prb.flags,
                    hint, attr);
            break;
        case dnnl_lbr_augru:
            status = dnnl_lbr_augru_backward_primitive_desc_create(pd, engine,
                    prop_kind, prb.direction, src_layer_d, src_iter_d,
                    attention_d, weights_layer_d, weights_iter_d, bias_d,
                    dst_layer_d, dst_iter_d, diff_src_layer_d, diff_src_iter_d,
                    diff_attention_d, diff_weights_layer_d, diff_weights_iter_d,
                    diff_bias_d, diff_dst_layer_d, diff_dst_iter_d, prb.flags,
                    hint, attr);
            break;
        default: status = dnnl_unimplemented;
    }
    return status;
}
416 | |
// Fills buf[0..size) with `value`. A size of zero is a no-op.
void init_buffer(float *buf, int64_t size, float value) {
    // Standard-library fill instead of a hand-rolled loop.
    std::fill_n(buf, size, value);
}
421 | |
422 | // If needed, dequantize u8 data to f32 via data scale and shift |
423 | float maybe_deq(const prb_t &prb, const float in) { |
424 | if (!prb.cfg.is_int8()) return in; |
425 | return (in - prb.data_shift) * (1.0f / prb.data_scale); |
426 | } |
427 | |
428 | // If needed, dequantize s32 accumulators to f32 via data and weights scales |
429 | // (no data shift is needed as it is handled by the compensation in the bias) |
430 | float maybe_deq(const prb_t &prb, const float in, int64_t oc) { |
431 | if (!prb.cfg.is_int8()) return in; |
432 | float scale = prb.get_wei_scale(oc); |
433 | return in * (1.0f / (scale * prb.data_scale)); |
434 | } |
435 | |
436 | // If needed, dequantize s32 accumulators to f32 via data, weights scales |
437 | // and compensation. |
438 | float maybe_deq( |
439 | const prb_t &prb, const float in, float scale, float compensation) { |
440 | if (!prb.cfg.is_int8()) return in; |
441 | return (in - compensation * prb.data_shift) |
442 | * (1.0f / (scale * prb.data_scale)); |
443 | } |
444 | |
445 | float maybe_deq_proj( |
446 | const prb_t &prb, const float in, float compensation, int64_t oc) { |
447 | return maybe_deq(prb, in, prb.get_wei_proj_scale(oc), compensation); |
448 | } |
449 | |
450 | float maybe_q(const prb_t &prb, float h) { |
451 | if (!prb.cfg.is_int8()) return h; |
452 | float fp = prb.data_scale * h + prb.data_shift; |
453 | if (fp > prb.cfg[SRC_LAYER].max) fp = prb.cfg[SRC_LAYER].max; |
454 | if (fp < prb.cfg[SRC_LAYER].min) fp = prb.cfg[SRC_LAYER].min; |
455 | fp = mxcsr_cvt(fp); |
456 | return fp; |
457 | } |
458 | |
// Numerically stable sigmoid: expf() is only ever called with a non-positive
// argument, so it cannot overflow for large |x|.
float logistic(float x) {
    if (x < 0) return expf(x) / (1 + expf(x));
    return 1.0f - expf(-x) / (1 + expf(-x));
}
465 | float dlogistic(float x) { |
466 | float tmp = logistic(x); |
467 | return tmp * (1 - tmp); |
468 | } |
// tanh derivative: 1 - tanh(x)^2, kept in the (1 - t) * (1 + t) form for
// accuracy when |tanh(x)| is close to 1. tanhf() is evaluated once instead
// of twice as before.
float dtanhf(float x) {
    const float t = tanhf(x);
    return (1 - t) * (1 + t);
}
// Computes v - v^2 (the logistic derivative expressed via the logistic
// value itself).
float x_m_square(float v) {
    const float sq = v * v;
    return v - sq;
}
// Leaky ReLU: identity on the positive side, slope `alpha` otherwise.
float relu(float x, float alpha) {
    if (x > 0) return x;
    return alpha * x;
}
// Derivative of the leaky ReLU above: 1 on the positive side, `alpha`
// elsewhere (including at x == 0).
float drelu(float x, float alpha) {
    if (x > 0) return 1.0f;
    return alpha;
}
// Computes 1 - v^2 (the tanh derivative expressed via the tanh value itself).
float one_m_square(float v) {
    const float sq = v * v;
    return 1 - sq;
}
484 | |
485 | rnn_data_kind_t data_kind2rnn_data_kind(data_kind_t data_kind) { |
486 | switch (data_kind) { |
487 | case data_kind_t::DST: return rnn_data_kind_t::DST_LAYER; |
488 | case data_kind_t::DST_ITER: return rnn_data_kind_t::DST_ITER; |
489 | case data_kind_t::DST_ITER_C: return rnn_data_kind_t::DST_ITER_C; |
490 | case data_kind_t::SRC: return rnn_data_kind_t::DIFF_SRC_LAYER; |
491 | case data_kind_t::AUGRU_ATTENTION: |
492 | return rnn_data_kind_t::DIFF_AUGRU_ATTENTION; |
493 | case data_kind_t::SRC_ITER: return rnn_data_kind_t::DIFF_SRC_ITER; |
494 | case data_kind_t::SRC_ITER_C: return rnn_data_kind_t::DIFF_SRC_ITER_C; |
495 | case data_kind_t::WEI: return rnn_data_kind_t::DIFF_WEIGHTS_LAYER; |
496 | case data_kind_t::WEI_ITER: return rnn_data_kind_t::DIFF_WEIGHTS_ITER; |
497 | case data_kind_t::WEI_PEEPHOLE: |
498 | return rnn_data_kind_t::DIFF_WEIGHTS_PEEPHOLE; |
499 | case data_kind_t::WEI_PROJECTION: |
500 | return rnn_data_kind_t::DIFF_WEIGHTS_PROJECTION; |
501 | case data_kind_t::BIA: return rnn_data_kind_t::DIFF_BIAS; |
502 | default: assert(!"unknown data kind" ); |
503 | } |
504 | return KIND_TOTAL; |
505 | } |
506 | |
// Derives quantization parameters (data_shift, data_scale, per-channel
// weights scales) that map fp32 data in [fp_min, fp_max] onto the integer
// ranges declared by the config. For non-int8 configs everything is set to
// the identity transform.
void prb_t::set_qparams(float fp_min, float fp_max) {
    if (!cfg.is_int8()) {
        // f32 path: identity quantization.
        data_shift = 0.;
        data_scale = 1.;
        wei_scales[0] = 1.;
        return;
    }

    /* Set parameters for quantization of src and weights from fp32 data
     * in [-1, 1] to int8 data in a range specified in cfg */
    float fp_range = fp_max - fp_min;
    float int8_src_range = cfg[SRC_LAYER].f_max - cfg[SRC_LAYER].f_min,
          int8_wei_range = cfg[WEIGHTS_LAYER].f_max - cfg[WEIGHTS_LAYER].f_min;

    // No shift needed for s8s8 AMX LSTM
    data_shift = cfg.is_s8() ? 0 : cfg[SRC_LAYER].f_mean;
    data_scale = int8_src_range / fp_range;

    // Base weights scale; individual channels get a distinct multiplier in
    // (1, 2) so per-channel scaling is actually exercised.
    float K = int8_wei_range / fp_range;
    auto set_wei_scales = [&](float *scales, int nelems) {
        for (int64_t i = 0; i < nelems; i++)
            scales[i] = K * (1. + (float)i / nelems);
    };

    set_wei_scales(wei_scales, wei_nscales);
    // Projection weights get their own scales only when projection is on.
    if (with_projection) set_wei_scales(wei_proj_scales, wei_proj_nscales);
}
534 | |
// When the nonlinearity is skipped (pure-linear testing mode), picks
// per-gate linear scaling factors so that post-gemm values stay inside
// [fp_min, fp_max]. Scales are exact powers-of-two fractions so the
// reference and library computations agree bit-wise.
void prb_t::set_tparams(float fp_min, float fp_max) {
    if (skip_nonlinear) {
        assert(linear_scales != nullptr);
        // Here, we assume that the inputs of the cells are in [fp_min,fp_max].
        // We pick the scaling factors to ensure that the output of the linear
        // pre/post gemm is in [fp_min,fp_max]

        // Also, we rely on the fact that for forward, the weights
        // matrices are sparse, and contain coefficient equal to
        // 1/n_gates() to compensate for the gemm accumulation. So
        // here, we account only for the post-gemm accumulation, and
        // the fact that we want to use different scales per gate.

        // For BWD_W, we cannot assume sparseness though since the
        // gates and diff_dst_* are dense.
        int64_t fwd_acc_dim = n_gates();
        int64_t bwdd_acc_dim = dhc;
        int64_t bwdw_acc_dim = mb;
        int64_t acc_dim = 0;
        // Backward accumulates over the worst of the three dimensions above.
        if (prop == dnnl_backward)
            acc_dim = n_gates()
                    * MAX2(fwd_acc_dim, MAX2(bwdd_acc_dim, bwdw_acc_dim));
        else
            acc_dim = fwd_acc_dim;
        // make scaling exact by choosing powers of two.
        // LSTM needs one extra slot for the cell-state scale; int8 halves
        // the budget again via the extra factor of 2.
        int64_t n_cscale = (alg == VANILLA_LSTM);
        int64_t divisor = next_pow2((acc_dim + n_cscale) * (is_int8() ? 2 : 1));
        float factor = (1.0f / (float)(divisor));
        for (int64_t i = 0; i < n_gates(); i++)
            linear_scales[i] = (i + 1) * factor;
        if (n_cscale) linear_cscale = (n_gates() + 1) * factor;
    }
}
568 | |
569 | } // namespace rnn |
570 | |