1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /* |
18 | * Cell execution LSTM |
19 | */ |
20 | |
21 | #include "common/bit_cast.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/math_utils.hpp" |
24 | |
25 | #include "cpu/simple_q10n.hpp" |
26 | |
27 | #include "cpu/rnn/postgemm_dispatcher.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | |
33 | using namespace dnnl::impl::utils; |
34 | using namespace dnnl::impl::math; |
35 | using namespace rnn_utils; |
36 | #define AOC array_offset_calculator |
37 | |
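// Part 1: computes the update gate G0 and reset gate G1 from the first
// GEMM output, stashes G0 in scratch_gates for part 2, and writes
// G1 * h_prev to dst_layer/dst_iter so it can feed the second GEMM.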
38 | template <typename T1, typename T2, typename T3, typename T4, typename T5, |
39 | typename src_data_t, typename scratch_data_t> |
40 | void gru_fwd_part1_postgemm_template(T1 func1, T2 to_src, T3 acc_to_float, |
41 | T4 src_to_float, T5 reinterpret_as_acc, const float *scales, |
42 | const rnn_utils::rnn_conf_t &rnn, |
43 | rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_, |
44 | scratch_data_t *scratch_gates_, const src_data_t *augru_attention_, |
45 | src_data_t *dst_layer_, src_data_t *dst_iter_, |
46 | const src_data_t *src_iter_, const void *bias_, int block_step) { |
47 | const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_); |
48 | const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_); |
49 | const auto bias_aoc = rnn_utils::make_raw_aoc( |
50 | bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc); |
51 | const auto bias = [&](int gate_id, int dhc_id) { |
52 | return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt); |
53 | }; |
54 | |
55 | const auto dst_iter_ld = rnn.dst_iter_ld(cell_position); |
56 | const auto dst_layer_ld = rnn.dst_layer_ld(cell_position); |
57 | const auto src_iter_ld = rnn.src_iter_ld(cell_position); |
58 | |
59 | const ws_states_layer_aoc<src_data_t> dst_layer( |
60 | rnn, dst_layer_, dst_layer_ld); |
61 | const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld); |
62 | const ws_states_iter_aoc<const src_data_t> src_iter( |
63 | rnn, src_iter_, src_iter_ld); |
64 | |
65 | const float *scales_G1 = scales ? scales + 1 : nullptr; |
66 | |
67 | const auto postgemm_call = [&](int i) { |
68 | const int n_elem = block_step; |
69 | PRAGMA_OMP_SIMD() |
70 | for (int j = 0; j < n_elem; j++) { |
71 | const auto G0 // default func1 is sigmoid |
72 | = func1(scales, |
73 | acc_to_float(scratch_gates(i, 0, j), 0, j) |
74 | + bias(0, j)); |
75 | const auto G1 // default func1 is sigmoid |
76 | = func1(scales_G1, |
77 | acc_to_float(scratch_gates(i, 1, j), 1, j) |
78 | + bias(1, j)); |
            /* TODO: Can be optimized for fwd_training by using ws_gates
             * instead of scratch_gates in part 2 */
80 | scratch_gates(i, 0, j) = reinterpret_as_acc(G0); |
81 | const auto t = to_src(src_to_float(src_iter(i, j)) * G1); |
            if (dst_layer_ != nullptr) dst_layer(i, j) = t;
            if (dst_iter_ != nullptr) dst_iter(i, j) = t;
84 | |
85 | if (rnn.is_training) { |
86 | ws_gates(i, 0, j) = to_src(G0); |
87 | ws_gates(i, 1, j) = to_src(G1); |
88 | } |
89 | } |
90 | }; |
91 | |
92 | if (rnn.is_brgemm && !rnn.unfused_post_gemm) { |
93 | for (int i = 0; i < rnn.m_block; i++) |
94 | postgemm_call(i); |
95 | } else { |
96 | parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); }); |
97 | } |
98 | } |
99 | |
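// Part 2: computes the candidate gate G2 from the second GEMM output,
// recovers the G0 stashed by part 1, and forms the cell output
// h = G0 * h_prev + (1 - G0) * G2, with G0 scaled by (1 - attention)
// for AUGRU.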
100 | template <typename T1, typename T2, typename T3, typename T4, typename T5, |
101 | typename src_data_t, typename scratch_data_t> |
102 | void gru_fwd_part2_postgemm_template(T1 func1, T2 to_src, T3 acc_to_float, |
103 | T4 src_to_float, T5 reinterpret_as_float, const float *scales, |
104 | const rnn_utils::rnn_conf_t &rnn, |
105 | rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_, |
106 | scratch_data_t *scratch_gates_, const src_data_t *augru_attention_, |
107 | src_data_t *dst_layer_, src_data_t *dst_iter_, |
108 | const src_data_t *src_iter_, const void *bias_, int block_step) { |
109 | const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_); |
110 | const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_); |
111 | const auto bias_aoc = rnn_utils::make_raw_aoc( |
112 | bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc); |
113 | const auto bias = [&](int gate_id, int dhc_id) { |
114 | return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt); |
115 | }; |
116 | |
117 | const auto dst_layer_ld = rnn.dst_layer_ld(cell_position); |
118 | const auto dst_iter_ld = rnn.dst_iter_ld(cell_position); |
119 | const auto src_iter_ld = rnn.src_iter_ld(cell_position); |
120 | const augru_attention_aoc<const src_data_t> augru_attention( |
121 | rnn, augru_attention_); |
122 | const ws_states_layer_aoc<src_data_t> dst_layer( |
123 | rnn, dst_layer_, dst_layer_ld); |
124 | const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld); |
125 | const ws_states_iter_aoc<const src_data_t> src_iter( |
126 | rnn, src_iter_, src_iter_ld); |
127 | |
128 | const float *scales_G2 = scales ? scales + 2 : nullptr; |
129 | |
130 | const auto postgemm_call = [&](int i) { |
131 | const int n_elem = block_step; |
132 | PRAGMA_OMP_SIMD() |
133 | for (int j = 0; j < n_elem; j++) { |
134 | auto G0 = reinterpret_as_float(scratch_gates(i, 0, j)); |
135 | const auto G2 // default func1 is tanh |
136 | = func1(scales_G2, |
137 | acc_to_float(scratch_gates(i, 2, j), 2, j) |
138 | + bias(2, j)); |
139 | |
140 | if (rnn.is_augru) { |
141 | const auto a = reinterpret_as_float(augru_attention(i)); |
142 | G0 = (1.0f - a) * G0; |
143 | } |
144 | |
145 | const auto tmp = to_src( |
146 | src_to_float(src_iter(i, j)) * G0 + (1.0f - G0) * G2); |
147 | if (dst_layer_ != nullptr) dst_layer(i, j) = tmp; |
148 | if (dst_iter_ != nullptr) dst_iter(i, j) = tmp; |
149 | |
150 | if (rnn.is_training) { ws_gates(i, 2, j) = to_src(G2); } |
151 | } |
152 | }; |
153 | |
154 | if (rnn.is_brgemm && !rnn.unfused_post_gemm) { |
155 | for (int i = 0; i < rnn.m_block; i++) |
156 | postgemm_call(i); |
157 | } else { |
158 | parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); }); |
159 | } |
160 | } |
161 | |
162 | template <> |
163 | rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part1_postgemm) { |
164 | const float *scales = pd_->attr()->rnn_tparams_.scales_; |
165 | const auto linear_f |
166 | = [](const float *scale, float a) { return *scale * a; }; |
167 | const auto logistic_f = [](const float *scale, float a) { |
168 | return logistic_fwd<float>(a); |
169 | }; |
170 | |
171 | const auto deq_id = [](float f, int i, int j) { return f; }; |
172 | const auto id = [](float f) { return f; }; |
173 | |
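    // In test mode the activation is replaced by a scaled linear function
    // (rnn_tparams_) so that results can be validated exactly.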
174 | if (!pd_->attr()->rnn_tparams_.test_mode_) |
175 | gru_fwd_part1_postgemm_template(logistic_f, id, deq_id, id, id, scales, |
176 | rnn, cell_position, ws_gates_, scratch_gates_, augru_attention_, |
177 | dst_layer_, dst_iter_, src_iter_, bias_, block_step); |
178 | else |
179 | gru_fwd_part1_postgemm_template(linear_f, id, deq_id, id, id, scales, |
180 | rnn, cell_position, ws_gates_, scratch_gates_, augru_attention_, |
181 | dst_layer_, dst_iter_, src_iter_, bias_, block_step); |
182 | } |
183 | |
184 | template <> |
185 | rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part2_postgemm) { |
186 | const float *scales = pd_->attr()->rnn_tparams_.scales_; |
187 | const auto linear_f |
188 | = [](const float *scale, float a) { return *scale * a; }; |
189 | const auto tanh_f |
190 | = [](const float *scale, float a) { return tanh_fwd<float>(a); }; |
191 | |
192 | const auto deq_id = [](float f, int i, int j) { return f; }; |
193 | const auto id = [](float f) { return f; }; |
194 | |
195 | if (!pd_->attr()->rnn_tparams_.test_mode_) |
196 | gru_fwd_part2_postgemm_template(tanh_f, id, deq_id, id, id, scales, rnn, |
197 | cell_position, ws_gates_, scratch_gates_, augru_attention_, |
198 | dst_layer_, dst_iter_, src_iter_, bias_, block_step); |
199 | else |
200 | gru_fwd_part2_postgemm_template(linear_f, id, deq_id, id, id, scales, |
201 | rnn, cell_position, ws_gates_, scratch_gates_, augru_attention_, |
202 | dst_layer_, dst_iter_, src_iter_, bias_, block_step); |
203 | } |
204 | |
205 | template <> |
206 | rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part1_postgemm) { |
207 | const float *scales = pd_->attr()->rnn_tparams_.scales_; |
208 | const auto linear_f |
209 | = [](const float *scale, float a) { return *scale * a; }; |
210 | const auto logistic_f = [](const float *scale, float a) { |
211 | return logistic_fwd<float>(a); |
212 | }; |
213 | |
214 | const auto dn_cvt_f32_bf16 = [](float f) { return bfloat16_t(f); }; |
215 | const auto up_cvt_bf16_f32 = [](bfloat16_t b) { return float(b); }; |
216 | const auto deq_id = [](float f, int i, int j) { return f; }; |
217 | const auto id = [](float f) { return f; }; |
218 | |
219 | if (!pd_->attr()->rnn_tparams_.test_mode_) |
220 | gru_fwd_part1_postgemm_template(logistic_f, dn_cvt_f32_bf16, deq_id, |
221 | up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_, |
222 | scratch_gates_, augru_attention_, dst_layer_, dst_iter_, |
223 | src_iter_, bias_, block_step); |
224 | else |
225 | gru_fwd_part1_postgemm_template(linear_f, dn_cvt_f32_bf16, deq_id, |
226 | up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_, |
227 | scratch_gates_, augru_attention_, dst_layer_, dst_iter_, |
228 | src_iter_, bias_, block_step); |
229 | } |
230 | template <> |
231 | rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part2_postgemm) { |
232 | const float *scales = pd_->attr()->rnn_tparams_.scales_; |
233 | const auto linear_f |
234 | = [](const float *scale, float a) { return *scale * a; }; |
235 | const auto tanh_f |
236 | = [](const float *scale, float a) { return tanh_fwd<float>(a); }; |
237 | |
238 | const auto dn_cvt_f32_bf16 = [](float f) { return bfloat16_t(f); }; |
239 | const auto up_cvt_bf16_f32 = [](bfloat16_t b) { return float(b); }; |
240 | const auto deq_id = [](float f, int i, int j) { return f; }; |
241 | const auto id = [](float f) { return f; }; |
242 | |
243 | if (!pd_->attr()->rnn_tparams_.test_mode_) |
244 | gru_fwd_part2_postgemm_template(tanh_f, dn_cvt_f32_bf16, deq_id, |
245 | up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_, |
246 | scratch_gates_, augru_attention_, dst_layer_, dst_iter_, |
247 | src_iter_, bias_, block_step); |
248 | else |
249 | gru_fwd_part2_postgemm_template(linear_f, dn_cvt_f32_bf16, deq_id, |
250 | up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_, |
251 | scratch_gates_, augru_attention_, dst_layer_, dst_iter_, |
252 | src_iter_, bias_, block_step); |
253 | } |
254 | |
255 | template <> |
256 | rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part1_postgemm) { |
257 | const float *scales = pd_->attr()->rnn_tparams_.scales_; |
258 | const auto linear_f |
259 | = [](const float *scale, float a) { return *scale * a; }; |
260 | const auto logistic_f = [](const float *scale, float a) { |
261 | return logistic_fwd<float>(a); |
262 | }; |
263 | |
264 | const float data_shift = pd_->attr()->rnn_data_qparams_.shift_; |
265 | const float data_scale = pd_->attr()->rnn_data_qparams_.scale_; |
266 | |
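    // Data quantization: q = saturate_to_u8(f * data_scale + data_shift).
    // s32 accumulator entries are dequantized with the combined
    // weights/data scale; u8 states with the data scale and shift.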
267 | const auto quantize_f32_u8 = [&](float f) { |
268 | float qf = f * data_scale + data_shift; |
269 | qf = nstl::min(qf, 255.0f); |
270 | qf = nstl::max(qf, 0.0f); |
271 | return (dst_layer_t)mxcsr_cvt(qf); |
272 | }; |
273 | |
274 | const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) { |
275 | const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0 |
276 | ? weights_scales_[0] |
277 | : weights_scales_[gate * rnn.dhc + j]; |
278 | return saturate<float>(s) * (1.f / (wscale * data_scale)); |
279 | }; |
280 | |
281 | const auto dequantize_u8_f32 = [&](src_iter_t s) { |
282 | return (static_cast<float>(s) - data_shift) * (1.f / data_scale); |
283 | }; |
284 | |
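    // G0 is carried over to part 2 by bit-casting its f32 value into the
    // s32 scratch_gates slot (undone by reinterpret_s32_f32 in part 2).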
285 | const auto reinterpret_f32_s32 |
286 | = [](float a) { return bit_cast<gemm_acc_t>(a); }; |
287 | |
288 | if (!pd_->attr()->rnn_tparams_.test_mode_) |
289 | gru_fwd_part1_postgemm_template(logistic_f, quantize_f32_u8, |
290 | dequantize_s32_f32, dequantize_u8_f32, reinterpret_f32_s32, |
291 | scales, rnn, cell_position, ws_gates_, scratch_gates_, |
292 | augru_attention_, dst_layer_, dst_iter_, src_iter_, bias_, |
293 | block_step); |
294 | else |
295 | gru_fwd_part1_postgemm_template(linear_f, quantize_f32_u8, |
296 | dequantize_s32_f32, dequantize_u8_f32, reinterpret_f32_s32, |
297 | scales, rnn, cell_position, ws_gates_, scratch_gates_, |
298 | augru_attention_, dst_layer_, dst_iter_, src_iter_, bias_, |
299 | block_step); |
300 | } |
301 | |
302 | template <> |
303 | rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part2_postgemm) { |
304 | const float *scales = pd_->attr()->rnn_tparams_.scales_; |
305 | const auto linear_f |
306 | = [](const float *scale, float a) { return *scale * a; }; |
307 | const auto tanh_f |
308 | = [](const float *scale, float a) { return tanh_fwd<float>(a); }; |
309 | |
310 | const float data_shift = pd_->attr()->rnn_data_qparams_.shift_; |
311 | const float data_scale = pd_->attr()->rnn_data_qparams_.scale_; |
312 | |
313 | const auto quantize_f32_u8 = [&](float f) { |
314 | float qf = f * data_scale + data_shift; |
315 | qf = nstl::min(qf, 255.0f); |
316 | qf = nstl::max(qf, 0.0f); |
317 | return (dst_layer_t)mxcsr_cvt(qf); |
318 | }; |
319 | |
320 | const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) { |
321 | const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0 |
322 | ? weights_scales_[0] |
323 | : weights_scales_[gate * rnn.dhc + j]; |
324 | return saturate<float>(s) * (1.f / (wscale * data_scale)); |
325 | }; |
326 | |
327 | const auto dequantize_u8_f32 = [&](src_iter_t s) { |
328 | return (static_cast<float>(s) - data_shift) * (1.f / data_scale); |
329 | }; |
330 | |
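    // Recover the f32 value of G0 that part 1 bit-cast into scratch_gates.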
331 | const auto reinterpret_s32_f32 |
332 | = [](gemm_acc_t a) { return bit_cast<float>(a); }; |
333 | |
334 | if (!pd_->attr()->rnn_tparams_.test_mode_) |
335 | gru_fwd_part2_postgemm_template(tanh_f, quantize_f32_u8, |
336 | dequantize_s32_f32, dequantize_u8_f32, reinterpret_s32_f32, |
337 | scales, rnn, cell_position, ws_gates_, scratch_gates_, |
338 | augru_attention_, dst_layer_, dst_iter_, src_iter_, bias_, |
339 | block_step); |
340 | else |
341 | gru_fwd_part2_postgemm_template(linear_f, quantize_f32_u8, |
342 | dequantize_s32_f32, dequantize_u8_f32, reinterpret_s32_f32, |
343 | scales, rnn, cell_position, ws_gates_, scratch_gates_, |
344 | augru_attention_, dst_layer_, dst_iter_, src_iter_, bias_, |
345 | block_step); |
346 | } |
347 | |
348 | template <> |
349 | rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part1_postgemm) { |
350 | assert(!"GRU signed int8 is not supported" ); |
351 | } |
352 | |
353 | template <> |
354 | rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part2_postgemm) { |
355 | assert(!"GRU signed int8 is not supported" ); |
356 | } |
357 | |
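// Backward part 1: given dHt = diff_dst_iter + diff_dst_layer, computes
// the update- and candidate-gate gradients (stored in scratch_gates for
// the following GEMMs) and the part of dh_prev that flows through G0.
// For AUGRU it also accumulates the attention gradient.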
358 | template <typename T, typename src_data_t, typename acc_data_t, |
359 | typename scratch_data_t> |
360 | void gru_bwd_part1_postgemm_template(T to_src, const rnn_utils::rnn_conf_t &rnn, |
361 | cell_position_t cell_position, src_data_t *ws_gates_, |
362 | scratch_data_t *scratch_gates_, const src_data_t *augru_attention_, |
363 | src_data_t *dst_layer_, const src_data_t *src_iter_, |
364 | acc_data_t *diff_src_iter_, acc_data_t *diff_dst_iter_, |
365 | acc_data_t *diff_augru_attention_, acc_data_t *diff_dst_layer_) { |
366 | const auto src_iter_ld = rnn.src_iter_ld(cell_position); |
367 | |
368 | const augru_attention_aoc<const src_data_t> augru_attention( |
369 | rnn, augru_attention_); |
370 | const augru_attention_aoc<acc_data_t> diff_augru_attention( |
371 | rnn, diff_augru_attention_); |
372 | |
373 | const ws_states_iter_aoc<const src_data_t> src_iter( |
374 | rnn, src_iter_, src_iter_ld); |
375 | const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_); |
376 | const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_); |
377 | const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter( |
378 | rnn, diff_src_iter_); |
379 | const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter( |
380 | rnn, diff_dst_iter_); |
381 | const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer( |
382 | rnn, diff_dst_layer_); |
383 | |
    // dG2^ = dHt * (1 - G0) * (1 - G2^2)
    // dG0^ = dHt * (ht-1 - G2) * G0 * (1 - G0)
    // dht-1 (part) = dHt * G0
387 | parallel_nd(rnn.mb, [&](dim_t i) { |
388 | acc_data_t diff_attention = 0.0f; |
389 | PRAGMA_OMP_SIMD(reduction(+ : diff_attention)) |
390 | for (int j = 0; j < rnn.dhc; j++) { |
391 | const float h = src_iter(i, j); |
392 | const float dHt = diff_dst_iter(i, j) + diff_dst_layer(i, j); |
393 | const float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt |
394 | * one_m_square(ws_gates(i, 2, j)); |
395 | float dG0 = (h - ws_gates(i, 2, j)) * dHt |
396 | * x_m_square(ws_gates(i, 0, j)); |
397 | |
398 | if (rnn.is_augru) { |
399 | diff_attention -= dG0 * ws_gates(i, 0, j); |
400 | dG0 *= 1.0f - augru_attention(i); |
401 | } |
402 | |
403 | diff_src_iter(i, j) = dHt * ws_gates(i, 0, j); |
404 | scratch_gates(i, 0, j) = to_src(dG0); |
405 | scratch_gates(i, 2, j) = to_src(dG2); |
406 | } |
407 | if (rnn.is_augru) diff_augru_attention(i) = diff_attention; |
408 | }); |
409 | } |
410 | |
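// Backward part 2: dhG1 aliases diff_src_layer_, which at this point holds
// d(h_prev * G1) produced by the preceding GEMM on the candidate-gate
// gradient. This part computes the reset-gate gradient, accumulates its
// contribution to dh_prev, and stores h_prev * G1 in scratch_cell_ for the
// weight-gradient GEMM.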
411 | template <typename T, typename src_data_t, typename acc_data_t, |
412 | typename scratch_data_t> |
413 | void gru_bwd_part2_postgemm_template(T to_src, const rnn_utils::rnn_conf_t &rnn, |
414 | cell_position_t cell_position, src_data_t *ws_gates_, |
415 | scratch_data_t *scratch_gates_, src_data_t *dst_layer_, |
416 | const src_data_t *src_iter_, acc_data_t *diff_src_layer_, |
417 | acc_data_t *diff_src_iter_, acc_data_t *diff_dst_iter_, |
418 | acc_data_t *diff_dst_layer_, scratch_data_t *scratch_cell_) { |
419 | const auto src_iter_ld = rnn.src_iter_ld(cell_position); |
422 | const ws_states_iter_aoc<const src_data_t> src_iter( |
423 | rnn, src_iter_, src_iter_ld); |
424 | const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_); |
425 | const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_); |
426 | const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer( |
427 | rnn, diff_dst_layer_); |
428 | const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter( |
429 | rnn, diff_dst_iter_); |
430 | |
431 | const ws_diff_states_layer_aoc<acc_data_t> dhG1(rnn, diff_src_layer_); |
432 | const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter( |
433 | rnn, diff_src_iter_); |
434 | const AOC<scratch_data_t, 2> hG1( |
435 | scratch_cell_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld); |
436 | |
437 | // dG1^ = d(hG1) * h * G1 * (1 - G1) |
438 | // dht-1 (part) += d(hG1) * G1 |
439 | // h * G1 (required for dWh) |
440 | parallel_nd(rnn.mb, [&](dim_t i) { |
441 | PRAGMA_OMP_SIMD() |
442 | for (int j = 0; j < rnn.dhc; j++) { |
443 | const float h = src_iter(i, j); |
444 | const float G1 = ws_gates(i, 1, j); |
445 | diff_src_iter(i, j) += dhG1(i, j) * G1; |
446 | scratch_gates(i, 1, j) = to_src(dhG1(i, j) * h * x_m_square(G1)); |
447 | hG1(i, j) = to_src(G1 * h); |
448 | } |
449 | }); |
450 | } |
451 | |
452 | template <> |
453 | rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part1_postgemm) { |
454 | const auto to_src = [](float a) { return a; }; |
455 | |
456 | gru_bwd_part1_postgemm_template(to_src, rnn, cell_position, ws_gates_, |
457 | scratch_gates_, augru_attention_, dst_layer_, src_iter_, |
458 | diff_src_iter_, diff_dst_iter_, diff_augru_attention_, |
459 | diff_dst_layer_); |
460 | } |
461 | |
462 | template <> |
463 | rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part2_postgemm) { |
464 | const auto to_src = [](float a) { return a; }; |
465 | |
466 | gru_bwd_part2_postgemm_template(to_src, rnn, cell_position, ws_gates_, |
467 | scratch_gates_, dst_layer_, src_iter_, diff_src_layer_, |
468 | diff_src_iter_, diff_dst_iter_, diff_dst_layer_, scratch_cell_); |
469 | } |
470 | |
471 | template <> |
472 | rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part1_postgemm) { |
473 | const auto to_src = [](float a) { return bfloat16_t(a); }; |
474 | |
475 | gru_bwd_part1_postgemm_template(to_src, rnn, cell_position, ws_gates_, |
476 | scratch_gates_, augru_attention_, dst_layer_, src_iter_, |
477 | diff_src_iter_, diff_dst_iter_, diff_augru_attention_, |
478 | diff_dst_layer_); |
479 | } |
480 | |
481 | template <> |
482 | rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part2_postgemm) { |
483 | const auto to_src = [](float a) { return bfloat16_t(a); }; |
484 | |
485 | gru_bwd_part2_postgemm_template(to_src, rnn, cell_position, ws_gates_, |
486 | scratch_gates_, dst_layer_, src_iter_, diff_src_layer_, |
487 | diff_src_iter_, diff_dst_iter_, diff_dst_layer_, scratch_cell_); |
488 | } |
489 | |
490 | #undef AOC |
491 | } // namespace cpu |
492 | } // namespace impl |
493 | } // namespace dnnl |
494 | |