1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "./utils.h"
20
21namespace tvm {
22namespace tir {
23
24/**************** Constructor ****************/
25
26BlockRV::BlockRV() { this->data_ = make_object<BlockRVNode>(); }
27
28LoopRV::LoopRV() { this->data_ = make_object<LoopRVNode>(); }
29
30/**************** GetSRef ****************/
31
32StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const {
33 ScheduleState state = this->state();
34 auto it = state->stmt2ref.find(stmt);
35 if (it == state->stmt2ref.end()) {
36 LOG(FATAL) << "IndexError: The stmt doesn't exist in the IR";
37 }
38 return it->second;
39}
40
41/**************** FFI ****************/
42
// Register the random-variable node types and the schedule node type.
// NOTE(review): ScheduleNode uses the OBJECT_TYPE macro while the RV nodes use NODE_TYPE;
// presumably because ScheduleNode is not default-constructible -- confirm against its declaration.
TVM_REGISTER_NODE_TYPE(BlockRVNode);
TVM_REGISTER_NODE_TYPE(LoopRVNode);
TVM_REGISTER_OBJECT_TYPE(ScheduleNode);
46
// Accessors and utilities of the schedule itself. Each global function below is bound
// directly to the ScheduleNode member function named on its second line.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod")  //
    .set_body_method<Schedule>(&ScheduleNode::mod);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState")  //
    .set_body_method<Schedule>(&ScheduleNode::state);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace")  //
    .set_body_method<Schedule>(&ScheduleNode::trace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy")  //
    .set_body_method<Schedule>(&ScheduleNode::Copy);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed")  //
    .set_body_method<Schedule>(&ScheduleNode::Seed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed")  //
    .set_body_method<Schedule>(&ScheduleNode::ForkSeed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn")  //
    .set_body_method<Schedule>(&ScheduleNode::WorkOn);
61
62/**************** (FFI) Constructor ****************/
63
64TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); });
65TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); });
66TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
67 .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
68 int debug_mask, int error_render_level) -> Schedule {
69 return Schedule::Concrete(mod, debug_mask, seed,
70 static_cast<ScheduleErrorRenderLevel>(error_render_level));
71 });
72TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule")
73 .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
74 int debug_mask, int error_render_level) -> Schedule {
75 return Schedule::Traced(mod, seed, debug_mask,
76 static_cast<ScheduleErrorRenderLevel>(error_render_level));
77 });
78
79/******** (FFI) Lookup random variables ********/
80
81TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet")
82 .set_body_typed([](Schedule self, ObjectRef obj) -> ObjectRef {
83 if (const auto* loop_rv = obj.as<LoopRVNode>()) {
84 return self->Get(GetRef<LoopRV>(loop_rv));
85 }
86 if (const auto* block_rv = obj.as<BlockRVNode>()) {
87 return self->Get(GetRef<BlockRV>(block_rv));
88 }
89 if (const auto* expr_rv = obj.as<ExprRVNode>()) {
90 return self->Get(GetRef<ExprRV>(expr_rv));
91 }
92 LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << obj->GetTypeKey()
93 << ". Its value is: " << obj;
94 throw;
95 });
96TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef")
97 .set_body_typed([](Schedule self, ObjectRef obj) -> Optional<ObjectRef> {
98 if (const auto* loop_rv = obj.as<LoopRVNode>()) {
99 return self->GetSRef(GetRef<LoopRV>(loop_rv));
100 }
101 if (const auto* block_rv = obj.as<BlockRVNode>()) {
102 return self->GetSRef(GetRef<BlockRV>(block_rv));
103 }
104 if (const auto* stmt = obj.as<StmtNode>()) {
105 return self->GetSRef(GetRef<Stmt>(stmt));
106 }
107 LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey();
108 throw;
109 });
110TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV")
111 .set_body_typed([](Schedule self, ObjectRef obj) -> void {
112 if (const auto* loop_rv = obj.as<LoopRVNode>()) {
113 return self->RemoveRV(GetRef<LoopRV>(loop_rv));
114 }
115 if (const auto* block_rv = obj.as<BlockRVNode>()) {
116 return self->RemoveRV(GetRef<BlockRV>(block_rv));
117 }
118 if (const auto* expr_rv = obj.as<ExprRVNode>()) {
119 return self->RemoveRV(GetRef<ExprRV>(expr_rv));
120 }
121 LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey();
122 throw;
123 });
124
125/******** (FFI) Sampling ********/
// Each sampling global is bound directly to the ScheduleNode member of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical")
    .set_body_method<Schedule>(&ScheduleNode::SampleCategorical);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile")
    .set_body_method<Schedule>(&ScheduleNode::SamplePerfectTile);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation")
    .set_body_method<Schedule>(&ScheduleNode::SampleComputeLocation);
