1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/alloc.hpp" |
18 | |
19 | #include "gpu/jit/ir/message.hpp" |
20 | #include "gpu/jit/utils/trace.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace jit { |
26 | |
27 | class alloc_lifter_t : public ir_mutator_t { |
28 | public: |
29 | alloc_lifter_t(const stmt_t &root, bool ) |
30 | : reuse_headers_(reuse_headers) { |
31 | if (!reuse_headers_) return; |
32 | auto calls = find_objects<func_call_t>(root); |
33 | for (auto &c : calls) { |
34 | if (!is_func_call<send_t>(c)) continue; |
35 | auto header_buf = send_t::arg_mem_off(c); |
36 | ir_assert(is_var(header_buf)) << header_buf; |
37 | header_bufs_.insert(header_buf); |
38 | } |
39 | } |
40 | |
41 | object_t _mutate(const alloc_t &obj) override { |
42 | if (!do_lift(obj)) return ir_mutator_t::_mutate(obj); |
43 | // Remove alloc and insert it before the compute loop. |
44 | allocs_.push_back(&obj); |
45 | return obj.body; |
46 | } |
47 | |
48 | object_t _mutate(const stmt_group_t &obj) override { |
49 | bool is_compute_loop = (obj.label == stmt_label_t::compute_loop()); |
50 | if (is_compute_loop) in_compute_loop_ = true; |
51 | auto new_obj = ir_mutator_t::_mutate(obj); |
52 | if (is_compute_loop) { |
53 | in_compute_loop_ = false; |
54 | // Outermost loop. |
55 | for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) { |
56 | auto &a = it->as<alloc_t>(); |
57 | new_obj = alloc_t::make( |
58 | a.buf, a.size, a.kind, a.attrs, new_obj); |
59 | } |
60 | allocs_.resize(0); |
61 | } |
62 | return new_obj; |
63 | } |
64 | |
65 | private: |
66 | bool do_lift(const alloc_t &obj) const { |
67 | if (!in_compute_loop_) return false; |
68 | if (reuse_headers_) { |
69 | bool = (header_bufs_.count(obj.buf) != 0); |
70 | return !is_header_alloc; |
71 | } |
72 | return true; |
73 | } |
74 | |
75 | bool ; |
76 | object_set_t<expr_t> ; |
77 | |
78 | bool in_compute_loop_ = false; |
79 | std::vector<stmt_t> allocs_; |
80 | }; |
81 | |
82 | stmt_t lift_alloc(const stmt_t &s, ir_context_t &ir_ctx, bool ) { |
83 | trace_start(); |
84 | auto ret = alloc_lifter_t(s, reuse_headers).mutate(s); |
85 | trace_pass("lift_alloc" , ret, ir_ctx); |
86 | return ret; |
87 | } |
88 | |
// Removes dead let/alloc scopes and inlines trivially-substitutable lets:
// - A let whose variable is never referenced is dropped (its body is kept).
// - A let with exactly one consumer at the same loop level as the producer
//   has its value substituted into the body.
// - An alloc whose buffer is only written to (never read) is dropped, along
//   with all stores into that buffer.
// Reference counts are gathered as a side effect of mutation via refs_.
class alloc_let_optimizer_t : public ir_mutator_t {
public:
    // Also track alloc_t and for_t to validate all variable usages.
    object_t _mutate(const alloc_t &obj) override {
        return mutate_scope(obj, obj.buf);
    }

    object_t _mutate(const for_t &obj) override {
        // level_ tracks loop-nesting depth so ref_info_t can tell whether a
        // variable's definition and uses sit in the same loop.
        level_++;
        auto new_obj = mutate_scope(obj, obj.var);
        level_--;
        return new_obj;
    }

    object_t _mutate(const let_t &obj) override {
        return mutate_scope(obj, obj.var);
    }

    object_t _mutate(const store_t &obj) override {
        auto &base = (obj.buf.is<var_t>() ? obj.buf : obj.buf.as<ptr_t>().base);
        // Do not count store references. If there are only stores to a buffer
        // and no other usages, the buffer can be safely removed.
        skip_var_ = base;
        auto new_obj = ir_mutator_t::_mutate(obj);
        skip_var_ = expr_t();
        return new_obj;
    }

