1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file convert_block_to_opaque.cc
22 * \brief Convert the blocks to opaque blocks which do not have block vars.
23 */
24
25#include <tvm/tir/stmt_functor.h>
26#include <tvm/tir/transform.h>
27
28#include "ir_utils.h"
29
30namespace tvm {
31namespace tir {
32
33/*!
34 * \brief Substitute expr via BlockRealize value bindings and convert each block into opaque
35 * blocks.
36 */
37class OpaqueBlockConverter : public StmtExprMutator {
38 public:
39 static Stmt Substitute(const PrimFunc& f) {
40 OpaqueBlockConverter substituter;
41 return substituter.VisitStmt(f->body);
42 }
43
44 private:
45 OpaqueBlockConverter() = default;
46
47 PrimExpr VisitExpr_(const VarNode* var) final {
48 CHECK(!forbidden_iter_vars_.count(var))
49 << "Variable " << var->name_hint << " occurs in the predicate or iter_values of a block, "
50 << "but isn't defined until the body of the block";
51
52 auto it = var_substitutes_.find(var);
53 if (it != var_substitutes_.end()) {
54 return it->second;
55 }
56 return GetRef<Var>(var);
57 }
58
59 Stmt VisitStmt_(const BlockNode* block) final {
60 ICHECK(!block->init.defined())
61 << "Block Init part is not allowed in pass ConvertBlocksToOpaque";
62 Block new_block = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
63 if (!new_block->iter_vars.empty()) {
64 new_block.CopyOnWrite()->iter_vars.clear();
65 }
66 return std::move(new_block);
67 }
68
69 Stmt VisitStmt_(const BlockRealizeNode* realize) final {
70 const auto* block_op = realize->block.get();
71 ICHECK(!block_op->init.defined());
72
73 // Step 1. Visit the predicate and iter_values, without any variable bindings
74 for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.insert(iter->var.get());
75 PrimExpr predicate = VisitExpr(realize->predicate);
76 Array<PrimExpr> iter_values = realize->iter_values;
77 iter_values.MutateByApply([this](PrimExpr expr) { return VisitExpr(std::move(expr)); });
78 for (const auto& iter : block_op->iter_vars) forbidden_iter_vars_.erase(iter->var.get());
79
80 // Step 2. Update "block vars => binding values" for substitution.
81 ICHECK_EQ(block_op->iter_vars.size(), iter_values.size());
82 for (int i = 0, n = block_op->iter_vars.size(); i < n; ++i) {
83 IterVar block_var = block_op->iter_vars[i];
84 PrimExpr v = this->VisitExpr(iter_values[i]);
85 var_substitutes_.emplace(block_var->var.get(), v);
86 }
87 // Step 3. Visit recursively.
88 Block new_block = Downcast<Block>(VisitStmt(realize->block));
89
90 // Step 4. Clear the variable bindings
91 for (const auto& block_var : block_op->iter_vars) {
92 var_substitutes_.erase(block_var->var.get());
93 }
94
95 // Step 5. Return
96 if (predicate.same_as(realize->predicate) && iter_values.same_as(realize->iter_values) &&
97 new_block.same_as(realize->block) && realize->iter_values.size() == 0) {
98 return GetRef<BlockRealize>(realize);
99 } else {
100 return BlockRealize({}, predicate, new_block);
101 }
102 }
103
104 /*! \brief The map from block vars to their binding values. */
105 std::unordered_map<const VarNode*, PrimExpr> var_substitutes_;
106 /* \brief Variables that may not occur in the current context */
107 std::unordered_set<const VarNode*> forbidden_iter_vars_;
108};
109
110PrimFunc ConvertBlocksToOpaque(PrimFunc f) {
111 // Only apply this pass to TIR that is not from TE schedules
112 if (!IsFromLegacyTESchedule(f)) {
113 PrimFuncNode* fptr = f.CopyOnWrite();
114 fptr->body = OpaqueBlockConverter::Substitute(f);
115 return f;
116 } else {
117 return f;
118 }
119}
120
121namespace transform {
122
123Pass ConvertBlocksToOpaque() {
124 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
125 return ConvertBlocksToOpaque(std::move(f));
126 };
127 return CreatePrimFuncPass(pass_func, 0, "tir.ConvertBlocksToOpaque", {});
128}
129
// Register the pass constructor under the global name "tir.transform.ConvertBlocksToOpaque".
TVM_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque").set_body_typed(ConvertBlocksToOpaque);
131} // namespace transform
132
133} // namespace tir
134} // namespace tvm
135