1#pragma once
2
#include <vector>

#include <dispatch.h>
4
5namespace torch {
6namespace jit {
7namespace fuser {
8namespace cuda {
9
10class Expr;
11
12namespace kir {
13class Predicate;
14class TensorIndex;
15class ForLoop;
16class IfThenElse;
17class Scope;
18
19// Base visitor class that visits all nodes in provided vector<Expr*>.
20//
21// Includes visiting through scopes like IfThenElse and ForLoop, and tracks
22// them in scopes_ and for_loops_.
23//
24// Makes a copy of exprs at exprs_ which could be used to modify and return.
25//
26// When traversing through ITE/FLs it will use a copy
27// of the provided expressions to make it safe to insert/delete nodes.
28//
29// Provides a simple base class to inherit from for typical lowering passes on
30// Expr list
31class TORCH_CUDA_CU_API IrVisitor : public OptOutDispatch {
32 public:
33 std::vector<Expr*> handle(const std::vector<Expr*>& expr);
34
35 protected:
36 using OptOutDispatch::handle;
37
38 virtual void handle(ForLoop*) override;
39 virtual void handle(IfThenElse*) override;
40
41 protected:
42 std::vector<ForLoop*> for_loops_;
43 std::vector<Scope*> scope_;
44 std::vector<Expr*> scope_exprs_;
45 std::vector<Expr*> exprs_;
46};
47
48// Const version of IrVisitor
49class TORCH_CUDA_CU_API ConstIrVisitor : public OptOutConstDispatch {
50 public:
51 std::vector<const Expr*> handle(const std::vector<const Expr*>& expr);
52
53 protected:
54 using OptOutConstDispatch::handle;
55
56 virtual void handle(const ForLoop*) override;
57 virtual void handle(const IfThenElse*) override;
58
59 protected:
60 std::vector<const ForLoop*> for_loops_;
61 std::vector<const Scope*> scope_;
62 std::vector<const Expr*> scope_exprs_;
63 std::vector<const Expr*> exprs_;
64};
65
66// Base Expr Mutator class that visits all nodes with IrVisitor, and then
67// inserts new expressions, replaces expressions based on insertion/replace
68// maps provided or removes existing expressions. These replacement
69// maps are expected to accumulate during an initial traversal, then
70// runs an insertion based on them after the overloaded traversal.
71//
72// Order of mutations may be important, mutations are ordered according to the
73// following rules:
74// Before/After insertions are ordered as registered when reverse_order ==
75// false,
76//
77// Before/After insertions are in reverse order as registered when
78// reverse_order == true,
79//
80// Before/After insertions are done before Expr replacements, so reference for
81// insertions must be on pre-replaced Exprs
82//
83// Removal of expressions is done after replacements.
84//
85// To place in a scope that is empty, simply provide a nullptr reference
86// Since insertions are done in order, it's possible to insert an expression in
87// an empty scope, and then use that inserted scope as a reference for
88// subsequent mutations.
89class ExprMutator : public IrVisitor {
90 protected:
91 std::vector<Expr*> traverseAndInsert(
92 const std::vector<Expr*>& expr,
93 bool reverse_order = false);
94
95 std::vector<Expr*> mutate(bool reverse_order = false);
96
97 using IrVisitor::handle;
98 // Registration function which *don't* need to be called "in place" during
99 // visiting.
100 void registerInsertBefore(Expr* reference, Expr* new_expr, Scope* scope);
101 void registerInsertAfter(Expr* reference, Expr* new_expr, Scope* scope);
102 void registerReplace(Expr* reference, Expr* new_expr, Scope* scope);
103 void registerRemove(Expr* expr_to_remove, Scope* scope);
104
105 // Registration function which need to be called "in place" during visiting.
106 // I.E.
107 // if you want to insert before/after or replace an Expr, you must register
108 // when in handle(Expr*) of that expr.
109 void registerInsertBefore(Expr* reference, Expr* new_expr);
110 void registerInsertAfter(Expr* reference, Expr* new_expr);
111 void registerReplace(Expr* reference, Expr* new_expr);
112 void registerRemove(Expr* expr_to_remove);
113
114 private:
115 enum class MutationMode { BEFORE, AFTER, REPLACE, REMOVE };
116
117 void registerMutation(
118 Expr* ref,
119 Expr* new_expr,
120 Scope* scope,
121 MutationMode mode);
122
123 struct MutationInformation {
124 Expr* reference = nullptr;
125 Expr* new_expr = nullptr;
126 Scope* scope = nullptr;
127 MutationMode mode = MutationMode::BEFORE;
128 };
129
130 // Track insertions as they're registered
131 std::vector<MutationInformation> insertions_;
132
133 // Track replacements as they're registered
134 std::vector<MutationInformation> replacements_;
135
136 // Track removal as they're registered
137 std::vector<MutationInformation> removal_;
138};
139
140} // namespace kir
141} // namespace cuda
142} // namespace fuser
143} // namespace jit
144} // namespace torch
145