1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file relay/collage/candidate_partition_index.h
22 * \brief Index for finding relevant candidate partitions for a particular search state.
23 */
24#ifndef TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_
25#define TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_
26
27#include <tvm/relay/expr.h>
28
29#include <memory>
30#include <string>
31#include <unordered_map>
32#include <vector>
33
34#include "./partition_spec.h"
35
36namespace tvm {
37namespace relay {
38namespace collage {
39
40/*!
41 * \brief Collects and indexes all the candidate partitions for the overall expression. This index
42 * is used during partitioning search to find the next valid candidate partition to explore from the
43 * current search state. We do not yet attempt to estimate the cost of each candidate partition, and
44 * when we do so during the search we may discover it to be infeasible.
45 */
46class CandidatePartitionIndex {
47 public:
48 CandidatePartitionIndex(const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices,
49 DataflowGraph* dataflow_graph);
50
51 /*! \brief Constructs the index. */
52 void Index(const Array<PartitionSpec>& partition_specs);
53
54 /*! \brief Returns all the candidates which may begin at \p index. */
55 const std::vector<CandidatePartition>& candidates_at(PostDfsIndex index) const {
56 ICHECK_LT(index, dataflow_graph_->size());
57 return first_inside_index_to_candidates_[index];
58 }
59
60 /*! \brief Estimates the casts of all candidates in the index. Each candidate caches its cost. */
61 void EstimateAllCosts(const CostEstimator cost_estimator,
62 const std::shared_ptr<CandidateFunctionCache>& cache);
63
64 size_t size() const { return size_; }
65
66 std::string ToSummary() const;
67
68 private:
69 /*!
70 * \brief Returns true if \p candidate's desired target is compatible with any existing target
71 * constraints on the candidate's sub-expressions.
72 */
73 bool IsCompatibleWithVirtualDevice(const CandidatePartition& candidate);
74
75 /*! \brief Returns all valid candidates found from \p partition_specs. */
76 std::vector<CandidatePartition> Collect(const Array<PartitionSpec>& partition_specs);
77
78 /*!
79 * \brief The \p VirtualDevice for every sub-expression in the overall expression. Needed to
80 * ensure candidates do not contradict the target/device placement already determined by
81 * device planning.
82 */
83 const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices_;
84
85 /*! \brief Dataflow graph for overall expression. */
86 DataflowGraph* dataflow_graph_;
87
88 /*!
89 * \brief Maps post-dfs indexes to the all the candidates which have that as their first inside
90 * index, and which should be considered in the Collage search.
91 */
92 std::vector<std::vector<CandidatePartition>> first_inside_index_to_candidates_;
93
94 /*! \brief Number of entries in above. */
95 size_t size_ = 0;
96};
97
98} // namespace collage
99} // namespace relay
100} // namespace tvm
101
102#endif // TVM_RELAY_COLLAGE_CANDIDATE_PARTITION_INDEX_H_
103