1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_RNN_JIT_UNI_RNN_CELL_POSTGEMM_BWD_HPP |
18 | #define CPU_X64_RNN_JIT_UNI_RNN_CELL_POSTGEMM_BWD_HPP |
19 | |
20 | #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | |
// JIT generator for the backward (diff) post-GEMM pass of a vanilla RNN cell.
// Per element it loads the forward gate value G saved in ws_gates, sums the
// two incoming diff-state streams into dHt, and stores
//     dG = dHt * f'(G)
// into scratch_gates, where f is the cell activation (relu, tanh or logistic)
// and the derivative f' is expressed in terms of the forward output G.
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_rnn_cell_postgemm_bwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_rnn_cell_postgemm_bwd)

    jit_uni_rnn_cell_postgemm_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    ~jit_uni_rnn_cell_postgemm_bwd() {}

    // Initializes base post-gemm state then generates the kernel code.
    // NOTE(review): the `sdt` argument is ignored in favor of the
    // `src_data_t` template parameter, and the status returned by the
    // base-class init() is discarded — confirm both are intentional.
    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        return create_kernel();
    }

protected:
    // register size in bytes
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    // diff states are stored as f32 (see the movss/movups loads below)
    static constexpr size_t hstate_dt_size = sizeof(float);
    // bytes of scratch/gate memory covered by one full vector of f32 lanes
    // (scratch_data_t is presumably f32 or a narrower type such as bf16 —
    // for narrower types this is a fraction of vlen)
    const size_t vlen_scratch
            = vlen / (sizeof(float) / types::data_type_size(scratch_data_t));
    const size_t gate_dt_size = types::data_type_size(scratch_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);

    // Emits the kernel: a full-vector main loop over the dhc dimension
    // followed by a scalar remainder loop, plus inline constant tables.
    void generate() override {
        using namespace Xbyak;

        // Labels declaration
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;
        Label table_one_label, table_alpha_label;

        // Register map
        // aliasing with table_reg and loop_cnt since they are not used at the same time
        const Reg64 table_reg(r11);
        const Reg64 loop_cnt(r11);

        // Here we do no unrolling, loop overhead should not be that dramatic
        // Note: G has to be indexed at 0 as it is used as a mask in blend for bwd relu
        // (SSE4.1 blendvps implicitly uses xmm0 as the mask operand)
        enum {
            G_idx = 0,
            dG_idx,
            dHt_idx,
            tmp1_idx,
            one_idx,
            zero_idx,
            alpha_idx
        };
        // opmask used for the AVX-512 relu blend path only
        const Xbyak::Opmask kmask(1);

        // We start code generations here
        preamble();

        // extract addresses passed as parameter
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_diff_states_t_lp1_reg = abi_param3;
        const auto addr_diff_states_tp1_l_reg = abi_param4;

        // helper lambda to address the gates and biases
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };
        // auto sc_addr = [&](int i) {
        //     return ptr[addr_scratch_cell_reg + i * rnn_.dhc * scratch_dt_size];
        // };

        // initialize registers with addresses and constants
        init_regs(vlen);

        // broadcast the constant 1.0f from the inline table
        mov(table_reg, table_one_label);
        uni_vmovups(Vmm(one_idx), ptr[table_reg]);

        // alpha is only needed (and only emitted, see table_alpha_label
        // below) for the relu backward path
        if (pd_->activation_kind() == alg_kind::eltwise_relu) {
            mov(table_reg, table_alpha_label);
            uni_vmovups(Vmm(alpha_idx), ptr[table_reg]);
        }

        uni_vxorps(Vmm(zero_idx), Vmm(zero_idx), Vmm(zero_idx));

        // loop counter is maintained in bytes of scratch data
        mov(loop_cnt, rnn_.dhc * scratch_dt_size);
        cmp(loop_cnt, vlen_scratch);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        L(vector_loop_start_label);
        {
            const Vmm G(G_idx), dG(dG_idx), dHt(dHt_idx), tmp1(tmp1_idx),
                    one(one_idx), zero(zero_idx), alpha(alpha_idx);

            // load (and upconvert if needed) the saved forward gate value
            to_float(G, wg_addr(0), src_data_t, vlen);

            // compute dHt = diff_states_tp1_l + diff_states_t_lp1
            uni_vmovups(dHt, ptr[addr_diff_states_tp1_l_reg]);
            uni_vmovups(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddps(dHt, dHt, tmp1);

            // compute dG = f'(G)
            switch (pd_->activation_kind()) {
                case alg_kind::eltwise_relu:
                    // dG = G > 0 ? 1 : alpha
                    if (G.isZMM()) {
                        // kmask holds (G > 0); blendm picks `one` where set
                        vcmpps(kmask, G, zero, _cmp_nle_us);
                        vblendmps(dG | kmask, alpha, one);
                    } else {
                        // NOTE: here G is assumed to be xmm0 for sse4.1 blendvps to work
                        uni_vcmpps(G, G, zero, _cmp_nle_us);
                        uni_vmovups(dG, alpha);
                        uni_vblendvps(dG, dG, one, G);
                    }
                    break;
                case alg_kind::eltwise_tanh:
                    // dG = 1 - G^2 (with G = tanh(x))
                    uni_vmovups(dG, one);
                    uni_vfnmadd231ps(dG, G, G); // (1 - G^2)
                    break;
                case alg_kind::eltwise_logistic:
                    // dG = G - G^2 = G * (1 - G) (with G = sigmoid(x))
                    uni_vmovups(dG, G);
                    uni_vfnmadd231ps(dG, G, G); // (G - G^2)
                    break;
                default: assert(!"unsupported" );
            }

            // dG = dG * dHt
            uni_vmulps(dG, dG, dHt);

            // downconvert and write data
            to_src(sg_addr(0), dG, scratch_data_t, vlen);

            // increment address pointers
            add(addr_ws_gates_reg, vlen_scratch);
            add(addr_scratch_gates_reg, vlen_scratch);
            add(addr_diff_states_t_lp1_reg, vlen);
            add(addr_diff_states_tp1_l_reg, vlen);
            inc_regs(vlen);

            // increment loop counter
            sub(loop_cnt, vlen_scratch);
            cmp(loop_cnt, vlen_scratch);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same code as above, we just use movuss for accessing inputs
        // (scalar ops on Xmm, one element per iteration)
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            const Xmm G(G_idx), dG(dG_idx), dHt(dHt_idx), tmp1(tmp1_idx),
                    one(one_idx), zero(zero_idx), alpha(alpha_idx);

            to_float(G, wg_addr(0), src_data_t, hstate_dt_size);

            // compute dHt = diff_states_tp1_l + diff_states_t_lp1
            uni_vmovss(dHt, ptr[addr_diff_states_tp1_l_reg]);
            uni_vmovss(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddss(dHt, dHt, tmp1);

            // compute dG = f'(G), same derivative forms as the vector loop
            switch (pd_->activation_kind()) {
                case alg_kind::eltwise_relu:
                    // dG = G > 0 ? 1 : alpha
                    // NOTE: here G is assumed to be xmm0 for sse4.1 blendvps to work
                    uni_vcmpps(G, G, zero, _cmp_nle_us);
                    uni_vmovups(dG, alpha);
                    uni_vblendvps(dG, dG, one, G);
                    break;
                case alg_kind::eltwise_tanh:
                    // dG = 1 - G^2
                    uni_vmovss(dG, one);
                    uni_vfnmadd231ps(dG, G, G); // (1 - G^2)
                    break;
                case alg_kind::eltwise_logistic:
                    // dG = G - G^2
                    uni_vmovss(dG, G);
                    uni_vfnmadd231ps(dG, G, G); // (G - G^2)
                    break;
                default: assert(!"unsupported" );
            }

            // dG = dG * dHt
            uni_vmulps(dG, dG, dHt);

            // downconvert and write data
            to_src(sg_addr(0), dG, scratch_data_t, hstate_dt_size);

            // increment address pointers by one element
            add(addr_ws_gates_reg, scratch_dt_size);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_diff_states_t_lp1_reg, hstate_dt_size);
            add(addr_diff_states_tp1_l_reg, hstate_dt_size);
            inc_regs(hstate_dt_size);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();

        // inject the constant table for the activation
        // (data is emitted inline after the code, addressed via table_reg)
        init_table(vlen);
        L(table_one_label);
        {
            // one full vector of 1.0f
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(1.0f));
        }
        L(table_alpha_label);
        {
            // one full vector of the relu negative slope; empty for other
            // activations (the label is then never dereferenced)
            if (pd_->activation_kind() == alg_kind::eltwise_relu) {
                for (size_t i = 0; i < vlen / sizeof(float); i++)
                    dd(float2int(pd_->desc()->alpha));
            }
        }
    }
};
250 | |
251 | } // namespace x64 |
252 | } // namespace cpu |
253 | } // namespace impl |
254 | } // namespace dnnl |
255 | |
256 | #endif |
257 | |