1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/ir/ir.hpp" |
18 | |
19 | #include <sstream> |
20 | |
21 | #include "common/math_utils.hpp" |
22 | #include "common/optional.hpp" |
23 | #include "gpu/jit/codegen/register_allocator.hpp" |
24 | #include "gpu/jit/ir/core.hpp" |
25 | #include "gpu/jit/ir/message.hpp" |
26 | #include "gpu/jit/pass/simplify.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace gpu { |
31 | namespace jit { |
32 | |
33 | using namespace ir_utils; |
34 | |
35 | namespace { |
36 | |
37 | // Helper class to print IR objects. |
38 | class ir_printer_t : public ir_visitor_t { |
39 | public: |
40 | ir_printer_t(std::ostream &out) : out_(out) {} |
41 | |
42 | void _visit(const alloc_t &obj) override { |
43 | auto guard |
44 | = mem_usage_guard(obj.kind == alloc_kind_t::grf ? obj.size : 0); |
45 | print_indent(); |
46 | out_ << "alloc " << obj.buf.as<var_t>().name << "[" << obj.size |
47 | << "] (mem_usage: " << mem_usage_ << ")\n" ; |
48 | visit(obj.body); |
49 | } |
50 | |
51 | void _visit(const binary_op_t &obj) override { |
52 | if (utils::one_of(obj.op_kind, op_kind_t::_min, op_kind_t::_max)) { |
53 | out_ << to_string(obj.op_kind) << "(" << obj.a << ", " << obj.b |
54 | << ")" ; |
55 | return; |
56 | } |
57 | out_ << "(" ; |
58 | visit(obj.a); |
59 | out_ << " " << to_string(obj.op_kind) << " " ; |
60 | visit(obj.b); |
61 | out_ << ")" ; |
62 | } |
63 | |
64 | void _visit(const bool_imm_t &obj) override { |
65 | out_ << (obj.value ? "true" : "false" ); |
66 | } |
67 | |
68 | void _visit(const cast_t &obj) override { |
69 | out_ << obj.type; |
70 | if (obj.saturate) out_ << ".sat" ; |
71 | out_ << "(" << obj.expr << ")" ; |
72 | } |
73 | |
74 | void _visit(const float_imm_t &obj) override { out_ << obj.value; } |
75 | |
76 | void _visit(const for_t &obj) override { |
77 | print_indent(); |
78 | out_ << "for (" << obj.var << " = " << obj.init << "; " << obj.var |
79 | << " < " << obj.bound << "; " << obj.var << "++) " ; |
80 | if (obj.unroll != 1) out_ << "[unroll: " << obj.unroll << "] " ; |
81 | out_ << "{\n" ; |
82 | add_indent(); |
83 | visit(obj.body); |
84 | remove_indent(); |
85 | print_indent(); |
86 | out_ << "}\n" ; |
87 | } |
88 | |
89 | void _visit(const func_call_t &obj) override { |
90 | print_indent(); |
91 | out_ << obj.func << "(" << make_seq_print_helper(obj.args) << ")" ; |
92 | if (!obj.attr.is_empty()) out_ << " " << obj.attr; |
93 | out_ << "\n" ; |
94 | } |
95 | |
96 | void _visit(const func_impl_t &obj) override { out_ << obj.str(); } |
97 | |
98 | void _visit(const if_t &obj) override { |
99 | print_indent(); |
100 | out_ << "if (" << strip_parens(obj.cond.str()) << ") {\n" ; |
101 | add_indent(); |
102 | visit(obj.body); |
103 | remove_indent(); |
104 | print_indent(); |
105 | if (obj.else_body.is_empty()) { |
106 | out_ << "}\n" ; |
107 | return; |
108 | } |
109 | out_ << "} else {\n" ; |
110 | add_indent(); |
111 | visit(obj.else_body); |
112 | remove_indent(); |
113 | print_indent(); |
114 | out_ << "}\n" ; |
115 | } |
116 | |
117 | void _visit(const iif_t &obj) override { |
118 | out_ << "(" << obj.cond << " ? " << obj.true_expr << " : " |
119 | << obj.false_expr << ")" ; |
120 | } |
121 | |
122 | void _visit(const int_imm_t &obj) override { |
123 | out_ << std::to_string(obj.value); |
124 | } |
125 | |
126 | void _visit(const let_t &obj) override { |
127 | // Empty objects are allocated in reserved space |
128 | // nGEN only claims subregisters at dword granularity |
129 | int size = obj.value.is_empty() ? 0 |
130 | : utils::rnd_up(obj.var.type().size(), |
131 | reg_allocator_t::granularity); |
132 | auto guard = mem_usage_guard(size); |
133 | print_indent(); |
134 | out_ << obj.var << "." << obj.var.type() << " = " << obj.value << "\n" ; |
135 | visit(obj.body); |
136 | } |
137 | |
138 | void _visit(const load_t &obj) override { |
139 | out_ << obj.buf; |
140 | if (obj.has_default_stride()) { |
141 | out_ << "." << obj.type << "(" << obj.off / obj.type.size() << ")" ; |
142 | } else { |
143 | out_ << "[" << obj.off << "]." << obj.type; |
144 | out_ << "<" << obj.stride << ">" ; |
145 | } |
146 | } |
147 | |
148 | void _visit(const ptr_t &obj) override { |
149 | out_ << obj.base << "[" << obj.off << "]" ; |
150 | } |
151 | |
152 | void _visit(const shuffle_t &obj) override { |
153 | if (obj.is_broadcast()) { |
154 | out_ << "bcast" << obj.elems() << "(" << obj.vec[0] << ")" ; |
155 | return; |
156 | } |
157 | std::vector<expr_t> vec_all; |
158 | for (auto &v : obj.vec) { |
159 | for (int i = 0; i < v.type().elems(); i++) |
160 | vec_all.push_back(v); |
161 | } |
162 | int elems = obj.type.elems(); |
163 | out_ << "(" ; |
164 | for (int i = 0; i < elems; i++) { |
165 | int idx = obj.idx[i]; |
166 | auto &v = vec_all[idx]; |
167 | int v_elems = v.type().elems(); |
168 | out_ << v; |
169 | if (v_elems != 1) out_ << "[" << idx << "]" ; |
170 | if (i != elems - 1) out_ << ", " ; |
171 | } |
172 | out_ << ")" ; |
173 | } |
174 | |
175 | void _visit(const stmt_group_t &obj) override { |
176 | print_indent(); |
177 | out_ << obj.label << " {\n" ; |
178 | add_indent(); |
179 | visit(obj.body); |
180 | remove_indent(); |
181 | print_indent(); |
182 | out_ << "}\n" ; |
183 | return; |
184 | } |
185 | |
186 | void _visit(const stmt_seq_t &obj) override { |
187 | visit(obj.head); |
188 | visit(obj.tail); |
189 | } |
190 | |
191 | void _visit(const store_t &obj) override { |
192 | print_indent(); |
193 | out_ << load_t::make(obj.value.type(), obj.buf, obj.off, obj.stride); |
194 | out_ << " = " << obj.value; |
195 | if (!obj.mask.is_empty()) { |
196 | out_ << ", mask = " << obj.mask.str(); |
197 | if (obj.fill_mask0) out_ << " [FILL]" ; |
198 | } |
199 | out_ << "\n" ; |
200 | } |
201 | |
202 | void _visit(const ternary_op_t &obj) override { |
203 | out_ << to_string(obj.op_kind) << "(" << obj.a << ", " << obj.b << ", " |
204 | << obj.c << ")" ; |
205 | return; |
206 | } |
207 | |
208 | void _visit(const unary_op_t &obj) override { |
209 | out_ << to_string(obj.op_kind); |
210 | visit(obj.a); |
211 | } |
212 | |
213 | void _visit(const var_t &obj) override { out_ << obj.name; } |
214 | |
215 | private: |
216 | mem_usage_guard_t mem_usage_guard(int size) { |
217 | return mem_usage_guard_t(&mem_usage_, size); |
218 | } |
219 | |
220 | static std::string strip_parens(const std::string &s) { |
221 | if (s.size() < 2 || s[0] != '(' || s[s.size() - 1] != ')') return s; |
222 | auto ret = s; |
223 | ret.resize(s.size() - 1); |
224 | return ret.substr(1); |
225 | } |
226 | |
227 | void print_indent() { |
228 | for (int i = 0; i < indent_; i++) |
229 | out_ << prefix_; |
230 | } |
231 | |
232 | void add_indent() { indent_++; } |
233 | void remove_indent() { indent_--; } |
234 | |
235 | std::ostream &out_; |
236 | int indent_ = 0; |
237 | |
238 | std::string prefix_ = " " ; |
239 | |
240 | // Size required for all enclosed let/alloc statements. The value is |
241 | // updated during traversal. |
242 | int mem_usage_ = 0; |
243 | }; |
244 | |
245 | class substitute_mutator_t : public ir_mutator_t { |
246 | public: |
247 | substitute_mutator_t(const object_t &from, const object_t &to) |
248 | : from_(from), to_(to) {} |
249 | |
250 | int substitutions() const { return substitutions_; } |
251 | |
252 | #define HANDLE_IR_OBJECT(type) \ |
253 | object_t _mutate(const type &obj) override { \ |
254 | if (from_.impl() == (const object_impl_t *)&obj) { \ |
255 | substitutions_++; \ |
256 | return to_; \ |
257 | } \ |
258 | return ir_mutator_t::_mutate(obj); \ |
259 | }; |
260 | |
261 | HANDLE_TRAVERSE_TARGETS() |
262 | |
263 | #undef HANDLE_IR_OBJECT |
264 | |
265 | private: |
266 | object_t from_; |
267 | object_t to_; |
268 | |
269 | int substitutions_ = 0; |
270 | }; |
271 | |
272 | class substitute_and_type_mutator_t : public ir_mutator_t { |
273 | public: |
274 | substitute_and_type_mutator_t(const object_t &from, const object_t &to) { |
275 | substitutes_[from] = to; |
276 | } |
277 | |
278 | int substitutions() const { return substitutions_; } |
279 | |
280 | template <typename T> |
281 | object_t _mutate_after(const T &obj) { |
282 | return ir_mutator_t::_mutate(obj); |
283 | } |
284 | |
285 | object_t _mutate_after(const let_t &obj) { |
286 | auto var = mutate(obj.var); |
287 | auto value = mutate(obj.value); |
288 | |
289 | // Allow changing variable types when performing substitutions. Avoids |
290 | // the following invalid substitute transformation sequence: |
291 | // |
292 | // tmp0.s32 -> tmp0_0.u64 |
293 | // tmp1.s32 = tmp0.s32 -> tmp1.s32 = tmp0_0.u64 |
294 | if (!value.is_empty()) { |
295 | auto &value_type = expr_t(value).type(); |
296 | if (var.as<var_t>().type != value_type) { |
297 | auto var_old = var; |
298 | var = var_t::make(value_type, var.as<var_t>().name); |
299 | |
300 | substitutes_[var_old] = var; |
301 | } |
302 | } |
303 | |
304 | auto body = mutate(obj.body); |
305 | |
306 | if (var.is_same(obj.var) && value.is_same(obj.value) |
307 | && body.is_same(obj.body)) |
308 | return obj; |
309 | |
310 | return let_t::make(var, value, body); |
311 | } |
312 | |
313 | #define HANDLE_IR_OBJECT(type) \ |
314 | object_t _mutate(const type &obj) override { \ |
315 | auto it = substitutes_.find(obj); \ |
316 | if (it != substitutes_.end()) { \ |
317 | substitutions_++; \ |
318 | return it->second; \ |
319 | } \ |
320 | return _mutate_after(obj); \ |
321 | }; |
322 | |
323 | HANDLE_ALL_IR_OBJECTS() |
324 | |
325 | #undef HANDLE_IR_OBJECT |
326 | |
327 | private: |
328 | object_eq_map_t<object_t, object_t> substitutes_; |
329 | |
330 | int substitutions_ = 0; |
331 | }; |
332 | |
333 | class stmt_flattener_t : public ir_visitor_t { |
334 | public: |
335 | #define HANDLE_IR_OBJECT(type) \ |
336 | void _visit(const type &obj) { \ |
337 | size_t old_size = stmts.size(); \ |
338 | ir_visitor_t::_visit(obj); \ |
339 | if (stmts.size() > old_size) return; \ |
340 | if (obj.is_stmt()) stmts.push_back(obj); \ |
341 | } |
342 | |
343 | HANDLE_ALL_IR_OBJECTS() |
344 | |
345 | #undef HANDLE_IR_OBJECT |
346 | |
347 | std::vector<stmt_t> stmts; |
348 | }; |
349 | |
350 | class alloc_injector_t : public ir_mutator_t { |
351 | public: |
352 | alloc_injector_t(const stmt_t &root, const std::vector<stmt_t> &allocs, |
353 | bool put_innermost) |
354 | : root_(root), put_innermost_(put_innermost), allocs_(allocs) { |
355 | for (auto &_a : allocs) { |
356 | auto &a = _a.as<alloc_t>(); |
357 | if (a.kind != alloc_kind_t::global) ir_assert(a.size > 0) << _a; |
358 | alloc_map_.insert({a.buf, _a}); |
359 | } |
360 | mutate(root_); |
361 | buf_total_refs_ = buf_cur_refs_; |
362 | for (auto &kv : buf_cur_refs_) |
363 | kv.second = 0; |
364 | in_ctor_ = false; |
365 | } |
366 | |
367 | #define HANDLE_IR_OBJECT(type) \ |
368 | object_t _mutate(const type &obj) override { return mutate_stmt(obj); } |
369 | |
370 | HANDLE_STMT_IR_OBJECTS() |
371 | |
372 | #undef HANDLE_IR_OBJECT |
373 | object_t _mutate(const var_t &obj) override { |
374 | if (alloc_map_.find(obj) != alloc_map_.end()) buf_cur_refs_[obj]++; |
375 | return obj; |
376 | } |
377 | |
378 | private: |
379 | template <typename T> |
380 | object_t mutate_stmt(const T &obj) { |
381 | if (in_ctor_) return ir_mutator_t::_mutate(obj); |
382 | object_t new_obj = obj; |
383 | object_set_t<expr_t> undef_bufs; |
384 | if (put_innermost_) { |
385 | for (auto &kv : buf_cur_refs_) |
386 | if (kv.second == 0) undef_bufs.insert(kv.first); |
387 | new_obj = ir_mutator_t::_mutate(obj); |
388 | } |
389 | for (auto &a : allocs_) { |
390 | auto it = alloc_map_.find(a.as<alloc_t>().buf); |
391 | auto &buf = it->first; |
392 | if (it->second.is_empty()) continue; // Already injected. |
393 | bool do_inject = false; |
394 | if (put_innermost_) { |
395 | int cur_refs = buf_cur_refs_[buf]; |
396 | int total_refs = buf_total_refs_[buf]; |
397 | bool was_undef = (undef_bufs.count(buf) != 0); |
398 | do_inject = was_undef && (cur_refs == total_refs); |
399 | } else { |
400 | do_inject = root_.is_same(obj); |
401 | } |
402 | if (do_inject) { |
403 | auto &a = it->second.as<alloc_t>(); |
404 | new_obj = alloc_t::make( |
405 | a.buf, a.size, a.kind, a.attrs, new_obj); |
406 | it->second = stmt_t(); |
407 | } |
408 | } |
409 | return new_obj; |
410 | } |
411 | |
412 | bool in_ctor_ = true; |
413 | const stmt_t &root_; |
414 | bool put_innermost_; |
415 | std::vector<stmt_t> allocs_; |
416 | object_map_t<expr_t, stmt_t> alloc_map_; |
417 | object_map_t<expr_t, int> buf_total_refs_; |
418 | object_map_t<expr_t, int> buf_cur_refs_; |
419 | }; |
420 | |
421 | } // namespace |
422 | |
423 | std::string object_impl_t::str() const { |
424 | std::ostringstream oss; |
425 | ir_printer_t printer(oss); |
426 | printer.visit(this); |
427 | return oss.str(); |
428 | } |
429 | |
430 | object_t substitute(const object_t &root, const object_t &from, |
431 | const object_t &to, int max_substitutions) { |
432 | if (to.is_same(from)) return root; |
433 | substitute_mutator_t sm(from, to); |
434 | auto ret = sm.mutate(root); |
435 | ir_assert(sm.substitutions() <= max_substitutions) |
436 | << "Unexpected number of substitutions." ; |
437 | return ret; |
438 | } |
439 | |
440 | object_t substitute_with_different_type(const object_t &root, |
441 | const object_t &from, const object_t &to, int max_substitutions) { |
442 | if (to.is_same(from)) return root; |
443 | substitute_and_type_mutator_t sm(from, to); |
444 | auto ret = sm.mutate(root); |
445 | ir_assert(sm.substitutions() <= max_substitutions) |
446 | << "Unexpected number of substitutions." ; |
447 | return ret; |
448 | } |
449 | |
450 | std::vector<stmt_t> flatten_statements(const stmt_t &root) { |
451 | stmt_flattener_t f; |
452 | f.visit(root); |
453 | return f.stmts; |
454 | } |
455 | |
456 | stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector<stmt_t> &allocs, |
457 | bool put_innermost) { |
458 | alloc_injector_t injector(stmt, allocs, put_innermost); |
459 | return injector.mutate(stmt); |
460 | } |
461 | |
462 | stmt_t inject_let_stmts(const stmt_t &stmt, const std::vector<stmt_t> &lets) { |
463 | stmt_t ret = stmt; |
464 | for (auto it = lets.rbegin(); it != lets.rend(); ++it) { |
465 | auto &let = it->as<let_t>(); |
466 | ret = let_t::make(let.var, let.value, ret); |
467 | } |
468 | return ret; |
469 | } |
470 | |
471 | expr_t abs(const expr_t &e) { |
472 | ir_assert(is_const(e)) << e; |
473 | if (to_cpp<bool>(e >= 0)) return e; |
474 | return -e; |
475 | } |
476 | |
477 | expr_t cast(const expr_t &e, const type_t &type, bool saturate) { |
478 | return const_fold(cast_t::make(type, e, saturate)); |
479 | } |
480 | |
481 | bool is_zero(const expr_t &e) { |
482 | if (e.is_empty()) return false; |
483 | if (!e.type().is_scalar() || e.type().is_ptr()) return false; |
484 | return e.is_equal(to_expr(0, e.type())); |
485 | } |
486 | |
487 | bool is_one(const expr_t &e) { |
488 | if (e.is_empty()) return false; |
489 | if (!e.type().is_scalar() || e.type().is_ptr()) return false; |
490 | return e.is_equal(to_expr(1, e.type())); |
491 | } |
492 | |
493 | bool is_minus_one(const expr_t &e) { |
494 | if (e.is_empty()) return false; |
495 | if (!e.type().is_scalar() || e.type().is_ptr()) return false; |
496 | return e.is_equal(to_expr(-1, e.type())); |
497 | } |
498 | |
499 | bool is_const_broadcast(const expr_t &e) { |
500 | auto *shuffle = e.as_ptr<shuffle_t>(); |
501 | if (!shuffle) return false; |
502 | if (!shuffle->is_broadcast()) return false; |
503 | return is_const(shuffle->vec[0]); |
504 | } |
505 | |
506 | bool is_const_broadcast(const expr_t &e, const expr_t &value) { |
507 | if (!is_const_broadcast(e)) return false; |
508 | return e.as<shuffle_t>().vec[0].is_equal(value); |
509 | } |
510 | |
511 | expr_t make_buffer(const std::string &name) { |
512 | return var_t::make(type_t::byte_ptr(), name); |
513 | } |
514 | |
515 | // Returns number of occurrences of `obj` in `root` (based on identity equality). |
516 | int count_object(const object_t &root, const object_t &obj) { |
517 | ir_assert(!obj.is_empty()); |
518 | |
519 | std::vector<object_t> found; |
520 | do { |
521 | #define HANDLE_IR_OBJECT(type) \ |
522 | if (obj.dispatch_type_id() == type::_dispatch_type_id()) { \ |
523 | found = find_objects<type>(root); \ |
524 | break; \ |
525 | } |
526 | |
527 | HANDLE_ALL_IR_OBJECTS() |
528 | |
529 | #undef HANDLE_IR_OBJECT |
530 | |
531 | ir_error_not_expected() << obj; |
532 | } while (false); |
533 | |
534 | int ret = 0; |
535 | for (auto &f : found) |
536 | if (f.is_equal(obj)) ret++; |
537 | return ret; |
538 | } |
539 | |
540 | bool contains_object(const object_t &root, const object_t &obj) { |
541 | ir_assert(is_var(obj)) << obj; |
542 | return count_object(root, obj) > 0; |
543 | } |
544 | |
545 | std::vector<stmt_t> find_stmt_groups( |
546 | const object_t &root, const stmt_label_t &label) { |
547 | auto groups = find_objects<stmt_group_t>(root); |
548 | std::vector<stmt_t> ret; |
549 | for (auto &g : groups) { |
550 | if (g.as<stmt_group_t>().label == label) ret.push_back(g); |
551 | } |
552 | return ret; |
553 | } |
554 | |
555 | utils::optional_t<stmt_t> find_stmt_group( |
556 | const object_t &root, const stmt_label_t &label) { |
557 | auto groups = find_stmt_groups(root, label); |
558 | if (groups.size() == 1) |
559 | return groups[0]; |
560 | else |
561 | return utils::nullopt; |
562 | } |
563 | |
564 | class stmt_group_remover_t : public ir_mutator_t { |
565 | public: |
566 | stmt_group_remover_t(stmt_label_t label) : label_(label) {} |
567 | object_t _mutate(const stmt_group_t &obj) override { |
568 | if (obj.label == label_) return stmt_t(); |
569 | return ir_mutator_t::_mutate(obj); |
570 | } |
571 | stmt_label_t label_; |
572 | }; |
573 | |
574 | object_t remove_stmt_group(const object_t &root, stmt_label_t label) { |
575 | stmt_group_remover_t remover(label); |
576 | return remover.mutate(root); |
577 | } |
578 | |
579 | stmt_t get_stmt_body(const stmt_t &stmt) { |
580 | auto *alloc = stmt.as_ptr<alloc_t>(); |
581 | if (alloc) return alloc->body; |
582 | |
583 | auto *_for = stmt.as_ptr<for_t>(); |
584 | if (_for) return _for->body; |
585 | |
586 | auto *let = stmt.as_ptr<let_t>(); |
587 | if (let) return let->body; |
588 | |
589 | auto *group = stmt.as_ptr<stmt_group_t>(); |
590 | if (group) return group->body; |
591 | |
592 | return stmt; |
593 | } |
594 | |
595 | stmt_t replace_stmt_body(const stmt_t &stmt, const stmt_t &new_body) { |
596 | auto *alloc = stmt.as_ptr<alloc_t>(); |
597 | if (alloc) { |
598 | return alloc_t::make( |
599 | alloc->buf, alloc->size, alloc->kind, alloc->attrs, new_body); |
600 | } |
601 | |
602 | auto *_for = stmt.as_ptr<for_t>(); |
603 | if (_for) { |
604 | return for_t::make( |
605 | _for->var, _for->init, _for->bound, new_body, _for->unroll); |
606 | } |
607 | |
608 | auto *let = stmt.as_ptr<let_t>(); |
609 | if (let) { return let_t::make(let->var, let->value, new_body); } |
610 | |
611 | auto *group = stmt.as_ptr<stmt_group_t>(); |
612 | if (group) { return stmt_group_t::make(group->label, new_body); } |
613 | |
614 | return new_body; |
615 | } |
616 | |
617 | class grf_usage_visitor_t : public ir_visitor_t { |
618 | public: |
619 | grf_usage_visitor_t(int grf_size, int external_usage, bool skip_let) |
620 | : grf_size_(grf_size) |
621 | , skip_let_(skip_let) |
622 | , grf_usage_(external_usage) {} |
623 | |
624 | void _visit(const alloc_t &obj) override { |
625 | int size = (obj.kind == alloc_kind_t::grf ? obj.size : 0); |
626 | size = utils::rnd_up(size, grf_size_); |
627 | auto guard = grf_usage_guard(size); |
628 | ir_visitor_t::_visit(obj); |
629 | } |
630 | |
631 | void _visit(const let_t &obj) override { |
632 | // Empty objects are allocated in reserved space |
633 | // nGEN only claims subregisters at dword granularity |
634 | int size = (skip_let_ || obj.value.is_empty()) |
635 | ? 0 |
636 | : utils::rnd_up( |
637 | obj.var.type().size(), reg_allocator_t::granularity); |
638 | auto guard = grf_usage_guard(size); |
639 | ir_visitor_t::_visit(obj); |
640 | } |
641 | |
642 | int peak_grf_usage() const { return peak_grf_usage_; } |
643 | |
644 | private: |
645 | mem_usage_guard_t grf_usage_guard(int size) { |
646 | auto ret = mem_usage_guard_t(&grf_usage_, size); |
647 | peak_grf_usage_ = std::max(peak_grf_usage_, grf_usage_); |
648 | return ret; |
649 | } |
650 | |
651 | int grf_size_ = 0; |
652 | bool skip_let_ = false; |
653 | int grf_usage_ = 0; |
654 | int peak_grf_usage_ = 0; |
655 | }; |
656 | |
657 | int get_peak_grf_usage( |
658 | const stmt_t &stmt, int grf_size, int external_usage, bool skip_let) { |
659 | grf_usage_visitor_t visitor(grf_size, external_usage, skip_let); |
660 | visitor.visit(stmt); |
661 | return utils::div_up(visitor.peak_grf_usage(), grf_size); |
662 | } |
663 | |
664 | class has_send_atomics_visitor_t : public ir_visitor_t { |
665 | public: |
666 | void _visit(const func_call_t &obj) override { |
667 | auto *send = obj.func.as_ptr<send_t>(); |
668 | if (send && send->is_atomic()) found = true; |
669 | } |
670 | |
671 | bool found = false; |
672 | }; |
673 | |
674 | bool has_send_atomics(const stmt_t &s) { |
675 | has_send_atomics_visitor_t visitor; |
676 | visitor.visit(s); |
677 | return visitor.found; |
678 | } |
679 | |
680 | bool relation_t::implies(const relation_t &other) const { |
681 | ir_assert(var().is_same(other.var())); |
682 | |
683 | if (op_kind() != other.op_kind()) return false; |
684 | |
685 | auto A = to_cpp<int64_t>(rhs()); |
686 | auto B = to_cpp<int64_t>(other.rhs()); |
687 | |
688 | switch (op_kind()) { |
689 | // (x > A) && (A >= B) => (x > B) |
690 | // (x >= A) && (A >= B) => (x >= B) |
691 | case op_kind_t::_gt: |
692 | case op_kind_t::_ge: return A >= B; |
693 | // (x < A) && (A <= B) => (x < B) |
694 | // (x <= A) && (A <= B) => (x <= B) |
695 | case op_kind_t::_lt: |
696 | case op_kind_t::_le: return A <= B; |
697 | default: ir_error_not_expected() << "Not implemented: " << expr_; |
698 | } |
699 | return false; |
700 | } |
701 | |
702 | relation_t relation_t::transform( |
703 | const linear_transform_t &t, const expr_t &new_var) { |
704 | ir_assert(t.a == 1) << "Not implemented." ; |
705 | return relation_t(binary_op_t::make(op_kind(), new_var, rhs() + t.b)); |
706 | } |
707 | |
708 | expr_t relation_t::normalize(const expr_t &e) { |
709 | ir_assert(is_relation_constraint(e)) << e; |
710 | auto &op = e.as<binary_op_t>(); |
711 | |
712 | auto op_kind = op.op_kind; |
713 | auto a = op.a; |
714 | auto b = op.b; |
715 | |
716 | switch (op_kind) { |
717 | case op_kind_t::_lt: |
718 | op_kind = op_kind_t::_le; |
719 | b -= 1; |
720 | break; |
721 | case op_kind_t::_gt: |
722 | op_kind = op_kind_t::_ge; |
723 | b += 1; |
724 | break; |
725 | default: return e; |
726 | } |
727 | return binary_op_t::make(op_kind, a, b); |
728 | } |
729 | |
730 | bool modulus_info_t::is_modulus_constraint(const expr_t &e) { |
731 | auto *binary_op = e.as_ptr<binary_op_t>(); |
732 | if (!binary_op) return false; |
733 | if (!is_zero(binary_op->b)) return false; |
734 | if (binary_op->op_kind != op_kind_t::_eq) return false; |
735 | |
736 | auto *mod_op = binary_op->a.as_ptr<binary_op_t>(); |
737 | if (!mod_op) return false; |
738 | if (mod_op->op_kind != op_kind_t::_mod) return false; |
739 | if (!is_var(mod_op->a)) return false; |
740 | if (!is_const(mod_op->b)) return false; |
741 | |
742 | return true; |
743 | } |
744 | |
745 | int64_t bound_finder_base_t::find_bound_impl( |
746 | const expr_t &e, bool is_low) const { |
747 | int64_t def_bound = unlimited_bound(is_low); |
748 | if (is_const(e)) return to_cpp<int64_t>(e); |
749 | if (is_var(e)) return get_var_bound(e, is_low); |
750 | |
751 | auto *unary = e.as_ptr<unary_op_t>(); |
752 | if (unary) { |
753 | ir_assert(unary->op_kind == op_kind_t::_minus) << e; |
754 | auto a = find_bound_impl(unary->a, !is_low); |
755 | if (!is_good_bound(a)) return def_bound; |
756 | return -a; |
757 | } |
758 | |
759 | auto *binary = e.as_ptr<binary_op_t>(); |
760 | if (binary) { |
761 | switch (binary->op_kind) { |
762 | case op_kind_t::_add: { |
763 | auto a = find_bound_impl(binary->a, is_low); |
764 | auto b = find_bound_impl(binary->b, is_low); |
765 | if (!is_good_bound(a) || !is_good_bound(b)) return def_bound; |
766 | return a + b; |
767 | } |
768 | case op_kind_t::_sub: { |
769 | auto a = find_bound_impl(binary->a, is_low); |
770 | auto b = find_bound_impl(binary->b, !is_low); |
771 | if (!is_good_bound(a) || !is_good_bound(b)) return def_bound; |
772 | return a - b; |
773 | } |
774 | case op_kind_t::_mul: { |
775 | auto a = binary->a; |
776 | auto b = binary->b; |
777 | if (!is_const(a) && is_const(b)) std::swap(a, b); |
778 | if (!is_const(a)) return def_bound; |
779 | |
780 | auto a_const = to_cpp<int64_t>(a); |
781 | if (a_const == 0) return 0; |
782 | |
783 | auto b_lo = find_low_bound(b); |
784 | auto b_hi = find_high_bound(b); |
785 | auto b_lo_ok = is_good_bound(b_lo); |
786 | auto b_hi_ok = is_good_bound(b_hi); |
787 | |
788 | if ((a_const > 0) == is_low && b_lo_ok) return a_const * b_lo; |
789 | if ((a_const > 0) != is_low && b_hi_ok) return a_const * b_hi; |
790 | |
791 | break; |
792 | } |
793 | case op_kind_t::_div: { |
794 | if (!is_const(binary->b)) return def_bound; |
795 | |
796 | auto b = to_cpp<int64_t>(binary->b); |
797 | ir_assert(b != 0); |
798 | |
799 | auto a = find_bound_impl(binary->a, b > 0 ? is_low : !is_low); |
800 | if (!is_good_bound(a)) return def_bound; |
801 | |
802 | bool is_neg = ((a > 0) && (b < 0)) || ((a < 0) && (b > 0)); |
803 | |
804 | int64_t div_bound; |
805 | if (is_low != is_neg) { |
806 | // Truncate away from zero. |
807 | div_bound = utils::div_up(std::abs(a), std::abs(b)); |
808 | } else { |
809 | // Truncate towards zero. |
810 | div_bound = std::abs(a) / std::abs(b); |
811 | } |
812 | if (is_neg) div_bound *= -1; |
813 | return div_bound; |
814 | } |
815 | case op_kind_t::_mod: { |
816 | if (is_low) return 0; |
817 | auto max_mod = find_bound_impl(binary->b, /*is_low=*/false); |
818 | if (!is_good_bound(max_mod)) return def_bound; |
819 | return max_mod - 1; |
820 | } |
821 | case op_kind_t::_and: { |
822 | if (e.type().is_u16()) { |
823 | return is_low ? e.type().min<int64_t>() |
824 | : e.type().max<int64_t>(); |
825 | } |
826 | break; |
827 | } |
828 | default: break; |
829 | } |
830 | } |
831 | |
832 | auto *cast = e.as_ptr<cast_t>(); |
833 | if (cast) { |
834 | // Saturate if needed, otherwise assume the same bounds. |
835 | if (!cast->is_bool_vec_u16() && !cast->saturate) |
836 | return find_bound_impl(cast->expr, is_low); |
837 | |
838 | if (is_low) { |
839 | auto type_lo = cast->type.min<int64_t>(); |
840 | auto lo = find_low_bound(cast->expr); |
841 | return std::max(type_lo, lo); |
842 | } |
843 | // Check u64 explicitly as its max doesn't fit into int64_t. |
844 | if (cast->type.is_u64()) return find_bound_impl(cast->expr, is_low); |
845 | auto type_hi = cast->type.max<int64_t>(); |
846 | auto hi = find_high_bound(cast->expr); |
847 | return std::min(type_hi, hi); |
848 | } |
849 | |
850 | if (e.type().is_bool()) return is_low ? 0 : 1; |
851 | |
852 | return def_bound; |
853 | } |
854 | |
855 | bool is_linear_var_transform(const expr_t &e, linear_transform_t &t) { |
856 | if (is_var(e)) { |
857 | t.x = e; |
858 | t.a = 1; |
859 | t.b = 0; |
860 | return true; |
861 | } |
862 | |
863 | auto *binary_op = e.as_ptr<binary_op_t>(); |
864 | if (!binary_op) return false; |
865 | |
866 | auto vars = find_objects<var_t>(e); |
867 | if (vars.size() != 1) return false; |
868 | |
869 | auto &var = vars[0]; |
870 | |
871 | // TODO: Extend to match multiplication: (a * var). |
872 | if (!utils::one_of(binary_op->op_kind, op_kind_t::_add, op_kind_t::_sub)) |
873 | return false; |
874 | |
875 | auto &a = binary_op->a; |
876 | auto &b = binary_op->b; |
877 | |
878 | bool is_sub = (binary_op->op_kind == op_kind_t::_sub); |
879 | |
880 | // var op b -> (t.a = 1, t.b = +/-b) |
881 | if (a.is_same(var) && is_const(b)) { |
882 | t.x = var; |
883 | t.a = 1; |
884 | t.b = (is_sub ? -1 : 1) * to_cpp<int>(b); |
885 | return true; |
886 | } |
887 | |
888 | // a op var -> (t.a = +/-1, t.b = a) |
889 | if (is_const(a) && b.is_same(var)) { |
890 | t.x = var; |
891 | t.a = (is_sub ? -1 : 1); |
892 | t.b = to_cpp<int>(a); |
893 | return true; |
894 | } |
895 | |
896 | return false; |
897 | } |
898 | |
899 | void ir_context_t::add_constraint(const expr_t &e) { |
900 | cset_.add_constraint(e); |
901 | } |
902 | |
903 | void constraint_set_t::add_constraint(const expr_t &e) { |
904 | auto *shuffle = e.as_ptr<shuffle_t>(); |
905 | if (shuffle) { |
906 | if (shuffle->is_broadcast()) add_constraint(shuffle->vec[0]); |
907 | return; |
908 | } |
909 | |
910 | if (modulus_info_t::is_modulus_constraint(e)) { |
911 | modulus_info_t mi(e); |
912 | modulus_infos_[mi.var()].push_back(mi); |
913 | return; |
914 | } |
915 | |
916 | if (relation_t::is_relation_constraint(e)) { |
917 | relation_t rel(e); |
918 | relations_[rel.var()].push_back(rel); |
919 | return; |
920 | } |
921 | |
922 | // Propagate constraints from y for (x == y) equalities. |
923 | auto *binary_op = e.as_ptr<binary_op_t>(); |
924 | if (binary_op && binary_op->op_kind == op_kind_t::_eq) { |
925 | auto &a = binary_op->a; |
926 | auto &b = binary_op->b; |
927 | linear_transform_t t; |
928 | if (is_var(a) && is_linear_var_transform(b, t)) { |
929 | // Relations. |
930 | auto r_it = relations_.find(t.x); |
931 | if (r_it != relations_.end()) { |
932 | for (auto &c : r_it->second) { |
933 | add_constraint(c.transform(t, a).expr()); |
934 | } |
935 | } |
936 | // Modulus. |
937 | if (t.is_identity()) { |
938 | auto m_it = modulus_infos_.find(t.x); |
939 | if (m_it != modulus_infos_.end()) { |
940 | for (auto &c : m_it->second) { |
941 | add_constraint(substitute(c.expr(), b, a)); |
942 | } |
943 | } |
944 | } |
945 | return; |
946 | } |
947 | } |
948 | } |
949 | |
950 | bool constraint_set_t::is_single_value(const expr_t &e, expr_t &value) const { |
951 | ir_assert(is_var(e)) << e; |
952 | auto it = relations_.find(e); |
953 | if (it == relations_.end()) return false; |
954 | |
955 | expr_t lo; |
956 | expr_t hi; |
957 | for (auto &rel : it->second) { |
958 | ir_assert(is_const(rel.rhs())) << rel; |
959 | bool do_break = false; |
960 | switch (rel.op_kind()) { |
961 | case op_kind_t::_eq: |
962 | lo = hi = rel.rhs(); |
963 | do_break = true; |
964 | break; |
965 | case op_kind_t::_ge: |
966 | case op_kind_t::_gt: { |
967 | auto cur_lo = (rel.op_kind() == op_kind_t::_ge ? rel.rhs() |
968 | : rel.rhs() + 1); |
969 | if (lo.is_empty() || to_cpp<bool>(cur_lo > lo)) { lo = cur_lo; } |
970 | break; |
971 | } |
972 | case op_kind_t::_le: |
973 | case op_kind_t::_lt: { |
974 | auto cur_hi = (rel.op_kind() == op_kind_t::_le ? rel.rhs() |
975 | : rel.rhs() - 1); |
976 | if (hi.is_empty() || to_cpp<bool>(cur_hi < hi)) { hi = cur_hi; } |
977 | break; |
978 | } |
979 | default: ir_error_not_expected() << rel; |
980 | } |
981 | if (do_break) break; |
982 | } |
983 | bool ret = !lo.is_empty() && lo.is_equal(hi); |
984 | if (ret) value = lo; |
985 | return ret; |
986 | } |
987 | |
988 | bool constraint_set_t::can_prove_impl( |
989 | const expr_t &_e, bool do_simplify) const { |
990 | auto e = _e; |
991 | if (is_const(e)) { |
992 | ir_assert(e.type() == type_t::_bool()) << e; |
993 | return to_cpp<bool>(e); |
994 | } |
995 | |
996 | if (do_simplify) { |
997 | // These passes for comparison help to prove more inequalities. |
998 | e = simplify_cmp_move_const_to_rhs(e); |
999 | e = simplify_cmp_reduce_lhs_rhs(e); |
1000 | e = simplify(e); |
1001 | if (is_const(e)) { |
1002 | ir_assert(e.type() == type_t::_bool()) << e; |
1003 | return to_cpp<bool>(e); |
1004 | } |
1005 | } |
1006 | |
1007 | if (modulus_info_t::is_modulus_constraint(e)) return can_prove_modulus(e); |
1008 | if (relation_t::is_relation_constraint(e)) return can_prove_relation(e); |
1009 | |
1010 | // Try to estimate bounds for compound relation. |
1011 | if (try_prove_compound_relation(e)) return true; |
1012 | |
1013 | // Can't prove. |
1014 | return false; |
1015 | } |
1016 | |
1017 | int constraint_set_t::max_proven_gcd(const expr_t &var) const { |
1018 | auto it = modulus_infos_.find(var); |
1019 | if (it == modulus_infos_.end()) return 1; |
1020 | int ret = 1; |
1021 | for (auto &c : it->second) { |
1022 | ret = math::lcm(ret, to_cpp<int>(c.mod())); |
1023 | } |
1024 | return ret; |
1025 | } |
1026 | |
1027 | } // namespace jit |
1028 | } // namespace gpu |
1029 | } // namespace impl |
1030 | } // namespace dnnl |
1031 | |