/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_RNN_CELL_POSTGEMM_FWD_HPP
#define CPU_X64_RNN_JIT_UNI_RNN_CELL_POSTGEMM_FWD_HPP

#include <memory>
#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

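// Post-GEMM kernel for the forward pass of a vanilla RNN cell. It adds the
// bias to the gates produced by GEMM, applies the elementwise activation,
// and stores the result as the output hidden state (and, when training, as
// the workspace gates). A typical instantiation (the exact template
// arguments depend on the configuration the RNN primitive selects) might be
//     jit_uni_rnn_cell_postgemm_fwd<avx512_core, data_type::f32,
//             data_type::f32>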
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_rnn_cell_postgemm_fwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_rnn_cell_postgemm_fwd)

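    // Selects the eltwise injector used to emit the activation code inline:
    // the avx512_core specialization on avx512_core, and the ISA-matching
    // one otherwise.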
    using injector_t = typename utils::conditional<isa == avx512_core,
            jit_uni_eltwise_injector_f32<avx512_core>,
            jit_uni_eltwise_injector_f32<isa>>::type;

    jit_uni_rnn_cell_postgemm_fwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    status_t init(data_type_t sdt) override {
        CHECK(jit_uni_rnn_postgemm::init(src_data_t));
        // we use rax to hold the address of the injector's constant table
        injector_ = utils::make_unique<injector_t>(this, pd_->activation_kind(),
                pd_->desc()->alpha, pd_->desc()->beta, 1.0f, true, rax);
        return create_kernel();
    }

protected:
    std::unique_ptr<injector_t> injector_;

    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    // vector register size in bytes
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    static constexpr size_t cstate_dt_size = sizeof(float);
    static constexpr size_t qscale_dt_size = sizeof(float);

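    // Bytes covered by one full vector of elements in each buffer: a vector
    // holds vlen / sizeof(float) f32 elements, and storing them as a
    // narrower type scales the byte count down by
    // sizeof(float) / data_type_size (e.g. a factor of 2 for bf16).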
    const size_t vlen_dst
            = vlen / (sizeof(float) / types::data_type_size(src_data_t));
    const size_t vlen_bias = vlen / (sizeof(float) / bias_dt_size_);
    const size_t hstate_dt_size = types::data_type_size(src_data_t);
    const size_t gate_dt_size = types::data_type_size(src_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);

    void generate() override {
        using namespace Xbyak;

        const int mask = pd_->attr()->rnn_weights_qparams_.mask_;
        float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;

        // Label declarations
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;
        Label table_label;

        // Register map
        const Reg64 loop_cnt(r11); // loop counter
        const Reg64 n_step_reg(r12);

        // We do no unrolling here; the loop overhead should not be dramatic
        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        const Vmm G(1), tmp1_vmm(5), tmp2_vmm(6);

        const auto is_training
                = pd_->desc()->prop_kind == prop_kind::forward_training;

        // Code generation starts here
        preamble();

        // extract the addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_bias_reg = abi_param3;
        const auto addr_states_t_l_reg = abi_param4;
        const auto base_args = get_stack_params_address();
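        // The remaining arguments live on the stack; the offsets below (40
        // on Windows, 24 on System V) are where the n_step argument sits
        // relative to the address returned by get_stack_params_address()
        // under each ABI.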
#ifdef _WIN32
        const auto addr_states_t_l_copy_reg = r10;
        // We cannot use rbp to recover the initial stack pointer here, so we
        // use rsp offset by the size of the registers pushed in the preamble
        mov(addr_states_t_l_copy_reg, ptr[base_args]);
        if (rnn_.is_brgemm && !rnn_.unfused_post_gemm)
            mov(n_step_reg, ptr[base_args + 40]);
#else
        const auto addr_states_t_l_copy_reg = abi_param5;
        if (rnn_.is_brgemm && !rnn_.unfused_post_gemm)
            mov(n_step_reg, ptr[base_args + 24]);
#endif

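        // A vanilla RNN cell has a single gate, hence the 0 * rnn_.dhc
        // offsets: each address below points at gate 0 of the current row.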
        const auto sg_addr
                = ptr[addr_scratch_gates_reg + 0 * rnn_.dhc * scratch_dt_size];
        const auto wg_addr
                = ptr[addr_ws_gates_reg + 0 * rnn_.dhc * gate_dt_size];
        const auto B_addr = ptr[addr_bias_reg + 0 * rnn_.dhc * bias_dt_size_];

        // initialize registers with addresses and constants
        init_regs(weights_scales, vlen);
        injector_->load_table_addr();

        if (rnn_.is_brgemm && !rnn_.unfused_post_gemm)
            mov(loop_cnt, n_step_reg);
        else
            mov(loop_cnt, rnn_.dhc * scratch_dt_size);

        cmp(loop_cnt, vlen);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

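        // Main vectorized loop: each iteration loads one full vector of
        // gates, dequantizes it if needed, adds the bias, applies the
        // activation, and stores the result to the output state(s) and,
        // when training, to the workspace gates.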
        L_aligned(vector_loop_start_label, 64);
        {
            // load G
            uni_vmovups(G, sg_addr);

            // dequantize the gates from s32 to f32 if needed
            deq_w(src_data_t, G, tmp1_vmm, tmp2_vmm, 0, mask, vlen);

            // add biases
            to_float(tmp1_vmm, B_addr, rnn_.bias_dt, vlen);
            uni_vaddps(G, G, tmp1_vmm);

            // inject eltwise code
            injector_->compute_vector(G.getIdx());

            // if training, we write the gates back to the workspace
            if (is_training) to_src(wg_addr, G, src_data_t, vlen);

            to_src(ptr[addr_states_t_l_reg], G, src_data_t, vlen);
            // if states_t_l_copy is a non-null ptr, we write the output to
            // both tensors
            cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size);
            jle(vector_loop_inc_regs);
            // Since this to_src call uses write_only=true, for bf16 src_dt
            // it must execute right after the write_only=false to_src call
            // on the same Vmm
            to_src(ptr[addr_states_t_l_copy_reg], G, src_data_t, vlen, true);

            // increment address pointers
            L(vector_loop_inc_regs);
            add(addr_scratch_gates_reg, vlen);
            add(addr_bias_reg, vlen_bias);
            add(addr_states_t_l_reg, vlen_dst);
            add(addr_states_t_l_copy_reg, vlen_dst);
            if (is_training) add(addr_ws_gates_reg, vlen_dst);
            inc_regs(mask, vlen);

            // decrement the loop counter
            sub(loop_cnt, vlen);
            cmp(loop_cnt, vlen);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same code as above, except we use scalar moves (uni_vmovss) to
        // access the inputs
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            // remapping registers to Xmms
            const Xmm Gs(G.getIdx());
            const Xmm tmp1s_vmm(tmp1_vmm.getIdx());

            // load G
            uni_vmovss(Gs, sg_addr);

            // dequantize the gates from s32 to f32 if needed
            deq_w(src_data_t, G, tmp1_vmm, tmp2_vmm, 0, mask, scratch_dt_size);

            // add biases
            to_float(tmp1_vmm, B_addr, rnn_.bias_dt, sizeof(float));
            uni_vaddps(Gs, Gs, tmp1s_vmm);

            // inject eltwise code
            injector_->compute_vector(Gs.getIdx());

            // if training, we write the gates back to the workspace
            if (is_training) to_src(wg_addr, G, src_data_t, scratch_dt_size);

            to_src(ptr[addr_states_t_l_reg], G, src_data_t, scratch_dt_size);
            // if states_t_l_copy is a non-null ptr, we write the output to
            // both tensors
            cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size);
            jle(rem_loop_inc_regs);
            // Since this to_src call uses write_only=true, for bf16 src_dt
            // it must execute right after the write_only=false to_src call
            // on the same Vmm
            to_src(ptr[addr_states_t_l_copy_reg], G, src_data_t,
                    scratch_dt_size, true);

            // increment address pointers
            L(rem_loop_inc_regs);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_bias_reg, bias_dt_size_);
            add(addr_states_t_l_reg, hstate_dt_size);
            add(addr_states_t_l_copy_reg, hstate_dt_size);
            if (is_training) add(addr_ws_gates_reg, gate_dt_size);
            inc_regs(mask, qscale_dt_size);

            // decrement the loop counter
            sub(loop_cnt, scratch_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();

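        // The constant tables are emitted after the postamble, i.e. as data
        // placed past the generated code's final return.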
        // inject the constant table for the activation
        injector_->prepare_table();
        init_table(vlen);
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif