1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/partition_spec.cc |
22 | * \brief Combine a \p PartitionRule with a \p Target. |
23 | */ |
24 | |
25 | #include "./partition_spec.h" |
26 | |
27 | #include "./utils.h" |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | namespace collage { |
32 | |
33 | String DefaultValidateSubGraphFunc(const Function& function) { return String(); } |
34 | |
// Registers PartitionSpecNode via the TVM_REGISTER_NODE_TYPE macro.
// NOTE(review): the macro is defined outside this file; presumably it hooks the node into TVM's
// type/reflection registry -- confirm before relying on any specific side effect.
TVM_REGISTER_NODE_TYPE(PartitionSpecNode);
36 | |
37 | void PartitionSpecNode::VisitAttrs(AttrVisitor* v) { |
38 | // TODO(mbs) |
39 | } |
40 | |
41 | std::vector<CandidatePartition> PartitionSpecNode::AllCandidates( |
42 | const DataflowGraph& dataflow_graph) const { |
43 | std::vector<CandidatePartition> result; |
44 | // Make sure the target is in scope for inspection by any predicates in |
45 | // DFPatternPartitionRuleNode rules. |
46 | With<Target> target_scope(target_); |
47 | // Gather all the candidates. |
48 | std::vector<CandidatePartition> candidates = |
49 | rule_->AllCandidates(dataflow_graph, GetRef<PartitionSpec>(this)); |
50 | // Update the rules names. |
51 | for (const auto& candidate : candidates) { |
52 | ICHECK_EQ(candidate->spec_, GetRef<PartitionSpec>(this)); |
53 | String rule_name = NestLabels(spec_name_, candidate->rule_name_); |
54 | CandidatePartition new_candidate = WithRuleName(candidate, std::move(rule_name)); |
55 | result.emplace_back(std::move(new_candidate)); |
56 | } |
57 | return result; |
58 | } |
59 | |
60 | std::string PartitionSpecNode::ToString() const { |
61 | Doc doc; |
62 | doc << "PartitionSpec(" << Doc::NewLine(2); |
63 | std::vector<Doc> body_items; |
64 | body_items.emplace_back(); |
65 | body_items.back() << "spec_name=" << Doc::StrLiteral(spec_name_); |
66 | body_items.emplace_back(); |
67 | body_items.back() << "target=" << target_->ToDebugString(); |
68 | body_items.emplace_back(); |
69 | body_items.back() << "rule=" << rule_->ToDoc(); |
70 | doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine(); |
71 | doc << ")" ; |
72 | return doc.str(); |
73 | } |
74 | |
75 | PartitionSpec::PartitionSpec(String spec_name, Target target, PartitionRule rule, |
76 | TValidateSubGraphFunc validate_sub_graph_func) { |
77 | auto node = runtime::make_object<PartitionSpecNode>(); |
78 | node->spec_name_ = std::move(spec_name); |
79 | node->target_ = std::move(target); |
80 | node->rule_ = std::move(rule); |
81 | node->validate_sub_graph_func_ = std::move(validate_sub_graph_func); |
82 | data_ = std::move(node); |
83 | } |
84 | |
85 | } // namespace collage |
86 | } // namespace relay |
87 | } // namespace tvm |
88 | |