1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file hoist_expression.cc |
22 | */ |
23 | #include <tvm/arith/analyzer.h> |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/analysis.h> |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
30 | #include <queue> |
31 | #include <unordered_map> |
32 | #include <unordered_set> |
33 | #include <utility> |
34 | |
35 | #include "../../arith/interval_set.h" |
36 | #include "../../arith/ir_mutator_with_analyzer.h" |
37 | #include "../../runtime/thread_storage_scope.h" |
38 | #include "ir_utils.h" |
39 | |
40 | namespace tvm { |
41 | namespace tir { |
42 | |
/*!
 * \brief Bit flags selecting which conditionals may be hoisted.
 *
 * Values are combined with bitwise OR into an int bitmask.
 */
enum class HoistedConditionals : int {
  // No conditionals are hoisted.
  kNone = 0x0,
  // Conditions of IfThenElse statements.
  kIfElseStmt = 0x1,
  // Conditions of builtin::if_then_else() calls.
  kIfElseExpr = 0x2,
  // Operands of boolean And / Or expressions.
  kBooleanExpression = 0x4,
  // Permit hoisting conditions that involve block variables (variables bound
  // by an AttrStmt, e.g. blockIdx.x / threadIdx.x).
  kUsingBlockVar = 0x8,
};
50 | |
/*!
 * \brief Bit flags selecting which let bindings may be hoisted.
 *
 * Values are combined with bitwise OR into an int bitmask.
 */
enum class HoistedLetBindings : int {
  // No let bindings are hoisted.
  kNone = 0x0,
  // Bindings that a hoisted conditional depends on.
  kRequiredByCondition = 0x1,
  // Bindings introduced by LetStmt.
  kLetStmt = 0x2,
  // Bindings introduced by Let expressions.
  kLetExpr = 0x4,
};
57 | |
58 | struct HoistExpressionConfigNode : public tvm::AttrsNode<HoistExpressionConfigNode> { |
59 | int hoisted_conditionals; |
60 | int hoisted_let_bindings; |
61 | |
62 | TVM_DECLARE_ATTRS(HoistExpressionConfigNode, "tir.transform.HoistExpressionConfig" ) { |
63 | TVM_ATTR_FIELD(hoisted_conditionals) |
64 | .describe("Bitflags for the types of boolean expressions to hoist" ) |
65 | .set_default(static_cast<int>(HoistedConditionals::kIfElseStmt) | |
66 | static_cast<int>(HoistedConditionals::kIfElseExpr) | |
67 | static_cast<int>(HoistedConditionals::kBooleanExpression)); |
68 | TVM_ATTR_FIELD(hoisted_let_bindings) |
69 | .describe("Bitflags for the types of let bindings to hoist" ) |
70 | .set_default(static_cast<int>(HoistedLetBindings::kRequiredByCondition) | |
71 | static_cast<int>(HoistedLetBindings::kLetStmt) | |
72 | static_cast<int>(HoistedLetBindings::kLetExpr)); |
73 | } |
74 | |
75 | bool FlagSet(HoistedConditionals flag) const { |
76 | return static_cast<int>(flag) & hoisted_conditionals; |
77 | } |
78 | bool FlagSet(HoistedLetBindings flag) const { |
79 | return static_cast<int>(flag) & hoisted_let_bindings; |
80 | } |
81 | }; |
82 | |
83 | class HoistExpressionConfig : public Attrs { |
84 | public: |
85 | HoistExpressionConfig(int hoisted_conditionals, int hoisted_let_bindings) { |
86 | auto node = make_object<HoistExpressionConfigNode>(); |
87 | node->hoisted_conditionals = hoisted_conditionals; |
88 | node->hoisted_let_bindings = hoisted_let_bindings; |
89 | data_ = std::move(node); |
90 | } |
91 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistExpressionConfig, Attrs, |
92 | HoistExpressionConfigNode); |
93 | }; |
94 | |
95 | TVM_REGISTER_NODE_TYPE(HoistExpressionConfigNode); |
96 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistExpression" , HoistExpressionConfig); |
97 | |
98 | struct HoistIfThenElseConfigNode : public tvm::AttrsNode<HoistIfThenElseConfigNode> { |
99 | // Would like to replace the typo here from "hosting" to "hoisting", |
100 | // but that may impact user configurations. |
101 | bool support_block_scope_hosting; |
102 | |
103 | TVM_DECLARE_ATTRS(HoistIfThenElseConfigNode, "tir.transform.HoistIfThenElseConfig" ) { |
104 | TVM_ATTR_FIELD(support_block_scope_hosting) |
105 | .describe("Hoist if cond with block scope variables" ) |
106 | .set_default(false); |
107 | } |
108 | }; |
109 | |
/*! \brief Non-nullable reference to a HoistIfThenElseConfigNode. */
class HoistIfThenElseConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(HoistIfThenElseConfig, Attrs,
                                            HoistIfThenElseConfigNode);
};

