1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <stdlib.h>
18
19#include "utils/parallel.hpp"
20
21#include "rnn/rnn.hpp"
22#include "rnn/rnn_aux.hpp"
23
24#include "rnn/cells.hpp"
25
26namespace rnn {
27template <typename T1, typename T2>
28void lbr_gru_fwd_postgemm_template(T1 func1, T2 func2, const prb_t &prb,
29 float *gates_, const float *src_iter_, const float *bias_,
30 const float *src_layer_attention_, float *dst_layer_,
31 float *cell_scratchpad_) {
32 AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
33 AOC<const float> bias(bias_, prb.n_gates() + 1, prb.dhc);
34 AOC<const float> src_layer_attention(src_layer_attention_, prb.mb);
35 AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
36 AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc);
37 AOC<float> cell_scratchpad(
38 cell_scratchpad_, prb.mb, prb.n_gates(), prb.dhc);
39
40 for (int64_t i = 0; i < prb.mb; i++)
41 for (int64_t j = 0; j < prb.n_gates() - 1; j++)
42 for (int64_t k = 0; k < prb.dhc; k++) {
43 gates(i, j, k) = func1(prb.linear_scales[j],
44 gates(i, j, k) + cell_scratchpad(i, j, k) + bias(j, k));
45 }
46
47 for (int64_t i = 0; i < prb.mb; i++)
48 for (int64_t k = 0; k < prb.dhc; k++) {
49 gates(i, GRU_O, k) = func2(prb.linear_scales[GRU_O],
50 gates(i, GRU_O, k)
51 + gates(i, GRU_R, k)
52 * (cell_scratchpad(i, GRU_O, k)
53 + bias(LBR_GRU_U_PRIME, k))
54 + bias(GRU_O, k));
55 }
56
57 for (int64_t i = 0; i < prb.mb; i++)
58 for (int64_t k = 0; k < prb.dhc; k++) {
59 double U = gates(i, GRU_U, k);
60 if (prb.alg == LBR_AUGRU) {
61 const double A = src_layer_attention(i);
62 U = (1 - A) * U;
63 }
64 dst_layer(i, k) = U * src_iter(i, k) + (1 - U) * gates(i, GRU_O, k);
65 }
66}
67
68void lbr_gru_fwd_postgemm(const prb_t &prb, float *gates_,
69 const float *src_iter_, const float *bias_,
70 const float *src_layer_attention_, float *dst_layer_,
71 float *cell_scratchpad_) {
72 if (prb.skip_nonlinear)
73 lbr_gru_fwd_postgemm_template(
74 [](float scale, float a) { return scale * a; },
75 [](float scale, float a) { return scale * a; }, prb, gates_,
76 src_iter_, bias_, src_layer_attention_, dst_layer_,
77 cell_scratchpad_);
78 else
79 lbr_gru_fwd_postgemm_template(
80 [](float scale, float a) { return logistic(a); },
81 [](float scale, float a) { return tanhf(a); }, prb, gates_,
82 src_iter_, bias_, src_layer_attention_, dst_layer_,
83 cell_scratchpad_);
84}
85
86void lbr_gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_,
87 const float *weights_layer_, const float *weights_iter_,
88 const float *bias_, const float *src_layer_,
89 const float *src_layer_attention_, const float *src_iter_,
90 float *cell_scratchpad_) {
91 gemm("C", "N", "N", prb.mb, prb.n_gates() * prb.dhc, prb.slc, 1.0,
92 src_layer_, prb.wc, weights_layer_, prb.n_gates() * prb.dhc, 0.0,
93 gates_, prb.n_gates() * prb.dhc);
94
95 gemm("C", "N", "N", prb.mb, prb.n_gates() * prb.dhc, prb.sic, 1.0,
96 src_iter_, prb.wc, weights_iter_, prb.n_gates() * prb.dhc, 0.0,
97 cell_scratchpad_, prb.n_gates() * prb.dhc);
98
99 lbr_gru_fwd_postgemm(prb, gates_, src_iter_, bias_, src_layer_attention_,
100 dst_layer_, cell_scratchpad_);
101}
102
103void lbr_gru_bwd_pregemm(const prb_t &prb, const float *src_iter_,
104 const float *src_layer_attention_, const float *diff_dst_layer_,
105 const float *diff_dst_iter_, const float *gates_, const float *Wh_b_,
106 float *diff_src_iter_, float *diff_src_layer_attention_,
107 float *b_gates_, float *b_gates_r_) {
108 AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
109 AOC<const float> src_layer_attention(src_layer_attention_, prb.mb);
110 AOC<const float> diff_dst_layer(diff_dst_layer_, prb.mb, prb.wc);
111 AOC<const float> diff_dst_iter(diff_dst_iter_, prb.mb, prb.wc);
112 AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
113 AOC<const float> Wh_b(Wh_b_, prb.mb, prb.dhc);
114
115 AOC<float> diff_src_iter(diff_src_iter_, prb.mb, prb.wc);
116 AOC<float> diff_src_layer_attention(diff_src_layer_attention_, prb.mb);
117 AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);
118 AOC<float> b_gates_r(b_gates_r_, prb.mb, prb.n_gates(), prb.dhc);
119
120 // do = (1 - u) * dh; do^ = one_m_square(o) * do;
121 // du = (h - o) * dh; du^ = x_m_square(u) * du;
122 // dr = (Wh + b) * do^; dr^ = x_m_square(r) * dr;
123 for (int64_t ib = 0; ib < prb.mb; ib++) {
124 if (prb.alg == LBR_AUGRU) diff_src_layer_attention(ib) = 0.0f;
125 for (int64_t ih = 0; ih < prb.dhc; ih++) {
126 float h = src_iter(ib, ih);
127 float dh = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih);
128 float u = gates(ib, GRU_U, ih);
129 float r = gates(ib, GRU_R, ih);
130 float o = gates(ib, GRU_O, ih);
131 float du = (h - o) * dh * x_m_square(u);
132 float dO = (1.0f - u) * dh;
133 if (prb.alg == LBR_AUGRU) {
134 diff_src_layer_attention(ib) -= du * u;
135 du *= 1 - src_layer_attention(ib);
136 }
137
138 b_gates(ib, GRU_U, ih) = du;
139 b_gates(ib, GRU_O, ih) = one_m_square(o) * dO;
140
141 float dr = Wh_b(ib, ih) * b_gates(ib, GRU_O, ih);
142 b_gates(ib, GRU_R, ih) = x_m_square(r) * dr;
143
144 b_gates_r(ib, GRU_U, ih) = b_gates(ib, GRU_U, ih);
145 b_gates_r(ib, GRU_R, ih) = b_gates(ib, GRU_R, ih);
146 b_gates_r(ib, GRU_O, ih) = b_gates(ib, GRU_O, ih) * r;
147 diff_src_iter(ib, ih) = dh * u;
148 }
149 }
150}
151
152void lbr_gru_bwd(const prb_t &prb, float *diff_src_layer_,
153 float *diff_src_layer_attention_, float *diff_src_iter_,
154 float *diff_weights_layer_, float *diff_weights_iter_,
155 float *diff_bias_, float *b_gates_, const float *src_layer_,
156 const float *src_layer_attention_, const float *src_iter_,
157 const float *weights_layer_, const float *weights_iter_,
158 const float *bias_, const float *gates_, const float *diff_dst_layer_,
159 const float *diff_dst_iter_, float *cell_scratchpad_) {
160 AOC<const float> weights_iter(
161 weights_iter_, prb.sic, prb.n_gates(), prb.dhc);
162 AOC<const float> bias(bias_, prb.n_gates() + 1, prb.dhc);
163
164 float *Wh_b_ = cell_scratchpad_;
165 float *b_gates_r_ = cell_scratchpad_ + prb.dhc * prb.mb;
166 AOC<float> Wh_b(Wh_b_, prb.mb, prb.dhc);
167 AOC<float> b_gates_r(b_gates_r_, prb.mb, prb.n_gates(), prb.dhc);
168
169 // TODO: save this this GEMM + bias in the fwd pass
170 for (int64_t ib = 0; ib < prb.mb; ib++)
171 for (int64_t ih = 0; ih < prb.dhc; ih++)
172 Wh_b(ib, ih) = bias(LBR_GRU_U_PRIME, ih);
173
174 gemm("C", "N", "N", prb.mb, prb.dhc, prb.sic, 1.0, src_iter_, prb.wc,
175 &weights_iter(0, GRU_O, 0), prb.n_gates() * prb.dhc, 1.0, Wh_b_,
176 prb.dhc);
177
178 lbr_gru_bwd_pregemm(prb, src_iter_, src_layer_attention_, diff_dst_layer_,
179 diff_dst_iter_, gates_, Wh_b_, diff_src_iter_,
180 diff_src_layer_attention_, b_gates_, b_gates_r_);
181
182 gemm("C", "T", "N", prb.sic, prb.n_gates() * prb.dhc, prb.mb, 1.0,
183 src_iter_, prb.wc, b_gates_r_, prb.n_gates() * prb.dhc, 1.0,
184 diff_weights_iter_, prb.n_gates() * prb.dhc);
185 gemm("C", "T", "N", prb.slc, prb.n_gates() * prb.dhc, prb.mb, 1.0,
186 src_layer_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
187 diff_weights_layer_, prb.n_gates() * prb.dhc);
188
189 gemm("C", "N", "T", prb.mb, prb.slc, prb.n_gates() * prb.dhc, 1.0, b_gates_,
190 prb.n_gates() * prb.dhc, weights_layer_, prb.n_gates() * prb.dhc,
191 0.0, diff_src_layer_, prb.wc);
192 gemm("C", "N", "T", prb.mb, prb.sic, prb.n_gates() * prb.dhc, 1.0,
193 b_gates_r_, prb.n_gates() * prb.dhc, weights_iter_,
194 prb.n_gates() * prb.dhc, 1.0, diff_src_iter_, prb.wc);
195
196 gates_reduction(prb, b_gates_, diff_bias_);
197 for (int64_t i = 0; i < prb.mb; i++)
198 for (int64_t k = 0; k < prb.dhc; k++)
199 diff_bias_[LBR_GRU_U_PRIME * prb.dhc + k] += b_gates_r(i, GRU_O, k);
200}
201
202} // namespace rnn
203