1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_BWD_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_BWD_HPP |
19 | |
20 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
27 | template <cpu_isa_t isa, impl::data_type_t src_data_t, |
28 | impl::data_type_t scratch_data_t> |
29 | struct jit_uni_gru_cell_postgemm_part2_bwd : public jit_uni_rnn_postgemm { |
30 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part2_bwd) |
31 | |
32 | jit_uni_gru_cell_postgemm_part2_bwd( |
33 | const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd) |
34 | : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {} |
35 | |
36 | ~jit_uni_gru_cell_postgemm_part2_bwd() {} |
37 | |
38 | status_t init(data_type_t sdt) override { |
39 | jit_uni_rnn_postgemm::init(src_data_t); |
40 | return create_kernel(); |
41 | } |
42 | |
43 | protected: |
44 | // register size in bytes |
45 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
46 | static constexpr size_t vlen = cpu_isa_traits<isa>::vlen; |
47 | static constexpr size_t hstate_dt_size = sizeof(float); |
48 | const size_t vlen_scratch |
49 | = vlen / (sizeof(float) / types::data_type_size(scratch_data_t)); |
50 | const size_t gate_dt_size = types::data_type_size(scratch_data_t); |
51 | const size_t scratch_dt_size = types::data_type_size(scratch_data_t); |
52 | |
53 | void generate() override { |
54 | using namespace Xbyak; |
55 | |
56 | // Labels declaration |
57 | Label vector_loop_start_label, vector_loop_inc_regs, |
58 | vector_loop_end_label; |
59 | Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label; |
60 | |
61 | // Register map |
62 | const Reg64 loop_cnt(rbx); // loop counter |
63 | |
64 | // We skip vmm0 as it can be used by the injector for masks on sse4.1 |
65 | enum { |
66 | dG1_idx = 1, |
67 | dhG1_idx = 2, |
68 | hG1_idx = 3, |
69 | G1_idx = 4, |
70 | dH_idx = 5, |
71 | tmp1_idx = 6, |
72 | h_idx = 7 |
73 | }; |
74 | |
75 | // We start code generations here |
76 | preamble(); |
77 | |
78 | // extract addresses passed as parameter |
79 | const auto addr_ws_gates_reg = abi_param1; |
80 | const auto addr_scratch_gates_reg = abi_param2; |
81 | // auto addr_diff_states_t_lp1_reg = abi_param3; // not needed |
82 | // auto addr_diff_states_tp1_l_reg = abi_param4; // not needed |
83 | #ifdef _WIN32 |
84 | const auto addr_diff_states_t_l_reg = r10; |
85 | const auto addr_states_tm1_l_reg = r11; |
86 | const auto addr_scratch_cell_reg = r12; |
87 | // auto addr_ws_grid_reg = rsi; // not needed |
88 | const auto addr_dhG1_reg = rsi; |
89 | const auto base_args = get_stack_params_address(); |
90 | mov(addr_diff_states_t_l_reg, ptr[base_args]); |
91 | mov(addr_states_tm1_l_reg, ptr[base_args + 8]); |
92 | mov(addr_scratch_cell_reg, ptr[base_args + 16]); |
93 | // mov(addr_ws_grid_reg, ptr[base_args + 24]); |
94 | mov(addr_dhG1_reg, ptr[base_args + 32]); |
95 | #else |
96 | const auto addr_diff_states_t_l_reg = abi_param5; |
97 | const auto addr_states_tm1_l_reg = abi_param6; |
98 | const auto addr_scratch_cell_reg = r10; |
99 | // auto addr_ws_grid_reg = r11; // not needed |
100 | const auto addr_dhG1_reg = r11; |
101 | const auto base_args = get_stack_params_address(); |
102 | mov(addr_scratch_cell_reg, ptr[base_args]); |
103 | // mov(addr_ws_grid_reg, ptr[base_args + 8]); |
104 | mov(addr_dhG1_reg, ptr[base_args + 16]); |
105 | #endif |
106 | |
107 | // helper lambda to address the gates and biases |
108 | const auto sg_addr = [&](int i) { |
109 | return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size]; |
110 | }; |
111 | const auto wg_addr = [&](int i) { |
112 | return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size]; |
113 | }; |
114 | |
115 | // initialize registers with addresses and constants |
116 | init_regs(vlen); |
117 | |
118 | mov(loop_cnt, rnn_.dhc * scratch_dt_size); |
119 | cmp(loop_cnt, vlen_scratch); |
120 | jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR); |
121 | |
122 | L(vector_loop_start_label); |
123 | { |
124 | const Vmm dG1(dG1_idx), dhG1(dhG1_idx), hG1(hG1_idx), G1(G1_idx), |
125 | dH(dH_idx), tmp1(tmp1_idx), h(h_idx); |
126 | |
127 | to_float(G1, wg_addr(1), src_data_t, vlen); |
128 | to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, vlen); |
129 | |
130 | // compute dG1 |
131 | uni_vmovups(dG1, G1); |
132 | uni_vmovups(tmp1, G1); |
133 | uni_vfnmadd231ps(dG1, tmp1, tmp1); // (G1 - G1^2) |
134 | uni_vmulps(dG1, dG1, h); |
135 | uni_vmovups(dhG1, ptr[addr_dhG1_reg]); |
136 | uni_vmulps(dG1, dG1, dhG1); // dhG1 * h * (G0 - G0^2) * dHt |
137 | |
138 | // compute hG1 |
139 | uni_vmovups(hG1, G1); |
140 | uni_vmulps(hG1, hG1, h); |
141 | |
142 | // compute diff_states_t_l = diff_states_t_l + dhG1 * G1 |
143 | uni_vmovups(dH, ptr[addr_diff_states_t_l_reg]); |
144 | uni_vfmadd231ps(dH, dhG1, G1); |
145 | |
146 | // downconvert and write data |
147 | to_src(sg_addr(1), dG1, scratch_data_t, vlen); |
148 | to_src(ptr[addr_scratch_cell_reg], hG1, scratch_data_t, vlen); |
149 | uni_vmovups(ptr[addr_diff_states_t_l_reg], dH); |
150 | |
151 | // increment address pointers |
152 | add(addr_ws_gates_reg, vlen_scratch); |
153 | add(addr_scratch_gates_reg, vlen_scratch); |
154 | add(addr_dhG1_reg, vlen); |
155 | add(addr_diff_states_t_l_reg, vlen); |
156 | add(addr_states_tm1_l_reg, vlen_scratch); |
157 | add(addr_scratch_cell_reg, vlen_scratch); |
158 | inc_regs(vlen); |
159 | |
160 | // increment loop counter |
161 | sub(loop_cnt, vlen_scratch); |
162 | cmp(loop_cnt, vlen_scratch); |
163 | jge(vector_loop_start_label); |
164 | } |
165 | L(vector_loop_end_label); |
166 | |
167 | cmp(loop_cnt, 0); |
168 | je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR); |
169 | // Same code as above, we just use movuss for accessing inputs |
170 | // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar |
171 | L(rem_loop_start_label); |
172 | { |
173 | const Xmm dG1(dG1_idx), dhG1(dhG1_idx), hG1(hG1_idx), G1(G1_idx), |
174 | dH(dH_idx), tmp1(tmp1_idx), h(h_idx); |
175 | |
176 | to_float(G1, wg_addr(1), src_data_t, hstate_dt_size); |
177 | to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, hstate_dt_size); |
178 | |
179 | // compute dG1 |
180 | uni_vmovss(dG1, G1); |
181 | uni_vmovss(tmp1, G1); |
182 | uni_vfnmadd231ps(dG1, tmp1, tmp1); // (G1 - G1^2) |
183 | uni_vmulss(dG1, dG1, h); |
184 | uni_vmovss(dhG1, ptr[addr_dhG1_reg]); |
185 | uni_vmulss(dG1, dG1, dhG1); // dhG1 * h * (G0 - G0^2) * dHt |
186 | |
187 | // compute hG1 |
188 | uni_vmovss(hG1, G1); |
189 | uni_vmulss(hG1, hG1, h); |
190 | |
191 | // compute diff_states_t_l = diff_states_t_l + dhG1 * G1 |
192 | uni_vmovss(dH, ptr[addr_diff_states_t_l_reg]); |
193 | uni_vfmadd231ps(dH, dhG1, G1); |
194 | |
195 | // downconvert and write data |
196 | to_src(sg_addr(1), dG1, scratch_data_t, hstate_dt_size); |
197 | to_src(ptr[addr_scratch_cell_reg], hG1, scratch_data_t, |
198 | hstate_dt_size); |
199 | uni_vmovss(ptr[addr_diff_states_t_l_reg], dH); |
200 | |
201 | // increment address pointers |
202 | add(addr_ws_gates_reg, scratch_dt_size); |
203 | add(addr_scratch_gates_reg, scratch_dt_size); |
204 | add(addr_dhG1_reg, hstate_dt_size); |
205 | add(addr_diff_states_t_l_reg, hstate_dt_size); |
206 | add(addr_states_tm1_l_reg, scratch_dt_size); |
207 | add(addr_scratch_cell_reg, scratch_dt_size); |
208 | inc_regs(hstate_dt_size); |
209 | |
210 | // increment loop counter |
211 | sub(loop_cnt, scratch_dt_size); |
212 | cmp(loop_cnt, 0); |
213 | jg(rem_loop_start_label); |
214 | } |
215 | L(rem_loop_end_label); |
216 | |
217 | postamble(); |
218 | |
219 | init_table(vlen); |
220 | } |
221 | }; |
222 | |
223 | } // namespace x64 |
224 | } // namespace cpu |
225 | } // namespace impl |
226 | } // namespace dnnl |
227 | |
228 | #endif |
229 | |