1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file graph.h
 * \brief Utilities to get information about the schedule graph.
23 */
24#ifndef TVM_TE_SCHEDULE_GRAPH_H_
25#define TVM_TE_SCHEDULE_GRAPH_H_
26
27#include <tvm/te/operation.h>
28#include <tvm/te/schedule.h>
29#include <tvm/tir/expr.h>
30
31#include <unordered_map>
32#include <unordered_set>
33#include <vector>
34
35namespace tvm {
36namespace te {
37
38/*!
39 * \brief data structure of Operation->Tensors it reads
40 */
41using ReadGraph = Map<Operation, Array<Tensor>>;
42
43/*!
44 * \brief AttachPath maps op-> a list of IterVar
45 */
46using AttachPath = Map<Operation, Array<IterVar>>;
47
48/*!
49 * \brief The map between tensor and operation it feeds to.
50 */
51using FeedGraph = std::unordered_map<Tensor, std::vector<Operation>>;
52
53/*!
54 * \brief Get read graph of each operation to all the
55 * Tensors that it directly depends on.
56 *
57 * The result map contains Operations needed to finish root Operation.
58 * \param roots The root operation.
59 * \return The result map.
60 */
61ReadGraph CreateReadGraph(const Array<Operation>& roots);
62
63/*!
64 * \brief Get minimum subgraph between outputs and inputs.
65 * The operations contains node which input-reachable from any inputs
66 * output reachable to any outputs.
67 *
68 * The inputs won't be included in the subgraph, the outputs will be included.
69 *
70 * \param outputs The outputs of the subgraph
71 * \param inputs The inputs to the subgraph.
72 * \param include_inputs Whether to include inputs
73 *
74 * \return The subgraph.
75 */
76Array<Operation> GetSubGraph(const Array<Tensor>& outputs, const Array<Tensor>& inputs,
77 bool include_inputs);
78
79/*!
80 * \brief Get a post DFS ordered of operations in the graph.
81 * \param roots The root of the graph.
82 * \param g The read graph.
83 * \return vector order of Operations in PostDFS order.
84 *
85 * \note PostDFSOrder is a special case of Topoligical order,
86 * and can be used when topoligical order is needed.
87 */
88Array<Operation> PostDFSOrder(const Array<Operation>& roots, const ReadGraph& g);
89
90/*!
91 * \brief Create feedgraph for given Schedule
92 * \param g The read graph.
93 * \return The created feedgraph.
94 */
95FeedGraph CreateFeedGraph(const ReadGraph& g);
96
97/*!
98 * \brief Create AttachPath that maps op-> a list of IterVar
99 * That represents the loop nest op sits in from inner most to outermost
100 * Also inserts attach_stage for scan updates when needed.
101 *
102 * \param sch The schedule.
103 * \return The attach path.
104 */
105AttachPath CreateAttachPath(Schedule sch);
106
107/*!
108 * \brief Get all operations inside the recursion of scan.
109 * \param scan_op The scan node ops.
110 * \return The body operations, in read dependency order.
111 */
112Array<Operation> ScanGetBody(const Operation& scan_op);
113
114/*!
115 * \brief Analyze each spatial dimension of scan's result.
116 * Give check on whether each dimension is fix point,
117 * An axis is a fixed point if it only refers back to itself in recursion
118 * and it is not used in axis of other recursion field.
119 *
120 * next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...]
121 *
122 * \param scan The scan node.
123 * \return Map of spatial_axis -> IntImm
124 */
125Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan);
126
127} // namespace te
128} // namespace tvm
129
130#endif // TVM_TE_SCHEDULE_GRAPH_H_
131