/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_RNN_REF_RNN_HPP
#define GPU_OCL_RNN_REF_RNN_HPP

#include <assert.h>
#include <stdio.h>

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/primitive_desc_iterator.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gemm/gpu_gemm.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/gpu_rnn_pd.hpp"
#include "gpu/ocl/ocl_memory_storage.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/ocl/ocl_utils.hpp"
#include "gpu/ocl/rnn/rnn_utils.hpp"
#include "gpu/primitive_conf.hpp"

// TODO: for debugging only
#define WS_NAN_FILLING 0

#define DEBUGPRINT 0

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

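// Distinguishes the nested GEMM primitives used by grid/cell execution. The
// *_2 variants are the extra GEMMs needed only by vanilla GRU (see the
// conf.is_vanilla_gru checks in init_scratchpad()).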
enum gemm_kind_t {
    gemm_iter_fwd,
    gemm_iter_fwd_2,
    gemm_layer_fwd,
    gemm_iter_bwd,
    gemm_iter_bwd_2,
    gemm_layer_bwd,
    gemm_diff_wei_iter,
    gemm_diff_wei_iter_2,
    gemm_diff_wei_layer
};

template <prop_kind_t aprop>
struct _ref_rnn_common_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;

    using class_name = _ref_rnn_common_t<aprop>;

    typedef elemwise_sig((class_name::*elemwise_f));
    typedef elemwise_sig_gru((class_name::*elemwise_gru_f));
    typedef elemwise_sig_gru_lbr((class_name::*elemwise_gru_lbr_f));
    typedef cell_execution_sig((class_name::*cell_execution_f));
    typedef grid_execution_sig((class_name::*grid_execution_f));
    typedef gemm_sig((class_name::*gemm_t));
    typedef weights_assign_sig((class_name::*weights_assign_t));

    using base_pd_t =
            typename utils::conditional<aprop == prop_kind::forward,
                    gpu_rnn_fwd_pd_t, gpu_rnn_bwd_pd_t>::type;
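    // Scratchpad keys for the nested GEMM primitives. Starting the enum at
    // key_nested_multiple gives each nested GEMM its own consecutive key, so
    // their scratchpad registries can be booked side by side in
    // init_scratchpad() below.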
    enum {
        key_gemm_iter_fwd = memory_tracking::names::key_nested_multiple,
        key_gemm_iter_fwd_2,
        key_gemm_layer_fwd,
        key_gemm_iter_bwd,
        key_gemm_iter_bwd_2,
        key_gemm_layer_bwd,
        key_gemm_diff_wei_layer,
        key_gemm_diff_wei_iter,
        key_gemm_diff_wei_iter_2,
    };

    struct pd_t : public base_pd_t {

        using base_pd_t::base_pd_t;

        pd_t(const pd_t &other) = default;

        DECLARE_COMMON_PD_T("ref:any", class_name);

        status_t init(engine_t *engine);

        status_t set_default_params();

        rnn_conf_t conf;
        rnn_offsets_t off;
        rnn_utils::conf_t rnn_conf;
        data_type_t acc_data_t;
        data_type_t src_type;
        data_type_t weights_type;
        bool is_xe_hpc;
        int subgroup_size;
        int max_eus_per_wg;
        bool use_subgroup_reduction;

        std::shared_ptr<primitive_desc_t> gemm_iter_fwd_pd_;
        std::shared_ptr<primitive_desc_t> gemm_iter_fwd_2_pd_;
        std::shared_ptr<primitive_desc_t> gemm_layer_fwd_pd_;
        std::shared_ptr<primitive_desc_t> gemm_iter_bwd_pd_;
        std::shared_ptr<primitive_desc_t> gemm_iter_bwd_2_pd_;
        std::shared_ptr<primitive_desc_t> gemm_layer_bwd_pd_;
        std::shared_ptr<primitive_desc_t> gemm_diff_wei_layer_pd_;
        std::shared_ptr<primitive_desc_t> gemm_diff_wei_iter_pd_;
        std::shared_ptr<primitive_desc_t> gemm_diff_wei_iter_2_pd_;

    private:
        void init_scratchpad(size_t scratchpad_sz) {
            using namespace memory_tracking::names;
            auto scratchpad = this->scratchpad_registry().registrar();
            scratchpad.book(key_rnn_space, scratchpad_sz, 1,
                    OCL_BUFFER_ALIGNMENT, 4096);
            scratchpad.book(key_rnn_gates, rnn_conf.scratch_gates_size, 1,
                    OCL_BUFFER_ALIGNMENT, 4096);
            scratchpad.book(key_rnn_cell, rnn_conf.scratch_cell_size, 1,
                    OCL_BUFFER_ALIGNMENT, 4096);
            scratchpad.book(key_rnn_diff_states,
                    rnn_conf.scratch_diff_states_size, 1, OCL_BUFFER_ALIGNMENT,
                    4096);
            scratchpad.book(key_rnn_diff_ht, rnn_conf.scratch_dhG1_size, 1,
                    OCL_BUFFER_ALIGNMENT, 4096);
            // book scratchpad for nested primitives
            switch (aprop) {
                case prop_kind::forward:
                    scratchpad.book(key_gemm_iter_fwd,
                            gemm_iter_fwd_pd_->scratchpad_registry());
                    scratchpad.book(key_gemm_layer_fwd,
                            gemm_layer_fwd_pd_->scratchpad_registry());
                    if (conf.is_vanilla_gru)
                        scratchpad.book(key_gemm_iter_fwd_2,
                                gemm_iter_fwd_2_pd_->scratchpad_registry());
                    break;
                case prop_kind::backward:
                    scratchpad.book(key_gemm_iter_bwd,
                            gemm_iter_bwd_pd_->scratchpad_registry());
                    scratchpad.book(key_gemm_layer_bwd,
                            gemm_layer_bwd_pd_->scratchpad_registry());
                    scratchpad.book(key_gemm_diff_wei_layer,
                            gemm_diff_wei_layer_pd_->scratchpad_registry());
                    scratchpad.book(key_gemm_diff_wei_iter,
                            gemm_diff_wei_iter_pd_->scratchpad_registry());
                    if (conf.is_vanilla_gru) {
                        scratchpad.book(key_gemm_iter_bwd_2,
                                gemm_iter_bwd_2_pd_->scratchpad_registry());
                        scratchpad.book(key_gemm_diff_wei_iter_2,
                                gemm_diff_wei_iter_2_pd_
                                        ->scratchpad_registry());
                    }
                    break;
                default: assert(!"unknown prop_kind");
            }
        }
    }; // struct pd_t : public base_pd_t

    status_t init(engine_t *engine) override;

    ~_ref_rnn_common_t() {
        free(wei_layer_offset_ptr);
        free(wei_iter_offset_ptr);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_(ctx);
    }

protected:
    status_t init_res_storage(
            engine_t *engine, gpu_resource_t *r) const override;

private:
    status_t execute_(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    compute::nd_range_t get_nd_range(std::vector<int> gws) const {
        // Try to schedule one local thread per eu
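        // A minimal worked example (subgroup_size = 16 and max_eus_per_wg = 8
        // are illustrative values, not queried here): lws_max starts at 128;
        // for gws = {20, 5, 1} the loop picks lws = {32, 4, 1} (20 rounds up
        // to the next power of two, then the remaining budget 128 / 32 = 4
        // caps the second dimension) and pads gws to {32, 8, 1} so that each
        // gws[i] is a multiple of lws[i]. The first dimension is also kept at
        // least subgroup_size wide.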
        int subgroup_size = pd()->subgroup_size;
        int lws_max = pd()->max_eus_per_wg * subgroup_size;
        std::vector<int> lws;
        lws.reserve(gws.size());
        for (int i = 0; i < (int)gws.size(); i++) {
            int l_dim = 2 * gws[i] < lws_max ? utils::rnd_up_pow2(gws[i])
                                             : lws_max;
            if (i == 0 && l_dim < subgroup_size) l_dim = subgroup_size;
            lws.emplace_back(l_dim);
            gws[i] = utils::rnd_up(gws[i], l_dim);
            lws_max = lws_max / l_dim;
        }

        return compute::nd_range_t(gws, lws);
    }

    // Member-function implementations declared via the signature macros;
    // they are bound to the function pointers declared further below.
    grid_execution_sig(linear_execution);

    cell_execution_sig(cell_execution);
    cell_execution_sig(cell_execution_gru);
    cell_execution_sig(cell_execution_gru_lbr);

    elemwise_sig(rnn_elemwise);
    elemwise_sig(lstm_elemwise);
    elemwise_sig(lstm_elemwise_u8s8);
    elemwise_sig_gru(gru_elemwise);
    elemwise_sig_gru_lbr(gru_lbr_elemwise);

    gemm_sig(gemm_primitive);

    weights_assign_sig(assign_weights);

    float (*activation_func)(float dd, float s, float alpha, float clipping);
    void bias_prepare(const exec_ctx_t &ctx,
            compute::compute_stream_t *compute_stream, int n_layer, int n_dir,
            int n_bias, int n_gates, int dhc, const memory_storage_t &ws,
            const memory_storage_t &scales, const memory_storage_t &wei_layer,
            const memory_storage_t &wei_iter,
            const memory_storage_t &bias) const;
    void copy_init_layer(const exec_ctx_t &ctx,
            compute::compute_stream_t *compute_stream, bool lr, bool rl,
            int n_iter, int batch, int slc, const memory_storage_t &ws,
            const memory_storage_t &scratch_diff_states,
            const memory_storage_t &input,
            const memory_storage_t &diff_dst_layer) const;
    void copy_init_iter(const exec_ctx_t &ctx,
            compute::compute_stream_t *compute_stream, int n_layer, int n_dir,
            int batch, int sic, int dhc, const memory_storage_t &ws,
            const memory_storage_t &scratch_diff_states,
            const memory_storage_t &firstit_states,
            const memory_storage_t &firstit_c_states,
            const memory_storage_t &diff_dst_iter,
            const memory_storage_t &diff_dst_iter_c, const float shift,
            const float scale, const bool quantize) const;
    void copy_res_layer(const exec_ctx_t &ctx,
            compute::compute_stream_t *compute_stream, bool lr, bool rl,
            int n_iter, int batch, int slc, int dlc,
            const memory_storage_t &scratch_diff_states,
            const memory_storage_t &dst_last_layer,
            const memory_storage_t &diff_src_layer, const memory_storage_t &ws,
            const float shift, const float scale, const bool dequantize) const;
    void copy_res_iter(const exec_ctx_t &ctx,
            compute::compute_stream_t *compute_stream, int n_layer, int n_dir,
            int batch, int sic, int dhc,
            const memory_storage_t &scratch_diff_states,
            const memory_storage_t &dst_last_iter,
            const memory_storage_t &dst_last_iter_c,
            const memory_storage_t &diff_src_iter,
            const memory_storage_t &diff_src_iter_c, const memory_storage_t &ws,
            const float shift, const float scale, const bool dequantize) const;
    void gates_reduction(const exec_ctx_t &ctx, int dir, int lay, int iter,
            int n_gates, int dhc, int batch, const memory_storage_t &gates,
            const memory_storage_t &cell,
            const memory_storage_t &diff_bias) const;
    void ws_set(const exec_ctx_t &ctx,
            compute::compute_stream_t *compute_stream,
            const memory_storage_t &workspace, const cl_ulong ws_offset,
            const int ws_part, const float val, const size_t size) const;
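    // Debug-only helper; compiled in only when DEBUGPRINT (defined near the
    // top of this file) is non-zero.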
#if DEBUGPRINT
    void ws_print(const exec_ctx_t &ctx, compute::compute_stream_t *s,
            const memory_storage_t &workspace) const;
    compute::kernel_t ws_print_kernel_;
#endif

    compute::kernel_t bias_prepare_kernel_;
    compute::kernel_t copy_init_layer_kernel_;
    compute::kernel_t copy_init_iter_kernel_;
    compute::kernel_t copy_res_layer_kernel_;
    compute::kernel_t copy_res_iter_kernel_;

    compute::kernel_t ws_set_kernel_;
    compute::kernel_t elemwise_fwd_kernel_;
    compute::kernel_t elemwise_bwd_kernel_;
    compute::kernel_t gates_reduction_kernel_;

    // ptrs to GEMM primitives
    std::shared_ptr<primitive_t> gemm_layer_fwd_;
    std::shared_ptr<primitive_t> gemm_iter_fwd_;
    std::shared_ptr<primitive_t> gemm_iter_fwd_2_;
    std::shared_ptr<primitive_t> gemm_layer_bwd_;
    std::shared_ptr<primitive_t> gemm_iter_bwd_;
    std::shared_ptr<primitive_t> gemm_iter_bwd_2_;
    std::shared_ptr<primitive_t> gemm_diff_wei_layer_;
    std::shared_ptr<primitive_t> gemm_diff_wei_iter_;
    std::shared_ptr<primitive_t> gemm_diff_wei_iter_2_;

    // Offsets into the workspace and scratch buffers; used in offset
    // calculations for grid & cell execution and in the fwd & bwd kernel
    // macros
    cl_ulong ws_gates_offset_;
    cl_ulong ws_states_offset_;
    cl_ulong ws_c_states_offset_;
    cl_ulong ws_grid_comp_offset_;
    cl_ulong ws_bias_offset_;
    cl_ulong scratch_dhG1_offset_;
    cl_ulong scratch_cell_offset_;
    cl_ulong scratch_gates_offset_;
    cl_ulong scratch_diff_states_offset_;

    // Pointers to weight offsets that are pre-computed during grid execution
    // by the weights_*_assign_func function pointers
    size_t *wei_layer_offset_ptr;
    size_t *wei_iter_offset_ptr;

    grid_execution_f grid_computation;
    cell_execution_f cell_func;

    weights_assign_t weights_layer_assign_func;
    weights_assign_t weights_iter_assign_func;

    gemm_t gemm_iter_func;
    gemm_t gemm_layer_func;
    elemwise_f elemwise_common;
    elemwise_gru_f elemwise_gru;
    elemwise_gru_lbr_f elemwise_gru_lbr;

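    // Indices of the scales buffers registered with the gpu_resource_t
    // (populated in init_res_storage())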
    enum { SCALES_ = 0, TM_SCALES_ = 1 };
};
using ref_rnn_fwd_t = _ref_rnn_common_t<prop_kind::forward>;
using ref_rnn_bwd_t = _ref_rnn_common_t<prop_kind::backward>;
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl
#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s