1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef GLOW_IMPORTER_CAFFE2MODELLOADER_H
18#define GLOW_IMPORTER_CAFFE2MODELLOADER_H
19
#include "glow/Graph/Graph.h"
#include "glow/Importer/CommonOperatorLoader.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"

#include "caffe2/proto/caffe2.pb.h"

#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_set>
29
// Forward declarations of the caffe2 protobuf message types used in the
// loader's interface. NOTE(review): caffe2.pb.h is included by this header as
// well, so these declarations are redundant but harmless.
namespace caffe2 {
class NetDef;
class OperatorDef;
class Argument;
} // namespace caffe2
35
36namespace glow {
37
// Forward declarations. NOTE(review): neither type is referenced in this
// header; they are kept only so that existing includers keep compiling.
class Value;
class Tensor;
40
41/// Loads caffe2 models.
42class Caffe2ModelLoader
43 : public CommonOperatorLoader<caffe2::OperatorDef, caffe2::Argument> {
44 /// \returns True if the operator has broadcasting activated.
45 Expected<bool> getBroadcast(ArgumentDictionaryTy &dict) override;
46
47 /// \returns True if the operator with the name \p typeName has support for
48 /// multidirectional broadcasting.
49 bool hasMultidirectionalBroadcast(const llvm::StringRef typeName) override;
50
51 /// Load the weight tensors from the 'init' file and register them in the map
52 /// \p tensors.
53 Error loadWeightsFromNet(caffe2::NetDef &net);
54
55 /// Loads an individual weight \p op.
56 Error loadWeight(const caffe2::OperatorDef &op);
57
58 /// Load the structure of the network from the 'net' file.
59 Error loadNetwork(caffe2::NetDef &net);
60
61 /// Load the operator \p op into the network. This creates one or more nodes
62 /// in the network.
63 Error loadOperator(const caffe2::OperatorDef &op);
64
65 /// \returns True if the operator \p op is successfully folded.
66 Expected<bool> foldOperator(const caffe2::OperatorDef &op);
67
68 /// Helper function to print better log information for operator failure cases
69 const std::string opErrMsg(const caffe2::OperatorDef &op,
70 const std::string &errMsg);
71
72 /// Load the PRelu operator.
73 Error loadPRelu(const caffe2::OperatorDef &op, ArgumentDictionaryTy &dict);
74
75 /// Load the Conv or ConvRelu operators.
76 Error loadConv(const caffe2::OperatorDef &op, ArgumentDictionaryTy &dict);
77
78 /// Load the Softmax operator
79 Error loadSoftmax(const caffe2::OperatorDef &op, ArgumentDictionaryTy &dict);
80
81 /// Load the ConvTranspose operator.
82 Error loadConvTranspose(const caffe2::OperatorDef &op,
83 ArgumentDictionaryTy &dict);
84
85 /// Load the Int8Conv or Int8ConvRelu operators.
86 Error loadConvQuantized(const caffe2::OperatorDef &op,
87 ArgumentDictionaryTy &dict);
88
89 /// Load LayerNorm Caffe2 operator \p op given \p dict.
90 Error loadLayerNorm(const caffe2::OperatorDef &op,
91 ArgumentDictionaryTy &dict);
92
93 /// Reads a network (weights or structure) from the serialized protocol
94 /// buffer file.
95 Expected<caffe2::NetDef> loadProtoFile(const std::string &filename);
96
97 /// loadInputs calls this function for each member in its target arguments.
98 /// Currently we are supporting two tensorprototypes:
99 /// caffe2::TensorProto, caffe2::QTensorProto
100 template <class TensorProtoType>
101 Error loadInputsWithTensorProtoType(
102 const caffe2::NetDef &net,
103 const std::unordered_set<std::string> &initializers,
104 const TensorProtoType &in);
105
106 /// Creates tensor \p T from the input \p in. Note, there is no data
107 /// associated with the Tensor. This method makes sure that the tensor is
108 /// created with the proper shape and element type.
109 Expected<LoadWeightResult>
110 createAndSetTensorType(const caffe2::TensorProto &in);
111
112 /// Creates quantized tensor \p T from the input \p in. Note, there is no data
113 /// associated with the Tensor. This method makes sure that the tensor is
114 /// created with the proper shape and element type.
115 Expected<LoadWeightResult>
116 createAndSetTensorType(const caffe2::QTensorProto &in);
117
118 /// Load the inputs from the NetDef. \p initializers is the set of tensors
119 /// that should be loaded as empty Constants in the graph for the purposes of
120 /// onnxifi compatibility checks, any other inputs will be loaded as
121 /// placeholders.
122 Error loadInputs(const caffe2::NetDef &net,
123 const std::unordered_set<std::string> &initializers);
124
125 /// \returns Expected<NetDef> if a NetDef can be constructed from the
126 /// in-memory serialized protobuf.
127 /// Loads ModelProto from the in-memory serialized protobuf \p
128 /// c2Model with the model size \p c2ModelSize.
129 static Expected<caffe2::NetDef> loadProto(const void *c2Model,
130 size_t c2ModelSize);
131
132 /// Creates a Caffe2 model loader to build one or more Functions in \p mod.
133 /// Loads the ONNIXFI \p model from memory of \p modelSize size,
134 /// and \p weightsCount, and \p onnxTensorDescriptorV1 correspondent
135 /// descriptors. Reports success/failure through optional parameter \p errPtr.
136 /// This constructor always overrides the default constant folding in loader
137 /// flag with \p constFoldInLoader. If the model is pre-partitioned, then \p
138 /// PPC will be filled with relevant configuration for partitioning, and all
139 /// Functions created will be named with prefix /p funNamePrefix. Otherwise \p
140 /// PPC is ignored, and \p funNamePrefix is used as the name of the single
141 /// Function that is created inside \p mod. If \p originNameToTQPMap is
142 /// non-null then names of ops and inputs that are quantized will be mapped to
143 /// the TQP that it came with. If \p loadUniquedDummyQParams then the actual
144 /// quant params in the model will be discarded and unique dummies will be
145 /// used instead.
146 Caffe2ModelLoader(const void *model, uint32_t modelSize,
147 uint32_t weightsCount,
148 const onnxTensorDescriptorV1 *weightDescriptors,
149 Module &mod, llvm::StringRef funNamePrefix,
150 runtime::PrePartitionedConfig *PPC, Error *errPtr = nullptr,
151 bool constFoldInLoader = true,
152 OriginNameToTQPMap *originNameToTQPMap = nullptr,
153 bool loadUniquedDummyQParams = false,
154 bool zeroScaleFP16Clip = false,
155 bool clipQuantRangeToFP16 = false);
156
157 friend class ONNXIFIModelLoader;
158
159 /// Complete initialization when loading a module, including loading
160 /// pre-partitioned models, given \p networkDef loaded from caller, as well as
161 /// \p funNamePrefix, and \p PPC forwarded from caller.
162 Error initWithModule(caffe2::NetDef &networkDef,
163 llvm::StringRef funNamePrefix,
164 runtime::PrePartitionedConfig *PPC);
165
166 /// \returns success if the folding of operator \p op in the loader
167 /// \p loader is successful. The folding utility uses temporary
168 /// loader \p tmpLoader, and associated temporary function \p F.
169 template <class LoaderType, class OpType>
170 friend Error constantFoldInLoader(Function *F, LoaderType &tmpLoader,
171 LoaderType *loader, const OpType &op);
172
173public:
174 /// Loads the caffe2 model that's represented by a network description file,
175 /// serialized in \p netDescFilename, and weights file, serialized in
176 /// \p netWeightFilename, and populates the network in \p F.
177 /// The list \p types and \p names are used to initialized the inputs and
178 /// outputs with specific names and types.
179 /// If \p errPtr is not null then if an error occurs it will get assigned
180 /// there otherwise if an error occurs it will abort.
181 /// If \p originNameToTQPMap is non-null then names of ops and inputs that are
182 /// quantized will be mapped to the TQP that it came with.
183 /// If \p loadUniquedDummyQParams then the actual quant params in the model
184 /// will be discarded and unique dummies will be used instead.
185 Caffe2ModelLoader(const std::string &netDescFilename,
186 const std::string &netWeightFilename,
187 llvm::ArrayRef<const char *> names,
188 llvm::ArrayRef<TypeRef> types, Function &F,
189 Error *errPtr = nullptr,
190 OriginNameToTQPMap *originNameToTQPMap = nullptr,
191 bool loadUniquedDummyQParams = false,
192 bool zeroScaleFP16Clip = false,
193 bool clipQuantRangeToFP16 = false);
194
195 /// Loads the caffe2 model that's represented by a network description file,
196 /// serialized in \p netDescFilename, and weights file, serialized in
197 /// \p netWeightFilename, and populates the network in \p mod.
198 /// Any Functions created in \p mod will have name (or prefixed name for
199 /// pre-partitioned protos) \p funNamePrefix. \p PPC is used to store the
200 /// pre-partitioned config for the model if relevant.
201 /// The list \p types and \p names are used to initialized the inputs and
202 /// outputs with specific names and types.
203 /// If \p errPtr is not null then if an error occurs it will get assigned
204 /// there otherwise if an error occurs it will abort.
205 Caffe2ModelLoader(const std::string &netDescFilename,
206 const std::string &netWeightFilename,
207 llvm::ArrayRef<const char *> names,
208 llvm::ArrayRef<TypeRef> types, Module &mod,
209 llvm::StringRef funNamePrefix,
210 runtime::PrePartitionedConfig *PPC = nullptr,
211 Error *errPtr = nullptr);
212
213 /// Creates a Caffe2 model loader to build \p F.
214 /// If \p errPtr is not null then if an error occurs it will get assigned
215 /// there otherwise if an error occurs it will abort.
216 Caffe2ModelLoader(Function &F, Error *errPtr);
217
218 /// Creates a Caffe2 model loader that builds into what is intended to be a
219 /// dummy Module in \p dummyMod, in order to fill in \p originNameToTQPMap
220 /// with a map from C2 op names to TQPs that they were loaded with in model
221 /// \p modelStr given \p weightsCount and \p weightDescriptors. Returns any
222 /// errors into \p errPtr.
223 Caffe2ModelLoader(const std::string &modelStr, uint32_t weightsCount,
224 const onnxTensorDescriptorV1 *weightDescriptors,
225 Module &dummyMod, Error *errPtr,
226 OriginNameToTQPMap *originNameToTQPMap,
227 bool clipQuantRangeToFP16);
228};
229
230} // namespace glow
231
232#endif // GLOW_IMPORTER_CAFFE2MODELLOADER_H
233