/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GLOW_ONNXIFI_BASE_H
#define GLOW_ONNXIFI_BASE_H

#include "glow/Backend/Backend.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Importer/ONNXIFIModelLoader.h"
#include "glow/Runtime/RuntimeTypes.h"
#include "glow/Support/TensorPool.h"

#include "foxi/onnxifi.h"
#include "foxi/onnxifi_ext.h"

#include <atomic>
#include <condition_variable>
#include <mutex>

#include "llvm/ADT/StringMap.h"

namespace glow {
namespace onnxifi {

class Graph;

/// ONNXIFI backend wrapping a Glow backend.
class Backend {
public:
  /// Create a Glow ONNXIFI backend for the Glow backend named
  /// \p backendName. \p useOnnx selects whether models are loaded as ONNX
  /// (true) or Caffe2 (false).
  Backend(llvm::StringRef backendName, bool useOnnx)
      : useOnnx_(useOnnx), glowBackend_(createBackend(backendName)) {}

  virtual ~Backend() = default;

  /// Verify that a given ONNX graph is supported by the backend by importing
  /// the ONNX graph to a Glow function, lowering this function, and checking
  /// that all of the Glow nodes contained in the lowered graph are compatible
  /// with the Glow backend.
  onnxStatus checkGraphCompatibility(const void *onnxModel,
                                     size_t onnxModelSize);

  /// \returns whether this backend expects ONNX models (as opposed to
  /// Caffe2 models).
  bool getUseOnnx() const { return useOnnx_; }

  /// \returns a reference to the underlying Glow backend.
  const glow::Backend &getBackend() const { return *glowBackend_; }

  /// Run the network represented by \p graph asynchronously using
  /// \p context, invoking \p callback when the run completes. The base
  /// implementation is a no-op.
  virtual void runNetwork(const Graph *graph,
                          std::unique_ptr<ExecutionContext> context,
                          runtime::ResultCBTy callback,
                          uint64_t priority = 0) {}

  /// Remove the network represented by \p graph from the backend. The base
  /// implementation does nothing and reports success.
  virtual onnxStatus removeNetwork(const Graph *graph) {
    return ONNXIFI_STATUS_SUCCESS;
  }

protected:
  /// Whether to load models as ONNX (true) or Caffe2 (false).
  bool useOnnx_;
  /// The underlying Glow backend.
  std::unique_ptr<glow::Backend> glowBackend_;
};
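
// A minimal usage sketch (illustrative only; "Interpreter" is one of Glow's
// backend names, and the modelData/modelSize variables are hypothetical):
//
//   onnxifi::Backend backend("Interpreter", /*useOnnx=*/true);
//   // modelData/modelSize point at a serialized ONNX model.
//   onnxStatus status =
//       backend.checkGraphCompatibility(modelData, modelSize);
//   if (status != ONNXIFI_STATUS_SUCCESS) {
//     // The lowered graph contains nodes this backend cannot run.
//   }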

typedef Backend *BackendPtr;

/// A one-shot event that is signalled with an onnxStatus and can be waited
/// on from other threads.
class Event {
public:
  Event() : fired_{false} {}

  /// Signal the event with \p status.
  bool signal(onnxStatus status);

  /// Wait until the event is signalled. \returns the status the event was
  /// signalled with.
  onnxStatus wait();

  /// Wait until the event is signalled or until at least \p timeoutMs
  /// milliseconds have elapsed. \returns a pair whose first value is a
  /// boolean that is true if the event was signalled (no timeout occurred)
  /// and whose second value is the event's status.
  std::pair<bool, onnxStatus> waitFor(size_t timeoutMs);

  /// \returns true if the event has been signalled.
  bool isSignalled() { return fired_; }

  /// \returns the message associated with the event.
  const std::string &getMessage() const { return message_; }

  /// Set the message associated with the event to \p message.
  void setMessage(const std::string &message) { message_ = message; }

private:
  /// Whether the event has been signalled.
  std::atomic<bool> fired_;
  std::mutex mutex_;
  std::condition_variable cond_;
  /// An optional message associated with the event.
  std::string message_;
  /// Holds an onnxStatus that is passed from the signaller of the event to a
  /// waiter. Should only be accessed while holding mutex_.
  onnxStatus status_ = ONNXIFI_STATUS_SUCCESS;
};
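
// A minimal synchronization sketch (illustrative only): one thread signals
// completion while another waits with a timeout. The 500ms timeout value is
// arbitrary.
//
//   Event done;
//   // Worker thread, when the run finishes:
//   //   done.signal(ONNXIFI_STATUS_SUCCESS);
//   auto result = done.waitFor(/*timeoutMs=*/500);
//   if (!result.first) {
//     // Timed out before the event was signalled.
//   } else if (result.second != ONNXIFI_STATUS_SUCCESS) {
//     // The run failed; done.getMessage() may have details.
//   }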

typedef Event *EventPtr;

class Graph {
public:
  explicit Graph(BackendPtr backendPtr);
  virtual ~Graph() = default;

  /// \returns the backend this graph is associated with.
  BackendPtr backend() { return backendPtr_; }

  /// Set up the Glow graph in preparation for inference and run it. Input
  /// memory addresses are taken from \p inputDescriptors and output memory
  /// addresses from \p outputDescriptors. \p outputEvent is signalled
  /// asynchronously when the run is complete. \p traceEvents is a pointer to
  /// an onnxTraceEventList; if it is not null then it is expected to be
  /// populated with trace events from the run before \p outputEvent is
  /// signalled.
  onnxStatus setIOAndRun(uint32_t inputsCount,
                         const onnxTensorDescriptorV1 *inputDescriptors,
                         uint32_t outputsCount,
                         const onnxTensorDescriptorV1 *outputDescriptors,
                         EventPtr outputEvent, onnxTraceEventList *traceEvents);

  /// Initialize the Glow graph from the ONNX model \p onnxModel (of size
  /// \p onnxModelSize bytes) and the statically trained weights
  /// \p weightDescriptors. Weights can alternatively be read in later via
  /// \p deferedBlobReader.
  virtual onnxStatus initGraph(const void *onnxModel, size_t onnxModelSize,
                               uint32_t weightCount,
                               const onnxTensorDescriptorV1 *weightDescriptors,
                               uint32_t maxSeqLength, void *deferedBlobReader,
                               bool loadingGlowAOT) = 0;

  /// Run the graph asynchronously using \p ctx, signalling \p outputEvent
  /// when the run is complete and populating \p traceEvents if it is not
  /// null.
  virtual onnxStatus run(std::unique_ptr<ExecutionContext> ctx,
                         EventPtr outputEvent,
                         onnxTraceEventList *traceEvents) = 0;

  /// Copy any trace events from \p traceContext into \p traceEvents. If
  /// \p traceEvents is null then do nothing.
  static void setTraceEvents(onnxTraceEventList *traceEvents,
                             TraceContext *traceContext);

  /// Free all memory that was allocated by setTraceEvents when creating
  /// \p traceEvents.
  static void releaseTraceEvents(onnxTraceEventList *traceEvents);

protected:
  BackendPtr backendPtr_;

  /// Mapping from the ONNX name of an input variable to its Glow input
  /// placeholder.
  llvm::StringMap<Placeholder *> onnxInputToPlaceholder_;

  /// Mapping from the ONNX name of an output variable to its Glow output
  /// placeholder.
  llvm::StringMap<Placeholder *> onnxOutputToPlaceholder_;

  /// A list of input names ordered by their position in the ONNXIFI input
  /// descriptor array.
  std::vector<std::string> onnxInputNames_;

  /// A list of input placeholders ordered by their position in the ONNXIFI
  /// input descriptor array.
  std::vector<Placeholder *> onnxInputPlaceholders_;

  /// A list of output names ordered by their position in the ONNXIFI output
  /// descriptor array.
  std::vector<std::string> onnxOutputNames_;

  /// A list of output placeholders ordered by their position in the ONNXIFI
  /// output descriptor array.
  std::vector<Placeholder *> onnxOutputPlaceholders_;

  /// An object pool for tensors, to share allocations.
  TensorPool tensorPool_;

  /// An anchor tensor specialized for zero-length indices.
  Tensor zeroLengthSequence_;

  /// Set up the zero-length sequence tensor for up to \p maxSeqLength
  /// indices.
  void setZeroLengthSequence(dim_t maxSeqLength);

  /// Bind input/output placeholders from \p loader, optionally recording the
  /// loaded placeholder names in \p loadedPHNames.
  bool bindPlaceholders(const ONNXIFIModelLoader &loader,
                        LoadedPlaceholderNameMap *loadedPHNames = nullptr);

private:
  /// Counter used to number inference I/O dumps.
  std::atomic<size_t> ioDumpCounter_{0};

  /// Set up input mappings and adjust inputs if necessary.
  onnxStatus adjustInputs(uint32_t inputsCount,
                          const onnxTensorDescriptorV1 *inputDescriptors,
                          ExecutionContext *ctx);
};
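
// A minimal lifecycle sketch (illustrative only; Graph is abstract, so in
// practice a concrete subclass provided by the backend is used, and the
// model buffer and descriptor arrays shown here are hypothetical):
//
//   Graph *graph = /* concrete Graph subclass created by the backend */;
//   graph->initGraph(modelData, modelSize, weightCount, weightDescriptors,
//                    maxSeqLength, /*deferedBlobReader=*/nullptr,
//                    /*loadingGlowAOT=*/false);
//   Event outputEvent;
//   graph->setIOAndRun(inputsCount, inputDescriptors, outputsCount,
//                      outputDescriptors, &outputEvent,
//                      /*traceEvents=*/nullptr);
//   outputEvent.wait();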

typedef Graph *GraphPtr;

/// Save the Function \p F loaded via ONNXIFI to a file.
void saveOnnxifiModel(Function *F);

} // namespace onnxifi
} // namespace glow

#endif // GLOW_ONNXIFI_BASE_H