1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_IR_IR_HPP |
18 | #define GPU_JIT_IR_IR_HPP |
19 | |
20 | #include <algorithm> |
21 | #include <mutex> |
22 | #include <thread> |
23 | #include <vector> |
24 | |
25 | #include "common/optional.hpp" |
26 | #include "gpu/jit/ir/core.hpp" |
27 | #include "gpu/jit/ir/hw_config.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace gpu { |
32 | namespace jit { |
33 | |
34 | class constraint_set_t; |
35 | |
36 | class ir_context_t { |
37 | public: |
38 | ir_context_t(const exec_config_t &exec_cfg, constraint_set_t &cset) |
39 | : exec_cfg_(exec_cfg), cset_(cset) {} |
40 | |
41 | const exec_config_t &exec_cfg() const { return exec_cfg_; } |
42 | |
43 | const hw_config_t &hw_cfg() const { return exec_cfg().hw_cfg(); } |
44 | |
45 | ngen::HW hw() const { return hw_cfg().hw(); } |
46 | |
47 | int grf_size() const { return hw_cfg().grf_size(); } |
48 | |
49 | const constraint_set_t &cset() { return cset_; } |
50 | |
51 | void add_constraint(const expr_t &e); |
52 | |
53 | expr_t create_tmp_var( |
54 | const type_t &type, const std::string &prefix = "tmp" ) { |
55 | int &id = prefix_ids_[prefix]; |
56 | auto name = prefix + "_" + std::to_string(id); |
57 | id++; |
58 | return var_t::make(type, name); |
59 | } |
60 | |
61 | private: |
62 | exec_config_t exec_cfg_; |
63 | constraint_set_t &cset_; |
64 | std::unordered_map<std::string, int> prefix_ids_; |
65 | }; |
66 | |
67 | class alloc_updater_t : public ir_mutator_t { |
68 | public: |
69 | void resize(const expr_t &buf, int new_size) { |
70 | auto ret = resizes_.insert({buf, new_size}); |
71 | ir_assert(ret.second) << buf; |
72 | MAYBE_UNUSED(ret); |
73 | } |
74 | |
75 | void add_attr(const expr_t &buf, const alloc_attr_t &attr) { |
76 | auto ret = attrs_.insert({buf, attr}); |
77 | ir_assert(ret.second) << buf; |
78 | MAYBE_UNUSED(ret); |
79 | } |
80 | |
81 | void remove(const expr_t &buf) { |
82 | auto ret = removes_.insert(buf); |
83 | ir_assert(ret.second) << buf; |
84 | MAYBE_UNUSED(ret); |
85 | } |
86 | |
87 | stmt_t update(const stmt_t &stmt) { return mutate(stmt); } |
88 | |
89 | object_t _mutate(const alloc_t &obj) override { |
90 | auto new_obj = ir_mutator_t::_mutate(obj); |
91 | |
92 | // If removal succeeds, stop any further updates. |
93 | if (try_remove(new_obj)) return new_obj; |
94 | |
95 | // Otherwise try to apply other modifications one by one. |
96 | try_resize(new_obj); |
97 | try_add_attr(new_obj); |
98 | |
99 | return new_obj; |
100 | } |
101 | |
102 | private: |
103 | bool try_remove(object_t &obj) { |
104 | auto &alloc = obj.as<alloc_t>(); |
105 | auto it = removes_.find(alloc.buf); |
106 | if (it == removes_.end()) return false; |
107 | |
108 | obj = alloc.body; |
109 | removes_.erase(it); |
110 | return true; |
111 | } |
112 | |
113 | bool try_resize(object_t &obj) { |
114 | auto &alloc = obj.as<alloc_t>(); |
115 | auto it = resizes_.find(alloc.buf); |
116 | if (it == resizes_.end()) return false; |
117 | |
118 | obj = alloc_t::make( |
119 | alloc.buf, it->second, alloc.kind, alloc.attrs, alloc.body); |
120 | resizes_.erase(it); |
121 | return true; |
122 | } |
123 | |
124 | bool try_add_attr(object_t &obj) { |
125 | auto &alloc = obj.as<alloc_t>(); |
126 | auto it = attrs_.find(alloc.buf); |
127 | if (it == attrs_.end()) return false; |
128 | |
129 | auto new_attrs = alloc.attrs; |
130 | new_attrs.push_back(it->second); |
131 | |
132 | obj = alloc_t::make( |
133 | alloc.buf, alloc.size, alloc.kind, new_attrs, alloc.body); |
134 | attrs_.erase(it); |
135 | return true; |
136 | } |
137 | |
138 | object_set_t<expr_t> removes_; |
139 | object_map_t<expr_t, int> resizes_; |
140 | object_map_t<expr_t, alloc_attr_t> attrs_; |
141 | }; |
142 | |
// Returns a new statement with injected buffer allocations from `allocs`.
// - If put_innermost is false, then `stmt` is nested inside all allocations.
// - If put_innermost is true, then every allocation is injected as deep
//   (innermost) as possible.
147 | stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector<stmt_t> &allocs, |
148 | bool put_innermost = false); |
149 | |
// Returns a new statement with injected let statements; `stmt` is nested
// inside all let statements.
152 | stmt_t inject_let_stmts(const stmt_t &stmt, const std::vector<stmt_t> &lets); |
153 | |
154 | template <typename T> |
155 | struct expr_cast_helper_t { |
156 | static T call(const expr_t &e) { return to_cpp<T>(e); } |
157 | |
158 | static std::vector<T> call(const std::vector<expr_t> &exprs) { |
159 | std::vector<T> ret; |
160 | for (auto &e : exprs) |
161 | ret.push_back(to_cpp<T>(e)); |
162 | return ret; |
163 | } |
164 | }; |
165 | |
166 | template <> |
167 | struct expr_cast_helper_t<expr_t> { |
168 | static expr_t call(const expr_t &e) { return e; } |
169 | |
170 | static std::vector<expr_t> call(const std::vector<expr_t> &exprs) { |
171 | return exprs; |
172 | } |
173 | |
174 | template <typename U, |
175 | typename |
176 | = typename std::enable_if<std::is_arithmetic<U>::value>::type> |
177 | static std::vector<expr_t> call(const std::vector<U> &vec) { |
178 | std::vector<expr_t> ret; |
179 | for (auto &v : vec) |
180 | ret.push_back(to_expr(v)); |
181 | return ret; |
182 | } |
183 | }; |
184 | |
185 | template <typename DstT, typename SrcT> |
186 | DstT expr_cast(const SrcT &src) { |
187 | return expr_cast_helper_t<DstT>::call(src); |
188 | } |
189 | |
190 | template <typename DstT, typename SrcT> |
191 | std::vector<DstT> expr_cast(const std::vector<SrcT> &src) { |
192 | return expr_cast_helper_t<DstT>::call(src); |
193 | } |
194 | |
// Performs constant folding recursively on an IR tree.
196 | object_t const_fold(const object_t &obj); |
197 | |
// Performs constant folding non-recursively on an expression.
199 | expr_t const_fold_non_recursive(const expr_t &e); |
200 | |
201 | template <typename T> |
202 | std::vector<object_t> find_objects(const object_t &root); |
203 | |
204 | class alloc_manager_t { |
205 | public: |
206 | alloc_manager_t(const stmt_t &root) { |
207 | auto allocs = find_objects<alloc_t>(root); |
208 | for (auto &_a : allocs) { |
209 | auto &a = _a.as<alloc_t>(); |
210 | auto ret = buf2alloc_.insert({a.buf, _a}); |
211 | buffers_.push_back(a.buf); |
212 | ir_assert(ret.second) << "Buffer already exists: " << a.buf; |
213 | MAYBE_UNUSED(ret); |
214 | } |
215 | |
216 | // Sort buffers by name. |
217 | std::sort(buffers_.begin(), buffers_.end(), |
218 | [](const expr_t &a, const expr_t &b) { |
219 | return a.as<var_t>().name < b.as<var_t>().name; |
220 | }); |
221 | } |
222 | |
223 | const std::vector<expr_t> &buffers() const { return buffers_; } |
224 | |
225 | expr_t find_buffer( |
226 | const std::string &name, bool allow_empty = false) const { |
227 | for (auto &b : buffers()) |
228 | if (b.as<var_t>().name == name) return b; |
229 | |
230 | if (!allow_empty) ir_error_not_expected() << name; |
231 | return expr_t(); |
232 | } |
233 | |
234 | std::vector<expr_t> find_buffers(alloc_kind_t kind) const { |
235 | std::vector<expr_t> ret; |
236 | for (auto &b : buffers()) |
237 | if (alloc_kind(b) == kind) ret.push_back(b); |
238 | return ret; |
239 | } |
240 | |
241 | int alloc_size(const expr_t &buf) const { |
242 | auto *a = find_alloc(buf); |
243 | ir_assert(a) << buf; |
244 | return a->size; |
245 | } |
246 | |
247 | alloc_kind_t alloc_kind(const expr_t &buf) const { |
248 | auto *a = find_alloc(buf); |
249 | ir_assert(a) << buf; |
250 | return a->kind; |
251 | } |
252 | |
253 | int total_size(alloc_kind_t kind) const { |
254 | int ret = 0; |
255 | for (auto &kv : buf2alloc_) { |
256 | auto &a = kv.second.as<alloc_t>(); |
257 | if (a.kind == kind) ret += a.size; |
258 | } |
259 | return ret; |
260 | } |
261 | |
262 | private: |
263 | const alloc_t *find_alloc(const expr_t &buf) const { |
264 | auto it = buf2alloc_.find(buf); |
265 | if (it == buf2alloc_.end()) return nullptr; |
266 | return it->second.as_ptr<alloc_t>(); |
267 | } |
268 | |
269 | object_map_t<expr_t, stmt_t> buf2alloc_; |
270 | std::vector<expr_t> buffers_; |
271 | object_map_t<expr_t, stmt_t> alloc_updates_; |
272 | }; |
273 | |
274 | // IR utility functions. |
275 | expr_t abs(const expr_t &e); |
276 | |
277 | expr_t cast(const expr_t &e, const type_t &type, bool saturate = false); |
278 | |
279 | bool is_zero(const expr_t &e); |
280 | |
281 | bool is_one(const expr_t &e); |
282 | |
283 | bool is_minus_one(const expr_t &e); |
284 | |
285 | bool is_const_broadcast(const expr_t &e); |
286 | |
287 | bool is_const_broadcast(const expr_t &e, const expr_t &value); |
288 | |
289 | expr_t make_buffer(const std::string &name); |
290 | |
291 | // Utility functions for nary_op_t. |
292 | expr_t nary_op_back_transform(const expr_t &e); |
293 | expr_t nary_op_canonicalize(const expr_t &_e); |
294 | expr_t make_nary_op(op_kind_t op_kind, const std::vector<expr_t> &args); |
295 | std::vector<expr_t> cvt_expr_to_nary_op_args(const expr_t &e); |
296 | |
// Replaces all occurrences of `from` with `to` in `root`.
298 | object_t substitute(const object_t &root, const object_t &from, |
299 | const object_t &to, |
300 | int max_substitutions = std::numeric_limits<int>::max()); |
301 | |
// Replaces all occurrences of `from` with `to` in `root` and propagates any
// required type changes.
304 | object_t substitute_with_different_type(const object_t &root, |
305 | const object_t &from, const object_t &to, |
306 | int max_substitutions = std::numeric_limits<int>::max()); |
307 | |
308 | // Returns leaf statements of `root`. Uses inorder traversal. |
309 | std::vector<stmt_t> flatten_statements(const stmt_t &root); |
310 | |
311 | template <typename T, bool find_unique = false, bool save_objects = true> |
312 | class object_finder_t : public ir_visitor_t { |
313 | public: |
314 | void _visit(const T &obj) override { |
315 | ir_visitor_t::_visit(obj); |
316 | occurrences++; |
317 | if (!save_objects) return; |
318 | if (find_unique) { |
319 | found_unique.insert(obj); |
320 | } else { |
321 | found.push_back(obj); |
322 | } |
323 | } |
324 | |
325 | std::vector<object_t> found; |
326 | object_set_t<object_t> found_unique; |
327 | int occurrences = 0; |
328 | }; |
329 | |
330 | // Returns all IR objects of type `T` found in `root`. |
331 | template <typename T> |
332 | std::vector<object_t> find_objects(const object_t &root) { |
333 | object_finder_t<T, /*find_unique=*/false> finder; |
334 | finder.visit(root); |
335 | return finder.found; |
336 | } |
337 | |
338 | template <typename T> |
339 | int count_objects(const object_t &root) { |
340 | object_finder_t<T, /*find_unique=*/false, /*save_objects=*/false> finder; |
341 | finder.visit(root); |
342 | return finder.occurrences; |
343 | } |
344 | |
345 | // Returns unique IR objects of type `T` found in `root`. |
346 | template <typename T> |
347 | object_set_t<object_t> find_unique_objects(const object_t &root) { |
348 | object_finder_t<T, /*find_unique=*/true> finder; |
349 | finder.visit(root); |
350 | return finder.found_unique; |
351 | } |
352 | |
353 | // Returns number of occurrences of `obj` in `root` (based on identity |
354 | // comparison). |
355 | int count_object(const object_t &root, const object_t &obj); |
356 | |
357 | // Returns number of occurrences of `obj` in vector of root objects (based on |
358 | // identity comparison). |
359 | template <typename T> |
360 | int count_object(const std::vector<T> &roots, const object_t &obj) { |
361 | int ret = 0; |
362 | for (auto &root : roots) |
363 | ret += count_object(root, obj); |
364 | return ret; |
365 | } |
366 | |
367 | // Checks if `root` contains `obj`. |
368 | bool contains_object(const object_t &root, const object_t &obj); |
369 | |
370 | // Returns all statement groups matching the label. |
371 | std::vector<stmt_t> find_stmt_groups( |
372 | const object_t &root, const stmt_label_t &label); |
373 | |
374 | // Returns a statement group matching the label. `root` must have exactly one |
375 | // occurrence. |
376 | utils::optional_t<stmt_t> find_stmt_group( |
377 | const object_t &root, const stmt_label_t &label); |
378 | |
379 | // Removes all statement groups matching the label. |
380 | object_t remove_stmt_group(const object_t &root, stmt_label_t label); |
381 | |
382 | class scope_visitor_t : public ir_visitor_t { |
383 | public: |
384 | bool is_expr_defined(const expr_t &e) const { |
385 | auto vars = find_unique_objects<var_t>(e); |
386 | for (auto &v : vars) { |
387 | if (def_vars_.count(v) == 0) return false; |
388 | } |
389 | return true; |
390 | } |
391 | |
392 | #define CASE(type, var_field, is_pre) \ |
393 | if (obj.is<type>()) { \ |
394 | visit_scope((const type &)obj, ((const type &)obj).var_field, is_pre); \ |
395 | return; \ |
396 | } |
397 | |
398 | void pre_visit(const object_impl_t &obj) override { |
399 | CASE(alloc_t, buf, true); |
400 | CASE(let_t, var, true); |
401 | CASE(for_t, var, true); |
402 | } |
403 | |
404 | void post_visit(const object_impl_t &obj) override { |
405 | CASE(alloc_t, buf, false); |
406 | CASE(let_t, var, false); |
407 | CASE(for_t, var, false); |
408 | } |
409 | |
410 | #undef CASE |
411 | |
412 | private: |
413 | template <typename T> |
414 | void visit_scope(const T &obj, const expr_t &var, bool is_pre_visit) { |
415 | if (is_pre_visit) { |
416 | def_vars_.insert(var); |
417 | return; |
418 | } |
419 | def_vars_.erase(var); |
420 | } |
421 | |
422 | object_set_t<expr_t> def_vars_; |
423 | }; |
424 | |
425 | class ir_path_t { |
426 | public: |
427 | void push(const object_impl_t *obj) { path_.push_back(obj); } |
428 | |
429 | void pop() { path_.pop_back(); } |
430 | |
431 | const object_impl_t *back() const { |
432 | ir_assert(!is_empty()); |
433 | return path_.back(); |
434 | } |
435 | |
436 | bool is_empty() const { return path_.empty(); } |
437 | |
438 | void merge(const ir_path_t &other) { |
439 | size_t idx; |
440 | size_t min_size = std::min(path_.size(), other.path_.size()); |
441 | for (idx = 0; idx < min_size; idx++) { |
442 | if (path_[idx] != other.path_[idx]) break; |
443 | } |
444 | path_.resize(idx); |
445 | } |
446 | |
447 | private: |
448 | std::vector<const object_impl_t *> path_; |
449 | }; |
450 | |
451 | // Only for statements that create scope. |
452 | stmt_t get_stmt_body(const stmt_t &stmt); |
453 | |
454 | stmt_t replace_stmt_body(const stmt_t &stmt, const stmt_t &new_body); |
455 | |
456 | int get_peak_grf_usage(const stmt_t &stmt, int grf_size, int external_usage = 0, |
457 | bool skip_let = false); |
458 | |
459 | bool has_send_atomics(const stmt_t &s); |
460 | |
// Scoped accounting of memory usage: the constructor adds `size` to `*usage`
// (and raises `*peak_usage` to the new usage if needed), the destructor
// subtracts `size` again. A guard with a null `usage` pointer is inactive.
// Guards are movable but not copyable, so every accounted size is released
// exactly once.
struct mem_usage_guard_t {
    // `usage` and `peak_usage` may be null; the peak is tracked only when
    // both are non-null.
    mem_usage_guard_t(int *usage, int *peak_usage, int size)
        : usage(usage), peak_usage(peak_usage), size(size) {
        if (usage) *usage += size;
        if (usage && peak_usage) *peak_usage = std::max(*peak_usage, *usage);
    }

