1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "./concrete_schedule.h"
20
21#include <random>
22
23namespace tvm {
24namespace tir {
25
26Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
27 int debug_mask, ScheduleErrorRenderLevel error_render_level) {
28 ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
29 n->state_ = ScheduleState(mod, debug_mask);
30 n->error_render_level_ = error_render_level;
31 n->symbol_table_ = {};
32 n->analyzer_ = std::make_unique<arith::Analyzer>();
33 n->Seed(seed);
34 GlobalVar gv = NullValue<GlobalVar>();
35 if (FindEntryFunc(mod, &gv) != nullptr) {
36 n->func_working_on_ = gv;
37 } else {
38 n->func_working_on_ = NullOpt;
39 }
40 return Schedule(std::move(n));
41}
42
43/******** Copy ********/
44
45/*! \brief Helper class to perform a deep copy of the sref tree */
46class ScheduleCopier {
47 using TSymbolTable = ConcreteScheduleNode::TSymbolTable;
48 template <class K, class V>
49 using UMap = std::unordered_map<K, V>;
50 template <class K, class V>
51 using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>;
52
53 public:
54 static void Copy(const ConcreteScheduleNode* self, ScheduleState* new_state,
55 TSymbolTable* new_symbol_table) {
56 const ScheduleState& src_state = self->state_;
57 ScheduleCopier copier(src_state);
58 ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
59 n->mod = src_state->mod;
60 n->block_info = copier.Copy(src_state->block_info);
61 n->stmt2ref = copier.Copy(src_state->stmt2ref);
62 n->debug_mask = src_state->debug_mask;
63 *new_state = ScheduleState(std::move(n));
64 *new_symbol_table = copier.Copy(self->symbol_table_);
65 }
66
67 private:
68 /*! \brief Create the copier and properly set up the `old2new_` table */
69 explicit ScheduleCopier(const ScheduleState& state) {
70 // Create SRef tree without parents
71 for (const auto& kv : state->stmt2ref) {
72 const StmtSRefNode* sref = kv.second.operator->();
73 old2new_.emplace(sref, // the old StmtSRef
74 StmtSRef(/*stmt=*/sref->stmt, // the new StmtSRef
75 /*parent=*/nullptr, // parent is not set yet
76 /*seq_index=*/sref->seq_index));
77 }
78 // Fill in the parent field
79 // Find out the root along the way
80 for (auto& kv : old2new_) {
81 const StmtSRefNode* parent = kv.first->parent;
82 StmtSRef& sref = kv.second;
83 sref->parent = parent ? old2new_.at(parent).get() : nullptr;
84 }
85 }
86
87 /*! \brief Copy StmtSRef */
88 StmtSRef Copy(const StmtSRef& sref) { return old2new_.at(sref.operator->()); }
89
90 /*! \brief Copy StmtSRefNode */
91 StmtSRef Copy(const StmtSRefNode* sref) {
92 if (old2new_.count(sref)) {
93 return old2new_.at(sref);
94 }
95 // Handle expired sref
96 return old2new_[sref] = StmtSRef(nullptr, nullptr, -1);
97 }
98
99 /*! \brief Copy Array<StmtSRef> */
100 Array<StmtSRef> Copy(const Array<StmtSRef>& list) {
101 Array<StmtSRef> result;
102 result.reserve(list.size());
103 for (const StmtSRef& elem : list) {
104 result.push_back(Copy(elem));
105 }
106 return result;
107 }
108
109 /*! \brief Copy Array<Dependency> */
110 Array<Dependency> Copy(const Array<Dependency>& list) {
111 Array<Dependency> result;
112 result.reserve(list.size());
113 for (const Dependency& elem : list) {
114 result.push_back(Dependency(Copy(elem->src), Copy(elem->dst), elem->kind));
115 }
116 return result;
117 }
118
119 /*! \brief Copy SMap<StmtSRef, Array<Dependency>> */
120 SMap<StmtSRef, Array<Dependency>> Copy(const SMap<StmtSRef, Array<Dependency>>& map) {
121 SMap<StmtSRef, Array<Dependency>> result;
122 result.reserve(map.size());
123 for (const auto& kv : map) {
124 result[Copy(kv.first)] = Copy(kv.second);
125 }
126 return result;
127 }
128
129 /*! \brief Copy SMap<Buffer, Array<StmtSRef>> */
130 SMap<Buffer, Array<StmtSRef>> Copy(const SMap<Buffer, Array<StmtSRef>>& map) {
131 SMap<Buffer, Array<StmtSRef>> result;
132 result.reserve(map.size());
133 for (const auto& kv : map) {
134 result[kv.first] = Copy(kv.second);
135 }
136 return result;
137 }
138
139 /*! \brief Copy SMap<StmtSRef, Scope> */
140 SMap<StmtSRef, BlockInfo> Copy(const SMap<StmtSRef, BlockInfo>& scopes) {
141 SMap<StmtSRef, BlockInfo> result;
142 for (const auto& kv : scopes) {
143 const StmtSRef& old_sref = kv.first;
144 const BlockInfo& old_info = kv.second;
145 BlockInfo new_info = old_info;
146 ObjectPtr<BlockScopeNode> scope = make_object<BlockScopeNode>();
147 scope->src2deps = Copy(old_info.scope->src2deps);
148 scope->dst2deps = Copy(old_info.scope->dst2deps);
149 scope->buffer_writers = Copy(old_info.scope->buffer_writers);
150 scope->stage_pipeline = old_info.scope->stage_pipeline;
151 new_info.scope = BlockScope(std::move(scope));
152 result[Copy(old_sref)] = std::move(new_info);
153 }
154 return result;
155 }
156
157 /*! \brief Copy the stmt2ref */
158 UMap<const StmtNode*, StmtSRef> Copy(const UMap<const StmtNode*, StmtSRef>& stmt2ref) {
159 UMap<const StmtNode*, StmtSRef> result;
160 result.reserve(stmt2ref.size());
161 for (const auto& kv : stmt2ref) {
162 const StmtNode* stmt = kv.first;
163 const StmtSRef& sref = kv.second;
164 result.emplace(stmt, Copy(sref));
165 }
166 return result;
167 }
168
169 /*! \brief Copy the symbol table */
170 TSymbolTable Copy(const TSymbolTable& tab) {
171 TSymbolTable result;
172 for (const auto& kv : tab) {
173 ObjectRef entry = kv.second;
174 if (const auto* sref = entry.as<StmtSRefNode>()) {
175 entry = Copy(sref);
176 }
177 result.Set(kv.first, entry);
178 }
179 return result;
180 }
181
182 private:
183 std::unordered_map<const StmtSRefNode*, StmtSRef> old2new_;
184};
185
186void ConcreteScheduleNode::WorkOn(const String& func_name) {
187 this->func_working_on_ = this->state_->mod->GetGlobalVar(func_name);
188}
189
190void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const {
191 ScheduleCopier::Copy(this, new_state, new_symbol_table);
192 new_state->get()->DebugVerify();
193}
194
195Schedule ConcreteScheduleNode::Copy() {
196 ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
197 n->func_working_on_ = this->func_working_on_;
198 n->error_render_level_ = this->error_render_level_;
199 ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
200 n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful
201 n->rand_state_ = ForkSeed();
202 return Schedule(std::move(n));
203}
204
/*! \brief Macro that guards the beginning of each invocation of TensorIR schedule primitive */
#define TVM_TIR_SCHEDULE_BEGIN() try {
/*!
 * \brief Macro that pairs with `TVM_TIR_SCHEDULE_BEGIN`, handling potential errors and error
 * message rendering
 * \param primitive The name of the schedule primitive, used when rendering a detailed report
 * \param level An ScheduleErrorRenderLevel enum, level of error rendering
 * \note A caught ScheduleError is always propagated. For a known `level` it is converted into a
 * tvm::runtime::Error; for any other value the original error is re-thrown unchanged, so
 * that it is never silently swallowed.
 * \sa ScheduleErrorRenderLevel
 */
#define TVM_TIR_SCHEDULE_END(primitive, level)                                                \
  }                                                                                           \
  catch (const ScheduleError& error) {                                                        \
    if ((level) == ScheduleErrorRenderLevel::kDetail) {                                       \
      throw tvm::runtime::Error(error.RenderReport(primitive) + "\n" + runtime::Backtrace()); \
    } else if ((level) == ScheduleErrorRenderLevel::kFast) {                                  \
      throw tvm::runtime::Error(error.FastErrorString());                                     \
    } else if ((level) == ScheduleErrorRenderLevel::kNone) {                                  \
      throw tvm::runtime::Error("ScheduleError: (not rendered)");                             \
    }                                                                                         \
    /* Unrecognized render level: re-throw instead of swallowing the error. */                \
    throw;                                                                                    \
  }
224
/******** Schedule: Sampling ********/
226
227void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) {
228 this->rand_state_ = support::LinearCongruentialEngine::NormalizeSeed(seed);
229}
230
231support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() {
232 return support::LinearCongruentialEngine(&rand_state_).ForkSeed();
233}
234
235ExprRV ConcreteScheduleNode::SampleCategorical(const Array<Integer>& candidates,
236 const Array<FloatImm>& probs,
237 Optional<Integer> decision) {
238 TVM_TIR_SCHEDULE_BEGIN();
239 return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision));
240 TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_);
241 throw;
242}
243
244Array<ExprRV> ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n,
245 int max_innermost_factor,
246 Optional<Array<Integer>> decision) {
247 TVM_TIR_SCHEDULE_BEGIN();
248 return CreateRV(tir::SamplePerfectTile(&this->rand_state_, this->GetSRef(loop_rv), n,
249 max_innermost_factor, &decision));
250 TVM_TIR_SCHEDULE_END("sample-perfect-tile", this->error_render_level_);
251 throw;
252}
253
254LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv,
255 Optional<Integer> decision) {
256 TVM_TIR_SCHEDULE_BEGIN();
257 return CreateRV<LoopRV>(
258 tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision));
259 TVM_TIR_SCHEDULE_END("sample-compute-location", this->error_render_level_);
260 throw;
261}
262
263/******** Schedule: Get blocks & loops ********/
264
265BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional<String>& func_name) {
266 class NotSingleResult : public ScheduleError {
267 public:
268 explicit NotSingleResult(String name, IRModule mod, const Array<StmtSRef>& blocks)
269 : name_(name), mod_(mod), blocks_{} {
270 blocks_.reserve(blocks.size());
271 for (const StmtSRef& block_sref : blocks) {
272 const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
273 blocks_.push_back(GetRef<Block>(block));
274 }
275 }
276
277 IRModule mod() const final { return mod_; }
278 Array<ObjectRef> LocationsOfInterest() const final { return {blocks_.begin(), blocks_.end()}; }
279
280 String DetailRenderTemplate() const final {
281 if (blocks_.empty()) {
282 return "Cannot find a block with the name: " + name_;
283 } else {
284 return "Found " + std::to_string(blocks_.size()) + " blocks with the name: " + name_;
285 }
286 }
287
288 String FastErrorString() const final {
289 if (blocks_.empty()) {
290 return "ScheduleError: Cannot find a block with the specified name";
291 } else {
292 return "ScheduleError: Found multiple blocks with the specified name";
293 }
294 }
295
296 String name_;
297 IRModule mod_;
298 Array<Block> blocks_;
299 };
300 GlobalVar gv = NullValue<GlobalVar>();
301 if (func_name.defined()) {
302 gv = state_->mod->GetGlobalVar(func_name.value());
303 } else if (func_working_on_.defined()) {
304 gv = this->func_working_on_.value();
305 } else {
306 LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please "
307 "specify the function name explicitly, or call `work_on` to specify the function "
308 "before using `get_block`.";
309 }
310 Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, gv);
311 if (blocks.size() != 1) {
312 TVM_TIR_SCHEDULE_BEGIN();
313 throw NotSingleResult(name, this->state_->mod, blocks);
314 TVM_TIR_SCHEDULE_END("get-block", this->error_render_level_);
315 }
316 return CreateRV<BlockRV>(blocks[0]);
317}
318
319Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
320 return CreateRV<LoopRV>(tir::GetLoops(this->GetSRef(block_rv)));
321}
322
323Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
324 Array<BlockRV> result;
325 TVM_TIR_SCHEDULE_BEGIN();
326 result = CreateRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(block_rv)));
327 TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
328 this->state_->DebugVerify();
329 return result;
330}
331
332Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
333 Array<BlockRV> result;
334 TVM_TIR_SCHEDULE_BEGIN();
335 result = CreateRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(loop_rv)));
336 TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
337 this->state_->DebugVerify();
338 return result;
339}
340
341Array<BlockRV> ConcreteScheduleNode::GetProducers(const BlockRV& block_rv) {
342 TVM_TIR_SCHEDULE_BEGIN();
343 return CreateRV<BlockRV>(tir::GetProducers(state_, this->GetSRef(block_rv)));
344 TVM_TIR_SCHEDULE_END("get-producers", this->error_render_level_);
345 throw;
346}
347
348Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
349 TVM_TIR_SCHEDULE_BEGIN();
350 return CreateRV<BlockRV>(tir::GetConsumers(state_, this->GetSRef(block_rv)));
351 TVM_TIR_SCHEDULE_END("get-consumers", this->error_render_level_);
352 throw;
353}
354
355/******** Schedule: Transform loops ********/
356
357LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
358 CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
359 Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
360 StmtSRef result{nullptr};
361 TVM_TIR_SCHEDULE_BEGIN();
362 result = tir::Fuse(state_, loop_srefs, preserve_unit_iters);
363 TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
364 this->state_->DebugVerify();
365 return CreateRV<LoopRV>(result);
366}
367
368Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
369 const Array<Optional<ExprRV>>& factor_rvs,
370 bool preserve_unit_iters) {
371 class NotSingleInferFactorError : public ScheduleError {
372 public:
373 explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
374
375 String FastErrorString() const final {
376 return "ScheduleError: only one factor can be specified as -1 or none";
377 }
378
379 String DetailRenderTemplate() const final {
380 return "Only one factor can be specified as -1 or none";
381 }
382
383 IRModule mod() const final { return mod_; }
384 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
385
386 IRModule mod_;
387 };
388
389 class WrongFactorProductError : public ScheduleError {
390 public:
391 explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}
392
393 String FastErrorString() const final {
394 return "ScheduleError: The product of factors is not larger than or equal to the extent of "
395 "loop";
396 }
397
398 String DetailRenderTemplate() const final {
399 return "The product of factors is not larger than or equal to the extent of loop {0}";
400 }
401
402 IRModule mod() const final { return mod_; }
403 Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
404
405 IRModule mod_;
406 For loop_;
407 };
408
409 class NonPositiveFactorError : public ScheduleError {
410 public:
411 explicit NonPositiveFactorError(IRModule mod, int64_t factor, size_t idx)
412 : mod_(std::move(mod)), factor_(factor), idx_(idx) {}
413
414 String FastErrorString() const final {
415 return "ScheduleError: All the constant factors are required to be positive. However, some "
416 "constant input factor is zero or negative.";
417 }
418 String DetailRenderTemplate() const final {
419 std::ostringstream os;
420 os << "All the constant factors are required to be positive. However, the factor at position "
421 << idx_ << " is " << factor_;
422 return os.str();
423 }
424 IRModule mod() const final { return mod_; }
425 Array<ObjectRef> LocationsOfInterest() const final { return {}; }
426
427 private:
428 IRModule mod_;
429 int64_t factor_;
430 size_t idx_;
431 };
432
433 // Prepare for the splitting
434 StmtSRef loop_sref = this->GetSRef(loop_rv);
435 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
436 Array<PrimExpr> factors;
437 factors.reserve(factor_rvs.size());
438 int infer_index = -1;
439 PrimExpr tot_length = 1;
440 Array<StmtSRef> results;
441 TVM_TIR_SCHEDULE_BEGIN();
442 // infer factor if needed and check validity of factors
443 for (size_t i = 0; i < factor_rvs.size(); i++) {
444 if (!factor_rvs[i].defined()) {
445 factors.push_back(Integer(-1));
446 if (infer_index != -1) {
447 throw NotSingleInferFactorError(state_->mod);
448 }
449 infer_index = i;
450 } else {
451 PrimExpr factor = this->Get(factor_rvs[i].value());
452 if (is_const_int(factor) && !is_positive_const(factor)) {
453 throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i);
454 }
455 if (factor.dtype().bits() > loop->extent.dtype().bits()) {
456 factor = cast(loop->extent.dtype(), factor);
457 }
458 factors.push_back(factor);
459 tot_length *= factor;
460 }
461 }
462 if (infer_index != -1) {
463 factors.Set(infer_index,
464 this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length)));
465 } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) {
466 throw WrongFactorProductError(state_->mod, GetRef<For>(loop));
467 }
468 results = tir::Split(state_, loop_sref, factors, preserve_unit_iters);
469 TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
470 this->state_->DebugVerify();
471 return CreateRV<LoopRV>(results);
472}
473
474void ConcreteScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
475 TVM_TIR_SCHEDULE_BEGIN();
476 tir::Reorder(state_, GetSRefs(ordered_loop_rvs));
477 TVM_TIR_SCHEDULE_END("reorder", this->error_render_level_);
478 this->state_->DebugVerify();
479}
480
481LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) {
482 LoopRV result{nullptr};
483 TVM_TIR_SCHEDULE_BEGIN();
484 result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(block_rv)));
485 TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
486 this->state_->DebugVerify();
487 return result;
488}
489
490LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {
491 LoopRV result{nullptr};
492 TVM_TIR_SCHEDULE_BEGIN();
493 result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(loop_rv)));
494 TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
495 this->state_->DebugVerify();
496 return result;
497}
498
499/******** Schedule: Manipulate ForKind ********/
500
501void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) {
502 TVM_TIR_SCHEDULE_BEGIN();
503 tir::Parallel(state_, this->GetSRef(loop_rv));
504 this->state_->DebugVerify();
505 TVM_TIR_SCHEDULE_END("parallel", this->error_render_level_);
506}
507
508void ConcreteScheduleNode::Vectorize(const LoopRV& loop_rv) {
509 TVM_TIR_SCHEDULE_BEGIN();
510 tir::Vectorize(state_, this->GetSRef(loop_rv));
511 this->state_->DebugVerify();
512 TVM_TIR_SCHEDULE_END("vectorize", this->error_render_level_);
513}
514
515void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) {
516 if (thread_axis == "vthread") {
517 LOG(WARNING) << "`vthread` is legacy behavior and is going to be deprecated. Please use "
518 "`vthread.x`, `vthread.y` and `vthread.z` instead";
519 }
520 TVM_TIR_SCHEDULE_BEGIN();
521 tir::Bind(state_, this->GetSRef(loop_rv),
522 IterVar(/*dom=*/Range(nullptr), /*var=*/Var(thread_axis), /*iter_type=*/kThreadIndex,
523 /*thread_tag=*/thread_axis));
524 this->state_->DebugVerify();
525 TVM_TIR_SCHEDULE_END("bind", this->error_render_level_);
526}
527
528void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
529 TVM_TIR_SCHEDULE_BEGIN();
530 tir::Unroll(state_, this->GetSRef(loop_rv));
531 this->state_->DebugVerify();
532 TVM_TIR_SCHEDULE_END("unroll", this->error_render_level_);
533}
534
535/******** Schedule: Insert cache stages ********/
536
537BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
538 const String& storage_scope,
539 const Array<BlockRV> consumer_blocks) {
540 StmtSRef result{nullptr};
541 // Create a new array of SRefs from the consumer block list.
542 Array<StmtSRef> consumer_block_refs = {};
543 for (BlockRV block : consumer_blocks) {
544 consumer_block_refs.push_back(this->GetSRef(block));
545 }
546 TVM_TIR_SCHEDULE_BEGIN();
547 result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope,
548 consumer_block_refs);
549 TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_);
550 this->state_->DebugVerify();
551 return CreateRV<BlockRV>(result);
552}
553
554BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
555 const String& storage_scope,
556 const Array<BlockRV> consumer_blocks) {
557 StmtSRef result{nullptr};
558 // Create a new array of SRefs from the consumer block list.
559 Array<StmtSRef> consumer_block_refs = {};
560 for (BlockRV block : consumer_blocks) {
561 consumer_block_refs.push_back(this->GetSRef(block));
562 }
563 TVM_TIR_SCHEDULE_BEGIN();
564 result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope,
565 consumer_block_refs);
566 TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_);
567 this->state_->DebugVerify();
568 return CreateRV<BlockRV>(result);
569}
570
571Array<BlockRV> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index,
572 const String& storage_scope) {
573 Array<StmtSRef> results;
574 TVM_TIR_SCHEDULE_BEGIN();
575 results = tir::CacheInplace(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
576 TVM_TIR_SCHEDULE_END("cache-buffer", this->error_render_level_);
577 this->state_->DebugVerify();
578 Array<BlockRV> return_blocks;
579 return_blocks.push_back(CreateRV<BlockRV>(results[0]));
580 return_blocks.push_back(CreateRV<BlockRV>(results[1]));
581 return return_blocks;
582}
583
584Array<BlockRV> ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv,
585 const String& storage_scope, int cse_thresh) {
586 Array<StmtSRef> result;
587 TVM_TIR_SCHEDULE_BEGIN();
588 result = tir::CacheIndex(state_, this->GetSRef(block_rv), storage_scope, cse_thresh);
589 TVM_TIR_SCHEDULE_END("cache-index", this->error_render_level_);
590 this->state_->DebugVerify();
591 Array<BlockRV> return_blocks;
592 for (const StmtSRef& blockrv : result) {
593 return_blocks.push_back(CreateRV<BlockRV>(blockrv));
594 }
595 return return_blocks;
596}
597
598BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
599 BufferIndexType buffer_index_type) {
600 StmtSRef result{nullptr};
601 TVM_TIR_SCHEDULE_BEGIN();
602 result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type);
603 TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_);
604 this->state_->DebugVerify();
605 return CreateRV<BlockRV>(result);
606}
607
608/******** Schedule: Compute location ********/
609
610void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
611 bool preserve_unit_loops, int index) {
612 static StmtSRef inline_mark = StmtSRef::InlineMark();
613 static StmtSRef root_mark = StmtSRef::RootMark();
614 StmtSRef loop_sref = this->GetSRef(loop_rv);
615 if (loop_sref.same_as(root_mark)) {
616 // do nothing
617 } else if (loop_sref.same_as(inline_mark)) {
618 TVM_TIR_SCHEDULE_BEGIN();
619 tir::ComputeInline(state_, this->GetSRef(block_rv));
620 TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
621 } else {
622 TVM_TIR_SCHEDULE_BEGIN();
623 tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index);
624 TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
625 }
626 this->state_->DebugVerify();
627}
628
629void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
630 bool preserve_unit_loops, int index) {
631 static StmtSRef inline_mark = StmtSRef::InlineMark();
632 static StmtSRef root_mark = StmtSRef::RootMark();
633 StmtSRef loop_sref = this->GetSRef(loop_rv);
634 if (loop_sref.same_as(root_mark)) {
635 // do nothing
636 } else if (loop_sref.same_as(inline_mark)) {
637 TVM_TIR_SCHEDULE_BEGIN();
638 tir::ReverseComputeInline(state_, this->GetSRef(block_rv));
639 TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
640 } else {
641 TVM_TIR_SCHEDULE_BEGIN();
642 tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index);
643 TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
644 }
645 this->state_->DebugVerify();
646}
647
648void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
649 TVM_TIR_SCHEDULE_BEGIN();
650 tir::ComputeInline(state_, this->GetSRef(block_rv));
651 TVM_TIR_SCHEDULE_END("compute-inline", this->error_render_level_);
652 this->state_->DebugVerify();
653}
654
655void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
656 TVM_TIR_SCHEDULE_BEGIN();
657 tir::ReverseComputeInline(state_, this->GetSRef(block_rv));
658 TVM_TIR_SCHEDULE_END("reverse-compute-inline", this->error_render_level_);
659 this->state_->DebugVerify();
660}
661
662/******** Schedule: Block Annotation ********/
663
664void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
665 int factor, int offset) {
666 TVM_TIR_SCHEDULE_BEGIN();
667 tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset);
668 TVM_TIR_SCHEDULE_END("storage-align", this->error_render_level_);
669 this->state_->DebugVerify();
670}
671
672void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
673 const String& storage_scope) {
674 TVM_TIR_SCHEDULE_BEGIN();
675 tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope);
676 TVM_TIR_SCHEDULE_END("set-scope", this->error_render_level_);
677 this->state_->DebugVerify();
678}
679
680/******** Schedule: Reduction ********/
681
682BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
683 StmtSRef result{nullptr};
684 TVM_TIR_SCHEDULE_BEGIN();
685 result = tir::DecomposeReduction(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv));
686 TVM_TIR_SCHEDULE_END("decompose-reduction", this->error_render_level_);
687 this->state_->DebugVerify();
688 return CreateRV<BlockRV>(result);
689}
690
691BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
692 StmtSRef result{nullptr};
693 TVM_TIR_SCHEDULE_BEGIN();
694 result = tir::RFactor(state_, this->GetSRef(loop_rv), factor_axis);
695 TVM_TIR_SCHEDULE_END("rfactor", this->error_render_level_);
696 this->state_->DebugVerify();
697 return CreateRV<BlockRV>(result);
698}
699
700/******** Schedule: Blockize & Tensorize ********/
701BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) {
702 StmtSRef result{nullptr};
703 TVM_TIR_SCHEDULE_BEGIN();
704 result = tir::Blockize(state_, this->GetSRef(loop_rv), preserve_unit_iters);
705 this->state_->DebugVerify();
706 TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
707 return CreateRV<BlockRV>(result);
708}
709
710void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin,
711 bool preserve_unit_iters) {
712 TVM_TIR_SCHEDULE_BEGIN();
713 tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(),
714 preserve_unit_iters);
715 this->state_->DebugVerify();
716 TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
717}
718
719void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin,
720 bool preserve_unit_iters) {
721 TVM_TIR_SCHEDULE_BEGIN();
722 tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(),
723 preserve_unit_iters);
724 this->state_->DebugVerify();
725 TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
726}
727
728/******** Schedule: Annotation ********/
729
730ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) {
731 if (ann_val.as<StringObj>()) {
732 return ann_val;
733 }
734 if (const auto* expr = ann_val.as<PrimExprNode>()) {
735 ICHECK(!ann_val->IsInstance<StringImmNode>())
736 << "TypeError: runtime::String is expected, but gets StringImm";
737 return this->Get(GetRef<PrimExpr>(expr));
738 }
739 if (const auto* arr = ann_val.as<ArrayNode>()) {
740 Array<ObjectRef> result;
741 result.reserve(arr->size());
742 for (size_t i = 0; i < arr->size(); i++) {
743 result.push_back(CheckAndGetAnnotationValue(arr->at(i)));
744 }
745 return std::move(result);
746 }
747 if (const auto* dict = ann_val.as<MapNode>()) {
748 Map<String, ObjectRef> result;
749 for (auto it = dict->begin(); it != dict->end(); ++it) {
750 const auto& key = it->first;
751 auto value = CheckAndGetAnnotationValue(it->second);
752 if (const StringImmNode* imm = key.as<StringImmNode>()) {
753 result.Set(imm->value, value);
754 } else if (key->IsInstance<StringObj>()) {
755 result.Set(Downcast<String>(key), value);
756 } else {
757 LOG(FATAL) << "TypeError: annotation dict key expect to be String or StringImm";
758 }
759 }
760 return std::move(result);
761 }
762 LOG(FATAL)
763 << "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but "
764 << "gets: " << ann_val->GetTypeKey();
765 throw;
766}
767
768void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key,
769 const ObjectRef& ann_val) {
770 TVM_TIR_SCHEDULE_BEGIN();
771 tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val));
772 this->state_->DebugVerify();
773 TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
774}
775
776void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) {
777 TVM_TIR_SCHEDULE_BEGIN();
778 tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key);
779 this->state_->DebugVerify();
780 TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
781}
782
783void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key,
784 const ObjectRef& ann_val) {
785 TVM_TIR_SCHEDULE_BEGIN();
786 tir::Annotate(state_, this->GetSRef(block_rv), ann_key,
787 this->CheckAndGetAnnotationValue(ann_val));
788 this->state_->DebugVerify();
789 TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
790}
791
792void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) {
793 TVM_TIR_SCHEDULE_BEGIN();
794 tir::Unannotate(state_, this->GetSRef(block_rv), ann_key);
795 this->state_->DebugVerify();
796 TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
797}
798
799/******** Schedule: Layout transformation ********/
800void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
801 BufferIndexType buffer_index_type,
802 const IndexMap& index_map,
803 const Optional<IndexMap>& pad_value) {
804 TVM_TIR_SCHEDULE_BEGIN();
805 auto f_subst = [&](const Var& var) -> Optional<PrimExpr> {
806 return Downcast<Optional<PrimExpr>>(symbol_table_.Get(var));
807 };
808 auto new_index_map = Substitute(index_map, f_subst);
809 tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
810 new_index_map, pad_value);
811 this->state_->DebugVerify();
812 TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
813}
814
815void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv,
816 const IndexMap& index_map) {
817 TVM_TIR_SCHEDULE_BEGIN();
818 tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map);
819 this->state_->DebugVerify();
820 TVM_TIR_SCHEDULE_END("transform_block_layout", this->error_render_level_);
821}
822
823void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
824 BufferIndexType buffer_index_type,
825 const Array<IntImm>& axis_separators) {
826 TVM_TIR_SCHEDULE_BEGIN();
827 tir::SetAxisSeparator(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
828 axis_separators);
829 TVM_TIR_SCHEDULE_END("set-axis-separator", this->error_render_level_);
830 this->state_->DebugVerify();
831}
832
833/******** Schedule: Padding ********/
834
835BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
836 StmtSRef result{nullptr};
837 TVM_TIR_SCHEDULE_BEGIN();
838 result = tir::DecomposePadding(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv));
839 TVM_TIR_SCHEDULE_END("decompose-padding", this->error_render_level_);
840 this->state_->DebugVerify();
841 return CreateRV<BlockRV>(result);
842}
843
844void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) {
845 TVM_TIR_SCHEDULE_BEGIN();
846 tir::PadEinsum(state_, this->GetSRef(block_rv), padding);
847 TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_);
848 this->state_->DebugVerify();
849}
850
851/******** Schedule: Buffer Transformation ********/
852
853void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) {
854 TVM_TIR_SCHEDULE_BEGIN();
855 tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index);
856 TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_);
857 this->state_->DebugVerify();
858}
859
860/******** Schedule: Misc ********/
861
862} // namespace tir
863} // namespace tvm
864