1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_ONNXIFI_BASE_H |
17 | #define GLOW_ONNXIFI_BASE_H |
18 | |
19 | #include "glow/Backend/Backend.h" |
20 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
21 | #include "glow/Importer/ONNXIFIModelLoader.h" |
22 | #include "glow/Runtime/RuntimeTypes.h" |
23 | #include "glow/Support/TensorPool.h" |
24 | |
25 | #include "foxi/onnxifi.h" |
26 | #include "foxi/onnxifi_ext.h" |
27 | |
28 | #include <atomic> |
29 | #include <condition_variable> |
30 | #include <mutex> |
31 | |
32 | #include "llvm/ADT/StringMap.h" |
33 | |
34 | namespace glow { |
35 | namespace onnxifi { |
36 | |
37 | class Graph; |
38 | |
39 | /// Backend associated with the Glow backend. |
40 | class Backend { |
41 | public: |
42 | /// Create Glow ONNXIFI backend identifier with the |
43 | /// given Glow backend \p backendName, whether to use onnx or caffe2 for |
44 | /// models |
45 | /// (\p useOnnx) |
46 | Backend(llvm::StringRef backendName, bool useOnnx) |
47 | : useOnnx_(useOnnx), glowBackend_(createBackend(backendName)) {} |
48 | |
49 | virtual ~Backend() = default; |
50 | |
51 | /// Verify that a given onnx graph is supported by the backend by importing |
52 | /// the onnx graph to a glow function, lowering this function, and checking |
53 | /// that all of the glow nodes that are contained in the lowered graph are |
54 | /// compatible with the glow backend. |
55 | onnxStatus checkGraphCompatibility(const void *onnxModel, |
56 | size_t onnxModelSize); |
57 | |
58 | /// \returns the whether use onnx or not. |
59 | bool getUseOnnx() const { return useOnnx_; } |
60 | |
61 | /// \returns a reference to the backend. |
62 | const glow::Backend &getBackend() const { return *glowBackend_; } |
63 | |
64 | virtual void runNetwork(const Graph *graph, |
65 | std::unique_ptr<ExecutionContext> context, |
66 | runtime::ResultCBTy callback, uint64_t priority = 0) { |
67 | } |
68 | |
69 | virtual onnxStatus removeNetwork(const Graph *graph) { |
70 | return ONNXIFI_STATUS_SUCCESS; |
71 | } |
72 | |
73 | protected: |
74 | bool useOnnx_; |
75 | std::unique_ptr<glow::Backend> glowBackend_; |
76 | }; |
77 | |
78 | typedef Backend *BackendPtr; |
79 | |
80 | class Event { |
81 | public: |
82 | Event() : fired_{false} {} |
83 | /// Signal the event. |
84 | bool signal(onnxStatus status); |
85 | |
86 | /// Wait until the event is signalled. |
87 | onnxStatus wait(); |
88 | |
89 | /// Wait until the event is signalled or until at least \p timeoutMs |
90 | /// milliseconds have elapsed. \returns a pair with the first value being a |
91 | /// boolean that is true if the event was signalled (no timeout occurred) and |
92 | /// the second is the value of the event's status. |
93 | std::pair<bool, onnxStatus> waitFor(size_t timeoutMs); |
94 | |
95 | /// Check if event was signalled. |
96 | bool isSignalled() { return fired_; } |
97 | |
98 | const std::string &getMessage() const { return message_; } |
99 | |
100 | void setMessage(const std::string &message) { message_ = message; } |
101 | |
102 | private: |
103 | std::atomic<bool> fired_; |
104 | std::mutex mutex_; |
105 | std::condition_variable cond_; |
106 | std::string message_; |
107 | /// Used to hold an onnxStatus that will be passed for the signaller of the |
108 | /// event to a waiter. Should only be accessed while holding mutex_. |
109 | onnxStatus status_ = ONNXIFI_STATUS_SUCCESS; |
110 | }; |
111 | |
112 | typedef Event *EventPtr; |
113 | |
/// A single loaded inference graph. Owns the mapping between ONNXIFI tensor
/// descriptors and Glow placeholders and drives I/O setup and execution of
/// the imported function on its Backend.
class Graph {
public:
  /// \p backendPtr is the (non-owning) backend this graph will run on.
  explicit Graph(BackendPtr backendPtr);
  virtual ~Graph() = default;

  /// \returns the (non-owning) backend this graph was created on.
  BackendPtr backend() { return backendPtr_; }

  /// Setup Glow graph in preparation for the inference and run.
  /// Set input memory addresses for inputs based on the \p inputDescriptors.
  /// Set output memory addresses for outputs based on the \p
  /// outputDescriptors. Will async signal the \p outputEvent when run is
  /// complete. \p traceEvents is a pointer to onnxTraceEventList, if it is not
  /// null then it is expected that this will be populated with trace events
  /// from the run before signalling the outputEvent.
  onnxStatus setIOAndRun(uint32_t inputsCount,
                         const onnxTensorDescriptorV1 *inputDescriptors,
                         uint32_t outputsCount,
                         const onnxTensorDescriptorV1 *outputDescriptors,
                         EventPtr outputEvent, onnxTraceEventList *traceEvents);

  /// Init Glow graph based on the ONNX model \p onnxModel (of
  /// \p onnxModelSize bytes) and static trained weights \p weightDescriptors
  /// (\p weightCount entries). Weights can be read in later
  /// by a \p deferedBlobReader [sic]. \p maxSeqLength bounds the
  /// zeroLengthSequence_ anchor tensor; \p loadingGlowAOT marks a model that
  /// was compiled ahead of time by Glow.
  virtual onnxStatus initGraph(const void *onnxModel, size_t onnxModelSize,
                               uint32_t weightCount,
                               const onnxTensorDescriptorV1 *weightDescriptors,
                               uint32_t maxSeqLength, void *deferedBlobReader,
                               bool loadingGlowAOT) = 0;

  /// Execute the graph with \p ctx, signalling \p outputEvent on completion
  /// and, if \p traceEvents is non-null, populating it with trace events from
  /// the run first. Implemented by backend-specific subclasses.
  virtual onnxStatus run(std::unique_ptr<ExecutionContext> ctx,
                         EventPtr outputEvent,
                         onnxTraceEventList *traceEvents) = 0;

  /// Copy any trace events \p traceContext into \p traceEvents. If
  /// \p traceEvents is null then do nothing.
  static void setTraceEvents(onnxTraceEventList *traceEvents,
                             TraceContext *traceContext);

  /// Free all memory that was allocated by setTraceEvents when creating \p
  /// traceEvents.
  static void releaseTraceEvents(onnxTraceEventList *traceEvents);

protected:
  /// Non-owning pointer to the backend this graph runs on.
  BackendPtr backendPtr_;

  /// Mapping between ONNX name for the input variable and Glow
  /// placeholder for input.
  llvm::StringMap<Placeholder *> onnxInputToPlaceholder_;

  /// Mapping between ONNX name for the output variable and Glow
  /// placeholder for output.
  llvm::StringMap<Placeholder *> onnxOutputToPlaceholder_;

  /// A list of input names ordered by their position in ONNXIFI input
  /// descriptor array.
  std::vector<std::string> onnxInputNames_;

  /// A list of input placeholders ordered by their position in ONNXIFI input
  /// descriptor array.
  std::vector<Placeholder *> onnxInputPlaceholders_;

  /// A list of output names ordered by their position in ONNXIFI output
  /// descriptor array.
  std::vector<std::string> onnxOutputNames_;

  /// A list of output placeholders ordered by their position in ONNXIFI output
  /// descriptor array.
  std::vector<Placeholder *> onnxOutputPlaceholders_;

  /// An object pool for tensors, to share allocations across runs.
  TensorPool tensorPool_;

  /// An anchor tensor specialized for zero length indices.
  Tensor zeroLengthSequence_;

  /// Set the zero length tensor, sized by \p maxSeqLength.
  void setZeroLengthSequence(dim_t maxSeqLength);

  /// Bind input/output placeholders from \p loader into the maps and ordered
  /// lists above; \p loadedPHNames, if non-null, receives the loaded
  /// placeholder names. \returns whether binding succeeded.
  bool bindPlaceholders(const ONNXIFIModelLoader &loader,
                        LoadedPlaceholderNameMap *loadedPHNames = nullptr);

private:
  /// Monotonic counter used to number inference I/O dumps; atomic because
  /// runs may be issued concurrently.
  std::atomic<size_t> ioDumpCounter_{0};

  /// Setup input mapping and adjust inputs if necessary, writing the bound
  /// inputs into \p ctx.
  onnxStatus adjustInputs(uint32_t inputsCount,
                          const onnxTensorDescriptorV1 *inputDescriptors,
                          ExecutionContext *ctx);
};
205 | |
206 | typedef Graph *GraphPtr; |
207 | |
/// Save the Function \p F loaded via ONNXIFI to a file (defined out-of-line;
/// destination path is determined by the implementation).
void saveOnnxifiModel(Function *F);
210 | |
211 | } // namespace onnxifi |
212 | } // namespace glow |
213 | |
214 | #endif // GLOW_ONNXIFI_BASE_H |
215 | |