1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_GRU_LBR_CELL_POSTGEMM_FWD_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_GRU_LBR_CELL_POSTGEMM_FWD_HPP |
19 | |
20 | #include <memory> |
21 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace x64 { |
27 | |
// Post-GEMM fused kernel for the forward pass of an LBR (linear-before-reset)
// GRU cell, also covering the attention variant (lbr_augru). After the GEMMs
// have produced the gate pre-activations, this kernel adds the biases,
// applies the sigmoid/tanh activations and computes the new hidden state.
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_lbr_cell_postgemm_fwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_lbr_cell_postgemm_fwd)

    // Eltwise injector that emits the activation code inline; avx512_core is
    // named explicitly so that isa values extending it still select the
    // avx512_core injector implementation.
    using injector_t = typename utils::conditional<isa == avx512_core,
            jit_uni_eltwise_injector_f32<avx512_core>,
            jit_uni_eltwise_injector_f32<isa>>::type;

    jit_uni_gru_lbr_cell_postgemm_fwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    // Creates the two activation injectors and generates the kernel.
    // NOTE(review): the status returned by the base-class init() is
    // discarded, and the runtime 'sdt' argument is ignored in favor of the
    // compile-time src_data_t -- confirm both are intentional.
    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        // we use rax for both constant tables and load the corresponding
        // label into it when calling the corresponding injector.
        sigmoid_injector_ = utils::make_unique<injector_t>(
                this, alg_kind::eltwise_logistic, 0.0f, 0.0f, 1.0f, true, rax);
        tanh_injector_ = utils::make_unique<injector_t>(
                this, alg_kind::eltwise_tanh, 0.0f, 0.0f, 1.0f, true, rax);
        return create_kernel();
    }

protected:
    std::unique_ptr<injector_t> sigmoid_injector_;
    std::unique_ptr<injector_t> tanh_injector_;

    // Vector register type and its size in bytes for the chosen isa
    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;

    // Per-element and per-vector byte sizes of the tensors involved; dst and
    // bias may be narrower than f32, so their effective vector width shrinks
    // proportionally.
    const size_t vlen_dst
            = vlen / (sizeof(float) / types::data_type_size(src_data_t));
    const size_t vlen_bias_ = vlen / (sizeof(float) / bias_dt_size_);
    const size_t hstate_dt_size = types::data_type_size(src_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);
    const size_t gate_dt_size = types::data_type_size(src_data_t);
    // Total bytes of one gate row, and the remainder that does not fill a
    // full vector register (processed by the tail pass).
    const size_t loop_len_bytes = rnn_.dhc * scratch_dt_size;
    const size_t loop_tail_bytes = loop_len_bytes % vlen;

    void generate() override {
        using namespace Xbyak;

        // In training mode the activated gates are also written back to the
        // workspace so the backward pass can reuse them.
        const auto is_training
                = (pd_->desc()->prop_kind == prop_kind::forward_training);

        const bool is_augru = pd_->cell_kind() == alg_kind::lbr_augru;

        // Labels declaration
        Label tail_processing_or_exit_label, table_label;

        // Register map
        const Reg64 loop_cnt(r10); // loop counter
        const Reg64 table_reg(rbx); // table is used for data scale and shifts

        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        const Vmm G0(1), G1(2), G2(3), tmp1_vmm(5), tmp2_vmm(6), tmp3_vmm(7);

        // constant table map (single entry: a vector of 1.0f, emitted at
        // table_label below)
        const Address one_addr = ptr[table_reg];

        // We start code generations here
        preamble();

        // extract addresses passed as parameter
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_bias_reg = abi_param3;
        const auto addr_states_t_l_reg = abi_param4;
        const auto addr_attn_reg = r15;
#ifdef _WIN32
        // Win64 ABI: only four register parameters, so the remaining
        // arguments are read from the stack (8-byte slots).
        const auto addr_states_t_l_copy_reg = r11;
        const auto addr_states_tm1_l_reg = r12;
        const auto addr_scratch_cell_reg = rsi;
        const auto addr_ws_h_reg = rdi;
        // Here we cannot use rbp to have initial stack pointer so we
        // use rsp and offset it with the size of pushed registers in
        // preamble
        const auto base_args = get_stack_params_address();
        mov(addr_states_t_l_copy_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
        mov(addr_scratch_cell_reg, ptr[base_args + 16]);
        mov(addr_ws_h_reg, ptr[base_args + 24]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 48]);
#else
        // System V ABI: six register parameters, only the trailing arguments
        // come from the stack.
        const auto addr_states_t_l_copy_reg = abi_param5;
        const auto addr_states_tm1_l_reg = abi_param6;
        const auto addr_scratch_cell_reg = r11;
        const auto addr_ws_h_reg = r12;
        const auto base_args = get_stack_params_address();
        mov(addr_scratch_cell_reg, ptr[base_args]);
        mov(addr_ws_h_reg, ptr[base_args + 8]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 32]);
