1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24/*! \brief The visitor that finds all the reduction block to be decomposed */
25struct ReductionBlockFinder : private StmtVisitor {
26 public:
27 /*! \brief Find all the reduction blocks that should be decomposed */
28 static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) {
29 std::vector<std::pair<StmtSRef, String>> results;
30 for (const auto& kv : self->mod->functions) {
31 GlobalVar g_var = kv.first;
32 BaseFunc base_func = kv.second;
33 if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
34 ReductionBlockFinder finder;
35 finder(prim_func->body);
36 for (const BlockNode* block : finder.results_) {
37 results.emplace_back(self->stmt2ref.at(block), g_var->name_hint);
38 }
39 }
40 }
41 return results;
42 }
43
44 private:
45 void VisitStmt_(const ForNode* loop) final {
46 runtime::ThreadScope thread_scope = GetThreadScope(loop);
47 if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) {
48 thread_bound_loop_vars_.insert(loop->loop_var.get());
49 }
50 StmtVisitor::VisitStmt_(loop);
51 }
52
53 void VisitStmt_(const BlockRealizeNode* realize) final {
54 if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) {
55 results_.push_back(realize->block.get());
56 }
57 StmtVisitor::VisitStmt_(realize);
58 }
59
60 bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const {
61 if (thread_bound_loop_vars_.empty()) {
62 return true;
63 }
64 auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); };
65 const BlockNode* block = realize->block.get();
66 ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
67 int n = block->iter_vars.size();
68 for (int i = 0; i < n; ++i) {
69 IterVar iter_var = block->iter_vars[i];
70 PrimExpr binding = realize->iter_values[i];
71 if (iter_var->iter_type == tir::kCommReduce) {
72 if (UsesVar(binding, f_find)) {
73 return false;
74 }
75 }
76 }
77 return true;
78 }
79
80 /*! \brief The results of the collection */
81 std::vector<const BlockNode*> results_;
82 /*! \brief Loop variables that are bound to threads */
83 std::unordered_set<const VarNode*> thread_bound_loop_vars_;
84};
85
86/*!
87 * \brief Find the innermost loop that the `init` of the input block could be decomposed to
88 * \param block_sref The StmtSRef of the block to be decomposed
89 * \return The index of the innermost loop where the `init` of the input block could be decomposed,
90 * or -1 if the `init` does not need to be decomposed.
91 */
92int FindDecomposePoint(const StmtSRef& block_sref) {
93 Array<StmtSRef> loop_srefs = GetLoops(block_sref);
94 int n = loop_srefs.size();
95 for (int i = 0; i < n; ++i) {
96 if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) {
97 return i;
98 }
99 }
100 return -1;
101}
102
103} // namespace tir
104} // namespace tvm
105
106namespace tvm {
107namespace meta_schedule {
108
109/*! \brief Rewrite reduction block by moving the init block out */
110class RewriteReductionBlockNode : public PostprocNode {
111 public:
112 // Inherited from PostprocNode
113 void InitializeWithTuneContext(const TuneContext& context) final {}
114 // Inherited from PostprocNode
115 bool Apply(const tir::Schedule& sch) final;
116
117 Postproc Clone() const {
118 ObjectPtr<RewriteReductionBlockNode> n = make_object<RewriteReductionBlockNode>(*this);
119 return Postproc(n);
120 }
121
122 void VisitAttrs(tvm::AttrVisitor* v) {}
123
124 static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock";
125 TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode);
126};
127
128bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
129 for (;;) {
130 std::vector<std::pair<tir::StmtSRef, String>> results =
131 tir::ReductionBlockFinder::Find(sch->state());
132 int rewritten = 0;
133 for (const auto& kv : results) {
134 const tir::StmtSRef& block_sref = kv.first;
135 const String& global_var_name = kv.second;
136 int decompose_point = tir::FindDecomposePoint(block_sref);
137 if (decompose_point == -1) {
138 continue;
139 }
140 tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
141 Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
142 tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);
143
144 // Rewrite auto tensorization related annotations
145 if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) {
146 // Remove tensorization annotation as it shouldn't be propagated to the init block.
147 sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize);
148 Optional<String> tensorize_init =
149 tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init);
150 // The annotation of tensorization of the init statement should be moved to the init block
151 // after 'DecomposeReduction'.
152 // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is NullOpt.
153 sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize,
154 tensorize_init.value_or(""));
155 if (tensorize_init.defined()) {
156 sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init);
157 sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init);
158 }
159 }
160 ++rewritten;
161 }
162 if (rewritten == 0) {
163 break;
164 }
165 }
166 return true;
167}
168
169Postproc Postproc::RewriteReductionBlock() {
170 ObjectPtr<RewriteReductionBlockNode> n = make_object<RewriteReductionBlockNode>();
171 return Postproc(n);
172}
173
// Register the node type, and register the factory `Postproc::RewriteReductionBlock` as the
// global function "meta_schedule.PostprocRewriteReductionBlock".
TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock")
    .set_body_typed(Postproc::RewriteReductionBlock);
177
178} // namespace meta_schedule
179} // namespace tvm
180