1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_RNN_CELL_POSTGEMM_FWD_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_RNN_CELL_POSTGEMM_FWD_HPP |
19 | |
20 | #include <memory> |
21 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace x64 { |
27 | |
// JIT post-GEMM kernel for the forward pass of a vanilla RNN cell.
// After the GEMM has produced the scratch gates, this kernel:
//   1) dequantizes the gates if needed (s32 -> f32),
//   2) adds the bias,
//   3) applies the cell's elementwise activation (via an eltwise injector),
//   4) stores the result to the output hidden state states_t_l
//      (and to ws_gates when running forward training, and to an optional
//      second destination states_t_l_copy).
//
// Template parameters:
//   isa            - instruction set targeted by the generated code
//   src_data_t     - data type of the states / ws_gates tensors
//   scratch_data_t - data type of the scratch gates produced by the GEMM
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_rnn_cell_postgemm_fwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_rnn_cell_postgemm_fwd)

    // Eltwise injector emitting the activation code inline; avx512_core
    // uses its dedicated specialization, all other ISAs the generic one.
    using injector_t = typename utils::conditional<isa == avx512_core,
            jit_uni_eltwise_injector_f32<avx512_core>,
            jit_uni_eltwise_injector_f32<isa>>::type;

    jit_uni_rnn_cell_postgemm_fwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    // Initializes the base post-gemm state, creates the activation
    // injector, and compiles the kernel.
    // NOTE(review): the runtime argument 'sdt' is ignored here; the
    // compile-time template parameter 'src_data_t' is forwarded to the
    // base-class init instead.
    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        // we use rax for constant tables
        injector_ = utils::make_unique<injector_t>(this, pd_->activation_kind(),
                pd_->desc()->alpha, pd_->desc()->beta, 1.0f, true, rax);
        return create_kernel();
    }

