1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file convert_block_to_opaque.cc |
22 | * \brief Convert the blocks to opaque blocks which do not have block vars. |
23 | */ |
24 | |
25 | #include <tvm/tir/stmt_functor.h> |
26 | #include <tvm/tir/transform.h> |
27 | |
28 | #include "ir_utils.h" |
29 | |
30 | namespace tvm { |
31 | namespace tir { |
32 | |
33 | /*! |
34 | * \brief Substitute expr via BlockRealize value bindings and convert each block into opaque |
35 | * blocks. |
36 | */ |
37 | class OpaqueBlockConverter : public StmtExprMutator { |
38 | public: |
39 | static Stmt Substitute(const PrimFunc& f) { |
40 | OpaqueBlockConverter substituter; |
41 | return substituter.VisitStmt(f->body); |
42 | } |
43 | |
44 | private: |
45 | OpaqueBlockConverter() = default; |
46 | |
47 | PrimExpr VisitExpr_(const VarNode* var) final { |
48 | CHECK(!forbidden_iter_vars_.count(var)) |
49 | << "Variable " << var->name_hint << " occurs in the predicate or iter_values of a block, " |
50 | << "but isn't defined until the body of the block" ; |
51 | |
52 | auto it = var_substitutes_.find(var); |
53 | if (it != var_substitutes_.end()) { |
54 | return it->second; |
55 | } |
56 | return GetRef<Var>(var); |
57 | } |
58 | |
59 | Stmt VisitStmt_(const BlockNode* block) final { |
60 | ICHECK(!block->init.defined()) |
61 | << "Block Init part is not allowed in pass ConvertBlocksToOpaque" ; |
62 | Block new_block = Downcast<Block>(StmtExprMutator::VisitStmt_(block)); |
63 | if (!new_block->iter_vars.empty()) { |
64 | new_block.CopyOnWrite()->iter_vars.clear(); |
65 | } |
66 | return std::move(new_block); |
67 | } |
68 | |
69 | Stmt VisitStmt_(const BlockRealizeNode* realize) final { |
70 | const auto* block_op = realize->block.get(); |
71 | ICHECK(!block_op->init.defined()); |
72 | |
73 | // Step 1. Visit the predicate and iter_values, without any variable bindings |
74 | for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.insert(iter->var.get()); |
75 | PrimExpr predicate = VisitExpr(realize->predicate); |
76 | Array<PrimExpr> iter_values = realize->iter_values; |
77 | iter_values.MutateByApply([this](PrimExpr expr) { return VisitExpr(std::move(expr)); }); |
78 | for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.erase(iter->var.get()); |
79 | |
80 | // Step 2. Update "block vars => binding values" for substitution. |
81 | ICHECK_EQ(block_op->iter_vars.size(), iter_values.size()); |
82 | for (int i = 0, n = block_op->iter_vars.size(); i < n; ++i) { |
83 | IterVar block_var = block_op->iter_vars[i]; |
84 | PrimExpr v = this->VisitExpr(iter_values[i]); |
85 | var_substitutes_.emplace(block_var->var.get(), v); |
86 | } |
87 | // Step 3. Visit recursively. |
88 | Block new_block = Downcast<Block>(VisitStmt(realize->block)); |
89 | |
90 | // Step 4. Clear the variable bindings |
91 | for (const auto& block_var : block_op->iter_vars) { |
92 | var_substitutes_.erase(block_var->var.get()); |
93 | } |
94 | |
95 | // Step 5. Return |
96 | if (predicate.same_as(realize->predicate) && iter_values.same_as(realize->iter_values) && |
97 | new_block.same_as(realize->block) && realize->iter_values.size() == 0) { |
98 | return GetRef<BlockRealize>(realize); |
99 | } else { |
100 | return BlockRealize({}, predicate, new_block); |
101 | } |
102 | } |
103 | |
104 | /*! \brief The map from block vars to their binding values. */ |
105 | std::unordered_map<const VarNode*, PrimExpr> var_substitutes_; |
106 | /* \brief Variables that may not occur in the current context */ |
107 | std::unordered_set<const VarNode*> forbidden_iter_vars_; |
108 | }; |
109 | |
110 | PrimFunc ConvertBlocksToOpaque(PrimFunc f) { |
111 | // Only apply this pass to TIR that is not from TE schedules |
112 | if (!IsFromLegacyTESchedule(f)) { |
113 | PrimFuncNode* fptr = f.CopyOnWrite(); |
114 | fptr->body = OpaqueBlockConverter::Substitute(f); |
115 | return f; |
116 | } else { |
117 | return f; |
118 | } |
119 | } |
120 | |
121 | namespace transform { |
122 | |
123 | Pass ConvertBlocksToOpaque() { |
124 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
125 | return ConvertBlocksToOpaque(std::move(f)); |
126 | }; |
127 | return CreatePrimFuncPass(pass_func, 0, "tir.ConvertBlocksToOpaque" , {}); |
128 | } |
129 | |
130 | TVM_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque" ).set_body_typed(ConvertBlocksToOpaque); |
131 | } // namespace transform |
132 | |
133 | } // namespace tir |
134 | } // namespace tvm |
135 | |