1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/loop_state.cc
 * \brief A lightweight IR (intermediate representation) for loop structures.
 * See auto_scheduler/loop_state.h for more explanation.
24 */
25
26#include <tvm/auto_scheduler/compute_dag.h>
27#include <tvm/auto_scheduler/loop_state.h>
28#include <tvm/auto_scheduler/transform_step.h>
29#include <tvm/runtime/registry.h>
30#include <tvm/te/operation.h>
31
32#include <utility>
33
34#include "utils.h"
35
36namespace tvm {
37namespace auto_scheduler {
38
// Register the loop-state object types with TVM's object type registry.
// NOTE(review): StepNode uses TVM_REGISTER_OBJECT_TYPE rather than TVM_REGISTER_NODE_TYPE,
// presumably because it is the abstract base of the concrete transform steps -- confirm in
// transform_step.h.
TVM_REGISTER_OBJECT_TYPE(StepNode);
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(StateNode);
TVM_REGISTER_NODE_TYPE(IteratorNode);
43
44/********** Iterator **********/
45Iterator::Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation,
46 const std::vector<Iterator>* orig_iters) {
47 auto node = make_object<IteratorNode>();
48 node->name = std::move(name);
49 node->range = std::move(range);
50 node->iter_kind = iter_kind;
51 node->annotation = annotation;
52 if (orig_iters != nullptr) {
53 node->orig_iters = *orig_iters;
54 }
55 data_ = std::move(node);
56}
57
58/********** Stage **********/
59Stage::Stage(te::Operation op) {
60 auto node = make_object<StageNode>();
61 if (op->IsInstance<te::ComputeOpNode>()) {
62 node->op_type = StageKind::kCompute;
63 auto* pop = op.as<te::ComputeOpNode>();
64 for (const auto& axis : pop->axis) {
65 node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
66 IteratorKind::kSpatial, IteratorAnnotation::kNone));
67 }
68 for (const auto& axis : pop->reduce_axis) {
69 node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom,
70 IteratorKind::kReduction, IteratorAnnotation::kNone));
71 }
72 } else if (op->IsInstance<te::PlaceholderOpNode>()) {
73 node->op_type = StageKind::kPlaceholder;
74 } else {
75 LOG(FATAL) << "Unsupported operator type" << op->_type_key;
76 }
77
78 node->compute_at = ComputeAtKind::kRoot;
79 node->op = std::move(op);
80 node->attrs.auto_unroll_max_step = 0;
81 node->attrs.storage_offset = 0;
82 data_ = std::move(node);
83}
84
85Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
86 ComputeAtKind compute_at, StageAttributes attrs) {
87 auto node = make_object<StageNode>();
88 node->op = std::move(op);
89 node->op_type = op_type;
90 node->iters = iters;
91 node->compute_at = compute_at;
92 node->attrs = attrs;
93 data_ = std::move(node);
94}
95
96/********** AttachMap **********/
97void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
98 AttachMapNode* pnode = CopyOnWrite();
99
100 // Delete the current entry of this stage
101 DeleteStageEntry(pnode, stage_id);
102
103 // Store the new stage/iterator relations to map
104 IterKey iter_key(target_stage_id, target_iter_id);
105 pnode->stage_to_attach_iter[stage_id] = iter_key;
106 pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
107}
108
109void AttachMap::DeleteStage(int stage_id) {
110 AttachMapNode* pnode = CopyOnWrite();
111 // Delete the original stage entry
112 DeleteStageEntry(pnode, stage_id);
113}
114
115void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters,
116 const std::vector<IterKey>& new_iters) {
117 ICHECK_EQ(original_iters.size(), new_iters.size());
118 AttachMapNode* pnode = CopyOnWrite();
119 std::unordered_map<IterKey, std::vector<StageKey>> new_iter_to_attached_stages;
120 for (size_t i = 0; i < original_iters.size(); ++i) {
121 auto entry = pnode->iter_to_attached_stages.find(original_iters[i]);
122 // We get <IterKey, std::vector<StageKey>> from this map
123 if (entry == pnode->iter_to_attached_stages.end()) {
124 // Skip if this iterator does not have any attach relations
125 continue;
126 }
127
128 // Update the attaching target of an stage to the new iter in `stage_to_attach_iter`
129 for (const auto& s : entry->second) {
130 pnode->stage_to_attach_iter[s] = new_iters[i];
131 }
132
133 // Remove the original iterator relation from `iter_to_attached_stages` and add the new
134 // iterator to it
135 std::vector<int> attached_stages = std::move(entry->second);
136 pnode->iter_to_attached_stages.erase(entry);
137 new_iter_to_attached_stages[new_iters[i]] = std::move(attached_stages);
138 }
139
140 // Update new entries
141 for (auto& it : new_iter_to_attached_stages) {
142 pnode->iter_to_attached_stages[it.first] = std::move(it.second);
143 }
144}
145
146void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) {
147 auto old_entry = pnode->stage_to_attach_iter.find(stage_id);
148 // We get <StageKey, IterKey> from this map
149 if (old_entry != pnode->stage_to_attach_iter.end()) {
150 // Delete the stage in `iter_to_attached_stages`, if the corresponding iterator does not have
151 // any attached stage, delete this iterm too
152 auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second);
153 // We get <IterKey, std::vector<StageKey>> from this map
154 FindAndDeleteItem(&entry2->second, stage_id);
155 if (entry2->second.size() == 0) {
156 pnode->iter_to_attached_stages.erase(entry2);
157 }
158 // Delete the stage in `stage_to_attach_iter`
159 pnode->stage_to_attach_iter.erase(old_entry);
160 }
161}
162
163AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const {
164 AttachMap map = AttachMap(make_object<AttachMapNode>());
165 auto pmap = map.CopyOnWrite();
166 for (const auto& x : operator->()->stage_to_attach_iter) {
167 auto key = x.first;
168 if (key >= start_id) {
169 key += offset;
170 }
171 auto value = x.second;
172 if (value.first >= start_id) {
173 value.first += offset;
174 }
175 pmap->stage_to_attach_iter.insert(std::make_pair(key, value));
176 }
177 for (const auto& x : operator->()->iter_to_attached_stages) {
178 auto key = x.first;
179 if (key.first >= start_id) {
180 key.first += offset;
181 }
182 auto value = x.second;
183 for (auto& i : value) {
184 if (i >= start_id) {
185 i += offset;
186 }
187 }
188 pmap->iter_to_attached_stages.insert(std::make_pair(key, value));
189 }
190 return map;
191}
192
193/********** State **********/
194State::State(const Array<te::Operation>& ops) {
195 auto node = make_object<StateNode>();
196 for (const auto& op : ops) {
197 node->stages.push_back(Stage(op));
198 }
199 node->attach_map = AttachMap(make_object<AttachMapNode>());
200 node->concrete = true;
201 data_ = std::move(node);
202}
203
/********** Schedule primitive APIs for state **********/
205Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) {
206 const Stage& stage = operator->()->stages[stage_id];
207 if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) {
208 LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
209 << "kThreadX, kThreadY, kBlockZ, kThreadZ";
210 }
211 AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type);
212 CopyOnWrite()->transform_steps.push_back(step);
213 return step->ApplyToState(this);
214}
215
216Iterator State::parallel(int stage_id, const Iterator& it) {
217 const Stage& stage = operator->()->stages[stage_id];
218 AnnotationStep step =
219 AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel);
220 CopyOnWrite()->transform_steps.push_back(step);
221 return step->ApplyToState(this);
222}
223
224Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) {
225 const Stage& stage = operator->()->stages[stage_id];
226
227 // Don't unroll if the extent is larger than max_unroll
228 if (max_unroll != -1 && it->range.defined()) {
229 if (auto imm = it->range->extent.as<IntImmNode>()) {
230 if (imm->value > max_unroll) {
231 return it;
232 }
233 }
234 }
235
236 AnnotationStep step =
237 AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll);
238 CopyOnWrite()->transform_steps.push_back(step);
239 return step->ApplyToState(this);
240}
241
242Iterator State::vectorize(int stage_id, const Iterator& it) {
243 const Stage& stage = operator->()->stages[stage_id];
244 AnnotationStep step =
245 AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize);
246 CopyOnWrite()->transform_steps.push_back(step);
247 return step->ApplyToState(this);
248}
249
250Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
251 const Stage& stage = operator->()->stages[stage_id];
252 Array<Integer> indices;
253 GetIndices(stage->iters, iters, &indices);
254 FuseStep step = FuseStep(stage_id, indices);
255 CopyOnWrite()->transform_steps.push_back(step);
256 return step->ApplyToState(this);
257}
258
259void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) {
260 const Stage& stage = operator->()->stages[stage_id];
261 PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type);
262 CopyOnWrite()->transform_steps.push_back(step);
263 return step->ApplyToState(this);
264}
265
266void State::reorder(int stage_id, const Array<Iterator>& order) {
267 const Stage& stage = operator->()->stages[stage_id];
268 ICHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
269 << "should be specified";
270 Array<Integer> after_ids;
271 GetIndices(stage->iters, order, &after_ids);
272 ReorderStep step = ReorderStep(stage_id, after_ids);
273 CopyOnWrite()->transform_steps.push_back(step);
274 step->ApplyToState(this);
275}
276
277Array<Iterator> State::split(int stage_id, const Iterator& it,
278 const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
279 const Stage& stage = operator->()->stages[stage_id];
280 SplitStep step =
281 SplitStep(stage_id, GetIndex(stage->iters, it),
282 it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer);
283 CopyOnWrite()->transform_steps.push_back(step);
284 return step->ApplyToState(this);
285}
286
287Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
288 int n_split) {
289 const Stage& stage = operator->()->stages[stage_id];
290 FollowSplitStep step =
291 FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
292 CopyOnWrite()->transform_steps.push_back(step);
293 return step->ApplyToState(this);
294}
295
296Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
297 const Array<Integer>& src_step_ids, int level,
298 bool factor_or_nparts) {
299 const Stage& stage = operator->()->stages[stage_id];
300 FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it),
301 src_step_ids, level, factor_or_nparts);
302 CopyOnWrite()->transform_steps.push_back(step);
303 return step->ApplyToState(this);
304}
305
306void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) {
307 const Stage& stage = operator->()->stages[stage_id];
308 StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset);
309 CopyOnWrite()->transform_steps.push_back(step);
310 return step->ApplyToState(this);
311}
312
313void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
314 const Stage& target_stage = operator->()->stages[target_stage_id];
315 ComputeAtStep step =
316 ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter));
317 CopyOnWrite()->transform_steps.push_back(step);
318 step->ApplyToState(this);
319}
320
321void State::compute_inline(int stage_id) {
322 ComputeInlineStep step = ComputeInlineStep(stage_id);
323 CopyOnWrite()->transform_steps.push_back(step);
324 step->ApplyToState(this);
325}
326
327void State::compute_root(int stage_id) {
328 ComputeRootStep step = ComputeRootStep(stage_id);
329 CopyOnWrite()->transform_steps.push_back(step);
330 step->ApplyToState(this);
331}
332
333int State::cache_read(int stage_id, const String& scope_name,
334 const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
335 CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids);
336 CopyOnWrite()->transform_steps.push_back(step);
337 return step->ApplyToState(this, dag);
338}
339
340int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) {
341 CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
342 CopyOnWrite()->transform_steps.push_back(step);
343 return step->ApplyToState(this, dag);
344}
345
346int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) {
347 const Stage& stage = operator->()->stages[stage_id];
348 RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id);
349 CopyOnWrite()->transform_steps.push_back(step);
350 return step->ApplyToState(this, dag);
351}
352
353// Print stage to ostream
354void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent,
355 bool delete_trivial_loop) {
356 const Stage& stage = state->stages[stage_id];
357
358 if (stage->attrs.auto_unroll_max_step != 0) {
359 for (size_t j = 0; j < base_indent; ++j) {
360 *os << " ";
361 }
362 *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n";
363 }
364 if (stage->attrs.storage_offset != 0) {
365 for (size_t j = 0; j < base_indent; ++j) {
366 *os << " ";
367 }
368 *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n";
369 }
370
371 size_t indent = 0;
372 for (size_t i = 0; i < stage->iters.size(); ++i) {
373 const Iterator& iter = stage->iters[i];
374
375 if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) {
376 for (size_t j = 0; j < base_indent + indent; ++j) {
377 *os << " ";
378 }
379 *os << IteratorAnnotationString[static_cast<int>(iter->annotation)] << " ";
380 if (iter->range.defined()) {
381 *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")";
382 } else {
383 *os << iter->name << " (None)";
384 }
385 *os << "\n";
386
387 indent += 2;
388 }
389
390 if (state.defined()) {
391 IterKey iter_key(stage_id, i);
392 auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
393 if (pair != state->attach_map->iter_to_attached_stages.end()) {
394 // Print the attached stage
395 for (const auto& attach_stage_id : pair->second) {
396 PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop);
397 }
398 }
399 }
400 }
401
402 for (size_t j = 0; j < base_indent + indent; ++j) {
403 *os << " ";
404 }
405 *os << stage->op->name << " = ...\n";
406}
407
408// Print state to ostream
409void PrintState(std::ostream* os, const State& state, bool delete_trivial_loop) {
410 // Gather placeholders
411 Array<String> placeholders;
412 for (const auto& stage : state->stages) {
413 if (stage->op_type == StageKind::kPlaceholder) {
414 placeholders.push_back(stage->op->name);
415 }
416 }
417
418 *os << "Placeholder: ";
419 for (size_t i = 0; i < placeholders.size(); ++i) {
420 *os << placeholders[i];
421 if (i != placeholders.size() - 1) {
422 *os << ", ";
423 }
424 }
425 *os << "\n";
426
427 // Print all stages
428 for (size_t i = 0; i < state->stages.size(); ++i) {
429 const Stage& stage = state->stages[i];
430 if (stage->op_type == StageKind::kPlaceholder) {
431 continue;
432 } else if (stage->op_type == StageKind::kCompute) {
433 if (stage->compute_at == ComputeAtKind::kRoot) {
434 PrintStage(os, i, state, 0, delete_trivial_loop);
435 }
436 } else {
437 LOG(FATAL) << "Invalid op type";
438 }
439 }
440}
441
442String State::ToStr(bool delete_trivial_loop) const {
443 std::ostringstream os;
444 PrintState(&os, (*this), delete_trivial_loop);
445 return os.str();
446}
447
// Repr of a Stage: "<type key>(<address>: <op name>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<StageNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const auto& stage = tvm::Downcast<Stage>(ref);
      p->stream << stage->GetTypeKey() << "(" << stage.get() << ": " << stage->op->name << ")";
    });

// Repr of a State: the full PrintState output, with trivial (extent-one) loops omitted.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<StateNode>([](const ObjectRef& ref, ReprPrinter* p) {
      PrintState(&p->stream, tvm::Downcast<State>(ref), true);
    });
458
459/********** State interface API for ffi **********/
// FFI entry points for the State schedule primitives, registered under "auto_scheduler.*".
// Each lambda takes `state` by value, applies the primitive to that copy, and returns the
// resulting state -- alone for primitives without a result, otherwise paired with the result in
// an Array<ObjectRef> {state, result} -- so the caller can pick up the updated state.
// NOTE(review): the lambda parameter types define the packed-function argument conversion; keep
// them in sync with the Python-side callers.
TVM_REGISTER_GLOBAL("auto_scheduler.StateBind")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) {
      // `thread_type` arrives as an int and is converted to IteratorAnnotation; State::bind
      // rejects values outside [kVThread, kThreadZ].
      const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type));
      return Array<ObjectRef>{state, res};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel")
    .set_body_typed([](State state, int stage_id, const Iterator& it) {
      const auto& res = state.parallel(stage_id, it);
      return Array<ObjectRef>{state, res};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) {
      // When unrolling is skipped because of `max_unroll`, `res` is the unchanged iterator.
      const auto& res = state.unroll(stage_id, it, max_unroll);
      return Array<ObjectRef>{state, res};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize")
    .set_body_typed([](State state, int stage_id, const Iterator& it) {
      const auto& res = state.vectorize(stage_id, it);
      return Array<ObjectRef>{state, res};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
    .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
      const auto& res = state.fuse(stage_id, iters);
      return Array<ObjectRef>{state, res};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma")
    .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) {
      state.pragma(stage_id, it, pragma_type);
      return state;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
    .set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
      state.reorder(stage_id, order);
      return state;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
    .set_body_typed([](State state, int stage_id, const Iterator& it,
                       const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
      const auto& res = state.split(stage_id, it, lengths, inner_to_outer);
      return Array<ObjectRef>{state, res};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
                       int n_split) {
      const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
      return Array<ObjectRef>{state, Array<Iterator>(res)};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
    .set_body_typed([](State state, int stage_id, const Iterator& it,
                       const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) {
      const auto& res =
          state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts);
      return Array<ObjectRef>{state, Array<Iterator>(res)};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) {
      state.storage_align(stage_id, it, factor, offset);
      return state;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
    .set_body_typed([](State state, int stage_id, int target_stage_id,
                       const Iterator& target_iter) {
      state.compute_at(stage_id, target_stage_id, target_iter);
      return state;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline")
    .set_body_typed([](State state, int stage_id) {
      state.compute_inline(stage_id);
      return state;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot")
    .set_body_typed([](State state, int stage_id) {
      state.compute_root(stage_id);
      return state;
    });

// The three DAG-dependent primitives return an int, boxed as Integer for the FFI.
TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead")
    .set_body_typed([](State state, int stage_id, const String& scope_name,
                       const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
      int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag);
      return Array<ObjectRef>{state, Integer(res)};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
    .set_body_typed([](State state, int stage_id, const String& scope_name,
                       const ComputeDAG& task_dag) {
      int res = state.cache_write(stage_id, scope_name, task_dag);
      return Array<ObjectRef>{state, Integer(res)};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor")
    .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id,
                       const ComputeDAG& dag) {
      int res = state.rfactor(stage_id, it, factor_iter_id, dag);
      return Array<ObjectRef>{state, Integer(res)};
    });

// Equality via std::equal_to<State>; the specialization is assumed to be declared in
// loop_state.h -- confirm what it compares before relying on it.
TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
  return std::equal_to<State>()(state1, state2);
});
573
574} // namespace auto_scheduler
575} // namespace tvm
576