1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_ONNXIFI_HOSTMANAGERONNXIFI_H |
17 | #define GLOW_ONNXIFI_HOSTMANAGERONNXIFI_H |
18 | |
#include "Base.h"

#include "glow/Runtime/HostManager/HostManager.h"

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
22 | |
23 | namespace glow { |
24 | namespace onnxifi { |
25 | |
26 | class HostManagerBackend : public Backend { |
27 | public: |
28 | /// Create Glow ONNXIFI backend identifier using HostManager with the |
29 | /// given Glow backend \p kindName, whether to use onnx or caffe2 for models |
30 | /// (\p useOnnx). |
31 | HostManagerBackend(std::shared_ptr<runtime::HostManager> hostManager, |
32 | llvm::StringRef backendName, bool useOnnx) |
33 | : Backend(backendName, useOnnx), hostManager_(hostManager) {} |
34 | |
35 | void runNetwork(const Graph *graph, std::unique_ptr<ExecutionContext> context, |
36 | runtime::ResultCBTy callback, uint64_t priority = 0) override; |
37 | |
38 | onnxStatus addNetwork(std::unique_ptr<Module> module, |
39 | void *deferredBlobReader, CompilationContext &cctx, |
40 | std::map<std::string, Type> &&staticPlaceholderTypes); |
41 | |
42 | onnxStatus removeNetwork(const Graph *graph) override; |
43 | |
44 | // \returns a unique_ptr to a new HostManager for the given Backend \p |
45 | // backendName. |
46 | static std::unique_ptr<runtime::HostManager> |
47 | createHostManager(llvm::StringRef backendName); |
48 | |
49 | private: |
50 | std::shared_ptr<runtime::HostManager> hostManager_; |
51 | }; |
52 | |
53 | class HostManagerGraph : public Graph { |
54 | public: |
55 | using Graph::Graph; |
56 | |
57 | ~HostManagerGraph() override; |
58 | |
59 | /// \returns a globally unique graph id. |
60 | static size_t makeUniqueGraphId(); |
61 | |
62 | /// Init Glow graph based on the ONNX model \p onnxModel and |
63 | /// static trained weights \p weightDescriptors. Weights can be read in later |
64 | /// by a \p deferedBlobReader. \p loadingGlowAOT specifies if the model has |
65 | /// already been AOT optimized via Glow. |
66 | onnxStatus initGraph(const void *onnxModel, size_t onnxModelSize, |
67 | uint32_t weightCount, |
68 | const onnxTensorDescriptorV1 *weightDescriptors, |
69 | uint32_t maxSeqLengths, void *deferedBlobReader, |
70 | bool loadingGlowAOT) override; |
71 | |
72 | /// Async run HostManagerGraph with the given ExecutionContext \p ctx then |
73 | /// signal \p outputEvent when done. \p phNameToOnnxTensorOutputs is a mapping |
74 | /// that is generated by the base class Graph and should be used to map |
75 | /// copy output placeholder tensors back to the given onnxifi tensors. |
76 | onnxStatus run(std::unique_ptr<ExecutionContext> ctx, EventPtr outputEvent, |
77 | onnxTraceEventList *traceEvents) override; |
78 | |
79 | /// \returns the unique string name of the HostManagerGraph that the |
80 | /// underlying HostManagerGraph uses to identify this network. |
81 | const std::string &getName() const { return netName_; } |
82 | |
83 | private: |
84 | std::string netName_; |
85 | std::mutex tracesMutex_; |
86 | std::unique_ptr<TraceContext> mergedTraceContext_; |
87 | int numTracesToDump_{0}; |
88 | }; |
89 | |
90 | } // namespace onnxifi |
91 | } // namespace glow |
92 | |
93 | #endif // GLOW_ONNXIFI_HOSTMANAGERONNXIFI_H |
94 | |