1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/strength_reduce.hpp" |
18 | |
19 | #include "gpu/jit/pass/simplify.hpp" |
20 | #include "gpu/jit/utils/trace.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace jit { |
26 | |
27 | class loop_strength_reducer_t : public ir_mutator_t { |
28 | public: |
29 | loop_strength_reducer_t() { |
30 | // Create top-level dummy loop. |
31 | loops_.emplace_back(); |
32 | } |
33 | |
34 | ~loop_strength_reducer_t() override { |
35 | // Sanity check, all stores must be applied. |
36 | ir_assert(post_inc_stores.empty()); |
37 | } |
38 | |
39 | object_t _mutate(const for_t &obj) override { |
40 | loops_.emplace_back(obj); |
41 | auto new_obj = ir_mutator_t::_mutate(obj); |
42 | return inject_stores_and_pop_loop(new_obj); |
43 | } |
44 | |
45 | object_t _mutate(const let_t &obj) override { |
46 | int loop_level = int(loops_.size()) - 1; |
47 | auto ret = lets_.insert( |
48 | {obj.var, let_info_t(obj.var, obj.value, loop_level)}); |
49 | ir_assert(ret.second); |
50 | MAYBE_UNUSED(ret); |
51 | auto new_obj = ir_mutator_t::_mutate(obj); |
52 | lets_.erase(obj.var); |
53 | return new_obj; |
54 | } |
55 | |
56 | object_t _mutate(const stmt_group_t &obj) override { |
57 | if (obj.body.is<for_t>()) { |
58 | loops_.emplace_back(obj.body); |
59 | const for_t *for_obj = obj.body.as_ptr<for_t>(); |
60 | auto body = for_obj ? ir_mutator_t::_mutate(*for_obj) : for_obj; |
61 | if (body.is_same(obj.body)) return obj; |
62 | auto new_obj = stmt_group_t::make(obj.label, body); |
63 | return inject_stores_and_pop_loop(new_obj); |
64 | } |
65 | return ir_mutator_t::_mutate(obj); |
66 | } |
67 | |
68 | // Pattern to handle: |
69 | // for (...) { |
70 | // store(buf_ptr, ...) <- Write (producer). |
71 | // // ... |
72 | // stmt_t(..., buf_ptr, ...) <- Read (consumer). |
73 | // } |
74 | object_t _mutate(const store_t &obj) override { |
75 | if (loops_.size() == 1) return ir_mutator_t::_mutate(obj); |
76 | |
77 | // Try to reduce strength, moving the store up. |
78 | int init_store_level = -1; |
79 | stmt_t init_store_stmt = obj; |
80 | post_inc_store_info_t post_inc_store(obj); |
81 | for (int level = int(loops_.size()) - 1; level >= 1; level--) { |
82 | auto &loop_info = loops_[level]; |
83 | int refs = count_object(loop_info.loop, obj.buf); |
84 | // Producer and consumer - must be 2 references. |
85 | if (refs != 2) break; |
86 | |
87 | // Try to insert the store before level-th loop. |
88 | auto &store = init_store_stmt.as<store_t>(); |
89 | auto &store_value = store.value; |
90 | auto &loop_var = loop_info.loop_var(); |
91 | |
92 | auto cur_value = substitute_let(store_value, level); |
93 | auto next_value = substitute(cur_value, loop_var, loop_var + 1); |
94 | auto inc = simplify(next_value - cur_value); |
95 | |
96 | // Cannot eliminate loop variable, break. |
97 | if (contains_object(inc, loop_var)) break; |
98 | |
99 | // Not scalar, break. |
100 | if (!store_value.type().is_scalar()) break; |
101 | |
102 | // Success, replace store by post-increment store. |
103 | init_store_level = level; |
104 | |
105 | auto new_store_value |
106 | = substitute(cur_value, loop_var, loop_info.loop_init()); |
107 | init_store_stmt = store_t::make(store.buf, store.off, |
108 | simplify(new_store_value), store.stride); |
109 | |
110 | post_inc_store.update(loop_info, inc); |
111 | } |
112 | |
113 | // Can't do anything, return as is. |
114 | if (init_store_level == -1) return ir_mutator_t::_mutate(obj); |
115 | |
116 | // Move this store up, remove from here. |
117 | loops_[init_store_level].init_stores.push_back(init_store_stmt); |
118 | if (!post_inc_store.is_empty()) { |
119 | auto ret = post_inc_stores.insert({obj.buf, post_inc_store}); |
120 | ir_assert(ret.second); |
121 | MAYBE_UNUSED(ret); |
122 | } |
123 | return stmt_t(); |
124 | } |
125 | |
126 | object_t _mutate(const func_call_t &obj) override { |
127 | for (auto &kv : post_inc_stores) { |
128 | int refs = count_object(obj, kv.first); |
129 | if (refs == 1) { |
130 | auto ret = stmt_seq_t::make(obj, kv.second.stmt()); |
131 | post_inc_stores.erase(kv.first); |
132 | return std::move(ret); |
133 | } |
134 | } |
135 | return ir_mutator_t::_mutate(obj); |
136 | } |
137 | |
138 | private: |
139 | struct loop_info_t { |
140 | loop_info_t(const stmt_t &loop = {}) : loop(loop) {} |
141 | |
142 | const expr_t &loop_var() const { return loop.as<for_t>().var; } |
143 | |
144 | const expr_t &loop_init() const { return loop.as<for_t>().init; } |
145 | |
146 | const expr_t &loop_bound() const { return loop.as<for_t>().bound; } |
147 | |
148 | expr_t loop_extent() const { return loop_bound() - loop_init(); } |
149 | |
150 | // Loop being analyzed. |
151 | stmt_t loop; |
152 | // Stores to insert before the loop. |
153 | std::vector<stmt_t> init_stores; |
154 | |
155 | std::vector<stmt_t> lets; |
156 | }; |
157 | |
158 | struct let_info_t { |
159 | let_info_t(const expr_t &var, const expr_t &value, int loop_level) |
160 | : var(var), value(value), loop_level(loop_level) {} |
161 | |
162 | expr_t var; |
163 | expr_t value; |
164 | int loop_level; |
165 | }; |
166 | |
167 | struct post_inc_store_info_t { |
168 | post_inc_store_info_t(const store_t &obj) |
169 | : store(&obj), inc(0), last_iter_cond(true), compensation(0) {} |
170 | |
171 | stmt_t stmt() const { |
172 | auto load |
173 | = load_t::make(store->value.type(), store->buf, store->off); |
174 | return store_t::make(store->buf, store->off, load + inc); |
175 | } |
176 | |
177 | bool is_empty() const { return is_zero(inc); } |
178 | |
179 | void update(const loop_info_t &loop, const expr_t &loop_inc) { |
180 | inc = simplify(iif_t::make( |
181 | last_iter_cond, inc - compensation + loop_inc, inc)); |
182 | if (last_iter_cond.is_equal(expr_t(true))) { |
183 | last_iter_cond = (loop.loop_var() == loop.loop_bound() - 1); |
184 | } else { |
185 | last_iter_cond = last_iter_cond |
186 | & (loop.loop_var() == loop.loop_bound() - 1); |
187 | } |
188 | compensation = simplify(loop.loop_extent() * loop_inc); |
189 | } |
190 | |
191 | const store_t *store; |
192 | expr_t inc; |
193 | |
194 | expr_t last_iter_cond; |
195 | expr_t compensation; |
196 | }; |
197 | |
198 | // Recursively substitutes all variable from let statements located under |
199 | // the given loop level. |
200 | expr_t substitute_let(const expr_t &_e, int loop_level) const { |
201 | auto e = _e; |
202 | for (;;) { |
203 | bool found = false; |
204 | auto vars = find_unique_objects<var_t>(e); |
205 | for (auto &v : vars) { |
206 | auto it = lets_.find(v); |
207 | if (it == lets_.end()) continue; |
208 | auto &let_info = it->second; |
209 | // Do not substitute top-level let variables. |
210 | if (let_info.loop_level < loop_level) continue; |
211 | found = true; |
212 | e = substitute(e, v, let_info.value); |
213 | } |
214 | if (!found) break; |
215 | } |
216 | return e; |
217 | } |
218 | |
219 | // Injects initial store statements if any. |
220 | object_t inject_stores_and_pop_loop(const stmt_t &_s) { |
221 | stmt_t s = _s; |
222 | auto &stores = loops_.back().init_stores; |
223 | for (auto it = stores.rbegin(); it != stores.rend(); ++it) { |
224 | s = stmt_seq_t::make(*it, s); |
225 | } |
226 | loops_.pop_back(); |
227 | // The top-level dummy loop shouldn't be removed. |
228 | ir_assert(loops_.size() >= 1); |
229 | return std::move(s); |
230 | } |
231 | |
232 | // Loops, ordered from outermost to innermost. The first loop is dummy, to |
233 | // represent let statements in the top-level scope. |
234 | std::vector<loop_info_t> loops_; |
235 | |
236 | // Buffers whose references are to be updated. |
237 | object_map_t<expr_t, post_inc_store_info_t> post_inc_stores; |
238 | |
239 | // Let statements available at the current IR node. |
240 | object_map_t<expr_t, let_info_t> lets_; |
241 | }; |
242 | |
243 | stmt_t loop_strength_reduce(const stmt_t &s, ir_context_t &ir_ctx) { |
244 | trace_start(); |
245 | auto ret = loop_strength_reducer_t().mutate(s); |
246 | trace_pass("loop_strength_reduce" , ret, ir_ctx); |
247 | return ret; |
248 | } |
249 | |
250 | } // namespace jit |
251 | } // namespace gpu |
252 | } // namespace impl |
253 | } // namespace dnnl |
254 | |