1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stdlib.h> |
18 | |
19 | #include "utils/parallel.hpp" |
20 | |
21 | #include "rnn/rnn.hpp" |
22 | #include "rnn/rnn_aux.hpp" |
23 | |
24 | #include "rnn/cells.hpp" |
25 | |
26 | namespace rnn { |
27 | template <typename T1, typename T2> |
28 | void lbr_gru_fwd_postgemm_template(T1 func1, T2 func2, const prb_t &prb, |
29 | float *gates_, const float *src_iter_, const float *bias_, |
30 | const float *src_layer_attention_, float *dst_layer_, |
31 | float *cell_scratchpad_) { |
32 | AOC<const float> src_iter(src_iter_, prb.mb, prb.wc); |
33 | AOC<const float> bias(bias_, prb.n_gates() + 1, prb.dhc); |
34 | AOC<const float> src_layer_attention(src_layer_attention_, prb.mb); |
35 | AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc); |
36 | AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc); |
37 | AOC<float> cell_scratchpad( |
38 | cell_scratchpad_, prb.mb, prb.n_gates(), prb.dhc); |
39 | |
40 | for (int64_t i = 0; i < prb.mb; i++) |
41 | for (int64_t j = 0; j < prb.n_gates() - 1; j++) |
42 | for (int64_t k = 0; k < prb.dhc; k++) { |
43 | gates(i, j, k) = func1(prb.linear_scales[j], |
44 | gates(i, j, k) + cell_scratchpad(i, j, k) + bias(j, k)); |
45 | } |
46 | |
47 | for (int64_t i = 0; i < prb.mb; i++) |
48 | for (int64_t k = 0; k < prb.dhc; k++) { |
49 | gates(i, GRU_O, k) = func2(prb.linear_scales[GRU_O], |
50 | gates(i, GRU_O, k) |
51 | + gates(i, GRU_R, k) |
52 | * (cell_scratchpad(i, GRU_O, k) |
53 | + bias(LBR_GRU_U_PRIME, k)) |
54 | + bias(GRU_O, k)); |
55 | } |
56 | |
57 | for (int64_t i = 0; i < prb.mb; i++) |
58 | for (int64_t k = 0; k < prb.dhc; k++) { |
59 | double U = gates(i, GRU_U, k); |
60 | if (prb.alg == LBR_AUGRU) { |
61 | const double A = src_layer_attention(i); |
62 | U = (1 - A) * U; |
63 | } |
64 | dst_layer(i, k) = U * src_iter(i, k) + (1 - U) * gates(i, GRU_O, k); |
65 | } |
66 | } |
67 | |
68 | void lbr_gru_fwd_postgemm(const prb_t &prb, float *gates_, |
69 | const float *src_iter_, const float *bias_, |
70 | const float *src_layer_attention_, float *dst_layer_, |
71 | float *cell_scratchpad_) { |
72 | if (prb.skip_nonlinear) |
73 | lbr_gru_fwd_postgemm_template( |
74 | [](float scale, float a) { return scale * a; }, |
75 | [](float scale, float a) { return scale * a; }, prb, gates_, |
76 | src_iter_, bias_, src_layer_attention_, dst_layer_, |
77 | cell_scratchpad_); |
78 | else |
79 | lbr_gru_fwd_postgemm_template( |
80 | [](float scale, float a) { return logistic(a); }, |
81 | [](float scale, float a) { return tanhf(a); }, prb, gates_, |
82 | src_iter_, bias_, src_layer_attention_, dst_layer_, |
83 | cell_scratchpad_); |
84 | } |
85 | |
86 | void lbr_gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_, |
87 | const float *weights_layer_, const float *weights_iter_, |
88 | const float *bias_, const float *src_layer_, |
89 | const float *src_layer_attention_, const float *src_iter_, |
90 | float *cell_scratchpad_) { |
91 | gemm("C" , "N" , "N" , prb.mb, prb.n_gates() * prb.dhc, prb.slc, 1.0, |
92 | src_layer_, prb.wc, weights_layer_, prb.n_gates() * prb.dhc, 0.0, |
93 | gates_, prb.n_gates() * prb.dhc); |
94 | |
95 | gemm("C" , "N" , "N" , prb.mb, prb.n_gates() * prb.dhc, prb.sic, 1.0, |
96 | src_iter_, prb.wc, weights_iter_, prb.n_gates() * prb.dhc, 0.0, |
97 | cell_scratchpad_, prb.n_gates() * prb.dhc); |
98 | |
99 | lbr_gru_fwd_postgemm(prb, gates_, src_iter_, bias_, src_layer_attention_, |
100 | dst_layer_, cell_scratchpad_); |
101 | } |
102 | |
// Elementwise (pre-GEMM) part of the LBR-GRU backward pass.
//
// Inputs: forward gate activations (gates_), the previous iteration state
// (src_iter_), the precomputed Wh*h + b term for the candidate gate (Wh_b_),
// and the incoming diffs from the layer above and the next iteration.
// Outputs: per-gate diffs b_gates_ (consumed by the layer-weights GEMM),
// the recurrent variant b_gates_r_ (candidate diff pre-multiplied by the
// reset gate, consumed by the iter-weights GEMM), diff_src_iter_, and —
// for LBR_AUGRU only — diff_src_layer_attention_.
void lbr_gru_bwd_pregemm(const prb_t &prb, const float *src_iter_,
        const float *src_layer_attention_, const float *diff_dst_layer_,
        const float *diff_dst_iter_, const float *gates_, const float *Wh_b_,
        float *diff_src_iter_, float *diff_src_layer_attention_,
        float *b_gates_, float *b_gates_r_) {
    AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
    AOC<const float> src_layer_attention(src_layer_attention_, prb.mb);
    AOC<const float> diff_dst_layer(diff_dst_layer_, prb.mb, prb.wc);
    AOC<const float> diff_dst_iter(diff_dst_iter_, prb.mb, prb.wc);
    AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
    AOC<const float> Wh_b(Wh_b_, prb.mb, prb.dhc);

    AOC<float> diff_src_iter(diff_src_iter_, prb.mb, prb.wc);
    AOC<float> diff_src_layer_attention(diff_src_layer_attention_, prb.mb);
    AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);
    AOC<float> b_gates_r(b_gates_r_, prb.mb, prb.n_gates(), prb.dhc);

    // do = (1 - u) * dh; do^ = one_m_square(o) * do;
    // du = (h - o) * dh; du^ = x_m_square(u) * du;
    // dr = (Wh + b) * do^; dr^ = x_m_square(r) * dr;
    for (int64_t ib = 0; ib < prb.mb; ib++) {
        // Attention diff is accumulated over dhc below; reset it first.
        if (prb.alg == LBR_AUGRU) diff_src_layer_attention(ib) = 0.0f;
        for (int64_t ih = 0; ih < prb.dhc; ih++) {
            float h = src_iter(ib, ih);
            // Total incoming gradient for this state element.
            float dh = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih);
            float u = gates(ib, GRU_U, ih);
            float r = gates(ib, GRU_R, ih);
            float o = gates(ib, GRU_O, ih);
            // Update-gate diff through the logistic: u * (1 - u) * du.
            float du = (h - o) * dh * x_m_square(u);
            float dO = (1.0f - u) * dh;
            if (prb.alg == LBR_AUGRU) {
                // Fwd used u' = (1 - a) * u, so da gets -du * u and the
                // remaining update-gate diff is scaled by (1 - a).
                diff_src_layer_attention(ib) -= du * u;
                du *= 1 - src_layer_attention(ib);
            }

            b_gates(ib, GRU_U, ih) = du;
            // Candidate diff through tanh: (1 - o^2) * do.
            b_gates(ib, GRU_O, ih) = one_m_square(o) * dO;

            // Reset-gate diff: in LBR the reset gate multiplies (Wh*h + b).
            float dr = Wh_b(ib, ih) * b_gates(ib, GRU_O, ih);
            b_gates(ib, GRU_R, ih) = x_m_square(r) * dr;

            // Recurrent copy: identical for U/R, but the candidate diff is
            // pre-scaled by r since r gates the iter path in the fwd pass.
            b_gates_r(ib, GRU_U, ih) = b_gates(ib, GRU_U, ih);
            b_gates_r(ib, GRU_R, ih) = b_gates(ib, GRU_R, ih);
            b_gates_r(ib, GRU_O, ih) = b_gates(ib, GRU_O, ih) * r;
            // Direct path of h into the blend: dst = u*h + (1-u)*o.
            diff_src_iter(ib, ih) = dh * u;
        }
    }
}
151 | |
// Backward LBR-GRU cell: recomputes the Wh*h + b term needed by the
// reset-gate diff, runs the elementwise pre-GEMM pass, then accumulates
// weight, source and bias diffs via GEMMs and a reduction.
void lbr_gru_bwd(const prb_t &prb, float *diff_src_layer_,
        float *diff_src_layer_attention_, float *diff_src_iter_,
        float *diff_weights_layer_, float *diff_weights_iter_,
        float *diff_bias_, float *b_gates_, const float *src_layer_,
        const float *src_layer_attention_, const float *src_iter_,
        const float *weights_layer_, const float *weights_iter_,
        const float *bias_, const float *gates_, const float *diff_dst_layer_,
        const float *diff_dst_iter_, float *cell_scratchpad_) {
    AOC<const float> weights_iter(
            weights_iter_, prb.sic, prb.n_gates(), prb.dhc);
    AOC<const float> bias(bias_, prb.n_gates() + 1, prb.dhc);

    // Carve two temporaries out of the caller-provided scratchpad:
    // Wh_b (mb x dhc), then b_gates_r (mb x n_gates x dhc) right after it.
    float *Wh_b_ = cell_scratchpad_;
    float *b_gates_r_ = cell_scratchpad_ + prb.dhc * prb.mb;
    AOC<float> Wh_b(Wh_b_, prb.mb, prb.dhc);
    AOC<float> b_gates_r(b_gates_r_, prb.mb, prb.n_gates(), prb.dhc);

    // Wh_b = weights_iter[:, O] * src_iter + bias[U'].
    // TODO: save this GEMM + bias from the fwd pass instead of recomputing.
    for (int64_t ib = 0; ib < prb.mb; ib++)
        for (int64_t ih = 0; ih < prb.dhc; ih++)
            Wh_b(ib, ih) = bias(LBR_GRU_U_PRIME, ih);

    // beta = 1.0 accumulates on top of the bias initialization above.
    gemm("C", "N", "N", prb.mb, prb.dhc, prb.sic, 1.0, src_iter_, prb.wc,
            &weights_iter(0, GRU_O, 0), prb.n_gates() * prb.dhc, 1.0, Wh_b_,
            prb.dhc);

    lbr_gru_bwd_pregemm(prb, src_iter_, src_layer_attention_, diff_dst_layer_,
            diff_dst_iter_, gates_, Wh_b_, diff_src_iter_,
            diff_src_layer_attention_, b_gates_, b_gates_r_);

    // Weight diffs (accumulated, beta = 1.0): iter weights use the
    // recurrent gate diffs, layer weights the plain ones.
    gemm("C", "T", "N", prb.sic, prb.n_gates() * prb.dhc, prb.mb, 1.0,
            src_iter_, prb.wc, b_gates_r_, prb.n_gates() * prb.dhc, 1.0,
            diff_weights_iter_, prb.n_gates() * prb.dhc);
    gemm("C", "T", "N", prb.slc, prb.n_gates() * prb.dhc, prb.mb, 1.0,
            src_layer_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
            diff_weights_layer_, prb.n_gates() * prb.dhc);

    // Source diffs: diff_src_layer is overwritten (beta = 0.0) while
    // diff_src_iter accumulates (beta = 1.0) onto the dh*u term computed
    // in the pre-GEMM pass.
    gemm("C", "N", "T", prb.mb, prb.slc, prb.n_gates() * prb.dhc, 1.0, b_gates_,
            prb.n_gates() * prb.dhc, weights_layer_, prb.n_gates() * prb.dhc,
            0.0, diff_src_layer_, prb.wc);
    gemm("C", "N", "T", prb.mb, prb.sic, prb.n_gates() * prb.dhc, 1.0,
            b_gates_r_, prb.n_gates() * prb.dhc, weights_iter_,
            prb.n_gates() * prb.dhc, 1.0, diff_src_iter_, prb.wc);

    // Regular gate biases reduce over mb from b_gates; the extra U' bias
    // row gets the r-scaled candidate diffs.
    gates_reduction(prb, b_gates_, diff_bias_);
    for (int64_t i = 0; i < prb.mb; i++)
        for (int64_t k = 0; k < prb.dhc; k++)
            diff_bias_[LBR_GRU_U_PRIME * prb.dhc + k] += b_gates_r(i, GRU_O, k);
}
201 | |
202 | } // namespace rnn |
203 | |