1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_RNN_RNN_UTILS_HPP |
18 | #define GPU_OCL_RNN_RNN_UTILS_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl_types.h" |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/memory_desc_wrapper.hpp" |
24 | |
25 | #define OFF6(i0, d0, i1, d1, i2, d2, i3, d3, i4, d4, i5, d5) \ |
26 | ((((((i0) * (d1) + (i1)) * (d2) + (i2)) * (d3) + (i3)) * (d4) + (i4)) \ |
27 | * (d5) \ |
28 | + (i5)) |
29 | #define OFF5(i0, d0, i1, d1, i2, d2, i3, d3, i4, d4) \ |
30 | (((((i0) * (d1) + (i1)) * (d2) + (i2)) * (d3) + (i3)) * (d4) + (i4)) |
31 | #define OFF4(i0, d0, i1, d1, i2, d2, i3, d3) \ |
32 | ((((i0) * (d1) + (i1)) * (d2) + (i2)) * (d3) + (i3)) |
33 | #define OFF3(i0, d0, i1, d1, i2, d2) (((i0) * (d1) + (i1)) * (d2) + (i2)) |
34 | #define OFF2(i0, d0, i1, d1) ((i0) * (d1) + (i1)) |
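// Note: the OFFn macros above compute a row-major linear offset; the leading
// dimension d0 is accepted only for call-site symmetry and is otherwise
// unused. For example, OFF3(i0, d0, i1, d1, i2, d2) expands to
//     (((i0) * (d1) + (i1)) * (d2) + (i2))
// which is the offset of element (i0, i1, i2) in a d0 x d1 x d2 tensor.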
35 | |
36 | #define elemwise_sig(f) \ |
37 | void f(const exec_ctx_t &ctx, int dir, int lay, int iter, int dhc, \ |
38 | int batch, const memory_storage_t &workspace, \ |
39 | const memory_storage_t &scratch_gates, \ |
40 | const memory_storage_t &scratch_diff_states, \ |
41 | const memory_storage_t *scales, const memory_storage_t &bias, \ |
42 | const memory_storage_t *tm_scales) const |
43 | |
44 | #define elemwise_sig_gru_lbr(f) \ |
45 | void f(const exec_ctx_t &ctx, int dir, int lay, int iter, int dhc, \ |
46 | int batch, const memory_storage_t &workspace, \ |
47 | const memory_storage_t &scratch_gates, \ |
48 | const memory_storage_t &scratch_cell, \ |
49 | const memory_storage_t &scratch_diff_states, \ |
50 | const memory_storage_t &bias, const memory_storage_t *tm_scales) \ |
51 | const |
52 | |
53 | #define elemwise_sig_gru(f) \ |
54 | void f(const exec_ctx_t &ctx, int dir, int lay, int iter, int dhc, \ |
55 | int batch, const memory_storage_t &workspace, \ |
56 | const memory_storage_t &scratch_gates, \ |
57 | const memory_storage_t &scratch_cell, \ |
58 | const memory_storage_t &scratch_diff_states, \ |
59 | const memory_storage_t &scratch_dhG1, \ |
60 | const memory_storage_t &bias, const memory_storage_t *tm_scales, \ |
61 | int part) const |
62 | |
63 | #define cell_execution_sig(f) \ |
64 | void f(engine_t *engine, const exec_ctx_t &ctx, int dir, int lay, \ |
65 | int iter, size_t *wei_layer_offset, size_t *wei_iter_offset, \ |
66 | const memory_storage_t &bias, const memory_storage_t &workspace, \ |
67 | const memory_storage_t &scratch_gates, \ |
68 | const memory_storage_t &scratch_cell, \ |
69 | const memory_storage_t &scratch_diff_states, \ |
70 | const memory_storage_t &scratch_dhG1, \ |
71 | const memory_storage_t &wei_layer, \ |
72 | const memory_storage_t &wei_iter, \ |
73 | const memory_storage_t &diff_weights_layer, \ |
74 | const memory_storage_t &diff_weights_iter, \ |
75 | const memory_storage_t &diff_bias, const memory_storage_t *scales, \ |
76 | const memory_storage_t *tm_scales) const |
77 | |
78 | #define grid_execution_sig(f) \ |
79 | void f(engine_t *engine, const exec_ctx_t &ctx, \ |
80 | const memory_storage_t &bias, const memory_storage_t &workspace, \ |
81 | const memory_storage_t &scratch_gates, \ |
82 | const memory_storage_t &scratch_cell, \ |
83 | const memory_storage_t &scratch_diff_states, \ |
84 | const memory_storage_t &scratch_dhG1, \ |
85 | const memory_storage_t &wei_layer, \ |
86 | const memory_storage_t &wei_iter, \ |
87 | const memory_storage_t &diff_weights_layer, \ |
88 | const memory_storage_t &diff_weights_iter, \ |
89 | const memory_storage_t &diff_bias, const memory_storage_t *scales, \ |
90 | const memory_storage_t *tm_scales) const |
91 | |
92 | #define gemm_sig(f) \ |
93 | void f(engine_t *engine, const exec_ctx_t &ctx, const memory_storage_t &a, \ |
94 | size_t off_a, const memory_storage_t &b, size_t off_b, \ |
95 | const memory_storage_t &c, size_t off_c, gemm_kind_t gemm_kind) \ |
96 | const |
97 | |
98 | #define weights_assign_sig(f) \ |
99 | void f(const rnn_utils::conf_t &rnn, const memory_desc_t *md, \ |
100 | size_t *weights_, int n_parts, const int *gates_per_part, \ |
101 | const memory_storage_t &w_, int ld, int nld, data_type_t wei_t) \ |
102 | const |
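// The *_sig macros above expand to the full parameter lists of the member
// functions implementing the element-wise, cell-execution, grid-execution,
// GEMM, and weights-assignment stages, so declarations and definitions stay
// in sync across RNN implementations. Illustrative use (the names below are
// hypothetical and not declared in this header):
//
//     elemwise_sig(rnn_elemwise); // declaration inside a class body
//     elemwise_sig(some_rnn_impl_t::rnn_elemwise) { /* dispatch kernel */ }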
103 | |
104 | namespace dnnl { |
105 | namespace impl { |
106 | namespace gpu { |
107 | namespace ocl { |
108 | |
109 | namespace rnn_utils { |
110 | |
111 | enum execution_direction_t { |
112 | l2r, |
113 | r2l, |
114 | bi_concat, |
115 | bi_sum, |
116 | }; |
117 | |
118 | enum data_type_conf_t { |
119 | all_f32, |
120 | all_f16, |
121 | all_bf16, |
122 | u8u8u8f32, |
123 | f32u8f32f32, |
124 | u8u8u8u8, |
125 | f32u8f32u8 |
126 | }; |
127 | |
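// Logical parts of the workspace/scratchpad; the byte offset of each part is
// computed by set_offsets() declared below.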
128 | enum ws_part_t { |
129 | gates, |
130 | states, |
131 | c_states, |
132 | diff_states, |
133 | dhG1_gru, |
134 | cell, |
135 | grid, |
136 | bias |
137 | }; |
138 | |
139 | struct conf_t { |
140 | execution_direction_t exec_dir; |
141 | data_type_conf_t dt_conf; |
142 | int n_layer, n_iter, n_dir, n_gates, n_states; |
143 | int mb; |
144 | int slc, sic, dhc, dlc; |
145 | |
146 | int gates_ld, gates_nld, gates_ws_ld, arch_ld; |
147 | |
148 | int n_parts_weights_layer, parts_weights_layer[DNNL_RNN_MAX_N_PARTS]; |
149 | int n_parts_weights_iter, parts_weights_iter[DNNL_RNN_MAX_N_PARTS]; |
150 | int n_bias, n_parts_bias, parts_bias[DNNL_RNN_MAX_N_PARTS]; |
151 | |
152 | size_t part_weights_iter_pack_size[DNNL_RNN_MAX_N_PARTS], |
153 | part_weights_layer_pack_size[DNNL_RNN_MAX_N_PARTS]; |
154 | |
    // Offsets and sizes of packed weights data, in bytes
156 | size_t weights_layer_comp_offset, weights_layer_pack_size, |
157 | weights_iter_comp_offset, weights_iter_pack_size; |
158 | |
159 | bool copy_bias; |
160 | int weights_layer_ld, weights_layer_nld; |
161 | int diff_weights_layer_ld, diff_weights_layer_nld; |
162 | int weights_iter_ld, weights_iter_nld; |
163 | int diff_weights_iter_ld, diff_weights_iter_nld; |
164 | int states_nld, states_ws_ld, scratch_diff_states_ld; |
165 | int weights_iter_compensation_size, weights_layer_compensation_size; |
166 | bool is_fwd, is_training, is_lbr, is_int8, is_testmode, is_vanilla_gru; |
167 | bool use_workspace; |
168 | |
    // For test mode (--skip-nonlinear=true in benchdnn)
170 | float tm_cscale; |
171 | int tm_ngates; |
172 | |
    // Sizes of the workspace and scratch buffers for each tensor, in bytes
174 | size_t ws_gates_size, ws_states_size, ws_c_states_size, |
175 | scratch_diff_states_size, scratch_cell_size, scratch_dhG1_size, |
176 | ws_grid_comp_size, ws_per_cell, ws_bias_size; |
177 | |
178 | bool merge_gemm_iter, merge_gemm_layer, use_gemm, use_layer_packed_gemm, |
179 | use_iter_packed_gemm; |
180 | |
181 | // Element size of each workspace part in bytes |
182 | int ws_gates_elsz, ws_states_elsz, ws_grid_comp_elsz, ws_bias_elsz; |
183 | |
184 | size_t scratch_gates_size; |
185 | int n_iter_scratch_gates; |
186 | int scratch_gates_elsz, scratch_gates_ld; |
187 | |
188 | data_type_t acc_data_type; |
189 | int acc_data_type_elsz; |
190 | data_type_t aux_data_type; |
191 | data_type_t input_data_type; |
192 | data_type_t output_data_type; |
193 | data_type_t dst_data_type; |
194 | data_type_t diff_data_type; |
195 | }; |
196 | bool is_ldigo(const memory_desc_wrapper &md); |
197 | bool is_ldgoi(const memory_desc_wrapper &md); |
198 | |
199 | int get_good_ld(int arch_ld, int dim, int sizeof_dt); |
200 | void init_rnn_conf(conf_t &rnn, const rnn_desc_t &rd, |
201 | const memory_desc_wrapper &src_layer_d, |
202 | const memory_desc_wrapper &src_iter_d, |
203 | const memory_desc_wrapper &weights_layer_d, |
204 | const memory_desc_wrapper &weights_iter_d, |
205 | const memory_desc_wrapper &dst_layer_d, bool is_xe_hpc); |
206 | void init_test_mode(conf_t &rnn, const primitive_attr_t &attr); |
207 | void set_rnn_conf(conf_t &rnn, const rnn_desc_t &rd, |
208 | const memory_desc_wrapper &weights_layer_d, |
209 | const memory_desc_wrapper &weights_iter_d, |
210 | const memory_desc_wrapper &diff_weights_layer_d, |
211 | const memory_desc_wrapper &diff_weights_iter_d); |
212 | void set_offsets(const conf_t &rnn, size_t &ws_gates_offset, |
213 | size_t &ws_h_state_offset, size_t &ws_c_state_offset, |
214 | size_t &ws_grid_comp_offset, size_t &ws_bias_offset, |
215 | size_t &scratch_diff_states_offset, size_t &scratch_cell_offset, |
216 | size_t &scratch_dhG1_offset, size_t &scratch_gates_offset, |
217 | size_t &scratchpad_size, size_t &workspace_size); |
218 | void set_gru_offsets_part2(const conf_t &rnn, int iter, int dir, int lay, |
219 | data_type_t src_t, size_t *wei_iter_off_ptr, |
220 | const size_t &ws_states_offset_, size_t &cell_wei_iter_offset, |
221 | size_t &cell_scratch_offset, size_t &cell_ws_iter_offset); |
222 | void set_offsets_fwd_gemm(const conf_t &rnn, int dir, int lay, |
223 | data_type_t src_t, size_t *wei_layer_off_ptr, |
224 | const size_t &ws_states_offset_, size_t &grid_ws_lay_offset, |
225 | size_t &grid_wei_lay_offset, size_t &grid_ws_iter_offset); |
226 | void set_offsets_fwd_gemm(const conf_t &rnn, int iter, int dir, int lay, |
227 | data_type_t src_t, size_t *wei_iter_off_ptr, |
228 | const size_t &ws_states_offset_, size_t &cell_ws_iter_offset, |
229 | size_t &cell_ws_lay_offset, size_t &cell_scratch_offset, |
230 | size_t &cell_wei_iter_offset); |
231 | void set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir, int lay, |
232 | size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off, |
233 | size_t &cell_scr_diff_lay_off, size_t &cell_scr_diff_iter_off); |
234 | void set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir, int lay, |
235 | size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off, |
236 | size_t &cell_scr_diff_lay_off, size_t &cell_scr_diff_iter_off, |
237 | size_t &cell_diff_wei_iter_off2); |
238 | void set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir, int lay, |
239 | size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off, |
240 | size_t &cell_scr_diff_lay_off); |
241 | void get_scratchpad_and_workspace_sizes( |
242 | const conf_t &rnn, size_t &scratchpad_size, size_t &workspace_size); |
243 | status_t set_expected_desc( |
244 | conf_t &rnn, memory_desc_t &weights_md, bool is_iter); |
245 | status_t set_good_strides(int ld_, memory_desc_t &weights_md, format_tag_t tag); |
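// Illustrative call sequence (the actual driver is the RNN primitive
// implementation): init_rnn_conf() fills the shape and data-type fields,
// init_test_mode() applies benchdnn test-mode overrides, set_rnn_conf()
// derives leading dimensions and buffer sizes, and set_offsets() together
// with get_scratchpad_and_workspace_sizes() computes the final byte layout.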
246 | } // namespace rnn_utils |
247 | |
248 | } // namespace ocl |
249 | } // namespace gpu |
250 | } // namespace impl |
251 | } // namespace dnnl |
252 | |
253 | #endif |
254 | |