1 | /******************************************************************************* |
2 | * Copyright 2018-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /* |
18 | * Cell execution of Vanilla RNN |
19 | */ |
20 | |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/math_utils.hpp" |
23 | |
24 | #include "cpu/rnn/postgemm_dispatcher.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | |
30 | using namespace dnnl::impl::utils; |
31 | using namespace dnnl::impl::math; |
32 | using namespace rnn_utils; |
33 | |
34 | template <> |
35 | float activation<alg_kind::eltwise_relu, prop_kind::forward>( |
36 | float s, float alpha, float cliping) { |
37 | return relu_fwd<float>(s, alpha); |
38 | } |
39 | |
40 | template <> |
41 | float activation<alg_kind::eltwise_relu, prop_kind::backward>( |
42 | float s, float alpha, float cliping) { |
43 | return relu_bwd<float>(s, alpha); |
44 | } |
45 | |
46 | template <> |
47 | float activation<alg_kind::eltwise_tanh, prop_kind::forward>( |
48 | float s, float alpha, float cliping) { |
49 | return tanh_fwd<float>(s); |
50 | } |
51 | |
52 | template <> |
53 | float activation<alg_kind::eltwise_tanh, prop_kind::backward>( |
54 | float s, float alpha, float cliping) { |
55 | return one_m_square<float>(s); |
56 | } |
57 | |
58 | template <> |
59 | float activation<alg_kind::eltwise_logistic, prop_kind::forward>( |
60 | float s, float alpha, float cliping) { |
61 | return logistic_fwd<float>(s); |
62 | } |
63 | |
64 | template <> |
65 | float activation<alg_kind::eltwise_logistic, prop_kind::backward>( |
66 | float s, float alpha, float cliping) { |
67 | return x_m_square<float>(s); |
68 | } |
69 | |
// Linear "activation" used in tparams test mode: scales the input by
// alpha. The clipping argument is accepted for signature compatibility
// with the real activations but is ignored.
constexpr float linear(float s, float alpha, float clipping) {
    return s * alpha;
}
73 | |
/* Elementwise post-GEMM kernel for the Vanilla RNN forward cell.
 *
 * For each mini-batch row i and gate element j it computes
 *     h = func1(scratch_gates(i, 0, j) + bias(0, j), alpha, 0)
 * and stores h into dst_layer / dst_iter (when those pointers are
 * non-null) and, when training, into ws_gates for the backward pass.
 *
 * func1      - activation functor (gate, alpha, clipping) -> float;
 *              the real activation, or `linear` in tparams test mode.
 * scales     - when non-null, scales[0] overrides alpha (test mode).
 * block_step - number of bytes of scratch data processed per row;
 *              element count is block_step / sizeof(scratch_data_t).
 */
template <typename T, typename src_data_t, typename scratch_data_t>
void rnn_fwd_postgemm_template(T func1, const float *scales, float alpha,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
        src_data_t *dst_iter_, const src_data_t *src_iter_, const void *bias_,
        int block_step) {

    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    // Bias may be stored in a non-f32 data type (rnn.bias_dt): access it
    // through a raw aoc and convert each element to float on read.
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
    const auto bias = [&](int gate_id, int dhc_id) {
        return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };
    // Destination leading dimensions depend on where this cell sits in
    // the layer/iteration grid.
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);

    // Test mode: the first tparam scale replaces alpha.
    if (scales != nullptr) alpha = scales[0];

    const int n_elem = block_step / sizeof(scratch_data_t);

    // Process one row of the (blocked) minibatch.
    const auto postgemm_call = [&](dim_t i) {
        for (int j = 0; j < n_elem; j++) {
            const float h
                    = func1(scratch_gates(i, 0, j) + bias(0, j), alpha, 0);
            if (dst_layer_ != nullptr) dst_layer(i, j) = h;
            if (dst_iter_ != nullptr) dst_iter(i, j) = h;
            if (rnn.is_training) ws_gates(i, 0, j) = h;
        }
    };

    // Fused brgemm postgemm iterates its m_block sequentially —
    // NOTE(review): presumably because the caller already runs this inside
    // a parallel region; confirm against the brgemm cell driver. Otherwise
    // parallelize across the whole minibatch.
    if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
        for (int i = 0; i < rnn.m_block; i++)
            postgemm_call(i);
    } else
        parallel_nd(rnn.mb, postgemm_call);
}
115 | |
/* f32 forward postgemm: applies the configured activation, or the
 * `linear` function when tparams test mode is enabled (in which case the
 * tparam scales supply the slope instead of desc()->alpha). */
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::rnn_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    // Regular path: the cell's activation function.
    const auto act_f = [this](float a, float alpha, float clipping) {
        return this->activation_func(a, alpha, clipping);
    };
    // Test-mode path: linear scaling only.
    const auto linear_f = [](float a, float alpha, float clipping) {
        return linear(a, alpha, clipping);
    };
    const auto alpha = pd_->desc()->alpha;
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        rnn_fwd_postgemm_template(act_f, nullptr, alpha, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, src_iter_,
                bias_, block_step);
    else
        rnn_fwd_postgemm_template(linear_f, scales, alpha, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, src_iter_,
                bias_, block_step);
}
135 | |
/* bf16 forward postgemm: same dispatch as the f32 version, but the
 * activation result is rounded to bfloat16 before being stored. */
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::rnn_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    // Regular path: activation computed in f32, downconverted to bf16.
    const auto act_f = [this](float a, float alpha, float clipping) {
        return bfloat16_t(this->activation_func(a, alpha, clipping));
    };
    // Test-mode path: linear scaling, downconverted to bf16.
    const auto linear_f = [](float a, float alpha, float clipping) {
        return bfloat16_t(linear(a, alpha, clipping));
    };
    const auto alpha = pd_->desc()->alpha;
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        rnn_fwd_postgemm_template(act_f, nullptr, alpha, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, src_iter_,
                bias_, block_step);
    else
        rnn_fwd_postgemm_template(linear_f, scales, alpha, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, src_iter_,
                bias_, block_step);
}
155 | |
156 | template <> |
157 | rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::rnn_postgemm) { |
158 | assert(!"VANILLA RNN int8 is not supported" ); |
159 | } |
160 | |
161 | template <> |
162 | rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::rnn_postgemm) { |
163 | assert(!"VANILLA RNN int8 is not supported" ); |
164 | } |
165 | |
/* Elementwise post-GEMM kernel for the Vanilla RNN backward cell.
 *
 * For each (i, j) it combines the two incoming gradients
 *     dH = diff_dst_layer(i, j) + diff_dst_iter(i, j)
 * multiplies by func1(ws_gates(i, 0, j), alpha, 0) — the activation
 * derivative evaluated on the saved forward result — converts the product
 * to src_data_t via `to_src`, and writes it to scratch_gates, which feeds
 * the subsequent diff GEMMs.
 *
 * scales - when non-null, scales[0] overrides alpha (test mode). */
