/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_RNN_RNN_UTILS_HPP
#define GPU_OCL_RNN_RNN_UTILS_HPP

#include "oneapi/dnnl/dnnl_types.h"

#include "common/c_types_map.hpp"
#include "common/memory_desc_wrapper.hpp"

#define OFF6(i0, d0, i1, d1, i2, d2, i3, d3, i4, d4, i5, d5) \
    ((((((i0) * (d1) + (i1)) * (d2) + (i2)) * (d3) + (i3)) * (d4) + (i4)) \
                    * (d5) \
            + (i5))
#define OFF5(i0, d0, i1, d1, i2, d2, i3, d3, i4, d4) \
    (((((i0) * (d1) + (i1)) * (d2) + (i2)) * (d3) + (i3)) * (d4) + (i4))
#define OFF4(i0, d0, i1, d1, i2, d2, i3, d3) \
    ((((i0) * (d1) + (i1)) * (d2) + (i2)) * (d3) + (i3))
#define OFF3(i0, d0, i1, d1, i2, d2) (((i0) * (d1) + (i1)) * (d2) + (i2))
#define OFF2(i0, d0, i1, d1) ((i0) * (d1) + (i1))
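
// The OFF<N> macros linearize an N-dimensional index into a flat, row-major
// offset; the d0 argument is unused and is kept only so call sites can list
// every (index, dimension) pair. For example, for a tensor of logical shape
// [D0][D1][D2]:
//
//     OFF3(i, D0, j, D1, k, D2) == ((i) * (D1) + (j)) * (D2) + (k)
//
// i.e. the offset of element (i, j, k) when the last dimension is densest.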

#define elemwise_sig(f) \
    void f(const exec_ctx_t &ctx, int dir, int lay, int iter, int dhc, \
            int batch, const memory_storage_t &workspace, \
            const memory_storage_t &scratch_gates, \
            const memory_storage_t &scratch_diff_states, \
            const memory_storage_t *scales, const memory_storage_t &bias, \
            const memory_storage_t *tm_scales) const

#define elemwise_sig_gru_lbr(f) \
    void f(const exec_ctx_t &ctx, int dir, int lay, int iter, int dhc, \
            int batch, const memory_storage_t &workspace, \
            const memory_storage_t &scratch_gates, \
            const memory_storage_t &scratch_cell, \
            const memory_storage_t &scratch_diff_states, \
            const memory_storage_t &bias, const memory_storage_t *tm_scales) \
            const

#define elemwise_sig_gru(f) \
    void f(const exec_ctx_t &ctx, int dir, int lay, int iter, int dhc, \
            int batch, const memory_storage_t &workspace, \
            const memory_storage_t &scratch_gates, \
            const memory_storage_t &scratch_cell, \
            const memory_storage_t &scratch_diff_states, \
            const memory_storage_t &scratch_dhG1, \
            const memory_storage_t &bias, const memory_storage_t *tm_scales, \
            int part) const

#define cell_execution_sig(f) \
    void f(engine_t *engine, const exec_ctx_t &ctx, int dir, int lay, \
            int iter, size_t *wei_layer_offset, size_t *wei_iter_offset, \
            const memory_storage_t &bias, const memory_storage_t &workspace, \
            const memory_storage_t &scratch_gates, \
            const memory_storage_t &scratch_cell, \
            const memory_storage_t &scratch_diff_states, \
            const memory_storage_t &scratch_dhG1, \
            const memory_storage_t &wei_layer, \
            const memory_storage_t &wei_iter, \
            const memory_storage_t &diff_weights_layer, \
            const memory_storage_t &diff_weights_iter, \
            const memory_storage_t &diff_bias, const memory_storage_t *scales, \
            const memory_storage_t *tm_scales) const

#define grid_execution_sig(f) \
    void f(engine_t *engine, const exec_ctx_t &ctx, \
            const memory_storage_t &bias, const memory_storage_t &workspace, \
            const memory_storage_t &scratch_gates, \
            const memory_storage_t &scratch_cell, \
            const memory_storage_t &scratch_diff_states, \
            const memory_storage_t &scratch_dhG1, \
            const memory_storage_t &wei_layer, \
            const memory_storage_t &wei_iter, \
            const memory_storage_t &diff_weights_layer, \
            const memory_storage_t &diff_weights_iter, \
            const memory_storage_t &diff_bias, const memory_storage_t *scales, \
            const memory_storage_t *tm_scales) const

#define gemm_sig(f) \
    void f(engine_t *engine, const exec_ctx_t &ctx, const memory_storage_t &a, \
            size_t off_a, const memory_storage_t &b, size_t off_b, \
            const memory_storage_t &c, size_t off_c, gemm_kind_t gemm_kind) \
            const

#define weights_assign_sig(f) \
    void f(const rnn_utils::conf_t &rnn, const memory_desc_t *md, \
            size_t *weights_, int n_parts, const int *gates_per_part, \
            const memory_storage_t &w_, int ld, int nld, data_type_t wei_t) \
            const
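
// These *_sig macros expand to complete member-function signatures so that
// the declaration in a primitive class and the definition in its .cpp file
// stay in sync. A minimal sketch of the intended usage (class and function
// names below are illustrative only):
//
//     struct my_rnn_fwd_t : public gpu_primitive_t {
//         elemwise_sig(rnn_elemwise);
//         gemm_sig(gemm_primitive);
//         // ...
//     };
//
//     // in the .cpp
//     elemwise_sig(my_rnn_fwd_t::rnn_elemwise) {
//         // launch the elementwise kernel for this (lay, dir, iter) cell
//     }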

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

namespace rnn_utils {

enum execution_direction_t {
    l2r,
    r2l,
    bi_concat,
    bi_sum,
};

enum data_type_conf_t {
    all_f32,
    all_f16,
    all_bf16,
    u8u8u8f32,
    f32u8f32f32,
    u8u8u8u8,
    f32u8f32u8
};

enum ws_part_t {
    gates,
    states,
    c_states,
    diff_states,
    dhG1_gru,
    cell,
    grid,
    bias
};

struct conf_t {
    execution_direction_t exec_dir;
    data_type_conf_t dt_conf;
    int n_layer, n_iter, n_dir, n_gates, n_states;
    int mb;
    int slc, sic, dhc, dlc;

    int gates_ld, gates_nld, gates_ws_ld, arch_ld;

    int n_parts_weights_layer, parts_weights_layer[DNNL_RNN_MAX_N_PARTS];
    int n_parts_weights_iter, parts_weights_iter[DNNL_RNN_MAX_N_PARTS];
    int n_bias, n_parts_bias, parts_bias[DNNL_RNN_MAX_N_PARTS];

    size_t part_weights_iter_pack_size[DNNL_RNN_MAX_N_PARTS],
            part_weights_layer_pack_size[DNNL_RNN_MAX_N_PARTS];

    // Size of packed data in bytes
    size_t weights_layer_comp_offset, weights_layer_pack_size,
            weights_iter_comp_offset, weights_iter_pack_size;

    bool copy_bias;
    int weights_layer_ld, weights_layer_nld;
    int diff_weights_layer_ld, diff_weights_layer_nld;
    int weights_iter_ld, weights_iter_nld;
    int diff_weights_iter_ld, diff_weights_iter_nld;
    int states_nld, states_ws_ld, scratch_diff_states_ld;
    int weights_iter_compensation_size, weights_layer_compensation_size;
    bool is_fwd, is_training, is_lbr, is_int8, is_testmode, is_vanilla_gru;
    bool use_workspace;

    // For test mode (--skip-nonlinear=true of benchdnn)
    float tm_cscale;
    int tm_ngates;

    // Size of workspace for each tensor in bytes
    size_t ws_gates_size, ws_states_size, ws_c_states_size,
            scratch_diff_states_size, scratch_cell_size, scratch_dhG1_size,
            ws_grid_comp_size, ws_per_cell, ws_bias_size;

    bool merge_gemm_iter, merge_gemm_layer, use_gemm, use_layer_packed_gemm,
            use_iter_packed_gemm;

    // Element size of each workspace part in bytes
    int ws_gates_elsz, ws_states_elsz, ws_grid_comp_elsz, ws_bias_elsz;

    size_t scratch_gates_size;
    int n_iter_scratch_gates;
    int scratch_gates_elsz, scratch_gates_ld;

    data_type_t acc_data_type;
    int acc_data_type_elsz;
    data_type_t aux_data_type;
    data_type_t input_data_type;
    data_type_t output_data_type;
    data_type_t dst_data_type;
    data_type_t diff_data_type;
};
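
// How the size and element-size fields above typically relate (an
// illustrative sketch only; set_rnn_conf() below holds the authoritative
// formulas):
//
//     rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter
//             * rnn.mb * rnn.gates_ws_ld * rnn.ws_gates_elsz;
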
bool is_ldigo(const memory_desc_wrapper &md);
bool is_ldgoi(const memory_desc_wrapper &md);

int get_good_ld(int arch_ld, int dim, int sizeof_dt);
void init_rnn_conf(conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &src_layer_d,
        const memory_desc_wrapper &src_iter_d,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &dst_layer_d, bool is_xe_hpc);
void init_test_mode(conf_t &rnn, const primitive_attr_t &attr);
void set_rnn_conf(conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &diff_weights_layer_d,
        const memory_desc_wrapper &diff_weights_iter_d);
void set_offsets(const conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_h_state_offset, size_t &ws_c_state_offset,
        size_t &ws_grid_comp_offset, size_t &ws_bias_offset,
        size_t &scratch_diff_states_offset, size_t &scratch_cell_offset,
        size_t &scratch_dhG1_offset, size_t &scratch_gates_offset,
        size_t &scratchpad_size, size_t &workspace_size);
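// Typical call pattern for set_offsets() (a sketch; variable names are
// illustrative only): every workspace/scratchpad part offset and the two
// total sizes are returned through the output references in a single call.
//
//     size_t ws_gates_off, ws_h_state_off, ws_c_state_off, ws_grid_comp_off,
//             ws_bias_off, scr_diff_states_off, scr_cell_off, scr_dhG1_off,
//             scr_gates_off, scratchpad_sz, workspace_sz;
//     set_offsets(rnn, ws_gates_off, ws_h_state_off, ws_c_state_off,
//             ws_grid_comp_off, ws_bias_off, scr_diff_states_off,
//             scr_cell_off, scr_dhG1_off, scr_gates_off, scratchpad_sz,
//             workspace_sz);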
void set_gru_offsets_part2(const conf_t &rnn, int iter, int dir, int lay,
        data_type_t src_t, size_t *wei_iter_off_ptr,
        const size_t &ws_states_offset_, size_t &cell_wei_iter_offset,
        size_t &cell_scratch_offset, size_t &cell_ws_iter_offset);
void set_offsets_fwd_gemm(const conf_t &rnn, int dir, int lay,
        data_type_t src_t, size_t *wei_layer_off_ptr,
        const size_t &ws_states_offset_, size_t &grid_ws_lay_offset,
        size_t &grid_wei_lay_offset, size_t &grid_ws_iter_offset);
void set_offsets_fwd_gemm(const conf_t &rnn, int iter, int dir, int lay,
        data_type_t src_t, size_t *wei_iter_off_ptr,
        const size_t &ws_states_offset_, size_t &cell_ws_iter_offset,
        size_t &cell_ws_lay_offset, size_t &cell_scratch_offset,
        size_t &cell_wei_iter_offset);
void set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir, int lay,
        size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off,
        size_t &cell_scr_diff_lay_off, size_t &cell_scr_diff_iter_off);
void set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir, int lay,
        size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off,
        size_t &cell_scr_diff_lay_off, size_t &cell_scr_diff_iter_off,
        size_t &cell_diff_wei_iter_off2);
void set_offsets_bwd_gemm(const conf_t &rnn, int iter, int dir, int lay,
        size_t &cell_diff_wei_iter_off, size_t &cell_diff_wei_lay_off,
        size_t &cell_scr_diff_lay_off);
void get_scratchpad_and_workspace_sizes(
        const conf_t &rnn, size_t &scratchpad_size, size_t &workspace_size);
status_t set_expected_desc(
        conf_t &rnn, memory_desc_t &weights_md, bool is_iter);
status_t set_good_strides(int ld_, memory_desc_t &weights_md, format_tag_t tag);
} // namespace rnn_utils

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif