/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_RNN_POSTGEMM_DISPATCHER_HPP
#define CPU_RNN_POSTGEMM_DISPATCHER_HPP

#include <memory>

#include "common/z_magic.hpp"

#include "cpu/platform.hpp"

#include "cpu/rnn/cpu_rnn_pd.hpp"
#include "cpu/rnn/rnn_utils.hpp"

#if DNNL_X64
#include "cpu/x64/rnn/jit_uni_gru_cell_postgemm_1_bwd.hpp"
#include "cpu/x64/rnn/jit_uni_gru_cell_postgemm_1_fwd.hpp"
#include "cpu/x64/rnn/jit_uni_gru_cell_postgemm_2_bwd.hpp"
#include "cpu/x64/rnn/jit_uni_gru_cell_postgemm_2_fwd.hpp"
#include "cpu/x64/rnn/jit_uni_gru_lbr_cell_postgemm_bwd.hpp"
#include "cpu/x64/rnn/jit_uni_gru_lbr_cell_postgemm_fwd.hpp"
#include "cpu/x64/rnn/jit_uni_lstm_cell_postgemm_bwd.hpp"
#include "cpu/x64/rnn/jit_uni_lstm_cell_postgemm_fwd.hpp"
#include "cpu/x64/rnn/jit_uni_lstm_cell_projection_postgemm_fwd.hpp"
#include "cpu/x64/rnn/jit_uni_rnn_cell_postgemm_bwd.hpp"
#include "cpu/x64/rnn/jit_uni_rnn_cell_postgemm_fwd.hpp"
#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"
#endif

namespace dnnl {
namespace impl {
namespace cpu {

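// Reference scalar activation used by the vanilla RNN cell postgemm.
// Specializations are selected by activation algorithm kind and propagation
// kind (see the constructor of rnn_postgemm_dispatcher below).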
template <alg_kind_t alg_kind, prop_kind_t prop_kind>
float activation(float s, float alpha, float clipping);

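// Dispatches the elementwise (post-GEMM) part of an RNN cell: the constructor
// selects a reference C++ postgemm function based on the cell kind and, on
// x64, may additionally create a JIT kernel that execute() prefers when
// available.
//
// Minimal usage sketch (the actual call sites live in the CPU RNN cell
// execution code; argument names here are illustrative only):
//     rnn_postgemm_fwd_f32_t postgemm(rnn_conf, rnn_pd);
//     ...
//     postgemm.execute(rnn_conf, cell_position, ws_gates, scratch_gates, ...);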
template <prop_kind_t aprop, impl::data_type_t src_type,
        impl::data_type_t scratch_type, impl::data_type_t acc_type>
struct rnn_postgemm_dispatcher {

    typedef typename prec_traits<src_type>::type src_layer_t;
    typedef typename prec_traits<src_type>::type src_iter_t;
    typedef typename prec_traits<src_type>::type dst_layer_t;
    typedef typename prec_traits<src_type>::type dst_iter_t;
    typedef typename prec_traits<acc_type>::type gemm_acc_t;
    typedef typename prec_traits<scratch_type>::type scratch_t;
    typedef typename prec_traits<src_type>::type ht_t;
    typedef typename prec_traits<src_type>::type gates_t;

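    // postgemm_f is a pointer-to-member-function type with the common
    // postgemm signature (declared via the rnn_postgemm_sig macro);
    // execute() below calls it as (this->*postgemm_func)(...) whenever no
    // JIT kernel is available.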
    using class_name
            = rnn_postgemm_dispatcher<aprop, src_type, scratch_type, acc_type>;
    typedef rnn_postgemm_sig((class_name::*postgemm_f));

    rnn_postgemm_dispatcher(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : pd_(pd) {
        // Sanity check: in testing mode the number of gates must match the
        // cell kind.
        if (pd->attr()->rnn_tparams_.test_mode_) {
            const auto ngates = utils::map(pd->cell_kind(), 0,
                    alg_kind::vanilla_rnn, 1, alg_kind::vanilla_lstm, 4,
                    alg_kind::vanilla_gru, 3, alg_kind::lbr_gru, 3,
                    alg_kind::vanilla_augru, 3, alg_kind::lbr_augru, 3);
            assert(pd->attr()->rnn_tparams_.ngates_ == ngates);
            MAYBE_UNUSED(ngates);
        }

        switch (pd->cell_kind()) {
            case alg_kind::vanilla_lstm:
                postgemm_func = &class_name::lstm_postgemm;
                // used for int8 requantization after projection
                postgemm_part2_func = pd->is_lstm_projection() && pd_->is_fwd()
                        ? &class_name::lstm_projection_postgemm
                        : nullptr;
                break;
            case alg_kind::vanilla_rnn:
                postgemm_func = &class_name::rnn_postgemm;
                switch (pd->activation_kind()) {
                    case alg_kind::eltwise_relu:
                        activation_func
                                = &activation<alg_kind::eltwise_relu, aprop>;
                        break;
                    case alg_kind::eltwise_tanh:
                        activation_func
                                = &activation<alg_kind::eltwise_tanh, aprop>;
                        break;
                    case alg_kind::eltwise_logistic:
                        activation_func
                                = &activation<alg_kind::eltwise_logistic,
                                        aprop>;
                        break;
                    default: assert(!"Unsupported activation function"); break;
                }
                break;
            case alg_kind::vanilla_gru:
            case alg_kind::vanilla_augru:
                postgemm_func = &class_name::gru_part1_postgemm;
                postgemm_part2_func = &class_name::gru_part2_postgemm;
                break;
            case alg_kind::lbr_gru:
            case alg_kind::lbr_augru:
                postgemm_func = &class_name::gru_lbr_postgemm;
                break;
            default: assert(!"Unsupported algorithm kind"); break;
        }
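        // Example (sketch): a forward vanilla LSTM selects lstm_postgemm as
        // postgemm_func and, when LSTM projection is used, sets
        // postgemm_part2_func to lstm_projection_postgemm; on x64 a matching
        // JIT kernel created below may take over in execute().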

        DNNL_X64_ONLY(initialize_jit(rnn));
    }

    ~rnn_postgemm_dispatcher() = default;

    rnn_postgemm_sig(unpoison) {
        // XXX (rsdubtso): This is a big hammer that unpoisons everything
        // a postgemm may touch, to avoid writing per-cell-kind versions of
        // the unpoisoning code. It must be removed together with the big
        // unpoison_outputs() hammer in common/primitive.cpp.

        const size_t states_nelems
                = rnn.ws_states_layer_nld * rnn.ws_states_layer_ld;
        const size_t gates_nelems
                = rnn.scratch_gates_nld * rnn.scratch_gates_ld;

        if (pd_->is_fwd()) {
            msan_unpoison(dst_layer_, sizeof(*dst_layer_) * states_nelems);
            msan_unpoison(dst_iter_, sizeof(*dst_iter_) * states_nelems);
            if (rnn.is_training)
                msan_unpoison(ws_gates_, sizeof(*ws_gates_) * gates_nelems);
        } else {
            msan_unpoison(diff_src_layer_,
                    sizeof(*diff_src_layer_) * (rnn.n_iter + 1)
                            * rnn.ws_diff_states_layer_nld
                            * rnn.ws_diff_states_layer_ld);
            msan_unpoison(diff_augru_attention_,
                    sizeof(*diff_augru_attention_) * rnn.n_iter * rnn.mb
                            * rnn.dhc);
            msan_unpoison(diff_src_iter_,
                    sizeof(*diff_src_iter_) * (rnn.n_iter + 1)
                            * rnn.ws_diff_states_iter_nld
                            * rnn.ws_diff_states_iter_ld);
            msan_unpoison(diff_src_iter_c_,
                    sizeof(*diff_src_iter_c_) * (rnn.n_iter + 1)
                            * rnn.ws_diff_states_iter_c_nld
                            * rnn.ws_diff_states_iter_c_ld);
            msan_unpoison(
                    scratch_gates_, sizeof(*scratch_gates_) * gates_nelems);
            msan_unpoison(
                    scratch_cell_, sizeof(*scratch_cell_) * states_nelems);
        }
    }

    // template <typename src_data_t, typename acc_data_t>
    rnn_postgemm_sig(execute) {
        /* This block affects performance when it is executed many times.
         * Be careful when changing it.
         * XXX: the code is compiler-sensitive; jit might help with that.
         */
#if DNNL_X64
        if (rnn_postgemm_) {
            rnn_postgemm_->execute(rnn, cell_position, ws_gates_,
                    scratch_gates_, augru_attention_, dst_layer_, dst_iter_c_,
                    src_iter_, src_iter_c_, diff_src_layer_,
                    diff_augru_attention_, diff_src_iter_, diff_src_iter_c_,
                    diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
                    weights_peephole_, bias_, ws_grid_, scratch_cell_,
                    dst_iter_, weights_scales_, block_step);
            unpoison(rnn, cell_position, ws_gates_, scratch_gates_,
                    augru_attention_, dst_layer_, dst_iter_c_, src_iter_,
                    src_iter_c_, diff_src_layer_, diff_augru_attention_,
                    diff_src_iter_, diff_src_iter_c_, diff_dst_layer_,
                    diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_,
                    ws_grid_, scratch_cell_, dst_iter_, weights_scales_,
                    block_step);
            return;
        }
#endif
        (this->*postgemm_func)(rnn, cell_position, ws_gates_, scratch_gates_,
                augru_attention_, dst_layer_, dst_iter_c_, src_iter_,
                src_iter_c_, diff_src_layer_, diff_augru_attention_,
                diff_src_iter_, diff_src_iter_c_, diff_dst_layer_,
                diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_,
                ws_grid_, scratch_cell_, dst_iter_, weights_scales_,
                block_step);
    }

    // template <typename src_data_t, typename acc_data_t>
    rnn_postgemm_sig(execute_part2) {
        /* This block affects performance when it is executed many times.
         * Be careful when changing it.
         * XXX: the code is compiler-sensitive; jit might help with that.
         */
#if DNNL_X64
        if (rnn_postgemm_part2_) {
            rnn_postgemm_part2_->execute(rnn, cell_position, ws_gates_,
                    scratch_gates_, augru_attention_, dst_layer_, dst_iter_c_,
                    src_iter_, src_iter_c_, diff_src_layer_,
                    diff_augru_attention_, diff_src_iter_, diff_src_iter_c_,
                    diff_dst_layer_, diff_dst_iter_, diff_dst_iter_c_,
                    weights_peephole_, bias_, ws_grid_, scratch_cell_,
                    dst_iter_, weights_scales_, block_step);
            unpoison(rnn, cell_position, ws_gates_, scratch_gates_,
                    augru_attention_, dst_layer_, dst_iter_c_, src_iter_,
                    src_iter_c_, diff_src_layer_, diff_augru_attention_,
                    diff_src_iter_, diff_src_iter_c_, diff_dst_layer_,
                    diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_,
                    ws_grid_, scratch_cell_, dst_iter_, weights_scales_,
                    block_step);
            return;
        }
#endif
        (this->*postgemm_part2_func)(rnn, cell_position, ws_gates_,
                scratch_gates_, augru_attention_, dst_layer_, dst_iter_c_,
                src_iter_, src_iter_c_, diff_src_layer_, diff_augru_attention_,
                diff_src_iter_, diff_src_iter_c_, diff_dst_layer_,
                diff_dst_iter_, diff_dst_iter_c_, weights_peephole_, bias_,
                ws_grid_, scratch_cell_, dst_iter_, weights_scales_,
                block_step);
    }

private:
    float (*activation_func)(float s, float alpha, float clipping);
    rnn_postgemm_sig(rnn_postgemm);
    rnn_postgemm_sig(lstm_postgemm);
    rnn_postgemm_sig(lstm_projection_postgemm);
    rnn_postgemm_sig(gru_part1_postgemm);
    rnn_postgemm_sig(gru_part2_postgemm);
    rnn_postgemm_sig(gru_lbr_postgemm);

    const rnn_pd_t *pd_;

    postgemm_f postgemm_func;
    postgemm_f postgemm_part2_func;

    DNNL_DISALLOW_COPY_AND_ASSIGN(rnn_postgemm_dispatcher);

#if DNNL_X64
    std::unique_ptr<x64::jit_uni_rnn_postgemm> rnn_postgemm_;
    std::unique_ptr<x64::jit_uni_rnn_postgemm> rnn_postgemm_part2_;

    void initialize_jit(const rnn_utils::rnn_conf_t &rnn) {
        using namespace dnnl::impl::cpu::x64;

        if (pd_->attr()->rnn_tparams_.test_mode_) return;

        const bool jit_fwd = pd_->is_fwd()
                && utils::one_of(src_type, data_type::f32, data_type::u8,
                        data_type::s8, data_type::bf16);
        const bool jit_bwd = !pd_->is_fwd()
                && utils::one_of(src_type, data_type::f32, data_type::bf16);

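// CREATE_WITH_DIR instantiates the kernel for the best ISA reported by
// mayiuse() (avx512_core, then avx2, then sse41); CREATE picks the _fwd or
// _bwd variant of the kernel type to match the propagation kind. If no JIT
// kernel is created, the reference postgemm functions remain the fallback.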
#define CREATE_WITH_DIR(k, ker_t) \
    do { \
        if (mayiuse(avx512_core)) \
            k.reset(new ker_t<avx512_core, src_type, scratch_type>(rnn, pd_)); \
        else if (mayiuse(avx2)) \
            k.reset(new ker_t<avx2, src_type, scratch_type>(rnn, pd_)); \
        else \
            k.reset(new ker_t<sse41, src_type, scratch_type>(rnn, pd_)); \
    } while (0)
#define CREATE(k, ker_t) \
    do { \
        if (jit_fwd) CREATE_WITH_DIR(k, CONCAT2(ker_t, _fwd)); \
        if (jit_bwd) CREATE_WITH_DIR(k, CONCAT2(ker_t, _bwd)); \
    } while (0)

        if (pd_->cell_kind() == alg_kind::vanilla_lstm) {
            CREATE(rnn_postgemm_, jit_uni_lstm_cell_postgemm);
        } else if (pd_->cell_kind() == alg_kind::vanilla_rnn) {
            CREATE(rnn_postgemm_, jit_uni_rnn_cell_postgemm);
        } else if (utils::one_of(pd_->cell_kind(), alg_kind::vanilla_gru,
                           alg_kind::vanilla_augru)) {
            CREATE(rnn_postgemm_, jit_uni_gru_cell_postgemm_part1);
            CREATE(rnn_postgemm_part2_, jit_uni_gru_cell_postgemm_part2);
        } else if (utils::one_of(pd_->cell_kind(), alg_kind::lbr_gru,
                           alg_kind::lbr_augru)) {
            CREATE(rnn_postgemm_, jit_uni_gru_lbr_cell_postgemm);
        }

#undef CREATE
#undef CREATE_WITH_DIR

        if (rnn_postgemm_) rnn_postgemm_->init(src_type);
        if (rnn_postgemm_part2_) rnn_postgemm_part2_->init(src_type);
    }
#endif
};

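// Dispatcher instantiations used by the CPU RNN primitives. Note that the
// int8 (u8/s8) forward configurations use s32 scratch and accumulator types,
// while the bf16 backward configuration keeps bf16 scratch gates.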
using rnn_postgemm_fwd_f32_t = rnn_postgemm_dispatcher<prop_kind::forward,
        data_type::f32, data_type::f32, data_type::f32>;
using rnn_postgemm_bwd_f32_t = rnn_postgemm_dispatcher<prop_kind::backward,
        data_type::f32, data_type::f32, data_type::f32>;

using rnn_postgemm_fwd_bf16_t = rnn_postgemm_dispatcher<prop_kind::forward,
        data_type::bf16, data_type::f32, data_type::f32>;
using rnn_postgemm_bwd_bf16_t = rnn_postgemm_dispatcher<prop_kind::backward,
        data_type::bf16, data_type::bf16, data_type::f32>;

using rnn_postgemm_fwd_u8_t = rnn_postgemm_dispatcher<prop_kind::forward,
        data_type::u8, data_type::s32, data_type::s32>;
using rnn_postgemm_fwd_s8_t = rnn_postgemm_dispatcher<prop_kind::forward,
        data_type::s8, data_type::s32, data_type::s32>;

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif