1#pragma once
2
#include <c10/macros/Export.h>

#include <dispatch.h>
#include <type.h>

#include <cstddef>
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>
11
12namespace torch {
13namespace jit {
14namespace fuser {
15namespace cuda {
16
// Forward declarations: this header only handles these types through
// pointers, so their full definitions are not needed here.
class Expr;
class Fusion;
class Statement;
class Val;
21
22/*
23 * IterVisitor starts from leaf nodes, fusion outputs, or the provided values.
24 * It walks the DAG bacwkards from the starting nodes, to roots. Each node in
25 * the dag will be called with handle(Statement*) in topolgical order inputs of
26 * the fusion to outputs of the fusion.
27 *
28 * TODO: We may want a BFS version of this code to extract ILP, not implemented
29 * yet.
30 *
31 * TODO: We may want to have ordering of outputs to inputs. I'm not sure why we
32 * would want this, but seems like it would be a reasonable request.
33 */
34// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
35class TORCH_CUDA_CU_API IterVisitor : public OptOutDispatch {
36 public:
37 ~IterVisitor() override = default;
38
39 IterVisitor() = default;
40
41 IterVisitor(const IterVisitor& other) = default;
42 IterVisitor& operator=(const IterVisitor& other) = default;
43
44 IterVisitor(IterVisitor&& other) = default;
45 IterVisitor& operator=(IterVisitor&& other) = default;
46
47 protected:
48 // Functions return nodes in reverse order to be added to the to_visit queue
49 // These functions will start at outputs and propagate up through the DAG
50 // to inputs based on depth first traversal. Next could be called on a node
51 // multiple times.
52 virtual std::vector<Statement*> next(Statement* stmt);
53
54 virtual std::vector<Statement*> next(Val* v);
55
56 virtual std::vector<Statement*> next(Expr* expr);
57
58 // This handle functions is called on every Statement* in topological order,
59 // starting from outputs to inputs.
60 void handle(Statement* s) override;
61
62 // This handle functions is called on every Expr* in topological order,
63 // starting from outputs to inputs.
64 void handle(Expr* e) override;
65
66 // This handle functions is called on every Val* in topological order,
67 // starting from outputs to inputs.
68 void handle(Val* v) override;
69
70 // The entire stack during traversal. stmt_stack.back().back() is the node
71 // that is being called in handle(). stmt_stack.back() contains siblings (not
72 // guarenteed to be all siblings throughout traversal). stmt_stack.front()
73 // contains the outputs we started with (not guarenteed to be all outputs
74 // throughout traversal).
75 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
76 std::vector<std::vector<Statement*>> stmt_stack;
77
78 void traverseHelper(Fusion* fusion, bool traverse_all_paths = false);
79
80 public:
81 //! Traverses nodes in Fusion from inputs in topological order to "to". i.e.
82 //! from inputs towards outputs.
83 //! \param traverseAllPaths = false only call handle on each Statement* once
84 //! traverseAllPaths = true traverses all paths between expressions/values.
85 //! Calls handle on a Statement* for every path from inputs to "to".
86 //! \param traverseIntoMembers = When hitting nodes like TensorView,
87 //! TensorDomain, or IterDomain where there are members of the nodes that are
88 //! Val's a value of "true" will also traverse into those member Val's, a
89 //! value of "false" will not traverse into the members.
90 void traverseTo(
91 Fusion* fusion,
92 const std::vector<Val*>& to,
93 bool traverse_all_paths = false,
94 bool traverse_into_members = false);
95
96 //! Traverses nodes in Fusion from inputs in topological order to "to". i.e.
97 //! from inputs towards outputs.
98 //! \param traverseAllPaths = false only call handle on each Statement* once
99 //! traverseAllPaths = true traverses all paths between expressions/values.
100 //! Calls handle on a Statement* for every path from inputs to "to".
101 //! \param traverseIntoMembers = When hitting nodes like TensorView,
102 //! TensorDomain, or IterDomain where there are members of the nodes that are
103 //! Val's a value of "true" will also traverse into those member Val's, a
104 //! value of "false" will not traverse into the members.
105 //! \param from: Specified values to start traversing. If a "from" Val is not
106 //! on path from inputs to "to" node it will not be visited. If there's a path
107 //! from inputs to "to" that doesn't go through "from" that input and the path
108 //! from it will also be traversed.
109 void traverseBetween(
110 Fusion* fusion,
111 const std::unordered_set<Val*>& from,
112 const std::vector<Val*>& to,
113 bool traverse_all_paths = false,
114 bool traverse_into_members = false);
115
116 // Iterates from terminating outputs registered with the fusion. Terminating
117 // means value is not used to generate any other value used in producing
118 // registered outputs.
119 void traverse(Fusion* fusion);
120
121 // Same as traverse put it traverses every edge, meaning it will traverse
122 // values more than once.
123 void traverseAllPaths(Fusion* fusion);
124
125 //! Get inputs to vals. Possible input vals can be optionally
126 //! given. If not, vals with no producers are returned.
127 //
128 // TODO: This doesn't seem to fit with IterVisitor. Should probably be moved
129 // out of the class.
130 static std::vector<Val*> getInputsTo(
131 const std::vector<Val*>& vals,
132 const std::vector<Val*>& inputs = {});
133};
134
135/*
136 * Backward visitor IterVisitor calls handle in reverse order from outputs
137 * to inputs It would be really nice to unify this with IterVisitor, however,
138 * the challenge there is that we specify traversal from outputs towards inputs
139 * because it implicitly provides DCE. However, if users are not careful, they
140 * could miss necessary outputs to do a backward traversal.
141 *
142 * BackwardVisitor checks that all outputs of an Expr is visited before visiting
143 * the Expr. If we don't provide nodes to start from on all backward paths of
144 * those outputs we will never visit the Expr.
145 *
146 * The first step of BackwardVisitor is to make sure we've specified enough
147 * outputs to guarentee that we will traverse all outputs of all exprs during
148 * the backward traversal. In the case where we don't require visiting all
149 * outputs of some exprs, example being the `N` output of welford ops.
150 * `must_cover_all_expr_outputs` is added to disable the check, and in
151 * this case the visitor pass need be aware
152 * 1. Exprs with any output that has a use chain that ends with a final
153 * consumer in the `from` list `will be` visited.
154 * 2. Vals that doesn't have a use chain that ends with a final
155 * consumer in the `from` list `will not be` visited, even though its
156 * definition expr might be visited. An example is if the `N` output
157 * of an welford op is unused, but other outputs are, the welford op
158 * will be visited but the `N` output will not.
159 *
160 */
161// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
162class TORCH_CUDA_CU_API BackwardVisitor : public OptOutDispatch {
163 protected:
164 // NOLINTNEXTLINE(modernize-use-override)
165 virtual ~BackwardVisitor() = default;
166
167 BackwardVisitor(bool must_cover_all_expr_outputs = true)
168 : must_cover_all_expr_outputs_(must_cover_all_expr_outputs) {}
169
170 BackwardVisitor(const BackwardVisitor& other) = default;
171 BackwardVisitor& operator=(const BackwardVisitor& other) = default;
172
173 BackwardVisitor(BackwardVisitor&& other) = default;
174 BackwardVisitor& operator=(BackwardVisitor&& other) = default;
175
176 // Functions return nodes in reverse order to be added to the to_visit queue
177 // These functions will start at outputs and propagate up through the DAG
178 // to inputs based on depth first traversal. Next could be called on a node
179 // multiple times.
180 virtual std::vector<Statement*> next(Statement* stmt);
181
182 virtual std::vector<Statement*> next(Expr* expr);
183
184 virtual std::vector<Statement*> next(Val* val);
185
186 // This handle functions is called on every Statement* in topological order,
187 // starting from outputs to inputs.
188 // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
189 virtual void handle(Statement* stmt) override;
190
191 // This handle functions is called on every Expr* in topological order,
192 // starting from outputs to inputs.
193 // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
194 virtual void handle(Expr* expr) override;
195
196 // This handle functions is called on every Val* in topological order,
197 // starting from outputs to inputs.
198 // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
199 virtual void handle(Val* val) override;
200
201 // All exprs that need to be visited in this traversal. Labeled in topological
202 // order (size_t).
203 std::unordered_map<Expr*, size_t> traversal_exprs_;
204
205 // The entire stack during traversal. stmt_stack.back().back() is the node
206 // that is being called in handle(). stmt_stack.back() contains siblings (not
207 // guarenteed to be all siblings throughout traversal). stmt_stack.front()
208 // contains the inputs we started with (not guarenteed to be all outputs
209 // throughout traversal).
210 std::deque<std::deque<Statement*>> stmt_stack_;
211
212 // Starts at nodes provided in from, traverses from these nodes to inputs.
213 // Calls handle on all Statement*s in topological sorted order.
214 // traverseAllPaths = false only call handle on each Statement* once
215 // traverseAllPaths = true traverses all paths from nodes in from to inputs.
216 // Handle on a Statement* for every path from "from" nodes, to inputs.
217 void traverseTo(
218 Fusion* fusion,
219 const std::vector<Val*>& from,
220 bool traverseAllPaths = false);
221
222 bool must_cover_all_expr_outputs_ = true;
223};
224
225class TORCH_CUDA_CU_API DependencyCheck {
226 public:
227 // Returns if "dependency" is a dependency of "of".
228 static bool isDependencyOf(Val* dependency, Val* of);
229
230 // Finds a Val* path from "of" to "dependency". Returns that path.
231 // deque.back() is "of", deque[0] is dependency if a chain exists.
232 static std::deque<Val*> getSingleDependencyChain(Val* dependency, Val* of);
233
234 // Finds all Val* paths from "of" to "dependency". Returns those paths.
235 // deque[i].back() is "of", and deque[i][0] is "dependency". Returns an
236 // empty deque if no dependency found.
237 static std::deque<std::deque<Val*>> getAllDependencyChains(
238 Val* dependency,
239 Val* of);
240
241 // Finds all Val* paths from all leaf nodes to "dependency". Returns those
242 // paths. deque[i].back() are leaf nodes, and deque[i][0] is "dependency".
243 // Returns an empty deque if there are no uses of dependency found.
244 static std::deque<std::deque<Val*>> getAllUseChains(Val* dependency);
245
246 // Grab all values that exist between and including provided
247 // vals. Returned values are topologicaly ordered, and unique.
248 static std::vector<Val*> getAllValsBetween(
249 const std::unordered_set<Val*>& dependencies,
250 const std::vector<Val*>& of);
251
252 // Returns all dependent exprs that exist between
253 // the provided vals
254 static std::vector<Expr*> getAllExprsBetween(
255 const std::unordered_set<Val*>& dependencies,
256 const std::vector<Val*>& of);
257
258 // Return registered outputs of the fusion that are a dependency of any val of
259 static std::unordered_set<Val*> getAllOutputsOf(
260 const std::unordered_set<Val*>& of);
261
262 // Return all Vals that depend on the given Vals
263 static std::unordered_set<Val*> getAllDependentVals(
264 const std::unordered_set<Val*>& of);
265};
266
267// Expr sort will take a fusion and return a topologically sorted list of
268// expressions.
269class StmtSort : public IterVisitor {
270 protected:
271 StmtSort() = default;
272
273 std::vector<Statement*> stmts;
274
275 void handle(Statement* stmt) override;
276
277 public:
278 // If traverse_members it will also extract all member nodes in the sorted
279 // statement list in the fusion. i.e. all IterDomains, extents, and associated
280 // expressions of them
281 static std::vector<Statement*> getStmts(
282 Fusion* fusion,
283 bool traverse_members = false);
284
285 // Returns ordered Statements required to produce from, including from.
286 static std::vector<Statement*> getStmts(
287 Fusion* fusion,
288 const std::vector<Val*>& to,
289 bool traverse_members = false);
290
291 // Returns ordered Statements required to produce from, including from.
292 // Stops traversal once hiting any Statements in to. Includes Statements in
293 // to.
294 //
295 // Warning: this doesn't necessarily prevent statements before `to` from being
296 // returned. e.g.
297 // i1 = i0
298 // i2 = i1
299 // i3 = i2
300 // i4 = i3 + i1
301 // getExprs(fusion, {i4}, {i3})
302 // will return the definition and values {i0, i1, i4}
303 // i3 is dependent on i1, but since i4 also is then the traversal will go down
304 // the i4->i1->i0 path, even though the i4->i3-//>i2->i1 path is blocked.
305 //
306 // If traverse_members it will also extract all member nodes in the sorted
307 // expr list in the fusion. i.e. all expressions on IterDomains, extents, etc
308 static std::vector<Statement*> getStmtsBetween(
309 Fusion* fusion,
310 const std::vector<Val*>& from,
311 const std::vector<Val*>& to,
312 bool traverse_members = false);
313
314 // Same as getStmts version but filters to only return the Expr*s
315 static std::vector<Expr*> getExprs(
316 Fusion* fusion,
317 bool traverse_members = false);
318
319 // Same as getStmts version but filters to only return the Expr*s
320 static std::vector<Expr*> getExprs(
321 Fusion* fusion,
322 const std::vector<Val*>& to,
323 bool traverse_members = false);
324
325 // Same as getStmts version but filters to only return the Expr*s
326 static std::vector<Expr*> getExprsBetween(
327 Fusion* fusion,
328 const std::vector<Val*>& from,
329 const std::vector<Val*>& to,
330 bool traverse_members = false);
331};
332
333class InputsOf : public IterVisitor {
334 private:
335 std::unordered_set<Val*> grabbed_inputs;
336 std::vector<Val*> ordered_inputs;
337
338 void handle(Val* v) final;
339
340 public:
341 static std::vector<Val*> output(Fusion* fusion, Val* output_);
342 static std::vector<Val*> outputs(
343 Fusion* fusion,
344 const std::vector<Val*>& outputs_);
345};
346
347} // namespace cuda
348} // namespace fuser
349} // namespace jit
350} // namespace torch
351