1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | /* |
18 | General architecture |
19 | |
20 | for diff states, we have n_states + 1 as we have n_states diff |
21 | to propagate to the previous iteration and 1 states to propagate |
22 | to the previous layer |
23 | index 0 is dh for cell(t-1, l) to consume // replaced by diff_src_iter |
24 | index 1 is dc for cell(t-1, l) to consume // replaced by diff_src_iter_c |
25 | index 2 is dh for cell(t, l-1) to consume // replace by diff_src_layer |
26 | this indexing enables to have the same indexing for states in elemwise |
27 | function |
28 | only the cell execution function should be impacted |
29 | |
30 | */ |
31 | |
32 | #include "common/dnnl_thread.hpp" |
33 | #include "common/stream.hpp" |
34 | |
35 | #include "cpu/simple_q10n.hpp" |
36 | |
37 | #include "cpu/gemm/gemm.hpp" |
38 | #include "cpu/gemm/gemm_pack.hpp" |
39 | |
40 | #include "cpu/rnn/ref_rnn.hpp" |
41 | |
42 | namespace dnnl { |
43 | namespace impl { |
44 | namespace cpu { |
45 | |
46 | using namespace dnnl::impl::utils; |
47 | using namespace dnnl::impl::memory_tracking::names; |
48 | using namespace rnn_utils; |
49 | #define AOC array_offset_calculator |
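// array_offset_calculator (pulled in from utils above) linearizes a
// multi-dimensional index over a flat buffer in row-major order; roughly,
// AOC<float, 3> a(ptr, D0, D1, D2) makes a(i, j, k) address
// ptr[(i * D1 + j) * D2 + k].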
50 | |
51 | // GEMM functions wrapper definitions |
52 | |
53 | template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type, |
54 | data_type_t acc_type> |
55 | rnn_gemm_sig( |
56 | (_ref_rnn_common_t<aprop, src_type, weights_type, acc_type>::gemm)) { |
57 | assert(!"non packed gemm is unavailable for this data type" ); |
58 | return dnnl_unimplemented; |
59 | } |
60 | |
61 | template <> |
62 | rnn_gemm_sig((ref_rnn_fwd_f32_t::gemm)) { |
63 | assert(ldA * ldB * ldC != 0); |
64 | return extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, |
65 | &ldB, &beta, c_, &ldC, nullptr, pd()->rnn_.force_nocopy); |
66 | } |
67 | |
68 | template <> |
69 | rnn_gemm_sig((ref_rnn_bwd_f32_t::gemm)) { |
70 | assert(ldA * ldB * ldC != 0); |
71 | return extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, |
72 | &ldB, &beta, c_, &ldC, nullptr, pd()->rnn_.force_nocopy); |
73 | } |
74 | |
75 | template <> |
76 | rnn_gemm_sig((ref_rnn_fwd_bf16_t::gemm)) { |
77 | assert(ldA * ldB * ldC != 0); |
78 | return gemm_bf16bf16f32(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, |
79 | &ldB, &beta, c_, &ldC); |
80 | } |
81 | |
82 | template <> |
83 | rnn_gemm_sig((ref_rnn_bwd_bf16_t::gemm)) { |
84 | assert(ldA * ldB * ldC != 0); |
85 | return gemm_bf16bf16f32(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, |
86 | &ldB, &beta, c_, &ldC); |
87 | } |
88 | |
89 | // packed GEMM functions wrapper definitions |
90 | |
91 | template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type, |
92 | data_type_t acc_type> |
93 | rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type, |
94 | acc_type>::packed_gemm)) { |
95 | assert(!"packed gemm is unavailable for this datatype" ); |
96 | return dnnl_unimplemented; |
97 | } |
98 | |
99 | template <> |
100 | rnn_gemm_sig(ref_rnn_fwd_f32_t::packed_gemm) { |
101 | assert(transA == 'N' && transB == 'N' && alpha == 1.); |
    return sgemm_compute(
            "P", "N", &m, &n, &k, a_, &ldA, b_, &ldB, &beta, c_, &ldC);
104 | } |
105 | |
106 | template <> |
107 | rnn_gemm_sig(ref_rnn_bwd_f32_t::packed_gemm) { |
108 | assert(transA == 'N' && transB == 'N' && alpha == 1.); |
    return sgemm_compute(
            "P", "N", &m, &n, &k, a_, &ldA, b_, &ldB, &beta, c_, &ldC);
111 | } |
112 | |
113 | template <> |
114 | rnn_gemm_sig((ref_rnn_fwd_bf16_t::packed_gemm)) { |
115 | assert(transA == 'N' && transB == 'N' && alpha == 1.); |
    return gemm_bf16bf16f32_compute(
            "P", "N", &m, &n, &k, a_, &ldA, b_, &ldB, &beta, c_, &ldC);
118 | } |
119 | |
120 | template <> |
121 | rnn_gemm_sig((ref_rnn_bwd_bf16_t::packed_gemm)) { |
122 | assert(transA == 'N' && transB == 'N' && alpha == 1.); |
    return gemm_bf16bf16f32_compute(
            "P", "N", &m, &n, &k, a_, &ldA, b_, &ldB, &beta, c_, &ldC);
125 | } |
126 | |
127 | template <> |
128 | rnn_gemm_sig(ref_rnn_fwd_u8s8_t::packed_gemm) { |
129 | assert(transA == 'N' && transB == 'N' && alpha == 1.); |
130 | int32_t offsetc = 0; |
    return gemm_s8u8s32_compute("P", "N", "F", &m, &n, &k, a_, &ldA, b_, &ldB,
            &beta, c_, &ldC, &offsetc);
133 | } |
134 | |
135 | template <> |
136 | rnn_gemm_sig(ref_rnn_fwd_s8s8_t::packed_gemm) { |
137 | assert(transA == 'N' && transB == 'N' && alpha == 1.); |
138 | int32_t offsetc = 0; |
    return gemm_s8s8s32_compute("P", "N", "F", &m, &n, &k, a_, &ldA, b_, &ldB,
            &beta, c_, &ldC, &offsetc);
141 | } |
142 | |
143 | //*************** Grid computations strategy: linear ***************// |
144 | template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type, |
145 | data_type_t acc_type> |
146 | rnn_grid_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type, |
147 | acc_type>::linear_execution)) { |
148 | const AOC<src_layer_t, 4> ws_states_layer(ws_states_layer_, rnn.n_layer + 1, |
149 | rnn.n_dir, rnn.n_iter + 1, |
150 | rnn.ws_states_layer_nld * rnn.ws_states_layer_ld); |
151 | const AOC<const src_layer_t, 3> augru_attention( |
152 | augru_attention_, rnn.n_iter, rnn.mb, 1); |
153 | const AOC<src_iter_t, 4> ws_states_iter(ws_states_iter_, rnn.n_layer + 1, |
154 | rnn.n_dir, rnn.n_iter + 1, |
155 | rnn.ws_states_iter_nld * rnn.ws_states_iter_ld); |
    const auto ws_states_iter_c = rnn_utils::make_raw_aoc(ws_states_iter_c_,
            types::data_type_size(rnn.src_iter_c_dt), rnn.n_layer + 1,
            rnn.n_dir, rnn.n_iter + 1,
            rnn.ws_states_iter_c_nld * rnn.ws_states_iter_c_ld);
160 | const AOC<gemm_acc_t, 4> ws_diff_states_layer(ws_diff_states_layer_, |
161 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, |
162 | rnn.ws_diff_states_layer_nld * rnn.ws_diff_states_layer_ld); |
163 | const AOC<gemm_acc_t, 3> diff_augru_attention( |
164 | diff_augru_attention_, rnn.n_iter, rnn.mb, 1); |
165 | const AOC<gemm_acc_t, 4> ws_diff_states_iter(ws_diff_states_iter_, |
166 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, |
167 | rnn.ws_diff_states_iter_nld * rnn.ws_diff_states_iter_ld); |
168 | const AOC<gemm_acc_t, 4> ws_diff_states_iter_c(ws_diff_states_iter_c_, |
169 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, |
170 | rnn.ws_diff_states_iter_c_nld * rnn.ws_diff_states_iter_c_ld); |
171 | const AOC<gates_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, |
172 | rnn.n_iter, rnn.ws_gates_nld * rnn.ws_gates_ld); |
173 | const AOC<dst_iter_t, 4> ws_ht(ws_ht_, rnn.n_layer, rnn.n_dir, rnn.n_iter, |
174 | rnn.ws_ht_nld * rnn.ws_ht_ld); |
175 | const AOC<weights_t *, 3> weights_layer( |
176 | weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer); |
177 | const AOC<weights_t *, 3> weights_iter( |
178 | weights_iter_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter); |
179 | const AOC<weights_t *, 2> weights_projection( |
180 | weights_projection_, rnn.n_layer, rnn.n_dir); |
181 | const AOC<const float, 3> weights_peephole( |
182 | weights_peephole_, rnn.n_layer, rnn.n_dir, 3 * rnn.dhc); |
183 | bias_linear_exec_aoc_t bias(rnn, bias_); |
184 | const AOC<gemm_acc_t, 3> diff_weights_layer(diff_weights_layer_, |
185 | rnn.n_layer, rnn.n_dir, |
186 | rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld); |
187 | const AOC<gemm_acc_t, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, |
188 | rnn.n_dir, rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld); |
189 | const AOC<float, 3> diff_weights_peephole( |
190 | diff_weights_peephole_, rnn.n_layer, rnn.n_dir, 3 * rnn.dhc); |
191 | const AOC<float, 3> diff_weights_projection(diff_weights_projection_, |
192 | rnn.n_layer, rnn.n_dir, |
193 | rnn.diff_weights_projection_nld * rnn.diff_weights_projection_ld); |
194 | const AOC<float, 3> diff_bias( |
195 | diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dhc); |
196 | const AOC<gates_t, 4> ws_grid( |
197 | ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell); |
198 | |
    /* Raw inputs/outputs coming from the user */
    // Here we cannot use AOC, as the user's input can have arbitrary strides,
    // so we use memory_desc_wrapper instead.
201 | const auto src_layer_mdw = memory_desc_wrapper(pd()->src_md(0)); |
202 | const auto dst_layer_mdw = memory_desc_wrapper(pd()->dst_md(0)); |
203 | const auto src_iter_mdw = memory_desc_wrapper(pd()->src_md(1)); |
204 | const auto dst_iter_mdw = memory_desc_wrapper(pd()->dst_md(1)); |
205 | const auto src_iter_c_mdw = memory_desc_wrapper(pd()->src_md(2)); |
206 | const auto dst_iter_c_mdw = memory_desc_wrapper(pd()->dst_md(2)); |
207 | |
    // Since the function FN(...) returns by reference, an extra exception
    // has to be made for a nullptr argument
210 | #define SAFE_PTR(FN, ...) CONCAT2(FN, _) ? &(FN(__VA_ARGS__)) : nullptr |
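    // For example, SAFE_PTR(ws_gates, lay, dir, iter, 0) expands to
    //   ws_gates_ ? &(ws_gates(lay, dir, iter, 0)) : nullptr
    // so a null raw pointer short-circuits before the AOC is dereferenced.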
211 | const auto compute_merged_layer_part_if_applicable |
212 | = [&](prop_kind_t target_prop, int dir, int lay) { |
213 | if (IMPLICATION(rnn.merge_gemm_layer, aprop != target_prop)) |
214 | return dnnl_success; |
215 | |
216 | cell_position_t cell_position = middle_cell; |
217 | if (lay == 0) cell_position |= first_layer; |
218 | |
219 | const src_layer_t *src_layer |
220 | = lay == 0 && rnn.skip_src_layer_copy() |
221 | ? src_layer_ |
222 | : SAFE_PTR(ws_states_layer, lay, dir, 1, 0); |
223 | #if DNNL_X64 |
224 | CHECK((this->*merged_layer_func)(ctx, rnn, cell_position, |
225 | SAFE_PTR(weights_layer, lay, dir, 0), src_layer, |
226 | scratch_gates_, |
227 | SAFE_PTR(ws_diff_states_layer, lay, dir, 0, 0), |
228 | SAFE_PTR(diff_weights_layer, lay, dir, 0), |
229 | amx_scratchpad, addr_batch_global)); |
230 | #else |
231 | CHECK((this->*merged_layer_func)(rnn, cell_position, |
232 | SAFE_PTR(weights_layer, lay, dir, 0), src_layer, |
233 | scratch_gates_, |
234 | SAFE_PTR(ws_diff_states_layer, lay, dir, 0, 0), |
235 | SAFE_PTR(diff_weights_layer, lay, dir, 0))); |
236 | #endif |
237 | return dnnl_success; |
238 | }; |
239 | |
    // We run the grid of computations
241 | for_(int dir = 0; dir < rnn.n_dir; dir++) |
242 | for (int j = 0; j < rnn.n_layer; j++) { |
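        // Forward propagation walks the layers bottom-up; backward walks
        // them top-down, so each layer consumes the diff states produced
        // by the layer above it.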
243 | const int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1; |
244 | |
245 | CHECK(compute_merged_layer_part_if_applicable( |
246 | prop_kind::forward, dir, lay)); |
247 | |
248 | // TODO: enable merging projection gemm in bwd lstm projection |
249 | |
250 | for (int i = 0; i < rnn.n_iter; i++) { |
251 | const int iter |
252 | = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1; |
253 | |
            // We set the parameters for the cell execution call

            // dst_layer is equal to dst_iter. To avoid duplicating memory
            // accesses, we use only dst_layer and set dst_iter to nullptr,
            // unless we cannot, as in the following case:
            // - in the last layer and last iteration, we need to copy ht
            //   into two tensors (dst_layer and dst_iter)
262 | dst_layer_t *cell_dst_layer |
263 | = &(ws_states_layer(lay + 1, dir, iter + 1, 0)); |
264 | dst_iter_t *cell_dst_iter = nullptr; |
265 | const src_layer_t *cell_src_layer |
266 | = &(ws_states_layer(lay, dir, iter + 1, 0)); |
267 | const src_iter_t *cell_src_iter |
268 | = &(ws_states_iter(lay + 1, dir, iter, 0)); |
269 | |
270 | void *cell_dst_iter_c = const_cast<void *>( |
271 | ws_states_iter_c(lay + 1, dir, iter + 1, 0)); |
272 | const void *cell_src_iter_c |
273 | = ws_states_iter_c(lay + 1, dir, iter, 0); |
274 | |
            // cell_position is used only when skip_data_copy is supported,
            // which is currently the case only for forward propagation
277 | cell_position_t cell_position = middle_cell; |
278 | if (iter == 0) cell_position |= first_iter; |
279 | if (lay == 0) cell_position |= first_layer; |
280 | if (iter == rnn.n_iter - 1) cell_position |= last_iter; |
281 | if (lay == rnn.n_layer - 1) cell_position |= last_layer; |
282 | |
            // The dst_* paths should come before the src_* paths, as the
            // latter will override cell_src_layer and cell_src_iter
            // appropriately for the 1st layer and 1st iteration.
287 | const bool last_iter_skip_copy |
288 | = rnn.skip_dst_iter_copy() && (cell_position & last_iter); |
289 | if (last_iter_skip_copy) { |
290 | cell_dst_layer = dst_iter_ + dst_iter_mdw.off(lay, dir, 0, 0); |
291 | cell_src_layer |
292 | = dst_iter_ + dst_iter_mdw.off(lay - 1, dir, 0, 0); |
293 | } |
294 | |
295 | if (rnn.skip_dst_layer_copy() && (cell_position & last_layer)) { |
                // Note: for the last layer and last iteration, the output is
                // in dst_layer and still needs to be copied to dst_iter
298 | cell_dst_layer = dst_layer_ + dst_layer_mdw.off(iter, 0, 0); |
299 | cell_dst_iter = last_iter_skip_copy |
300 | ? dst_iter_ + dst_iter_mdw.off(lay, dir, 0, 0) |
301 | : nullptr; |
302 | cell_src_iter = (iter != 0) |
303 | ? dst_layer_ + dst_layer_mdw.off(iter - 1, 0, 0) |
304 | : cell_src_iter; |
305 | } |
306 | if (rnn.skip_src_iter_copy() && (cell_position & first_iter)) |
307 | cell_src_iter = src_iter_ + src_iter_mdw.off(lay, dir, 0, 0); |
308 | |
309 | if (rnn.skip_src_layer_copy() && (cell_position & first_layer)) |
310 | cell_src_layer = src_layer_ + src_layer_mdw.off(iter, 0, 0); |
311 | |
            // Because the c state is always f32 and requires no conversion,
            // we can always skip the copy for the 1st and last iterations
315 | if (iter == 0 && src_iter_c_) { |
316 | cell_src_iter_c = inc_ptr(src_iter_c_, rnn.src_iter_c_dt, |
317 | src_iter_c_mdw.off(lay, dir, 0, 0)); |
318 | cell_position |= c_state_first_iter; |
319 | } |
320 | if (iter == rnn.n_iter - 1 && dst_iter_c_) { |
321 | cell_dst_iter_c = inc_ptr(dst_iter_c_, rnn.dst_iter_c_dt, |
322 | dst_iter_c_mdw.off(lay, dir, 0, 0)); |
323 | cell_position |= c_state_last_iter; |
324 | } |
325 | const size_t sg_start_idx = rnn.n_iter_scratch_gates == 1 |
326 | ? static_cast<size_t>(0) |
327 | : static_cast<size_t>(iter) * rnn.scratch_gates_nld |
328 | * rnn.scratch_gates_ld; |
329 | const auto cell_scratch_gates = &scratch_gates_[sg_start_idx]; |
330 | |
331 | dst_iter_t *proj_ht = nullptr; |
332 | if (rnn.is_lstm_projection) { |
333 | if (rnn.is_training) |
334 | proj_ht = &(ws_ht(lay, dir, iter, 0)); |
335 | else |
336 | proj_ht = scratch_ht_; |
337 | } |
338 | |
339 | #if DNNL_X64 |
340 | CHECK((this->*cell_func)(ctx, rnn, cell_position, cell_dst_layer, |
341 | cell_dst_iter_c, |
342 | SAFE_PTR(ws_diff_states_layer, lay, dir, iter, 0), |
343 | SAFE_PTR(diff_augru_attention, iter, 0, 0), |
344 | SAFE_PTR(ws_diff_states_iter, lay, dir, iter, 0), |
345 | SAFE_PTR(ws_diff_states_iter_c, lay, dir, iter, 0), |
346 | SAFE_PTR(weights_layer, lay, dir, 0), |
347 | SAFE_PTR(weights_iter, lay, dir, 0), |
348 | SAFE_PTR(weights_projection, lay, dir), |
349 | SAFE_PTR(weights_peephole, lay, dir, 0), |
350 | w_proj_comp ? w_proj_comp + (j * rnn.n_dir + dir) * rnn.dic |
351 | : nullptr, |
352 | bias(lay, dir), cell_src_layer, |
353 | SAFE_PTR(augru_attention, iter, 0, 0), cell_src_iter, |
354 | cell_src_iter_c, |
355 | SAFE_PTR(ws_diff_states_layer, lay + 1, dir, iter, 0), |
356 | SAFE_PTR(ws_diff_states_iter, lay, dir, iter + 1, 0), |
357 | SAFE_PTR(ws_diff_states_iter_c, lay, dir, iter + 1, 0), |
358 | SAFE_PTR(diff_weights_layer, lay, dir, 0), |
359 | SAFE_PTR(diff_weights_iter, lay, dir, 0), |
360 | SAFE_PTR(diff_weights_projection, lay, dir, 0), |
361 | SAFE_PTR(diff_weights_peephole, lay, dir, 0), |
362 | SAFE_PTR(diff_bias, lay, dir, 0), |
363 | SAFE_PTR(ws_gates, lay, dir, iter, 0), cell_scratch_gates, |
364 | proj_ht, scratch_diff_ht_, |
365 | SAFE_PTR(ws_grid, lay, dir, iter, 0), scratch_cell_, |
366 | scratch_gates_blocked_, scratch_src_layer_, |
367 | scratch_src_iter_, cell_dst_iter, amx_scratchpad, |
368 | addr_batch_global)); |
369 | #else |
370 | CHECK((this->*cell_func)(rnn, cell_position, cell_dst_layer, |
371 | cell_dst_iter_c, |
372 | SAFE_PTR(ws_diff_states_layer, lay, dir, iter, 0), |
373 | SAFE_PTR(diff_augru_attention, iter, 0, 0), |
374 | SAFE_PTR(ws_diff_states_iter, lay, dir, iter, 0), |
375 | SAFE_PTR(ws_diff_states_iter_c, lay, dir, iter, 0), |
376 | SAFE_PTR(weights_layer, lay, dir, 0), |
377 | SAFE_PTR(weights_iter, lay, dir, 0), |
378 | SAFE_PTR(weights_projection, lay, dir), |
379 | SAFE_PTR(weights_peephole, lay, dir, 0), |
380 | w_proj_comp ? w_proj_comp + (j * rnn.n_dir + dir) * rnn.dic |
381 | : nullptr, |
382 | bias(lay, dir), cell_src_layer, |
383 | SAFE_PTR(augru_attention, iter, 0, 0), cell_src_iter, |
384 | cell_src_iter_c, |
385 | SAFE_PTR(ws_diff_states_layer, lay + 1, dir, iter, 0), |
386 | SAFE_PTR(ws_diff_states_iter, lay, dir, iter + 1, 0), |
387 | SAFE_PTR(ws_diff_states_iter_c, lay, dir, iter + 1, 0), |
388 | SAFE_PTR(diff_weights_layer, lay, dir, 0), |
389 | SAFE_PTR(diff_weights_iter, lay, dir, 0), |
390 | SAFE_PTR(diff_weights_projection, lay, dir, 0), |
391 | SAFE_PTR(diff_weights_peephole, lay, dir, 0), |
392 | SAFE_PTR(diff_bias, lay, dir, 0), |
393 | SAFE_PTR(ws_gates, lay, dir, iter, 0), cell_scratch_gates, |
394 | proj_ht, scratch_diff_ht_, |
395 | SAFE_PTR(ws_grid, lay, dir, iter, 0), scratch_cell_, |
396 | cell_dst_iter, amx_scratchpad)); |
397 | #endif |
398 | } |
399 | |
400 | CHECK(compute_merged_layer_part_if_applicable( |
401 | prop_kind::backward, dir, lay)); |
402 | #undef SAFE_PTR |
403 | |
404 | if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) { |
            // This is split into 3 pieces if we skip copies:
            // last iter in user mem, middle iters in ws, first iter in user mem.
            // Note 1: here we assume no change in data types among src_iter,
            // ws_iter and dst_iter
408 | |
409 | const dst_iter_t *states_iter = nullptr; |
410 | int states_iter_ld = 0; |
411 | int niter_merge_gemm_iter = 0; |
412 | |
413 | states_iter = &( |
414 | ws_states_iter(lay + 1, dir, rnn.skip_src_iter_copy(), 0)); |
415 | states_iter_ld = rnn.ws_states_iter_ld; |
416 | if (rnn.skip_dst_layer_copy() |
417 | && (lay == rnn.n_layer - 1)) { // last layer |
418 | states_iter = dst_layer_; |
419 | states_iter_ld = rnn.dst_layer_ld_; |
420 | } |
421 | niter_merge_gemm_iter = rnn.n_iter - rnn.skip_src_iter_copy(); |
422 | if (niter_merge_gemm_iter > 0) { |
423 | CHECK(gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, |
424 | rnn.mb * niter_merge_gemm_iter, 1.0, |
425 | (weights_t *)scratch_gates_ |
426 | + rnn.skip_src_iter_copy() |
427 | * rnn.scratch_gates_nld |
428 | * rnn.scratch_gates_ld, |
429 | rnn.scratch_gates_ld, states_iter, states_iter_ld, 1.0, |
430 | &(diff_weights_iter(lay, dir, 0)), |
431 | rnn.diff_weights_iter_ld)); |
432 | } |
433 | |
434 | if (rnn.skip_src_iter_copy()) { |
435 | states_iter = src_iter_ + src_iter_mdw.off(lay, dir, 0, 0); |
436 | states_iter_ld = rnn.src_iter_ld_; |
437 | niter_merge_gemm_iter = 1; |
438 | CHECK(gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic, |
439 | rnn.mb * niter_merge_gemm_iter, 1.0, |
440 | (weights_t *)scratch_gates_, rnn.scratch_gates_ld, |
441 | states_iter, states_iter_ld, 1.0, |
442 | &(diff_weights_iter(lay, dir, 0)), |
443 | rnn.diff_weights_iter_ld)); |
444 | } |
445 | } |
446 | } |
447 | return dnnl_success; |
448 | } |
449 | |
450 | //********* GRID computations strategy: utility functions **********// |
451 | |
// For bf32, the src_data_t (bf16) and input_data_t (f32) types can differ.
453 | template <typename src_data_t, typename input_data_t> |
454 | void copy_init_layer_fwd_template(const rnn_conf_t &rnn, |
455 | src_data_t *__restrict ws_states_layer_, |
456 | const input_data_t *__restrict xt_, const memory_desc_wrapper &xt_d) { |
457 | |
458 | const AOC<src_data_t, 4> ws_states_layer(ws_states_layer_, rnn.n_dir, |
459 | rnn.n_iter + 1, rnn.mb, rnn.ws_states_layer_ld); |
460 | |
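    // The left-to-right direction reads x(t) into iteration slot it + 1 of
    // direction 0; the right-to-left direction mirrors the sequence into
    // direction n_dir - 1, so that the grid can always walk iterations in
    // increasing order.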
461 | parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) { |
462 | auto xxt = xt_ + xt_d.blk_off(it, b); |
463 | src_data_t *ws_l2r_ptr = &(ws_states_layer(0, it + 1, b, 0)); |
464 | src_data_t *ws_r2l_ptr |
465 | = &(ws_states_layer(rnn.n_dir - 1, rnn.n_iter - it, b, 0)); |
466 | if (rnn.exec_dir != r2l) { |
467 | if (rnn.is_bf32()) { |
468 | cvt_float_to_bfloat16( |
469 | (bfloat16_t *)ws_l2r_ptr, (const float *)xxt, rnn.slc); |
470 | } else { |
471 | PRAGMA_OMP_SIMD() |
472 | for (int c = 0; c < rnn.slc; c++) |
473 | ws_l2r_ptr[c] = xxt[c]; |
474 | } |
475 | } |
476 | if (rnn.exec_dir != l2r) { |
477 | if (rnn.is_bf32()) { |
478 | cvt_float_to_bfloat16( |
479 | (bfloat16_t *)ws_r2l_ptr, (const float *)xxt, rnn.slc); |
480 | } else { |
481 | PRAGMA_OMP_SIMD() |
482 | for (int c = 0; c < rnn.slc; c++) |
483 | ws_r2l_ptr[c] = xxt[c]; |
484 | } |
485 | } |
486 | }); |
487 | } |
488 | |
489 | template <typename acc_data_t> |
490 | void copy_init_layer_bwd_template(const rnn_conf_t &rnn, |
491 | acc_data_t *ws_diff_states_layer_, const acc_data_t *diff_dst_layer_, |
492 | const memory_desc_wrapper &diff_dst_layer_d) { |
493 | const AOC<acc_data_t, 5> ws_diff_states_layer(ws_diff_states_layer_, |
494 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb, |
495 | rnn.ws_diff_states_layer_ld); |
496 | |
497 | switch (rnn.exec_dir) { |
498 | case bi_concat: |
499 | parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) { |
500 | const auto diff_dst_layer_x |
501 | = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); |
502 | for (int s = 0; s < rnn.dlc; s++) { |
503 | ws_diff_states_layer(rnn.n_layer, 0, it, b, s) |
504 | = diff_dst_layer_x[s]; |
505 | ws_diff_states_layer( |
506 | rnn.n_layer, 1, rnn.n_iter - it - 1, b, s) |
507 | = diff_dst_layer_x[rnn.dlc + s]; |
508 | } |
509 | }); |
510 | break; |
511 | case bi_sum: |
512 | parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) { |
513 | const auto diff_dst_layer_x |
514 | = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); |
515 | for (int s = 0; s < rnn.dlc; s++) { |
516 | ws_diff_states_layer(rnn.n_layer, 0, it, b, s) |
517 | = diff_dst_layer_x[s]; |
518 | ws_diff_states_layer( |
519 | rnn.n_layer, 1, rnn.n_iter - it - 1, b, s) |
520 | = diff_dst_layer_x[s]; |
521 | } |
522 | }); |
523 | break; |
524 | case l2r: |
525 | parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) { |
526 | const auto diff_dst_layer_x |
527 | = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b); |
528 | for (int s = 0; s < rnn.dlc; s++) { |
529 | ws_diff_states_layer(rnn.n_layer, 0, it, b, s) |
530 | = diff_dst_layer_x[s]; |
531 | } |
532 | }); |
533 | break; |
534 | case r2l: |
535 | parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) { |
536 | const auto diff_dst_layer_x = diff_dst_layer_ |
537 | + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b); |
538 | for (int s = 0; s < rnn.dlc; s++) { |
539 | ws_diff_states_layer(rnn.n_layer, 0, it, b, s) |
540 | = diff_dst_layer_x[s]; |
541 | } |
542 | }); |
543 | break; |
544 | default: assert(!"Unsupported direction" ); break; |
545 | } |
546 | } |
547 | |
548 | #define RNN_DECL_COPY_INIT_LAYER_FWD(cname) \ |
549 | template <> \ |
550 | template <typename input_data_t> \ |
551 | void cname::copy_init_layer(const rnn_conf_t &rnn, \ |
552 | src_layer_t *ws_states_layer_, gemm_acc_t *ws_diff_states_layer_, \ |
553 | const input_data_t *xt_, const gemm_acc_t *diff_dst_layer_) \ |
554 | const { \ |
555 | copy_init_layer_fwd_template(rnn, ws_states_layer_, xt_, \ |
556 | memory_desc_wrapper(pd()->src_md(0))); \ |
557 | } |
558 | |
559 | RNN_DECL_COPY_INIT_LAYER_FWD(ref_rnn_fwd_f32_t) |
560 | RNN_DECL_COPY_INIT_LAYER_FWD(ref_rnn_fwd_bf16_t) |
561 | RNN_DECL_COPY_INIT_LAYER_FWD(ref_rnn_fwd_u8s8_t) |
562 | RNN_DECL_COPY_INIT_LAYER_FWD(ref_rnn_fwd_s8s8_t) |
563 | |
564 | #define RNN_DECL_COPY_INIT_LAYER_BWD(cname) \ |
565 | template <> \ |
566 | template <typename input_data_t> \ |
567 | void cname::copy_init_layer(const rnn_conf_t &rnn, \ |
568 | src_layer_t *ws_states_layer_, gemm_acc_t *ws_diff_states_layer_, \ |
569 | const input_data_t *xt_, const gemm_acc_t *diff_dst_layer_) \ |
570 | const { \ |
571 | copy_init_layer_bwd_template(rnn, ws_diff_states_layer_, \ |
572 | diff_dst_layer_, memory_desc_wrapper(pd()->diff_dst_md(0))); \ |
573 | } |
574 | |
575 | RNN_DECL_COPY_INIT_LAYER_BWD(ref_rnn_bwd_f32_t) |
576 | RNN_DECL_COPY_INIT_LAYER_BWD(ref_rnn_bwd_bf16_t) |
577 | |
/* For an int8 configuration, input iteration states may be of type f32 or u8.
 * Internally, h_state is always stored in u8 and c_state is always stored in
 * f32. If the input states are u8, then h_state is copied and c_state is
 * dequantized; if the input states are f32, then h_state is quantized and
 * c_state is copied.
 */
583 | template <typename src_data_t, typename input_data_t> |
584 | void copy_init_iter_fwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd, |
585 | src_data_t *__restrict ws_states_iter_, |
586 | void *__restrict ws_states_iter_c_, |
587 | const input_data_t *__restrict src_iter_, |
588 | const memory_desc_wrapper &src_iter_d, |
589 | const void *__restrict src_iter_c_, |
590 | const memory_desc_wrapper &src_iter_c_d) { |
591 | const AOC<src_data_t, 5> ws_states_iter(ws_states_iter_, rnn.n_layer + 1, |
592 | rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.ws_states_iter_ld); |
593 | const auto ws_states_iter_c_aoc = rnn_utils::make_raw_aoc(ws_states_iter_c_, |
594 | types::data_type_size(rnn.src_iter_c_dt), rnn.n_layer + 1, |
595 | rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.ws_states_iter_c_ld); |
596 | |
597 | const float data_shift = pd->attr()->rnn_data_qparams_.shift_; |
598 | const float data_scale = pd->attr()->rnn_data_qparams_.scale_; |
599 | |
600 | const bool quantize = rnn.is_int8_conf() |
601 | && IMPLICATION(pd->with_src_iter(), |
602 | pd->src_md(1)->data_type == data_type::f32); |
603 | const auto maybe_q = [&](input_data_t f) { |
604 | if (quantize) { |
605 | float qf = f * data_scale + data_shift; |
606 | return qz_a1b0<float, src_data_t>()(qf); |
607 | } else |
608 | return (src_data_t)f; |
609 | }; |
610 | const src_data_t zero = maybe_q(0.f); |
611 | const auto zero_ws_iter_c = [&](int lay, int dir, int mb_id, int sic_id) { |
612 | void *ws_states_iter_c = const_cast<void *>( |
613 | ws_states_iter_c_aoc(lay, dir, 0, mb_id, sic_id)); |
614 | if (rnn.src_iter_c_dt == data_type::f32) |
615 | *(static_cast<float *>(ws_states_iter_c)) = 0.0f; |
616 | else if (rnn.src_iter_c_dt == data_type::bf16) |
617 | *(static_cast<bfloat16_t *>(ws_states_iter_c)) = 0.0f; |
618 | }; |
619 | |
620 | if (src_iter_) { |
621 | parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, |
622 | [&](dim_t lay, dim_t dir, dim_t b) { |
623 | const auto *ss |
624 | = &src_iter_[src_iter_d.blk_off(lay, dir, b, 0)]; |
625 | auto *dd = &ws_states_iter(lay + 1, dir, 0, b, 0); |
626 | PRAGMA_OMP_SIMD() |
627 | for (int s = 0; s < rnn.sic; s++) |
628 | dd[s] = maybe_q(ss[s]); |
629 | }); |
630 | } else { |
631 | parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, |
632 | [&](dim_t lay, dim_t dir, dim_t b) { |
633 | for (int j = 0; j < rnn.sic; j++) |
634 | ws_states_iter(lay + 1, dir, 0, b, j) = zero; |
635 | if (pd->cell_kind() == alg_kind::vanilla_lstm) |
636 | for (int j = 0; j < rnn.dhc; j++) |
637 | zero_ws_iter_c(lay + 1, dir, b, j); |
638 | }); |
639 | } |
640 | } |
641 | |
642 | template <typename acc_data_t> |
643 | void copy_init_iter_bwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd, |
644 | acc_data_t *ws_diff_states_iter_, acc_data_t *ws_diff_states_iter_c_, |
645 | const acc_data_t *diff_dst_iter_, |
646 | const memory_desc_wrapper diff_dst_iter_d, |
647 | const float *diff_dst_iter_c_, |
648 | const memory_desc_wrapper diff_dst_iter_c_d) { |
649 | const AOC<acc_data_t, 5> ws_diff_states_iter(ws_diff_states_iter_, |
650 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb, |
651 | rnn.ws_diff_states_iter_ld); |
652 | const AOC<acc_data_t, 5> ws_diff_states_iter_c(ws_diff_states_iter_c_, |
653 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb, |
654 | rnn.ws_diff_states_iter_c_ld); |
655 | if (diff_dst_iter_) { |
656 | parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, |
657 | [&](dim_t lay, dim_t dir, dim_t b) { |
658 | array_copy( |
659 | &(ws_diff_states_iter(lay, dir, rnn.n_iter, b, 0)), |
660 | diff_dst_iter_ |
661 | + diff_dst_iter_d.blk_off(lay, dir, b), |
662 | rnn.dic); |
663 | if (pd->cell_kind() == alg_kind::vanilla_lstm) |
664 | array_copy(&(ws_diff_states_iter_c( |
665 | lay, dir, rnn.n_iter, b, 0)), |
666 | diff_dst_iter_c_ |
667 | + diff_dst_iter_c_d.blk_off( |
668 | lay, dir, b), |
669 | rnn.dhc); |
670 | }); |
671 | } else { |
672 | parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, |
673 | [&](dim_t lay, dim_t dir, dim_t i) { |
674 | for (int j = 0; j < rnn.dic; j++) |
675 | ws_diff_states_iter(lay, dir, rnn.n_iter, i, j) = 0.0f; |
676 | if (pd->cell_kind() == alg_kind::vanilla_lstm) |
677 | for (int j = 0; j < rnn.dhc; j++) |
678 | ws_diff_states_iter_c(lay, dir, rnn.n_iter, i, j) |
679 | = 0.0f; |
680 | }); |
681 | } |
682 | } |
683 | |
684 | #define RNN_DECL_COPY_INIT_ITER_FWD(cname) \ |
685 | template <> \ |
686 | template <typename input_data_t> \ |
687 | void cname::copy_init_iter(const rnn_conf_t &rnn, \ |
688 | src_layer_t *__restrict ws_states_iter_, \ |
689 | void *__restrict ws_states_iter_c_, \ |
690 | gemm_acc_t *__restrict ws_diff_states_iter_, \ |
691 | gemm_acc_t *__restrict ws_diff_states_iter_c_, \ |
692 | const input_data_t *__restrict src_iter_, \ |
693 | const void *__restrict src_iter_c_, \ |
694 | const gemm_acc_t *__restrict diff_dst_iter_, \ |
695 | const float *__restrict diff_dst_iter_c_) const { \ |
696 | auto src_iter_d = memory_desc_wrapper(pd()->src_md(1)); \ |
697 | auto src_iter_c_d = memory_desc_wrapper(pd()->src_md(2)); \ |
698 | copy_init_iter_fwd_template(rnn, pd(), ws_states_iter_, \ |
699 | ws_states_iter_c_, src_iter_, src_iter_d, src_iter_c_, \ |
700 | src_iter_c_d); \ |
701 | } |
702 | |
703 | RNN_DECL_COPY_INIT_ITER_FWD(ref_rnn_fwd_f32_t) |
704 | RNN_DECL_COPY_INIT_ITER_FWD(ref_rnn_fwd_bf16_t) |
705 | RNN_DECL_COPY_INIT_ITER_FWD(ref_rnn_fwd_u8s8_t) |
706 | RNN_DECL_COPY_INIT_ITER_FWD(ref_rnn_fwd_s8s8_t) |
707 | |
708 | #define RNN_DECL_COPY_INIT_ITER_BWD(cname) \ |
709 | template <> \ |
710 | template <typename input_data_t> \ |
711 | void cname::copy_init_iter(const rnn_conf_t &rnn, \ |
712 | src_layer_t *ws_states_iter_, void *ws_states_iter_c_, \ |
713 | gemm_acc_t *ws_diff_states_iter_, \ |
714 | gemm_acc_t *ws_diff_states_iter_c_, const input_data_t *src_iter_, \ |
715 | const void *src_iter_c_, const gemm_acc_t *diff_dst_iter_, \ |
716 | const float *diff_dst_iter_c_) const { \ |
717 | auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1)); \ |
718 | auto diff_dst_iter_c_d = memory_desc_wrapper(pd()->diff_dst_md(2)); \ |
719 | copy_init_iter_bwd_template(rnn, pd(), ws_diff_states_iter_, \ |
720 | ws_diff_states_iter_c_, diff_dst_iter_, diff_dst_iter_d, \ |
721 | diff_dst_iter_c_, diff_dst_iter_c_d); \ |
722 | } |
723 | |
724 | RNN_DECL_COPY_INIT_ITER_BWD(ref_rnn_bwd_f32_t) |
725 | RNN_DECL_COPY_INIT_ITER_BWD(ref_rnn_bwd_bf16_t) |
726 | |
727 | template <typename src_data_t, typename dst_layer_dt, typename dst_iter_dt> |
728 | void copy_res_layer_fwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd, |
729 | dst_layer_dt *dst_layer_, memory_desc_wrapper &dst_layer_d, |
730 | const dst_iter_dt *dst_iter_, const memory_desc_wrapper &dst_iter_d, |
731 | const src_data_t *ws_states_layer_) { |
732 | |
733 | const AOC<const src_data_t, 5> ws_states_layer(ws_states_layer_, |
734 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb, |
735 | rnn.ws_states_layer_ld); |
736 | const float shift = (pd->attr()->rnn_data_qparams_.shift_); |
737 | const float scale = (pd->attr()->rnn_data_qparams_.scale_); |
738 | |
739 | const bool dequantize |
740 | = pd->dst_md(0)->data_type == data_type::f32 && rnn.is_int8_conf(); |
741 | const bool dequantize_at_copy = dequantize && rnn.exec_dir != bi_sum; |
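    // For bi_sum, the two directions must be accumulated before the shift
    // can be removed (it is applied twice in the sum), so dequantization is
    // deferred to acc_vec below.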
742 | |
    // minor optimization helper for the compiler
744 | static constexpr bool rnn_u8u8_case |
745 | = std::is_same<dst_layer_dt, uint8_t>::value |
746 | && std::is_same<src_data_t, uint8_t>::value; |
747 | static constexpr bool rnn_s8s8_case |
748 | = std::is_same<dst_layer_dt, int8_t>::value |
749 | && std::is_same<src_data_t, int8_t>::value; |
750 | |
751 | const auto copy_vec = [&](dst_layer_dt *dd, const src_data_t *ss) { |
752 | if (dequantize_at_copy) { |
753 | PRAGMA_OMP_SIMD() |
754 | for (int s = 0; s < rnn.dlc; s++) |
755 | dd[s] = (dst_layer_dt)(((float)ss[s] - shift) / scale); |
756 | } else { |
757 | PRAGMA_OMP_SIMD() |
758 | for (int s = 0; s < rnn.dlc; s++) |
759 | dd[s] = (dst_layer_dt)ss[s]; |
760 | } |
761 | }; |
762 | |
763 | const auto acc_vec = [&](dst_layer_dt *dd, const src_data_t *ss) { |
764 | if (dequantize) { |
765 | PRAGMA_OMP_SIMD() |
766 | for (int s = 0; s < rnn.dlc; s++) { |
767 | float val = (float)ss[s] + dd[s]; |
768 | val = qz_a1b0<float, src_data_t>()(val); |
769 | dd[s] = (dst_layer_dt)((val - 2 * shift) / scale); |
770 | } |
771 | } else if (rnn_u8u8_case |
772 | || rnn_s8s8_case) { // instead of checking for rnn.is_int8() |
773 | PRAGMA_OMP_SIMD() |
774 | for (int s = 0; s < rnn.dlc; s++) |
775 | dd[s] = saturate<dst_layer_dt, int16_t>( |
776 | (int16_t)dd[s] + (int16_t)ss[s]); |
777 | } else { |
778 | PRAGMA_OMP_SIMD() |
779 | for (int s = 0; s < rnn.dlc; s++) |
780 | dd[s] += (dst_layer_dt)ss[s]; |
781 | } |
782 | }; |
783 | |
784 | // if skip_dst_iter_copy, then the data for the last iteration is |
785 | // in dst_iter, not in workspace |
786 | parallel_nd(rnn.n_iter - (rnn.skip_dst_iter_copy() ? 1 : 0), rnn.mb, |
787 | [&](dim_t it, dim_t b) { |
788 | int dir = 0; |
789 | if (rnn.exec_dir != r2l) { |
790 | const auto *ss |
791 | = &ws_states_layer(rnn.n_layer, dir, it + 1, b, 0); |
792 | auto *dd = &dst_layer_[dst_layer_d.blk_off( |
793 | it, b, dir * rnn.dlc)]; |
794 | copy_vec(dd, ss); |
795 | dir = 1; |
796 | } |
797 | if (rnn.exec_dir != l2r) { |
798 | const auto *ss = &ws_states_layer( |
799 | rnn.n_layer, dir, rnn.n_iter - it, b, 0); |
800 | if (rnn.exec_dir == bi_sum) { |
801 | auto *dd = &dst_layer_[dst_layer_d.blk_off(it, b, 0)]; |
802 | acc_vec(dd, ss); |
803 | } else { |
804 | auto *dd = &dst_layer_[dst_layer_d.blk_off( |
805 | it, b, dir * rnn.dlc)]; |
806 | copy_vec(dd, ss); |
807 | } |
808 | } |
809 | }); |
810 | if (rnn.skip_dst_iter_copy()) { |
811 | parallel_nd(rnn.mb, [&](dim_t b) { |
812 | const int it = rnn.n_iter - 1; |
813 | int dir = 0; |
814 | if (rnn.exec_dir != r2l) { |
815 | const auto *ss = dst_iter_ |
816 | + dst_iter_d.blk_off(rnn.n_layer - 1, dir, b, 0); |
817 | auto *dd = &dst_layer_[dst_layer_d.blk_off( |
818 | it, b, dir * rnn.dlc)]; |
819 | copy_vec(dd, (src_data_t *)ss); |
820 | dir = 1; |
821 | } |
822 | if (rnn.exec_dir != l2r) { |
823 | const auto *ss = dst_iter_ |
824 | + dst_iter_d.blk_off(rnn.n_layer - 1, dir, b, 0); |
825 | if (rnn.exec_dir == bi_sum) { |
826 | auto *dd = &dst_layer_[dst_layer_d.blk_off(it, b, 0)]; |
827 | acc_vec(dd, (src_data_t *)ss); |
828 | } else { |
829 | auto *dd = &dst_layer_[dst_layer_d.blk_off( |
830 | it, b, dir * rnn.dlc)]; |
831 | copy_vec(dd, (src_data_t *)ss); |
832 | } |
833 | } |
834 | }); |
835 | } |
836 | } |
837 | |
838 | template <typename acc_data_t> |
839 | void copy_res_layer_bwd_template(const rnn_conf_t &rnn, |
840 | acc_data_t *diff_src_layer_, memory_desc_wrapper &diff_src_layer_d, |
841 | const acc_data_t *ws_diff_states_layer_) { |
842 | const AOC<const acc_data_t, 5> ws_diff_states_layer(ws_diff_states_layer_, |
843 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb, |
844 | rnn.ws_diff_states_layer_ld); |
845 | |
846 | parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) { |
847 | int dir = 0; |
848 | for (int s = 0; s < rnn.slc; s++) { |
849 | acc_data_t *dst_addr = diff_src_layer_ |
850 | + diff_src_layer_d.blk_off( |
851 | (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it, b, |
852 | dir * rnn.slc + s); |
853 | acc_data_t res = ws_diff_states_layer(0, 0, it, b, s); |
            if (rnn.n_dir > 1)
855 | res += ws_diff_states_layer(0, 1, rnn.n_iter - 1 - it, b, s); |
856 | dst_addr[0] = res; |
857 | } |
858 | }); |
859 | } |
860 | |
861 | #define RNN_DECL_COPY_RES_LAYER_FWD(cname) \ |
862 | template <> \ |
863 | template <typename dst_layer_dt, typename dst_iter_dt> \ |
864 | void cname::copy_res_layer(const rnn_conf_t &rnn, \ |
865 | dst_layer_dt *dst_layer_, gemm_acc_t *diff_src_layer, \ |
866 | const dst_iter_dt *dst_iter_, const src_layer_t *ws_states_layer_, \ |
867 | const gemm_acc_t *ws_diff_states_layer_) const { \ |
868 | auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); \ |
869 | auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); \ |
870 | copy_res_layer_fwd_template(rnn, pd(), dst_layer_, dst_layer_d, \ |
871 | dst_iter_, dst_iter_d, ws_states_layer_); \ |
872 | } |
873 | |
874 | RNN_DECL_COPY_RES_LAYER_FWD(ref_rnn_fwd_f32_t) |
875 | RNN_DECL_COPY_RES_LAYER_FWD(ref_rnn_fwd_bf16_t) |
876 | RNN_DECL_COPY_RES_LAYER_FWD(ref_rnn_fwd_u8s8_t) |
877 | RNN_DECL_COPY_RES_LAYER_FWD(ref_rnn_fwd_s8s8_t) |
878 | |
879 | #define RNN_DECL_COPY_RES_LAYER_BWD(cname) \ |
880 | template <> \ |
881 | template <typename dst_layer_dt, typename dst_iter_dt> \ |
882 | void cname::copy_res_layer(const rnn_conf_t &rnn, \ |
883 | dst_layer_dt *dst_layer_, gemm_acc_t *diff_src_layer_, \ |
884 | const dst_iter_dt *dst_iter_, const src_layer_t *ws_states_layer_, \ |
885 | const gemm_acc_t *ws_diff_states_layer_) const { \ |
886 | auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0)); \ |
887 | copy_res_layer_bwd_template(rnn, diff_src_layer_, diff_src_layer_d, \ |
888 | ws_diff_states_layer_); \ |
889 | } |
890 | |
891 | RNN_DECL_COPY_RES_LAYER_BWD(ref_rnn_bwd_f32_t) |
892 | RNN_DECL_COPY_RES_LAYER_BWD(ref_rnn_bwd_bf16_t) |
893 | |
894 | template <typename src_data_t, typename dst_iter_dt, typename dst_layer_dt> |
895 | void copy_res_iter_fwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd, |
896 | dst_iter_dt *dst_iter_, memory_desc_wrapper &dst_iter_d, |
897 | void *dst_iter_c_, memory_desc_wrapper dst_iter_c_d, |
898 | const dst_layer_dt *dst_layer_, memory_desc_wrapper dst_layer_d, |
899 | const src_data_t *ws_states_iter_, const void *ws_states_iter_c_) { |
900 | if (dst_iter_ == nullptr) return; |
901 | |
902 | const AOC<const src_data_t, 5> ws_states_iter(ws_states_iter_, |
903 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb, |
904 | rnn.ws_states_iter_ld); |
905 | |
906 | const float data_shift = pd->attr()->rnn_data_qparams_.shift_; |
907 | const float data_scale = pd->attr()->rnn_data_qparams_.scale_; |
908 | |
909 | const bool dequantize = pd->with_dst_iter() |
910 | && pd->dst_md(1)->data_type == data_type::f32 && rnn.is_int8_conf(); |
911 | const auto copy_vec = [&](dst_iter_dt *dd, const src_data_t *ss) { |
912 | if (dequantize) { |
913 | PRAGMA_OMP_SIMD() |
914 | for (int s = 0; s < rnn.dic; s++) |
915 | dd[s] = (dst_iter_dt)(((float)ss[s] - data_shift) / data_scale); |
916 | } else { |
917 | PRAGMA_OMP_SIMD() |
918 | for (int s = 0; s < rnn.dic; s++) |
919 | dd[s] = (dst_iter_dt)ss[s]; |
920 | } |
921 | }; |
922 | |
923 | // If skip_dst_layer_copy, then the data to copy for the last |
924 | // layer is in dst_layer, not in workspace. |
925 | const auto n_layer_in_ws = rnn.n_layer - rnn.skip_dst_layer_copy(); |
926 | |
927 | parallel_nd(n_layer_in_ws, rnn.n_dir, rnn.mb, |
928 | [&](dim_t lay, dim_t dir, dim_t b) { |
929 | const auto *ss |
930 | = &ws_states_iter(lay + 1, dir, rnn.n_iter, b, 0); |
931 | auto *dd = dst_iter_ + dst_iter_d.blk_off(lay, dir, b, 0); |
932 | copy_vec(dd, ss); |
933 | }); |
934 | |
935 | if (rnn.skip_dst_layer_copy()) { |
936 | parallel_nd(rnn.n_dir, rnn.mb, [&](dim_t dir, dim_t b) { |
937 | const auto *ss |
938 | = &dst_layer_[dst_layer_d.blk_off(rnn.n_iter - 1, b, dir)]; |
939 | auto *dd = &dst_iter_[dst_iter_d.blk_off( |
940 | rnn.n_layer - 1, dir, b, 0)]; |
941 | copy_vec(dd, (src_data_t *)ss); |
942 | }); |
943 | } |
944 | } |
945 | |
946 | template <typename acc_data_t> |
947 | void copy_res_iter_bwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd, |
948 | acc_data_t *diff_src_iter_, memory_desc_wrapper &diff_src_iter_d, |
949 | float *diff_src_iter_c_, memory_desc_wrapper &diff_src_iter_c_d, |
950 | const acc_data_t *ws_diff_states_iter_, |
951 | const acc_data_t *ws_diff_states_iter_c_) { |
952 | const AOC<const acc_data_t, 5> ws_diff_states_iter(ws_diff_states_iter_, |
953 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb, |
954 | rnn.ws_diff_states_iter_ld); |
955 | const AOC<const acc_data_t, 5> ws_diff_states_iter_c(ws_diff_states_iter_c_, |
956 | rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb, |
957 | rnn.ws_diff_states_iter_c_ld); |
958 | if (diff_src_iter_) { |
959 | parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb, |
960 | [&](dim_t lay, dim_t dir, dim_t b) { |
961 | for (int s = 0; s < rnn.sic; s++) { |
962 | diff_src_iter_[diff_src_iter_d.blk_off(lay, dir, b, s)] |
963 | = ws_diff_states_iter(lay, dir, 0, b, s); |
964 | } |
965 | if (pd->cell_kind() == alg_kind::vanilla_lstm) |
966 | for (int s = 0; s < rnn.dhc; s++) { |
967 | diff_src_iter_c_[diff_src_iter_c_d.blk_off( |
968 | lay, dir, b, s)] |
969 | = ws_diff_states_iter_c(lay, dir, 0, b, s); |
970 | } |
971 | }); |
972 | } |
973 | } |
974 | |
975 | #define RNN_DECL_COPY_RES_ITER_FWD(cname) \ |
976 | template <> \ |
977 | template <typename dst_iter_dt, typename dst_layer_dt> \ |
978 | void cname::copy_res_iter(const rnn_conf_t &rnn, dst_iter_dt *dst_iter_, \ |
979 | void *dst_iter_c_, gemm_acc_t *diff_src_iter_, \ |
980 | float *diff_src_iter_c_, const dst_layer_dt *dst_layer_, \ |
981 | const src_layer_t *ws_states_layer_, \ |
982 | const void *ws_states_iter_c_, \ |
983 | const gemm_acc_t *ws_diff_states_iter_, \ |
984 | const gemm_acc_t *ws_diff_states_iter_c_) const { \ |
985 | auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); \ |
986 | auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); \ |
987 | auto dst_iter_c_d = memory_desc_wrapper(pd()->dst_md(2)); \ |
988 | copy_res_iter_fwd_template(rnn, pd(), dst_iter_, dst_iter_d, \ |
989 | dst_iter_c_, dst_iter_c_d, dst_layer_, dst_layer_d, \ |
990 | ws_states_layer_, ws_states_iter_c_); \ |
991 | } |
992 | |
993 | RNN_DECL_COPY_RES_ITER_FWD(ref_rnn_fwd_f32_t) |
994 | RNN_DECL_COPY_RES_ITER_FWD(ref_rnn_fwd_bf16_t) |
995 | RNN_DECL_COPY_RES_ITER_FWD(ref_rnn_fwd_u8s8_t) |
996 | RNN_DECL_COPY_RES_ITER_FWD(ref_rnn_fwd_s8s8_t) |
997 | |
998 | #define RNN_DECL_COPY_RES_ITER_BWD(cname) \ |
999 | template <> \ |
1000 | template <typename output_data_t, typename dst_data_t> \ |
1001 | void cname::copy_res_iter(const rnn_conf_t &rnn, output_data_t *dst_iter_, \ |
1002 | void *dst_iter_c_, gemm_acc_t *diff_src_iter_, \ |
1003 | float *diff_src_iter_c_, const dst_data_t *dst_layer_, \ |
1004 | const src_layer_t *ws_states_layer_, \ |
1005 | const void *ws_states_iter_c_, \ |
1006 | const gemm_acc_t *ws_diff_states_iter_, \ |
1007 | const gemm_acc_t *ws_diff_states_iter_c_) const { \ |
1008 | auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1)); \ |
1009 | auto diff_src_iter_c_d = memory_desc_wrapper(pd()->diff_src_md(2)); \ |
1010 | copy_res_iter_bwd_template(rnn, pd(), diff_src_iter_, diff_src_iter_d, \ |
1011 | diff_src_iter_c_, diff_src_iter_c_d, ws_diff_states_iter_, \ |
1012 | ws_diff_states_iter_c_); \ |
1013 | } |
1014 | |
1015 | RNN_DECL_COPY_RES_ITER_BWD(ref_rnn_bwd_f32_t) |
1016 | RNN_DECL_COPY_RES_ITER_BWD(ref_rnn_bwd_bf16_t) |
1017 | |
1018 | rnn_bias_prepare_sig_templ(copy_bias_to_scratch) { |
1019 | const AOC<T, 3> scratch_bias( |
1020 | scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dhc); |
1021 | |
1022 | parallel_nd(rnn.n_layer * rnn.n_dir, [&](dim_t i) { |
1023 | const int off = i * rnn.n_bias * rnn.dhc; |
1024 | PRAGMA_OMP_SIMD() |
1025 | for (int j = 0; j < rnn.n_bias * rnn.dhc; j++) |
1026 | scratch_bias_[off + j] = b_[off + j]; |
1027 | }); |
1028 | } |
1029 | |
1030 | rnn_bias_prepare_sig_templ(copy_bias_to_ws) { |
    /* Original set of biases provided by the user */
    const AOC<const T, 3> b(b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dhc);
1033 | /* Array of pointers initialized in packing */ |
1034 | const AOC<T *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias); |
1035 | const AOC<T, 3> scratch_bias( |
1036 | scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dhc); |
1037 | |
1038 | for (int i = 0; i < rnn.n_layer; i++) { |
1039 | for (int d = 0; d < rnn.n_dir; d++) { |
1040 | int offset_bias = 0; |
1041 | for (int p = 0; p < rnn.n_parts_bias; p++) { |
1042 | bias(i, d, p) = rnn.copy_bias |
1043 | ? const_cast<T *>(&scratch_bias(i, d, offset_bias)) |
1044 | : const_cast<T *>(&b(i, d, offset_bias)); |
1045 | offset_bias += rnn.parts_bias[p] * rnn.dhc; |
1046 | } |
1047 | } |
1048 | } |
1049 | } |
1050 | |
1051 | template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type, |
1052 | data_type_t acc_type> |
1053 | rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type, |
1054 | acc_type>::bias_prepare)) { |
1055 | |
1056 | if (rnn.copy_bias) { |
1057 | if (rnn.bias_dt == data_type::f32) |
1058 | copy_bias_to_scratch(rnn, reinterpret_cast<float **>(bias_), |
1059 | static_cast<const float *>(b_), |
1060 | static_cast<float *>(scratch_bias_)); |
1061 | else if (rnn.bias_dt == data_type::bf16) |
1062 | copy_bias_to_scratch(rnn, reinterpret_cast<bfloat16_t **>(bias_), |
1063 | static_cast<const bfloat16_t *>(b_), |
1064 | static_cast<bfloat16_t *>(scratch_bias_)); |
1065 | else |
1066 | assert("Unsupported bias data type" ); |
1067 | } |
1068 | |
1069 | if (rnn.bias_dt == data_type::f32) |
1070 | copy_bias_to_ws(rnn, reinterpret_cast<float **>(bias_), |
1071 | static_cast<const float *>(b_), |
1072 | static_cast<float *>(scratch_bias_)); |
1073 | else if (rnn.bias_dt == data_type::bf16) |
1074 | copy_bias_to_ws(rnn, reinterpret_cast<bfloat16_t **>(bias_), |
1075 | static_cast<const bfloat16_t *>(b_), |
1076 | static_cast<bfloat16_t *>(scratch_bias_)); |
1077 | else |
1078 | assert("Unsupported bias data type" ); |
1079 | } |
1080 | |
1081 | static void apply_bias_compensation(const rnn_utils::rnn_conf_t &rnn, |
1082 | float *scratch_bias_, const float *w_iter_comp, |
1083 | const float *w_layer_comp, const float data_shift, |
1084 | const float data_scale, const float *const weights_scales, |
1085 | const bool scale_per_oc) { |
1086 | |
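    // With data quantized as q = scale * x + shift, the integer GEMM picks
    // up an extra data_shift * sum(weights) term per output channel;
    // w_iter_comp and w_layer_comp hold those per-channel weight sums, so
    // subtracting them here (rescaled back to f32) cancels that term.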
1087 | for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++) |
1088 | for (int j = 0; j < rnn.n_bias * rnn.dhc; j++) { |
1089 | const size_t off = i * rnn.n_bias * rnn.dhc + j; |
1090 | const float weights_scale |
1091 | = scale_per_oc ? weights_scales[j] : weights_scales[0]; |
1092 | scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off]) |
1093 | * data_shift / (weights_scale * data_scale); |
1094 | } |
1095 | } |
1096 | |
1097 | template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type, |
1098 | data_type_t acc_type> |
1099 | rnn_bias_finalize_sig((_ref_rnn_common_t<aprop, src_type, weights_type, |
1100 | acc_type>::bias_finalize)) { |
1101 | if (rnn.is_unsigned_int8_conf()) { |
1102 | const float data_shift = pd()->attr()->rnn_data_qparams_.shift_; |
1103 | const float data_scale = pd()->attr()->rnn_data_qparams_.scale_; |
1104 | const float *const weights_scales |
1105 | = pd()->attr()->rnn_weights_qparams_.scales_; |
1106 | const bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0; |
1107 | |
1108 | apply_bias_compensation(rnn, static_cast<float *>(scratch_bias_), |
1109 | w_iter_comp, w_layer_comp, data_shift, data_scale, |
1110 | weights_scales, scale_per_oc); |
1111 | } |
1112 | } |
1113 | |
1114 | template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type, |
1115 | data_type_t acc_type> |
1116 | rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type, weights_type, |
1117 | acc_type>::assign_packed_weights)) { |
1118 | assert(md->format_kind == format_kind::rnn_packed); |
1119 | const auto packed_desc = md->format_desc.rnn_packed_desc; |
1120 | const AOC<weights_t *, 3> weights( |
1121 | weights_, rnn.n_layer, rnn.n_dir, packed_desc.n_parts); |
1122 | |
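    // Packed weights are laid out part after part; walk the buffer once,
    // recording where each (layer, direction, part) chunk begins.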
1123 | size_t offset_packed = 0; |
1124 | for (int l = 0; l < rnn.n_layer; l++) |
1125 | for (int d = 0; d < rnn.n_dir; d++) { |
1126 | for (int p = 0; p < packed_desc.n_parts; p++) { |
1127 | weights(l, d, p) = (weights_t *)&w_[offset_packed]; |
1128 | offset_packed |
1129 | += packed_desc.part_pack_size[p] / sizeof(weights_t); |
1130 | } |
1131 | } |
1132 | } |
1133 | |
1134 | template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type, |
1135 | data_type_t acc_type> |
1136 | rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type, weights_type, |
1137 | acc_type>::assign_weights)) { |
1138 | assert(md->format_kind == format_kind::blocked); |
1139 | const auto &blk = md->format_desc.blocking; |
1140 | /* Original set of weights provided by the user */ |
1141 | const AOC<const weights_t, 3> w( |
1142 | w_, rnn.n_layer, rnn.n_dir, (int)blk.strides[1]); |
1143 | /* Array of pointers for each part of weights */ |
1144 | const AOC<weights_t *, 3> weights( |
1145 | weights_, rnn.n_layer, rnn.n_dir, n_parts); |
1146 | |
1147 | for (int i = 0; i < rnn.n_layer; i++) |
1148 | for (int d = 0; d < rnn.n_dir; d++) { |
1149 | size_t offset_weights = 0; |
1150 | for (int p = 0; p < n_parts; p++) { |
1151 | weights(i, d, p) = (weights_t *)&w(i, d, offset_weights); |
1152 | offset_weights += gates_per_part[p] * blk.strides[3]; |
1153 | } |
1154 | } |
1155 | } |
1156 | |
1157 | //********************* Execution function *********************// |
1158 | template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type, |
1159 | data_type_t acc_type> |
1160 | void _ref_rnn_common_t<aprop, src_type, weights_type, acc_type>::execute_( |
1161 | const exec_ctx_t &ctx) const { |
1162 | const rnn_conf_t &rnn = this->pd()->rnn_; |
1163 | auto src_layer = CTX_IN_MEM(const src_layer_t *, DNNL_ARG_SRC_LAYER); |
1164 | auto augru_attention |
1165 | = CTX_IN_MEM(const src_layer_t *, DNNL_ARG_AUGRU_ATTENTION); |
1166 | auto src_iter = CTX_IN_MEM(const char *, DNNL_ARG_SRC_ITER); |
1167 | auto src_iter_c = CTX_IN_MEM(const void *, DNNL_ARG_SRC_ITER_C); |
1168 | auto layer_weights_n_comp |
1169 | = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS_LAYER); |
1170 | auto iter_weights_n_comp = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS_ITER); |
1171 | auto weights_peephole |
1172 | = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS_PEEPHOLE); |
1173 | auto projection_weights_n_comp |
1174 | = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS_PROJECTION); |
1175 | auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
1176 | |
1177 | auto dst_layer = rnn.is_fwd |
1178 | ? CTX_OUT_MEM(char *, DNNL_ARG_DST_LAYER) |
1179 | : const_cast<char *>(CTX_IN_MEM(const char *, DNNL_ARG_DST_LAYER)); |
1180 | auto dst_iter = rnn.is_fwd |
1181 | ? CTX_OUT_MEM(char *, DNNL_ARG_DST_ITER) |
1182 | : const_cast<char *>(CTX_IN_MEM(const char *, DNNL_ARG_DST_ITER)); |
1183 | auto dst_iter_c = CTX_OUT_MEM(void *, DNNL_ARG_DST_ITER_C); |
1184 | |
1185 | auto diff_dst_layer |
1186 | = CTX_IN_MEM(const gemm_acc_t *, DNNL_ARG_DIFF_DST_LAYER); |
1187 | auto diff_dst_iter = CTX_IN_MEM(const gemm_acc_t *, DNNL_ARG_DIFF_DST_ITER); |
1188 | auto diff_dst_iter_c = CTX_IN_MEM(const float *, DNNL_ARG_DIFF_DST_ITER_C); |
1189 | |
1190 | auto w_layer = reinterpret_cast<const weights_t *>(layer_weights_n_comp); |
1191 | auto w_iter = reinterpret_cast<const weights_t *>(iter_weights_n_comp); |
1192 | auto w_projection |
1193 | = reinterpret_cast<const weights_t *>(projection_weights_n_comp); |
1194 | auto w_layer_comp = reinterpret_cast<const float *>( |
1195 | layer_weights_n_comp + rnn.weights_layer_comp_offset); |
1196 | auto w_iter_comp = reinterpret_cast<const float *>( |
1197 | iter_weights_n_comp + rnn.weights_iter_comp_offset); |
1198 | auto w_projection_comp = reinterpret_cast<const float *>( |
1199 | projection_weights_n_comp + rnn.weights_projection_comp_offset); |
1200 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1201 | |
1202 | auto ptr_wei_layer |
1203 | = scratchpad.template get<weights_t *>(key_rnn_ptrs_wei_layer); |
1204 | auto ptr_wei_iter |
1205 | = scratchpad.template get<weights_t *>(key_rnn_ptrs_wei_iter); |
1206 | auto ptr_wei_projection |
1207 | = scratchpad.template get<weights_t *>(key_rnn_ptrs_wei_projection); |
1208 | auto ptr_bias = scratchpad.template get<void *>(key_rnn_ptrs_bia); |
    // Here we use scratch_gates as the output of GEMMs on FWD and as the
    // input of GEMMs on BWD. None of the values are kept for BWD.
1211 | auto scratch_gates = scratchpad.template get<scratch_t>(key_rnn_gates); |
1212 | #if DNNL_X64 |
1213 | const auto scratch_gates_blocked |
1214 | = scratchpad.template get<scratch_t>(key_rnn_gates_blocked); |
1215 | const auto scratch_src_layer |
1216 | = scratchpad.template get<scratch_t>(key_rnn_src_layer_trans); |
1217 | const auto scratch_src_iter |
1218 | = scratchpad.template get<scratch_t>(key_rnn_src_iter_trans); |
1219 | #endif |
1220 | |
1221 | auto scratch_ht = scratchpad.template get<ht_t>(key_rnn_ht); |
1222 | auto scratch_diff_ht = scratchpad.template get<gemm_acc_t>(key_rnn_diff_ht); |
1223 | auto scratch_cell = scratchpad.template get<scratch_t>(key_rnn_cell); |
1224 | |
1225 | gemm_acc_t *amx_scratchpad = nullptr; |
1226 | #if DNNL_X64 |
1227 | x64::brgemm_batch_element_t *addr_batch_global = nullptr; |
1228 | if (rnn.is_brgemm && (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx())) { |
1229 | amx_scratchpad = scratchpad.template get<gemm_acc_t>( |
1230 | key_brgemm_primitive_buffer); |
1231 | } |
1232 | addr_batch_global = scratchpad.template get<x64::brgemm_batch_element_t>( |
1233 | key_brgemm_primitive_batch); |
1234 | #endif |
    // Fetching buffers from the workspace;
    // if no workspace was provided, we use the scratchpad
1237 | char *scratch_ptr = scratchpad.template get<char>(key_rnn_space); |
1238 | char *ws_ptr = nullptr; |
1239 | if (rnn.use_workspace) |
1240 | ws_ptr = rnn.is_fwd ? CTX_OUT_MEM(char *, DNNL_ARG_WORKSPACE) |
1241 | : const_cast<char *>(CTX_IN_MEM( |
1242 | const char *, DNNL_ARG_WORKSPACE)); |
1243 | |
1244 | char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr; |
1245 | // ws_gates is only used to pass data from FWD to BWD. |
1246 | // assumption: in training, src_data_t and weights_t match |
1247 | gates_t *ws_gates = (gates_t *)(base_ptr + ws_gates_offset_); |
1248 | dst_iter_t *ws_ht = (dst_iter_t *)(base_ptr + ws_ht_offset_); |
1249 | src_layer_t *ws_states_layer |
1250 | = (src_layer_t *)(base_ptr + ws_states_layer_offset_); |
1251 | src_iter_t *ws_states_iter |
1252 | = (src_iter_t *)(base_ptr + ws_states_iter_offset_); |
1253 | void *ws_states_iter_c = (void *)(base_ptr + ws_states_iter_c_offset_); |
1254 | gemm_acc_t *ws_diff_states_layer |
1255 | = (gemm_acc_t *)(base_ptr + ws_diff_states_layer_offset_); |
1256 | gemm_acc_t *ws_diff_states_iter |
1257 | = (gemm_acc_t *)(base_ptr + ws_diff_states_iter_offset_); |
1258 | gemm_acc_t *ws_diff_states_iter_c |
1259 | = (gemm_acc_t *)(base_ptr + ws_diff_states_iter_c_offset_); |
1260 | gates_t *ws_grid = (gates_t *)(base_ptr + ws_grid_comp_offset_); |
1261 | |
1262 | auto diff_src_layer = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_SRC_LAYER); |
1263 | auto diff_src_iter = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_SRC_ITER); |
1264 | auto diff_src_iter_c = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_SRC_ITER_C); |
1265 | |
1266 | auto diff_augru_attention |
1267 | = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_AUGRU_ATTENTION); |
1268 | auto diff_weights_layer |
1269 | = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_WEIGHTS_LAYER); |
1270 | auto diff_weights_iter |
1271 | = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_WEIGHTS_ITER); |
1272 | auto diff_weights_projection |
1273 | = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_WEIGHTS_PROJECTION); |
1274 | auto diff_weights_peephole |
1275 | = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE); |
1276 | auto diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS); |
1277 | |
1278 | // Fetching extra buffers from scratchpad |
1279 | void *ws_bias = static_cast<void *>(scratch_ptr + ws_bias_offset_); |
    /* Pack (if using the packed gemm API) or copy (if input arrays have a bad
     * leading dimension) */
1282 | (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias); |
1283 | |
1284 | const memory_desc_t *weights_layer_md = pd()->weights_md(0); |
1285 | const memory_desc_t *weights_iter_md = pd()->weights_md(1); |
1286 | |
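    // memory descriptors for the bf16-blocked weights layout consumed by the
    // x64 bf32 path below; the 2-element inner 'i' block of ldgOI{64,32}o2i
    // is meant to match the bf16 pairing consumed by the brgemm kernels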
1287 | const auto tag = rnn.n_block == 64 ? format_tag::ldgOI64o2i |
1288 | : format_tag::ldgOI32o2i; |
1289 | memory_desc_t wei_layer_desc; |
1290 | memory_desc_init_by_tag(wei_layer_desc, weights_layer_md->ndims, |
1291 | weights_layer_md->dims, data_type::bf16, tag); |
1292 | |
1293 | memory_desc_t wei_iter_desc; |
1294 | memory_desc_init_by_tag(wei_iter_desc, weights_iter_md->ndims, |
1295 | weights_iter_md->dims, data_type::bf16, tag); |
1296 | |
1297 | #if DNNL_X64 |
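    // bf32: user data is f32 but computation runs in bf16, so the f32
    // weights (and, for AUGRU, the attention) are converted up front into
    // the bf16 blocked layout through nested reorder primitives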
1298 | if (rnn.is_bf32()) { |
1299 | if (rnn.is_augru) { |
1300 | const auto bf32_augru_attention |
1301 | = scratchpad.template get<src_layer_t>( |
1302 | key_rnn_bf32_attention_trans); |
1303 | cvt_float_to_bfloat16((bfloat16_t *)bf32_augru_attention, |
1304 | (float *)augru_attention, rnn.n_iter * rnn.mb); |
1305 | augru_attention = bf32_augru_attention; |
1306 | } |
1307 | engine_t *engine = ctx.stream()->engine(); |
1308 | auto wei_layer_mem |
1309 | = scratchpad.get_memory_storage(key_rnn_bf32_wei_layer_trans); |
1310 | auto wei_iter_mem |
1311 | = scratchpad.get_memory_storage(key_rnn_bf32_wei_iter_trans); |
1312 | { |
1313 | memory_t reorder_dst( |
1314 | engine, &wei_layer_desc, std::move(wei_layer_mem)); |
1315 | exec_args_t reorder_args; |
1316 | reorder_args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_WEIGHTS_LAYER); |
1317 | reorder_args[DNNL_ARG_DST] = {&reorder_dst, false}; |
1318 | exec_ctx_t reorder_ctx(ctx, std::move(reorder_args)); |
1319 | nested_scratchpad_t ns( |
1320 | ctx, key_nested_multiple + 0, bf32_wei_layer_reorder_); |
1321 | reorder_ctx.set_scratchpad_grantor(ns.grantor()); |
1322 | bf32_wei_layer_reorder_->execute(reorder_ctx); |
1323 | w_layer = scratchpad.template get<weights_t>( |
1324 | key_rnn_bf32_wei_layer_trans); |
1325 | weights_layer_md = &wei_layer_desc; |
1326 | } |
1327 | |
1328 | { |
1329 | memory_t reorder_dst( |
1330 | engine, &wei_iter_desc, std::move(wei_iter_mem)); |
1331 | exec_args_t reorder_args; |
1332 | reorder_args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_WEIGHTS_ITER); |
1333 | reorder_args[DNNL_ARG_DST] = {&reorder_dst, false}; |
1334 | exec_ctx_t reorder_ctx(ctx, std::move(reorder_args)); |
1335 | nested_scratchpad_t ns( |
1336 | ctx, key_nested_multiple + 1, bf32_wei_iter_reorder_); |
1337 | reorder_ctx.set_scratchpad_grantor(ns.grantor()); |
1338 | bf32_wei_iter_reorder_->execute(reorder_ctx); |
1339 | w_iter = scratchpad.template get<weights_t>( |
1340 | key_rnn_bf32_wei_iter_trans); |
1341 | weights_iter_md = &wei_iter_desc; |
1342 | } |
1343 | } |
1344 | #endif |
1345 | |
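    // set up the per-layer/direction/part weights pointer arrays; depending
    // on the configuration these point into packed buffers or directly into
    // the user-provided weights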
1346 | (this->*weights_iter_assign_func)(rnn, weights_iter_md, |
1347 | rnn.n_parts_weights_iter, rnn.parts_weights_iter, ptr_wei_iter, |
1348 | w_iter); |
1349 | (this->*weights_layer_assign_func)(rnn, weights_layer_md, |
1350 | rnn.n_parts_weights_layer, rnn.parts_weights_layer, ptr_wei_layer, |
1351 | w_layer); |
1352 | |
1353 | if (rnn.is_lstm_projection) { |
1354 | (this->*weights_projection_assign_func)(rnn, |
1355 | pd()->arg_md(DNNL_ARG_WEIGHTS_PROJECTION), |
1356 | rnn.n_parts_weights_projection, rnn.parts_weights_projection, |
1357 | ptr_wei_projection, w_projection); |
1358 | } |
1359 | |
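    // for int8 configurations this folds the weights compensation terms into
    // the bias; for f32/bf16 it is expected to be a no-op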
1360 | (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp); |
1361 | |
    // we first need to copy the initial states and input into the workspace;
    // the copies are elided when the user buffers can be consumed directly
    // (skip_src_layer_copy / skip_src_iter_copy) on the forward pass
1363 | if (!(rnn.skip_src_layer_copy() && rnn.is_fwd)) { |
1364 | if (pd()->src_md(0)->data_type == data_type::f32) |
1365 | copy_init_layer(rnn, ws_states_layer, ws_diff_states_layer, |
1366 | (const float *)src_layer, diff_dst_layer); |
1367 | else |
1368 | copy_init_layer(rnn, ws_states_layer, ws_diff_states_layer, |
1369 | src_layer, diff_dst_layer); |
1370 | } |
1371 | |
1372 | if (!(rnn.skip_src_iter_copy() && rnn.is_fwd)) { |
1373 | if (pd()->src_md(1)->data_type == data_type::f32) |
1374 | copy_init_iter(rnn, ws_states_iter, |
1375 | static_cast<void *>(ws_states_iter_c), ws_diff_states_iter, |
1376 | ws_diff_states_iter_c, (const float *)src_iter, src_iter_c, |
1377 | diff_dst_iter, diff_dst_iter_c); |
1378 | else |
1379 | copy_init_iter(rnn, ws_states_iter, ws_states_iter_c, |
1380 | ws_diff_states_iter, ws_diff_states_iter_c, |
1381 | (const src_iter_t *)src_iter, src_iter_c, diff_dst_iter, |
1382 | diff_dst_iter_c); |
1383 | } |
1384 | |
1385 | // run the execution on the grid |
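    // grid_computation typically resolves to linear_execution: a loop nest
    // over layers, directions and time steps that invokes the configured
    // cell function at each grid point. As a rough sketch (simplified; the
    // actual loop order and merged-GEMM handling live in the grid
    // implementation):
    //     for (int lay = 0; lay < rnn.n_layer; ++lay)
    //         for (int dir = 0; dir < rnn.n_dir; ++dir)
    //             for (int it = 0; it < rnn.n_iter; ++it)
    //                 (this->*cell_func)(...);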
1386 | (this->*grid_computation)( |
1387 | #if DNNL_X64 |
1388 | ctx, |
1389 | #endif |
1390 | rnn, ptr_wei_layer, ptr_wei_iter, ptr_wei_projection, |
1391 | weights_peephole, w_projection_comp, ptr_bias, src_layer, |
1392 | augru_attention, (const src_iter_t *)src_iter, src_iter_c, |
1393 | (dst_layer_t *)dst_layer, (dst_iter_t *)dst_iter, dst_iter_c, |
1394 | ws_states_layer, ws_states_iter, ws_states_iter_c, |
1395 | ws_diff_states_layer, ws_diff_states_iter, ws_diff_states_iter_c, |
1396 | ws_gates, ws_ht, ws_grid, scratch_gates, scratch_ht, |
1397 | scratch_diff_ht, scratch_cell, |
1398 | #if DNNL_X64 |
1399 | scratch_gates_blocked, scratch_src_layer, scratch_src_iter, |
1400 | #endif |
1401 | diff_augru_attention, diff_weights_layer, diff_weights_iter, |
1402 | diff_weights_projection, diff_weights_peephole, diff_bias, |
1403 | amx_scratchpad |
1404 | #if DNNL_X64 |
1405 | , |
1406 | addr_batch_global |
1407 | #endif |
1408 | ); |
1409 | |
    // Finally we copy the results back to the user-provided destination
    // buffers (again elided on forward when skip_dst_layer_copy /
    // skip_dst_iter_copy apply)
1411 | if (!(rnn.skip_dst_layer_copy() && rnn.is_fwd)) { |
1412 | if (pd()->dst_md(0)->data_type == data_type::f32) |
1413 | copy_res_layer(rnn, (float *)dst_layer, diff_src_layer, dst_iter, |
1414 | ws_states_layer, ws_diff_states_layer); |
1415 | else |
1416 | copy_res_layer(rnn, (dst_layer_t *)dst_layer, diff_src_layer, |
1417 | dst_iter, ws_states_layer, ws_diff_states_layer); |
1418 | } |
1419 | |
1420 | if (!(rnn.skip_dst_iter_copy() && rnn.is_fwd)) { |
1421 | if (pd()->dst_md(1)->data_type == data_type::f32) |
1422 | copy_res_iter(rnn, (float *)dst_iter, dst_iter_c, diff_src_iter, |
1423 | diff_src_iter_c, dst_layer, ws_states_iter, |
1424 | ws_states_iter_c, ws_diff_states_iter, |
1425 | ws_diff_states_iter_c); |
1426 | else |
1427 | copy_res_iter(rnn, (dst_iter_t *)dst_iter, dst_iter_c, |
1428 | diff_src_iter, diff_src_iter_c, dst_layer, ws_states_iter, |
1429 | ws_states_iter_c, ws_diff_states_iter, |
1430 | ws_diff_states_iter_c); |
1431 | } |
1432 | }; |
1433 | |
/* Fix for MSVS warning C4661: no suitable definition provided for explicit
 * template instantiation request */
1435 | template <> |
1436 | rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_ref); |
1437 | template <> |
1438 | rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_brgemm_fwd); |
1439 | template <> |
1440 | rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_brgemm_bwd); |
1441 | template <> |
1442 | rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru); |
1443 | template <> |
1444 | rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr); |
1445 | template <> |
1446 | rnn_merged_layer_execution_sig(ref_rnn_fwd_f32_t::merged_layer_execution_ref); |
1447 | template <> |
1448 | rnn_merged_layer_execution_sig(ref_rnn_fwd_f32_t::merged_layer_brgemm_fwd); |
1449 | template <> |
1450 | rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_ref); |
1451 | template <> |
1452 | rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_brgemm_fwd); |
1453 | template <> |
1454 | rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_brgemm_bwd); |
1455 | template <> |
1456 | rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru); |
1457 | template <> |
1458 | rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr); |
1459 | template <> |
1460 | rnn_merged_layer_execution_sig(ref_rnn_bwd_f32_t::merged_layer_execution_ref); |
1461 | template <> |
1462 | rnn_merged_layer_execution_sig(ref_rnn_bwd_f32_t::merged_layer_brgemm_fwd); |
1463 | |
1464 | template <> |
1465 | rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_ref); |
1466 | template <> |
1467 | rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_brgemm_fwd); |
1468 | template <> |
1469 | rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_brgemm_bwd); |
1470 | template <> |
1471 | rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_gru); |
1472 | template <> |
1473 | rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_gru_lbr); |
1474 | template <> |
1475 | rnn_merged_layer_execution_sig(ref_rnn_fwd_bf16_t::merged_layer_execution_ref); |
1476 | template <> |
1477 | rnn_merged_layer_execution_sig(ref_rnn_fwd_bf16_t::merged_layer_brgemm_fwd); |
1478 | template <> |
1479 | rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_ref); |
1480 | template <> |
1481 | rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_brgemm_fwd); |
1482 | template <> |
1483 | rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_brgemm_bwd); |
1484 | template <> |
1485 | rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_gru); |
1486 | template <> |
1487 | rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_gru_lbr); |
1488 | template <> |
1489 | rnn_merged_layer_execution_sig(ref_rnn_bwd_bf16_t::merged_layer_execution_ref); |
1490 | template <> |
1491 | rnn_merged_layer_execution_sig(ref_rnn_bwd_bf16_t::merged_layer_brgemm_fwd); |
1492 | |
1493 | template <> |
1494 | rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_ref); |
1495 | template <> |
1496 | rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_brgemm_fwd); |
1497 | template <> |
1498 | rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_brgemm_bwd); |
1499 | template <> |
1500 | rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru); |
1501 | template <> |
1502 | rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr); |
1503 | template <> |
1504 | rnn_merged_layer_execution_sig(ref_rnn_fwd_u8s8_t::merged_layer_execution_ref); |
1505 | template <> |
1506 | rnn_merged_layer_execution_sig(ref_rnn_fwd_u8s8_t::merged_layer_brgemm_fwd); |
1507 | |
1508 | template <> |
1509 | rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_ref); |
1510 | template <> |
1511 | rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_brgemm_fwd); |
1512 | template <> |
1513 | rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_brgemm_bwd); |
1514 | template <> |
1515 | rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_gru); |
1516 | template <> |
1517 | rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_gru_lbr); |
1518 | template <> |
1519 | rnn_merged_layer_execution_sig(ref_rnn_fwd_s8s8_t::merged_layer_execution_ref); |
1520 | template <> |
1521 | rnn_merged_layer_execution_sig(ref_rnn_fwd_s8s8_t::merged_layer_brgemm_fwd); |
1522 | |
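// explicit instantiations of the common RNN implementation for every
// supported propagation kind / data type combination; the int8 (u8s8 and
// s8s8) configurations are forward-only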
1523 | template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32, |
1524 | data_type::f32, data_type::f32>; |
1525 | template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32, |
1526 | data_type::f32, data_type::f32>; |
1527 | |
1528 | template struct _ref_rnn_common_t<prop_kind::forward, data_type::bf16, |
1529 | data_type::bf16, data_type::f32>; |
1530 | template struct _ref_rnn_common_t<prop_kind::backward, data_type::bf16, |
1531 | data_type::bf16, data_type::f32>; |
1532 | |
1533 | template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8, |
1534 | data_type::s8, data_type::s32>; |
1535 | template struct _ref_rnn_common_t<prop_kind::forward, data_type::s8, |
1536 | data_type::s8, data_type::s32>; |
1537 | |
1538 | #undef AOC |
1539 | |
1540 | } // namespace cpu |
1541 | } // namespace impl |
1542 | } // namespace dnnl |
1543 | |