/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_BWD_HPP
#define CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_BWD_HPP

#include <memory>
#include "common/utils.hpp"
#include "cpu/x64/rnn/jit_uni_lstm_cell_postgemm.hpp"
#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_lstm_cell_postgemm_bwd
    : public jit_uni_rnn_postgemm,
      public jit_uni_lstm_cell_postgemm_t<isa> {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_cell_postgemm_bwd)

    jit_uni_lstm_cell_postgemm_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name())
        , jit_uni_lstm_cell_postgemm_t<isa>(this, 11 /*tmp_id_begin*/,
                // We cannot rely on jit_uni_rnn_postgemm::bf16_emu_ to detect
                // the bf16 emulation case here: it is created in
                // jit_uni_rnn_postgemm::init(), not in the constructor, so
                // bf16_emu_ is still nullptr at this point.
                src_data_t == data_type::bf16 && !mayiuse(avx512_core_bf16)) {}
    ~jit_uni_lstm_cell_postgemm_bwd() = default;

    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        // rax is used for the injector's constant table: the table label is
        // loaded into rax right before the injector is called
        tanh_injector_ = utils::make_unique<injector_t>(
                this, alg_kind::eltwise_tanh, 0.0f, 0.0f, 1.0f, true, rax);
        return create_kernel();
    }

protected:
    using injector_t = typename jit_uni_lstm_cell_postgemm_t<isa>::injector_t;
    using Vmm = typename jit_uni_lstm_cell_postgemm_t<isa>::Vmm;

    std::unique_ptr<injector_t> tanh_injector_;

    // vector register size in bytes
    static constexpr size_t vlen_ = cpu_isa_traits<isa>::vlen;
    // bytes of c-state data consumed per vector of floats
    const size_t vlen_c_states_ = vlen_ / (sizeof(float) / cstate_dt_size_);

    static constexpr size_t diff_cstate_dt_size_ = sizeof(float);
    static constexpr size_t hstate_dt_size_ = sizeof(float);
    static constexpr size_t weights_peephole_dt_size_ = sizeof(float);

    // bytes of scratch/gate data consumed per vector of floats
    const size_t vlen_scratch_
            = vlen_ / (sizeof(float) / types::data_type_size(scratch_data_t));
    const size_t gate_dt_size_ = types::data_type_size(scratch_data_t);
    const size_t scratch_dt_size_ = types::data_type_size(scratch_data_t);
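
    // For intuition (illustrative example, not used by the code): one loop
    // iteration processes vlen_ / sizeof(float) elements. With avx512_core,
    // vlen_ = 64, so a vector iteration covers 16 elements; if the c states
    // were stored as bf16 (cstate_dt_size_ = 2), the pointer stride would be
    // vlen_c_states_ = 64 / (4 / 2) = 32 bytes, i.e. the same 16 elements at
    // 2 bytes each. For f32 states the stride is simply vlen_ bytes.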

    void generate() override {
        using namespace Xbyak;

        // Labels declaration
        Label vector_loop_start_label, vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_end_label;
        Label table_label;

        // Register map
        const Reg64 table_reg(rbx); // used to load ones before the loop
        const Reg64 loop_cnt(rbx); // loop counter, can alias table_reg
        // We skip vmm0 as it can be used by the injector for masks on sse4.1;
        // temporaries handed out by the base class start at index 11
        // (tmp_id_begin passed in the constructor)
        const int dG0_idx = 1, dG1_idx = 2, dG2_idx = 3, dG3_idx = 4,
                  tanhCt_idx = 5, dHt_idx = 6, dCt_idx = 7, G0_idx = 8,
                  G1_idx = 9, one_idx = 10;
        const Vmm one_vmm(one_idx);
        const Xmm one_xmm(one_idx);

        // Address mapping
        // points at the vector of 1.0f constants emitted at table_label below
        const Address one_addr = ptr[table_reg];
        // We start code generation here
        preamble();

        // extract addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_diff_states_t_lp1_reg = abi_param3;
        const auto addr_diff_states_tp1_l_reg = abi_param4;
        const auto addr_weights_peephole_reg = r12;
#ifdef _WIN32
        // the Windows x64 ABI passes only four arguments in registers,
        // so the remaining pointers are read from the stack
        const auto addr_diff_c_states_t_l_reg = r10;
        const auto addr_diff_c_states_tp1_l_reg = r11;
        const auto addr_c_states_tm1_l_reg = rdi;
        const auto addr_c_states_t_l_reg = rsi;
        const auto base_args = get_stack_params_address();
        mov(addr_diff_c_states_t_l_reg, ptr[base_args]);
        mov(addr_diff_c_states_tp1_l_reg, ptr[base_args + 8]);
        mov(addr_c_states_tm1_l_reg, ptr[base_args + 16]);
        mov(addr_c_states_t_l_reg, ptr[base_args + 24]);
        mov(addr_weights_peephole_reg, ptr[base_args + 32]);
#else
        const auto addr_diff_c_states_t_l_reg = abi_param5;
        const auto addr_diff_c_states_tp1_l_reg = abi_param6;
        const auto addr_c_states_tm1_l_reg = r10;
        const auto addr_c_states_t_l_reg = r11;
        const auto base_args = get_stack_params_address();
        mov(addr_c_states_tm1_l_reg, ptr[base_args]);
        mov(addr_c_states_t_l_reg, ptr[base_args + 8]);
        mov(addr_weights_peephole_reg, ptr[base_args + 16]);
#endif

        // helper lambdas to compute the addresses of the gates and the
        // peephole weights
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg
                    + i * rnn_.dhc * scratch_dt_size_];
        };
        const auto weights_peephole_addr = [&](int i) {
            return ptr[addr_weights_peephole_reg
                    + i * rnn_.dhc * weights_peephole_dt_size_];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size_];
        };
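
        // Layout note (illustrative): as the lambdas above show, the four
        // gates are stored gate-major, [G0 | G1 | G2 | G3], each block
        // holding rnn_.dhc contiguous values, so gate i lives at offset
        // i * rnn_.dhc * dt_size from the current position; e.g. wg_addr(3)
        // reads from the output-gate block.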

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        init_regs(vlen_);
        uni_vmovups(one_vmm, one_addr);
        tanh_injector_->load_table_addr();

        mov(loop_cnt, rnn_.dhc * scratch_dt_size_);
        cmp(loop_cnt, vlen_scratch_);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        L(vector_loop_start_label);
        {
            const Vmm dG0(dG0_idx), dG1(dG1_idx), dG2(dG2_idx), dG3(dG3_idx),
                    tanhCt(tanhCt_idx), dHt(dHt_idx), dCt(dCt_idx), G0(G0_idx),
                    G1(G1_idx);

            // TODO: if w_gates are bfloat, we have to convert them to float
            // datatypes summary:
            // - c states are all float
            // - h states are all src_data_t
            // - diff_* are all float
            // - scratch is src_data_t
            // - ws_gates is src_data_t
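
            // For reference, these are the standard LSTM backward formulas
            // the loop below implements, in the gate order used by
            // wg_addr()/sg_addr() (G0 = input, G1 = forget, G2 = candidate,
            // G3 = output gate; all gates are already activated in ws_gates):
            //   dHt      = diff_states_t_lp1
            //              (+ diff_states_tp1_l if there is no projection)
            //   dCt      = dCt(t+1) + dHt * G3 * (1 - tanh(Ct)^2)
            //   dG3      = dHt * tanh(Ct) * G3 * (1 - G3)
            //   dG0      = dCt * G2 * G0 * (1 - G0)
            //   dG1      = dCt * Ct(t-1) * G1 * (1 - G1)
            //   dG2      = dCt * G0 * (1 - G2^2)
            //   dCt(out) = dCt * G1 (+ peephole terms when enabled)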

            // compute tanhCt
            to_float(tanhCt, ptr[addr_c_states_t_l_reg], rnn_.src_iter_c_dt,
                    vlen_);
            tanh_injector_->compute_vector(tanhCt.getIdx());

            // compute dHt
            // assumption: the diff_states_t_lp1 address is already offset by
            // rnn.n_states
            uni_vmovups(dHt, ptr[addr_diff_states_t_lp1_reg]);
            if (!rnn_.is_lstm_projection) {
                this->vaddps_rhs_op_mem(
                        dHt, dHt, ptr[addr_diff_states_tp1_l_reg]);
            }

            // compute dCt = dCt(t+1) + (1 - tanhCt^2) * dHt * G3
            const auto tmp_dCt1 = this->get_next_tmp_vmm();
            const auto tmp_dCt2 = this->get_next_tmp_vmm();

            uni_vmovups(tmp_dCt1, one_vmm);
            // work on a copy of tanhCt: the sse4.1 emulation of fnmadd
            // clobbers its second source operand, and tanhCt is needed
            // again for dG3
            uni_vmovups(tmp_dCt2, tanhCt);
            uni_vfnmadd231ps(tmp_dCt1, tmp_dCt2, tmp_dCt2);
            uni_vmulps(tmp_dCt1, tmp_dCt1, dHt);
            to_float(dG3, wg_addr(3), src_data_t, vlen_);
            uni_vmulps(tmp_dCt1, tmp_dCt1, dG3);
            uni_vmovups(dCt, ptr[addr_diff_c_states_tp1_l_reg]);
            uni_vaddps(dCt, dCt, tmp_dCt1);

            // compute dG3 = dHt * tanhCt * G3 * (1 - G3)
            // (dG3 still holds G3, loaded above)
            const auto tmp_dG3 = this->get_next_tmp_vmm();
            uni_vmovups(tmp_dG3, dG3);
            uni_vfnmadd231ps(dG3, tmp_dG3, tmp_dG3);
            uni_vmulps(dG3, dG3, dHt);
            uni_vmulps(dG3, dG3, tanhCt);

            // update dCt if lstm_peephole: dCt += dG3 * w_peephole(2)
            if (rnn_.is_lstm_peephole)
                this->vfmadd231ps_rhs_op_mem(
                        dCt, dG3, weights_peephole_addr(2));

            // compute dG0 = dCt * G2 * G0 * (1 - G0)
            // we will reuse G0 and G2 later for dG2
            to_float(G0, wg_addr(0), src_data_t, vlen_);
            to_float(dG2, wg_addr(2), src_data_t, vlen_);
            uni_vmovups(dG0, G0);
            const auto tmp_g0 = this->vmm_backup(G0);
            uni_vfnmadd231ps(dG0, tmp_g0, tmp_g0);
            uni_vmulps(dG0, dG0, dCt);
            uni_vmulps(dG0, dG0, dG2);

            // compute dG1 = dCt * Ct(t-1) * G1 * (1 - G1)
            to_float(G1, wg_addr(1), src_data_t, vlen_);
            uni_vmovups(dG1, G1);
            const auto tmp_g1 = this->vmm_backup(G1);
            uni_vfnmadd231ps(dG1, tmp_g1, tmp_g1);
            uni_vmulps(dG1, dG1, dCt);

            const auto tmp_c_states_tm1 = this->get_next_tmp_vmm();
            to_float(tmp_c_states_tm1, ptr[addr_c_states_tm1_l_reg],
                    rnn_.src_iter_c_dt, vlen_);
            this->uni_vmulps(dG1, dG1, tmp_c_states_tm1);

            // compute dG2 = dCt * G0 * (1 - G2^2)
            const auto tmp_dg2 = this->get_next_tmp_vmm();
            uni_vmovups(tmp_dg2, one_vmm);
            const auto tmp_g2 = this->vmm_backup(dG2);

            uni_vfnmadd231ps(tmp_dg2, tmp_g2, tmp_g2);
            uni_vmulps(G0, G0, dCt);
            uni_vmulps(tmp_dg2, tmp_dg2, G0);
            uni_vmovups(dG2, tmp_dg2);

            // compute diff_c_states_t_l = dCt * G1 (+ peephole contributions)
            uni_vmulps(dCt, dCt, G1);
            if (rnn_.is_lstm_peephole) {
                this->vfmadd231ps_rhs_op_mem(
                        dCt, dG0, weights_peephole_addr(0));
                this->vfmadd231ps_rhs_op_mem(
                        dCt, dG1, weights_peephole_addr(1));
            }
            uni_vmovups(ptr[addr_diff_c_states_t_l_reg], dCt);

            // downconvert and store the gate gradients to scratch_gates
            to_src(sg_addr(0), dG0, scratch_data_t, vlen_);
            to_src(sg_addr(1), dG1, scratch_data_t, vlen_);
            to_src(sg_addr(2), dG2, scratch_data_t, vlen_);
            to_src(sg_addr(3), dG3, scratch_data_t, vlen_);

            // increment address pointers
            add(addr_ws_gates_reg, vlen_scratch_);
            add(addr_scratch_gates_reg, vlen_scratch_);
            add(addr_diff_states_t_lp1_reg, vlen_);
            add(addr_diff_states_tp1_l_reg, vlen_);
            add(addr_diff_c_states_t_l_reg, vlen_);
            add(addr_diff_c_states_tp1_l_reg, vlen_);
            add(addr_c_states_tm1_l_reg, vlen_c_states_);
            add(addr_c_states_t_l_reg, vlen_c_states_);
            if (rnn_.is_lstm_peephole) add(addr_weights_peephole_reg, vlen_);
            inc_regs(vlen_);

            // increment loop counter
            sub(loop_cnt, vlen_scratch_);
            cmp(loop_cnt, vlen_scratch_);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same computation as above for the tail elements, using scalar (ss)
        // instructions to access the inputs
        this->reset_vmm_cnt();
        L(rem_loop_start_label);
        {
            const Xmm dG0(dG0_idx), dG1(dG1_idx), dG2(dG2_idx), dG3(dG3_idx),
                    tanhCt(tanhCt_idx), dHt(dHt_idx), dCt(dCt_idx), G0(G0_idx),
                    G1(G1_idx);

            // compute tanhCt
            to_float(tanhCt, ptr[addr_c_states_t_l_reg], rnn_.src_iter_c_dt,
                    sizeof(float));
            tanh_injector_->compute_vector(tanhCt.getIdx());

            // compute dHt
            // assumption: the diff_states_t_lp1 address is already offset by
            // rnn.n_states
            uni_vmovss(dHt, ptr[addr_diff_states_t_lp1_reg]);
            if (!rnn_.is_lstm_projection)
                this->vaddss_rhs_op_mem(
                        dHt, dHt, ptr[addr_diff_states_tp1_l_reg]);

            // compute dCt
            const auto tmp_dCt1 = this->get_next_tmp_xmm();
            const auto tmp_dCt2 = this->get_next_tmp_xmm();

            uni_vmovss(tmp_dCt1, one_xmm);
            // work on a copy: the sse4.1 fnmadd emulation would otherwise
            // clobber tanhCt, which is still needed for dG3
            uni_vmovss(tmp_dCt2, tanhCt);
            uni_vfnmadd231ss(tmp_dCt1, tmp_dCt2, tmp_dCt2);
            uni_vmulss(tmp_dCt1, tmp_dCt1, dHt);
            to_float(dG3, wg_addr(3), src_data_t, hstate_dt_size_);
            uni_vmulss(tmp_dCt1, tmp_dCt1, dG3);
            uni_vmovss(dCt, ptr[addr_diff_c_states_tp1_l_reg]);
            uni_vaddss(dCt, dCt, tmp_dCt1);

            // compute dG3
            const auto tmp_dG3 = this->get_next_tmp_xmm();
            uni_vmovss(tmp_dG3, dG3);
            uni_vfnmadd231ss(dG3, tmp_dG3, tmp_dG3);
            uni_vmulss(dG3, dG3, dHt);
            uni_vmulss(dG3, dG3, tanhCt);

            // update dCt if lstm_peephole
            if (rnn_.is_lstm_peephole) {
                this->vfmadd231ss_rhs_op_mem(
                        dCt, dG3, weights_peephole_addr(2));
            }

            // compute dG0
            // we will reuse G0 and G2 later for dG2
            to_float(G0, wg_addr(0), src_data_t, hstate_dt_size_);
            to_float(dG2, wg_addr(2), src_data_t, hstate_dt_size_);

            uni_vmovss(dG0, G0);
            const auto tmp_g0 = this->xmm_backup(G0);
            uni_vfnmadd231ss(dG0, tmp_g0, tmp_g0);
            uni_vmulss(dG0, dG0, dCt);
            uni_vmulss(dG0, dG0, dG2);

            // compute dG1
            to_float(G1, wg_addr(1), src_data_t, hstate_dt_size_);
            const auto tmp_g1 = this->xmm_backup(G1);
            uni_vmovss(dG1, G1);
            uni_vfnmadd231ss(dG1, tmp_g1, tmp_g1);
            uni_vmulss(dG1, dG1, dCt);

            const auto tmp_c_states_tm1 = this->get_next_tmp_xmm();
            to_float(tmp_c_states_tm1, ptr[addr_c_states_tm1_l_reg],
                    rnn_.src_iter_c_dt, sizeof(float));
            this->uni_vmulss(dG1, dG1, tmp_c_states_tm1);

            // compute dG2
            const auto tmp_dG2 = this->get_next_tmp_xmm();
            uni_vmovss(tmp_dG2, one_xmm);
            const auto tmp_g2 = this->xmm_backup(dG2);

            uni_vfnmadd231ss(tmp_dG2, tmp_g2, tmp_g2);
            uni_vmulss(G0, G0, dCt);
            uni_vmulss(tmp_dG2, tmp_dG2, G0);
            uni_vmovss(dG2, tmp_dG2);

            // compute diff_c_states_t_l
            uni_vmulss(dCt, dCt, G1);
            if (rnn_.is_lstm_peephole) {
                this->vfmadd231ss_rhs_op_mem(
                        dCt, dG1, weights_peephole_addr(1));
                this->vfmadd231ss_rhs_op_mem(
                        dCt, dG0, weights_peephole_addr(0));
            }
            uni_vmovss(ptr[addr_diff_c_states_t_l_reg], dCt);

            to_src(sg_addr(0), dG0, scratch_data_t, hstate_dt_size_);
            to_src(sg_addr(1), dG1, scratch_data_t, hstate_dt_size_);
            to_src(sg_addr(2), dG2, scratch_data_t, hstate_dt_size_);
            to_src(sg_addr(3), dG3, scratch_data_t, hstate_dt_size_);

            // increment address pointers
            add(addr_ws_gates_reg, scratch_dt_size_);
            add(addr_scratch_gates_reg, scratch_dt_size_);
            add(addr_diff_states_t_lp1_reg, hstate_dt_size_);
            add(addr_diff_states_tp1_l_reg, hstate_dt_size_);
            add(addr_diff_c_states_t_l_reg, diff_cstate_dt_size_);
            add(addr_diff_c_states_tp1_l_reg, diff_cstate_dt_size_);
            add(addr_c_states_tm1_l_reg, cstate_dt_size_);
            add(addr_c_states_t_l_reg, cstate_dt_size_);
            if (rnn_.is_lstm_peephole)
                add(addr_weights_peephole_reg, weights_peephole_dt_size_);
            inc_regs(hstate_dt_size_);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size_);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();

        tanh_injector_->prepare_table();
        init_table(vlen_);
        L(table_label);
        {
            // table of 1.0f constants, one per float lane
            for (size_t i = 0; i < vlen_ / sizeof(float); ++i)
                dd(float2int(1.0f));
        }
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif