/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/pass/alloc.hpp"

#include "gpu/jit/ir/message.hpp"
#include "gpu/jit/utils/trace.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

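// Lifts alloc statements out of the compute loop: the allocations are removed
// from the loop body and re-inserted around the compute loop itself. When
// reuse_headers is set, allocations of send message header buffers are kept
// in place.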
class alloc_lifter_t : public ir_mutator_t {
public:
    alloc_lifter_t(const stmt_t &root, bool reuse_headers)
        : reuse_headers_(reuse_headers) {
        if (!reuse_headers_) return;
        auto calls = find_objects<func_call_t>(root);
        for (auto &c : calls) {
            if (!is_func_call<send_t>(c)) continue;
            auto header_buf = send_t::arg_mem_off(c);
            ir_assert(is_var(header_buf)) << header_buf;
            header_bufs_.insert(header_buf);
        }
    }

    object_t _mutate(const alloc_t &obj) override {
        if (!do_lift(obj)) return ir_mutator_t::_mutate(obj);
        // Remove alloc and insert it before the compute loop. Mutate the body
        // so that nested allocations are lifted as well.
        allocs_.push_back(&obj);
        return mutate(obj.body);
    }

    object_t _mutate(const stmt_group_t &obj) override {
        bool is_compute_loop = (obj.label == stmt_label_t::compute_loop());
        if (is_compute_loop) in_compute_loop_ = true;
        auto new_obj = ir_mutator_t::_mutate(obj);
        if (is_compute_loop) {
            in_compute_loop_ = false;
            // Leaving the outermost compute loop: wrap it with the lifted
            // allocations, preserving their original nesting order.
            for (auto it = allocs_.rbegin(); it != allocs_.rend(); ++it) {
                auto &a = it->as<alloc_t>();
                new_obj = alloc_t::make(
                        a.buf, a.size, a.kind, a.attrs, new_obj);
            }
            allocs_.resize(0);
        }
        return new_obj;
    }

private:
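    // Returns true if the allocation should be lifted out of the compute
    // loop. With header reuse enabled, header buffer allocations stay in
    // place.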
    bool do_lift(const alloc_t &obj) const {
        if (!in_compute_loop_) return false;
        if (reuse_headers_) {
            bool is_header_alloc = (header_bufs_.count(obj.buf) != 0);
            return !is_header_alloc;
        }
        return true;
    }

    bool reuse_headers_;
    object_set_t<expr_t> header_bufs_;

    bool in_compute_loop_ = false;
    std::vector<stmt_t> allocs_;
};

stmt_t lift_alloc(const stmt_t &s, ir_context_t &ir_ctx, bool reuse_headers) {
    trace_start();
    auto ret = alloc_lifter_t(s, reuse_headers).mutate(s);
    trace_pass("lift_alloc", ret, ir_ctx);
    return ret;
}

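// Removes unused let and alloc statements and substitutes the values of let
// variables that are used exactly once at the same loop nesting level where
// they are defined.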
class alloc_let_optimizer_t : public ir_mutator_t {
public:
    // Also track alloc_t and for_t to validate all variable usages.
    object_t _mutate(const alloc_t &obj) override {
        return mutate_scope(obj, obj.buf);
    }

    object_t _mutate(const for_t &obj) override {
        level_++;
        auto new_obj = mutate_scope(obj, obj.var);
        level_--;
        return new_obj;
    }

    object_t _mutate(const let_t &obj) override {
        return mutate_scope(obj, obj.var);
    }

    object_t _mutate(const store_t &obj) override {
        auto &base = (obj.buf.is<var_t>() ? obj.buf : obj.buf.as<ptr_t>().base);
        // Do not count store references. If there are only stores to a buffer
        // and no other usages, the buffer can be safely removed.
        skip_var_ = base;
        auto new_obj = ir_mutator_t::_mutate(obj);
        skip_var_ = expr_t();
        return new_obj;
    }

    object_t _mutate(const var_t &obj) override {
        ir_assert(refs_.count(obj) == 1)
                << "Variable is not defined: " << expr_t(&obj);
        if (!skip_var_.is_same(obj)) refs_[&obj].update(increment_, level_);
        return ir_mutator_t::_mutate(obj);
    }

private:
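    // Tracks how many times a variable or buffer is referenced and the range
    // of loop nesting levels where the references occur.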
    struct ref_info_t {
        ref_info_t(int level = 0)
            : refs(0), min_level(level), max_level(level) {}

        void update(int increment, int level) {
            refs += increment;
            max_level = std::max(max_level, level);
        }

        bool is_same_level() const { return min_level == max_level; }

        int refs;
        int min_level;
        int max_level;
    };

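    // Registers the scope variable, mutates the scope body while counting
    // references to it, then applies the let/alloc-specific simplification.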
    template <typename T>
    object_t mutate_scope(const T &obj, const expr_t &var) {
        auto ret = refs_.insert({var, ref_info_t(level_)});
        ir_assert(ret.second) << stmt_t(obj);
        MAYBE_UNUSED(ret);

        auto new_obj = ir_mutator_t::_mutate(obj);
        auto &ref_info = refs_[var];

        if (std::is_same<T, let_t>()) {
            new_obj = mutate_let(new_obj.template as<let_t>(), ref_info);
        } else if (std::is_same<T, alloc_t>()) {
            new_obj = mutate_alloc(new_obj.template as<alloc_t>(), ref_info);
        }

        refs_.erase(var);
        return new_obj;
    }

    object_t mutate_let(const let_t &obj, const ref_info_t &ref_info) {
        ir_assert(ref_info.refs >= 1);
        if (ref_info.refs == 1) {
            // The variable is never used: the only reference is the let
            // definition itself, so drop the let statement.
            remove_refs(obj);
            return obj.body;
        }
        // Substitute the let value when all of the following hold:
        // - There are exactly 2 references: one from the producer (the let
        //   itself) and one from the consumer, i.e. a single usage
        // - Producer and consumer are at the same loop nesting level
        // - The variable is not external (its value is not empty)
        if (ref_info.refs == 2 && ref_info.is_same_level()
                && !obj.value.is_empty()) {
            return substitute(obj.body, obj.var, obj.value);
        }
        return obj;
    }

    object_t mutate_alloc(const alloc_t &obj, const ref_info_t &ref_info) {
        ir_assert(ref_info.refs >= 1);
        // The buffer is unused: the only reference comes from the alloc_t
        // itself. Remove any stores to the buffer and drop the alloc.
        if (ref_info.refs == 1) return remove_stores(obj.body, obj.buf);
        return obj;
    }

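    // Decrements the reference counts of all variables used in the let value
    // before the let statement is dropped.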
    void remove_refs(const let_t &obj) {
        increment_ = -1;
        mutate(obj.value);
        increment_ = 1;
    }

    // Removes all nested stores to the buffer.
    stmt_t remove_stores(const stmt_t &stmt, const expr_t &buf) {
        auto ret = stmt;
        auto stores = find_objects<store_t>(stmt);
        for (auto &_s : stores) {
            auto &s = _s.as<store_t>();
            auto &base = (s.buf.is<var_t>() ? s.buf : s.buf.as<ptr_t>().base);
            if (base.is_same(buf)) ret = substitute(ret, _s, stmt_t());
        }
        return ret;
    }

    int increment_ = 1;
    int level_ = 0;

    expr_t skip_var_;
    object_map_t<expr_t, ref_info_t> refs_;
};

stmt_t optimize_alloc_let(const stmt_t &s, ir_context_t &ir_ctx) {
    trace_start();
    auto ret = alloc_let_optimizer_t().mutate(s);
    trace_pass("optimize_alloc_let", ret, ir_ctx);
    return ret;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl