1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/compute_dag.cc |
22 | * \brief Compute declaration graph and its related analysis tools. |
23 | */ |
24 | |
25 | #include <tvm/auto_scheduler/compute_dag.h> |
26 | #include <tvm/auto_scheduler/loop_state.h> |
27 | #include <tvm/auto_scheduler/search_policy.h> |
28 | #include <tvm/auto_scheduler/transform_step.h> |
29 | #include <tvm/runtime/registry.h> |
30 | #include <tvm/support/parallel_for.h> |
31 | #include <tvm/te/operation.h> |
32 | #include <tvm/te/schedule.h> |
33 | #include <tvm/te/schedule_pass.h> |
34 | #include <tvm/tir/builtin.h> |
35 | #include <tvm/tir/stmt_functor.h> |
36 | #include <tvm/topi/transform.h> |
37 | |
38 | #include <algorithm> |
39 | #include <cstdint> |
40 | #include <queue> |
41 | #include <unordered_map> |
42 | #include <unordered_set> |
43 | #include <vector> |
44 | |
45 | #include "../arith/pattern_match.h" |
46 | #include "../relay/transforms/auto_scheduler_layout_rewrite.h" |
47 | #include "search_policy/utils.h" |
48 | #include "utils.h" |
49 | |
50 | namespace tvm { |
51 | namespace auto_scheduler { |
52 | |
53 | using namespace tvm::tir; |
54 | |
55 | template <class T> |
56 | using OperationMap = AccessAnalyzerNode::OperationMap<T>; |
57 | using OperationSet = std::unordered_set<te::Operation, ObjectHash, ObjectEqual>; |
58 | |
59 | TVM_REGISTER_NODE_TYPE(ComputeDAGNode); |
60 | |
61 | // Topo-sort ops from tensors according to their read-write relations. |
62 | Array<te::Operation> TopoSortOps(const Array<te::Tensor>& tensors) { |
63 | std::unordered_map<const te::OperationNode*, int> degree; |
64 | std::unordered_map<const te::OperationNode*, std::vector<const te::OperationNode*>> edge_set; |
65 | std::unordered_map<const te::OperationNode*, int> priority; |
66 | std::unordered_set<const te::OperationNode*> visited; |
67 | |
68 | // traverse to build edge_set and count degree |
69 | std::vector<const te::OperationNode*> stack; |
70 | stack.reserve(tensors.size()); |
71 | for (const auto& x : tensors) { |
72 | stack.push_back(x->op.operator->()); |
73 | } |
74 | |
75 | int ct = 0; |
76 | while (!stack.empty()) { |
77 | const te::OperationNode* op = stack.back(); |
78 | stack.pop_back(); |
79 | if (visited.count(op)) { |
80 | continue; |
81 | } |
82 | |
83 | priority[op] = ct; |
84 | ct++; |
85 | visited.insert(op); |
86 | |
87 | if (op->IsInstance<te::PlaceholderOpNode>()) { |
88 | degree[op] = 0; |
89 | } else if (auto cop = GetRef<te::Operation>(op).as<te::ComputeOpNode>()) { |
90 | const Array<te::Tensor>& input_tensors = cop->InputTensors(); |
91 | degree[op] = input_tensors.size(); |
92 | for (const auto& ten : input_tensors) { |
93 | edge_set[ten->op.operator->()].push_back(op); |
94 | stack.push_back(ten->op.operator->()); |
95 | } |
96 | } else { |
97 | LOG(FATAL) << "Unsupported op " << GetRef<te::Operation>(op); |
98 | } |
99 | } |
100 | |
101 | // topo sort |
102 | Array<te::Operation> ops; |
103 | |
104 | using Item = std::pair<const te::OperationNode*, int>; |
105 | auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; }; |
106 | std::priority_queue<Item, std::vector<Item>, decltype(cmp)> queue(cmp); |
107 | for (const auto& iter : degree) { |
108 | if (iter.second == 0) { |
109 | queue.push(Item(iter.first, priority[iter.first])); |
110 | } |
111 | } |
112 | |
113 | ops.reserve(degree.size()); |
114 | while (!queue.empty()) { |
115 | Item item = queue.top(); |
116 | queue.pop(); |
117 | ops.push_back(GetRef<te::Operation>(item.first)); |
118 | for (const auto& dst : edge_set[item.first]) { |
119 | degree[dst] -= 1; |
120 | if (degree[dst] == 0) { |
121 | queue.push(Item(dst, priority[dst])); |
122 | } |
123 | } |
124 | } |
125 | |
126 | return ops; |
127 | } |
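// Illustrative example (hypothetical three-op DAG): for
//   A = te::placeholder(...), B = f(A), C = g(A, B),
// the traversal above records in-degrees A:0, B:1, C:2 and consumer edges
// A->{B, C}, B->{C}; the priority queue then emits {A, B, C}, so every producer
// precedes all of its consumers in the returned array.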
128 | |
129 | // Extract all tensor accesses in an expr |
class ReadAccessExtractor : public StmtExprVisitor {
131 | public: |
  void Extract(PrimExpr expr) { this->VisitExpr(expr); }
133 | |
  void VisitExpr_(const CallNode* op) final {
135 | if (op->op.same_as(builtin::if_then_else())) { |
136 | has_branch = true; |
137 | } |
138 | StmtExprVisitor::VisitExpr_(op); |
139 | } |
140 | |
  void VisitExpr_(const ProducerLoadNode* op) final {
142 | read_access[Downcast<te::Tensor>(op->producer)->op].emplace_back(op->indices.begin(), |
143 | op->indices.end()); |
144 | StmtExprVisitor::VisitExpr_(op); |
145 | } |
146 | |
  void VisitStmt_(const IfThenElseNode* op) final {
148 | has_branch = true; |
149 | StmtExprVisitor::VisitStmt_(op); |
150 | } |
151 | |
  void VisitExpr_(const SelectNode* op) final {
153 | has_branch = true; |
154 | StmtExprVisitor::VisitExpr_(op); |
155 | } |
156 | |
157 | // All read accesses to all operations |
158 | // The innermost vector stores multi-dimensional indices. |
  // The middle vector stores possibly multiple accesses to the same operation.
  OperationMap<std::vector<std::vector<PrimExpr>>> read_access;
  // Whether this expression contains branches (e.g. if_then_else, Select)
  bool has_branch{false};
163 | }; |
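// Illustrative example (hypothetical compute body): extracting
//   A[i, k] * B[k, j] + if_then_else(i > 0, A[i - 1, j], 0)
// yields read_access[A->op] = {{i, k}, {i - 1, j}}, read_access[B->op] = {{k, j}},
// and has_branch = true because of the if_then_else call.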
164 | |
// Return whether the expr equals the var, with an optional constant shift
166 | bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { |
167 | arith::PVar<PrimExpr> x; |
168 | arith::PVar<IntImm> c; |
169 | |
170 | if (((x + c).Match(expr) || (x - c).Match(expr) || (c + x).Match(expr) || x.Match(expr)) && |
171 | x.Eval().same_as(var)) { |
172 | return true; |
173 | } |
174 | return false; |
175 | } |
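// For example, for an iteration variable i, expressions such as i, i + 1, i - 2,
// and 3 + i match, while i * 2 or i + j do not.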
176 | |
177 | // Return whether the access to an operation is a simple access |
// (i.e. every index is a single iteration variable with an optional constant shift).
179 | // For example, A[i][j], A[i+1][j] are simple accesses but A[i][j+i] is not. |
180 | bool IsSimpleAccess(const te::Operation& op, const std::vector<PrimExpr>& indices, |
181 | bool* axis_missing, bool* axis_duplicated, bool* same_order) { |
182 | auto cop = op.as<te::ComputeOpNode>(); |
183 | if (cop == nullptr) { |
184 | return false; |
185 | } |
186 | |
187 | std::vector<int> index_to_var_idx; |
188 | std::vector<int> var_idx_ct(cop->axis.size(), 0); |
189 | |
190 | for (const auto& expr : indices) { |
191 | if (!is_const_int(expr)) { |
192 | bool found = false; |
193 | for (size_t i = 0; i < cop->axis.size(); ++i) { |
194 | if (IsConstShiftEqual(cop->axis[i]->var, expr)) { |
195 | index_to_var_idx.push_back(i); |
196 | var_idx_ct[i]++; |
197 | found = true; |
198 | break; |
199 | } |
200 | } |
201 | if (!found) { |
202 | return false; |
203 | } |
204 | } |
205 | } |
206 | |
207 | *axis_missing = false; // Some axes are missing |
208 | *axis_duplicated = false; // Some axes appear more than once |
209 | *same_order = true; // The axis order is the same as op->axis |
210 | for (int ct : var_idx_ct) { |
211 | if (ct == 0) { |
212 | *axis_missing = true; |
213 | } else if (ct > 1) { |
214 | *axis_duplicated = true; |
215 | } |
216 | } |
217 | for (size_t i = 1; i < index_to_var_idx.size(); ++i) { |
218 | if (index_to_var_idx[i] < index_to_var_idx[i - 1]) { |
219 | *same_order = false; |
220 | break; |
221 | } |
222 | } |
223 | |
224 | return true; |
225 | } |
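// Illustrative example (hypothetical op with axes i, j):
//   indices (i, j)     -> simple access; same_order = true
//   indices (j, i + 1) -> simple access; same_order = false
//   indices (i)        -> simple access; axis_missing = true
//   indices (i, i)     -> simple access; axis_duplicated = true (and j is missing)
//   indices (i + j)    -> not a simple access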
226 | |
227 | // Gather all VarNodes in an expr |
228 | void GatherVars(const PrimExpr& expr, std::unordered_set<const VarNode*>* vars) { |
229 | PostOrderVisit(expr, [&vars](const ObjectRef& node) { |
230 | if (const VarNode* op = node.as<VarNode>()) { |
231 | vars->insert(op); |
232 | } |
233 | }); |
234 | } |
235 | |
236 | // Check whether an expr has expensive operations (e.g. exp) |
237 | bool HasExpensiveOp(const PrimExpr& expr) { |
238 | bool found = false; |
239 | PostOrderVisit(expr, [&found](const ObjectRef& node) { |
240 | if (const CallNode* op = node.as<CallNode>()) { |
241 | if (op->op.as<OpNode>()->name == "tir.exp" ) { |
242 | found = true; |
243 | } |
244 | } |
245 | }); |
246 | return found; |
247 | } |
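// For example, a body such as exp(A[i]) + 1.0f contains a call to "tir.exp" and is
// treated as expensive, while A[i] * 2.0f + B[i] is not.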
248 | |
249 | AccessAnalyzer::AccessAnalyzer(const Array<te::Tensor>& tensors) { |
250 | auto node = make_object<AccessAnalyzerNode>(); |
251 | OperationMap<bool> has_branch; |
252 | |
253 | // Get all ops in topological order |
254 | node->ops_topo_order = TopoSortOps(tensors); |
255 | |
256 | arith::Analyzer analyzer; |
257 | |
258 | // Build read & write access map |
259 | for (const auto& op : node->ops_topo_order) { |
260 | if (op->IsInstance<te::PlaceholderOpNode>()) { |
261 | node->read_from[op] = OperationMap<std::vector<std::vector<PrimExpr>>>(); |
262 | } else if (auto cop = op.as<te::ComputeOpNode>()) { |
      ReadAccessExtractor extractor;
264 | for (const auto& exp : cop->body) { |
265 | extractor.Extract(exp); |
266 | } |
267 | |
268 | // read_by and read_from map |
269 | for (const auto& iter : extractor.read_access) { |
270 | std::vector<std::vector<PrimExpr>>& accesses = node->read_by[iter.first][op]; |
271 | accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); |
272 | } |
273 | |
274 | node->read_from[op] = std::move(extractor.read_access); |
275 | has_branch[op] = extractor.has_branch; |
276 | |
277 | // compute number of common outer iterators |
278 | for (const auto& pair : node->read_from[op]) { |
279 | const te::Operation& producer = pair.first; |
280 | const std::vector<std::vector<PrimExpr>>& access_list = pair.second; |
281 | const Array<PrimExpr>& output_shape = op->output_shape(0); |
282 | const Array<PrimExpr>& producer_shape = producer->output_shape(0); |
283 | |
284 | int n_common; |
285 | for (n_common = 0; |
286 | n_common < static_cast<int>(std::min(output_shape.size(), producer_shape.size())); |
287 | n_common++) { |
288 | if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) { |
289 | break; |
290 | } |
291 | |
292 | bool injective = true; |
293 | for (const auto& access : access_list) { |
294 | if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) { |
295 | injective = false; |
296 | break; |
297 | } |
298 | } |
299 | |
300 | if (!injective) { |
301 | break; |
302 | } |
303 | } |
304 | |
305 | node->num_common_outer_iterators[op][producer] = n_common; |
306 | node->num_common_outer_iterators[producer][op] = n_common; |
307 | } |
308 | } else { |
309 | LOG(FATAL) << "Invalid op: " << op; |
310 | } |
311 | } |
312 | |
313 | // Do some static analysis on ComputeOps |
314 | for (const auto& op : node->ops_topo_order) { |
315 | if (op->IsInstance<te::PlaceholderOpNode>()) { |
316 | node->is_simple_access[op] = true; |
317 | node->needs_multi_level_tiling[op] = false; |
318 | node->is_strictly_inlineable[op] = false; |
319 | node->is_output[op] = false; |
320 | } else if (auto cop = op.as<te::ComputeOpNode>()) { |
321 | // check whether this op is element-wise and strict-inlineable |
322 | bool is_simple_access = true; |
323 | bool is_strictly_inlineable = true; |
324 | |
325 | bool axis_missing, axis_duplicated, same_order; |
326 | for (const auto& pair : node->read_from[op]) { |
327 | const std::vector<std::vector<PrimExpr>>& access_list = pair.second; |
328 | for (const auto& access : access_list) { |
329 | if (!auto_scheduler::IsSimpleAccess(op, access, &axis_missing, &axis_duplicated, |
330 | &same_order)) { |
331 | is_simple_access = false; |
332 | is_strictly_inlineable = false; |
333 | break; |
334 | } |
335 | if (!same_order || axis_duplicated) { |
336 | // do not strictly inline transpose |
337 | is_strictly_inlineable = false; |
338 | } |
339 | } |
340 | if (!is_simple_access) { |
341 | break; |
342 | } |
343 | } |
344 | |
345 | // don't strictly inline expensive op (e.g. exp) |
346 | bool has_expensive_op = false; |
347 | for (const auto& expr : cop->body) { |
348 | has_expensive_op |= HasExpensiveOp(expr); |
349 | } |
350 | if (has_expensive_op || has_branch[op]) { |
351 | is_strictly_inlineable = false; |
352 | } |
353 | |
354 | // constant tensor is strict-inlineable |
355 | if (node->read_from[op].empty()) { |
356 | is_strictly_inlineable = true; |
357 | } |
358 | |
359 | node->is_simple_access[op] = is_simple_access; |
360 | node->is_strictly_inlineable[op] = is_strictly_inlineable; |
361 | |
362 | // check whether the op needs multi-level tiling |
363 | bool needs_multi_level_tiling = false; |
364 | int n_missing = 0; |
365 | |
366 | for (const auto& pair : node->read_from[op]) { |
367 | const std::vector<std::vector<PrimExpr>>& access_list = pair.second; |
368 | std::unordered_set<const VarNode*> vars; |
369 | for (const std::vector<PrimExpr>& access : access_list) { |
370 | for (const PrimExpr& expr : access) { |
371 | GatherVars(expr, &vars); |
372 | } |
373 | } |
374 | |
375 | for (const auto& axis : cop->axis) { |
376 | if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { |
377 | n_missing++; |
378 | break; |
379 | } |
380 | } |
381 | |
382 | if (n_missing >= 2 || (n_missing >= 1 && !cop->reduce_axis.empty())) { |
383 | needs_multi_level_tiling = true; |
384 | break; |
385 | } |
386 | } |
387 | |
388 | // do not perform multi-level tiling on "fake reduction" with const tensors |
389 | if (op->attrs.count(SearchPolicyKey::simplify_const_tensor_indices)) { |
390 | needs_multi_level_tiling = false; |
391 | } |
392 | |
393 | node->needs_multi_level_tiling[op] = needs_multi_level_tiling; |
394 | |
395 | // check whether the op is output |
396 | node->is_output[op] = node->read_by[op].empty(); |
397 | } else { |
      LOG(FATAL) << "Invalid op: " << op;
399 | } |
400 | } |
401 | |
402 | data_ = std::move(node); |
403 | } |
404 | |
405 | bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation& op) const { |
406 | return operator->()->needs_multi_level_tiling.at(op); |
407 | } |
408 | |
409 | bool AccessAnalyzer::IsOutput(const te::Operation& op) const { |
410 | return operator->()->is_output.at(op); |
411 | } |
412 | |
413 | bool AccessAnalyzer::IsSimpleAccess(const te::Operation& op) const { |
414 | return operator->()->is_simple_access.at(op); |
415 | } |
416 | |
417 | bool AccessAnalyzer::IsStrictlyInlineable(const te::Operation& op) const { |
418 | return operator->()->is_strictly_inlineable.at(op); |
419 | } |
420 | |
421 | OperationSet AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op) const { |
422 | OperationSet inlined_ops; |
423 | for (const auto& stage : state->stages) { |
424 | if (stage->compute_at == ComputeAtKind::kInlined) { |
425 | inlined_ops.insert(stage->op); |
426 | } |
427 | } |
428 | |
429 | OperationSet consumers; |
430 | std::function<void(const te::Operation&)> collect; |
431 | collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { |
432 | for (const auto& iter : operator->()->read_by.at(op)) { |
433 | if (inlined_ops.count(iter.first)) { |
434 | collect(iter.first); |
435 | } else { |
436 | consumers.insert(iter.first); |
437 | } |
438 | } |
439 | }; |
440 | |
441 | collect(op); |
442 | return consumers; |
443 | } |
444 | |
445 | OperationSet AccessAnalyzer::GetDirectProducers(const te::Operation& op) const { |
446 | OperationSet producers; |
447 | for (const auto& iter : operator->()->read_from.at(op)) { |
448 | producers.insert(iter.first); |
449 | } |
450 | return producers; |
451 | } |
452 | |
453 | OperationSet AccessAnalyzer::GetProducers(const State& state, const te::Operation& op) const { |
454 | OperationSet inlined_ops; |
455 | for (const auto& stage : state->stages) { |
456 | if (stage->compute_at == ComputeAtKind::kInlined) { |
457 | inlined_ops.insert(stage->op); |
458 | } |
459 | } |
460 | |
461 | OperationSet producers; |
462 | std::function<void(const te::Operation&)> collect; |
463 | collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) { |
464 | for (const auto& iter : operator->()->read_from.at(op)) { |
465 | if (inlined_ops.count(iter.first)) { |
466 | collect(iter.first); |
467 | } else { |
468 | producers.insert(iter.first); |
469 | } |
470 | } |
471 | }; |
472 | |
473 | collect(op); |
474 | return producers; |
475 | } |
476 | |
477 | int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op, |
478 | const te::Operation& target_op) const { |
479 | int ret = INT32_MAX; |
480 | bool meet = false; |
481 | |
482 | std::function<void(const te::Operation&, int)> traverse; |
483 | traverse = [this, &traverse, &target_op, &ret, &meet](const te::Operation& cur_op, int cur_num) { |
484 | if (cur_op == target_op) { |
485 | ret = std::min(ret, cur_num); |
486 | meet = true; |
487 | return; |
488 | } |
489 | |
490 | for (const auto& iter : operator->()->read_by.at(cur_op)) { |
491 | traverse( |
492 | iter.first, |
493 | std::min(cur_num, operator->()->num_common_outer_iterators.at(cur_op).at(iter.first))); |
494 | } |
495 | }; |
496 | |
497 | traverse(op, op->output_shape(0).size()); |
498 | return meet ? ret : 0; |
499 | } |
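// Illustrative example (hypothetical ops): for A with shape [N, M] and
// B(i, j) = relu(A[i, j]), A and B share 2 common outer iterators. If C(i) then
// sums B[i, k] over k, only the outermost iterator survives along the chain
// A -> B -> C, so GetNumCommonOuterIterator(A->op, C->op) returns 1.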
500 | |
501 | bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, |
502 | const te::Operation& target_op) const { |
503 | te::Operation cur_op = op; |
504 | while (cur_op != target_op) { |
505 | const AccessAnalyzerNode::OperationMap<std::vector<std::vector<PrimExpr>>>& map = |
506 | operator->()->read_by.at(cur_op); |
507 | |
508 | if (map.size() != 1) { |
509 | return false; |
510 | } |
511 | te::Operation next_op = map.begin()->first; |
512 | |
513 | // Check condition 1: They have the same output size |
514 | auto p_cur = cur_op.as<te::ComputeOpNode>(); |
515 | auto p_next = next_op.as<te::ComputeOpNode>(); |
516 | if (p_cur == nullptr || p_next == nullptr) { |
517 | return false; |
518 | } |
519 | |
520 | Array<PrimExpr> output_shape = p_cur->output_shape(0); |
521 | for (int i = 1; i < p_cur->num_outputs(); ++i) { |
522 | if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) { |
523 | return false; |
524 | } |
525 | } |
526 | for (int i = 0; i < p_next->num_outputs(); ++i) { |
527 | if (!IntArrayEqual(p_next->output_shape(i), output_shape)) { |
528 | return false; |
529 | } |
530 | } |
531 | |
532 | // Check condition 2: The read is elementwise |
533 | const std::vector<std::vector<PrimExpr>> reads = map.begin()->second; |
534 | bool is_simple_access, axis_missing, axis_duplicated, same_order; |
535 | for (const auto& read : reads) { |
536 | is_simple_access = auto_scheduler::IsSimpleAccess(next_op, read, &axis_missing, |
537 | &axis_duplicated, &same_order); |
538 | if (!is_simple_access || axis_missing || axis_duplicated || !same_order) { |
539 | return false; |
540 | } |
541 | } |
542 | |
543 | cur_op = std::move(next_op); |
544 | } |
545 | return true; |
546 | } |
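// Illustrative example (hypothetical ops): with B(i, j) = A[i, j] + 1 and
// C(i, j) = exp(B[i, j]), ElementWiseMatch(B->op, C->op) returns true: C is B's
// only consumer, the output shapes match, and the read B[i, j] uses all of C's
// axes in order. A transposed read such as C(i, j) = B[j, i] would fail the
// same_order check and return false.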
547 | |
548 | // Estimate the number of float operations in an expression |
549 | class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> { |
550 | public: |
551 | double EstimateFlop(const Array<te::Operation>& ops) { |
552 | double ret = 0; |
553 | for (const auto& op : ops) { |
554 | if (auto pop = op.as<te::ComputeOpNode>()) { |
555 | if (pop->attrs.count("FLOP" )) { |
556 | // Use user-provided FLOP |
557 | auto pint = pop->attrs["FLOP" ].as<IntImmNode>(); |
558 | ICHECK(pint != nullptr); |
559 | ret += pint->value; |
560 | } else { |
561 | // Estimate by parsing the compute body |
562 | double num_element = AxisLengthProd(pop->axis); |
563 | if (num_element == -1) { |
564 | fail_ = true; |
565 | break; |
566 | } |
567 | cur_type_code_ = pop->output_dtype(0).code(); |
568 | double op_per_element = 0; |
569 | for (const auto& x : pop->body) { |
570 | op_per_element += VisitExpr(x); |
571 | } |
572 | ret += num_element * op_per_element; |
573 | } |
574 | } else if (op->IsInstance<te::PlaceholderOpNode>()) { |
575 | {} // do nothing |
576 | } else { |
577 | LOG(FATAL) << "Invalid op type " << op; |
578 | } |
579 | } |
580 | |
581 | return fail_ ? -1 : ret; |
582 | } |
583 | |
584 | double VisitExpr_(const ReduceNode* op) final { |
585 | uint64_t num_iter = 1; |
586 | for (const auto& x : op->axis) { |
587 | if (auto imm = x->dom->extent.as<IntImmNode>()) { |
588 | num_iter *= imm->value; |
589 | } else { |
590 | fail_ = true; |
591 | num_iter = -1; |
592 | } |
593 | } |
594 | double body_flop = 0; |
595 | for (size_t i = 0; i < op->combiner->result.size(); ++i) { |
596 | body_flop += VisitExpr(op->combiner->result[i]); |
597 | body_flop += VisitExpr(op->source[i]); |
598 | } |
599 | return num_iter * body_flop; |
600 | } |
601 | |
602 | double VisitExpr_(const FloatImmNode* op) final { return 0.0; } |
603 | double VisitExpr_(const IntImmNode* op) final { return 0.0; } |
604 | double VisitExpr_(const ProducerLoadNode* op) final { return 0.0; } |
605 | |
606 | double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } |
607 | double VisitExpr_(const VarNode* op) final { return 0.0; } |
608 | |
609 | double VisitExpr_(const SelectNode* op) final { |
610 | return VisitExpr(op->condition) + |
611 | std::max(VisitExpr(op->true_value), VisitExpr(op->false_value)); |
612 | } |
613 | |
614 | // Index calculations (e.g., the "i + j" expression in A[i + j]) are not counted in FLOPS. |
615 | #define VisitBinary(Node) \ |
616 | double VisitExpr_(const Node* op) final { \ |
617 | double base = 1.0; \ |
618 | if ((op->a->dtype.code() != cur_type_code_) && (op->b->dtype.code() != cur_type_code_)) { \ |
619 | base = 0.0; \ |
620 | } \ |
621 | return base + VisitExpr(op->a) + VisitExpr(op->b); \ |
622 | } |
623 | |
624 | #define VisitUnary(Node) \ |
625 | double VisitExpr_(const Node* op) final { \ |
626 | double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \ |
627 | return base + VisitExpr(op->a); \ |
628 | } |
629 | |
630 | VisitBinary(AddNode); |
631 | VisitBinary(SubNode); |
632 | VisitBinary(MulNode); |
633 | VisitBinary(DivNode); |
634 | VisitBinary(ModNode); |
635 | VisitBinary(FloorDivNode); |
636 | VisitBinary(FloorModNode); |
637 | VisitBinary(MaxNode); |
638 | VisitBinary(MinNode); |
639 | VisitBinary(EQNode); |
640 | VisitBinary(NENode); |
641 | VisitBinary(LTNode); |
642 | VisitBinary(LENode); |
643 | VisitBinary(GTNode); |
644 | VisitBinary(GENode); |
645 | VisitBinary(AndNode); |
646 | VisitBinary(OrNode); |
647 | VisitUnary(NotNode); |
648 | |
649 | double VisitExpr_(const CallNode* op) final { |
650 | double ret = 0.0; |
651 | for (const auto& x : op->args) { |
652 | ret += VisitExpr(x); |
653 | } |
654 | return ret; |
655 | } |
656 | |
657 | double VisitExprDefault_(const Object* op) final { |
658 | fail_ = true; |
659 | return -1.0; |
660 | } |
661 | |
662 | private: |
663 | bool fail_{false}; |
664 | int cur_type_code_; |
665 | }; |
666 | |
667 | void CheckComputeValidity(const te::Schedule& sch) { |
668 | // Check the validity of a compute definition: |
669 | // The name of each iterator should be unique. |
670 | for (auto stage : sch->stages) { |
671 | if (stage->op->IsInstance<te::ComputeOpNode>()) { |
672 | std::unordered_set<std::string> names; |
673 | for (const auto& x : stage->leaf_iter_vars) { |
674 | ICHECK(!names.count(x->var->name_hint)) |
675 | << "Find duplicated iterator names in the compute definition: " << x->var->name_hint |
676 | << ". Please use different names for different iterators." ; |
677 | names.insert(x->var->name_hint); |
678 | } |
679 | } |
680 | } |
681 | } |
682 | |
683 | ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) { |
684 | auto node = make_object<ComputeDAGNode>(); |
685 | node->tensors = std::move(tensors); |
686 | node->access_analyzer = AccessAnalyzer(node->tensors); |
687 | |
688 | Array<te::Operation> out_ops; |
689 | for (const auto& op : node->access_analyzer->ops_topo_order) { |
690 | if (node->access_analyzer.IsOutput(op)) { |
691 | out_ops.push_back(op); |
692 | } |
693 | } |
694 | te::Schedule sch = te::create_schedule(out_ops); |
695 | for (auto stage : sch->stages) { |
696 | node->ops.push_back(stage->op); |
697 | } |
698 | |
699 | // Make sure it is a valid compute definition |
700 | CheckComputeValidity(sch); |
701 | |
702 | node->flop_ct = FlopEstimator().EstimateFlop(node->ops); |
703 | node->init_state = State(node->ops); |
704 | data_ = std::move(node); |
705 | } |
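// Minimal construction sketch (assuming te tensors A, B, C defined elsewhere,
// e.g. a matmul with inputs A, B and output C):
//   ComputeDAG dag({A, B, C});
//   // dag->ops follows the stage order of a fresh schedule, dag->flop_ct holds
//   // the estimated FLOP count, and dag->init_state is the unscheduled state.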
706 | |
707 | ComputeDAG::ComputeDAG(const te::Schedule& sch) { |
708 | auto node = make_object<ComputeDAGNode>(); |
709 | |
710 | // Make sure it is a valid compute definition |
711 | CheckComputeValidity(sch); |
712 | |
  // Initialize ops. Here we enforce that the order of ops and stages is consistent.
714 | for (auto stage : sch->stages) { |
715 | node->ops.push_back(stage->op); |
716 | } |
717 | |
718 | // Collect input and output tensors |
719 | Array<te::Tensor> tensors; |
720 | for (auto stage : sch->stages) { |
721 | if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) { |
722 | for (auto i = 0; i < stage->op->num_outputs(); ++i) { |
723 | tensors.push_back(stage->op.output(i)); |
724 | } |
725 | } |
726 | } |
727 | node->tensors = std::move(tensors); |
728 | node->access_analyzer = AccessAnalyzer(node->tensors); |
729 | node->flop_ct = FlopEstimator().EstimateFlop(node->ops); |
730 | node->init_state = State(node->ops); |
731 | data_ = std::move(node); |
732 | } |
733 | |
734 | class IndexRewriter : public StmtExprMutator { |
735 | public: |
736 | IndexRewriter(const te::Operation& placeholder_op, const std::string& new_layout) |
737 | : placeholder_op_(placeholder_op) { |
738 | ParseKernelLayout(new_layout, &new_shape_, &new_names_); |
739 | } |
740 | |
741 | PrimExpr Rewrite(PrimExpr expr) { return this->VisitExpr(expr); } |
742 | |
743 | PrimExpr VisitExpr_(const ProducerLoadNode* op) final { |
744 | te::Tensor t = Downcast<te::Tensor>(op->producer); |
745 | if (t->op == placeholder_op_) { |
746 | std::unordered_map<std::string, PrimExpr> name_to_arg; |
747 | for (const auto& arg : op->indices) { |
748 | std::string axis_name; |
749 | if (const auto* int_imm = arg.as<IntImmNode>()) { |
750 | ICHECK_EQ(int_imm->value, 0); |
751 | axis_name = "IntImm" ; |
752 | } else { |
753 | axis_name = AxisBaseName(CleanName(Downcast<Var>(arg)->name_hint)); |
754 | ICHECK_EQ(name_to_arg.count(axis_name), 0); |
755 | name_to_arg[axis_name] = arg; |
756 | } |
757 | } |
758 | |
759 | std::unordered_map<std::string, PrimExpr> div_factors; |
760 | std::vector<PrimExpr> r_new_args; |
761 | for (int i = new_names_.size() - 1; i >= 0; --i) { |
762 | auto ori_iter_name = new_names_[i]; |
763 | auto name_it = name_to_arg.find(ori_iter_name); |
764 | ICHECK(name_it != name_to_arg.end()); |
765 | PrimExpr ori_arg = name_it->second; |
766 | |
767 | PrimExpr mod_factor = new_shape_[i]; |
768 | |
769 | PrimExpr div_factor = 1; |
770 | if (div_factors.count(ori_iter_name)) { |
771 | div_factor = div_factors[ori_iter_name]; |
772 | } |
773 | div_factors[ori_iter_name] = div_factor * new_shape_[i]; |
774 | |
775 | PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); |
776 | |
777 | r_new_args.push_back(new_arg); |
778 | } |
779 | |
780 | Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()), |
781 | std::make_move_iterator(r_new_args.rend())); |
782 | return ProducerLoad(op->producer, new_args); |
783 | } |
784 | return GetRef<PrimExpr>(op); |
785 | } |
786 | |
787 | private: |
788 | const te::Operation& placeholder_op_; |
789 | Array<PrimExpr> new_shape_; |
790 | std::vector<std::string> new_names_; |
791 | }; |
792 | |
793 | std::string GetOrigLayout(std::set<std::string>* placeholder_axis_names, const te::Operation& op, |
794 | const te::Tensor& placeholder) { |
  ReadAccessExtractor extractor;
796 | for (const auto& exp : op.as<te::ComputeOpNode>()->body) { |
797 | extractor.Extract(exp); |
798 | } |
799 | |
800 | std::ostringstream os; |
801 | uint32_t i = 0; |
802 | const auto& placeholder_op = placeholder->op; |
803 | ICHECK_GT(extractor.read_access.count(placeholder_op), 0); |
804 | for (const auto& ev : extractor.read_access[placeholder_op]) { |
805 | for (const auto& e : ev) { |
806 | std::string axis_name; |
807 | if (const auto* int_imm = e.as<IntImmNode>()) { |
808 | ICHECK_EQ(int_imm->value, 0); |
809 | axis_name = "IntImm" ; |
810 | } else { |
811 | axis_name = AxisBaseName(CleanName(Downcast<Var>(e)->name_hint)); |
812 | } |
813 | |
814 | placeholder_axis_names->insert(axis_name); |
815 | os << placeholder->shape[i++] << axis_name; |
816 | } |
817 | } |
818 | |
819 | ICHECK_EQ(placeholder_axis_names->size(), placeholder->shape.size()); |
820 | std::string orig_layout = os.str(); |
821 | os.str("" ); |
822 | ::tvm::relay::AutoSchedulerLayoutRewriter::global_ori_layouts_queue.push_back(orig_layout); |
823 | return orig_layout; |
824 | } |
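// Illustrative example (hypothetical placeholder): if the placeholder has shape
// (16, 64) and the compute body reads it as B[j, k], the returned layout string
// is "16j64k" and placeholder_axis_names becomes {"j", "k"}.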
825 | |
826 | std::string GetNewLayout(const State& state, const int stage_id, const Stage& stage, |
827 | const te::Operation& op, const te::Tensor& placeholder, |
828 | const std::set<std::string>& placeholder_axis_names) { |
829 | std::ostringstream os; |
830 | Array<Iterator> stage_iters; |
831 | |
832 | auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); |
833 | int attach_pos = -1; |
834 | size_t iters_before_attach = 0; |
835 | if (attach_it != state->attach_map->stage_to_attach_iter.end()) { |
836 | auto attach = attach_it->second; |
837 | const auto& attach_stage = state->stages[attach.first]; |
838 | attach_pos = attach.second; |
839 | stage_iters.insert(stage_iters.end(), attach_stage->iters.begin(), |
840 | attach_stage->iters.begin() + attach_pos + 1); |
841 | } |
842 | |
843 | stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end()); |
844 | |
845 | std::vector<Iterator> iters; |
846 | for (size_t i = 0; i < stage_iters.size(); ++i) { |
847 | const auto& iter = stage_iters[i]; |
848 | if (iter->orig_iters.empty()) { |
849 | iters.push_back(iter); |
850 | } else { |
851 | for (const Iterator& ori_iter : iter->orig_iters) { |
852 | iters.push_back(ori_iter); |
853 | } |
854 | } |
855 | if (static_cast<int>(i) == attach_pos) { |
856 | iters_before_attach = iters.size(); |
857 | } |
858 | } |
859 | |
860 | std::vector<std::string> new_names; |
861 | std::vector<std::string> new_axis_names; |
862 | for (const Iterator& iter : iters) { |
863 | std::set<std::string> ori_iter_names; |
864 | ExtractOriginalIterators(iter->name, &ori_iter_names); |
865 | // fused iters have been replaced with iter->orig_iters. |
866 | // So there should be only one ori iter name extracted from iter->name. |
867 | ICHECK_EQ(ori_iter_names.size(), 1); |
868 | auto ori_iter_name = AxisBaseName(*ori_iter_names.begin()); |
869 | new_axis_names.push_back(ori_iter_name); |
870 | } |
871 | for (size_t i = 0; i < new_axis_names.size(); ++i) { |
872 | auto iter = iters[i]; |
873 | std::string ori_iter_name; |
874 | if (i < iters_before_attach) { |
875 | ori_iter_name = new_axis_names[i + iters_before_attach]; |
876 | } else { |
877 | ori_iter_name = new_axis_names[i]; |
878 | } |
879 | if (placeholder_axis_names.count(ori_iter_name)) { |
880 | PrimExpr extent; |
881 | if (iter->range.defined()) { |
882 | extent = iter->range->extent; |
883 | } else { |
884 | // This iter is simplified by InferBound, so it must have a length of one. |
885 | extent = 1; |
886 | } |
887 | os << extent << ori_iter_name; |
888 | new_names.push_back(ori_iter_name); |
889 | } |
890 | } |
891 | std::string new_layout = os.str(); |
892 | os.str("" ); |
893 | ::tvm::relay::AutoSchedulerLayoutRewriter::global_new_layouts_queue.push_back(new_layout); |
894 | return new_layout; |
895 | } |
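// Illustrative example (hypothetical state): if the stage's iterators after
// splitting are j.0 (extent 4), k.0 (extent 16), and k.1 (extent 4), and both "j"
// and "k" are axes of the placeholder, the returned layout string is "4j16k4k".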
896 | |
897 | ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps, |
898 | LayoutRewriteOption layout_rewrite) const { |
899 | CHECK(layout_rewrite != LayoutRewriteOption::NoRewrite) |
900 | << "Call ComputeDAG::RewriteLayout with NoRewrite." ; |
901 | ComputeDAG new_dag = *this; |
902 | ComputeDAGNode* p_dag = new_dag.CopyOnWrite(); |
903 | |
904 | auto node = make_object<StateNode>(); |
905 | node->transform_steps = *transform_steps; |
906 | node->concrete = true; |
907 | const State& state = InferBound(State(node)); |
908 | |
909 | OperationSet handled_ops; |
910 | for (size_t stage_id = 0; stage_id < state->stages.size(); stage_id++) { |
911 | const auto& stage = state->stages[stage_id]; |
912 | |
913 | const te::Operation& op = stage->op; |
914 | if (!op->IsInstance<te::ComputeOpNode>()) { |
915 | continue; |
916 | } |
917 | const Map<String, ObjectRef>& attrs = op->attrs; |
918 | if (attrs.count(layout_free_placeholders_key) == 0) { |
919 | continue; |
920 | } |
921 | const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; |
922 | for (const auto& placeholder : Downcast<Array<te::Tensor>>(attr_value)) { |
923 | const auto& placeholder_op = placeholder->op; |
924 | |
925 | // Check whether this placeholder has already been handled |
926 | if (handled_ops.count(placeholder_op)) { |
927 | continue; |
928 | } |
929 | // Skip the op that is not direct consumer of this placeholder. |
930 | // This is usually caused by cache read/write. |
931 | bool direct_consumer = false; |
932 | for (auto& t : op->InputTensors()) { |
933 | if (t->op == placeholder_op) { |
934 | direct_consumer = true; |
935 | break; |
936 | } |
937 | } |
938 | if (!direct_consumer) { |
939 | continue; |
940 | } |
941 | handled_ops.insert(placeholder_op); |
942 | |
943 | // Process original layout |
944 | std::set<std::string> placeholder_axis_names; |
945 | std::string origin_layout = GetOrigLayout(&placeholder_axis_names, op, placeholder); |
946 | Array<PrimExpr> origin_shape; |
947 | std::vector<std::string> origin_axes; |
948 | ParseKernelLayout(origin_layout, &origin_shape, &origin_axes); |
949 | |
950 | // Process new layout |
951 | std::string new_layout = |
952 | GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names); |
953 | Array<PrimExpr> new_shape; |
954 | std::vector<std::string> new_axes; |
955 | ParseKernelLayout(new_layout, &new_shape, &new_axes); |
956 | |
957 | // Process op updates |
958 | te::Operation new_op_to_update; |
959 | if (layout_rewrite == LayoutRewriteOption::RewriteForPreTransformed) { |
960 | // Create new placeholder |
961 | new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape, |
962 | placeholder_op.as<te::PlaceholderOpNode>()->dtype); |
963 | } else if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) { |
964 | // Process index strides |
965 | std::unordered_map<std::string, PrimExpr> axes_stride; |
966 | for (const auto& i : origin_axes) { |
967 | axes_stride[i] = Integer(1); |
968 | } |
969 | Array<PrimExpr> new_stride(new_shape.size(), PrimExpr()); |
970 | PrimExpr temp = Integer(1); |
971 | for (int i = new_shape.size() - 1; i >= 0; i--) { |
972 | new_stride.Set(i, axes_stride[new_axes[i]]); |
973 | axes_stride[new_axes[i]] *= new_shape[i]; |
974 | } |
975 | |
976 | // Add an extra layout transform stage |
977 | const auto& layout_transform_tensor = te::compute( |
978 | new_shape, |
979 | [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes, |
980 | &new_axes](const tvm::runtime::Array<tvm::tir::Var>& indices) -> tvm::PrimExpr { |
981 | Array<PrimExpr> access_indices; |
982 | for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) { |
983 | PrimExpr temp = Integer(0); |
984 | for (size_t i = 0; i < new_shape.size(); i++) { |
985 | if (origin_axes[indice_index].compare(new_axes[i]) == 0) { |
986 | temp += indices[i] * new_stride[i]; |
987 | } |
988 | } |
989 | access_indices.push_back(temp); |
990 | } |
991 | return placeholder_op.output(0)(access_indices); |
992 | }, |
993 | "auto_scheduler_layout_transform" ); |
994 | new_op_to_update = layout_transform_tensor->op; |
995 | |
996 | // Update the transform steps |
997 | for (size_t i = 0; i < transform_steps->size(); i++) { |
998 | Step step = (*transform_steps)[i]; |
999 | if (step->stage_id >= static_cast<int>(stage_id)) { |
1000 | step.CopyOnWrite()->stage_id++; |
1001 | } |
1002 | if (step->IsInstance<ComputeAtStepNode>()) { |
1003 | auto compute_at_step = tvm::Downcast<ComputeAtStep>(step); |
1004 | if (compute_at_step->target_stage_id >= static_cast<int>(stage_id)) { |
1005 | dynamic_cast<ComputeAtStepNode*>(compute_at_step.CopyOnWrite())->target_stage_id++; |
1006 | } |
1007 | transform_steps->Set(i, std::move(compute_at_step)); |
1008 | } else { |
1009 | transform_steps->Set(i, std::move(step)); |
1010 | } |
1011 | } |
1012 | |
1013 | // Add schedule for the new added transform stage |
1014 | Array<Integer> to_fuse; |
1015 | |
1016 | if (new_shape.size() >= 5) { |
1017 | to_fuse.push_back(0); |
1018 | to_fuse.push_back(1); |
1019 | to_fuse.push_back(2); |
1020 | transform_steps->push_back(FuseStep(stage_id, to_fuse)); |
1021 | } else if (new_shape.size() >= 3) { |
1022 | to_fuse.push_back(0); |
1023 | to_fuse.push_back(1); |
1024 | transform_steps->push_back(FuseStep(stage_id, to_fuse)); |
1025 | } |
1026 | transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel)); |
1027 | } |
1028 | |
1029 | te::Operation new_compute_op, original_compute_op; |
1030 | Array<PrimExpr> new_body; |
1031 | IndexRewriter index_rewriter(placeholder_op, new_layout); |
1032 | for (const auto& op : p_dag->ops) { |
1033 | if (auto* pop = op.as<te::ComputeOpNode>()) { |
1034 | bool need_update = false; |
1035 | for (auto& t : op->InputTensors()) { |
1036 | if (t->op == placeholder_op) { |
1037 | need_update = true; |
1038 | break; |
1039 | } |
1040 | } |
1041 | if (need_update) { |
1042 | for (const auto& body : pop->body) { |
1043 | new_body.push_back(index_rewriter.Rewrite(body)); |
1044 | } |
1045 | original_compute_op = op; |
1046 | CHECK(!new_compute_op.defined()); |
1047 | auto new_attrs = pop->attrs; |
1048 | new_attrs.Set("ori_placeholder_layout" , tvm::String(origin_layout)); |
1049 | new_attrs.Set("new_placeholder_layout" , tvm::String(new_layout)); |
1050 | new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body); |
1051 | } |
1052 | } |
1053 | } |
1054 | |
1055 | // construct the map from original_op to new_op |
1056 | std::unordered_map<te::Operation, te::Operation> updated_ops; |
1057 | |
1058 | Array<te::Operation> original_ops = p_dag->ops; |
1059 | p_dag->ops.clear(); |
1060 | for (size_t i = 0; i < original_ops.size(); ++i) { |
1061 | const auto& original_op = original_ops[i]; |
1062 | if (original_op == placeholder_op) { |
1063 | if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) { |
1064 | p_dag->ops.push_back(placeholder_op); |
1065 | } |
1066 | p_dag->ops.push_back(new_op_to_update); |
1067 | updated_ops[placeholder_op] = new_op_to_update; |
1068 | } else if (original_op == original_compute_op) { |
1069 | p_dag->ops.push_back(new_compute_op); |
1070 | updated_ops[original_compute_op] = new_compute_op; |
1071 | } else { |
1072 | p_dag->ops.push_back(original_op); |
1073 | } |
1074 | } |
1075 | |
1076 | ArrayNode* pops = p_dag->ops.CopyOnWrite(); |
      // Because the ops are sorted in topological order, a single linear pass suffices here.
1078 | for (size_t i = 0; i < pops->size(); ++i) { |
1079 | const auto& original_op = Downcast<te::Operation>(pops->at(i)); |
1080 | if (auto* pop = original_op.as<te::ComputeOpNode>()) { |
1081 | if (original_op == new_op_to_update) { |
1082 | continue; |
1083 | } |
1084 | auto inputs = pop->InputTensors(); |
1085 | std::unordered_map<te::Tensor, te::Tensor> rmap; |
1086 | for (auto input : inputs) { |
1087 | auto it = updated_ops.find(input->op); |
1088 | te::Operation new_op; |
1089 | while (it != updated_ops.end()) { |
1090 | new_op = it->second; |
1091 | it = updated_ops.find(new_op); |
1092 | } |
1093 | if (new_op.defined()) { |
1094 | int index = input->value_index; |
1095 | rmap[input] = new_op.output(index); |
1096 | } |
1097 | } |
1098 | if (!rmap.empty()) { |
1099 | te::Operation new_op = pop->ReplaceInputs(original_op, rmap); |
1100 | updated_ops[original_op] = new_op; |
1101 | pops->SetItem(i, new_op); |
1102 | } |
1103 | } |
1104 | } |
1105 | |
1106 | Array<te::Tensor> old_tensors = p_dag->tensors; |
1107 | ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite(); |
1108 | for (size_t i = 0; i < old_tensors.size(); ++i) { |
1109 | const auto& old_tensor = old_tensors[i]; |
1110 | if (layout_rewrite != LayoutRewriteOption::RewriteForPreTransformed && |
1111 | old_tensor->op->IsInstance<te::PlaceholderOpNode>()) { |
1112 | continue; |
1113 | } |
1114 | auto it = updated_ops.find(old_tensor->op); |
1115 | te::Operation new_op; |
1116 | while (it != updated_ops.end()) { |
1117 | new_op = it->second; |
1118 | it = updated_ops.find(new_op); |
1119 | } |
1120 | if (new_op.defined()) { |
1121 | auto index = old_tensor->value_index; |
1122 | p_tensors->SetItem(i, new_op.output(index)); |
1123 | } |
1124 | } |
1125 | } // end for placeholder |
1126 | } // end for stage |
1127 | p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors); |
1128 | |
1129 | Array<te::Operation> out_ops; |
1130 | for (const auto& op : p_dag->access_analyzer->ops_topo_order) { |
1131 | if (p_dag->access_analyzer.IsOutput(op)) { |
1132 | out_ops.push_back(op); |
1133 | } |
1134 | } |
1135 | |
1136 | p_dag->ops.clear(); |
1137 | te::Schedule sch = te::create_schedule(out_ops); |
1138 | for (auto stage : sch->stages) { |
1139 | p_dag->ops.push_back(stage->op); |
1140 | } |
1141 | p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops); |
1142 | p_dag->init_state = State(p_dag->ops); |
1143 | |
1144 | return new_dag; |
1145 | } |
1146 | |
1147 | // Return whether a DAG has placeholders that are marked as "layout free". |
1148 | bool HasLayoutFreeTensors(const ComputeDAG& dag) { |
1149 | for (const auto& op : dag->ops) { |
1150 | if (!op->IsInstance<te::ComputeOpNode>()) { |
1151 | continue; |
1152 | } |
1153 | if (op->attrs.count(ComputeDAG::layout_free_placeholders_key)) { |
1154 | return true; |
1155 | } |
1156 | } |
1157 | |
1158 | return false; |
1159 | } |
1160 | |
1161 | std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps( |
1162 | const Array<Step>& transform_steps, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes, |
1163 | LayoutRewriteOption layout_rewrite) const { |
1164 | if (layout_rewrite != LayoutRewriteOption::NoRewrite && HasLayoutFreeTensors(*this) && |
1165 | !transform_steps.empty()) { |
1166 | Array<Step> steps = transform_steps; |
1167 | const auto& dag = RewriteLayout(&steps, layout_rewrite); |
1168 | return dag.ApplySteps(steps); |
1169 | } |
1170 | |
  // Temporary objects to be used if the input pointers are nullptr
1172 | Array<te::Stage> temp_stages; |
1173 | StageToAxesMap temp_stage_to_axes; |
1174 | if (stages == nullptr) { |
1175 | stages = &temp_stages; |
1176 | } |
1177 | if (stage_to_axes == nullptr) { |
1178 | stage_to_axes = &temp_stage_to_axes; |
1179 | } |
1180 | Array<te::Operation> out_ops; |
1181 | for (const auto& op : operator->()->ops) { |
1182 | if (operator->()->access_analyzer.IsOutput(op)) { |
1183 | out_ops.push_back(op); |
1184 | } |
1185 | } |
1186 | |
1187 | // Create the initial schedule |
1188 | te::Schedule schedule = te::create_schedule(out_ops); |
1189 | |
1190 | // init axes |
1191 | for (const auto& x : operator->()->ops) { |
1192 | const te::Stage& stage = schedule[x]; |
1193 | stages->push_back(stage); |
1194 | UpdateStageToAxesMap(stage, stage_to_axes); |
1195 | } |
1196 | |
1197 | // Apply the history steps to TVM schedule |
1198 | // Call each step's ApplyToSchedule method |
1199 | for (const auto& step : transform_steps) { |
1200 | StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps); |
1201 | } |
1202 | |
1203 | return std::make_pair(schedule, operator->()->tensors); |
1204 | } |
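// Minimal usage sketch (assuming a concrete state produced by the search):
//   auto [sch, args] = dag.ApplySteps(state->transform_steps);
//   // sch and args can then be lowered and built like a hand-written TE schedule.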
1205 | |
1206 | String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const { |
1207 | Array<te::Stage> stages; |
1208 | StageToAxesMap stage_to_axes; |
1209 | Array<te::Operation> out_ops; |
1210 | for (const auto& op : operator->()->ops) { |
1211 | if (operator->()->access_analyzer.IsOutput(op)) { |
1212 | out_ops.push_back(op); |
1213 | } |
1214 | } |
1215 | // Create the initial schedule |
1216 | te::Schedule schedule = te::create_schedule(out_ops); |
1217 | |
1218 | // init axes |
1219 | for (const auto& x : operator->()->ops) { |
1220 | const te::Stage& stage = schedule[x]; |
1221 | stages.push_back(stage); |
1222 | UpdateStageToAxesMap(stage, &stage_to_axes); |
1223 | } |
1224 | |
1225 | std::stringstream ss; |
1226 | for (const auto& stage : stages) { |
1227 | if (stage->op->IsInstance<te::ComputeOpNode>()) { |
1228 | auto op_name = CleanName(stage->op->name); |
1229 | |
1230 | for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { |
1231 | ss << CleanName(stage->leaf_iter_vars[i]->var->name_hint, op_name); |
1232 | if (i != stage->leaf_iter_vars.size() - 1) { |
1233 | ss << ", " ; |
1234 | } |
1235 | } |
1236 | ss << " = " |
1237 | << "tuple(" << op_name << ".op.axis)" |
1238 | << " + " |
1239 | << "tuple(" << op_name << ".op.reduce_axis)\n" ; |
1240 | } |
1241 | } |
1242 | // Call each step's PrintAsPythonAPI method |
1243 | for (const auto& step : transform_steps) { |
1244 | ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps); |
1245 | } |
1246 | |
1247 | return ss.str(); |
1248 | } |
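// Illustrative output sketch (hypothetical stage named C with axes i, j and a
// reduce axis k), before the per-step statements:
//   i, j, k = tuple(C.op.axis) + tuple(C.op.reduce_axis)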
1249 | |
1250 | String ComputeDAG::PrintDAG(bool simple_mode) const { |
1251 | std::stringstream ss; |
1252 | |
1253 | for (const auto& op : operator->()->ops) { |
1254 | if (op->IsInstance<te::PlaceholderOpNode>()) { |
1255 | ss << op->name << " = PLACEHOLDER " ; |
1256 | if (!simple_mode) { |
1257 | ss << op.output(0)->shape; |
1258 | } |
1259 | ss << "\n" ; |
1260 | } else if (auto pop = op.as<te::ComputeOpNode>()) { |
1261 | for (size_t k = 0; k < pop->body.size(); ++k) { |
1262 | ss << op->name << "(" ; |
1263 | for (size_t i = 0; i < pop->axis.size(); i++) { |
1264 | ss << pop->axis[i]->var->name_hint; |
1265 | if (i != pop->axis.size() - 1) { |
1266 | ss << ", " ; |
1267 | } |
1268 | } |
1269 | ss << ")" ; |
1270 | if (pop->body.size() > 1) { |
1271 | ss << ".v" << k; |
1272 | } |
1273 | if (auto p_reduce = pop->body[k].as<ReduceNode>()) { |
1274 | ICHECK_LT(k, p_reduce->combiner->result.size()); |
1275 | PrimExpr combiner = p_reduce->combiner->result[k]; |
1276 | if (combiner->IsInstance<AddNode>()) { |
1277 | ss << " += " << AsLegacyRepr(p_reduce->source[0]) << "\n" ; |
1278 | } else if (combiner->IsInstance<MaxNode>()) { |
1279 | ss << " max= " << AsLegacyRepr(p_reduce->source[0]) << "\n" ; |
1280 | } else if (combiner->IsInstance<MinNode>()) { |
1281 | ss << " min= " << AsLegacyRepr(p_reduce->source[0]) << "\n" ; |
1282 | } else if (combiner->IsInstance<SelectNode>()) { |
1283 | const auto& select = combiner.as<SelectNode>(); |
1284 | ss << " select(" << AsLegacyRepr(select->condition) // |
1285 | << ", " << AsLegacyRepr(select->true_value) // |
1286 | << ", " << AsLegacyRepr(select->false_value) // |
1287 | << ")= (" << AsLegacyRepr(p_reduce->source[0]) // |
1288 | << ',' << AsLegacyRepr(p_reduce->source[1]) // |
1289 | << ")\n" ; |
1290 | } else { |
1291 | ss << "reduce" << AsLegacyRepr(combiner) << "\n" ; |
1292 | } |
1293 | } else { |
1294 | auto call = pop->body[k].as<CallNode>(); |
1295 | if (simple_mode && call) { |
1296 | ss << " = " << AsLegacyRepr(call->op) << "\n" ; |
1297 | } else { |
1298 | ss << " = " << AsLegacyRepr(pop->body[k]) << "\n" ; |
1299 | } |
1300 | } |
1301 | } |
1302 | } else { |
1303 | LOG(FATAL) << "Invalid op" ; |
1304 | } |
1305 | } |
1306 | return String(ss.str()); |
1307 | } |
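// Illustrative output sketch (hypothetical matmul DAG):
//   A = PLACEHOLDER [512, 512]
//   B = PLACEHOLDER [512, 512]
//   C(i, j) += (A[i, k]*B[k, j])
// In simple_mode, the placeholder shapes are omitted.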
1308 | |
1309 | State ComputeDAG::InferBound(const State& state) const { |
1310 | ICHECK(state->concrete) << "Only concrete state can be processed to get bound info." ; |
1311 | |
1312 | State ret_state; |
1313 | StateNode* pstate; |
1314 | |
1315 | if (state->stages.empty()) { |
    // If the input state is incomplete (i.e. it has no operation stages),
    // create a new state from init_state and update it first.
1318 | ret_state = operator->()->init_state; |
1319 | pstate = ret_state.CopyOnWrite(); |
1320 | pstate->transform_steps = state->transform_steps; |
1321 | for (const auto& step : pstate->transform_steps) { |
1322 | StepApplyToState(step, &ret_state, *this); |
1323 | } |
1324 | } else { |
1325 | ret_state = state; |
1326 | pstate = ret_state.CopyOnWrite(); |
1327 | } |
1328 | |
1329 | Array<te::Stage> stages; |
1330 | StageToAxesMap stage_to_axes; |
1331 | // Replay steps to tvm::Schedule |
1332 | auto [sch, tensors] = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes); |
1333 | (void)tensors; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 |
1334 | sch = sch.normalize_for_feature_extraction(); |
1335 | // Get bound information from TVM schedule |
1336 | Map<IterVar, Range> bounds = te::InferBound(sch); |
1337 | |
1338 | // Update the state bound information |
1339 | for (size_t i = 0; i < pstate->stages.size(); ++i) { |
1340 | const Stage& stage = pstate->stages[i]; |
1341 | |
1342 | if (stage->compute_at == ComputeAtKind::kInlined) { |
1343 | continue; |
1344 | } |
1345 | |
1346 | Array<Iterator> new_iters; |
1347 | new_iters.reserve(stage->iters.size()); |
1348 | // Get bound information from schedule |
1349 | // the StageToAxesMap is used to find the corresponding IterVar in TVM schedule result |
1350 | for (size_t j = 0; j < stage->iters.size(); ++j) { |
1351 | const Iterator& iter = stage->iters[j]; |
1352 | const IterVar& axis = stage_to_axes.at(stages[i])[j]; |
1353 | |
1354 | auto find_res = bounds.find(axis); |
1355 | if (find_res != bounds.end()) { |
1356 | new_iters.push_back(Iterator(iter->name, (*find_res).second, iter->iter_kind, |
1357 | iter->annotation, &iter->orig_iters)); |
1358 | } else { |
1359 | LOG(FATAL) << "Infer bound fails" ; |
1360 | } |
1361 | } |
1362 | |
1363 | pstate->stages.Set( |
1364 | i, Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs)); |
1365 | } |
1366 | |
1367 | return ret_state; |
1368 | } |
1369 | |
1370 | Array<State> ComputeDAG::InferBound(const Array<State>& states) const { |
1371 | Array<State> out_states(states.size(), State()); |
1372 | |
1373 | support::parallel_for(0, states.size(), [this, &states, &out_states](int i) { |
1374 | try { |
1375 | out_states.Set(i, (states[i].defined()) ? this->InferBound(states[i]) : states[i]); |
1376 | } catch (Error& e) { |
1377 | LOG(WARNING) << "InferBound fails on the state:\n" |
1378 | << states[i] << "\n" |
1379 | << "with: " << e.what() << std::endl; |
1380 | } |
1381 | }); |
1382 | |
1383 | return out_states; |
1384 | } |
1385 | |
1386 | ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array<Step>& transform_steps) const { |
1387 | auto [sch, old_tensors] = ApplySteps(transform_steps); |
1388 | (void)old_tensors; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 |
1389 | return ComputeDAG(sch); |
1390 | } |
1391 | |
1392 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
1393 | .set_dispatch<AccessAnalyzerNode>([](const ObjectRef& ref, ReprPrinter* p) { |
1394 | auto* node = static_cast<const AccessAnalyzerNode*>(ref.get()); |
1395 | for (const auto& op : node->ops_topo_order) { |
1396 | p->stream << op << std::endl; |
1397 | p->stream << "is_simple_access:\t" << node->is_simple_access.at(op) << "\t\t" ; |
1398 | p->stream << "needs_multi_level_tiling:\t" << node->needs_multi_level_tiling.at(op) |
1399 | << std::endl; |
1400 | p->stream << "is_strictly_inlinable:\t" << node->is_strictly_inlineable.at(op) << "\t" ; |
1401 | p->stream << "is_output:\t" << node->is_output.at(op) << std::endl; |
1402 | p->stream << "Read from:\t" ; |
1403 | for (const auto& pair : node->read_from.at(op)) { |
1404 | for (const auto& index : pair.second) { |
1405 | p->stream << pair.first->name << Array<PrimExpr>(index) << ", " ; |
1406 | } |
1407 | } |
1408 | p->stream << std::endl; |
1409 | p->stream << "Read by:\t" ; |
1410 | for (const auto& pair : node->read_by.at(op)) { |
1411 | for (const auto& index : pair.second) { |
1412 | p->stream << pair.first->name << Array<PrimExpr>(index) << ", " ; |
1413 | } |
1414 | } |
1415 | p->stream << std::endl; |
1416 | p->stream << Chars('=', 50) << std::endl; |
1417 | } |
1418 | |
1419 | AccessAnalyzer ana = GetRef<AccessAnalyzer>(node); |
1420 | p->stream << "ElementwiseMatch: \n" ; |
1421 | for (size_t i = 0; i < node->ops_topo_order.size(); ++i) { |
1422 | for (size_t j = 0; j < node->ops_topo_order.size(); ++j) { |
1423 | if (i == j) { |
1424 | continue; |
1425 | } |
1426 | if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) { |
1427 | p->stream << node->ops_topo_order[i]->name << " -> " << node->ops_topo_order[j]->name |
1428 | << std::endl; |
1429 | } |
1430 | } |
1431 | } |
1432 | p->stream << Chars('=', 50) << std::endl; |
1433 | |
1434 | p->stream << "NumCommonOuterIterators: \n" ; |
1435 | for (const auto& src_pair : node->num_common_outer_iterators) { |
1436 | for (const auto& dst_pair : src_pair.second) { |
1437 | p->stream << src_pair.first->name << " " << dst_pair.first->name << " " << dst_pair.second |
1438 | << std::endl; |
1439 | } |
1440 | } |
1441 | p->stream << Chars('=', 50) << std::endl; |
1442 | }); |
1443 | |
1444 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
1445 | .set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter* p) { |
1446 | auto* node = static_cast<const ComputeDAGNode*>(ref.get()); |
1447 | auto dag = GetRef<ComputeDAG>(node); |
1448 | auto dag_str = dag.PrintDAG(); |
1449 | p->stream << dag_str; |
1450 | }); |
1451 | |
1452 | Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names) { |
1453 | Array<PrimExpr> shape; |
  std::vector<std::string> extracted_names;
1455 | topi::parse_auto_scheduler_layout(rewritten_layout, &shape, &extracted_names); |
1456 | |
1457 | Array<PrimExpr> ret(axis_names.size(), 1); |
1458 | |
1459 | size_t ct = 0; |
1460 | for (size_t i = 0; i < axis_names.size(); ++i) { |
1461 | for (size_t j = 0; j < extracted_names.size(); ++j) { |
1462 | if (axis_names[i] == extracted_names[j]) { |
1463 | ret.Set(i, ret[i] * shape[j]); |
1464 | ct++; |
1465 | } |
1466 | } |
1467 | } |
1468 | |
1469 | CHECK_EQ(ct, extracted_names.size()) << "The number or names of axes do not match" ; |
1470 | |
1471 | return ret; |
1472 | } |
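// Illustrative example: GetShapeFromRewrittenLayout("2i4j8i", {"i", "j"}) parses
// factors (2, 4, 8) with names ("i", "j", "i"), multiplies the factors belonging
// to each requested axis, and returns [16, 4].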
1473 | |
1474 | TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG" ) |
1475 | .set_body_typed([](Optional<Array<te::Tensor>> tensors, Optional<te::Schedule> sch) { |
1476 | if (sch) { |
1477 | return ComputeDAG(sch.value()); |
1478 | } |
1479 | ICHECK(tensors) << "Both tensors and schedule are null" ; |
1480 | return ComputeDAG(tensors.value()); |
1481 | }); |
1482 | |
1483 | TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState" ) |
1484 | .set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) { |
1485 | auto [sch, return_tensors] = dag.ApplySteps(state->transform_steps, nullptr, nullptr, |
1486 | static_cast<LayoutRewriteOption>(layout_rewrite)); |
1487 | return Array<ObjectRef>{sch, return_tensors}; |
1488 | }); |
1489 | |
1490 | TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintPythonCodeFromState" ) |
1491 | .set_body_typed([](const ComputeDAG& dag, const State& state) { |
1492 | return dag.PrintStepsAsPython(state->transform_steps); |
1493 | }); |
1494 | |
1495 | TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintDAG" ) |
1496 | .set_body_typed([](const ComputeDAG& dag, bool simple_mode) { |
1497 | return dag.PrintDAG(simple_mode); |
1498 | }); |
1499 | |
1500 | TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGInferBoundFromState" ) |
1501 | .set_body_typed([](const ComputeDAG& dag, const State& state) { |
1502 | return dag.InferBound(state); |
1503 | }); |
1504 | |
1505 | TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGRewriteLayoutFromState" ) |
1506 | .set_body_typed([](const ComputeDAG& dag, const State& state) { |
1507 | Array<Step>* transform_steps = const_cast<Array<Step>*>(&state->transform_steps); |
1508 | return dag.RewriteLayout(transform_steps, LayoutRewriteOption::RewriteForPreTransformed); |
1509 | }); |
1510 | |
1511 | TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout" ) |
1512 | .set_body_typed([](const te::Operation& placeholder_op, const std::string& new_layout, |
1513 | const PrimExpr& body) { |
1514 | IndexRewriter index_rewriter(placeholder_op, new_layout); |
1515 | return index_rewriter.Rewrite(body); |
1516 | }); |
1517 | |
1518 | TVM_REGISTER_GLOBAL("auto_scheduler.RewriteTensorShape" ) |
1519 | .set_body_typed([](te::Tensor tensor, Array<PrimExpr> new_shape) -> void { |
1520 | ICHECK(tensor->op->IsInstance<te::PlaceholderOpNode>()); |
1521 | te::PlaceholderOpNode* op = |
1522 | const_cast<te::PlaceholderOpNode*>(tensor->op.as<te::PlaceholderOpNode>()); |
1523 | te::TensorNode* t = const_cast<te::TensorNode*>(tensor.get()); |
1524 | op->shape = new_shape; |
1525 | t->shape = new_shape; |
1526 | }); |
1527 | |
1528 | TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout" ) |
1529 | .set_body_typed(GetShapeFromRewrittenLayout); |
1530 | |
1531 | } // namespace auto_scheduler |
1532 | } // namespace tvm |
1533 | |