1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_GRU_LBR_CELL_POSTGEMM_BWD_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_GRU_LBR_CELL_POSTGEMM_BWD_HPP |
19 | |
20 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
// JIT kernel emitting the backward (diff) post-GEMM computation for a
// linear-before-reset (LBR) GRU cell, including the attention-gated
// (LBR AUGRU) variant. Given the forward gate activations (ws_gates),
// the incoming state diffs, and the previous-timestep states, it produces
// the gate diffs (scratch_gates), the scratch_cell buffer consumed by the
// following GEMMs, diff_states_t_l, and — for AUGRU — the accumulated
// attention diff.
//
// Gate layout (3 gates per cell, visible from the derivative forms below):
//   G0 — update gate  (sigmoid: derivative G0 - G0^2)
//   G1 — reset gate   (sigmoid: derivative G1 - G1^2)
//   G2 — candidate    (tanh:    derivative 1 - G2^2)
//
// Template parameters:
//   isa            — target ISA (sse41/avx2/avx512_core...), fixes Vmm width
//   src_data_t     — data type of states/ws_gates (converted to f32 on load)
//   scratch_data_t — data type of scratch_gates/scratch_cell (f32 results
//                    are downconverted on store)
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_lbr_cell_postgemm_bwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_lbr_cell_postgemm_bwd)

    jit_uni_gru_lbr_cell_postgemm_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    ~jit_uni_gru_lbr_cell_postgemm_bwd() {}

    // Initializes base-class state and generates the kernel.
    // Note: the runtime argument `sdt` is ignored here; the kernel is
    // specialized on the compile-time template parameter src_data_t.
    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        return create_kernel();
    }

protected:
    // register size in bytes
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    // state values are computed in f32 regardless of src/scratch types
    static constexpr size_t hstate_dt_size = sizeof(float);
    // bytes of scratch-typed data that expand into one full f32 vector
    // (== vlen for f32 scratch, vlen/2 for 16-bit scratch types)
    const size_t vlen_scratch
            = vlen / (sizeof(float) / types::data_type_size(scratch_data_t));
    const size_t gate_dt_size = types::data_type_size(scratch_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);

    void generate() override {
        using namespace Xbyak;

        // AUGRU scales the update-gate diff by (1 - attention) and
        // accumulates an attention diff; plain LBR GRU skips those steps.
        const bool is_augru = pd_->cell_kind() == alg_kind::lbr_augru;

        // Labels declaration
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;
        Label table_label;

        // Register map
        const Reg64 table_reg(rbx); // used to load ones before the loop
        const Reg64 loop_cnt(
                rbx); // loop counter, can be aliased with table_reg

        // Vector register allocation (manual).
        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        const int dG0_idx = 1, dG1_idx = 2, dG2_idx = 3, G0_idx = 4, G1_idx = 5,
                  G2_idx = 6, h_idx = 7, dHt_idx = 8, one_idx = 9,
                  tmp1_idx = 10, tmp2_idx = 11, dattn_acc_idx = 12,
                  attn_idx = 13;
        const Vmm one_vmm(one_idx);
        const Xmm one_xmm(one_idx);

        // constant table map: a single entry, a vector of 1.0f (emitted at
        // table_label below)
        const Address one_addr = ptr[table_reg];

        // We start code generations here
        preamble();

        // extract addresses passed as parameter
        // First four arguments arrive in registers on both ABIs; the rest
        // come from the stack (all of them on Windows beyond the 4th,
        // params 7+ on System V). Offsets below mirror the caller's
        // parameter order — NOTE(review): keep them in sync with the
        // postgemm call-site struct; not verifiable from this file alone.
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_diff_states_t_lp1_reg = abi_param3;
        const auto addr_diff_states_tp1_l_reg = abi_param4;
        const auto addr_attn_reg = r14;
#ifdef _WIN32
        const auto addr_diff_states_t_l_reg = r10;
        const auto addr_states_tm1_l_reg = r11;
        const auto addr_scratch_cell_reg = r12;
        const auto addr_ws_grid_reg = rsi;
        const auto base_args = get_stack_params_address();
        mov(addr_diff_states_t_l_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
        mov(addr_scratch_cell_reg, ptr[base_args + 16]);
        mov(addr_ws_grid_reg, ptr[base_args + 24]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 48]);
#else
        const auto addr_diff_states_t_l_reg = abi_param5;
        const auto addr_states_tm1_l_reg = abi_param6;
        const auto addr_scratch_cell_reg = r10;
        const auto addr_ws_grid_reg = r11;
        const auto base_args = get_stack_params_address();
        mov(addr_scratch_cell_reg, ptr[base_args]);
        mov(addr_ws_grid_reg, ptr[base_args + 8]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 32]);
