1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | |
5 | /* `getCustomPrePasses()` returns a vector of passes that will be executed |
6 | * after differentiation but before any fusion. This is the de-facto location |
7 | * for compiler backends to insert passes. |
8 | * |
9 | * `getCustomPostPasses()` returns a vector of passes that will be |
10 | * executed after differentiation and after fusion (if any). This is the |
11 | * location for fusion cleanup passes if they are needed. |
12 | * |
 * Static registration of a post pass can be done by creating a global
 * `RegisterPostPass r(Pass)` variable in a compilation unit. This header
 * declares no `RegisterPrePass`; pre passes are registered by calling
 * `registerPrePass(Pass)`.
15 | * |
 * pass_manager.h uses a Meyers singleton to store a vector of `Pass`es, which
 * modify the IR graph in place.
18 | */ |
19 | |
20 | namespace torch { |
21 | namespace jit { |
22 | |
23 | // A pass modifies a Graph in place. |
24 | using GraphPass = std::function<void(std::shared_ptr<Graph>&)>; |
25 | |
26 | // Since Passes are std::functions, we associate a UUID to each pass, this way |
27 | // if we want to deregister a pass, we have something to reference it by. |
28 | using GraphPassNameType = unsigned int; |
29 | |
30 | // Graph pass entries have a name associated with them |
31 | using GraphPassEntry = std::pair<GraphPass, GraphPassNameType>; |
32 | |
33 | // Return currently registered passes. Passes are stored in a static vector |
34 | TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>& |
35 | getCustomPostPasses(); |
36 | TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>& |
37 | getCustomPrePasses(); |
38 | |
39 | TORCH_API GraphPassNameType registerPostPass(GraphPass p); |
40 | TORCH_API GraphPassNameType registerPrePass(GraphPass p); |
41 | |
42 | // Look up pass by name passed in, remove it from registered passes |
43 | TORCH_API void clearPostPass(GraphPassNameType p); |
44 | TORCH_API void clearPrePass(GraphPassNameType p); |
45 | |
46 | // Remove all passes |
47 | TORCH_API void clearAllPostPasses(); |
48 | TORCH_API void clearAllPrePasses(); |
49 | |
50 | // LEGACY CALL |
51 | struct TORCH_API RegisterPostPass { |
52 | RegisterPostPass(GraphPass p); |
53 | }; |
54 | |
55 | using RegisterPass = RegisterPostPass; |
56 | |
57 | /* |
58 | * PassManager is a wrapper on the register/clear PostPass functions above. It |
59 | * will register the pass provided in "registerPass" and will hold on to its |
60 | * associated name that way clearPass can be later called and will delete the |
61 | * pass used to register when called. |
62 | * |
63 | * PassManager is templated because we want static variables based on a |
64 | * particular GraphPass. When deriving from PassManager, you should send as the |
65 | * template parameter your derived class as you would for the curiously |
66 | * recurring template pattern. This template parameter isn't actually used and |
67 | * is simply done to prevent static members from being shared across derived |
68 | * types. |
69 | */ |
70 | template <typename DerivedType> |
71 | struct C10_EXPORT PassManager { |
72 | private: |
73 | // We want this class to be abstract because it's |
74 | virtual void abstract() = 0; |
75 | |
76 | protected: |
77 | /* |
78 | * isRegistered() will return if a pass has been registered |
79 | * isRegistered(true) will change the value of the internal static bool |
80 | * |
81 | * There's an internal static bool to this function to keep track of the |
82 | * state, this is so when functions are derived from this class, they don't |
83 | * have to worry about initializing the static members. |
84 | */ |
85 | static bool isRegistered(bool flip_bit = false) { |
86 | static bool val = false; |
87 | if (flip_bit) |
88 | val = !val; |
89 | return val; |
90 | } |
91 | |
92 | /* |
93 | * name() will return the name of the registered pass |
94 | * name(pass_name, true) will set the name of the pass |
95 | * Similarly to isRegistered we use an internal static variable to hold the |
96 | * name. |
97 | */ |
98 | static GraphPassNameType passID( |
99 | GraphPassNameType PassID = 0, |
100 | bool set = false) { |
101 | static GraphPassNameType pass_id = 0; |
102 | if (set) |
103 | pass_id = PassID; |
104 | return pass_id; |
105 | } |
106 | |
107 | public: |
108 | // registerPass(pass) will register the pass provided and set the |
109 | // name/isRegistered functions appropriately, it returns a bool value |
110 | // indicating whether the given pass is already registered previously. |
111 | static bool registerPass(GraphPass p) { |
112 | if (!isRegistered()) { |
113 | // If we don't already have a registered pass, register pass |
114 | // hold on to its name, change isRegistered to true |
115 | passID(registerPostPass(std::move(p)), true); |
116 | isRegistered(true); |
117 | return false; |
118 | } |
119 | return true; |
120 | } |
121 | |
122 | // Calls ClearPostPass(passID()) |
123 | static void clearPass() { |
124 | // If the pass is registered, clear it and change isRegistered to false. |
125 | if (isRegistered()) { |
126 | clearPostPass(passID()); |
127 | isRegistered(true); |
128 | } |
129 | } |
130 | |
131 | // clang-tidy requires virtual destructor; |
132 | virtual ~PassManager() = default; |
133 | }; |
134 | |
135 | } // namespace jit |
136 | } // namespace torch |
137 | |