1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /*! \brief The visitor that finds all the reduction block to be decomposed */ |
25 | struct ReductionBlockFinder : private StmtVisitor { |
26 | public: |
27 | /*! \brief Find all the reduction blocks that should be decomposed */ |
28 | static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) { |
29 | std::vector<std::pair<StmtSRef, String>> results; |
30 | for (const auto& kv : self->mod->functions) { |
31 | GlobalVar g_var = kv.first; |
32 | BaseFunc base_func = kv.second; |
33 | if (const auto* prim_func = base_func.as<PrimFuncNode>()) { |
34 | ReductionBlockFinder finder; |
35 | finder(prim_func->body); |
36 | for (const BlockNode* block : finder.results_) { |
37 | results.emplace_back(self->stmt2ref.at(block), g_var->name_hint); |
38 | } |
39 | } |
40 | } |
41 | return results; |
42 | } |
43 | |
44 | private: |
45 | void VisitStmt_(const ForNode* loop) final { |
46 | runtime::ThreadScope thread_scope = GetThreadScope(loop); |
47 | if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) { |
48 | thread_bound_loop_vars_.insert(loop->loop_var.get()); |
49 | } |
50 | StmtVisitor::VisitStmt_(loop); |
51 | } |
52 | |
53 | void VisitStmt_(const BlockRealizeNode* realize) final { |
54 | if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) { |
55 | results_.push_back(realize->block.get()); |
56 | } |
57 | StmtVisitor::VisitStmt_(realize); |
58 | } |
59 | |
60 | bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const { |
61 | if (thread_bound_loop_vars_.empty()) { |
62 | return true; |
63 | } |
64 | auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); }; |
65 | const BlockNode* block = realize->block.get(); |
66 | ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); |
67 | int n = block->iter_vars.size(); |
68 | for (int i = 0; i < n; ++i) { |
69 | IterVar iter_var = block->iter_vars[i]; |
70 | PrimExpr binding = realize->iter_values[i]; |
71 | if (iter_var->iter_type == tir::kCommReduce) { |
72 | if (UsesVar(binding, f_find)) { |
73 | return false; |
74 | } |
75 | } |
76 | } |
77 | return true; |
78 | } |
79 | |
80 | /*! \brief The results of the collection */ |
81 | std::vector<const BlockNode*> results_; |
82 | /*! \brief Loop variables that are bound to threads */ |
83 | std::unordered_set<const VarNode*> thread_bound_loop_vars_; |
84 | }; |
85 | |
86 | /*! |
87 | * \brief Find the innermost loop that the `init` of the input block could be decomposed to |
88 | * \param block_sref The StmtSRef of the block to be decomposed |
89 | * \return The index of the innermost loop where the `init` of the input block could be decomposed, |
90 | * or -1 if the `init` does not need to be decomposed. |
91 | */ |
92 | int FindDecomposePoint(const StmtSRef& block_sref) { |
93 | Array<StmtSRef> loop_srefs = GetLoops(block_sref); |
94 | int n = loop_srefs.size(); |
95 | for (int i = 0; i < n; ++i) { |
96 | if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) { |
97 | return i; |
98 | } |
99 | } |
100 | return -1; |
101 | } |
102 | |
103 | } // namespace tir |
104 | } // namespace tvm |
105 | |
106 | namespace tvm { |
107 | namespace meta_schedule { |
108 | |
109 | /*! \brief Rewrite reduction block by moving the init block out */ |
110 | class RewriteReductionBlockNode : public PostprocNode { |
111 | public: |
112 | // Inherited from PostprocNode |
113 | void InitializeWithTuneContext(const TuneContext& context) final {} |
114 | // Inherited from PostprocNode |
115 | bool Apply(const tir::Schedule& sch) final; |
116 | |
117 | Postproc Clone() const { |
118 | ObjectPtr<RewriteReductionBlockNode> n = make_object<RewriteReductionBlockNode>(*this); |
119 | return Postproc(n); |
120 | } |
121 | |
122 | void VisitAttrs(tvm::AttrVisitor* v) {} |
123 | |
124 | static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock" ; |
125 | TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode); |
126 | }; |
127 | |
128 | bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { |
129 | for (;;) { |
130 | std::vector<std::pair<tir::StmtSRef, String>> results = |
131 | tir::ReductionBlockFinder::Find(sch->state()); |
132 | int rewritten = 0; |
133 | for (const auto& kv : results) { |
134 | const tir::StmtSRef& block_sref = kv.first; |
135 | const String& global_var_name = kv.second; |
136 | int decompose_point = tir::FindDecomposePoint(block_sref); |
137 | if (decompose_point == -1) { |
138 | continue; |
139 | } |
140 | tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); |
141 | Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv); |
142 | tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); |
143 | |
144 | // Rewrite auto tensorization related annotations |
145 | if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) { |
146 | // Remove tensorization annotation as it shouldn't be propagated to the init block. |
147 | sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize); |
148 | Optional<String> tensorize_init = |
149 | tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init); |
150 | // The annotation of tensorization of the init statement should be moved to the init block |
151 | // after 'DecomposeReduction'. |
152 | // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is NullOpt. |
153 | sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize, |
154 | tensorize_init.value_or("" )); |
155 | if (tensorize_init.defined()) { |
156 | sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init); |
157 | sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init); |
158 | } |
159 | } |
160 | ++rewritten; |
161 | } |
162 | if (rewritten == 0) { |
163 | break; |
164 | } |
165 | } |
166 | return true; |
167 | } |
168 | |
169 | Postproc Postproc::RewriteReductionBlock() { |
170 | ObjectPtr<RewriteReductionBlockNode> n = make_object<RewriteReductionBlockNode>(); |
171 | return Postproc(n); |
172 | } |
173 | |
// Register the node type, and expose the `Postproc::RewriteReductionBlock` factory under the
// global function name "meta_schedule.PostprocRewriteReductionBlock".
TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock")
    .set_body_typed(Postproc::RewriteReductionBlock);
177 | |
178 | } // namespace meta_schedule |
179 | } // namespace tvm |
180 | |