1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24/*! \brief Collecting all the blocks */
25class BlockCollector : public tir::StmtVisitor {
26 public:
27 static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
28 const runtime::PackedFunc f_block_filter = nullptr) { //
29 return BlockCollector(sch, f_block_filter).Run();
30 }
31
32 private:
33 /*! \brief Entry point */
34 Array<tir::BlockRV> Run() {
35 std::vector<tir::BlockRV> results;
36 for (const auto& kv : sch_->mod()->functions) {
37 const GlobalVar& gv = kv.first; // `gv->name_hint` is the name of the function
38 const BaseFunc& base_func = kv.second; // this can be PrimFunc or relay::Function
39 if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
40 func_name_ = gv->name_hint;
41 block_names_.clear();
42 blocks_to_collect_.clear();
43 VisitStmt(func->body);
44 for (const String& name : blocks_to_collect_) {
45 results.push_back(sch_->GetBlock(name, func_name_));
46 }
47 }
48 }
49 return results;
50 }
51 /*! \brief Constructor */
52 explicit BlockCollector(const tir::Schedule& sch,
53 const runtime::PackedFunc f_block_filter = nullptr)
54 : sch_(sch), f_block_filter_(f_block_filter) {}
55 /*! \brief Override the Stmt visiting behaviour */
56 void VisitStmt_(const tir::BlockNode* block) override {
57 tir::StmtVisitor::VisitStmt_(block);
58 CHECK(block_names_.count(block->name_hint) == 0)
59 << "Duplicated block name " << block->name_hint << " in function " << func_name_
60 << " not supported!";
61 block_names_.insert(block->name_hint);
62
63 // If filter function is provided, use it to selectively collect blocks.
64 // Otherwise collect all blocks.
65 Bool collect_block = Bool(true);
66 if (f_block_filter_ != nullptr) {
67 collect_block = f_block_filter_(GetRef<tir::Block>(block));
68 }
69 if (collect_block) {
70 blocks_to_collect_.push_back(block->name_hint);
71 }
72 }
73
74 /*! \brief The schedule to be collected */
75 const tir::Schedule& sch_;
76 /*! \brief An optional packed func that allows only certain blocks to be collected. */
77 const runtime::PackedFunc f_block_filter_;
78 /*! \brief The set of func name and block name pair */
79 std::unordered_set<String> block_names_;
80 /* \brief The list of blocks to collect in order */
81 Array<String> blocks_to_collect_;
82 /*! \brief Name of the current PrimFunc */
83 String func_name_;
84};
85
86/*!
87 * \brief Design Space Generator that generates design spaces by applying schedule rules to blocks
88 * in post-DFS order.
89 * */
90class PostOrderApplyNode : public SpaceGeneratorNode {
91 public:
92 /*!
93 * \brief Optional block names to target. If not specified all blocks will have spaces generated.
94 */
95 runtime::PackedFunc f_block_filter_ = nullptr;
96 /*! \brief The random state. -1 means using random number. */
97 TRandState rand_state_ = -1;
98
99 void VisitAttrs(tvm::AttrVisitor* v) {
100 SpaceGeneratorNode::VisitAttrs(v);
101 // `rand_state_` is not visited
102 // `sch_rules_` is not visited
103 }
104
105 void InitializeWithTuneContext(const TuneContext& context) final {
106 SpaceGeneratorNode::InitializeWithTuneContext(context);
107 this->rand_state_ = ForkSeed(&context->rand_state);
108 }
109
110 Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
111 using ScheduleAndUnvisitedBlocks = std::pair<tir::Schedule, Array<tir::BlockRV>>;
112 CHECK(sch_rules.defined()) << "ValueError: `sch_rules` is not set in PostOrderApply";
113 tir::Schedule sch = tir::Schedule::Traced(
114 /*mod=*/mod,
115 /*rand_state=*/ForkSeed(&this->rand_state_),
116 /*debug_mode=*/0,
117 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail);
118
119 std::vector<ScheduleAndUnvisitedBlocks> stack;
120 Array<tir::Schedule> result{sch};
121 Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch, f_block_filter_);
122
123 for (ScheduleRule sch_rule : sch_rules.value()) {
124 for (const tir::Schedule& sch : result) {
125 stack.emplace_back(sch, all_blocks);
126 }
127 result.clear();
128 while (!stack.empty()) {
129 // get the stack.top()
130 auto [sch, blocks] = stack.back();
131 stack.pop_back();
132 // if all blocks are visited
133 if (blocks.empty()) {
134 result.push_back(sch);
135 continue;
136 }
137 // otherwise, get the last block that is not visited
138 tir::BlockRV block_rv = blocks.back();
139 blocks.pop_back();
140 if (!sch->HasBlock(block_rv)) {
141 stack.emplace_back(sch, blocks);
142 continue;
143 }
144 if (!ScheduleRule::IsApplyCustomRule(sch_rule)) {
145 if (tir::GetAnn<String>(sch->GetSRef(block_rv), "schedule_rule").defined()) {
146 stack.emplace_back(sch, blocks);
147 continue;
148 }
149 }
150 Array<tir::Schedule> applied = sch_rule->Apply(sch, /*block=*/block_rv);
151 for (const tir::Schedule& sch : applied) {
152 stack.emplace_back(sch, blocks);
153 }
154 }
155 }
156 return result;
157 }
158
159 SpaceGenerator Clone() const final {
160 ObjectPtr<PostOrderApplyNode> n = make_object<PostOrderApplyNode>(*this);
161 CloneRules(this, n.get());
162 return SpaceGenerator(n);
163 }
164 static constexpr const char* _type_key = "meta_schedule.PostOrderApply";
165 TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode);
166};
167
168SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter,
169 Optional<Array<ScheduleRule>> sch_rules,
170 Optional<Array<Postproc>> postprocs,
171 Optional<Map<Mutator, FloatImm>> mutator_probs) {
172 ObjectPtr<PostOrderApplyNode> n = make_object<PostOrderApplyNode>();
173 n->sch_rules = std::move(sch_rules);
174 n->postprocs = std::move(postprocs);
175 n->mutator_probs = std::move(mutator_probs);
176 n->f_block_filter_ = std::move(f_block_filter);
177 return SpaceGenerator(n);
178}
179
// Register PostOrderApplyNode as a TVM node type, and expose the factory
// SpaceGenerator::PostOrderApply as the global function
// "meta_schedule.SpaceGeneratorPostOrderApply".
TVM_REGISTER_NODE_TYPE(PostOrderApplyNode);
TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply")
    .set_body_typed(SpaceGenerator::PostOrderApply);
183
184} // namespace meta_schedule
185} // namespace tvm
186