1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay/collage/candidate_partition_index.h |
22 | * \brief Index for finding relevant candidate partitions for a particular search state. |
23 | */ |
24 | #ifndef TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_ |
25 | #define TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_ |
26 | |
27 | #include <tvm/relay/expr.h> |
28 | |
29 | #include <memory> |
30 | #include <string> |
31 | #include <unordered_map> |
32 | #include <vector> |
33 | |
34 | #include "./partition_spec.h" |
35 | |
36 | namespace tvm { |
37 | namespace relay { |
38 | namespace collage { |
39 | |
40 | /*! |
41 | * \brief Collects and indexes all the candidate partitions for the overall expression. This index |
42 | * is used during partitioning search to find the next valid candidate partition to explore from the |
43 | * current search state. We do not yet attempt to estimate the cost of each candidate partition, and |
44 | * when we do so during the search we may discover it to be infeasible. |
45 | */ |
46 | class CandidatePartitionIndex { |
47 | public: |
48 | CandidatePartitionIndex(const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices, |
49 | DataflowGraph* dataflow_graph); |
50 | |
51 | /*! \brief Constructs the index. */ |
52 | void Index(const Array<PartitionSpec>& partition_specs); |
53 | |
54 | /*! \brief Returns all the candidates which may begin at \p index. */ |
55 | const std::vector<CandidatePartition>& candidates_at(PostDfsIndex index) const { |
56 | ICHECK_LT(index, dataflow_graph_->size()); |
57 | return first_inside_index_to_candidates_[index]; |
58 | } |
59 | |
60 | /*! \brief Estimates the casts of all candidates in the index. Each candidate caches its cost. */ |
61 | void EstimateAllCosts(const CostEstimator cost_estimator, |
62 | const std::shared_ptr<CandidateFunctionCache>& cache); |
63 | |
64 | size_t size() const { return size_; } |
65 | |
66 | std::string ToSummary() const; |
67 | |
68 | private: |
69 | /*! |
70 | * \brief Returns true if \p candidate's desired target is compatible with any existing target |
71 | * constraints on the candidate's sub-expressions. |
72 | */ |
73 | bool IsCompatibleWithVirtualDevice(const CandidatePartition& candidate); |
74 | |
75 | /*! \brief Returns all valid candidates found from \p partition_specs. */ |
76 | std::vector<CandidatePartition> Collect(const Array<PartitionSpec>& partition_specs); |
77 | |
78 | /*! |
79 | * \brief The \p VirtualDevice for every sub-expression in the overall expression. Needed to |
80 | * ensure candidates do not contradict the target/device placement already determined by |
81 | * device planning. |
82 | */ |
83 | const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices_; |
84 | |
85 | /*! \brief Dataflow graph for overall expression. */ |
86 | DataflowGraph* dataflow_graph_; |
87 | |
88 | /*! |
89 | * \brief Maps post-dfs indexes to the all the candidates which have that as their first inside |
90 | * index, and which should be considered in the Collage search. |
91 | */ |
92 | std::vector<std::vector<CandidatePartition>> first_inside_index_to_candidates_; |
93 | |
94 | /*! \brief Number of entries in above. */ |
95 | size_t size_ = 0; |
96 | }; |
97 | |
98 | } // namespace collage |
99 | } // namespace relay |
100 | } // namespace tvm |
101 | |
102 | #endif // TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_ |
103 | |