#include <torch/csrc/jit/passes/pass_manager.h>

#include <algorithm>

3 | namespace torch { |
4 | namespace jit { |
5 | |
6 | // Start UUID at 1 |
7 | static GraphPassNameType graphPassID = 1; |
8 | |
9 | std::vector<GraphPassEntry>& getCustomPostPasses() { |
10 | static std::vector<GraphPassEntry> passes; |
11 | return passes; |
12 | } |
13 | |
14 | std::vector<GraphPassEntry>& getCustomPrePasses() { |
15 | static std::vector<GraphPassEntry> passes; |
16 | return passes; |
17 | } |
18 | |
19 | GraphPassNameType registerPostPass(GraphPass p) { |
20 | getCustomPostPasses().emplace_back(std::move(p), graphPassID); |
21 | return graphPassID++; |
22 | } |
23 | |
24 | GraphPassNameType registerPass(GraphPass p) { |
25 | return registerPostPass(std::move(p)); |
26 | } |
27 | |
28 | GraphPassNameType registerPrePass(GraphPass p) { |
29 | getCustomPrePasses().emplace_back(std::move(p), graphPassID); |
30 | return graphPassID++; |
31 | } |
32 | |
33 | void clearPostPass(GraphPassNameType pid) { |
34 | auto& passes = getCustomPostPasses(); |
35 | auto it = passes.begin(); |
36 | for (; it != passes.end(); it++) { |
37 | if (pid == (*it).second) |
38 | break; |
39 | } |
40 | if (it != passes.end()) |
41 | passes.erase(it); |
42 | } |
43 | |
44 | void clearPrePass(GraphPassNameType pid) { |
45 | auto& passes = getCustomPrePasses(); |
46 | auto it = passes.begin(); |
47 | for (; it != passes.end(); it++) { |
48 | if (pid == (*it).second) |
49 | break; |
50 | } |
51 | if (it != passes.end()) |
52 | passes.erase(it); |
53 | } |
54 | |
55 | void clearAllPostPasses() { |
56 | auto& passes = getCustomPostPasses(); |
57 | passes.erase(passes.begin(), passes.end()); |
58 | } |
59 | |
60 | void clearAllPrePasses() { |
61 | auto& passes = getCustomPrePasses(); |
62 | passes.erase(passes.begin(), passes.end()); |
63 | } |
64 | |
65 | // LEGACY CALL |
66 | RegisterPostPass::RegisterPostPass(GraphPass p) { |
67 | registerPass(std::move(p)); |
68 | } |
69 | |
70 | } // namespace jit |
71 | } // namespace torch |
72 |