1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_RNN_REF_RNN_HPP |
18 | #define GPU_OCL_RNN_REF_RNN_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <stdio.h> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/primitive.hpp" |
25 | #include "common/primitive_desc_iterator.hpp" |
26 | #include "common/type_helpers.hpp" |
27 | #include "common/utils.hpp" |
28 | #include "gpu/compute/compute.hpp" |
29 | #include "gpu/gemm/gpu_gemm.hpp" |
30 | #include "gpu/gpu_primitive.hpp" |
31 | #include "gpu/gpu_resource.hpp" |
32 | #include "gpu/gpu_rnn_pd.hpp" |
33 | #include "gpu/ocl/ocl_memory_storage.hpp" |
34 | #include "gpu/ocl/ocl_stream.hpp" |
35 | #include "gpu/ocl/ocl_utils.hpp" |
36 | #include "gpu/ocl/rnn/rnn_utils.hpp" |
37 | #include "gpu/primitive_conf.hpp" |
38 | |
// TODO: debugging aid only (see WS_NAN_FILLING below); remove when no longer needed
40 | #define WS_NAN_FILLING 0 |
41 | |
42 | #define DEBUGPRINT 0 |
43 | |
44 | namespace dnnl { |
45 | namespace impl { |
46 | namespace gpu { |
47 | namespace ocl { |
48 | |
// Identifies which of the RNN GEMM sub-operations a call refers to; used to
// select the matching nested GEMM primitive (and its scratchpad key).
// The "_2" kinds are the extra GEMMs needed by the vanilla-GRU cell.
enum gemm_kind_t {
    gemm_iter_fwd,
    gemm_iter_fwd_2,
    gemm_layer_fwd,
    gemm_iter_bwd,
    gemm_iter_bwd_2,
    gemm_layer_bwd,
    gemm_diff_wei_iter,
    gemm_diff_wei_iter_2,
    gemm_diff_wei_layer
};
60 | |
61 | template <prop_kind_t aprop> |
62 | struct _ref_rnn_common_t : public gpu_primitive_t { |
63 | using gpu_primitive_t::gpu_primitive_t; |
64 | |
65 | using class_name = _ref_rnn_common_t<aprop>; |
66 | |
67 | typedef elemwise_sig((class_name::*elemwise_f)); |
68 | typedef elemwise_sig_gru((class_name::*elemwise_gru_f)); |
69 | typedef elemwise_sig_gru_lbr((class_name::*elemwise_gru_lbr_f)); |
70 | typedef cell_execution_sig((class_name::*cell_execution_f)); |
71 | typedef grid_execution_sig((class_name::*grid_execution_f)); |
72 | typedef gemm_sig((class_name::*gemm_t)); |
73 | typedef weights_assign_sig((class_name::*weights_assign_t)); |
74 | |
75 | using base_pd_t = |
76 | typename utils::conditional<false || aprop == prop_kind::forward, |
77 | gpu_rnn_fwd_pd_t, gpu_rnn_bwd_pd_t>::type; |
78 | enum { |
79 | key_gemm_iter_fwd = memory_tracking::names::key_nested_multiple, |
80 | key_gemm_iter_fwd_2, |
81 | key_gemm_layer_fwd, |
82 | key_gemm_iter_bwd, |
83 | key_gemm_iter_bwd_2, |
84 | key_gemm_layer_bwd, |
85 | key_gemm_diff_wei_layer, |
86 | key_gemm_diff_wei_iter, |
87 | key_gemm_diff_wei_iter_2, |
88 | }; |
89 | |
90 | struct pd_t : public base_pd_t { |
91 | |
92 | using base_pd_t::base_pd_t; |
93 | |
94 | pd_t(const pd_t &other) = default; |
95 | |
96 | DECLARE_COMMON_PD_T("ref:any" , class_name); |
97 | |
98 | status_t init(engine_t *engine); |
99 | |
100 | status_t set_default_params(); |
101 | |
102 | rnn_conf_t conf; |
103 | rnn_offsets_t off; |
104 | rnn_utils::conf_t rnn_conf; |
105 | data_type_t acc_data_t; |
106 | data_type_t src_type; |
107 | data_type_t weights_type; |
108 | bool is_xe_hpc; |
109 | int subgroup_size; |
110 | int max_eus_per_wg; |
111 | bool use_subgroup_reduction; |
112 | |
113 | std::shared_ptr<primitive_desc_t> gemm_iter_fwd_pd_; |
114 | std::shared_ptr<primitive_desc_t> gemm_iter_fwd_2_pd_; |
115 | std::shared_ptr<primitive_desc_t> gemm_layer_fwd_pd_; |
116 | std::shared_ptr<primitive_desc_t> gemm_iter_bwd_pd_; |
117 | std::shared_ptr<primitive_desc_t> gemm_iter_bwd_2_pd_; |
118 | std::shared_ptr<primitive_desc_t> gemm_layer_bwd_pd_; |
119 | std::shared_ptr<primitive_desc_t> gemm_diff_wei_layer_pd_; |
120 | std::shared_ptr<primitive_desc_t> gemm_diff_wei_iter_pd_; |
121 | std::shared_ptr<primitive_desc_t> gemm_diff_wei_iter_2_pd_; |
122 | |
123 | private: |
124 | void init_scratchpad(size_t scratchpad_sz) { |
125 | using namespace memory_tracking::names; |
126 | auto scratchpad = this->scratchpad_registry().registrar(); |
127 | scratchpad.book(key_rnn_space, scratchpad_sz, 1, |
128 | OCL_BUFFER_ALIGNMENT, 4096); |
129 | scratchpad.book(key_rnn_gates, rnn_conf.scratch_gates_size, 1, |
130 | OCL_BUFFER_ALIGNMENT, 4096); |
131 | scratchpad.book(key_rnn_cell, rnn_conf.scratch_cell_size, 1, |
132 | OCL_BUFFER_ALIGNMENT, 4096); |
133 | scratchpad.book(key_rnn_diff_states, |
134 | rnn_conf.scratch_diff_states_size, 1, OCL_BUFFER_ALIGNMENT, |
135 | 4096); |
136 | scratchpad.book(key_rnn_diff_ht, rnn_conf.scratch_dhG1_size, 1, |
137 | OCL_BUFFER_ALIGNMENT, 4096); |
138 | // book scratchpad for nested primitives |
139 | switch (aprop) { |
140 | case prop_kind::forward: |
141 | scratchpad.book(key_gemm_iter_fwd, |
142 | gemm_iter_fwd_pd_->scratchpad_registry()); |
143 | scratchpad.book(key_gemm_layer_fwd, |
144 | gemm_layer_fwd_pd_->scratchpad_registry()); |
145 | if (conf.is_vanilla_gru) |
146 | scratchpad.book(key_gemm_iter_fwd_2, |
147 | gemm_iter_fwd_2_pd_->scratchpad_registry()); |
148 | break; |
149 | case prop_kind::backward: |
150 | scratchpad.book(key_gemm_iter_bwd, |
151 | gemm_iter_bwd_pd_->scratchpad_registry()); |
152 | scratchpad.book(key_gemm_layer_bwd, |
153 | gemm_layer_bwd_pd_->scratchpad_registry()); |
154 | scratchpad.book(key_gemm_diff_wei_layer, |
155 | gemm_diff_wei_layer_pd_->scratchpad_registry()); |
156 | scratchpad.book(key_gemm_diff_wei_iter, |
157 | gemm_diff_wei_iter_pd_->scratchpad_registry()); |
158 | if (conf.is_vanilla_gru) { |
159 | scratchpad.book(key_gemm_iter_bwd_2, |
160 | gemm_iter_bwd_2_pd_->scratchpad_registry()); |
161 | scratchpad.book(key_gemm_diff_wei_iter_2, |
162 | gemm_diff_wei_iter_2_pd_ |
163 | ->scratchpad_registry()); |
164 | } |
165 | break; |
166 | default: assert(!"unknown prop_kind" ); |
167 | } |
168 | } |
169 | }; // struct pd_t : public base_pd_t |
170 | |
171 | status_t init(engine_t *engine) override; |
172 | |
173 | ~_ref_rnn_common_t() { |
174 | free(wei_layer_offset_ptr); |
175 | free(wei_iter_offset_ptr); |
176 | } |
177 | |
178 | status_t execute(const exec_ctx_t &ctx) const override { |
179 | return execute_(ctx); |
180 | } |
181 | |
182 | protected: |
183 | status_t init_res_storage( |
184 | engine_t *engine, gpu_resource_t *r) const override; |
185 | |
186 | private: |
187 | status_t execute_(const exec_ctx_t &ctx) const; |
188 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
189 | |
190 | compute::nd_range_t get_nd_range(std::vector<int> gws) const { |
191 | // Try to schedule one local thread per eu |
192 | int subgroup_size = pd()->subgroup_size; |
193 | int lws_max = pd()->max_eus_per_wg * subgroup_size; |
194 | std::vector<int> lws; |
195 | lws.reserve(gws.size()); |
196 | for (int i = 0; i < (int)gws.size(); i++) { |
197 | int l_dim = 2 * gws[i] < lws_max ? utils::rnd_up_pow2(gws[i]) |
198 | : lws_max; |
199 | if (i == 0 && l_dim < subgroup_size) l_dim = subgroup_size; |
200 | lws.emplace_back(l_dim); |
201 | gws[i] = utils::rnd_up(gws[i], l_dim); |
202 | lws_max = lws_max / l_dim; |
203 | } |
204 | |
205 | return compute::nd_range_t(gws, lws); |
206 | } |
207 | |
208 | // set the class names |
209 | grid_execution_sig(linear_execution); |
210 | |
211 | cell_execution_sig(cell_execution); |
212 | cell_execution_sig(cell_execution_gru); |
213 | cell_execution_sig(cell_execution_gru_lbr); |
214 | |
215 | elemwise_sig(rnn_elemwise); |
216 | elemwise_sig(lstm_elemwise); |
217 | elemwise_sig(lstm_elemwise_u8s8); |
218 | elemwise_sig_gru(gru_elemwise); |
219 | elemwise_sig_gru_lbr(gru_lbr_elemwise); |
220 | |
221 | gemm_sig(gemm_primitive); |
222 | |
223 | weights_assign_sig(assign_weights); |
224 | |
225 | float (*activation_func)(float dd, float s, float alpha, float cliping); |
226 | void bias_prepare(const exec_ctx_t &ctx, |
227 | compute::compute_stream_t *compute_stream, int n_layer, int n_dir, |
228 | int n_bias, int n_gates, int dhc, const memory_storage_t &ws, |
229 | const memory_storage_t &scales, const memory_storage_t &wei_layer, |
230 | const memory_storage_t &wei_iter, |
231 | const memory_storage_t &bias) const; |
232 | void copy_init_layer(const exec_ctx_t &ctx, |
233 | compute::compute_stream_t *compute_stream, bool lr, bool rl, |
234 | int n_iter, int batch, int slc, const memory_storage_t &ws, |
235 | const memory_storage_t &scratch_diff_states, |
236 | const memory_storage_t &input, |
237 | const memory_storage_t &diff_dst_layer) const; |
238 | void copy_init_iter(const exec_ctx_t &ctx, |
239 | compute::compute_stream_t *compute_stream, int n_layer, int n_dir, |
240 | int batch, int sic, int dhc, const memory_storage_t &ws, |
241 | const memory_storage_t &scratch_diff_states, |
242 | const memory_storage_t &firstit_states, |
243 | const memory_storage_t &firstit_c_states, |
244 | const memory_storage_t &diff_dst_iter, |
245 | const memory_storage_t &diff_dst_iter_c, const float shift, |
246 | const float scale, const bool quantize) const; |
247 | void copy_res_layer(const exec_ctx_t &ctx, |
248 | compute::compute_stream_t *compute_stream, bool lr, bool rl, |
249 | int n_iter, int batch, int slc, int dlc, |
250 | const memory_storage_t &scratch_diff_states, |
251 | const memory_storage_t &dst_last_layer, |
252 | const memory_storage_t &diff_src_layer, const memory_storage_t &ws, |
253 | const float shift, const float scale, const bool dequantize) const; |
254 | void copy_res_iter(const exec_ctx_t &ctx, |
255 | compute::compute_stream_t *compute_stream, int n_layer, int n_dir, |
256 | int batch, int sic, int dhc, |
257 | const memory_storage_t &scratch_diff_states, |
258 | const memory_storage_t &dst_last_iter, |
259 | const memory_storage_t &dst_last_iter_c, |
260 | const memory_storage_t &diff_src_iter, |
261 | const memory_storage_t &diff_src_iter_c, const memory_storage_t &ws, |
262 | const float shift, const float scale, const bool dequantize) const; |
263 | void gates_reduction(const exec_ctx_t &ctx, int dir, int lay, int iter, |
264 | int n_gates, int dhc, int batch, const memory_storage_t &gates, |
265 | const memory_storage_t &cell, |
266 | const memory_storage_t &diff_bias) const; |
267 | void ws_set(const exec_ctx_t &ctx, |
268 | compute::compute_stream_t *compute_stream, |
269 | const memory_storage_t &workspace, const cl_ulong ws_offset, |
270 | const int ws_part, const float val, const size_t size) const; |
271 | #if DEBUGPRINT |
272 | void ws_print(const exec_ctx_t &ctx, compute::compute_stream_t *s, |
273 | const memory_storage_t &workspace) const; |
274 | compute::kernel_t ws_print_kernel_; |
275 | #endif |
276 | |
277 | compute::kernel_t bias_prepare_kernel_; |
278 | compute::kernel_t copy_init_layer_kernel_; |
279 | compute::kernel_t copy_init_iter_kernel_; |
280 | compute::kernel_t copy_res_layer_kernel_; |
281 | compute::kernel_t copy_res_iter_kernel_; |
282 | |
283 | compute::kernel_t ws_set_kernel_; |
284 | compute::kernel_t elemwise_fwd_kernel_; |
285 | compute::kernel_t elemwise_bwd_kernel_; |
286 | compute::kernel_t gates_reduction_kernel_; |
287 | |
288 | // ptrs to GEMM primitives |
289 | std::shared_ptr<primitive_t> gemm_layer_fwd_; |
290 | std::shared_ptr<primitive_t> gemm_iter_fwd_; |
291 | std::shared_ptr<primitive_t> gemm_iter_fwd_2_; |
292 | std::shared_ptr<primitive_t> gemm_layer_bwd_; |
293 | std::shared_ptr<primitive_t> gemm_iter_bwd_; |
294 | std::shared_ptr<primitive_t> gemm_iter_bwd_2_; |
295 | std::shared_ptr<primitive_t> gemm_diff_wei_layer_; |
296 | std::shared_ptr<primitive_t> gemm_diff_wei_iter_; |
297 | std::shared_ptr<primitive_t> gemm_diff_wei_iter_2_; |
298 | |
299 | // offset variables set in workspace and used in offset calculations for |
300 | // grid & cell execution and fwd & bwd kernel macros |
301 | cl_ulong ws_gates_offset_; |
302 | cl_ulong ws_states_offset_; |
303 | cl_ulong ws_c_states_offset_; |
304 | cl_ulong ws_grid_comp_offset_; |
305 | cl_ulong ws_bias_offset_; |
306 | cl_ulong scratch_dhG1_offset_; |
307 | cl_ulong scratch_cell_offset_; |
308 | cl_ulong scratch_gates_offset_; |
309 | cl_ulong scratch_diff_states_offset_; |
310 | |
311 | // ptrs for storing weight offsets which are pre-calculated in |
312 | // in grid execution as weights_*_assing_func |
313 | size_t *wei_layer_offset_ptr; |
314 | size_t *wei_iter_offset_ptr; |
315 | |
316 | grid_execution_f grid_computation; |
317 | cell_execution_f cell_func; |
318 | |
319 | weights_assign_t weights_layer_assign_func; |
320 | weights_assign_t weights_iter_assign_func; |
321 | |
322 | gemm_t gemm_iter_func; |
323 | gemm_t gemm_layer_func; |
324 | elemwise_f elemwise_common; |
325 | elemwise_gru_f elemwise_gru; |
326 | elemwise_gru_lbr_f elemwise_gru_lbr; |
327 | |
328 | enum { SCALES_ = 0, TM_SCALES_ = 1 }; |
329 | }; |
330 | using ref_rnn_fwd_t = _ref_rnn_common_t<prop_kind::forward>; |
331 | using ref_rnn_bwd_t = _ref_rnn_common_t<prop_kind::backward>; |
332 | } // namespace ocl |
333 | } // namespace gpu |
334 | } // namespace impl |
335 | } // namespace dnnl |
336 | #endif |
337 | |
338 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
339 | |