1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stdlib.h> |
18 | |
19 | #include "utils/parallel.hpp" |
20 | |
21 | #include "rnn/rnn.hpp" |
22 | #include "rnn/rnn_aux.hpp" |
23 | |
24 | #include "rnn/cells.hpp" |
25 | |
26 | namespace rnn { |
27 | |
28 | template <typename T> |
29 | void gru_fwd_postgemm_part1_template(T func1, const prb_t &prb, float *gates_, |
30 | const float *src_iter_, const float *bias_, float *dst_layer_) { |
31 | AOC<const float> bias(bias_, prb.n_gates(), prb.dhc); |
32 | AOC<const float> src_iter(src_iter_, prb.mb, prb.wc); |
33 | AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc); |
34 | AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc); |
35 | |
36 | for (int64_t i = 0; i < prb.mb; i++) |
37 | for (int64_t k = 0; k < prb.dhc; k++) { |
38 | gates(i, GRU_U, k) = func1(prb.linear_scales[GRU_U], |
39 | maybe_deq(prb, gates(i, GRU_U, k), GRU_U * prb.dhc + k) |
40 | + bias(GRU_U, k)); |
41 | gates(i, GRU_R, k) = func1(prb.linear_scales[GRU_R], |
42 | maybe_deq(prb, gates(i, GRU_R, k), GRU_R * prb.dhc + k) |
43 | + bias(GRU_R, k)); |
44 | dst_layer(i, k) = maybe_q( |
45 | prb, (maybe_deq(prb, src_iter(i, k)) * gates(i, GRU_R, k))); |
46 | } |
47 | } |
48 | |
49 | void gru_fwd_postgemm_part1(const prb_t &prb, float *gates_, |
50 | const float *src_iter_, const float *bias_, float *dst_layer_) { |
51 | if (prb.skip_nonlinear) |
52 | gru_fwd_postgemm_part1_template( |
53 | [](float scale, float a) { return scale * a; }, prb, gates_, |
54 | src_iter_, bias_, dst_layer_); |
55 | else |
56 | gru_fwd_postgemm_part1_template( |
57 | [](float scale, float a) { return logistic(a); }, prb, gates_, |
58 | src_iter_, bias_, dst_layer_); |
59 | } |
60 | |
61 | template <typename T> |
62 | void gru_fwd_postgemm_part2_template(T func1, const prb_t &prb, float *gates_, |
63 | const float *src_iter_, const float *bias_, |
64 | const float *src_layer_attention_, float *dst_layer_) { |
65 | AOC<const float> bias(bias_, prb.n_gates(), prb.dhc); |
66 | AOC<const float> src_iter(src_iter_, prb.mb, prb.wc); |
67 | AOC<const float> src_layer_attention(src_layer_attention_, prb.mb); |
68 | AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc); |
69 | AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc); |
70 | for (int64_t i = 0; i < prb.mb; i++) |
71 | for (int64_t k = 0; k < prb.dhc; k++) { |
72 | double U = gates(i, GRU_U, k); |
73 | double O = func1(prb.linear_scales[GRU_O], |
74 | maybe_deq(prb, gates(i, GRU_O, k), GRU_O * prb.dhc + k) |
75 | + bias(GRU_O, k)); |
76 | if (prb.alg == VANILLA_AUGRU) { |
77 | const double A = src_layer_attention(i); |
78 | U = (1 - A) * U; |
79 | } |
80 | dst_layer(i, k) = maybe_q(prb, |
81 | (float)(U * maybe_deq(prb, src_iter(i, k)) |
82 | + (1.0 - U) * O)); |
83 | |
84 | gates(i, GRU_O, k) = O; |
85 | } |
86 | } |
87 | |
88 | void gru_fwd_postgemm_part2(const prb_t &prb, float *gates_, |
89 | const float *src_iter_, const float *bias_, |
90 | const float *src_layer_attention_, float *dst_layer_) { |
91 | if (prb.skip_nonlinear) |
92 | gru_fwd_postgemm_part2_template( |
93 | [](float scale, float a) { return scale * a; }, prb, gates_, |
94 | src_iter_, bias_, src_layer_attention_, dst_layer_); |
95 | else |
96 | gru_fwd_postgemm_part2_template( |
97 | [](float scale, float a) { return tanhf(a); }, prb, gates_, |
98 | src_iter_, bias_, src_layer_attention_, dst_layer_); |
99 | } |
100 | |
101 | void gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_, |
102 | const float *weights_layer_, const float *weights_iter_, |
103 | const float *bias_, const float *src_layer_, |
104 | const float *src_layer_attention_, const float *src_iter_) { |
105 | AOC<const float> weights_iter( |
106 | weights_iter_, prb.sic, prb.n_gates(), prb.dhc); |
107 | AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc); |
108 | |
109 | gemm("C" , "N" , "N" , prb.mb, prb.n_gates() * prb.dhc, prb.slc, 1.0, |
110 | src_layer_, prb.wc, weights_layer_, prb.n_gates() * prb.dhc, 0.0, |
111 | gates_, prb.n_gates() * prb.dhc); |
112 | gemm("C" , "N" , "N" , prb.mb, (prb.n_gates() - 1) * prb.dhc, prb.sic, 1.0, |
113 | src_iter_, prb.wc, weights_iter_, prb.n_gates() * prb.dhc, 1.0, |
114 | gates_, prb.n_gates() * prb.dhc); |
115 | |
116 | gru_fwd_postgemm_part1(prb, gates_, src_iter_, bias_, dst_layer_); |
117 | |
118 | gemm("C" , "N" , "N" , prb.mb, prb.dhc, prb.sic, 1.0, dst_layer_, prb.wc, |
119 | &(weights_iter(0, GRU_O, 0)), prb.n_gates() * prb.dhc, 1.0, |
120 | &(gates(0, GRU_O, 0)), prb.n_gates() * prb.dhc); |
121 | |
122 | gru_fwd_postgemm_part2( |
123 | prb, gates_, src_iter_, bias_, src_layer_attention_, dst_layer_); |
124 | } |
125 | |
126 | void gru_bwd_pregemm_part1(const prb_t &prb, const float *src_iter_, |
127 | const float *src_layer_attention_, const float *diff_dst_layer_, |
128 | const float *diff_dst_iter_, const float *gates_, float *diff_src_iter_, |
129 | float *diff_src_layer_attention_, float *b_gates_) { |
130 | AOC<const float> src_iter(src_iter_, prb.mb, prb.wc); |
131 | AOC<const float> src_layer_attention( |
132 | src_layer_attention_, prb.n_iter, prb.mb); |
133 | AOC<const float> diff_dst_layer(diff_dst_layer_, prb.mb, prb.wc); |
134 | AOC<const float> diff_dst_iter(diff_dst_iter_, prb.mb, prb.wc); |
135 | AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc); |
136 | |
137 | AOC<float> diff_src_iter(diff_src_iter_, prb.mb, prb.wc); |
138 | AOC<float> diff_src_layer_attention(diff_src_layer_attention_, prb.mb); |
139 | AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc); |
140 | |
141 | // do = (1 - u) * dh; do^ = one_m_square(o) * do; |
142 | // du = (h - u) * dh; du^ = x_m_square(u) * du; |
143 | for (int64_t ib = 0; ib < prb.mb; ib++) { |
144 | if (prb.alg == VANILLA_AUGRU) diff_src_layer_attention(ib) = 0.0f; |
145 | for (int64_t ih = 0; ih < prb.dhc; ih++) { |
146 | float h = src_iter(ib, ih); |
147 | float o = gates(ib, GRU_O, ih); |
148 | float u = gates(ib, GRU_U, ih); |
149 | float dh = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih); |
150 | float du = (h - o) * dh * x_m_square(u); |
151 | float dO = (1.0f - u) * dh * one_m_square(o); |
152 | if (prb.alg == VANILLA_AUGRU) { |
153 | diff_src_layer_attention(ib) -= du * u; |
154 | du *= 1 - src_layer_attention(ib); |
155 | } |
156 | b_gates(ib, GRU_U, ih) = du; |
157 | b_gates(ib, GRU_O, ih) = dO; |
158 | diff_src_iter(ib, ih) = dh * u; |
159 | } |
160 | } |
161 | } |
162 | |
163 | void gru_bwd_pregemm_part2(const prb_t &prb, const float *src_iter_, |
164 | const float *gates_, const float *dhr_, float *diff_src_iter_, |
165 | float *b_gates_, float *hr_) { |
166 | AOC<const float> src_iter(src_iter_, prb.mb, prb.wc); |
167 | AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc); |
168 | AOC<const float> dhr(dhr_, prb.mb, prb.dhc); |
169 | AOC<float> diff_src_iter(diff_src_iter_, prb.mb, prb.wc); |
170 | AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc); |
171 | AOC<float> hr(hr_, prb.mb, prb.dhc); |
172 | |
173 | // dhr = Wo do^; |
174 | // dr = h * dhr; dr^ = x_m_square(r) * dr; |
175 | for (int64_t ib = 0; ib < prb.mb; ib++) |
176 | for (int64_t ih = 0; ih < prb.dhc; ih++) { |
177 | float h = src_iter(ib, ih); |
178 | float r = gates(ib, GRU_R, ih); |
179 | float dr = h * dhr(ib, ih); |
180 | hr(ib, ih) = h * r; |
181 | diff_src_iter(ib, ih) += dhr(ib, ih) * r; |
182 | b_gates(ib, GRU_R, ih) = x_m_square(r) * dr; |
183 | } |
184 | } |
185 | |
186 | void gru_bwd(const prb_t &prb, float *diff_src_layer_, |
187 | float *diff_src_layer_attention_, float *diff_src_iter_, |
188 | float *diff_weights_layer_, float *diff_weights_iter_, |
189 | float *diff_bias_, float *b_gates_, const float *src_layer_, |
190 | const float *src_layer_attention_, const float *src_iter_, |
191 | const float *weights_layer_, const float *weights_iter_, |
192 | const float *bias_, const float *gates_, const float *diff_dst_layer_, |
193 | const float *diff_dst_iter_, float *cell_scratchpad_) { |
194 | AOC<const float> weights_iter( |
195 | weights_iter_, prb.sic, prb.n_gates(), prb.dhc); |
196 | |
197 | AOC<float> diff_weights_iter( |
198 | diff_weights_iter_, prb.sic, prb.n_gates(), prb.dhc); |
199 | AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc); |
200 | |
201 | assert(prb.dhc == prb.sic); |
202 | float *dhr_ = cell_scratchpad_; |
203 | float *hr_ = cell_scratchpad_ + prb.mb * prb.dhc; |
204 | |
205 | gru_bwd_pregemm_part1(prb, src_iter_, src_layer_attention_, diff_dst_layer_, |
206 | diff_dst_iter_, gates_, diff_src_iter_, diff_src_layer_attention_, |
207 | b_gates_); |
208 | |
209 | gemm("C" , "N" , "T" , prb.mb, prb.sic, prb.dhc, 1.0, &(b_gates(0, GRU_O, 0)), |
210 | prb.n_gates() * prb.dhc, &(weights_iter(0, GRU_O, 0)), |
211 | prb.n_gates() * prb.dhc, 0.0, dhr_, prb.dhc); |
212 | |
213 | gru_bwd_pregemm_part2( |
214 | prb, src_iter_, gates_, dhr_, diff_src_iter_, b_gates_, hr_); |
215 | |
216 | // dWx += xdu^ | xdr^ | xdo^ |
217 | // dWh += hdu^ | ddr^ | (h * r)do^ |
218 | gemm("C" , "T" , "N" , prb.sic, (prb.n_gates() - 1) * prb.dhc, prb.mb, 1.0, |
219 | src_iter_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0, |
220 | diff_weights_iter_, prb.n_gates() * prb.dhc); |
221 | gemm("C" , "T" , "N" , prb.sic, prb.dhc, prb.mb, 1.0, hr_, prb.dhc, |
222 | &(b_gates(0, GRU_O, 0)), prb.n_gates() * prb.dhc, 1.0, |
223 | &(diff_weights_iter(0, GRU_O, 0)), prb.n_gates() * prb.dhc); |
224 | gemm("C" , "T" , "N" , prb.slc, prb.n_gates() * prb.dhc, prb.mb, 1.0, |
225 | src_layer_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0, |
226 | diff_weights_layer_, prb.n_gates() * prb.dhc); |
227 | |
228 | // dx_next = Wxudu^ + Wxrdr^ + Wxodo^ |
229 | // dh_next = dh * u + Whudu^ + Whzdz^ + r * Whodo^ |
230 | gemm("C" , "N" , "T" , prb.mb, prb.sic, (prb.n_gates() - 1) * prb.dhc, 1.0, |
231 | b_gates_, prb.n_gates() * prb.dhc, weights_iter_, |
232 | prb.n_gates() * prb.dhc, 1.0, diff_src_iter_, prb.wc); |
233 | gemm("C" , "N" , "T" , prb.mb, prb.slc, prb.n_gates() * prb.dhc, 1.0, b_gates_, |
234 | prb.n_gates() * prb.dhc, weights_layer_, prb.n_gates() * prb.dhc, |
235 | 0.0, diff_src_layer_, prb.wc); |
236 | |
237 | gates_reduction(prb, b_gates_, diff_bias_); |
238 | } |
239 | |
240 | } // namespace rnn |
241 | |