/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_BWD_HPP
#define CPU_X64_RNN_JIT_UNI_LSTM_CELL_POSTGEMM_BWD_HPP

#include <memory>
#include "common/utils.hpp"
#include "cpu/x64/rnn/jit_uni_lstm_cell_postgemm.hpp"
#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_lstm_cell_postgemm_bwd
    : public jit_uni_rnn_postgemm,
      public jit_uni_lstm_cell_postgemm_t<isa> {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_cell_postgemm_bwd)

    jit_uni_lstm_cell_postgemm_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name())
        , jit_uni_lstm_cell_postgemm_t<isa>(this, 11 /*tmp_id_begin*/,
                // We cannot use jit_uni_rnn_postgemm::bf16_emu_ to detect
                // the bf16 emulation case here: it is created in
                // jit_uni_rnn_postgemm::init(), not in the constructor,
                // so bf16_emu_ is still nullptr at this point.
                src_data_t == data_type::bf16 && !mayiuse(avx512_core_bf16)) {
    }
    ~jit_uni_lstm_cell_postgemm_bwd() = default;

    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        // we use rax as the table register for the tanh injector's
        // constant table
        tanh_injector_ = utils::make_unique<injector_t>(
                this, alg_kind::eltwise_tanh, 0.0f, 0.0f, 1.0f, true, rax);
        return create_kernel();
    }

protected:
    using injector_t = typename jit_uni_lstm_cell_postgemm_t<isa>::injector_t;
    using Vmm = typename jit_uni_lstm_cell_postgemm_t<isa>::Vmm;

    std::unique_ptr<injector_t> tanh_injector_;

    // register size in bytes
    static constexpr size_t vlen_ = cpu_isa_traits<isa>::vlen;
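    // bytes to advance the c-state pointers per vector iteration: one
    // vector holds vlen_ / sizeof(float) elements, each cstate_dt_size_
    // bytes wide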
    const size_t vlen_c_states_ = vlen_ / (sizeof(float) / cstate_dt_size_);

    static constexpr size_t diff_cstate_dt_size_ = sizeof(float);
    static constexpr size_t hstate_dt_size_ = sizeof(float);
    static constexpr size_t weights_peephole_dt_size_ = sizeof(float);

    const size_t vlen_scratch_
            = vlen_ / (sizeof(float) / types::data_type_size(scratch_data_t));
    const size_t gate_dt_size_ = types::data_type_size(scratch_data_t);
    const size_t scratch_dt_size_ = types::data_type_size(scratch_data_t);

    void generate() override {
        using namespace Xbyak;

        // Label declarations
        Label vector_loop_start_label, vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_end_label;
        Label table_label;

        // Register map
        const Reg64 table_reg(rbx); // used to load ones before the loop
        const Reg64 loop_cnt(rbx); // loop counter; may alias table_reg
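        // The gate registers below follow oneDNN's LSTM gate order:
        // G0 = input (i), G1 = forget (f), G2 = candidate (c~),
        // G3 = output (o)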
        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        const int dG0_idx = 1, dG1_idx = 2, dG2_idx = 3, dG3_idx = 4,
                  tanhCt_idx = 5, dHt_idx = 6, dCt_idx = 7, G0_idx = 8,
                  G1_idx = 9, one_idx = 10;
        const Vmm one_vmm(one_idx);
        const Xmm one_xmm(one_idx);
        // Address mapping
        const Address one_addr = ptr[table_reg];
        // We start code generation here
        preamble();

        // extract addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_diff_states_t_lp1_reg = abi_param3;
        const auto addr_diff_states_tp1_l_reg = abi_param4;
        const auto addr_weights_peephole_reg = r12;
#ifdef _WIN32
        const auto addr_diff_c_states_t_l_reg = r10;
        const auto addr_diff_c_states_tp1_l_reg = r11;
        const auto addr_c_states_tm1_l_reg = rdi;
        const auto addr_c_states_t_l_reg = rsi;
        const auto base_args = get_stack_params_address();
        mov(addr_diff_c_states_t_l_reg, ptr[base_args]);
        mov(addr_diff_c_states_tp1_l_reg, ptr[base_args + 8]);
        mov(addr_c_states_tm1_l_reg, ptr[base_args + 16]);
        mov(addr_c_states_t_l_reg, ptr[base_args + 24]);
        mov(addr_weights_peephole_reg, ptr[base_args + 32]);
#else
        const auto addr_diff_c_states_t_l_reg = abi_param5;
        const auto addr_diff_c_states_tp1_l_reg = abi_param6;
        const auto addr_c_states_tm1_l_reg = r10;
        const auto addr_c_states_t_l_reg = r11;
        const auto base_args = get_stack_params_address();
        mov(addr_c_states_tm1_l_reg, ptr[base_args]);
        mov(addr_c_states_t_l_reg, ptr[base_args + 8]);
        mov(addr_weights_peephole_reg, ptr[base_args + 16]);
#endif

        // helper lambdas to compute the addresses of the gates and
        // peephole weights
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg
                    + i * rnn_.dhc * scratch_dt_size_];
        };
        const auto weights_peephole_addr = [&](int i) {
            return ptr[addr_weights_peephole_reg
                    + i * rnn_.dhc * weights_peephole_dt_size_];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size_];
        };

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        init_regs(vlen_);
        uni_vmovups(one_vmm, one_addr);
        tanh_injector_->load_table_addr();

        mov(loop_cnt, rnn_.dhc * scratch_dt_size_);
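        // if fewer than one full vector of work remains, skip straight to
        // the scalar tail loop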
        cmp(loop_cnt, vlen_scratch_);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        L(vector_loop_start_label);
        {
            const Vmm dG0(dG0_idx), dG1(dG1_idx), dG2(dG2_idx), dG3(dG3_idx),
                    tanhCt(tanhCt_idx), dHt(dHt_idx), dCt(dCt_idx), G0(G0_idx),
                    G1(G1_idx);

            // TODO: if w_gates are bfloat, we have to convert them to float
            // Data type summary:
            // - c states are all float
            // - h states are all src_data_t
            // - diff_* are all float
            // - scratch is src_data_t
            // - ws_gates is src_data_t

            // compute tanhCt
            to_float(tanhCt, ptr[addr_c_states_t_l_reg], rnn_.src_iter_c_dt,
                    vlen_);
            tanh_injector_->compute_vector(tanhCt.getIdx());

            // compute dHt
            // assumption: the diff_states_t_lp1 address is already offset
            // by rnn.n_states
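            // dHt = dH coming from the layer above (+ the recurrent dH(t+1)
            // term; with projection that term is presumably folded in by the
            // projection backward pass instead, so it is skipped here)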
            uni_vmovups(dHt, ptr[addr_diff_states_t_lp1_reg]);
            if (!rnn_.is_lstm_projection) {
                this->vaddps_rhs_op_mem(
                        dHt, dHt, ptr[addr_diff_states_tp1_l_reg]);
            }

            // compute dCt
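            // scalar sketch of the vector ops below:
            //   dCt = dC(t+1) + dHt * G3 * (1 - tanhCt^2)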
            const auto tmp_dCt1 = this->get_next_tmp_vmm();
            const auto tmp_dCt2 = this->get_next_tmp_vmm();

            uni_vmovups(tmp_dCt1, one_vmm);
            uni_vmovups(tmp_dCt2, tanhCt);
            uni_vfnmadd231ps(tmp_dCt1, tmp_dCt2, tmp_dCt2);
            uni_vmulps(tmp_dCt1, tmp_dCt1, dHt);
            to_float(dG3, wg_addr(3), src_data_t, vlen_);
            uni_vmulps(tmp_dCt1, tmp_dCt1, dG3);
            uni_vmovups(dCt, ptr[addr_diff_c_states_tp1_l_reg]);
            uni_vaddps(dCt, dCt, tmp_dCt1);

            // compute dG3
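            // scalar sketch: dG3 = dHt * tanhCt * G3 * (1 - G3)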
            const auto tmp_dG3 = this->get_next_tmp_vmm();
            uni_vmovups(tmp_dG3, dG3);
            uni_vfnmadd231ps(dG3, tmp_dG3, tmp_dG3);
            uni_vmulps(dG3, dG3, dHt);
            uni_vmulps(dG3, dG3, tanhCt);

            // update dCt if lstm_peephole
            if (rnn_.is_lstm_peephole)
                this->vfmadd231ps_rhs_op_mem(
                        dCt, dG3, weights_peephole_addr(2));

            // compute dG0
            // we will reuse G0 and the gate-2 value (kept in dG2) when
            // computing dG2 below
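            // scalar sketch: dG0 = dCt * G2 * G0 * (1 - G0)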
            to_float(G0, wg_addr(0), src_data_t, vlen_);
            to_float(dG2, wg_addr(2), src_data_t, vlen_);
            uni_vmovups(dG0, G0);
            const auto tmp_g0 = this->vmm_backup(G0);
            uni_vfnmadd231ps(dG0, tmp_g0, tmp_g0);
            uni_vmulps(dG0, dG0, dCt);
            uni_vmulps(dG0, dG0, dG2);

            // compute dG1
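            // scalar sketch: dG1 = dCt * C(t-1) * G1 * (1 - G1)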
            to_float(G1, wg_addr(1), src_data_t, vlen_);
            uni_vmovups(dG1, G1);
            const auto tmp_g1 = this->vmm_backup(G1);
            uni_vfnmadd231ps(dG1, tmp_g1, tmp_g1);
            uni_vmulps(dG1, dG1, dCt);

            const auto tmp_c_states_tm1 = this->get_next_tmp_vmm();
            to_float(tmp_c_states_tm1, ptr[addr_c_states_tm1_l_reg],
                    rnn_.src_iter_c_dt, vlen_);
            this->uni_vmulps(dG1, dG1, tmp_c_states_tm1);

            // compute dG2
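            // scalar sketch: dG2 = dCt * G0 * (1 - G2^2)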
            const auto tmp_dg2 = this->get_next_tmp_vmm();
            uni_vmovups(tmp_dg2, one_vmm);
            const auto tmp_g2 = this->vmm_backup(dG2);

            uni_vfnmadd231ps(tmp_dg2, tmp_g2, tmp_g2);
            uni_vmulps(G0, G0, dCt);
            uni_vmulps(tmp_dg2, tmp_dg2, G0);
            uni_vmovups(dG2, tmp_dg2);

            // compute diff_state_t_l
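            // the value stored below is the gradient carried back to the
            // previous cell state: dCt * G1 (+ dG0 * weights_peephole[0]
            // + dG1 * weights_peephole[1] in the peephole case)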
            uni_vmulps(dCt, dCt, G1);
            if (rnn_.is_lstm_peephole) {
                this->vfmadd231ps_rhs_op_mem(
                        dCt, dG0, weights_peephole_addr(0));
                this->vfmadd231ps_rhs_op_mem(
                        dCt, dG1, weights_peephole_addr(1));
            }
            uni_vmovups(ptr[addr_diff_c_states_t_l_reg], dCt);

            to_src(sg_addr(0), dG0, scratch_data_t, vlen_);
            to_src(sg_addr(1), dG1, scratch_data_t, vlen_);
            to_src(sg_addr(2), dG2, scratch_data_t, vlen_);
            to_src(sg_addr(3), dG3, scratch_data_t, vlen_);

            // increment address pointers
            add(addr_ws_gates_reg, vlen_scratch_);
            add(addr_scratch_gates_reg, vlen_scratch_);
            add(addr_diff_states_t_lp1_reg, vlen_);
            add(addr_diff_states_tp1_l_reg, vlen_);
            add(addr_diff_c_states_t_l_reg, vlen_);
            add(addr_diff_c_states_tp1_l_reg, vlen_);
            add(addr_c_states_tm1_l_reg, vlen_c_states_);
            add(addr_c_states_t_l_reg, vlen_c_states_);
            if (rnn_.is_lstm_peephole) add(addr_weights_peephole_reg, vlen_);
            inc_regs(vlen_);

            // increment loop counter
            sub(loop_cnt, vlen_scratch_);
            cmp(loop_cnt, vlen_scratch_);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same code as above; we just use vmovss to access the inputs
        this->reset_vmm_cnt();
        L(rem_loop_start_label);
        {
            const Xmm dG0(dG0_idx), dG1(dG1_idx), dG2(dG2_idx), dG3(dG3_idx),
                    tanhCt(tanhCt_idx), dHt(dHt_idx), dCt(dCt_idx), G0(G0_idx),
                    G1(G1_idx);

            // compute tanhCt
            to_float(tanhCt, ptr[addr_c_states_t_l_reg], rnn_.src_iter_c_dt,
                    sizeof(float));
            tanh_injector_->compute_vector(tanhCt.getIdx());

            // compute dHt
            // assumption: the diff_states_t_lp1 address is already offset
            // by rnn.n_states
            uni_vmovss(dHt, ptr[addr_diff_states_t_lp1_reg]);
            if (!rnn_.is_lstm_projection)
                this->vaddss_rhs_op_mem(
                        dHt, dHt, ptr[addr_diff_states_tp1_l_reg]);

            // compute dCt
            const auto tmp_dCt1 = this->get_next_tmp_xmm();
            const auto tmp_dCt2 = this->get_next_tmp_xmm();

            uni_vmovss(tmp_dCt1, one_xmm);
            // work on a copy of tanhCt: the non-AVX uni_vfnmadd231ss
            // emulation may clobber its second operand, and tanhCt is
            // still needed below
            uni_vmovss(tmp_dCt2, tanhCt);
            uni_vfnmadd231ss(tmp_dCt1, tmp_dCt2, tmp_dCt2);
            uni_vmulss(tmp_dCt1, tmp_dCt1, dHt);
            to_float(dG3, wg_addr(3), src_data_t, hstate_dt_size_);
            uni_vmulss(tmp_dCt1, tmp_dCt1, dG3);
            uni_vmovss(dCt, ptr[addr_diff_c_states_tp1_l_reg]);
            uni_vaddss(dCt, dCt, tmp_dCt1);

            // compute dG3
            const auto tmp_dG3 = this->get_next_tmp_xmm();
            uni_vmovss(tmp_dG3, dG3);
            uni_vfnmadd231ss(dG3, tmp_dG3, tmp_dG3);
            uni_vmulss(dG3, dG3, dHt);
            uni_vmulss(dG3, dG3, tanhCt);

            // update dCt if lstm_peephole
            if (rnn_.is_lstm_peephole) {
                this->vfmadd231ss_rhs_op_mem(
                        dCt, dG3, weights_peephole_addr(2));
            }

            // compute dG0
            // we will reuse G0 and the gate-2 value (kept in dG2) when
            // computing dG2 below
            to_float(G0, wg_addr(0), src_data_t, hstate_dt_size_);
            to_float(dG2, wg_addr(2), src_data_t, hstate_dt_size_);

            uni_vmovss(dG0, G0);
            const auto tmp_g0 = this->xmm_backup(G0);
            uni_vfnmadd231ss(dG0, tmp_g0, tmp_g0);
            uni_vmulss(dG0, dG0, dCt);
            uni_vmulss(dG0, dG0, dG2);

            // compute dG1
            to_float(G1, wg_addr(1), src_data_t, hstate_dt_size_);
            const auto tmp_g1 = this->xmm_backup(G1);
            uni_vmovss(dG1, G1);
            uni_vfnmadd231ss(dG1, tmp_g1, tmp_g1);
            uni_vmulss(dG1, dG1, dCt);

            const auto tmp_c_states_tm1 = this->get_next_tmp_xmm();
            to_float(tmp_c_states_tm1, ptr[addr_c_states_tm1_l_reg],
                    rnn_.src_iter_c_dt, sizeof(float));
            this->uni_vmulss(dG1, dG1, tmp_c_states_tm1);

            // compute dG2
            const auto tmp_dG2 = this->get_next_tmp_xmm();
            uni_vmovss(tmp_dG2, one_xmm);
            const auto tmp_g2 = this->xmm_backup(dG2);

            uni_vfnmadd231ss(tmp_dG2, tmp_g2, tmp_g2);
            uni_vmulss(G0, G0, dCt);
            uni_vmulss(tmp_dG2, tmp_dG2, G0);
            uni_vmovss(dG2, tmp_dG2);

            // compute diff_state_t_l
            uni_vmulss(dCt, dCt, G1);
            if (rnn_.is_lstm_peephole) {
                this->vfmadd231ss_rhs_op_mem(
                        dCt, dG1, weights_peephole_addr(1));
                this->vfmadd231ss_rhs_op_mem(
                        dCt, dG0, weights_peephole_addr(0));
            }
            uni_vmovss(ptr[addr_diff_c_states_t_l_reg], dCt);

            to_src(sg_addr(0), dG0, scratch_data_t, hstate_dt_size_);
            to_src(sg_addr(1), dG1, scratch_data_t, hstate_dt_size_);
            to_src(sg_addr(2), dG2, scratch_data_t, hstate_dt_size_);
            to_src(sg_addr(3), dG3, scratch_data_t, hstate_dt_size_);

            // increment address pointers
            add(addr_ws_gates_reg, scratch_dt_size_);
            add(addr_scratch_gates_reg, scratch_dt_size_);
            add(addr_diff_states_t_lp1_reg, hstate_dt_size_);
            add(addr_diff_states_tp1_l_reg, hstate_dt_size_);
            add(addr_diff_c_states_t_l_reg, diff_cstate_dt_size_);
            add(addr_diff_c_states_tp1_l_reg, diff_cstate_dt_size_);
            add(addr_c_states_tm1_l_reg, cstate_dt_size_);
            add(addr_c_states_t_l_reg, cstate_dt_size_);
            if (rnn_.is_lstm_peephole)
                add(addr_weights_peephole_reg, weights_peephole_dt_size_);
            inc_regs(hstate_dt_size_);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size_);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();

        tanh_injector_->prepare_table();
        init_table(vlen_);
        L(table_label);
        {
            for (size_t i = 0; i < vlen_ / sizeof(float); ++i)
                dd(float2int(1.0f));
        }
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif