1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_1_BWD_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_1_BWD_HPP |
19 | |
20 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
// Kernel generator for part 1 of the (AU)GRU cell backward post-GEMM pass.
// It consumes the saved forward gates G0 (update gate) and G2 (candidate
// gate) together with the incoming state gradients and produces:
//   dG0 = (h_{t-1} - G2) * (G0 - G0^2) * dHt      [times (1 - a) for AUGRU]
//   dG2 = (1 - G0) * (1 - G2^2) * dHt
//   diff_state_t_l = dHt * G0
// and, for AUGRU, accumulates the attention gradient dA -= sum(dG0 * G0).
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_cell_postgemm_part1_bwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part1_bwd)

    jit_uni_gru_cell_postgemm_part1_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    ~jit_uni_gru_cell_postgemm_part1_bwd() {}

    // Initializes base-class state and generates the kernel.
    // NOTE(review): the 'sdt' argument is ignored in favor of the src_data_t
    // template parameter, and the status_t returned by the base-class init()
    // is discarded — presumably it cannot fail on this path; confirm.
    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        return create_kernel();
    }

protected:
    // Widest vector register type for the target ISA and its size in bytes.
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    // Bytes of scratch-precision data covered by one full vector iteration
    // (a vector holds vlen / sizeof(float) lanes; each scratch element is
    // scratch_dt_size bytes, so e.g. for bf16 scratch this is vlen / 2).
    const size_t vlen_scratch
            = vlen / (sizeof(float) / types::data_type_size(scratch_data_t));
    // Hidden-state elements are stored as f32.
    static constexpr size_t hstate_dt_size = sizeof(float);
    const size_t gate_dt_size = types::data_type_size(scratch_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);

    // Emits the kernel: a main loop processing one full vector of channels
    // per iteration, followed by a scalar tail loop for the remaining
    // (dhc % lanes) channels, plus a trailing constant table holding 1.0f.
    void generate() override {
        using namespace Xbyak;

        const bool is_augru = pd_->cell_kind() == alg_kind::vanilla_augru;

        // Labels declaration
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;
        Label table_label;

        // Register map
        const Reg64 table_reg(rbx); // used to load ones before the loop
        const Reg64 loop_cnt(
                rbx); // loop counter, can be aliased with table_reg
                      // (the table is only read before the loops start)

        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        // These indices must all stay distinct: each names a live vector
        // register for the duration of a loop body.
        const int dG0_idx = 1, dG2_idx = 3, G0_idx = 4, G2_idx = 6, h_idx = 7,
                  dHt_idx = 8, one_idx = 9, tmp1_idx = 10, tmp2_idx = 11,
                  dattn_acc_idx = 12, attn_idx = 13;
        const Vmm one_vmm(one_idx);
        const Xmm one_xmm(one_idx);

        // constant table map
        const Address one_addr = ptr[table_reg];

        // We start code generations here
        preamble();

        // extract addresses passed as parameter
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_diff_states_t_lp1_reg = abi_param3;
        const auto addr_diff_states_tp1_l_reg = abi_param4;
        const auto addr_attn_reg = r15;
#ifdef _WIN32
        // Windows x64 passes only 4 args in registers; the 5th and 6th
        // arguments (diff_states_t_l, states_tm1_l) come from the stack.
        const auto addr_diff_states_t_l_reg = r10;
        const auto addr_states_tm1_l_reg = r11;
        const auto base_args = get_stack_params_address();
        mov(addr_diff_states_t_l_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
        // NOTE(review): +48 (+32 on SysV below) should be the attention
        // pointer slot in the postgemm call frame — verify against the
        // caller's argument layout.
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 48]);
#else
        const auto addr_diff_states_t_l_reg = abi_param5;
        const auto addr_states_tm1_l_reg = abi_param6;
        const auto base_args = get_stack_params_address();
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 32]);
#endif

        // helper lambda to address the gates and biases
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        init_regs(vlen);
        uni_vmovups(one_vmm, one_addr);

        if (is_augru) {
            // zero the attention-gradient accumulator and load the scalar
            // attention value for this cell
            uni_vpxor(
                    Vmm(dattn_acc_idx), Vmm(dattn_acc_idx), Vmm(dattn_acc_idx));
            const Xmm attn1s(attn_idx);
            to_float(attn1s, ptr[addr_attn_reg], src_data_t, hstate_dt_size);
        }

        // loop_cnt counts remaining bytes of scratch data; skip the vector
        // loop entirely when fewer than one full vector remains
        mov(loop_cnt, rnn_.dhc * scratch_dt_size);
        cmp(loop_cnt, vlen_scratch);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        if (is_augru) {
            // broadcast the scalar attention value across all vector lanes
            const Xmm attn1s(attn_idx);
            const Vmm attn(attn_idx);
            uni_vbroadcastss(attn, attn1s);
        }

        L(vector_loop_start_label);
        {
            const Vmm dG0(dG0_idx), dG2(dG2_idx), G0(G0_idx), G2(G2_idx),
                    dHt(dHt_idx), tmp1(tmp1_idx), tmp2(tmp2_idx), h(h_idx),
                    diff_attn_acc(dattn_acc_idx), attn(attn_idx);

            // load saved forward gates, upconverted to f32
            to_float(G0, wg_addr(0), src_data_t, vlen);
            to_float(G2, wg_addr(2), src_data_t, vlen);

            // compute dHt = d(state from layer above) + d(state from t+1)
            uni_vmovups(dHt, ptr[addr_diff_states_tp1_l_reg]);
            // assumption: the diff_states_t_lp1 address is already offset by rnn.n_states
            uni_vmovups(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddps(dHt, dHt, tmp1);

            // compute dG0
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, vlen);
            uni_vmovups(dG0, G0);
            uni_vmovups(tmp1, G0);
            // fnmadd: dG0 = dG0 - tmp1 * tmp1, i.e. the sigmoid derivative
            uni_vfnmadd231ps(dG0, tmp1, tmp1); // (G0 - G0^2)
            uni_vsubps(h, h, G2); // (h - G2)
            uni_vmulps(dG0, dG0, h);
            uni_vmulps(dG0, dG0, dHt); // (h - G2) * (G0 - G0^2) * dHt

            // compute dG2
            uni_vmovups(tmp1, one_vmm);
            uni_vsubps(tmp1, tmp1, G0); // (1 - G0)
            uni_vmovups(dG2, one_vmm);
            uni_vmovups(tmp2, G2);
            // fnmadd: dG2 = 1 - G2 * G2, i.e. the tanh derivative
            uni_vfnmadd231ps(dG2, tmp2, tmp2); // (1 - G2^2)
            uni_vmulps(dG2, dG2, tmp1);
            uni_vmulps(dG2, dG2, dHt); //(1 - G0) * (1 - G2^2) * dHt

            if (is_augru) {
                // Compute diff_attention
                // 1. compute dAttention -= dG0 * G
                //    (tmp2 is a scratch operand for ISAs without native FMA)
                uni_vfnmadd231ps(diff_attn_acc, dG0, G0, tmp2);
                // 2. Compute dG0 *= 1 - Attention
                uni_vsubps(tmp1, one_vmm, attn, tmp2);
                uni_vmulps(dG0, dG0, tmp1);
            }

            // compute diff_state_t_l = dHt * G0
            uni_vmulps(dHt, dHt, G0);
            uni_vmovups(ptr[addr_diff_states_t_l_reg], dHt);

            // downconvert and write data
            to_src(sg_addr(0), dG0, scratch_data_t, vlen);
            to_src(sg_addr(2), dG2, scratch_data_t, vlen);

            // increment address pointers
            // (scratch/gates/states advance by vlen_scratch bytes; f32
            // diff-state streams advance by vlen bytes)
            add(addr_ws_gates_reg, vlen_scratch);
            add(addr_scratch_gates_reg, vlen_scratch);
            add(addr_diff_states_t_lp1_reg, vlen);
            add(addr_diff_states_tp1_l_reg, vlen);
            add(addr_diff_states_t_l_reg, vlen);
            add(addr_states_tm1_l_reg, vlen_scratch);
            inc_regs(vlen);

            // increment loop counter
            sub(loop_cnt, vlen_scratch);
            cmp(loop_cnt, vlen_scratch);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        // Reduce diff attention into XMM size. Otherwise accumulation
        // using XMM will zero high part of YMM/ZMM.
        // NOTE(review): this reduction is emitted even when !is_augru; the
        // instructions then operate on an unused register and the result is
        // never stored, so this is dead work rather than a correctness issue.
        if (vlen >= cpu_isa_traits<avx512_core>::vlen) {
            Zmm diff_attn_acc(dattn_acc_idx);
            Ymm diff_attn_acc_high(tmp1_idx);
            Ymm diff_attn_acc_low(dattn_acc_idx);
            // fold upper 256 bits into the lower 256 bits
            vextractf32x8(diff_attn_acc_high, diff_attn_acc, 1);
            vaddps(diff_attn_acc_low, diff_attn_acc_low, diff_attn_acc_high);
        }
        if (vlen >= cpu_isa_traits<avx2>::vlen) {
            Ymm diff_attn_acc(dattn_acc_idx);
            Xmm diff_attn_acc_high(tmp1_idx);
            Xmm diff_attn_acc_low(dattn_acc_idx);
            // fold upper 128 bits into the lower 128 bits
            vextractf128(diff_attn_acc_high, diff_attn_acc, 1);
            vaddps(diff_attn_acc_low, diff_attn_acc_low, diff_attn_acc_high);
        }

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        // Same code as above, we just use movuss for accessing inputs
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            const Xmm dG0(dG0_idx), dG2(dG2_idx), G0(G0_idx), G2(G2_idx),
                    dHt(dHt_idx), tmp1(tmp1_idx), tmp2(tmp2_idx), h(h_idx),
                    diff_attn_acc(dattn_acc_idx), attn(attn_idx);

            to_float(G0, wg_addr(0), src_data_t, hstate_dt_size);
            to_float(G2, wg_addr(2), src_data_t, hstate_dt_size);

            // compute dHt
            uni_vmovss(dHt, ptr[addr_diff_states_tp1_l_reg]);
            // assumption: the diff_states_t_lp1 address is already offset by rnn.n_states
            uni_vmovss(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddss(dHt, dHt, tmp1);

            // compute dG0
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, hstate_dt_size);
            uni_vmovss(dG0, G0);
            uni_vmovss(tmp1, G0);
            uni_vfnmadd231ps(dG0, tmp1, tmp1); // (G0 - G0^2)
            uni_vsubss(h, h, G2); // (h - G2)
            uni_vmulss(dG0, dG0, h);
            uni_vmulss(dG0, dG0, dHt); // (h - G2) * (G0 - G0^2) * dHt

            // compute dG2
            uni_vmovss(tmp1, one_xmm);
            uni_vsubss(tmp1, tmp1, G0); // (1 - G0)
            uni_vmovss(dG2, one_xmm);
            uni_vmovss(tmp2, G2);
            uni_vfnmadd231ps(dG2, tmp2, tmp2); // (1 - G2^2)
            uni_vmulss(dG2, dG2, tmp1);
            uni_vmulss(dG2, dG2, dHt); //(1 - G0) * (1 - G2^2) * dHt

            if (is_augru) {
                // compute diff_attention
                // 1. compute tmp2 = -dG0 * G
                uni_vmovss(tmp2, dG0);
                uni_vmulss(tmp2, tmp2, G0);
                // 2. Store dAttention (accumulates into lane 0 on top of the
                //    vector-loop partial sums reduced above)
                uni_vsubss(diff_attn_acc, diff_attn_acc, tmp2);
                // 3. Compute dG0 *= 1 - attention
                uni_vmovss(tmp1, one_xmm);
                uni_vsubss(tmp1, tmp1, attn);
                uni_vmulss(dG0, dG0, tmp1);
            }

            // compute diff_state_t_l
            uni_vmulss(dHt, dHt, G0);
            uni_vmovss(ptr[addr_diff_states_t_l_reg], dHt);

            // downconvert and write data
            to_src(sg_addr(0), dG0, scratch_data_t, hstate_dt_size);
            to_src(sg_addr(2), dG2, scratch_data_t, hstate_dt_size);

            // increment address pointers (one element at a time in the tail)
            add(addr_ws_gates_reg, scratch_dt_size);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_diff_states_t_lp1_reg, hstate_dt_size);
            add(addr_diff_states_tp1_l_reg, hstate_dt_size);
            add(addr_diff_states_t_l_reg, hstate_dt_size);
            add(addr_states_tm1_l_reg, scratch_dt_size);
            inc_regs(hstate_dt_size);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size);
            jnz(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        if (is_augru) {
            // Complete diff attention reduction: two horizontal adds collapse
            // the 4 remaining f32 lanes into lane 0
            Xmm diff_attn_acc(dattn_acc_idx);
            uni_vhaddps(diff_attn_acc, diff_attn_acc, diff_attn_acc);
            uni_vhaddps(diff_attn_acc, diff_attn_acc, diff_attn_acc);
            // reload the output pointer for diff attention from the stack
            // (addr_attn_reg/r15 is free again at this point)
            const auto base_args = get_stack_params_address();
#ifdef _WIN32
            mov(addr_attn_reg, ptr[base_args + 56]);
#else
            mov(addr_attn_reg, ptr[base_args + 40]);
#endif
            uni_vmovss(ptr[addr_attn_reg], diff_attn_acc);
        }

        postamble();

        // constant tables emitted after the code: base-class tables first,
        // then one vector's worth of 1.0f used for (1 - x) computations
        init_table(vlen);
        L(table_label);
        {
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(1.0f));
        }
    }
};
313 | |
314 | } // namespace x64 |
315 | } // namespace cpu |
316 | } // namespace impl |
317 | } // namespace dnnl |
318 | |
319 | #endif |
320 | |