1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | /*! \brief Collecting all the blocks */ |
25 | class BlockCollector : public tir::StmtVisitor { |
26 | public: |
27 | static Array<tir::BlockRV> Collect(const tir::Schedule& sch, |
28 | const runtime::PackedFunc f_block_filter = nullptr) { // |
29 | return BlockCollector(sch, f_block_filter).Run(); |
30 | } |
31 | |
32 | private: |
33 | /*! \brief Entry point */ |
34 | Array<tir::BlockRV> Run() { |
35 | std::vector<tir::BlockRV> results; |
36 | for (const auto& kv : sch_->mod()->functions) { |
37 | const GlobalVar& gv = kv.first; // `gv->name_hint` is the name of the function |
38 | const BaseFunc& base_func = kv.second; // this can be PrimFunc or relay::Function |
39 | if (const auto* func = base_func.as<tir::PrimFuncNode>()) { |
40 | func_name_ = gv->name_hint; |
41 | block_names_.clear(); |
42 | blocks_to_collect_.clear(); |
43 | VisitStmt(func->body); |
44 | for (const String& name : blocks_to_collect_) { |
45 | results.push_back(sch_->GetBlock(name, func_name_)); |
46 | } |
47 | } |
48 | } |
49 | return results; |
50 | } |
51 | /*! \brief Constructor */ |
52 | explicit BlockCollector(const tir::Schedule& sch, |
53 | const runtime::PackedFunc f_block_filter = nullptr) |
54 | : sch_(sch), f_block_filter_(f_block_filter) {} |
55 | /*! \brief Override the Stmt visiting behaviour */ |
56 | void VisitStmt_(const tir::BlockNode* block) override { |
57 | tir::StmtVisitor::VisitStmt_(block); |
58 | CHECK(block_names_.count(block->name_hint) == 0) |
59 | << "Duplicated block name " << block->name_hint << " in function " << func_name_ |
60 | << " not supported!" ; |
61 | block_names_.insert(block->name_hint); |
62 | |
63 | // If filter function is provided, use it to selectively collect blocks. |
64 | // Otherwise collect all blocks. |
65 | Bool collect_block = Bool(true); |
66 | if (f_block_filter_ != nullptr) { |
67 | collect_block = f_block_filter_(GetRef<tir::Block>(block)); |
68 | } |
69 | if (collect_block) { |
70 | blocks_to_collect_.push_back(block->name_hint); |
71 | } |
72 | } |
73 | |
74 | /*! \brief The schedule to be collected */ |
75 | const tir::Schedule& sch_; |
76 | /*! \brief An optional packed func that allows only certain blocks to be collected. */ |
77 | const runtime::PackedFunc f_block_filter_; |
78 | /*! \brief The set of func name and block name pair */ |
79 | std::unordered_set<String> block_names_; |
80 | /* \brief The list of blocks to collect in order */ |
81 | Array<String> blocks_to_collect_; |
82 | /*! \brief Name of the current PrimFunc */ |
83 | String func_name_; |
84 | }; |
85 | |
86 | /*! |
87 | * \brief Design Space Generator that generates design spaces by applying schedule rules to blocks |
88 | * in post-DFS order. |
89 | * */ |
90 | class PostOrderApplyNode : public SpaceGeneratorNode { |
91 | public: |
92 | /*! |
93 | * \brief Optional block names to target. If not specified all blocks will have spaces generated. |
94 | */ |
95 | runtime::PackedFunc f_block_filter_ = nullptr; |
96 | /*! \brief The random state. -1 means using random number. */ |
97 | TRandState rand_state_ = -1; |
98 | |
99 | void VisitAttrs(tvm::AttrVisitor* v) { |
100 | SpaceGeneratorNode::VisitAttrs(v); |
101 | // `rand_state_` is not visited |
102 | // `sch_rules_` is not visited |
103 | } |
104 | |
105 | void InitializeWithTuneContext(const TuneContext& context) final { |
106 | SpaceGeneratorNode::InitializeWithTuneContext(context); |
107 | this->rand_state_ = ForkSeed(&context->rand_state); |
108 | } |
109 | |
110 | Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final { |
111 | using ScheduleAndUnvisitedBlocks = std::pair<tir::Schedule, Array<tir::BlockRV>>; |
112 | CHECK(sch_rules.defined()) << "ValueError: `sch_rules` is not set in PostOrderApply" ; |
113 | tir::Schedule sch = tir::Schedule::Traced( |
114 | /*mod=*/mod, |
115 | /*rand_state=*/ForkSeed(&this->rand_state_), |
116 | /*debug_mode=*/0, |
117 | /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); |
118 | |
119 | std::vector<ScheduleAndUnvisitedBlocks> stack; |
120 | Array<tir::Schedule> result{sch}; |
121 | Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch, f_block_filter_); |
122 | |
123 | for (ScheduleRule sch_rule : sch_rules.value()) { |
124 | for (const tir::Schedule& sch : result) { |
125 | stack.emplace_back(sch, all_blocks); |
126 | } |
127 | result.clear(); |
128 | while (!stack.empty()) { |
129 | // get the stack.top() |
130 | auto [sch, blocks] = stack.back(); |
131 | stack.pop_back(); |
132 | // if all blocks are visited |
133 | if (blocks.empty()) { |
134 | result.push_back(sch); |
135 | continue; |
136 | } |
137 | // otherwise, get the last block that is not visited |
138 | tir::BlockRV block_rv = blocks.back(); |
139 | blocks.pop_back(); |
140 | if (!sch->HasBlock(block_rv)) { |
141 | stack.emplace_back(sch, blocks); |
142 | continue; |
143 | } |
144 | if (!ScheduleRule::IsApplyCustomRule(sch_rule)) { |
145 | if (tir::GetAnn<String>(sch->GetSRef(block_rv), "schedule_rule" ).defined()) { |
146 | stack.emplace_back(sch, blocks); |
147 | continue; |
148 | } |
149 | } |
150 | Array<tir::Schedule> applied = sch_rule->Apply(sch, /*block=*/block_rv); |
151 | for (const tir::Schedule& sch : applied) { |
152 | stack.emplace_back(sch, blocks); |
153 | } |
154 | } |
155 | } |
156 | return result; |
157 | } |
158 | |
159 | SpaceGenerator Clone() const final { |
160 | ObjectPtr<PostOrderApplyNode> n = make_object<PostOrderApplyNode>(*this); |
161 | CloneRules(this, n.get()); |
162 | return SpaceGenerator(n); |
163 | } |
164 | static constexpr const char* _type_key = "meta_schedule.PostOrderApply" ; |
165 | TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); |
166 | }; |
167 | |
168 | SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter, |
169 | Optional<Array<ScheduleRule>> sch_rules, |
170 | Optional<Array<Postproc>> postprocs, |
171 | Optional<Map<Mutator, FloatImm>> mutator_probs) { |
172 | ObjectPtr<PostOrderApplyNode> n = make_object<PostOrderApplyNode>(); |
173 | n->sch_rules = std::move(sch_rules); |
174 | n->postprocs = std::move(postprocs); |
175 | n->mutator_probs = std::move(mutator_probs); |
176 | n->f_block_filter_ = std::move(f_block_filter); |
177 | return SpaceGenerator(n); |
178 | } |
179 | |
180 | TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); |
181 | TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply" ) |
182 | .set_body_typed(SpaceGenerator::PostOrderApply); |
183 | |
184 | } // namespace meta_schedule |
185 | } // namespace tvm |
186 | |