1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/hoist.hpp"
18
19#include "gpu/jit/codegen/register_allocator.hpp"
20#include "gpu/jit/ir/message.hpp"
21#include "gpu/jit/utils/trace.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace gpu {
26namespace jit {
27
28class sum_expr_t {
29public:
30 sum_expr_t(const expr_t &e)
31 : type_(e.type()), args_(split_by_add(e, e.type().elems())) {}
32
33 std::vector<expr_t> args() const { return args_; }
34
35 bool is_trivial() const { return args_.size() <= 1; }
36
37 expr_t expr() const { return make_add(args_, type_); }
38
39 static expr_t make_add(
40 const std::vector<expr_t> &args, const type_t &type) {
41 auto maybe_bcast = [&](const expr_t &e) {
42 if (e.type().elems() == type.elems()) return e;
43 ir_assert(e.type().is_scalar());
44 return shuffle_t::make_broadcast(e, type.elems());
45 };
46 if (args.empty()) return cast(0, type);
47 auto ret = maybe_bcast(args[0]);
48 for (int i = 1; i < (int)args.size(); i++)
49 ret += maybe_bcast(args[i]);
50 return ret;
51 }
52
53private:
54 static std::vector<expr_t> split_by_add(const expr_t &e, int elems) {
55 auto *shuffle = e.as_ptr<shuffle_t>();
56 if (shuffle && shuffle->is_broadcast() && shuffle->elems() == elems) {
57 return split_by_add(shuffle->vec[0], elems);
58 }
59 auto *op = e.as_ptr<binary_op_t>();
60 if (!op || op->op_kind != op_kind_t::_add) return {e};
61 auto a_args = split_by_add(op->a, elems);
62 auto b_args = split_by_add(op->b, elems);
63 std::vector<expr_t> args;
64 args.insert(args.end(), a_args.begin(), a_args.end());
65 args.insert(args.end(), b_args.begin(), b_args.end());
66 return args;
67 }
68
69 type_t type_;
70 std::vector<expr_t> args_;
71};
72
73class hoist_exprs_mutator_t : public ir_mutator_t {
74public:
75 hoist_exprs_mutator_t(ir_context_t &ir_ctx,
76 int max_hoist_size = std::numeric_limits<int>::max())
77 : ir_ctx_(ir_ctx), max_hoist_size_(max_hoist_size) {}
78
79 ~hoist_exprs_mutator_t() override { ir_assert(let_vars_.empty()); }
80
81 object_t _mutate(const func_call_t &obj) override {
82 if (!obj.func.is<send_t>()) return ir_mutator_t::_mutate(obj);
83
84 std::vector<expr_t> new_args;
85 for (auto &e : obj.args) {
86 new_args.push_back(hoist_expr(e));
87 }
88
89 if (ir_utils::is_equal(new_args, obj.args)) return obj;
90
91 return func_call_t::make(obj.func, new_args, obj.attr);
92 }
93
94 object_t _mutate(const stmt_group_t &obj) override {
95 if (obj.body.is<for_t>()) {
96 loops_.emplace_back(obj.body.as<for_t>().var);
97 const for_t *for_obj = obj.body.as_ptr<for_t>();
98 auto body = for_obj ? ir_mutator_t::_mutate(*for_obj) : for_obj;
99 if (body.is_same(obj.body)) return obj;
100 auto new_obj = stmt_group_t::make(obj.label, body);
101 return injects_lets_and_pop_loop(new_obj);
102 }
103 return ir_mutator_t::_mutate(obj);
104 }
105
106 object_t _mutate(const store_t &obj) override {
107 auto value = hoist_expr(obj.value);
108 if (value.is_equal(obj.value)) return obj;
109 return store_t::make(obj.buf, obj.off, value, obj.stride);
110 }
111
112 object_t _mutate(const for_t &obj) override {
113 loops_.emplace_back(obj.var);
114 auto new_obj = ir_mutator_t::_mutate(obj);
115 return injects_lets_and_pop_loop(new_obj);
116 }
117
118 object_t _mutate(const let_t &obj) override {
119 bool fully_hoisted = false;
120 expr_t new_value;
121 bool is_const_let = is_const(obj.value) || is_shuffle_const(obj.value);
122 if (is_const_let && loops_.size() > 0 && can_hoist(obj.var)) {
123 fully_hoisted = true;
124 register_let(obj.var, obj.value);
125 add_hoist_let(loops_[0], obj.var, obj.value);
126 } else {
127 new_value = hoist_expr(obj.value, obj.var, &fully_hoisted);
128 }
129 if (fully_hoisted) return mutate(obj.body);
130 register_let(obj.var, new_value);
131 auto new_obj = let_t::make(
132 obj.var, new_value, ir_mutator_t::mutate(obj.body));
133 unregister_let(obj.var);
134 return std::move(new_obj);
135 }
136
137private:
138 struct loop_info_t {
139 loop_info_t(const expr_t &var) : var(var) {}
140
141 expr_t var;
142 int var_count = 0;
143 std::vector<stmt_t> lets;
144 };
145
146 bool can_hoist(const expr_t &expr) {
147 return expr.type().size() <= max_hoist_size_ - current_hoist_size_;
148 }
149
150 void add_hoist_let(
151 loop_info_t &loop, const expr_t &var, const expr_t &value) {
152 loop.lets.emplace_back(let_t::make(var, value));
153 current_hoist_size_ += utils::rnd_up(
154 var.type().size(), reg_allocator_t::granularity);
155 }
156
157 expr_t hoist_expr(const expr_t &expr, const expr_t &expr_var = {},
158 bool *fully_hoisted = nullptr) {
159 if (expr.is_empty()) return expr;
160 if (expr.type().is_ptr()) return expr;
161 if (expr.type().is_bool()) return expr;
162 if (is_const(expr) || is_shuffle_const(expr) || is_var(expr))
163 return expr;
164 if (!can_hoist(expr)) return expr;
165
166 auto hoisted_expr = hoist_expr_with_add(expr, expr_var, fully_hoisted);
167 if (!hoisted_expr.is_equal(expr)) return hoisted_expr;
168
169 // hoist_expr_with_add() doesn't handle cast so try to hoist it manually.
170 auto *cast = expr.as_ptr<cast_t>();
171 if (!cast) return hoisted_expr;
172
173 auto hoisted_cast_expr = hoist_expr(cast->expr);
174 if (!hoisted_cast_expr.is_equal(cast->expr)) {
175 hoisted_expr = cast_t::make(
176 cast->type, hoisted_cast_expr, cast->saturate);
177 }
178 return hoisted_expr;
179 }
180
181 expr_t hoist_expr_with_add(const expr_t &expr, const expr_t &expr_var = {},
182 bool *fully_hoisted = nullptr) {
183 const type_t &type = expr.type();
184 sum_expr_t cur_expr(expr);
185
186 for (size_t i = 0; i < loops_.size(); i++) {
187 std::vector<expr_t> invariant_args;
188 std::vector<expr_t> other_args;
189 std::vector<expr_t> nary_args;
190 if (!cur_expr.is_trivial()) {
191 nary_args = cur_expr.args();
192 } else {
193 nary_args.push_back(cur_expr.expr());
194 }
195 for (auto &a : nary_args) {
196 bool is_inv_arg = true;
197 for (size_t j = i; j < loops_.size(); j++) {
198 if (!is_invariant(a, loops_[j].var)) is_inv_arg = false;
199 }
200 if (is_inv_arg) {
201 invariant_args.push_back(a);
202 } else {
203 other_args.push_back(a);
204 }
205 }
206 // Nothing to hoist for this loop, continue.
207 if (invariant_args.empty()) continue;
208 if (invariant_args.size() == 1 && is_var(invariant_args[0])
209 && !other_args.empty())
210 continue;
211 if (invariant_args.size() == 1
212 && (is_const(invariant_args[0])
213 || is_const_broadcast(invariant_args[0])))
214 continue;
215
216 // Introduce new variable for the invariant sub-expression.
217 auto inv_expr = sum_expr_t::make_add(invariant_args, type);
218 expr_t inv_var;
219 if (!expr_var.is_empty() && other_args.empty()) {
220 // If nothing to hoist further, reuse the old variable and
221 // return.
222 inv_var = expr_var;
223 } else {
224 inv_var = ir_ctx_.create_tmp_var(inv_expr.type());
225 }
226 register_let(inv_var, inv_expr);
227 add_hoist_let(loops_[i], inv_var, inv_expr);
228
229 if (other_args.empty()) {
230 if (fully_hoisted) *fully_hoisted = true;
231 return inv_var;
232 }
233
234 other_args.push_back(inv_var);
235 cur_expr = sum_expr_t::make_add(other_args, type);
236 }
237 return cur_expr.expr();
238 }
239
240 stmt_t injects_lets_and_pop_loop(const stmt_t &_s) {
241 stmt_t s = _s;
242 // Inject let statements if any.
243 auto &lets = loops_.back().lets;
244 for (auto it = lets.rbegin(); it != lets.rend(); ++it) {
245 auto &let = it->as<let_t>();
246 s = let_t::make(let.var, let.value, s);
247 unregister_let(let.var);
248 }
249 loops_.pop_back();
250 return s;
251 }
252
253 void register_let(const expr_t &var, const expr_t &value) {
254 let_vars_.insert({var, value});
255 }
256
257 void unregister_let(const expr_t &var) { let_vars_.erase(var); }
258
259 bool is_invariant(const expr_t &e, const expr_t &var) const {
260 if (contains_object(e, var)) return false;
261 if (!find_objects<load_t>(e).empty()) return false;
262
263 // Check value if this is a let variable.
264 auto it = let_vars_.find(e);
265 if (it != let_vars_.end()) return is_invariant(it->second, var);
266
267 if (is_var(e)) return true;
268
269 // Check transitive dependencies.
270 auto vars = find_unique_objects<var_t>(e);
271 for (auto &v : vars) {
272 if (!is_invariant(v, var)) return false;
273 }
274 return true;
275 }
276
277 ir_context_t &ir_ctx_;
278 std::vector<loop_info_t> loops_;
279 int max_hoist_size_;
280 int current_hoist_size_ = 0;
281
282 object_map_t<expr_t, expr_t> let_vars_;
283};
284stmt_t hoist_exprs_impl(
285 const stmt_t &s, ir_context_t &ir_ctx, int reserved_regs) {
286
287 int grf_size = ir_ctx.hw_cfg().grf_size();
288 int available_regs = ir_ctx.exec_cfg().regs() - reserved_regs;
289 int memory_usage_limit = available_regs * grf_size;
290
291 auto stmt = hoist_exprs_mutator_t(ir_ctx).mutate(s);
292
293 int memory_usage = get_peak_grf_usage(stmt, grf_size) * grf_size;
294 if (memory_usage >= memory_usage_limit) {
295 // Pessimistically hoist expressions. Does not identify and account for
296 // hoists which do not change memory usage.
297 int memory_usage_original = get_peak_grf_usage(s, grf_size) * grf_size;
298 stmt = hoist_exprs_mutator_t(
299 ir_ctx, memory_usage_limit - memory_usage_original)
300 .mutate(s);
301 }
302 return stmt;
303}
304
305stmt_t hoist_exprs(const stmt_t &s, ir_context_t &ir_ctx, int reserved_regs) {
306 trace_start();
307 auto ret = hoist_exprs_impl(s, ir_ctx, reserved_regs);
308 trace_pass("hoist_exprs", ret, ir_ctx);
309 return ret;
310}
311
312class hoist_send_masks_mutator_t : public ir_mutator_t {
313public:
314 hoist_send_masks_mutator_t(
315 ir_context_t &ir_ctx, const stmt_label_t &label, bool split_by_and)
316 : ir_ctx_(ir_ctx), label_(label), split_by_and_(split_by_and) {}
317
318 object_t _mutate(const for_t &obj) override {
319 loop_deps_.insert(obj.var);
320 return ir_mutator_t::_mutate(obj);
321 }
322
323 object_t _mutate(const func_call_t &obj) override {
324 if (!in_stmt_group || !is_func_call<send_t>(obj))
325 return ir_mutator_t::_mutate(obj);
326
327 auto &mask = send_t::arg_mask(obj);
328 if (mask.is_empty()) return ir_mutator_t::_mutate(obj);
329
330 auto new_args = obj.args;
331 auto hoisted_mask = hoist_mask(mask);
332 if (hoisted_mask.is_same(mask)) return ir_mutator_t::_mutate(obj);
333
334 ir_assert(hoisted_mask.type().is_u16()) << hoisted_mask;
335
336 send_t::arg_mask(new_args) = cast(hoisted_mask, mask.type());
337 return func_call_t::make(obj.func, new_args, obj.attr);
338 }
339
340 object_t _mutate(const let_t &obj) override {
341 auto value_vars = find_objects<var_t>(obj.value);
342 for (auto &v : value_vars) {
343 if (is_loop_dependency(v)) {
344 loop_deps_.insert(obj.var);
345 break;
346 }
347 }
348
349 if (in_stmt_group) {
350 ir_assert(!obj.value.is_empty());
351 let_values_.emplace(obj.var, expand(obj.value, value_vars));
352 }
353
354 return ir_mutator_t::_mutate(obj);
355 }
356
357 object_t _mutate(const stmt_group_t &obj) override {
358 bool is_stmt_group = (obj.label == label_);
359 if (is_stmt_group) in_stmt_group = true;
360 auto new_obj = ir_mutator_t::_mutate(obj);
361 if (is_stmt_group) {
362 in_stmt_group = false;
363 return create_mask_stmt(new_obj);
364 }
365 return new_obj;
366 }
367
368private:
369 bool is_loop_dependency(const expr_t &v) const {
370 ir_assert(is_var(v)) << v;
371 return loop_deps_.count(v) != 0;
372 }
373
374 expr_t hoist_mask(const expr_t &e) {
375 ir_assert(e.type().is_bool()) << e;
376
377 if (e.type().elems() > 16) return e;
378 if (is_const(e) || is_shuffle_const(e)) return e;
379
380 // Can't hoist a mask containing loop vars.
381 auto vars = find_objects<var_t>(e);
382 for (auto &v : vars) {
383 if (is_loop_dependency(v)) return e;
384 }
385
386 auto e_expanded = expand(e, vars);
387
388 // Can't hoist a mask containing loads.
389 if (!find_objects<load_t>(e_expanded).empty()) return e;
390
391 auto it = hoisted_masks_.find(e_expanded);
392 if (it != hoisted_masks_.end()) return it->second;
393
394 auto var = ir_ctx_.create_tmp_var(type_t::u16());
395 hoisted_masks_.emplace(e_expanded, var);
396
397 return var;
398 }
399
400 expr_t expand(const expr_t &_e, const std::vector<object_t> &e_vars) const {
401 auto e = _e;
402 for (auto &v : e_vars) {
403 auto it = let_values_.find(v);
404 if (it == let_values_.end()) continue;
405 e = substitute(e, v, it->second);
406 }
407 return e;
408 }
409
410 stmt_t create_mask_stmt(const stmt_t &body) {
411 stmt_t s = body;
412
413 object_eq_map_t<expr_t, expr_t> and_ops;
414 object_eq_map_t<expr_t, expr_t> mask_exprs;
415 for (auto &kv : hoisted_masks_) {
416 if (split_by_and_) {
417 auto e = split_by_and_ops(kv.first, and_ops);
418 mask_exprs.emplace(e, kv.second);
419 }
420 }
421 if (and_ops.size() < mask_exprs.size()) {
422 for (auto &kv : mask_exprs) {
423 s = let_t::make(kv.second, cast(kv.first, kv.second.type()), s);
424 }
425 for (auto &kv : and_ops) {
426 s = let_t::make(kv.second, cast(kv.first, kv.second.type()), s);
427 }
428 } else {
429 for (auto &kv : hoisted_masks_)
430 s = let_t::make(kv.second, cast(kv.first, kv.second.type()), s);
431 }
432
433 return s;
434 }
435
436 expr_t split_by_and_ops(
437 const expr_t &e, object_eq_map_t<expr_t, expr_t> &ops) {
438 auto *binary_op = e.as_ptr<binary_op_t>();
439 if (!binary_op || binary_op->op_kind != op_kind_t::_and) {
440 auto it = ops.find(e);
441 if (it != ops.end()) return it->second;
442
443 auto var = ir_ctx_.create_tmp_var(type_t::u16());
444 ops.emplace(e, var);
445 return var;
446 }
447
448 auto a = split_by_and_ops(binary_op->a, ops);
449 auto b = split_by_and_ops(binary_op->b, ops);
450 return binary_op_t::make(op_kind_t::_and, a, b);
451 }
452
453 bool in_stmt_group = false;
454 object_set_t<expr_t> loop_deps_;
455 object_eq_map_t<expr_t, expr_t> hoisted_masks_;
456 object_map_t<expr_t, expr_t> let_values_;
457
458 ir_context_t &ir_ctx_;
459 stmt_label_t label_;
460 bool split_by_and_;
461};
462
463stmt_t hoist_send_masks(const stmt_t &s, ir_context_t &ir_ctx,
464 const stmt_label_t &label, bool split_by_and) {
465 trace_start();
466 hoist_send_masks_mutator_t mutator(ir_ctx, label, split_by_and);
467
468 auto ret = mutator.mutate(s);
469 trace_pass("hoist_send_masks", ret, ir_ctx);
470 return ret;
471}
472
473} // namespace jit
474} // namespace gpu
475} // namespace impl
476} // namespace dnnl
477