1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_IR_IR_HPP
18#define GPU_JIT_IR_IR_HPP
19
#include <algorithm>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

#include "common/optional.hpp"
#include "gpu/jit/ir/core.hpp"
#include "gpu/jit/ir/hw_config.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace gpu {
32namespace jit {
33
34class constraint_set_t;
35
36class ir_context_t {
37public:
38 ir_context_t(const exec_config_t &exec_cfg, constraint_set_t &cset)
39 : exec_cfg_(exec_cfg), cset_(cset) {}
40
41 const exec_config_t &exec_cfg() const { return exec_cfg_; }
42
43 const hw_config_t &hw_cfg() const { return exec_cfg().hw_cfg(); }
44
45 ngen::HW hw() const { return hw_cfg().hw(); }
46
47 int grf_size() const { return hw_cfg().grf_size(); }
48
49 const constraint_set_t &cset() { return cset_; }
50
51 void add_constraint(const expr_t &e);
52
53 expr_t create_tmp_var(
54 const type_t &type, const std::string &prefix = "tmp") {
55 int &id = prefix_ids_[prefix];
56 auto name = prefix + "_" + std::to_string(id);
57 id++;
58 return var_t::make(type, name);
59 }
60
61private:
62 exec_config_t exec_cfg_;
63 constraint_set_t &cset_;
64 std::unordered_map<std::string, int> prefix_ids_;
65};
66
67class alloc_updater_t : public ir_mutator_t {
68public:
69 void resize(const expr_t &buf, int new_size) {
70 auto ret = resizes_.insert({buf, new_size});
71 ir_assert(ret.second) << buf;
72 MAYBE_UNUSED(ret);
73 }
74
75 void add_attr(const expr_t &buf, const alloc_attr_t &attr) {
76 auto ret = attrs_.insert({buf, attr});
77 ir_assert(ret.second) << buf;
78 MAYBE_UNUSED(ret);
79 }
80
81 void remove(const expr_t &buf) {
82 auto ret = removes_.insert(buf);
83 ir_assert(ret.second) << buf;
84 MAYBE_UNUSED(ret);
85 }
86
87 stmt_t update(const stmt_t &stmt) { return mutate(stmt); }
88
89 object_t _mutate(const alloc_t &obj) override {
90 auto new_obj = ir_mutator_t::_mutate(obj);
91
92 // If removal succeeds, stop any further updates.
93 if (try_remove(new_obj)) return new_obj;
94
95 // Otherwise try to apply other modifications one by one.
96 try_resize(new_obj);
97 try_add_attr(new_obj);
98
99 return new_obj;
100 }
101
102private:
103 bool try_remove(object_t &obj) {
104 auto &alloc = obj.as<alloc_t>();
105 auto it = removes_.find(alloc.buf);
106 if (it == removes_.end()) return false;
107
108 obj = alloc.body;
109 removes_.erase(it);
110 return true;
111 }
112
113 bool try_resize(object_t &obj) {
114 auto &alloc = obj.as<alloc_t>();
115 auto it = resizes_.find(alloc.buf);
116 if (it == resizes_.end()) return false;
117
118 obj = alloc_t::make(
119 alloc.buf, it->second, alloc.kind, alloc.attrs, alloc.body);
120 resizes_.erase(it);
121 return true;
122 }
123
124 bool try_add_attr(object_t &obj) {
125 auto &alloc = obj.as<alloc_t>();
126 auto it = attrs_.find(alloc.buf);
127 if (it == attrs_.end()) return false;
128
129 auto new_attrs = alloc.attrs;
130 new_attrs.push_back(it->second);
131
132 obj = alloc_t::make(
133 alloc.buf, alloc.size, alloc.kind, new_attrs, alloc.body);
134 attrs_.erase(it);
135 return true;
136 }
137
138 object_set_t<expr_t> removes_;
139 object_map_t<expr_t, int> resizes_;
140 object_map_t<expr_t, alloc_attr_t> attrs_;
141};
142
// Returns a new statement with the buffer allocations from `allocs` injected.
// - If put_innermost is false, `stmt` is nested inside all allocations.
// - If put_innermost is true, every allocation is injected at the innermost
//   possible position.
147stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector<stmt_t> &allocs,
148 bool put_innermost = false);
149
// Returns a new statement with the let statements from `lets` injected; `stmt`
// is nested inside all of them.
152stmt_t inject_let_stmts(const stmt_t &stmt, const std::vector<stmt_t> &lets);
153
154template <typename T>
155struct expr_cast_helper_t {
156 static T call(const expr_t &e) { return to_cpp<T>(e); }
157
158 static std::vector<T> call(const std::vector<expr_t> &exprs) {
159 std::vector<T> ret;
160 for (auto &e : exprs)
161 ret.push_back(to_cpp<T>(e));
162 return ret;
163 }
164};
165
166template <>
167struct expr_cast_helper_t<expr_t> {
168 static expr_t call(const expr_t &e) { return e; }
169
170 static std::vector<expr_t> call(const std::vector<expr_t> &exprs) {
171 return exprs;
172 }
173
174 template <typename U,
175 typename
176 = typename std::enable_if<std::is_arithmetic<U>::value>::type>
177 static std::vector<expr_t> call(const std::vector<U> &vec) {
178 std::vector<expr_t> ret;
179 for (auto &v : vec)
180 ret.push_back(to_expr(v));
181 return ret;
182 }
183};
184
185template <typename DstT, typename SrcT>
186DstT expr_cast(const SrcT &src) {
187 return expr_cast_helper_t<DstT>::call(src);
188}
189
190template <typename DstT, typename SrcT>
191std::vector<DstT> expr_cast(const std::vector<SrcT> &src) {
192 return expr_cast_helper_t<DstT>::call(src);
193}
194
// Recursively performs constant folding on an IR tree.
196object_t const_fold(const object_t &obj);
197
// Performs constant folding on an expression, non-recursively.
199expr_t const_fold_non_recursive(const expr_t &e);
200
201template <typename T>
202std::vector<object_t> find_objects(const object_t &root);
203
204class alloc_manager_t {
205public:
206 alloc_manager_t(const stmt_t &root) {
207 auto allocs = find_objects<alloc_t>(root);
208 for (auto &_a : allocs) {
209 auto &a = _a.as<alloc_t>();
210 auto ret = buf2alloc_.insert({a.buf, _a});
211 buffers_.push_back(a.buf);
212 ir_assert(ret.second) << "Buffer already exists: " << a.buf;
213 MAYBE_UNUSED(ret);
214 }
215
216 // Sort buffers by name.
217 std::sort(buffers_.begin(), buffers_.end(),
218 [](const expr_t &a, const expr_t &b) {
219 return a.as<var_t>().name < b.as<var_t>().name;
220 });
221 }
222
223 const std::vector<expr_t> &buffers() const { return buffers_; }
224
225 expr_t find_buffer(
226 const std::string &name, bool allow_empty = false) const {
227 for (auto &b : buffers())
228 if (b.as<var_t>().name == name) return b;
229
230 if (!allow_empty) ir_error_not_expected() << name;
231 return expr_t();
232 }
233
234 std::vector<expr_t> find_buffers(alloc_kind_t kind) const {
235 std::vector<expr_t> ret;
236 for (auto &b : buffers())
237 if (alloc_kind(b) == kind) ret.push_back(b);
238 return ret;
239 }
240
241 int alloc_size(const expr_t &buf) const {
242 auto *a = find_alloc(buf);
243 ir_assert(a) << buf;
244 return a->size;
245 }
246
247 alloc_kind_t alloc_kind(const expr_t &buf) const {
248 auto *a = find_alloc(buf);
249 ir_assert(a) << buf;
250 return a->kind;
251 }
252
253 int total_size(alloc_kind_t kind) const {
254 int ret = 0;
255 for (auto &kv : buf2alloc_) {
256 auto &a = kv.second.as<alloc_t>();
257 if (a.kind == kind) ret += a.size;
258 }
259 return ret;
260 }
261
262private:
263 const alloc_t *find_alloc(const expr_t &buf) const {
264 auto it = buf2alloc_.find(buf);
265 if (it == buf2alloc_.end()) return nullptr;
266 return it->second.as_ptr<alloc_t>();
267 }
268
269 object_map_t<expr_t, stmt_t> buf2alloc_;
270 std::vector<expr_t> buffers_;
271 object_map_t<expr_t, stmt_t> alloc_updates_;
272};
273
274// IR utility functions.
275expr_t abs(const expr_t &e);
276
277expr_t cast(const expr_t &e, const type_t &type, bool saturate = false);
278
279bool is_zero(const expr_t &e);
280
281bool is_one(const expr_t &e);
282
283bool is_minus_one(const expr_t &e);
284
285bool is_const_broadcast(const expr_t &e);
286
287bool is_const_broadcast(const expr_t &e, const expr_t &value);
288
289expr_t make_buffer(const std::string &name);
290
291// Utility functions for nary_op_t.
292expr_t nary_op_back_transform(const expr_t &e);
293expr_t nary_op_canonicalize(const expr_t &_e);
294expr_t make_nary_op(op_kind_t op_kind, const std::vector<expr_t> &args);
295std::vector<expr_t> cvt_expr_to_nary_op_args(const expr_t &e);
296
// Replaces all occurrences of `from` with `to` in `root`.
298object_t substitute(const object_t &root, const object_t &from,
299 const object_t &to,
300 int max_substitutions = std::numeric_limits<int>::max());
301
// Replaces all occurrences of `from` with `to` in `root` and propagates any
// required type changes.
304object_t substitute_with_different_type(const object_t &root,
305 const object_t &from, const object_t &to,
306 int max_substitutions = std::numeric_limits<int>::max());
307
308// Returns leaf statements of `root`. Uses inorder traversal.
309std::vector<stmt_t> flatten_statements(const stmt_t &root);
310
311template <typename T, bool find_unique = false, bool save_objects = true>
312class object_finder_t : public ir_visitor_t {
313public:
314 void _visit(const T &obj) override {
315 ir_visitor_t::_visit(obj);
316 occurrences++;
317 if (!save_objects) return;
318 if (find_unique) {
319 found_unique.insert(obj);
320 } else {
321 found.push_back(obj);
322 }
323 }
324
325 std::vector<object_t> found;
326 object_set_t<object_t> found_unique;
327 int occurrences = 0;
328};
329
330// Returns all IR objects of type `T` found in `root`.
331template <typename T>
332std::vector<object_t> find_objects(const object_t &root) {
333 object_finder_t<T, /*find_unique=*/false> finder;
334 finder.visit(root);
335 return finder.found;
336}
337
338template <typename T>
339int count_objects(const object_t &root) {
340 object_finder_t<T, /*find_unique=*/false, /*save_objects=*/false> finder;
341 finder.visit(root);
342 return finder.occurrences;
343}
344
345// Returns unique IR objects of type `T` found in `root`.
346template <typename T>
347object_set_t<object_t> find_unique_objects(const object_t &root) {
348 object_finder_t<T, /*find_unique=*/true> finder;
349 finder.visit(root);
350 return finder.found_unique;
351}
352
353// Returns number of occurrences of `obj` in `root` (based on identity
354// comparison).
355int count_object(const object_t &root, const object_t &obj);
356
357// Returns number of occurrences of `obj` in vector of root objects (based on
358// identity comparison).
359template <typename T>
360int count_object(const std::vector<T> &roots, const object_t &obj) {
361 int ret = 0;
362 for (auto &root : roots)
363 ret += count_object(root, obj);
364 return ret;
365}
366
367// Checks if `root` contains `obj`.
368bool contains_object(const object_t &root, const object_t &obj);
369
370// Returns all statement groups matching the label.
371std::vector<stmt_t> find_stmt_groups(
372 const object_t &root, const stmt_label_t &label);
373
374// Returns a statement group matching the label. `root` must have exactly one
375// occurrence.
376utils::optional_t<stmt_t> find_stmt_group(
377 const object_t &root, const stmt_label_t &label);
378
379// Removes all statement groups matching the label.
380object_t remove_stmt_group(const object_t &root, stmt_label_t label);
381
382class scope_visitor_t : public ir_visitor_t {
383public:
384 bool is_expr_defined(const expr_t &e) const {
385 auto vars = find_unique_objects<var_t>(e);
386 for (auto &v : vars) {
387 if (def_vars_.count(v) == 0) return false;
388 }
389 return true;
390 }
391
392#define CASE(type, var_field, is_pre) \
393 if (obj.is<type>()) { \
394 visit_scope((const type &)obj, ((const type &)obj).var_field, is_pre); \
395 return; \
396 }
397
398 void pre_visit(const object_impl_t &obj) override {
399 CASE(alloc_t, buf, true);
400 CASE(let_t, var, true);
401 CASE(for_t, var, true);
402 }
403
404 void post_visit(const object_impl_t &obj) override {
405 CASE(alloc_t, buf, false);
406 CASE(let_t, var, false);
407 CASE(for_t, var, false);
408 }
409
410#undef CASE
411
412private:
413 template <typename T>
414 void visit_scope(const T &obj, const expr_t &var, bool is_pre_visit) {
415 if (is_pre_visit) {
416 def_vars_.insert(var);
417 return;
418 }
419 def_vars_.erase(var);
420 }
421
422 object_set_t<expr_t> def_vars_;
423};
424
425class ir_path_t {
426public:
427 void push(const object_impl_t *obj) { path_.push_back(obj); }
428
429 void pop() { path_.pop_back(); }
430
431 const object_impl_t *back() const {
432 ir_assert(!is_empty());
433 return path_.back();
434 }
435
436 bool is_empty() const { return path_.empty(); }
437
438 void merge(const ir_path_t &other) {
439 size_t idx;
440 size_t min_size = std::min(path_.size(), other.path_.size());
441 for (idx = 0; idx < min_size; idx++) {
442 if (path_[idx] != other.path_[idx]) break;
443 }
444 path_.resize(idx);
445 }
446
447private:
448 std::vector<const object_impl_t *> path_;
449};
450
451// Only for statements that create scope.
452stmt_t get_stmt_body(const stmt_t &stmt);
453
454stmt_t replace_stmt_body(const stmt_t &stmt, const stmt_t &new_body);
455
456int get_peak_grf_usage(const stmt_t &stmt, int grf_size, int external_usage = 0,
457 bool skip_let = false);
458
459bool has_send_atomics(const stmt_t &s);
460
// Scoped accounting of memory usage: adds `size` to `*usage` on construction
// and subtracts it again on destruction. Optionally keeps `*peak_usage` at the
// maximum usage observed at construction time. The guard is movable but not
// copyable; a moved-from guard is disarmed and does nothing on destruction.
struct mem_usage_guard_t {
    // Registers `size` in `*usage` and updates `*peak_usage`. Either pointer
    // may be null; the peak is only updated when both are non-null.
    mem_usage_guard_t(int *usage, int *peak_usage, int size)
        : usage(usage), peak_usage(peak_usage), size(size) {
        if (usage) *usage += size;
        if (usage && peak_usage) *peak_usage = std::max(*peak_usage, *usage);
    }

