1#pragma once
2
3#include <torch/csrc/jit/ir/ir.h>
4
5/* `getCustomPrePasses()` returns a vector of passes that will be executed
6 * after differentiation but before any fusion. This is the de-facto location
7 * for compiler backends to insert passes.
8 *
9 * `getCustomPostPasses()` returns a vector of passes that will be
10 * executed after differentiation and after fusion (if any). This is the
11 * location for fusion cleanup passes if they are needed.
12 *
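 * Static registration of a post-pass can be done by creating a global
 * `RegisterPostPass r(pass)` (or its legacy alias `RegisterPass`) variable in
 * a compilation unit. Passes can also be added and removed at runtime with the
 * register*Pass() / clear*Pass() functions declared below.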
15 *
 * pass_manager.h uses a Meyers singleton to store a vector of `Pass`es, each
 * of which modifies the IR graph in place.
18 */
19
20namespace torch {
21namespace jit {
22
23// A pass modifies a Graph in place.
24using GraphPass = std::function<void(std::shared_ptr<Graph>&)>;
25
26// Since Passes are std::functions, we associate a UUID to each pass, this way
27// if we want to deregister a pass, we have something to reference it by.
28using GraphPassNameType = unsigned int;
29
30// Graph pass entries have a name associated with them
31using GraphPassEntry = std::pair<GraphPass, GraphPassNameType>;
32
33// Return currently registered passes. Passes are stored in a static vector
34TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>&
35getCustomPostPasses();
36TORCH_API std::vector<std::pair<GraphPass, GraphPassNameType>>&
37getCustomPrePasses();
38
39TORCH_API GraphPassNameType registerPostPass(GraphPass p);
40TORCH_API GraphPassNameType registerPrePass(GraphPass p);
41
42// Look up pass by name passed in, remove it from registered passes
43TORCH_API void clearPostPass(GraphPassNameType p);
44TORCH_API void clearPrePass(GraphPassNameType p);
45
46// Remove all passes
47TORCH_API void clearAllPostPasses();
48TORCH_API void clearAllPrePasses();
49
50// LEGACY CALL
51struct TORCH_API RegisterPostPass {
52 RegisterPostPass(GraphPass p);
53};
54
55using RegisterPass = RegisterPostPass;
56
57/*
58 * PassManager is a wrapper on the register/clear PostPass functions above. It
59 * will register the pass provided in "registerPass" and will hold on to its
60 * associated name that way clearPass can be later called and will delete the
61 * pass used to register when called.
62 *
63 * PassManager is templated because we want static variables based on a
64 * particular GraphPass. When deriving from PassManager, you should send as the
65 * template parameter your derived class as you would for the curiously
66 * recurring template pattern. This template parameter isn't actually used and
67 * is simply done to prevent static members from being shared across derived
68 * types.
69 */
70template <typename DerivedType>
71struct C10_EXPORT PassManager {
72 private:
73 // We want this class to be abstract because it's
74 virtual void abstract() = 0;
75
76 protected:
77 /*
78 * isRegistered() will return if a pass has been registered
79 * isRegistered(true) will change the value of the internal static bool
80 *
81 * There's an internal static bool to this function to keep track of the
82 * state, this is so when functions are derived from this class, they don't
83 * have to worry about initializing the static members.
84 */
85 static bool isRegistered(bool flip_bit = false) {
86 static bool val = false;
87 if (flip_bit)
88 val = !val;
89 return val;
90 }
91
92 /*
93 * name() will return the name of the registered pass
94 * name(pass_name, true) will set the name of the pass
95 * Similarly to isRegistered we use an internal static variable to hold the
96 * name.
97 */
98 static GraphPassNameType passID(
99 GraphPassNameType PassID = 0,
100 bool set = false) {
101 static GraphPassNameType pass_id = 0;
102 if (set)
103 pass_id = PassID;
104 return pass_id;
105 }
106
107 public:
108 // registerPass(pass) will register the pass provided and set the
109 // name/isRegistered functions appropriately, it returns a bool value
110 // indicating whether the given pass is already registered previously.
111 static bool registerPass(GraphPass p) {
112 if (!isRegistered()) {
113 // If we don't already have a registered pass, register pass
114 // hold on to its name, change isRegistered to true
115 passID(registerPostPass(std::move(p)), true);
116 isRegistered(true);
117 return false;
118 }
119 return true;
120 }
121
122 // Calls ClearPostPass(passID())
123 static void clearPass() {
124 // If the pass is registered, clear it and change isRegistered to false.
125 if (isRegistered()) {
126 clearPostPass(passID());
127 isRegistered(true);
128 }
129 }
130
131 // clang-tidy requires virtual destructor;
132 virtual ~PassManager() = default;
133};
134
135} // namespace jit
136} // namespace torch
137