#include <torch/csrc/jit/passes/pass_manager.h>

#include <algorithm>
2
3namespace torch {
4namespace jit {
5
6// Start UUID at 1
7static GraphPassNameType graphPassID = 1;
8
9std::vector<GraphPassEntry>& getCustomPostPasses() {
10 static std::vector<GraphPassEntry> passes;
11 return passes;
12}
13
14std::vector<GraphPassEntry>& getCustomPrePasses() {
15 static std::vector<GraphPassEntry> passes;
16 return passes;
17}
18
19GraphPassNameType registerPostPass(GraphPass p) {
20 getCustomPostPasses().emplace_back(std::move(p), graphPassID);
21 return graphPassID++;
22}
23
24GraphPassNameType registerPass(GraphPass p) {
25 return registerPostPass(std::move(p));
26}
27
28GraphPassNameType registerPrePass(GraphPass p) {
29 getCustomPrePasses().emplace_back(std::move(p), graphPassID);
30 return graphPassID++;
31}
32
33void clearPostPass(GraphPassNameType pid) {
34 auto& passes = getCustomPostPasses();
35 auto it = passes.begin();
36 for (; it != passes.end(); it++) {
37 if (pid == (*it).second)
38 break;
39 }
40 if (it != passes.end())
41 passes.erase(it);
42}
43
44void clearPrePass(GraphPassNameType pid) {
45 auto& passes = getCustomPrePasses();
46 auto it = passes.begin();
47 for (; it != passes.end(); it++) {
48 if (pid == (*it).second)
49 break;
50 }
51 if (it != passes.end())
52 passes.erase(it);
53}
54
55void clearAllPostPasses() {
56 auto& passes = getCustomPostPasses();
57 passes.erase(passes.begin(), passes.end());
58}
59
60void clearAllPrePasses() {
61 auto& passes = getCustomPrePasses();
62 passes.erase(passes.begin(), passes.end());
63}
64
65// LEGACY CALL
66RegisterPostPass::RegisterPostPass(GraphPass p) {
67 registerPass(std::move(p));
68}
69
70} // namespace jit
71} // namespace torch
72