1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #ifndef GLOW_IMPORTER_CAFFE2MODELLOADER_H |
18 | #define GLOW_IMPORTER_CAFFE2MODELLOADER_H |
19 | |
#include "glow/Graph/Graph.h"
#include "glow/Importer/CommonOperatorLoader.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"

#include "caffe2/proto/caffe2.pb.h"

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_set>
29 | |
// Forward declarations of the Caffe2 protobuf message types used in this
// header. caffe2/proto/caffe2.pb.h (included above) also defines them, so these
// are redundant but harmless redeclarations.
namespace caffe2 {
class NetDef;
class OperatorDef;
class Argument;
} // namespace caffe2
35 | |
36 | namespace glow { |
37 | |
// Forward declarations of Glow types; no full definitions are required by the
// declarations in this header.
class Value;
class Tensor;
40 | |
41 | /// Loads caffe2 models. |
42 | class Caffe2ModelLoader |
43 | : public CommonOperatorLoader<caffe2::OperatorDef, caffe2::Argument> { |
44 | /// \returns True if the operator has broadcasting activated. |
45 | Expected<bool> getBroadcast(ArgumentDictionaryTy &dict) override; |
46 | |
47 | /// \returns True if the operator with the name \p typeName has support for |
48 | /// multidirectional broadcasting. |
49 | bool hasMultidirectionalBroadcast(const llvm::StringRef typeName) override; |
50 | |
51 | /// Load the weight tensors from the 'init' file and register them in the map |
52 | /// \p tensors. |
53 | Error loadWeightsFromNet(caffe2::NetDef &net); |
54 | |
55 | /// Loads an individual weight \p op. |
56 | Error loadWeight(const caffe2::OperatorDef &op); |
57 | |
58 | /// Load the structure of the network from the 'net' file. |
59 | Error loadNetwork(caffe2::NetDef &net); |
60 | |
61 | /// Load the operator \p op into the network. This creates one or more nodes |
62 | /// in the network. |
63 | Error loadOperator(const caffe2::OperatorDef &op); |
64 | |
65 | /// \returns True if the operator \p op is successfully folded. |
66 | Expected<bool> foldOperator(const caffe2::OperatorDef &op); |
67 | |
68 | /// Helper function to print better log information for operator failure cases |
69 | const std::string opErrMsg(const caffe2::OperatorDef &op, |
70 | const std::string &errMsg); |
71 | |
72 | /// Load the PRelu operator. |
73 | Error loadPRelu(const caffe2::OperatorDef &op, ArgumentDictionaryTy &dict); |
74 | |
75 | /// Load the Conv or ConvRelu operators. |
76 | Error loadConv(const caffe2::OperatorDef &op, ArgumentDictionaryTy &dict); |
77 | |
78 | /// Load the Softmax operator |
79 | Error loadSoftmax(const caffe2::OperatorDef &op, ArgumentDictionaryTy &dict); |
80 | |
81 | /// Load the ConvTranspose operator. |
82 | Error loadConvTranspose(const caffe2::OperatorDef &op, |
83 | ArgumentDictionaryTy &dict); |
84 | |
85 | /// Load the Int8Conv or Int8ConvRelu operators. |
86 | Error loadConvQuantized(const caffe2::OperatorDef &op, |
87 | ArgumentDictionaryTy &dict); |
88 | |
89 | /// Load LayerNorm Caffe2 operator \p op given \p dict. |
90 | Error loadLayerNorm(const caffe2::OperatorDef &op, |
91 | ArgumentDictionaryTy &dict); |
92 | |
93 | /// Reads a network (weights or structure) from the serialized protocol |
94 | /// buffer file. |
95 | Expected<caffe2::NetDef> loadProtoFile(const std::string &filename); |
96 | |
97 | /// loadInputs calls this function for each member in its target arguments. |
98 | /// Currently we are supporting two tensorprototypes: |
99 | /// caffe2::TensorProto, caffe2::QTensorProto |
100 | template <class TensorProtoType> |
101 | Error loadInputsWithTensorProtoType( |
102 | const caffe2::NetDef &net, |
103 | const std::unordered_set<std::string> &initializers, |
104 | const TensorProtoType &in); |
105 | |
106 | /// Creates tensor \p T from the input \p in. Note, there is no data |
107 | /// associated with the Tensor. This method makes sure that the tensor is |
108 | /// created with the proper shape and element type. |
109 | Expected<LoadWeightResult> |
110 | createAndSetTensorType(const caffe2::TensorProto &in); |
111 | |
112 | /// Creates quantized tensor \p T from the input \p in. Note, there is no data |
113 | /// associated with the Tensor. This method makes sure that the tensor is |
114 | /// created with the proper shape and element type. |
115 | Expected<LoadWeightResult> |
116 | createAndSetTensorType(const caffe2::QTensorProto &in); |
117 | |
118 | /// Load the inputs from the NetDef. \p initializers is the set of tensors |
119 | /// that should be loaded as empty Constants in the graph for the purposes of |
120 | /// onnxifi compatibility checks, any other inputs will be loaded as |
121 | /// placeholders. |
122 | Error loadInputs(const caffe2::NetDef &net, |
123 | const std::unordered_set<std::string> &initializers); |
124 | |
125 | /// \returns Expected<NetDef> if a NetDef can be constructed from the |
126 | /// in-memory serialized protobuf. |
127 | /// Loads ModelProto from the in-memory serialized protobuf \p |
128 | /// c2Model with the model size \p c2ModelSize. |
129 | static Expected<caffe2::NetDef> loadProto(const void *c2Model, |
130 | size_t c2ModelSize); |
131 | |
132 | /// Creates a Caffe2 model loader to build one or more Functions in \p mod. |
133 | /// Loads the ONNIXFI \p model from memory of \p modelSize size, |
134 | /// and \p weightsCount, and \p onnxTensorDescriptorV1 correspondent |
135 | /// descriptors. Reports success/failure through optional parameter \p errPtr. |
136 | /// This constructor always overrides the default constant folding in loader |
137 | /// flag with \p constFoldInLoader. If the model is pre-partitioned, then \p |
138 | /// PPC will be filled with relevant configuration for partitioning, and all |
139 | /// Functions created will be named with prefix /p funNamePrefix. Otherwise \p |
140 | /// PPC is ignored, and \p funNamePrefix is used as the name of the single |
141 | /// Function that is created inside \p mod. If \p originNameToTQPMap is |
142 | /// non-null then names of ops and inputs that are quantized will be mapped to |
143 | /// the TQP that it came with. If \p loadUniquedDummyQParams then the actual |
144 | /// quant params in the model will be discarded and unique dummies will be |
145 | /// used instead. |
146 | Caffe2ModelLoader(const void *model, uint32_t modelSize, |
147 | uint32_t weightsCount, |
148 | const onnxTensorDescriptorV1 *weightDescriptors, |
149 | Module &mod, llvm::StringRef funNamePrefix, |
150 | runtime::PrePartitionedConfig *PPC, Error *errPtr = nullptr, |
151 | bool constFoldInLoader = true, |
152 | OriginNameToTQPMap *originNameToTQPMap = nullptr, |
153 | bool loadUniquedDummyQParams = false, |
154 | bool zeroScaleFP16Clip = false, |
155 | bool clipQuantRangeToFP16 = false); |
156 | |
157 | friend class ONNXIFIModelLoader; |
158 | |
159 | /// Complete initialization when loading a module, including loading |
160 | /// pre-partitioned models, given \p networkDef loaded from caller, as well as |
161 | /// \p funNamePrefix, and \p PPC forwarded from caller. |
162 | Error initWithModule(caffe2::NetDef &networkDef, |
163 | llvm::StringRef funNamePrefix, |
164 | runtime::PrePartitionedConfig *PPC); |
165 | |
166 | /// \returns success if the folding of operator \p op in the loader |
167 | /// \p loader is successful. The folding utility uses temporary |
168 | /// loader \p tmpLoader, and associated temporary function \p F. |
169 | template <class LoaderType, class OpType> |
170 | friend Error constantFoldInLoader(Function *F, LoaderType &tmpLoader, |
171 | LoaderType *loader, const OpType &op); |
172 | |
173 | public: |
174 | /// Loads the caffe2 model that's represented by a network description file, |
175 | /// serialized in \p netDescFilename, and weights file, serialized in |
176 | /// \p netWeightFilename, and populates the network in \p F. |
177 | /// The list \p types and \p names are used to initialized the inputs and |
178 | /// outputs with specific names and types. |
179 | /// If \p errPtr is not null then if an error occurs it will get assigned |
180 | /// there otherwise if an error occurs it will abort. |
181 | /// If \p originNameToTQPMap is non-null then names of ops and inputs that are |
182 | /// quantized will be mapped to the TQP that it came with. |
183 | /// If \p loadUniquedDummyQParams then the actual quant params in the model |
184 | /// will be discarded and unique dummies will be used instead. |
185 | Caffe2ModelLoader(const std::string &netDescFilename, |
186 | const std::string &netWeightFilename, |
187 | llvm::ArrayRef<const char *> names, |
188 | llvm::ArrayRef<TypeRef> types, Function &F, |
189 | Error *errPtr = nullptr, |
190 | OriginNameToTQPMap *originNameToTQPMap = nullptr, |
191 | bool loadUniquedDummyQParams = false, |
192 | bool zeroScaleFP16Clip = false, |
193 | bool clipQuantRangeToFP16 = false); |
194 | |
195 | /// Loads the caffe2 model that's represented by a network description file, |
196 | /// serialized in \p netDescFilename, and weights file, serialized in |
197 | /// \p netWeightFilename, and populates the network in \p mod. |
198 | /// Any Functions created in \p mod will have name (or prefixed name for |
199 | /// pre-partitioned protos) \p funNamePrefix. \p PPC is used to store the |
200 | /// pre-partitioned config for the model if relevant. |
201 | /// The list \p types and \p names are used to initialized the inputs and |
202 | /// outputs with specific names and types. |
203 | /// If \p errPtr is not null then if an error occurs it will get assigned |
204 | /// there otherwise if an error occurs it will abort. |
205 | Caffe2ModelLoader(const std::string &netDescFilename, |
206 | const std::string &netWeightFilename, |
207 | llvm::ArrayRef<const char *> names, |
208 | llvm::ArrayRef<TypeRef> types, Module &mod, |
209 | llvm::StringRef funNamePrefix, |
210 | runtime::PrePartitionedConfig *PPC = nullptr, |
211 | Error *errPtr = nullptr); |
212 | |
213 | /// Creates a Caffe2 model loader to build \p F. |
214 | /// If \p errPtr is not null then if an error occurs it will get assigned |
215 | /// there otherwise if an error occurs it will abort. |
216 | Caffe2ModelLoader(Function &F, Error *errPtr); |
217 | |
218 | /// Creates a Caffe2 model loader that builds into what is intended to be a |
219 | /// dummy Module in \p dummyMod, in order to fill in \p originNameToTQPMap |
220 | /// with a map from C2 op names to TQPs that they were loaded with in model |
221 | /// \p modelStr given \p weightsCount and \p weightDescriptors. Returns any |
222 | /// errors into \p errPtr. |
223 | Caffe2ModelLoader(const std::string &modelStr, uint32_t weightsCount, |
224 | const onnxTensorDescriptorV1 *weightDescriptors, |
225 | Module &dummyMod, Error *errPtr, |
226 | OriginNameToTQPMap *originNameToTQPMap, |
227 | bool clipQuantRangeToFP16); |
228 | }; |
229 | |
230 | } // namespace glow |
231 | |
232 | #endif // GLOW_IMPORTER_CAFFE2MODELLOADER_H |
233 | |