1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/transform_step.cc
22 * \brief Transformation steps. These steps are used to manipulate the LoopState.
23 * They are similar to the schedule primitives in te::Stage.
24 */
25
26#include <tvm/auto_scheduler/compute_dag.h>
27#include <tvm/auto_scheduler/loop_state.h>
28#include <tvm/auto_scheduler/transform_step.h>
29#include <tvm/runtime/logging.h>
30#include <tvm/runtime/registry.h>
31#include <tvm/te/operation.h>
32
33#include <string>
34#include <utility>
35#include <vector>
36
37#include "utils.h"
38
39namespace dmlc {
40namespace json {
41
42template <>
43struct Handler<::tvm::Array<::tvm::Integer>> {
44 inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::Integer>& array) {
45 writer->BeginArray(false);
46 for (const auto& i : array) {
47 ICHECK(i.defined());
48 writer->WriteArrayItem(i->value);
49 }
50 writer->EndArray();
51 }
52 inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::Integer>* array) {
53 array->clear();
54 reader->BeginArray();
55 while (reader->NextArrayItem()) {
56 int value;
57 Handler<int>::Read(reader, &value);
58 array->push_back(value);
59 }
60 }
61};
62
63template <>
64struct Handler<::tvm::Array<::tvm::Optional<::tvm::Integer>>> {
65 inline static void Write(dmlc::JSONWriter* writer,
66 const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& array) {
67 writer->BeginArray(false);
68 for (const auto& i : array) {
69 ICHECK(i);
70 writer->WriteArrayItem(i.value()->value);
71 }
72 writer->EndArray();
73 }
74 inline static void Read(dmlc::JSONReader* reader,
75 ::tvm::Array<::tvm::Optional<::tvm::Integer>>* array) {
76 array->clear();
77 reader->BeginArray();
78 while (reader->NextArrayItem()) {
79 int value;
80 Handler<int>::Read(reader, &value);
81 array->push_back(::tvm::Integer(value));
82 }
83 }
84};
85
86} // namespace json
87} // namespace dmlc
88
89namespace tvm {
90namespace auto_scheduler {
91
92// Update the te::stage to tir::IterVar axis mapping
93void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) {
94 if (auto pop = stage->op.as<te::ComputeOpNode>()) {
95 Array<IterVar> axes;
96 for (const auto& axis : pop->axis) {
97 axes.push_back(axis);
98 }
99 for (const auto& axis : pop->reduce_axis) {
100 axes.push_back(axis);
101 }
102 stage_to_axes->Set(stage, std::move(axes));
103 } else if (stage->op->IsInstance<te::PlaceholderOpNode>()) {
104 {} // do nothing on Placeholder
105 } else {
106 LOG(FATAL) << "Invalid op " << stage->op;
107 }
108}
109
/*!
 * \brief Printable names of the IteratorAnnotation values, indexed by the enum's integer value
 * (the per-entry comments give that value). Keep the entries in enum order.
 * The vthread/blockIdx/threadIdx entries are also used as thread tags for te::thread_axis when
 * binding, and as the tag string in the printed Python schedule.
 */
const char* IteratorAnnotationString[] = {
    "for",          // kNone = 0
    "unroll",       // kUnroll = 1
    "vectorize",    // kVectorize = 2
    "parallel",     // kParallel = 3
    "vthread",      // kVThread = 4
    "blockIdx.x",   // kBlockX = 5
    "threadIdx.x",  // kThreadX = 6
    "blockIdx.y",   // kBlockY = 7
    "threadIdx.y",  // kThreadY = 8
    "blockIdx.z",   // kBlockZ = 9
    "threadIdx.z",  // kThreadZ = 10
    "tensorize"     // kTensorized = 11
};
124
125StepNode* Step::CopyOnWrite() {
126 CHECK(data_ != nullptr);
127 if (!data_.unique()) {
128 if (const auto& ps = as<AnnotationStepNode>()) {
129 auto n = make_object<AnnotationStepNode>(*ps);
130 ObjectPtr<Object>(std::move(n)).swap(data_);
131 } else if (const auto& ps = as<FuseStepNode>()) {
132 auto n = make_object<FuseStepNode>(*ps);
133 ObjectPtr<Object>(std::move(n)).swap(data_);
134 } else if (const auto& ps = as<PragmaStepNode>()) {
135 auto n = make_object<PragmaStepNode>(*ps);
136 ObjectPtr<Object>(std::move(n)).swap(data_);
137 } else if (const auto& ps = as<ReorderStepNode>()) {
138 auto n = make_object<ReorderStepNode>(*ps);
139 ObjectPtr<Object>(std::move(n)).swap(data_);
140 } else if (const auto& ps = as<SplitStepNode>()) {
141 auto n = make_object<SplitStepNode>(*ps);
142 ObjectPtr<Object>(std::move(n)).swap(data_);
143 } else if (const auto& ps = as<FollowSplitStepNode>()) {
144 auto n = make_object<FollowSplitStepNode>(*ps);
145 ObjectPtr<Object>(std::move(n)).swap(data_);
146 } else if (const auto& ps = as<FollowFusedSplitStepNode>()) {
147 auto n = make_object<FollowFusedSplitStepNode>(*ps);
148 ObjectPtr<Object>(std::move(n)).swap(data_);
149 } else if (const auto& ps = as<StorageAlignStepNode>()) {
150 auto n = make_object<StorageAlignStepNode>(*ps);
151 ObjectPtr<Object>(std::move(n)).swap(data_);
152 } else if (const auto& ps = as<ComputeAtStepNode>()) {
153 auto n = make_object<ComputeAtStepNode>(*ps);
154 ObjectPtr<Object>(std::move(n)).swap(data_);
155 } else if (const auto& ps = as<ComputeInlineStepNode>()) {
156 auto n = make_object<ComputeInlineStepNode>(*ps);
157 ObjectPtr<Object>(std::move(n)).swap(data_);
158 } else if (const auto& ps = as<ComputeRootStepNode>()) {
159 auto n = make_object<ComputeRootStepNode>(*ps);
160 ObjectPtr<Object>(std::move(n)).swap(data_);
161 } else if (const auto& ps = as<CacheReadStepNode>()) {
162 auto n = make_object<CacheReadStepNode>(*ps);
163 ObjectPtr<Object>(std::move(n)).swap(data_);
164 } else if (const auto& ps = as<CacheWriteStepNode>()) {
165 auto n = make_object<CacheWriteStepNode>(*ps);
166 ObjectPtr<Object>(std::move(n)).swap(data_);
167 } else if (const auto& ps = as<RfactorStepNode>()) {
168 auto n = make_object<RfactorStepNode>(*ps);
169 ObjectPtr<Object>(std::move(n)).swap(data_);
170 } else {
171 LOG(FATAL) << "Invalid step: " << (*this);
172 }
173 }
174 return static_cast<StepNode*>(data_.get());
175}
176
177Step StepReadFromRecord(dmlc::JSONReader* reader) {
178 std::string name;
179 bool s;
180 s = reader->NextArrayItem();
181 ICHECK(s);
182 reader->Read(&name);
183 if (name == AnnotationStepNode::record_prefix_str) {
184 return AnnotationStep(reader);
185 } else if (name == FuseStepNode::record_prefix_str) {
186 return FuseStep(reader);
187 } else if (name == PragmaStepNode::record_prefix_str) {
188 return PragmaStep(reader);
189 } else if (name == ReorderStepNode::record_prefix_str) {
190 return ReorderStep(reader);
191 } else if (name == SplitStepNode::record_prefix_str) {
192 return SplitStep(reader);
193 } else if (name == FollowSplitStepNode::record_prefix_str) {
194 return FollowSplitStep(reader);
195 } else if (name == FollowFusedSplitStepNode::record_prefix_str) {
196 return FollowFusedSplitStep(reader);
197 } else if (name == StorageAlignStepNode::record_prefix_str) {
198 return StorageAlignStep(reader);
199 } else if (name == ComputeAtStepNode::record_prefix_str) {
200 return ComputeAtStep(reader);
201 } else if (name == ComputeInlineStepNode::record_prefix_str) {
202 return ComputeInlineStep(reader);
203 } else if (name == ComputeRootStepNode::record_prefix_str) {
204 return ComputeRootStep(reader);
205 } else if (name == CacheReadStepNode::record_prefix_str) {
206 return CacheReadStep(reader);
207 } else if (name == CacheWriteStepNode::record_prefix_str) {
208 return CacheWriteStep(reader);
209 } else if (name == RfactorStepNode::record_prefix_str) {
210 return RfactorStep(reader);
211 } else {
212 LOG(FATAL) << "Invalid step format: " << name;
213 }
214 return Step();
215}
216
217void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
218 // We need this runtime dispatcher because different steps have different function signatures
219 if (auto ps = step.as<AnnotationStepNode>()) {
220 ps->ApplyToState(state);
221 } else if (auto ps = step.as<FuseStepNode>()) {
222 ps->ApplyToState(state);
223 } else if (auto ps = step.as<PragmaStepNode>()) {
224 ps->ApplyToState(state);
225 } else if (auto ps = step.as<ReorderStepNode>()) {
226 ps->ApplyToState(state);
227 } else if (auto ps = step.as<SplitStepNode>()) {
228 ps->ApplyToState(state);
229 } else if (auto ps = step.as<FollowSplitStepNode>()) {
230 ps->ApplyToState(state);
231 } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
232 ps->ApplyToState(state);
233 } else if (auto ps = step.as<StorageAlignStepNode>()) {
234 ps->ApplyToState(state);
235 } else if (auto ps = step.as<ComputeAtStepNode>()) {
236 ps->ApplyToState(state);
237 } else if (auto ps = step.as<ComputeInlineStepNode>()) {
238 ps->ApplyToState(state);
239 } else if (auto ps = step.as<ComputeRootStepNode>()) {
240 ps->ApplyToState(state);
241 } else if (auto ps = step.as<CacheReadStepNode>()) {
242 ps->ApplyToState(state, dag);
243 } else if (auto ps = step.as<CacheWriteStepNode>()) {
244 ps->ApplyToState(state, dag);
245 } else if (auto ps = step.as<RfactorStepNode>()) {
246 ps->ApplyToState(state, dag);
247 } else {
248 LOG(FATAL) << "Invalid step: " << step;
249 }
250}
251
252void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
253 te::Schedule* schedule, const Array<Step>& transform_steps) {
254 if (auto ps = step.as<AnnotationStepNode>()) {
255 ps->ApplyToSchedule(stages, stage_to_axes);
256 } else if (auto ps = step.as<FuseStepNode>()) {
257 ps->ApplyToSchedule(stages, stage_to_axes);
258 } else if (auto ps = step.as<PragmaStepNode>()) {
259 ps->ApplyToSchedule(stages, stage_to_axes);
260 } else if (auto ps = step.as<ReorderStepNode>()) {
261 ps->ApplyToSchedule(stages, stage_to_axes);
262 } else if (auto ps = step.as<SplitStepNode>()) {
263 ps->ApplyToSchedule(stages, stage_to_axes);
264 } else if (auto ps = step.as<FollowSplitStepNode>()) {
265 ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
266 } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
267 ps->ApplyToSchedule(stages, stage_to_axes, transform_steps);
268 } else if (auto ps = step.as<StorageAlignStepNode>()) {
269 ps->ApplyToSchedule(stages, stage_to_axes);
270 } else if (auto ps = step.as<ComputeAtStepNode>()) {
271 ps->ApplyToSchedule(stages, stage_to_axes);
272 } else if (auto ps = step.as<ComputeInlineStepNode>()) {
273 ps->ApplyToSchedule(stages, stage_to_axes);
274 } else if (auto ps = step.as<ComputeRootStepNode>()) {
275 ps->ApplyToSchedule(stages, stage_to_axes);
276 } else if (auto ps = step.as<CacheReadStepNode>()) {
277 ps->ApplyToSchedule(stages, stage_to_axes, schedule);
278 } else if (auto ps = step.as<CacheWriteStepNode>()) {
279 ps->ApplyToSchedule(stages, stage_to_axes, schedule);
280 } else if (auto ps = step.as<RfactorStepNode>()) {
281 ps->ApplyToSchedule(stages, stage_to_axes, schedule);
282 } else {
283 LOG(FATAL) << "Invalid Step: " << step;
284 }
285}
286
287String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
288 StageToAxesMap* stage_to_axes, te::Schedule* schedule,
289 const Array<Step>& transform_steps) {
290 if (auto ps = step.as<AnnotationStepNode>()) {
291 return ps->PrintAsPythonAPI(stages, stage_to_axes);
292 } else if (auto ps = step.as<FuseStepNode>()) {
293 return ps->PrintAsPythonAPI(stages, stage_to_axes);
294 } else if (auto ps = step.as<PragmaStepNode>()) {
295 return ps->PrintAsPythonAPI(stages, stage_to_axes);
296 } else if (auto ps = step.as<ReorderStepNode>()) {
297 return ps->PrintAsPythonAPI(stages, stage_to_axes);
298 } else if (auto ps = step.as<SplitStepNode>()) {
299 return ps->PrintAsPythonAPI(stages, stage_to_axes);
300 } else if (auto ps = step.as<FollowSplitStepNode>()) {
301 return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
302 } else if (auto ps = step.as<FollowFusedSplitStepNode>()) {
303 return ps->PrintAsPythonAPI(stages, stage_to_axes, transform_steps);
304 } else if (auto ps = step.as<StorageAlignStepNode>()) {
305 return ps->PrintAsPythonAPI(stages, stage_to_axes);
306 } else if (auto ps = step.as<ComputeAtStepNode>()) {
307 return ps->PrintAsPythonAPI(stages, stage_to_axes);
308 } else if (auto ps = step.as<ComputeInlineStepNode>()) {
309 return ps->PrintAsPythonAPI(stages, stage_to_axes);
310 } else if (auto ps = step.as<ComputeRootStepNode>()) {
311 return ps->PrintAsPythonAPI(stages, stage_to_axes);
312 } else if (auto ps = step.as<CacheReadStepNode>()) {
313 return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
314 } else if (auto ps = step.as<CacheWriteStepNode>()) {
315 return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
316 } else if (auto ps = step.as<RfactorStepNode>()) {
317 return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
318 } else {
319 LOG(FATAL) << "Invalid Step: " << step;
320 }
321 return "";
322}
323
324/********** Steps working on single stage **********/
325
326/********** Annotation **********/
327AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) {
328 auto node = make_object<AnnotationStepNode>();
329 node->stage_id = stage_id;
330 node->iter_id = iter_id;
331 node->annotation = ann;
332 data_ = std::move(node);
333}
334
335AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) {
336 auto node = make_object<AnnotationStepNode>();
337 bool s;
338 s = reader->NextArrayItem();
339 ICHECK(s);
340 reader->Read(&node->stage_id);
341 s = reader->NextArrayItem();
342 ICHECK(s);
343 reader->Read(&node->iter_id);
344 s = reader->NextArrayItem();
345 ICHECK(s);
346 int int_val;
347 reader->Read(&int_val);
348 node->annotation = IteratorAnnotation(int_val);
349 data_ = std::move(node);
350}
351
352void AnnotationStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
353 writer->WriteArraySeperator();
354 writer->WriteString(record_prefix_str);
355 writer->WriteArrayItem(stage_id);
356 writer->WriteArrayItem(iter_id);
357 writer->WriteArrayItem(static_cast<int>(annotation));
358}
359
360Iterator AnnotationStepNode::ApplyToState(State* state) const {
361 const Stage& stage = (*state)->stages[stage_id];
362 Iterator it = stage->iters[iter_id];
363
364 ICHECK(it->annotation == IteratorAnnotation::kNone);
365 Iterator new_it = Iterator(it->name, it->range, it->iter_kind, annotation, &it->orig_iters);
366 Stage new_stage = stage;
367 new_stage.CopyOnWrite()->iters.Set(iter_id, new_it);
368 state->CopyOnWrite()->stages.Set(stage_id, std::move(new_stage));
369 return new_it;
370}
371
372void AnnotationStepNode::ApplyToSchedule(Array<te::Stage>* stages,
373 StageToAxesMap* stage_to_axes) const {
374 te::Stage stage = (*stages)[stage_id];
375 const Array<IterVar>& axes = (*stage_to_axes)[stage];
376
377 switch (annotation) {
378 case IteratorAnnotation::kUnroll:
379 stage.unroll(axes[iter_id]);
380 break;
381 case IteratorAnnotation::kVectorize:
382 stage.vectorize(axes[iter_id]);
383 break;
384 case IteratorAnnotation::kParallel:
385 stage.parallel(axes[iter_id]);
386 break;
387 case IteratorAnnotation::kVThread:
388 case IteratorAnnotation::kBlockX:
389 case IteratorAnnotation::kBlockY:
390 case IteratorAnnotation::kBlockZ:
391 case IteratorAnnotation::kThreadX:
392 case IteratorAnnotation::kThreadY:
393 case IteratorAnnotation::kThreadZ:
394 stage.bind(axes[iter_id],
395 te::thread_axis(Range(), IteratorAnnotationString[static_cast<int>(annotation)]));
396 break;
397 case IteratorAnnotation::kNone:
398 break;
399 default:
400 LOG(FATAL) << "Invalid Annotation " << static_cast<int>(annotation);
401 break;
402 }
403
404 stages->Set(stage_id, std::move(stage));
405}
406
407String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
408 StageToAxesMap* stage_to_axes) const {
409 std::stringstream ss;
410 const auto& stage = (*stages)[stage_id];
411 const auto& iter = (*stage_to_axes)[stage][iter_id];
412 const auto& op_name = CleanName(stage->op->name);
413
414 ss << "s[" << op_name << "].";
415 switch (annotation) {
416 case IteratorAnnotation::kUnroll:
417 ss << "unroll(";
418 break;
419 case IteratorAnnotation::kVectorize:
420 ss << "vectorize(";
421 break;
422 case IteratorAnnotation::kParallel:
423 ss << "parallel(";
424 break;
425 case IteratorAnnotation::kVThread:
426 case IteratorAnnotation::kBlockX:
427 case IteratorAnnotation::kBlockY:
428 case IteratorAnnotation::kBlockZ:
429 case IteratorAnnotation::kThreadX:
430 case IteratorAnnotation::kThreadY:
431 case IteratorAnnotation::kThreadZ:
432 ss << "bind(";
433 break;
434 case IteratorAnnotation::kNone:
435 break;
436 default:
437 LOG(FATAL) << "Invalid annotation " << static_cast<int>(annotation);
438 break;
439 }
440 ss << CleanName(iter->var->name_hint, op_name);
441 switch (annotation) {
442 case IteratorAnnotation::kVThread:
443 case IteratorAnnotation::kBlockX:
444 case IteratorAnnotation::kBlockY:
445 case IteratorAnnotation::kBlockZ:
446 case IteratorAnnotation::kThreadX:
447 case IteratorAnnotation::kThreadY:
448 case IteratorAnnotation::kThreadZ:
449 ss << ", te.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
450 << "\")";
451 break;
452 default:
453 break;
454 }
455 ss << ")\n";
456
457 ApplyToSchedule(stages, stage_to_axes);
458 return ss.str();
459}
460
461/********** Fuse **********/
462FuseStep::FuseStep(int stage_id, const Array<Integer>& fused_ids) {
463 auto node = make_object<FuseStepNode>();
464 node->stage_id = stage_id;
465 for (const auto& x : fused_ids) {
466 ICHECK(x->IsInstance<IntImmNode>());
467 }
468 node->fused_ids = fused_ids;
469 data_ = std::move(node);
470}
471
472FuseStep::FuseStep(dmlc::JSONReader* reader) {
473 auto node = make_object<FuseStepNode>();
474 bool s;
475 s = reader->NextArrayItem();
476 ICHECK(s);
477 reader->Read(&node->stage_id);
478 s = reader->NextArrayItem();
479 ICHECK(s);
480 reader->Read(&node->fused_ids);
481 data_ = std::move(node);
482}
483
484void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
485 writer->WriteArraySeperator();
486 writer->WriteString(record_prefix_str);
487 writer->WriteArrayItem(stage_id);
488 writer->WriteArrayItem(fused_ids);
489}
490
491Iterator FuseStepNode::ApplyToState(State* state) const {
492 const Stage& stage = (*state)->stages[stage_id];
493 size_t old_iter_size = static_cast<int>(stage->iters.size());
494
495 String new_name;
496 PrimExpr new_extent = 1;
497 IteratorKind new_iter_kind = IteratorKind::kSpecial;
498 std::vector<Iterator> orig_iters;
499
500 for (size_t i = 0; i < fused_ids.size(); ++i) {
501 if (i > 0) {
502 ICHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1);
503 }
504 if (i != fused_ids.size() - 1) {
505 const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages;
506 if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i].IntValue())) !=
507 iter_to_attached_stage.end()) {
508 LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some "
509 << "stages. State before fusion:\n"
510 << (*state);
511 }
512 }
513
514 const Iterator& it = stage->iters[fused_ids[i].IntValue()];
515 orig_iters.push_back(it);
516 new_name = new_name + it->name + "@";
517
518 if (it->range.defined() && new_extent.defined()) {
519 new_extent = new_extent * it->range->extent;
520 } else {
521 new_extent = PrimExpr();
522 }
523
524 if (i == 0) {
525 new_iter_kind = it->iter_kind;
526 } else {
527 if (new_iter_kind != it->iter_kind) {
528 new_iter_kind = IteratorKind::kMixed;
529 }
530 }
531 }
532
533 Range range;
534 if (new_extent.defined()) {
535 range = Range::FromMinExtent(0, new_extent);
536 }
537 Iterator new_it =
538 Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone, &orig_iters);
539 Array<Iterator> new_iters;
540
541 if (fused_ids.empty()) {
542 new_iters.push_back(new_it);
543 } else {
544 new_iters.insert(new_iters.end(), stage->iters.begin(),
545 stage->iters.begin() + fused_ids.front().IntValue());
546 new_iters.push_back(new_it);
547 new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back().IntValue() + 1,
548 stage->iters.end());
549 }
550
551 StateNode* pstate = state->CopyOnWrite();
552 pstate->stages.Set(stage_id,
553 Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
554
555 if (fused_ids.empty()) {
556 return new_it;
557 }
558
559 // Two vectors are used to represent the iterator relation before and after fuse
560 // The original iterators in AttachMap will be updated with the new iterators
561 std::vector<IterKey> from_iters;
562 std::vector<IterKey> to_iters;
563 const size_t begin_id = fused_ids.front().IntValue(), end_id = fused_ids.back().IntValue();
564 for (size_t i = 0; i < old_iter_size; ++i) {
565 if (i <= begin_id) {
566 continue;
567 } else if (i > end_id) {
568 // move forward
569 from_iters.emplace_back(stage_id, i);
570 to_iters.emplace_back(stage_id, i - end_id + begin_id);
571 } else {
572 // move to the fused id
573 from_iters.emplace_back(stage_id, i);
574 to_iters.emplace_back(stage_id, begin_id);
575 }
576 }
577 pstate->attach_map.UpdateIters(from_iters, to_iters);
578
579 return new_it;
580}
581
582IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
583 StageToAxesMap* stage_to_axes) const {
584 auto stage = (*stages)[stage_id];
585 const Array<IterVar>& axes = stage_to_axes->at(stage);
586
587 Array<IterVar> to_fuse;
588 for (const auto& i : fused_ids) {
589 to_fuse.push_back(axes[i.IntValue()]);
590 }
591 IterVar fused_axis;
592 stage.fuse(to_fuse, &fused_axis);
593
594 Array<IterVar> new_axes;
595 if (fused_ids.empty()) {
596 new_axes.push_back(fused_axis);
597 } else {
598 new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front().IntValue());
599 new_axes.push_back(fused_axis);
600 new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back().IntValue() + 1, axes.end());
601 }
602
603 stage_to_axes->Set(stage, std::move(new_axes));
604 stages->Set(stage_id, std::move(stage));
605 return fused_axis;
606}
607
608String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
609 StageToAxesMap* stage_to_axes) const {
610 const auto& stage = (*stages)[stage_id];
611 const auto& op_name = CleanName(stage->op->name);
612 std::stringstream to_fuse;
613
614 for (size_t i = 0; i < fused_ids.size(); ++i) {
615 to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i].IntValue()]->var->name_hint,
616 op_name);
617 if (i != fused_ids.size() - 1) {
618 to_fuse << ", ";
619 }
620 }
621
622 std::stringstream ss;
623 const auto& fused = ApplyToSchedule(stages, stage_to_axes);
624
625 ss << CleanName(fused->var->name_hint, op_name) << " = s[" << op_name << "].fuse("
626 << to_fuse.str() << ")\n";
627
628 return ss.str();
629}
630
631/********** Pragma **********/
632PragmaStep::PragmaStep(int stage_id, int iter_id, String pragma_type) {
633 auto node = make_object<PragmaStepNode>();
634 node->stage_id = stage_id;
635 node->iter_id = iter_id;
636 node->pragma_type = std::move(pragma_type);
637 data_ = std::move(node);
638}
639
640PragmaStep::PragmaStep(dmlc::JSONReader* reader) {
641 auto node = make_object<PragmaStepNode>();
642 bool s;
643 s = reader->NextArrayItem();
644 ICHECK(s);
645 reader->Read(&node->stage_id);
646 s = reader->NextArrayItem();
647 ICHECK(s);
648 reader->Read(&node->iter_id);
649 s = reader->NextArrayItem();
650 ICHECK(s);
651 std::string string_value;
652 reader->Read(&string_value);
653 node->pragma_type = std::move(string_value);
654 data_ = std::move(node);
655}
656
657void PragmaStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
658 writer->WriteArraySeperator();
659 writer->WriteString(record_prefix_str);
660 writer->WriteArrayItem(stage_id);
661 writer->WriteArrayItem(iter_id);
662 writer->WriteArraySeperator();
663 writer->WriteString(pragma_type);
664}
665
666void PragmaStepNode::ApplyToState(State* state) const {
667 if (pragma_type == "debug_skip_region") {
668 StateNode* pstate = state->CopyOnWrite();
669 pstate->attach_map.DeleteStage(stage_id);
670 } else if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
671 StateNode* pstate = state->CopyOnWrite();
672 Stage stage = pstate->stages[stage_id];
673 size_t pos = 0;
674 for (; pos < pragma_type.size(); ++pos) {
675 if ((*(pragma_type.c_str() + pos)) == '$') {
676 break;
677 }
678 }
679 ICHECK_LT(pos, pragma_type.size()) << "max step value not found.";
680 stage.CopyOnWrite()->attrs.auto_unroll_max_step = atoi(pragma_type.c_str() + pos + 1);
681 pstate->stages.Set(stage_id, std::move(stage));
682 } else {
683 LOG(FATAL) << "Unsupported pragma: " << pragma_type;
684 }
685}
686
687void PragmaStepNode::ApplyToSchedule(Array<te::Stage>* stages,
688 StageToAxesMap* stage_to_axes) const {
689 te::Stage stage = (*stages)[stage_id];
690 const Array<IterVar>& axes = (*stage_to_axes)[stage];
691 if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
692 size_t pos = 0;
693 for (; pos < pragma_type.size(); ++pos) {
694 if ((*(pragma_type.c_str() + pos)) == '$') {
695 break;
696 }
697 }
698 ICHECK_LT(pos, pragma_type.size()) << "max step value not found.";
699 int value = atoi(pragma_type.c_str() + pos + 1);
700 if (iter_id < static_cast<int>(axes.size())) {
701 stage.pragma(axes[iter_id], "auto_unroll_max_step", value);
702 stage.pragma(axes[iter_id], "unroll_explicit", true);
703 }
704 } else {
705 ICHECK_LT(iter_id, axes.size());
706 stage.pragma(axes[iter_id], pragma_type);
707 }
708 stages->Set(stage_id, std::move(stage));
709}
710
711String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
712 StageToAxesMap* stage_to_axes) const {
713 std::stringstream ss;
714 const auto& stage = (*stages)[stage_id];
715 const auto& op_name = CleanName(stage->op->name);
716
717 if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
718 size_t pos = 0;
719 for (; pos < pragma_type.size(); ++pos) {
720 if ((*(pragma_type.c_str() + pos)) == '$') {
721 break;
722 }
723 }
724 ICHECK_LT(pos, pragma_type.size()) << "max step value not found.";
725 int value = atoi(pragma_type.c_str() + pos + 1);
726 ss << "s[" << op_name << "].pragma("
727 << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name)
728 << ", \"auto_unroll_max_step\", " << value << ")\n";
729 ss << "s[" << op_name << "].pragma("
730 << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name)
731 << ", \"unroll_explicit\", True)\n";
732 } else {
733 ss << "s[" << op_name << "].pragma("
734 << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", \""
735 << pragma_type << "\")\n";
736 }
737
738 ApplyToSchedule(stages, stage_to_axes);
739 return ss.str();
740}
741
742/********** Reorder **********/
743ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
744 auto node = make_object<ReorderStepNode>();
745 node->stage_id = stage_id;
746 for (const auto& x : after_ids) {
747 ICHECK(x->IsInstance<IntImmNode>());
748 }
749 node->after_ids = after_ids;
750 data_ = std::move(node);
751}
752
753ReorderStep::ReorderStep(dmlc::JSONReader* reader) {
754 auto node = make_object<ReorderStepNode>();
755 bool s;
756 s = reader->NextArrayItem();
757 ICHECK(s);
758 reader->Read(&node->stage_id);
759 s = reader->NextArrayItem();
760 ICHECK(s);
761 reader->Read(&node->after_ids);
762 data_ = std::move(node);
763}
764
765void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
766 writer->WriteArraySeperator();
767 writer->WriteString(record_prefix_str);
768 writer->WriteArrayItem(stage_id);
769 writer->WriteArrayItem(after_ids);
770}
771
772void ReorderStepNode::ApplyToState(State* state) const {
773 const Stage& stage = (*state)->stages[stage_id];
774 Array<Iterator> iters;
775 for (auto x : after_ids) {
776 iters.push_back(stage->iters[x.IntValue()]);
777 }
778 state->CopyOnWrite()->stages.Set(
779 stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
780}
781
782void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
783 StageToAxesMap* stage_to_axes) const {
784 auto stage = (*stages)[stage_id];
785 const Array<IterVar>& axes = stage_to_axes->at(stage);
786 ICHECK_EQ(after_ids.size(), axes.size());
787
788 Array<IterVar> new_axes;
789 new_axes.reserve(axes.size());
790 for (auto i : after_ids) {
791 new_axes.push_back(axes[i.IntValue()]);
792 }
793 stage.reorder(new_axes);
794
795 stage_to_axes->Set(stage, std::move(new_axes));
796 stages->Set(stage_id, std::move(stage));
797}
798
799String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
800 StageToAxesMap* stage_to_axes) const {
801 const auto& stage = (*stages)[stage_id];
802 const auto& op_name = CleanName(stage->op->name);
803 std::stringstream ss;
804
805 ss << "s[" << op_name << "].reorder(";
806 for (size_t i = 0; i < after_ids.size(); ++i) {
807 ss << CleanName((*stage_to_axes)[stage][after_ids[i].IntValue()]->var->name_hint, op_name);
808 if (i != after_ids.size() - 1) {
809 ss << ", ";
810 }
811 }
812 ss << ")\n";
813
814 ApplyToSchedule(stages, stage_to_axes);
815 return ss.str();
816}
817
818/********** Split **********/
// Common part for SplitStep, FollowSplitStep, and FollowFusedSplitStep.
/*!
 * \brief Split iterator `iter_id` of stage `stage_id` into lengths.size() + 1 new iterators.
 *
 * If `inner_to_outer` is true, lengths[k] is the extent of new iterator "<name>.<k+1>", and the
 * outermost iterator "<name>.0" gets the remaining extent. Otherwise lengths[k] is the extent of
 * "<name>.<k>", and the innermost iterator "<name>.<lengths.size()>" gets the remaining extent.
 * Once a length is undefined, or the original range is unknown, that iterator and every
 * iterator created after it get an undefined range. An undefined length also marks the state as
 * not concrete.
 *
 * \return The new iterators ordered from outermost to innermost.
 */
Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id,
                                  const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
  const Stage& stage = (*state)->stages[stage_id];
  const Iterator& it = stage->iters[iter_id];
  size_t old_iter_size = stage->iters.size();
  // Cleared if any split length is undefined.
  bool concrete = true;

  // Range still left to split; NullOpt once it becomes unknown.
  Optional<PrimExpr> tosplit_min, tosplit_extent;
  if (it->range.defined()) {
    tosplit_min = it->range->min;
    tosplit_extent = it->range->extent;
  } else {
    tosplit_min = NullOpt;
    tosplit_extent = NullOpt;
  }

  // Iterators are created in "split order": innermost first when inner_to_outer, else outermost.
  Array<Iterator> outs;
  for (size_t i = 0; i < lengths.size(); ++i) {
    Optional<Integer> l;
    String name;
    if (inner_to_outer) {
      l = lengths[lengths.size() - i - 1];
      name = it->name + "." + std::to_string(lengths.size() - i);
    } else {
      l = lengths[i];
      name = it->name + "." + std::to_string(i);
    }
    Iterator res;
    if (l && tosplit_min && tosplit_extent) {
      res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind,
                     IteratorAnnotation::kNone);
      // The remainder starts at 0 and has extent ceil(extent / l).
      tosplit_min = Integer(0);
      tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value());
    } else {
      res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
      // After this point the remaining range is unknown.
      tosplit_min = NullOpt;
      tosplit_extent = NullOpt;
      if (!l.defined()) {
        concrete = false;
      }
    }
    outs.push_back(std::move(res));
  }

  // The last iterator covers whatever range remains (undefined if unknown).
  Range range;
  if (tosplit_min && tosplit_extent) {
    range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
  }
  if (inner_to_outer) {
    outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone));
    // Reverse the Iterator array
    Array<Iterator> temp(outs.rbegin(), outs.rend());
    outs = std::move(temp);
  } else {
    outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind,
                            IteratorAnnotation::kNone));
  }

  // Replace the original iterator with the new ones, in place.
  Array<Iterator> new_iters;
  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
  new_iters.insert(new_iters.end(), outs.begin(), outs.end());
  new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());

  StateNode* pstate = state->CopyOnWrite();
  pstate->stages.Set(stage_id,
                     Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
  pstate->concrete &= concrete;

  // Two vectors are used to represent the iterator relation before and after split
  // The original iterators in AttachMap will be updated with the new iterators
  // Positions >= iter_id shift by lengths.size(). Attachments to the split iterator itself thus
  // move to position iter_id + lengths.size(), the innermost new iterator.
  std::vector<IterKey> from_iters;
  std::vector<IterKey> to_iters;
  for (size_t i = iter_id; i < old_iter_size; ++i) {
    from_iters.emplace_back(stage_id, i);
    to_iters.emplace_back(stage_id, i + lengths.size());
  }
  pstate->attach_map.UpdateIters(from_iters, to_iters);

  return outs;
}
900
901Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
902 int stage_id, int iter_id,
903 const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
904 auto stage = (*stages)[stage_id];
905 const Array<IterVar>& axes = stage_to_axes->at(stage);
906
907 Array<IterVar> outs;
908 if (inner_to_outer) {
909 IterVar outer = axes[iter_id], inner;
910 for (int i = static_cast<int>(lengths.size()) - 1; i >= 0; i--) {
911 IterVar to_split = outer;
912 stage.split(to_split, lengths[i].value(), &outer, &inner);
913 outs.push_back(inner);
914 }
915 outs.push_back(outer);
916 } else {
917 IterVar outer, inner = axes[iter_id];
918 for (size_t i = 0; i < lengths.size(); i++) {
919 IterVar to_split = inner;
920 stage.split_by_nparts(to_split, lengths[i].value(), &outer, &inner);
921 outs.push_back(outer);
922 }
923 outs.push_back(inner);
924 }
925
926 Array<IterVar> new_axes;
927 new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id);
928 if (inner_to_outer) {
929 for (auto x = outs.rbegin(); x != outs.rend(); ++x) {
930 new_axes.push_back((*x));
931 }
932 } else {
933 for (const auto& x : outs) {
934 new_axes.push_back(x);
935 }
936 }
937 new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end());
938
939 stage_to_axes->Set(stage, std::move(new_axes));
940 stages->Set(stage_id, std::move(stage));
941 return outs;
942}
943
/*!
 * \brief Print the Python TE schedule code for a split and apply the split to `stages` /
 *        `stage_to_axes` (via ApplySplitToSchedule).
 * \param stages The te::Stage list; updated by ApplySplitToSchedule.
 * \param stage_to_axes Stage -> axes map; updated by ApplySplitToSchedule.
 * \param stage_id Index of the stage to split.
 * \param iter_id Index of the axis to split.
 * \param lengths Split lengths (factors if `inner_to_outer`, nparts otherwise).
 * \param inner_to_outer Split direction, forwarded to ApplySplitToSchedule.
 * \return One `outer, inner = s[op].split(...)` line per split length.
 */
String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, int stage_id,
                             int iter_id, const Array<Optional<Integer>>& lengths,
                             bool inner_to_outer) {
  const auto& stage = (*stages)[stage_id];
  // Capture the axis to split before ApplySplitToSchedule replaces it in stage_to_axes.
  auto to_split = stage_to_axes->at(stage)[iter_id];
  const auto& func_name = CleanName(stage->op->name);
  const auto& outs =
      ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
  ICHECK_EQ(outs.size(), lengths.size() + 1);

  std::stringstream ss;
  int size = static_cast<int>(lengths.size());
  if (inner_to_outer) {
    // `outs` is innermost-first here. Iteration i prints
    //   outs[size - i], outs[size - i - 1] = split(to_split, factor=lengths[i])
    // and then continues splitting outs[size - i]; the printed outer Python variable is simply
    // rebound by the next printed line.
    for (int i = size - 1; i >= 0; i--) {
      ss << CleanName(outs[size - i]->var->name_hint, func_name) << ", "
         << CleanName(outs[size - i - 1]->var->name_hint, func_name) << " = s[" << func_name
         << "].split(" << CleanName(to_split->var->name_hint, func_name)
         << ", factor=" << lengths[i] << ")\n";
      to_split = outs[size - i];
    }
  } else {
    // `outs` is outermost-first here: each nparts split yields outs[i] and continues on outs[i + 1].
    for (int i = 0; i < size; i++) {
      ss << CleanName(outs[i]->var->name_hint, func_name) << ", "
         << CleanName(outs[i + 1]->var->name_hint, func_name) << " = s[" << func_name << "].split("
         << CleanName(to_split->var->name_hint, func_name) << ", nparts=" << lengths[i] << ")\n";
      to_split = outs[i + 1];
    }
  }

  return ss.str();
}
975
976SplitStep::SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
977 const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
978 auto node = make_object<SplitStepNode>();
979 node->stage_id = stage_id;
980 // Extent can be a irreducible expression in some special cases
981 if (extent && extent.value()->IsInstance<IntImmNode>()) {
982 node->extent = tvm::Downcast<Integer>(extent.value());
983 }
984 node->iter_id = iter_id;
985 node->lengths = lengths;
986 node->inner_to_outer = inner_to_outer;
987 data_ = std::move(node);
988}
989
990SplitStep::SplitStep(dmlc::JSONReader* reader) {
991 auto node = make_object<SplitStepNode>();
992 bool s;
993 s = reader->NextArrayItem();
994 ICHECK(s);
995 reader->Read(&node->stage_id);
996 s = reader->NextArrayItem();
997 ICHECK(s);
998 reader->Read(&node->iter_id);
999 int int_val;
1000 s = reader->NextArrayItem();
1001 ICHECK(s);
1002 reader->Read(&int_val);
1003 if (int_val) {
1004 node->extent = Integer(int_val);
1005 }
1006 s = reader->NextArrayItem();
1007 ICHECK(s);
1008 reader->Read(&node->lengths);
1009 s = reader->NextArrayItem();
1010 ICHECK(s);
1011 reader->Read(&node->inner_to_outer);
1012 data_ = std::move(node);
1013}
1014
/*!
 * \brief Serialize this step as JSON array items: the record prefix, then stage_id, iter_id,
 *        extent (0 when undefined), lengths and inner_to_outer (as an int).
 * \note SplitStep(dmlc::JSONReader*) reads the fields after the prefix in this same order, so the
 *       order here must not change.
 */
void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArrayItem(iter_id);
  // 0 is used as the "no extent" sentinel.
  writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0);
  writer->WriteArrayItem(lengths);
  writer->WriteArrayItem(static_cast<int>(inner_to_outer));
}
1024
1025Array<Iterator> SplitStepNode::ApplyToState(State* state) const {
1026 return ApplySplitToState(state, stage_id, iter_id, lengths, inner_to_outer);
1027}
1028
1029Array<IterVar> SplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1030 StageToAxesMap* stage_to_axes) const {
1031 return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
1032}
1033
1034String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
1035 StageToAxesMap* stage_to_axes) const {
1036 return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
1037}
1038
1039/********** Follow Split **********/
1040FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split) {
1041 auto node = make_object<FollowSplitStepNode>();
1042 node->stage_id = stage_id;
1043 node->iter_id = iter_id;
1044 node->src_step_id = src_step_id;
1045 node->n_split = n_split;
1046 data_ = std::move(node);
1047}
1048
/*!
 * \brief Serialize as: record prefix, stage_id, iter_id, src_step_id, n_split.
 * \note FollowSplitStep(dmlc::JSONReader*) reads the fields after the prefix in this order.
 */
void FollowSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArrayItem(iter_id);
  writer->WriteArrayItem(src_step_id);
  writer->WriteArrayItem(n_split);
}
1057
1058Array<Optional<Integer>> FollowSplitStepNode::ExtractSplitLengths(
1059 const Array<Step>& transform_steps) const {
1060 // Make sure src_step_id is within the range of transform_steps.
1061 ICHECK_LT(src_step_id, transform_steps.size());
1062 auto ps = transform_steps[src_step_id].as<SplitStepNode>();
1063 ICHECK(ps != nullptr);
1064
1065 // Make sure the size of ps->lengths is not smaller than n_split-1.
1066 // Note that the number of actual splitting factors of src_step is ps->lengths.size()+1.
1067 ICHECK_LE(n_split, ps->lengths.size() + 1);
1068 ICHECK(ps != nullptr);
1069
1070 Array<Optional<Integer>> lengths;
1071 lengths.reserve(n_split);
1072 int j = 0;
1073 // Get the first (n_split-1) split factors of followed src_step.
1074 for (; j < n_split - 1; ++j) {
1075 lengths.push_back(ps->lengths[j]);
1076 }
1077
1078 // Get the last split factor of src_step for splitting level if n_split is smaller than
1079 // ps->lengths.size()+1.
1080 PrimExpr last_factor = 1;
1081 for (; j < static_cast<int>(ps->lengths.size()); ++j) {
1082 if (ps->lengths[j]) {
1083 last_factor *= ps->lengths[j].value();
1084 } else {
1085 last_factor = PrimExpr();
1086 break;
1087 }
1088 }
1089 if (last_factor.defined()) {
1090 lengths.push_back(Downcast<Integer>(last_factor));
1091 } else {
1092 lengths.push_back(NullOpt);
1093 }
1094
1095 return lengths;
1096}
1097
1098FollowSplitStep::FollowSplitStep(dmlc::JSONReader* reader) {
1099 auto node = make_object<FollowSplitStepNode>();
1100 bool s;
1101 s = reader->NextArrayItem();
1102 ICHECK(s);
1103 reader->Read(&node->stage_id);
1104 s = reader->NextArrayItem();
1105 ICHECK(s);
1106 reader->Read(&node->iter_id);
1107 s = reader->NextArrayItem();
1108 ICHECK(s);
1109 reader->Read(&node->src_step_id);
1110 s = reader->NextArrayItem();
1111 ICHECK(s);
1112 reader->Read(&node->n_split);
1113 data_ = std::move(node);
1114}
1115
1116Array<Iterator> FollowSplitStepNode::ApplyToState(State* state) const {
1117 return ApplySplitToState(state, stage_id, iter_id, ExtractSplitLengths((*state)->transform_steps),
1118 true);
1119}
1120
1121Array<IterVar> FollowSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1122 StageToAxesMap* stage_to_axes,
1123 const Array<Step>& transform_steps) const {
1124 return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
1125 ExtractSplitLengths(transform_steps), true);
1126}
1127
1128String FollowSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
1129 StageToAxesMap* stage_to_axes,
1130 const Array<Step>& transform_steps) const {
1131 return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
1132 ExtractSplitLengths(transform_steps), true);
1133}
1134
1135/********** Follow Fused Split **********/
1136FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id,
1137 const Array<Integer>& src_step_ids, int level,
1138 bool factor_or_nparts) {
1139 auto node = make_object<FollowFusedSplitStepNode>();
1140 node->stage_id = stage_id;
1141 node->iter_id = iter_id;
1142 node->src_step_ids = src_step_ids;
1143 node->level = level;
1144 node->factor_or_nparts = factor_or_nparts;
1145 data_ = std::move(node);
1146}
1147
1148FollowFusedSplitStep::FollowFusedSplitStep(dmlc::JSONReader* reader) {
1149 auto node = make_object<FollowFusedSplitStepNode>();
1150 bool s;
1151 s = reader->NextArrayItem();
1152 ICHECK(s);
1153 reader->Read(&node->stage_id);
1154 s = reader->NextArrayItem();
1155 ICHECK(s);
1156 reader->Read(&node->iter_id);
1157 s = reader->NextArrayItem();
1158 ICHECK(s);
1159 reader->Read(&node->src_step_ids);
1160 s = reader->NextArrayItem();
1161 ICHECK(s);
1162 reader->Read(&node->level);
1163 s = reader->NextArrayItem();
1164 ICHECK(s);
1165 reader->Read(&node->factor_or_nparts);
1166 data_ = std::move(node);
1167}
1168
/*!
 * \brief Serialize as: record prefix, stage_id, iter_id, src_step_ids, level,
 *        factor_or_nparts (as an int).
 * \note FollowFusedSplitStep(dmlc::JSONReader*) reads the fields after the prefix in this order.
 */
void FollowFusedSplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArrayItem(iter_id);
  writer->WriteArrayItem(src_step_ids);
  writer->WriteArrayItem(level);
  writer->WriteArrayItem(static_cast<int>(factor_or_nparts));
}
1178
1179Optional<Integer> FollowFusedSplitStepNode::ExtractSplitLength(
1180 const Array<Step>& transform_steps) const {
1181 PrimExpr ret(1);
1182
1183 for (auto src_step_id : src_step_ids) {
1184 // Make sure the src_step_id is within the range of transform_steps.
1185 ICHECK_LT(src_step_id.IntValue(), transform_steps.size());
1186 auto ps = transform_steps[src_step_id.IntValue()].as<SplitStepNode>();
1187 ICHECK(ps != nullptr);
1188 // Multiple the splitting factor on corresponding splitting level of src_steps.
1189 if (ps->lengths[level] && ret.defined()) {
1190 ret *= ps->lengths[level].value();
1191 } else {
1192 return NullOpt;
1193 }
1194 }
1195 return Downcast<Integer>(ret);
1196}
1197
1198Array<Iterator> FollowFusedSplitStepNode::ApplyToState(State* state) const {
1199 return ApplySplitToState(state, stage_id, iter_id,
1200 {ExtractSplitLength((*state)->transform_steps)}, factor_or_nparts);
1201}
1202
1203Array<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1204 StageToAxesMap* stage_to_axes,
1205 const Array<Step>& transform_steps) const {
1206 return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
1207 {ExtractSplitLength(transform_steps)}, factor_or_nparts);
1208}
1209
1210String FollowFusedSplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
1211 StageToAxesMap* stage_to_axes,
1212 const Array<Step>& transform_steps) const {
1213 return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
1214 {ExtractSplitLength(transform_steps)}, factor_or_nparts);
1215}
1216
1217/********** Storage Align **********/
1218StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, int factor, int offset) {
1219 auto node = make_object<StorageAlignStepNode>();
1220 node->stage_id = stage_id;
1221 node->iter_id = iter_id;
1222 node->factor = factor;
1223 node->offset = offset;
1224 data_ = std::move(node);
1225}
1226
1227StorageAlignStep::StorageAlignStep(dmlc::JSONReader* reader) {
1228 auto node = make_object<StorageAlignStepNode>();
1229 bool s;
1230 s = reader->NextArrayItem();
1231 ICHECK(s);
1232 reader->Read(&node->stage_id);
1233 s = reader->NextArrayItem();
1234 ICHECK(s);
1235 reader->Read(&node->iter_id);
1236 s = reader->NextArrayItem();
1237 ICHECK(s);
1238 reader->Read(&node->factor);
1239 s = reader->NextArrayItem();
1240 ICHECK(s);
1241 reader->Read(&node->offset);
1242 data_ = std::move(node);
1243}
1244
/*!
 * \brief Serialize as: record prefix, stage_id, iter_id, factor, offset.
 * \note StorageAlignStep(dmlc::JSONReader*) reads the fields after the prefix in this order.
 */
void StorageAlignStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArrayItem(iter_id);
  writer->WriteArrayItem(factor);
  writer->WriteArrayItem(offset);
}
1253
1254void StorageAlignStepNode::ApplyToState(State* state) const {
1255 StateNode* pstate = state->CopyOnWrite();
1256 Stage stage = pstate->stages[stage_id];
1257 stage.CopyOnWrite()->attrs.storage_offset = offset;
1258 pstate->stages.Set(stage_id, std::move(stage));
1259}
1260
1261void StorageAlignStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1262 StageToAxesMap* stage_to_axes) const {
1263 te::Stage stage = (*stages)[stage_id];
1264 const Array<IterVar>& axes = (*stage_to_axes)[stage];
1265 stage.storage_align(axes[iter_id], factor, offset);
1266 stages->Set(stage_id, std::move(stage));
1267}
1268
1269String StorageAlignStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
1270 StageToAxesMap* stage_to_axes) const {
1271 std::stringstream ss;
1272 const auto& stage = (*stages)[stage_id];
1273 const auto& op_name = CleanName(stage->op->name);
1274 ss << "s[" << op_name << "].storage_align("
1275 << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", " << factor
1276 << ", " << offset << ")\n";
1277
1278 ApplyToSchedule(stages, stage_to_axes);
1279 return ss.str();
1280}
1281
1282/********** Steps working on multiple stages **********/
1283
1284/********** Compute At **********/
1285ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) {
1286 auto node = make_object<ComputeAtStepNode>();
1287 node->stage_id = stage_id;
1288 node->target_stage_id = target_stage_id;
1289 node->target_iter_id = target_iter_id;
1290 data_ = std::move(node);
1291}
1292
1293ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) {
1294 auto node = make_object<ComputeAtStepNode>();
1295 bool s;
1296 s = reader->NextArrayItem();
1297 ICHECK(s);
1298 reader->Read(&node->stage_id);
1299 s = reader->NextArrayItem();
1300 ICHECK(s);
1301 reader->Read(&node->target_stage_id);
1302 s = reader->NextArrayItem();
1303 ICHECK(s);
1304 reader->Read(&node->target_iter_id);
1305 data_ = std::move(node);
1306}
1307
/*!
 * \brief Serialize as: record prefix, stage_id, target_stage_id, target_iter_id.
 * \note ComputeAtStep(dmlc::JSONReader*) reads the fields after the prefix in this order.
 */
void ComputeAtStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  writer->WriteArrayItem(target_stage_id);
  writer->WriteArrayItem(target_iter_id);
}
1315void ComputeAtStepNode::ApplyToState(State* state) const {
1316 const Stage& stage = (*state)->stages[stage_id];
1317
1318 // Remove the bound information of each iterator since they may not be accurate after
1319 // compute at
1320 Array<Iterator> new_iters;
1321 for (const Iterator& it : stage->iters) {
1322 new_iters.push_back(
1323 Iterator(it->name, Range(), it->iter_kind, it->annotation, &it->orig_iters));
1324 }
1325
1326 StateNode* pstate = state->CopyOnWrite();
1327 pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
1328 ComputeAtKind::kIter, stage->attrs));
1329 // Update attach map
1330 pstate->attach_map.SetComputeAtIter(stage_id, target_stage_id, target_iter_id);
1331}
1332
1333void ComputeAtStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1334 StageToAxesMap* stage_to_axes) const {
1335 te::Stage stage = (*stages)[stage_id];
1336 const auto& target_stage = (*stages)[target_stage_id];
1337 const auto& target_axis = (*stage_to_axes)[target_stage][target_iter_id];
1338 stage.compute_at(target_stage, target_axis);
1339
1340 stages->Set(stage_id, std::move(stage));
1341}
1342
1343String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
1344 StageToAxesMap* stage_to_axes) const {
1345 std::stringstream ss;
1346 const auto& stage = (*stages)[stage_id];
1347 const auto& target_stage = (*stages)[target_stage_id];
1348 const auto& op_name = CleanName(stage->op->name);
1349 const auto& target_op_name = CleanName(target_stage->op->name);
1350 ss << "s[" << op_name << "].compute_at(s[" << target_op_name << "], "
1351 << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint, target_op_name)
1352 << ")\n";
1353 ApplyToSchedule(stages, stage_to_axes);
1354 return ss.str();
1355}
1356
1357/********** Compute Inline **********/
1358ComputeInlineStep::ComputeInlineStep(int stage_id) {
1359 auto node = make_object<ComputeInlineStepNode>();
1360 node->stage_id = stage_id;
1361 data_ = std::move(node);
1362}
1363
1364ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) {
1365 auto node = make_object<ComputeInlineStepNode>();
1366 bool s;
1367 s = reader->NextArrayItem();
1368 ICHECK(s);
1369 reader->Read(&node->stage_id);
1370 data_ = std::move(node);
1371}
1372
/*!
 * \brief Serialize as: record prefix, stage_id.
 * \note ComputeInlineStep(dmlc::JSONReader*) reads stage_id after the prefix.
 */
void ComputeInlineStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
}
1378
1379void ComputeInlineStepNode::ApplyToState(State* state) const {
1380 const Stage& stage = (*state)->stages[stage_id];
1381
1382 // Check the validity of compute_inline
1383 for (size_t i = 0; i < stage->iters.size(); ++i) {
1384 ICHECK_EQ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, i)), 0)
1385 << "Invalid compute_inline: There are some other stages that are attached to the "
1386 << "target stage";
1387 }
1388
1389 StateNode* pstate = state->CopyOnWrite();
1390 auto new_stage = pstate->stages[stage_id];
1391 new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined;
1392 pstate->stages.Set(stage_id, std::move(new_stage));
1393 // Update attach map
1394 pstate->attach_map.DeleteStage(stage_id);
1395}
1396
1397void ComputeInlineStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1398 StageToAxesMap* stage_to_axes) const {
1399 auto stage = (*stages)[stage_id];
1400 stage.compute_inline();
1401 stages->Set(stage_id, std::move(stage));
1402}
1403
1404String ComputeInlineStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
1405 StageToAxesMap* stage_to_axes) const {
1406 std::stringstream ss;
1407 const auto& stage = (*stages)[stage_id];
1408 ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n";
1409 ApplyToSchedule(stages, stage_to_axes);
1410 return ss.str();
1411}
1412
1413/********** Compute Root **********/
1414ComputeRootStep::ComputeRootStep(int stage_id) {
1415 auto node = make_object<ComputeRootStepNode>();
1416 node->stage_id = stage_id;
1417 data_ = std::move(node);
1418}
1419
1420ComputeRootStep::ComputeRootStep(dmlc::JSONReader* reader) {
1421 auto node = make_object<ComputeRootStepNode>();
1422 bool s;
1423 s = reader->NextArrayItem();
1424 ICHECK(s);
1425 reader->Read(&node->stage_id);
1426 data_ = std::move(node);
1427}
1428
/*!
 * \brief Serialize as: record prefix, stage_id.
 * \note ComputeRootStep(dmlc::JSONReader*) reads stage_id after the prefix.
 */
void ComputeRootStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
}
1434
1435void ComputeRootStepNode::ApplyToState(State* state) const {
1436 const Stage& stage = (*state)->stages[stage_id];
1437
1438 // Remove the bound information of each iterator since they may not be accurate after
1439 // compute root
1440 Array<Iterator> new_iters;
1441 for (const Iterator& it : stage->iters) {
1442 new_iters.push_back(
1443 Iterator(it->name, Range(), it->iter_kind, it->annotation, &it->orig_iters));
1444 }
1445
1446 StateNode* pstate = state->CopyOnWrite();
1447 pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
1448 ComputeAtKind::kRoot, stage->attrs));
1449 // Update attach map
1450 pstate->attach_map.DeleteStage(stage_id);
1451}
1452
1453void ComputeRootStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1454 StageToAxesMap* stage_to_axes) const {
1455 auto stage = (*stages)[stage_id];
1456 stage.compute_root();
1457 stages->Set(stage_id, std::move(stage));
1458}
1459
1460String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
1461 StageToAxesMap* stage_to_axes) const {
1462 std::stringstream ss;
1463 const auto& stage = (*stages)[stage_id];
1464 ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n";
1465 ApplyToSchedule(stages, stage_to_axes);
1466 return ss.str();
1467}
1468
1469/********** Steps adding new stages **********/
1470
1471/*!
1472 * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep,
1473 * RfactorStep). This will return all steps that can change the number of stages in a ComputeDAG,
1474 * and stop by the current step.
1475 */
1476Array<Step> GetFormerStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
1477 Array<Step> ret_steps;
1478 for (size_t i = 0; i < transform_steps.size(); ++i) {
1479 const Step& step = transform_steps[i];
1480 if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
1481 ret_steps.push_back(step);
1482 } else if (step->IsInstance<RfactorStepNode>()) {
1483 // add FuseStepNode required by rfactor
1484 if (i >= 2 && transform_steps[i - 2]->IsInstance<FuseStepNode>()) {
1485 const Step& fuse_step = transform_steps[i - 2];
1486 if (fuse_step->stage_id == step->stage_id) {
1487 ret_steps.push_back(fuse_step);
1488 }
1489 }
1490 // add SplitStepNode required by rfactor
1491 ICHECK_GE(i, 1);
1492 ICHECK(transform_steps[i - 1]->IsInstance<SplitStepNode>());
1493 const Step& split_step = transform_steps[i - 1];
1494 ICHECK_EQ(split_step->stage_id, step->stage_id);
1495 ret_steps.push_back(split_step);
1496 // add RfactorStepNode
1497 ret_steps.push_back(step);
1498 }
1499 // A state may have multiple stage modifiable steps, stop by the current step to avoid
1500 // replaying excess steps
1501 if (step.same_as(current_step)) {
1502 break;
1503 }
1504 }
1505 return ret_steps;
1506}
1507
1508/********** Cache Read **********/
1509CacheReadStep::CacheReadStep(int stage_id, String scope_name,
1510 const Array<Integer>& reader_stage_ids) {
1511 auto node = make_object<CacheReadStepNode>();
1512 node->stage_id = stage_id;
1513 node->scope_name = std::move(scope_name);
1514 node->reader_stage_ids = reader_stage_ids;
1515 data_ = std::move(node);
1516}
1517
1518CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
1519 auto node = make_object<CacheReadStepNode>();
1520 bool s;
1521 s = reader->NextArrayItem();
1522 ICHECK(s);
1523 reader->Read(&node->stage_id);
1524 s = reader->NextArrayItem();
1525 ICHECK(s);
1526 std::string string_value;
1527 reader->Read(&string_value);
1528 node->scope_name = std::move(string_value);
1529 s = reader->NextArrayItem();
1530 ICHECK(s);
1531 reader->Read(&node->reader_stage_ids);
1532 data_ = std::move(node);
1533}
1534
/*!
 * \brief Serialize as: record prefix, stage_id, scope_name, reader_stage_ids.
 * \note CacheReadStep(dmlc::JSONReader*) reads the fields after the prefix in this order.
 */
void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  // The scope name is emitted as an explicit separator + string item.
  writer->WriteArraySeperator();
  writer->WriteString(scope_name);
  writer->WriteArrayItem(reader_stage_ids);
}
1543
1544int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
1545 StateNode* pstate = state->CopyOnWrite();
1546 const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(
1547 GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
1548
1549 // target_stage -> target_stage + target_store
1550 // Update the op of the target stage, insert a new cache read stage behind, update the op of
1551 // later stages, then update the stage_id mapping in AttachMap
1552 int added_stage_id = stage_id + 1;
1553 Stage tmp_stage = pstate->stages[stage_id];
1554 tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id];
1555 pstate->stages.Set(stage_id, std::move(tmp_stage));
1556 pstate->stages.insert(pstate->stages.begin() + added_stage_id,
1557 Stage(current_compute_dag->ops[added_stage_id]));
1558 for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) {
1559 tmp_stage = pstate->stages[i];
1560 tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
1561 pstate->stages.Set(i, std::move(tmp_stage));
1562 }
1563 pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id);
1564 pstate->current_compute_dag = std::move(current_compute_dag);
1565
1566 return added_stage_id;
1567}
1568
1569te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1570 StageToAxesMap* stage_to_axes,
1571 te::Schedule* schedule) const {
1572 const te::Stage& stage = (*stages)[stage_id];
1573 Array<te::Operation> readers;
1574 for (const auto& i : reader_stage_ids) {
1575 readers.push_back((*stages)[i.IntValue()]->origin_op);
1576 }
1577 auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers);
1578
1579 const auto& new_stage = (*schedule)[out->op];
1580 UpdateStageToAxesMap(new_stage, stage_to_axes);
1581 stages->insert(stages->begin() + stage_id + 1, new_stage);
1582
1583 return out;
1584}
1585
1586String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
1587 te::Schedule* schedule) const {
1588 std::stringstream ss;
1589 // Since the original stage will be changed after schedule apply, keep a copy here
1590 // These information will be used to print Python API string later
1591 auto stage = (*stages)[stage_id];
1592 Array<te::Stage> reader_stages;
1593 for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
1594 reader_stages.push_back((*stages)[reader_stage_ids[i].IntValue()]);
1595 }
1596 auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
1597
1598 const auto& op_name = CleanName(out->op->name);
1599 ss << op_name << " = "
1600 << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", ["
1601 << CleanName(reader_stages[0]->op->name);
1602 for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
1603 ss << ", " << CleanName(reader_stages[i]->op->name);
1604 }
1605 ss << "])\n";
1606
1607 // Print the iterators of the new added stage
1608 const auto& iters = out->op->root_iter_vars();
1609 for (size_t i = 0; i < iters.size(); ++i) {
1610 ss << CleanName(iters[i]->var->name_hint, op_name);
1611 if (i != iters.size() - 1) {
1612 ss << ", ";
1613 }
1614 }
1615 ss << " = "
1616 << "tuple(" << op_name << ".op.axis)\n";
1617
1618 return ss.str();
1619}
1620
1621/********** Cache Write **********/
1622CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) {
1623 auto node = make_object<CacheWriteStepNode>();
1624 node->stage_id = stage_id;
1625 node->scope_name = std::move(scope_name);
1626 data_ = std::move(node);
1627}
1628
1629CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) {
1630 auto node = make_object<CacheWriteStepNode>();
1631 bool s;
1632 s = reader->NextArrayItem();
1633 ICHECK(s);
1634 reader->Read(&node->stage_id);
1635 s = reader->NextArrayItem();
1636 ICHECK(s);
1637 std::string string_value;
1638 reader->Read(&string_value);
1639 node->scope_name = std::move(string_value);
1640 data_ = std::move(node);
1641}
1642
/*!
 * \brief Serialize as: record prefix, stage_id, scope_name.
 * \note CacheWriteStep(dmlc::JSONReader*) reads the fields after the prefix in this order.
 */
void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
  writer->WriteArraySeperator();
  writer->WriteString(record_prefix_str);
  writer->WriteArrayItem(stage_id);
  // The scope name is emitted as an explicit separator + string item.
  writer->WriteArraySeperator();
  writer->WriteString(scope_name);
}
1650
1651int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
1652 StateNode* pstate = state->CopyOnWrite();
1653 int last_dag_op_size = pstate->current_compute_dag
1654 ? pstate->current_compute_dag.value().as<ComputeDAGNode>()->ops.size()
1655 : dag->ops.size();
1656 const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(
1657 GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
1658 int added_ops = current_compute_dag->ops.size() - last_dag_op_size;
1659 // TODO(jcf94): Update this check to equal after fixing the cache write bug in TVM
1660 ICHECK_GE(added_ops, 1);
1661
1662 // target_stage -> cache_write_stage + target_stage
1663 // Assume no step has been applied to the target stage before cache write.
1664 // Insert a new cache write stage ahead, update the op of the target stage and later stages, then
1665 // update the stage_id mapping in AttachMap
1666 pstate->stages.insert(pstate->stages.begin() + stage_id,
1667 Stage(current_compute_dag->ops[stage_id]));
1668 pstate->stages.Set(stage_id + 1, Stage(current_compute_dag->ops[stage_id + 1]));
1669 int next_stage_id = stage_id + 2;
1670 // TODO(jc94): Fix the cache write bug in TVM and remove added_op == 2 support.
1671 // TVM's cache_write has a bug with multi outputs. See
1672 // `tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write` test
1673 // for more details
1674 if (added_ops == 2) {
1675 pstate->stages.insert(pstate->stages.begin() + next_stage_id,
1676 Stage(current_compute_dag->ops[next_stage_id]));
1677 next_stage_id++;
1678 } else if (added_ops > 2) {
1679 LOG(ERROR) << "Unexpected behavior of CacheWrite.";
1680 }
1681 for (size_t i = next_stage_id; i < current_compute_dag->ops.size(); ++i) {
1682 Stage tmp_stage = pstate->stages[i];
1683 tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
1684 pstate->stages.Set(i, std::move(tmp_stage));
1685 }
1686 pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id, added_ops);
1687 pstate->current_compute_dag = std::move(current_compute_dag);
1688
1689 return stage_id;
1690}
1691
1692Array<te::Tensor> CacheWriteStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1693 StageToAxesMap* stage_to_axes,
1694 te::Schedule* schedule) const {
1695 const te::Stage& stage = (*stages)[stage_id];
1696 Array<te::Tensor> tensor_array;
1697 // If the target stage has multi outputs, TVM requires to cache_write
1698 // all of them or schedule.cache_write will raise an error
1699 for (auto i = 0; i < stage->op->num_outputs(); ++i) {
1700 tensor_array.push_back(stage->origin_op.output(i));
1701 }
1702 auto outs = schedule->cache_write(tensor_array, scope_name);
1703
1704 UpdateStageToAxesMap(stage, stage_to_axes);
1705 // Even if there is multi outputs, TVM schedule only generate one
1706 // new stage
1707 const auto& new_stage = (*schedule)[outs[0]->op];
1708 UpdateStageToAxesMap(new_stage, stage_to_axes);
1709 stages->insert(stages->begin() + stage_id, new_stage);
1710
1711 return outs;
1712}
1713
1714String CacheWriteStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
1715 te::Schedule* schedule) const {
1716 std::stringstream ss;
1717 // Since the original stage will be changed after schedule apply, keep a copy here
1718 // These information will be used to print Python API string later
1719 te::Stage stage = (*stages)[stage_id];
1720 auto outs = ApplyToSchedule(stages, stage_to_axes, schedule);
1721
1722 for (size_t i = 0; i < outs.size(); ++i) {
1723 ss << CleanName(outs[i]->op->name) << ", ";
1724 }
1725 ss << "= "
1726 << "s.cache_write([" << CleanName(stage->op.output(0)->op->name);
1727 for (auto i = 1; i < stage->op->num_outputs(); ++i) {
1728 ss << ", " << CleanName(stage->op.output(i)->op->name);
1729 }
1730 ss << "], \"" << scope_name << "\")\n";
1731
1732 // Print the iterators of the new added stage
1733 for (const auto& out : outs) {
1734 const auto& iters = out->op->root_iter_vars();
1735 const auto& op_name = CleanName(out->op->name);
1736 for (size_t i = 0; i < iters.size(); ++i) {
1737 ss << CleanName(iters[i]->var->name_hint, op_name);
1738 if (i != iters.size() - 1) {
1739 ss << ", ";
1740 }
1741 }
1742 ss << " = "
1743 << "tuple(" << op_name << ".op.axis)"
1744 << " + "
1745 << "tuple(" << op_name << ".op.reduce_axis)\n";
1746 }
1747
1748 return ss.str();
1749}
1750
1751/********** Rfactor **********/
1752RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) {
1753 auto node = make_object<RfactorStepNode>();
1754 node->stage_id = stage_id;
1755 node->iter_id = iter_id;
1756 node->factor_iter_id = factor_iter_id;
1757 data_ = std::move(node);
1758}
1759
1760RfactorStep::RfactorStep(dmlc::JSONReader* reader) {
1761 auto node = make_object<RfactorStepNode>();
1762 bool s;
1763 s = reader->NextArrayItem();
1764 ICHECK(s);
1765 reader->Read(&node->stage_id);
1766 s = reader->NextArrayItem();
1767 ICHECK(s);
1768 reader->Read(&node->iter_id);
1769 s = reader->NextArrayItem();
1770 ICHECK(s);
1771 reader->Read(&node->factor_iter_id);
1772 data_ = std::move(node);
1773}
1774
1775void RfactorStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
1776 writer->WriteArraySeperator();
1777 writer->WriteString(record_prefix_str);
1778 writer->WriteArrayItem(stage_id);
1779 writer->WriteArrayItem(iter_id);
1780 writer->WriteArrayItem(factor_iter_id);
1781}
1782
1783int RfactorStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
1784 StateNode* pstate = state->CopyOnWrite();
1785 const auto& compute_at_type = pstate->stages[stage_id]->compute_at;
1786 const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(
1787 GetFormerStageModifiableSteps(GetRef<Step>(this), (*state)->transform_steps));
1788
1789 // target_stage -> rfactor_compute + target_stage
1790 // Insert a new compute stage, update the target stage and later stage, then update the stage_id
1791 // mapping in AttachMap
1792 pstate->stages.insert(pstate->stages.begin() + stage_id,
1793 Stage(current_compute_dag->ops[stage_id]));
1794 // Maintain the compute_at type of the target stage
1795 Stage target_stage = Stage(current_compute_dag->ops[stage_id + 1]);
1796 target_stage.CopyOnWrite()->compute_at = compute_at_type;
1797 pstate->stages.Set(stage_id + 1, std::move(target_stage));
1798 for (size_t i = stage_id + 2; i < pstate->stages.size(); ++i) {
1799 Stage stage = pstate->stages[i];
1800 stage.CopyOnWrite()->op = current_compute_dag->ops[i];
1801 pstate->stages.Set(i, std::move(stage));
1802 }
1803 pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(stage_id);
1804 pstate->current_compute_dag = std::move(current_compute_dag);
1805
1806 return stage_id;
1807}
1808
1809Array<te::Tensor> RfactorStepNode::ApplyToSchedule(Array<te::Stage>* stages,
1810 StageToAxesMap* stage_to_axes,
1811 te::Schedule* schedule) const {
1812 const auto& stage = (*stages)[stage_id];
1813 const Array<IterVar>& axes = (*stage_to_axes)[stage];
1814
1815 const te::Tensor& tensor = stage->origin_op.output(0);
1816 const IterVar& axis = axes[iter_id];
1817 auto outs = schedule->rfactor(tensor, axis, factor_iter_id);
1818
1819 UpdateStageToAxesMap(stage, stage_to_axes);
1820 const auto& new_stage = (*schedule)[outs[0]->op];
1821 UpdateStageToAxesMap(new_stage, stage_to_axes);
1822 stages->insert(stages->begin() + stage_id, new_stage);
1823
1824 return outs;
1825}
1826
1827String RfactorStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
1828 te::Schedule* schedule) const {
1829 std::stringstream ss;
1830 const auto& stage = (*stages)[stage_id];
1831
1832 const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name);
1833 const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint);
1834
1835 const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule);
1836
1837 for (size_t i = 0; i < outs.size(); ++i) {
1838 ss << CleanName(outs[i]->op->name);
1839 if (i != outs.size() - 1) {
1840 ss << ", ";
1841 }
1842 }
1843 ss << " = "
1844 << "s.rfactor(" << tensor_name << ", " << axis_name << ", " << factor_iter_id << ")\n";
1845
1846 for (const auto& out : outs) {
1847 const auto& iters = out->op->root_iter_vars();
1848 const auto& op_name = CleanName(out->op->name);
1849 for (size_t i = 0; i < iters.size(); ++i) {
1850 ss << CleanName(iters[i]->var->name_hint, op_name);
1851 if (i != iters.size() - 1) {
1852 ss << ", ";
1853 }
1854 }
1855 ss << " = "
1856 << "tuple(" << op_name << ".op.axis)"
1857 << " + "
1858 << "tuple(" << op_name << ".op.reduce_axis)\n";
1859 }
1860
1861 const auto& output = (*stages)[stage_id + 1]->op.output(0);
1862 const auto& iters = output->op->root_iter_vars();
1863 const auto& op_name = CleanName(output->op->name);
1864 for (size_t i = 0; i < iters.size(); ++i) {
1865 ss << CleanName(iters[i]->var->name_hint, op_name);
1866 if (i != iters.size() - 1) {
1867 ss << ", ";
1868 }
1869 }
1870 ss << " = "
1871 << "tuple(s[" << op_name << "].op.axis)"
1872 << " + "
1873 << "tuple(s[" << op_name << "].op.reduce_axis)\n";
1874
1875 return ss.str();
1876}
1877
1878} // namespace auto_scheduler
1879} // namespace tvm
1880