/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_1_BWD_HPP
#define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_1_BWD_HPP

#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_cell_postgemm_part1_bwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part1_bwd)

    jit_uni_gru_cell_postgemm_part1_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    ~jit_uni_gru_cell_postgemm_part1_bwd() = default;

    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        return create_kernel();
    }

protected:
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    // vector register size in bytes
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
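    // bytes of scratch/src data covered by one full Vmm of f32 elements
    // (e.g. vlen / 2 when the scratch data type is bf16)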
    const size_t vlen_scratch
            = vlen / (sizeof(float) / types::data_type_size(scratch_data_t));
    static constexpr size_t hstate_dt_size = sizeof(float);
    const size_t gate_dt_size = types::data_type_size(scratch_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);

    void generate() override {
        using namespace Xbyak;

        const bool is_augru = pd_->cell_kind() == alg_kind::vanilla_augru;

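        // Part 1 of the GRU/AUGRU backward pass. With G0 the update gate
        // (sigmoid) and G2 the candidate state (tanh), the forward pass
        // computes h_t = G0 * h_{t-1} + (1 - G0) * G2, so given
        // dHt = dL/dh_t this kernel produces
        //     dG0 = dHt * (h_{t-1} - G2) * G0 * (1 - G0)
        //     dG2 = dHt * (1 - G0) * (1 - G2^2)
        // and the G0-path contribution dHt * G0 to dh_{t-1}. For AUGRU,
        // the update gate is additionally scaled by (1 - attention), so
        // dG0 picks up a (1 - attention) chain-rule factor and the
        // attention gradient is accumulated as dAttention -= dG0 * G0.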
        // Labels declaration
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;
        Label table_label;

        // Register map
        const Reg64 table_reg(rbx); // used to load ones before the loop
        const Reg64 loop_cnt(
                rbx); // loop counter, can be aliased with table_reg

        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        const int dG0_idx = 1, dG2_idx = 3, G0_idx = 4, G2_idx = 6, h_idx = 7,
                  dHt_idx = 8, one_idx = 9, tmp1_idx = 10, tmp2_idx = 11,
                  dattn_acc_idx = 12, attn_idx = 13;
        const Vmm one_vmm(one_idx);
        const Xmm one_xmm(one_idx);

        // constant table map
        const Address one_addr = ptr[table_reg];

        // We start code generation here
        preamble();

        // extract addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_diff_states_t_lp1_reg = abi_param3;
        const auto addr_diff_states_tp1_l_reg = abi_param4;
        const auto addr_attn_reg = r15;
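        // Win64 passes the first four integer arguments in registers and the
        // rest on the stack; System V passes the first six in registers. The
        // remaining parameters are therefore read from the stack-parameter
        // area, at offsets that differ between the two ABIs.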
#ifdef _WIN32
        const auto addr_diff_states_t_l_reg = r10;
        const auto addr_states_tm1_l_reg = r11;
        const auto base_args = get_stack_params_address();
        mov(addr_diff_states_t_l_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 48]);
#else
        const auto addr_diff_states_t_l_reg = abi_param5;
        const auto addr_states_tm1_l_reg = abi_param6;
        const auto base_args = get_stack_params_address();
        if (is_augru) mov(addr_attn_reg, ptr[base_args + 32]);
#endif

        // helper lambdas to compute the addresses of the workspace and
        // scratch gates
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };

        // initialize registers with addresses and constants
        mov(table_reg, table_label);
        init_regs(vlen);
        uni_vmovups(one_vmm, one_addr);

        if (is_augru) {
            uni_vpxor(
                    Vmm(dattn_acc_idx), Vmm(dattn_acc_idx), Vmm(dattn_acc_idx));
            const Xmm attn1s(attn_idx);
            to_float(attn1s, ptr[addr_attn_reg], src_data_t, hstate_dt_size);
        }

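        // loop_cnt holds the number of bytes left in one dhc-long gate row,
        // expressed in the scratch data type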
        mov(loop_cnt, rnn_.dhc * scratch_dt_size);
        cmp(loop_cnt, vlen_scratch);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        if (is_augru) {
            const Xmm attn1s(attn_idx);
            const Vmm attn(attn_idx);
            uni_vbroadcastss(attn, attn1s);
        }

        L(vector_loop_start_label);
        {
            const Vmm dG0(dG0_idx), dG2(dG2_idx), G0(G0_idx), G2(G2_idx),
                    dHt(dHt_idx), tmp1(tmp1_idx), tmp2(tmp2_idx), h(h_idx),
                    diff_attn_acc(dattn_acc_idx), attn(attn_idx);

            to_float(G0, wg_addr(0), src_data_t, vlen);
            to_float(G2, wg_addr(2), src_data_t, vlen);

            // compute dHt
            uni_vmovups(dHt, ptr[addr_diff_states_tp1_l_reg]);
            // assumption: the diff_states_t_lp1 address is already offset by rnn.n_states
            uni_vmovups(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddps(dHt, dHt, tmp1);

            // compute dG0
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, vlen);
            uni_vmovups(dG0, G0);
            uni_vmovups(tmp1, G0);
            uni_vfnmadd231ps(dG0, tmp1, tmp1); // (G0 - G0^2)
            uni_vsubps(h, h, G2); // (h - G2)
            uni_vmulps(dG0, dG0, h);
            uni_vmulps(dG0, dG0, dHt); // (h - G2) * (G0 - G0^2) * dHt
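            // note: G0 - G0^2 = G0 * (1 - G0) is the sigmoid derivative
            // expressed in terms of the gate value itself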

            // compute dG2
            uni_vmovups(tmp1, one_vmm);
            uni_vsubps(tmp1, tmp1, G0); // (1 - G0)
            uni_vmovups(dG2, one_vmm);
            uni_vmovups(tmp2, G2);
            uni_vfnmadd231ps(dG2, tmp2, tmp2); // (1 - G2^2)
            uni_vmulps(dG2, dG2, tmp1);
            uni_vmulps(dG2, dG2, dHt); // (1 - G0) * (1 - G2^2) * dHt
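            // note: 1 - G2^2 is the tanh derivative expressed in terms of
            // the candidate value itself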

            if (is_augru) {
                // Compute diff_attention
                // 1. compute dAttention -= dG0 * G0
                uni_vfnmadd231ps(diff_attn_acc, dG0, G0, tmp2);
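                // (the trailing operand is a scratch register used by the
                // uni_ wrapper for its SSE4.1 fallback)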
                // 2. Compute dG0 *= 1 - Attention
                uni_vsubps(tmp1, one_vmm, attn, tmp2);
                uni_vmulps(dG0, dG0, tmp1);
            }

            // compute diff_state_t_l
            uni_vmulps(dHt, dHt, G0);
            uni_vmovups(ptr[addr_diff_states_t_l_reg], dHt);

            // downconvert and write data
            to_src(sg_addr(0), dG0, scratch_data_t, vlen);
            to_src(sg_addr(2), dG2, scratch_data_t, vlen);

            // increment address pointers
            add(addr_ws_gates_reg, vlen_scratch);
            add(addr_scratch_gates_reg, vlen_scratch);
            add(addr_diff_states_t_lp1_reg, vlen);
            add(addr_diff_states_tp1_l_reg, vlen);
            add(addr_diff_states_t_l_reg, vlen);
            add(addr_states_tm1_l_reg, vlen_scratch);
            inc_regs(vlen);

            // increment loop counter
            sub(loop_cnt, vlen_scratch);
            cmp(loop_cnt, vlen_scratch);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        // Reduce diff attention into XMM width; otherwise the XMM-width
        // accumulation in the remainder loop below would zero the high part
        // of the YMM/ZMM accumulator.
        if (vlen >= cpu_isa_traits<avx512_core>::vlen) {
            Zmm diff_attn_acc(dattn_acc_idx);
            Ymm diff_attn_acc_high(tmp1_idx);
            Ymm diff_attn_acc_low(dattn_acc_idx);
            vextractf32x8(diff_attn_acc_high, diff_attn_acc, 1);
            vaddps(diff_attn_acc_low, diff_attn_acc_low, diff_attn_acc_high);
        }
        if (vlen >= cpu_isa_traits<avx2>::vlen) {
            Ymm diff_attn_acc(dattn_acc_idx);
            Xmm diff_attn_acc_high(tmp1_idx);
            Xmm diff_attn_acc_low(dattn_acc_idx);
            vextractf128(diff_attn_acc_high, diff_attn_acc, 1);
            vaddps(diff_attn_acc_low, diff_attn_acc_low, diff_attn_acc_high);
        }

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        // Same code as above, we just use movss for accessing inputs
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            const Xmm dG0(dG0_idx), dG2(dG2_idx), G0(G0_idx), G2(G2_idx),
                    dHt(dHt_idx), tmp1(tmp1_idx), tmp2(tmp2_idx), h(h_idx),
                    diff_attn_acc(dattn_acc_idx), attn(attn_idx);

            to_float(G0, wg_addr(0), src_data_t, hstate_dt_size);
            to_float(G2, wg_addr(2), src_data_t, hstate_dt_size);

            // compute dHt
            uni_vmovss(dHt, ptr[addr_diff_states_tp1_l_reg]);
            // assumption: the diff_states_t_lp1 address is already offset by rnn.n_states
            uni_vmovss(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddss(dHt, dHt, tmp1);

            // compute dG0
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, hstate_dt_size);
            uni_vmovss(dG0, G0);
            uni_vmovss(tmp1, G0);
            uni_vfnmadd231ps(dG0, tmp1, tmp1); // (G0 - G0^2)
            uni_vsubss(h, h, G2); // (h - G2)
            uni_vmulss(dG0, dG0, h);
            uni_vmulss(dG0, dG0, dHt); // (h - G2) * (G0 - G0^2) * dHt

            // compute dG2
            uni_vmovss(tmp1, one_xmm);
            uni_vsubss(tmp1, tmp1, G0); // (1 - G0)
            uni_vmovss(dG2, one_xmm);
            uni_vmovss(tmp2, G2);
            uni_vfnmadd231ps(dG2, tmp2, tmp2); // (1 - G2^2)
            uni_vmulss(dG2, dG2, tmp1);
            uni_vmulss(dG2, dG2, dHt); // (1 - G0) * (1 - G2^2) * dHt

            if (is_augru) {
                // compute diff_attention
                // 1. compute tmp2 = dG0 * G0
                uni_vmovss(tmp2, dG0);
                uni_vmulss(tmp2, tmp2, G0);
                // 2. accumulate dAttention -= tmp2
                uni_vsubss(diff_attn_acc, diff_attn_acc, tmp2);
                // 3. compute dG0 *= 1 - attention
                uni_vmovss(tmp1, one_xmm);
                uni_vsubss(tmp1, tmp1, attn);
                uni_vmulss(dG0, dG0, tmp1);
            }

            // compute diff_state_t_l
            uni_vmulss(dHt, dHt, G0);
            uni_vmovss(ptr[addr_diff_states_t_l_reg], dHt);

            // downconvert and write data
            to_src(sg_addr(0), dG0, scratch_data_t, hstate_dt_size);
            to_src(sg_addr(2), dG2, scratch_data_t, hstate_dt_size);

            // increment address pointers
            add(addr_ws_gates_reg, scratch_dt_size);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_diff_states_t_lp1_reg, hstate_dt_size);
            add(addr_diff_states_tp1_l_reg, hstate_dt_size);
            add(addr_diff_states_t_l_reg, hstate_dt_size);
            add(addr_states_tm1_l_reg, scratch_dt_size);
            inc_regs(hstate_dt_size);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size);
            jnz(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        if (is_augru) {
            // Complete diff attention reduction
            Xmm diff_attn_acc(dattn_acc_idx);
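            // two horizontal adds collapse the four f32 lanes into lane 0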
            uni_vhaddps(diff_attn_acc, diff_attn_acc, diff_attn_acc);
            uni_vhaddps(diff_attn_acc, diff_attn_acc, diff_attn_acc);
            const auto base_args = get_stack_params_address();
#ifdef _WIN32
            mov(addr_attn_reg, ptr[base_args + 56]);
#else
            mov(addr_attn_reg, ptr[base_args + 40]);
#endif
            uni_vmovss(ptr[addr_attn_reg], diff_attn_acc);
        }

        postamble();

        init_table(vlen);
        L(table_label);
        {
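            // one vector's worth of 1.0f constants, loaded through one_addr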
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(1.0f));
        }
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif