1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/Runtime/Executor/NetworkExecutionState.h" |
17 | #include "glow/Backends/DeviceManager.h" |
18 | |
19 | using namespace glow; |
20 | using namespace glow::runtime; |
21 | |
22 | namespace { |
23 | static void updateTensor(Tensor &tensor, const Tensor &seed) { |
24 | if (auto *tensorPool = tensor.getOwningPool()) { |
25 | tensorPool->reclaim(std::move(tensor)); |
26 | } |
27 | tensor = seed.getUnowned(); |
28 | } |
29 | } // namespace |
30 | |
31 | void NetworkExecutionStatePool::addNewState( |
32 | std::unique_ptr<NetworkExecutionState> state) { |
33 | |
34 | std::lock_guard<std::mutex> lock(stateLock_); |
35 | availableStates_.push_back(state.get()); |
36 | states_.push_back(std::move(state)); |
37 | } |
38 | |
39 | NetworkExecutionState::NetworkExecutionState(const DAGNode *root, |
40 | bool enableDRT, bool enableP2P) |
41 | : enableDRT_(enableDRT), enableP2P_(enableP2P), inflightNodes_(0), |
42 | module_(root->module), root_(root) {} |
43 | |
44 | NetworkExecutionState::~NetworkExecutionState() { |
45 | // Free all allocated buffers. |
46 | for (auto &allocation : deviceAllocations_) { |
47 | allocation.second->freeAllocatedDeviceIOBuffer(allocation.first); |
48 | } |
49 | } |
50 | |
51 | void NetworkExecutionState::bind(std::unique_ptr<ExecutionContext> resultCtx, |
52 | ResultCBTy cb, RunIdentifierTy runId) { |
53 | resultCtx_ = std::move(resultCtx); |
54 | cb_ = std::move(cb); |
55 | runId_ = runId; |
56 | // Reset execution state, inflight nodes, parents done, etc. |
57 | for (auto &count : nodeParentsDone_) { |
58 | count.second = 0; |
59 | } |
60 | inflightNodes_ = 0; |
61 | // Setup tracing if desired. |
62 | auto resultTraceContext = resultCtx_->getTraceContext(); |
63 | if (resultTraceContext) { |
64 | for (auto &context : intermediateContexts_) { |
65 | context.second->setTraceContext( |
66 | glow::make_unique<TraceContext>(resultTraceContext->getTraceLevel())); |
67 | } |
68 | } else { |
69 | // Clear any trace context from a previous run. |
70 | for (auto &context : intermediateContexts_) { |
71 | context.second->setTraceContext(nullptr); |
72 | } |
73 | } |
74 | // Move inputs into tensors backing intermediate contexts. |
75 | // Instead we point the tensors to the provided buffers to avoid copy in and |
76 | // out. Once we have pinned allocations we will need to transfer. |
77 | // For now point input and output tensors to buffers used in resultCtx. |
78 | const auto &externalIOBindings = resultCtx_->getExternalIOBindings(); |
79 | if (!externalIOBindings.empty()) { |
80 | // Fast path |
81 | if (ioIdxMapping_.empty()) { |
82 | for (const auto &pair : externalIOBindings) { |
83 | const auto it = externalPlaceholdersIdx_.find(pair.first); |
84 | if (it == externalPlaceholdersIdx_.end()) { |
85 | LOG(WARNING) << "Cannot match external placeholder: " |
86 | << pair.first->getDebugDesc(); |
87 | ioIdxMapping_.emplace_back(-1); |
88 | } else { |
89 | ioIdxMapping_.emplace_back(it->second); |
90 | } |
91 | } |
92 | } |
93 | DCHECK(ioIdxMapping_.size() == externalIOBindings.size()); |
94 | for (unsigned i = 0, e = externalIOBindings.size(); i < e; ++i) { |
95 | if (ioIdxMapping_[i] < 0) { |
96 | // Unable to match external placeholder |
97 | continue; |
98 | } |
99 | const auto &pair = externalIOBindings[i]; |
100 | const auto &resultTensor = pair.second; |
101 | for (auto &bindingIt : externalPlaceholders_[ioIdxMapping_[i]]) { |
102 | updateTensor(bindingIt->second, resultTensor); |
103 | } |
104 | } |
105 | } else { |
106 | // Slow path for backward compatibility, we will do extra hash lookup |
107 | auto *resultPHBindings = resultCtx_->getPlaceholderBindings(); |
108 | for (auto &pair : resultPHBindings->pairs()) { |
109 | auto *PH = pair.first; |
110 | const auto &resultTensor = pair.second; |
111 | const auto it = externalPlaceholdersIdx_.find(PH); |
112 | if (it == externalPlaceholdersIdx_.end()) { |
113 | continue; |
114 | } |
115 | for (auto &bindingIt : externalPlaceholders_[it->second]) { |
116 | updateTensor(bindingIt->second, resultTensor); |
117 | } |
118 | } |
119 | } |
120 | } |
121 | |
122 | void NetworkExecutionState::init( |
123 | const DeviceManagerMapTy &devices, |
124 | std::unordered_map<DAGNode *, DeviceIDTy> &staticAssignment) { |
125 | // Create a queue for the breadth-first traversal through the graph. |
126 | std::queue<DAGNode *> bfsQueue; |
127 | // Marking the default err as checked so we don't get an unchecked error in |
128 | // destructor if we never use this state. |
129 | errContainer_.containsErr(); |
130 | |
131 | // Place the root nodes in the queue. |
132 | for (auto &node : root_->children) { |
133 | bfsQueue.push(node); |
134 | // Make a counter for the number of node parents done. This also is used for |
135 | // tracking if we've added the node already. |
136 | nodeParentsDone_[node] = 0; |
137 | } |
138 | |
139 | // Breadth-first search. |
140 | while (!bfsQueue.empty()) { |
141 | // Get the next node in the BFS queue. |
142 | DAGNode *node = bfsQueue.front(); |
143 | bfsQueue.pop(); |
144 | |
145 | // Push all unvisited children onto the BFS queue. |
146 | for (const auto &child : node->children) { |
147 | // Use nodeParentsDone_ as a set of nodes that have been visited already |
148 | // to avoid visiting a node more than once. |
149 | if (!nodeParentsDone_.count(child)) { |
150 | nodeParentsDone_[child] = 0; |
151 | bfsQueue.push(child); |
152 | } |
153 | } |
154 | |
155 | // Make an (empty) context for the node. |
156 | auto intermediateContext = glow::make_unique<ExecutionContext>(); |
157 | auto it = staticAssignment.find(node); |
158 | // If an assignment is provided for this context set it here. |
159 | if (it != staticAssignment.end()) { |
160 | auto dm = devices.find(it->second)->second.get(); |
161 | intermediateContext->setBoundDeviceManager(dm); |
162 | } |
163 | // Get a device to do allocation we can use the first device since the |
164 | // allocation is not device specific. |
165 | auto &device = devices.begin()->second; |
166 | |
167 | auto intermediatePHBindings = intermediateContext->getPlaceholderBindings(); |
168 | |
169 | // Get the symbol table for the node. |
170 | const SymbolTableTy &symbolTable = node->runtimeBundle->getSymbolTable(); |
171 | |
172 | // Add inputs/outputs to the context. Skip any marked as static. |
173 | for (const auto &symbolPair : symbolTable) { |
174 | const auto &symbolName = symbolPair.first; |
175 | const auto &symbolInfo = symbolPair.second; |
176 | if (symbolInfo.symbolCategory == SymbolCategory::Placeholder) { |
177 | auto PH = module_->getPlaceholderByNameSlow(symbolName); |
178 | |
179 | DCHECK(PH) << "Placeholder: " << symbolName << " is not in the module" ; |
180 | // If PH is marked static skip it. |
181 | if (PH->isStatic()) { |
182 | continue; |
183 | } |
184 | // If we haven't allocated a buffer for this PH yet do so, otherwise |
185 | // reuse the allocation. |
186 | // TODO: for intermediate placeholders in DRT/P2P cases, we don't need |
187 | // to allocate a backing tensor on host. |
188 | auto bufferIt = buffers_.find(PH); |
189 | if (bufferIt == buffers_.end()) { |
190 | auto *deviceBuffer = |
191 | device->allocateDeviceIOBuffer(PH->getType()->getSizeInBytes()); |
192 | buffers_[PH] = deviceBuffer; |
193 | deviceAllocations_.insert({deviceBuffer, device.get()}); |
194 | } |
195 | auto buffer = buffers_[PH]; |
196 | Tensor backingTensor(buffer, PH->getType()); |
197 | auto itt = intermediatePHBindings->insert(PH, std::move(backingTensor)); |
198 | // TODO: Only add to externalPlaceholders_ of PH is external placeholder |
199 | auto idxIt = externalPlaceholdersIdx_.find(PH); |
200 | if (idxIt == externalPlaceholdersIdx_.end()) { |
201 | externalPlaceholdersIdx_.emplace(PH, externalPlaceholders_.size()); |
202 | externalPlaceholders_.emplace_back(); |
203 | auto &vec = externalPlaceholders_.back(); |
204 | vec.push_back(itt); |
205 | } else { |
206 | auto &vec = externalPlaceholders_[idxIt->second]; |
207 | vec.push_back(itt); |
208 | } |
209 | } |
210 | } |
211 | |
212 | // Insert the prepared ExecutionContext into the input contexts map. |
213 | intermediateContexts_.emplace(node, std::move(intermediateContext)); |
214 | } |
215 | // If we used a static assignment call backend->bindContexts() on the new |
216 | // contexts. |
217 | if (staticAssignment.size()) { |
218 | std::vector<runtime::ContextBinding> contexts; |
219 | for (auto &intermediate : intermediateContexts_) { |
220 | runtime::ContextBinding intermediateBinding; |
221 | intermediateBinding.context = intermediate.second.get(); |
222 | intermediateBinding.networkName = intermediate.first->name; |
223 | intermediateBinding.device = intermediate.second->getBoundDeviceManager(); |
224 | contexts.push_back(intermediateBinding); |
225 | } |
226 | const auto &backendName = devices.begin()->second->getBackendName(); |
227 | // Create a backend to call bindContexts on, since bindContexts only puts |
228 | // state in the DeviceManager and Context we can safely discard this backend |
229 | // once we are done with it. |
230 | std::unique_ptr<Backend> newBackend(createBackend(backendName)); |
231 | |
232 | EXIT_ON_ERR( |
233 | newBackend->bindContexts(contexts, root_, enableP2P_, enableDRT_)); |
234 | } |
235 | initialized_ = true; |
236 | } |
237 | |
238 | std::unique_ptr<ExecutionContext> |
239 | NetworkExecutionState::getUniqueNodeContextPtr(const DAGNode *node) { |
240 | // The input PlaceholderBindings for the node should have been created in |
241 | // the constructor. |
242 | auto ctxIt = intermediateContexts_.find(node); |
243 | |
244 | DCHECK(ctxIt != intermediateContexts_.end()) |
245 | << "Input bindings not found but should exist!" ; |
246 | |
247 | return std::move(ctxIt->second); |
248 | } |
249 | |
250 | void NetworkExecutionState::returnUniqueNodeContextPtr( |
251 | const DAGNode *node, std::unique_ptr<ExecutionContext> ctx) { |
252 | intermediateContexts_[node] = std::move(ctx); |
253 | } |
254 | |
255 | void NetworkExecutionState::incrementInflightNodes(unsigned increment) { |
256 | inflightNodes_ += increment; |
257 | } |
258 | |
259 | bool NetworkExecutionState::decrementInflightNodes(unsigned decrement) { |
260 | // fetch_sub must be used here so that the function returns true to only one |
261 | // caller. |
262 | unsigned previousValue = inflightNodes_.fetch_sub(decrement); |
263 | |
264 | // The decrement should never be more than the value of the counter at the |
265 | // time of decrement. |
266 | DCHECK_GE(previousValue, decrement) |
267 | << "More decrements than increments to inflight nodes!" ; |
268 | |
269 | // Return true when the counter hits zero. |
270 | return (previousValue == decrement); |
271 | } |
272 | |
273 | bool NetworkExecutionState::incrementNodeParentsDone(const DAGNode *node, |
274 | unsigned increment) { |
275 | // Get the parents done counter for the node. It should have |
276 | // been created in the constructor. |
277 | auto it = nodeParentsDone_.find(node); |
278 | |
279 | DCHECK(it != nodeParentsDone_.end()) |
280 | << "Node parents done counter should exist but not found!" ; |
281 | |
282 | // fetch_add must be used here so that the function returns true to only |
283 | // one caller. |
284 | unsigned numParents = (node->parents).size(); |
285 | unsigned previousValue = (it->second).fetch_add(increment); |
286 | unsigned newValue = previousValue + increment; |
287 | |
288 | // The new value of the counter cannot exceed the number of parents that |
289 | // the node has. |
290 | DCHECK_LE(newValue, numParents) |
291 | << "Node parents done counter incremented beyond limit!" ; |
292 | |
293 | // Return true only when the counter hits the total numer of parents. |
294 | return (newValue == numParents); |
295 | } |
296 | |
297 | void NetworkExecutionState::insertIntoTraceContext(TraceContext *runCtx) { |
298 | if (!resultCtx_->getTraceContext()) { |
299 | return; |
300 | } |
301 | |
302 | resultCtx_->getTraceContext()->merge(runCtx); |
303 | } |
304 | |
305 | std::unique_ptr<ExecutionContext> |
306 | NetworkExecutionState::getUniqueResultContextPtr() { |
307 | // The result PlaceholderBindings should have been created in the |
308 | // constructor. |
309 | DCHECK_NOTNULL(resultCtx_.get()); |
310 | return std::move(resultCtx_); |
311 | } |
312 | |
313 | ExecutionContext *NetworkExecutionState::getRawResultContextPtr() const { |
314 | // The result PlaceholderBindings should have been been created in the |
315 | // constructor and should not yet have been moved out if this function is |
316 | // being called. |
317 | DCHECK_NOTNULL(resultCtx_.get()); |
318 | return resultCtx_.get(); |
319 | } |
320 | |