1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_RNN_POSTGEMM_DISPATCHER_HPP |
18 | #define CPU_RNN_POSTGEMM_DISPATCHER_HPP |
19 | |
20 | #include <memory> |
21 | |
22 | #include "common/z_magic.hpp" |
23 | |
24 | #include "cpu/platform.hpp" |
25 | |
26 | #include "cpu/rnn/cpu_rnn_pd.hpp" |
27 | #include "cpu/rnn/rnn_utils.hpp" |
28 | |
29 | #if DNNL_X64 |
30 | #include "cpu/x64/rnn/jit_uni_gru_cell_postgemm_1_bwd.hpp" |
31 | #include "cpu/x64/rnn/jit_uni_gru_cell_postgemm_1_fwd.hpp" |
32 | #include "cpu/x64/rnn/jit_uni_gru_cell_postgemm_2_bwd.hpp" |
33 | #include "cpu/x64/rnn/jit_uni_gru_cell_postgemm_2_fwd.hpp" |
34 | #include "cpu/x64/rnn/jit_uni_gru_lbr_cell_postgemm_bwd.hpp" |
35 | #include "cpu/x64/rnn/jit_uni_gru_lbr_cell_postgemm_fwd.hpp" |
36 | #include "cpu/x64/rnn/jit_uni_lstm_cell_postgemm_bwd.hpp" |
37 | #include "cpu/x64/rnn/jit_uni_lstm_cell_postgemm_fwd.hpp" |
38 | #include "cpu/x64/rnn/jit_uni_lstm_cell_projection_postgemm_fwd.hpp" |
39 | #include "cpu/x64/rnn/jit_uni_rnn_cell_postgemm_bwd.hpp" |
40 | #include "cpu/x64/rnn/jit_uni_rnn_cell_postgemm_fwd.hpp" |
41 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
42 | #endif |
43 | |
44 | namespace dnnl { |
45 | namespace impl { |
46 | namespace cpu { |
47 | |
// Scalar reference activation, specialized per algorithm kind
// (eltwise_relu / eltwise_tanh / eltwise_logistic, see the dispatcher's
// constructor) and per propagation kind (forward computes the activation,
// backward its derivative). NOTE(review): "cliping" is presumably the
// clipping threshold applied in test mode — confirm against the definitions.
template <alg_kind_t alg_kind, prop_kind_t prop_kind>
float activation(float s, float alpha, float cliping);
50 | |
51 | template <prop_kind_t aprop, impl::data_type_t src_type, |
52 | impl::data_type_t scratch_type, impl::data_type_t acc_type> |
53 | struct rnn_postgemm_dispatcher { |
54 | |
55 | typedef typename prec_traits<src_type>::type src_layer_t; |
56 | typedef typename prec_traits<src_type>::type src_iter_t; |
57 | typedef typename prec_traits<src_type>::type dst_layer_t; |
58 | typedef typename prec_traits<src_type>::type dst_iter_t; |
59 | typedef typename prec_traits<acc_type>::type gemm_acc_t; |
60 | typedef typename prec_traits<scratch_type>::type scratch_t; |
61 | typedef typename prec_traits<src_type>::type ht_t; |
62 | typedef typename prec_traits<src_type>::type gates_t; |
63 | |
64 | using class_name |
65 | = rnn_postgemm_dispatcher<aprop, src_type, scratch_type, acc_type>; |
66 | typedef rnn_postgemm_sig((class_name::*postgemm_f)); |
67 | |
68 | rnn_postgemm_dispatcher( |
69 | const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd) |
70 | : pd_(pd) { |
71 | // add check if in testing mode |
72 | if (pd->attr()->rnn_tparams_.test_mode_) { |
73 | const auto ngates = utils::map(pd->cell_kind(), 0, |
74 | alg_kind::vanilla_rnn, 1, alg_kind::vanilla_lstm, 4, |
75 | alg_kind::vanilla_gru, 3, alg_kind::lbr_gru, 3, |
76 | alg_kind::vanilla_augru, 3, alg_kind::lbr_augru, 3); |
77 | assert(pd->attr()->rnn_tparams_.ngates_ == ngates); |
78 | MAYBE_UNUSED(ngates); |
79 | } |
80 | |
81 | switch (pd->cell_kind()) { |
82 | case alg_kind::vanilla_lstm: |
83 | postgemm_func = &class_name::lstm_postgemm; |
84 | // used for int8 requantization after projection |
85 | postgemm_part2_func = pd->is_lstm_projection() && pd_->is_fwd() |
86 | ? &class_name::lstm_projection_postgemm |
87 | : nullptr; |
88 | break; |
89 | case alg_kind::vanilla_rnn: |
90 | postgemm_func = &class_name::rnn_postgemm; |
91 | switch (pd->activation_kind()) { |
92 | case alg_kind::eltwise_relu: |
93 | activation_func |
94 | = &activation<alg_kind::eltwise_relu, aprop>; |
95 | break; |
96 | case alg_kind::eltwise_tanh: |
97 | activation_func |
98 | = &activation<alg_kind::eltwise_tanh, aprop>; |
99 | break; |
100 | case alg_kind::eltwise_logistic: |
101 | activation_func |
102 | = &activation<alg_kind::eltwise_logistic, |
103 | aprop>; |
104 | break; |
105 | default: assert(!"Unsupported activation function" ); break; |
106 | } |
107 | break; |
108 | case alg_kind::vanilla_gru: |
109 | case alg_kind::vanilla_augru: |
110 | postgemm_func = &class_name::gru_part1_postgemm; |
111 | postgemm_part2_func = &class_name::gru_part2_postgemm; |
112 | break; |
113 | case alg_kind::lbr_gru: |
114 | case alg_kind::lbr_augru: |
115 | postgemm_func = &class_name::gru_lbr_postgemm; |
116 | break; |
117 | default: assert(!"Unsupported algorithm kind" ); break; |
118 | } |
119 | |
120 | DNNL_X64_ONLY(initialize_jit(rnn)); |
121 | } |
122 | |
123 | ~rnn_postgemm_dispatcher() = default; |
124 | |
125 | rnn_postgemm_sig(unpoison) { |
126 | // XXX (rsdubtso): This is a big hammer that unpoisons everything |
127 | // that a postgemm may touch to avoid writing per-cell-kind |
128 | // versions of unpoisoning code. This must be removed alongside with |
129 | // the big unpoison_outputs() hammer in common/primitive.cpp. |
130 | |
131 | const size_t states_nelems |
132 | = rnn.ws_states_layer_nld * rnn.ws_states_layer_ld; |
133 | const size_t gates_nelems |
134 | = rnn.scratch_gates_nld * rnn.scratch_gates_ld; |
135 | |
136 | if (pd_->is_fwd()) { |
137 | msan_unpoison(dst_layer_, sizeof(*dst_layer_) * states_nelems); |
138 | msan_unpoison(dst_iter_, sizeof(*dst_iter_) * states_nelems); |
139 | if (rnn.is_training) |
140 | msan_unpoison(ws_gates_, sizeof(*ws_gates_) * gates_nelems); |
141 | } else { |
142 | msan_unpoison(diff_src_layer_, |
143 | sizeof(*diff_src_layer_) * (rnn.n_iter + 1) |
144 | * rnn.ws_diff_states_layer_nld |
145 | * rnn.ws_diff_states_layer_ld); |
146 | msan_unpoison(diff_augru_attention_, |
147 | sizeof(*diff_augru_attention_) * rnn.n_iter * rnn.mb |
148 | * rnn.dhc); |
149 | msan_unpoison(diff_src_iter_, |
150 | sizeof(*diff_src_iter_) * (rnn.n_iter + 1) |
151 | * rnn.ws_diff_states_iter_nld |
152 | * rnn.ws_diff_states_iter_ld); |
153 | msan_unpoison(diff_src_iter_c_, |
154 | sizeof(*diff_src_iter_c_) * (rnn.n_iter + 1) |
155 | * rnn.ws_diff_states_iter_c_nld |
156 | * rnn.ws_diff_states_iter_c_ld); |
157 | msan_unpoison( |
158 | scratch_gates_, sizeof(*scratch_gates_) * gates_nelems); |
159 | msan_unpoison( |
160 | scratch_cell_, sizeof(*scratch_cell_) * states_nelems); |
161 | } |
162 | } |
163 | |
164 | // template <typename src_data_t, typename acc_data_t> |
165 | rnn_postgemm_sig(execute) { |
166 | /* This block has an impact on performance in case it is executed |
167 | * multiple times. Be careful when changing it. |
168 | * XXX: The code is compiler sensitive, jit might help with that. |
169 | */ |
170 | #if DNNL_X64 |
171 | if (rnn_postgemm_) { |
172 | rnn_postgemm_->execute(rnn, cell_position, ws_gates_, |
173 | scratch_gates_, augru_attention_, dst_layer_, dst_iter_c_, |
174 | src_iter_, src_iter_c_, diff_src_layer_, |
175 | diff_augru_attention_, diff_src_iter_, diff_src_iter_c_, |
176 | diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_, |
177 | weights_peephole_, bias_, ws_grid_, scratch_cell_, |
178 | dst_iter_, weights_scales_, block_step); |
179 | unpoison(rnn, cell_position, ws_gates_, scratch_gates_, |
180 | augru_attention_, dst_layer_, dst_iter_c_, src_iter_, |
181 | src_iter_c_, diff_src_layer_, diff_augru_attention_, |
182 | diff_src_iter_, diff_src_iter_c_, diff_dst_layer_, |
183 | diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_, |
184 | ws_grid_, scratch_cell_, dst_iter_, weights_scales_, |
185 | block_step); |
186 | return; |
187 | } |
188 | #endif |
189 | (this->*postgemm_func)(rnn, cell_position, ws_gates_, scratch_gates_, |
190 | augru_attention_, dst_layer_, dst_iter_c_, src_iter_, |
191 | src_iter_c_, diff_src_layer_, diff_augru_attention_, |
192 | diff_src_iter_, diff_src_iter_c_, diff_dst_layer_, |
193 | diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_, |
194 | ws_grid_, scratch_cell_, dst_iter_, weights_scales_, |
195 | block_step); |
196 | } |
197 | |
198 | // template <typename src_data_t, typename acc_data_t> |
199 | rnn_postgemm_sig(execute_part2) { |
200 | /* This block has an impact on performance in case it is executed |
201 | * multiple times. Be careful when changing it. |
202 | * XXX: The code is compiler sensitive, jit might help with that. |
203 | */ |
204 | #if DNNL_X64 |
205 | if (rnn_postgemm_part2_) { |
206 | rnn_postgemm_part2_->execute(rnn, cell_position, ws_gates_, |
207 | scratch_gates_, augru_attention_, dst_layer_, dst_iter_c_, |
208 | src_iter_, src_iter_c_, diff_src_layer_, |
209 | diff_augru_attention_, diff_src_iter_, diff_src_iter_c_, |
210 | diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_, |
211 | weights_peephole_, bias_, ws_grid_, scratch_cell_, |
212 | dst_iter_, weights_scales_, block_step); |
213 | unpoison(rnn, cell_position, ws_gates_, scratch_gates_, |
214 | augru_attention_, dst_layer_, dst_iter_c_, src_iter_, |
215 | src_iter_c_, diff_src_layer_, diff_augru_attention_, |
216 | diff_src_iter_, diff_src_iter_c_, diff_dst_layer_, |
217 | diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_, |
218 | ws_grid_, scratch_cell_, dst_iter_, weights_scales_, |
219 | block_step); |
220 | return; |
221 | } |
222 | #endif |
223 | (this->*postgemm_part2_func)(rnn, cell_position, ws_gates_, |
224 | scratch_gates_, augru_attention_, dst_layer_, dst_iter_c_, |
225 | src_iter_, src_iter_c_, diff_src_layer_, diff_augru_attention_, |
226 | diff_src_iter_, diff_src_iter_c_, diff_dst_layer_, |
227 | diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_, |
228 | ws_grid_, scratch_cell_, dst_iter_, weights_scales_, |
229 | block_step); |
230 | } |
231 | |
232 | private: |
233 | float (*activation_func)(float s, float alpha, float cliping); |
234 | rnn_postgemm_sig(rnn_postgemm); |
235 | rnn_postgemm_sig(lstm_postgemm); |
236 | rnn_postgemm_sig(lstm_projection_postgemm); |
237 | rnn_postgemm_sig(gru_part1_postgemm); |
238 | rnn_postgemm_sig(gru_part2_postgemm); |
239 | rnn_postgemm_sig(gru_lbr_postgemm); |
240 | |
241 | const rnn_pd_t *pd_; |
242 | |
243 | postgemm_f postgemm_func; |
244 | postgemm_f postgemm_part2_func; |
245 | |
246 | DNNL_DISALLOW_COPY_AND_ASSIGN(rnn_postgemm_dispatcher); |
247 | |
248 | #if DNNL_X64 |
249 | std::unique_ptr<x64::jit_uni_rnn_postgemm> rnn_postgemm_; |
250 | std::unique_ptr<x64::jit_uni_rnn_postgemm> rnn_postgemm_part2_; |
251 | |
252 | void initialize_jit(const rnn_utils::rnn_conf_t &rnn) { |
253 | using namespace dnnl::impl::cpu::x64; |
254 | |
255 | if (pd_->attr()->rnn_tparams_.test_mode_) return; |
256 | |
257 | const bool jit_fwd = pd_->is_fwd() |
258 | && utils::one_of(src_type, data_type::f32, data_type::u8, |
259 | data_type::s8, data_type::bf16); |
260 | const bool jit_bwd = !pd_->is_fwd() |
261 | && utils::one_of(src_type, data_type::f32, data_type::bf16); |
262 | |
263 | #define CREATE_WITH_DIR(k, ker_t) \ |
264 | do { \ |
265 | if (mayiuse(avx512_core)) \ |
266 | k.reset(new ker_t<avx512_core, src_type, scratch_type>(rnn, pd_)); \ |
267 | else if (mayiuse(avx2)) \ |
268 | k.reset(new ker_t<avx2, src_type, scratch_type>(rnn, pd_)); \ |
269 | else \ |
270 | k.reset(new ker_t<sse41, src_type, scratch_type>(rnn, pd_)); \ |
271 | } while (0) |
272 | #define CREATE(k, ker_t) \ |
273 | do { \ |
274 | if (jit_fwd) CREATE_WITH_DIR(k, CONCAT2(ker_t, _fwd)); \ |
275 | if (jit_bwd) CREATE_WITH_DIR(k, CONCAT2(ker_t, _bwd)); \ |
276 | } while (0) |
277 | |
278 | if (pd_->cell_kind() == alg_kind::vanilla_lstm) { |
279 | CREATE(rnn_postgemm_, jit_uni_lstm_cell_postgemm); |
280 | } else if (pd_->cell_kind() == alg_kind::vanilla_rnn) { |
281 | CREATE(rnn_postgemm_, jit_uni_rnn_cell_postgemm); |
282 | } else if (utils::one_of(pd_->cell_kind(), alg_kind::vanilla_gru, |
283 | alg_kind::vanilla_augru)) { |
284 | CREATE(rnn_postgemm_, jit_uni_gru_cell_postgemm_part1); |
285 | CREATE(rnn_postgemm_part2_, jit_uni_gru_cell_postgemm_part2); |
286 | } else if (utils::one_of(pd_->cell_kind(), alg_kind::lbr_gru, |
287 | alg_kind::lbr_augru)) { |
288 | CREATE(rnn_postgemm_, jit_uni_gru_lbr_cell_postgemm); |
289 | } |
290 | |
291 | #undef CREATE |
292 | #undef CREATE_WITH_DIR |
293 | |
294 | if (rnn_postgemm_) rnn_postgemm_->init(src_type); |
295 | if (rnn_postgemm_part2_) rnn_postgemm_part2_->init(src_type); |
296 | } |
297 | #endif |
298 | }; |
299 | |
// Concrete dispatcher instantiations used by the RNN primitives, one per
// supported (direction, src/scratch/acc data type) combination.

// f32: f32 source, f32 scratch, f32 accumulation, both directions.
using rnn_postgemm_fwd_f32_t = rnn_postgemm_dispatcher<prop_kind::forward,
        data_type::f32, data_type::f32, data_type::f32>;
using rnn_postgemm_bwd_f32_t = rnn_postgemm_dispatcher<prop_kind::backward,
        data_type::f32, data_type::f32, data_type::f32>;

// bf16: f32 accumulation; note forward keeps f32 scratch while backward
// uses bf16 scratch.
using rnn_postgemm_fwd_bf16_t = rnn_postgemm_dispatcher<prop_kind::forward,
        data_type::bf16, data_type::f32, data_type::f32>;
using rnn_postgemm_bwd_bf16_t = rnn_postgemm_dispatcher<prop_kind::backward,
        data_type::bf16, data_type::bf16, data_type::f32>;

// int8 (u8/s8 source, s32 scratch/accumulation): forward only.
using rnn_postgemm_fwd_u8_t = rnn_postgemm_dispatcher<prop_kind::forward,
        data_type::u8, data_type::s32, data_type::s32>;
using rnn_postgemm_fwd_s8_t = rnn_postgemm_dispatcher<prop_kind::forward,
        data_type::s8, data_type::s32, data_type::s32>;
314 | |
315 | } // namespace cpu |
316 | } // namespace impl |
317 | } // namespace dnnl |
318 | |
319 | #endif |
320 | |