#endif

        // helper lambda to address the gates and biases
        // (gate i lives dhc elements after gate i-1 in each buffer)
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };
        const auto sc_addr = [&](int i) {
            return ptr[addr_scratch_cell_reg + i * rnn_.dhc * scratch_dt_size];
        };

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        init_regs(vlen);
        uni_vmovups(one_vmm, one_addr);

        if (is_augru) {
            // zero the attention-diff accumulator and load the scalar
            // attention value for this cell
            uni_vpxor(
                    Vmm(dattn_acc_idx), Vmm(dattn_acc_idx), Vmm(dattn_acc_idx));
            const Xmm attn1s(attn_idx);
            to_float(attn1s, ptr[addr_attn_reg], src_data_t, hstate_dt_size);
        }

        // loop counter holds the remaining bytes of scratch data; the main
        // loop consumes vlen_scratch bytes (one full vector) per iteration,
        // the remainder loop below handles the scalar tail
        mov(loop_cnt, rnn_.dhc * scratch_dt_size);
        cmp(loop_cnt, vlen_scratch);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        if (is_augru) {
            // broadcast the scalar attention across the full vector for the
            // vectorized loop
            const Xmm attn1s(attn_idx);
            const Vmm attn(attn_idx);
            uni_vbroadcastss(attn, attn1s);
        }

        L(vector_loop_start_label);
        {
            const Vmm dG0(dG0_idx), dG1(dG1_idx), dG2(dG2_idx), G0(G0_idx),
                    G1(G1_idx), G2(G2_idx), dHt(dHt_idx), tmp1(tmp1_idx),
                    tmp2(tmp2_idx), h(h_idx), diff_attn_acc(dattn_acc_idx),
                    attn(attn_idx);

            // load forward gate activations, upconverting to f32
            to_float(G0, wg_addr(0), src_data_t, vlen);
            to_float(G1, wg_addr(1), src_data_t, vlen);
            to_float(G2, wg_addr(2), src_data_t, vlen);

            // compute dHt = dh from the next layer + dh from the next timestep
            uni_vmovups(dHt, ptr[addr_diff_states_tp1_l_reg]);
            // assumption: the diff_states_t_lp1 address is already offset by rnn.n_states
            uni_vmovups(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddps(dHt, dHt, tmp1);

            // compute dG0 = (h_prev - G2) * G0 * (1 - G0) * dHt
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, vlen);
            uni_vmovups(dG0, G0);
            uni_vmovups(tmp1, G0);
            uni_vfnmadd231ps(dG0, tmp1, tmp1); // (G0 - G0^2)
            uni_vsubps(h, h, G2); // (h - G2)
            uni_vmulps(dG0, dG0, h);
            uni_vmulps(dG0, dG0, dHt); // (h - G2) * (G0 - G0^2) * dHt

            if (is_augru) {
                // Compute diff_attention
                // 1. compute dAttention = -dG0 * G
                uni_vfnmadd231ps(diff_attn_acc, dG0, G0, tmp2);
                // 2. Compute dG0 *= 1 - Attention
                uni_vsubps(tmp1, one_vmm, attn, tmp2);
                uni_vmulps(dG0, dG0, tmp1);
            }
            // compute dG2 = (1 - G0) * (1 - G2^2) * dHt
            uni_vmovups(tmp1, one_vmm);
            uni_vsubps(tmp1, tmp1, G0); // (1 - G0)
            uni_vmovups(dG2, one_vmm);
            uni_vmovups(tmp2, G2);
            uni_vfnmadd231ps(dG2, tmp2, tmp2); // (1 - G2^2)
            uni_vmulps(dG2, dG2, tmp1);
            uni_vmulps(dG2, dG2, dHt); //(1 - G0) * (1 - G2^2) * dHt

            // compute dG1 = G1 * (1 - G1) * dG2 * ws_grid
            // (ws_grid holds the forward workspace term paired with the reset
            // gate — presumably W*h_{t-1}+b; confirm against forward kernel)
            to_float(tmp1, ptr[addr_ws_grid_reg], src_data_t, vlen);
            uni_vmovups(dG1, G1);
            uni_vmovups(tmp2, G1);
            uni_vfnmadd231ps(dG1, tmp2, tmp2); // (G1 - G1^2)
            uni_vmulps(dG1, dG1, dG2);
            uni_vmulps(dG1, dG1, tmp1); // (G1 - G1^2) * dG2 * ws_grid

            // compute diff_state_t_l = dHt * G0 (diff flowing to h_{t-1})
            uni_vmulps(dHt, dHt, G0);
            uni_vmovups(ptr[addr_diff_states_t_l_reg], dHt);

            // compute scratch_cell gate 2 = dG2 * G1
            uni_vmovups(tmp1, dG2);
            uni_vmulps(tmp1, tmp1, G1);

            // downconvert and write data
            to_src(sc_addr(0), dG0, scratch_data_t, vlen);
            // As to_src is called with write_only=true it's important for bf16
            // src_dt to execute just after to_src method with write_only=false
            // for the same Vmm
            to_src(sg_addr(0), dG0, scratch_data_t, vlen, true);

            to_src(sc_addr(1), dG1, scratch_data_t, vlen);
            // As to_src is called with write_only=true it's important for bf16
            // src_dt to execute just after to_src method with write_only=false
            // for the same Vmm
            to_src(sg_addr(1), dG1, scratch_data_t, vlen, true);

            to_src(sc_addr(2), tmp1, scratch_data_t, vlen);
            to_src(sg_addr(2), dG2, scratch_data_t, vlen);

            // increment address pointers
            // (f32-typed buffers advance by vlen bytes, src/scratch-typed
            // buffers by vlen_scratch bytes)
            add(addr_ws_gates_reg, vlen_scratch);
            add(addr_scratch_gates_reg, vlen_scratch);
            add(addr_diff_states_t_lp1_reg, vlen);
            add(addr_diff_states_tp1_l_reg, vlen);
            add(addr_diff_states_t_l_reg, vlen);
            add(addr_states_tm1_l_reg, vlen_scratch);
            add(addr_scratch_cell_reg, vlen_scratch);
            add(addr_ws_grid_reg, vlen_scratch);
            inc_regs(vlen);

            // increment loop counter
            sub(loop_cnt, vlen_scratch);
            cmp(loop_cnt, vlen_scratch);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        // Reduce diff attention into XMM size. Otherwise accumulation
        // using XMM will zero high part of YMM/ZMM.
        // Note: emitted unconditionally (also for plain GRU); it only
        // touches the accumulator and tmp1, both dead in that case.
        if (vlen >= cpu_isa_traits<avx512_core>::vlen) {
            Zmm diff_attn_acc(dattn_acc_idx);
            Ymm diff_attn_acc_high(tmp1_idx);
            Ymm diff_attn_acc_low(dattn_acc_idx);
            vextractf32x8(diff_attn_acc_high, diff_attn_acc, 1);
            vaddps(diff_attn_acc_low, diff_attn_acc_low, diff_attn_acc_high);
        }
        if (vlen >= cpu_isa_traits<avx2>::vlen) {
            Ymm diff_attn_acc(dattn_acc_idx);
            Xmm diff_attn_acc_high(tmp1_idx);
            Xmm diff_attn_acc_low(dattn_acc_idx);
            vextractf128(diff_attn_acc_high, diff_attn_acc, 1);
            vaddps(diff_attn_acc_low, diff_attn_acc_low, diff_attn_acc_high);
        }

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        // Same code as above, we just use movuss for accessing inputs
        // this is the scalar tail loop, one element per iteration
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            const Xmm dG0(dG0_idx), dG1(dG1_idx), dG2(dG2_idx), G0(G0_idx),
                    G1(G1_idx), G2(G2_idx), dHt(dHt_idx), tmp1(tmp1_idx),
                    tmp2(tmp2_idx), h(h_idx), diff_attn_acc(dattn_acc_idx),
                    attn(attn_idx);

            to_float(G0, wg_addr(0), src_data_t, hstate_dt_size);
            to_float(G1, wg_addr(1), src_data_t, hstate_dt_size);
            to_float(G2, wg_addr(2), src_data_t, hstate_dt_size);

            // compute dHt
            uni_vmovss(dHt, ptr[addr_diff_states_tp1_l_reg]);
            // assumption: the diff_states_t_lp1 address is already offset by rnn.n_states
            uni_vmovss(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddss(dHt, dHt, tmp1);

            // compute dG0 = (h_prev - G2) * G0 * (1 - G0) * dHt
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, hstate_dt_size);
            uni_vmovss(dG0, G0);
            uni_vmovss(tmp1, dG0);
            uni_vfnmadd231ps(dG0, tmp1, tmp1); // (G0 - G0^2)
            uni_vsubss(h, h, G2); // (h - G2)
            uni_vmulss(dG0, dG0, h);
            uni_vmulss(dG0, dG0, dHt); // (h - G2) * (G0 - G0^2) * dHt

            if (is_augru) {
                // compute diff_attention
                // 1. compute tmp2 = dG0 * G
                uni_vmovss(tmp2, dG0);
                uni_vmulss(tmp2, tmp2, G0);
                // 2. Store dAttention
                uni_vsubss(diff_attn_acc, diff_attn_acc, tmp2);
                // 3. Compute dG0 *= 1 - attention
                uni_vmovss(tmp1, one_xmm);
                uni_vsubss(tmp1, tmp1, attn);
                uni_vmulss(dG0, dG0, tmp1);
            }

            // compute dG2 = (1 - G0) * (1 - G2^2) * dHt
            uni_vmovss(tmp1, one_xmm);
            uni_vsubss(tmp1, tmp1, G0); // (1 - G0)

            uni_vmovss(dG2, one_xmm);
            uni_vmovss(tmp2, G2);
            uni_vfnmadd231ps(dG2, tmp2, tmp2); // (1 - G2^2)
            uni_vmulss(dG2, dG2, tmp1);
            uni_vmulss(dG2, dG2, dHt); //(1 - G0) * (1 - G2^2) * dHt

            // compute dG1 = G1 * (1 - G1) * dG2 * ws_grid
            to_float(tmp1, ptr[addr_ws_grid_reg], src_data_t, hstate_dt_size);
            uni_vmovss(dG1, G1);
            uni_vmovss(tmp2, G1);
            uni_vfnmadd231ps(dG1, tmp2, tmp2); // (G1 - G1^2)
            uni_vmulss(dG1, dG1, dG2);
            uni_vmulss(dG1, dG1, tmp1); // (G1 - G1^2) * dG2 * ws_grid

            // compute diff_state_t_l = dHt * G0
            uni_vmulss(dHt, dHt, G0);
            uni_vmovss(ptr[addr_diff_states_t_l_reg], dHt);

            // compute scratch_cell gate 2 = dG2 * G1
            uni_vmovss(tmp1, dG2);
            uni_vmulss(tmp1, tmp1, G1);

            // downconvert and write data
            to_src(sc_addr(0), dG0, scratch_data_t, hstate_dt_size);
            // As to_src is called with write_only=true it's important for bf16
            // src_dt to execute just after to_src method with write_only=false
            // for the same Vmm
            to_src(sg_addr(0), dG0, scratch_data_t, hstate_dt_size, true);

            to_src(sc_addr(1), dG1, scratch_data_t, hstate_dt_size);
            // As to_src is called with write_only=true it's important for bf16
            // src_dt to execute just after to_src method with write_only=false
            // for the same Vmm
            to_src(sg_addr(1), dG1, scratch_data_t, hstate_dt_size, true);

            to_src(sc_addr(2), tmp1, scratch_data_t, hstate_dt_size);
            to_src(sg_addr(2), dG2, scratch_data_t, hstate_dt_size);

            // increment address pointers by one scalar element
            add(addr_ws_gates_reg, scratch_dt_size);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_diff_states_t_lp1_reg, hstate_dt_size);
            add(addr_diff_states_tp1_l_reg, hstate_dt_size);
            add(addr_diff_states_t_l_reg, hstate_dt_size);
            add(addr_states_tm1_l_reg, scratch_dt_size);
            add(addr_scratch_cell_reg, scratch_dt_size);
            add(addr_ws_grid_reg, scratch_dt_size);
            inc_regs(hstate_dt_size);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        if (is_augru) {
            // Complete diff attention reduction: horizontally add the 4 XMM
            // lanes down to a single scalar, then store it to the diff
            // attention output pointer (reloaded from the stack since
            // table_reg/loop_cnt aliased rbx and the arg regs were advanced)
            Xmm diff_attn_acc(dattn_acc_idx);
            uni_vhaddps(diff_attn_acc, diff_attn_acc, diff_attn_acc);
            uni_vhaddps(diff_attn_acc, diff_attn_acc, diff_attn_acc);
            const auto base_args = get_stack_params_address();
#ifdef _WIN32
            mov(addr_attn_reg, ptr[base_args + 56]);
#else
            mov(addr_attn_reg, ptr[base_args + 40]);
#endif
            uni_vmovss(ptr[addr_attn_reg], diff_attn_acc);
        }

        postamble();

        // emit the constant table: one full vector of 1.0f
        init_table(vlen);
        L(table_label);
        {
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(1.0f));
        }
    }
}; // struct jit_uni_gru_lbr_cell_postgemm_bwd
382 | |
383 | } // namespace x64 |
384 | } // namespace cpu |
385 | } // namespace impl |
386 | } // namespace dnnl |
387 | |
388 | #endif |
389 | |