1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef BENCHDNN_RNN_CELLS_HPP |
18 | #define BENCHDNN_RNN_CELLS_HPP |
19 | |
20 | #include "rnn/rnn.hpp" |
21 | |
22 | namespace rnn { |
23 | |
24 | void rnn_fwd(const prb_t &prb, float *dst_layer_, float *gates_, |
25 | const float *weights_layer_, const float *weights_iter_, |
26 | const float *bias_, const float *src_layer_, const float *src_iter_); |
27 | |
28 | void rnn_bwd(const prb_t &prb, float *diff_src_layer_, float *diff_src_iter_, |
29 | float *diff_weights_layer_, float *diff_weights_iter_, |
30 | float *diff_bias_, float *b_gates_, const float *src_layer_, |
31 | const float *src_iter_, const float *weights_layer_, |
32 | const float *weights_iter_, const float *bias_, const float *gates_, |
33 | const float *diff_dst_layer_, const float *diff_dst_iter_); |
34 | |
35 | void lstm_fwd(const prb_t &prb, float *dst_layer_, float *dst_iter_, |
36 | float *dst_iter_c_, float *gates_, float *ht_, |
37 | const float *weights_layer_, const float *weights_iter_, |
38 | const float *weights_peephole_, const float *weights_projection_, |
39 | const float *weights_projection_compensation, const float *bias_, |
40 | const float *src_layer_, const float *src_iter_, |
41 | const float *src_iter_c_); |
42 | |
43 | void lstm_bwd(const prb_t &prb, float *diff_src_layer_, float *diff_src_iter_, |
44 | float *diff_src_iter_c_, float *diff_weights_layer_, |
45 | float *diff_weights_iter_, float *diff_weights_peephole_, |
46 | float *diff_weights_projection_, float *diff_bias_, float *b_gates_, |
47 | const float *src_layer_, const float *src_iter_, |
48 | const float *src_iter_c_, const float *weights_layer_, |
49 | const float *weights_iter_, const float *weights_peephole_, |
50 | const float *weights_projection_, const float *bias_, |
51 | const float *dst_layer_, const float *dst_iter_c_, const float *gates_, |
52 | const float *ht_, const float *diff_dst_layer_, |
53 | const float *diff_dst_iter_, const float *diff_dst_iter_c_, |
54 | float *cell_scratchpad_); |
55 | |
56 | void gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_, |
57 | const float *weights_layer_, const float *weights_iter_, |
58 | const float *bias_, const float *src_layer_, |
59 | const float *src_layer_attention_, const float *src_iter_); |
60 | |
61 | void gru_bwd(const prb_t &prb, float *diff_src_layer_, |
62 | float *diff_src_layer_attention_, float *diff_src_iter_, |
63 | float *diff_weights_layer_, float *diff_weights_iter_, |
64 | float *diff_bias_, float *b_gates_, const float *src_layer_, |
65 | const float *src_layer_attention_, const float *src_iter_, |
66 | const float *weights_layer_, const float *weights_iter_, |
67 | const float *bias_, const float *gates_, const float *diff_dst_layer_, |
68 | const float *diff_dst_iter_, float *cell_scratchpad_); |
69 | |
70 | void lbr_gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_, |
71 | const float *weights_layer_, const float *weights_iter_, |
72 | const float *bias_, const float *src_layer_, |
73 | const float *src_layer_attention_, const float *src_iter_, |
74 | float *cell_scratchpad_); |
75 | |
76 | void lbr_gru_bwd(const prb_t &prb, float *diff_src_layer_, |
77 | float *diff_src_layer_attention_, float *diff_src_iter_, |
78 | float *diff_weights_layer_, float *diff_weights_iter_, |
79 | float *diff_bias_, float *b_gates_, const float *src_layer_, |
80 | const float *src_layer_attention_, const float *src_iter_, |
81 | const float *weights_layer_, const float *weights_iter_, |
82 | const float *bias_, const float *gates_, const float *diff_dst_layer_, |
83 | const float *diff_dst_iter_, float *cell_scratchpad_); |
84 | |
85 | } // namespace rnn |
86 | |
87 | #endif |
88 | |