1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | class RandomComputeLocationNode : public ScheduleRuleNode { |
25 | public: |
26 | // Inherited from ScheduleRuleNode |
27 | void InitializeWithTuneContext(const TuneContext& context) final {} |
28 | |
29 | // Inherited from ScheduleRuleNode |
30 | Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { |
31 | if (!CheckConditions(sch, block_rv)) { |
32 | return {sch}; |
33 | } |
34 | |
35 | // Step 1. If the producer of the input block needs a random compute-at location (specified by |
36 | // the annotation), we collect the producer first, and transform the producer block later. |
37 | // - The reason we collect the producer before transforming the input block is that, if the |
38 | // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer |
39 | // access the input block. Hence we collect its producer ahead of time. |
40 | // - Note that only single producer is allowed in this case. |
41 | Array<tir::BlockRV> producers{nullptr}; |
42 | if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer, |
43 | true)) { |
44 | producers = sch->GetProducers(block_rv); |
45 | sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer); |
46 | ICHECK_EQ(producers.size(), 1); |
47 | } |
48 | |
49 | // Step 2. Transform the input block. |
50 | tir::Schedule res = RandomlyComputeAt(sch, block_rv); |
51 | |
52 | // Step 3. Transform the producer block if compute-location sampling is needed. |
53 | if (producers.defined()) { |
54 | res = RandomlyComputeAt(res, producers[0]); |
55 | } |
56 | |
57 | return {res}; |
58 | } |
59 | |
60 | // Inherited from ScheduleRuleNode |
61 | ScheduleRule Clone() const final { |
62 | ObjectPtr<RandomComputeLocationNode> n = make_object<RandomComputeLocationNode>(*this); |
63 | return ScheduleRule(n); |
64 | } |
65 | |
66 | private: |
67 | bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { |
68 | tir::StmtSRef block_sref = sch->GetSRef(block_rv); |
69 | TVM_SREF_TO_BLOCK(block_sref); |
70 | |
71 | // Cond 1. The block is not the root block. |
72 | if (block_sref->parent == nullptr) { |
73 | return false; |
74 | } |
75 | // Cond 2. The block should be the direct child block of the root block. |
76 | if (GetScopeRoot(sch->state(), block_sref, |
77 | /*require_stage_pipeline=*/false) |
78 | ->parent != nullptr) { |
79 | return false; |
80 | } |
81 | // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child |
82 | // block. |
83 | Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref); |
84 | if (loop_srefs.empty()) { |
85 | return false; |
86 | } |
87 | if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) { |
88 | return false; |
89 | } |
90 | // Cond 5. The block is not tiled. We check this condition by examine the block's annotation. |
91 | if (tir::HasBeenMultiLevelTiled(block_sref)) { |
92 | return false; |
93 | } |
94 | // Cond 6. The block has at lease one consumer. |
95 | if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) { |
96 | return false; |
97 | } |
98 | return true; |
99 | } |
100 | |
101 | /*! |
102 | * \brief Keep sampling a compute-at location for the input block until success. |
103 | * \param sch The TIR schedule |
104 | * \param block_rv The block whose compute-at location is to be sampled |
105 | * \return The TIR schedule after transformation |
106 | */ |
107 | tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) { |
108 | tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv); |
109 | sch->ComputeAt(block_rv, compute_at_loc, true); |
110 | return sch; |
111 | } |
112 | |
113 | public: |
114 | void VisitAttrs(tvm::AttrVisitor* v) {} |
115 | |
116 | static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation" ; |
117 | TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode); |
118 | }; |
119 | |
120 | ScheduleRule ScheduleRule::RandomComputeLocation() { |
121 | return ScheduleRule(make_object<RandomComputeLocationNode>()); |
122 | } |
123 | |
// Register the node type declared above (type key "meta_schedule.RandomComputeLocation").
TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode);
// Register ScheduleRule::RandomComputeLocation as a global function under the name
// "meta_schedule.ScheduleRuleRandomComputeLocation".
// NOTE(review): presumably this is how non-C++ frontends construct the rule; confirm the caller.
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation" )
    .set_body_typed(ScheduleRule::RandomComputeLocation);
127 | } // namespace meta_schedule |
128 | } // namespace tvm |
129 | |