1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "./traced_schedule.h"
20
21namespace tvm {
22namespace tir {
23
24Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
25 int debug_mask, ScheduleErrorRenderLevel error_render_level) {
26 ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
27 n->state_ = ScheduleState(mod, debug_mask);
28 n->error_render_level_ = error_render_level;
29 n->symbol_table_ = {};
30 n->analyzer_ = std::make_unique<arith::Analyzer>();
31 n->trace_ = Trace();
32 n->Seed(seed);
33 GlobalVar gv = NullValue<GlobalVar>();
34 if (FindEntryFunc(mod, &gv) != nullptr) {
35 n->func_working_on_ = gv;
36 } else {
37 n->func_working_on_ = NullOpt;
38 }
39 return Schedule(std::move(n));
40}
41
42Schedule TracedScheduleNode::Copy() {
43 ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
44 n->error_render_level_ = this->error_render_level_;
45 ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
46 n->func_working_on_ = this->func_working_on_;
47 n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful
48 n->rand_state_ = ForkSeed();
49 n->trace_ = Trace(this->trace_->insts, this->trace_->decisions);
50 return Schedule(std::move(n));
51}
52
53/******** Schedule: Sampling ********/
54
55ExprRV TracedScheduleNode::SampleCategorical(const Array<Integer>& candidates,
56 const Array<FloatImm>& probs,
57 Optional<Integer> decision) {
58 ExprRV result =
59 CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision));
60 static const InstructionKind& kind = InstructionKind::Get("SampleCategorical");
61 trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
62 /*inputs=*/{},
63 /*attrs=*/{candidates, probs},
64 /*outputs=*/{result}),
65 /*decision=*/decision);
66 return result;
67}
68
69Array<ExprRV> TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n,
70 int max_innermost_factor,
71 Optional<Array<Integer>> decision) {
72 Array<ExprRV> results = CreateRV(tir::SamplePerfectTile(
73 &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision));
74
75 static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile");
76 trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
77 /*inputs=*/{loop_rv},
78 /*attrs=*/{Integer(n), Integer(max_innermost_factor)},
79 /*outputs=*/{results.begin(), results.end()}),
80 /*decision=*/decision);
81 return results;
82}
83
84LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv,
85 Optional<Integer> decision) {
86 LoopRV result = CreateRV<LoopRV>(tir::SampleComputeLocation(this->state_, &this->rand_state_,
87 this->GetSRef(block_rv), &decision));
88
89 static const InstructionKind& kind = InstructionKind::Get("SampleComputeLocation");
90 trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
91 /*inputs=*/{block_rv},
92 /*attrs=*/{},
93 /*outputs=*/{result}),
94 /*decision=*/decision);
95 return result;
96}
97
98/******** Schedule: Get blocks & loops ********/
99
100BlockRV TracedScheduleNode::GetBlock(const String& name, const Optional<String>& func_name) {
101 GlobalVar gv = NullValue<GlobalVar>();
102 if (func_name.defined()) {
103 gv = state_->mod->GetGlobalVar(func_name.value());
104 } else if (func_working_on_.defined()) {
105 gv = this->func_working_on_.value();
106 } else {
107 LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please "
108 "specify the function name explicitly, or call `work_on` to specify the function "
109 "before using `get_block`.";
110 }
111 BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name);
112
113 static const InstructionKind& kind = InstructionKind::Get("GetBlock");
114 trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
115 /*inputs=*/{},
116 /*attrs=*/{name, gv->name_hint},
117 /*outputs=*/{result}));
118 return result;
119}
120
121Array<LoopRV> TracedScheduleNode::GetLoops(const BlockRV& block_rv) {
122 Array<LoopRV> results = ConcreteScheduleNode::GetLoops(block_rv);
123
124 static const InstructionKind& kind = InstructionKind::Get("GetLoops");
125 trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
126 /*inputs=*/{block_rv},
127 /*attrs=*/{},
128 /*outputs=*/{results.begin(), results.end()}));
129 return results;
130}
131
132Array<BlockRV> TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
133 Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(block_rv);
134
135 static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks");
136 trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
137 /*inputs=*/{block_rv},
138 /*attrs=*/{},
139 /*outputs=*/{results.begin(), results.end()}));
140 return results;
141}
142
143Array<BlockRV> TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
144 Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(loop_rv);
145
146 static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks");
147 trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
148 /*inputs=*/{loop_rv},
149 /*attrs=*/{},
150 /*outputs=*/{results.begin(), results.end()}));
151 return results;
152}
153
154Array<BlockRV> TracedScheduleNode::GetProducers(const BlockRV& block_rv) {
155 Array<BlockRV> results = ConcreteScheduleNode::GetProducers(block_rv);
156
157 static const InstructionKind& kind = InstructionKind::Get("GetProducers");
158 trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
159 /*inputs=*/{block_rv},
160 /*attrs=*/{},
161 /*outputs=*/{results.begin(), results.end()}));
162 return results;
163}
164
165Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
166 Array<BlockRV> results = ConcreteScheduleNode::GetConsumers(block_rv);
167
168 static const InstructionKind& kind = InstructionKind::Get("GetConsumers");
169 trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
170 /*inputs=*/{block_rv},
171 /*attrs=*/{},
172 /*outputs=*/{results.begin(), results.end()}));
173 return results;
174}
175
176/******** Schedule: Transform loops ********/
177
178LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_loops) {
179 LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops);
180
181 static const InstructionKind& kind = InstructionKind::Get("Fuse");
182 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
183 /*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
184 /*attrs=*/{Integer(preserve_unit_loops)},
185 /*outputs=*/{result}));
186 return result;
187}
188
189Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
190 const Array<Optional<ExprRV>>& factor_rvs,
191 bool preserve_unit_iters) {
192 Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters);
193
194 std::vector<ObjectRef> inputs;
195 inputs.reserve(1 + factor_rvs.size());
196 inputs.push_back(loop_rv);
197 for (const ObjectRef& obj : factor_rvs) {
198 inputs.push_back(obj);
199 }
200
201 static const InstructionKind& kind = InstructionKind::Get("Split");
202 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
203 /*inputs=*/inputs,
204 /*attrs=*/{Integer(preserve_unit_iters)},
205 /*outputs=*/{results.begin(), results.end()}));
206 return results;
207}
208
209void TracedScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
210 ConcreteScheduleNode::Reorder(ordered_loop_rvs);
211
212 static const InstructionKind& kind = InstructionKind::Get("Reorder");
213 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
214 /*inputs=*/{ordered_loop_rvs.begin(), ordered_loop_rvs.end()},
215 /*attrs=*/{},
216 /*outputs=*/{}));
217}
218
219LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) {
220 LoopRV result = ConcreteScheduleNode::AddUnitLoop(block_rv);
221
222 static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop");
223 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
224 /*inputs=*/{block_rv},
225 /*attrs=*/{},
226 /*outputs=*/{result}));
227 return result;
228}
229
230LoopRV TracedScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {
231 LoopRV result = ConcreteScheduleNode::AddUnitLoop(loop_rv);
232
233 static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop");
234 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
235 /*inputs=*/{loop_rv},
236 /*attrs=*/{},
237 /*outputs=*/{result}));
238 return result;
239}
240
241/******** Schedule: Manipulate ForKind ********/
242
243void TracedScheduleNode::Parallel(const LoopRV& loop_rv) {
244 ConcreteScheduleNode::Parallel(loop_rv);
245
246 static const InstructionKind& kind = InstructionKind::Get("Parallel");
247 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
248 /*inputs=*/{loop_rv},
249 /*attrs=*/{},
250 /*outputs=*/{}));
251}
252
253void TracedScheduleNode::Vectorize(const LoopRV& loop_rv) {
254 ConcreteScheduleNode::Vectorize(loop_rv);
255
256 static const InstructionKind& kind = InstructionKind::Get("Vectorize");
257 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
258 /*inputs=*/{loop_rv},
259 /*attrs=*/{},
260 /*outputs=*/{}));
261}
262
263void TracedScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) {
264 ConcreteScheduleNode::Bind(loop_rv, thread_axis);
265
266 static const InstructionKind& kind = InstructionKind::Get("Bind");
267 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
268 /*inputs=*/{loop_rv},
269 /*attrs=*/{thread_axis},
270 /*outputs=*/{}));
271}
272
273void TracedScheduleNode::Unroll(const LoopRV& loop_rv) {
274 ConcreteScheduleNode::Unroll(loop_rv);
275
276 static const InstructionKind& kind = InstructionKind::Get("Unroll");
277 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
278 /*inputs=*/{loop_rv},
279 /*attrs=*/{},
280 /*outputs=*/{}));
281}
282
283/******** Schedule: Insert cache stages ********/
284BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
285 const String& storage_scope,
286 const Array<BlockRV> consumer_blocks) {
287 BlockRV result =
288 ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope, consumer_blocks);
289
290 static const InstructionKind& kind = InstructionKind::Get("CacheRead");
291 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
292 /*inputs=*/{block_rv, consumer_blocks},
293 /*attrs=*/{Integer(read_buffer_index), storage_scope},
294 /*outputs=*/{result}));
295 return result;
296}
297
298BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
299 const String& storage_scope,
300 const Array<BlockRV> consumer_blocks) {
301 BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope,
302 consumer_blocks);
303
304 static const InstructionKind& kind = InstructionKind::Get("CacheWrite");
305 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
306 /*inputs=*/{block_rv, consumer_blocks},
307 /*attrs=*/{Integer(write_buffer_index), storage_scope},
308 /*outputs=*/{result}));
309 return result;
310}
311
312Array<BlockRV> TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index,
313 const String& storage_scope) {
314 Array<BlockRV> result =
315 ConcreteScheduleNode::CacheInplace(block_rv, read_buffer_index, storage_scope);
316 Array<ObjectRef> results;
317 for (const BlockRV& r : result) {
318 results.push_back(r);
319 }
320 static const InstructionKind& kind = InstructionKind::Get("CacheInplace");
321 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
322 /*inputs=*/{block_rv},
323 /*attrs=*/{Integer(read_buffer_index), storage_scope},
324 /*outputs=*/results));
325 return result;
326}
327
328Array<BlockRV> TracedScheduleNode::CacheIndex(const BlockRV& block_rv, const String& storage_scope,
329 int cse_thresh) {
330 Array<BlockRV> result = ConcreteScheduleNode::CacheIndex(block_rv, storage_scope, cse_thresh);
331 Array<ObjectRef> outputs;
332 for (const BlockRV& r : result) {
333 outputs.push_back(r);
334 }
335 static const InstructionKind& kind = InstructionKind::Get("CacheIndex");
336 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
337 /*inputs=*/{block_rv},
338 /*attrs=*/{storage_scope, Integer(cse_thresh)},
339 /*outputs=*/outputs));
340 return result;
341}
342
343BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
344 BufferIndexType buffer_index_type) {
345 BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type);
346
347 static const InstructionKind& kind = InstructionKind::Get("ReIndex");
348 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
349 /*inputs=*/{block_rv},
350 /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type)},
351 /*outputs=*/{result}));
352 return result;
353}
354
355/******** Schedule: Compute location ********/
356
357void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
358 bool preserve_unit_loops, int index) {
359 ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, index);
360
361 static const InstructionKind& kind = InstructionKind::Get("ComputeAt");
362 trace_->Append(
363 /*inst=*/Instruction(/*kind=*/kind,
364 /*inputs=*/{block_rv, loop_rv},
365 /*attrs=*/{Integer(preserve_unit_loops), Integer(index)},
366 /*outputs=*/{}));
367}
368
369void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
370 bool preserve_unit_loops, int index) {
371 ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops, index);
372
373 static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt");
374 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
375 /*inputs=*/{block_rv, loop_rv},
376 /*attrs=*/{Integer(preserve_unit_loops), Integer(index)},
377 /*outputs=*/{}));
378}
379
380void TracedScheduleNode::ComputeInline(const BlockRV& block_rv) {
381 ConcreteScheduleNode::ComputeInline(block_rv);
382
383 static const InstructionKind& kind = InstructionKind::Get("ComputeInline");
384 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
385 /*inputs=*/{block_rv},
386 /*attrs=*/{},
387 /*outputs=*/{}));
388}
389
390void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
391 ConcreteScheduleNode::ReverseComputeInline(block_rv);
392
393 static const InstructionKind& kind = InstructionKind::Get("ReverseComputeInline");
394 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
395 /*inputs=*/{block_rv},
396 /*attrs=*/{},
397 /*outputs=*/{}));
398}
399
400/******** Schedule: Reduction ********/
401
402BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
403 BlockRV result = ConcreteScheduleNode::DecomposeReduction(block_rv, loop_rv);
404 static const InstructionKind& kind = InstructionKind::Get("DecomposeReduction");
405 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
406 /*inputs=*/{block_rv, loop_rv},
407 /*attrs=*/{},
408 /*outputs=*/{result}));
409 return result;
410}
411
412BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
413 BlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis);
414 static const InstructionKind& kind = InstructionKind::Get("RFactor");
415 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
416 /*inputs=*/{loop_rv},
417 /*attrs=*/{Integer(factor_axis)},
418 /*outputs=*/{result}));
419 return result;
420}
421
422/******** Schedule: Block annotation ********/
423
424void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
425 int factor, int offset) {
426 ConcreteScheduleNode::StorageAlign(block_rv, buffer_index, axis, factor, offset);
427 static const InstructionKind& kind = InstructionKind::Get("StorageAlign");
428 trace_->Append(/*inst=*/Instruction(
429 /*kind=*/kind,
430 /*inputs=*/{block_rv},
431 /*attrs=*/{Integer(buffer_index), Integer(axis), Integer(factor), Integer(offset)},
432 /*outputs=*/{}));
433}
434
435void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
436 const String& storage_scope) {
437 ConcreteScheduleNode::SetScope(block_rv, buffer_index, storage_scope);
438 static const InstructionKind& kind = InstructionKind::Get("SetScope");
439 trace_->Append(/*inst=*/Instruction(
440 /*kind=*/kind,
441 /*inputs=*/{block_rv},
442 /*attrs=*/{Integer(buffer_index), storage_scope},
443 /*outputs=*/{}));
444}
445
446/******** Schedule: Blockize & Tensorize ********/
447
448BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) {
449 BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv, preserve_unit_iters);
450 static const InstructionKind& kind = InstructionKind::Get("Blockize");
451 trace_->Append(/*inst=*/Instruction(
452 /*kind=*/kind,
453 /*inputs=*/{loop_rv},
454 /*attrs=*/{Bool(preserve_unit_iters)},
455 /*outputs=*/{new_block}));
456 return new_block;
457}
458
459void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin,
460 bool preserve_unit_iters) {
461 ConcreteScheduleNode::Tensorize(loop_rv, intrin, preserve_unit_iters);
462 static const InstructionKind& kind = InstructionKind::Get("Tensorize");
463 trace_->Append(/*inst=*/Instruction(
464 /*kind=*/kind,
465 /*inputs=*/{loop_rv},
466 /*attrs=*/{intrin, Bool(preserve_unit_iters)},
467 /*outputs=*/{}));
468}
469
470void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin,
471 bool preserve_unit_iters) {
472 ConcreteScheduleNode::Tensorize(block_rv, intrin, preserve_unit_iters);
473 static const InstructionKind& kind = InstructionKind::Get("Tensorize");
474 trace_->Append(/*inst=*/Instruction(
475 /*kind=*/kind,
476 /*inputs=*/{block_rv},
477 /*attrs=*/{intrin, Bool(preserve_unit_iters)},
478 /*outputs=*/{}));
479}
480
481/******** Schedule: Annotation ********/
482
483void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key,
484 const ObjectRef& ann_val) {
485 ConcreteScheduleNode::Annotate(loop_rv, ann_key, ann_val);
486 static const InstructionKind& kind = InstructionKind::Get("Annotate");
487 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
488 /*inputs=*/{loop_rv, ann_val},
489 /*attrs=*/{ann_key},
490 /*outputs=*/{}));
491}
492
493void TracedScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key,
494 const ObjectRef& ann_val) {
495 ConcreteScheduleNode::Annotate(block_rv, ann_key, ann_val);
496 static const InstructionKind& kind = InstructionKind::Get("Annotate");
497 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
498 /*inputs=*/{block_rv, ann_val},
499 /*attrs=*/{ann_key},
500 /*outputs=*/{}));
501}
502
503void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) {
504 ConcreteScheduleNode::Unannotate(loop_rv, ann_key);
505 static const InstructionKind& kind = InstructionKind::Get("Unannotate");
506 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
507 /*inputs=*/{loop_rv},
508 /*attrs=*/{ann_key},
509 /*outputs=*/{}));
510}
511
512void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) {
513 ConcreteScheduleNode::Unannotate(block_rv, ann_key);
514 static const InstructionKind& kind = InstructionKind::Get("Unannotate");
515 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
516 /*inputs=*/{block_rv},
517 /*attrs=*/{ann_key},
518 /*outputs=*/{}));
519}
520
521/******** Schedule: Layout transformation ********/
522
523void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
524 BufferIndexType buffer_index_type,
525 const IndexMap& index_map,
526 const Optional<IndexMap>& pad_value) {
527 ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map,
528 pad_value);
529 static const InstructionKind& kind = InstructionKind::Get("TransformLayout");
530 trace_->Append(
531 /*inst=*/Instruction(
532 /*kind=*/kind,
533 /*inputs=*/{block_rv, index_map},
534 /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), pad_value},
535 /*outputs=*/{}));
536}
537
538void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) {
539 ConcreteScheduleNode::TransformBlockLayout(block_rv, index_map);
540 static const InstructionKind& kind = InstructionKind::Get("TransformBlockLayout");
541 trace_->Append(
542 /*inst=*/Instruction(/*kind=*/kind,
543 /*inputs=*/{block_rv},
544 /*attrs=*/{index_map},
545 /*outputs=*/{}));
546}
547
548void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
549 BufferIndexType buffer_index_type,
550 const Array<IntImm>& axis_separators) {
551 ConcreteScheduleNode::SetAxisSeparator(block_rv, buffer_index, buffer_index_type,
552 axis_separators);
553 static const InstructionKind& kind = InstructionKind::Get("SetAxisSeparator");
554 trace_->Append(/*inst=*/Instruction(
555 /*kind=*/kind,
556 /*inputs=*/{block_rv},
557 /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), axis_separators},
558 /*outputs=*/{}));
559}
560
561/******** Schedule: Padding ********/
562BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
563 BlockRV new_block = ConcreteScheduleNode::DecomposePadding(block_rv, loop_rv);
564 static const InstructionKind& kind = InstructionKind::Get("DecomposePadding");
565 trace_->Append(/*inst=*/Instruction(
566 /*kind=*/kind,
567 /*inputs=*/{block_rv, loop_rv},
568 /*attrs=*/{},
569 /*outputs=*/{new_block}));
570 return new_block;
571}
572
573void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) {
574 ConcreteScheduleNode::PadEinsum(block_rv, padding);
575 static const InstructionKind& kind = InstructionKind::Get("PadEinsum");
576 trace_->Append(/*inst=*/Instruction(
577 /*kind=*/kind,
578 /*inputs=*/{block_rv},
579 /*attrs=*/{padding},
580 /*outputs=*/{}));
581}
582
583/******** Schedule: Buffer transformation ********/
584
585void TracedScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) {
586 ConcreteScheduleNode::RollingBuffer(block_rv, write_buffer_index);
587 static const InstructionKind& kind = InstructionKind::Get("RollingBuffer");
588 trace_->Append(/*inst=*/Instruction(
589 /*kind=*/kind,
590 /*inputs=*/{block_rv},
591 /*attrs=*/{Integer(write_buffer_index)},
592 /*outputs=*/{}));
593}
594
595/******** Schedule: Misc ********/
596
597void TracedScheduleNode::EnterPostproc() {
598 ConcreteScheduleNode::EnterPostproc();
599 static const InstructionKind& kind = InstructionKind::Get("EnterPostproc");
600 trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
601 /*inputs=*/{},
602 /*attrs=*/{},
603 /*outputs=*/{}));
604}
605
606} // namespace tir
607} // namespace tvm
608