1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_ONNXIFI_HOSTMANAGERONNXIFI_H
17#define GLOW_ONNXIFI_HOSTMANAGERONNXIFI_H
18
#include "Base.h"

#include "glow/Runtime/HostManager/HostManager.h"

#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
22
23namespace glow {
24namespace onnxifi {
25
26class HostManagerBackend : public Backend {
27public:
28 /// Create Glow ONNXIFI backend identifier using HostManager with the
29 /// given Glow backend \p kindName, whether to use onnx or caffe2 for models
30 /// (\p useOnnx).
31 HostManagerBackend(std::shared_ptr<runtime::HostManager> hostManager,
32 llvm::StringRef backendName, bool useOnnx)
33 : Backend(backendName, useOnnx), hostManager_(hostManager) {}
34
35 void runNetwork(const Graph *graph, std::unique_ptr<ExecutionContext> context,
36 runtime::ResultCBTy callback, uint64_t priority = 0) override;
37
38 onnxStatus addNetwork(std::unique_ptr<Module> module,
39 void *deferredBlobReader, CompilationContext &cctx,
40 std::map<std::string, Type> &&staticPlaceholderTypes);
41
42 onnxStatus removeNetwork(const Graph *graph) override;
43
44 // \returns a unique_ptr to a new HostManager for the given Backend \p
45 // backendName.
46 static std::unique_ptr<runtime::HostManager>
47 createHostManager(llvm::StringRef backendName);
48
49private:
50 std::shared_ptr<runtime::HostManager> hostManager_;
51};
52
53class HostManagerGraph : public Graph {
54public:
55 using Graph::Graph;
56
57 ~HostManagerGraph() override;
58
59 /// \returns a globally unique graph id.
60 static size_t makeUniqueGraphId();
61
62 /// Init Glow graph based on the ONNX model \p onnxModel and
63 /// static trained weights \p weightDescriptors. Weights can be read in later
64 /// by a \p deferedBlobReader. \p loadingGlowAOT specifies if the model has
65 /// already been AOT optimized via Glow.
66 onnxStatus initGraph(const void *onnxModel, size_t onnxModelSize,
67 uint32_t weightCount,
68 const onnxTensorDescriptorV1 *weightDescriptors,
69 uint32_t maxSeqLengths, void *deferedBlobReader,
70 bool loadingGlowAOT) override;
71
72 /// Async run HostManagerGraph with the given ExecutionContext \p ctx then
73 /// signal \p outputEvent when done. \p phNameToOnnxTensorOutputs is a mapping
74 /// that is generated by the base class Graph and should be used to map
75 /// copy output placeholder tensors back to the given onnxifi tensors.
76 onnxStatus run(std::unique_ptr<ExecutionContext> ctx, EventPtr outputEvent,
77 onnxTraceEventList *traceEvents) override;
78
79 /// \returns the unique string name of the HostManagerGraph that the
80 /// underlying HostManagerGraph uses to identify this network.
81 const std::string &getName() const { return netName_; }
82
83private:
84 std::string netName_;
85 std::mutex tracesMutex_;
86 std::unique_ptr<TraceContext> mergedTraceContext_;
87 int numTracesToDump_{0};
88};
89
90} // namespace onnxifi
91} // namespace glow
92
93#endif // GLOW_ONNXIFI_HOSTMANAGERONNXIFI_H
94