1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_IR_KERNEL_INFO_HPP
18#define GPU_JIT_IR_KERNEL_INFO_HPP
19
20#include <limits>
21#include <memory>
22#include <string>
23#include <vector>
24
25#include "common/c_types_map.hpp"
26#include "common/primitive_exec_types.hpp"
27#include "gpu/compute/compute.hpp"
28#include "gpu/gpu_primitive.hpp"
29#include "gpu/jit/ir/ir.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace gpu {
34namespace jit {
35
36class memory_storage_ptr_t {
37public:
38 memory_storage_ptr_t(std::unique_ptr<memory_storage_t> &&ptr)
39 : unique_ptr_(std::move(ptr)) {}
40 memory_storage_ptr_t(const memory_storage_t *ptr) : raw_ptr_(ptr) {}
41 memory_storage_ptr_t(const memory_storage_ptr_t &) = delete;
42
43 const memory_storage_t *get() const {
44 if (unique_ptr_) return unique_ptr_.get();
45 return raw_ptr_;
46 }
47
48private:
49 std::unique_ptr<memory_storage_t> unique_ptr_; // Owning pointer.
50 const memory_storage_t *raw_ptr_ = nullptr; // Non-owning pointer.
51};
52
53class memory_storage_wrapper_t {
54public:
55 memory_storage_wrapper_t() = default;
56 memory_storage_wrapper_t(std::unique_ptr<memory_storage_t> &&ptr)
57 : ptr_(new memory_storage_ptr_t(std::move(ptr))) {}
58 memory_storage_wrapper_t(const memory_storage_t *ptr)
59 : ptr_(new memory_storage_ptr_t(ptr)) {}
60 memory_storage_wrapper_t(const memory_storage_t &ref)
61 : memory_storage_wrapper_t(&ref) {}
62
63 const memory_storage_t *get() const {
64 if (!ptr_) return nullptr;
65 return ptr_.get()->get();
66 }
67
68private:
69 std::shared_ptr<memory_storage_ptr_t> ptr_;
70};
71
// Identifies the kind of a kernel described by kernel_info_t.
// kernel_info_t::stage_id() maps every identifier to an execution stage.
enum class kernel_id_t {
    convolution = 0,
    pre_reorder = 1,
    post_reorder = 2,
    zero_out = 3,
};
78
79// Kernel information, includes:
80// - Kernel identifier
81// - Kernel arguments
82// - ND-range for submission (optional)
83// Kernel arguments can be of the following kinds:
84// - Internal arguments: only scalar
85// - Examples: common output scales (contain a single value)
86// - Resource arguments: stored to a resource storage during primitive creation
87// - Examples: output scales or zero points
88// - User arguments: passed by the user at run time
89// - Examples: source, weights, destination
90class kernel_info_t {
91public:
92 void set_id(kernel_id_t id) { id_ = id; }
93
94 kernel_id_t id() const { return id_; }
95
96 // Returns stage ID, kernels with smaller stage IDs are executed first.
97 int stage_id() const {
98 switch (id()) {
99 case kernel_id_t::convolution: return 1;
100 case kernel_id_t::pre_reorder: return 0;
101 case kernel_id_t::post_reorder: return 2;
102 case kernel_id_t::zero_out: return 0;
103 default: ir_error_not_expected();
104 }
105 return -1;
106 }
107
108 void set_nd_range(const compute::nd_range_t &nd_range) {
109 nd_range_ = nd_range;
110 }
111
112 const compute::nd_range_t &nd_range() const { return nd_range_; }
113
114 void register_internal_arg(const expr_t &var, const expr_t &value) {
115 register_arg(var, arg_kind_t::internal, -1, /*is_input=*/true, value);
116 }
117
118 void register_resource_arg(const expr_t &var) {
119 // TODO: Check key uniqueness.
120 register_arg(var, arg_kind_t::resource, nargs(), /*is_input=*/true);
121 }
122
123 void register_user_arg(const expr_t &var, int dnnl_arg, bool is_input) {
124 register_arg(var, arg_kind_t::user, dnnl_arg, is_input);
125 }
126
127 void register_scratchpad_arg(
128 const expr_t &var, int key, bool is_input, size_t size) {
129 register_arg(
130 var, arg_kind_t::scratchpad, key, is_input, expr_t(), size);
131 }
132
133 const std::string &arg_name(int idx) const {
134 ir_assert(idx >= 0 && idx < nargs());
135 return args_[idx].var.as<var_t>().name;
136 }
137
138 const expr_t &arg_var(int idx) const {
139 ir_assert(idx >= 0 && idx < nargs());
140 return args_[idx].var;
141 }
142
143 const type_t &arg_type(int idx) const { return arg_var(idx).type(); }
144
145 expr_t find_arg(const std::string &name, bool allow_empty = false) const {
146 for (int i = 0; i < nargs(); i++) {
147 if (arg_name(i) == name) return args_[i].var;
148 }
149 if (!allow_empty)
150 ir_error_not_expected() << "Argument not found: " << name;
151 return expr_t();
152 }
153
154 int key(int idx) const {
155 ir_assert(idx >= 0 && idx < nargs());
156 return args_[idx].key;
157 }
158
159 int key(const std::string &name) const {
160 for (int i = 0; i < nargs(); i++) {
161 if (arg_name(i) == name) return key(i);
162 }
163 ir_error_not_expected() << "Argument not found: " << name;
164 return -1;
165 }
166
167 int nargs() const { return int(args_.size()); }
168
169 bool is_resource(int idx) const {
170 ir_assert(idx >= 0 && idx < nargs());
171 return args_[idx].kind == arg_kind_t::resource;
172 }
173
174 bool is_scratchpad(int idx) const {
175 ir_assert(idx >= 0 && idx < nargs());
176 return args_[idx].kind == arg_kind_t::scratchpad;
177 }
178
179 bool is_user(int idx) const {
180 ir_assert(idx >= 0 && idx < nargs());
181 return args_[idx].kind == arg_kind_t::user;
182 }
183
184 bool is_input(int idx) const {
185 ir_assert(idx >= 0 && idx < nargs());
186 return args_[idx].is_input;
187 }
188
189 bool is_output(int idx) const { return !is_input(idx); }
190
191 memory_storage_wrapper_t arg_storage(int idx, const exec_ctx_t &ctx,
192 const gpu_primitive_t *primitive) const {
193 ir_assert(idx >= 0 && idx < nargs());
194 bool is_input = args_[idx].is_input;
195 int key = args_[idx].key;
196 switch (args_[idx].kind) {
197 case arg_kind_t::resource:
198 return *(primitive->cached_mapper()
199 ->template get<gpu_resource_t>(primitive)
200 ->get_memory_storage(key));
201 case arg_kind_t::scratchpad:
202 return ctx.get_scratchpad_grantor().get_memory_storage(key);
203 case arg_kind_t::user: {
204 if (is_input)
205 return ctx.input(args_[idx].key)->memory_storage();
206 return ctx.output(args_[idx].key)->memory_storage();
207 }
208 // No storage for internal arguments.
209 case arg_kind_t::internal: return memory_storage_wrapper_t();
210 default: ir_error_not_expected();
211 }
212 return memory_storage_wrapper_t();
213 }
214
215 size_t arg_size(int idx, const gpu_primitive_t *primitive) const {
216 switch (args_[idx].kind) {
217 case arg_kind_t::user: {
218 auto *md = primitive->pd()->arg_md(key(idx));
219 return memory_desc_wrapper(md).size();
220 }
221 case arg_kind_t::scratchpad: return args_[idx].scratchpad_size;
222 default: ir_error_not_expected();
223 }
224 return std::numeric_limits<size_t>::max();
225 }
226
227 void init_memory_storage_list(std::vector<memory_storage_wrapper_t> &list,
228 const exec_ctx_t &ctx, const gpu_primitive_t *primitive) const {
229 list = std::vector<memory_storage_wrapper_t>(nargs());
230 for (int i = 0; i < nargs(); i++) {
231 list[i] = arg_storage(i, ctx, primitive);
232 }
233 }
234
235 void set_args(compute::kernel_arg_list_t &arg_list,
236 const std::vector<memory_storage_wrapper_t> &storage_list) const {
237 for (int i = 0; i < nargs(); i++) {
238 switch (args_[i].kind) {
239 case arg_kind_t::internal: {
240 auto &value = args_[i].value;
241 auto &type = args_[i].var.type();
242
243 do {
244#define CASE(ir_type, cpp_type) \
245 if (type == type_t::ir_type()) { \
246 arg_list.set(i, to_cpp<cpp_type>(value)); \
247 break; \
248 }
249
250 CASE(f32, float)
251 CASE(s16, int16_t)
252 CASE(s32, int32_t)
253 CASE(s64, int64_t)
254 CASE(u16, uint16_t)
255 CASE(u32, uint32_t)
256 CASE(u64, uint64_t)
257#undef CASE
258
259 ir_error_not_expected() << type;
260 } while (false);
261 break;
262 }
263 case arg_kind_t::resource:
264 case arg_kind_t::scratchpad:
265 case arg_kind_t::user: {
266 arg_list.set(i, *storage_list[i].get());
267 break;
268 }
269 default: ir_error_not_expected();
270 }
271 }
272 }
273
274private:
275 enum class arg_kind_t { internal, resource, scratchpad, user };
276
277 struct arg_t {
278 arg_t(const expr_t &var, arg_kind_t kind, int key, bool is_input,
279 const expr_t &value, size_t scratchpad_size)
280 : var(var)
281 , kind(kind)
282 , key(key)
283 , is_input(is_input)
284 , value(value)
285 , scratchpad_size(scratchpad_size) {}
286
287 expr_t var;
288 arg_kind_t kind;
289 int key; // Unique key across arguments with the same kind.
290 bool is_input;
291 expr_t value; // For internal arguments, must be a constant.
292 size_t scratchpad_size; // For scratchpad arguments only.
293 };
294
295 void register_arg(const expr_t &var, arg_kind_t kind, int key,
296 bool is_input, const expr_t &value = expr_t(),
297 size_t scratchpad_size = 0) {
298 ir_assert(is_var(var)) << "Expected var, got: " << var;
299 args_.emplace_back(var, kind, key, is_input, value, scratchpad_size);
300 }
301
302 kernel_id_t id_;
303 compute::nd_range_t nd_range_;
304
305 std::vector<arg_t> args_;
306};
307
308} // namespace jit
309} // namespace gpu
310} // namespace impl
311} // namespace dnnl
312
313#endif
314