/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Cell execution GRU with linear before reset
 */
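
/*
 * Elementwise (post-GEMM) kernels for the linear-before-reset GRU cell
 * (lbr_gru) and its attention-gated AUGRU variant (the rnn.is_augru
 * branches below): the forward pass applies the gate nonlinearities and
 * computes the new hidden state; the backward pass computes the gate
 * gradients.
 */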

#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"

#include "cpu/rnn/postgemm_dispatcher.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::math;
using namespace rnn_utils;
#define AOC array_offset_calculator
template <typename T1, typename T2, typename T3, typename src_data_t,
        typename scratch_data_t>
void gru_lbr_fwd_postgemm_template(T1 func1, T2 func2, T3 to_src,
        const float *scales, const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, const src_data_t *augru_attention_,
        src_data_t *dst_layer_, src_data_t *dst_iter_,
        const src_data_t *src_iter_, const void *bias_, src_data_t *ws_grid_,
        scratch_data_t *scratch_cell_) {

    const auto src_iter_ld = rnn.src_iter_ld(cell_position);
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);

    const augru_attention_aoc<const src_data_t> augru_attention(
            rnn, augru_attention_);
    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);
    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);

    // Bias is stored in its runtime data type; convert each element to
    // float on access.
    const auto bias = [&](int gate_id, int dhc_id) {
        return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };
    const ws_gates_aoc<scratch_data_t> scratch_cell(rnn, scratch_cell_);
    const AOC<src_data_t, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dhc);

    const auto get_scales = [](const float *scales, int idx) {
        return scales ? scales + idx : nullptr;
    };
    const float *scales_G1 = get_scales(scales, 1);
    const float *scales_G2 = get_scales(scales, 2);

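    // Linear-before-reset GRU (oneDNN lbr_gru formulation), with G0 the
    // update gate, G1 the reset gate, and G2 the candidate state. Here
    // scratch_gates holds the input GEMM results (W * x), scratch_cell the
    // recurrent GEMM results (U * h), and bias(3) the extra LBR bias b_c':
    //     Wh_b = U_c * h + b_c'
    //     G0 = sigmoid(W_u * x + U_u * h + b_u)
    //     G1 = sigmoid(W_r * x + U_r * h + b_r)
    //     G2 = tanh(W_c * x + G1 * Wh_b + b_c)
    //     h_t = G0 * h + (1 - G0) * G2
    // For AUGRU, the update gate is additionally scaled by the attention
    // value a: G0 = (1 - a) * G0.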
    parallel_nd(rnn.mb, [&](dim_t i) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < rnn.dhc; j++) {
            const float Wh_b = scratch_cell(i, 2, j) + bias(3, j);
            auto G0 = func1(scales, // default func1 is sigmoid
                    scratch_gates(i, 0, j) + scratch_cell(i, 0, j)
                            + bias(0, j));
            const auto G1 = func1(scales_G1, // default func1 is sigmoid
                    scratch_gates(i, 1, j) + scratch_cell(i, 1, j)
                            + bias(1, j));
            const auto G2 = func2(scales_G2, // default func2 is tanh
                    scratch_gates(i, 2, j) + G1 * Wh_b + bias(2, j));
            if (rnn.is_training) {
                ws_gates(i, 0, j) = to_src(G0);
                ws_gates(i, 1, j) = to_src(G1);
                ws_gates(i, 2, j) = to_src(G2);
                ws_Wh_b(i, j) = to_src(Wh_b);
            }
            if (rnn.is_augru) {
                const auto a = to_src(augru_attention(i));
                G0 = (1.0f - a) * G0;
            }
            const auto tmp = to_src(src_iter(i, j) * G0 + (1.0f - G0) * G2);
            if (dst_layer_ != nullptr) dst_layer(i, j) = tmp;
            if (dst_iter_ != nullptr) dst_iter(i, j) = tmp;
        }
    });
}

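// The activations are passed in as callables so the two dispatch paths
// below can swap them: in the regular path func1/func2 are sigmoid and
// tanh, while in test mode (rnn_tparams_.test_mode_, a testing-only
// attribute) each gate applies a per-gate scaled identity instead, which
// keeps the elementwise part linear when validating the GEMM computations.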
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_lbr_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };
    const auto to_src = [](float a) { return a; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_lbr_fwd_postgemm_template(logistic_f, tanh_f, to_src, scales, rnn,
                cell_position, ws_gates_, scratch_gates_, augru_attention_,
                dst_layer_, dst_iter_, src_iter_, bias_, ws_grid_,
                scratch_cell_);
    else
        gru_lbr_fwd_postgemm_template(linear_f, linear_f, to_src, scales, rnn,
                cell_position, ws_gates_, scratch_gates_, augru_attention_,
                dst_layer_, dst_iter_, src_iter_, bias_, ws_grid_,
                scratch_cell_);
}
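
// Same as the f32 specialization, except that results are computed in f32
// and rounded to bfloat16 on store via to_src.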
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_lbr_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;

    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };
    const auto to_src = [](float a) { return bfloat16_t(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_lbr_fwd_postgemm_template(logistic_f, tanh_f, to_src, scales, rnn,
                cell_position, ws_gates_, scratch_gates_, augru_attention_,
                dst_layer_, dst_iter_, src_iter_, bias_, ws_grid_,
                scratch_cell_);
    else
        gru_lbr_fwd_postgemm_template(linear_f, linear_f, to_src, scales, rnn,
                cell_position, ws_gates_, scratch_gates_, augru_attention_,
                dst_layer_, dst_iter_, src_iter_, bias_, ws_grid_,
                scratch_cell_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_lbr_postgemm) {
    assert(!"GRU LBR int8 is not supported");
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_lbr_postgemm) {
    assert(!"GRU LBR signed int8 is not supported");
}

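// Backward elementwise pass. It consumes the gate values (ws_gates) and the
// linear term Wh_b (ws_grid) saved by the forward pass in training mode and
// produces the gate gradients consumed by the subsequent GEMMs.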
template <typename T1, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void gru_lbr_bwd_postgemm_template(T1 to_src, const rnn_utils::rnn_conf_t &rnn,
        cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, const src_data_t *augru_attention_,
        const src_data_t *src_iter_, acc_data_t *diff_src_iter_,
        acc_data_t *diff_dst_iter_, acc_data_t *diff_augru_attention_,
        acc_data_t *diff_dst_layer_, scratch_data_t *scratch_cell_,
        src_data_t *ws_grid_) {
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

    const augru_attention_aoc<const src_data_t> augru_attention(
            rnn, augru_attention_);
    const augru_attention_aoc<acc_data_t> diff_augru_attention(
            rnn, diff_augru_attention_);

    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter(
            rnn, diff_src_iter_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);
    const ws_gates_aoc<scratch_data_t> scratch_gates_r(rnn, scratch_cell_);
    const AOC<src_data_t, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dhc);

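    // Two copies of the gate gradients are written: scratch_gates feeds the
    // follow-up GEMMs directly, while scratch_gates_r is identical except
    // that gate 2 is pre-multiplied by the reset gate G1, for use where the
    // gradient flows back through Wh_b (the recurrent contribution to the
    // candidate passes through the reset gate in the LBR formulation).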
    // Compute the gate gradients dG0, dG1, dG2 (with h = h_{t-1} and
    // dHt = dh_t gathered from the iter and layer gradients):
    //     dG0 = (h - G2) * dHt * (1 - G0) * G0
    //     dG2 = (1 - G0) * dHt * (1 - G2^2)
    //     dG1 = Wh_b * dG2 * (1 - G1) * G1
    // The direct part of the state gradient is dh_{t-1} = dHt * G0; the
    // remaining contributions come back through the GEMMs on the scratch
    // gates.
    parallel_nd(rnn.mb, [&](dim_t i) {
        acc_data_t diff_attention = 0.0f;
        PRAGMA_OMP_SIMD(reduction(+ : diff_attention))
        for (int j = 0; j < rnn.dhc; j++) {
            const float h = src_iter(i, j);
            const float dHt = diff_dst_iter(i, j) + diff_dst_layer(i, j);
            float dG0 = (h - ws_gates(i, 2, j)) * dHt
                    * x_m_square(ws_gates(i, 0, j));
            const float dG2 = (1.0f - ws_gates(i, 0, j))
                    * one_m_square(ws_gates(i, 2, j)) * dHt;
            const float dG1
                    = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j));

            if (rnn.is_augru) {
                diff_attention -= dG0 * ws_gates(i, 0, j);
                dG0 *= 1.0f - augru_attention(i);
            }

            diff_src_iter(i, j) = dHt * ws_gates(i, 0, j);
            scratch_gates(i, 2, j) = to_src(dG2);
            scratch_gates_r(i, 2, j) = to_src(dG2 * ws_gates(i, 1, j));
            scratch_gates(i, 0, j) = scratch_gates_r(i, 0, j) = to_src(dG0);
            scratch_gates(i, 1, j) = scratch_gates_r(i, 1, j) = to_src(dG1);
        }
        if (rnn.is_augru) diff_augru_attention(i) = diff_attention;
    });
}

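// Backward specializations: gradients are accumulated in f32 (acc_data_t)
// and converted to the source data type by to_src when stored to the
// scratch gates.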
template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_lbr_postgemm) {
    auto to_src = [&](float a) { return a; };
    gru_lbr_bwd_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, augru_attention_, src_iter_, diff_src_iter_,
            diff_dst_iter_, diff_augru_attention_, diff_dst_layer_,
            scratch_cell_, ws_grid_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_lbr_postgemm) {
    auto to_src = [&](float a) { return bfloat16_t(a); };
    gru_lbr_bwd_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, augru_attention_, src_iter_, diff_src_iter_,
            diff_dst_iter_, diff_augru_attention_, diff_dst_layer_,
            scratch_cell_, ws_grid_);
}

} // namespace cpu
} // namespace impl
} // namespace dnnl