1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./concrete_schedule.h" |
20 | |
21 | #include <random> |
22 | |
23 | namespace tvm { |
24 | namespace tir { |
25 | |
26 | Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, |
27 | int debug_mask, ScheduleErrorRenderLevel error_render_level) { |
28 | ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>(); |
29 | n->state_ = ScheduleState(mod, debug_mask); |
30 | n->error_render_level_ = error_render_level; |
31 | n->symbol_table_ = {}; |
32 | n->analyzer_ = std::make_unique<arith::Analyzer>(); |
33 | n->Seed(seed); |
34 | GlobalVar gv = NullValue<GlobalVar>(); |
35 | if (FindEntryFunc(mod, &gv) != nullptr) { |
36 | n->func_working_on_ = gv; |
37 | } else { |
38 | n->func_working_on_ = NullOpt; |
39 | } |
40 | return Schedule(std::move(n)); |
41 | } |
42 | |
43 | /******** Copy ********/ |
44 | |
45 | /*! \brief Helper class to perform a deep copy of the sref tree */ |
46 | class ScheduleCopier { |
47 | using TSymbolTable = ConcreteScheduleNode::TSymbolTable; |
48 | template <class K, class V> |
49 | using UMap = std::unordered_map<K, V>; |
50 | template <class K, class V> |
51 | using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>; |
52 | |
53 | public: |
54 | static void Copy(const ConcreteScheduleNode* self, ScheduleState* new_state, |
55 | TSymbolTable* new_symbol_table) { |
56 | const ScheduleState& src_state = self->state_; |
57 | ScheduleCopier copier(src_state); |
58 | ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>(); |
59 | n->mod = src_state->mod; |
60 | n->block_info = copier.Copy(src_state->block_info); |
61 | n->stmt2ref = copier.Copy(src_state->stmt2ref); |
62 | n->debug_mask = src_state->debug_mask; |
63 | *new_state = ScheduleState(std::move(n)); |
64 | *new_symbol_table = copier.Copy(self->symbol_table_); |
65 | } |
66 | |
67 | private: |
68 | /*! \brief Create the copier and properly set up the `old2new_` table */ |
69 | explicit ScheduleCopier(const ScheduleState& state) { |
70 | // Create SRef tree without parents |
71 | for (const auto& kv : state->stmt2ref) { |
72 | const StmtSRefNode* sref = kv.second.operator->(); |
73 | old2new_.emplace(sref, // the old StmtSRef |
74 | StmtSRef(/*stmt=*/sref->stmt, // the new StmtSRef |
75 | /*parent=*/nullptr, // parent is not set yet |
76 | /*seq_index=*/sref->seq_index)); |
77 | } |
78 | // Fill in the parent field |
79 | // Find out the root along the way |
80 | for (auto& kv : old2new_) { |
81 | const StmtSRefNode* parent = kv.first->parent; |
82 | StmtSRef& sref = kv.second; |
83 | sref->parent = parent ? old2new_.at(parent).get() : nullptr; |
84 | } |
85 | } |
86 | |
87 | /*! \brief Copy StmtSRef */ |
88 | StmtSRef Copy(const StmtSRef& sref) { return old2new_.at(sref.operator->()); } |
89 | |
90 | /*! \brief Copy StmtSRefNode */ |
91 | StmtSRef Copy(const StmtSRefNode* sref) { |
92 | if (old2new_.count(sref)) { |
93 | return old2new_.at(sref); |
94 | } |
95 | // Handle expired sref |
96 | return old2new_[sref] = StmtSRef(nullptr, nullptr, -1); |
97 | } |
98 | |
99 | /*! \brief Copy Array<StmtSRef> */ |
100 | Array<StmtSRef> Copy(const Array<StmtSRef>& list) { |
101 | Array<StmtSRef> result; |
102 | result.reserve(list.size()); |
103 | for (const StmtSRef& elem : list) { |
104 | result.push_back(Copy(elem)); |
105 | } |
106 | return result; |
107 | } |
108 | |
109 | /*! \brief Copy Array<Dependency> */ |
110 | Array<Dependency> Copy(const Array<Dependency>& list) { |
111 | Array<Dependency> result; |
112 | result.reserve(list.size()); |
113 | for (const Dependency& elem : list) { |
114 | result.push_back(Dependency(Copy(elem->src), Copy(elem->dst), elem->kind)); |
115 | } |
116 | return result; |
117 | } |
118 | |
119 | /*! \brief Copy SMap<StmtSRef, Array<Dependency>> */ |
120 | SMap<StmtSRef, Array<Dependency>> Copy(const SMap<StmtSRef, Array<Dependency>>& map) { |
121 | SMap<StmtSRef, Array<Dependency>> result; |
122 | result.reserve(map.size()); |
123 | for (const auto& kv : map) { |
124 | result[Copy(kv.first)] = Copy(kv.second); |
125 | } |
126 | return result; |
127 | } |
128 | |
129 | /*! \brief Copy SMap<Buffer, Array<StmtSRef>> */ |
130 | SMap<Buffer, Array<StmtSRef>> Copy(const SMap<Buffer, Array<StmtSRef>>& map) { |
131 | SMap<Buffer, Array<StmtSRef>> result; |
132 | result.reserve(map.size()); |
133 | for (const auto& kv : map) { |
134 | result[kv.first] = Copy(kv.second); |
135 | } |
136 | return result; |
137 | } |
138 | |
139 | /*! \brief Copy SMap<StmtSRef, Scope> */ |
140 | SMap<StmtSRef, BlockInfo> Copy(const SMap<StmtSRef, BlockInfo>& scopes) { |
141 | SMap<StmtSRef, BlockInfo> result; |
142 | for (const auto& kv : scopes) { |
143 | const StmtSRef& old_sref = kv.first; |
144 | const BlockInfo& old_info = kv.second; |
145 | BlockInfo new_info = old_info; |
146 | ObjectPtr<BlockScopeNode> scope = make_object<BlockScopeNode>(); |
147 | scope->src2deps = Copy(old_info.scope->src2deps); |
148 | scope->dst2deps = Copy(old_info.scope->dst2deps); |
149 | scope->buffer_writers = Copy(old_info.scope->buffer_writers); |
150 | scope->stage_pipeline = old_info.scope->stage_pipeline; |
151 | new_info.scope = BlockScope(std::move(scope)); |
152 | result[Copy(old_sref)] = std::move(new_info); |
153 | } |
154 | return result; |
155 | } |
156 | |
157 | /*! \brief Copy the stmt2ref */ |
158 | UMap<const StmtNode*, StmtSRef> Copy(const UMap<const StmtNode*, StmtSRef>& stmt2ref) { |
159 | UMap<const StmtNode*, StmtSRef> result; |
160 | result.reserve(stmt2ref.size()); |
161 | for (const auto& kv : stmt2ref) { |
162 | const StmtNode* stmt = kv.first; |
163 | const StmtSRef& sref = kv.second; |
164 | result.emplace(stmt, Copy(sref)); |
165 | } |
166 | return result; |
167 | } |
168 | |
169 | /*! \brief Copy the symbol table */ |
170 | TSymbolTable Copy(const TSymbolTable& tab) { |
171 | TSymbolTable result; |
172 | for (const auto& kv : tab) { |
173 | ObjectRef entry = kv.second; |
174 | if (const auto* sref = entry.as<StmtSRefNode>()) { |
175 | entry = Copy(sref); |
176 | } |
177 | result.Set(kv.first, entry); |
178 | } |
179 | return result; |
180 | } |
181 | |
182 | private: |
183 | std::unordered_map<const StmtSRefNode*, StmtSRef> old2new_; |
184 | }; |
185 | |
186 | void ConcreteScheduleNode::WorkOn(const String& func_name) { |
187 | this->func_working_on_ = this->state_->mod->GetGlobalVar(func_name); |
188 | } |
189 | |
190 | void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const { |
191 | ScheduleCopier::Copy(this, new_state, new_symbol_table); |
192 | new_state->get()->DebugVerify(); |
193 | } |
194 | |
195 | Schedule ConcreteScheduleNode::Copy() { |
196 | ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>(); |
197 | n->func_working_on_ = this->func_working_on_; |
198 | n->error_render_level_ = this->error_render_level_; |
199 | ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); |
200 | n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful |
201 | n->rand_state_ = ForkSeed(); |
202 | return Schedule(std::move(n)); |
203 | } |
204 | |
/*! \brief Macro that opens the guarded region at the start of a TensorIR schedule primitive */
#define TVM_TIR_SCHEDULE_BEGIN() try {
/*!
 * \brief Macro that closes the region opened by `TVM_TIR_SCHEDULE_BEGIN`. A ScheduleError that
 * escapes the region is rethrown as tvm::runtime::Error, with a message chosen by `level`
 * \param primitive Name of the schedule primitive, passed to ScheduleError::RenderReport
 * \param level An ScheduleErrorRenderLevel enum, level of error rendering
 * \sa ScheduleErrorRenderLevel
 */
#define TVM_TIR_SCHEDULE_END(primitive, level)                                                \
  }                                                                                           \
  catch (const ScheduleError& error) {                                                        \
    const auto tvm_tir_render_level_ = (level);                                               \
    if (tvm_tir_render_level_ == ScheduleErrorRenderLevel::kDetail) {                         \
      throw tvm::runtime::Error(error.RenderReport(primitive) + "\n" + runtime::Backtrace()); \
    }                                                                                         \
    if (tvm_tir_render_level_ == ScheduleErrorRenderLevel::kFast) {                           \
      throw tvm::runtime::Error(error.FastErrorString());                                     \
    }                                                                                         \
    if (tvm_tir_render_level_ == ScheduleErrorRenderLevel::kNone) {                           \
      throw tvm::runtime::Error("ScheduleError: (not rendered)");                             \
    }                                                                                         \
  }
224 | |
/******** Schedule: Sampling ********/
226 | |
227 | void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { |
228 | this->rand_state_ = support::LinearCongruentialEngine::NormalizeSeed(seed); |
229 | } |
230 | |
231 | support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { |
232 | return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); |
233 | } |
234 | |
235 | ExprRV ConcreteScheduleNode::SampleCategorical(const Array<Integer>& candidates, |
236 | const Array<FloatImm>& probs, |
237 | Optional<Integer> decision) { |
238 | TVM_TIR_SCHEDULE_BEGIN(); |
239 | return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); |
240 | TVM_TIR_SCHEDULE_END("sample-categorical" , this->error_render_level_); |
241 | throw; |
242 | } |
243 | |
244 | Array<ExprRV> ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, |
245 | int max_innermost_factor, |
246 | Optional<Array<Integer>> decision) { |
247 | TVM_TIR_SCHEDULE_BEGIN(); |
248 | return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n, |
249 | max_innermost_factor, &decision)); |
250 | TVM_TIR_SCHEDULE_END("sample-perfect-tile" , this->error_render_level_); |
251 | throw; |
252 | } |
253 | |
254 | LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, |
255 | Optional<Integer> decision) { |
256 | TVM_TIR_SCHEDULE_BEGIN(); |
257 | return CreateRV<LoopRV>( |
258 | tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); |
259 | TVM_TIR_SCHEDULE_END("sample-compute-location" , this->error_render_level_); |
260 | throw; |
261 | } |
262 | |
263 | /******** Schedule: Get blocks & loops ********/ |
264 | |
265 | BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional<String>& func_name) { |
266 | class NotSingleResult : public ScheduleError { |
267 | public: |
268 | explicit NotSingleResult(String name, IRModule mod, const Array<StmtSRef>& blocks) |
269 | : name_(name), mod_(mod), blocks_{} { |
270 | blocks_.reserve(blocks.size()); |
271 | for (const StmtSRef& block_sref : blocks) { |
272 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
273 | blocks_.push_back(GetRef<Block>(block)); |
274 | } |
275 | } |
276 | |
277 | IRModule mod() const final { return mod_; } |
278 | Array<ObjectRef> LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; } |
279 | |
280 | String DetailRenderTemplate() const final { |
281 | if (blocks_.empty()) { |
282 | return "Cannot find a block with the name: " + name_; |
283 | } else { |
284 | return "Found " + std::to_string(blocks_.size()) + " blocks with the name: " + name_; |
285 | } |
286 | } |
287 | |
288 | String FastErrorString() const final { |
289 | if (blocks_.empty()) { |
290 | return "ScheduleError: Cannot find a block with the specified name" ; |
291 | } else { |
292 | return "ScheduleError: Found multiple blocks with the specified name" ; |
293 | } |
294 | } |
295 | |
296 | String name_; |
297 | IRModule mod_; |
298 | Array<Block> blocks_; |
299 | }; |
300 | GlobalVar gv = NullValue<GlobalVar>(); |
301 | if (func_name.defined()) { |
302 | gv = state_->mod->GetGlobalVar(func_name.value()); |
303 | } else if (func_working_on_.defined()) { |
304 | gv = this->func_working_on_.value(); |
305 | } else { |
306 | LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please " |
307 | "specify the function name explicitly, or call `work_on` to specify the function " |
308 | "before using `get_block`." ; |
309 | } |
310 | Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, gv); |
311 | if (blocks.size() != 1) { |
312 | TVM_TIR_SCHEDULE_BEGIN(); |
313 | throw NotSingleResult(name, this->state_->mod, blocks); |
314 | TVM_TIR_SCHEDULE_END("get-block" , this->error_render_level_); |
315 | } |
316 | return CreateRV<BlockRV>(blocks[0]); |
317 | } |
318 | |
319 | Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { |
320 | return CreateRV<LoopRV>(tir::GetLoops(this->GetSRef(block_rv))); |
321 | } |
322 | |
323 | Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) { |
324 | Array<BlockRV> result; |
325 | TVM_TIR_SCHEDULE_BEGIN(); |
326 | result = CreateRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(block_rv))); |
327 | TVM_TIR_SCHEDULE_END("get-child-blocks" , this->error_render_level_); |
328 | this->state_->DebugVerify(); |
329 | return result; |
330 | } |
331 | |
332 | Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { |
333 | Array<BlockRV> result; |
334 | TVM_TIR_SCHEDULE_BEGIN(); |
335 | result = CreateRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(loop_rv))); |
336 | TVM_TIR_SCHEDULE_END("get-child-blocks" , this->error_render_level_); |
337 | this->state_->DebugVerify(); |
338 | return result; |
339 | } |
340 | |
341 | Array<BlockRV> ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) { |
342 | TVM_TIR_SCHEDULE_BEGIN(); |
343 | return CreateRV<BlockRV>(tir::GetProducers(state_, this->GetSRef(block_rv))); |
344 | TVM_TIR_SCHEDULE_END("get-producers" , this->error_render_level_); |
345 | throw; |
346 | } |
347 | |
348 | Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) { |
349 | TVM_TIR_SCHEDULE_BEGIN(); |
350 | return CreateRV<BlockRV>(tir::GetConsumers(state_, this->GetSRef(block_rv))); |
351 | TVM_TIR_SCHEDULE_END("get-consumers" , this->error_render_level_); |
352 | throw; |
353 | } |
354 | |
355 | /******** Schedule: Transform loops ********/ |
356 | |
357 | LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) { |
358 | CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)" ; |
359 | Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs); |
360 | StmtSRef result{nullptr}; |
361 | TVM_TIR_SCHEDULE_BEGIN(); |
362 | result = tir::Fuse(state_, loop_srefs, preserve_unit_iters); |
363 | TVM_TIR_SCHEDULE_END("fuse" , this->error_render_level_); |
364 | this->state_->DebugVerify(); |
365 | return CreateRV<LoopRV>(result); |
366 | } |
367 | |
368 | Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv, |
369 | const Array<Optional<ExprRV>>& factor_rvs, |
370 | bool preserve_unit_iters) { |
371 | class NotSingleInferFactorError : public ScheduleError { |
372 | public: |
373 | explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} |
374 | |
375 | String FastErrorString() const final { |
376 | return "ScheduleError: only one factor can be specified as -1 or none" ; |
377 | } |
378 | |
379 | String DetailRenderTemplate() const final { |
380 | return "Only one factor can be specified as -1 or none" ; |
381 | } |
382 | |
383 | IRModule mod() const final { return mod_; } |
384 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
385 | |
386 | IRModule mod_; |
387 | }; |
388 | |
389 | class WrongFactorProductError : public ScheduleError { |
390 | public: |
391 | explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} |
392 | |
393 | String FastErrorString() const final { |
394 | return "ScheduleError: The product of factors is not larger than or equal to the extent of " |
395 | "loop" ; |
396 | } |
397 | |
398 | String DetailRenderTemplate() const final { |
399 | return "The product of factors is not larger than or equal to the extent of loop {0}" ; |
400 | } |
401 | |
402 | IRModule mod() const final { return mod_; } |
403 | Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; } |
404 | |
405 | IRModule mod_; |
406 | For loop_; |
407 | }; |
408 | |
409 | class NonPositiveFactorError : public ScheduleError { |
410 | public: |
411 | explicit NonPositiveFactorError(IRModule mod, int64_t factor, size_t idx) |
412 | : mod_(std::move(mod)), factor_(factor), idx_(idx) {} |
413 | |
414 | String FastErrorString() const final { |
415 | return "ScheduleError: All the constant factors are required to be positive. However, some " |
416 | "constant input factor is zero or negative." ; |
417 | } |
418 | String DetailRenderTemplate() const final { |
419 | std::ostringstream os; |
420 | os << "All the constant factors are required to be positive. However, the factor at position " |
421 | << idx_ << " is " << factor_; |
422 | return os.str(); |
423 | } |
424 | IRModule mod() const final { return mod_; } |
425 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
426 | |
427 | private: |
428 | IRModule mod_; |
429 | int64_t factor_; |
430 | size_t idx_; |
431 | }; |
432 | |
433 | // Prepare for the splitting |
434 | StmtSRef loop_sref = this->GetSRef(loop_rv); |
435 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
436 | Array<PrimExpr> factors; |
437 | factors.reserve(factor_rvs.size()); |
438 | int infer_index = -1; |
439 | PrimExpr tot_length = 1; |
440 | Array<StmtSRef> results; |
441 | TVM_TIR_SCHEDULE_BEGIN(); |
442 | // infer factor if needed and check validity of factors |
443 | for (size_t i = 0; i < factor_rvs.size(); i++) { |
444 | if (!factor_rvs[i].defined()) { |
445 | factors.push_back(Integer(-1)); |
446 | if (infer_index != -1) { |
447 | throw NotSingleInferFactorError(state_->mod); |
448 | } |
449 | infer_index = i; |
450 | } else { |
451 | PrimExpr factor = this->Get(factor_rvs[i].value()); |
452 | if (is_const_int(factor) && !is_positive_const(factor)) { |
453 | throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i); |
454 | } |
455 | if (factor.dtype().bits() > loop->extent.dtype().bits()) { |
456 | factor = cast(loop->extent.dtype(), factor); |
457 | } |
458 | factors.push_back(factor); |
459 | tot_length *= factor; |
460 | } |
461 | } |
462 | if (infer_index != -1) { |
463 | factors.Set(infer_index, |
464 | this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); |
465 | } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) { |
466 | throw WrongFactorProductError(state_->mod, GetRef<For>(loop)); |
467 | } |
468 | results = tir::Split(state_, loop_sref, factors, preserve_unit_iters); |
469 | TVM_TIR_SCHEDULE_END("split" , this->error_render_level_); |
470 | this->state_->DebugVerify(); |
471 | return CreateRV<LoopRV>(results); |
472 | } |
473 | |
474 | void ConcreteScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) { |
475 | TVM_TIR_SCHEDULE_BEGIN(); |
476 | tir::Reorder(state_, GetSRefs(ordered_loop_rvs)); |
477 | TVM_TIR_SCHEDULE_END("reorder" , this->error_render_level_); |
478 | this->state_->DebugVerify(); |
479 | } |
480 | |
481 | LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) { |
482 | LoopRV result{nullptr}; |
483 | TVM_TIR_SCHEDULE_BEGIN(); |
484 | result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(block_rv))); |
485 | TVM_TIR_SCHEDULE_END("add-unit-loop" , this->error_render_level_); |
486 | this->state_->DebugVerify(); |
487 | return result; |
488 | } |
489 | |
490 | LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) { |
491 | LoopRV result{nullptr}; |
492 | TVM_TIR_SCHEDULE_BEGIN(); |
493 | result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(loop_rv))); |
494 | TVM_TIR_SCHEDULE_END("add-unit-loop" , this->error_render_level_); |
495 | this->state_->DebugVerify(); |
496 | return result; |
497 | } |
498 | |
499 | /******** Schedule: Manipulate ForKind ********/ |
500 | |
501 | void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) { |
502 | TVM_TIR_SCHEDULE_BEGIN(); |
503 | tir::Parallel(state_, this->GetSRef(loop_rv)); |
504 | this->state_->DebugVerify(); |
505 | TVM_TIR_SCHEDULE_END("parallel" , this->error_render_level_); |
506 | } |
507 | |
508 | void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) { |
509 | TVM_TIR_SCHEDULE_BEGIN(); |
510 | tir::Vectorize(state_, this->GetSRef(loop_rv)); |
511 | this->state_->DebugVerify(); |
512 | TVM_TIR_SCHEDULE_END("vectorize" , this->error_render_level_); |
513 | } |
514 | |
515 | void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { |
516 | if (thread_axis == "vthread" ) { |
517 | LOG(WARNING) << "`vthread` is legacy behavior and is going to be deprecated. Please use " |
518 | "`vthread.x`, `vthread.y` and `vthread.z` instead" ; |
519 | } |
520 | TVM_TIR_SCHEDULE_BEGIN(); |
521 | tir::Bind(state_, this->GetSRef(loop_rv), |
522 | IterVar(/*dom=*/Range(nullptr), /*var=*/Var(thread_axis), /*iter_type=*/kThreadIndex, |
523 | /*thread_tag=*/thread_axis)); |
524 | this->state_->DebugVerify(); |
525 | TVM_TIR_SCHEDULE_END("bind" , this->error_render_level_); |
526 | } |
527 | |
528 | void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) { |
529 | TVM_TIR_SCHEDULE_BEGIN(); |
530 | tir::Unroll(state_, this->GetSRef(loop_rv)); |
531 | this->state_->DebugVerify(); |
532 | TVM_TIR_SCHEDULE_END("unroll" , this->error_render_level_); |
533 | } |
534 | |
535 | /******** Schedule: Insert cache stages ********/ |
536 | |
537 | BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, |
538 | const String& storage_scope, |
539 | const Array<BlockRV> consumer_blocks) { |
540 | StmtSRef result{nullptr}; |
541 | // Create a new array of SRefs from the consumer block list. |
542 | Array<StmtSRef> consumer_block_refs = {}; |
543 | for (BlockRV block : consumer_blocks) { |
544 | consumer_block_refs.push_back(this->GetSRef(block)); |
545 | } |
546 | TVM_TIR_SCHEDULE_BEGIN(); |
547 | result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope, |
548 | consumer_block_refs); |
549 | TVM_TIR_SCHEDULE_END("cache-read" , this->error_render_level_); |
550 | this->state_->DebugVerify(); |
551 | return CreateRV<BlockRV>(result); |
552 | } |
553 | |
554 | BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, |
555 | const String& storage_scope, |
556 | const Array<BlockRV> consumer_blocks) { |
557 | StmtSRef result{nullptr}; |
558 | // Create a new array of SRefs from the consumer block list. |
559 | Array<StmtSRef> consumer_block_refs = {}; |
560 | for (BlockRV block : consumer_blocks) { |
561 | consumer_block_refs.push_back(this->GetSRef(block)); |
562 | } |
563 | TVM_TIR_SCHEDULE_BEGIN(); |
564 | result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope, |
565 | consumer_block_refs); |
566 | TVM_TIR_SCHEDULE_END("cache-write" , this->error_render_level_); |
567 | this->state_->DebugVerify(); |
568 | return CreateRV<BlockRV>(result); |
569 | } |
570 | |
571 | Array<BlockRV> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index, |
572 | const String& storage_scope) { |
573 | Array<StmtSRef> results; |
574 | TVM_TIR_SCHEDULE_BEGIN(); |
575 | results = tir::CacheInplace(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); |
576 | TVM_TIR_SCHEDULE_END("cache-buffer" , this->error_render_level_); |
577 | this->state_->DebugVerify(); |
578 | Array<BlockRV> return_blocks; |
579 | return_blocks.push_back(CreateRV<BlockRV>(results[0])); |
580 | return_blocks.push_back(CreateRV<BlockRV>(results[1])); |
581 | return return_blocks; |
582 | } |
583 | |
584 | Array<BlockRV> ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, |
585 | const String& storage_scope, int cse_thresh) { |
586 | Array<StmtSRef> result; |
587 | TVM_TIR_SCHEDULE_BEGIN(); |
588 | result = tir::CacheIndex(state_, this->GetSRef(block_rv), storage_scope, cse_thresh); |
589 | TVM_TIR_SCHEDULE_END("cache-index" , this->error_render_level_); |
590 | this->state_->DebugVerify(); |
591 | Array<BlockRV> return_blocks; |
592 | for (const StmtSRef& blockrv : result) { |
593 | return_blocks.push_back(CreateRV<BlockRV>(blockrv)); |
594 | } |
595 | return return_blocks; |
596 | } |
597 | |
598 | BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, |
599 | BufferIndexType buffer_index_type) { |
600 | StmtSRef result{nullptr}; |
601 | TVM_TIR_SCHEDULE_BEGIN(); |
602 | result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type); |
603 | TVM_TIR_SCHEDULE_END("reindex" , this->error_render_level_); |
604 | this->state_->DebugVerify(); |
605 | return CreateRV<BlockRV>(result); |
606 | } |
607 | |
608 | /******** Schedule: Compute location ********/ |
609 | |
610 | void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, |
611 | bool preserve_unit_loops, int index) { |
612 | static StmtSRef inline_mark = StmtSRef::InlineMark(); |
613 | static StmtSRef root_mark = StmtSRef::RootMark(); |
614 | StmtSRef loop_sref = this->GetSRef(loop_rv); |
615 | if (loop_sref.same_as(root_mark)) { |
616 | // do nothing |
617 | } else if (loop_sref.same_as(inline_mark)) { |
618 | TVM_TIR_SCHEDULE_BEGIN(); |
619 | tir::ComputeInline(state_, this->GetSRef(block_rv)); |
620 | TVM_TIR_SCHEDULE_END("compute-at" , this->error_render_level_); |
621 | } else { |
622 | TVM_TIR_SCHEDULE_BEGIN(); |
623 | tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index); |
624 | TVM_TIR_SCHEDULE_END("compute-at" , this->error_render_level_); |
625 | } |
626 | this->state_->DebugVerify(); |
627 | } |
628 | |
629 | void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, |
630 | bool preserve_unit_loops, int index) { |
631 | static StmtSRef inline_mark = StmtSRef::InlineMark(); |
632 | static StmtSRef root_mark = StmtSRef::RootMark(); |
633 | StmtSRef loop_sref = this->GetSRef(loop_rv); |
634 | if (loop_sref.same_as(root_mark)) { |
635 | // do nothing |
636 | } else if (loop_sref.same_as(inline_mark)) { |
637 | TVM_TIR_SCHEDULE_BEGIN(); |
638 | tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); |
639 | TVM_TIR_SCHEDULE_END("reverse-compute-at" , this->error_render_level_); |
640 | } else { |
641 | TVM_TIR_SCHEDULE_BEGIN(); |
642 | tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index); |
643 | TVM_TIR_SCHEDULE_END("reverse-compute-at" , this->error_render_level_); |
644 | } |
645 | this->state_->DebugVerify(); |
646 | } |
647 | |
648 | void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { |
649 | TVM_TIR_SCHEDULE_BEGIN(); |
650 | tir::ComputeInline(state_, this->GetSRef(block_rv)); |
651 | TVM_TIR_SCHEDULE_END("compute-inline" , this->error_render_level_); |
652 | this->state_->DebugVerify(); |
653 | } |
654 | |
655 | void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { |
656 | TVM_TIR_SCHEDULE_BEGIN(); |
657 | tir::ReverseComputeInline(state_, this->GetSRef(block_rv)); |
658 | TVM_TIR_SCHEDULE_END("reverse-compute-inline" , this->error_render_level_); |
659 | this->state_->DebugVerify(); |
660 | } |
661 | |
662 | /******** Schedule: Block Annotation ********/ |
663 | |
664 | void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, |
665 | int factor, int offset) { |
666 | TVM_TIR_SCHEDULE_BEGIN(); |
667 | tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset); |
668 | TVM_TIR_SCHEDULE_END("storage-align" , this->error_render_level_); |
669 | this->state_->DebugVerify(); |
670 | } |
671 | |
672 | void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, |
673 | const String& storage_scope) { |
674 | TVM_TIR_SCHEDULE_BEGIN(); |
675 | tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope); |
676 | TVM_TIR_SCHEDULE_END("set-scope" , this->error_render_level_); |
677 | this->state_->DebugVerify(); |
678 | } |
679 | |
680 | /******** Schedule: Reduction ********/ |
681 | |
682 | BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) { |
683 | StmtSRef result{nullptr}; |
684 | TVM_TIR_SCHEDULE_BEGIN(); |
685 | result = tir::DecomposeReduction(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv)); |
686 | TVM_TIR_SCHEDULE_END("decompose-reduction" , this->error_render_level_); |
687 | this->state_->DebugVerify(); |
688 | return CreateRV<BlockRV>(result); |
689 | } |
690 | |
691 | BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { |
692 | StmtSRef result{nullptr}; |
693 | TVM_TIR_SCHEDULE_BEGIN(); |
694 | result = tir::RFactor(state_, this->GetSRef(loop_rv), factor_axis); |
695 | TVM_TIR_SCHEDULE_END("rfactor" , this->error_render_level_); |
696 | this->state_->DebugVerify(); |
697 | return CreateRV<BlockRV>(result); |
698 | } |
699 | |
700 | /******** Schedule: Blockize & Tensorize ********/ |
701 | BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { |
702 | StmtSRef result{nullptr}; |
703 | TVM_TIR_SCHEDULE_BEGIN(); |
704 | result = tir::Blockize(state_, this->GetSRef(loop_rv), preserve_unit_iters); |
705 | this->state_->DebugVerify(); |
706 | TVM_TIR_SCHEDULE_END("blockize" , this->error_render_level_); |
707 | return CreateRV<BlockRV>(result); |
708 | } |
709 | |
710 | void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin, |
711 | bool preserve_unit_iters) { |
712 | TVM_TIR_SCHEDULE_BEGIN(); |
713 | tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(), |
714 | preserve_unit_iters); |
715 | this->state_->DebugVerify(); |
716 | TVM_TIR_SCHEDULE_END("tensorize" , this->error_render_level_); |
717 | } |
718 | |
719 | void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin, |
720 | bool preserve_unit_iters) { |
721 | TVM_TIR_SCHEDULE_BEGIN(); |
722 | tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(), |
723 | preserve_unit_iters); |
724 | this->state_->DebugVerify(); |
725 | TVM_TIR_SCHEDULE_END("tensorize" , this->error_render_level_); |
726 | } |
727 | |
728 | /******** Schedule: Annotation ********/ |
729 | |
730 | ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) { |
731 | if (ann_val.as<StringObj>()) { |
732 | return ann_val; |
733 | } |
734 | if (const auto* expr = ann_val.as<PrimExprNode>()) { |
735 | ICHECK(!ann_val->IsInstance<StringImmNode>()) |
736 | << "TypeError: runtime::String is expected, but gets StringImm" ; |
737 | return this->Get(GetRef<PrimExpr>(expr)); |
738 | } |
739 | if (const auto* arr = ann_val.as<ArrayNode>()) { |
740 | Array<ObjectRef> result; |
741 | result.reserve(arr->size()); |
742 | for (size_t i = 0; i < arr->size(); i++) { |
743 | result.push_back(CheckAndGetAnnotationValue(arr->at(i))); |
744 | } |
745 | return std::move(result); |
746 | } |
747 | if (const auto* dict = ann_val.as<MapNode>()) { |
748 | Map<String, ObjectRef> result; |
749 | for (auto it = dict->begin(); it != dict->end(); ++it) { |
750 | const auto& key = it->first; |
751 | auto value = CheckAndGetAnnotationValue(it->second); |
752 | if (const StringImmNode* imm = key.as<StringImmNode>()) { |
753 | result.Set(imm->value, value); |
754 | } else if (key->IsInstance<StringObj>()) { |
755 | result.Set(Downcast<String>(key), value); |
756 | } else { |
757 | LOG(FATAL) << "TypeError: annotation dict key expect to be String or StringImm" ; |
758 | } |
759 | } |
760 | return std::move(result); |
761 | } |
762 | LOG(FATAL) |
763 | << "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but " |
764 | << "gets: " << ann_val->GetTypeKey(); |
765 | throw; |
766 | } |
767 | |
768 | void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, |
769 | const ObjectRef& ann_val) { |
770 | TVM_TIR_SCHEDULE_BEGIN(); |
771 | tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val)); |
772 | this->state_->DebugVerify(); |
773 | TVM_TIR_SCHEDULE_END("annotate" , this->error_render_level_); |
774 | } |
775 | |
776 | void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) { |
777 | TVM_TIR_SCHEDULE_BEGIN(); |
778 | tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key); |
779 | this->state_->DebugVerify(); |
780 | TVM_TIR_SCHEDULE_END("unannotate" , this->error_render_level_); |
781 | } |
782 | |
783 | void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, |
784 | const ObjectRef& ann_val) { |
785 | TVM_TIR_SCHEDULE_BEGIN(); |
786 | tir::Annotate(state_, this->GetSRef(block_rv), ann_key, |
787 | this->CheckAndGetAnnotationValue(ann_val)); |
788 | this->state_->DebugVerify(); |
789 | TVM_TIR_SCHEDULE_END("annotate" , this->error_render_level_); |
790 | } |
791 | |
792 | void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) { |
793 | TVM_TIR_SCHEDULE_BEGIN(); |
794 | tir::Unannotate(state_, this->GetSRef(block_rv), ann_key); |
795 | this->state_->DebugVerify(); |
796 | TVM_TIR_SCHEDULE_END("unannotate" , this->error_render_level_); |
797 | } |
798 | |
799 | /******** Schedule: Layout transformation ********/ |
800 | void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, |
801 | BufferIndexType buffer_index_type, |
802 | const IndexMap& index_map, |
803 | const Optional<IndexMap>& pad_value) { |
804 | TVM_TIR_SCHEDULE_BEGIN(); |
805 | auto f_subst = [&](const Var& var) -> Optional<PrimExpr> { |
806 | return Downcast<Optional<PrimExpr>>(symbol_table_.Get(var)); |
807 | }; |
808 | auto new_index_map = Substitute(index_map, f_subst); |
809 | tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, |
810 | new_index_map, pad_value); |
811 | this->state_->DebugVerify(); |
812 | TVM_TIR_SCHEDULE_END("transform_layout" , this->error_render_level_); |
813 | } |
814 | |
815 | void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv, |
816 | const IndexMap& index_map) { |
817 | TVM_TIR_SCHEDULE_BEGIN(); |
818 | tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map); |
819 | this->state_->DebugVerify(); |
820 | TVM_TIR_SCHEDULE_END("transform_block_layout" , this->error_render_level_); |
821 | } |
822 | |
823 | void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index, |
824 | BufferIndexType buffer_index_type, |
825 | const Array<IntImm>& axis_separators) { |
826 | TVM_TIR_SCHEDULE_BEGIN(); |
827 | tir::SetAxisSeparator(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, |
828 | axis_separators); |
829 | TVM_TIR_SCHEDULE_END("set-axis-separator" , this->error_render_level_); |
830 | this->state_->DebugVerify(); |
831 | } |
832 | |
833 | /******** Schedule: Padding ********/ |
834 | |
835 | BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) { |
836 | StmtSRef result{nullptr}; |
837 | TVM_TIR_SCHEDULE_BEGIN(); |
838 | result = tir::DecomposePadding(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv)); |
839 | TVM_TIR_SCHEDULE_END("decompose-padding" , this->error_render_level_); |
840 | this->state_->DebugVerify(); |
841 | return CreateRV<BlockRV>(result); |
842 | } |
843 | |
844 | void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) { |
845 | TVM_TIR_SCHEDULE_BEGIN(); |
846 | tir::PadEinsum(state_, this->GetSRef(block_rv), padding); |
847 | TVM_TIR_SCHEDULE_END("pad-einsum" , this->error_render_level_); |
848 | this->state_->DebugVerify(); |
849 | } |
850 | |
851 | /******** Schedule: Buffer Transformation ********/ |
852 | |
853 | void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) { |
854 | TVM_TIR_SCHEDULE_BEGIN(); |
855 | tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index); |
856 | TVM_TIR_SCHEDULE_END("rolling-buffer" , this->error_render_level_); |
857 | this->state_->DebugVerify(); |
858 | } |
859 | |
860 | /******** Schedule: Misc ********/ |
861 | |
862 | } // namespace tir |
863 | } // namespace tvm |
864 | |