/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_GRU_LBR_CELL_POSTGEMM_BWD_HPP
#define CPU_X64_RNN_JIT_UNI_GRU_LBR_CELL_POSTGEMM_BWD_HPP

#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_lbr_cell_postgemm_bwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_lbr_cell_postgemm_bwd)

    jit_uni_gru_lbr_cell_postgemm_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    ~jit_uni_gru_lbr_cell_postgemm_bwd() {}

    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        return create_kernel();
    }

protected:
    // register size in bytes
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    static constexpr size_t hstate_dt_size = sizeof(float);
    const size_t vlen_scratch
            = vlen / (sizeof(float) / types::data_type_size(scratch_data_t));
    const size_t gate_dt_size = types::data_type_size(scratch_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);

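    // Note: generate() emits the backward postgemm for a linear-before-reset
    // (LBR) GRU cell. For each chunk of the dhc dimension it reloads the
    // saved gates G0/G1/G2 from ws_gates, accumulates the output gradient
    // dHt from the layer above and the next time step, and writes the
    // pre-activation gate gradients (scratch_gates), the values feeding the
    // recurrent backward GEMM (scratch_cell), and diff_states_t_l.
    // (Gate naming below is inferred from the math: G0 = update, G1 = reset,
    // G2 = candidate.)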
    void generate() override {
        using namespace Xbyak;

        const bool is_augru = pd_->cell_kind() == alg_kind::lbr_augru;

        // Label declarations
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;
        Label table_label;

        // Register map
        const Reg64 table_reg(rbx); // used to load ones before the loop
        const Reg64 loop_cnt(
                rbx); // loop counter, can be aliased with table_reg

        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        const int dG0_idx = 1, dG1_idx = 2, dG2_idx = 3, G0_idx = 4, G1_idx = 5,
                  G2_idx = 6, h_idx = 7, dHt_idx = 8, one_idx = 9,
                  tmp1_idx = 10, tmp2_idx = 11, dattn_acc_idx = 12,
                  attn_idx = 13;
        const Vmm one_vmm(one_idx);
        const Xmm one_xmm(one_idx);

        // constant table map
        const Address one_addr = ptr[table_reg];

        // We start code generation here
        preamble();
        // extract addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_diff_states_t_lp1_reg = abi_param3;
        const auto addr_diff_states_tp1_l_reg = abi_param4;
        const auto addr_attn_reg = r14;
#ifdef _WIN32
        const auto addr_diff_states_t_l_reg = r10;
        const auto addr_states_tm1_l_reg = r11;
        const auto addr_scratch_cell_reg = r12;
        const auto addr_ws_grid_reg = rsi;
        const auto base_args = get_stack_params_address();
        mov(addr_diff_states_t_l_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
        mov(addr_scratch_cell_reg, ptr[base_args + 16]);
        mov(addr_ws_grid_reg, ptr[base_args + 24]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 48]);
#else
        const auto addr_diff_states_t_l_reg = abi_param5;
        const auto addr_states_tm1_l_reg = abi_param6;
        const auto addr_scratch_cell_reg = r10;
        const auto addr_ws_grid_reg = r11;
        const auto base_args = get_stack_params_address();
        mov(addr_scratch_cell_reg, ptr[base_args]);
        mov(addr_ws_grid_reg, ptr[base_args + 8]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 32]);
#endif
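
        // Note on the split above: this assumes the standard calling
        // conventions. Win64 passes only four integer arguments in registers,
        // so diff_states_t_l, states_tm1_l, scratch_cell and ws_grid come
        // from the stack; System V passes six, leaving only scratch_cell,
        // ws_grid and (for AUGRU) the attention pointer on the stack. The
        // 8-byte offsets must match the order used by the postgemm caller.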

        // helper lambdas to address the gates and scratch cell
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };
        const auto sc_addr = [&](int i) {
            return ptr[addr_scratch_cell_reg + i * rnn_.dhc * scratch_dt_size];
        };

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        init_regs(vlen);
        uni_vmovups(one_vmm, one_addr);

        if (is_augru) {
            uni_vpxor(
                    Vmm(dattn_acc_idx), Vmm(dattn_acc_idx), Vmm(dattn_acc_idx));
            const Xmm attn1s(attn_idx);
            to_float(attn1s, ptr[addr_attn_reg], src_data_t, hstate_dt_size);
        }

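        // loop_cnt counts the remaining bytes of the dhc dimension in
        // scratch-datatype units; once fewer than one full vector is left we
        // fall through to the scalar remainder loop further below.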
        mov(loop_cnt, rnn_.dhc * scratch_dt_size);
        cmp(loop_cnt, vlen_scratch);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        if (is_augru) {
            const Xmm attn1s(attn_idx);
            const Vmm attn(attn_idx);
            uni_vbroadcastss(attn, attn1s);
        }

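        // Math sketch for the loop body, assuming the usual linear-before-
        // reset GRU formulation with ws_grid holding the "linear before
        // reset" term (U_c * h_{t-1} + b):
        //     h_t = G0 * h_{t-1} + (1 - G0) * G2
        //     dG0 = dHt * (h_{t-1} - G2) * (G0 - G0^2)
        //     dG2 = dHt * (1 - G0) * (1 - G2^2)
        //     dG1 = dG2 * ws_grid * (G1 - G1^2)
        //     diff_state_t_l (direct part) = dHt * G0
        // All gate gradients are pre-activation values, hence the sigmoid and
        // tanh derivative factors.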
        L(vector_loop_start_label);
        {
            const Vmm dG0(dG0_idx), dG1(dG1_idx), dG2(dG2_idx), G0(G0_idx),
                    G1(G1_idx), G2(G2_idx), dHt(dHt_idx), tmp1(tmp1_idx),
                    tmp2(tmp2_idx), h(h_idx), diff_attn_acc(dattn_acc_idx),
                    attn(attn_idx);

            to_float(G0, wg_addr(0), src_data_t, vlen);
            to_float(G1, wg_addr(1), src_data_t, vlen);
            to_float(G2, wg_addr(2), src_data_t, vlen);

            // compute dHt
            uni_vmovups(dHt, ptr[addr_diff_states_tp1_l_reg]);
            // assumption: the diff_states_t_lp1 address is already offset by rnn.n_states
            uni_vmovups(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddps(dHt, dHt, tmp1);

            // compute dG0
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, vlen);
            uni_vmovups(dG0, G0);
            uni_vmovups(tmp1, G0);
            uni_vfnmadd231ps(dG0, tmp1, tmp1); // (G0 - G0^2)
            uni_vsubps(h, h, G2); // (h - G2)
            uni_vmulps(dG0, dG0, h);
            uni_vmulps(dG0, dG0, dHt); // (h - G2) * (G0 - G0^2) * dHt

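            // For AUGRU the update gate was presumably scaled by
            // (1 - attention) in the forward pass, so here the attention
            // gradient accumulates -dG0 * G0 and dG0 itself is rescaled by
            // (1 - attention).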
            if (is_augru) {
                // compute diff_attention
                // 1. compute dAttention -= dG0 * G0
                uni_vfnmadd231ps(diff_attn_acc, dG0, G0, tmp2);
                // 2. compute dG0 *= 1 - Attention
                uni_vsubps(tmp1, one_vmm, attn, tmp2);
                uni_vmulps(dG0, dG0, tmp1);
            }
            // compute dG2
            uni_vmovups(tmp1, one_vmm);
            uni_vsubps(tmp1, tmp1, G0); // (1 - G0)
            uni_vmovups(dG2, one_vmm);
            uni_vmovups(tmp2, G2);
            uni_vfnmadd231ps(dG2, tmp2, tmp2); // (1 - G2^2)
            uni_vmulps(dG2, dG2, tmp1);
            uni_vmulps(dG2, dG2, dHt); // (1 - G0) * (1 - G2^2) * dHt

            // compute dG1
            to_float(tmp1, ptr[addr_ws_grid_reg], src_data_t, vlen);
            uni_vmovups(dG1, G1);
            uni_vmovups(tmp2, G1);
            uni_vfnmadd231ps(dG1, tmp2, tmp2); // (G1 - G1^2)
            uni_vmulps(dG1, dG1, dG2);
            uni_vmulps(dG1, dG1, tmp1); // (G1 - G1^2) * dG2 * ws_grid

            // compute diff_state_t_l
            uni_vmulps(dHt, dHt, G0);
            uni_vmovups(ptr[addr_diff_states_t_l_reg], dHt);

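            // scratch_cell gate 2 holds dG2 * G1: in the LBR formulation the
            // candidate gradient reaches h_{t-1} only through the reset-gated
            // recurrent term, so it is pre-multiplied by G1 before the
            // recurrent backward GEMM (assumption based on the sketch above).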
            // compute scratch_cell
            uni_vmovups(tmp1, dG2);
            uni_vmulps(tmp1, tmp1, G1);

            // downconvert and write data
            to_src(sc_addr(0), dG0, scratch_data_t, vlen);
            // Because the to_src call below uses write_only=true, for bf16
            // src_dt it must execute right after the to_src call with
            // write_only=false on the same Vmm.
            to_src(sg_addr(0), dG0, scratch_data_t, vlen, true);

            to_src(sc_addr(1), dG1, scratch_data_t, vlen);
            // Because the to_src call below uses write_only=true, for bf16
            // src_dt it must execute right after the to_src call with
            // write_only=false on the same Vmm.
            to_src(sg_addr(1), dG1, scratch_data_t, vlen, true);

            to_src(sc_addr(2), tmp1, scratch_data_t, vlen);
            to_src(sg_addr(2), dG2, scratch_data_t, vlen);

            // increment address pointers
            add(addr_ws_gates_reg, vlen_scratch);
            add(addr_scratch_gates_reg, vlen_scratch);
            add(addr_diff_states_t_lp1_reg, vlen);
            add(addr_diff_states_tp1_l_reg, vlen);
            add(addr_diff_states_t_l_reg, vlen);
            add(addr_states_tm1_l_reg, vlen_scratch);
            add(addr_scratch_cell_reg, vlen_scratch);
            add(addr_ws_grid_reg, vlen_scratch);
            inc_regs(vlen);

            // increment loop counter
            sub(loop_cnt, vlen_scratch);
            cmp(loop_cnt, vlen_scratch);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        // Reduce the diff attention accumulator down to XMM width before the
        // remainder loop: the scalar (XMM) accumulation below would zero the
        // upper part of the YMM/ZMM register and lose what was accumulated.
        if (vlen >= cpu_isa_traits<avx512_core>::vlen) {
            Zmm diff_attn_acc(dattn_acc_idx);
            Ymm diff_attn_acc_high(tmp1_idx);
            Ymm diff_attn_acc_low(dattn_acc_idx);
            vextractf32x8(diff_attn_acc_high, diff_attn_acc, 1);
            vaddps(diff_attn_acc_low, diff_attn_acc_low, diff_attn_acc_high);
        }
        if (vlen >= cpu_isa_traits<avx2>::vlen) {
            Ymm diff_attn_acc(dattn_acc_idx);
            Xmm diff_attn_acc_high(tmp1_idx);
            Xmm diff_attn_acc_low(dattn_acc_idx);
            vextractf128(diff_attn_acc_high, diff_attn_acc, 1);
            vaddps(diff_attn_acc_low, diff_attn_acc_low, diff_attn_acc_high);
        }
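
        // The remaining horizontal sum over the surviving XMM lanes is
        // completed by the two uni_vhaddps passes in the is_augru epilogue
        // below.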

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        // Same code as above, we just use movss for accessing inputs
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            const Xmm dG0(dG0_idx), dG1(dG1_idx), dG2(dG2_idx), G0(G0_idx),
                    G1(G1_idx), G2(G2_idx), dHt(dHt_idx), tmp1(tmp1_idx),
                    tmp2(tmp2_idx), h(h_idx), diff_attn_acc(dattn_acc_idx),
                    attn(attn_idx);

            to_float(G0, wg_addr(0), src_data_t, hstate_dt_size);
            to_float(G1, wg_addr(1), src_data_t, hstate_dt_size);
            to_float(G2, wg_addr(2), src_data_t, hstate_dt_size);

            // compute dHt
            uni_vmovss(dHt, ptr[addr_diff_states_tp1_l_reg]);
            // assumption: the diff_states_t_lp1 address is already offset by rnn.n_states
            uni_vmovss(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddss(dHt, dHt, tmp1);

            // compute dG0
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, hstate_dt_size);
            uni_vmovss(dG0, G0);
            uni_vmovss(tmp1, dG0);
            uni_vfnmadd231ps(dG0, tmp1, tmp1); // (G0 - G0^2)
            uni_vsubss(h, h, G2); // (h - G2)
            uni_vmulss(dG0, dG0, h);
            uni_vmulss(dG0, dG0, dHt); // (h - G2) * (G0 - G0^2) * dHt

            if (is_augru) {
                // compute diff_attention
                // 1. compute tmp2 = dG0 * G0
                uni_vmovss(tmp2, dG0);
                uni_vmulss(tmp2, tmp2, G0);
                // 2. accumulate dAttention
                uni_vsubss(diff_attn_acc, diff_attn_acc, tmp2);
                // 3. compute dG0 *= 1 - attention
                uni_vmovss(tmp1, one_xmm);
                uni_vsubss(tmp1, tmp1, attn);
                uni_vmulss(dG0, dG0, tmp1);
            }

            // compute dG2
            uni_vmovss(tmp1, one_xmm);
            uni_vsubss(tmp1, tmp1, G0); // (1 - G0)

            uni_vmovss(dG2, one_xmm);
            uni_vmovss(tmp2, G2);
            uni_vfnmadd231ps(dG2, tmp2, tmp2); // (1 - G2^2)
            uni_vmulss(dG2, dG2, tmp1);
            uni_vmulss(dG2, dG2, dHt); // (1 - G0) * (1 - G2^2) * dHt

            // compute dG1
            to_float(tmp1, ptr[addr_ws_grid_reg], src_data_t, hstate_dt_size);
            uni_vmovss(dG1, G1);
            uni_vmovss(tmp2, G1);
            uni_vfnmadd231ps(dG1, tmp2, tmp2); // (G1 - G1^2)
            uni_vmulss(dG1, dG1, dG2);
            uni_vmulss(dG1, dG1, tmp1); // (G1 - G1^2) * dG2 * ws_grid

            // compute diff_state_t_l
            uni_vmulss(dHt, dHt, G0);
            uni_vmovss(ptr[addr_diff_states_t_l_reg], dHt);

            // compute scratch_cell
            uni_vmovss(tmp1, dG2);
            uni_vmulss(tmp1, tmp1, G1);

            // downconvert and write data
            to_src(sc_addr(0), dG0, scratch_data_t, hstate_dt_size);
            // Because the to_src call below uses write_only=true, for bf16
            // src_dt it must execute right after the to_src call with
            // write_only=false on the same Vmm.
            to_src(sg_addr(0), dG0, scratch_data_t, hstate_dt_size, true);

            to_src(sc_addr(1), dG1, scratch_data_t, hstate_dt_size);
            // Because the to_src call below uses write_only=true, for bf16
            // src_dt it must execute right after the to_src call with
            // write_only=false on the same Vmm.
            to_src(sg_addr(1), dG1, scratch_data_t, hstate_dt_size, true);

            to_src(sc_addr(2), tmp1, scratch_data_t, hstate_dt_size);
            to_src(sg_addr(2), dG2, scratch_data_t, hstate_dt_size);

            // increment address pointers
            add(addr_ws_gates_reg, scratch_dt_size);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_diff_states_t_lp1_reg, hstate_dt_size);
            add(addr_diff_states_tp1_l_reg, hstate_dt_size);
            add(addr_diff_states_t_l_reg, hstate_dt_size);
            add(addr_states_tm1_l_reg, scratch_dt_size);
            add(addr_scratch_cell_reg, scratch_dt_size);
            add(addr_ws_grid_reg, scratch_dt_size);
            inc_regs(hstate_dt_size);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        if (is_augru) {
            // Complete diff attention reduction
            Xmm diff_attn_acc(dattn_acc_idx);
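            // Two horizontal adds collapse the four f32 lanes into a single
            // scalar sum (every lane ends up holding the total).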
            uni_vhaddps(diff_attn_acc, diff_attn_acc, diff_attn_acc);
            uni_vhaddps(diff_attn_acc, diff_attn_acc, diff_attn_acc);
            const auto base_args = get_stack_params_address();
#ifdef _WIN32
            mov(addr_attn_reg, ptr[base_args + 56]);
#else
            mov(addr_attn_reg, ptr[base_args + 40]);
#endif
            uni_vmovss(ptr[addr_attn_reg], diff_attn_acc);
        }

        postamble();

        init_table(vlen);
        L(table_label);
        {
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(1.0f));
        }
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif