1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /**************** Constructor ****************/ |
25 | |
26 | BlockRV::BlockRV() { this->data_ = make_object<BlockRVNode>(); } |
27 | |
28 | LoopRV::LoopRV() { this->data_ = make_object<LoopRVNode>(); } |
29 | |
30 | /**************** GetSRef ****************/ |
31 | |
32 | StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const { |
33 | ScheduleState state = this->state(); |
34 | auto it = state->stmt2ref.find(stmt); |
35 | if (it == state->stmt2ref.end()) { |
36 | LOG(FATAL) << "IndexError: The stmt doesn't exist in the IR" ; |
37 | } |
38 | return it->second; |
39 | } |
40 | |
41 | /**************** FFI ****************/ |
42 | |
// Register the random-variable node types with TVM's object system so they can cross the FFI.
TVM_REGISTER_NODE_TYPE(BlockRVNode);
TVM_REGISTER_NODE_TYPE(LoopRVNode);
// ScheduleNode is registered with the object-type macro rather than the node-type macro
// (NOTE(review): presumably because it is an abstract interface with no default creator).
TVM_REGISTER_OBJECT_TYPE(ScheduleNode);
46 | |
// Accessors and state operations of a Schedule, each bound directly to the ScheduleNode
// member of the same name via set_body_method<Schedule>; the first FFI argument is the
// Schedule object the method is invoked on.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod")  //
    .set_body_method<Schedule>(&ScheduleNode::mod);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState")  //
    .set_body_method<Schedule>(&ScheduleNode::state);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace")  //
    .set_body_method<Schedule>(&ScheduleNode::trace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy")  //
    .set_body_method<Schedule>(&ScheduleNode::Copy);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed")  //
    .set_body_method<Schedule>(&ScheduleNode::Seed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed")  //
    .set_body_method<Schedule>(&ScheduleNode::ForkSeed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn")  //
    .set_body_method<Schedule>(&ScheduleNode::WorkOn);
61 | |
62 | /**************** (FFI) Constructor ****************/ |
63 | |
64 | TVM_REGISTER_GLOBAL("tir.schedule.BlockRV" ).set_body_typed([]() { return BlockRV(); }); |
65 | TVM_REGISTER_GLOBAL("tir.schedule.LoopRV" ).set_body_typed([]() { return LoopRV(); }); |
66 | TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule" ) |
67 | .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, |
68 | int debug_mask, int error_render_level) -> Schedule { |
69 | return Schedule::Concrete(mod, debug_mask, seed, |
70 | static_cast<ScheduleErrorRenderLevel>(error_render_level)); |
71 | }); |
72 | TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule" ) |
73 | .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, |
74 | int debug_mask, int error_render_level) -> Schedule { |
75 | return Schedule::Traced(mod, seed, debug_mask, |
76 | static_cast<ScheduleErrorRenderLevel>(error_render_level)); |
77 | }); |
78 | |
79 | /******** (FFI) Lookup random variables ********/ |
80 | |
81 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet" ) |
82 | .set_body_typed([](Schedule self, ObjectRef obj) -> ObjectRef { |
83 | if (const auto* loop_rv = obj.as<LoopRVNode>()) { |
84 | return self->Get(GetRef<LoopRV>(loop_rv)); |
85 | } |
86 | if (const auto* block_rv = obj.as<BlockRVNode>()) { |
87 | return self->Get(GetRef<BlockRV>(block_rv)); |
88 | } |
89 | if (const auto* expr_rv = obj.as<ExprRVNode>()) { |
90 | return self->Get(GetRef<ExprRV>(expr_rv)); |
91 | } |
92 | LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << obj->GetTypeKey() |
93 | << ". Its value is: " << obj; |
94 | throw; |
95 | }); |
96 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef" ) |
97 | .set_body_typed([](Schedule self, ObjectRef obj) -> Optional<ObjectRef> { |
98 | if (const auto* loop_rv = obj.as<LoopRVNode>()) { |
99 | return self->GetSRef(GetRef<LoopRV>(loop_rv)); |
100 | } |
101 | if (const auto* block_rv = obj.as<BlockRVNode>()) { |
102 | return self->GetSRef(GetRef<BlockRV>(block_rv)); |
103 | } |
104 | if (const auto* stmt = obj.as<StmtNode>()) { |
105 | return self->GetSRef(GetRef<Stmt>(stmt)); |
106 | } |
107 | LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); |
108 | throw; |
109 | }); |
110 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV" ) |
111 | .set_body_typed([](Schedule self, ObjectRef obj) -> void { |
112 | if (const auto* loop_rv = obj.as<LoopRVNode>()) { |
113 | return self->RemoveRV(GetRef<LoopRV>(loop_rv)); |
114 | } |
115 | if (const auto* block_rv = obj.as<BlockRVNode>()) { |
116 | return self->RemoveRV(GetRef<BlockRV>(block_rv)); |
117 | } |
118 | if (const auto* expr_rv = obj.as<ExprRVNode>()) { |
119 | return self->RemoveRV(GetRef<ExprRV>(expr_rv)); |
120 | } |
121 | LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); |
122 | throw; |
123 | }); |
124 | |
125 | /******** (FFI) Sampling ********/ |
// Sampling primitives, forwarded unchanged to the matching ScheduleNode methods.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical")
    .set_body_method<Schedule>(&ScheduleNode::SampleCategorical);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile")
    .set_body_method<Schedule>(&ScheduleNode::SamplePerfectTile);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation")
    .set_body_method<Schedule>(&ScheduleNode::SampleComputeLocation);
132 | /******** (FFI) Get blocks & loops ********/ |
// Block and loop lookup, bound directly to the ScheduleNode members of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock")
    .set_body_method<Schedule>(&ScheduleNode::GetBlock);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops")
    .set_body_method<Schedule>(&ScheduleNode::GetLoops);
137 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks" ) |
138 | .set_body_typed([](Schedule self, ObjectRef rv) { |
139 | if (const auto* block_rv = rv.as<BlockRVNode>()) { |
140 | return self->GetChildBlocks(GetRef<BlockRV>(block_rv)); |
141 | } |
142 | if (const auto* loop_rv = rv.as<LoopRVNode>()) { |
143 | return self->GetChildBlocks(GetRef<LoopRV>(loop_rv)); |
144 | } |
145 | LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() |
146 | << ". Its value is: " << rv; |
147 | throw; |
148 | }); |
// Producer/consumer queries, bound directly to the ScheduleNode members of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
    .set_body_method<Schedule>(&ScheduleNode::GetProducers);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
    .set_body_method<Schedule>(&ScheduleNode::GetConsumers);
153 | /******** (FFI) Transform loops ********/ |
// Loop transformations, bound directly to the ScheduleNode members of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
    .set_body_method<Schedule>(&ScheduleNode::Reorder);
158 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop" ) |
159 | .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV { |
160 | if (const auto* loop_rv = rv.as<LoopRVNode>()) { |
161 | return self->AddUnitLoop(GetRef<LoopRV>(loop_rv)); |
162 | } else if (const auto* block_rv = rv.as<BlockRVNode>()) { |
163 | return self->AddUnitLoop(GetRef<BlockRV>(block_rv)); |
164 | } else { |
165 | LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() |
166 | << ". Its value is: " << rv; |
167 | throw; |
168 | } |
169 | }); |
170 | /******** (FFI) Manipulate ForKind ********/ |
// ForKind manipulation (parallel / vectorize / bind / unroll), bound directly to the
// ScheduleNode members of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel")
    .set_body_method<Schedule>(&ScheduleNode::Parallel);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize")
    .set_body_method<Schedule>(&ScheduleNode::Vectorize);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method<Schedule>(&ScheduleNode::Bind);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method<Schedule>(&ScheduleNode::Unroll);
177 | /******** (FFI) Insert cache stages ********/ |
// Cache-stage insertion, bound directly to the ScheduleNode members of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead")
    .set_body_method<Schedule>(&ScheduleNode::CacheRead);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite")
    .set_body_method<Schedule>(&ScheduleNode::CacheWrite);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace")
    .set_body_method<Schedule>(&ScheduleNode::CacheInplace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex")
    .set_body_method<Schedule>(&ScheduleNode::CacheIndex);
186 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex" ) |
187 | .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, |
188 | int buffer_index_type) { |
189 | return self->ReIndex(block_rv, buffer_index, static_cast<BufferIndexType>(buffer_index_type)); |
190 | }); |
191 | /******** (FFI) Compute location ********/ |
// Compute-location primitives, bound directly to the ScheduleNode members of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt")
    .set_body_method<Schedule>(&ScheduleNode::ComputeAt);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt")
    .set_body_method<Schedule>(&ScheduleNode::ReverseComputeAt);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
    .set_body_method<Schedule>(&ScheduleNode::ComputeInline);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
    .set_body_method<Schedule>(&ScheduleNode::ReverseComputeInline);