    // Same as above without peak tracking.
    mem_usage_guard_t(int *usage, int size)
        : mem_usage_guard_t(usage, nullptr, size) {}

    // Creates a disarmed guard.
    mem_usage_guard_t() : mem_usage_guard_t(nullptr, nullptr, 0) {}

    // Takes over the accounting of `other` and disarms it.
    mem_usage_guard_t(mem_usage_guard_t &&other) noexcept
        : usage(other.usage), peak_usage(other.peak_usage), size(other.size) {
        other.usage = nullptr;
        other.peak_usage = nullptr;
        other.size = 0;
    }

    // Releases whatever this guard currently accounts for, then takes over
    // the accounting of `other` and disarms it.
    mem_usage_guard_t &operator=(mem_usage_guard_t &&other) noexcept {
        // Self-move must keep the guard armed.
        if (this == &other) return *this;

        // Without this release the usage registered by the overwritten guard
        // would never be subtracted.
        if (usage) *usage -= size;

        usage = other.usage;
        peak_usage = other.peak_usage;
        size = other.size;
        other.usage = nullptr;
        other.peak_usage = nullptr;
        other.size = 0;
        return *this;
    }

    mem_usage_guard_t(const mem_usage_guard_t &) = delete;
    mem_usage_guard_t &operator=(const mem_usage_guard_t &) = delete;