    object_t _mutate(const var_t &obj) override {
        // Every variable must have a visible defining scope (let/alloc/for).
        ir_assert(refs_.count(obj) == 1)
                << "Variable is not defined: " << expr_t(&obj);
        // skip_var_ suppresses counting the store-destination base buffer.
        if (!skip_var_.is_same(obj)) refs_[&obj].update(increment_, level_);
        return ir_mutator_t::_mutate(obj);
    }

private:
    // Per-variable usage record: reference count plus the loop-nesting range
    // ([min_level, max_level]) where references occur. min_level is fixed at
    // the definition level; max_level grows as deeper uses are seen.
    struct ref_info_t {
        ref_info_t(int level = 0)
            : refs(0), min_level(level), max_level(level) {}

        void update(int increment, int level) {
            refs += increment;
            max_level = std::max(max_level, level);
        }

        // True when all uses are at the same loop level as the definition.
        bool is_same_level() const { return min_level == max_level; }

        int refs;
        int min_level;
        int max_level;
    };

    // Registers `var` as defined, mutates the scope body (counting references
    // as a side effect), then applies the let/alloc-specific optimization and
    // unregisters the variable.
    template <typename T>
    object_t mutate_scope(const T &obj, const expr_t &var) {
        auto ret = refs_.insert({var, ref_info_t(level_)});
        ir_assert(ret.second) << stmt_t(obj); // no variable shadowing allowed
        MAYBE_UNUSED(ret);

        auto new_obj = ir_mutator_t::_mutate(obj);
        auto &ref_info = refs_[var];

        if (std::is_same<T, let_t>()) {
            new_obj = mutate_let(new_obj.template as<let_t>(), ref_info);
        } else if (std::is_same<T, alloc_t>()) {
            new_obj = mutate_alloc(new_obj.template as<alloc_t>(), ref_info);
        }

        refs_.erase(var);
        return new_obj;
    }

    object_t mutate_let(const let_t &obj, const ref_info_t &ref_info) {
        // At least one reference: the let_t definition itself.
        ir_assert(ref_info.refs >= 1);
        if (ref_info.refs == 1) {
            // Variable is not used.
            remove_refs(obj);
            return obj.body;
        }
        // Check following conditions to substitute let value:
        // - 2 references: one from producer, one from consumer - means single usage
        // - Consumer and producer are on the same level (same loop)
        // - Variable is not external
        if (ref_info.refs == 2 && ref_info.is_same_level()
                && !obj.value.is_empty()) {
            return substitute(obj.body, obj.var, obj.value);
        }
        return obj;
    }

    object_t mutate_alloc(const alloc_t &obj, const ref_info_t &ref_info) {
        ir_assert(ref_info.refs >= 1);
        // Buffer is not used, single reference from alloc_t itself. Remove
        // stores to the buffer if any.
        if (ref_info.refs == 1) return remove_stores(obj.body, obj.buf);
        return obj;
    }

    // Un-counts references made by a removed let's value expression so the
    // counts of variables it referenced stay accurate.
    void remove_refs(const let_t &obj) {
        increment_ = -1;
        mutate(obj.value);
        increment_ = 1;
    }

    // Removes all nested stores to the buffer.
    stmt_t remove_stores(const stmt_t &stmt, const expr_t &buf) {
        auto ret = stmt;
        auto stores = find_objects<store_t>(stmt);
        for (auto &_s : stores) {
            auto &s = _s.as<store_t>();
            auto &base = (s.buf.is<var_t>() ? s.buf : s.buf.as<ptr_t>().base);
            if (base.is_same(buf)) ret = substitute(ret, _s, stmt_t());
        }
        return ret;
    }

    int increment_ = 1; // -1 while un-counting refs of a removed let value
    int level_ = 0; // current loop-nesting depth

    expr_t skip_var_; // store-destination base currently being skipped
    object_map_t<expr_t, ref_info_t> refs_;
};
210 | |
211 | stmt_t optimize_alloc_let(const stmt_t &s, ir_context_t &ir_ctx) { |
212 | trace_start(); |
213 | auto ret = alloc_let_optimizer_t().mutate(s); |
214 | trace_pass("optimize_alloc_let" , ret, ir_ctx); |
215 | return ret; |
216 | } |
217 | |
218 | } // namespace jit |
219 | } // namespace gpu |
220 | } // namespace impl |
221 | } // namespace dnnl |
222 | |