1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_IR_KERNEL_INFO_HPP |
18 | #define GPU_JIT_IR_KERNEL_INFO_HPP |
19 | |
20 | #include <limits> |
21 | #include <memory> |
22 | #include <string> |
23 | #include <vector> |
24 | |
25 | #include "common/c_types_map.hpp" |
26 | #include "common/primitive_exec_types.hpp" |
27 | #include "gpu/compute/compute.hpp" |
28 | #include "gpu/gpu_primitive.hpp" |
29 | #include "gpu/jit/ir/ir.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | namespace jit { |
35 | |
36 | class memory_storage_ptr_t { |
37 | public: |
38 | memory_storage_ptr_t(std::unique_ptr<memory_storage_t> &&ptr) |
39 | : unique_ptr_(std::move(ptr)) {} |
40 | memory_storage_ptr_t(const memory_storage_t *ptr) : raw_ptr_(ptr) {} |
41 | memory_storage_ptr_t(const memory_storage_ptr_t &) = delete; |
42 | |
43 | const memory_storage_t *get() const { |
44 | if (unique_ptr_) return unique_ptr_.get(); |
45 | return raw_ptr_; |
46 | } |
47 | |
48 | private: |
49 | std::unique_ptr<memory_storage_t> unique_ptr_; // Owning pointer. |
50 | const memory_storage_t *raw_ptr_ = nullptr; // Non-owning pointer. |
51 | }; |
52 | |
53 | class memory_storage_wrapper_t { |
54 | public: |
55 | memory_storage_wrapper_t() = default; |
56 | memory_storage_wrapper_t(std::unique_ptr<memory_storage_t> &&ptr) |
57 | : ptr_(new memory_storage_ptr_t(std::move(ptr))) {} |
58 | memory_storage_wrapper_t(const memory_storage_t *ptr) |
59 | : ptr_(new memory_storage_ptr_t(ptr)) {} |
60 | memory_storage_wrapper_t(const memory_storage_t &ref) |
61 | : memory_storage_wrapper_t(&ref) {} |
62 | |
63 | const memory_storage_t *get() const { |
64 | if (!ptr_) return nullptr; |
65 | return ptr_.get()->get(); |
66 | } |
67 | |
68 | private: |
69 | std::shared_ptr<memory_storage_ptr_t> ptr_; |
70 | }; |
71 | |
// Identifies which kernel a kernel_info_t object describes.
// Enumerator values are spelled out explicitly; they are the same values the
// compiler would assign implicitly.
enum class kernel_id_t {
    convolution = 0,
    pre_reorder = 1,
    post_reorder = 2,
    zero_out = 3,
};
78 | |
79 | // Kernel information, includes: |
80 | // - Kernel identifier |
81 | // - Kernel arguments |
82 | // - ND-range for submission (optional) |
83 | // Kernel arguments can be of the following kinds: |
84 | // - Internal arguments: only scalar |
85 | // - Examples: common output scales (contain a single value) |
86 | // - Resource arguments: stored to a resource storage during primitive creation |
87 | // - Examples: output scales or zero points |
88 | // - User arguments: passed by the user at run time |
89 | // - Examples: source, weights, destination |
90 | class kernel_info_t { |
91 | public: |
92 | void set_id(kernel_id_t id) { id_ = id; } |
93 | |
94 | kernel_id_t id() const { return id_; } |
95 | |
96 | // Returns stage ID, kernels with smaller stage IDs are executed first. |
97 | int stage_id() const { |
98 | switch (id()) { |
99 | case kernel_id_t::convolution: return 1; |
100 | case kernel_id_t::pre_reorder: return 0; |
101 | case kernel_id_t::post_reorder: return 2; |
102 | case kernel_id_t::zero_out: return 0; |
103 | default: ir_error_not_expected(); |
104 | } |
105 | return -1; |
106 | } |
107 | |
108 | void set_nd_range(const compute::nd_range_t &nd_range) { |
109 | nd_range_ = nd_range; |
110 | } |
111 | |
112 | const compute::nd_range_t &nd_range() const { return nd_range_; } |
113 | |
114 | void register_internal_arg(const expr_t &var, const expr_t &value) { |
115 | register_arg(var, arg_kind_t::internal, -1, /*is_input=*/true, value); |
116 | } |
117 | |
118 | void register_resource_arg(const expr_t &var) { |
119 | // TODO: Check key uniqueness. |
120 | register_arg(var, arg_kind_t::resource, nargs(), /*is_input=*/true); |
121 | } |
122 | |
123 | void register_user_arg(const expr_t &var, int dnnl_arg, bool is_input) { |
124 | register_arg(var, arg_kind_t::user, dnnl_arg, is_input); |
125 | } |
126 | |
127 | void register_scratchpad_arg( |
128 | const expr_t &var, int key, bool is_input, size_t size) { |
129 | register_arg( |
130 | var, arg_kind_t::scratchpad, key, is_input, expr_t(), size); |
131 | } |
132 | |
133 | const std::string &arg_name(int idx) const { |
134 | ir_assert(idx >= 0 && idx < nargs()); |
135 | return args_[idx].var.as<var_t>().name; |
136 | } |
137 | |
138 | const expr_t &arg_var(int idx) const { |
139 | ir_assert(idx >= 0 && idx < nargs()); |
140 | return args_[idx].var; |
141 | } |
142 | |
143 | const type_t &arg_type(int idx) const { return arg_var(idx).type(); } |
144 | |
145 | expr_t find_arg(const std::string &name, bool allow_empty = false) const { |
146 | for (int i = 0; i < nargs(); i++) { |
147 | if (arg_name(i) == name) return args_[i].var; |
148 | } |
149 | if (!allow_empty) |
150 | ir_error_not_expected() << "Argument not found: " << name; |
151 | return expr_t(); |
152 | } |
153 | |
154 | int key(int idx) const { |
155 | ir_assert(idx >= 0 && idx < nargs()); |
156 | return args_[idx].key; |
157 | } |
158 | |
159 | int key(const std::string &name) const { |
160 | for (int i = 0; i < nargs(); i++) { |
161 | if (arg_name(i) == name) return key(i); |
162 | } |
163 | ir_error_not_expected() << "Argument not found: " << name; |
164 | return -1; |
165 | } |
166 | |
167 | int nargs() const { return int(args_.size()); } |
168 | |
169 | bool is_resource(int idx) const { |
170 | ir_assert(idx >= 0 && idx < nargs()); |
171 | return args_[idx].kind == arg_kind_t::resource; |
172 | } |
173 | |
174 | bool is_scratchpad(int idx) const { |
175 | ir_assert(idx >= 0 && idx < nargs()); |
176 | return args_[idx].kind == arg_kind_t::scratchpad; |
177 | } |
178 | |
179 | bool is_user(int idx) const { |
180 | ir_assert(idx >= 0 && idx < nargs()); |
181 | return args_[idx].kind == arg_kind_t::user; |
182 | } |
183 | |
184 | bool is_input(int idx) const { |
185 | ir_assert(idx >= 0 && idx < nargs()); |
186 | return args_[idx].is_input; |
187 | } |
188 | |
189 | bool is_output(int idx) const { return !is_input(idx); } |
190 | |
191 | memory_storage_wrapper_t arg_storage(int idx, const exec_ctx_t &ctx, |
192 | const gpu_primitive_t *primitive) const { |
193 | ir_assert(idx >= 0 && idx < nargs()); |
194 | bool is_input = args_[idx].is_input; |
195 | int key = args_[idx].key; |
196 | switch (args_[idx].kind) { |
197 | case arg_kind_t::resource: |
198 | return *(primitive->cached_mapper() |
199 | ->template get<gpu_resource_t>(primitive) |
200 | ->get_memory_storage(key)); |
201 | case arg_kind_t::scratchpad: |
202 | return ctx.get_scratchpad_grantor().get_memory_storage(key); |
203 | case arg_kind_t::user: { |
204 | if (is_input) |
205 | return ctx.input(args_[idx].key)->memory_storage(); |
206 | return ctx.output(args_[idx].key)->memory_storage(); |
207 | } |
208 | // No storage for internal arguments. |
209 | case arg_kind_t::internal: return memory_storage_wrapper_t(); |
210 | default: ir_error_not_expected(); |
211 | } |
212 | return memory_storage_wrapper_t(); |
213 | } |
214 | |
215 | size_t arg_size(int idx, const gpu_primitive_t *primitive) const { |
216 | switch (args_[idx].kind) { |
217 | case arg_kind_t::user: { |
218 | auto *md = primitive->pd()->arg_md(key(idx)); |
219 | return memory_desc_wrapper(md).size(); |
220 | } |
221 | case arg_kind_t::scratchpad: return args_[idx].scratchpad_size; |
222 | default: ir_error_not_expected(); |
223 | } |
224 | return std::numeric_limits<size_t>::max(); |
225 | } |
226 | |
227 | void init_memory_storage_list(std::vector<memory_storage_wrapper_t> &list, |
228 | const exec_ctx_t &ctx, const gpu_primitive_t *primitive) const { |
229 | list = std::vector<memory_storage_wrapper_t>(nargs()); |
230 | for (int i = 0; i < nargs(); i++) { |
231 | list[i] = arg_storage(i, ctx, primitive); |
232 | } |
233 | } |
234 | |
235 | void set_args(compute::kernel_arg_list_t &arg_list, |
236 | const std::vector<memory_storage_wrapper_t> &storage_list) const { |
237 | for (int i = 0; i < nargs(); i++) { |
238 | switch (args_[i].kind) { |
239 | case arg_kind_t::internal: { |
240 | auto &value = args_[i].value; |
241 | auto &type = args_[i].var.type(); |
242 | |
243 | do { |
244 | #define CASE(ir_type, cpp_type) \ |
245 | if (type == type_t::ir_type()) { \ |
246 | arg_list.set(i, to_cpp<cpp_type>(value)); \ |
247 | break; \ |
248 | } |
249 | |
250 | CASE(f32, float) |
251 | CASE(s16, int16_t) |
252 | CASE(s32, int32_t) |
253 | CASE(s64, int64_t) |
254 | CASE(u16, uint16_t) |
255 | CASE(u32, uint32_t) |
256 | CASE(u64, uint64_t) |
257 | #undef CASE |
258 | |
259 | ir_error_not_expected() << type; |
260 | } while (false); |
261 | break; |
262 | } |
263 | case arg_kind_t::resource: |
264 | case arg_kind_t::scratchpad: |
265 | case arg_kind_t::user: { |
266 | arg_list.set(i, *storage_list[i].get()); |
267 | break; |
268 | } |
269 | default: ir_error_not_expected(); |
270 | } |
271 | } |
272 | } |
273 | |
274 | private: |
275 | enum class arg_kind_t { internal, resource, scratchpad, user }; |
276 | |
277 | struct arg_t { |
278 | arg_t(const expr_t &var, arg_kind_t kind, int key, bool is_input, |
279 | const expr_t &value, size_t scratchpad_size) |
280 | : var(var) |
281 | , kind(kind) |
282 | , key(key) |
283 | , is_input(is_input) |
284 | , value(value) |
285 | , scratchpad_size(scratchpad_size) {} |
286 | |
287 | expr_t var; |
288 | arg_kind_t kind; |
289 | int key; // Unique key across arguments with the same kind. |
290 | bool is_input; |
291 | expr_t value; // For internal arguments, must be a constant. |
292 | size_t scratchpad_size; // For scratchpad arguments only. |
293 | }; |
294 | |
295 | void register_arg(const expr_t &var, arg_kind_t kind, int key, |
296 | bool is_input, const expr_t &value = expr_t(), |
297 | size_t scratchpad_size = 0) { |
298 | ir_assert(is_var(var)) << "Expected var, got: " << var; |
299 | args_.emplace_back(var, kind, key, is_input, value, scratchpad_size); |
300 | } |
301 | |
302 | kernel_id_t id_; |
303 | compute::nd_range_t nd_range_; |
304 | |
305 | std::vector<arg_t> args_; |
306 | }; |
307 | |
308 | } // namespace jit |
309 | } // namespace gpu |
310 | } // namespace impl |
311 | } // namespace dnnl |
312 | |
313 | #endif |
314 | |