1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/unroll.hpp" |
18 | |
19 | #include "gpu/jit/utils/trace.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace gpu { |
24 | namespace jit { |
25 | |
26 | class unrolling_updater_t : public ir_mutator_t { |
27 | public: |
28 | object_t _mutate(const let_t &obj) override { |
29 | if (level_ == 0) { |
30 | // Skip top-level let statements. |
31 | return ir_mutator_t::_mutate(obj); |
32 | } |
33 | lets_.push_back(&obj); |
34 | auto new_body = mutate(obj.body); |
35 | if (!lets_.back()) { |
36 | // Let was moved to the innermost loop. |
37 | lets_.pop_back(); |
38 | return new_body; |
39 | } |
40 | lets_.pop_back(); |
41 | if (new_body.is_same(obj.body)) return obj; |
42 | return let_t::make(obj.var, obj.value, new_body); |
43 | } |
44 | |
45 | object_t _mutate(const for_t &obj) override { |
46 | if (in_compute_loop_) level_++; |
47 | found_loop_ = false; |
48 | auto new_obj = ir_mutator_t::_mutate(obj); |
49 | if (in_compute_loop_) level_--; |
50 | if (!found_loop_) { |
51 | // Innermost loop, inject let statements. |
52 | auto body = get_stmt_body(new_obj); |
53 | for (auto it = lets_.rbegin(); it != lets_.rend(); ++it) { |
54 | body = let_t::make((*it)->var, (*it)->value, body); |
55 | *it = nullptr; |
56 | } |
57 | new_obj = replace_stmt_body(new_obj, body); |
58 | } |
59 | found_loop_ = true; |
60 | return new_obj; |
61 | } |
62 | |
63 | object_t _mutate(const stmt_group_t &obj) override { |
64 | if (obj.label == stmt_label_t::compute_loop()) { |
65 | in_compute_loop_ = true; |
66 | } |
67 | auto new_obj = ir_mutator_t::_mutate(obj); |
68 | if (obj.label == stmt_label_t::compute_loop()) { |
69 | in_compute_loop_ = false; |
70 | } |
71 | return new_obj; |
72 | } |
73 | |
74 | private: |
75 | bool found_loop_ = false; |
76 | bool in_compute_loop_ = false; |
77 | int level_ = 0; |
78 | std::vector<const let_t *> lets_; |
79 | }; |
80 | |
81 | stmt_t update_loops_for_unrolling(const stmt_t &s, ir_context_t &ir_ctx) { |
82 | trace_start(); |
83 | auto ret = unrolling_updater_t().mutate(s); |
84 | trace_pass("update_loops_for_unrolling" , ret, ir_ctx); |
85 | return ret; |
86 | } |
87 | |
88 | class loop_unroller_t : public ir_mutator_t { |
89 | public: |
90 | loop_unroller_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {} |
91 | |
92 | object_t _mutate(const for_t &obj) override { |
93 | auto new_obj = ir_mutator_t::_mutate(obj); |
94 | auto &_for = new_obj.as<for_t>(); |
95 | // No unrolling. |
96 | if (_for.unroll == 1) return new_obj; |
97 | |
98 | ir_assert(is_const(_for.init)) |
99 | << "Can't unroll loop with non-const bound: " << _for.init; |
100 | ir_assert(is_const(_for.bound)) |
101 | << "Can't unroll loop with non-const bound: " << _for.bound; |
102 | |
103 | auto init = to_cpp<int>(_for.init); |
104 | auto bound = to_cpp<int>(_for.bound); |
105 | |
106 | ir_assert(_for.unroll == (bound - init)) |
107 | << "Only full loop unroll is supported." ; |
108 | |
109 | stmt_t ret; |
110 | for (int i = init; i < bound; i++) { |
111 | auto iter_stmt = substitute( |
112 | _for.body, _for.var, to_expr(i, _for.var.type())); |
113 | iter_stmt = rename_let_alloc(iter_stmt, i - init); |
114 | ret = ret.append(iter_stmt); |
115 | } |
116 | return std::move(ret); |
117 | } |
118 | |
119 | private: |
120 | stmt_t rename_let_alloc(const stmt_t &s, int idx) { |
121 | auto lets = find_objects<let_t>(s); |
122 | auto ret = s; |
123 | for (auto &_let : lets) { |
124 | auto &let = _let.as<let_t>(); |
125 | auto &var = let.var.as<var_t>(); |
126 | auto new_var = ir_ctx_.create_tmp_var(var.type, var.name); |
127 | ret = substitute(ret, let.var, new_var); |
128 | } |
129 | auto allocs = find_objects<alloc_t>(s); |
130 | for (auto &_alloc : allocs) { |
131 | auto &alloc = _alloc.as<alloc_t>(); |
132 | auto &buf = alloc.buf.as<var_t>(); |
133 | auto new_buf = ir_ctx_.create_tmp_var(buf.type, buf.name); |
134 | ret = substitute(ret, alloc.buf, new_buf); |
135 | } |
136 | return ret; |
137 | } |
138 | |
139 | ir_context_t &ir_ctx_; |
140 | }; |
141 | |
142 | stmt_t unroll_loops(const stmt_t &s, ir_context_t &ir_ctx) { |
143 | trace_start(); |
144 | auto ret = loop_unroller_t(ir_ctx).mutate(s); |
145 | trace_pass("unroll_loops" , ret, ir_ctx); |
146 | return ret; |
147 | } |
148 | |
149 | } // namespace jit |
150 | } // namespace gpu |
151 | } // namespace impl |
152 | } // namespace dnnl |
153 | |