// Register the node type and expose it as the "tir.HoistIfThenElse" pass-config option.
TVM_REGISTER_NODE_TYPE(HoistIfThenElseConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.HoistIfThenElse", HoistIfThenElseConfig);
118 | |
119 | class HoistInfoCollector : public StmtExprVisitor { |
120 | public: |
121 | struct ConditionInfo { |
122 | ConditionInfo(PrimExpr condition, HoistedConditionals hoist_from, bool uses_block_var, |
123 | std::unordered_set<const VarNode*> required_let_bindings, bool generate_else_case) |
124 | : condition(condition), |
125 | hoist_from(hoist_from), |
126 | uses_block_var(uses_block_var), |
127 | required_let_bindings(required_let_bindings), |
128 | generate_else_case(generate_else_case) {} |
129 | PrimExpr condition; |
130 | HoistedConditionals hoist_from; |
131 | bool uses_block_var; |
132 | std::unordered_set<const VarNode*> required_let_bindings; |
133 | bool generate_else_case; |
134 | |
135 | bool IsEnabled(const HoistExpressionConfig& config) const { |
136 | bool valid_source = config->FlagSet(hoist_from); |
137 | |
138 | bool all_required_bindings_are_hoisted = |
139 | required_let_bindings.empty() || |
140 | config->FlagSet(HoistedLetBindings::kRequiredByCondition) || |
141 | config->FlagSet(HoistedLetBindings::kLetStmt); |
142 | |
143 | bool valid_block_var_usage = |
144 | config->FlagSet(HoistedConditionals::kUsingBlockVar) || !uses_block_var; |
145 | return valid_source && all_required_bindings_are_hoisted && valid_block_var_usage; |
146 | } |
147 | }; |
148 | |
149 | struct LetBindingInfo { |
150 | LetBindingInfo(Var var, PrimExpr value, HoistedLetBindings hoist_from) |
151 | : var(var), value(value), hoist_from(hoist_from) {} |
152 | Var var; |
153 | PrimExpr value; |
154 | HoistedLetBindings hoist_from; |
155 | |
156 | bool IsEnabled(const HoistExpressionConfig& config) const { |
157 | return config->FlagSet(hoist_from); |
158 | } |
159 | }; |
160 | |
161 | struct HoistInfo { |
162 | // The loop variable |
163 | Var loop_var; |
164 | |
165 | // The For or AttrStmt that defines the loop var. |
166 | Stmt loop_def; |
167 | |
168 | // Bindings defined in LetStmt inside the for-loop whose value |
169 | // does not depend on the loop variable. These can be hoisted |
170 | // outside this for-loop. |
171 | std::vector<LetBindingInfo> let_bindings; |
172 | |
173 | // Conditions evaluated inside the for-loop whose value does not |
174 | // depend on the loop variable. These can be hoisted outside this |
175 | // for loop. These may depend on the let_bindings. |
176 | std::vector<ConditionInfo> conditions; |
177 | |
178 | // Only conditions that impact the entire body of the loop |
179 | // hoisted. Conditionals may not be hoisted from inside a |
180 | // sequential node to outside. |
181 | bool reached_sequential_node{false}; |
182 | |
183 | // True if the loop variable representing a block variable |
184 | // (e.g. blockIdx.x, threadIdx.x), false otherwise. |
185 | bool IsBlockVariable() const { return !loop_def.as<ForNode>(); } |
186 | }; |
187 | |
188 | static std::vector<HoistInfo> Collect(Stmt stmt, HoistExpressionConfig config) { |
189 | HoistInfoCollector collector(config); |
190 | collector(stmt); |
191 | return collector.completed_loops; |
192 | } |
193 | |
194 | private: |
195 | using Parent = StmtExprVisitor; |
196 | using Parent::VisitExpr_; |
197 | using Parent::VisitStmt_; |
198 | |
199 | explicit HoistInfoCollector(HoistExpressionConfig config) : config(config) {} |
200 | |
201 | void AttemptHoistConditional(PrimExpr cond, HoistedConditionals hoist_from, |
202 | bool generate_else_block = true) { |
203 | if (SideEffect(cond) > CallEffectKind::kPure) { |
204 | return; |
205 | } |
206 | if (auto info = FindHoistDestination(cond)) { |
207 | if (!info->reached_sequential_node) { |
208 | // Record whether this conditional uses any block variables. |
209 | bool uses_block_var = active_block_vars.size() && UsesVar(cond, [&](const VarNode* var) { |
210 | return active_block_vars.count(var); |
211 | }); |
212 | |
213 | std::unordered_set<const VarNode*> let_bindings_used; |
214 | |
215 | for (Var var : UndefinedVars(cond)) { |
216 | auto it = let_var_to_let_vars.find(var.get()); |
217 | if (it != let_var_to_let_vars.end()) { |
218 | let_bindings_used.insert(it->first); |
219 | for (auto used : it->second) { |
220 | let_bindings_used.insert(used); |
221 | } |
222 | } |
223 | } |
224 | info->conditions.push_back(ConditionInfo(cond, hoist_from, uses_block_var, |
225 | let_bindings_used, generate_else_block)); |
226 | } |
227 | } |
228 | } |
229 | |
230 | void VisitExpr_(const AndNode* op) final { |
231 | AttemptHoistConditional(op->a, HoistedConditionals::kBooleanExpression); |
232 | AttemptHoistConditional(op->b, HoistedConditionals::kBooleanExpression); |
233 | Parent::VisitExpr_(op); |
234 | } |
235 | |
236 | void VisitExpr_(const OrNode* op) final { |
237 | AttemptHoistConditional(op->a, HoistedConditionals::kBooleanExpression); |
238 | AttemptHoistConditional(op->b, HoistedConditionals::kBooleanExpression); |
239 | Parent::VisitExpr_(op); |
240 | } |
241 | |
242 | void VisitStmt_(const ForNode* op) final { |
243 | active_loops.push_back({op->loop_var, GetRef<Stmt>(op)}); |
244 | active_loop_vars.insert(op->loop_var.get()); |
245 | |
246 | Parent::VisitStmt_(op); |
247 | completed_loops.push_back(active_loops.back()); |
248 | |
249 | active_loop_vars.erase(op->loop_var.get()); |
250 | active_loops.pop_back(); |
251 | } |
252 | |
253 | void VisitStmt_(const AttrStmtNode* op) final { |
254 | Var var; |
255 | if (const auto* node_iter_var = op->node.as<IterVarNode>()) { |
256 | var = node_iter_var->var; |
257 | } else if (const auto* node_var = op->node.as<VarNode>()) { |
258 | var = GetRef<Var>(node_var); |
259 | } else { |
260 | return Parent::VisitStmt_(op); |
261 | } |
262 | |
263 | active_block_vars.insert(var.get()); |
264 | active_loop_vars.insert(var.get()); |
265 | active_loops.push_back({var, GetRef<Stmt>(op)}); |
266 | |
267 | Parent::VisitStmt_(op); |
268 | |
269 | completed_loops.push_back(active_loops.back()); |
270 | active_loops.pop_back(); |
271 | |
272 | active_loop_vars.erase(var.get()); |
273 | active_block_vars.erase(var.get()); |
274 | } |
275 | |
276 | void VisitBinding(Var var, PrimExpr value, HoistedLetBindings hoist_from) { |
277 | ICHECK_EQ(let_var_to_loop_vars.count(var.get()), 0) |
278 | << "Multiple nested definitions of variable " << var; |
279 | ICHECK_EQ(let_var_to_let_vars.count(var.get()), 0) |
280 | << "Multiple nested definitions of variable " << var; |
281 | |
282 | if (auto info = FindHoistDestination(value)) { |
283 | if (!info->reached_sequential_node) { |
284 | info->let_bindings.push_back(LetBindingInfo(var, value, hoist_from)); |
285 | } |
286 | } |
287 | |
288 | // Walk through the loop binding |
289 | std::unordered_set<const VarNode*> loop_vars_used; |
290 | std::unordered_set<const VarNode*> let_bindings_used; |
291 | for (Var var : UndefinedVars(value)) { |
292 | if (active_loop_vars.count(var.get())) { |
293 | loop_vars_used.insert(var.get()); |
294 | } else { |
295 | auto it = let_var_to_loop_vars.find(var.get()); |
296 | if (it != let_var_to_loop_vars.end()) { |
297 | for (const VarNode* used : it->second) { |
298 | loop_vars_used.insert(used); |
299 | } |
300 | } |
301 | } |
302 | |
303 | auto it = let_var_to_let_vars.find(var.get()); |
304 | if (it != let_var_to_let_vars.end()) { |
305 | let_bindings_used.insert(it->first); |
306 | for (const VarNode* used : it->second) { |
307 | let_bindings_used.insert(used); |
308 | } |
309 | } |
310 | } |
311 | |
312 | let_var_to_loop_vars[var.get()] = std::move(loop_vars_used); |
313 | let_var_to_let_vars[var.get()] = std::move(let_bindings_used); |
314 | } |
315 | |
316 | void VisitStmt_(const LetStmtNode* op) final { |
317 | VisitBinding(op->var, op->value, HoistedLetBindings::kLetStmt); |
318 | |
319 | Parent::VisitStmt_(op); |
320 | |
321 | let_var_to_loop_vars.erase(op->var.get()); |
322 | let_var_to_let_vars.erase(op->var.get()); |
323 | } |
324 | |
325 | void VisitExpr_(const LetNode* op) final { |
326 | VisitBinding(op->var, op->value, HoistedLetBindings::kLetExpr); |
327 | |
328 | Parent::VisitExpr_(op); |
329 | |
330 | let_var_to_loop_vars.erase(op->var.get()); |
331 | let_var_to_let_vars.erase(op->var.get()); |
332 | } |
333 | |
334 | void VisitStmt_(const IfThenElseNode* op) final { |
335 | AttemptHoistConditional(op->condition, HoistedConditionals::kIfElseStmt, |
336 | op->else_case.defined()); |
337 | Parent::VisitStmt_(op); |
338 | } |
339 | |
340 | void VisitExpr_(const CallNode* op) final { |
341 | if (op->op.same_as(builtin::if_then_else())) { |
342 | PrimExpr cond = op->args[0]; |
343 | AttemptHoistConditional(cond, HoistedConditionals::kIfElseExpr); |
344 | } |
345 | Parent::VisitExpr_(op); |
346 | } |
347 | |
348 | void VisitStmt_(const SeqStmtNode* op) final { |
349 | if (active_loops.size()) { |
350 | active_loops.back().reached_sequential_node = true; |
351 | } |
352 | Parent::VisitStmt_(op); |
353 | } |
354 | |
355 | // Find the loop above which this expression could be hoisted. If |
356 | // nullptr, the expression cannot be hoisted. |
357 | HoistInfo* FindHoistDestination(PrimExpr expr) { |
358 | // Cannot hoist above a loop if we aren't already in a loop. |
359 | if (active_loops.empty()) { |
360 | return nullptr; |
361 | } |
362 | |
363 | for (auto it = active_loops.rbegin(); it != active_loops.rend(); it++) { |
364 | Var loop_var = it->loop_var; |
365 | bool uses_loop_var = UsesVar(expr, [&](const VarNode* var) -> bool { |
366 | if (var == loop_var.get()) { |
367 | return true; |
368 | } |
369 | |
370 | auto it = let_var_to_loop_vars.find(var); |
371 | if (it == let_var_to_loop_vars.end()) { |
372 | return false; |
373 | } |
374 | |
375 | return it->second.count(loop_var.get()); |
376 | }); |
377 | |
378 | bool is_disabled_hoist_across_block_var = |
379 | !config->FlagSet(HoistedConditionals::kUsingBlockVar) && it->IsBlockVariable(); |
380 | |
381 | if (it->reached_sequential_node || uses_loop_var || is_disabled_hoist_across_block_var) { |
382 | if (it == active_loops.rbegin()) { |
383 | // Cannot hoist beyond the innermost loop iterator. |
384 | return nullptr; |
385 | } else { |
386 | // Hoist to just below the loop iterator that is required. |
387 | it--; |
388 | return &(*it); |
389 | } |
390 | } |
391 | } |
392 | |
393 | // If no loop variables are used, can hoist above the outermost |
394 | // loop. |
395 | return &active_loops.front(); |
396 | } |
397 | |
398 | // The user-provided config describing which expressions should be |
399 | // hoisted. |
400 | HoistExpressionConfig config; |
401 | |
402 | // Current thread_extent bindings of block variables. |
403 | std::unordered_set<const VarNode*> active_block_vars; |
404 | |
405 | // An ordered list of loops that are currently being visited. |
406 | std::vector<HoistInfo> active_loops; |
407 | |
408 | // Loops that have already been visited |
409 | std::vector<HoistInfo> completed_loops; |
410 | |
411 | // Map from a bound variable to the loop variables it depends on. |
412 | // Includes indirect usage. |
413 | std::unordered_map<const VarNode*, std::unordered_set<const VarNode*>> let_var_to_loop_vars; |
414 | |
415 | // Map from a bound variable to the other let bindings it depends on. |
416 | // Includes indirect usage. |
417 | std::unordered_map<const VarNode*, std::unordered_set<const VarNode*>> let_var_to_let_vars; |
418 | |
419 | // Lookup table for the currently active loops. |
420 | std::unordered_set<const VarNode*> active_loop_vars; |
421 | }; |
422 | |
423 | class ExpressionHoister : public arith::IRMutatorWithAnalyzer { |
424 | public: |
425 | static Stmt Hoist(Stmt stmt, HoistExpressionConfig config) { |
426 | auto loop_info = HoistInfoCollector::Collect(stmt, config); |
427 | |
428 | arith::Analyzer analyzer; |
429 | ExpressionHoister hoister(std::move(loop_info), config, &analyzer); |
430 | stmt = hoister(std::move(stmt)); |
431 | stmt = ConvertSSA(std::move(stmt)); |
432 | return stmt; |
433 | } |
434 | |
435 | private: |
436 | using Parent = arith::IRMutatorWithAnalyzer; |
437 | using Parent::VisitExpr_; |
438 | using Parent::VisitStmt_; |
439 | |
440 | explicit ExpressionHoister(std::vector<HoistInfoCollector::HoistInfo> loop_info, |
441 | HoistExpressionConfig config, arith::Analyzer* analyzer) |
442 | : Parent(analyzer), config_(config) { |
443 | for (auto& info : loop_info) { |
444 | // Mark let bindings to use if they are enabled on their own. |
445 | for (const auto& binding : info.let_bindings) { |
446 | if (binding.IsEnabled(config)) { |
447 | hoisted_let_bindings.insert(binding.var.get()); |
448 | } |
449 | } |
450 | |
451 | // Or if they are required by a conditional |
452 | if (config->FlagSet(HoistedLetBindings::kRequiredByCondition)) { |
453 | for (const auto& conditional : info.conditions) { |
454 | if (conditional.IsEnabled(config)) { |
455 | for (const auto& var : conditional.required_let_bindings) { |
456 | hoisted_let_bindings.insert(var); |
457 | } |
458 | } |
459 | } |
460 | } |
461 | |
462 | loop_info_lookup[info.loop_def.get()] = std::move(info); |
463 | } |
464 | } |
465 | |
466 | Stmt WrapHoistedStatements(Stmt stmt, const HoistInfoCollector::HoistInfo& info) { |
467 | for (auto cond_it = info.conditions.rbegin(); cond_it != info.conditions.rend(); cond_it++) { |
468 | if (cond_it->IsEnabled(config_)) { |
469 | if (cond_it->generate_else_case) { |
470 | stmt = IfThenElse(cond_it->condition, stmt, stmt); |
471 | } else { |
472 | stmt = IfThenElse(cond_it->condition, stmt); |
473 | } |
474 | } |
475 | } |
476 | for (auto let_it = info.let_bindings.rbegin(); let_it != info.let_bindings.rend(); let_it++) { |
477 | if (hoisted_let_bindings.count(let_it->var.get())) { |
478 | stmt = LetStmt(let_it->var, let_it->value, stmt); |
479 | } |
480 | } |
481 | |
482 | return stmt; |
483 | } |
484 | |
485 | Stmt VisitStmt_(const ForNode* op) final { |
486 | Stmt stmt = Parent::VisitStmt_(op); |
487 | |
488 | auto it = loop_info_lookup.find(op); |
489 | ICHECK(it != loop_info_lookup.end()) |
490 | << "Could not find pre-pass information for loop over " << op->loop_var; |
491 | return WrapHoistedStatements(stmt, it->second); |
492 | } |
493 | |
494 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
495 | Stmt stmt = Parent::VisitStmt_(op); |
496 | |
497 | auto it = loop_info_lookup.find(op); |
498 | if (it == loop_info_lookup.end()) { |
499 | return stmt; |
500 | } else { |
501 | return WrapHoistedStatements(stmt, it->second); |
502 | } |
503 | } |
504 | |
505 | Stmt VisitStmt_(const LetStmtNode* op) final { |
506 | if (hoisted_let_bindings.count(op->var.get())) { |
507 | return this->VisitStmt(op->body); |
508 | } else { |
509 | return Parent::VisitStmt_(op); |
510 | } |
511 | } |
512 | |
513 | PrimExpr VisitExpr_(const LetNode* op) final { |
514 | if (hoisted_let_bindings.count(op->var.get())) { |
515 | return this->VisitExpr(op->body); |
516 | } else { |
517 | return Parent::VisitExpr_(op); |
518 | } |
519 | } |
520 | |
521 | HoistExpressionConfig config_; |
522 | |
523 | std::unordered_map<const StmtNode*, HoistInfoCollector::HoistInfo> loop_info_lookup; |
524 | std::unordered_set<const VarNode*> hoisted_let_bindings; |
525 | }; |
526 | |
527 | Stmt HoistExpression(Stmt stmt, HoistExpressionConfig config) { |
528 | return ExpressionHoister::Hoist(stmt, config); |
529 | } |
530 | |
531 | namespace transform { |
532 | |
533 | Pass HoistExpression() { |
534 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
535 | auto* n = f.CopyOnWrite(); |
536 | auto cfg = ctx->GetConfig<HoistExpressionConfig>("tir.HoistExpression" ); |
537 | |
538 | if (!cfg.defined()) { |
539 | cfg = AttrsWithDefaultValues<HoistExpressionConfig>(); |
540 | } |
541 | n->body = ExpressionHoister::Hoist(std::move(n->body), cfg.value()); |
542 | return f; |
543 | }; |
544 | auto insertion_pass = CreatePrimFuncPass(pass_func, 0, "tir.InsertHoistedExpression" , {}); |
545 | |
546 | return Sequential( |
547 | { |
548 | insertion_pass, |
549 | Simplify(), |
550 | RemoveNoOp(), |
551 | }, |
552 | "tir.HoistExpression" ); |
553 | } |
554 | |
555 | TVM_REGISTER_GLOBAL("tir.transform.HoistExpression" ).set_body_typed(HoistExpression); |
556 | |
557 | Pass HoistIfThenElse() { |
558 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
559 | auto* n = f.CopyOnWrite(); |
560 | auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse" ); |
561 | |
562 | if (!cfg.defined()) { |
563 | cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>(); |
564 | } |
565 | int block_var = static_cast<int>(cfg.value()->support_block_scope_hosting |
566 | ? HoistedConditionals::kUsingBlockVar |
567 | : HoistedConditionals::kNone); |
568 | HoistExpressionConfig config(block_var | static_cast<int>(HoistedConditionals::kIfElseStmt), |
569 | static_cast<int>(HoistedLetBindings::kNone)); |
570 | n->body = ExpressionHoister::Hoist(std::move(n->body), config); |
571 | return f; |
572 | }; |
573 | auto insertion_pass = CreatePrimFuncPass(pass_func, 0, "tir.InsertHoistIfThenElse" , {}); |
574 | return Sequential( |
575 | { |
576 | insertion_pass, |
577 | Simplify(), |
578 | RemoveNoOp(), |
579 | }, |
580 | "tir.HoistIfThenElse" ); |
581 | } |
582 | |
583 | TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElse" ).set_body_typed(HoistIfThenElse); |
584 | |
585 | Pass HoistIfThenElseBasic() { |
586 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
587 | auto* n = f.CopyOnWrite(); |
588 | HoistExpressionConfig config(static_cast<int>(HoistedConditionals::kIfElseStmt), |
589 | static_cast<int>(HoistedLetBindings::kNone)); |
590 | n->body = ExpressionHoister::Hoist(std::move(n->body), config); |
591 | return f; |
592 | }; |
593 | auto insertion_pass = CreatePrimFuncPass(pass_func, 0, "tir.InsertHoistIfThenElseBasic" , {}); |
594 | return Sequential( |
595 | { |
596 | insertion_pass, |
597 | Simplify(), |
598 | RemoveNoOp(), |
599 | }, |
600 | "tir.HoistIfThenElseBasic" ); |
601 | } |
602 | |
603 | TVM_REGISTER_GLOBAL("tir.transform.HoistIfThenElseBasic" ).set_body_typed(HoistIfThenElseBasic); |
604 | |
605 | } // namespace transform |
606 | |
607 | } // namespace tir |
608 | } // namespace tvm |
609 | |