/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <stdlib.h>

#include "utils/parallel.hpp"

#include "rnn/rnn.hpp"
#include "rnn/rnn_aux.hpp"

#include "rnn/cells.hpp"

namespace rnn {
template <typename T>
void gru_fwd_postgemm_part1_template(T func1, const prb_t &prb, float *gates_,
        const float *src_iter_, const float *bias_, float *dst_layer_) {
    AOC<const float> bias(bias_, prb.n_gates(), prb.dhc);
    AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
    AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc);
    AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);

    for (int64_t i = 0; i < prb.mb; i++)
        for (int64_t k = 0; k < prb.dhc; k++) {
            gates(i, GRU_U, k) = func1(prb.linear_scales[GRU_U],
                    maybe_deq(prb, gates(i, GRU_U, k), GRU_U * prb.dhc + k)
                            + bias(GRU_U, k));
            gates(i, GRU_R, k) = func1(prb.linear_scales[GRU_R],
                    maybe_deq(prb, gates(i, GRU_R, k), GRU_R * prb.dhc + k)
                            + bias(GRU_R, k));
            dst_layer(i, k) = maybe_q(
                    prb, (maybe_deq(prb, src_iter(i, k)) * gates(i, GRU_R, k)));
        }
}

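// Dispatch on prb.skip_nonlinear: when set, the nonlinearity is replaced
// with multiplication by a per-gate linear scale; otherwise logistic is
// applied and the scale argument is ignored.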
void gru_fwd_postgemm_part1(const prb_t &prb, float *gates_,
        const float *src_iter_, const float *bias_, float *dst_layer_) {
    if (prb.skip_nonlinear)
        gru_fwd_postgemm_part1_template(
                [](float scale, float a) { return scale * a; }, prb, gates_,
                src_iter_, bias_, dst_layer_);
    else
        gru_fwd_postgemm_part1_template(
                [](float scale, float a) { return logistic(a); }, prb, gates_,
                src_iter_, bias_, dst_layer_);
}

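// Part 2 of the forward postgemm: finish the o gate with tanh (or the
// linear test scale), optionally apply the AUGRU attention scaling to the
// update gate, and blend dst_layer = u * h + (1 - u) * o.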
template <typename T>
void gru_fwd_postgemm_part2_template(T func1, const prb_t &prb, float *gates_,
        const float *src_iter_, const float *bias_,
        const float *src_layer_attention_, float *dst_layer_) {
    AOC<const float> bias(bias_, prb.n_gates(), prb.dhc);
    AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
    AOC<const float> src_layer_attention(src_layer_attention_, prb.mb);
    AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc);
    AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
    for (int64_t i = 0; i < prb.mb; i++)
        for (int64_t k = 0; k < prb.dhc; k++) {
            double U = gates(i, GRU_U, k);
            double O = func1(prb.linear_scales[GRU_O],
                    maybe_deq(prb, gates(i, GRU_O, k), GRU_O * prb.dhc + k)
                            + bias(GRU_O, k));
            if (prb.alg == VANILLA_AUGRU) {
                const double A = src_layer_attention(i);
                U = (1 - A) * U;
            }
            dst_layer(i, k) = maybe_q(prb,
                    (float)(U * maybe_deq(prb, src_iter(i, k))
                            + (1.0 - U) * O));

            gates(i, GRU_O, k) = O;
        }
}

void gru_fwd_postgemm_part2(const prb_t &prb, float *gates_,
        const float *src_iter_, const float *bias_,
        const float *src_layer_attention_, float *dst_layer_) {
    if (prb.skip_nonlinear)
        gru_fwd_postgemm_part2_template(
                [](float scale, float a) { return scale * a; }, prb, gates_,
                src_iter_, bias_, src_layer_attention_, dst_layer_);
    else
        gru_fwd_postgemm_part2_template(
                [](float scale, float a) { return tanhf(a); }, prb, gates_,
                src_iter_, bias_, src_layer_attention_, dst_layer_);
}

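// Forward GRU cell: three GEMMs interleaved with the two postgemm steps
// above.
//   1) gates  = src_layer x weights_layer              (all gates)
//   2) gates += src_iter x weights_iter                (u and r gates only)
//   3) part1: logistic on u and r; dst_layer <- r * h  (used as scratch)
//   4) gates(o) += (r * h) x weights_iter(o)
//   5) part2: tanh on o; dst_layer <- u * h + (1 - u) * o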
void gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_,
        const float *weights_layer_, const float *weights_iter_,
        const float *bias_, const float *src_layer_,
        const float *src_layer_attention_, const float *src_iter_) {
    AOC<const float> weights_iter(
            weights_iter_, prb.sic, prb.n_gates(), prb.dhc);
    AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);

    gemm("C", "N", "N", prb.mb, prb.n_gates() * prb.dhc, prb.slc, 1.0,
            src_layer_, prb.wc, weights_layer_, prb.n_gates() * prb.dhc, 0.0,
            gates_, prb.n_gates() * prb.dhc);
    gemm("C", "N", "N", prb.mb, (prb.n_gates() - 1) * prb.dhc, prb.sic, 1.0,
            src_iter_, prb.wc, weights_iter_, prb.n_gates() * prb.dhc, 1.0,
            gates_, prb.n_gates() * prb.dhc);

    gru_fwd_postgemm_part1(prb, gates_, src_iter_, bias_, dst_layer_);

    gemm("C", "N", "N", prb.mb, prb.dhc, prb.sic, 1.0, dst_layer_, prb.wc,
            &(weights_iter(0, GRU_O, 0)), prb.n_gates() * prb.dhc, 1.0,
            &(gates(0, GRU_O, 0)), prb.n_gates() * prb.dhc);

    gru_fwd_postgemm_part2(
            prb, gates_, src_iter_, bias_, src_layer_attention_, dst_layer_);
}

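// Backward, part 1 (before any GEMM): with dh = diff_dst_layer +
// diff_dst_iter and h_new = u * h + (1 - u) * o, form the gate gradients.
// x_m_square(x) = x - x^2 is the logistic derivative and one_m_square(x)
// = 1 - x^2 the tanh derivative, both expressed in the gate values. For
// AUGRU, u was scaled by (1 - a) in the forward pass, so the attention
// gradient accumulates -du * u and du itself is rescaled by (1 - a).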
void gru_bwd_pregemm_part1(const prb_t &prb, const float *src_iter_,
        const float *src_layer_attention_, const float *diff_dst_layer_,
        const float *diff_dst_iter_, const float *gates_, float *diff_src_iter_,
        float *diff_src_layer_attention_, float *b_gates_) {
    AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
    AOC<const float> src_layer_attention(
            src_layer_attention_, prb.n_iter, prb.mb);
    AOC<const float> diff_dst_layer(diff_dst_layer_, prb.mb, prb.wc);
    AOC<const float> diff_dst_iter(diff_dst_iter_, prb.mb, prb.wc);
    AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);

    AOC<float> diff_src_iter(diff_src_iter_, prb.mb, prb.wc);
    AOC<float> diff_src_layer_attention(diff_src_layer_attention_, prb.mb);
    AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);

    // do = (1 - u) * dh; do^ = one_m_square(o) * do;
    // du = (h - o) * dh; du^ = x_m_square(u) * du;
    for (int64_t ib = 0; ib < prb.mb; ib++) {
        if (prb.alg == VANILLA_AUGRU) diff_src_layer_attention(ib) = 0.0f;
        for (int64_t ih = 0; ih < prb.dhc; ih++) {
            float h = src_iter(ib, ih);
            float o = gates(ib, GRU_O, ih);
            float u = gates(ib, GRU_U, ih);
            float dh = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih);
            float du = (h - o) * dh * x_m_square(u);
            float dO = (1.0f - u) * dh * one_m_square(o);
            if (prb.alg == VANILLA_AUGRU) {
                diff_src_layer_attention(ib) -= du * u;
                du *= 1 - src_layer_attention(ib);
            }
            b_gates(ib, GRU_U, ih) = du;
            b_gates(ib, GRU_O, ih) = dO;
            diff_src_iter(ib, ih) = dh * u;
        }
    }
}

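// Backward, part 2 (after dhr = do^ Who^T has been computed by the caller):
// propagate through the r * h product that fed the o-gate GEMM and stage
// hr = h * r for the diff_weights_iter(o) GEMM.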
void gru_bwd_pregemm_part2(const prb_t &prb, const float *src_iter_,
        const float *gates_, const float *dhr_, float *diff_src_iter_,
        float *b_gates_, float *hr_) {
    AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
    AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
    AOC<const float> dhr(dhr_, prb.mb, prb.dhc);
    AOC<float> diff_src_iter(diff_src_iter_, prb.mb, prb.wc);
    AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);
    AOC<float> hr(hr_, prb.mb, prb.dhc);

    // dhr = Wo do^;
    // dr = h * dhr; dr^ = x_m_square(r) * dr;
    for (int64_t ib = 0; ib < prb.mb; ib++)
        for (int64_t ih = 0; ih < prb.dhc; ih++) {
            float h = src_iter(ib, ih);
            float r = gates(ib, GRU_R, ih);
            float dr = h * dhr(ib, ih);
            hr(ib, ih) = h * r;
            diff_src_iter(ib, ih) += dhr(ib, ih) * r;
            b_gates(ib, GRU_R, ih) = x_m_square(r) * dr;
        }
}

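// Backward GRU cell. cell_scratchpad_ is assumed to hold two back-to-back
// mb x dhc buffers: dhr_ (gradient w.r.t. the r * h product) and hr_ (the
// recomputed r * h product). dhc == sic is required because dhr is
// produced as an mb x sic matrix with leading dimension dhc.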
void gru_bwd(const prb_t &prb, float *diff_src_layer_,
        float *diff_src_layer_attention_, float *diff_src_iter_,
        float *diff_weights_layer_, float *diff_weights_iter_,
        float *diff_bias_, float *b_gates_, const float *src_layer_,
        const float *src_layer_attention_, const float *src_iter_,
        const float *weights_layer_, const float *weights_iter_,
        const float *bias_, const float *gates_, const float *diff_dst_layer_,
        const float *diff_dst_iter_, float *cell_scratchpad_) {
    AOC<const float> weights_iter(
            weights_iter_, prb.sic, prb.n_gates(), prb.dhc);

    AOC<float> diff_weights_iter(
            diff_weights_iter_, prb.sic, prb.n_gates(), prb.dhc);
    AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);

    assert(prb.dhc == prb.sic);
    float *dhr_ = cell_scratchpad_;
    float *hr_ = cell_scratchpad_ + prb.mb * prb.dhc;

    gru_bwd_pregemm_part1(prb, src_iter_, src_layer_attention_, diff_dst_layer_,
            diff_dst_iter_, gates_, diff_src_iter_, diff_src_layer_attention_,
            b_gates_);

    gemm("C", "N", "T", prb.mb, prb.sic, prb.dhc, 1.0, &(b_gates(0, GRU_O, 0)),
            prb.n_gates() * prb.dhc, &(weights_iter(0, GRU_O, 0)),
            prb.n_gates() * prb.dhc, 0.0, dhr_, prb.dhc);

    gru_bwd_pregemm_part2(
            prb, src_iter_, gates_, dhr_, diff_src_iter_, b_gates_, hr_);

    // dWx += xdu^ | xdr^ | xdo^
    // dWh += hdu^ | hdr^ | (h * r)do^
    gemm("C", "T", "N", prb.sic, (prb.n_gates() - 1) * prb.dhc, prb.mb, 1.0,
            src_iter_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
            diff_weights_iter_, prb.n_gates() * prb.dhc);
    gemm("C", "T", "N", prb.sic, prb.dhc, prb.mb, 1.0, hr_, prb.dhc,
            &(b_gates(0, GRU_O, 0)), prb.n_gates() * prb.dhc, 1.0,
            &(diff_weights_iter(0, GRU_O, 0)), prb.n_gates() * prb.dhc);
    gemm("C", "T", "N", prb.slc, prb.n_gates() * prb.dhc, prb.mb, 1.0,
            src_layer_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
            diff_weights_layer_, prb.n_gates() * prb.dhc);

    // dx_next = Wxudu^ + Wxrdr^ + Wxodo^
    // dh_next = dh * u + Whudu^ + Whrdr^ + r * Whodo^
    gemm("C", "N", "T", prb.mb, prb.sic, (prb.n_gates() - 1) * prb.dhc, 1.0,
            b_gates_, prb.n_gates() * prb.dhc, weights_iter_,
            prb.n_gates() * prb.dhc, 1.0, diff_src_iter_, prb.wc);
    gemm("C", "N", "T", prb.mb, prb.slc, prb.n_gates() * prb.dhc, 1.0, b_gates_,
            prb.n_gates() * prb.dhc, weights_layer_, prb.n_gates() * prb.dhc,
            0.0, diff_src_layer_, prb.wc);

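    // diff_bias is the reduction of b_gates over the minibatch.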
    gates_reduction(prb, b_gates_, diff_bias_);
}

} // namespace rnn