/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_LSTM_CELL_PROJECTION_POSTGEMM_FWD_HPP
#define CPU_X64_RNN_JIT_UNI_LSTM_CELL_PROJECTION_POSTGEMM_FWD_HPP

#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

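// Post-GEMM kernel for the projection step of an LSTM cell (forward pass).
// It dequantizes the projection GEMM output held in scratch memory,
// converts it to the destination data type, and stores it to the hidden
// state tensor (and, when a non-null copy pointer is passed, to a second
// copy of that tensor as well).
//
// Illustrative instantiation (a sketch only; the exact template arguments
// used in practice come from the RNN driver code):
//
//   jit_uni_lstm_cell_projection_postgemm_fwd<avx512_core, data_type::f32,
//           data_type::f32>
//           postgemm(rnn, pd);
//   status_t st = postgemm.init(data_type::f32); // generates the kernel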
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_lstm_cell_projection_postgemm_fwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_cell_projection_postgemm_fwd)

    jit_uni_lstm_cell_projection_postgemm_fwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    ~jit_uni_lstm_cell_projection_postgemm_fwd() {}

    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        projection_ = true;
        return create_kernel();
    }

protected:
    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    // vector register size in bytes
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    static constexpr size_t qscale_dt_size = sizeof(float);
    const size_t vlen_dst
            = vlen / (sizeof(float) / types::data_type_size(src_data_t));
    const size_t hstate_dt_size = types::data_type_size(src_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);
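    // Example values (for illustration only): with isa = avx512_core and
    // src_data_t = bf16, vlen is 64 bytes, i.e. 16 f32 scratch elements per
    // vector; vlen_dst = 64 / (4 / 2) = 32 bytes covers the same 16 elements
    // in the 2-byte destination type, with hstate_dt_size = 2 and
    // scratch_dt_size = 4.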

    void generate() override {
        using namespace Xbyak;

        // Labels declaration
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;

        // Register map
        const Reg64 loop_cnt(rbx); // loop counter
        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        const Vmm in(1), tmp1_vmm(5), tmp2_vmm(6);

        const int mask = pd_->attr()->rnn_weights_projection_qparams_.mask_;
        float *weights_scales
                = pd_->attr()->rnn_weights_projection_qparams_.scales_;
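        // A zero mask selects a single common weights scale; a non-zero
        // mask selects per-output-channel scales, in which case inc_regs()
        // advances the scales pointer as the output row is traversed.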

        // We start code generation here
        preamble();

        // extract addresses passed as parameters
        const auto addr_scratch_reg = abi_param2;
        const auto addr_states_t_l_reg = abi_param4;
#ifdef _WIN32
        const auto addr_states_t_l_copy_reg = r10;
        const auto addr_wcomp_reg = rdi;
        // Here we cannot use rbp to retrieve the initial stack pointer, so
        // we use rsp and offset it by the size of the registers pushed in
        // the preamble
        const auto base_args = get_stack_params_address();
        mov(addr_states_t_l_copy_reg, ptr[base_args]);
        mov(addr_wcomp_reg, ptr[base_args + 8]);
#else
        const auto addr_states_t_l_copy_reg = abi_param5;
        const auto addr_wcomp_reg = abi_param6;
#endif

        // initialize registers with addresses and constants
        init_regs(weights_scales, vlen);

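        // loop_cnt counts the remaining bytes of the scratch row
        // (rnn_.dic elements of scratch_data_t): full vectors of vlen bytes
        // are handled by the vector loop, and any leftover elements are
        // handled one at a time by the remainder loop below.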
        mov(loop_cnt, rnn_.dic * scratch_dt_size);
        cmp(loop_cnt, vlen);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        L(vector_loop_start_label);
        {
            uni_vmovups(in, ptr[addr_scratch_reg]);
            deq_w(src_data_t, in, tmp1_vmm, tmp2_vmm, 0, mask, vlen,
                    &addr_wcomp_reg);
            to_src(ptr[addr_states_t_l_reg], in, src_data_t, vlen);

            // if states_t_l_copy is a non-null ptr, we write the output to
            // both tensors
            cmp(addr_states_t_l_copy_reg, 0);
            je(vector_loop_inc_regs);
            // Since to_src is called here with write_only = true, for bf16
            // src_dt it must execute right after the to_src call with
            // write_only = false on the same Vmm
            to_src(ptr[addr_states_t_l_copy_reg], in, src_data_t, vlen, true);
            add(addr_states_t_l_copy_reg, vlen_dst);

            // increment address pointers
            L(vector_loop_inc_regs);
            add(addr_scratch_reg, vlen);
            add(addr_states_t_l_reg, vlen_dst);
            inc_regs(mask, vlen);

            // decrement loop counter
            sub(loop_cnt, vlen);
            cmp(loop_cnt, vlen);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same code as above, except we use vmovss to access the inputs
        L(rem_loop_start_label);
        {
            uni_vmovss(in, ptr[addr_scratch_reg]);
            deq_w(src_data_t, in, tmp1_vmm, tmp2_vmm, 0, mask, scratch_dt_size,
                    &addr_wcomp_reg);
            to_src(ptr[addr_states_t_l_reg], in, src_data_t, scratch_dt_size);

            // if states_t_l_copy is a non-null ptr, we write the output to
            // both tensors
            cmp(addr_states_t_l_copy_reg, 0);
            je(rem_loop_inc_regs);
            // Since to_src is called here with write_only = true, for bf16
            // src_dt it must execute right after the to_src call with
            // write_only = false on the same Vmm
            to_src(ptr[addr_states_t_l_copy_reg], in, src_data_t,
                    scratch_dt_size, true);
            add(addr_states_t_l_copy_reg, hstate_dt_size);

            // increment address pointers
            L(rem_loop_inc_regs);
            add(addr_scratch_reg, scratch_dt_size);
            add(addr_states_t_l_reg, hstate_dt_size);
            inc_regs(mask, qscale_dt_size);

            // decrement loop counter
            sub(loop_cnt, scratch_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();
        init_table(vlen);
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif