1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "./traced_schedule.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, |
25 | int debug_mask, ScheduleErrorRenderLevel error_render_level) { |
26 | ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>(); |
27 | n->state_ = ScheduleState(mod, debug_mask); |
28 | n->error_render_level_ = error_render_level; |
29 | n->symbol_table_ = {}; |
30 | n->analyzer_ = std::make_unique<arith::Analyzer>(); |
31 | n->trace_ = Trace(); |
32 | n->Seed(seed); |
33 | GlobalVar gv = NullValue<GlobalVar>(); |
34 | if (FindEntryFunc(mod, &gv) != nullptr) { |
35 | n->func_working_on_ = gv; |
36 | } else { |
37 | n->func_working_on_ = NullOpt; |
38 | } |
39 | return Schedule(std::move(n)); |
40 | } |
41 | |
42 | Schedule TracedScheduleNode::Copy() { |
43 | ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>(); |
44 | n->error_render_level_ = this->error_render_level_; |
45 | ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); |
46 | n->func_working_on_ = this->func_working_on_; |
47 | n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful |
48 | n->rand_state_ = ForkSeed(); |
49 | n->trace_ = Trace(this->trace_->insts, this->trace_->decisions); |
50 | return Schedule(std::move(n)); |
51 | } |
52 | |
53 | /******** Schedule: Sampling ********/ |
54 | |
55 | ExprRV TracedScheduleNode::SampleCategorical(const Array<Integer>& candidates, |
56 | const Array<FloatImm>& probs, |
57 | Optional<Integer> decision) { |
58 | ExprRV result = |
59 | CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); |
60 | static const InstructionKind& kind = InstructionKind::Get("SampleCategorical" ); |
61 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // |
62 | /*inputs=*/{}, |
63 | /*attrs=*/{candidates, probs}, |
64 | /*outputs=*/{result}), |
65 | /*decision=*/decision); |
66 | return result; |
67 | } |
68 | |
69 | Array<ExprRV> TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n, |
70 | int max_innermost_factor, |
71 | Optional<Array<Integer>> decision) { |
72 | Array<ExprRV> results = CreateRV(tir::SamplePerfectTile( |
73 | &this->rand_state_, this->GetSRef(loop_rv), n, max_innermost_factor, &decision)); |
74 | |
75 | static const InstructionKind& kind = InstructionKind::Get("SamplePerfectTile" ); |
76 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // |
77 | /*inputs=*/{loop_rv}, |
78 | /*attrs=*/{Integer(n), Integer(max_innermost_factor)}, |
79 | /*outputs=*/{results.begin(), results.end()}), |
80 | /*decision=*/decision); |
81 | return results; |
82 | } |
83 | |
84 | LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, |
85 | Optional<Integer> decision) { |
86 | LoopRV result = CreateRV<LoopRV>(tir::SampleComputeLocation(this->state_, &this->rand_state_, |
87 | this->GetSRef(block_rv), &decision)); |
88 | |
89 | static const InstructionKind& kind = InstructionKind::Get("SampleComputeLocation" ); |
90 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // |
91 | /*inputs=*/{block_rv}, |
92 | /*attrs=*/{}, |
93 | /*outputs=*/{result}), |
94 | /*decision=*/decision); |
95 | return result; |
96 | } |
97 | |
98 | /******** Schedule: Get blocks & loops ********/ |
99 | |
100 | BlockRV TracedScheduleNode::GetBlock(const String& name, const Optional<String>& func_name) { |
101 | GlobalVar gv = NullValue<GlobalVar>(); |
102 | if (func_name.defined()) { |
103 | gv = state_->mod->GetGlobalVar(func_name.value()); |
104 | } else if (func_working_on_.defined()) { |
105 | gv = this->func_working_on_.value(); |
106 | } else { |
107 | LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please " |
108 | "specify the function name explicitly, or call `work_on` to specify the function " |
109 | "before using `get_block`." ; |
110 | } |
111 | BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name); |
112 | |
113 | static const InstructionKind& kind = InstructionKind::Get("GetBlock" ); |
114 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // |
115 | /*inputs=*/{}, |
116 | /*attrs=*/{name, gv->name_hint}, |
117 | /*outputs=*/{result})); |
118 | return result; |
119 | } |
120 | |
121 | Array<LoopRV> TracedScheduleNode::GetLoops(const BlockRV& block_rv) { |
122 | Array<LoopRV> results = ConcreteScheduleNode::GetLoops(block_rv); |
123 | |
124 | static const InstructionKind& kind = InstructionKind::Get("GetLoops" ); |
125 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // |
126 | /*inputs=*/{block_rv}, |
127 | /*attrs=*/{}, |
128 | /*outputs=*/{results.begin(), results.end()})); |
129 | return results; |
130 | } |
131 | |
132 | Array<BlockRV> TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { |
133 | Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(block_rv); |
134 | |
135 | static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks" ); |
136 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // |
137 | /*inputs=*/{block_rv}, |
138 | /*attrs=*/{}, |
139 | /*outputs=*/{results.begin(), results.end()})); |
140 | return results; |
141 | } |
142 | |
143 | Array<BlockRV> TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { |
144 | Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(loop_rv); |
145 | |
146 | static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks" ); |
147 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // |
148 | /*inputs=*/{loop_rv}, |
149 | /*attrs=*/{}, |
150 | /*outputs=*/{results.begin(), results.end()})); |
151 | return results; |
152 | } |
153 | |
154 | Array<BlockRV> TracedScheduleNode::GetProducers(const BlockRV& block_rv) { |
155 | Array<BlockRV> results = ConcreteScheduleNode::GetProducers(block_rv); |
156 | |
157 | static const InstructionKind& kind = InstructionKind::Get("GetProducers" ); |
158 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // |
159 | /*inputs=*/{block_rv}, |
160 | /*attrs=*/{}, |
161 | /*outputs=*/{results.begin(), results.end()})); |
162 | return results; |
163 | } |
164 | |
165 | Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { |
166 | Array<BlockRV> results = ConcreteScheduleNode::GetConsumers(block_rv); |
167 | |
168 | static const InstructionKind& kind = InstructionKind::Get("GetConsumers" ); |
169 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // |
170 | /*inputs=*/{block_rv}, |
171 | /*attrs=*/{}, |
172 | /*outputs=*/{results.begin(), results.end()})); |
173 | return results; |
174 | } |
175 | |
176 | /******** Schedule: Transform loops ********/ |
177 | |
178 | LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_loops) { |
179 | LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops); |
180 | |
181 | static const InstructionKind& kind = InstructionKind::Get("Fuse" ); |
182 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
183 | /*inputs=*/{loop_rvs.begin(), loop_rvs.end()}, |
184 | /*attrs=*/{Integer(preserve_unit_loops)}, |
185 | /*outputs=*/{result})); |
186 | return result; |
187 | } |
188 | |
189 | Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv, |
190 | const Array<Optional<ExprRV>>& factor_rvs, |
191 | bool preserve_unit_iters) { |
192 | Array<LoopRV> results = ConcreteScheduleNode::Split(loop_rv, factor_rvs, preserve_unit_iters); |
193 | |
194 | std::vector<ObjectRef> inputs; |
195 | inputs.reserve(1 + factor_rvs.size()); |
196 | inputs.push_back(loop_rv); |
197 | for (const ObjectRef& obj : factor_rvs) { |
198 | inputs.push_back(obj); |
199 | } |
200 | |
201 | static const InstructionKind& kind = InstructionKind::Get("Split" ); |
202 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
203 | /*inputs=*/inputs, |
204 | /*attrs=*/{Integer(preserve_unit_iters)}, |
205 | /*outputs=*/{results.begin(), results.end()})); |
206 | return results; |
207 | } |
208 | |
209 | void TracedScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) { |
210 | ConcreteScheduleNode::Reorder(ordered_loop_rvs); |
211 | |
212 | static const InstructionKind& kind = InstructionKind::Get("Reorder" ); |
213 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
214 | /*inputs=*/{ordered_loop_rvs.begin(), ordered_loop_rvs.end()}, |
215 | /*attrs=*/{}, |
216 | /*outputs=*/{})); |
217 | } |
218 | |
219 | LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) { |
220 | LoopRV result = ConcreteScheduleNode::AddUnitLoop(block_rv); |
221 | |
222 | static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop" ); |
223 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
224 | /*inputs=*/{block_rv}, |
225 | /*attrs=*/{}, |
226 | /*outputs=*/{result})); |
227 | return result; |
228 | } |
229 | |
230 | LoopRV TracedScheduleNode::AddUnitLoop(const LoopRV& loop_rv) { |
231 | LoopRV result = ConcreteScheduleNode::AddUnitLoop(loop_rv); |
232 | |
233 | static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop" ); |
234 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
235 | /*inputs=*/{loop_rv}, |
236 | /*attrs=*/{}, |
237 | /*outputs=*/{result})); |
238 | return result; |
239 | } |
240 | |
241 | /******** Schedule: Manipulate ForKind ********/ |
242 | |
243 | void TracedScheduleNode::Parallel(const LoopRV& loop_rv) { |
244 | ConcreteScheduleNode::Parallel(loop_rv); |
245 | |
246 | static const InstructionKind& kind = InstructionKind::Get("Parallel" ); |
247 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
248 | /*inputs=*/{loop_rv}, |
249 | /*attrs=*/{}, |
250 | /*outputs=*/{})); |
251 | } |
252 | |
253 | void TracedScheduleNode::Vectorize(const LoopRV& loop_rv) { |
254 | ConcreteScheduleNode::Vectorize(loop_rv); |
255 | |
256 | static const InstructionKind& kind = InstructionKind::Get("Vectorize" ); |
257 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
258 | /*inputs=*/{loop_rv}, |
259 | /*attrs=*/{}, |
260 | /*outputs=*/{})); |
261 | } |
262 | |
263 | void TracedScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis) { |
264 | ConcreteScheduleNode::Bind(loop_rv, thread_axis); |
265 | |
266 | static const InstructionKind& kind = InstructionKind::Get("Bind" ); |
267 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
268 | /*inputs=*/{loop_rv}, |
269 | /*attrs=*/{thread_axis}, |
270 | /*outputs=*/{})); |
271 | } |
272 | |
273 | void TracedScheduleNode::Unroll(const LoopRV& loop_rv) { |
274 | ConcreteScheduleNode::Unroll(loop_rv); |
275 | |
276 | static const InstructionKind& kind = InstructionKind::Get("Unroll" ); |
277 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
278 | /*inputs=*/{loop_rv}, |
279 | /*attrs=*/{}, |
280 | /*outputs=*/{})); |
281 | } |
282 | |
283 | /******** Schedule: Insert cache stages ********/ |
284 | BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index, |
285 | const String& storage_scope, |
286 | const Array<BlockRV> consumer_blocks) { |
287 | BlockRV result = |
288 | ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope, consumer_blocks); |
289 | |
290 | static const InstructionKind& kind = InstructionKind::Get("CacheRead" ); |
291 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
292 | /*inputs=*/{block_rv, consumer_blocks}, |
293 | /*attrs=*/{Integer(read_buffer_index), storage_scope}, |
294 | /*outputs=*/{result})); |
295 | return result; |
296 | } |
297 | |
298 | BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index, |
299 | const String& storage_scope, |
300 | const Array<BlockRV> consumer_blocks) { |
301 | BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope, |
302 | consumer_blocks); |
303 | |
304 | static const InstructionKind& kind = InstructionKind::Get("CacheWrite" ); |
305 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
306 | /*inputs=*/{block_rv, consumer_blocks}, |
307 | /*attrs=*/{Integer(write_buffer_index), storage_scope}, |
308 | /*outputs=*/{result})); |
309 | return result; |
310 | } |
311 | |
312 | Array<BlockRV> TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, |
313 | const String& storage_scope) { |
314 | Array<BlockRV> result = |
315 | ConcreteScheduleNode::CacheInplace(block_rv, read_buffer_index, storage_scope); |
316 | Array<ObjectRef> results; |
317 | for (const BlockRV& r : result) { |
318 | results.push_back(r); |
319 | } |
320 | static const InstructionKind& kind = InstructionKind::Get("CacheInplace" ); |
321 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
322 | /*inputs=*/{block_rv}, |
323 | /*attrs=*/{Integer(read_buffer_index), storage_scope}, |
324 | /*outputs=*/results)); |
325 | return result; |
326 | } |
327 | |
328 | Array<BlockRV> TracedScheduleNode::CacheIndex(const BlockRV& block_rv, const String& storage_scope, |
329 | int cse_thresh) { |
330 | Array<BlockRV> result = ConcreteScheduleNode::CacheIndex(block_rv, storage_scope, cse_thresh); |
331 | Array<ObjectRef> outputs; |
332 | for (const BlockRV& r : result) { |
333 | outputs.push_back(r); |
334 | } |
335 | static const InstructionKind& kind = InstructionKind::Get("CacheIndex" ); |
336 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
337 | /*inputs=*/{block_rv}, |
338 | /*attrs=*/{storage_scope, Integer(cse_thresh)}, |
339 | /*outputs=*/outputs)); |
340 | return result; |
341 | } |
342 | |
343 | BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, |
344 | BufferIndexType buffer_index_type) { |
345 | BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type); |
346 | |
347 | static const InstructionKind& kind = InstructionKind::Get("ReIndex" ); |
348 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
349 | /*inputs=*/{block_rv}, |
350 | /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type)}, |
351 | /*outputs=*/{result})); |
352 | return result; |
353 | } |
354 | |
355 | /******** Schedule: Compute location ********/ |
356 | |
357 | void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, |
358 | bool preserve_unit_loops, int index) { |
359 | ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, index); |
360 | |
361 | static const InstructionKind& kind = InstructionKind::Get("ComputeAt" ); |
362 | trace_->Append( |
363 | /*inst=*/Instruction(/*kind=*/kind, |
364 | /*inputs=*/{block_rv, loop_rv}, |
365 | /*attrs=*/{Integer(preserve_unit_loops), Integer(index)}, |
366 | /*outputs=*/{})); |
367 | } |
368 | |
369 | void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, |
370 | bool preserve_unit_loops, int index) { |
371 | ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops, index); |
372 | |
373 | static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt" ); |
374 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
375 | /*inputs=*/{block_rv, loop_rv}, |
376 | /*attrs=*/{Integer(preserve_unit_loops), Integer(index)}, |
377 | /*outputs=*/{})); |
378 | } |
379 | |
380 | void TracedScheduleNode::ComputeInline(const BlockRV& block_rv) { |
381 | ConcreteScheduleNode::ComputeInline(block_rv); |
382 | |
383 | static const InstructionKind& kind = InstructionKind::Get("ComputeInline" ); |
384 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
385 | /*inputs=*/{block_rv}, |
386 | /*attrs=*/{}, |
387 | /*outputs=*/{})); |
388 | } |
389 | |
390 | void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { |
391 | ConcreteScheduleNode::ReverseComputeInline(block_rv); |
392 | |
393 | static const InstructionKind& kind = InstructionKind::Get("ReverseComputeInline" ); |
394 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
395 | /*inputs=*/{block_rv}, |
396 | /*attrs=*/{}, |
397 | /*outputs=*/{})); |
398 | } |
399 | |
400 | /******** Schedule: Reduction ********/ |
401 | |
402 | BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) { |
403 | BlockRV result = ConcreteScheduleNode::DecomposeReduction(block_rv, loop_rv); |
404 | static const InstructionKind& kind = InstructionKind::Get("DecomposeReduction" ); |
405 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
406 | /*inputs=*/{block_rv, loop_rv}, |
407 | /*attrs=*/{}, |
408 | /*outputs=*/{result})); |
409 | return result; |
410 | } |
411 | |
412 | BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { |
413 | BlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis); |
414 | static const InstructionKind& kind = InstructionKind::Get("RFactor" ); |
415 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
416 | /*inputs=*/{loop_rv}, |
417 | /*attrs=*/{Integer(factor_axis)}, |
418 | /*outputs=*/{result})); |
419 | return result; |
420 | } |
421 | |
422 | /******** Schedule: Block annotation ********/ |
423 | |
424 | void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, |
425 | int factor, int offset) { |
426 | ConcreteScheduleNode::StorageAlign(block_rv, buffer_index, axis, factor, offset); |
427 | static const InstructionKind& kind = InstructionKind::Get("StorageAlign" ); |
428 | trace_->Append(/*inst=*/Instruction( |
429 | /*kind=*/kind, |
430 | /*inputs=*/{block_rv}, |
431 | /*attrs=*/{Integer(buffer_index), Integer(axis), Integer(factor), Integer(offset)}, |
432 | /*outputs=*/{})); |
433 | } |
434 | |
435 | void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, |
436 | const String& storage_scope) { |
437 | ConcreteScheduleNode::SetScope(block_rv, buffer_index, storage_scope); |
438 | static const InstructionKind& kind = InstructionKind::Get("SetScope" ); |
439 | trace_->Append(/*inst=*/Instruction( |
440 | /*kind=*/kind, |
441 | /*inputs=*/{block_rv}, |
442 | /*attrs=*/{Integer(buffer_index), storage_scope}, |
443 | /*outputs=*/{})); |
444 | } |
445 | |
446 | /******** Schedule: Blockize & Tensorize ********/ |
447 | |
448 | BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) { |
449 | BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv, preserve_unit_iters); |
450 | static const InstructionKind& kind = InstructionKind::Get("Blockize" ); |
451 | trace_->Append(/*inst=*/Instruction( |
452 | /*kind=*/kind, |
453 | /*inputs=*/{loop_rv}, |
454 | /*attrs=*/{Bool(preserve_unit_iters)}, |
455 | /*outputs=*/{new_block})); |
456 | return new_block; |
457 | } |
458 | |
459 | void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin, |
460 | bool preserve_unit_iters) { |
461 | ConcreteScheduleNode::Tensorize(loop_rv, intrin, preserve_unit_iters); |
462 | static const InstructionKind& kind = InstructionKind::Get("Tensorize" ); |
463 | trace_->Append(/*inst=*/Instruction( |
464 | /*kind=*/kind, |
465 | /*inputs=*/{loop_rv}, |
466 | /*attrs=*/{intrin, Bool(preserve_unit_iters)}, |
467 | /*outputs=*/{})); |
468 | } |
469 | |
470 | void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin, |
471 | bool preserve_unit_iters) { |
472 | ConcreteScheduleNode::Tensorize(block_rv, intrin, preserve_unit_iters); |
473 | static const InstructionKind& kind = InstructionKind::Get("Tensorize" ); |
474 | trace_->Append(/*inst=*/Instruction( |
475 | /*kind=*/kind, |
476 | /*inputs=*/{block_rv}, |
477 | /*attrs=*/{intrin, Bool(preserve_unit_iters)}, |
478 | /*outputs=*/{})); |
479 | } |
480 | |
481 | /******** Schedule: Annotation ********/ |
482 | |
483 | void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, |
484 | const ObjectRef& ann_val) { |
485 | ConcreteScheduleNode::Annotate(loop_rv, ann_key, ann_val); |
486 | static const InstructionKind& kind = InstructionKind::Get("Annotate" ); |
487 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
488 | /*inputs=*/{loop_rv, ann_val}, |
489 | /*attrs=*/{ann_key}, |
490 | /*outputs=*/{})); |
491 | } |
492 | |
493 | void TracedScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, |
494 | const ObjectRef& ann_val) { |
495 | ConcreteScheduleNode::Annotate(block_rv, ann_key, ann_val); |
496 | static const InstructionKind& kind = InstructionKind::Get("Annotate" ); |
497 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
498 | /*inputs=*/{block_rv, ann_val}, |
499 | /*attrs=*/{ann_key}, |
500 | /*outputs=*/{})); |
501 | } |
502 | |
503 | void TracedScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) { |
504 | ConcreteScheduleNode::Unannotate(loop_rv, ann_key); |
505 | static const InstructionKind& kind = InstructionKind::Get("Unannotate" ); |
506 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
507 | /*inputs=*/{loop_rv}, |
508 | /*attrs=*/{ann_key}, |
509 | /*outputs=*/{})); |
510 | } |
511 | |
512 | void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_key) { |
513 | ConcreteScheduleNode::Unannotate(block_rv, ann_key); |
514 | static const InstructionKind& kind = InstructionKind::Get("Unannotate" ); |
515 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
516 | /*inputs=*/{block_rv}, |
517 | /*attrs=*/{ann_key}, |
518 | /*outputs=*/{})); |
519 | } |
520 | |
521 | /******** Schedule: Layout transformation ********/ |
522 | |
523 | void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, |
524 | BufferIndexType buffer_index_type, |
525 | const IndexMap& index_map, |
526 | const Optional<IndexMap>& pad_value) { |
527 | ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map, |
528 | pad_value); |
529 | static const InstructionKind& kind = InstructionKind::Get("TransformLayout" ); |
530 | trace_->Append( |
531 | /*inst=*/Instruction( |
532 | /*kind=*/kind, |
533 | /*inputs=*/{block_rv, index_map}, |
534 | /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), pad_value}, |
535 | /*outputs=*/{})); |
536 | } |
537 | |
538 | void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) { |
539 | ConcreteScheduleNode::TransformBlockLayout(block_rv, index_map); |
540 | static const InstructionKind& kind = InstructionKind::Get("TransformBlockLayout" ); |
541 | trace_->Append( |
542 | /*inst=*/Instruction(/*kind=*/kind, |
543 | /*inputs=*/{block_rv}, |
544 | /*attrs=*/{index_map}, |
545 | /*outputs=*/{})); |
546 | } |
547 | |
548 | void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index, |
549 | BufferIndexType buffer_index_type, |
550 | const Array<IntImm>& axis_separators) { |
551 | ConcreteScheduleNode::SetAxisSeparator(block_rv, buffer_index, buffer_index_type, |
552 | axis_separators); |
553 | static const InstructionKind& kind = InstructionKind::Get("SetAxisSeparator" ); |
554 | trace_->Append(/*inst=*/Instruction( |
555 | /*kind=*/kind, |
556 | /*inputs=*/{block_rv}, |
557 | /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), axis_separators}, |
558 | /*outputs=*/{})); |
559 | } |
560 | |
561 | /******** Schedule: Padding ********/ |
562 | BlockRV TracedScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) { |
563 | BlockRV new_block = ConcreteScheduleNode::DecomposePadding(block_rv, loop_rv); |
564 | static const InstructionKind& kind = InstructionKind::Get("DecomposePadding" ); |
565 | trace_->Append(/*inst=*/Instruction( |
566 | /*kind=*/kind, |
567 | /*inputs=*/{block_rv, loop_rv}, |
568 | /*attrs=*/{}, |
569 | /*outputs=*/{new_block})); |
570 | return new_block; |
571 | } |
572 | |
573 | void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) { |
574 | ConcreteScheduleNode::PadEinsum(block_rv, padding); |
575 | static const InstructionKind& kind = InstructionKind::Get("PadEinsum" ); |
576 | trace_->Append(/*inst=*/Instruction( |
577 | /*kind=*/kind, |
578 | /*inputs=*/{block_rv}, |
579 | /*attrs=*/{padding}, |
580 | /*outputs=*/{})); |
581 | } |
582 | |
583 | /******** Schedule: Buffer transformation ********/ |
584 | |
585 | void TracedScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) { |
586 | ConcreteScheduleNode::RollingBuffer(block_rv, write_buffer_index); |
587 | static const InstructionKind& kind = InstructionKind::Get("RollingBuffer" ); |
588 | trace_->Append(/*inst=*/Instruction( |
589 | /*kind=*/kind, |
590 | /*inputs=*/{block_rv}, |
591 | /*attrs=*/{Integer(write_buffer_index)}, |
592 | /*outputs=*/{})); |
593 | } |
594 | |
595 | /******** Schedule: Misc ********/ |
596 | |
597 | void TracedScheduleNode::EnterPostproc() { |
598 | ConcreteScheduleNode::EnterPostproc(); |
599 | static const InstructionKind& kind = InstructionKind::Get("EnterPostproc" ); |
600 | trace_->Append(/*inst=*/Instruction(/*kind=*/kind, |
601 | /*inputs=*/{}, |
602 | /*attrs=*/{}, |
603 | /*outputs=*/{})); |
604 | } |
605 | |
606 | } // namespace tir |
607 | } // namespace tvm |
608 | |