132/******** (FFI) Get blocks & loops ********/
// Direct bindings to ScheduleNode::GetBlock and ScheduleNode::GetLoops.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock")
    .set_body_method<Schedule>(&ScheduleNode::GetBlock);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops")
    .set_body_method<Schedule>(&ScheduleNode::GetLoops);
137TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks")
138 .set_body_typed([](Schedule self, ObjectRef rv) {
139 if (const auto* block_rv = rv.as<BlockRVNode>()) {
140 return self->GetChildBlocks(GetRef<BlockRV>(block_rv));
141 }
142 if (const auto* loop_rv = rv.as<LoopRVNode>()) {
143 return self->GetChildBlocks(GetRef<LoopRV>(loop_rv));
144 }
145 LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
146 << ". Its value is: " << rv;
147 throw;
148 });
// Direct bindings to ScheduleNode::GetProducers and ScheduleNode::GetConsumers.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
    .set_body_method<Schedule>(&ScheduleNode::GetProducers);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
    .set_body_method<Schedule>(&ScheduleNode::GetConsumers);
153/******** (FFI) Transform loops ********/
// Direct bindings to ScheduleNode::Fuse, ScheduleNode::Split and ScheduleNode::Reorder.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
    .set_body_method<Schedule>(&ScheduleNode::Reorder);
158TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop")
159 .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV {
160 if (const auto* loop_rv = rv.as<LoopRVNode>()) {
161 return self->AddUnitLoop(GetRef<LoopRV>(loop_rv));
162 } else if (const auto* block_rv = rv.as<BlockRVNode>()) {
163 return self->AddUnitLoop(GetRef<BlockRV>(block_rv));
164 } else {
165 LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
166 << ". Its value is: " << rv;
167 throw;
168 }
169 });
170/******** (FFI) Manipulate ForKind ********/
// Direct bindings to ScheduleNode::Parallel, Vectorize, Bind and Unroll.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel")
    .set_body_method<Schedule>(&ScheduleNode::Parallel);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize")
    .set_body_method<Schedule>(&ScheduleNode::Vectorize);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method<Schedule>(&ScheduleNode::Bind);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method<Schedule>(&ScheduleNode::Unroll);
177/******** (FFI) Insert cache stages ********/
// Direct bindings to ScheduleNode::CacheRead, CacheWrite, CacheInplace and CacheIndex.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead")
    .set_body_method<Schedule>(&ScheduleNode::CacheRead);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite")
    .set_body_method<Schedule>(&ScheduleNode::CacheWrite);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace")
    .set_body_method<Schedule>(&ScheduleNode::CacheInplace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex")
    .set_body_method<Schedule>(&ScheduleNode::CacheIndex);
186TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex")
187 .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
188 int buffer_index_type) {
189 return self->ReIndex(block_rv, buffer_index, static_cast<BufferIndexType>(buffer_index_type));
190 });
191/******** (FFI) Compute location ********/
// Direct bindings to ScheduleNode::ComputeAt, ReverseComputeAt, ComputeInline and
// ReverseComputeInline.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt")
    .set_body_method<Schedule>(&ScheduleNode::ComputeAt);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt")
    .set_body_method<Schedule>(&ScheduleNode::ReverseComputeAt);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
    .set_body_method<Schedule>(&ScheduleNode::ComputeInline);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
    .set_body_method<Schedule>(&ScheduleNode::ReverseComputeInline);
