/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GLOW_TOOLS_LOADER_LOADER_H
#define GLOW_TOOLS_LOADER_LOADER_H

#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Importer/ProtobufLoader.h"
#include "glow/Importer/TFLiteModelLoader.h"
#include "glow/Quantization/Quantization.h"
#include "glow/Runtime/HostManager/HostManager.h"

#include "llvm/Support/CommandLine.h"

#include <glog/logging.h>

/// Options.
extern llvm::cl::OptionCategory loaderCat;

/// Number of devices to use.
extern llvm::cl::opt<unsigned> numDevices;

/// Whether to run all inputs on all numDevices. Used for testing.
extern llvm::cl::opt<bool> runAllInputsOnAllDevices;

/// Timer option (-time) used to indicate whether inferences should be timed.
extern llvm::cl::opt<bool> timeOpt;
/// Option (-iterations) used to indicate the number of inference iterations
/// to run.
extern llvm::cl::opt<unsigned> iterationsOpt;

namespace glow {

class Tensor;
struct CompilationContext;

/// \return true if emit bundle mode is enabled.
bool emittingBundle();

/// \return true if profiling the graph.
bool profilingGraph();

/// Parse/verify command line parameters.
void parseCommandLine(int argc, char **argv);
class Loader;
class ProtobufLoader;

/// Loader extension interface from which to derive in order to extend the
/// Loader driver object.
class LoaderExtension {
public:
  /// Called once after ONNX or Caffe2 model loading.
  virtual void postModelLoad(Loader &, PlaceholderBindings &, ProtobufLoader &,
                             llvm::StringMap<Placeholder *> &,
                             llvm::ArrayRef<TypeRef> inputImageType) = 0;
  /// Called once after TensorFlowLite model loading.
  virtual void postModelLoad(Loader &, PlaceholderBindings &,
                             TFLiteModelLoader &,
                             llvm::StringMap<Placeholder *> &,
                             llvm::ArrayRef<TypeRef> inputImageType) = 0;
  /// Called once at the beginning of the mini-batch inference.
  virtual void inferInitMiniBatch(Loader &, PlaceholderBindings &,
                                  size_t minibatchIndex,
                                  size_t minibatchSize) = 0;
  /// Called once after the completion of the mini-batch inference.
  virtual void inferEndMiniBatch(Loader &, PlaceholderBindings &,
                                 size_t minibatchIndex,
                                 size_t minibatchSize) = 0;
  virtual ~LoaderExtension() {}
};
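
// Example (illustrative sketch, not part of the Loader API): a minimal
// extension that only logs progress via glog. The class name and messages
// are hypothetical; the overridden signatures match LoaderExtension above.
//
//   class LoggingExtension : public LoaderExtension {
//   public:
//     void postModelLoad(Loader &, PlaceholderBindings &, ProtobufLoader &,
//                        llvm::StringMap<Placeholder *> &,
//                        llvm::ArrayRef<TypeRef>) override {
//       LOG(INFO) << "Caffe2/ONNX model loaded.";
//     }
//     void postModelLoad(Loader &, PlaceholderBindings &, TFLiteModelLoader &,
//                        llvm::StringMap<Placeholder *> &,
//                        llvm::ArrayRef<TypeRef>) override {
//       LOG(INFO) << "TFLite model loaded.";
//     }
//     void inferInitMiniBatch(Loader &, PlaceholderBindings &,
//                             size_t minibatchIndex, size_t) override {
//       LOG(INFO) << "Starting mini-batch " << minibatchIndex;
//     }
//     void inferEndMiniBatch(Loader &, PlaceholderBindings &,
//                            size_t minibatchIndex, size_t) override {
//       LOG(INFO) << "Finished mini-batch " << minibatchIndex;
//     }
//   };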

/// Driver class for loading, compiling, and running inference for Caffe2,
/// ONNX, and TensorFlowLite models.
class Loader {
  /// Caffe2 network file name.
  std::string caffe2NetDescFilename_;
  /// Caffe2 weights file name.
  std::string caffe2NetWeightFilename_;
  /// ONNX model file name.
  std::string onnxModelFilename_;
  /// TensorFlowLite model file name.
  std::string tfliteModelFilename_;
  /// Name of loaded function.
  std::string functionName_;
  /// Host Manager for running the model.
  std::unique_ptr<glow::runtime::HostManager> hostManager_;
  /// Backend used for saving bundle and quantization.
  std::unique_ptr<glow::Backend> backend_;
  /// Function containing the model.
  Function *F_{nullptr};
  /// Module that holds the function.
  std::unique_ptr<Module> M_;
  /// A map used for quantization profiling of NodeValues that were lowered
  /// from each other. Each key is the output name of a NodeValue; it maps to
  /// the set of names and NodeKinds of the NodeValues that were replaced by
  /// that NodeValue during lowering.
  LoweredInfoMap loweredMap_;
  /// List of Loader-owned extension objects.
  std::vector<std::unique_ptr<LoaderExtension>> loaderExtensionList_;
  /// A map from the original names of the model inputs to placeholders.
  llvm::StringMap<Placeholder *> inputPlaceholderByName_;
  /// A map from the original names of the model outputs to placeholders.
  llvm::StringMap<Placeholder *> outputPlaceholderByName_;
  /// Info produced after calling the \ref compile function.
  CompilationInfo compilationInfo_;

public:
  /// Getter for the HostManager; this can be useful for calling into the
  /// HostManager directly.
  runtime::HostManager *getHostManager() { return hostManager_.get(); }

  /// Getter for the Function. This should not be called after compile since
  /// the compile process is destructive on the original function.
  Function *getFunction() { return F_; }

  /// Getter for function name.
  std::string getFunctionName() { return functionName_; }

  /// Getter for the Module. This should not be called after compile since the
  /// compile process is destructive on the original function and module.
  Module *getModule() { return F_->getParent(); }

  /// Getter for the Caffe2 network file name.
  llvm::StringRef getCaffe2NetDescFilename() { return caffe2NetDescFilename_; }

  /// Getter for the Caffe2 weights file name.
  llvm::StringRef getCaffe2NetWeightFilename() {
    return caffe2NetWeightFilename_;
  }

  /// Getter for the ONNX model file name.
  llvm::StringRef getOnnxModelFilename() { return onnxModelFilename_; }

  /// Getter for the TensorFlowLite model file name.
  llvm::StringRef getTFLiteModelFilename() { return tfliteModelFilename_; }

  /// Getter for the model path.
  /// \pre (modelPathOpt.size() == 1)
  static std::string getModelOptPath();

  /// Getter for the model path, expected to be a directory.
  /// \pre (modelPathOpt.size() == 1)
  static llvm::StringRef getModelOptDir();

  /// Get the quantization options based on the command line parameters of the
  /// Loader.
  static quantization::QuantizationConfiguration getQuantizationConfiguration();

  /// Load a Caffe2, ONNX or TensorFlowLite model into this Loader object based
  /// on the Loader command line options. If \p inputType is optionally given
  /// then the model input is forced to have the given input type regardless of
  /// the actual command line options (this requires the model to have only
  /// one input).
  void loadModel(PlaceholderBindings *bindings = nullptr,
                 llvm::ArrayRef<TypeRef> inputType = {});

  /// \returns a map between the model input names and the input placeholders.
  /// The placeholder map is available once \ref loadModel() is called.
  const llvm::StringMap<Placeholder *> &getInputPlaceholderMap() const {
    return inputPlaceholderByName_;
  }

  /// \returns a map between the model output names and the output
  /// placeholders. The placeholder map is available once \ref loadModel() is
  /// called.
  const llvm::StringMap<Placeholder *> &getOutputPlaceholderMap() const {
    return outputPlaceholderByName_;
  }
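
  // Example (illustrative sketch): looking up an input placeholder by its
  // original model name after loadModel(). The input name "data" below is
  // hypothetical and depends on the loaded model.
  //
  //   PlaceholderBindings bindings;
  //   loader.loadModel(&bindings);
  //   Placeholder *inputPH = loader.getInputPlaceholderMap().lookup("data");
  //   // The tensor bound to inputPH in bindings can then be filled with
  //   // input data before running inference.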

  /// Get the compilation options (context) for a given quantization \p mode.
  /// The options are initialized by the Loader command line arguments.
  CompilationContext getCompilationContext(QuantizationMode mode);

  /// Get the default compilation options (context) initialized by the Loader
  /// command line arguments.
  CompilationContext getCompilationContext();

  /// Compiles the Function F_. Handles quantization, emitting bundles, and
  /// dumping debug information. \p bindings binds specific placeholders to
  /// concrete tensors. The concrete tensors include quantization profile
  /// guided information.
  void compile(PlaceholderBindings &bindings);

  /// Compiles the Function F_. Handles quantization, emitting bundles, and
  /// dumping debug information. \p cctx is used for compiling F_.
  void compile(CompilationContext &cctx);

  /// Runs inference, unless emit bundle mode is enabled. \p bindings
  /// binds specific placeholders to concrete tensors. The concrete
  /// tensors include quantization profile guided information.
  void runInference(PlaceholderBindings &bindings, size_t batchSize = 1);

  /// Runs inference. \p context binds both Tensors to Placeholders and
  /// potentially holds a TraceContext. This method allows obtaining
  /// TraceEvents from the run.
  void runInference(ExecutionContext *context, size_t batchSize = 1);

  /// Register a loader extension.
  Loader &registerExtension(std::unique_ptr<LoaderExtension> ext);
  /// Called once after model loading.
  void postModelLoad(PlaceholderBindings &bindings, ProtobufLoader &protoLoader,
                     llvm::StringMap<Placeholder *> &,
                     llvm::ArrayRef<TypeRef> inputImageType);
  void postModelLoad(PlaceholderBindings &bindings,
                     TFLiteModelLoader &protoLoader,
                     llvm::StringMap<Placeholder *> &,
                     llvm::ArrayRef<TypeRef> inputImageType);
  /// Called at the beginning of each mini-batch inference.
  void inferInitMiniBatch(PlaceholderBindings &bindings, size_t minibatchIndex,
                          size_t minibatchSize);
  /// Called after the completion of each mini-batch inference.
  void inferEndMiniBatch(PlaceholderBindings &, size_t minibatchIndex,
                         size_t minibatchSize);

  /// Generates and serializes the profiling infos after gathering a profile
  /// by running inference one or more times. \p bindings
  /// binds specific placeholders to concrete tensors. The concrete tensors
  /// include quantization profile guided information.
  void generateAndSerializeProfilingInfos(PlaceholderBindings &bindings);
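
  // Example (illustrative sketch): a typical profiling flow, assuming the
  // QuantizationMode::Profile compilation mode is requested. The variable
  // names are hypothetical.
  //
  //   CompilationContext cctx =
  //       loader.getCompilationContext(QuantizationMode::Profile);
  //   loader.compile(cctx);
  //   loader.runInference(bindings);
  //   loader.generateAndSerializeProfilingInfos(bindings);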

  /// Create the Loader driver object. If \p configDeviceIDs is empty then
  /// \ref numDevices DeviceConfigs are created, one per device; otherwise
  /// DeviceConfigs are created with the IDs given in \p configDeviceIDs.
  Loader(llvm::ArrayRef<size_t> configDeviceIDs = {});
};
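
// Example (illustrative sketch): a minimal command-line tool driving the
// Loader end to end. Model and input details are hypothetical and depend on
// the command line options passed to the tool.
//
//   int main(int argc, char **argv) {
//     glow::parseCommandLine(argc, argv);
//     glow::Loader loader;
//     glow::PlaceholderBindings bindings;
//     loader.loadModel(&bindings);
//     loader.compile(bindings);
//     if (!glow::emittingBundle()) {
//       loader.runInference(bindings);
//     }
//     return 0;
//   }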

} // namespace glow

#endif // GLOW_TOOLS_LOADER_LOADER_H