/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_1_FWD_HPP
#define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_1_FWD_HPP

#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

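// Part 1 of the GRU forward postgemm. For each block of the layer it:
//   - adds the bias to gates 0 and 1 and applies the sigmoid:
//     G0 = sigmoid(G0 + b0), G1 = sigmoid(G1 + b1)
//   - stores G0 back to the scratchpad so that part 2 can reuse it
//   - computes states_t_l = states_tm1_l * G1, which feeds the second GEMM of
//     the cell
// (with the usual GRU naming, gate 0 is the update gate and gate 1 the reset
// gate; the code below only relies on the gate indices)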
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_cell_postgemm_part1_fwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part1_fwd)

    using injector_t = typename utils::conditional<isa == avx512_core,
            jit_uni_eltwise_injector_f32<avx512_core>,
            jit_uni_eltwise_injector_f32<isa>>::type;

    jit_uni_gru_cell_postgemm_part1_fwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        // no need to save the state of the registers
        // (unless we are emulating bf16 support)
        const bool save_state
                = src_data_t == data_type::bf16 && !mayiuse(avx512_core_bf16);
        // we use rax for the constant table; all injectors share the same one
        CHECK(safe_ptr_assign(sigmoid_injector_,
                new injector_t(this, alg_kind::eltwise_logistic, 0.0f, 0.0f,
                        1.0f, save_state, rax)));
        return create_kernel();
    }

protected:
    std::unique_ptr<injector_t> sigmoid_injector_;

    using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
    // vector register size in bytes
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    static constexpr size_t qscale_dt_size = sizeof(float);
    const size_t vlen_dst
            = vlen / (sizeof(float) / types::data_type_size(src_data_t));
    const size_t vlen_bias_ = vlen / (sizeof(float) / bias_dt_size_);
    const size_t hstate_dt_size = types::data_type_size(src_data_t);
    const size_t gate_dt_size = types::data_type_size(src_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);
    const size_t vlen_qscale = vlen / qscale_dt_size;
    const size_t vlen_elems = vlen / scratch_dt_size;

    const int loop_ur_max = 4;
    // We skip vmm0 as it can be used by the injector for masks on sse4.1
    int G0_idx(int i) {
        const int idx = 1 + i;
        assert(idx < loop_ur_max + 1);
        return idx;
    }
    int G1_idx(int i) {
        const int idx = loop_ur_max + 1 + i;
        assert(idx < 2 * loop_ur_max + 1);
        return idx;
    }
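    // vmm 9 and 10 are temporaries used for bias conversion and weights
    // dequantization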
    const Vmm tmp1_vmm = Vmm(9);
    const Vmm tmp2_vmm = Vmm(10);

    void generate() override {
        using namespace Xbyak;
        const auto is_training
                = pd_->desc()->prop_kind == prop_kind::forward_training;

        const int mask = pd_->attr()->rnn_weights_qparams_.mask_;
        float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;

        // Register map
        const Reg64 loop_cnt(rbx); // loop counter

        // We start code generation here
        preamble();

        // extract the addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_bias_reg = abi_param3;
        const auto addr_states_t_l_reg = abi_param4;
#ifdef _WIN32
        const auto addr_states_t_l_copy_reg = r10;
        const auto addr_states_tm1_l_reg = r11;
        // We cannot rely on rbp holding the initial stack pointer here, so we
        // use rsp and offset it by the size of the registers pushed in the
        // preamble
        const auto base_args = get_stack_params_address();
        mov(addr_states_t_l_copy_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
#else
        const auto addr_states_t_l_copy_reg = abi_param5;
        const auto addr_states_tm1_l_reg = abi_param6;
#endif
        // helper lambdas to address the gates and biases
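        // (i selects the gate, j the vector-unroll block)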
        const auto sg_addr = [&](int i, int j) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size
                    + j * vlen];
        };
        const auto wg_addr = [&](int i, int j) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size
                    + j * vlen_dst];
        };
        const auto B_addr = [&](int i, int j) {
            return ptr[addr_bias_reg + i * rnn_.dhc * bias_dt_size_
                    + j * vlen_bias_];
        };

        const size_t loop_len = rnn_.dhc;
        const size_t loop_tail = loop_len % vlen_elems;
        // initialize registers with addresses and constants
        init_regs(weights_scales, vlen, loop_tail);

        // only one constant table is used, so load its address once into rax
        sigmoid_injector_->load_table_addr();

        const size_t nb_loop_len = loop_len / vlen_elems;
        size_t loop_ur_val = 1;
        const bool is_brgemm = rnn_.is_brgemm && !rnn_.unfused_post_gemm;
        if (is_brgemm) {
#ifdef _WIN32
            mov(loop_cnt, ptr[base_args + 40]);
#else
            // We cannot rely on rbp holding the initial stack pointer here, so
            // we use rsp and offset it by the size of the registers pushed in
            // the preamble
            const auto base_args = get_stack_params_address();
            mov(loop_cnt, ptr[base_args + 24]);
#endif
        } else {
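            // pick the largest unroll factor (up to loop_ur_max) that evenly
            // divides the number of full-vector iterations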
            for (loop_ur_val = loop_ur_max; loop_ur_val > 1; --loop_ur_val)
                if (nb_loop_len % loop_ur_val == 0) break;

            mov(loop_cnt, loop_len);
        }
        const size_t loop_ur = loop_ur_val;

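        // compute_loop emits one pass over current_loop_unroll vectors of
        // current_vlen_elem elements each: add the bias and apply the sigmoid
        // to gates 0 and 1, store gate 0 back to the scratchpad for part 2,
        // and scale the previous state by gate 1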
        auto compute_loop = [=](size_t current_vlen_elem,
                                        size_t current_loop_unroll) {
            const auto current_vlen = current_vlen_elem * scratch_dt_size;
            Label loop_start_label;
            L(loop_start_label);
            {
                for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll;
                        ++loop_ur_idx) {
                    const Vmm G0(G0_idx(loop_ur_idx));
                    const Vmm G1(G1_idx(loop_ur_idx));
                    // batch these operations in order to combine calls to injector:
                    // Compute gate 0: G0 = sigmoid(G0 + b0)
                    // Compute gate 1: G1 = sigmoid(G1 + b1)

                    // load gates from scratchpad
                    load(G0, sg_addr(0, loop_ur_idx), scratch_data_t,
                            current_vlen);
                    load(G1, sg_addr(1, loop_ur_idx), scratch_data_t,
                            current_vlen);

                    // dequantize gates from s32 to f32 if needed
                    deq_w(src_data_t, G0, tmp1_vmm, tmp2_vmm,
                            0 * rnn_.dhc + loop_ur_idx * vlen_qscale, mask,
                            current_vlen);
                    deq_w(src_data_t, G1, tmp1_vmm, tmp2_vmm,
                            1 * rnn_.dhc + loop_ur_idx * vlen_qscale, mask,
                            current_vlen);

                    // apply bias
                    to_float(tmp1_vmm, B_addr(0, loop_ur_idx), rnn_.bias_dt,
                            current_vlen);
                    compute_vaddps(G0, G0, tmp1_vmm, current_vlen);
                    to_float(tmp2_vmm, B_addr(1, loop_ur_idx), rnn_.bias_dt,
                            current_vlen);
                    compute_vaddps(G1, G1, tmp2_vmm, current_vlen);
                }

                // Compute the sigmoid of the unrolled G0 and G1 regs together
                // (this way no registers need to be preserved across the
                // eltwise calls)
                injector_utils::vmm_index_set_t vmm_idxs;
                for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll;
                        ++loop_ur_idx) {
                    vmm_idxs.emplace(G0_idx(loop_ur_idx));
                    vmm_idxs.emplace(G1_idx(loop_ur_idx));
                }
                sigmoid_injector_->compute_vector_range(vmm_idxs);

                for (size_t loop_ur_idx = 0; loop_ur_idx < current_loop_unroll;
                        ++loop_ur_idx) {
                    const Vmm G0(G0_idx(loop_ur_idx));
                    const Vmm G1(G1_idx(loop_ur_idx));
                    // store G0 for use in postgemm_part2
                    store(sg_addr(0, loop_ur_idx), G0, scratch_data_t,
                            current_vlen);

                    // if training we write back the gates
                    if (is_training) {
                        to_src(wg_addr(1, loop_ur_idx), G1, src_data_t,
                                current_vlen);
                        to_src(wg_addr(0, loop_ur_idx), G0, src_data_t,
                                current_vlen);
                    }

                    // states_t_l = states_tm1_l * G1
                    to_float(tmp1_vmm,
                            ptr[addr_states_tm1_l_reg + loop_ur_idx * vlen_dst],
                            src_data_t, current_vlen);
                    compute_vmulps(G1, G1, tmp1_vmm, current_vlen);
                    to_src(ptr[addr_states_t_l_reg + loop_ur_idx * vlen_dst],
                            G1, src_data_t, current_vlen);
                    // if states_t_l_copy is a non null ptr, we write the output
                    // to both tensors
                    Label loop_inc_regs;
                    cmp(addr_states_t_l_copy_reg, rnn_.dhc * hstate_dt_size);
                    jle(loop_inc_regs);
                    // Since to_src is called here with write_only = true, for
                    // bf16 src_dt it must execute right after the to_src call
                    // with write_only = false on the same Vmm
                    to_src(ptr[addr_states_t_l_copy_reg
                                   + loop_ur_idx * vlen_dst],
                            G1, src_data_t, current_vlen, true);
                    L(loop_inc_regs);
                }

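                // if this pass exactly covers the remaining tail there is
                // nothing left to process, so the pointer increments and the
                // backward jump are skipped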
                if (current_vlen_elem != loop_tail) {
                    // increment address pointers
                    const auto current_gate_size = current_vlen == vlen
                            ? vlen_dst * current_loop_unroll
                            : gate_dt_size;
                    const auto current_states_size = current_vlen == vlen
                            ? vlen_dst * current_loop_unroll
                            : hstate_dt_size;

                    add(addr_scratch_gates_reg,
                            current_vlen * current_loop_unroll);
                    add(addr_bias_reg,
                            current_vlen == vlen
                                    ? vlen_bias_ * current_loop_unroll
                                    : bias_dt_size_);
                    add(addr_states_t_l_reg, current_states_size);
                    add(addr_states_t_l_copy_reg, current_states_size);
                    add(addr_states_tm1_l_reg, current_states_size);
                    if (is_training) add(addr_ws_gates_reg, current_gate_size);
                    inc_regs(mask,
                            current_vlen == vlen
                                    ? current_vlen * current_loop_unroll
                                    : qscale_dt_size);

                    // increment loop counter
                    sub(loop_cnt, current_vlen_elem * current_loop_unroll);
                    cmp(loop_cnt, current_vlen_elem * current_loop_unroll);
                    jge(loop_start_label);
                }
            }
        };

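        // when the postgemm is fused into the brgemm driver, loop_cnt holds a
        // runtime block length, so the vector and tail loops are guarded
        // against it below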
        // vector processing
        if (loop_len >= vlen_elems) {
            Label tail_processing_or_exit_label;
            if (is_brgemm) {
                cmp(loop_cnt, vlen_elems * loop_ur);
                jl(tail_processing_or_exit_label, T_NEAR);
            }
            compute_loop(vlen_elems, loop_ur);
            L(tail_processing_or_exit_label);
        }

        // tail processing
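        // (with avx512 the tail is handled in a single masked pass, otherwise
        // one element at a time)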
        if (loop_tail > 0) {
            Label exit_label;
            if (is_brgemm) {
                cmp(loop_cnt, 0);
                jle(exit_label, T_NEAR);
            }

            compute_loop(is_avx512 ? loop_tail : 1, 1);
            L(exit_label);
        }

        postamble();

        // Again, only one constant table is needed; it is shared by all
        // injector calls
        sigmoid_injector_->prepare_table(true);
        init_table(vlen);
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif