1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_1_FWD_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_1_FWD_HPP |
19 | |
20 | #include "cpu/x64/injectors/injector_utils.hpp" |
21 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace x64 { |
27 | |
28 | template <cpu_isa_t isa, impl::data_type_t src_data_t, |
29 | impl::data_type_t scratch_data_t> |
30 | struct jit_uni_gru_cell_postgemm_part1_fwd : public jit_uni_rnn_postgemm { |
31 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part1_fwd) |
32 | |
33 | using injector_t = typename utils::conditional<isa == avx512_core, |
34 | jit_uni_eltwise_injector_f32<avx512_core>, |
35 | jit_uni_eltwise_injector_f32<isa>>::type; |
36 | |
37 | jit_uni_gru_cell_postgemm_part1_fwd( |
38 | const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd) |
39 | : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {} |
40 | |
41 | status_t init(data_type_t sdt) override { |
        CHECK(jit_uni_rnn_postgemm::init(src_data_t));
        // no need to save the state of the registers
        // (unless bf16 support is emulated)
        const bool save_state
                = src_data_t == data_type::bf16 && !mayiuse(avx512_core_bf16);
        // rax holds the injector's constant-table address; sigmoid and tanh
        // share a single table, so one register suffices for both
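        // The constructor arguments below follow what appears to be the
        // jit_uni_eltwise_injector_f32 order (alg, alpha, beta, scale,
        // save_state, table address register); for eltwise_logistic the
        // alpha and beta values are unused and scale = 1.0f leaves the
        // result unscaled.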
48 | CHECK(safe_ptr_assign(sigmoid_injector_, |
49 | new injector_t(this, alg_kind::eltwise_logistic, 0.0f, 0.0f, |
50 | 1.0f, save_state, rax))); |
51 | return create_kernel(); |
52 | } |
53 | |
54 | protected: |
55 | std::unique_ptr<injector_t> sigmoid_injector_; |
56 | |
    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    // vector register size in bytes
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
60 | static constexpr size_t qscale_dt_size = sizeof(float); |
61 | const size_t vlen_dst |
62 | = vlen / (sizeof(float) / types::data_type_size(src_data_t)); |
63 | const size_t vlen_bias_ = vlen / (sizeof(float) / bias_dt_size_); |
64 | const size_t hstate_dt_size = types::data_type_size(src_data_t); |
65 | const size_t gate_dt_size = types::data_type_size(src_data_t); |
66 | const size_t scratch_dt_size = types::data_type_size(scratch_data_t); |
67 | const size_t vlen_qscale = vlen / qscale_dt_size; |
68 | const size_t vlen_elems = vlen / scratch_dt_size; |
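    // Worked example of the sizes above (illustrative, assuming avx512_core
    // with bf16 src and f32 scratch): vlen = 64 bytes, vlen_elems = 64 / 4
    // = 16 elements per vector, vlen_dst = 64 / (4 / 2) = 32 bytes of bf16
    // output per vector, and vlen_qscale = 16 f32 scales. With f32 src all
    // of these collapse to 64 bytes / 16 elements.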
69 | |
70 | const int loop_ur_max = 4; |
71 | // We skip vmm0 as it can be used by the injector for masks on sse4.1 |
72 | int G0_idx(int i) { |
73 | const int idx = 1 + i; |
74 | assert(idx < loop_ur_max + 1); |
75 | return idx; |
76 | } |
77 | int G1_idx(int i) { |
78 | const int idx = loop_ur_max + 1 + i; |
79 | assert(idx < 2 * loop_ur_max + 1); |
80 | return idx; |
81 | } |
82 | const Vmm tmp1_vmm = Vmm(9); |
83 | const Vmm tmp2_vmm = Vmm(10); |
84 | |
85 | void generate() override { |
86 | using namespace Xbyak; |
87 | const auto is_training |
88 | = pd_->desc()->prop_kind == prop_kind::forward_training; |
89 | |
90 | const int mask = pd_->attr()->rnn_weights_qparams_.mask_; |
91 | float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_; |
92 | |
93 | // Register map |
94 | const Reg64 loop_cnt(rbx); // loop counter |
95 | |
        // We start code generation here
97 | preamble(); |
98 | |
99 | // extract addresses passed as parameter |
100 | const auto addr_ws_gates_reg = abi_param1; |
101 | const auto addr_scratch_gates_reg = abi_param2; |
102 | const auto addr_bias_reg = abi_param3; |
103 | const auto addr_states_t_l_reg = abi_param4; |
104 | #ifdef _WIN32 |
105 | const auto addr_states_t_l_copy_reg = r10; |
106 | const auto addr_states_tm1_l_reg = r11; |
        // We cannot use rbp as a pointer to the initial stack frame here, so
        // we use rsp offset by the size of the registers pushed in the
        // preamble.
110 | const auto base_args = get_stack_params_address(); |
111 | mov(addr_states_t_l_copy_reg, ptr[base_args]); |
112 | mov(addr_states_tm1_l_reg, ptr[base_args + 8]); |
113 | #else |
114 | const auto addr_states_t_l_copy_reg = abi_param5; |
115 | const auto addr_states_tm1_l_reg = abi_param6; |
116 | #endif |
        // helper lambdas to compute the addresses of the gates and biases
118 | const auto sg_addr = [&](int i, int j) { |
119 | return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size |
120 | + j * vlen]; |
121 | }; |
122 | const auto wg_addr = [&](int i, int j) { |
123 | return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size |
124 | + j * vlen_dst]; |
125 | }; |
126 | const auto B_addr = [&](int i, int j) { |
127 | return ptr[addr_bias_reg + i * rnn_.dhc * bias_dt_size_ + j * vlen]; |
128 | }; |
129 | |
130 | const size_t loop_len = rnn_.dhc; |
131 | const size_t loop_tail = loop_len % vlen_elems; |
132 | // initialize registers with addresses and constants |
133 | init_regs(weights_scales, vlen, loop_tail); |
134 | |
        // sigmoid and tanh share the same table, so load its address just
        // once into rax
136 | sigmoid_injector_->load_table_addr(); |
137 | |
138 | const size_t nb_loop_len = loop_len / vlen_elems; |
139 | size_t loop_ur_val = 1; |
140 | const bool is_brgemm = rnn_.is_brgemm && !rnn_.unfused_post_gemm; |
141 | if (is_brgemm) { |
142 | #ifdef _WIN32 |
143 | mov(loop_cnt, ptr[base_args + 40]); |
144 | #else |
            // We cannot use rbp as a pointer to the initial stack frame here,
            // so we use rsp offset by the size of the registers pushed in the
            // preamble.
148 | const auto base_args = get_stack_params_address(); |
149 | mov(loop_cnt, ptr[base_args + 24]); |
150 | #endif |
151 | } else { |
152 | for (loop_ur_val = loop_ur_max; loop_ur_val > 1; --loop_ur_val) |
153 | if (nb_loop_len % loop_ur_val == 0) break; |
154 | |
155 | mov(loop_cnt, loop_len); |
156 | } |
157 | const size_t loop_ur = loop_ur_val; |
158 | |
159 | auto compute_loop = [=](size_t current_vlen_elem, |
160 | size_t current_loop_unroll) { |
161 | const auto current_vlen = current_vlen_elem * scratch_dt_size; |
162 | Label loop_start_label; |
163 | L(loop_start_label); |
164 | { |
165 | for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll; |
166 | ++loop_ur_idx) { |
167 | const Vmm G0(G0_idx(loop_ur_idx)); |
168 | const Vmm G1(G1_idx(loop_ur_idx)); |
                    // batch these operations in order to combine the
                    // injector calls:
170 | // Compute gate 0: G0 = sigmoid(G0 + b0) |
171 | // Compute gate 1: G1 = sigmoid(G1 + b1) |
172 | |
173 | // load gates from scratchpad |
174 | load(G0, sg_addr(0, loop_ur_idx), scratch_data_t, |
175 | current_vlen); |
176 | load(G1, sg_addr(1, loop_ur_idx), scratch_data_t, |
177 | current_vlen); |
178 | |
179 | // dequantize gates from s32 to f32 if needed |
180 | deq_w(src_data_t, G0, tmp1_vmm, tmp2_vmm, |
181 | 0 * rnn_.dhc + loop_ur_idx * vlen_qscale, mask, |
182 | current_vlen); |
183 | deq_w(src_data_t, G1, tmp1_vmm, tmp2_vmm, |
184 | 1 * rnn_.dhc + loop_ur_idx * vlen_qscale, mask, |
185 | current_vlen); |
186 | |
187 | // apply bias |
188 | to_float(tmp1_vmm, B_addr(0, loop_ur_idx), rnn_.bias_dt, |
189 | current_vlen); |
190 | compute_vaddps(G0, G0, tmp1_vmm, current_vlen); |
191 | to_float(tmp2_vmm, B_addr(1, loop_ur_idx), rnn_.bias_dt, |
192 | current_vlen); |
193 | compute_vaddps(G1, G1, tmp2_vmm, current_vlen); |
194 | } |
195 | |
                // Compute the sigmoid of the unrolled G0 and G1 regs together
                // (this avoids saving and restoring registers during eltwise)
198 | injector_utils::vmm_index_set_t vmm_idxs; |
199 | for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll; |
200 | ++loop_ur_idx) { |
201 | vmm_idxs.emplace(G0_idx(loop_ur_idx)); |
202 | vmm_idxs.emplace(G1_idx(loop_ur_idx)); |
203 | } |
204 | sigmoid_injector_->compute_vector_range(vmm_idxs); |
205 | |
206 | for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll; |
207 | ++loop_ur_idx) { |
208 | const Vmm G0(G0_idx(loop_ur_idx)); |
209 | const Vmm G1(G1_idx(loop_ur_idx)); |
210 | // store G0 for use in postgemm_part2 |
211 | store(sg_addr(0, loop_ur_idx), G0, scratch_data_t, |
212 | current_vlen); |
213 | |
214 | // if training we write back the gates |
215 | if (is_training) { |
216 | to_src(wg_addr(1, loop_ur_idx), G1, src_data_t, |
217 | current_vlen); |
218 | to_src(wg_addr(0, loop_ur_idx), G0, src_data_t, |
219 | current_vlen); |
220 | } |
221 | |
222 | // states_t_l = states_tm1_l * G1 |
223 | to_float(tmp1_vmm, |
224 | ptr[addr_states_tm1_l_reg + loop_ur_idx * vlen_dst], |
225 | src_data_t, current_vlen); |
226 | compute_vmulps(G1, G1, tmp1_vmm, current_vlen); |
227 | to_src(ptr[addr_states_t_l_reg + loop_ur_idx * vlen_dst], |
228 | G1, src_data_t, current_vlen); |
                    // if states_t_l_copy is a non-null ptr, we write the
                    // output to both tensors. Note the register is advanced
                    // by current_states_size every iteration, so a null input
                    // stays at or below rnn_.dhc * hstate_dt_size for the
                    // whole loop; hence the comparison below rather than a
                    // test against zero.
231 | Label loop_inc_regs; |
232 | cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size); |
233 | jle(loop_inc_regs); |
                    // Since this to_src call uses write_only = true, for bf16
                    // src_dt it must execute right after the
                    // write_only = false to_src call on the same Vmm
237 | to_src(ptr[addr_states_t_l_copy_reg |
238 | + loop_ur_idx * vlen_dst], |
239 | G1, src_data_t, current_vlen, true); |
240 | L(loop_inc_regs); |
241 | } |
242 | |
243 | if (current_vlen_elem != loop_tail) { |
244 | // increment address pointers |
245 | const auto current_gate_size = current_vlen == vlen |
246 | ? vlen_dst * current_loop_unroll |
247 | : gate_dt_size; |
248 | const auto current_states_size = current_vlen == vlen |
249 | ? vlen_dst * current_loop_unroll |
250 | : hstate_dt_size; |
251 | |
252 | add(addr_scratch_gates_reg, |
253 | current_vlen * current_loop_unroll); |
254 | add(addr_bias_reg, |
255 | current_vlen == vlen |
256 | ? vlen_bias_ * current_loop_unroll |
257 | : bias_dt_size_); |
258 | add(addr_states_t_l_reg, current_states_size); |
259 | add(addr_states_t_l_copy_reg, current_states_size); |
260 | add(addr_states_tm1_l_reg, current_states_size); |
261 | if (is_training) add(addr_ws_gates_reg, current_gate_size); |
262 | inc_regs(mask, |
263 | current_vlen == vlen |
264 | ? current_vlen * current_loop_unroll |
265 | : qscale_dt_size); |
266 | |
267 | // increment loop counter |
268 | sub(loop_cnt, current_vlen_elem * current_loop_unroll); |
269 | cmp(loop_cnt, current_vlen_elem * current_loop_unroll); |
270 | jge(loop_start_label); |
271 | } |
272 | } |
273 | }; |
274 | |
275 | // vector processing |
276 | if (loop_len >= vlen_elems) { |
277 | Label tail_processing_or_exit_label; |
278 | if (is_brgemm) { |
279 | cmp(loop_cnt, vlen_elems * loop_ur); |
280 | jl(tail_processing_or_exit_label, T_NEAR); |
281 | } |
282 | compute_loop(vlen_elems, loop_ur); |
283 | L(tail_processing_or_exit_label); |
284 | } |
285 | |
        // tail processing: avx512 handles the whole tail in a single masked
        // pass, while other ISAs process one element per iteration
287 | if (loop_tail > 0) { |
288 | Label exit_label; |
289 | if (is_brgemm) { |
290 | cmp(loop_cnt, 0); |
291 | jle(exit_label, T_NEAR); |
292 | } |
293 | |
294 | compute_loop(is_avx512 ? loop_tail : 1, 1); |
295 | L(exit_label); |
296 | } |
297 | |
298 | postamble(); |
299 | |
        // As noted above, a single table is shared between sigmoid and tanh
301 | sigmoid_injector_->prepare_table(true); |
302 | init_table(vlen); |
303 | } |
304 | }; |
305 | |
306 | } // namespace x64 |
307 | } // namespace cpu |
308 | } // namespace impl |
309 | } // namespace dnnl |
310 | |
311 | #endif |
312 | |