1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/partition_spec.h |
22 | * \brief Combine a \p PartitionRule with a \p Target. |
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_COLLAGE_PARTITION_SPEC_H_ |
26 | #define TVM_RELAY_COLLAGE_PARTITION_SPEC_H_ |
27 | |
28 | #include <tvm/relay/function.h> |
29 | #include <tvm/runtime/container/string.h> |
30 | #include <tvm/target/target.h> |
31 | |
32 | #include <string> |
33 | #include <vector> |
34 | |
35 | #include "./partition_rule.h" |
36 | #include "./sub_graph.h" |
37 | |
38 | namespace tvm { |
39 | namespace relay { |
40 | namespace collage { |
41 | |
42 | /*! |
43 | * \brief Type of functions for checking the validity of partitions before they proceed to lowering |
44 | * and codegen. The argument is the function extracted from the overall expression to represent |
45 | * the partition. The result is a non-empty error message string if the candidate should be |
46 | * rejected. |
47 | */ |
48 | using TValidateSubGraphFunc = TypedPackedFunc<String(const Function& function)>; |
49 | |
50 | /*! |
51 | * \brief The default validation function. Always returns the empty string, ie no error. |
52 | */ |
53 | String DefaultValidateSubGraphFunc(const Function& function); |
54 | |
55 | /*! |
56 | * \brief Pairs a \p PartitionRule with one or more \p Targets it can be used for. |
57 | */ |
58 | class PartitionSpecNode : public Object { |
59 | public: |
60 | /*! |
61 | * \brief Specification name to distinguish this spec from all others. Typically the BYOC |
62 | * 'compiler' name, "tvm", or "host". |
63 | */ |
64 | String spec_name_; |
65 | |
66 | /*! |
67 | * \brief The target all candidate partitions should be compiled for. |
68 | * |
69 | * It's tempting to support multiple targets here since. Eg the partitioning rules for |
70 | * TVM are the same irrespective of whether the target is "cuda" or "llvm", so it would make |
71 | * sense to build the candidate partitions first without committing to any target, then 'stamp' |
72 | * them for each target as the final step. |
73 | * |
74 | * However, we want to make sure any predicate in \p DFPatternPartitionRuleNode instances |
75 | * can have access to the current target instance. Eg the predicate may need to consult |
76 | * build-time configuration to decide what operators, shapes etc are actually supported. |
77 | * That implies the specific target is known when the candidate partitions are being constructed. |
78 | * |
79 | * So for now we'll just force each spec to have exactly one target. |
80 | */ |
81 | Target target_; |
82 | |
83 | /*! |
84 | * \brief The partition rule to use to gather candidates. |
85 | */ |
86 | PartitionRule rule_; |
87 | |
88 | /*! |
89 | * \brief The validation function to apply to each candidate's the extracted function before |
90 | * proceeding to lowering/codegen. |
91 | */ |
92 | TValidateSubGraphFunc validate_sub_graph_func_ = DefaultValidateSubGraphFunc; |
93 | |
94 | void VisitAttrs(AttrVisitor* v); |
95 | |
96 | /*! |
97 | * \brief Returns all the candidate partitions found by this specification. The candidates |
98 | * will be for a specific target, but will not yet have an extracted function or cost. |
99 | */ |
100 | std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph) const; |
101 | |
102 | std::string ToString() const; |
103 | |
104 | static constexpr const char* _type_key = "relay.collage.PartitionSpec" ; |
105 | TVM_DECLARE_FINAL_OBJECT_INFO(PartitionSpecNode, Object); |
106 | }; |
107 | |
108 | class PartitionSpec : public ObjectRef { |
109 | public: |
110 | PartitionSpec(String spec_name, Target target, PartitionRule rule, |
111 | TValidateSubGraphFunc validate_sub_graph_func = DefaultValidateSubGraphFunc); |
112 | |
113 | TVM_DEFINE_OBJECT_REF_METHODS(PartitionSpec, ObjectRef, PartitionSpecNode); |
114 | }; |
115 | |
116 | } // namespace collage |
117 | } // namespace relay |
118 | } // namespace tvm |
119 | |
120 | #endif // TVM_RELAY_COLLAGE_PARTITION_SPEC_H_ |
121 | |