1 | #pragma once |
2 | |
#include <c10/macros/Export.h>

#include <dispatch.h>
#include <type.h>

#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | namespace fuser { |
15 | namespace cuda { |
16 | |
// Forward declarations. The declarations in this header only name these types
// through pointers.
class Fusion;
class Statement;
class Expr;
class Val;
21 | |
/*
 * IterVisitor starts from leaf nodes, fusion outputs, or the provided values.
 * It walks the DAG backwards from the starting nodes to the roots. Each node
 * in the DAG will be passed to handle(Statement*) in topological order, from
 * the inputs of the fusion to the outputs of the fusion.
 *
 * TODO: We may want a BFS version of this code to extract ILP, not implemented
 * yet.
 *
 * TODO: We may want to have ordering of outputs to inputs. I'm not sure why we
 * would want this, but seems like it would be a reasonable request.
 */
34 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
35 | class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch { |
36 | public: |
37 | ~IterVisitor() override = default; |
38 | |
39 | IterVisitor() = default; |
40 | |
41 | IterVisitor(const IterVisitor& other) = default; |
42 | IterVisitor& operator=(const IterVisitor& other) = default; |
43 | |
44 | IterVisitor(IterVisitor&& other) = default; |
45 | IterVisitor& operator=(IterVisitor&& other) = default; |
46 | |
47 | protected: |
48 | // Functions return nodes in reverse order to be added to the to_visit queue |
49 | // These functions will start at outputs and propagate up through the DAG |
50 | // to inputs based on depth first traversal. Next could be called on a node |
51 | // multiple times. |
52 | virtual std::vector<Statement*> next(Statement* stmt); |
53 | |
54 | virtual std::vector<Statement*> next(Val* v); |
55 | |
56 | virtual std::vector<Statement*> next(Expr* expr); |
57 | |
58 | // This handle functions is called on every Statement* in topological order, |
59 | // starting from outputs to inputs. |
60 | void handle(Statement* s) override; |
61 | |
62 | // This handle functions is called on every Expr* in topological order, |
63 | // starting from outputs to inputs. |
64 | void handle(Expr* e) override; |
65 | |
66 | // This handle functions is called on every Val* in topological order, |
67 | // starting from outputs to inputs. |
68 | void handle(Val* v) override; |
69 | |
70 | // The entire stack during traversal. stmt_stack.back().back() is the node |
71 | // that is being called in handle(). stmt_stack.back() contains siblings (not |
72 | // guarenteed to be all siblings throughout traversal). stmt_stack.front() |
73 | // contains the outputs we started with (not guarenteed to be all outputs |
74 | // throughout traversal). |
75 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
76 | std::vector<std::vector<Statement*>> stmt_stack; |
77 | |
78 | void traverseHelper(Fusion* fusion, bool traverse_all_paths = false); |
79 | |
80 | public: |
81 | //! Traverses nodes in Fusion from inputs in topological order to "to". i.e. |
82 | //! from inputs towards outputs. |
83 | //! \param traverseAllPaths = false only call handle on each Statement* once |
84 | //! traverseAllPaths = true traverses all paths between expressions/values. |
85 | //! Calls handle on a Statement* for every path from inputs to "to". |
86 | //! \param traverseIntoMembers = When hitting nodes like TensorView, |
87 | //! TensorDomain, or IterDomain where there are members of the nodes that are |
88 | //! Val's a value of "true" will also traverse into those member Val's, a |
89 | //! value of "false" will not traverse into the members. |
90 | void traverseTo( |
91 | Fusion* fusion, |
92 | const std::vector<Val*>& to, |
93 | bool traverse_all_paths = false, |
94 | bool traverse_into_members = false); |
95 | |
96 | //! Traverses nodes in Fusion from inputs in topological order to "to". i.e. |
97 | //! from inputs towards outputs. |
98 | //! \param traverseAllPaths = false only call handle on each Statement* once |
99 | //! traverseAllPaths = true traverses all paths between expressions/values. |
100 | //! Calls handle on a Statement* for every path from inputs to "to". |
101 | //! \param traverseIntoMembers = When hitting nodes like TensorView, |
102 | //! TensorDomain, or IterDomain where there are members of the nodes that are |
103 | //! Val's a value of "true" will also traverse into those member Val's, a |
104 | //! value of "false" will not traverse into the members. |
105 | //! \param from: Specified values to start traversing. If a "from" Val is not |
106 | //! on path from inputs to "to" node it will not be visited. If there's a path |
107 | //! from inputs to "to" that doesn't go through "from" that input and the path |
108 | //! from it will also be traversed. |
109 | void traverseBetween( |
110 | Fusion* fusion, |
111 | const std::unordered_set<Val*>& from, |
112 | const std::vector<Val*>& to, |
113 | bool traverse_all_paths = false, |
114 | bool traverse_into_members = false); |
115 | |
116 | // Iterates from terminating outputs registered with the fusion. Terminating |
117 | // means value is not used to generate any other value used in producing |
118 | // registered outputs. |
119 | void traverse(Fusion* fusion); |
120 | |
121 | // Same as traverse put it traverses every edge, meaning it will traverse |
122 | // values more than once. |
123 | void traverseAllPaths(Fusion* fusion); |
124 | |
125 | //! Get inputs to vals. Possible input vals can be optionally |
126 | //! given. If not, vals with no producers are returned. |
127 | // |
128 | // TODO: This doesn't seem to fit with IterVisitor. Should probably be moved |
129 | // out of the class. |
130 | static std::vector<Val*> getInputsTo( |
131 | const std::vector<Val*>& vals, |
132 | const std::vector<Val*>& inputs = {}); |
133 | }; |
134 | |
/*
 * Backward visitor: calls handle in reverse order, from outputs to inputs.
 * It would be really nice to unify this with IterVisitor; however, the
 * challenge there is that we specify traversal from outputs towards inputs
 * because it implicitly provides DCE. However, if users are not careful, they
 * could miss necessary outputs to do a backward traversal.
 *
 * BackwardVisitor checks that all outputs of an Expr are visited before
 * visiting the Expr. If we don't provide nodes to start from on all backward
 * paths of those outputs we will never visit the Expr.
 *
 * The first step of BackwardVisitor is to make sure we've specified enough
 * outputs to guarantee that we will traverse all outputs of all exprs during
 * the backward traversal. There are cases where we don't require visiting all
 * outputs of some exprs, an example being the `N` output of welford ops.
 * `must_cover_all_expr_outputs` is added to disable the check, and in
 * this case the visitor pass needs to be aware that:
 *  1. Exprs with any output that has a use chain that ends with a final
 *     consumer in the `from` list `will be` visited.
 *  2. Vals that don't have a use chain that ends with a final
 *     consumer in the `from` list `will not be` visited, even though their
 *     definition expr might be visited. An example is if the `N` output
 *     of a welford op is unused, but other outputs are, the welford op
 *     will be visited but the `N` output will not.
 *
 */
161 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
162 | class TORCH_CUDA_CU_API BackwardVisitor : public OptOutDispatch { |
163 | protected: |
164 | // NOLINTNEXTLINE(modernize-use-override) |
165 | virtual ~BackwardVisitor() = default; |
166 | |
167 | BackwardVisitor(bool must_cover_all_expr_outputs = true) |
168 | : must_cover_all_expr_outputs_(must_cover_all_expr_outputs) {} |
169 | |
170 | BackwardVisitor(const BackwardVisitor& other) = default; |
171 | BackwardVisitor& operator=(const BackwardVisitor& other) = default; |
172 | |
173 | BackwardVisitor(BackwardVisitor&& other) = default; |
174 | BackwardVisitor& operator=(BackwardVisitor&& other) = default; |
175 | |
176 | // Functions return nodes in reverse order to be added to the to_visit queue |
177 | // These functions will start at outputs and propagate up through the DAG |
178 | // to inputs based on depth first traversal. Next could be called on a node |
179 | // multiple times. |
180 | virtual std::vector<Statement*> next(Statement* stmt); |
181 | |
182 | virtual std::vector<Statement*> next(Expr* expr); |
183 | |
184 | virtual std::vector<Statement*> next(Val* val); |
185 | |
186 | // This handle functions is called on every Statement* in topological order, |
187 | // starting from outputs to inputs. |
188 | // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions) |
189 | virtual void handle(Statement* stmt) override; |
190 | |
191 | // This handle functions is called on every Expr* in topological order, |
192 | // starting from outputs to inputs. |
193 | // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions) |
194 | virtual void handle(Expr* expr) override; |
195 | |
196 | // This handle functions is called on every Val* in topological order, |
197 | // starting from outputs to inputs. |
198 | // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions) |
199 | virtual void handle(Val* val) override; |
200 | |
201 | // All exprs that need to be visited in this traversal. Labeled in topological |
202 | // order (size_t). |
203 | std::unordered_map<Expr*, size_t> traversal_exprs_; |
204 | |
205 | // The entire stack during traversal. stmt_stack.back().back() is the node |
206 | // that is being called in handle(). stmt_stack.back() contains siblings (not |
207 | // guarenteed to be all siblings throughout traversal). stmt_stack.front() |
208 | // contains the inputs we started with (not guarenteed to be all outputs |
209 | // throughout traversal). |
210 | std::deque<std::deque<Statement*>> stmt_stack_; |
211 | |
212 | // Starts at nodes provided in from, traverses from these nodes to inputs. |
213 | // Calls handle on all Statement*s in topological sorted order. |
214 | // traverseAllPaths = false only call handle on each Statement* once |
215 | // traverseAllPaths = true traverses all paths from nodes in from to inputs. |
216 | // Handle on a Statement* for every path from "from" nodes, to inputs. |
217 | void traverseTo( |
218 | Fusion* fusion, |
219 | const std::vector<Val*>& from, |
220 | bool traverseAllPaths = false); |
221 | |
222 | bool must_cover_all_expr_outputs_ = true; |
223 | }; |
224 | |
225 | class TORCH_CUDA_CU_API DependencyCheck { |
226 | public: |
227 | // Returns if "dependency" is a dependency of "of". |
228 | static bool isDependencyOf(Val* dependency, Val* of); |
229 | |
230 | // Finds a Val* path from "of" to "dependency". Returns that path. |
231 | // deque.back() is "of", deque[0] is dependency if a chain exists. |
232 | static std::deque<Val*> getSingleDependencyChain(Val* dependency, Val* of); |
233 | |
234 | // Finds all Val* paths from "of" to "dependency". Returns those paths. |
235 | // deque[i].back() is "of", and deque[i][0] is "dependency". Returns an |
236 | // empty deque if no dependency found. |
237 | static std::deque<std::deque<Val*>> getAllDependencyChains( |
238 | Val* dependency, |
239 | Val* of); |
240 | |
241 | // Finds all Val* paths from all leaf nodes to "dependency". Returns those |
242 | // paths. deque[i].back() are leaf nodes, and deque[i][0] is "dependency". |
243 | // Returns an empty deque if there are no uses of dependency found. |
244 | static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency); |
245 | |
246 | // Grab all values that exist between and including provided |
247 | // vals. Returned values are topologicaly ordered, and unique. |
248 | static std::vector<Val*> getAllValsBetween( |
249 | const std::unordered_set<Val*>& dependencies, |
250 | const std::vector<Val*>& of); |
251 | |
252 | // Returns all dependent exprs that exist between |
253 | // the provided vals |
254 | static std::vector<Expr*> getAllExprsBetween( |
255 | const std::unordered_set<Val*>& dependencies, |
256 | const std::vector<Val*>& of); |
257 | |
258 | // Return registered outputs of the fusion that are a dependency of any val of |
259 | static std::unordered_set<Val*> getAllOutputsOf( |
260 | const std::unordered_set<Val*>& of); |
261 | |
262 | // Return all Vals that depend on the given Vals |
263 | static std::unordered_set<Val*> getAllDependentVals( |
264 | const std::unordered_set<Val*>& of); |
265 | }; |
266 | |
// StmtSort takes a fusion and returns topologically sorted lists of its
// statements (or of only its expressions).
269 | class StmtSort : public IterVisitor { |
270 | protected: |
271 | StmtSort() = default; |
272 | |
273 | std::vector<Statement*> stmts; |
274 | |
275 | void handle(Statement* stmt) override; |
276 | |
277 | public: |
278 | // If traverse_members it will also extract all member nodes in the sorted |
279 | // statement list in the fusion. i.e. all IterDomains, extents, and associated |
280 | // expressions of them |
281 | static std::vector<Statement*> getStmts( |
282 | Fusion* fusion, |
283 | bool traverse_members = false); |
284 | |
285 | // Returns ordered Statements required to produce from, including from. |
286 | static std::vector<Statement*> getStmts( |
287 | Fusion* fusion, |
288 | const std::vector<Val*>& to, |
289 | bool traverse_members = false); |
290 | |
291 | // Returns ordered Statements required to produce from, including from. |
292 | // Stops traversal once hiting any Statements in to. Includes Statements in |
293 | // to. |
294 | // |
295 | // Warning: this doesn't necessarily prevent statements before `to` from being |
296 | // returned. e.g. |
297 | // i1 = i0 |
298 | // i2 = i1 |
299 | // i3 = i2 |
300 | // i4 = i3 + i1 |
301 | // getExprs(fusion, {i4}, {i3}) |
302 | // will return the definition and values {i0, i1, i4} |
303 | // i3 is dependent on i1, but since i4 also is then the traversal will go down |
304 | // the i4->i1->i0 path, even though the i4->i3-//>i2->i1 path is blocked. |
305 | // |
306 | // If traverse_members it will also extract all member nodes in the sorted |
307 | // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc |
308 | static std::vector<Statement*> getStmtsBetween( |
309 | Fusion* fusion, |
310 | const std::vector<Val*>& from, |
311 | const std::vector<Val*>& to, |
312 | bool traverse_members = false); |
313 | |
314 | // Same as getStmts version but filters to only return the Expr*s |
315 | static std::vector<Expr*> getExprs( |
316 | Fusion* fusion, |
317 | bool traverse_members = false); |
318 | |
319 | // Same as getStmts version but filters to only return the Expr*s |
320 | static std::vector<Expr*> getExprs( |
321 | Fusion* fusion, |
322 | const std::vector<Val*>& to, |
323 | bool traverse_members = false); |
324 | |
325 | // Same as getStmts version but filters to only return the Expr*s |
326 | static std::vector<Expr*> getExprsBetween( |
327 | Fusion* fusion, |
328 | const std::vector<Val*>& from, |
329 | const std::vector<Val*>& to, |
330 | bool traverse_members = false); |
331 | }; |
332 | |
333 | class InputsOf : public IterVisitor { |
334 | private: |
335 | std::unordered_set<Val*> grabbed_inputs; |
336 | std::vector<Val*> ordered_inputs; |
337 | |
338 | void handle(Val* v) final; |
339 | |
340 | public: |
341 | static std::vector<Val*> output(Fusion* fusion, Val* output_); |
342 | static std::vector<Val*> outputs( |
343 | Fusion* fusion, |
344 | const std::vector<Val*>& outputs_); |
345 | }; |
346 | |
347 | } // namespace cuda |
348 | } // namespace fuser |
349 | } // namespace jit |
350 | } // namespace torch |
351 | |