template <typename T1, typename T2, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void rnn_bwd_postgemm_template(T1 func1, T2 to_src, const float *scales,
        float alpha, const rnn_utils::rnn_conf_t &rnn, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, acc_data_t *diff_dst_iter_,
        acc_data_t *diff_dst_layer_) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);
    // Test mode: the first tparam scale replaces alpha.
    if (scales != nullptr) alpha = scales[0];

    parallel_nd(rnn.mb, [&](dim_t i) {
        for (int j = 0; j < rnn.dhc; ++j) {
            // Sum gradients coming from the layer above and the next
            // iteration, then apply the chain rule with the activation
            // derivative.
            const float dH = diff_dst_layer(i, j) + diff_dst_iter(i, j);
            const auto g = (float)ws_gates(i, 0, j);
            const float res = dH * func1(g, alpha, 0);
            src_data_t res_converted = to_src(res);
            scratch_gates(i, 0, j) = res_converted;
        }
    });
}
190 | |
/* f32 backward postgemm: uses the configured activation derivative, or
 * `linear` in tparams test mode. Note both lambdas deliberately pass 0 as
 * the clipping argument; results stay in f32 (identity to_src). */
template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::rnn_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto act_f = [this](float a, float alpha, float clipping) {
        return this->activation_func(a, alpha, 0);
    };
    const auto linear_f = [](float a, float alpha, float clipping) {
        return linear(a, alpha, 0);
    };
    // f32 scratch: no conversion needed.
    const auto to_src = [&](float a) { return a; };
    const auto alpha = pd_->desc()->alpha;
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        rnn_bwd_postgemm_template(act_f, to_src, nullptr, alpha, rnn, ws_gates_,
                scratch_gates_, diff_dst_iter_, diff_dst_layer_);
    else
        rnn_bwd_postgemm_template(linear_f, to_src, scales, alpha, rnn,
                ws_gates_, scratch_gates_, diff_dst_iter_, diff_dst_layer_);
}
209 | |
/* bf16 backward postgemm: same dispatch as the f32 version, but the
 * gradient is rounded to bfloat16 before being stored in scratch_gates.
 * Both lambdas deliberately pass 0 as the clipping argument. */
template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::rnn_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto act_f = [this](float a, float alpha, float clipping) {
        return this->activation_func(a, alpha, 0);
    };
    const auto linear_f = [](float a, float alpha, float clipping) {
        return linear(a, alpha, 0);
    };
    // Downconvert the f32 gradient to bf16 for the scratch buffer.
    const auto to_src = [&](float a) { return bfloat16_t(a); };
    const auto alpha = pd_->desc()->alpha;
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        rnn_bwd_postgemm_template(act_f, to_src, nullptr, alpha, rnn, ws_gates_,
                scratch_gates_, diff_dst_iter_, diff_dst_layer_);
    else
        rnn_bwd_postgemm_template(linear_f, to_src, scales, alpha, rnn,
                ws_gates_, scratch_gates_, diff_dst_iter_, diff_dst_layer_);
}
228 | |
229 | } // namespace cpu |
230 | } // namespace impl |
231 | } // namespace dnnl |
232 | |