1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_BWD_HPP
18#define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_BWD_HPP
19
20#include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25namespace x64 {
26
// JIT kernel for the second part of the GRU cell backward post-GEMM pass.
// For each element of the hidden-state channel (dhc) it computes:
//   dG1 = dhG1 * h * (G1 - G1^2)       -> written to scratch_gates[1]
//   hG1 = h * G1                       -> written to scratch_cell
//   diff_states_t_l += dhG1 * G1       -> accumulated in place
// where G1 is the update-gate activation read from ws_gates, h is the
// previous-iteration state (states_tm1_l) and dhG1 is an intermediate
// gradient produced earlier in the backward pass.
// The main loop is vectorized with Vmm registers; a scalar tail loop
// handles the remaining elements.
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_cell_postgemm_part2_bwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part2_bwd)

    jit_uni_gru_cell_postgemm_part2_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd, jit_name()) {}

    ~jit_uni_gru_cell_postgemm_part2_bwd() {}

    // NOTE(review): the `sdt` argument is ignored; the base class is
    // initialized with the compile-time template parameter `src_data_t`
    // instead. The status of the base init call is also discarded —
    // presumably it cannot fail for the configurations reaching this
    // kernel; confirm against the base-class implementation.
    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        return create_kernel();
    }

protected:
    // register size in bytes
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    static constexpr size_t vlen = cpu_isa_traits<isa>::vlen;
    // states/diff buffers are always f32 in this kernel
    static constexpr size_t hstate_dt_size = sizeof(float);
    // bytes of scratch data covered by one vector register: when the
    // scratch type is narrower than f32 (e.g. bf16), one full-width f32
    // vector corresponds to proportionally fewer scratch bytes
    const size_t vlen_scratch
            = vlen / (sizeof(float) / types::data_type_size(scratch_data_t));
    const size_t gate_dt_size = types::data_type_size(scratch_data_t);
    const size_t scratch_dt_size = types::data_type_size(scratch_data_t);

    void generate() override {
        using namespace Xbyak;

        // Labels declaration
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;

        // Register map
        const Reg64 loop_cnt(rbx); // loop counter

        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        enum {
            dG1_idx = 1, // gradient wrt gate 1 (output)
            dhG1_idx = 2, // incoming gradient dhG1 (input)
            hG1_idx = 3, // h * G1 (output, goes to scratch_cell)
            G1_idx = 4, // gate 1 activation (input, from ws_gates)
            dH_idx = 5, // diff_states accumulator (in/out)
            tmp1_idx = 6,
            h_idx = 7 // previous state h (input)
        };

        // We start code generations here
        preamble();

        // extract addresses passed as parameter
        const auto addr_ws_gates_reg = abi_param1;
        const auto addr_scratch_gates_reg = abi_param2;
        // auto addr_diff_states_t_lp1_reg = abi_param3; // not needed
        // auto addr_diff_states_tp1_l_reg = abi_param4; // not needed
#ifdef _WIN32
        // Windows x64 ABI: only 4 register parameters, so the remaining
        // arguments are read from the stack (8 bytes apart each)
        const auto addr_diff_states_t_l_reg = r10;
        const auto addr_states_tm1_l_reg = r11;
        const auto addr_scratch_cell_reg = r12;
        // auto addr_ws_grid_reg = rsi; // not needed
        const auto addr_dhG1_reg = rsi;
        const auto base_args = get_stack_params_address();
        mov(addr_diff_states_t_l_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
        mov(addr_scratch_cell_reg, ptr[base_args + 16]);
        // mov(addr_ws_grid_reg, ptr[base_args + 24]);
        mov(addr_dhG1_reg, ptr[base_args + 32]);
#else
        // System V ABI: params 5 and 6 still come in registers; the rest
        // are fetched from the stack
        const auto addr_diff_states_t_l_reg = abi_param5;
        const auto addr_states_tm1_l_reg = abi_param6;
        const auto addr_scratch_cell_reg = r10;
        // auto addr_ws_grid_reg = r11; // not needed
        const auto addr_dhG1_reg = r11;
        const auto base_args = get_stack_params_address();
        mov(addr_scratch_cell_reg, ptr[base_args]);
        // mov(addr_ws_grid_reg, ptr[base_args + 8]);
        mov(addr_dhG1_reg, ptr[base_args + 16]);
#endif

        // helper lambda to address the gates and biases
        // (gate i lives at offset i * dhc elements within the buffer)
        const auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        const auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };

        // initialize registers with addresses and constants
        init_regs(vlen);

        // loop counter is expressed in bytes of scratch data
        mov(loop_cnt, rnn_.dhc * scratch_dt_size);
        cmp(loop_cnt, vlen_scratch);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        L(vector_loop_start_label);
        {
            const Vmm dG1(dG1_idx), dhG1(dhG1_idx), hG1(hG1_idx), G1(G1_idx),
                    dH(dH_idx), tmp1(tmp1_idx), h(h_idx);

            // load and upconvert inputs to f32
            to_float(G1, wg_addr(1), src_data_t, vlen);
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, vlen);

            // compute dG1 = dhG1 * h * (G1 - G1^2)
            // (sigmoid derivative: G1 * (1 - G1) = G1 - G1^2)
            uni_vmovups(dG1, G1);
            uni_vmovups(tmp1, G1);
            uni_vfnmadd231ps(dG1, tmp1, tmp1); // (G1 - G1^2)
            uni_vmulps(dG1, dG1, h);
            uni_vmovups(dhG1, ptr[addr_dhG1_reg]);
            uni_vmulps(dG1, dG1, dhG1); // dhG1 * h * (G1 - G1^2)

            // compute hG1 = h * G1
            uni_vmovups(hG1, G1);
            uni_vmulps(hG1, hG1, h);

            // compute diff_states_t_l = diff_states_t_l + dhG1 * G1
            uni_vmovups(dH, ptr[addr_diff_states_t_l_reg]);
            uni_vfmadd231ps(dH, dhG1, G1);

            // downconvert and write data
            to_src(sg_addr(1), dG1, scratch_data_t, vlen);
            to_src(ptr[addr_scratch_cell_reg], hG1, scratch_data_t, vlen);
            uni_vmovups(ptr[addr_diff_states_t_l_reg], dH);

            // increment address pointers
            // (f32 buffers advance by vlen, scratch-typed ones by
            // vlen_scratch)
            add(addr_ws_gates_reg, vlen_scratch);
            add(addr_scratch_gates_reg, vlen_scratch);
            add(addr_dhG1_reg, vlen);
            add(addr_diff_states_t_l_reg, vlen);
            add(addr_states_tm1_l_reg, vlen_scratch);
            add(addr_scratch_cell_reg, vlen_scratch);
            inc_regs(vlen);

            // increment loop counter
            sub(loop_cnt, vlen_scratch);
            cmp(loop_cnt, vlen_scratch);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same code as above, we just use movss for accessing inputs
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            // scalar tail: same math as the vector loop, one element at a
            // time using the low lane of Xmm registers
            const Xmm dG1(dG1_idx), dhG1(dhG1_idx), hG1(hG1_idx), G1(G1_idx),
                    dH(dH_idx), tmp1(tmp1_idx), h(h_idx);

            to_float(G1, wg_addr(1), src_data_t, hstate_dt_size);
            to_float(h, ptr[addr_states_tm1_l_reg], src_data_t, hstate_dt_size);

            // compute dG1 = dhG1 * h * (G1 - G1^2)
            uni_vmovss(dG1, G1);
            uni_vmovss(tmp1, G1);
            uni_vfnmadd231ps(dG1, tmp1, tmp1); // (G1 - G1^2)
            uni_vmulss(dG1, dG1, h);
            uni_vmovss(dhG1, ptr[addr_dhG1_reg]);
            uni_vmulss(dG1, dG1, dhG1); // dhG1 * h * (G1 - G1^2)

            // compute hG1 = h * G1
            uni_vmovss(hG1, G1);
            uni_vmulss(hG1, hG1, h);

            // compute diff_states_t_l = diff_states_t_l + dhG1 * G1
            uni_vmovss(dH, ptr[addr_diff_states_t_l_reg]);
            uni_vfmadd231ps(dH, dhG1, G1);

            // downconvert and write data
            to_src(sg_addr(1), dG1, scratch_data_t, hstate_dt_size);
            to_src(ptr[addr_scratch_cell_reg], hG1, scratch_data_t,
                    hstate_dt_size);
            uni_vmovss(ptr[addr_diff_states_t_l_reg], dH);

            // increment address pointers
            add(addr_ws_gates_reg, scratch_dt_size);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_dhG1_reg, hstate_dt_size);
            add(addr_diff_states_t_l_reg, hstate_dt_size);
            add(addr_states_tm1_l_reg, scratch_dt_size);
            add(addr_scratch_cell_reg, scratch_dt_size);
            inc_regs(hstate_dt_size);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();

        // emit the constant table used by the conversion helpers after the
        // function body
        init_table(vlen);
    }
};
222
223} // namespace x64
224} // namespace cpu
225} // namespace impl
226} // namespace dnnl
227
228#endif
229