1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #ifndef GLOW_IMPORTER_ONNXIFIMODELLOADER_H |
18 | #define GLOW_IMPORTER_ONNXIFIMODELLOADER_H |
19 | |
20 | #include "foxi/onnxifi.h" |
21 | |
22 | #include "glow/Importer/ONNXModelLoader.h" |
23 | |
24 | #include "llvm/ADT/StringMap.h" |
25 | |
26 | namespace glow { |
27 | |
28 | class ONNXIFIModelLoader { |
29 | private: |
30 | /// Default constructor. |
31 | explicit ONNXIFIModelLoader() {} |
32 | |
33 | /// The real loader. It can be ONNXModelLoader or Caffe2ModelLoader |
34 | std::unique_ptr<ProtobufLoader> core_{nullptr}; |
35 | |
36 | public: |
37 | /// \returns mapping between ONNX names and actual Glow input vars. |
38 | const llvm::StringMap<Placeholder *> &getInputVarsMapping() const { |
39 | return core_->getInputVarsMapping(); |
40 | } |
41 | |
42 | /// \returns mapping between ONNX names and actual Glow output nodes. |
43 | const llvm::StringMap<Placeholder *> &getOutputVarsMapping() const { |
44 | return core_->getOutputVarsMapping(); |
45 | } |
46 | |
47 | /// \returns vector of primary input names based on their position |
48 | const std::vector<std::string> &getPositionalInputNames() const { |
49 | return core_->getPositionalInputNames(); |
50 | } |
51 | |
52 | /// \returns vector of primary output names based on their position |
53 | const std::vector<std::string> &getPositionalOutputNames() const { |
54 | return core_->getPositionalOutputNames(); |
55 | } |
56 | |
57 | /// \returns a unique_ptr<ONNXIFIModelLoader> if \p onnxModel can be |
58 | /// parsed and static weights can be loaded from the \p weightDescriptors. |
59 | /// \returns Error otherwise. \p loadInputsAsPlaceholdersForOnnx is passed to |
60 | /// loadInputs to determine whether graph inputs are loaded as Placeholders or |
61 | /// Tensors. Loading inputs as Tensors is useful for when weights are not |
62 | /// provided such as when the graph being loaded is actually a small patch of |
63 | /// a larger graph because the graph inputs in this case may represent |
64 | /// internal values for the larger graph. \p constFoldInLoader is used to |
65 | /// determine whether to try constant folding at load time. \p mod will be |
66 | /// filled wth one or more Functions built. If the model is pre-partitioned, |
67 | /// then prepartitionedConfig from \p cctx will be filled with relevant |
68 | /// configuration for partitioning, and all Functions created will be named |
69 | /// with prefix \p netName. Otherwise prepartitionedConfig from \p cctx is |
70 | /// ignored, and \p netName is used as the name of the single Function that is |
71 | /// created inside \p mod. backendOpts.backendSpecificNodeInfo from \p cctx |
72 | /// is filled with any info loaded from the proto, relevant for custom ONNX |
73 | /// Glow models only. \p staticPlaceholderTypes will be filled with types to |
74 | /// use for static Placeholders if the proto being parsed is a custom ONNX |
75 | /// Glow model and contains such information; otherwise it's left unchanged. |
76 | static Expected<std::unique_ptr<ONNXIFIModelLoader>> |
77 | parse(const void *onnxModel, uint32_t onnxModelSize, uint32_t weightsCount, |
78 | const onnxTensorDescriptorV1 *weightDescriptors, Module &mod, |
79 | llvm::StringRef netName, CompilationContext &cctx, |
80 | std::map<std::string, Type> *staticPlaceholderTypes, |
81 | bool loadInputsAsPlaceholdersForOnnx = true, bool use_onnx = true, |
82 | bool constFoldInLoader = true); |
83 | }; |
84 | |
85 | } // namespace glow |
86 | |
87 | #endif // GLOW_IMPORTER_ONNXIFIMODELLOADER_H |
88 | |