/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

/*! \brief The type of inline to be performed on a specific block */
enum class InlineType : int32_t {
  /*! \brief No inline opportunity */
  kNoInline = 0,
  /*! \brief Inline the block into its consumer */
  kInlineIntoConsumer = 1,
  /*! \brief Inline the block into its producer */
  kInlineIntoProducer = 2,
};
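// Note: kInlineIntoConsumer is realized via Schedule::ComputeInline, and
// kInlineIntoProducer via Schedule::ReverseComputeInline (see AutoInlineNode::Apply below).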

/*! \brief Check if the given block lives in a PrimFunc whose blocks are all spatial */
bool IsInSpatialPrimFunc(const tir::Schedule& sch, const tir::StmtSRef& block_sref) {
  using namespace tvm::tir;
  const StmtSRefNode* sref = block_sref.get();
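  // Walk up the sref tree to reach the root block of the enclosing PrimFunc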
  for (; sref->parent != nullptr; sref = sref->parent) {
  }
  ICHECK(sref->stmt != nullptr && sref->stmt->IsInstance<BlockNode>());
  return IsSpatialPrimFunc(GetRef<PrimFunc>(GetRootPrimFunc(sch->mod(), sref->stmt, nullptr)));
}

/*! \brief The rule that inlines spatial blocks if they satisfy some conditions. */
class AutoInlineNode : public ScheduleRuleNode {
 public:
  /*! \brief Check whether and how the given block should be inlined */
  inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv);

  // Inherited from ScheduleRuleNode
  void InitializeWithTuneContext(const TuneContext& context) final {}

  // Inherited from ScheduleRuleNode
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
    InlineType inline_type = CheckInline(sch, block_rv);
    if (inline_type == InlineType::kInlineIntoConsumer) {
      sch->ComputeInline(block_rv);
    } else if (inline_type == InlineType::kInlineIntoProducer) {
      sch->ReverseComputeInline(block_rv);
    }
    return {sch};
  }

  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const final {
    ObjectPtr<AutoInlineNode> n = make_object<AutoInlineNode>(*this);
    return ScheduleRule(n);
  }

 public:
  /*! \brief Whether to allow a block to be inlined into its producer */
  bool into_producer;
  /*! \brief Whether to allow a block to be inlined into its consumer */
  bool into_consumer;
  /*! \brief Always inline constant tensors */
  bool inline_const_tensor;
  /*! \brief Always disallow if-then-else-like constructs */
  bool disallow_if_then_else;
  /*! \brief Always require the read-to-write mapping to be injective to do auto inline */
  bool require_injective;
  /*! \brief Always require the read-to-write mapping to be ordered to do auto inline */
  bool require_ordered;
  /*! \brief The operators that are disallowed in auto inline */
  Array<Op> disallow_op;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("into_producer", &into_producer);
    v->Visit("into_consumer", &into_consumer);
    v->Visit("inline_const_tensor", &inline_const_tensor);
    v->Visit("disallow_if_then_else", &disallow_if_then_else);
    v->Visit("require_injective", &require_injective);
    v->Visit("require_ordered", &require_ordered);
    v->Visit("disallow_op", &disallow_op);
  }

  static constexpr const char* _type_key = "meta_schedule.AutoInline";
  TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode);
};

inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
                                              const tir::BlockRV& block_rv) {
  using namespace tvm::tir;
  StmtSRef block_sref = sch->GetSRef(block_rv);
  bool is_pure_spatial = IsInSpatialPrimFunc(sch, block_sref);
  ScheduleState state = sch->state();
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  BlockRealize realize = GetBlockRealize(state, block_sref);
  // Cond 1. The block has only one write buffer
  if (block->writes.size() != 1) {
    return InlineType::kNoInline;
  }
  // Cond 2. For a block that generates a constant tensor, ignore all other conditions
  if (inline_const_tensor && block->reads.empty()) {
    Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
    if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
      return InlineType::kInlineIntoConsumer;
    }
  }
  // Cond 3. The block doesn't contain any disallowed operators
  if (!is_pure_spatial && !disallow_op.empty() && HasOp(realize, disallow_op)) {
    return InlineType::kNoInline;
  }
  // Cond 4. The block doesn't have any if-then-else-like constructs
  if (!is_pure_spatial && disallow_if_then_else && HasIfThenElse(realize)) {
    return InlineType::kNoInline;
  }
  // Cond 5. The mapping from read indices to write indices is injective and ordered
  if (!is_pure_spatial && (require_injective || require_ordered)) {
    const BufferRegion& write_region = block->writes[0];
    for (const BufferRegion& read_region : block->reads) {
      bool injective, ordered;
      auto _ = std::ignore;
      std::tie(/*exists=*/_, /*surjective=*/_, injective, ordered, /*no_const_read=*/_,
               /*no_shift_read=*/_) = AnalyzeReadWritePattern(read_region, write_region);
      if (require_injective && !injective) {
        return InlineType::kNoInline;
      }
      if (require_ordered && !ordered) {
        return InlineType::kNoInline;
      }
    }
  }
  // Cond 6. The block is not annotated to disable auto inline
  if (Optional<String> ann =
          tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_inline_rule)) {
    if (ann.value() == "disable") return InlineType::kNoInline;
  }
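  // (In TVMScript, such a block would carry an annotation along the lines of
  //  T.block_attr({"meta_schedule.inline_rule": "disable"}); the attr key is
  //  tir::attr::meta_schedule_inline_rule.)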
  // Last cond: Check inline into the consumers or the spatial producer
  tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref,
                                                /*require_stage_pipeline=*/false);
  if (into_consumer) {
    Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
    if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
      return InlineType::kInlineIntoConsumer;
    }
  }
  if (into_producer) {
    Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref);
    if (producer_srefs.size() == 1 &&
        tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) &&
        CanReverseComputeInline(state, block_sref) &&
        !GetAnn<String>(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).defined()) {
      return InlineType::kInlineIntoProducer;
    }
  }
  return InlineType::kNoInline;
}

ScheduleRule ScheduleRule::AutoInline(bool into_producer,          //
                                      bool into_consumer,          //
                                      bool inline_const_tensor,    //
                                      bool disallow_if_then_else,  //
                                      bool require_injective,      //
                                      bool require_ordered,        //
                                      Optional<Array<String>> disallow_op) {
  ObjectPtr<AutoInlineNode> n = make_object<AutoInlineNode>();
  n->into_producer = into_producer;
  n->into_consumer = into_consumer;
  n->inline_const_tensor = inline_const_tensor;
  n->disallow_if_then_else = disallow_if_then_else;
  n->require_injective = require_injective;
  n->require_ordered = require_ordered;
  n->disallow_op.clear();
  if (disallow_op.defined()) {
    Array<String> op_names = disallow_op.value();
    n->disallow_op.reserve(op_names.size());
    for (const String& op_name : op_names) {
      n->disallow_op.push_back(Op::Get(op_name));
    }
  }
  return ScheduleRule(n);
}
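
// Example (an illustrative sketch, not a prescribed default): constructing the
// rule with CPU-style settings. The exact flag values and the disallowed-op
// list below are assumptions for demonstration only.
//
//   ScheduleRule rule = ScheduleRule::AutoInline(
//       /*into_producer=*/false,
//       /*into_consumer=*/true,
//       /*inline_const_tensor=*/true,
//       /*disallow_if_then_else=*/true,
//       /*require_injective=*/true,
//       /*require_ordered=*/true,
//       /*disallow_op=*/Array<String>{"tir.exp"});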

TVM_REGISTER_NODE_TYPE(AutoInlineNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
    .set_body_typed(ScheduleRule::AutoInline);

/*! \brief Inline blocks that produce a constant scalar. */
class InlineConstantScalarsNode : public ScheduleRuleNode {
 public:
  void InitializeWithTuneContext(const TuneContext& context) final {}

  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
    // Look for a block of the form
    //  block compile_engine_const(iter_var(vi, range(min=0, ext=1))) {
    //    reads([])
    //    writes([compile_engine_const[]])
    //    compile_engine_const[] = 59
    //  }
    auto block = sch->Get(block_rv);
    if (block->reads.empty() && block->writes.size() == 1 &&
        block->writes[0]->buffer->shape.empty()) {
      sch->ComputeInline(block_rv);
    }
    return {sch};
  }

  ScheduleRule Clone() const final {
    ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>(*this);
    return ScheduleRule(n);
  }

  static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars";
  TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::InlineConstantScalars() {
  ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>();
  return ScheduleRule(n);
}
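
// Example (illustrative only): the rule takes no parameters. A schedule-rule
// list would typically place it before AutoInline so that constant scalars are
// folded away first; the composition below is an assumption for demonstration.
//
//   Array<ScheduleRule> rules{ScheduleRule::InlineConstantScalars(),
//                             ScheduleRule::AutoInline(...)};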

TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars")
    .set_body_typed(ScheduleRule::InlineConstantScalars);
}  // namespace meta_schedule
}  // namespace tvm