200/******** (FFI) Reduction ********/
// Direct bindings to ScheduleNode::DecomposeReduction and ScheduleNode::RFactor.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction")
    .set_body_method<Schedule>(&ScheduleNode::DecomposeReduction);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
    .set_body_method<Schedule>(&ScheduleNode::RFactor);
205/******** (FFI) Block annotation ********/
// Direct bindings to ScheduleNode::StorageAlign and ScheduleNode::SetScope.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
    .set_body_method<Schedule>(&ScheduleNode::StorageAlign);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
    .set_body_method<Schedule>(&ScheduleNode::SetScope);
210/******** (FFI) Blockize & Tensorize ********/
// Direct binding to ScheduleNode::Blockize.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
    .set_body_method<Schedule>(&ScheduleNode::Blockize);
213TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize")
214 .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) {
215 if (const auto* block_rv = rv.as<BlockRVNode>()) {
216 self->Tensorize(GetRef<BlockRV>(block_rv), intrin, preserve_unit_iters);
217 } else if (const auto* loop_rv = rv.as<LoopRVNode>()) {
218 self->Tensorize(GetRef<LoopRV>(loop_rv), intrin, preserve_unit_iters);
219 } else {
220 LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
221 << ". Its value is: " << rv;
222 }
223 });
224
225/******** (FFI) Annotation ********/
226TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate")
227 .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key,
228 const ObjectRef& ann_val) {
229 if (const auto* block_rv = rv.as<BlockRVNode>()) {
230 return self->Annotate(GetRef<BlockRV>(block_rv), ann_key, ann_val);
231 }
232 if (const auto* loop_rv = rv.as<LoopRVNode>()) {
233 return self->Annotate(GetRef<LoopRV>(loop_rv), ann_key, ann_val);
234 }
235 LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
236 << ". Its value is: " << rv;
237 throw;
238 });
239TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate")
240 .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key) {
241 if (const auto* block_rv = rv.as<BlockRVNode>()) {
242 return self->Unannotate(GetRef<BlockRV>(block_rv), ann_key);
243 }
244 if (const auto* loop_rv = rv.as<LoopRVNode>()) {
245 return self->Unannotate(GetRef<LoopRV>(loop_rv), ann_key);
246 }
247 LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
248 << ". Its value is: " << rv;
249 throw;
250 });
251
252/******** (FFI) Layout transformation ********/
253TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout")
254 .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
255 int buffer_index_type, const IndexMap& index_map,
256 const Optional<IndexMap>& pad_value) {
257 return self->TransformLayout(block_rv, buffer_index,
258 static_cast<BufferIndexType>(buffer_index_type), index_map,
259 pad_value);
260 });
// Direct binding to ScheduleNode::TransformBlockLayout.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout")
    .set_body_method<Schedule>(&ScheduleNode::TransformBlockLayout);
263TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator")
264 .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
265 int buffer_index_type, const Array<IntImm>& axis_separators) {
266 return self->SetAxisSeparator(
267 block_rv, buffer_index, static_cast<BufferIndexType>(buffer_index_type), axis_separators);
268 });
269
270/******** (FFI) Padding decomposition ********/
// Direct bindings to ScheduleNode::DecomposePadding and ScheduleNode::PadEinsum.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding")
    .set_body_method<Schedule>(&ScheduleNode::DecomposePadding);
TVM_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum")
    .set_body_method<Schedule>(&ScheduleNode::PadEinsum);
275/******** (FFI) Buffer transformation ********/
// Direct binding to ScheduleNode::RollingBuffer.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRollingBuffer")
    .set_body_method<Schedule>(&ScheduleNode::RollingBuffer);
278/******** (FFI) Misc ********/
// Direct binding to ScheduleNode::EnterPostproc.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
    .set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
281
282} // namespace tir
283} // namespace tvm
284