1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_FWD_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_FWD_HPP |
19 | |
20 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
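// GRU cell postgemm, part 2: adds the bias to the candidate gate, applies
// tanh, and blends the previous state with the candidate using the update
// gate G0 computed in part 1:
//     G2 = tanh(G2 + b2)
//     states_t_l = states_tm1_l * G0 + (1 - G0) * G2
// For AUGRU the update gate is first scaled by the attention value a:
//     G01 = (1 - a) * G0
//     states_t_l = states_tm1_l * G01 + (1 - G01) * G2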
27 | template <cpu_isa_t isa, impl::data_type_t src_data_t, |
28 | impl::data_type_t scratch_data_t> |
29 | struct jit_uni_gru_cell_postgemm_part2_fwd : public jit_uni_rnn_postgemm { |
30 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part2_fwd) |
31 | |
32 | using injector_t = typename utils::conditional<isa == avx512_core, |
33 | jit_uni_eltwise_injector_f32<avx512_core>, |
34 | jit_uni_eltwise_injector_f32<isa>>::type; |
35 | |
36 | jit_uni_gru_cell_postgemm_part2_fwd( |
37 | const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd) |
38 | : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {} |
39 | |
    status_t init(data_type_t sdt) override {
        CHECK(jit_uni_rnn_postgemm::init(src_data_t));
        // The injector does not need to preserve register state
        // unless bf16 is emulated or a pre-avx2 isa is used.
44 | const bool save_state = (isa == sse41 || isa == avx) |
45 | || (src_data_t == data_type::bf16 |
46 | && !mayiuse(avx512_core_bf16)); |
        // rax is used to address the injector's constant table
48 | CHECK(safe_ptr_assign(tanh_injector_, |
49 | new injector_t(this, alg_kind::eltwise_tanh, 0.0f, 0.0f, 1.0f, |
50 | save_state, rax))); |
51 | return create_kernel(); |
52 | } |
53 | |
54 | protected: |
55 | std::unique_ptr<injector_t> tanh_injector_; |
56 | |
    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    // vector register size in bytes
59 | static constexpr size_t vlen = cpu_isa_traits<isa>::vlen; |
60 | static constexpr size_t qscale_dt_size = sizeof(float); |
61 | const size_t vlen_dst |
62 | = vlen / (sizeof(float) / types::data_type_size(src_data_t)); |
63 | const size_t vlen_bias = vlen / (sizeof(float) / bias_dt_size_); |
64 | const size_t hstate_dt_size = types::data_type_size(src_data_t); |
65 | const size_t gate_dt_size = types::data_type_size(src_data_t); |
66 | const size_t scratch_dt_size = types::data_type_size(scratch_data_t); |
67 | const size_t vlen_qscale = vlen / qscale_dt_size; |
68 | const size_t vlen_elems = vlen / scratch_dt_size; |
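    // A worked example of the derived sizes (illustrative): on avx512_core
    // with bf16 src and f32 scratch, vlen = 64 bytes, so vlen_elems = 16
    // f32 elements per vector, vlen_dst = 32 bytes of bf16 destination, and
    // vlen_qscale = 16 quantization scales per vector.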
69 | |
70 | const int loop_ur_max = 4; |
71 | // We skip vmm0 as it can be used by the injector for masks on sse4.1 |
72 | Vmm G0(int i) { |
73 | const int idx = 1 + i; |
74 | assert(idx < loop_ur_max + 1); |
75 | return Vmm(idx); // max of vmm4 |
76 | } |
77 | Vmm G2(int i) { |
78 | const int idx = loop_ur_max + 1 + i; |
79 | assert(idx < 2 * loop_ur_max + 1); |
80 | return Vmm(idx); // max of vmm8 |
81 | } |
82 | const Vmm tmp1_vmm = Vmm(9); |
83 | const Vmm tmp2_vmm = Vmm(10); |
84 | const Vmm tmp3_vmm = Vmm(11); |
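    // Resulting register map for loop_ur_max = 4:
    //   vmm0         reserved for injector masks on sse4.1
    //   vmm1-vmm4    G0, one register per unrolled iteration
    //   vmm5-vmm8    G2, one register per unrolled iteration
    //   vmm9-vmm11   temporaries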
85 | |
86 | void generate() override { |
87 | using namespace Xbyak; |
88 | const auto is_training |
89 | = pd_->desc()->prop_kind == prop_kind::forward_training; |
90 | |
91 | const bool is_augru = pd_->cell_kind() == alg_kind::vanilla_augru; |
92 | |
93 | const int mask = pd_->attr()->rnn_weights_qparams_.mask_; |
94 | float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_; |
95 | |
        // Label declarations
97 | Label table_label; |
98 | |
99 | // Register map |
100 | const Reg64 loop_cnt(r10); // loop counter |
101 | const Reg64 table_reg(rbx); // table is used for data scale and shifts |
102 | |
103 | // constant table map |
104 | const Address one_addr = ptr[table_reg]; |
105 | |
        // Code generation starts here
107 | preamble(); |
108 | |
        // extract addresses passed as parameters
110 | const auto addr_ws_gates_reg = abi_param1; |
111 | const auto addr_scratch_gates_reg = abi_param2; |
112 | const auto addr_bias_reg = abi_param3; |
113 | const auto addr_states_t_l_reg = abi_param4; |
114 | const auto addr_attn_reg = r15; |
115 | #ifdef _WIN32 |
116 | const auto addr_states_t_l_copy_reg = r11; |
117 | const auto addr_states_tm1_l_reg = r12; |
        // We cannot use rbp to reach the initial stack pointer here, so we
        // use rsp offset by the size of the registers pushed in the
        // preamble
121 | const auto base_args = get_stack_params_address(); |
122 | mov(addr_states_t_l_copy_reg, ptr[base_args]); |
123 | mov(addr_states_tm1_l_reg, ptr[base_args + 8]); |
124 | if (is_augru) mov(addr_attn_reg, ptr[base_args + 48]); |
125 | #else |
126 | const auto addr_states_t_l_copy_reg = abi_param5; |
127 | const auto addr_states_tm1_l_reg = abi_param6; |
128 | const auto base_args = get_stack_params_address(); |
129 | if (is_augru) mov(addr_attn_reg, ptr[base_args + 32]); |
130 | #endif |
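        // Note: the stack offsets above differ between ABIs because Windows
        // x64 passes the first 4 integer arguments in registers while
        // System V passes 6, so the same logical parameter lands 16 bytes
        // further into the stack area on Windows (e.g. offset 48 vs 32 for
        // the attention pointer).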
131 | |
        // helper lambdas to address the gates and biases
133 | const auto sg_addr = [&](int i, int j) { |
134 | return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size |
135 | + j * vlen]; |
136 | }; |
137 | const auto wg_addr = [&](int i, int j) { |
138 | return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size |
139 | + j * vlen_dst]; |
140 | }; |
141 | const auto B_addr = [&](int i, int j) { |
142 | return ptr[addr_bias_reg + i * rnn_.dhc * bias_dt_size_ |
143 | + j * vlen_bias]; |
144 | }; |
145 | |
146 | const size_t loop_len = rnn_.dhc; |
147 | const size_t loop_tail = loop_len % vlen_elems; |
148 | |
149 | // initialize registers with addresses and constants |
150 | mov(table_reg, table_label); |
151 | tanh_injector_->load_table_addr(); |
152 | init_regs(weights_scales, vlen, loop_tail); |
153 | |
154 | const size_t nb_loop_len = loop_len / vlen_elems; |
155 | size_t loop_ur_val = 1; |
156 | const bool is_brgemm = rnn_.is_brgemm && !rnn_.unfused_post_gemm; |
157 | if (is_brgemm) { |
158 | #ifdef _WIN32 |
159 | mov(loop_cnt, ptr[base_args + 40]); |
160 | #else |
165 | mov(loop_cnt, ptr[base_args + 24]); |
166 | #endif |
167 | } else { |
168 | for (loop_ur_val = loop_ur_max; loop_ur_val > 1; --loop_ur_val) |
169 | if (nb_loop_len % loop_ur_val == 0) break; |
170 | |
171 | mov(loop_cnt, loop_len); |
172 | } |
173 | const size_t loop_ur = loop_ur_val; |
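        // Illustrative example for the non-brgemm path: with dhc = 512 on
        // avx512_core and f32 scratch, vlen_elems = 16 and nb_loop_len = 32,
        // so the search above picks loop_ur = 4 and each loop iteration
        // processes 64 elements.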
174 | |
175 | auto compute_loop = [=](size_t current_vlen_elem, |
176 | size_t current_loop_unroll) { |
177 | const auto current_vlen = current_vlen_elem * scratch_dt_size; |
178 | Label loop_start_label; |
179 | L(loop_start_label); |
180 | { |
181 | for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll; |
182 | ++loop_ur_idx) { |
183 | // Compute gate 2: G2 = tanh(G2 + b2) |
184 | load(G2(loop_ur_idx), sg_addr(2, loop_ur_idx), |
185 | scratch_data_t, current_vlen); |
186 | // dequantize gate from s32 to f32 if needed |
187 | deq_w(src_data_t, G2(loop_ur_idx), tmp1_vmm, tmp2_vmm, |
188 | 2 * rnn_.dhc + loop_ur_idx * vlen_qscale, mask, |
189 | current_vlen); |
190 | to_float(tmp1_vmm, B_addr(2, loop_ur_idx), rnn_.bias_dt, |
191 | current_vlen); |
192 | compute_vaddps(G2(loop_ur_idx), G2(loop_ur_idx), tmp1_vmm, |
193 | current_vlen); |
194 | } |
195 | |
                // Compute tanh of the unrolled G2 regs together
                // (this avoids saving registers during the eltwise)
198 | injector_utils::vmm_index_set_t vmm_idxs; |
199 | for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll; |
200 | ++loop_ur_idx) { |
201 | vmm_idxs.emplace(G2(loop_ur_idx).getIdx()); |
202 | } |
203 | tanh_injector_->compute_vector_range(vmm_idxs); |
204 | |
205 | for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll; |
206 | ++loop_ur_idx) { |
207 | // if training we write back the gates |
208 | if (is_training) |
209 | to_src(wg_addr(2, loop_ur_idx), G2(loop_ur_idx), |
210 | src_data_t, current_vlen); |
211 | |
212 | load(G0(loop_ur_idx), sg_addr(0, loop_ur_idx), |
213 | scratch_data_t, current_vlen); |
214 | load(tmp1_vmm, one_addr, scratch_data_t, current_vlen); |
215 | if (is_augru) { |
                        // for AUGRU there is an additional step G01 = (1 - a) * G0
217 | // states_t_l = states_tm1_l * G01 + (1 - G01) * G2 |
218 | const Xmm tmp2s_vmm(tmp2_vmm.getIdx()); |
219 | to_float(tmp2s_vmm, ptr[addr_attn_reg], src_data_t, |
220 | scratch_dt_size); |
221 | uni_vbroadcastss(tmp2_vmm, tmp2s_vmm); |
222 | // G01 = (1 - a) * G0 |
223 | compute_vsubps(tmp2_vmm, tmp1_vmm, tmp2_vmm, tmp3_vmm, |
224 | current_vlen); |
225 | compute_vmulps(G0(loop_ur_idx), G0(loop_ur_idx), |
226 | tmp2_vmm, current_vlen); |
227 | to_float(tmp2_vmm, |
228 | ptr[addr_states_tm1_l_reg |
229 | + loop_ur_idx * vlen_dst], |
230 | src_data_t, current_vlen); |
231 | // tmp1 = 1 - G01 |
232 | compute_vsubps(tmp1_vmm, tmp1_vmm, G0(loop_ur_idx), |
233 | current_vlen); |
234 | // tmp1 = G2 * tmp1 |
235 | compute_vmulps(tmp1_vmm, G2(loop_ur_idx), tmp1_vmm, |
236 | tmp3_vmm, current_vlen); |
237 | // states_t_l = G01 * states_tm1_l + tmp1 |
238 | compute_vfmadd213ps(G0(loop_ur_idx), tmp2_vmm, tmp1_vmm, |
239 | current_vlen); |
240 | } else { |
241 | // states_t_l = states_tm1_l * G0 + (1 - G0) * G2 |
242 | compute_vsubps(tmp1_vmm, tmp1_vmm, G0(loop_ur_idx), |
243 | current_vlen); |
244 | to_float(tmp2_vmm, |
245 | ptr[addr_states_tm1_l_reg |
246 | + loop_ur_idx * vlen_dst], |
247 | src_data_t, current_vlen); |
248 | compute_vmulps(G0(loop_ur_idx), G0(loop_ur_idx), |
249 | tmp2_vmm, current_vlen); |
250 | compute_vfmadd231ps(G0(loop_ur_idx), tmp1_vmm, |
251 | G2(loop_ur_idx), current_vlen); |
252 | } |
253 | to_src(ptr[addr_states_t_l_reg + loop_ur_idx * vlen_dst], |
254 | G0(loop_ur_idx), src_data_t, current_vlen); |
                    // if states_t_l_copy is a non-null ptr, we write the
                    // output to both tensors
257 | Label loop_inc_regs; |
258 | cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size); |
259 | jle(loop_inc_regs); |
                    // Since this to_src call uses write_only = true, for
                    // bf16 src_data_t it must execute right after the
                    // to_src call with write_only = false on the same Vmm
263 | to_src(ptr[addr_states_t_l_copy_reg |
264 | + loop_ur_idx * vlen_dst], |
265 | G0(loop_ur_idx), src_data_t, current_vlen, true); |
266 | L(loop_inc_regs); |
267 | } |
268 | |
269 | if (current_vlen_elem != loop_tail) { |
270 | // increment address pointers |
271 | const auto current_gate_size = current_vlen == vlen |
272 | ? vlen_dst * current_loop_unroll |
273 | : gate_dt_size; |
274 | const auto current_states_size = current_vlen == vlen |
275 | ? vlen_dst * current_loop_unroll |
276 | : hstate_dt_size; |
277 | |
278 | add(addr_scratch_gates_reg, |
279 | current_vlen * current_loop_unroll); |
280 | add(addr_bias_reg, |
281 | current_vlen == vlen |
282 | ? vlen_bias * current_loop_unroll |
283 | : bias_dt_size_); |
284 | add(addr_states_t_l_reg, current_states_size); |
285 | add(addr_states_t_l_copy_reg, current_states_size); |
286 | add(addr_states_tm1_l_reg, current_states_size); |
287 | if (is_training) add(addr_ws_gates_reg, current_gate_size); |
288 | inc_regs(mask, |
289 | current_vlen == vlen |
290 | ? current_vlen * current_loop_unroll |
291 | : qscale_dt_size); |
292 | |
293 | // increment loop counter |
294 | sub(loop_cnt, current_vlen_elem * current_loop_unroll); |
295 | cmp(loop_cnt, current_vlen_elem * current_loop_unroll); |
296 | jge(loop_start_label); |
297 | } |
298 | } |
299 | }; |
300 | |
301 | // vector processing |
302 | if (loop_len >= vlen_elems) { |
303 | Label tail_processing_or_exit_label; |
304 | if (is_brgemm) { |
305 | cmp(loop_cnt, vlen_elems * loop_ur); |
306 | jl(tail_processing_or_exit_label, T_NEAR); |
307 | } |
308 | compute_loop(vlen_elems, loop_ur); |
309 | L(tail_processing_or_exit_label); |
310 | } |
311 | |
312 | // tail processing |
313 | if (loop_tail > 0) { |
314 | Label exit_label; |
315 | if (is_brgemm) { |
316 | cmp(loop_cnt, 0); |
317 | jle(exit_label, T_NEAR); |
318 | } |
319 | |
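            // On avx512 the tail is handled in a single masked pass over
            // loop_tail elements; on older ISAs it falls back to one
            // element per iteration.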
320 | compute_loop(is_avx512 ? loop_tail : 1, 1); |
321 | L(exit_label); |
322 | } |
323 | |
324 | postamble(); |
325 | |
326 | tanh_injector_->prepare_table(true); |
327 | init_table(vlen); |
328 | L(table_label); |
329 | { |
330 | for (size_t i = 0; i < vlen / sizeof(float); i++) |
331 | dd(float2int(1.0f)); |
332 | } |
333 | } |
334 | }; |
335 | |
336 | } // namespace x64 |
337 | } // namespace cpu |
338 | } // namespace impl |
339 | } // namespace dnnl |
340 | |
341 | #endif |
342 | |