1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace meta_schedule {
23
24class RandomComputeLocationNode : public ScheduleRuleNode {
25 public:
26 // Inherited from ScheduleRuleNode
27 void InitializeWithTuneContext(const TuneContext& context) final {}
28
29 // Inherited from ScheduleRuleNode
30 Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
31 if (!CheckConditions(sch, block_rv)) {
32 return {sch};
33 }
34
35 // Step 1. If the producer of the input block needs a random compute-at location (specified by
36 // the annotation), we collect the producer first, and transform the producer block later.
37 // - The reason we collect the producer before transforming the input block is that, if the
38 // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer
39 // access the input block. Hence we collect its producer ahead of time.
40 // - Note that only single producer is allowed in this case.
41 Array<tir::BlockRV> producers{nullptr};
42 if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer,
43 true)) {
44 producers = sch->GetProducers(block_rv);
45 sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer);
46 ICHECK_EQ(producers.size(), 1);
47 }
48
49 // Step 2. Transform the input block.
50 tir::Schedule res = RandomlyComputeAt(sch, block_rv);
51
52 // Step 3. Transform the producer block if compute-location sampling is needed.
53 if (producers.defined()) {
54 res = RandomlyComputeAt(res, producers[0]);
55 }
56
57 return {res};
58 }
59
60 // Inherited from ScheduleRuleNode
61 ScheduleRule Clone() const final {
62 ObjectPtr<RandomComputeLocationNode> n = make_object<RandomComputeLocationNode>(*this);
63 return ScheduleRule(n);
64 }
65
66 private:
67 bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
68 tir::StmtSRef block_sref = sch->GetSRef(block_rv);
69 TVM_SREF_TO_BLOCK(block_sref);
70
71 // Cond 1. The block is not the root block.
72 if (block_sref->parent == nullptr) {
73 return false;
74 }
75 // Cond 2. The block should be the direct child block of the root block.
76 if (GetScopeRoot(sch->state(), block_sref,
77 /*require_stage_pipeline=*/false)
78 ->parent != nullptr) {
79 return false;
80 }
81 // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child
82 // block.
83 Array<tir::StmtSRef> loop_srefs = tir::GetLoops(block_sref);
84 if (loop_srefs.empty()) {
85 return false;
86 }
87 if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) {
88 return false;
89 }
90 // Cond 5. The block is not tiled. We check this condition by examine the block's annotation.
91 if (tir::HasBeenMultiLevelTiled(block_sref)) {
92 return false;
93 }
94 // Cond 6. The block has at lease one consumer.
95 if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) {
96 return false;
97 }
98 return true;
99 }
100
101 /*!
102 * \brief Keep sampling a compute-at location for the input block until success.
103 * \param sch The TIR schedule
104 * \param block_rv The block whose compute-at location is to be sampled
105 * \return The TIR schedule after transformation
106 */
107 tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
108 tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv);
109 sch->ComputeAt(block_rv, compute_at_loc, true);
110 return sch;
111 }
112
113 public:
114 void VisitAttrs(tvm::AttrVisitor* v) {}
115
116 static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation";
117 TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode);
118};
119
120ScheduleRule ScheduleRule::RandomComputeLocation() {
121 return ScheduleRule(make_object<RandomComputeLocationNode>());
122}
123
// Register the node type, then expose the factory ScheduleRule::RandomComputeLocation as a
// global function under the name "meta_schedule.ScheduleRuleRandomComputeLocation".
TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation")
    .set_body_typed(ScheduleRule::RandomComputeLocation);
127} // namespace meta_schedule
128} // namespace tvm
129