1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_PASSMANAGER_PASSCONFIG_H
17#define GLOW_PASSMANAGER_PASSCONFIG_H
18
#include "glow/Optimizer/GraphOptimizer/CompilationContext.h"
#include "glow/Support/Support.h"

#include <bitset>
#include <set>
23
24namespace glow {
25
/// How the PassManager should iterate a pass over a Function.
enum class ConvergenceMode {
  /// Execute the pass exactly once over the Function.
  OnePass = 0,
  /// Keep re-running the pass over the Function until it reaches a fixed
  /// point.
  UntilFixedPoint = 1,
};
33
34/// The base class for all pass config classes.
35class PassConfigBase {
36protected:
37 /// Convergence mode to inform the PassManager how to run the FunctionPass.
38 ConvergenceMode convergenceMode_{ConvergenceMode::OnePass};
39 /// Which CompilationModes the Pass should be run in.
40 unsigned enabledCompModes_;
41 /// ID of the pass.
42 unsigned passID_;
43
44public:
45 /// Destructor.
46 virtual ~PassConfigBase() = default;
47 /// Constructor.
48 PassConfigBase(unsigned passID,
49 ConvergenceMode convergenceMode = ConvergenceMode::OnePass,
50 const std::set<CompilationMode> &enabledCompModes =
51 {CompilationMode::Infer, CompilationMode::Train})
52 : convergenceMode_(convergenceMode), enabledCompModes_(0),
53 passID_(passID) {
54 for (const auto &mode : enabledCompModes) {
55 enabledCompModes_ |= 1 << (convertEnumToUnsigned(mode));
56 }
57 }
58
59 /// Constructor.
60 PassConfigBase(unsigned passID, ConvergenceMode convergenceMode,
61 unsigned enabledCompModes)
62 : convergenceMode_(convergenceMode), enabledCompModes_(enabledCompModes),
63 passID_(passID) {
64 CHECK(
65 (~((1 << convertEnumToUnsigned(CompilationMode::NumCompilationModes)) -
66 1) &
67 enabledCompModes) == 0)
68 << "Unknown compilation modes: " << enabledCompModes;
69 }
70
71 /// \returns the ConvergenceMode of this config.
72 ConvergenceMode getConvergenceMode() const { return convergenceMode_; }
73
74 /// \returns whether \p mode is an enabled mode for this config.
75 bool isEnabledForCompilationMode(CompilationMode mode) const {
76 return enabledCompModes_ & (1 << (convertEnumToUnsigned(mode)));
77 }
78
79 /// \returns enabled compilation modes.
80 unsigned getEnabledCompilationModes() const { return enabledCompModes_; }
81
82 unsigned getID() const { return passID_; }
83
84 /// Dump a textual representation of this config to \p os.
85 virtual void dump(llvm::raw_ostream &os, llvm::StringRef passName) const;
86
87 /// \returns the name of the pass for this config.
88 virtual llvm::StringRef getNameOfPass() const = 0;
89
90 /// \return true if two configs are equal.
91 virtual bool equals(const PassConfigBase &other) const;
92};
93
94/// Specifies a configuration for running an Pass when used in a
95/// PassPipeline. Pass ids are represented by the type \p PASS_ID.
96template <typename PASS_ID> class PassConfig : public PassConfigBase {
97public:
98 using PassIDTy = PASS_ID;
99
100public:
101 // Constructor.
102 PassConfig(PassIDTy ID,
103 ConvergenceMode convergenceMode = ConvergenceMode::OnePass,
104 const std::set<CompilationMode> &enabledCompModes =
105 {CompilationMode::Infer, CompilationMode::Train})
106 : PassConfigBase(static_cast<unsigned>(ID), convergenceMode,
107 enabledCompModes) {}
108 // Constructor.
109 PassConfig(PassIDTy ID, ConvergenceMode convergenceMode,
110 unsigned enabledCompModes)
111 : PassConfigBase(static_cast<unsigned>(ID), convergenceMode,
112 enabledCompModes) {}
113 // Destructor.
114 ~PassConfig() = default;
115
116 /// \returns the passID of this config.
117 PassIDTy getPassID() const { return static_cast<PassIDTy>(passID_); }
118
119 virtual llvm::StringRef getNameOfPass() const override {
120 return "<unknown pass>";
121 }
122
123 virtual void dump(llvm::raw_ostream &os,
124 llvm::StringRef passName) const override {
125 PassConfigBase::dump(os, passName);
126 }
127
128 /// Dump a textual representation of this config to \p os.
129 virtual void dump(llvm::raw_ostream &os = llvm::outs()) const {
130 dump(os, getNameOfPass());
131 }
132
133 /// \return true if two configs are equal.
134 virtual bool equals(const PassConfigBase &other) const override {
135 return (*this).PassConfigBase::equals(other);
136 }
137};
138
139} // namespace glow
140
141#endif // GLOW_PASSMANAGER_PASSCONFIG_H
142