1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_TOOLS_LOADER_LOADER_H |
17 | #define GLOW_TOOLS_LOADER_LOADER_H |
18 | |
19 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
20 | #include "glow/Importer/ProtobufLoader.h" |
21 | #include "glow/Importer/TFLiteModelLoader.h" |
22 | #include "glow/Quantization/Quantization.h" |
23 | #include "glow/Runtime/HostManager/HostManager.h" |
24 | |
25 | #include "llvm/Support/CommandLine.h" |
26 | |
27 | #include <glog/logging.h> |
28 | |
29 | /// Options. |
30 | extern llvm::cl::OptionCategory loaderCat; |
31 | |
32 | /// Number of devices to use. |
33 | extern llvm::cl::opt<unsigned> numDevices; |
34 | |
/// Whether to run all inputs on all of the \ref numDevices devices. Used for
/// testing.
36 | extern llvm::cl::opt<bool> runAllInputsOnAllDevices; |
37 | |
/// Option indicating whether inferences should be timed (-time).
39 | extern llvm::cl::opt<bool> timeOpt; |
/// Option indicating the number of inference iterations to run
/// (-iterations).
42 | extern llvm::cl::opt<unsigned> iterationsOpt; |
43 | |
44 | namespace glow { |
45 | |
46 | class Tensor; |
47 | struct CompilationContext; |
48 | |
49 | /// \return true if emit bundle mode is enabled. |
50 | bool emittingBundle(); |
51 | |
52 | /// \return true if profiling the graph. |
53 | bool profilingGraph(); |
54 | |
55 | /// Parse/verify command line parameters. |
56 | void parseCommandLine(int argc, char **argv); |
57 | |
class Loader;
class ProtobufLoader;

/// Loader extension interface from which to derive in order to extend the
/// Loader driver object.
class LoaderExtension {
64 | public: |
65 | /// Called once after ONNX or Caffe2 model loading. |
66 | virtual void postModelLoad(Loader &, PlaceholderBindings &, ProtobufLoader &, |
67 | llvm::StringMap<Placeholder *> &, |
68 | llvm::ArrayRef<TypeRef> inputImageType) = 0; |
  /// Called once after TensorFlowLite model loading.
  virtual void postModelLoad(Loader &, PlaceholderBindings &,
                             TFLiteModelLoader &,
                             llvm::StringMap<Placeholder *> &,
                             llvm::ArrayRef<TypeRef> inputImageType) = 0;
73 | /// Called once at the beginning of the mini-batch inference. |
74 | virtual void inferInitMiniBatch(Loader &, PlaceholderBindings &, |
75 | size_t minibatchIndex, |
76 | size_t minibatchSize) = 0; |
77 | /// Called once after the completion of the mini-batch inference. |
78 | virtual void inferEndMiniBatch(Loader &, PlaceholderBindings &, |
79 | size_t minibatchIndex, |
80 | size_t minibatchSize) = 0; |
81 | virtual ~LoaderExtension() {} |
82 | }; |
83 | |
/// Driver class for loading, compiling, and running inference for ONNX,
/// Caffe2, and TensorFlowLite models.
86 | class Loader { |
87 | /// Caffe2 network file name. |
88 | std::string caffe2NetDescFilename_; |
89 | /// Caffe2 weights file name. |
90 | std::string caffe2NetWeightFilename_; |
91 | /// ONNX model file name. |
92 | std::string onnxModelFilename_; |
93 | /// TensorFlowLite model file name. |
94 | std::string tfliteModelFilename_; |
95 | /// Name of loaded function. |
96 | std::string functionName_; |
97 | /// Host Manager for running the model. |
98 | std::unique_ptr<glow::runtime::HostManager> hostManager_; |
99 | /// Backend used for saving bundle and quantization. |
100 | std::unique_ptr<glow::Backend> backend_; |
101 | /// Function containing the model. |
102 | Function *F_{nullptr}; |
  /// Module containing the function.
104 | std::unique_ptr<Module> M_; |
  /// A map used for quantization profiling of NodeValues that were lowered
  /// from each other. Each key is the output name of a NodeValue and maps to
  /// the set of names and NodeKinds of the NodeValues that it replaced during
  /// lowering.
109 | LoweredInfoMap loweredMap_; |
110 | /// List of Loader owned extension objects. |
111 | std::vector<std::unique_ptr<LoaderExtension>> loaderExtensionList_; |
112 | /// A map from the original names of the model inputs to placeholders. |
113 | llvm::StringMap<Placeholder *> inputPlaceholderByName_; |
114 | /// A map from the original names of the model outputs to placeholders. |
115 | llvm::StringMap<Placeholder *> outputPlaceholderByName_; |
116 | /// Info produced after calling the \ref compile function. |
117 | CompilationInfo compilationInfo_; |
118 | |
119 | public: |
  /// Getter for the HostManager; this can be useful for calling into the
  /// HostManager directly.
122 | runtime::HostManager *getHostManager() { return hostManager_.get(); } |
123 | |
  /// Getter for the Function. This should not be called after \ref compile,
  /// since the compile process is destructive on the original function.
126 | Function *getFunction() { return F_; } |
127 | |
128 | /// Getter for function name. |
129 | std::string getFunctionName() { return functionName_; } |
130 | |
  /// Getter for the Module. This should not be called after \ref compile,
  /// since the compile process is destructive on the original function and
  /// module.
133 | Module *getModule() { return F_->getParent(); } |
134 | |
135 | /// Getter for the Caffe2 network file name. |
136 | llvm::StringRef getCaffe2NetDescFilename() { return caffe2NetDescFilename_; } |
137 | |
138 | /// Getter for the Caffe2 weights file name. |
139 | llvm::StringRef getCaffe2NetWeightFilename() { |
140 | return caffe2NetWeightFilename_; |
141 | } |
142 | |
143 | /// Getter for the ONNX model file name. |
144 | llvm::StringRef getOnnxModelFilename() { return onnxModelFilename_; } |
145 | |
146 | /// Getter for the TensorFlowLite model file name. |
147 | llvm::StringRef getTFLiteModelFilename() { return tfliteModelFilename_; } |
148 | |
149 | /// Getter for the model path. |
150 | /// \pre (modelPathOpt.size() == 1) |
151 | static std::string getModelOptPath(); |
152 | |
153 | /// Getter for the model path, expected to be a directory. |
154 | /// \pre (modelPathOpt.size() == 1) |
155 | static llvm::StringRef getModelOptDir(); |
156 | |
157 | /// Get the quantization options based on the command line parameters of the |
158 | /// Loader. |
159 | static quantization::QuantizationConfiguration getQuantizationConfiguration(); |
160 | |
  /// Load a Caffe2, ONNX or TensorFlowLite model into this Loader object
  /// based on the Loader command line options. If \p inputType is given then
  /// the model input is forced to have the given input type regardless of
  /// the actual command line options (this requires the model to have only
  /// one input).
166 | void loadModel(PlaceholderBindings *bindings = nullptr, |
167 | llvm::ArrayRef<TypeRef> inputType = {}); |
168 | |
169 | /// \returns a map between the model input names and the input placeholders. |
170 | /// The placeholder map is available once \ref loadModel() is called. |
171 | const llvm::StringMap<Placeholder *> &getInputPlaceholderMap() const { |
172 | return inputPlaceholderByName_; |
173 | } |
174 | |
175 | /// \returns a map between the model output names and the output placeholders. |
176 | /// The placeholder map is available once \ref loadModel() is called. |
177 | const llvm::StringMap<Placeholder *> &getOutputPlaceholderMap() const { |
178 | return outputPlaceholderByName_; |
179 | } |
180 | |
181 | /// Get the compilation options (context) for a given quantization \p mode. |
182 | /// The options are initialized by the Loader command line arguments. |
183 | CompilationContext getCompilationContext(QuantizationMode mode); |
184 | |
185 | /// Get the default compilation options (context) initialized by the Loader |
186 | /// command line arguments. |
187 | CompilationContext getCompilationContext(); |
188 | |
189 | /// Compiles the Function F_. Handles quantization, emitting bundles, and |
  /// dumping debug information. \p bindings binds specific
191 | /// placeholders to concrete tensors. The concrete tensors include |
192 | /// quantization profile guided information. |
193 | void compile(PlaceholderBindings &bindings); |
194 | |
195 | /// Compiles the Function F_. Handles quantization, emitting bundles, and |
196 | /// dumping debug information. \p cctx is used for compiling F_. |
197 | void compile(CompilationContext &cctx); |
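
  /// Example: a minimal sketch of compiling with an explicit compilation
  /// context derived from the Loader command line options:
  /// \code
  ///   CompilationContext cctx = loader.getCompilationContext();
  ///   loader.compile(cctx);
  /// \endcode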
198 | |
199 | /// Runs inference, unless emit bundle mode is enabled. \p bindings |
200 | /// binds specific placeholders to concrete tensors. The concrete |
201 | /// tensors include quantization profile guided information. |
202 | void runInference(PlaceholderBindings &bindings, size_t batchSize = 1); |
203 | |
  /// Runs inference. \p context binds Tensors to Placeholders and may also
  /// hold a TraceContext. This method allows obtaining TraceEvents from the
  /// run.
207 | void runInference(ExecutionContext *context, size_t batchSize = 1); |
208 | |
209 | /// Register a loader extension. |
210 | Loader ®isterExtension(std::unique_ptr<LoaderExtension> ext); |
  /// Called once after ONNX or Caffe2 model loading.
  void postModelLoad(PlaceholderBindings &bindings, ProtobufLoader &protoLoader,
                     llvm::StringMap<Placeholder *> &,
                     llvm::ArrayRef<TypeRef> inputImageType);
  /// Called once after TensorFlowLite model loading.
  void postModelLoad(PlaceholderBindings &bindings,
                     TFLiteModelLoader &tfliteLoader,
                     llvm::StringMap<Placeholder *> &,
                     llvm::ArrayRef<TypeRef> inputImageType);
219 | /// Called at the beginning of each mini-batch inference. |
220 | void inferInitMiniBatch(PlaceholderBindings &bindings, size_t minibatchIndex, |
221 | size_t minibatchSize); |
222 | /// Called after the completion of each mini-batch inference. |
223 | void inferEndMiniBatch(PlaceholderBindings &, size_t minibatchIndex, |
224 | size_t minibatchSize); |
225 | |
226 | /// Generates and serializes the profiling infos after gathering a profile |
227 | /// by running inference one or more times. \p bindings |
228 | /// binds specific placeholders to concrete tensors. The concrete tensors |
229 | /// include quantization profile guided information. |
230 | void generateAndSerializeProfilingInfos(PlaceholderBindings &bindings); |
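
  /// Example: a minimal sketch of a profiling run (assumes the command line
  /// options selecting profiling mode have already been parsed):
  /// \code
  ///   PlaceholderBindings bindings;
  ///   loader.loadModel(&bindings);
  ///   loader.compile(bindings);
  ///   loader.runInference(bindings);
  ///   loader.generateAndSerializeProfilingInfos(bindings);
  /// \endcode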
231 | |
  /// Create the Loader driver object. If \p configDeviceIDs is empty then
  /// \ref numDevices DeviceConfigs are created, one per device; otherwise
  /// \p configDeviceIDs is used to create DeviceConfigs with the specified
  /// IDs.
235 | Loader(llvm::ArrayRef<size_t> configDeviceIDs = {}); |
236 | }; |
237 | |
238 | } // namespace glow |
239 | |
240 | #endif // GLOW_TOOLS_LOADER_LOADER_H |
241 | |