protected:
    std::unique_ptr<injector_t> injector_; // emits the activation code

    // register size in bytes
    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    static constexpr size_t cstate_dt_size = sizeof(float);
    static constexpr size_t qscale_dt_size = sizeof(float);

    // Per-tensor byte strides for one vector iteration: the f32 vector
    // length scaled down by the ratio of f32 size to the tensor's
    // element size.
    const size_t vlen_dst
            = vlen / (sizeof(float) / types::data_type_size(src_data_t));
    const size_t vlen_bias = vlen / (sizeof(float) / bias_dt_size_);
    const size_t hstate_dt_size = types::data_type_size(src_data_t);
    const size_t gate_dt_size = types::data_type_size(src_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);

    // Emits the kernel: a vectorized main loop over the dhc dimension
    // followed by a scalar remainder loop handling the tail elements.
    void generate() override {
        using namespace Xbyak;

        // Weights quantization parameters: broadcast mask and scales.
        const int mask = pd_->attr()->rnn_weights_qparams_.mask_;
        float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;

        // Labels declaration
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;
        Label table_label;

        // Register map
        const Reg64 loop_cnt(r11); // loop counter (in bytes of scratch data)
        const Reg64 n_step_reg(r12);

        // Here we do no unrolling, loop overhead should not be that dramatic
        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        const Vmm G(1), tmp1_vmm(5), tmp2_vmm(6);

        const auto is_training
                = pd_->desc()->prop_kind == prop_kind::forward_training;

        // We start code generations here
        preamble();

        // extract addresses passed as parameter
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_bias_reg = abi_param3;
        const auto addr_states_t_l_reg = abi_param4;
        const auto base_args = get_stack_params_address();
#ifdef _WIN32
        // Windows x64 ABI passes only 4 args in registers; the 5th
        // (states_t_l_copy) and later args come from the stack.
        const auto addr_states_t_l_copy_reg = r10;
        // Here we cannot use rbp to have initial stack pointer so we
        // use rsp and offset it with the size of pushed registers in
        // preamble
        mov(addr_states_t_l_copy_reg, ptr[base_args]);
        // NOTE(review): +40 is the stack offset of the n_step argument
        // under the Windows ABI — confirm against the caller's
        // argument layout if the kernel signature changes.
        if (rnn_.is_brgemm && !rnn_.unfused_post_gemm)
            mov(n_step_reg, ptr[base_args + 40]);
#else
        const auto addr_states_t_l_copy_reg = abi_param5;
        // NOTE(review): +24 is the stack offset of the n_step argument
        // under the System V ABI — confirm against the caller's
        // argument layout if the kernel signature changes.
        if (rnn_.is_brgemm && !rnn_.unfused_post_gemm)
            mov(n_step_reg, ptr[base_args + 24]);
#endif

        // Current-position operands for scratch gates, ws gates and bias
        // (single gate, hence the 0 * dhc offsets); the base registers
        // are advanced at the end of every loop iteration.
        const auto sg_addr
                = ptr[addr_scratch_gates_reg + 0 * rnn_.dhc * scratch_dt_size];
        const auto wg_addr
                = ptr[addr_ws_gates_reg + 0 * rnn_.dhc * gate_dt_size];
        const auto B_addr = ptr[addr_bias_reg + 0 * rnn_.dhc * bias_dt_size_];

        // initialize registers with addresses and constants
        init_regs(weights_scales, vlen);
        injector_->load_table_addr();

        // Number of bytes to process: either the runtime n_step (fused
        // brgemm path) or the full dhc row.
        if (rnn_.is_brgemm && !rnn_.unfused_post_gemm)
            mov(loop_cnt, n_step_reg);
        else
            mov(loop_cnt, rnn_.dhc * scratch_dt_size);

        // Skip the vector loop entirely when fewer than vlen bytes remain.
        cmp(loop_cnt, vlen);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        L_aligned(vector_loop_start_label, 64);
        {
            // load G
            uni_vmovups(G, sg_addr);

            // dequantize the gates from s32 to f32 if needed
            deq_w(src_data_t, G, tmp1_vmm, tmp2_vmm, 0, mask, vlen);

            // add biases
            to_float(tmp1_vmm, B_addr, rnn_.bias_dt, vlen);
            uni_vaddps(G, G, tmp1_vmm);

            // inject eltwise code
            injector_->compute_vector(G.getIdx());

            // if training we write back the gates
            if (is_training) to_src(wg_addr, G, src_data_t, vlen);

            to_src(ptr[addr_states_t_l_reg], G, src_data_t, vlen);
            // if states_t_l_copy is a non null ptr, we write the output to both
            // tensors
            // NOTE(review): "null" appears to be encoded as any value
            // <= dhc * hstate_dt_size rather than literal 0 — verify
            // against the caller's convention.
            cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size);
            jle(vector_loop_inc_regs);
            // As to_src is called with write_only=true it's important for bf16
            // src_dt to execute just after to_src method with write_only=false
            // for the same Vmm
            to_src(ptr[addr_states_t_l_copy_reg], G, src_data_t, vlen, true);

            // increment address pointers
            L(vector_loop_inc_regs);
            add(addr_scratch_gates_reg, vlen);
            add(addr_bias_reg, vlen_bias);
            add(addr_states_t_l_reg, vlen_dst);
            add(addr_states_t_l_copy_reg, vlen_dst);
            if (is_training) add(addr_ws_gates_reg, vlen_dst);
            inc_regs(mask, vlen);

            // increment loop counter
            sub(loop_cnt, vlen);
            cmp(loop_cnt, vlen);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same code as above, we just use movuss for accessing inputs
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            // remaping registers to Xmms
            const Xmm Gs(G.getIdx());
            const Xmm tmp1s_vmm(tmp1_vmm.getIdx());

            // load G
            uni_vmovss(Gs, sg_addr);

            // dequantize the gates from s32 to f32 if needed
            deq_w(src_data_t, G, tmp1_vmm, tmp2_vmm, 0, mask, scratch_dt_size);

            // add biases
            to_float(tmp1_vmm, B_addr, rnn_.bias_dt, sizeof(float));
            uni_vaddps(Gs, Gs, tmp1s_vmm);

            // inject eltwise code
            injector_->compute_vector(Gs.getIdx());

            // if training we write back the gates
            if (is_training) to_src(wg_addr, G, src_data_t, scratch_dt_size);

            to_src(ptr[addr_states_t_l_reg], G, src_data_t, scratch_dt_size);
            // if states_t_l_copy is a non null ptr, we write the output to both
            // tensors
            // NOTE(review): same "small value means null" convention as
            // in the vector loop above — verify against the caller.
            cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size);
            jle(rem_loop_inc_regs);
            // As to_src is called with write_only=true it's important for bf16
            // src_dt to execute just after to_src method with write_only=false
            // for the same Vmm
            to_src(ptr[addr_states_t_l_copy_reg], G, src_data_t,
                    scratch_dt_size, true);

            // increment address pointers
            L(rem_loop_inc_regs);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_bias_reg, bias_dt_size_);
            add(addr_states_t_l_reg, hstate_dt_size);
            add(addr_states_t_l_copy_reg, hstate_dt_size);
            if (is_training) add(addr_ws_gates_reg, gate_dt_size);
            inc_regs(mask, qscale_dt_size);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();

        // inject the constant table for the activation
        injector_->prepare_table();
        init_table(vlen);
    }
};
233 | |
234 | } // namespace x64 |
235 | } // namespace cpu |
236 | } // namespace impl |
237 | } // namespace dnnl |
238 | |
239 | #endif |
240 | |