/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_GRU_LBR_CELL_POSTGEMM_FWD_HPP
#define CPU_X64_RNN_JIT_UNI_GRU_LBR_CELL_POSTGEMM_FWD_HPP

#include <memory>
#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_lbr_cell_postgemm_fwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_lbr_cell_postgemm_fwd)

    using injector_t = typename utils::conditional<isa == avx512_core,
            jit_uni_eltwise_injector_f32<avx512_core>,
            jit_uni_eltwise_injector_f32<isa>>::type;
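
    // This kernel implements the postgemm for the linear-before-reset GRU
    // cell. With u = G0 (update gate), r = G1 (reset gate) and c = G2 (the
    // candidate state), generate() computes
    //     u = sigmoid(sg0 + b0 + sc0)
    //     r = sigmoid(sg1 + b1 + sc1)
    //     c = tanh(sg2 + b2 + r * (sc2 + b3))
    //     h_t = u * h_{t-1} + (1 - u) * c
    // where sgX is the x GEMM output for gate X (scratch gates), scX the
    // h_{t-1} GEMM output (scratch cell), and b0..b3 the four lbr_gru bias
    // vectors (b3 being the extra recurrent bias of the candidate).
    // For lbr_augru the update gate is additionally scaled by the attention
    // coefficient a: u' = (1 - a) * u.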

    jit_uni_gru_lbr_cell_postgemm_fwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        // We use rax for the constant tables of both injectors and load the
        // corresponding label into it when calling the corresponding injector.
        sigmoid_injector_ = utils::make_unique<injector_t>(
                this, alg_kind::eltwise_logistic, 0.0f, 0.0f, 1.0f, true, rax);
        tanh_injector_ = utils::make_unique<injector_t>(
                this, alg_kind::eltwise_tanh, 0.0f, 0.0f, 1.0f, true, rax);
        return create_kernel();
    }
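
    // A minimal usage sketch (hypothetical; in the library the kernel is
    // created and run by the RNN primitive implementation):
    //     jit_uni_gru_lbr_cell_postgemm_fwd<avx512_core, data_type::f32,
    //             data_type::f32> k(rnn, pd);
    //     CHECK(k.init(data_type::f32));
    //     // the generated code is then invoked through the base
    //     // jit_uni_rnn_postgemm interface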

protected:
    std::unique_ptr<injector_t> sigmoid_injector_;
    std::unique_ptr<injector_t> tanh_injector_;

    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    // vector register width in bytes
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;

    const size_t vlen_dst
            = vlen / (sizeof(float) / types::data_type_size(src_data_t));
    const size_t vlen_bias_ = vlen / (sizeof(float) / bias_dt_size_);
    const size_t hstate_dt_size = types::data_type_size(src_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);
    const size_t gate_dt_size = types::data_type_size(src_data_t);
    const size_t loop_len_bytes = rnn_.dhc * scratch_dt_size;
    const size_t loop_tail_bytes = loop_len_bytes % vlen;
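
    // generate() walks the dhc dimension in chunks of vlen bytes of scratch
    // data. A remainder is handled by a second compute_loop pass: on avx512
    // the load/store helpers cover all loop_tail_bytes at once through the
    // tail mask set up in init_regs(), while other ISAs step through the
    // tail one element at a time.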

    void generate() override {
        using namespace Xbyak;

        const auto is_training
                = (pd_->desc()->prop_kind == prop_kind::forward_training);

        const bool is_augru = pd_->cell_kind() == alg_kind::lbr_augru;

        // Label declarations
        Label tail_processing_or_exit_label, table_label;

        // Register map
        const Reg64 loop_cnt(r10); // loop counter
        const Reg64 table_reg(rbx); // table is used for data scale and shifts

        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        const Vmm G0(1), G1(2), G2(3), tmp1_vmm(5), tmp2_vmm(6), tmp3_vmm(7);

        // constant table map
        const Address one_addr = ptr[table_reg];

        // We start code generation here
        preamble();

        // extract addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_bias_reg = abi_param3;
        const auto addr_states_t_l_reg = abi_param4;
        const auto addr_attn_reg = r15;
#ifdef _WIN32
        const auto addr_states_t_l_copy_reg = r11;
        const auto addr_states_tm1_l_reg = r12;
        const auto addr_scratch_cell_reg = rsi;
        const auto addr_ws_h_reg = rdi;
        // We cannot use rbp to recover the initial stack pointer here, so we
        // use rsp and offset it by the size of the registers pushed in the
        // preamble.
        const auto base_args = get_stack_params_address();
        mov(addr_states_t_l_copy_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
        mov(addr_scratch_cell_reg, ptr[base_args + 16]);
        mov(addr_ws_h_reg, ptr[base_args + 24]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 48]);
#else
        const auto addr_states_t_l_copy_reg = abi_param5;
        const auto addr_states_tm1_l_reg = abi_param6;
        const auto addr_scratch_cell_reg = r11;
        const auto addr_ws_h_reg = r12;
        const auto base_args = get_stack_params_address();
        mov(addr_scratch_cell_reg, ptr[base_args]);
        mov(addr_ws_h_reg, ptr[base_args + 8]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 32]);
#endif

        // helper lambdas to address the gates and biases
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };
        const auto B_addr = [&](int i) {
            return ptr[addr_bias_reg + i * rnn_.dhc * bias_dt_size_];
        };
        const auto sc_addr = [&](int i) {
            return ptr[addr_scratch_cell_reg + i * rnn_.dhc * scratch_dt_size];
        };
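
        // Each gate occupies a contiguous block of rnn_.dhc elements, so the
        // lambdas above offset the base pointers by whole blocks. Note that
        // lbr_gru carries four bias blocks: b0..b2 for the gates plus b3,
        // the extra recurrent bias added to the candidate's h GEMM term.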

        auto compute_loop = [=](size_t current_vlen) {
            Label loop_start_label, loop_inc_regs_or_finish;
            L(loop_start_label);
            {
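                // Compute gate 0 (update gate): G0 = sigmoid(sg0 + b0 + sc0)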
                load(G0, sg_addr(0), scratch_data_t, current_vlen);
                to_float(tmp1_vmm, B_addr(0), rnn_.bias_dt, current_vlen);
                compute_vaddps(G0, G0, tmp1_vmm, current_vlen);
                load(tmp1_vmm, sc_addr(0), scratch_data_t, current_vlen);
                compute_vaddps(G0, G0, tmp1_vmm, current_vlen);
                sigmoid_injector_->load_table_addr();
                sigmoid_injector_->compute_vector(G0.getIdx());
                // if training we write back the gates
                if (is_training)
                    to_src(wg_addr(0), G0, src_data_t, current_vlen);

                // Compute gate 1 (reset gate): G1 = sigmoid(sg1 + b1 + sc1)
                load(G1, sg_addr(1), scratch_data_t, current_vlen);
                to_float(tmp1_vmm, B_addr(1), rnn_.bias_dt, current_vlen);
                compute_vaddps(G1, G1, tmp1_vmm, current_vlen);
                load(tmp1_vmm, sc_addr(1), scratch_data_t, current_vlen);
                compute_vaddps(G1, G1, tmp1_vmm, current_vlen);
                sigmoid_injector_->load_table_addr();
                sigmoid_injector_->compute_vector(G1.getIdx());
                // if training we write back the gates
                if (is_training)
                    to_src(wg_addr(1), G1, src_data_t, current_vlen);

                // Compute the last gate (candidate):
                // G2 = tanh(sg2 + b2 + G1 * (sc2 + b3))
                const auto wh_b_addr = sc_addr(2);
                const auto ws_h_addr = ptr[addr_ws_h_reg];
                load(tmp1_vmm, wh_b_addr, scratch_data_t, current_vlen);
                to_float(tmp2_vmm, B_addr(3), rnn_.bias_dt, current_vlen);
                compute_vaddps(tmp1_vmm, tmp1_vmm, tmp2_vmm, current_vlen);
                if (is_training)
                    to_src(ws_h_addr, tmp1_vmm, src_data_t, current_vlen);
                load(G2, sg_addr(2), scratch_data_t, current_vlen);
                to_float(tmp2_vmm, B_addr(2), rnn_.bias_dt, current_vlen);
                compute_vaddps(G2, G2, tmp2_vmm, current_vlen);
                compute_vfmadd231ps(G2, G1, tmp1_vmm, current_vlen);
                tanh_injector_->load_table_addr();
                tanh_injector_->compute_vector(G2.getIdx());
                // if training we write back the gates
                if (is_training)
                    to_src(wg_addr(2), G2, src_data_t, current_vlen);

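                // For AUGRU, a single attention coefficient a per iteration
                // is broadcast from addr_attn_reg and scales the update gate
                // before the state blend below.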
                if (is_augru) {
                    load(tmp1_vmm, one_addr, scratch_data_t, current_vlen);
                    // for augru there is an additional step G01 = (1 - a) * G0
                    // states_t_l = states_tm1_l * G01 + (1 - G01) * G2
                    const Xmm tmp2s_vmm(tmp2_vmm.getIdx());
                    to_float(tmp2s_vmm, ptr[addr_attn_reg], src_data_t,
                            scratch_dt_size);
                    uni_vbroadcastss(tmp2_vmm, tmp2s_vmm);
                    // G01 = (1 - a) * G0
                    compute_vsubps(tmp2_vmm, tmp1_vmm, tmp2_vmm, tmp3_vmm,
                            current_vlen);
                    compute_vmulps(G0, G0, tmp2_vmm, current_vlen);
                    // tmp1 = 1 - G01
                    compute_vsubps(tmp1_vmm, tmp1_vmm, G0, current_vlen);
                    // tmp1 = G2 * tmp1
                    compute_vmulps(
                            tmp1_vmm, G2, tmp1_vmm, tmp3_vmm, current_vlen);
                    // states_t_l = G01 * states_tm1_l + tmp1
                    to_float(tmp2_vmm, ptr[addr_states_tm1_l_reg], src_data_t,
                            current_vlen);
                    compute_vfmadd213ps(G0, tmp2_vmm, tmp1_vmm, current_vlen);
                } else {
                    // states_t_l = states_tm1_l * G0 + (1 - G0) * G2
                    load(tmp1_vmm, one_addr, scratch_data_t, current_vlen);
                    compute_vsubps(tmp1_vmm, tmp1_vmm, G0, current_vlen);
                    to_float(tmp2_vmm, ptr[addr_states_tm1_l_reg], src_data_t,
                            current_vlen);
                    compute_vmulps(G0, G0, tmp2_vmm, current_vlen);
                    compute_vfmadd231ps(G0, tmp1_vmm, G2, current_vlen);
                }

                // write back the result
                to_src(ptr[addr_states_t_l_reg], G0, src_data_t, current_vlen);
                // if states_t_l_copy is a non-null ptr, we write the output to
                // both tensors
                cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size);
                jle(loop_inc_regs_or_finish);
                // Since this to_src call uses write_only=true, for bf16
                // src_dt it must execute right after the write_only=false
                // to_src call on the same Vmm
                to_src(ptr[addr_states_t_l_copy_reg], G0, src_data_t,
                        current_vlen, true);
                // increment address pointers
                L(loop_inc_regs_or_finish);
                if (current_vlen != loop_tail_bytes) {
                    const auto current_gate_size
                            = current_vlen == vlen ? vlen_dst : gate_dt_size;
                    const auto current_states_size
                            = current_vlen == vlen ? vlen_dst : hstate_dt_size;
                    add(addr_scratch_gates_reg, current_vlen);
                    add(addr_ws_h_reg, current_gate_size);
                    add(addr_bias_reg,
                            current_vlen == vlen ? vlen_bias_ : bias_dt_size_);
                    add(addr_states_t_l_reg, current_states_size);
                    add(addr_states_t_l_copy_reg, current_states_size);
                    add(addr_states_tm1_l_reg, current_states_size);
                    add(addr_scratch_cell_reg, current_vlen);
                    if (is_training) add(addr_ws_gates_reg, current_gate_size);

                    // decrement loop counter
                    sub(loop_cnt, current_vlen);
                    cmp(loop_cnt, current_vlen);
                    jge(loop_start_label);
                }
            }
        };
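
        // Dispatch: run compute_loop over full vlen chunks first; if dhc
        // leaves a remainder, run it once more for the tail (masked on
        // avx512, element by element otherwise).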

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        init_regs(vlen, loop_tail_bytes / scratch_dt_size);
        mov(loop_cnt, loop_len_bytes);
        if (loop_tail_bytes > 0) {
            cmp(loop_cnt, vlen);
            jl(tail_processing_or_exit_label, T_NEAR);
        }

        compute_loop(vlen);

        L(tail_processing_or_exit_label);
        if (loop_tail_bytes > 0) {
            Label exit_label;
            cmp(loop_cnt, 0);
            jle(exit_label, T_NEAR);
            compute_loop(is_avx512 ? loop_tail_bytes : scratch_dt_size);
            L(exit_label);
        }

        postamble();

        sigmoid_injector_->prepare_table(true);
        tanh_injector_->prepare_table(true);
        init_table(vlen);

        L(table_label);
        {
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(1.0f));
        }
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif