/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Cell execution of Vanilla RNN
 */

#include <assert.h>

#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"

#include "cpu/rnn/postgemm_dispatcher.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::math;
using namespace rnn_utils;

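// Elementwise activation kernels, specialized per algorithm and propagation
// kind. Forward variants apply the activation to the pre-activation value;
// backward variants return the derivative evaluated on the value saved by the
// forward pass (one_m_square yields 1 - s^2 for tanh, x_m_square yields
// s - s^2 for the logistic function). The clipping parameter is only kept for
// signature uniformity and is unused here.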
template <>
float activation<alg_kind::eltwise_relu, prop_kind::forward>(
        float s, float alpha, float clipping) {
    return relu_fwd<float>(s, alpha);
}

template <>
float activation<alg_kind::eltwise_relu, prop_kind::backward>(
        float s, float alpha, float clipping) {
    return relu_bwd<float>(s, alpha);
}

template <>
float activation<alg_kind::eltwise_tanh, prop_kind::forward>(
        float s, float alpha, float clipping) {
    return tanh_fwd<float>(s);
}

template <>
float activation<alg_kind::eltwise_tanh, prop_kind::backward>(
        float s, float alpha, float clipping) {
    return one_m_square<float>(s);
}

template <>
float activation<alg_kind::eltwise_logistic, prop_kind::forward>(
        float s, float alpha, float clipping) {
    return logistic_fwd<float>(s);
}

template <>
float activation<alg_kind::eltwise_logistic, prop_kind::backward>(
        float s, float alpha, float clipping) {
    return x_m_square<float>(s);
}

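// Linear "activation" used when the rnn_tparams test mode is enabled: the
// pre-activation value is simply scaled by alpha (or by scales[0], which
// overrides alpha inside the templates below).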
constexpr float linear(float s, float alpha, float clipping) {
    return alpha * s;
}

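// Forward postgemm: for every mini-batch row and every element of the current
// gate block, compute h = func1(G + bias, alpha), i.e. bias addition followed
// by the activation (or by the linear function in test mode). The result is
// written to dst_layer and dst_iter when those buffers are provided, and is
// additionally saved to ws_gates during training so the backward pass can
// reuse it. The fused brgemm path processes the current m_block sequentially;
// otherwise the loop over the mini-batch is parallelized.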
template <typename T, typename src_data_t, typename scratch_data_t>
void rnn_fwd_postgemm_template(T func1, const float *scales, float alpha,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
        src_data_t *dst_iter_, const src_data_t *src_iter_, const void *bias_,
        int block_step) {

    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
    const auto bias = [&](int gate_id, int dhc_id) {
        return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);

    if (scales != nullptr) alpha = scales[0];

    const int n_elem = block_step / sizeof(scratch_data_t);

    const auto postgemm_call = [&](dim_t i) {
        for (int j = 0; j < n_elem; j++) {
            const float h
                    = func1(scratch_gates(i, 0, j) + bias(0, j), alpha, 0);
            if (dst_layer_ != nullptr) dst_layer(i, j) = h;
            if (dst_iter_ != nullptr) dst_iter(i, j) = h;
            if (rnn.is_training) ws_gates(i, 0, j) = h;
        }
    };

    if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
        for (int i = 0; i < rnn.m_block; i++)
            postgemm_call(i);
    } else
        parallel_nd(rnn.mb, postgemm_call);
}

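// Forward drivers: regular runs apply the configured activation; when the
// rnn_tparams test mode is set, the linear function driven by the test scales
// is used instead. The bf16 flavor additionally rounds the result to
// bfloat16_t, and the int8 flavors below are not implemented for vanilla RNN.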
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::rnn_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto act_f = [this](float a, float alpha, float clipping) {
        return this->activation_func(a, alpha, clipping);
    };
    const auto linear_f = [](float a, float alpha, float clipping) {
        return linear(a, alpha, clipping);
    };
    const auto alpha = pd_->desc()->alpha;
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        rnn_fwd_postgemm_template(act_f, nullptr, alpha, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, src_iter_,
                bias_, block_step);
    else
        rnn_fwd_postgemm_template(linear_f, scales, alpha, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, src_iter_,
                bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::rnn_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto act_f = [this](float a, float alpha, float clipping) {
        return bfloat16_t(this->activation_func(a, alpha, clipping));
    };
    const auto linear_f = [](float a, float alpha, float clipping) {
        return bfloat16_t(linear(a, alpha, clipping));
    };
    const auto alpha = pd_->desc()->alpha;
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        rnn_fwd_postgemm_template(act_f, nullptr, alpha, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, src_iter_,
                bias_, block_step);
    else
        rnn_fwd_postgemm_template(linear_f, scales, alpha, rnn, cell_position,
                ws_gates_, scratch_gates_, dst_layer_, dst_iter_, src_iter_,
                bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::rnn_postgemm) {
    assert(!"VANILLA RNN int8 is not supported");
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::rnn_postgemm) {
    assert(!"VANILLA RNN int8 is not supported");
}

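// Backward postgemm: sum the gradients arriving from the layer and iteration
// directions, multiply by the activation derivative evaluated on the gate
// value saved in ws_gates (dG = (dH_layer + dH_iter) * func1(g, alpha)), and
// store the converted result in scratch_gates for the rest of the backward
// computation.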
template <typename T1, typename T2, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void rnn_bwd_postgemm_template(T1 func1, T2 to_src, const float *scales,
        float alpha, const rnn_utils::rnn_conf_t &rnn, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, acc_data_t *diff_dst_iter_,
        acc_data_t *diff_dst_layer_) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);
    if (scales != nullptr) alpha = scales[0];

    parallel_nd(rnn.mb, [&](dim_t i) {
        for (int j = 0; j < rnn.dhc; ++j) {
            const float dH = diff_dst_layer(i, j) + diff_dst_iter(i, j);
            const auto g = (float)ws_gates(i, 0, j);
            const float res = dH * func1(g, alpha, 0);
            src_data_t res_converted = to_src(res);
            scratch_gates(i, 0, j) = res_converted;
        }
    });
}

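// Backward drivers mirror the forward ones: the activation derivative in
// regular runs, the linear function with the test scales in test mode; the
// bf16 flavor converts the result back to bfloat16_t before storing it.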
template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::rnn_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto act_f = [this](float a, float alpha, float clipping) {
        return this->activation_func(a, alpha, 0);
    };
    const auto linear_f = [](float a, float alpha, float clipping) {
        return linear(a, alpha, 0);
    };
    const auto to_src = [&](float a) { return a; };
    const auto alpha = pd_->desc()->alpha;
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        rnn_bwd_postgemm_template(act_f, to_src, nullptr, alpha, rnn, ws_gates_,
                scratch_gates_, diff_dst_iter_, diff_dst_layer_);
    else
        rnn_bwd_postgemm_template(linear_f, to_src, scales, alpha, rnn,
                ws_gates_, scratch_gates_, diff_dst_iter_, diff_dst_layer_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::rnn_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto act_f = [this](float a, float alpha, float clipping) {
        return this->activation_func(a, alpha, 0);
    };
    const auto linear_f = [](float a, float alpha, float clipping) {
        return linear(a, alpha, 0);
    };
    const auto to_src = [&](float a) { return bfloat16_t(a); };
    const auto alpha = pd_->desc()->alpha;
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        rnn_bwd_postgemm_template(act_f, to_src, nullptr, alpha, rnn, ws_gates_,
                scratch_gates_, diff_dst_iter_, diff_dst_layer_);
    else
        rnn_bwd_postgemm_template(linear_f, to_src, scales, alpha, rnn,
                ws_gates_, scratch_gates_, diff_dst_iter_, diff_dst_layer_);
}

} // namespace cpu
} // namespace impl
} // namespace dnnl