    ~mem_usage_guard_t() {
        if (usage) *usage -= size;
    }

    // Counter this guard contributes to (null when disarmed).
    int *usage {nullptr};
    // Optional peak counter, only written at construction time.
    int *peak_usage {nullptr};
    // Amount added to `*usage` by this guard.
    int size {0};
};
501
502// Describes the linear transformation F(x) for variable x: F(x) = (a * x + b),
503// where a and b are integer constants.
504struct linear_transform_t {
505 expr_t x;
506 int a;
507 int b;
508
509 bool is_identity() const { return a == 1 && b == 0; }
510};
511
512// Relation: (lhs op rhs), where:
513// - lhs is a variable
514// - rhs is an integer constant
515// - op is a comparison operation
516class relation_t {
517public:
518 relation_t(const expr_t &expr) : expr_(normalize(expr)) {}
519
520 const expr_t &expr() const { return expr_; }
521
522 const expr_t &var() const { return expr_.as<binary_op_t>().a; }
523
524 const expr_t &rhs() const { return expr_.as<binary_op_t>().b; }
525
526 op_kind_t op_kind() const { return expr_.as<binary_op_t>().op_kind; }
527
528 bool implies(const relation_t &other) const;
529
530 // Applies linear transformation to left and right hand sides of the relation.
531 relation_t transform(const linear_transform_t &t, const expr_t &new_var);
532
533 std::string str() const {
534 std::ostringstream oss;
535 oss << expr_;
536 return oss.str();
537 }
538
539 static bool is_relation_constraint(const expr_t &e) {
540 auto *binary_op = e.as_ptr<binary_op_t>();
541 if (!binary_op) return false;
542 if (!is_var(binary_op->a)) return false;
543 if (!is_const(binary_op->b)) return false;
544 if (!is_cmp_op(binary_op->op_kind)) return false;
545 return true;
546 }
547
548private:
549 static expr_t normalize(const expr_t &e);
550
551 expr_t expr_;
552};
553
554inline std::ostream &operator<<(std::ostream &out, const relation_t &rel) {
555 out << rel.str();
556 return out;
557}
558
559// Equality for modulus: (var % mod) == 0, where:
560// - var is a variable
561// - mod is an integer constant
562class modulus_info_t {
563public:
564 modulus_info_t(const expr_t &expr) : expr_(expr) {}
565
566 const expr_t &expr() const { return expr_; }
567
568 const expr_t &var() const {
569 auto &mod_expr = expr_.as<binary_op_t>().a;
570 return mod_expr.as<binary_op_t>().a;
571 }
572
573 const expr_t &mod() const {
574 auto &mod_expr = expr_.as<binary_op_t>().a;
575 return mod_expr.as<binary_op_t>().b;
576 }
577
578 bool implies(const modulus_info_t &other) const {
579 ir_assert(var().is_same(other.var()));
580
581 int64_t this_mod = to_cpp<int64_t>(mod());
582 int64_t other_mod = to_cpp<int64_t>(other.mod());
583
584 return this_mod % other_mod == 0;
585 }
586
587 std::string str() const {
588 std::ostringstream oss;
589 oss << expr_;
590 return oss.str();
591 }
592
593 // Try to match (var % mod) == 0.
594 static bool is_modulus_constraint(const expr_t &e);
595
596private:
597 expr_t expr_;
598};
599
600inline std::ostream &operator<<(std::ostream &out, const modulus_info_t &mod) {
601 out << mod.str();
602 return out;
603}
604
605// Helper class to find constant bounds of integer expressions based on known
606// relations.
607class bound_finder_base_t {
608public:
609 int64_t find_low_bound(const expr_t &e) const {
610 return find_bound_impl(e, /*is_low=*/true);
611 }
612
613 int64_t find_high_bound(const expr_t &e) const {
614 return find_bound_impl(e, /*is_low=*/false);
615 }
616
617 virtual int64_t get_var_bound(const expr_t &e, bool is_low) const = 0;
618
619 static int64_t unlimited_bound(bool is_low) {
620 if (is_low) return std::numeric_limits<int64_t>::min();
621 return std::numeric_limits<int64_t>::max();
622 }
623
624 static bool is_good_bound(int64_t bound) {
625 if (bound == unlimited_bound(true)) return false;
626 if (bound == unlimited_bound(false)) return false;
627 return true;
628 }
629
630protected:
631 // If is_low is true, searches for proven low bound, and high bound
632 // otherwise.
633 virtual int64_t find_bound_impl(const expr_t &e, bool is_low) const;
634};
635
636class bound_finder_t : public bound_finder_base_t {
637public:
638 bound_finder_t(
639 const object_map_t<expr_t, std::vector<relation_t>> &relations)
640 : relations_(relations) {}
641
642 int64_t get_var_bound(const expr_t &e, bool is_low) const override {
643 ir_assert(is_var(e));
644 int64_t def_bound = unlimited_bound(is_low);
645 auto it = relations_.find(e);
646 if (it == relations_.end()) return def_bound;
647
648 int64_t ret = def_bound;
649 for (auto &rel : it->second) {
650 bool is_ge = (rel.op_kind() == op_kind_t::_ge);
651 if (is_ge != is_low) continue;
652 if (is_ge) {
653 ret = std::max(to_cpp<int64_t>(rel.rhs()), ret);
654 } else {
655 ret = std::min(to_cpp<int64_t>(rel.rhs()), ret);
656 }
657 }
658 return ret;
659 }
660
661private:
662 object_map_t<expr_t, std::vector<relation_t>> relations_;
663};
664
665// TODO: Add integers check (only integers can be constrained).
666class constraint_set_t {
667public:
668 const object_map_t<expr_t, std::vector<relation_t>> &relations() const {
669 return relations_;
670 }
671
672 void add_constraint(const expr_t &e);
673
674 bool can_prove(const expr_t &e, bool try_simplify = true) const {
675 auto ret = can_prove_impl(e, /*do_simplify=*/false);
676 if (ret || !try_simplify) return ret;
677
678 return can_prove_impl(e, /*do_simplify=*/true);
679 }
680
681 bool is_single_value(const expr_t &e, expr_t &value) const;
682
683 int max_proven_gcd(const expr_t &var) const;
684
685private:
686 bool can_prove_modulus(const expr_t &e) const {
687 modulus_info_t unknown(e);
688 auto it = modulus_infos_.find(unknown.var());
689 if (it == modulus_infos_.end()) return false;
690
691 for (auto &known : it->second) {
692 if (known.implies(unknown)) return true;
693 }
694
695 return false;
696 }
697
698 bool can_prove_relation(const expr_t &e) const {
699 relation_t unknown(e);
700 auto it = relations_.find(unknown.var());
701 if (it == relations_.end()) return false;
702
703 for (auto &known : it->second) {
704 if (known.implies(unknown)) return true;
705 }
706
707 return false;
708 }
709
710 bool try_prove_compound_relation(const expr_t &e) const {
711 auto *binary = e.as_ptr<binary_op_t>();
712 if (!binary) return false;
713
714 auto op_kind = binary->op_kind;
715 auto &a = binary->a;
716 auto &_b = binary->b;
717
718 if (!is_const(_b)) return false;
719
720 auto b = to_cpp<int64_t>(_b);
721
722 // Normalize operation kind.
723 switch (op_kind) {
724 case op_kind_t::_ge:
725 case op_kind_t::_le: break;
726 case op_kind_t::_gt:
727 op_kind = op_kind_t::_ge;
728 ir_assert(b < std::numeric_limits<int64_t>::max());
729 b += 1;
730 break;
731 case op_kind_t::_lt:
732 op_kind = op_kind_t::_le;
733 ir_assert(b > std::numeric_limits<int64_t>::min());
734 b -= 1;
735 break;
736 default: return false;
737 }
738
739 bound_finder_t finder(relations_);
740 if (op_kind == op_kind_t::_ge) {
741 auto lo = finder.find_low_bound(a);
742 if (!bound_finder_t::is_good_bound(lo)) return false;
743 return lo >= b;
744 }
745
746 if (op_kind == op_kind_t::_le) {
747 auto hi = finder.find_high_bound(a);
748 if (!bound_finder_t::is_good_bound(hi)) return false;
749 return hi <= b;
750 }
751
752 return false;
753 }
754
755 bool can_prove_impl(const expr_t &_e, bool do_simplify) const;
756
757 object_map_t<expr_t, std::vector<relation_t>> relations_;
758 object_map_t<expr_t, std::vector<modulus_info_t>> modulus_infos_;
759};
760
761// Pre-defined functions.
762namespace funcs {
763
764inline func_t barrier_func() {
765 static thread_local auto f = builtin_t::make("barrier");
766 return f;
767}
768
769inline stmt_t barrier() {
770 return barrier_func().call();
771}
772
773inline func_t slm_fence_func() {
774 static thread_local auto f = builtin_t::make("slm_fence");
775 return f;
776}
777
778inline stmt_t slm_fence() {
779 return slm_fence_func().call();
780}
781
782inline func_t signal_func() {
783 static thread_local auto f = builtin_t::make("signal");
784 return f;
785}
786
787inline stmt_t signal() {
788 return signal_func().call();
789}
790
791inline func_t barrier_wait_func() {
792 static thread_local auto f = builtin_t::make("barrier_wait");
793 return f;
794}
795
796inline stmt_t barrier_wait() {
797 return barrier_wait_func().call();
798}
799
800inline func_t continue_func() {
801 static thread_local auto f = builtin_t::make("continue");
802 return f;
803}
804
805inline stmt_t _continue() {
806 return continue_func().call();
807}
808
809} // namespace funcs
810
811} // namespace jit
812} // namespace gpu
813} // namespace impl
814} // namespace dnnl
815
816#endif
817