    mem_usage_guard_t(int *usage, int size)
        : mem_usage_guard_t(usage, nullptr, size) {}

    // Creates an inactive guard.
    mem_usage_guard_t() : mem_usage_guard_t(nullptr, nullptr, 0) {}

    // Takes over the accounting of `other`, which becomes inactive.
    mem_usage_guard_t(mem_usage_guard_t &&other) noexcept
        : usage(other.usage), peak_usage(other.peak_usage), size(other.size) {
        other.usage = nullptr;
        other.peak_usage = nullptr;
        other.size = 0;
    }

    // Releases whatever this guard currently accounts for, then takes over
    // the accounting of `other`, which becomes inactive.
    mem_usage_guard_t &operator=(mem_usage_guard_t &&other) noexcept {
        // Self-move must keep the guard intact: without this check the
        // pointers would be nulled and the size never released.
        if (this == &other) return *this;

        // Give back the currently held size, otherwise it would stay
        // accounted forever once the members are overwritten.
        if (usage) *usage -= size;

        usage = other.usage;
        peak_usage = other.peak_usage;
        size = other.size;
        other.usage = nullptr;
        other.peak_usage = nullptr;
        other.size = 0;
        return *this;
    }

    mem_usage_guard_t(const mem_usage_guard_t &) = delete;
    mem_usage_guard_t &operator=(const mem_usage_guard_t &) = delete;

    ~mem_usage_guard_t() {
        if (usage) *usage -= size;
    }

    int *usage {nullptr};
    int *peak_usage {nullptr};
    int size {0};
};
501 | |
502 | // Describes the linear transformation F(x) for variable x: F(x) = (a * x + b), |
503 | // where a and b are integer constants. |
504 | struct linear_transform_t { |
505 | expr_t x; |
506 | int a; |
507 | int b; |
508 | |
509 | bool is_identity() const { return a == 1 && b == 0; } |
510 | }; |
511 | |
512 | // Relation: (lhs op rhs), where: |
513 | // - lhs is a variable |
514 | // - rhs is an integer constant |
515 | // - op is a comparison operation |
516 | class relation_t { |
517 | public: |
518 | relation_t(const expr_t &expr) : expr_(normalize(expr)) {} |
519 | |
520 | const expr_t &expr() const { return expr_; } |
521 | |
522 | const expr_t &var() const { return expr_.as<binary_op_t>().a; } |
523 | |
524 | const expr_t &rhs() const { return expr_.as<binary_op_t>().b; } |
525 | |
526 | op_kind_t op_kind() const { return expr_.as<binary_op_t>().op_kind; } |
527 | |
528 | bool implies(const relation_t &other) const; |
529 | |
530 | // Applies linear transformation to left and right hand sides of the relation. |
531 | relation_t transform(const linear_transform_t &t, const expr_t &new_var); |
532 | |
533 | std::string str() const { |
534 | std::ostringstream oss; |
535 | oss << expr_; |
536 | return oss.str(); |
537 | } |
538 | |
539 | static bool is_relation_constraint(const expr_t &e) { |
540 | auto *binary_op = e.as_ptr<binary_op_t>(); |
541 | if (!binary_op) return false; |
542 | if (!is_var(binary_op->a)) return false; |
543 | if (!is_const(binary_op->b)) return false; |
544 | if (!is_cmp_op(binary_op->op_kind)) return false; |
545 | return true; |
546 | } |
547 | |
548 | private: |
549 | static expr_t normalize(const expr_t &e); |
550 | |
551 | expr_t expr_; |
552 | }; |
553 | |
554 | inline std::ostream &operator<<(std::ostream &out, const relation_t &rel) { |
555 | out << rel.str(); |
556 | return out; |
557 | } |
558 | |
559 | // Equality for modulus: (var % mod) == 0, where: |
560 | // - var is a variable |
561 | // - mod is an integer constant |
562 | class modulus_info_t { |
563 | public: |
564 | modulus_info_t(const expr_t &expr) : expr_(expr) {} |
565 | |
566 | const expr_t &expr() const { return expr_; } |
567 | |
568 | const expr_t &var() const { |
569 | auto &mod_expr = expr_.as<binary_op_t>().a; |
570 | return mod_expr.as<binary_op_t>().a; |
571 | } |
572 | |
573 | const expr_t &mod() const { |
574 | auto &mod_expr = expr_.as<binary_op_t>().a; |
575 | return mod_expr.as<binary_op_t>().b; |
576 | } |
577 | |
578 | bool implies(const modulus_info_t &other) const { |
579 | ir_assert(var().is_same(other.var())); |
580 | |
581 | int64_t this_mod = to_cpp<int64_t>(mod()); |
582 | int64_t other_mod = to_cpp<int64_t>(other.mod()); |
583 | |
584 | return this_mod % other_mod == 0; |
585 | } |
586 | |
587 | std::string str() const { |
588 | std::ostringstream oss; |
589 | oss << expr_; |
590 | return oss.str(); |
591 | } |
592 | |
593 | // Try to match (var % mod) == 0. |
594 | static bool is_modulus_constraint(const expr_t &e); |
595 | |
596 | private: |
597 | expr_t expr_; |
598 | }; |
599 | |
600 | inline std::ostream &operator<<(std::ostream &out, const modulus_info_t &mod) { |
601 | out << mod.str(); |
602 | return out; |
603 | } |
604 | |
605 | // Helper class to find constant bounds of integer expressions based on known |
606 | // relations. |
607 | class bound_finder_base_t { |
608 | public: |
609 | int64_t find_low_bound(const expr_t &e) const { |
610 | return find_bound_impl(e, /*is_low=*/true); |
611 | } |
612 | |
613 | int64_t find_high_bound(const expr_t &e) const { |
614 | return find_bound_impl(e, /*is_low=*/false); |
615 | } |
616 | |
617 | virtual int64_t get_var_bound(const expr_t &e, bool is_low) const = 0; |
618 | |
619 | static int64_t unlimited_bound(bool is_low) { |
620 | if (is_low) return std::numeric_limits<int64_t>::min(); |
621 | return std::numeric_limits<int64_t>::max(); |
622 | } |
623 | |
624 | static bool is_good_bound(int64_t bound) { |
625 | if (bound == unlimited_bound(true)) return false; |
626 | if (bound == unlimited_bound(false)) return false; |
627 | return true; |
628 | } |
629 | |
630 | protected: |
631 | // If is_low is true, searches for proven low bound, and high bound |
632 | // otherwise. |
633 | virtual int64_t find_bound_impl(const expr_t &e, bool is_low) const; |
634 | }; |
635 | |
636 | class bound_finder_t : public bound_finder_base_t { |
637 | public: |
638 | bound_finder_t( |
639 | const object_map_t<expr_t, std::vector<relation_t>> &relations) |
640 | : relations_(relations) {} |
641 | |
642 | int64_t get_var_bound(const expr_t &e, bool is_low) const override { |
643 | ir_assert(is_var(e)); |
644 | int64_t def_bound = unlimited_bound(is_low); |
645 | auto it = relations_.find(e); |
646 | if (it == relations_.end()) return def_bound; |
647 | |
648 | int64_t ret = def_bound; |
649 | for (auto &rel : it->second) { |
650 | bool is_ge = (rel.op_kind() == op_kind_t::_ge); |
651 | if (is_ge != is_low) continue; |
652 | if (is_ge) { |
653 | ret = std::max(to_cpp<int64_t>(rel.rhs()), ret); |
654 | } else { |
655 | ret = std::min(to_cpp<int64_t>(rel.rhs()), ret); |
656 | } |
657 | } |
658 | return ret; |
659 | } |
660 | |
661 | private: |
662 | object_map_t<expr_t, std::vector<relation_t>> relations_; |
663 | }; |
664 | |
665 | // TODO: Add integers check (only integers can be constrained). |
666 | class constraint_set_t { |
667 | public: |
668 | const object_map_t<expr_t, std::vector<relation_t>> &relations() const { |
669 | return relations_; |
670 | } |
671 | |
672 | void add_constraint(const expr_t &e); |
673 | |
674 | bool can_prove(const expr_t &e, bool try_simplify = true) const { |
675 | auto ret = can_prove_impl(e, /*do_simplify=*/false); |
676 | if (ret || !try_simplify) return ret; |
677 | |
678 | return can_prove_impl(e, /*do_simplify=*/true); |
679 | } |
680 | |
681 | bool is_single_value(const expr_t &e, expr_t &value) const; |
682 | |
683 | int max_proven_gcd(const expr_t &var) const; |
684 | |
685 | private: |
686 | bool can_prove_modulus(const expr_t &e) const { |
687 | modulus_info_t unknown(e); |
688 | auto it = modulus_infos_.find(unknown.var()); |
689 | if (it == modulus_infos_.end()) return false; |
690 | |
691 | for (auto &known : it->second) { |
692 | if (known.implies(unknown)) return true; |
693 | } |
694 | |
695 | return false; |
696 | } |
697 | |
698 | bool can_prove_relation(const expr_t &e) const { |
699 | relation_t unknown(e); |
700 | auto it = relations_.find(unknown.var()); |
701 | if (it == relations_.end()) return false; |
702 | |
703 | for (auto &known : it->second) { |
704 | if (known.implies(unknown)) return true; |
705 | } |
706 | |
707 | return false; |
708 | } |
709 | |
710 | bool try_prove_compound_relation(const expr_t &e) const { |
711 | auto *binary = e.as_ptr<binary_op_t>(); |
712 | if (!binary) return false; |
713 | |
714 | auto op_kind = binary->op_kind; |
715 | auto &a = binary->a; |
716 | auto &_b = binary->b; |
717 | |
718 | if (!is_const(_b)) return false; |
719 | |
720 | auto b = to_cpp<int64_t>(_b); |
721 | |
722 | // Normalize operation kind. |
723 | switch (op_kind) { |
724 | case op_kind_t::_ge: |
725 | case op_kind_t::_le: break; |
726 | case op_kind_t::_gt: |
727 | op_kind = op_kind_t::_ge; |
728 | ir_assert(b < std::numeric_limits<int64_t>::max()); |
729 | b += 1; |
730 | break; |
731 | case op_kind_t::_lt: |
732 | op_kind = op_kind_t::_le; |
733 | ir_assert(b > std::numeric_limits<int64_t>::min()); |
734 | b -= 1; |
735 | break; |
736 | default: return false; |
737 | } |
738 | |
739 | bound_finder_t finder(relations_); |
740 | if (op_kind == op_kind_t::_ge) { |
741 | auto lo = finder.find_low_bound(a); |
742 | if (!bound_finder_t::is_good_bound(lo)) return false; |
743 | return lo >= b; |
744 | } |
745 | |
746 | if (op_kind == op_kind_t::_le) { |
747 | auto hi = finder.find_high_bound(a); |
748 | if (!bound_finder_t::is_good_bound(hi)) return false; |
749 | return hi <= b; |
750 | } |
751 | |
752 | return false; |
753 | } |
754 | |
755 | bool can_prove_impl(const expr_t &_e, bool do_simplify) const; |
756 | |
757 | object_map_t<expr_t, std::vector<relation_t>> relations_; |
758 | object_map_t<expr_t, std::vector<modulus_info_t>> modulus_infos_; |
759 | }; |
760 | |
761 | // Pre-defined functions. |
762 | namespace funcs { |
763 | |
764 | inline func_t barrier_func() { |
765 | static thread_local auto f = builtin_t::make("barrier" ); |
766 | return f; |
767 | } |
768 | |
769 | inline stmt_t barrier() { |
770 | return barrier_func().call(); |
771 | } |
772 | |
773 | inline func_t slm_fence_func() { |
774 | static thread_local auto f = builtin_t::make("slm_fence" ); |
775 | return f; |
776 | } |
777 | |
778 | inline stmt_t slm_fence() { |
779 | return slm_fence_func().call(); |
780 | } |
781 | |
782 | inline func_t signal_func() { |
783 | static thread_local auto f = builtin_t::make("signal" ); |
784 | return f; |
785 | } |
786 | |
787 | inline stmt_t signal() { |
788 | return signal_func().call(); |
789 | } |
790 | |
791 | inline func_t barrier_wait_func() { |
792 | static thread_local auto f = builtin_t::make("barrier_wait" ); |
793 | return f; |
794 | } |
795 | |
796 | inline stmt_t barrier_wait() { |
797 | return barrier_wait_func().call(); |
798 | } |
799 | |
800 | inline func_t continue_func() { |
801 | static thread_local auto f = builtin_t::make("continue" ); |
802 | return f; |
803 | } |
804 | |
805 | inline stmt_t _continue() { |
806 | return continue_func().call(); |
807 | } |
808 | |
809 | } // namespace funcs |
810 | |
811 | } // namespace jit |
812 | } // namespace gpu |
813 | } // namespace impl |
814 | } // namespace dnnl |
815 | |
816 | #endif |
817 | |