/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_FWD_HPP
#define CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_FWD_HPP

#include <memory>
#include "common/utils.hpp"
#include "cpu/x64/rnn/jit_uni_lstm_cell_postgemm.hpp"
#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_lstm_cell_postgemm_fwd
    : public jit_uni_rnn_postgemm,
      public jit_uni_lstm_cell_postgemm_t<isa> {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_cell_postgemm_fwd)

    jit_uni_lstm_cell_postgemm_fwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name())
        , jit_uni_lstm_cell_postgemm_t<isa>(this,
                get_last_preserved_vmm_idx(1) + 1,
                // Using jit_uni_rnn_postgemm::bf16_emu_ to detect the bf16
                // emulation case is not valid here: bf16_emu_ is created in
                // jit_uni_rnn_postgemm::init(), not in the constructor, so it
                // is still nullptr at this point.
                src_data_t == data_type::bf16 && !mayiuse(avx512_core_bf16)) {
    }

    ~jit_uni_lstm_cell_postgemm_fwd() = default;

    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        // We use rax for both constant tables and load the corresponding
        // table label into it when calling the corresponding injector.
        sigmoid_injector_ = utils::make_unique<injector_t>(
                this, alg_kind::eltwise_logistic, 0.0f, 0.0f, 1.0f, true, rax);
        tanh_injector_ = utils::make_unique<injector_t>(
                this, alg_kind::eltwise_tanh, 0.0f, 0.0f, 1.0f, true, rax);
        return create_kernel();
    }

protected:
    using injector_t = typename jit_uni_lstm_cell_postgemm_t<isa>::injector_t;
    using Vmm = typename jit_uni_lstm_cell_postgemm_t<isa>::Vmm;

    std::unique_ptr<injector_t> sigmoid_injector_;
    std::unique_ptr<injector_t> tanh_injector_;

    // register size in bytes
    static constexpr size_t vlen_ = cpu_isa_traits<isa>::vlen;
    static constexpr size_t qscale_dt_size = sizeof(float);
    static constexpr size_t weights_peephole_dt_size_ = sizeof(float);
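    // byte size of one full vector's worth of elements when stored in the
    // dst, bias, or c_state data type (which may be narrower than f32)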
    const size_t vlen_dst_
            = vlen_ / (sizeof(float) / types::data_type_size(src_data_t));
    const size_t vlen_bias_ = vlen_ / (sizeof(float) / bias_dt_size_);
    const size_t vlen_c_states_ = vlen_ / (sizeof(float) / cstate_dt_size_);
    const size_t hstate_dt_size_ = types::data_type_size(src_data_t);
    const size_t gate_dt_size_ = types::data_type_size(src_data_t);
    const size_t scratch_dt_size_ = types::data_type_size(scratch_data_t);

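    // Each loop unroll iteration owns a contiguous block of 5 preserved vmm
    // registers starting at index 1: G0, G1, G2, G3 and the c_states
    // accumulator. get_vmm_idx() maps (unroll_idx, type_shift) to the
    // corresponding register index.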
    int get_vmm_idx(int unroll_idx, int type_shift) const {
        const int preserved_vmm_start_idx = 1;
        // G0, G1, G2, G3, c_states
        const int num_preserved_regs_for_loop_iter = 5;
        assert(type_shift < num_preserved_regs_for_loop_iter);
        const int unroll_idx_start = preserved_vmm_start_idx
                + num_preserved_regs_for_loop_iter * unroll_idx;
        return unroll_idx_start + type_shift;
    }

    int G0_idx(int unroll_idx) const { return get_vmm_idx(unroll_idx, 0); }
    int G1_idx(int unroll_idx) const { return get_vmm_idx(unroll_idx, 1); }
    int G2_idx(int unroll_idx) const { return get_vmm_idx(unroll_idx, 2); }
    int G3_idx(int unroll_idx) const { return get_vmm_idx(unroll_idx, 3); }
    int c_states_idx(int unroll_idx) const {
        return get_vmm_idx(unroll_idx, 4);
    }
    int get_last_preserved_vmm_idx(int current_loop_unroll) const {
        return c_states_idx(current_loop_unroll - 1);
    }

    dim_t scale_off(int gate_idx, int unroll_idx) const {
        const size_t vlen_qscale_elem = vlen_ / qscale_dt_size;
        return gate_idx * rnn_.dhc + unroll_idx * vlen_qscale_elem;
    }

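    // Generates the LSTM post-GEMM kernel. Using the conventional i/f/g/o
    // naming for G0..G3, each output element is computed as
    //     i = sigmoid(G0 + b0),  f = sigmoid(G1 + b1),
    //     g = tanh(G2 + b2),     o = sigmoid(G3 + b3),
    //     c_t = f * c_{t-1} + i * g,
    //     h_t = o * tanh(c_t),
    // with optional peephole contributions added to the i, f and o gates, and
    // with the activated gates written back to the workspace when training.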
    void generate() override {
        using namespace Xbyak;

        const auto is_training
                = (pd_->desc()->prop_kind == prop_kind::forward_training);

        const int mask = pd_->attr()->rnn_weights_qparams_.mask_;
        float *const weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;

        // Register map
        const Reg64 loop_cnt(rbx); // loop counter

        // We start code generation here
        preamble();

        const Reg64 n_step_reg(rbp);

        // extract addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_weights_peephole_reg = r11;
        const auto addr_bias_reg = abi_param3;
        const auto addr_states_t_l_reg = abi_param4;
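        // The remaining arguments do not fit into the ABI parameter registers
        // (4 on Windows x64, 6 on System V), so they are read from the stack.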
#ifdef _WIN32
        const auto addr_states_t_l_copy_reg = r10;
        const auto addr_c_states_tm1_l_reg = rdi;
        const auto addr_c_states_t_l_reg = rsi;
        // Here we cannot use rbp to keep the initial stack pointer, so we
        // use rsp and offset it by the size of the registers pushed in the
        // preamble.
        const auto base_args = get_stack_params_address();
        mov(addr_states_t_l_copy_reg, ptr[base_args]);
        mov(addr_c_states_tm1_l_reg, ptr[base_args + 8]);
        mov(addr_c_states_t_l_reg, ptr[base_args + 16]);
        mov(addr_weights_peephole_reg, ptr[base_args + 24]);
        mov(n_step_reg, ptr[base_args + 40]);
#else
        const auto addr_states_t_l_copy_reg = abi_param5;
        const auto addr_c_states_tm1_l_reg = abi_param6;
        const auto addr_c_states_t_l_reg = r10;
        const auto base_args = get_stack_params_address();
        mov(addr_c_states_t_l_reg, ptr[base_args]);
        mov(addr_weights_peephole_reg, ptr[base_args + 8]);
        mov(n_step_reg, ptr[base_args + 24]);
#endif

        // helper lambdas to address the gates, biases and peephole weights
        const auto sg_addr = [&](int i, int j = 0) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size_
                    + j * vlen_];
        };

        const auto wg_addr = [&](int i, int j = 0) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size_
                    + j * vlen_dst_];
        };

        const auto B_addr = [&](int i, int j = 0) {
            return ptr[addr_bias_reg + i * rnn_.dhc * bias_dt_size_
                    + j * vlen_bias_];
        };

        const auto weights_peephole_addr = [&](int i, int j = 0) {
            return ptr[addr_weights_peephole_reg
                    + i * rnn_.dhc * weights_peephole_dt_size_ + j * vlen_];
        };

        const auto loop_len = rnn_.dhc * scratch_dt_size_;
        const auto loop_tail = loop_len % vlen_;

        // initialize registers with addresses and constants
        init_regs(weights_scales, vlen_, loop_tail / scratch_dt_size_);
        sigmoid_injector_->load_table_addr();
        tanh_injector_->load_table_addr();
        if (rnn_.is_brgemm && !rnn_.unfused_post_gemm)
            mov(loop_cnt, n_step_reg);
        else
            mov(loop_cnt, loop_len);

        int loop_unroll = 1;
        int loop_unroll_tail = 0;

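        // Pick the largest unroll factor (up to 4 full vectors per iteration
        // on avx512, 1 otherwise) that the loop/block length can accommodate;
        // any leftover full vectors are handled by a non-unrolled vector loop
        // before the masked (avx512) or element-wise tail.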
        const int loop_unroll_max = is_avx512 ? 4 : 1;
        if (rnn_.is_brgemm && !rnn_.unfused_post_gemm) {
            const auto block_loop_len = rnn_.n_block * scratch_dt_size_;
            for (loop_unroll = loop_unroll_max; loop_unroll > 1;
                    loop_unroll--) {
                if (block_loop_len % (loop_unroll * vlen_) == 0) break;
            }
            if (loop_unroll > 1 && rnn_.n_tail > 0
                    && rnn_.n_tail * scratch_dt_size_ - loop_tail > 0)
                loop_unroll_tail = 1;
        } else {
            for (loop_unroll = loop_unroll_max; loop_unroll > 1;
                    loop_unroll--) {
                if (loop_len >= (loop_unroll * vlen_)) break;
            }
            if (loop_unroll > 1
                    && (loop_len - loop_tail) % (loop_unroll * vlen_) > 0)
                loop_unroll_tail = 1;
        }

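        // compute_loop emits one loop flavor: current_vlen is the number of
        // bytes of scratch data processed per vector (a full vlen_ or the
        // tail size) and current_unroll_len is the number of vectors handled
        // per loop iteration.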
        auto compute_loop = [=](size_t current_vlen, int current_unroll_len) {
            this->reset_tmp_vmm_idx_range(
                    get_last_preserved_vmm_idx(current_unroll_len) + 1,
                    this->get_max_allowed_tmp_vmm_allowed_idx());

            injector_utils::vmm_index_set_t vmm_idxs;

            const bool single_tail_loop_iter
                    = current_vlen < vlen_ && current_vlen == loop_tail;
            const bool need_increment_regs = !single_tail_loop_iter;
            const auto iter_size = current_unroll_len * current_vlen;

            Label loop_start_label, loop_skip_label;
            cmp(loop_cnt, iter_size);
            jl(loop_skip_label, T_NEAR);

            L_aligned(loop_start_label, 64);
            {
                for (int ur_idx = 0; ur_idx < current_unroll_len; ur_idx++) {
                    const Vmm G0(G0_idx(ur_idx)), G1(G1_idx(ur_idx)),
                            G2(G2_idx(ur_idx)), G3(G3_idx(ur_idx)),
                            tmp_c_states(c_states_idx(ur_idx));
                    // load G0, G1, G2, G3
                    load(G0, sg_addr(0, ur_idx), scratch_data_t, current_vlen);
                    load(G1, sg_addr(1, ur_idx), scratch_data_t, current_vlen);
                    load(G2, sg_addr(2, ur_idx), scratch_data_t, current_vlen);
                    load(G3, sg_addr(3, ur_idx), scratch_data_t, current_vlen);

                    // dequantize the gates from s32 to f32 if needed, add bias
                    deq_w(src_data_t, G0, this->get_next_tmp_vmm(),
                            this->get_next_tmp_vmm(), scale_off(0, ur_idx),
                            mask, current_vlen);
                    const auto bias_g0_vmm = this->get_next_tmp_vmm();
                    to_float(bias_g0_vmm, B_addr(0, ur_idx), rnn_.bias_dt,
                            current_vlen);
                    compute_vaddps(G0, G0, bias_g0_vmm, current_vlen);

                    deq_w(src_data_t, G1, this->get_next_tmp_vmm(),
                            this->get_next_tmp_vmm(), scale_off(1, ur_idx),
                            mask, current_vlen);
                    const auto bias_g1_vmm = this->get_next_tmp_vmm();
                    to_float(bias_g1_vmm, B_addr(1, ur_idx), rnn_.bias_dt,
                            current_vlen);
                    compute_vaddps(G1, G1, bias_g1_vmm, current_vlen);

                    deq_w(src_data_t, G2, this->get_next_tmp_vmm(),
                            this->get_next_tmp_vmm(), scale_off(2, ur_idx),
                            mask, current_vlen);
                    const auto bias_g2_vmm = this->get_next_tmp_vmm();
                    to_float(bias_g2_vmm, B_addr(2, ur_idx), rnn_.bias_dt,
                            current_vlen);
                    compute_vaddps(G2, G2, bias_g2_vmm, current_vlen);

                    deq_w(src_data_t, G3, this->get_next_tmp_vmm(),
                            this->get_next_tmp_vmm(), scale_off(3, ur_idx),
                            mask, current_vlen);
                    const auto bias_g3_vmm = this->get_next_tmp_vmm();
                    to_float(bias_g3_vmm, B_addr(3, ur_idx), rnn_.bias_dt,
                            current_vlen);
                    compute_vaddps(G3, G3, bias_g3_vmm, current_vlen);

                    to_float(tmp_c_states,
                            ptr[addr_c_states_tm1_l_reg
                                    + ur_idx * vlen_c_states_],
                            rnn_.src_iter_c_dt, current_vlen);

                    // add peephole
                    if (rnn_.is_lstm_peephole) {
                        compute_vfmadd231ps(G0, tmp_c_states,
                                weights_peephole_addr(0, ur_idx), current_vlen,
                                this->maybe_get_next_tmp_vmm_for_below_avx2_isa());
                        compute_vfmadd231ps(G1, tmp_c_states,
                                weights_peephole_addr(1, ur_idx), current_vlen,
                                this->maybe_get_next_tmp_vmm_for_below_avx2_isa());
                    }

                    vmm_idxs.emplace(G0.getIdx());
                    vmm_idxs.emplace(G1.getIdx());
                    if (!rnn_.is_lstm_peephole) vmm_idxs.emplace(G3.getIdx());
                }

                // inject eltwise code
                sigmoid_injector_->load_table_addr();
                sigmoid_injector_->compute_vector_range(vmm_idxs);
                vmm_idxs.clear();

                if (is_training) {
                    for (int ur_idx = 0; ur_idx < current_unroll_len;
                            ur_idx++) {
                        to_src(wg_addr(0, ur_idx), Vmm(G0_idx(ur_idx)),
                                src_data_t, current_vlen);
                        to_src(wg_addr(1, ur_idx), Vmm(G1_idx(ur_idx)),
                                src_data_t, current_vlen);
                        if (!rnn_.is_lstm_peephole)
                            to_src(wg_addr(3, ur_idx), Vmm(G3_idx(ur_idx)),
                                    src_data_t, current_vlen);
                    }
                }
                for (int ur_idx = 0; ur_idx < current_unroll_len; ur_idx++) {
                    vmm_idxs.emplace(G2_idx(ur_idx));
                }
                tanh_injector_->load_table_addr();
                tanh_injector_->compute_vector_range(vmm_idxs);
                vmm_idxs.clear();

                for (int ur_idx = 0; ur_idx < current_unroll_len; ur_idx++) {
                    const Vmm G0(G0_idx(ur_idx)), G1(G1_idx(ur_idx)),
                            G2(G2_idx(ur_idx)),
                            tmp_c_states(c_states_idx(ur_idx));
                    if (is_training) {
                        to_src(wg_addr(2, ur_idx), G2, src_data_t,
                                current_vlen);
                    }

                    // compute c_states_t_l = G1 * c_tm1_l + G0 * G2
                    compute_vmulps(
                            tmp_c_states, tmp_c_states, G1, current_vlen);
                    compute_vfmadd231ps(tmp_c_states, this->vmm_backup(G0), G2,
                            current_vlen);
                    to_src(ptr[addr_c_states_t_l_reg + ur_idx * vlen_c_states_],
                            tmp_c_states, rnn_.dst_iter_c_dt, current_vlen);
                }

                // add peephole
                if (rnn_.is_lstm_peephole) {
                    for (int ur_idx = 0; ur_idx < current_unroll_len;
                            ur_idx++) {
                        const int cur_g3_idx = G3_idx(ur_idx);
                        compute_vfmadd231ps(Vmm(cur_g3_idx),
                                Vmm(c_states_idx(ur_idx)),
                                weights_peephole_addr(2, ur_idx), current_vlen,
                                this->maybe_get_next_tmp_vmm_for_below_avx2_isa());
                        vmm_idxs.emplace(cur_g3_idx);
                    }
                    sigmoid_injector_->load_table_addr();
                    sigmoid_injector_->compute_vector_range(vmm_idxs);
                    vmm_idxs.clear();

                    // if training we write back the gates
                    if (is_training) {
                        for (int ur_idx = 0; ur_idx < current_unroll_len;
                                ur_idx++) {
                            to_src(wg_addr(3, ur_idx), Vmm(G3_idx(ur_idx)),
                                    src_data_t, current_vlen);
                        }
                    }
                }

                for (int ur_idx = 0; ur_idx < current_unroll_len; ur_idx++) {
                    vmm_idxs.emplace(c_states_idx(ur_idx));
                }
                // states_t_l = G3 * tanh(c_states_t_l)
                tanh_injector_->load_table_addr();
                tanh_injector_->compute_vector_range(vmm_idxs);
                vmm_idxs.clear();

                for (int ur_idx = 0; ur_idx < current_unroll_len; ur_idx++) {
                    const Vmm G3(G3_idx(ur_idx)),
                            tmp_c_states(c_states_idx(ur_idx));
                    compute_vmulps(
                            tmp_c_states, tmp_c_states, G3, current_vlen);
                }

                // downconvert/quantize and write back the state
                Label loop_inc_regs_label,
                        update_single_states_tensor_only_label;
                cmp(addr_states_t_l_copy_reg, 0);
                je(update_single_states_tensor_only_label, T_NEAR);
                // if states_t_l_copy is a non-null ptr, we write the output
                // to both tensors
                for (int ur_idx = 0; ur_idx < current_unroll_len; ur_idx++) {
                    const Vmm tmp_c_states(c_states_idx(ur_idx));
                    to_src(ptr[addr_states_t_l_reg + ur_idx * vlen_dst_],
                            tmp_c_states, src_data_t, current_vlen);
                    // Since the next to_src call uses write_only=true, for
                    // bf16 src_dt it must execute right after the to_src call
                    // above (write_only=false) for the same Vmm.
                    to_src(ptr[addr_states_t_l_copy_reg + ur_idx * vlen_dst_],
                            tmp_c_states, src_data_t, current_vlen, true);
                }
                const size_t hstate_shift = current_vlen < vlen_
                        ? hstate_dt_size_
                        : current_unroll_len * vlen_dst_;
                if (need_increment_regs)
                    add(addr_states_t_l_copy_reg, hstate_shift);
                jmp(loop_inc_regs_label, T_NEAR);

                L_aligned(update_single_states_tensor_only_label);
                for (int ur_idx = 0; ur_idx < current_unroll_len; ur_idx++) {
                    to_src(ptr[addr_states_t_l_reg + ur_idx * vlen_dst_],
                            Vmm(c_states_idx(ur_idx)), src_data_t,
                            current_vlen);
                }

                // increment address pointers
                L_aligned(loop_inc_regs_label);
                if (need_increment_regs) {
                    const size_t scratch_shift = current_vlen < vlen_
                            ? scratch_dt_size_
                            : current_unroll_len * vlen_;
                    add(addr_scratch_gates_reg, scratch_shift);
                    if (rnn_.is_lstm_peephole) {
                        const size_t wpeephole_shift = current_vlen < vlen_
                                ? weights_peephole_dt_size_
                                : current_unroll_len * vlen_;
                        add(addr_weights_peephole_reg, wpeephole_shift);
                    }
                    const size_t bias_shift = current_vlen < vlen_
                            ? bias_dt_size_
                            : current_unroll_len * vlen_bias_;
                    add(addr_bias_reg, bias_shift);
                    add(addr_states_t_l_reg, hstate_shift);
                    const size_t cstate_shift = current_vlen < vlen_
                            ? cstate_dt_size_
                            : current_unroll_len * vlen_c_states_;
                    add(addr_c_states_tm1_l_reg, cstate_shift);
                    add(addr_c_states_t_l_reg, cstate_shift);
                    if (is_training) {
                        const size_t gate_shift = current_vlen < vlen_
                                ? gate_dt_size_
                                : current_unroll_len * vlen_dst_;
                        add(addr_ws_gates_reg, gate_shift);
                    }
                    const size_t qscale_shift = current_vlen < vlen_
                            ? qscale_dt_size
                            : current_unroll_len * vlen_;
                    inc_regs(mask, qscale_shift);
                }

                // decrement loop counter
                sub(loop_cnt, iter_size);
                cmp(loop_cnt, iter_size);
                jge(loop_start_label, T_NEAR);
            }
            L_aligned(loop_skip_label, 64);
        };

        if (loop_unroll > 0) {
            // unrolled vector loop
            compute_loop(vlen_, loop_unroll);
        }

        if (loop_unroll_tail > 0) {
            // non-unrolled vector loop, if required
            compute_loop(vlen_, loop_unroll_tail);
        }

        if (loop_tail > 0) {
            // tail processing
            compute_loop(is_avx512 ? loop_tail : scratch_dt_size_, 1);
        }

        postamble();

        sigmoid_injector_->prepare_table(true);
        tanh_injector_->prepare_table(true);

        init_table(vlen_);
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif