1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/partition_spec.h
22 * \brief Combine a \p PartitionRule with a \p Target.
23 */
24
25#ifndef TVM_RELAY_COLLAGE_PARTITION_SPEC_H_
26#define TVM_RELAY_COLLAGE_PARTITION_SPEC_H_
27
28#include <tvm/relay/function.h>
29#include <tvm/runtime/container/string.h>
30#include <tvm/target/target.h>
31
32#include <string>
33#include <vector>
34
35#include "./partition_rule.h"
36#include "./sub_graph.h"
37
38namespace tvm {
39namespace relay {
40namespace collage {
41
42/*!
43 * \brief Type of functions for checking the validity of partitions before they proceed to lowering
44 * and codegen. The argument is the function extracted from the overall expression to represent
45 * the partition. The result is a non-empty error message string if the candidate should be
46 * rejected.
47 */
48using TValidateSubGraphFunc = TypedPackedFunc<String(const Function& function)>;
49
50/*!
51 * \brief The default validation function. Always returns the empty string, ie no error.
52 */
53String DefaultValidateSubGraphFunc(const Function& function);
54
55/*!
56 * \brief Pairs a \p PartitionRule with one or more \p Targets it can be used for.
57 */
58class PartitionSpecNode : public Object {
59 public:
60 /*!
61 * \brief Specification name to distinguish this spec from all others. Typically the BYOC
62 * 'compiler' name, "tvm", or "host".
63 */
64 String spec_name_;
65
66 /*!
67 * \brief The target all candidate partitions should be compiled for.
68 *
69 * It's tempting to support multiple targets here since. Eg the partitioning rules for
70 * TVM are the same irrespective of whether the target is "cuda" or "llvm", so it would make
71 * sense to build the candidate partitions first without committing to any target, then 'stamp'
72 * them for each target as the final step.
73 *
74 * However, we want to make sure any predicate in \p DFPatternPartitionRuleNode instances
75 * can have access to the current target instance. Eg the predicate may need to consult
76 * build-time configuration to decide what operators, shapes etc are actually supported.
77 * That implies the specific target is known when the candidate partitions are being constructed.
78 *
79 * So for now we'll just force each spec to have exactly one target.
80 */
81 Target target_;
82
83 /*!
84 * \brief The partition rule to use to gather candidates.
85 */
86 PartitionRule rule_;
87
88 /*!
89 * \brief The validation function to apply to each candidate's the extracted function before
90 * proceeding to lowering/codegen.
91 */
92 TValidateSubGraphFunc validate_sub_graph_func_ = DefaultValidateSubGraphFunc;
93
94 void VisitAttrs(AttrVisitor* v);
95
96 /*!
97 * \brief Returns all the candidate partitions found by this specification. The candidates
98 * will be for a specific target, but will not yet have an extracted function or cost.
99 */
100 std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph) const;
101
102 std::string ToString() const;
103
104 static constexpr const char* _type_key = "relay.collage.PartitionSpec";
105 TVM_DECLARE_FINAL_OBJECT_INFO(PartitionSpecNode, Object);
106};
107
108class PartitionSpec : public ObjectRef {
109 public:
110 PartitionSpec(String spec_name, Target target, PartitionRule rule,
111 TValidateSubGraphFunc validate_sub_graph_func = DefaultValidateSubGraphFunc);
112
113 TVM_DEFINE_OBJECT_REF_METHODS(PartitionSpec, ObjectRef, PartitionSpecNode);
114};
115
116} // namespace collage
117} // namespace relay
118} // namespace tvm
119
120#endif // TVM_RELAY_COLLAGE_PARTITION_SPEC_H_
121