200 | /******** (FFI) Reduction ********/ |
// Reduction primitives, bound directly to the ScheduleNode members of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction")
    .set_body_method<Schedule>(&ScheduleNode::DecomposeReduction);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
    .set_body_method<Schedule>(&ScheduleNode::RFactor);
205 | /******** (FFI) Block annotation ********/ |
// Block-level buffer annotations, bound directly to the ScheduleNode members of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
    .set_body_method<Schedule>(&ScheduleNode::StorageAlign);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
    .set_body_method<Schedule>(&ScheduleNode::SetScope);
210 | /******** (FFI) Blockize & Tensorize ********/ |
// Blockize, bound directly to ScheduleNode::Blockize.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
    .set_body_method<Schedule>(&ScheduleNode::Blockize);
213 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize" ) |
214 | .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool preserve_unit_iters) { |
215 | if (const auto* block_rv = rv.as<BlockRVNode>()) { |
216 | self->Tensorize(GetRef<BlockRV>(block_rv), intrin, preserve_unit_iters); |
217 | } else if (const auto* loop_rv = rv.as<LoopRVNode>()) { |
218 | self->Tensorize(GetRef<LoopRV>(loop_rv), intrin, preserve_unit_iters); |
219 | } else { |
220 | LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() |
221 | << ". Its value is: " << rv; |
222 | } |
223 | }); |
224 | |
225 | /******** (FFI) Annotation ********/ |
226 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate" ) |
227 | .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, |
228 | const ObjectRef& ann_val) { |
229 | if (const auto* block_rv = rv.as<BlockRVNode>()) { |
230 | return self->Annotate(GetRef<BlockRV>(block_rv), ann_key, ann_val); |
231 | } |
232 | if (const auto* loop_rv = rv.as<LoopRVNode>()) { |
233 | return self->Annotate(GetRef<LoopRV>(loop_rv), ann_key, ann_val); |
234 | } |
235 | LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() |
236 | << ". Its value is: " << rv; |
237 | throw; |
238 | }); |
239 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate" ) |
240 | .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key) { |
241 | if (const auto* block_rv = rv.as<BlockRVNode>()) { |
242 | return self->Unannotate(GetRef<BlockRV>(block_rv), ann_key); |
243 | } |
244 | if (const auto* loop_rv = rv.as<LoopRVNode>()) { |
245 | return self->Unannotate(GetRef<LoopRV>(loop_rv), ann_key); |
246 | } |
247 | LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() |
248 | << ". Its value is: " << rv; |
249 | throw; |
250 | }); |
251 | |
252 | /******** (FFI) Layout transformation ********/ |
253 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout" ) |
254 | .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, |
255 | int buffer_index_type, const IndexMap& index_map, |
256 | const Optional<IndexMap>& pad_value) { |
257 | return self->TransformLayout(block_rv, buffer_index, |
258 | static_cast<BufferIndexType>(buffer_index_type), index_map, |
259 | pad_value); |
260 | }); |
// Block-layout transformation, bound directly to ScheduleNode::TransformBlockLayout.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout")
    .set_body_method<Schedule>(&ScheduleNode::TransformBlockLayout);
263 | TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator" ) |
264 | .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, |
265 | int buffer_index_type, const Array<IntImm>& axis_separators) { |
266 | return self->SetAxisSeparator( |
267 | block_rv, buffer_index, static_cast<BufferIndexType>(buffer_index_type), axis_separators); |
268 | }); |
269 | |
270 | /******** (FFI) Padding decomposition ********/ |
// Padding-related primitives, bound directly to the ScheduleNode members of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding")
    .set_body_method<Schedule>(&ScheduleNode::DecomposePadding);
TVM_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum")
    .set_body_method<Schedule>(&ScheduleNode::PadEinsum);
275 | /******** (FFI) Buffer transformation ********/ |
// Rolling-buffer transformation, bound directly to ScheduleNode::RollingBuffer.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRollingBuffer")
    .set_body_method<Schedule>(&ScheduleNode::RollingBuffer);
278 | /******** (FFI) Misc ********/ |
// Bound directly to ScheduleNode::EnterPostproc.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
    .set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
281 | |
282 | } // namespace tir |
283 | } // namespace tvm |
284 | |