/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_FWD_HPP
#define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_FWD_HPP

#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

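// Part 2 of the GRU cell post-GEMM kernel: it applies tanh to the candidate
// gate (G2 = tanh(G2 + b2)) and blends it with the previous hidden state
// using the update gate G0 produced by part 1:
//     states_t_l = states_tm1_l * G0 + (1 - G0) * G2
// For AUGRU the update gate is first scaled by (1 - attention).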
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_cell_postgemm_part2_fwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part2_fwd)

    using injector_t = typename utils::conditional<isa == avx512_core,
            jit_uni_eltwise_injector_f32<avx512_core>,
            jit_uni_eltwise_injector_f32<isa>>::type;

    jit_uni_gru_cell_postgemm_part2_fwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        // there is no need to save the state of the registers
        // (unless we emulate bf16 support or run on a pre-avx2 isa)
        const bool save_state = (isa == sse41 || isa == avx)
                || (src_data_t == data_type::bf16
                        && !mayiuse(avx512_core_bf16));
        // rax is used as the injector's constant table address register
        CHECK(safe_ptr_assign(tanh_injector_,
                new injector_t(this, alg_kind::eltwise_tanh, 0.0f, 0.0f, 1.0f,
                        save_state, rax)));
        return create_kernel();
    }

protected:
    std::unique_ptr<injector_t> tanh_injector_;

    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    // vector register size in bytes
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    static constexpr size_t qscale_dt_size = sizeof(float);
    const size_t vlen_dst
            = vlen / (sizeof(float) / types::data_type_size(src_data_t));
    const size_t vlen_bias = vlen / (sizeof(float) / bias_dt_size_);
    const size_t hstate_dt_size = types::data_type_size(src_data_t);
    const size_t gate_dt_size = types::data_type_size(src_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);
    const size_t vlen_qscale = vlen / qscale_dt_size;
    const size_t vlen_elems = vlen / scratch_dt_size;

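    // maximum unroll factor: up to four vector-length blocks are
    // processed per loop iteration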
    const int loop_ur_max = 4;
    // We skip vmm0 as it can be used by the injector for masks on sse4.1
    Vmm G0(int i) {
        const int idx = 1 + i;
        assert(idx < loop_ur_max + 1);
        return Vmm(idx); // max of vmm4
    }
    Vmm G2(int i) {
        const int idx = loop_ur_max + 1 + i;
        assert(idx < 2 * loop_ur_max + 1);
        return Vmm(idx); // max of vmm8
    }
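    // scratch temporaries: tmp1 holds the converted bias or the 1.0f
    // constant, tmp2 the previous state or the attention value, and tmp3
    // is an extra scratch register passed to some compute_* helpers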
    const Vmm tmp1_vmm = Vmm(9);
    const Vmm tmp2_vmm = Vmm(10);
    const Vmm tmp3_vmm = Vmm(11);

    void generate() override {
        using namespace Xbyak;
        const auto is_training
                = pd_->desc()->prop_kind == prop_kind::forward_training;

        const bool is_augru = pd_->cell_kind() == alg_kind::vanilla_augru;

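        // weights quantization parameters, used when dequantizing the
        // s32 scratch gates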
        const int mask = pd_->attr()->rnn_weights_qparams_.mask_;
        float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;

        // Labels declaration
        Label table_label;

        // Register map
        const Reg64 loop_cnt(r10); // loop counter
        const Reg64 table_reg(rbx); // addresses the constant table below

        // constant table map
        const Address one_addr = ptr[table_reg];

        // We start code generation here
        preamble();

        // extract addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_bias_reg = abi_param3;
        const auto addr_states_t_l_reg = abi_param4;
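        // r15 holds the attention pointer; it is loaded below only for AUGRU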
        const auto addr_attn_reg = r15;
#ifdef _WIN32
        const auto addr_states_t_l_copy_reg = r11;
        const auto addr_states_tm1_l_reg = r12;
        // Here we cannot use rbp to reach the initial stack pointer, so we
        // use rsp offset by the size of the registers pushed in the
        // preamble
        const auto base_args = get_stack_params_address();
        mov(addr_states_t_l_copy_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 48]);
#else
        const auto addr_states_t_l_copy_reg = abi_param5;
        const auto addr_states_tm1_l_reg = abi_param6;
        const auto base_args = get_stack_params_address();
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 32]);
#endif

        // helper lambdas to address the gates and biases
        const auto sg_addr = [&](int i, int j) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size
                    + j * vlen];
        };
        const auto wg_addr = [&](int i, int j) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size
                    + j * vlen_dst];
        };
        const auto B_addr = [&](int i, int j) {
            return ptr[addr_bias_reg + i * rnn_.dhc * bias_dt_size_
                    + j * vlen_bias];
        };

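        // the kernel iterates over the dhc elements of the hidden state;
        // loop_tail is the remainder that does not fill a full vector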
        const size_t loop_len = rnn_.dhc;
        const size_t loop_tail = loop_len % vlen_elems;

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        tanh_injector_->load_table_addr();
        init_regs(weights_scales, vlen, loop_tail);

        const size_t nb_loop_len = loop_len / vlen_elems;
        size_t loop_ur_val = 1;
        const bool is_brgemm = rnn_.is_brgemm && !rnn_.unfused_post_gemm;
        if (is_brgemm) {
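            // for the brgemm-fused path, the number of elements to process
            // in this call is passed on the stack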
#ifdef _WIN32
            mov(loop_cnt, ptr[base_args + 40]);
#else
            // Here we cannot use rbp to reach the initial stack pointer, so
            // we use rsp offset by the size of the registers pushed in the
            // preamble
            const auto base_args = get_stack_params_address();
            mov(loop_cnt, ptr[base_args + 24]);
#endif
        } else {
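            // pick the largest unroll factor (up to loop_ur_max) that
            // evenly divides the number of full-vector blocks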
            for (loop_ur_val = loop_ur_max; loop_ur_val > 1; --loop_ur_val)
                if (nb_loop_len % loop_ur_val == 0) break;

            mov(loop_cnt, loop_len);
        }
        const size_t loop_ur = loop_ur_val;

        auto compute_loop = [=](size_t current_vlen_elem,
                                    size_t current_loop_unroll) {
            const auto current_vlen = current_vlen_elem * scratch_dt_size;
            Label loop_start_label;
            L(loop_start_label);
            {
                for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll;
                        ++loop_ur_idx) {
                    // Compute gate 2: G2 = tanh(G2 + b2)
                    load(G2(loop_ur_idx), sg_addr(2, loop_ur_idx),
                            scratch_data_t, current_vlen);
                    // dequantize gate from s32 to f32 if needed
                    deq_w(src_data_t, G2(loop_ur_idx), tmp1_vmm, tmp2_vmm,
                            2 * rnn_.dhc + loop_ur_idx * vlen_qscale, mask,
                            current_vlen);
                    to_float(tmp1_vmm, B_addr(2, loop_ur_idx), rnn_.bias_dt,
                            current_vlen);
                    compute_vaddps(G2(loop_ur_idx), G2(loop_ur_idx), tmp1_vmm,
                            current_vlen);
                }

                // Compute tanh of the unrolled G2 regs together
                // (so no registers need to be saved during eltwise)
                injector_utils::vmm_index_set_t vmm_idxs;
                for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll;
                        ++loop_ur_idx) {
                    vmm_idxs.emplace(G2(loop_ur_idx).getIdx());
                }
                tanh_injector_->compute_vector_range(vmm_idxs);

                for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll;
                        ++loop_ur_idx) {
                    // if training, we write the gates back to the workspace
                    if (is_training)
                        to_src(wg_addr(2, loop_ur_idx), G2(loop_ur_idx),
                                src_data_t, current_vlen);

                    load(G0(loop_ur_idx), sg_addr(0, loop_ur_idx),
                            scratch_data_t, current_vlen);
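                    // tmp1 = vector of 1.0f loaded from the constant table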
                    load(tmp1_vmm, one_addr, scratch_data_t, current_vlen);
                    if (is_augru) {
                        // for AUGRU there is an additional step:
                        // G01 = (1 - a) * G0
                        // states_t_l = states_tm1_l * G01 + (1 - G01) * G2
                        const Xmm tmp2s_vmm(tmp2_vmm.getIdx());
                        to_float(tmp2s_vmm, ptr[addr_attn_reg], src_data_t,
                                scratch_dt_size);
                        uni_vbroadcastss(tmp2_vmm, tmp2s_vmm);
                        // G01 = (1 - a) * G0
                        compute_vsubps(tmp2_vmm, tmp1_vmm, tmp2_vmm, tmp3_vmm,
                                current_vlen);
                        compute_vmulps(G0(loop_ur_idx), G0(loop_ur_idx),
                                tmp2_vmm, current_vlen);
                        to_float(tmp2_vmm,
                                ptr[addr_states_tm1_l_reg
                                        + loop_ur_idx * vlen_dst],
                                src_data_t, current_vlen);
                        // tmp1 = 1 - G01
                        compute_vsubps(tmp1_vmm, tmp1_vmm, G0(loop_ur_idx),
                                current_vlen);
                        // tmp1 = G2 * tmp1
                        compute_vmulps(tmp1_vmm, G2(loop_ur_idx), tmp1_vmm,
                                tmp3_vmm, current_vlen);
                        // states_t_l = G01 * states_tm1_l + tmp1
                        compute_vfmadd213ps(G0(loop_ur_idx), tmp2_vmm, tmp1_vmm,
                                current_vlen);
                    } else {
                        // states_t_l = states_tm1_l * G0 + (1 - G0) * G2
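                        // tmp1 = 1 - G0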
                        compute_vsubps(tmp1_vmm, tmp1_vmm, G0(loop_ur_idx),
                                current_vlen);
                        to_float(tmp2_vmm,
                                ptr[addr_states_tm1_l_reg
                                        + loop_ur_idx * vlen_dst],
                                src_data_t, current_vlen);
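                        // G0 = states_tm1_l * G0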
                        compute_vmulps(G0(loop_ur_idx), G0(loop_ur_idx),
                                tmp2_vmm, current_vlen);
                        compute_vfmadd231ps(G0(loop_ur_idx), tmp1_vmm,
                                G2(loop_ur_idx), current_vlen);
                    }
                    to_src(ptr[addr_states_t_l_reg + loop_ur_idx * vlen_dst],
                            G0(loop_ur_idx), src_data_t, current_vlen);
                    // if states_t_l_copy is a non-null pointer, we write the
                    // output to both tensors
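                    // (the pointer is compared against dhc * hstate_dt_size
                    // rather than 0, presumably because an initially null
                    // copy pointer is still incremented on every iteration)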
                    Label loop_inc_regs;
                    cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size);
                    jle(loop_inc_regs);
                    // Since to_src is called here with write_only=true, for
                    // bf16 src_dt it has to execute right after the to_src
                    // call with write_only=false on the same Vmm
                    to_src(ptr[addr_states_t_l_copy_reg
                                   + loop_ur_idx * vlen_dst],
                            G0(loop_ur_idx), src_data_t, current_vlen, true);
                    L(loop_inc_regs);
                }

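                // when this pass covers the whole remaining tail
                // (current_vlen_elem == loop_tail), a single iteration is
                // enough, so no pointer increments or loop back-edge are
                // emitted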
                if (current_vlen_elem != loop_tail) {
                    // increment address pointers
                    const auto current_gate_size = current_vlen == vlen
                            ? vlen_dst * current_loop_unroll
                            : gate_dt_size;
                    const auto current_states_size = current_vlen == vlen
                            ? vlen_dst * current_loop_unroll
                            : hstate_dt_size;

                    add(addr_scratch_gates_reg,
                            current_vlen * current_loop_unroll);
                    add(addr_bias_reg,
                            current_vlen == vlen
                                    ? vlen_bias * current_loop_unroll
                                    : bias_dt_size_);
                    add(addr_states_t_l_reg, current_states_size);
                    add(addr_states_t_l_copy_reg, current_states_size);
                    add(addr_states_tm1_l_reg, current_states_size);
                    if (is_training) add(addr_ws_gates_reg, current_gate_size);
                    inc_regs(mask,
                            current_vlen == vlen
                                    ? current_vlen * current_loop_unroll
                                    : qscale_dt_size);

                    // increment loop counter
                    sub(loop_cnt, current_vlen_elem * current_loop_unroll);
                    cmp(loop_cnt, current_vlen_elem * current_loop_unroll);
                    jge(loop_start_label);
                }
            }
        };

        // vector processing
        if (loop_len >= vlen_elems) {
            Label tail_processing_or_exit_label;
            if (is_brgemm) {
                cmp(loop_cnt, vlen_elems * loop_ur);
                jl(tail_processing_or_exit_label, T_NEAR);
            }
            compute_loop(vlen_elems, loop_ur);
            L(tail_processing_or_exit_label);
        }

        // tail processing
        if (loop_tail > 0) {
            Label exit_label;
            if (is_brgemm) {
                cmp(loop_cnt, 0);
                jle(exit_label, T_NEAR);
            }

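            // on avx512 the tail is processed in a single pass over the
            // remaining loop_tail elements; otherwise it is processed one
            // element at a time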
            compute_loop(is_avx512 ? loop_tail : 1, 1);
            L(exit_label);
        }

        postamble();

        tanh_injector_->prepare_table(true);
        init_table(vlen);
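        // table of 1.0f values referenced through one_addr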
        L(table_label);
        {
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(1.0f));
        }
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif