1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_LSTM_CELL_PROJECTION_POSTGEMM_FWD_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_LSTM_CELL_PROJECTION_POSTGEMM_FWD_HPP |
19 | |
20 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
27 | template <cpu_isa_t isa, impl::data_type_t src_data_t, |
28 | impl::data_type_t scratch_data_t> |
29 | struct jit_uni_lstm_cell_projection_postgemm_fwd : public jit_uni_rnn_postgemm { |
30 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_cell_projection_postgemm_fwd) |
31 | |
32 | jit_uni_lstm_cell_projection_postgemm_fwd( |
33 | const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd) |
34 | : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {} |
35 | |
36 | ~jit_uni_lstm_cell_projection_postgemm_fwd() {} |
37 | |
38 | status_t init(data_type_t sdt) override { |
39 | jit_uni_rnn_postgemm::init(src_data_t); |
40 | projection_ = true; |
41 | return create_kernel(); |
42 | } |
43 | |
44 | protected: |
45 | // register size in bytes |
46 | using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm; |
47 | static constexpr size_t vlen = cpu_isa_traits<isa>::vlen; |
48 | static constexpr size_t qscale_dt_size = sizeof(float); |
49 | const size_t vlen_dst |
50 | = vlen / (sizeof(float) / types::data_type_size(src_data_t)); |
51 | const size_t hstate_dt_size = types::data_type_size(src_data_t); |
52 | const size_t scratch_dt_size = types::data_type_size(scratch_data_t); |
53 | |
54 | void generate() { |
55 | using namespace Xbyak; |
56 | |
57 | // Labels declaration |
58 | Label vector_loop_start_label, vector_loop_inc_regs, |
59 | vector_loop_end_label; |
60 | Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label; |
61 | |
62 | // Register map |
63 | const Reg64 loop_cnt(rbx); // loop counter |
64 | // We skip vmm0 as it can be used by the injector for masks on sse4.1 |
65 | const Vmm in(1), tmp1_vmm(5), tmp2_vmm(6); |
66 | |
67 | const int mask = pd_->attr()->rnn_weights_projection_qparams_.mask_; |
68 | float *weights_scales |
69 | = pd_->attr()->rnn_weights_projection_qparams_.scales_; |
70 | |
71 | // We start code generations here |
72 | preamble(); |
73 | |
74 | // extract addresses passed as parameter |
75 | const auto addr_scratch_reg = abi_param2; |
76 | const auto addr_states_t_l_reg = abi_param4; |
77 | #ifdef _WIN32 |
78 | const auto addr_states_t_l_copy_reg = r10; |
79 | const auto addr_wcomp_reg = rdi; |
80 | // Here we cannot use rbp to have initial stack pointer so we |
81 | // use rsp and offset it with the size of pushed registers in |
82 | // preamble |
83 | const auto base_args = get_stack_params_address(); |
84 | mov(addr_states_t_l_copy_reg, ptr[base_args]); |
85 | mov(addr_wcomp_reg, ptr[base_args + 8]); |
86 | #else |
87 | const auto addr_states_t_l_copy_reg = abi_param5; |
88 | const auto addr_wcomp_reg = abi_param6; |
89 | #endif |
90 | |
91 | // initialize registers with addresses and constants |
92 | init_regs(weights_scales, vlen); |
93 | |
94 | mov(loop_cnt, rnn_.dic * scratch_dt_size); |
95 | cmp(loop_cnt, vlen); |
96 | jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR); |
97 | |
98 | L(vector_loop_start_label); |
99 | { |
100 | uni_vmovups(in, ptr[addr_scratch_reg]); |
101 | deq_w(src_data_t, in, tmp1_vmm, tmp2_vmm, 0, mask, vlen, |
102 | &addr_wcomp_reg); |
103 | to_src(ptr[addr_states_t_l_reg], in, src_data_t, vlen); |
104 | |
105 | // if states_t_l_copy is a non null ptr, we write the output to both |
106 | // tensors |
107 | cmp(addr_states_t_l_copy_reg, 0); |
108 | je(vector_loop_inc_regs); |
109 | // As to_src is called with write_only=true it's important for bf16 |
110 | // src_dt to execute just after to_src method with write_only=false |
111 | // for the same Vmm |
112 | to_src(ptr[addr_states_t_l_copy_reg], in, src_data_t, vlen, true); |
113 | add(addr_states_t_l_copy_reg, vlen_dst); |
114 | |
115 | // increment address pointers |
116 | L(vector_loop_inc_regs); |
117 | add(addr_scratch_reg, vlen); |
118 | add(addr_states_t_l_reg, vlen_dst); |
119 | inc_regs(mask, vlen); |
120 | |
121 | // increment loop counter |
122 | sub(loop_cnt, vlen); |
123 | cmp(loop_cnt, vlen); |
124 | jge(vector_loop_start_label); |
125 | } |
126 | L(vector_loop_end_label); |
127 | |
128 | cmp(loop_cnt, 0); |
129 | je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR); |
130 | // Same code as above, we just use vmovss for accessing inputs |
131 | L(rem_loop_start_label); |
132 | { |
133 | |
134 | uni_vmovss(in, ptr[addr_scratch_reg]); |
135 | deq_w(src_data_t, in, tmp1_vmm, tmp2_vmm, 0, mask, scratch_dt_size, |
136 | &addr_wcomp_reg); |
137 | to_src(ptr[addr_states_t_l_reg], in, src_data_t, scratch_dt_size); |
138 | |
139 | // if states_t_l_copy is a non null ptr, we write the output to both |
140 | // tensors |
141 | cmp(addr_states_t_l_copy_reg, 0); |
142 | je(rem_loop_inc_regs); |
143 | // As to_src is called with write_only=true it's important for bf16 |
144 | // src_dt to execute just after to_src method with write_only=false |
145 | // for the same Vmm |
146 | to_src(ptr[addr_states_t_l_copy_reg], in, src_data_t, |
147 | scratch_dt_size, true); |
148 | add(addr_states_t_l_copy_reg, hstate_dt_size); |
149 | |
150 | // increment address pointers |
151 | L(rem_loop_inc_regs); |
152 | add(addr_scratch_reg, scratch_dt_size); |
153 | add(addr_states_t_l_reg, hstate_dt_size); |
154 | inc_regs(mask, qscale_dt_size); |
155 | |
156 | // increment loop counter |
157 | sub(loop_cnt, scratch_dt_size); |
158 | cmp(loop_cnt, 0); |
159 | jg(rem_loop_start_label); |
160 | } |
161 | L(rem_loop_end_label); |
162 | |
163 | postamble(); |
164 | init_table(vlen); |
165 | } |
166 | }; |
167 | |
168 | } // namespace x64 |
169 | } // namespace cpu |
170 | } // namespace impl |
171 | } // namespace dnnl |
172 | |
173 | #endif |
174 | |