/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_RNN_JIT_UNI_RNN_CELL_POSTGEMM_BWD_HPP
#define CPU_X64_RNN_JIT_UNI_RNN_CELL_POSTGEMM_BWD_HPP

#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_rnn_cell_postgemm_bwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_rnn_cell_postgemm_bwd)

    jit_uni_rnn_cell_postgemm_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    ~jit_uni_rnn_cell_postgemm_bwd() {}

    status_t init(data_type_t sdt) override {
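        // note: the source type is fixed by the src_data_t template
        // parameter below; the runtime sdt argument is not used here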
        jit_uni_rnn_postgemm::init(src_data_t);
        return create_kernel();
    }

protected:
    // register size in bytes
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    static constexpr size_t hstate_dt_size = sizeof(float);
    const size_t vlen_scratch
            = vlen / (sizeof(float) / types::data_type_size(scratch_data_t));
    const size_t gate_dt_size = types::data_type_size(scratch_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);
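    // sizing sketch (our reading, not part of the original comments): with
    // avx512_core, vlen is 64 bytes, i.e. 16 f32 lanes, so an f32 scratch
    // type gives vlen_scratch = 64 bytes, while a 2-byte scratch type (e.g.
    // bf16) gives vlen_scratch = 64 / (4 / 2) = 32 bytes for the same lanes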

    void generate() override {
        using namespace Xbyak;

        // Labels declaration
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;
        Label table_one_label, table_alpha_label;

        // Register map
        // table_reg and loop_cnt alias on r11 since they are never live at
        // the same time
        const Reg64 table_reg(r11);
        const Reg64 loop_cnt(r11);

        // No unrolling here; the loop overhead should not be dramatic
        // Note: G has to be indexed at 0 as it is used as a mask in blend for bwd relu
        enum {
            G_idx = 0,
            dG_idx,
            dHt_idx,
            tmp1_idx,
            one_idx,
            zero_idx,
            alpha_idx
        };
        const Xbyak::Opmask kmask(1);
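        // on the AVX-512 path kmask holds the per-lane G > 0 predicate for
        // relu; pre-AVX-512 paths reuse G itself as the blend mask instead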

        // We start code generation here
        preamble();

        // extract addresses passed as parameters
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        const auto addr_diff_states_t_lp1_reg = abi_param3;
        const auto addr_diff_states_tp1_l_reg = abi_param4;

        // helper lambdas to address the gates
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };
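
        // scalar sketch of what the generated kernel computes per element i
        // (our reading of the code below, assuming f32 for every type):
        //   dHt = diff_states_tp1_l[i] + diff_states_t_lp1[i]
        //   dG  = dActivation(ws_gates[i]) * dHt
        //   scratch_gates[i] = dG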

        // initialize registers with addresses and constants
        init_regs(vlen);

        mov(table_reg, table_one_label);
        uni_vmovups(Vmm(one_idx), ptr[table_reg]);

        if (pd_->activation_kind() == alg_kind::eltwise_relu) {
            mov(table_reg, table_alpha_label);
            uni_vmovups(Vmm(alpha_idx), ptr[table_reg]);
        }

        uni_vxorps(Vmm(zero_idx), Vmm(zero_idx), Vmm(zero_idx));
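        // xor-ing a register with itself is the usual idiom to zero it; the
        // zero vector is only needed as the comparison operand for relu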

        mov(loop_cnt, rnn_.dhc * scratch_dt_size);
        cmp(loop_cnt, vlen_scratch);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
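        // main vectorized loop: one full vector of gates per iteration,
        // counted in bytes of scratch data; anything shorter than a full
        // vector falls through to the scalar remainder loop below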

        L(vector_loop_start_label);
        {
            const Vmm G(G_idx), dG(dG_idx), dHt(dHt_idx), tmp1(tmp1_idx),
                    one(one_idx), zero(zero_idx), alpha(alpha_idx);

            to_float(G, wg_addr(0), src_data_t, vlen);

            // compute dHt
            uni_vmovups(dHt, ptr[addr_diff_states_tp1_l_reg]);
            uni_vmovups(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddps(dHt, dHt, tmp1);

            // compute dG
            switch (pd_->activation_kind()) {
                case alg_kind::eltwise_relu:
                    // dG = G > 0 ? 1 : alpha
                    if (G.isZMM()) {
                        vcmpps(kmask, G, zero, _cmp_nle_us);
                        vblendmps(dG | kmask, alpha, one);
                    } else {
                        // NOTE: here G is assumed to be xmm0 for sse4.1 blendvps to work
                        uni_vcmpps(G, G, zero, _cmp_nle_us);
                        uni_vmovups(dG, alpha);
                        uni_vblendvps(dG, dG, one, G);
                    }
                    break;
                case alg_kind::eltwise_tanh:
                    // 1 - G^2
                    uni_vmovups(dG, one);
                    uni_vfnmadd231ps(dG, G, G); // dG = 1 - G^2
                    break;
                case alg_kind::eltwise_logistic:
                    uni_vmovups(dG, G);
                    uni_vfnmadd231ps(dG, G, G); // dG = G - G^2
                    break;
                default: assert(!"unsupported");
            }
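            // recap of the derivatives above (G holds the *forward*
            // activation output, which is why they can be written in G):
            //   relu:     1 if x > 0 else alpha; since (leaky) relu
            //             preserves sign, G > 0 recovers the x > 0 predicate
            //   tanh:     1 - tanh(x)^2 = 1 - G^2
            //   logistic: s(x) * (1 - s(x)) = G - G^2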

            // dG = dG * dHt
            uni_vmulps(dG, dG, dHt);

            // downconvert and write data
            to_src(sg_addr(0), dG, scratch_data_t, vlen);

            // increment address pointers
            add(addr_ws_gates_reg, vlen_scratch);
            add(addr_scratch_gates_reg, vlen_scratch);
            add(addr_diff_states_t_lp1_reg, vlen);
            add(addr_diff_states_tp1_l_reg, vlen);
            inc_regs(vlen);
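            // note: the gate pointers advance by vlen_scratch (bytes of
            // gate/scratch data) while the diff-state pointers advance by
            // vlen, since the states are kept in f32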

            // decrement loop counter
            sub(loop_cnt, vlen_scratch);
            cmp(loop_cnt, vlen_scratch);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same code as above, we just use movss for accessing inputs
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            const Xmm G(G_idx), dG(dG_idx), dHt(dHt_idx), tmp1(tmp1_idx),
                    one(one_idx), zero(zero_idx), alpha(alpha_idx);

            to_float(G, wg_addr(0), src_data_t, hstate_dt_size);

            // compute dHt
            uni_vmovss(dHt, ptr[addr_diff_states_tp1_l_reg]);
            uni_vmovss(tmp1, ptr[addr_diff_states_t_lp1_reg]);
            uni_vaddss(dHt, dHt, tmp1);

            // compute dG
            switch (pd_->activation_kind()) {
                case alg_kind::eltwise_relu:
                    // dG = G > 0 ? 1 : alpha
                    // NOTE: here G is assumed to be xmm0 for sse4.1 blendvps to work
                    uni_vcmpps(G, G, zero, _cmp_nle_us);
                    uni_vmovups(dG, alpha);
                    uni_vblendvps(dG, dG, one, G);
                    break;
                case alg_kind::eltwise_tanh:
                    // 1 - G^2
                    uni_vmovss(dG, one);
                    uni_vfnmadd231ps(dG, G, G); // dG = 1 - G^2
                    break;
                case alg_kind::eltwise_logistic:
                    uni_vmovss(dG, G);
                    uni_vfnmadd231ps(dG, G, G); // dG = G - G^2
                    break;
                default: assert(!"unsupported");
            }

            // dG = dG * dHt
            uni_vmulps(dG, dG, dHt);

            // downconvert and write data
            to_src(sg_addr(0), dG, scratch_data_t, hstate_dt_size);

            // increment address pointers
            add(addr_ws_gates_reg, scratch_dt_size);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_diff_states_t_lp1_reg, hstate_dt_size);
            add(addr_diff_states_tp1_l_reg, hstate_dt_size);
            inc_regs(hstate_dt_size);

            // decrement loop counter
            sub(loop_cnt, scratch_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();
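
        // everything emitted past the postamble is read-only data placed in
        // the same JIT code buffer; table_reg above points at these labels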
        // inject the constant table for the activation
        init_table(vlen);
        L(table_one_label);
        {
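            // one 1.0f per vector lane; float2int yields the raw IEEE-754
            // bit pattern so dd() can emit it as data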
            for (size_t i = 0; i < vlen / sizeof(float); i++)
                dd(float2int(1.0f));
        }
        L(table_alpha_label);
        {
            if (pd_->activation_kind() == alg_kind::eltwise_relu) {
                for (size_t i = 0; i < vlen / sizeof(float); i++)
                    dd(float2int(pd_->desc()->alpha));
            }
        }
    }
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif