1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/Runtime/Executor/NetworkExecutionState.h"
17#include "glow/Backends/DeviceManager.h"
18
19using namespace glow;
20using namespace glow::runtime;
21
22namespace {
23static void updateTensor(Tensor &tensor, const Tensor &seed) {
24 if (auto *tensorPool = tensor.getOwningPool()) {
25 tensorPool->reclaim(std::move(tensor));
26 }
27 tensor = seed.getUnowned();
28}
29} // namespace
30
31void NetworkExecutionStatePool::addNewState(
32 std::unique_ptr<NetworkExecutionState> state) {
33
34 std::lock_guard<std::mutex> lock(stateLock_);
35 availableStates_.push_back(state.get());
36 states_.push_back(std::move(state));
37}
38
// Construct an execution state for the network rooted at \p root. Only stores
// configuration; buffer allocation and context creation happen in init().
NetworkExecutionState::NetworkExecutionState(const DAGNode *root,
                                             bool enableDRT, bool enableP2P)
    : enableDRT_(enableDRT), enableP2P_(enableP2P), inflightNodes_(0),
      module_(root->module), root_(root) {}
43
44NetworkExecutionState::~NetworkExecutionState() {
45 // Free all allocated buffers.
46 for (auto &allocation : deviceAllocations_) {
47 allocation.second->freeAllocatedDeviceIOBuffer(allocation.first);
48 }
49}
50
// Prepare this state for a new run: take ownership of the caller's context,
// record the completion callback and run id, reset per-run counters, and
// repoint the intermediate contexts' tensors at the caller-provided memory so
// no copy-in/copy-out is needed.
void NetworkExecutionState::bind(std::unique_ptr<ExecutionContext> resultCtx,
                                 ResultCBTy cb, RunIdentifierTy runId) {
  resultCtx_ = std::move(resultCtx);
  cb_ = std::move(cb);
  runId_ = runId;
  // Reset execution state, inflight nodes, parents done, etc.
  for (auto &count : nodeParentsDone_) {
    count.second = 0;
  }
  inflightNodes_ = 0;
  // Setup tracing if desired: give each intermediate context its own
  // TraceContext at the same trace level as the result context; they are
  // merged back via insertIntoTraceContext().
  auto resultTraceContext = resultCtx_->getTraceContext();
  if (resultTraceContext) {
    for (auto &context : intermediateContexts_) {
      context.second->setTraceContext(
          glow::make_unique<TraceContext>(resultTraceContext->getTraceLevel()));
    }
  } else {
    // Clear any trace context from a previous run.
    for (auto &context : intermediateContexts_) {
      context.second->setTraceContext(nullptr);
    }
  }
  // Move inputs into tensors backing intermediate contexts.
  // Instead we point the tensors to the provided buffers to avoid copy in and
  // out. Once we have pinned allocations we will need to transfer.
  // For now point input and output tensors to buffers used in resultCtx.
  const auto &externalIOBindings = resultCtx_->getExternalIOBindings();
  if (!externalIOBindings.empty()) {
    // Fast path: on the first bind(), cache a mapping from each external IO
    // binding's position to its index in externalPlaceholders_ so later runs
    // skip the hash lookups.
    // NOTE(review): this cache assumes getExternalIOBindings() returns the
    // bindings in the same order on every run of this state — confirm.
    if (ioIdxMapping_.empty()) {
      for (const auto &pair : externalIOBindings) {
        const auto it = externalPlaceholdersIdx_.find(pair.first);
        if (it == externalPlaceholdersIdx_.end()) {
          LOG(WARNING) << "Cannot match external placeholder: "
                       << pair.first->getDebugDesc();
          // -1 marks an unmatched placeholder; skipped below.
          ioIdxMapping_.emplace_back(-1);
        } else {
          ioIdxMapping_.emplace_back(it->second);
        }
      }
    }
    DCHECK(ioIdxMapping_.size() == externalIOBindings.size());
    for (unsigned i = 0, e = externalIOBindings.size(); i < e; ++i) {
      if (ioIdxMapping_[i] < 0) {
        // Unable to match external placeholder
        continue;
      }
      const auto &pair = externalIOBindings[i];
      const auto &resultTensor = pair.second;
      // Point every intermediate binding of this placeholder at the caller's
      // tensor memory (no copy).
      for (auto &bindingIt : externalPlaceholders_[ioIdxMapping_[i]]) {
        updateTensor(bindingIt->second, resultTensor);
      }
    }
  } else {
    // Slow path for backward compatibility, we will do extra hash lookup
    // per placeholder in the result bindings.
    auto *resultPHBindings = resultCtx_->getPlaceholderBindings();
    for (auto &pair : resultPHBindings->pairs()) {
      auto *PH = pair.first;
      const auto &resultTensor = pair.second;
      const auto it = externalPlaceholdersIdx_.find(PH);
      if (it == externalPlaceholdersIdx_.end()) {
        // This placeholder is not used by any intermediate context; skip it.
        continue;
      }
      for (auto &bindingIt : externalPlaceholders_[it->second]) {
        updateTensor(bindingIt->second, resultTensor);
      }
    }
  }
}
121
// Walk the DAG once (BFS), creating an intermediate ExecutionContext per node
// and allocating (or reusing) a device IO buffer for every non-static
// placeholder. If a static device assignment is given, bindContexts() is
// called on the prepared contexts. Must run before the first bind().
void NetworkExecutionState::init(
    const DeviceManagerMapTy &devices,
    std::unordered_map<DAGNode *, DeviceIDTy> &staticAssignment) {
  // Create a queue for the breadth-first traversal through the graph.
  std::queue<DAGNode *> bfsQueue;
  // Marking the default err as checked so we don't get an unchecked error in
  // destructor if we never use this state.
  errContainer_.containsErr();

  // Place the root nodes in the queue.
  for (auto &node : root_->children) {
    bfsQueue.push(node);
    // Make a counter for the number of node parents done. This also is used for
    // tracking if we've added the node already.
    nodeParentsDone_[node] = 0;
  }

  // Breadth-first search.
  while (!bfsQueue.empty()) {
    // Get the next node in the BFS queue.
    DAGNode *node = bfsQueue.front();
    bfsQueue.pop();

    // Push all unvisited children onto the BFS queue.
    for (const auto &child : node->children) {
      // Use nodeParentsDone_ as a set of nodes that have been visited already
      // to avoid visiting a node more than once.
      if (!nodeParentsDone_.count(child)) {
        nodeParentsDone_[child] = 0;
        bfsQueue.push(child);
      }
    }

    // Make an (empty) context for the node.
    auto intermediateContext = glow::make_unique<ExecutionContext>();
    auto it = staticAssignment.find(node);
    // If an assignment is provided for this context set it here.
    if (it != staticAssignment.end()) {
      // NOTE(review): assumes the assigned DeviceIDTy exists in devices —
      // the find() result is dereferenced unchecked; confirm with callers.
      auto dm = devices.find(it->second)->second.get();
      intermediateContext->setBoundDeviceManager(dm);
    }
    // Get a device to do allocation we can use the first device since the
    // allocation is not device specific.
    auto &device = devices.begin()->second;

    auto intermediatePHBindings = intermediateContext->getPlaceholderBindings();

    // Get the symbol table for the node.
    const SymbolTableTy &symbolTable = node->runtimeBundle->getSymbolTable();

    // Add inputs/outputs to the context. Skip any marked as static.
    for (const auto &symbolPair : symbolTable) {
      const auto &symbolName = symbolPair.first;
      const auto &symbolInfo = symbolPair.second;
      if (symbolInfo.symbolCategory == SymbolCategory::Placeholder) {
        auto PH = module_->getPlaceholderByNameSlow(symbolName);

        DCHECK(PH) << "Placeholder: " << symbolName << " is not in the module";
        // If PH is marked static skip it.
        if (PH->isStatic()) {
          continue;
        }
        // If we haven't allocated a buffer for this PH yet do so, otherwise
        // reuse the allocation.
        // TODO: for intermediate placeholders in DRT/P2P cases, we don't need
        // to allocate a backing tensor on host.
        auto bufferIt = buffers_.find(PH);
        if (bufferIt == buffers_.end()) {
          auto *deviceBuffer =
              device->allocateDeviceIOBuffer(PH->getType()->getSizeInBytes());
          buffers_[PH] = deviceBuffer;
          // Remember which DeviceManager allocated this buffer so the
          // destructor can free it through the same device.
          deviceAllocations_.insert({deviceBuffer, device.get()});
        }
        auto buffer = buffers_[PH];
        // Bind a tensor viewing the shared buffer into this node's context;
        // every context that uses PH shares the same backing memory.
        Tensor backingTensor(buffer, PH->getType());
        auto itt = intermediatePHBindings->insert(PH, std::move(backingTensor));
        // TODO: Only add to externalPlaceholders_ of PH is external placeholder
        // Record the binding iterator so bind() can repoint the tensor at
        // caller-provided memory without another hash lookup.
        auto idxIt = externalPlaceholdersIdx_.find(PH);
        if (idxIt == externalPlaceholdersIdx_.end()) {
          externalPlaceholdersIdx_.emplace(PH, externalPlaceholders_.size());
          externalPlaceholders_.emplace_back();
          auto &vec = externalPlaceholders_.back();
          vec.push_back(itt);
        } else {
          auto &vec = externalPlaceholders_[idxIt->second];
          vec.push_back(itt);
        }
      }
    }

    // Insert the prepared ExecutionContext into the input contexts map.
    intermediateContexts_.emplace(node, std::move(intermediateContext));
  }
  // If we used a static assignment call backend->bindContexts() on the new
  // contexts.
  if (staticAssignment.size()) {
    std::vector<runtime::ContextBinding> contexts;
    for (auto &intermediate : intermediateContexts_) {
      runtime::ContextBinding intermediateBinding;
      intermediateBinding.context = intermediate.second.get();
      intermediateBinding.networkName = intermediate.first->name;
      intermediateBinding.device = intermediate.second->getBoundDeviceManager();
      contexts.push_back(intermediateBinding);
    }
    const auto &backendName = devices.begin()->second->getBackendName();
    // Create a backend to call bindContexts on, since bindContexts only puts
    // state in the DeviceManager and Context we can safely discard this backend
    // once we are done with it.
    std::unique_ptr<Backend> newBackend(createBackend(backendName));

    EXIT_ON_ERR(
        newBackend->bindContexts(contexts, root_, enableP2P_, enableDRT_));
  }
  initialized_ = true;
}
237
238std::unique_ptr<ExecutionContext>
239NetworkExecutionState::getUniqueNodeContextPtr(const DAGNode *node) {
240 // The input PlaceholderBindings for the node should have been created in
241 // the constructor.
242 auto ctxIt = intermediateContexts_.find(node);
243
244 DCHECK(ctxIt != intermediateContexts_.end())
245 << "Input bindings not found but should exist!";
246
247 return std::move(ctxIt->second);
248}
249
250void NetworkExecutionState::returnUniqueNodeContextPtr(
251 const DAGNode *node, std::unique_ptr<ExecutionContext> ctx) {
252 intermediateContexts_[node] = std::move(ctx);
253}
254
255void NetworkExecutionState::incrementInflightNodes(unsigned increment) {
256 inflightNodes_ += increment;
257}
258
259bool NetworkExecutionState::decrementInflightNodes(unsigned decrement) {
260 // fetch_sub must be used here so that the function returns true to only one
261 // caller.
262 unsigned previousValue = inflightNodes_.fetch_sub(decrement);
263
264 // The decrement should never be more than the value of the counter at the
265 // time of decrement.
266 DCHECK_GE(previousValue, decrement)
267 << "More decrements than increments to inflight nodes!";
268
269 // Return true when the counter hits zero.
270 return (previousValue == decrement);
271}
272
273bool NetworkExecutionState::incrementNodeParentsDone(const DAGNode *node,
274 unsigned increment) {
275 // Get the parents done counter for the node. It should have
276 // been created in the constructor.
277 auto it = nodeParentsDone_.find(node);
278
279 DCHECK(it != nodeParentsDone_.end())
280 << "Node parents done counter should exist but not found!";
281
282 // fetch_add must be used here so that the function returns true to only
283 // one caller.
284 unsigned numParents = (node->parents).size();
285 unsigned previousValue = (it->second).fetch_add(increment);
286 unsigned newValue = previousValue + increment;
287
288 // The new value of the counter cannot exceed the number of parents that
289 // the node has.
290 DCHECK_LE(newValue, numParents)
291 << "Node parents done counter incremented beyond limit!";
292
293 // Return true only when the counter hits the total numer of parents.
294 return (newValue == numParents);
295}
296
297void NetworkExecutionState::insertIntoTraceContext(TraceContext *runCtx) {
298 if (!resultCtx_->getTraceContext()) {
299 return;
300 }
301
302 resultCtx_->getTraceContext()->merge(runCtx);
303}
304
305std::unique_ptr<ExecutionContext>
306NetworkExecutionState::getUniqueResultContextPtr() {
307 // The result PlaceholderBindings should have been created in the
308 // constructor.
309 DCHECK_NOTNULL(resultCtx_.get());
310 return std::move(resultCtx_);
311}
312
313ExecutionContext *NetworkExecutionState::getRawResultContextPtr() const {
314 // The result PlaceholderBindings should have been been created in the
315 // constructor and should not yet have been moved out if this function is
316 // being called.
317 DCHECK_NOTNULL(resultCtx_.get());
318 return resultCtx_.get();
319}
320