/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

17#ifndef BENCHDNN_RNN_CELLS_HPP
18#define BENCHDNN_RNN_CELLS_HPP
19
20#include "rnn/rnn.hpp"
21
22namespace rnn {
23
24void rnn_fwd(const prb_t &prb, float *dst_layer_, float *gates_,
25 const float *weights_layer_, const float *weights_iter_,
26 const float *bias_, const float *src_layer_, const float *src_iter_);
27
28void rnn_bwd(const prb_t &prb, float *diff_src_layer_, float *diff_src_iter_,
29 float *diff_weights_layer_, float *diff_weights_iter_,
30 float *diff_bias_, float *b_gates_, const float *src_layer_,
31 const float *src_iter_, const float *weights_layer_,
32 const float *weights_iter_, const float *bias_, const float *gates_,
33 const float *diff_dst_layer_, const float *diff_dst_iter_);
34
35void lstm_fwd(const prb_t &prb, float *dst_layer_, float *dst_iter_,
36 float *dst_iter_c_, float *gates_, float *ht_,
37 const float *weights_layer_, const float *weights_iter_,
38 const float *weights_peephole_, const float *weights_projection_,
39 const float *weights_projection_compensation, const float *bias_,
40 const float *src_layer_, const float *src_iter_,
41 const float *src_iter_c_);
42
43void lstm_bwd(const prb_t &prb, float *diff_src_layer_, float *diff_src_iter_,
44 float *diff_src_iter_c_, float *diff_weights_layer_,
45 float *diff_weights_iter_, float *diff_weights_peephole_,
46 float *diff_weights_projection_, float *diff_bias_, float *b_gates_,
47 const float *src_layer_, const float *src_iter_,
48 const float *src_iter_c_, const float *weights_layer_,
49 const float *weights_iter_, const float *weights_peephole_,
50 const float *weights_projection_, const float *bias_,
51 const float *dst_layer_, const float *dst_iter_c_, const float *gates_,
52 const float *ht_, const float *diff_dst_layer_,
53 const float *diff_dst_iter_, const float *diff_dst_iter_c_,
54 float *cell_scratchpad_);
55
56void gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_,
57 const float *weights_layer_, const float *weights_iter_,
58 const float *bias_, const float *src_layer_,
59 const float *src_layer_attention_, const float *src_iter_);
60
61void gru_bwd(const prb_t &prb, float *diff_src_layer_,
62 float *diff_src_layer_attention_, float *diff_src_iter_,
63 float *diff_weights_layer_, float *diff_weights_iter_,
64 float *diff_bias_, float *b_gates_, const float *src_layer_,
65 const float *src_layer_attention_, const float *src_iter_,
66 const float *weights_layer_, const float *weights_iter_,
67 const float *bias_, const float *gates_, const float *diff_dst_layer_,
68 const float *diff_dst_iter_, float *cell_scratchpad_);
69
70void lbr_gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_,
71 const float *weights_layer_, const float *weights_iter_,
72 const float *bias_, const float *src_layer_,
73 const float *src_layer_attention_, const float *src_iter_,
74 float *cell_scratchpad_);
75
76void lbr_gru_bwd(const prb_t &prb, float *diff_src_layer_,
77 float *diff_src_layer_attention_, float *diff_src_iter_,
78 float *diff_weights_layer_, float *diff_weights_iter_,
79 float *diff_bias_, float *b_gates_, const float *src_layer_,
80 const float *src_layer_attention_, const float *src_iter_,
81 const float *weights_layer_, const float *weights_iter_,
82 const float *bias_, const float *gates_, const float *diff_dst_layer_,
83 const float *diff_dst_iter_, float *cell_scratchpad_);
84
85} // namespace rnn
86
87#endif
88