/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/compute_dag.cc
 * \brief Compute declaration graph and its related analysis tools.
 */

#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_policy.h>
#include <tvm/auto_scheduler/transform_step.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/topi/transform.h>

#include <algorithm>
#include <cstdint>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../arith/pattern_match.h"
#include "../relay/transforms/auto_scheduler_layout_rewrite.h"
#include "search_policy/utils.h"
#include "utils.h"

namespace tvm {
namespace auto_scheduler {

using namespace tvm::tir;

template <class T>
using OperationMap = AccessAnalyzerNode::OperationMap<T>;
using OperationSet = std::unordered_set<te::Operation, ObjectHash, ObjectEqual>;

TVM_REGISTER_NODE_TYPE(ComputeDAGNode);

// Topo-sort ops from tensors according to their read-write relations.
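// The result is deterministic: ties between ready ops are broken by the DFS
// discovery order recorded in `priority` below.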
Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) {
  std::unordered_map<const te::OperationNode*, int> degree;
  std::unordered_map<const te::OperationNode*, std::vector<const te::OperationNode*>> edge_set;
  std::unordered_map<const te::OperationNode*, int> priority;
  std::unordered_set<const te::OperationNode*> visited;

  // traverse to build edge_set and count degree
  std::vector<const te::OperationNode*> stack;
  stack.reserve(tensors.size());
  for (const auto& x : tensors) {
    stack.push_back(x->op.operator->());
  }

  int ct = 0;
  while (!stack.empty()) {
    const te::OperationNode* op = stack.back();
    stack.pop_back();
    if (visited.count(op)) {
      continue;
    }

    priority[op] = ct;
    ct++;
    visited.insert(op);

    if (op->IsInstance<te::PlaceholderOpNode>()) {
      degree[op] = 0;
    } else if (auto cop = GetRef<te::Operation>(op).as<te::ComputeOpNode>()) {
      const Array<te::Tensor>& input_tensors = cop->InputTensors();
      degree[op] = input_tensors.size();
      for (const auto& ten : input_tensors) {
        edge_set[ten->op.operator->()].push_back(op);
        stack.push_back(ten->op.operator->());
      }
    } else {
      LOG(FATAL) << "Unsupported op " << GetRef<te::Operation>(op);
    }
  }

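  // Kahn's algorithm: repeatedly emit a ready op (one whose inputs have all
  // been emitted), using a max-heap keyed by DFS discovery order to break
  // ties deterministically.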
  // topo sort
  Array<te::Operation> ops;

  using Item = std::pair<const te::OperationNode*, int>;
  auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; };
  std::priority_queue<Item, std::vector<Item>, decltype(cmp)> queue(cmp);
  for (const auto& iter : degree) {
    if (iter.second == 0) {
      queue.push(Item(iter.first, priority[iter.first]));
    }
  }

  ops.reserve(degree.size());
  while (!queue.empty()) {
    Item item = queue.top();
    queue.pop();
    ops.push_back(GetRef<te::Operation>(item.first));
    for (const auto& dst : edge_set[item.first]) {
      degree[dst] -= 1;
      if (degree[dst] == 0) {
        queue.push(Item(dst, priority[dst]));
      }
    }
  }

  return ops;
}

// Extract all tensor accesses in an expr
class ReadAccessExtractor : public StmtExprVisitor {
 public:
  void Extract(PrimExpr expr) { this->VisitExpr(expr); }

  void VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      has_branch = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const ProducerLoadNode* op) final {
    read_access[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(),
                                                                     op->indices.end());
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const IfThenElseNode* op) final {
    has_branch = true;
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitExpr_(const SelectNode* op) final {
    has_branch = true;
    StmtExprVisitor::VisitExpr_(op);
  }

  // All read accesses to all operations.
  // The innermost vector stores the multi-dimensional indices of one access;
  // the middle vector stores the possibly multiple accesses to the same operation.
  OperationMap<std::vector<std::vector<PrimExpr>>> read_access;
  // Whether this expression contains branches
  bool has_branch{false};
};

// Return whether the expr equals the var, up to an optional constant shift.
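// e.g., for var `i`: `i`, `i + 2`, `2 + i`, and `i - 1` all match, but `2 * i` does not.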
bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) {
  arith::PVar<PrimExpr> x;
  arith::PVar<IntImm> c;

  if (((x + c).Match(expr) || (x - c).Match(expr) || (c + x).Match(expr) || x.Match(expr)) &&
      x.Eval().same_as(var)) {
    return true;
  }
  return false;
}

// Return whether the access to an operation is a simple access
// (i.e., every index is just a variable with an optional constant shift).
// For example, A[i][j] and A[i+1][j] are simple accesses but A[i][j+i] is not.
bool IsSimpleAccess(const te::Operation& op, const std::vector<PrimExpr>& indices,
                    bool* axis_missing, bool* axis_duplicated, bool* same_order) {
  auto cop = op.as<te::ComputeOpNode>();
  if (cop == nullptr) {
    return false;
  }

  std::vector<int> index_to_var_idx;
  std::vector<int> var_idx_ct(cop->axis.size(), 0);

  for (const auto& expr : indices) {
    if (!is_const_int(expr)) {
      bool found = false;
      for (size_t i = 0; i < cop->axis.size(); ++i) {
        if (IsConstShiftEqual(cop->axis[i]->var, expr)) {
          index_to_var_idx.push_back(i);
          var_idx_ct[i]++;
          found = true;
          break;
        }
      }
      if (!found) {
        return false;
      }
    }
  }

  *axis_missing = false;     // Some axes are missing
  *axis_duplicated = false;  // Some axes appear more than once
  *same_order = true;        // The axis order is the same as op->axis
  for (int ct : var_idx_ct) {
    if (ct == 0) {
      *axis_missing = true;
    } else if (ct > 1) {
      *axis_duplicated = true;
    }
  }
  for (size_t i = 1; i < index_to_var_idx.size(); ++i) {
    if (index_to_var_idx[i] < index_to_var_idx[i - 1]) {
      *same_order = false;
      break;
    }
  }

  return true;
}

// Gather all VarNodes in an expr
void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) {
  PostOrderVisit(expr, [&vars](const ObjectRef& node) {
    if (const VarNode* op = node.as<VarNode>()) {
      vars->insert(op);
    }
  });
}

// Check whether an expr has expensive operations (e.g. exp)
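// Note: currently only `tir.exp` is matched; all other intrinsics are considered cheap.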
bool HasExpensiveOp(const PrimExpr& expr) {
  bool found = false;
  PostOrderVisit(expr, [&found](const ObjectRef& node) {
    if (const CallNode* op = node.as<CallNode>()) {
      if (op->op.as<OpNode>()->name == "tir.exp") {
        found = true;
      }
    }
  });
  return found;
}

AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) {
  auto node = make_object<AccessAnalyzerNode>();
  OperationMap<bool> has_branch;

  // Get all ops in topological order
  node->ops_topo_order = TopoSortOps(tensors);

  arith::Analyzer analyzer;

  // Build the read & write access maps
  for (const auto& op : node->ops_topo_order) {
    if (op->IsInstance<te::PlaceholderOpNode>()) {
      node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>();
    } else if (auto cop = op.as<te::ComputeOpNode>()) {
      ReadAccessExtractor extractor;
      for (const auto& exp : cop->body) {
        extractor.Extract(exp);
      }

      // Fill the read_by and read_from maps
      for (const auto& iter : extractor.read_access) {
        std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op];
        accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end());
      }

      node->read_from[op] = std::move(extractor.read_access);
      has_branch[op] = extractor.has_branch;

      // compute the number of common outer iterators
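      // i.e., the number of leading dimensions on which op and producer have
      // equal extents and the producer is accessed injectively. The search
      // policy uses this when choosing compute_at locations.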
      for (const auto& pair : node->read_from[op]) {
        const te::Operation& producer = pair.first;
        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
        const Array<PrimExpr>& output_shape = op->output_shape(0);
        const Array<PrimExpr>& producer_shape = producer->output_shape(0);

        int n_common;
        for (n_common = 0;
             n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size()));
             n_common++) {
          if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) {
            break;
          }

          bool injective = true;
          for (const auto& access : access_list) {
            if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) {
              injective = false;
              break;
            }
          }

          if (!injective) {
            break;
          }
        }

        node->num_common_outer_iterators[op][producer] = n_common;
        node->num_common_outer_iterators[producer][op] = n_common;
      }
    } else {
      LOG(FATAL) << "Invalid op: " << op;
    }
  }

  // Do some static analysis on ComputeOps
  for (const auto& op : node->ops_topo_order) {
    if (op->IsInstance<te::PlaceholderOpNode>()) {
      node->is_simple_access[op] = true;
      node->needs_multi_level_tiling[op] = false;
      node->is_strictly_inlineable[op] = false;
      node->is_output[op] = false;
    } else if (auto cop = op.as<te::ComputeOpNode>()) {
      // check whether this op is elementwise and strictly inlineable
      bool is_simple_access = true;
      bool is_strictly_inlineable = true;

      bool axis_missing, axis_duplicated, same_order;
      for (const auto& pair : node->read_from[op]) {
        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
        for (const auto& access : access_list) {
          if (!auto_scheduler::IsSimpleAccess(op, access, &axis_missing, &axis_duplicated,
                                              &same_order)) {
            is_simple_access = false;
            is_strictly_inlineable = false;
            break;
          }
          if (!same_order || axis_duplicated) {
            // do not strictly inline transpose
            is_strictly_inlineable = false;
          }
        }
        if (!is_simple_access) {
          break;
        }
      }

      // don't strictly inline expensive ops (e.g. exp)
      bool has_expensive_op = false;
      for (const auto& expr : cop->body) {
        has_expensive_op |= HasExpensiveOp(expr);
      }
      if (has_expensive_op || has_branch[op]) {
        is_strictly_inlineable = false;
      }

      // a constant tensor (one with no inputs) is strictly inlineable
      if (node->read_from[op].empty()) {
        is_strictly_inlineable = true;
      }

      node->is_simple_access[op] = is_simple_access;
      node->is_strictly_inlineable[op] = is_strictly_inlineable;

      // check whether the op needs multi-level tiling
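      // Heuristic: data reuse exists if some input's read indices miss at
      // least two spatial axes, or miss one axis while a reduction axis is
      // present (e.g., in matmul the read A[i, k] misses the spatial axis j).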
      bool needs_multi_level_tiling = false;
      int n_missing = 0;

      for (const auto& pair : node->read_from[op]) {
        const std::vector<std::vector<PrimExpr>>& access_list = pair.second;
        std::unordered_set<const VarNode*> vars;
        for (const std::vector<PrimExpr>& access : access_list) {
          for (const PrimExpr& expr : access) {
            GatherVars(expr, &vars);
          }
        }

        for (const auto& axis : cop->axis) {
          if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) {
            n_missing++;
            break;
          }
        }

        if (n_missing >= 2 || (n_missing >= 1 && !cop->reduce_axis.empty())) {
          needs_multi_level_tiling = true;
          break;
        }
      }

      // do not perform multi-level tiling on "fake reductions" with const tensors
      if (op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices)) {
        needs_multi_level_tiling = false;
      }

      node->needs_multi_level_tiling[op] = needs_multi_level_tiling;

      // check whether the op is an output
      node->is_output[op] = node->read_by[op].empty();
    } else {
      LOG(FATAL) << "Invalid op: " << op;
    }
  }

  data_ = std::move(node);
}

bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation& op) const {
  return operator->()->needs_multi_level_tiling.at(op);
}

bool AccessAnalyzer::IsOutput(const te::Operation& op) const {
  return operator->()->is_output.at(op);
}

bool AccessAnalyzer::IsSimpleAccess(const te::Operation& op) const {
  return operator->()->is_simple_access.at(op);
}

bool AccessAnalyzer::IsStrictlyInlineable(const te::Operation& op) const {
  return operator->()->is_strictly_inlineable.at(op);
}

OperationSet AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op) const {
  OperationSet inlined_ops;
  for (const auto& stage : state->stages) {
    if (stage->compute_at == ComputeAtKind::kInlined) {
      inlined_ops.insert(stage->op);
    }
  }

  OperationSet consumers;
  std::function<void(const te::Operation&)> collect;
  collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) {
    for (const auto& iter : operator->()->read_by.at(op)) {
      if (inlined_ops.count(iter.first)) {
        collect(iter.first);
      } else {
        consumers.insert(iter.first);
      }
    }
  };

  collect(op);
  return consumers;
}

OperationSet AccessAnalyzer::GetDirectProducers(const te::Operation& op) const {
  OperationSet producers;
  for (const auto& iter : operator->()->read_from.at(op)) {
    producers.insert(iter.first);
  }
  return producers;
}

OperationSet AccessAnalyzer::GetProducers(const State& state, const te::Operation& op) const {
  OperationSet inlined_ops;
  for (const auto& stage : state->stages) {
    if (stage->compute_at == ComputeAtKind::kInlined) {
      inlined_ops.insert(stage->op);
    }
  }

  OperationSet producers;
  std::function<void(const te::Operation&)> collect;
  collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) {
    for (const auto& iter : operator->()->read_from.at(op)) {
      if (inlined_ops.count(iter.first)) {
        collect(iter.first);
      } else {
        producers.insert(iter.first);
      }
    }
  };

  collect(op);
  return producers;
}

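// Return the minimum number of common outer iterators along any read chain
// from `op` to `target_op`; returns 0 if `target_op` is not reachable.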
int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op,
                                              const te::Operation& target_op) const {
  int ret = INT32_MAX;
  bool meet = false;

  std::function<void(const te::Operation&, int)> traverse;
  traverse = [this, &traverse, &target_op, &ret, &meet](const te::Operation& cur_op, int cur_num) {
    if (cur_op == target_op) {
      ret = std::min(ret, cur_num);
      meet = true;
      return;
    }

    for (const auto& iter : operator->()->read_by.at(cur_op)) {
      traverse(
          iter.first,
          std::min(cur_num, operator->()->num_common_outer_iterators.at(cur_op).at(iter.first)));
    }
  };

  traverse(op, op->output_shape(0).size());
  return meet ? ret : 0;
}

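// Return whether `op` can reach `target_op` through a chain in which every op
// has exactly one consumer, matching output shapes, and purely elementwise reads.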
bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op,
                                      const te::Operation& target_op) const {
  te::Operation cur_op = op;
  while (cur_op != target_op) {
    const AccessAnalyzerNode::OperationMap<std::vector<std::vector<PrimExpr>>>& map =
        operator->()->read_by.at(cur_op);

    if (map.size() != 1) {
      return false;
    }
    te::Operation next_op = map.begin()->first;

    // Check condition 1: They have the same output shape
    auto p_cur = cur_op.as<te::ComputeOpNode>();
    auto p_next = next_op.as<te::ComputeOpNode>();
    if (p_cur == nullptr || p_next == nullptr) {
      return false;
    }

    Array<PrimExpr> output_shape = p_cur->output_shape(0);
    for (int i = 1; i < p_cur->num_outputs(); ++i) {
      if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) {
        return false;
      }
    }
    for (int i = 0; i < p_next->num_outputs(); ++i) {
      if (!IntArrayEqual(p_next->output_shape(i), output_shape)) {
        return false;
      }
    }

    // Check condition 2: The read is elementwise
    const std::vector<std::vector<PrimExpr>>& reads = map.begin()->second;
    bool is_simple_access, axis_missing, axis_duplicated, same_order;
    for (const auto& read : reads) {
      is_simple_access = auto_scheduler::IsSimpleAccess(next_op, read, &axis_missing,
                                                        &axis_duplicated, &same_order);
      if (!is_simple_access || axis_missing || axis_duplicated || !same_order) {
        return false;
      }
    }

    cur_op = std::move(next_op);
  }
  return true;
}

// Estimate the number of floating-point operations in an expression.
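// Placeholders count as zero; a user-provided "FLOP" attr overrides the
// estimate for an op; -1 is returned when the body contains constructs the
// estimator does not support.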
class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
 public:
  double EstimateFlop(const Array<te::Operation>& ops) {
    double ret = 0;
    for (const auto& op : ops) {
      if (auto pop = op.as<te::ComputeOpNode>()) {
        if (pop->attrs.count("FLOP")) {
          // Use the user-provided FLOP count
          auto pint = pop->attrs["FLOP"].as<IntImmNode>();
          ICHECK(pint != nullptr);
          ret += pint->value;
        } else {
          // Estimate by parsing the compute body
          double num_element = AxisLengthProd(pop->axis);
          if (num_element == -1) {
            fail_ = true;
            break;
          }
          cur_type_code_ = pop->output_dtype(0).code();
          double op_per_element = 0;
          for (const auto& x : pop->body) {
            op_per_element += VisitExpr(x);
          }
          ret += num_element * op_per_element;
        }
      } else if (op->IsInstance<te::PlaceholderOpNode>()) {
        {}  // do nothing
      } else {
        LOG(FATAL) << "Invalid op type " << op;
      }
    }

    return fail_ ? -1 : ret;
  }

  double VisitExpr_(const ReduceNode* op) final {
    uint64_t num_iter = 1;
    for (const auto& x : op->axis) {
      if (auto imm = x->dom->extent.as<IntImmNode>()) {
        num_iter *= imm->value;
      } else {
        fail_ = true;
        num_iter = -1;
      }
    }
    double body_flop = 0;
    for (size_t i = 0; i < op->combiner->result.size(); ++i) {
      body_flop += VisitExpr(op->combiner->result[i]);
      body_flop += VisitExpr(op->source[i]);
    }
    return num_iter * body_flop;
  }

  double VisitExpr_(const FloatImmNode* op) final { return 0.0; }
  double VisitExpr_(const IntImmNode* op) final { return 0.0; }
  double VisitExpr_(const ProducerLoadNode* op) final { return 0.0; }

  double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
  double VisitExpr_(const VarNode* op) final { return 0.0; }

  double VisitExpr_(const SelectNode* op) final {
    return VisitExpr(op->condition) +
           std::max(VisitExpr(op->true_value), VisitExpr(op->false_value));
  }

// Index calculations (e.g., the "i + j" expression in A[i + j]) are not counted in the FLOP estimate.
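// A binary op counts as one FLOP only when at least one operand carries the
// output dtype's type code, so pure integer index arithmetic contributes zero.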
#define VisitBinary(Node)                                                                     \
  double VisitExpr_(const Node* op) final {                                                   \
    double base = 1.0;                                                                        \
    if ((op->a->dtype.code() != cur_type_code_) && (op->b->dtype.code() != cur_type_code_)) { \
      base = 0.0;                                                                             \
    }                                                                                         \
    return base + VisitExpr(op->a) + VisitExpr(op->b);                                        \
  }

#define VisitUnary(Node)                                          \
  double VisitExpr_(const Node* op) final {                       \
    double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \
    return base + VisitExpr(op->a);                               \
  }

  VisitBinary(AddNode);
  VisitBinary(SubNode);
  VisitBinary(MulNode);
  VisitBinary(DivNode);
  VisitBinary(ModNode);
  VisitBinary(FloorDivNode);
  VisitBinary(FloorModNode);
  VisitBinary(MaxNode);
  VisitBinary(MinNode);
  VisitBinary(EQNode);
  VisitBinary(NENode);
  VisitBinary(LTNode);
  VisitBinary(LENode);
  VisitBinary(GTNode);
  VisitBinary(GENode);
  VisitBinary(AndNode);
  VisitBinary(OrNode);
  VisitUnary(NotNode);

  double VisitExpr_(const CallNode* op) final {
    double ret = 0.0;
    for (const auto& x : op->args) {
      ret += VisitExpr(x);
    }
    return ret;
  }

  double VisitExprDefault_(const Object* op) final {
    fail_ = true;
    return -1.0;
  }

 private:
  bool fail_{false};
  int cur_type_code_;
};

void CheckComputeValidity(const te::Schedule& sch) {
  // Check the validity of a compute definition:
  // The name of each iterator should be unique.
  for (auto stage : sch->stages) {
    if (stage->op->IsInstance<te::ComputeOpNode>()) {
      std::unordered_set<std::string> names;
      for (const auto& x : stage->leaf_iter_vars) {
        ICHECK(!names.count(x->var->name_hint))
            << "Found duplicated iterator names in the compute definition: " << x->var->name_hint
            << ". Please use different names for different iterators.";
        names.insert(x->var->name_hint);
      }
    }
  }
}

ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
  auto node = make_object<ComputeDAGNode>();
  node->tensors = std::move(tensors);
  node->access_analyzer = AccessAnalyzer(node->tensors);

  Array<te::Operation> out_ops;
  for (const auto& op : node->access_analyzer->ops_topo_order) {
    if (node->access_analyzer.IsOutput(op)) {
      out_ops.push_back(op);
    }
  }
  te::Schedule sch = te::create_schedule(out_ops);
  for (auto stage : sch->stages) {
    node->ops.push_back(stage->op);
  }

  // Make sure it is a valid compute definition
  CheckComputeValidity(sch);

  node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
  node->init_state = State(node->ops);
  data_ = std::move(node);
}

ComputeDAG::ComputeDAG(const te::Schedule& sch) {
  auto node = make_object<ComputeDAGNode>();

  // Make sure it is a valid compute definition
  CheckComputeValidity(sch);

  // Initialize ops. Here we enforce that the order of ops and stages is consistent
  for (auto stage : sch->stages) {
    node->ops.push_back(stage->op);
  }

  // Collect input and output tensors
  Array<te::Tensor> tensors;
  for (auto stage : sch->stages) {
    if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) {
      for (auto i = 0; i < stage->op->num_outputs(); ++i) {
        tensors.push_back(stage->op.output(i));
      }
    }
  }
  node->tensors = std::move(tensors);
  node->access_analyzer = AccessAnalyzer(node->tensors);
  node->flop_ct = FlopEstimator().EstimateFlop(node->ops);
  node->init_state = State(node->ops);
  data_ = std::move(node);
}

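// Rewrite the access indices of a layout-free placeholder for a new layout.
// Each original index is remapped with indexdiv/indexmod according to the
// split factors implied by repeated axis names in the new layout string.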
class IndexRewriter : public StmtExprMutator {
 public:
  IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout)
      : placeholder_op_(placeholder_op) {
    ParseKernelLayout(new_layout, &new_shape_, &new_names_);
  }

  PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); }

  PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
    te::Tensor t = Downcast<te::Tensor>(op->producer);
    if (t->op == placeholder_op_) {
      std::unordered_map<std::string, PrimExpr> name_to_arg;
      for (const auto& arg : op->indices) {
        std::string axis_name;
        if (const auto* int_imm = arg.as<IntImmNode>()) {
          ICHECK_EQ(int_imm->value, 0);
          axis_name = "IntImm";
        } else {
          axis_name = AxisBaseName(CleanName(Downcast<Var>(arg)->name_hint));
          ICHECK_EQ(name_to_arg.count(axis_name), 0);
          name_to_arg[axis_name] = arg;
        }
      }

      std::unordered_map<std::string, PrimExpr> div_factors;
      std::vector<PrimExpr> r_new_args;
      for (int i = new_names_.size() - 1; i >= 0; --i) {
        auto ori_iter_name = new_names_[i];
        auto name_it = name_to_arg.find(ori_iter_name);
        ICHECK(name_it != name_to_arg.end());
        PrimExpr ori_arg = name_it->second;

        PrimExpr mod_factor = new_shape_[i];

        PrimExpr div_factor = 1;
        if (div_factors.count(ori_iter_name)) {
          div_factor = div_factors[ori_iter_name];
        }
        div_factors[ori_iter_name] = div_factor * new_shape_[i];

        PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);

        r_new_args.push_back(new_arg);
      }

      Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
                               std::make_move_iterator(r_new_args.rend()));
      return ProducerLoad(op->producer, new_args);
    }
    return GetRef<PrimExpr>(op);
  }

 private:
  const te::Operation& placeholder_op_;
  Array<PrimExpr> new_shape_;
  std::vector<std::string> new_names_;
};

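// Infer the original layout string of `placeholder` from how `op` reads it:
// for each access index, append "<extent><axis-base-name>"; a constant-zero
// index gets the special name "IntImm".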
std::string GetOrigLayout(std::set<std::string>* placeholder_axis_names, const te::Operation& op,
                          const te::Tensor& placeholder) {
  ReadAccessExtractor extractor;
  for (const auto& exp : op.as<te::ComputeOpNode>()->body) {
    extractor.Extract(exp);
  }

  std::ostringstream os;
  uint32_t i = 0;
  const auto& placeholder_op = placeholder->op;
  ICHECK_GT(extractor.read_access.count(placeholder_op), 0);
  for (const auto& ev : extractor.read_access[placeholder_op]) {
    for (const auto& e : ev) {
      std::string axis_name;
      if (const auto* int_imm = e.as<IntImmNode>()) {
        ICHECK_EQ(int_imm->value, 0);
        axis_name = "IntImm";
      } else {
        axis_name = AxisBaseName(CleanName(Downcast<Var>(e)->name_hint));
      }

      placeholder_axis_names->insert(axis_name);
      os << placeholder->shape[i++] << axis_name;
    }
  }

  ICHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size());
  std::string orig_layout = os.str();
  os.str("");
  ::tvm::relay::AutoSchedulerLayoutRewriter::global_ori_layouts_queue.push_back(orig_layout);
  return orig_layout;
}

std::string GetNewLayout(const State& state, const int stage_id, const Stage& stage,
                         const te::Operation& op, const te::Tensor& placeholder,
                         const std::set<std::string>& placeholder_axis_names) {
  std::ostringstream os;
  Array<Iterator> stage_iters;

  auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id);
  int attach_pos = -1;
  size_t iters_before_attach = 0;
  if (attach_it != state->attach_map->stage_to_attach_iter.end()) {
    auto attach = attach_it->second;
    const auto& attach_stage = state->stages[attach.first];
    attach_pos = attach.second;
    stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(),
                       attach_stage->iters.begin() + attach_pos + 1);
  }

  stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end());

  std::vector<Iterator> iters;
  for (size_t i = 0; i < stage_iters.size(); ++i) {
    const auto& iter = stage_iters[i];
    if (iter->orig_iters.empty()) {
      iters.push_back(iter);
    } else {
      for (const Iterator& ori_iter : iter->orig_iters) {
        iters.push_back(ori_iter);
      }
    }
    if (static_cast<int>(i) == attach_pos) {
      iters_before_attach = iters.size();
    }
  }

  std::vector<std::string> new_names;
  std::vector<std::string> new_axis_names;
  for (const Iterator& iter : iters) {
    std::set<std::string> ori_iter_names;
    ExtractOriginalIterators(iter->name, &ori_iter_names);
    // Fused iters have been replaced with iter->orig_iters,
    // so there should be only one original iter name extracted from iter->name.
    ICHECK_EQ(ori_iter_names.size(), 1);
    auto ori_iter_name = AxisBaseName(*ori_iter_names.begin());
    new_axis_names.push_back(ori_iter_name);
  }
  for (size_t i = 0; i < new_axis_names.size(); ++i) {
    auto iter = iters[i];
    std::string ori_iter_name;
    if (i < iters_before_attach) {
      ori_iter_name = new_axis_names[i + iters_before_attach];
    } else {
      ori_iter_name = new_axis_names[i];
    }
    if (placeholder_axis_names.count(ori_iter_name)) {
      PrimExpr extent;
      if (iter->range.defined()) {
        extent = iter->range->extent;
      } else {
        // This iter has been simplified away by InferBound, so it must have a length of one.
        extent = 1;
      }
      os << extent << ori_iter_name;
      new_names.push_back(ori_iter_name);
    }
  }
  std::string new_layout = os.str();
  os.str("");
  ::tvm::relay::AutoSchedulerLayoutRewriter::global_new_layouts_queue.push_back(new_layout);
  return new_layout;
}

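// Rewrite the layouts of layout-free placeholders according to the loop
// structure in `transform_steps`. RewriteForPreTransformed replaces the
// placeholder itself (the input data is expected to be pre-transformed),
// while InsertTransformStage keeps the placeholder and inserts an explicit,
// parallelized transform stage.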
ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
                                     LayoutRewriteOption layout_rewrite) const {
  CHECK(layout_rewrite != LayoutRewriteOption::NoRewrite)
      << "ComputeDAG::RewriteLayout should not be called with LayoutRewriteOption::NoRewrite.";
  ComputeDAG new_dag = *this;
  ComputeDAGNode* p_dag = new_dag.CopyOnWrite();

  auto node = make_object<StateNode>();
  node->transform_steps = *transform_steps;
  node->concrete = true;
  const State& state = InferBound(State(node));

  OperationSet handled_ops;
  for (size_t stage_id = 0; stage_id < state->stages.size(); stage_id++) {
    const auto& stage = state->stages[stage_id];

    const te::Operation& op = stage->op;
    if (!op->IsInstance<te::ComputeOpNode>()) {
      continue;
    }
    const Map<String, ObjectRef>& attrs = op->attrs;
    if (attrs.count(layout_free_placeholders_key) == 0) {
      continue;
    }
    const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
    for (const auto& placeholder : Downcast<Array<te::Tensor>>(attr_value)) {
      const auto& placeholder_op = placeholder->op;

      // Check whether this placeholder has already been handled
      if (handled_ops.count(placeholder_op)) {
        continue;
      }
      // Skip ops that are not direct consumers of this placeholder.
      // This is usually caused by cache read/write.
      bool direct_consumer = false;
      for (auto& t : op->InputTensors()) {
        if (t->op == placeholder_op) {
          direct_consumer = true;
          break;
        }
      }
      if (!direct_consumer) {
        continue;
      }
      handled_ops.insert(placeholder_op);

      // Process the original layout
      std::set<std::string> placeholder_axis_names;
      std::string origin_layout = GetOrigLayout(&placeholder_axis_names, op, placeholder);
      Array<PrimExpr> origin_shape;
      std::vector<std::string> origin_axes;
      ParseKernelLayout(origin_layout, &origin_shape, &origin_axes);

      // Process the new layout
      std::string new_layout =
          GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names);
      Array<PrimExpr> new_shape;
      std::vector<std::string> new_axes;
      ParseKernelLayout(new_layout, &new_shape, &new_axes);

      // Process op updates
      te::Operation new_op_to_update;
      if (layout_rewrite == LayoutRewriteOption::RewriteForPreTransformed) {
        // Create a new placeholder
        new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape,
                                             placeholder_op.as<te::PlaceholderOpNode>()->dtype);
      } else if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
        // Process index strides
        std::unordered_map<std::string, PrimExpr> axes_stride;
        for (const auto& i : origin_axes) {
          axes_stride[i] = Integer(1);
        }
        Array<PrimExpr> new_stride(new_shape.size(), PrimExpr());
        for (int i = new_shape.size() - 1; i >= 0; i--) {
          new_stride.Set(i, axes_stride[new_axes[i]]);
          axes_stride[new_axes[i]] *= new_shape[i];
        }

        // Add an extra layout transform stage
        const auto& layout_transform_tensor = te::compute(
            new_shape,
            [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes,
             &new_axes](const tvm::runtime::Array<tvm::tir::Var>& indices) -> tvm::PrimExpr {
              Array<PrimExpr> access_indices;
              for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) {
                PrimExpr temp = Integer(0);
                for (size_t i = 0; i < new_shape.size(); i++) {
                  if (origin_axes[indice_index].compare(new_axes[i]) == 0) {
                    temp += indices[i] * new_stride[i];
                  }
                }
                access_indices.push_back(temp);
              }
              return placeholder_op.output(0)(access_indices);
            },
            "auto_scheduler_layout_transform");
        new_op_to_update = layout_transform_tensor->op;

        // Update the transform steps
        for (size_t i = 0; i < transform_steps->size(); i++) {
          Step step = (*transform_steps)[i];
          if (step->stage_id >= static_cast<int>(stage_id)) {
            step.CopyOnWrite()->stage_id++;
          }
          if (step->IsInstance<ComputeAtStepNode>()) {
            auto compute_at_step = tvm::Downcast<ComputeAtStep>(step);
            if (compute_at_step->target_stage_id >= static_cast<int>(stage_id)) {
              dynamic_cast<ComputeAtStepNode*>(compute_at_step.CopyOnWrite())->target_stage_id++;
            }
            transform_steps->Set(i, std::move(compute_at_step));
          } else {
            transform_steps->Set(i, std::move(step));
          }
        }

        // Add a schedule for the newly added transform stage
        Array<Integer> to_fuse;

        if (new_shape.size() >= 5) {
          to_fuse.push_back(0);
          to_fuse.push_back(1);
          to_fuse.push_back(2);
          transform_steps->push_back(FuseStep(stage_id, to_fuse));
        } else if (new_shape.size() >= 3) {
          to_fuse.push_back(0);
          to_fuse.push_back(1);
          transform_steps->push_back(FuseStep(stage_id, to_fuse));
        }
        transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
      }

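      // Rewrite the body of every op that directly reads the placeholder so
      // that its access indices follow the new layout.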
      te::Operation new_compute_op, original_compute_op;
      Array<PrimExpr> new_body;
      IndexRewriter index_rewriter(placeholder_op, new_layout);
      for (const auto& op : p_dag->ops) {
        if (auto* pop = op.as<te::ComputeOpNode>()) {
          bool need_update = false;
          for (auto& t : op->InputTensors()) {
            if (t->op == placeholder_op) {
              need_update = true;
              break;
            }
          }
          if (need_update) {
            for (const auto& body : pop->body) {
              new_body.push_back(index_rewriter.Rewrite(body));
            }
            original_compute_op = op;
            CHECK(!new_compute_op.defined());
            auto new_attrs = pop->attrs;
            new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout));
            new_attrs.Set("new_placeholder_layout", tvm::String(new_layout));
            new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body);
          }
        }
      }

      // Construct the map from original ops to new ops
      std::unordered_map<te::Operation, te::Operation> updated_ops;

      Array<te::Operation> original_ops = p_dag->ops;
      p_dag->ops.clear();
      for (size_t i = 0; i < original_ops.size(); ++i) {
        const auto& original_op = original_ops[i];
        if (original_op == placeholder_op) {
          if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
            p_dag->ops.push_back(placeholder_op);
          }
          p_dag->ops.push_back(new_op_to_update);
          updated_ops[placeholder_op] = new_op_to_update;
        } else if (original_op == original_compute_op) {
          p_dag->ops.push_back(new_compute_op);
          updated_ops[original_compute_op] = new_compute_op;
        } else {
          p_dag->ops.push_back(original_op);
        }
      }

      ArrayNode* pops = p_dag->ops.CopyOnWrite();
      // Because ops are sorted in topological order, a single linear scan suffices here.
      for (size_t i = 0; i < pops->size(); ++i) {
        const auto& original_op = Downcast<te::Operation>(pops->at(i));
        if (auto* pop = original_op.as<te::ComputeOpNode>()) {
          if (original_op == new_op_to_update) {
            continue;
          }
          auto inputs = pop->InputTensors();
          std::unordered_map<te::Tensor, te::Tensor> rmap;
          for (auto input : inputs) {
            auto it = updated_ops.find(input->op);
            te::Operation new_op;
            while (it != updated_ops.end()) {
              new_op = it->second;
              it = updated_ops.find(new_op);
            }
            if (new_op.defined()) {
              int index = input->value_index;
              rmap[input] = new_op.output(index);
            }
          }
          if (!rmap.empty()) {
            te::Operation new_op = pop->ReplaceInputs(original_op, rmap);
            updated_ops[original_op] = new_op;
            pops->SetItem(i, new_op);
          }
        }
      }

      Array<te::Tensor> old_tensors = p_dag->tensors;
      ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite();
      for (size_t i = 0; i < old_tensors.size(); ++i) {
        const auto& old_tensor = old_tensors[i];
        if (layout_rewrite != LayoutRewriteOption::RewriteForPreTransformed &&
            old_tensor->op->IsInstance<te::PlaceholderOpNode>()) {
          continue;
        }
        auto it = updated_ops.find(old_tensor->op);
        te::Operation new_op;
        while (it != updated_ops.end()) {
          new_op = it->second;
          it = updated_ops.find(new_op);
        }
        if (new_op.defined()) {
          auto index = old_tensor->value_index;
          p_tensors->SetItem(i, new_op.output(index));
        }
      }
    }  // end for placeholder
  }    // end for stage
  p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors);

  Array<te::Operation> out_ops;
  for (const auto& op : p_dag->access_analyzer->ops_topo_order) {
    if (p_dag->access_analyzer.IsOutput(op)) {
      out_ops.push_back(op);
    }
  }

  p_dag->ops.clear();
  te::Schedule sch = te::create_schedule(out_ops);
  for (auto stage : sch->stages) {
    p_dag->ops.push_back(stage->op);
  }
  p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops);
  p_dag->init_state = State(p_dag->ops);

  return new_dag;
}

// Return whether a DAG has placeholders that are marked as "layout free".
bool HasLayoutFreeTensors(const ComputeDAG& dag) {
  for (const auto& op : dag->ops) {
    if (!op->IsInstance<te::ComputeOpNode>()) {
      continue;
    }
    if (op->attrs.count(ComputeDAG::layout_free_placeholders_key)) {
      return true;
    }
  }

  return false;
}

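// Apply the history transform steps to a fresh te::Schedule. If layout rewrite
// is enabled and the DAG contains layout-free placeholders, rewrite the layout
// first and replay the steps on the rewritten DAG.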
std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
    const Array<Step>& transform_steps, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
    LayoutRewriteOption layout_rewrite) const {
  if (layout_rewrite != LayoutRewriteOption::NoRewrite && HasLayoutFreeTensors(*this) &&
      !transform_steps.empty()) {
    Array<Step> steps = transform_steps;
    const auto& dag = RewriteLayout(&steps, layout_rewrite);
    return dag.ApplySteps(steps);
  }

  // Temporary objects to be used if the input pointers are nullptr
  Array<te::Stage> temp_stages;
  StageToAxesMap temp_stage_to_axes;
  if (stages == nullptr) {
    stages = &temp_stages;
  }
  if (stage_to_axes == nullptr) {
    stage_to_axes = &temp_stage_to_axes;
  }
  Array<te::Operation> out_ops;
  for (const auto& op : operator->()->ops) {
    if (operator->()->access_analyzer.IsOutput(op)) {
      out_ops.push_back(op);
    }
  }

  // Create the initial schedule
  te::Schedule schedule = te::create_schedule(out_ops);

  // init axes
  for (const auto& x : operator->()->ops) {
    const te::Stage& stage = schedule[x];
    stages->push_back(stage);
    UpdateStageToAxesMap(stage, stage_to_axes);
  }

  // Apply the history steps to the TVM schedule
  // by calling each step's ApplyToSchedule method
  for (const auto& step : transform_steps) {
    StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps);
  }

  return std::make_pair(schedule, operator->()->tensors);
}

String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const {
  Array<te::Stage> stages;
  StageToAxesMap stage_to_axes;
  Array<te::Operation> out_ops;
  for (const auto& op : operator->()->ops) {
    if (operator->()->access_analyzer.IsOutput(op)) {
      out_ops.push_back(op);
    }
  }
  // Create the initial schedule
  te::Schedule schedule = te::create_schedule(out_ops);

  // init axes
  for (const auto& x : operator->()->ops) {
    const te::Stage& stage = schedule[x];
    stages.push_back(stage);
    UpdateStageToAxesMap(stage, &stage_to_axes);
  }

  std::stringstream ss;
  for (const auto& stage : stages) {
    if (stage->op->IsInstance<te::ComputeOpNode>()) {
      auto op_name = CleanName(stage->op->name);

      for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
        ss << CleanName(stage->leaf_iter_vars[i]->var->name_hint, op_name);
        if (i != stage->leaf_iter_vars.size() - 1) {
          ss << ", ";
        }
      }
      ss << " = "
         << "tuple(" << op_name << ".op.axis)"
         << " + "
         << "tuple(" << op_name << ".op.reduce_axis)\n";
    }
  }
  // Call each step's PrintAsPythonAPI method
  for (const auto& step : transform_steps) {
    ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps);
  }

  return ss.str();
}

String ComputeDAG::PrintDAG(bool simple_mode) const {
  std::stringstream ss;

  for (const auto& op : operator->()->ops) {
    if (op->IsInstance<te::PlaceholderOpNode>()) {
      ss << op->name << " = PLACEHOLDER ";
      if (!simple_mode) {
        ss << op.output(0)->shape;
      }
      ss << "\n";
    } else if (auto pop = op.as<te::ComputeOpNode>()) {
      for (size_t k = 0; k < pop->body.size(); ++k) {
        ss << op->name << "(";
        for (size_t i = 0; i < pop->axis.size(); i++) {
          ss << pop->axis[i]->var->name_hint;
          if (i != pop->axis.size() - 1) {
            ss << ", ";
          }
        }
        ss << ")";
        if (pop->body.size() > 1) {
          ss << ".v" << k;
        }
        if (auto p_reduce = pop->body[k].as<ReduceNode>()) {
          ICHECK_LT(k, p_reduce->combiner->result.size());
          PrimExpr combiner = p_reduce->combiner->result[k];
          if (combiner->IsInstance<AddNode>()) {
            ss << " += " << AsLegacyRepr(p_reduce->source[0]) << "\n";
          } else if (combiner->IsInstance<MaxNode>()) {
            ss << " max= " << AsLegacyRepr(p_reduce->source[0]) << "\n";
          } else if (combiner->IsInstance<MinNode>()) {
            ss << " min= " << AsLegacyRepr(p_reduce->source[0]) << "\n";
          } else if (combiner->IsInstance<SelectNode>()) {
            const auto& select = combiner.as<SelectNode>();
            ss << " select(" << AsLegacyRepr(select->condition)  //
               << ", " << AsLegacyRepr(select->true_value)       //
               << ", " << AsLegacyRepr(select->false_value)      //
               << ")= (" << AsLegacyRepr(p_reduce->source[0])    //
               << ',' << AsLegacyRepr(p_reduce->source[1])       //
               << ")\n";
          } else {
            ss << "reduce" << AsLegacyRepr(combiner) << "\n";
          }
        } else {
          auto call = pop->body[k].as<CallNode>();
          if (simple_mode && call) {
            ss << " = " << AsLegacyRepr(call->op) << "\n";
          } else {
            ss << " = " << AsLegacyRepr(pop->body[k]) << "\n";
          }
        }
      }
    } else {
      LOG(FATAL) << "Invalid op";
    }
  }
  return String(ss.str());
}

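// Fill in the iterator ranges of a state by replaying its transform steps into
// a te::Schedule and running te::InferBound on the result.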
State ComputeDAG::InferBound(const State& state) const {
  ICHECK(state->concrete) << "Only a concrete state can be processed to get bound info.";

  State ret_state;
  StateNode* pstate;

  if (state->stages.empty()) {
    // If the input state is incomplete with an empty operation stage,
    // create a new state from init_state and update it first
    ret_state = operator->()->init_state;
    pstate = ret_state.CopyOnWrite();
    pstate->transform_steps = state->transform_steps;
    for (const auto& step : pstate->transform_steps) {
      StepApplyToState(step, &ret_state, *this);
    }
  } else {
    ret_state = state;
    pstate = ret_state.CopyOnWrite();
  }

  Array<te::Stage> stages;
  StageToAxesMap stage_to_axes;
  // Replay steps to a tvm::Schedule
  auto [sch, tensors] = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes);
  (void)tensors;  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
  sch = sch.normalize_for_feature_extraction();
  // Get bound information from the TVM schedule
  Map<IterVar, Range> bounds = te::InferBound(sch);

  // Update the state bound information
  for (size_t i = 0; i < pstate->stages.size(); ++i) {
    const Stage& stage = pstate->stages[i];

    if (stage->compute_at == ComputeAtKind::kInlined) {
      continue;
    }

    Array<Iterator> new_iters;
    new_iters.reserve(stage->iters.size());
    // Get bound information from the schedule;
    // the StageToAxesMap is used to find the corresponding IterVar in the TVM schedule result
    for (size_t j = 0; j < stage->iters.size(); ++j) {
      const Iterator& iter = stage->iters[j];
      const IterVar& axis = stage_to_axes.at(stages[i])[j];

      auto find_res = bounds.find(axis);
      if (find_res != bounds.end()) {
        new_iters.push_back(Iterator(iter->name, (*find_res).second, iter->iter_kind,
                                     iter->annotation, &iter->orig_iters));
      } else {
        LOG(FATAL) << "Infer bound fails";
      }
    }

    pstate->stages.Set(
        i, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
  }

  return ret_state;
}

Array<State> ComputeDAG::InferBound(const Array<State>& states) const {
  Array<State> out_states(states.size(), State());

  support::parallel_for(0, states.size(), [this, &states, &out_states](int i) {
    try {
      out_states.Set(i, (states[i].defined()) ? this->InferBound(states[i]) : states[i]);
    } catch (Error& e) {
      LOG(WARNING) << "InferBound fails on the state:\n"
                   << states[i] << "\n"
                   << "with: " << e.what() << std::endl;
    }
  });

  return out_states;
}

ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array<Step>& transform_steps) const {
  auto [sch, old_tensors] = ApplySteps(transform_steps);
  (void)old_tensors;  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
  return ComputeDAG(sch);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<AccessAnalyzerNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const AccessAnalyzerNode*>(ref.get());
      for (const auto& op : node->ops_topo_order) {
        p->stream << op << std::endl;
        p->stream << "is_simple_access:\t" << node->is_simple_access.at(op) << "\t\t";
        p->stream << "needs_multi_level_tiling:\t" << node->needs_multi_level_tiling.at(op)
                  << std::endl;
        p->stream << "is_strictly_inlineable:\t" << node->is_strictly_inlineable.at(op) << "\t";
        p->stream << "is_output:\t" << node->is_output.at(op) << std::endl;
        p->stream << "Read from:\t";
        for (const auto& pair : node->read_from.at(op)) {
          for (const auto& index : pair.second) {
            p->stream << pair.first->name << Array<PrimExpr>(index) << ", ";
          }
        }
        p->stream << std::endl;
        p->stream << "Read by:\t";
        for (const auto& pair : node->read_by.at(op)) {
          for (const auto& index : pair.second) {
            p->stream << pair.first->name << Array<PrimExpr>(index) << ", ";
          }
        }
        p->stream << std::endl;
        p->stream << Chars('=', 50) << std::endl;
      }

      AccessAnalyzer ana = GetRef<AccessAnalyzer>(node);
      p->stream << "ElementwiseMatch: \n";
      for (size_t i = 0; i < node->ops_topo_order.size(); ++i) {
        for (size_t j = 0; j < node->ops_topo_order.size(); ++j) {
          if (i == j) {
            continue;
          }
          if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) {
            p->stream << node->ops_topo_order[i]->name << " -> " << node->ops_topo_order[j]->name
                      << std::endl;
          }
        }
      }
      p->stream << Chars('=', 50) << std::endl;

      p->stream << "NumCommonOuterIterators: \n";
      for (const auto& src_pair : node->num_common_outer_iterators) {
        for (const auto& dst_pair : src_pair.second) {
          p->stream << src_pair.first->name << " " << dst_pair.first->name << " " << dst_pair.second
                    << std::endl;
        }
      }
      p->stream << Chars('=', 50) << std::endl;
    });

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const ComputeDAGNode*>(ref.get());
      auto dag = GetRef<ComputeDAG>(node);
      auto dag_str = dag.PrintDAG();
      p->stream << dag_str;
    });

Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names) {
  Array<PrimExpr> shape;
  std::vector<std::string> extracted_names;
  topi::parse_auto_scheduler_layout(rewritten_layout, &shape, &extracted_names);

  Array<PrimExpr> ret(axis_names.size(), 1);

  size_t ct = 0;
  for (size_t i = 0; i < axis_names.size(); ++i) {
    for (size_t j = 0; j < extracted_names.size(); ++j) {
      if (axis_names[i] == extracted_names[j]) {
        ret.Set(i, ret[i] * shape[j]);
        ct++;
      }
    }
  }

  CHECK_EQ(ct, extracted_names.size()) << "The numbers or names of the axes do not match";

  return ret;
}

TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG")
    .set_body_typed([](Optional<Array<te::Tensor>> tensors, Optional<te::Schedule> sch) {
      if (sch) {
        return ComputeDAG(sch.value());
      }
      ICHECK(tensors) << "Both tensors and schedule are null";
      return ComputeDAG(tensors.value());
    });

TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState")
    .set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) {
      auto [sch, return_tensors] = dag.ApplySteps(state->transform_steps, nullptr, nullptr,
                                                  static_cast<LayoutRewriteOption>(layout_rewrite));
      return Array<ObjectRef>{sch, return_tensors};
    });

TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintPythonCodeFromState")
    .set_body_typed([](const ComputeDAG& dag, const State& state) {
      return dag.PrintStepsAsPython(state->transform_steps);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintDAG")
    .set_body_typed([](const ComputeDAG& dag, bool simple_mode) {
      return dag.PrintDAG(simple_mode);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGInferBoundFromState")
    .set_body_typed([](const ComputeDAG& dag, const State& state) {
      return dag.InferBound(state);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGRewriteLayoutFromState")
    .set_body_typed([](const ComputeDAG& dag, const State& state) {
      Array<Step>* transform_steps = const_cast<Array<Step>*>(&state->transform_steps);
      return dag.RewriteLayout(transform_steps, LayoutRewriteOption::RewriteForPreTransformed);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
    .set_body_typed([](const te::Operation& placeholder_op, const std::string& new_layout,
                       const PrimExpr& body) {
      IndexRewriter index_rewriter(placeholder_op, new_layout);
      return index_rewriter.Rewrite(body);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.RewriteTensorShape")
    .set_body_typed([](te::Tensor tensor, Array<PrimExpr> new_shape) -> void {
      ICHECK(tensor->op->IsInstance<te::PlaceholderOpNode>());
      te::PlaceholderOpNode* op =
          const_cast<te::PlaceholderOpNode*>(tensor->op.as<te::PlaceholderOpNode>());
      te::TensorNode* t = const_cast<te::TensorNode*>(tensor.get());
      op->shape = new_shape;
      t->shape = new_shape;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout")
    .set_body_typed(GetShapeFromRewrittenLayout);

}  // namespace auto_scheduler
}  // namespace tvm