1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef BENCHDNN_RNN_AUX_HPP |
18 | #define BENCHDNN_RNN_AUX_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <stdlib.h> |
22 | #include "rnn/rnn.hpp" |
23 | |
24 | namespace rnn { |
25 | |
// Propagation tag with two enumerators numbered 0 and 1 in declaration order.
// NOTE(review): presumably selects the forward vs. backward pass; the code
// that consumes this type is not visible in this header.
enum rnn_propagation_t {
    rnn_forward = 0,
    rnn_backward = 1,
};
30 | |
// Iteration-direction tag with two enumerators numbered 0 and 1.
// NOTE(review): presumably the order in which time steps are traversed;
// confirm against the code that uses it (not visible here).
enum rnn_iter_direction_t {
    left2right = 0,
    right2left = 1,
};
35 | |
// Layer-direction tag with two enumerators numbered 0 and 1.
// NOTE(review): presumably the order in which stacked layers are traversed;
// confirm against the code that uses it (not visible here).
enum rnn_layer_direction_t {
    bottom2top = 0,
    top2bottom = 1,
};
40 | |
// Action selector accepted by copy() in this header, where action_copy is the
// default argument. Enumerators are numbered 0, 1, 2 in declaration order.
// NOTE(review): presumably copy / accumulate / concatenate semantics; the
// implementation is not visible here.
enum rnn_action_t {
    action_copy = 0,
    action_sum = 1,
    action_concat = 2
};
42 | |
43 | dnnl_status_t init_rnn_fwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine, |
44 | const prb_t &prb, dnnl_prop_kind_t prop_kind, |
45 | const_dnnl_memory_desc_t src_layer_d, |
46 | const_dnnl_memory_desc_t src_iter_d, |
47 | const_dnnl_memory_desc_t src_iter_c_d, |
48 | const_dnnl_memory_desc_t attention_d, |
49 | const_dnnl_memory_desc_t weights_layer_d, |
50 | const_dnnl_memory_desc_t weights_iter_d, |
51 | const_dnnl_memory_desc_t weights_peephole_d, |
52 | const_dnnl_memory_desc_t weights_projection_d, |
53 | const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d, |
54 | const_dnnl_memory_desc_t dst_iter_d, |
55 | const_dnnl_memory_desc_t dst_iter_c_d, dnnl_primitive_attr_t attr); |
56 | |
57 | dnnl_status_t init_rnn_bwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine, |
58 | const prb_t &prb, dnnl_prop_kind_t prop_kind, |
59 | const_dnnl_memory_desc_t src_layer_d, |
60 | const_dnnl_memory_desc_t src_iter_d, |
61 | const_dnnl_memory_desc_t src_iter_c_d, |
62 | const_dnnl_memory_desc_t attention_d, |
63 | const_dnnl_memory_desc_t weights_layer_d, |
64 | const_dnnl_memory_desc_t weights_iter_d, |
65 | const_dnnl_memory_desc_t weights_peephole_d, |
66 | const_dnnl_memory_desc_t weights_projection_d, |
67 | const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d, |
68 | const_dnnl_memory_desc_t dst_iter_d, |
69 | const_dnnl_memory_desc_t dst_iter_c_d, |
70 | const_dnnl_memory_desc_t diff_src_layer_d, |
71 | const_dnnl_memory_desc_t diff_src_iter_d, |
72 | const_dnnl_memory_desc_t diff_src_iter_c_d, |
73 | const_dnnl_memory_desc_t diff_attention_d, |
74 | const_dnnl_memory_desc_t diff_weights_layer_d, |
75 | const_dnnl_memory_desc_t diff_weights_iter_d, |
76 | const_dnnl_memory_desc_t diff_weights_peephole_d, |
77 | const_dnnl_memory_desc_t diff_weights_projection_d, |
78 | const_dnnl_memory_desc_t diff_bias_d, |
79 | const_dnnl_memory_desc_t diff_dst_layer_d, |
80 | const_dnnl_memory_desc_t diff_dst_iter_d, |
81 | const_dnnl_memory_desc_t diff_dst_iter_c_d, |
82 | const_dnnl_primitive_desc_t hint, dnnl_primitive_attr_t attr); |
83 | |
// Buffer initialization helper (declaration only): takes a mutable float
// pointer, a 64-bit `size`, and a float `value`.
// NOTE(review): presumably writes `value` into the first `size` elements of
// `buf`; the definition is not visible in this header, so confirm there.
void init_buffer(float *buf, int64_t size, float value);
85 | |
86 | float maybe_q(const prb_t &prb, float h); |
87 | float maybe_deq(const prb_t &prb, const float in); |
88 | float maybe_deq(const prb_t &prb, const float in, int64_t oc); |
89 | float maybe_deq( |
90 | const prb_t &prb, const float in, float scale, float compensation); |
91 | float maybe_deq_proj( |
92 | const prb_t &prb, const float in, float compensation, int64_t oc); |
93 | |
// Scalar float -> float helper.
// NOTE(review): presumably the logistic (sigmoid) function 1 / (1 + exp(-x));
// the definition is not visible in this header, so confirm.
float logistic(float x);
// Scalar float -> float helper.
// NOTE(review): presumably the derivative of logistic(); whether `x` is the
// pre-activation or the activated value is not visible here, so confirm.
float dlogistic(float x);
// Two-argument float helper returning a float.
// NOTE(review): presumably ReLU with `alpha` as the negative-side slope; the
// definition is not visible in this header, so confirm.
float relu(float x, float alpha);
// Two-argument float helper returning a float.
// NOTE(review): presumably the derivative of relu() with the same `alpha`;
// the definition is not visible in this header, so confirm.
float drelu(float x, float alpha);
// Scalar float -> float helper.
// NOTE(review): presumably the derivative of tanh evaluated for `x`; the
// definition is not visible in this header, so confirm.
float dtanhf(float x);
// Scalar float -> float helper.
// NOTE(review): the name suggests 1 - x^2; the exact expression lives in the
// definition, which is not visible here, so confirm.
float one_m_square(float x);
// Scalar float -> float helper.
// NOTE(review): the name suggests x - x^2; the exact expression lives in the
// definition, which is not visible here, so confirm.
float x_m_square(float x);
101 | |
102 | void copy(int64_t dimc, int64_t dimr, int64_t ld_src, int64_t ld_dst, |
103 | const float *src_, float *dst_, rnn_action_t action = action_copy, |
104 | bool saturate_to_u8 = false); |
// Operates on a mutable float buffer `src_` with the given scale and shift.
// NOTE(review): presumably quantizes a `dimc` x `dimr` region with leading
// dimension `ld_src` in place; the definition is not visible here, so confirm.
void data_q10n(
        int64_t dimc,
        int64_t dimr,
        int64_t ld_src,
        float *src_,
        float data_scale,
        float data_shift);
// Operates on a mutable float buffer `src_` with the given scale and shift.
// NOTE(review): presumably the inverse of data_q10n(), dequantizing a
// `dimc` x `dimr` region with leading dimension `ld_src` in place; the
// definition is not visible here, so confirm.
void data_deq10n(
        int64_t dimc,
        int64_t dimr,
        int64_t ld_src,
        float *src_,
        float data_scale,
        float data_shift);
109 | void gates_reduction( |
110 | const prb_t &prb, const float *b_gates_, float *diff_bias_); |
111 | |
112 | rnn_data_kind_t data_kind2rnn_data_kind(data_kind_t data_kind); |
113 | |
114 | }; // namespace rnn |
115 | |
116 | #endif |
117 | |