1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef BENCHDNN_RNN_AUX_HPP
18#define BENCHDNN_RNN_AUX_HPP
19
20#include <assert.h>
21#include <stdlib.h>
22#include "rnn/rnn.hpp"
23
24namespace rnn {
25
// Two-valued selector whose enumerators are named for the forward and the
// backward pass. The numeric values are written out: forward is 0 and
// backward is 1. How the values are consumed is not visible in this header.
enum rnn_propagation_t {
    rnn_forward = 0,
    rnn_backward = 1,
};
30
// Two-valued selector for the iteration direction: left2right is 0 and
// right2left is 1.
// NOTE(review): presumably this is the order in which time steps are walked;
// confirm against the users of this type.
enum rnn_iter_direction_t {
    left2right = 0,
    right2left = 1,
};
35
// Two-valued selector for the layer direction: bottom2top is 0 and
// top2bottom is 1.
// NOTE(review): presumably this is the order in which stacked layers are
// walked; confirm against the users of this type.
enum rnn_layer_direction_t {
    bottom2top = 0,
    top2bottom = 1,
};
40
// Action selector accepted by copy() below, where action_copy is the default.
// Values are written out: copy is 0, sum is 1, concat is 2.
enum rnn_action_t {
    action_copy = 0,
    action_sum = 1,
    action_concat = 2,
};
42
43dnnl_status_t init_rnn_fwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
44 const prb_t &prb, dnnl_prop_kind_t prop_kind,
45 const_dnnl_memory_desc_t src_layer_d,
46 const_dnnl_memory_desc_t src_iter_d,
47 const_dnnl_memory_desc_t src_iter_c_d,
48 const_dnnl_memory_desc_t attention_d,
49 const_dnnl_memory_desc_t weights_layer_d,
50 const_dnnl_memory_desc_t weights_iter_d,
51 const_dnnl_memory_desc_t weights_peephole_d,
52 const_dnnl_memory_desc_t weights_projection_d,
53 const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
54 const_dnnl_memory_desc_t dst_iter_d,
55 const_dnnl_memory_desc_t dst_iter_c_d, dnnl_primitive_attr_t attr);
56
57dnnl_status_t init_rnn_bwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
58 const prb_t &prb, dnnl_prop_kind_t prop_kind,
59 const_dnnl_memory_desc_t src_layer_d,
60 const_dnnl_memory_desc_t src_iter_d,
61 const_dnnl_memory_desc_t src_iter_c_d,
62 const_dnnl_memory_desc_t attention_d,
63 const_dnnl_memory_desc_t weights_layer_d,
64 const_dnnl_memory_desc_t weights_iter_d,
65 const_dnnl_memory_desc_t weights_peephole_d,
66 const_dnnl_memory_desc_t weights_projection_d,
67 const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
68 const_dnnl_memory_desc_t dst_iter_d,
69 const_dnnl_memory_desc_t dst_iter_c_d,
70 const_dnnl_memory_desc_t diff_src_layer_d,
71 const_dnnl_memory_desc_t diff_src_iter_d,
72 const_dnnl_memory_desc_t diff_src_iter_c_d,
73 const_dnnl_memory_desc_t diff_attention_d,
74 const_dnnl_memory_desc_t diff_weights_layer_d,
75 const_dnnl_memory_desc_t diff_weights_iter_d,
76 const_dnnl_memory_desc_t diff_weights_peephole_d,
77 const_dnnl_memory_desc_t diff_weights_projection_d,
78 const_dnnl_memory_desc_t diff_bias_d,
79 const_dnnl_memory_desc_t diff_dst_layer_d,
80 const_dnnl_memory_desc_t diff_dst_iter_d,
81 const_dnnl_memory_desc_t diff_dst_iter_c_d,
82 const_dnnl_primitive_desc_t hint, dnnl_primitive_attr_t attr);
83
// Declaration only; defined elsewhere.
// Takes a mutable float buffer, a 64-bit element count and a float value.
// NOTE(review): presumably writes `value` into the first `size` elements of
// `buf`; confirm against the definition.
void init_buffer(
        float *buf,
        int64_t size,
        float value);
85
86float maybe_q(const prb_t &prb, float h);
87float maybe_deq(const prb_t &prb, const float in);
88float maybe_deq(const prb_t &prb, const float in, int64_t oc);
89float maybe_deq(
90 const prb_t &prb, const float in, float scale, float compensation);
91float maybe_deq_proj(
92 const prb_t &prb, const float in, float compensation, int64_t oc);
93
// Scalar float helpers; declarations only, the definitions live elsewhere.
// NOTE(review): the names suggest activation functions and their derivatives
// (a leading `d` for "derivative"); the exact formulas must be read from the
// definitions.

// Logistic function and its `d`-prefixed companion.
float logistic(float x);
float dlogistic(float x);

// ReLU-named pair; both take an extra `alpha` parameter.
float relu(float x, float alpha);
float drelu(float x, float alpha);

// `d`-prefixed companion for tanhf.
float dtanhf(float x);

// NOTE(review): presumably 1 - x^2 and x - x^2 respectively; confirm.
float one_m_square(float x);
float x_m_square(float x);
101
102void copy(int64_t dimc, int64_t dimr, int64_t ld_src, int64_t ld_dst,
103 const float *src_, float *dst_, rnn_action_t action = action_copy,
104 bool saturate_to_u8 = false);
// Declaration only; defined elsewhere.
// Takes two dimensions, a leading dimension, a mutable float buffer and a
// scale/shift pair. The buffer is non-const, so it may be modified in place.
// NOTE(review): "q10n" presumably abbreviates "quantization"; confirm the
// exact transform against the definition.
void data_q10n(
        int64_t dimc,
        int64_t dimr,
        int64_t ld_src,
        float *src_,
        float data_scale,
        float data_shift);
// Declaration only; defined elsewhere.
// Same parameter list as data_q10n: two dimensions, a leading dimension, a
// mutable float buffer and a scale/shift pair.
// NOTE(review): "deq10n" presumably abbreviates "dequantization", i.e. the
// inverse of data_q10n; confirm against the definition.
void data_deq10n(
        int64_t dimc,
        int64_t dimr,
        int64_t ld_src,
        float *src_,
        float data_scale,
        float data_shift);
109void gates_reduction(
110 const prb_t &prb, const float *b_gates_, float *diff_bias_);
111
112rnn_data_kind_t data_kind2rnn_data_kind(data_kind_t data_kind);
113
114}; // namespace rnn
115
116#endif
117