1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file graph.h |
22 | * \brief Utilities to get information about schedule graph. |
23 | */ |
24 | #ifndef TVM_TE_SCHEDULE_GRAPH_H_ |
25 | #define TVM_TE_SCHEDULE_GRAPH_H_ |
26 | |
27 | #include <tvm/te/operation.h> |
28 | #include <tvm/te/schedule.h> |
29 | #include <tvm/tir/expr.h> |
30 | |
31 | #include <unordered_map> |
32 | #include <unordered_set> |
33 | #include <vector> |
34 | |
35 | namespace tvm { |
36 | namespace te { |
37 | |
38 | /*! |
39 | * \brief Read graph: maps each Operation to the Tensors it reads. |
40 | */ |
41 | using ReadGraph = Map<Operation, Array<Tensor>>; |
42 | |
43 | /*! |
44 | * \brief Attach path: maps each Operation to a list of IterVar. |
45 | */ |
46 | using AttachPath = Map<Operation, Array<IterVar>>; |
47 | |
48 | /*! |
49 | * \brief Feed graph: maps each Tensor to the Operations it feeds into. |
50 | */ |
51 | using FeedGraph = std::unordered_map<Tensor, std::vector<Operation>>; |
52 | |
53 | /*! |
54 | * \brief Build the read graph that maps each operation to the |
55 | * Tensors that it directly depends on. |
56 | * |
57 | * The result map contains the Operations needed to finish the root Operations. |
58 | * \param roots The root operations. |
59 | * \return The resulting read graph. |
60 | */ |
61 | ReadGraph CreateReadGraph(const Array<Operation>& roots); |
62 | |
63 | /*! |
64 | * \brief Get the minimum subgraph between outputs and inputs. |
65 | * The subgraph holds the operations reachable from any input that also reach any output. |
66 | * |
67 | * The outputs are included in the subgraph. Whether the inputs are included is |
68 | * controlled by include_inputs (NOTE(review): confirm against the implementation). |
69 | * |
70 | * \param outputs The outputs of the subgraph. |
71 | * \param inputs The inputs to the subgraph. |
72 | * \param include_inputs Whether to include the inputs. |
73 | * |
74 | * \return The operations of the subgraph. |
75 | */ |
76 | Array<Operation> GetSubGraph(const Array<Tensor>& outputs, const Array<Tensor>& inputs, |
77 | bool include_inputs); |
78 | |
79 | /*! |
80 | * \brief Get the operations in the graph in post-DFS order. |
81 | * \param roots The roots of the graph. |
82 | * \param g The read graph. |
83 | * \return Array of Operations in post-DFS order. |
84 | * |
85 | * \note Post-DFS order is a special case of topological order, |
86 | * and can be used when a topological order is needed. |
87 | */ |
88 | Array<Operation> PostDFSOrder(const Array<Operation>& roots, const ReadGraph& g); |
89 | |
90 | /*! |
91 | * \brief Create the feed graph for the given read graph. |
92 | * \param g The read graph. |
93 | * \return The created feed graph. |
94 | */ |
95 | FeedGraph CreateFeedGraph(const ReadGraph& g); |
96 | |
97 | /*! |
98 | * \brief Create the AttachPath that maps each op to a list of IterVar |
99 | * representing the loop nest the op sits in, from innermost to outermost. |
100 | * Also inserts attach_stage for scan updates when needed. |
101 | * |
102 | * \param sch The schedule. |
103 | * \return The attach path. |
104 | */ |
105 | AttachPath CreateAttachPath(Schedule sch); |
106 | |
107 | /*! |
108 | * \brief Get all operations inside the recursion of a scan. |
109 | * \param scan_op The scan operation. |
110 | * \return The body operations, in read dependency order. |
111 | */ |
112 | Array<Operation> ScanGetBody(const Operation& scan_op); |
113 | |
114 | /*! |
115 | * \brief Analyze each spatial dimension of the scan's result |
116 | * and check whether each dimension is a fixed point. |
117 | * An axis is a fixed point if it only refers back to itself in the recursion |
118 | * and it is not used in the axis of any other recursion field. |
119 | * |
120 | * next_state[t, ..., axis, ...] = f(prev_state[t-1, ..., axis, ...]) |
121 | * |
122 | * \param scan The scan node. |
123 | * \return Map of spatial_axis -> IntImm (the declared value type is PrimExpr). |
124 | */ |
125 | Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan); |
126 | |
127 | } // namespace te |
128 | } // namespace tvm |
129 | |
130 | #endif // TVM_TE_SCHEDULE_GRAPH_H_ |
131 | |