1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/hoist.hpp" |
18 | |
19 | #include "gpu/jit/codegen/register_allocator.hpp" |
20 | #include "gpu/jit/ir/message.hpp" |
21 | #include "gpu/jit/utils/trace.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace gpu { |
26 | namespace jit { |
27 | |
28 | class sum_expr_t { |
29 | public: |
30 | sum_expr_t(const expr_t &e) |
31 | : type_(e.type()), args_(split_by_add(e, e.type().elems())) {} |
32 | |
33 | std::vector<expr_t> args() const { return args_; } |
34 | |
35 | bool is_trivial() const { return args_.size() <= 1; } |
36 | |
37 | expr_t expr() const { return make_add(args_, type_); } |
38 | |
39 | static expr_t make_add( |
40 | const std::vector<expr_t> &args, const type_t &type) { |
41 | auto maybe_bcast = [&](const expr_t &e) { |
42 | if (e.type().elems() == type.elems()) return e; |
43 | ir_assert(e.type().is_scalar()); |
44 | return shuffle_t::make_broadcast(e, type.elems()); |
45 | }; |
46 | if (args.empty()) return cast(0, type); |
47 | auto ret = maybe_bcast(args[0]); |
48 | for (int i = 1; i < (int)args.size(); i++) |
49 | ret += maybe_bcast(args[i]); |
50 | return ret; |
51 | } |
52 | |
53 | private: |
54 | static std::vector<expr_t> split_by_add(const expr_t &e, int elems) { |
55 | auto *shuffle = e.as_ptr<shuffle_t>(); |
56 | if (shuffle && shuffle->is_broadcast() && shuffle->elems() == elems) { |
57 | return split_by_add(shuffle->vec[0], elems); |
58 | } |
59 | auto *op = e.as_ptr<binary_op_t>(); |
60 | if (!op || op->op_kind != op_kind_t::_add) return {e}; |
61 | auto a_args = split_by_add(op->a, elems); |
62 | auto b_args = split_by_add(op->b, elems); |
63 | std::vector<expr_t> args; |
64 | args.insert(args.end(), a_args.begin(), a_args.end()); |
65 | args.insert(args.end(), b_args.begin(), b_args.end()); |
66 | return args; |
67 | } |
68 | |
69 | type_t type_; |
70 | std::vector<expr_t> args_; |
71 | }; |
72 | |
73 | class hoist_exprs_mutator_t : public ir_mutator_t { |
74 | public: |
75 | hoist_exprs_mutator_t(ir_context_t &ir_ctx, |
76 | int max_hoist_size = std::numeric_limits<int>::max()) |
77 | : ir_ctx_(ir_ctx), max_hoist_size_(max_hoist_size) {} |
78 | |
79 | ~hoist_exprs_mutator_t() override { ir_assert(let_vars_.empty()); } |
80 | |
81 | object_t _mutate(const func_call_t &obj) override { |
82 | if (!obj.func.is<send_t>()) return ir_mutator_t::_mutate(obj); |
83 | |
84 | std::vector<expr_t> new_args; |
85 | for (auto &e : obj.args) { |
86 | new_args.push_back(hoist_expr(e)); |
87 | } |
88 | |
89 | if (ir_utils::is_equal(new_args, obj.args)) return obj; |
90 | |
91 | return func_call_t::make(obj.func, new_args, obj.attr); |
92 | } |
93 | |
94 | object_t _mutate(const stmt_group_t &obj) override { |
95 | if (obj.body.is<for_t>()) { |
96 | loops_.emplace_back(obj.body.as<for_t>().var); |
97 | const for_t *for_obj = obj.body.as_ptr<for_t>(); |
98 | auto body = for_obj ? ir_mutator_t::_mutate(*for_obj) : for_obj; |
99 | if (body.is_same(obj.body)) return obj; |
100 | auto new_obj = stmt_group_t::make(obj.label, body); |
101 | return injects_lets_and_pop_loop(new_obj); |
102 | } |
103 | return ir_mutator_t::_mutate(obj); |
104 | } |
105 | |
106 | object_t _mutate(const store_t &obj) override { |
107 | auto value = hoist_expr(obj.value); |
108 | if (value.is_equal(obj.value)) return obj; |
109 | return store_t::make(obj.buf, obj.off, value, obj.stride); |
110 | } |
111 | |
112 | object_t _mutate(const for_t &obj) override { |
113 | loops_.emplace_back(obj.var); |
114 | auto new_obj = ir_mutator_t::_mutate(obj); |
115 | return injects_lets_and_pop_loop(new_obj); |
116 | } |
117 | |
118 | object_t _mutate(const let_t &obj) override { |
119 | bool fully_hoisted = false; |
120 | expr_t new_value; |
121 | bool is_const_let = is_const(obj.value) || is_shuffle_const(obj.value); |
122 | if (is_const_let && loops_.size() > 0 && can_hoist(obj.var)) { |
123 | fully_hoisted = true; |
124 | register_let(obj.var, obj.value); |
125 | add_hoist_let(loops_[0], obj.var, obj.value); |
126 | } else { |
127 | new_value = hoist_expr(obj.value, obj.var, &fully_hoisted); |
128 | } |
129 | if (fully_hoisted) return mutate(obj.body); |
130 | register_let(obj.var, new_value); |
131 | auto new_obj = let_t::make( |
132 | obj.var, new_value, ir_mutator_t::mutate(obj.body)); |
133 | unregister_let(obj.var); |
134 | return std::move(new_obj); |
135 | } |
136 | |
137 | private: |
138 | struct loop_info_t { |
139 | loop_info_t(const expr_t &var) : var(var) {} |
140 | |
141 | expr_t var; |
142 | int var_count = 0; |
143 | std::vector<stmt_t> lets; |
144 | }; |
145 | |
146 | bool can_hoist(const expr_t &expr) { |
147 | return expr.type().size() <= max_hoist_size_ - current_hoist_size_; |
148 | } |
149 | |
150 | void add_hoist_let( |
151 | loop_info_t &loop, const expr_t &var, const expr_t &value) { |
152 | loop.lets.emplace_back(let_t::make(var, value)); |
153 | current_hoist_size_ += utils::rnd_up( |
154 | var.type().size(), reg_allocator_t::granularity); |
155 | } |
156 | |
157 | expr_t hoist_expr(const expr_t &expr, const expr_t &expr_var = {}, |
158 | bool *fully_hoisted = nullptr) { |
159 | if (expr.is_empty()) return expr; |
160 | if (expr.type().is_ptr()) return expr; |
161 | if (expr.type().is_bool()) return expr; |
162 | if (is_const(expr) || is_shuffle_const(expr) || is_var(expr)) |
163 | return expr; |
164 | if (!can_hoist(expr)) return expr; |
165 | |
166 | auto hoisted_expr = hoist_expr_with_add(expr, expr_var, fully_hoisted); |
167 | if (!hoisted_expr.is_equal(expr)) return hoisted_expr; |
168 | |
169 | // hoist_expr_with_add() doesn't handle cast so try to hoist it manually. |
170 | auto *cast = expr.as_ptr<cast_t>(); |
171 | if (!cast) return hoisted_expr; |
172 | |
173 | auto hoisted_cast_expr = hoist_expr(cast->expr); |
174 | if (!hoisted_cast_expr.is_equal(cast->expr)) { |
175 | hoisted_expr = cast_t::make( |
176 | cast->type, hoisted_cast_expr, cast->saturate); |
177 | } |
178 | return hoisted_expr; |
179 | } |
180 | |
181 | expr_t hoist_expr_with_add(const expr_t &expr, const expr_t &expr_var = {}, |
182 | bool *fully_hoisted = nullptr) { |
183 | const type_t &type = expr.type(); |
184 | sum_expr_t cur_expr(expr); |
185 | |
186 | for (size_t i = 0; i < loops_.size(); i++) { |
187 | std::vector<expr_t> invariant_args; |
188 | std::vector<expr_t> other_args; |
189 | std::vector<expr_t> nary_args; |
190 | if (!cur_expr.is_trivial()) { |
191 | nary_args = cur_expr.args(); |
192 | } else { |
193 | nary_args.push_back(cur_expr.expr()); |
194 | } |
195 | for (auto &a : nary_args) { |
196 | bool is_inv_arg = true; |
197 | for (size_t j = i; j < loops_.size(); j++) { |
198 | if (!is_invariant(a, loops_[j].var)) is_inv_arg = false; |
199 | } |
200 | if (is_inv_arg) { |
201 | invariant_args.push_back(a); |
202 | } else { |
203 | other_args.push_back(a); |
204 | } |
205 | } |
206 | // Nothing to hoist for this loop, continue. |
207 | if (invariant_args.empty()) continue; |
208 | if (invariant_args.size() == 1 && is_var(invariant_args[0]) |
209 | && !other_args.empty()) |
210 | continue; |
211 | if (invariant_args.size() == 1 |
212 | && (is_const(invariant_args[0]) |
213 | || is_const_broadcast(invariant_args[0]))) |
214 | continue; |
215 | |
216 | // Introduce new variable for the invariant sub-expression. |
217 | auto inv_expr = sum_expr_t::make_add(invariant_args, type); |
218 | expr_t inv_var; |
219 | if (!expr_var.is_empty() && other_args.empty()) { |
220 | // If nothing to hoist further, reuse the old variable and |
221 | // return. |
222 | inv_var = expr_var; |
223 | } else { |
224 | inv_var = ir_ctx_.create_tmp_var(inv_expr.type()); |
225 | } |
226 | register_let(inv_var, inv_expr); |
227 | add_hoist_let(loops_[i], inv_var, inv_expr); |
228 | |
229 | if (other_args.empty()) { |
230 | if (fully_hoisted) *fully_hoisted = true; |
231 | return inv_var; |
232 | } |
233 | |
234 | other_args.push_back(inv_var); |
235 | cur_expr = sum_expr_t::make_add(other_args, type); |
236 | } |
237 | return cur_expr.expr(); |
238 | } |
239 | |
240 | stmt_t injects_lets_and_pop_loop(const stmt_t &_s) { |
241 | stmt_t s = _s; |
242 | // Inject let statements if any. |
243 | auto &lets = loops_.back().lets; |
244 | for (auto it = lets.rbegin(); it != lets.rend(); ++it) { |
245 | auto &let = it->as<let_t>(); |
246 | s = let_t::make(let.var, let.value, s); |
247 | unregister_let(let.var); |
248 | } |
249 | loops_.pop_back(); |
250 | return s; |
251 | } |
252 | |
253 | void register_let(const expr_t &var, const expr_t &value) { |
254 | let_vars_.insert({var, value}); |
255 | } |
256 | |
257 | void unregister_let(const expr_t &var) { let_vars_.erase(var); } |
258 | |
259 | bool is_invariant(const expr_t &e, const expr_t &var) const { |
260 | if (contains_object(e, var)) return false; |
261 | if (!find_objects<load_t>(e).empty()) return false; |
262 | |
263 | // Check value if this is a let variable. |
264 | auto it = let_vars_.find(e); |
265 | if (it != let_vars_.end()) return is_invariant(it->second, var); |
266 | |
267 | if (is_var(e)) return true; |
268 | |
269 | // Check transitive dependencies. |
270 | auto vars = find_unique_objects<var_t>(e); |
271 | for (auto &v : vars) { |
272 | if (!is_invariant(v, var)) return false; |
273 | } |
274 | return true; |
275 | } |
276 | |
277 | ir_context_t &ir_ctx_; |
278 | std::vector<loop_info_t> loops_; |
279 | int max_hoist_size_; |
280 | int current_hoist_size_ = 0; |
281 | |
282 | object_map_t<expr_t, expr_t> let_vars_; |
283 | }; |
284 | stmt_t hoist_exprs_impl( |
285 | const stmt_t &s, ir_context_t &ir_ctx, int reserved_regs) { |
286 | |
287 | int grf_size = ir_ctx.hw_cfg().grf_size(); |
288 | int available_regs = ir_ctx.exec_cfg().regs() - reserved_regs; |
289 | int memory_usage_limit = available_regs * grf_size; |
290 | |
291 | auto stmt = hoist_exprs_mutator_t(ir_ctx).mutate(s); |
292 | |
293 | int memory_usage = get_peak_grf_usage(stmt, grf_size) * grf_size; |
294 | if (memory_usage >= memory_usage_limit) { |
295 | // Pessimistically hoist expressions. Does not identify and account for |
296 | // hoists which do not change memory usage. |
297 | int memory_usage_original = get_peak_grf_usage(s, grf_size) * grf_size; |
298 | stmt = hoist_exprs_mutator_t( |
299 | ir_ctx, memory_usage_limit - memory_usage_original) |
300 | .mutate(s); |
301 | } |
302 | return stmt; |
303 | } |
304 | |
305 | stmt_t hoist_exprs(const stmt_t &s, ir_context_t &ir_ctx, int reserved_regs) { |
306 | trace_start(); |
307 | auto ret = hoist_exprs_impl(s, ir_ctx, reserved_regs); |
308 | trace_pass("hoist_exprs" , ret, ir_ctx); |
309 | return ret; |
310 | } |
311 | |
312 | class hoist_send_masks_mutator_t : public ir_mutator_t { |
313 | public: |
314 | hoist_send_masks_mutator_t( |
315 | ir_context_t &ir_ctx, const stmt_label_t &label, bool split_by_and) |
316 | : ir_ctx_(ir_ctx), label_(label), split_by_and_(split_by_and) {} |
317 | |
318 | object_t _mutate(const for_t &obj) override { |
319 | loop_deps_.insert(obj.var); |
320 | return ir_mutator_t::_mutate(obj); |
321 | } |
322 | |
323 | object_t _mutate(const func_call_t &obj) override { |
324 | if (!in_stmt_group || !is_func_call<send_t>(obj)) |
325 | return ir_mutator_t::_mutate(obj); |
326 | |
327 | auto &mask = send_t::arg_mask(obj); |
328 | if (mask.is_empty()) return ir_mutator_t::_mutate(obj); |
329 | |
330 | auto new_args = obj.args; |
331 | auto hoisted_mask = hoist_mask(mask); |
332 | if (hoisted_mask.is_same(mask)) return ir_mutator_t::_mutate(obj); |
333 | |
334 | ir_assert(hoisted_mask.type().is_u16()) << hoisted_mask; |
335 | |
336 | send_t::arg_mask(new_args) = cast(hoisted_mask, mask.type()); |
337 | return func_call_t::make(obj.func, new_args, obj.attr); |
338 | } |
339 | |
340 | object_t _mutate(const let_t &obj) override { |
341 | auto value_vars = find_objects<var_t>(obj.value); |
342 | for (auto &v : value_vars) { |
343 | if (is_loop_dependency(v)) { |
344 | loop_deps_.insert(obj.var); |
345 | break; |
346 | } |
347 | } |
348 | |
349 | if (in_stmt_group) { |
350 | ir_assert(!obj.value.is_empty()); |
351 | let_values_.emplace(obj.var, expand(obj.value, value_vars)); |
352 | } |
353 | |
354 | return ir_mutator_t::_mutate(obj); |
355 | } |
356 | |
357 | object_t _mutate(const stmt_group_t &obj) override { |
358 | bool is_stmt_group = (obj.label == label_); |
359 | if (is_stmt_group) in_stmt_group = true; |
360 | auto new_obj = ir_mutator_t::_mutate(obj); |
361 | if (is_stmt_group) { |
362 | in_stmt_group = false; |
363 | return create_mask_stmt(new_obj); |
364 | } |
365 | return new_obj; |
366 | } |
367 | |
368 | private: |
369 | bool is_loop_dependency(const expr_t &v) const { |
370 | ir_assert(is_var(v)) << v; |
371 | return loop_deps_.count(v) != 0; |
372 | } |
373 | |
374 | expr_t hoist_mask(const expr_t &e) { |
375 | ir_assert(e.type().is_bool()) << e; |
376 | |
377 | if (e.type().elems() > 16) return e; |
378 | if (is_const(e) || is_shuffle_const(e)) return e; |
379 | |
380 | // Can't hoist a mask containing loop vars. |
381 | auto vars = find_objects<var_t>(e); |
382 | for (auto &v : vars) { |
383 | if (is_loop_dependency(v)) return e; |
384 | } |
385 | |
386 | auto e_expanded = expand(e, vars); |
387 | |
388 | // Can't hoist a mask containing loads. |
389 | if (!find_objects<load_t>(e_expanded).empty()) return e; |
390 | |
391 | auto it = hoisted_masks_.find(e_expanded); |
392 | if (it != hoisted_masks_.end()) return it->second; |
393 | |
394 | auto var = ir_ctx_.create_tmp_var(type_t::u16()); |
395 | hoisted_masks_.emplace(e_expanded, var); |
396 | |
397 | return var; |
398 | } |
399 | |
400 | expr_t expand(const expr_t &_e, const std::vector<object_t> &e_vars) const { |
401 | auto e = _e; |
402 | for (auto &v : e_vars) { |
403 | auto it = let_values_.find(v); |
404 | if (it == let_values_.end()) continue; |
405 | e = substitute(e, v, it->second); |
406 | } |
407 | return e; |
408 | } |
409 | |
410 | stmt_t create_mask_stmt(const stmt_t &body) { |
411 | stmt_t s = body; |
412 | |
413 | object_eq_map_t<expr_t, expr_t> and_ops; |
414 | object_eq_map_t<expr_t, expr_t> mask_exprs; |
415 | for (auto &kv : hoisted_masks_) { |
416 | if (split_by_and_) { |
417 | auto e = split_by_and_ops(kv.first, and_ops); |
418 | mask_exprs.emplace(e, kv.second); |
419 | } |
420 | } |
421 | if (and_ops.size() < mask_exprs.size()) { |
422 | for (auto &kv : mask_exprs) { |
423 | s = let_t::make(kv.second, cast(kv.first, kv.second.type()), s); |
424 | } |
425 | for (auto &kv : and_ops) { |
426 | s = let_t::make(kv.second, cast(kv.first, kv.second.type()), s); |
427 | } |
428 | } else { |
429 | for (auto &kv : hoisted_masks_) |
430 | s = let_t::make(kv.second, cast(kv.first, kv.second.type()), s); |
431 | } |
432 | |
433 | return s; |
434 | } |
435 | |
436 | expr_t split_by_and_ops( |
437 | const expr_t &e, object_eq_map_t<expr_t, expr_t> &ops) { |
438 | auto *binary_op = e.as_ptr<binary_op_t>(); |
439 | if (!binary_op || binary_op->op_kind != op_kind_t::_and) { |
440 | auto it = ops.find(e); |
441 | if (it != ops.end()) return it->second; |
442 | |
443 | auto var = ir_ctx_.create_tmp_var(type_t::u16()); |
444 | ops.emplace(e, var); |
445 | return var; |
446 | } |
447 | |
448 | auto a = split_by_and_ops(binary_op->a, ops); |
449 | auto b = split_by_and_ops(binary_op->b, ops); |
450 | return binary_op_t::make(op_kind_t::_and, a, b); |
451 | } |
452 | |
453 | bool in_stmt_group = false; |
454 | object_set_t<expr_t> loop_deps_; |
455 | object_eq_map_t<expr_t, expr_t> hoisted_masks_; |
456 | object_map_t<expr_t, expr_t> let_values_; |
457 | |
458 | ir_context_t &ir_ctx_; |
459 | stmt_label_t label_; |
460 | bool split_by_and_; |
461 | }; |
462 | |
463 | stmt_t hoist_send_masks(const stmt_t &s, ir_context_t &ir_ctx, |
464 | const stmt_label_t &label, bool split_by_and) { |
465 | trace_start(); |
466 | hoist_send_masks_mutator_t mutator(ir_ctx, label, split_by_and); |
467 | |
468 | auto ret = mutator.mutate(s); |
469 | trace_pass("hoist_send_masks" , ret, ir_ctx); |
470 | return ret; |
471 | } |
472 | |
473 | } // namespace jit |
474 | } // namespace gpu |
475 | } // namespace impl |
476 | } // namespace dnnl |
477 | |