#endif

        // helper lambdas addressing gate i of the scratch gates, workspace
        // gates, biases and scratch cell, respectively
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };
        const auto B_addr = [&](int i) {
            return ptr[addr_bias_reg + i * rnn_.dhc * bias_dt_size_];
        };
        const auto sc_addr = [&](int i) {
            return ptr[addr_scratch_cell_reg + i * rnn_.dhc * scratch_dt_size];
        };

        // Emits the element-wise loop processing current_vlen bytes per
        // iteration: a full vector in the main pass, the remainder in the
        // tail pass.
        auto compute_loop = [=](size_t current_vlen) {
            Label loop_start_label, loop_inc_regs_or_finish;
            L(loop_start_label);
            {
                // Compute gate 0 (update gate):
                // G0 = sigmoid(scratch_gate0 + bias0 + scratch_cell0)
                load(G0, sg_addr(0), scratch_data_t, current_vlen);
                to_float(tmp1_vmm, B_addr(0), rnn_.bias_dt, current_vlen);
                compute_vaddps(G0, G0, tmp1_vmm, current_vlen);
                load(tmp1_vmm, sc_addr(0), scratch_data_t, current_vlen);
                compute_vaddps(G0, G0, tmp1_vmm, current_vlen);
                sigmoid_injector_->load_table_addr();
                sigmoid_injector_->compute_vector(G0.getIdx());
                // if training we write back the gates
                if (is_training)
                    to_src(wg_addr(0), G0, src_data_t, current_vlen);

                // Compute gate 1 (reset gate):
                // G1 = sigmoid(scratch_gate1 + bias1 + scratch_cell1)
                load(G1, sg_addr(1), scratch_data_t, current_vlen);
                to_float(tmp1_vmm, B_addr(1), rnn_.bias_dt, current_vlen);
                compute_vaddps(G1, G1, tmp1_vmm, current_vlen);
                load(tmp1_vmm, sc_addr(1), scratch_data_t, current_vlen);
                compute_vaddps(G1, G1, tmp1_vmm, current_vlen);
                sigmoid_injector_->load_table_addr();
                sigmoid_injector_->compute_vector(G1.getIdx());
                // if training we write back the gates
                if (is_training)
                    to_src(wg_addr(1), G1, src_data_t, current_vlen);

                // compute last gate (candidate state), linear-before-reset:
                // G2 = tanh(scratch_gate2 + bias2 + G1 * (scratch_cell2 + bias3))
                const auto wh_b_addr = sc_addr(2);
                const auto ws_h_addr = ptr[addr_ws_h_reg];
                load(tmp1_vmm, wh_b_addr, scratch_data_t, current_vlen);
                to_float(tmp2_vmm, B_addr(3), rnn_.bias_dt, current_vlen);
                compute_vaddps(tmp1_vmm, tmp1_vmm, tmp2_vmm, current_vlen);
                // in training, also stash (scratch_cell2 + bias3) for the
                // backward pass
                if (is_training)
                    to_src(ws_h_addr, tmp1_vmm, src_data_t, current_vlen);
                load(G2, sg_addr(2), scratch_data_t, current_vlen);
                to_float(tmp2_vmm, B_addr(2), rnn_.bias_dt, current_vlen);
                compute_vaddps(G2, G2, tmp2_vmm, current_vlen);
                compute_vfmadd231ps(G2, G1, tmp1_vmm, current_vlen);
                tanh_injector_->load_table_addr();
                tanh_injector_->compute_vector(G2.getIdx());
                // if training we write back the gates
                if (is_training)
                    to_src(wg_addr(2), G2, src_data_t, current_vlen);

                if (is_augru) {
                    load(tmp1_vmm, one_addr, scratch_data_t, current_vlen);
                    // for augru there is additional step G01 = (1 - a) * G0
                    // states_t_l = states_tm1_l * G01 + (1 - G01) * G2
                    // the attention scalar 'a' is loaded once and broadcast
                    // to all lanes
                    const Xmm tmp2s_vmm(tmp2_vmm.getIdx());
                    to_float(tmp2s_vmm, ptr[addr_attn_reg], src_data_t,
                            scratch_dt_size);
                    uni_vbroadcastss(tmp2_vmm, tmp2s_vmm);
                    // G01 = (1 - a) * G0
                    compute_vsubps(tmp2_vmm, tmp1_vmm, tmp2_vmm, tmp3_vmm,
                            current_vlen);
                    compute_vmulps(G0, G0, tmp2_vmm, current_vlen);
                    // tmp1 = 1 - G01
                    compute_vsubps(tmp1_vmm, tmp1_vmm, G0, current_vlen);
                    // tmp1 = G2 * tmp1
                    compute_vmulps(
                            tmp1_vmm, G2, tmp1_vmm, tmp3_vmm, current_vlen);
                    // states_t_l = G01 * states_tm1_l + tmp1
                    to_float(tmp2_vmm, ptr[addr_states_tm1_l_reg], src_data_t,
                            current_vlen);
                    compute_vfmadd213ps(G0, tmp2_vmm, tmp1_vmm, current_vlen);
                } else {
                    // states_t_l = states_tm1_l * G0 + (1 - G0) * G2
                    load(tmp1_vmm, one_addr, scratch_data_t, current_vlen);
                    compute_vsubps(tmp1_vmm, tmp1_vmm, G0, current_vlen);
                    to_float(tmp2_vmm, ptr[addr_states_tm1_l_reg], src_data_t,
                            current_vlen);
                    compute_vmulps(G0, G0, tmp2_vmm, current_vlen);
                    compute_vfmadd231ps(G0, tmp1_vmm, G2, current_vlen);
                }

                // write back the result
                to_src(ptr[addr_states_t_l_reg], G0, src_data_t, current_vlen);
                // if states_t_l_copy is a non null ptr, we write the output to
                // both tensors
                // NOTE(review): pointer values <= dhc * hstate_dt_size are
                // treated as "no copy requested"; looks like a null/sentinel
                // convention -- confirm against the caller.
                cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size);
                jle(loop_inc_regs_or_finish);
                // As to_src is called with write_only=true it's important for
                // bf16 src_dt to execute just after to_src method with
                // write_only=false for the same Vmm
                to_src(ptr[addr_states_t_l_copy_reg], G0, src_data_t,
                        current_vlen, true);
                // increment address pointers
                L(loop_inc_regs_or_finish);
                // When the pass width equals the remaining tail exactly, the
                // body runs once and finishes, so pointer increments and the
                // loop back-edge are skipped.
                if (current_vlen != loop_tail_bytes) {
                    const auto current_gate_size
                            = current_vlen == vlen ? vlen_dst : gate_dt_size;
                    const auto current_states_size
                            = current_vlen == vlen ? vlen_dst : hstate_dt_size;
                    add(addr_scratch_gates_reg, current_vlen);
                    add(addr_ws_h_reg, current_gate_size);
                    add(addr_bias_reg,
                            current_vlen == vlen ? vlen_bias_ : bias_dt_size_);
                    add(addr_states_t_l_reg, current_states_size);
                    add(addr_states_t_l_copy_reg, current_states_size);
                    add(addr_states_tm1_l_reg, current_states_size);
                    add(addr_scratch_cell_reg, current_vlen);
                    if (is_training) add(addr_ws_gates_reg, current_gate_size);

                    // increment loop counter
                    sub(loop_cnt, current_vlen);
                    cmp(loop_cnt, current_vlen);
                    jge(loop_start_label);
                }
            }
        };

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        init_regs(vlen, loop_tail_bytes / scratch_dt_size);
        mov(loop_cnt, loop_len_bytes);
        if (loop_tail_bytes > 0) {
            cmp(loop_cnt, vlen);
            jl(tail_processing_or_exit_label, T_NEAR);
        }

        // main pass: full vector registers
        compute_loop(vlen);

        L(tail_processing_or_exit_label);
        if (loop_tail_bytes > 0) {
            Label exit_label;
            cmp(loop_cnt, 0);
            jle(exit_label, T_NEAR);
            // avx512 consumes the whole tail in one masked pass; other isas
            // fall back to one element per iteration
            compute_loop(is_avx512 ? loop_tail_bytes : scratch_dt_size);
            L(exit_label);
        }

        postamble();

        // data sections: the injectors' constant tables followed by this
        // kernel's own table (a full vector of 1.0f, read via one_addr)
        sigmoid_injector_->prepare_table(true);
        tanh_injector_->prepare_table(true);
        init_table(vlen);

        L(table_label);
        {
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(1.0f));
        }
    }
}; // struct jit_uni_gru_lbr_cell_postgemm_fwd
283 | |
284 | } // namespace x64 |
285 | } // namespace cpu |
286 | } // namespace impl |
287 | } // namespace dnnl |
288 | |
289 | #endif |
290 | |