1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/*
18 * Cell execution GRU
19 */
20
21#include "common/dnnl_thread.hpp"
22#include "common/math_utils.hpp"
23
24#include "cpu/rnn/ref_rnn.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29
30using namespace dnnl::impl::utils;
31using namespace dnnl::impl::math;
32using namespace rnn_utils;
33
// File-local shorthand for array_offset_calculator; it is #undef'd again at
// the end of this file.
#define AOC array_offset_calculator
35template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
36 data_type_t acc_type>
37rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
38 acc_type>::cell_execution_gru)) {
39 const ws_gates_aoc<gates_t> ws_gates(rnn, ws_gates_);
40 const scratch_gates_aoc<scratch_t> scratch_gates(rnn, scratch_gates_);
41 const auto weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;
42
43 const auto src_layer_ld = rnn.src_layer_ld(cell_position);
44 const auto src_iter_ld = rnn.src_iter_ld(cell_position);
45 const auto dst_iter_part2_ld = rnn.dst_iter_part2_ld(cell_position);
46
47 // 1. gemm Wx[0-2],x
48 if (rnn.need_gemm_layer(cell_position)) {
49 CHECK((this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
50 rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, src_layer_,
51 src_layer_ld, 0.0f, scratch_gates_, rnn.scratch_gates_ld));
52 }
53
54 // 2. gemm Wh[0-1],h
55 CHECK((this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dhc, rnn.mb,
56 rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, src_iter_,
57 src_iter_ld, 1.0f, scratch_gates_, rnn.scratch_gates_ld));
58
59 // 3. activation zt and rt + elemwise multiplication rt,ht-1
60 rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
61 augru_attention_, dst_layer_, nullptr, src_iter_, nullptr,
62 diff_src_layer_, diff_augru_attention_, diff_src_iter_, nullptr,
63 diff_dst_layer_, diff_dst_iter_, nullptr, nullptr, bias_[0],
64 nullptr, nullptr, dst_iter_, weights_scales, rnn.dhc);
65
66 // 4. gemm Wh[2],h~t
67 CHECK((this->*gemm_iter_func)('N', 'N', rnn.dhc, rnn.mb, rnn.sic, 1.0,
68 w_iter_[1], rnn.weights_iter_ld, dst_layer_, dst_iter_part2_ld, 1.0,
69 &(scratch_gates(0, 2, 0)), rnn.scratch_gates_ld));
70
71 // 5. activation h~t + calculate ht
72 rnn_postgemm_->execute_part2(rnn, cell_position, ws_gates_, scratch_gates_,
73 augru_attention_, dst_layer_, dst_iter_c_, src_iter_, src_iter_c_,
74 diff_src_layer_, diff_augru_attention_, diff_src_iter_, nullptr,
75 diff_dst_layer_, diff_dst_iter_, nullptr, nullptr, bias_[0],
76 nullptr, nullptr, dst_iter_, weights_scales, rnn.dhc);
77
78 return dnnl_success;
79}
80
// Explicit instantiations of the forward GRU cell for the f32, bf16, u8s8 and
// s8s8 forward reference RNN primitives.
template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
template rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_gru);
template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
template rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_gru);
85
86template <typename T1, typename T2, typename T3, typename T4, typename T5,
87 typename weights_data_t, typename src_data_t, typename acc_data_t,
88 typename scratch_data_t>
89dnnl_status_t gru_bwd_cell_exec_template(T1 gemm_layer_f, T2 gemm_iter_f,
90 T3 gemm_weights_layer_f, T4 gemm_weights_iter_f, T5 rnn_postgemm_,
91 const rnn_utils::rnn_conf_t &rnn, cell_position_t cell_position,
92 src_data_t *ws_gates_, scratch_data_t *scratch_gates_,
93 src_data_t *dst_layer_, const src_data_t *src_iter_,
94 const src_data_t *src_layer_, const src_data_t *augru_attention_,
95 weights_data_t **w_layer_, weights_data_t **w_iter_,
96 acc_data_t *diff_w_layer_, acc_data_t *diff_w_iter_,
97 acc_data_t *diff_src_layer_, acc_data_t *diff_augru_attention_,
98 acc_data_t *diff_src_iter_, acc_data_t *diff_dst_iter_,
99 acc_data_t *diff_dst_layer_, acc_data_t *diff_bias_,
100 scratch_data_t *scratch_cell_, src_data_t *dst_iter_) {
101 const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
102 const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
103
104 const auto src_layer_ld = rnn.src_layer_ld(cell_position);
105 const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
106 const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
107 const auto src_iter_ld = rnn.src_iter_ld(cell_position);
108 const ws_states_layer_aoc<src_data_t> dst_layer(rnn, dst_layer_,
109 (cell_position & last_layer) ? dst_layer_ld : dst_iter_ld);
110 const ws_states_iter_aoc<const src_data_t> src_iter(
111 rnn, src_iter_, src_iter_ld);
112 const ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_);
113
114 // use state memory for intermediate computations
115 // TODO: use cell ws for that
116 float *dhG1_ = diff_src_layer_;
117 const AOC<acc_data_t, 2> dhG1(
118 dhG1_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld);
119 // hg1 needs to be bf16 as it is used as gemm output
120 // hence it cannot alias to dhG1, and should use scratch_cell
121 const AOC<scratch_data_t, 2> hG1(
122 scratch_cell_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld);
123
124 // 1. calculate dG2, dG1, and part of dht-1
125 rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
126 augru_attention_, dst_layer_, nullptr, src_iter_, nullptr,
127 diff_src_layer_, diff_augru_attention_, diff_src_iter_, nullptr,
128 diff_dst_layer_, diff_dst_iter_, nullptr, nullptr, nullptr, nullptr,
129 scratch_cell_, dst_iter_, nullptr, 0);
130
131 // 2. calculate intermediate d(hG1)
132 // d(hG1) = dG2 * W2h^t
133 CHECK(gemm_iter_f(rnn.sic, rnn.mb, rnn.dhc, w_iter_[1],
134 &(scratch_gates(0, 2, 0)), 0.0f, dhG1_));
135
136 // 3. calculate dG1^ and part of dht-1
137 rnn_postgemm_->execute_part2(rnn, cell_position, ws_gates_, scratch_gates_,
138 augru_attention_, dst_layer_, nullptr, src_iter_, nullptr,
139 diff_src_layer_, diff_augru_attention_, diff_src_iter_, nullptr,
140 diff_dst_layer_, diff_dst_iter_, nullptr, nullptr, nullptr, nullptr,
141 scratch_cell_, dst_iter_, nullptr, 0);
142
143 // 4. calculate diff weights
144 // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h)
145 CHECK(gemm_weights_iter_f((rnn.n_gates - 1) * rnn.dhc, rnn.sic, rnn.mb,
146 scratch_gates_, src_iter_, src_iter_ld, 1.0f, diff_w_iter_));
147 CHECK(gemm_weights_iter_f(rnn.dhc, rnn.sic, rnn.mb,
148 &(scratch_gates(0, 2, 0)), scratch_cell_, rnn.ws_states_layer_ld,
149 1.0f, &(diff_w_iter(0, 2, 0))));
150
151 // 5. calculate diff states
152 // dht-1 += dG1 * W1h + dG0 * W0h
153 CHECK(gemm_iter_f(rnn.sic, rnn.mb, (rnn.n_gates - 1) * rnn.dhc, w_iter_[0],
154 scratch_gates_, 1.0f, diff_src_iter_));
155
156 // dWx += [dG0 dG1 dG2] * [x]
157 if (rnn.need_gemm_layer(cell_position))
158 CHECK(gemm_weights_layer_f(
159 scratch_gates_, src_layer_, src_layer_ld, diff_w_layer_));
160
161 // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x
162 if (!rnn.merge_gemm_layer)
163 CHECK(gemm_layer_f(w_layer_[0], scratch_gates_, diff_src_layer_));
164
165 // 6. calculate diff bias
166 gates_reduction(rnn, scratch_gates_, diff_bias_);
167
168 return dnnl_success;
169}
170
171template <>
172rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) {
173 auto gemm_iter_f
174 = [&](int m, int n, int k, const weights_t *A, const gemm_data_t *B,
175 float beta, gemm_acc_t *C) {
176 return (this->*gemm_iter_func)('N', 'N', m, n, k, 1.0f, A,
177 rnn.weights_iter_ld, B, rnn.scratch_gates_ld, beta, C,
178 rnn.ws_diff_states_iter_ld);
179 };
180 auto gemm_layer_f = [&](const weights_t *A, const gemm_data_t *B,
181 gemm_acc_t *C) {
182 return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
183 rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_layer_ld, B,
184 rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_layer_ld);
185 };
186 auto gemm_weights_layer_f = [&](const gemm_data_t *A, const weights_t *B,
187 int ldb, gemm_acc_t *C) {
188 return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb, 1.0, A,
189 rnn.scratch_gates_ld, B, ldb, 1.0, C,
190 rnn.diff_weights_layer_ld);
191 };
192 auto gemm_weights_iter_f
193 = [&](int m, int n, int k, const weights_t *A, const gemm_data_t *B,
194 int ldb, float beta, gemm_acc_t *C) {
195 return gemm('N', 'T', m, n, k, 1.0f, A, rnn.ws_gates_ld, B,
196 ldb, 1.0f, C, rnn.diff_weights_iter_ld);
197 };
198
199 return gru_bwd_cell_exec_template(gemm_layer_f, gemm_iter_f,
200 gemm_weights_layer_f, gemm_weights_iter_f, this->rnn_postgemm_, rnn,
201 cell_position, ws_gates_, scratch_gates_, dst_layer_, src_iter_,
202 src_layer_, augru_attention_, w_layer_, w_iter_, diff_w_layer_,
203 diff_w_iter_, diff_src_layer_, diff_augru_attention_,
204 diff_src_iter_, diff_dst_iter_, diff_dst_layer_, diff_bias_,
205 scratch_cell_, dst_iter_);
206}
207
208template <>
209rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_gru) {
210 auto gemm_iter_f
211 = [&](int m, int n, int k, const weights_t *A, const gemm_data_t *B,
212 float beta, gemm_acc_t *C) {
213 return (this->*gemm_iter_func)('N', 'N', m, n, k, 1.0f, A,
214 rnn.weights_iter_ld, B, rnn.scratch_gates_ld, beta, C,
215 rnn.ws_diff_states_iter_ld);
216 };
217 auto gemm_layer_f = [&](const weights_t *A, const gemm_data_t *B,
218 gemm_acc_t *C) {
219 return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
220 rnn.n_gates * rnn.dhc, 1.0, A, rnn.weights_layer_ld, B,
221 rnn.scratch_gates_ld, 0.0, C, rnn.ws_diff_states_layer_ld);
222 };
223 auto gemm_weights_layer_f = [&](const gemm_data_t *A, const weights_t *B,
224 int ldb, gemm_acc_t *C) {
225 return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb, 1.0, A,
226 rnn.scratch_gates_ld, B, ldb, 1.0, C,
227 rnn.diff_weights_layer_ld);
228 };
229 auto gemm_weights_iter_f
230 = [&](int m, int n, int k, const weights_t *A, const gemm_data_t *B,
231 int ldb, float beta, gemm_acc_t *C) {
232 return gemm('N', 'T', m, n, k, 1.0f, A, rnn.ws_gates_ld, B,
233 ldb, 1.0f, C, rnn.diff_weights_iter_ld);
234 };
235
236 return gru_bwd_cell_exec_template(gemm_layer_f, gemm_iter_f,
237 gemm_weights_layer_f, gemm_weights_iter_f, this->rnn_postgemm_, rnn,
238 cell_position, ws_gates_, scratch_gates_, dst_layer_, src_iter_,
239 src_layer_, augru_attention_, w_layer_, w_iter_, diff_w_layer_,
240 diff_w_iter_, diff_src_layer_, diff_augru_attention_,
241 diff_src_iter_, diff_dst_iter_, diff_dst_layer_, diff_bias_,
242 scratch_cell_, dst_iter_);
243}
244
// Undefine the file-local AOC shorthand defined near the top of this file.
#undef AOC
246} // namespace cpu
247} // namespace impl
248} // namespace dnnl
249