1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_PASSMANAGER_PASSCONFIG_H |
17 | #define GLOW_PASSMANAGER_PASSCONFIG_H |
18 | |
19 | #include "glow/Optimizer/GraphOptimizer/CompilationContext.h" |
20 | #include "glow/Support/Support.h" |
21 | |
22 | #include <bitset> |
23 | |
24 | namespace glow { |
25 | |
/// Tells the PassManager how many times a pass should be applied to a
/// Function.
enum class ConvergenceMode {
  OnePass,        ///< Make a single run of the pass over the Function.
  UntilFixedPoint ///< Keep running the pass until a fixed point is reached.
};
33 | |
34 | /// The base class for all pass config classes. |
35 | class PassConfigBase { |
36 | protected: |
37 | /// Convergence mode to inform the PassManager how to run the FunctionPass. |
38 | ConvergenceMode convergenceMode_{ConvergenceMode::OnePass}; |
39 | /// Which CompilationModes the Pass should be run in. |
40 | unsigned enabledCompModes_; |
41 | /// ID of the pass. |
42 | unsigned passID_; |
43 | |
44 | public: |
45 | /// Destructor. |
46 | virtual ~PassConfigBase() = default; |
47 | /// Constructor. |
48 | PassConfigBase(unsigned passID, |
49 | ConvergenceMode convergenceMode = ConvergenceMode::OnePass, |
50 | const std::set<CompilationMode> &enabledCompModes = |
51 | {CompilationMode::Infer, CompilationMode::Train}) |
52 | : convergenceMode_(convergenceMode), enabledCompModes_(0), |
53 | passID_(passID) { |
54 | for (const auto &mode : enabledCompModes) { |
55 | enabledCompModes_ |= 1 << (convertEnumToUnsigned(mode)); |
56 | } |
57 | } |
58 | |
59 | /// Constructor. |
60 | PassConfigBase(unsigned passID, ConvergenceMode convergenceMode, |
61 | unsigned enabledCompModes) |
62 | : convergenceMode_(convergenceMode), enabledCompModes_(enabledCompModes), |
63 | passID_(passID) { |
64 | CHECK( |
65 | (~((1 << convertEnumToUnsigned(CompilationMode::NumCompilationModes)) - |
66 | 1) & |
67 | enabledCompModes) == 0) |
68 | << "Unknown compilation modes: " << enabledCompModes; |
69 | } |
70 | |
71 | /// \returns the ConvergenceMode of this config. |
72 | ConvergenceMode getConvergenceMode() const { return convergenceMode_; } |
73 | |
74 | /// \returns whether \p mode is an enabled mode for this config. |
75 | bool isEnabledForCompilationMode(CompilationMode mode) const { |
76 | return enabledCompModes_ & (1 << (convertEnumToUnsigned(mode))); |
77 | } |
78 | |
79 | /// \returns enabled compilation modes. |
80 | unsigned getEnabledCompilationModes() const { return enabledCompModes_; } |
81 | |
82 | unsigned getID() const { return passID_; } |
83 | |
84 | /// Dump a textual representation of this config to \p os. |
85 | virtual void dump(llvm::raw_ostream &os, llvm::StringRef passName) const; |
86 | |
87 | /// \returns the name of the pass for this config. |
88 | virtual llvm::StringRef getNameOfPass() const = 0; |
89 | |
90 | /// \return true if two configs are equal. |
91 | virtual bool equals(const PassConfigBase &other) const; |
92 | }; |
93 | |
94 | /// Specifies a configuration for running an Pass when used in a |
95 | /// PassPipeline. Pass ids are represented by the type \p PASS_ID. |
96 | template <typename PASS_ID> class PassConfig : public PassConfigBase { |
97 | public: |
98 | using PassIDTy = PASS_ID; |
99 | |
100 | public: |
101 | // Constructor. |
102 | PassConfig(PassIDTy ID, |
103 | ConvergenceMode convergenceMode = ConvergenceMode::OnePass, |
104 | const std::set<CompilationMode> &enabledCompModes = |
105 | {CompilationMode::Infer, CompilationMode::Train}) |
106 | : PassConfigBase(static_cast<unsigned>(ID), convergenceMode, |
107 | enabledCompModes) {} |
108 | // Constructor. |
109 | PassConfig(PassIDTy ID, ConvergenceMode convergenceMode, |
110 | unsigned enabledCompModes) |
111 | : PassConfigBase(static_cast<unsigned>(ID), convergenceMode, |
112 | enabledCompModes) {} |
113 | // Destructor. |
114 | ~PassConfig() = default; |
115 | |
116 | /// \returns the passID of this config. |
117 | PassIDTy getPassID() const { return static_cast<PassIDTy>(passID_); } |
118 | |
119 | virtual llvm::StringRef getNameOfPass() const override { |
120 | return "<unknown pass>" ; |
121 | } |
122 | |
123 | virtual void dump(llvm::raw_ostream &os, |
124 | llvm::StringRef passName) const override { |
125 | PassConfigBase::dump(os, passName); |
126 | } |
127 | |
128 | /// Dump a textual representation of this config to \p os. |
129 | virtual void dump(llvm::raw_ostream &os = llvm::outs()) const { |
130 | dump(os, getNameOfPass()); |
131 | } |
132 | |
133 | /// \return true if two configs are equal. |
134 | virtual bool equals(const PassConfigBase &other) const override { |
135 | return (*this).PassConfigBase::equals(other); |
136 | } |
137 | }; |
138 | |
139 | } // namespace glow |
140 | |
141 | #endif // GLOW_PASSMANAGER_PASSCONFIG_H |
142 | |