1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/partition_spec.cc
22 * \brief Combine a \p PartitionRule with a \p Target.
23 */
24
25#include "./partition_spec.h"
26
27#include "./utils.h"
28
29namespace tvm {
30namespace relay {
31namespace collage {
32
33String DefaultValidateSubGraphFunc(const Function& function) { return String(); }
34
// Register PartitionSpecNode with TVM's node type registry.
TVM_REGISTER_NODE_TYPE(PartitionSpecNode);
36
37void PartitionSpecNode::VisitAttrs(AttrVisitor* v) {
38 // TODO(mbs)
39}
40
41std::vector<CandidatePartition> PartitionSpecNode::AllCandidates(
42 const DataflowGraph& dataflow_graph) const {
43 std::vector<CandidatePartition> result;
44 // Make sure the target is in scope for inspection by any predicates in
45 // DFPatternPartitionRuleNode rules.
46 With<Target> target_scope(target_);
47 // Gather all the candidates.
48 std::vector<CandidatePartition> candidates =
49 rule_->AllCandidates(dataflow_graph, GetRef<PartitionSpec>(this));
50 // Update the rules names.
51 for (const auto& candidate : candidates) {
52 ICHECK_EQ(candidate->spec_, GetRef<PartitionSpec>(this));
53 String rule_name = NestLabels(spec_name_, candidate->rule_name_);
54 CandidatePartition new_candidate = WithRuleName(candidate, std::move(rule_name));
55 result.emplace_back(std::move(new_candidate));
56 }
57 return result;
58}
59
60std::string PartitionSpecNode::ToString() const {
61 Doc doc;
62 doc << "PartitionSpec(" << Doc::NewLine(2);
63 std::vector<Doc> body_items;
64 body_items.emplace_back();
65 body_items.back() << "spec_name=" << Doc::StrLiteral(spec_name_);
66 body_items.emplace_back();
67 body_items.back() << "target=" << target_->ToDebugString();
68 body_items.emplace_back();
69 body_items.back() << "rule=" << rule_->ToDoc();
70 doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine();
71 doc << ")";
72 return doc.str();
73}
74
75PartitionSpec::PartitionSpec(String spec_name, Target target, PartitionRule rule,
76 TValidateSubGraphFunc validate_sub_graph_func) {
77 auto node = runtime::make_object<PartitionSpecNode>();
78 node->spec_name_ = std::move(spec_name);
79 node->target_ = std::move(target);
80 node->rule_ = std::move(rule);
81 node->validate_sub_graph_func_ = std::move(validate_sub_graph_func);
82 data_ = std::move(node);
83}
84
85} // namespace collage
86} // namespace relay
87} // namespace tvm
88