1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/strength_reduce.hpp"
18
19#include "gpu/jit/pass/simplify.hpp"
20#include "gpu/jit/utils/trace.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace jit {
26
27class loop_strength_reducer_t : public ir_mutator_t {
28public:
29 loop_strength_reducer_t() {
30 // Create top-level dummy loop.
31 loops_.emplace_back();
32 }
33
34 ~loop_strength_reducer_t() override {
35 // Sanity check, all stores must be applied.
36 ir_assert(post_inc_stores.empty());
37 }
38
39 object_t _mutate(const for_t &obj) override {
40 loops_.emplace_back(obj);
41 auto new_obj = ir_mutator_t::_mutate(obj);
42 return inject_stores_and_pop_loop(new_obj);
43 }
44
45 object_t _mutate(const let_t &obj) override {
46 int loop_level = int(loops_.size()) - 1;
47 auto ret = lets_.insert(
48 {obj.var, let_info_t(obj.var, obj.value, loop_level)});
49 ir_assert(ret.second);
50 MAYBE_UNUSED(ret);
51 auto new_obj = ir_mutator_t::_mutate(obj);
52 lets_.erase(obj.var);
53 return new_obj;
54 }
55
56 object_t _mutate(const stmt_group_t &obj) override {
57 if (obj.body.is<for_t>()) {
58 loops_.emplace_back(obj.body);
59 const for_t *for_obj = obj.body.as_ptr<for_t>();
60 auto body = for_obj ? ir_mutator_t::_mutate(*for_obj) : for_obj;
61 if (body.is_same(obj.body)) return obj;
62 auto new_obj = stmt_group_t::make(obj.label, body);
63 return inject_stores_and_pop_loop(new_obj);
64 }
65 return ir_mutator_t::_mutate(obj);
66 }
67
68 // Pattern to handle:
69 // for (...) {
70 // store(buf_ptr, ...) <- Write (producer).
71 // // ...
72 // stmt_t(..., buf_ptr, ...) <- Read (consumer).
73 // }
74 object_t _mutate(const store_t &obj) override {
75 if (loops_.size() == 1) return ir_mutator_t::_mutate(obj);
76
77 // Try to reduce strength, moving the store up.
78 int init_store_level = -1;
79 stmt_t init_store_stmt = obj;
80 post_inc_store_info_t post_inc_store(obj);
81 for (int level = int(loops_.size()) - 1; level >= 1; level--) {
82 auto &loop_info = loops_[level];
83 int refs = count_object(loop_info.loop, obj.buf);
84 // Producer and consumer - must be 2 references.
85 if (refs != 2) break;
86
87 // Try to insert the store before level-th loop.
88 auto &store = init_store_stmt.as<store_t>();
89 auto &store_value = store.value;
90 auto &loop_var = loop_info.loop_var();
91
92 auto cur_value = substitute_let(store_value, level);
93 auto next_value = substitute(cur_value, loop_var, loop_var + 1);
94 auto inc = simplify(next_value - cur_value);
95
96 // Cannot eliminate loop variable, break.
97 if (contains_object(inc, loop_var)) break;
98
99 // Not scalar, break.
100 if (!store_value.type().is_scalar()) break;
101
102 // Success, replace store by post-increment store.
103 init_store_level = level;
104
105 auto new_store_value
106 = substitute(cur_value, loop_var, loop_info.loop_init());
107 init_store_stmt = store_t::make(store.buf, store.off,
108 simplify(new_store_value), store.stride);
109
110 post_inc_store.update(loop_info, inc);
111 }
112
113 // Can't do anything, return as is.
114 if (init_store_level == -1) return ir_mutator_t::_mutate(obj);
115
116 // Move this store up, remove from here.
117 loops_[init_store_level].init_stores.push_back(init_store_stmt);
118 if (!post_inc_store.is_empty()) {
119 auto ret = post_inc_stores.insert({obj.buf, post_inc_store});
120 ir_assert(ret.second);
121 MAYBE_UNUSED(ret);
122 }
123 return stmt_t();
124 }
125
126 object_t _mutate(const func_call_t &obj) override {
127 for (auto &kv : post_inc_stores) {
128 int refs = count_object(obj, kv.first);
129 if (refs == 1) {
130 auto ret = stmt_seq_t::make(obj, kv.second.stmt());
131 post_inc_stores.erase(kv.first);
132 return std::move(ret);
133 }
134 }
135 return ir_mutator_t::_mutate(obj);
136 }
137
138private:
139 struct loop_info_t {
140 loop_info_t(const stmt_t &loop = {}) : loop(loop) {}
141
142 const expr_t &loop_var() const { return loop.as<for_t>().var; }
143
144 const expr_t &loop_init() const { return loop.as<for_t>().init; }
145
146 const expr_t &loop_bound() const { return loop.as<for_t>().bound; }
147
148 expr_t loop_extent() const { return loop_bound() - loop_init(); }
149
150 // Loop being analyzed.
151 stmt_t loop;
152 // Stores to insert before the loop.
153 std::vector<stmt_t> init_stores;
154
155 std::vector<stmt_t> lets;
156 };
157
158 struct let_info_t {
159 let_info_t(const expr_t &var, const expr_t &value, int loop_level)
160 : var(var), value(value), loop_level(loop_level) {}
161
162 expr_t var;
163 expr_t value;
164 int loop_level;
165 };
166
167 struct post_inc_store_info_t {
168 post_inc_store_info_t(const store_t &obj)
169 : store(&obj), inc(0), last_iter_cond(true), compensation(0) {}
170
171 stmt_t stmt() const {
172 auto load
173 = load_t::make(store->value.type(), store->buf, store->off);
174 return store_t::make(store->buf, store->off, load + inc);
175 }
176
177 bool is_empty() const { return is_zero(inc); }
178
179 void update(const loop_info_t &loop, const expr_t &loop_inc) {
180 inc = simplify(iif_t::make(
181 last_iter_cond, inc - compensation + loop_inc, inc));
182 if (last_iter_cond.is_equal(expr_t(true))) {
183 last_iter_cond = (loop.loop_var() == loop.loop_bound() - 1);
184 } else {
185 last_iter_cond = last_iter_cond
186 & (loop.loop_var() == loop.loop_bound() - 1);
187 }
188 compensation = simplify(loop.loop_extent() * loop_inc);
189 }
190
191 const store_t *store;
192 expr_t inc;
193
194 expr_t last_iter_cond;
195 expr_t compensation;
196 };
197
198 // Recursively substitutes all variable from let statements located under
199 // the given loop level.
200 expr_t substitute_let(const expr_t &_e, int loop_level) const {
201 auto e = _e;
202 for (;;) {
203 bool found = false;
204 auto vars = find_unique_objects<var_t>(e);
205 for (auto &v : vars) {
206 auto it = lets_.find(v);
207 if (it == lets_.end()) continue;
208 auto &let_info = it->second;
209 // Do not substitute top-level let variables.
210 if (let_info.loop_level < loop_level) continue;
211 found = true;
212 e = substitute(e, v, let_info.value);
213 }
214 if (!found) break;
215 }
216 return e;
217 }
218
219 // Injects initial store statements if any.
220 object_t inject_stores_and_pop_loop(const stmt_t &_s) {
221 stmt_t s = _s;
222 auto &stores = loops_.back().init_stores;
223 for (auto it = stores.rbegin(); it != stores.rend(); ++it) {
224 s = stmt_seq_t::make(*it, s);
225 }
226 loops_.pop_back();
227 // The top-level dummy loop shouldn't be removed.
228 ir_assert(loops_.size() >= 1);
229 return std::move(s);
230 }
231
232 // Loops, ordered from outermost to innermost. The first loop is dummy, to
233 // represent let statements in the top-level scope.
234 std::vector<loop_info_t> loops_;
235
236 // Buffers whose references are to be updated.
237 object_map_t<expr_t, post_inc_store_info_t> post_inc_stores;
238
239 // Let statements available at the current IR node.
240 object_map_t<expr_t, let_info_t> lets_;
241};
242
243stmt_t loop_strength_reduce(const stmt_t &s, ir_context_t &ir_ctx) {
244 trace_start();
245 auto ret = loop_strength_reducer_t().mutate(s);
246 trace_pass("loop_strength_reduce", ret, ir_ctx);
247 return ret;
248}
249
250} // namespace jit
251} // namespace gpu
252} // namespace impl
253} // namespace dnnl
254