1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file hoist_expression.cc
22 */
23#include <tvm/arith/analyzer.h>
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/analysis.h>
26#include <tvm/tir/expr.h>
27#include <tvm/tir/stmt_functor.h>
28#include <tvm/tir/transform.h>
29
30#include <queue>
31#include <unordered_map>
32#include <unordered_set>
33#include <utility>
34
35#include "../../arith/interval_set.h"
36#include "../../arith/ir_mutator_with_analyzer.h"
37#include "../../runtime/thread_storage_scope.h"
38#include "ir_utils.h"
39
40namespace tvm {
41namespace tir {
42
/*!
 * \brief Bitflags naming the sources from which a condition may be hoisted.
 *
 * Values are combined with bitwise OR and stored as an int in
 * HoistExpressionConfigNode::hoisted_conditionals.
 */
enum class HoistedConditionals : int {
  // No conditions are hoisted.
  kNone = 0,
  // The condition of an IfThenElse statement.
  kIfElseStmt = 1 << 0,
  // The condition argument of a builtin::if_then_else call.
  kIfElseExpr = 1 << 1,
  // An operand of a boolean And/Or expression.
  kBooleanExpression = 1 << 2,
  // Permit hoisting of conditions that involve block variables
  // (variables bound through an AttrStmt).
  kUsingBlockVar = 1 << 3,
};
50
/*!
 * \brief Bitflags naming which let bindings may be hoisted.
 *
 * Values are combined with bitwise OR and stored as an int in
 * HoistExpressionConfigNode::hoisted_let_bindings.
 */
enum class HoistedLetBindings : int {
  // No let bindings are hoisted.
  kNone = 0,
  // Bindings that a hoisted condition depends on.
  kRequiredByCondition = 1 << 0,
  // Bindings introduced by a LetStmt.
  kLetStmt = 1 << 1,
  // Bindings introduced by a Let expression.
  kLetExpr = 1 << 2,
};
57
58struct HoistExpressionConfigNode : public tvm::AttrsNode<HoistExpressionConfigNode> {
59 int hoisted_conditionals;
60 int hoisted_let_bindings;
61
62 TVM_DECLARE_ATTRS(HoistExpressionConfigNode, "tir.transform.HoistExpressionConfig") {
63 TVM_ATTR_FIELD(hoisted_conditionals)
64 .describe("Bitflags for the types of boolean expressions to hoist")
65 .set_default(static_cast<int>(HoistedConditionals::kIfElseStmt) |
66 static_cast<int>(HoistedConditionals::kIfElseExpr) |
67 static_cast<int>(HoistedConditionals::kBooleanExpression));
68 TVM_ATTR_FIELD(hoisted_let_bindings)
69 .describe("Bitflags for the types of let bindings to hoist")
70 .set_default(static_cast<int>(HoistedLetBindings::kRequiredByCondition) |
71 static_cast<int>(HoistedLetBindings::kLetStmt) |
72 static_cast<int>(HoistedLetBindings::kLetExpr));
73 }
74
75 bool FlagSet(HoistedConditionals flag) const {
76 return static_cast<int>(flag) & hoisted_conditionals;
77 }
78 bool FlagSet(HoistedLetBindings flag) const {
79 return static_cast<int>(flag) & hoisted_let_bindings;
80 }
81};
82
83class HoistExpressionConfig : public Attrs {
84 public:
85 HoistExpressionConfig(int hoisted_conditionals, int hoisted_let_bindings) {
86 auto node = make_object<HoistExpressionConfigNode>();
87 node->hoisted_conditionals = hoisted_conditionals;
88 node->hoisted_let_bindings = hoisted_let_bindings;
89 data_ = std::move(node);
90 }
91 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistExpressionConfig, Attrs,
92 HoistExpressionConfigNode);
93};
94
// Register the config node type, and register HoistExpressionConfig as the
// type of the "tir.HoistExpression" pass-config option.  That key is the one
// transform::HoistExpression() looks up through ctx->GetConfig.
TVM_REGISTER_NODE_TYPE(HoistExpressionConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistExpression", HoistExpressionConfig);
97
/*!
 * \brief Attributes for the HoistIfThenElse pass.
 *
 * The single option decides whether transform::HoistIfThenElse() adds
 * HoistedConditionals::kUsingBlockVar to the flags it hoists with.
 */
struct HoistIfThenElseConfigNode : public tvm::AttrsNode<HoistIfThenElseConfigNode> {
  // Would like to replace the typo here from "hosting" to "hoisting",
  // but that may impact user configurations.
  bool support_block_scope_hosting;

  TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig") {
    // Disabled unless the user opts in.
    TVM_ATTR_FIELD(support_block_scope_hosting)
        .describe("Hoist if cond with block scope variables")
        .set_default(false);
  }
};
109
/*!
 * \brief Reference class for HoistIfThenElseConfigNode.
 */
class HoistIfThenElseConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs,
                                            HoistIfThenElseConfigNode);
};
115
// Register the config node type, and register HoistIfThenElseConfig as the
// type of the "tir.HoistIfThenElse" pass-config option.  That key is the one
// transform::HoistIfThenElse() looks up through ctx->GetConfig.
TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig);
118
119class HoistInfoCollector : public StmtExprVisitor {
120 public:
121 struct ConditionInfo {
122 ConditionInfo(PrimExpr condition, HoistedConditionals hoist_from, bool uses_block_var,
123 std::unordered_set<const VarNode*> required_let_bindings, bool generate_else_case)
124 : condition(condition),
125 hoist_from(hoist_from),
126 uses_block_var(uses_block_var),
127 required_let_bindings(required_let_bindings),
128 generate_else_case(generate_else_case) {}
129 PrimExpr condition;
130 HoistedConditionals hoist_from;
131 bool uses_block_var;
132 std::unordered_set<const VarNode*> required_let_bindings;
133 bool generate_else_case;
134
135 bool IsEnabled(const HoistExpressionConfig& config) const {
136 bool valid_source = config->FlagSet(hoist_from);
137
138 bool all_required_bindings_are_hoisted =
139 required_let_bindings.empty() ||
140 config->FlagSet(HoistedLetBindings::kRequiredByCondition) ||
141 config->FlagSet(HoistedLetBindings::kLetStmt);
142
143 bool valid_block_var_usage =
144 config->FlagSet(HoistedConditionals::kUsingBlockVar) || !uses_block_var;
145 return valid_source && all_required_bindings_are_hoisted && valid_block_var_usage;
146 }
147 };
148
149 struct LetBindingInfo {
150 LetBindingInfo(Var var, PrimExpr value, HoistedLetBindings hoist_from)
151 : var(var), value(value), hoist_from(hoist_from) {}
152 Var var;
153 PrimExpr value;
154 HoistedLetBindings hoist_from;
155
156 bool IsEnabled(const HoistExpressionConfig& config) const {
157 return config->FlagSet(hoist_from);
158 }
159 };
160
161 struct HoistInfo {
162 // The loop variable
163 Var loop_var;
164
165 // The For or AttrStmt that defines the loop var.
166 Stmt loop_def;
167
168 // Bindings defined in LetStmt inside the for-loop whose value
169 // does not depend on the loop variable. These can be hoisted
170 // outside this for-loop.
171 std::vector<LetBindingInfo> let_bindings;
172
173 // Conditions evaluated inside the for-loop whose value does not
174 // depend on the loop variable. These can be hoisted outside this
175 // for loop. These may depend on the let_bindings.
176 std::vector<ConditionInfo> conditions;
177
178 // Only conditions that impact the entire body of the loop
179 // hoisted. Conditionals may not be hoisted from inside a
180 // sequential node to outside.
181 bool reached_sequential_node{false};
182
183 // True if the loop variable representing a block variable
184 // (e.g. blockIdx.x, threadIdx.x), false otherwise.
185 bool IsBlockVariable() const { return !loop_def.as<ForNode>(); }
186 };
187
188 static std::vector<HoistInfo> Collect(Stmt stmt, HoistExpressionConfig config) {
189 HoistInfoCollector collector(config);
190 collector(stmt);
191 return collector.completed_loops;
192 }
193
194 private:
195 using Parent = StmtExprVisitor;
196 using Parent::VisitExpr_;
197 using Parent::VisitStmt_;
198
199 explicit HoistInfoCollector(HoistExpressionConfig config) : config(config) {}
200
201 void AttemptHoistConditional(PrimExpr cond, HoistedConditionals hoist_from,
202 bool generate_else_block = true) {
203 if (SideEffect(cond) > CallEffectKind::kPure) {
204 return;
205 }
206 if (auto info = FindHoistDestination(cond)) {
207 if (!info->reached_sequential_node) {
208 // Record whether this conditional uses any block variables.
209 bool uses_block_var = active_block_vars.size() && UsesVar(cond, [&](const VarNode* var) {
210 return active_block_vars.count(var);
211 });
212
213 std::unordered_set<const VarNode*> let_bindings_used;
214
215 for (Var var : UndefinedVars(cond)) {
216 auto it = let_var_to_let_vars.find(var.get());
217 if (it != let_var_to_let_vars.end()) {
218 let_bindings_used.insert(it->first);
219 for (auto used : it->second) {
220 let_bindings_used.insert(used);
221 }
222 }
223 }
224 info->conditions.push_back(ConditionInfo(cond, hoist_from, uses_block_var,
225 let_bindings_used, generate_else_block));
226 }
227 }
228 }
229
230 void VisitExpr_(const AndNode* op) final {
231 AttemptHoistConditional(op->a, HoistedConditionals::kBooleanExpression);
232 AttemptHoistConditional(op->b, HoistedConditionals::kBooleanExpression);
233 Parent::VisitExpr_(op);
234 }
235
236 void VisitExpr_(const OrNode* op) final {
237 AttemptHoistConditional(op->a, HoistedConditionals::kBooleanExpression);
238 AttemptHoistConditional(op->b, HoistedConditionals::kBooleanExpression);
239 Parent::VisitExpr_(op);
240 }
241
242 void VisitStmt_(const ForNode* op) final {
243 active_loops.push_back({op->loop_var, GetRef<Stmt>(op)});
244 active_loop_vars.insert(op->loop_var.get());
245
246 Parent::VisitStmt_(op);
247 completed_loops.push_back(active_loops.back());
248
249 active_loop_vars.erase(op->loop_var.get());
250 active_loops.pop_back();
251 }
252
253 void VisitStmt_(const AttrStmtNode* op) final {
254 Var var;
255 if (const auto* node_iter_var = op->node.as<IterVarNode>()) {
256 var = node_iter_var->var;
257 } else if (const auto* node_var = op->node.as<VarNode>()) {
258 var = GetRef<Var>(node_var);
259 } else {
260 return Parent::VisitStmt_(op);
261 }
262
263 active_block_vars.insert(var.get());
264 active_loop_vars.insert(var.get());
265 active_loops.push_back({var, GetRef<Stmt>(op)});
266
267 Parent::VisitStmt_(op);
268
269 completed_loops.push_back(active_loops.back());
270 active_loops.pop_back();
271
272 active_loop_vars.erase(var.get());
273 active_block_vars.erase(var.get());
274 }
275
276 void VisitBinding(Var var, PrimExpr value, HoistedLetBindings hoist_from) {
277 ICHECK_EQ(let_var_to_loop_vars.count(var.get()), 0)
278 << "Multiple nested definitions of variable " << var;
279 ICHECK_EQ(let_var_to_let_vars.count(var.get()), 0)
280 << "Multiple nested definitions of variable " << var;
281
282 if (auto info = FindHoistDestination(value)) {
283 if (!info->reached_sequential_node) {
284 info->let_bindings.push_back(LetBindingInfo(var, value, hoist_from));
285 }
286 }
287
288 // Walk through the loop binding
289 std::unordered_set<const VarNode*> loop_vars_used;
290 std::unordered_set<const VarNode*> let_bindings_used;
291 for (Var var : UndefinedVars(value)) {
292 if (active_loop_vars.count(var.get())) {
293 loop_vars_used.insert(var.get());
294 } else {
295 auto it = let_var_to_loop_vars.find(var.get());
296 if (it != let_var_to_loop_vars.end()) {
297 for (const VarNode* used : it->second) {
298 loop_vars_used.insert(used);
299 }
300 }
301 }
302
303 auto it = let_var_to_let_vars.find(var.get());
304 if (it != let_var_to_let_vars.end()) {
305 let_bindings_used.insert(it->first);
306 for (const VarNode* used : it->second) {
307 let_bindings_used.insert(used);
308 }
309 }
310 }
311
312 let_var_to_loop_vars[var.get()] = std::move(loop_vars_used);
313 let_var_to_let_vars[var.get()] = std::move(let_bindings_used);
314 }
315
316 void VisitStmt_(const LetStmtNode* op) final {
317 VisitBinding(op->var, op->value, HoistedLetBindings::kLetStmt);
318
319 Parent::VisitStmt_(op);
320
321 let_var_to_loop_vars.erase(op->var.get());
322 let_var_to_let_vars.erase(op->var.get());
323 }
324
325 void VisitExpr_(const LetNode* op) final {
326 VisitBinding(op->var, op->value, HoistedLetBindings::kLetExpr);
327
328 Parent::VisitExpr_(op);
329
330 let_var_to_loop_vars.erase(op->var.get());
331 let_var_to_let_vars.erase(op->var.get());
332 }
333
334 void VisitStmt_(const IfThenElseNode* op) final {
335 AttemptHoistConditional(op->condition, HoistedConditionals::kIfElseStmt,
336 op->else_case.defined());
337 Parent::VisitStmt_(op);
338 }
339
340 void VisitExpr_(const CallNode* op) final {
341 if (op->op.same_as(builtin::if_then_else())) {
342 PrimExpr cond = op->args[0];
343 AttemptHoistConditional(cond, HoistedConditionals::kIfElseExpr);
344 }
345 Parent::VisitExpr_(op);
346 }
347
348 void VisitStmt_(const SeqStmtNode* op) final {
349 if (active_loops.size()) {
350 active_loops.back().reached_sequential_node = true;
351 }
352 Parent::VisitStmt_(op);
353 }
354
355 // Find the loop above which this expression could be hoisted. If
356 // nullptr, the expression cannot be hoisted.
357 HoistInfo* FindHoistDestination(PrimExpr expr) {
358 // Cannot hoist above a loop if we aren't already in a loop.
359 if (active_loops.empty()) {
360 return nullptr;
361 }
362
363 for (auto it = active_loops.rbegin(); it != active_loops.rend(); it++) {
364 Var loop_var = it->loop_var;
365 bool uses_loop_var = UsesVar(expr, [&](const VarNode* var) -> bool {
366 if (var == loop_var.get()) {
367 return true;
368 }
369
370 auto it = let_var_to_loop_vars.find(var);
371 if (it == let_var_to_loop_vars.end()) {
372 return false;
373 }
374
375 return it->second.count(loop_var.get());
376 });
377
378 bool is_disabled_hoist_across_block_var =
379 !config->FlagSet(HoistedConditionals::kUsingBlockVar) && it->IsBlockVariable();
380
381 if (it->reached_sequential_node || uses_loop_var || is_disabled_hoist_across_block_var) {
382 if (it == active_loops.rbegin()) {
383 // Cannot hoist beyond the innermost loop iterator.
384 return nullptr;
385 } else {
386 // Hoist to just below the loop iterator that is required.
387 it--;
388 return &(*it);
389 }
390 }
391 }
392
393 // If no loop variables are used, can hoist above the outermost
394 // loop.
395 return &active_loops.front();
396 }
397
398 // The user-provided config describing which expressions should be
399 // hoisted.
400 HoistExpressionConfig config;
401
402 // Current thread_extent bindings of block variables.
403 std::unordered_set<const VarNode*> active_block_vars;
404
405 // An ordered list of loops that are currently being visited.
406 std::vector<HoistInfo> active_loops;
407
408 // Loops that have already been visited
409 std::vector<HoistInfo> completed_loops;
410
411 // Map from a bound variable to the loop variables it depends on.
412 // Includes indirect usage.
413 std::unordered_map<const VarNode*, std::unordered_set<const VarNode*>> let_var_to_loop_vars;
414
415 // Map from a bound variable to the other let bindings it depends on.
416 // Includes indirect usage.
417 std::unordered_map<const VarNode*, std::unordered_set<const VarNode*>> let_var_to_let_vars;
418
419 // Lookup table for the currently active loops.
420 std::unordered_set<const VarNode*> active_loop_vars;
421};
422
423class ExpressionHoister : public arith::IRMutatorWithAnalyzer {
424 public:
425 static Stmt Hoist(Stmt stmt, HoistExpressionConfig config) {
426 auto loop_info = HoistInfoCollector::Collect(stmt, config);
427
428 arith::Analyzer analyzer;
429 ExpressionHoister hoister(std::move(loop_info), config, &analyzer);
430 stmt = hoister(std::move(stmt));
431 stmt = ConvertSSA(std::move(stmt));
432 return stmt;
433 }
434
435 private:
436 using Parent = arith::IRMutatorWithAnalyzer;
437 using Parent::VisitExpr_;
438 using Parent::VisitStmt_;
439
440 explicit ExpressionHoister(std::vector<HoistInfoCollector::HoistInfo> loop_info,
441 HoistExpressionConfig config, arith::Analyzer* analyzer)
442 : Parent(analyzer), config_(config) {
443 for (auto& info : loop_info) {
444 // Mark let bindings to use if they are enabled on their own.
445 for (const auto& binding : info.let_bindings) {
446 if (binding.IsEnabled(config)) {
447 hoisted_let_bindings.insert(binding.var.get());
448 }
449 }
450
451 // Or if they are required by a conditional
452 if (config->FlagSet(HoistedLetBindings::kRequiredByCondition)) {
453 for (const auto& conditional : info.conditions) {
454 if (conditional.IsEnabled(config)) {
455 for (const auto& var : conditional.required_let_bindings) {
456 hoisted_let_bindings.insert(var);
457 }
458 }
459 }
460 }
461
462 loop_info_lookup[info.loop_def.get()] = std::move(info);
463 }
464 }
465
466 Stmt WrapHoistedStatements(Stmt stmt, const HoistInfoCollector::HoistInfo& info) {
467 for (auto cond_it = info.conditions.rbegin(); cond_it != info.conditions.rend(); cond_it++) {
468 if (cond_it->IsEnabled(config_)) {
469 if (cond_it->generate_else_case) {
470 stmt = IfThenElse(cond_it->condition, stmt, stmt);
471 } else {
472 stmt = IfThenElse(cond_it->condition, stmt);
473 }
474 }
475 }
476 for (auto let_it = info.let_bindings.rbegin(); let_it != info.let_bindings.rend(); let_it++) {
477 if (hoisted_let_bindings.count(let_it->var.get())) {
478 stmt = LetStmt(let_it->var, let_it->value, stmt);
479 }
480 }
481
482 return stmt;
483 }
484
485 Stmt VisitStmt_(const ForNode* op) final {
486 Stmt stmt = Parent::VisitStmt_(op);
487
488 auto it = loop_info_lookup.find(op);
489 ICHECK(it != loop_info_lookup.end())
490 << "Could not find pre-pass information for loop over " << op->loop_var;
491 return WrapHoistedStatements(stmt, it->second);
492 }
493
494 Stmt VisitStmt_(const AttrStmtNode* op) final {
495 Stmt stmt = Parent::VisitStmt_(op);
496
497 auto it = loop_info_lookup.find(op);
498 if (it == loop_info_lookup.end()) {
499 return stmt;
500 } else {
501 return WrapHoistedStatements(stmt, it->second);
502 }
503 }
504
505 Stmt VisitStmt_(const LetStmtNode* op) final {
506 if (hoisted_let_bindings.count(op->var.get())) {
507 return this->VisitStmt(op->body);
508 } else {
509 return Parent::VisitStmt_(op);
510 }
511 }
512
513 PrimExpr VisitExpr_(const LetNode* op) final {
514 if (hoisted_let_bindings.count(op->var.get())) {
515 return this->VisitExpr(op->body);
516 } else {
517 return Parent::VisitExpr_(op);
518 }
519 }
520
521 HoistExpressionConfig config_;
522
523 std::unordered_map<const StmtNode*, HoistInfoCollector::HoistInfo> loop_info_lookup;
524 std::unordered_set<const VarNode*> hoisted_let_bindings;
525};
526
527Stmt HoistExpression(Stmt stmt, HoistExpressionConfig config) {
528 return ExpressionHoister::Hoist(stmt, config);
529}
530
531namespace transform {
532
533Pass HoistExpression() {
534 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
535 auto* n = f.CopyOnWrite();
536 auto cfg = ctx->GetConfig<HoistExpressionConfig>("tir.HoistExpression");
537
538 if (!cfg.defined()) {
539 cfg = AttrsWithDefaultValues<HoistExpressionConfig>();
540 }
541 n->body = ExpressionHoister::Hoist(std::move(n->body), cfg.value());
542 return f;
543 };
544 auto insertion_pass = CreatePrimFuncPass(pass_func, 0, "tir.InsertHoistedExpression", {});
545
546 return Sequential(
547 {
548 insertion_pass,
549 Simplify(),
550 RemoveNoOp(),
551 },
552 "tir.HoistExpression");
553}
554
// Register the HoistExpression pass constructor as the global function
// "tir.transform.HoistExpression".
TVM_REGISTER_GLOBAL("tir.transform.HoistExpression").set_body_typed(HoistExpression);
556
557Pass HoistIfThenElse() {
558 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
559 auto* n = f.CopyOnWrite();
560 auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");
561
562 if (!cfg.defined()) {
563 cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>();
564 }
565 int block_var = static_cast<int>(cfg.value()->support_block_scope_hosting
566 ? HoistedConditionals::kUsingBlockVar
567 : HoistedConditionals::kNone);
568 HoistExpressionConfig config(block_var | static_cast<int>(HoistedConditionals::kIfElseStmt),
569 static_cast<int>(HoistedLetBindings::kNone));
570 n->body = ExpressionHoister::Hoist(std::move(n->body), config);
571 return f;
572 };
573 auto insertion_pass = CreatePrimFuncPass(pass_func, 0, "tir.InsertHoistIfThenElse", {});
574 return Sequential(
575 {
576 insertion_pass,
577 Simplify(),
578 RemoveNoOp(),
579 },
580 "tir.HoistIfThenElse");
581}
582
// Register the HoistIfThenElse pass constructor as the global function
// "tir.transform.HoistIfThenElse".
TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse").set_body_typed(HoistIfThenElse);
584
585Pass HoistIfThenElseBasic() {
586 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
587 auto* n = f.CopyOnWrite();
588 HoistExpressionConfig config(static_cast<int>(HoistedConditionals::kIfElseStmt),
589 static_cast<int>(HoistedLetBindings::kNone));
590 n->body = ExpressionHoister::Hoist(std::move(n->body), config);
591 return f;
592 };
593 auto insertion_pass = CreatePrimFuncPass(pass_func, 0, "tir.InsertHoistIfThenElseBasic", {});
594 return Sequential(
595 {
596 insertion_pass,
597 Simplify(),
598 RemoveNoOp(),
599 },
600 "tir.HoistIfThenElseBasic");
601}
602
// Register the HoistIfThenElseBasic pass constructor as the global function
// "tir.transform.HoistIfThenElseBasic".
TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic").set_body_typed(HoistIfThenElseBasic);
604
605} // namespace transform
606
607} // namespace tir
608} // namespace tvm
609