1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/*
18 * Cell execution GRU with linear before reset
19 */
20#pragma warning(disable : 4503) /* name is too long */
21
22#include "common/dnnl_thread.hpp"
23#include "common/math_utils.hpp"
24
25#include "cpu/rnn/ref_rnn.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30
31using namespace dnnl::impl::utils;
32using namespace dnnl::impl::math;
33using namespace rnn_utils;
34#define AOC array_offset_calculator
35
36template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
37 data_type_t acc_type>
38rnn_cell_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
39 acc_type>::cell_execution_gru_lbr)) {
40 const auto src_layer_ld = rnn.src_layer_ld(cell_position);
41 const auto src_iter_ld = rnn.src_iter_ld(cell_position);
42
43 if (rnn.need_gemm_layer(cell_position)) {
44 CHECK((this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
45 rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld, src_layer_,
46 src_layer_ld, 0.0, scratch_gates_, rnn.scratch_gates_ld));
47 }
48 CHECK((this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dhc, rnn.mb,
49 rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, src_iter_,
50 src_iter_ld, 0.0, scratch_cell_, rnn.ws_gates_ld));
51
52 rnn_postgemm_->execute(rnn, cell_position, ws_gates_, scratch_gates_,
53 augru_attention_, dst_layer_, dst_iter_c_, src_iter_, src_iter_c_,
54 diff_src_layer_, diff_augru_attention_, diff_src_iter_,
55 diff_src_iter_c_, diff_dst_layer_, diff_dst_iter_, nullptr, nullptr,
56 bias_[0], ws_grid_, scratch_cell_, dst_iter_, nullptr, 0);
57
58 return dnnl_success;
59}
60
61template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
62template rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_gru_lbr);
63template <>
64rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) {
65 assert(!"GRU LBR int8 is not supported");
66 return dnnl_unimplemented;
67}
68
69template <>
70rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_gru_lbr) {
71 assert(!"GRU LBR int8 is not supported");
72 return dnnl_unimplemented;
73}
74
75template <typename T1, typename T2, typename T3, typename T4, typename T5,
76 typename weights_data_t, typename src_data_t, typename acc_data_t,
77 typename scratch_data_t>
78dnnl_status_t common_bwd_cell_exec_template(T1 gemm_layer_f, T2 gemm_iter_f,
79 T3 gemm_weights_layer_f, T4 gemm_weights_iter_f, T5 rnn_postgemm,
80 const rnn_utils::rnn_conf_t &rnn, cell_position_t cell_position,
81 src_data_t *dst_layer_, acc_data_t *diff_src_layer_,
82 acc_data_t *diff_augru_attention_, acc_data_t *diff_src_iter_,
83 weights_data_t **w_layer_, weights_data_t **w_iter_, void **bias_,
84 const src_data_t *src_layer_, const src_data_t *augru_attention_,
85 const src_data_t *src_iter_, acc_data_t *diff_dst_layer_,
86 acc_data_t *diff_dst_iter_, acc_data_t *diff_w_layer_,
87 acc_data_t *diff_w_iter_, acc_data_t *diff_bias_, src_data_t *ws_gates_,
88 src_data_t *ws_grid_, scratch_data_t *scratch_gates_,
89 scratch_data_t *scratch_cell_, src_data_t *dst_iter_) {
90 const auto src_layer_ld = rnn.src_layer_ld(cell_position);
91 const auto src_iter_ld = rnn.src_iter_ld(cell_position);
92
93 const ws_gates_aoc<scratch_data_t> scratch_gates_r(rnn, scratch_cell_);
94
95 rnn_postgemm->execute(rnn, cell_position, ws_gates_, scratch_gates_,
96 augru_attention_, dst_layer_, nullptr, src_iter_, nullptr, nullptr,
97 diff_augru_attention_, diff_src_iter_, nullptr, diff_dst_layer_,
98 diff_dst_iter_, nullptr, nullptr, bias_[0], ws_grid_, scratch_cell_,
99 dst_iter_, nullptr, 0);
100
101 // dWx += dG^t * x
102 if (rnn.need_gemm_layer(cell_position))
103 CHECK(gemm_weights_layer_f(
104 scratch_gates_, src_layer_, src_layer_ld, diff_w_layer_));
105
106 // dx = dG * Wx^t
107 if (!rnn.merge_gemm_layer)
108 CHECK(gemm_layer_f(w_layer_[0], scratch_gates_, diff_src_layer_));
109
110 // dh += dGr * Wh^t
111 CHECK(gemm_iter_f(w_iter_[0], scratch_cell_, diff_src_iter_));
112
113 // dWh += dGr^t * h
114 CHECK(gemm_weights_iter_f(
115 scratch_cell_, src_iter_, src_iter_ld, diff_w_iter_));
116
117 // db1-3 += e * dG
118 // db4 += e * (r * dG2)
119 gates_reduction(rnn, scratch_gates_, diff_bias_);
120
121 parallel_nd(rnn.dhc, [&](dim_t j) {
122 for (int i = 0; i < rnn.mb; i++) {
123 diff_bias_[3 * rnn.dhc + j] += scratch_gates_r(i, 2, j);
124 }
125 });
126
127 return dnnl_success;
128}
129
130#undef AOC
131
132template <>
133rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) {
134 const auto gemm_layer = [&](const float *A, const float *B, float *C) {
135 return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
136 rnn.n_gates * rnn.dhc, 1.0f, A, rnn.weights_layer_ld, B,
137 rnn.scratch_gates_ld, 0.0f, C, rnn.ws_diff_states_layer_ld);
138 };
139 const auto gemm_iter = [&](const float *A, const float *B, float *C) {
140 return (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
141 rnn.n_gates * rnn.dhc, 1.0f, A, rnn.weights_iter_ld, B,
142 rnn.ws_gates_ld, 1.0f, C, rnn.ws_diff_states_iter_ld);
143 };
144 const auto gemm_weights_layer
145 = [&](const float *A, const float *B, int ldb, float *C) {
146 return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb,
147 1.0f, A, rnn.scratch_gates_ld, B, ldb, 1.0f, C,
148 rnn.diff_weights_layer_ld);
149 };
150 const auto gemm_weights_iter = [&](const float *A, const float *B, int ldb,
151 float *C) {
152 return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, rnn.mb, 1.0f, A,
153 rnn.ws_gates_ld, B, ldb, 1.0f, C, rnn.diff_weights_iter_ld);
154 };
155
156 common_bwd_cell_exec_template(gemm_layer, gemm_iter, gemm_weights_layer,
157 gemm_weights_iter, rnn_postgemm_, rnn, cell_position, dst_layer_,
158 diff_src_layer_, diff_augru_attention_, diff_src_iter_, w_layer_,
159 w_iter_, bias_, src_layer_, augru_attention_, src_iter_,
160 diff_dst_layer_, diff_dst_iter_, diff_w_layer_, diff_w_iter_,
161 diff_bias_, ws_gates_, ws_grid_, scratch_gates_, scratch_cell_,
162 dst_iter_);
163
164 return dnnl_success;
165}
166
167template <>
168rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_gru_lbr) {
169 const auto gemm_layer = [&](const bfloat16_t *A, const bfloat16_t *B,
170 float *C) {
171 return (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
172 rnn.n_gates * rnn.dhc, 1.0f, A, rnn.weights_layer_ld, B,
173 rnn.scratch_gates_ld, 0.0f, C, rnn.ws_diff_states_layer_ld);
174 };
175 const auto gemm_iter = [&](const bfloat16_t *A, const bfloat16_t *B,
176 float *C) {
177 return (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
178 rnn.n_gates * rnn.dhc, 1.0f, A, rnn.weights_iter_ld, B,
179 rnn.ws_gates_ld, 1.0f, C, rnn.ws_diff_states_iter_ld);
180 };
181 const auto gemm_weights_layer
182 = [&](const bfloat16_t *A, const bfloat16_t *B, int ldb, float *C) {
183 return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.slc, rnn.mb,
184 1.0f, A, rnn.scratch_gates_ld, B, ldb, 1.0f, C,
185 rnn.diff_weights_layer_ld);
186 };
187 const auto gemm_weights_iter = [&](const bfloat16_t *A, const bfloat16_t *B,
188 int ldb, float *C) {
189 return gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, rnn.mb, 1.0f, A,
190 rnn.ws_gates_ld, B, ldb, 1.0f, C, rnn.diff_weights_iter_ld);
191 };
192
193 common_bwd_cell_exec_template(gemm_layer, gemm_iter, gemm_weights_layer,
194 gemm_weights_iter, rnn_postgemm_, rnn, cell_position, dst_layer_,
195 diff_src_layer_, diff_augru_attention_, diff_src_iter_, w_layer_,
196 w_iter_, bias_, src_layer_, augru_attention_, src_iter_,
197 diff_dst_layer_, diff_dst_iter_, diff_w_layer_, diff_w_iter_,
198 diff_bias_, ws_gates_, ws_grid_, scratch_gates_, scratch_cell_,
199 dst_iter_);
200 return dnnl_success;
201}
202
203} // namespace cpu
204} // namespace impl
205} // namespace dnnl
206