1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Runtime/Executor/ThreadPoolExecutor.h" |
18 | #include "glow/Backends/DeviceManager.h" |
19 | #include "glow/Support/Support.h" |
20 | #include "glow/Support/ThreadPool.h" |
21 | |
22 | #include "gtest/gtest.h" |
23 | |
24 | #include <chrono> |
25 | #include <future> |
26 | #include <thread> |
27 | #include <unordered_set> |
28 | |
29 | using namespace glow; |
30 | using namespace glow::runtime; |
31 | |
32 | /// This is an implementation of DeviceManager tailored for testing Executor |
33 | /// implementations. registerResult() gives the caller the ability to |
34 | /// dictate precisely what a subsequent call to runFunction() should return. |
35 | /// registerResult() should be called before calling Executor::run() in each |
36 | /// test. The rest of the implementation of the DeviceManager interface exists |
37 | /// to satisfy the compiler. |
38 | class TestDeviceManager final : public runtime::DeviceManager { |
39 | public: |
40 | TestDeviceManager(unsigned numWorkers, const DeviceConfig &deviceConfig) |
41 | : DeviceManager(deviceConfig), threadPool_(numWorkers) {} |
42 | |
43 | /// The functions below are the interface for DeviceManager. See |
44 | /// glow::DeviceManager for descriptions of what they do. Since this |
45 | /// class exists only to help test Executor implementations, the only |
46 | /// important function is runFunction(). |
47 | void addNetwork(const Module *module, FunctionMapTy functions, |
48 | ReadyCBTy readyCB) override {} |
49 | |
50 | void evictNetwork(std::string functionName, |
51 | EvictFunctionCBTy evictCB) override { |
52 | // Erase the entry so that the same function name can be used to register |
53 | // another result. |
54 | |
55 | if (!resultMap_.erase(functionName)) { |
56 | evictCB( |
57 | functionName, |
58 | MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_NOT_FOUND, |
59 | strFormat("Could not find function with name %s to evict" , |
60 | functionName.c_str()))); |
61 | return; |
62 | } |
63 | evictCB(functionName, Error::success()); |
64 | } |
65 | |
66 | /// Look up the previously registered response for \p functionName and |
67 | /// call \p resultCB with it after checking that \p context contains the |
68 | /// expected Placeholder-Tensor mappings. |
69 | void doRunFunction(std::string functionName, |
70 | std::unique_ptr<ExecutionContext> context, |
71 | ResultCBTy resultCB) { |
72 | |
73 | RunIdentifierTy runId = 0; |
74 | bool successResult = false; |
75 | |
76 | // Retrieve the registered response for the function if there is one. |
77 | if (context && resultCB && resultMap_.count(functionName)) { |
78 | std::unique_ptr<RunFunctionResult> registeredResult = |
79 | std::move(resultMap_[functionName]); |
80 | |
81 | // Check that context contains the expected Placeholder-Tensor mappings. |
82 | std::unique_ptr<ExecutionContext> inputContext = |
83 | std::move(registeredResult->inputContext); |
84 | |
85 | bool equalInputs = true; |
86 | for (auto &p : inputContext->getPlaceholderBindings()->pairs()) { |
87 | Tensor *CT = context->getPlaceholderBindings()->get(p.first); |
88 | if (!CT) { |
89 | equalInputs = false; |
90 | break; |
91 | } |
92 | equalInputs &= p.second.isEqual(*CT, 0.0001, true); |
93 | } |
94 | |
95 | if (equalInputs) { |
96 | // If bindings contains all expected mappings, overwrite the default |
97 | // runId, result and resultContext with the registered |
98 | // ones. |
99 | runId = registeredResult->runId; |
100 | successResult = registeredResult->success; |
101 | |
102 | for (const auto &p : |
103 | registeredResult->resultContext->getPlaceholderBindings() |
104 | ->pairs()) { |
105 | context->getPlaceholderBindings()->get(p.first)->assign(&p.second); |
106 | } |
107 | } |
108 | } |
109 | |
110 | if (successResult) { |
111 | resultCB(runId, Error::success(), std::move(context)); |
112 | } else { |
113 | resultCB(runId, MAKE_ERR("An error occurred" ), std::move(context)); |
114 | } |
115 | } |
116 | |
117 | /// Do not call this at the same time as registerResult(). |
118 | runtime::RunIdentifierTy |
119 | runFunction(std::string functionName, |
120 | std::unique_ptr<ExecutionContext> context, |
121 | ResultCBTy resultCB) override { |
122 | // Give the call to the thread pool to process to make the tests |
123 | // multithreaded if needed. |
124 | this->threadPool_.submit( |
125 | [this, functionName, context = std::move(context), resultCB]() mutable { |
126 | this->doRunFunction(functionName, std::move(context), resultCB); |
127 | }); |
128 | return 0; |
129 | } |
130 | |
131 | uint64_t getMaximumMemory() const override { |
132 | return std::numeric_limits<uint64_t>::max(); |
133 | } |
134 | |
135 | uint64_t getAvailableMemory() const override { |
136 | return std::numeric_limits<uint64_t>::max(); |
137 | } |
138 | |
139 | bool isMemoryAvailable(uint64_t /*estimate*/) const override { return true; } |
140 | |
141 | /// Register a result that should be returned by the subsequent call to |
142 | /// runFunction with the same \p functionName. The callback for that call |
143 | /// to runFunction will be called with \p runId, \p success, and \p |
144 | /// \p resultContext if the context passed in to runFunction |
145 | /// matches \p inputContext. \returns true if registration was |
146 | /// successful, false if not. Do not call this at the same time as |
147 | /// runFunction(). |
148 | bool registerResult(const std::string &functionName, RunIdentifierTy runId, |
149 | bool success, |
150 | std::unique_ptr<ExecutionContext> inputContext, |
151 | std::unique_ptr<ExecutionContext> resultContext) { |
152 | bool registered = false; |
153 | |
154 | if (!resultMap_.count(functionName)) { |
155 | // If the function name has not already been registered, insert it into |
156 | // resultMap_. |
157 | std::tie(std::ignore, registered) = resultMap_.insert(std::make_pair( |
158 | functionName, glow::make_unique<RunFunctionResult>( |
159 | runId, success, std::move(inputContext), |
160 | std::move(resultContext)))); |
161 | } |
162 | |
163 | return registered; |
164 | } |
165 | |
166 | private: |
167 | /// This struct wraps all of the data needed to reply to a runFunction() call. |
168 | /// It exists so that that all of these things can be stored in one map. |
169 | struct RunFunctionResult { |
170 | /// The run ID that should be returned. |
171 | RunIdentifierTy runId; |
172 | /// If success then no error should be returned otherwise an Error should be |
173 | /// returned. |
174 | bool success; |
175 | /// The expected input context for the invocation. |
176 | std::unique_ptr<ExecutionContext> inputContext; |
177 | /// The result context that should be returned. |
178 | std::unique_ptr<ExecutionContext> resultContext; |
179 | |
180 | /// Constructor. |
181 | RunFunctionResult(RunIdentifierTy run, bool successParam, |
182 | std::unique_ptr<ExecutionContext> inputcontext, |
183 | std::unique_ptr<ExecutionContext> resultcontext) |
184 | : runId(run), success(successParam), |
185 | inputContext(std::move(inputcontext)), |
186 | resultContext(std::move(resultcontext)) {} |
187 | }; |
188 | |
189 | /// Map of function name -> RunFunctionResult instance containing the |
190 | /// RunFunctionResult instance for the function. |
191 | using TestDeviceManagerResultMapTy = |
192 | std::unordered_map<std::string, std::unique_ptr<RunFunctionResult>>; |
193 | |
194 | /// Map for storing registered results. |
195 | TestDeviceManagerResultMapTy resultMap_; |
196 | /// Thread pool for executing runFunction() in a multithreaded fashion. |
197 | ThreadPool threadPool_; |
198 | }; |
199 | |
/// Map from a name to the Placeholder it owns.
using PlaceholderNameMapTy =
    std::unordered_map<std::string, std::unique_ptr<Placeholder>>;
/// Map from a name to the DAGNode it owns.
using DAGNodeNameMapTy =
    std::unordered_map<std::string, std::unique_ptr<DAGNode>>;
204 | |
205 | /// This class serves as an interface to a test created by ExecutorTestBuilder. |
206 | /// It also contains the resources necessary to run the test. Instances are |
207 | /// meant to be created only by ExecutorTestBuilder. |
208 | class ExecutorTest final { |
209 | public: |
210 | /// Constructor. |
211 | ExecutorTest(const std::shared_ptr<Executor> &executor, |
212 | std::unique_ptr<DAGNode> root, std::unique_ptr<Module> module, |
213 | std::unique_ptr<Type> type, DAGNodeNameMapTy nodes, |
214 | PlaceholderNameMapTy placeholders, |
215 | std::unique_ptr<ExecutionContext> inputContext, |
216 | std::unique_ptr<ExecutionContext> outputContext, |
217 | RunIdentifierTy runId, bool expectSuccess) |
218 | : executor_(executor), root_(std::move(root)), module_(std::move(module)), |
219 | type_(std::move(type)), nodes_(std::move(nodes)), |
220 | placeholders_(std::move(placeholders)), |
221 | inputContext_(std::move(inputContext)), |
222 | outputContext_(std::move(outputContext)), runId_(runId), |
223 | expectSuccess_(expectSuccess), testRun_(false) { |
224 | root_->module = module_.get(); |
225 | // Create context pool. |
226 | executor_->createPool(root_.get(), 1000, false, false); |
227 | } |
228 | |
229 | /// Run the test. |
230 | bool run() { |
231 | if (testRun_) { |
232 | assert(!"Test has already been run!" ); |
233 | } |
234 | |
235 | // Variables for storing runId actually returned by |
236 | // Executor::run() via its callback. |
237 | RunIdentifierTy executorRunId; |
238 | std::unique_ptr<ExecutionContext> executorOutputContext; |
239 | |
240 | // Call Executor::run(). |
241 | std::promise<bool> promise; |
242 | std::future<bool> future = promise.get_future(); |
243 | executor_->run(root_.get(), std::move(inputContext_), runId_, |
244 | [&promise, &executorRunId, &executorOutputContext]( |
245 | RunIdentifierTy runId, Error err, |
246 | std::unique_ptr<ExecutionContext> context) { |
247 | executorRunId = runId; |
248 | executorOutputContext = std::move(context); |
249 | promise.set_value(ERR_TO_BOOL(std::move(err))); |
250 | }); |
251 | |
252 | bool runSuccess = !future.get(); |
253 | |
254 | // Check that the values returned in the Executor callback match |
255 | // expectations. |
256 | bool runIdsMatch = executorRunId == runId_; |
257 | bool resultsMatch = runSuccess == expectSuccess_; |
258 | |
259 | bool bindingsMatch = PlaceholderBindings::compare( |
260 | executorOutputContext->getPlaceholderBindings(), |
261 | outputContext_->getPlaceholderBindings()); |
262 | |
263 | // If the run failed, we shouldn't expect bindingsMatch to be true. |
264 | bool testPassed = |
265 | runIdsMatch && resultsMatch && (!runSuccess || bindingsMatch); |
266 | |
267 | testRun_ = true; |
268 | executor_->freePool(root_.get()); |
269 | return testPassed; |
270 | } |
271 | |
272 | private: |
273 | /// The Executor to run the test with. |
274 | std::shared_ptr<Executor> executor_; |
275 | /// The root node of the DAG being tested. |
276 | std::unique_ptr<DAGNode> root_; |
277 | /// The Module containing the PHs. |
278 | std::unique_ptr<Module> module_; |
279 | /// The Type for all of the Placeholders that will be used during execution. |
280 | std::unique_ptr<Type> type_; |
281 | /// All nodes in the DAG. |
282 | DAGNodeNameMapTy nodes_; |
283 | /// All Placeholders that will be used during execution. |
284 | PlaceholderNameMapTy placeholders_; |
285 | /// The input ExecutionContext that should be passed to Executor::run() |
286 | /// when running the test. |
287 | std::unique_ptr<ExecutionContext> inputContext_; |
288 | /// The expected ExecutionContext that the Executor should return. |
289 | std::unique_ptr<ExecutionContext> outputContext_; |
290 | /// The run ID that should be passed to Executor::run() when running |
291 | /// the test. |
292 | RunIdentifierTy runId_; |
293 | /// The expected result that the Executor should return. |
294 | bool expectSuccess_; |
295 | /// Tracks whether or not the test has already been run. |
296 | bool testRun_; |
297 | }; |
298 | |
299 | /// This class helps build tests for testing Executor implementations. It |
300 | /// presents a simple interface for executor DAG construction; nodes are added |
301 | /// by specifying its parents, device ID, and named inputs and outputs. This |
302 | /// builder class takes care of all of the work needed to actually run this DAG: |
303 | /// creation of Placeholders and Tensors for all inputs and outputs; creation of |
304 | /// input/output ExecutionContext for each node to verify that each one |
305 | /// receives the correct input and produces the correct output; and registration |
306 | /// with the TestDeviceManager. |
307 | class ExecutorTestBuilder final { |
308 | public: |
309 | /// Constructor. The exact value of type_ doesn't really matter since the |
310 | /// important thing to test is that that Placeholder values are propagated |
311 | /// between ExecutionContexts correctly. |
312 | ExecutorTestBuilder(const std::shared_ptr<Executor> &executor, |
313 | const DeviceManagerMapTy &deviceManagers) |
314 | : executor_(executor), module_(glow::make_unique<Module>()), |
315 | root_(glow::make_unique<DAGNode>()), |
316 | bindings_(glow::make_unique<PlaceholderBindings>()), |
317 | type_( |
318 | std::unique_ptr<Type>(new Type(ElemKind::FloatTy, {32, 64, 128}))), |
319 | success_(true), deviceManagers_(deviceManagers) {} |
320 | |
321 | /// Add a node named \p name to the DAG with parents \p parents that runs on a |
322 | /// device specified by \p deviceId. A RuntimeBundle is created for the node |
323 | /// with runtime symbol information created from \p inputs and \p outputs. |
324 | /// \p runId is the run ID for the node and \p success is the desired |
325 | /// execution status. If \p parents is empty, the new node is added as a child |
326 | /// of the root. |
327 | void addNode(const std::string &name, DeviceIDTy deviceId, |
328 | llvm::ArrayRef<llvm::StringRef> parents, |
329 | llvm::ArrayRef<llvm::StringRef> inputs, |
330 | llvm::ArrayRef<llvm::StringRef> outputs, RunIdentifierTy runId, |
331 | bool success) { |
332 | auto newNode = glow::make_unique<DAGNode>(); |
333 | auto *newNodeRawPtr = newNode.get(); |
334 | |
335 | // If this is the first node being added, record the run ID for the graph. |
336 | // Otherwise, make sure that the runId matches that of the previous nodes. |
337 | if (nodes_.empty()) { |
338 | runId_ = runId; |
339 | } else { |
340 | assert(runId == runId_ && "Node run ID does not match rest of graph!" ); |
341 | } |
342 | |
343 | // If the result for this node is false, set the expected |
344 | // result for the entire test to false. |
345 | success_ &= success; |
346 | |
347 | // Add parents to the list of parents in the new node and add the newNode |
348 | // to the list of children in the parents. If the parent list is empty, |
349 | // make the root the only parent. Also, update the set of known leaves |
350 | // by removing any parents of the new node from it. This will be useful |
351 | // later. |
352 | if (!parents.empty()) { |
353 | for (const auto &parent : parents) { |
354 | auto it = nodes_.find(parent.str()); |
355 | if (it == nodes_.end()) { |
356 | assert(!"Parent specified for node not found!" ); |
357 | } |
358 | DAGNode *parentPtr = (it->second).get(); |
359 | (newNode->parents).emplace_back(parentPtr); |
360 | (parentPtr->children).emplace_back(newNodeRawPtr); |
361 | leaves_.erase(parentPtr); |
362 | } |
363 | } else { |
364 | (newNode->parents).emplace_back(root_.get()); |
365 | (root_->children).emplace_back(newNode.get()); |
366 | } |
367 | |
368 | // Iterate through inputs and outputs and: |
369 | // 1) Create Placeholders and Tensors for inputs/output names that have not |
370 | // been mapped to a Placeholder yet. |
371 | // 2) Assemble the input ExecutionContexts that the node is expected to be |
372 | // called with |
373 | // and the ExecutionContexts that the node should produce as output. |
374 | // 3) Generate the symbol table for the new node by generating |
375 | // RuntimeSymbolInfo objects for each input and output. |
376 | SymbolTableTy symbolTable; |
377 | size_t offset = 0; |
378 | |
379 | auto nodeInputContext = glow::make_unique<ExecutionContext>(); |
380 | auto nodeOutputContext = glow::make_unique<ExecutionContext>(); |
381 | |
382 | auto nodeInputBindings = nodeInputContext->getPlaceholderBindings(); |
383 | auto nodeOutputBindings = nodeOutputContext->getPlaceholderBindings(); |
384 | |
385 | for (const auto &input : inputs) { |
386 | // Both input and output bindings should contain bindings for the inputs. |
387 | insertSymbolIntoPlaceholderBindings(input, nodeInputBindings); |
388 | insertSymbolIntoPlaceholderBindings(input, nodeOutputBindings); |
389 | |
390 | RuntimeSymbolInfo runtimeSymbolInfo; |
391 | runtimeSymbolInfo.size = type_->getSizeInBytes(); |
392 | runtimeSymbolInfo.offset = offset; |
393 | runtimeSymbolInfo.type = *type_; |
394 | runtimeSymbolInfo.input = true; |
395 | runtimeSymbolInfo.output = false; |
396 | runtimeSymbolInfo.symbolCategory = SymbolCategory::Placeholder; |
397 | symbolTable.insert(std::make_pair(input, runtimeSymbolInfo)); |
398 | offset += type_->getSizeInBytes(); |
399 | } |
400 | |
401 | for (const auto &output : outputs) { |
402 | insertSymbolIntoPlaceholderBindings(output, nodeOutputBindings); |
403 | |
404 | RuntimeSymbolInfo runtimeSymbolInfo; |
405 | runtimeSymbolInfo.size = type_->getSizeInBytes(); |
406 | runtimeSymbolInfo.offset = offset; |
407 | runtimeSymbolInfo.type = *type_; |
408 | runtimeSymbolInfo.input = false; |
409 | runtimeSymbolInfo.output = true; |
410 | runtimeSymbolInfo.symbolCategory = SymbolCategory::Placeholder; |
411 | symbolTable.insert(std::make_pair(output, runtimeSymbolInfo)); |
412 | offset += type_->getSizeInBytes(); |
413 | } |
414 | |
415 | // Set the name, device ID, and RuntimeBundle of the new node. |
416 | newNode->name = name; |
417 | newNode->deviceRuntimeInfos[deviceId] = DeviceRuntimeInfo(); |
418 | |
419 | newNode->runtimeBundle = glow::make_unique<RuntimeBundle>( |
420 | symbolTable, /*constWeight=*/0, /*mutableWeight=*/0, |
421 | /*activations=*/0); |
422 | |
423 | // Register node result with the appropriate DeviceManager. |
424 | auto it = deviceManagers_.find(deviceId); |
425 | |
426 | if (it == deviceManagers_.end()) { |
427 | assert(!"No test device manager found for this device ID" ); |
428 | } |
429 | |
430 | auto *deviceManagerPtr = it->second.get(); |
431 | auto testDeviceManagerPtr = |
432 | static_cast<TestDeviceManager *>(deviceManagerPtr); |
433 | |
434 | bool registered = testDeviceManagerPtr->registerResult( |
435 | name, runId, success, std::move(nodeInputContext), |
436 | std::move(nodeOutputContext)); |
437 | |
438 | (void)registered; |
439 | assert(registered && "Node registration was not successful" ); |
440 | |
441 | // Add the new node to nodes_ and leaves_. |
442 | nodes_.insert(std::make_pair(name, std::move(newNode))); |
443 | leaves_.insert(newNodeRawPtr); |
444 | } |
445 | |
446 | /// Emit the test built so far and clear any state in the builder. |
447 | ExecutorTest emitTest() { |
448 | // Get the input and output symbol names for the whole DAG. |
449 | std::vector<std::string> inputSymbols = gatherInputSymbols(); |
450 | std::vector<std::string> outputSymbols = gatherOutputSymbols(); |
451 | |
452 | // Generate the input and output ExecutionContexts for the test. This |
453 | // input ExecutionContexts contains the input Placeholders of all root |
454 | // nodes and output Placeholders of all leaves (but backed by zero tensors). |
455 | // This is the ExecutionContexts that needs to be passed to |
456 | // Executor::run() to run the test. The output ExecutionContexts contains |
457 | // the same Placeholders as the input ExecutionContexts, but the leaves' |
458 | // output Placeholders are mapped to their expected output Tensors. This |
459 | // ExecutionContext is used to verify that the one returned by the |
460 | // Executor is correct. |
461 | auto inputContext = glow::make_unique<ExecutionContext>(); |
462 | auto outputContext = glow::make_unique<ExecutionContext>(); |
463 | |
464 | for (const auto &symbol : inputSymbols) { |
465 | insertSymbolIntoPlaceholderBindings( |
466 | symbol, inputContext->getPlaceholderBindings()); |
467 | insertSymbolIntoPlaceholderBindings( |
468 | symbol, outputContext->getPlaceholderBindings()); |
469 | } |
470 | |
471 | for (const auto &symbol : outputSymbols) { |
472 | auto *placeholder = bindings_->getPlaceholderByNameSlow(symbol); |
473 | if (!placeholder) { |
474 | assert(!"Placeholder for DAG output not found!" ); |
475 | } |
476 | insertSymbolIntoPlaceholderBindings( |
477 | symbol, inputContext->getPlaceholderBindings()); |
478 | insertSymbolIntoPlaceholderBindings( |
479 | symbol, outputContext->getPlaceholderBindings()); |
480 | } |
481 | // Create the test object. |
482 | ExecutorTest test(executor_, std::move(root_), std::move(module_), |
483 | std::move(type_), std::move(nodes_), |
484 | std::move(placeholders_), std::move(inputContext), |
485 | std::move(outputContext), runId_, success_); |
486 | |
487 | // Reset builder state to allow a new test to be constructed with this |
488 | // instance. |
489 | root_ = glow::make_unique<DAGNode>(); |
490 | module_ = glow::make_unique<Module>(); |
491 | bindings_->clear(); |
492 | type_ = std::unique_ptr<Type>(new Type(ElemKind::FloatTy, {1, 2, 2})); |
493 | nodes_.clear(); |
494 | leaves_.clear(); |
495 | placeholders_.clear(); |
496 | success_ = true; |
497 | |
498 | return test; |
499 | } |
500 | |
501 | private: |
502 | /// Collect all input symbol names for the test. \returns a vector containing |
503 | /// the names of all test input symbols. |
504 | std::vector<std::string> gatherInputSymbols() const { |
505 | std::vector<std::string> inputSymbols; |
506 | |
507 | // Input symbols for the entire test are the inputs of all nodes that have |
508 | // no parents. |
509 | for (const auto &node : root_->children) { |
510 | const SymbolTableTy &symbolTable = node->runtimeBundle->getSymbolTable(); |
511 | |
512 | for (const auto &symbolPair : symbolTable) { |
513 | const auto &symbolName = symbolPair.first; |
514 | const auto &symbolInfo = symbolPair.second; |
515 | |
516 | if (symbolInfo.input) { |
517 | inputSymbols.emplace_back(symbolName); |
518 | } |
519 | } |
520 | } |
521 | |
522 | return inputSymbols; |
523 | } |
524 | |
525 | /// Collect all output symbol names for the test. \returns a vector containing |
526 | /// the names of all test output symbols. |
527 | std::vector<std::string> gatherOutputSymbols() const { |
528 | std::vector<std::string> outputSymbols; |
529 | |
530 | // Input symbols for the entire test are the outputs of all nodes that have |
531 | // no children. |
532 | for (const auto &node : leaves_) { |
533 | const SymbolTableTy &symbolTable = node->runtimeBundle->getSymbolTable(); |
534 | |
535 | for (const auto &symbolPair : symbolTable) { |
536 | const auto &symbolName = symbolPair.first; |
537 | const auto &symbolInfo = symbolPair.second; |
538 | |
539 | if (symbolInfo.output) { |
540 | outputSymbols.emplace_back(symbolName); |
541 | } |
542 | } |
543 | } |
544 | |
545 | return outputSymbols; |
546 | } |
547 | |
548 | /// Insert a Placeholder named \p name with type type_ into \p bindings |
549 | /// and generate a random Tensor for it. If this Placeholder has already been |
550 | /// mapped for the test being created, reuse the existing value. |
551 | void insertSymbolIntoPlaceholderBindings(llvm::StringRef name, |
552 | PlaceholderBindings *bindings) { |
553 | auto ph = module_->getPlaceholderByNameSlow(name); |
554 | |
555 | if (!ph) { |
556 | // This is a new symbol. Create a Placeholder and an initialize and new |
557 | // Tensor for it. |
558 | auto placeholder = module_->createPlaceholder(type_.get(), name, false); |
559 | auto *tensor = bindings_->allocate(placeholder); |
560 | tensor->init(Tensor::InitKind::Xavier, 1.0, rng_); |
561 | bindings->insert(placeholder, tensor->clone()); |
562 | } else { |
563 | // This is a symbol that already has an associated Placeholder and Tensor. |
564 | // Copy that Tensor. |
565 | const auto *tensor = bindings_->get(ph); |
566 | bindings->insert(ph, tensor->clone()); |
567 | } |
568 | } |
569 | |
570 | /// The Executor being tested. |
571 | std::shared_ptr<Executor> executor_; |
572 | /// Module for holding PHs |
573 | std::unique_ptr<Module> module_; |
574 | /// The root of the DAG being constructed. |
575 | std::unique_ptr<DAGNode> root_; |
576 | /// This PlaceholderBindings holds all created and initialized Placeholders |
577 | /// for the test. |
578 | std::unique_ptr<PlaceholderBindings> bindings_; |
579 | /// The Type for all Placeholders and Tensors in the test. The exact value |
580 | /// is not important; the main thing being tested is the propagation of |
581 | /// Placeholders and Tensors as the DAG executes. |
582 | std::unique_ptr<Type> type_; |
583 | /// PRNG for filling Tensors. |
584 | PseudoRNG rng_; |
585 | /// The nodes in the DAG being constructed. |
586 | DAGNodeNameMapTy nodes_; |
587 | /// The leaves in the DAG being constructed. This helps collect output symbols |
588 | /// during test emission. |
589 | std::unordered_set<const DAGNode *> leaves_; |
590 | /// All Placeholders in the test. |
591 | PlaceholderNameMapTy placeholders_; |
592 | /// The run ID for the DAG. |
593 | RunIdentifierTy runId_; |
594 | /// The expected result for the DAG. |
595 | bool success_; |
596 | /// Map from DeviceIDTy -> TestDeviceManager. This enables the construction of |
597 | /// tests with nodes spread across devices. |
598 | const DeviceManagerMapTy &deviceManagers_; |
599 | }; |
600 | |
601 | /// This test fixture provides ThreadPoolExecutor, ExecutorTestBuilder, |
602 | /// DeviceManagerMapTy instances to all tests. |
class ThreadPoolExecutorTest : public ::testing::Test {
protected:
  /// Build the Executor and the test builder, both of which are handed
  /// deviceManagerMap_. Members are initialized in declaration order, so
  /// deviceManagerMap_ has not been constructed yet when it is passed here
  /// and is destroyed before executor_ and testBuilder_ at teardown.
  /// NOTE(review): this is only sound if neither constructor nor destructor
  /// reads the map -- confirm before reordering or extending these members.
  ThreadPoolExecutorTest()
      : executor_(std::make_shared<ThreadPoolExecutor>(deviceManagerMap_)),
        testBuilder_(executor_, deviceManagerMap_) {}
  ~ThreadPoolExecutorTest() = default;

  /// The Executor being tested.
  std::shared_ptr<ThreadPoolExecutor> executor_;
  /// An ExecutorTestBuilder instance for creating tests.
  ExecutorTestBuilder testBuilder_;
  /// DeviceManager map for initializing executor_. Tests populate it with
  /// TestDeviceManager instances before building their DAGs.
  DeviceManagerMapTy deviceManagerMap_;
};
617 | |
618 | /// Tests that an empty DAG is handled correctly. |
619 | TEST_F(ThreadPoolExecutorTest, EmptyDAG) { |
620 | constexpr RunIdentifierTy testRunId = 10; |
621 | |
622 | // Make a PlaceholderBindings with one Placeholder in it to make sure |
623 | // Executor::run() doesn't modify it when the root given to it is null. Make |
624 | // two identical copies; one to give to Executor::run(), and another to |
625 | // compare the returned PlaceholderBindings with. |
626 | PseudoRNG rng; |
627 | auto type = std::unique_ptr<Type>(new Type(ElemKind::FloatTy, {1, 2, 2})); |
628 | auto placeholder = glow::make_unique<Placeholder>( |
629 | "a" , type.get(), /*trainable=*/false, ANY_LAYOUT); |
630 | |
631 | auto testContext = glow::make_unique<ExecutionContext>(); |
632 | auto refContext = glow::make_unique<ExecutionContext>(); |
633 | |
634 | auto *tensor = |
635 | testContext->getPlaceholderBindings()->allocate(placeholder.get()); |
636 | tensor->init(Tensor::InitKind::Xavier, 1.0, rng); |
637 | refContext->getPlaceholderBindings()->insert(placeholder.get(), |
638 | tensor->clone()); |
639 | |
640 | // Variables for storing runId actually returned by |
641 | // Executor::run() via its callback. |
642 | RunIdentifierTy executorRunId; |
643 | std::unique_ptr<ExecutionContext> executorOutputContext; |
644 | |
645 | // Call Executor::run(). |
646 | std::promise<void> promise; |
647 | std::future<void> future = promise.get_future(); |
648 | std::unique_ptr<Error> runErr; |
649 | executor_->run(nullptr, std::move(testContext), testRunId, |
650 | [&runErr, &promise, &executorRunId, &executorOutputContext]( |
651 | RunIdentifierTy runId, Error err, |
652 | std::unique_ptr<ExecutionContext> context) { |
653 | executorRunId = runId; |
654 | executorOutputContext = std::move(context); |
655 | runErr = glow::make_unique<Error>(std::move(err)); |
656 | promise.set_value(); |
657 | }); |
658 | |
659 | EXPECT_FALSE(ERR_TO_BOOL(std::move(*DCHECK_NOTNULL(runErr.get())))); |
660 | |
661 | EXPECT_EQ(executorRunId, testRunId); |
662 | |
663 | EXPECT_TRUE(PlaceholderBindings::compare( |
664 | refContext->getPlaceholderBindings(), |
665 | executorOutputContext->getPlaceholderBindings())); |
666 | } |
667 | |
668 | /// Tests that a single node can run correctly. |
669 | TEST_F(ThreadPoolExecutorTest, SingleNode) { |
670 | constexpr RunIdentifierTy testRunId = 10; |
671 | constexpr DeviceIDTy testDeviceId = 111; |
672 | constexpr unsigned deviceManagerThreads = 1; |
673 | |
674 | // Make a TestDeviceManager and insert into the DeviceManagerMap map (which |
675 | // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map |
676 | // (which the ExecutorTestBuilder has a reference to). |
677 | auto deviceManager = glow::make_unique<TestDeviceManager>( |
678 | deviceManagerThreads, DeviceConfig("Interpreter" )); |
679 | deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager)); |
680 | |
681 | // Build the DAG. The DAG created below looks like this: |
682 | /** |
683 | * root |
684 | * | |
685 | * v |
686 | * net |
687 | **/ |
688 | |
689 | testBuilder_.addNode("net" , testDeviceId, |
690 | /*parents=*/{}, {"netInput" }, {"netOutput" }, testRunId, |
691 | true); |
692 | |
693 | ExecutorTest test = testBuilder_.emitTest(); |
694 | EXPECT_TRUE(test.run()); |
695 | } |
696 | |
697 | /// Tests that several instances of a single node DAG can be run in parallel. |
698 | TEST_F(ThreadPoolExecutorTest, ConcurrentSingleNode) { |
699 | constexpr RunIdentifierTy baseTestRunId = 10; |
700 | constexpr DeviceIDTy testDeviceId = 111; |
701 | constexpr unsigned deviceManagerThreads = 3; |
702 | unsigned numConcurrentRuns = 100; |
703 | |
704 | // Make a TestDeviceManager and insert into the DeviceManagerMap map (which |
705 | // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map |
706 | // (which the ExecutorTestBuilder has a reference to). |
707 | auto deviceManager = glow::make_unique<TestDeviceManager>( |
708 | deviceManagerThreads, DeviceConfig("Interpreter" )); |
709 | deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager)); |
710 | |
711 | // Mutex for accessing threadsReady and testsPassed. |
712 | std::mutex mtx; |
713 | // Condition variables for signalling between the test runner threads |
714 | // and this thread. These are used to implement a barrier that ensures |
715 | // all test runner threads have been created and are executing before any |
716 | // are allowed to run a test (in order to try and increase the number of |
717 | // threads that call Executor::run() at the same time). |
718 | std::condition_variable driverCV, threadCV; |
719 | // Counters for implementing the aforementioned barrier and tracking the |
720 | // number of tests that pass. |
721 | unsigned threadsReady = 0, testsPassed = 0; |
722 | std::vector<std::thread> threads; |
723 | for (unsigned i = 0; i < numConcurrentRuns; ++i) { |
724 | // Build the DAG. The DAG created below looks like this: |
725 | /** |
726 | * root |
727 | * | |
728 | * v |
729 | * net |
730 | **/ |
731 | |
732 | // The names must be distinct since the DeviceManager distinguishes based |
733 | // on function name. The run IDs must also be distinct (hence the +i). |
734 | testBuilder_.addNode(strFormat("net_%d" , i), testDeviceId, |
735 | /*parents=*/{}, {"netInput" }, {"netOutput" }, |
736 | baseTestRunId + i, true); |
737 | ExecutorTest t = testBuilder_.emitTest(); |
738 | |
739 | std::thread th([&mtx, &driverCV, &threadCV, &threadsReady, &testsPassed, |
740 | test = std::move(t), numConcurrentRuns]() mutable { |
741 | std::unique_lock<std::mutex> lock(mtx); |
742 | // Increment threadsReady to mark this thread as ready to run the test. |
743 | threadsReady++; |
744 | // If threadsReady == numConcurrentRuns, this thread is the last to be |
745 | // initialized and execute, so signal the driver that all threads are |
746 | // ready. |
747 | if (threadsReady == numConcurrentRuns) { |
748 | driverCV.notify_one(); |
749 | } |
750 | // Wait for the driver's signal. |
751 | threadCV.wait(lock); |
752 | // Unlock the mutex to let all other threads run their tests concurrently. |
753 | lock.unlock(); |
754 | bool passed = test.run(); |
755 | lock.lock(); |
756 | |
757 | if (passed) { |
758 | testsPassed++; |
759 | } |
760 | }); |
761 | threads.emplace_back(std::move(th)); |
762 | } |
763 | |
764 | std::unique_lock<std::mutex> lock(mtx); |
765 | // If threadsReady != numConcurrentRuns, not all threads are ready to run |
766 | // their tests. Wait until they are. |
767 | if (threadsReady != numConcurrentRuns) { |
768 | driverCV.wait(lock, [&threadsReady, numConcurrentRuns] { |
769 | return threadsReady == numConcurrentRuns; |
770 | }); |
771 | } |
772 | // Wake up all test runners. |
773 | threadCV.notify_all(); |
774 | lock.unlock(); |
775 | |
776 | // Join all threads. |
777 | for (unsigned i = 0; i < numConcurrentRuns; ++i) { |
778 | threads[i].join(); |
779 | } |
780 | |
781 | // All tests should pass. |
782 | EXPECT_EQ(testsPassed, numConcurrentRuns); |
783 | } |
784 | |
785 | /// Tests that successive calls to ThreadPoolExecutor::run() with the same |
786 | /// runId don't succeed. |
787 | TEST_F(ThreadPoolExecutorTest, ConcurrentSingleNodeDuplicateRunId) { |
788 | constexpr RunIdentifierTy testRunId = 10; |
789 | constexpr DeviceIDTy testDeviceId = 111; |
790 | constexpr unsigned deviceManagerThreads = 1; |
791 | constexpr unsigned numConcurrentRuns = 100; |
792 | |
793 | // Make a TestDeviceManager and insert into the DeviceManagerMap map (which |
794 | // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map |
795 | // (which the ExecutorTestBuilder has a reference to). |
796 | auto deviceManager = glow::make_unique<TestDeviceManager>( |
797 | deviceManagerThreads, DeviceConfig("Interpreter" )); |
798 | deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager)); |
799 | |
800 | std::atomic<unsigned> testsPassed{0}; |
801 | std::vector<std::thread> threads; |
802 | std::vector<ExecutorTest> tests; |
803 | |
804 | // Build all tests. |
805 | for (unsigned i = 0; i < numConcurrentRuns; ++i) { |
806 | // Build the DAG. The DAG created below looks like this: |
807 | /** |
808 | * root |
809 | * | |
810 | * v |
811 | * net |
812 | **/ |
813 | |
814 | testBuilder_.addNode(strFormat("net_%d" , i), testDeviceId, |
815 | /*parents=*/{}, {"netInput" }, {"netOutput" }, testRunId, |
816 | true); |
817 | tests.emplace_back(testBuilder_.emitTest()); |
818 | } |
819 | |
820 | // Run all tests. |
821 | for (unsigned i = 0; i < numConcurrentRuns; ++i) { |
822 | std::thread th([&testsPassed, test = std::move(tests[i])]() mutable { |
823 | bool passed = test.run(); |
824 | if (passed) { |
825 | testsPassed++; |
826 | } |
827 | }); |
828 | threads.emplace_back(std::move(th)); |
829 | } |
830 | |
831 | // Join all threads. |
832 | for (unsigned i = 0; i < numConcurrentRuns; ++i) { |
833 | threads[i].join(); |
834 | } |
835 | |
836 | // At least one test should pass. Depending on the interleaving, the |
837 | // rest can all pass or all fail or anything in between. |
838 | EXPECT_GE(testsPassed, 1); |
839 | } |
840 | |
841 | /// Tests that a DAG with multiple nodes can run correctly. |
842 | TEST_F(ThreadPoolExecutorTest, MultiNode) { |
843 | constexpr RunIdentifierTy testRunId = 10; |
844 | constexpr DeviceIDTy testDeviceId = 111; |
845 | constexpr unsigned deviceManagerThreads = 3; |
846 | |
847 | // Make a TestDeviceManager and insert into the DeviceManagerMap map (which |
848 | // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map |
849 | // (which the ExecutorTestBuilder has a reference to). |
850 | auto deviceManager = glow::make_unique<TestDeviceManager>( |
851 | deviceManagerThreads, DeviceConfig("Interpreter" )); |
852 | deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager)); |
853 | |
854 | // Build the DAG. The DAG created below looks like this: |
855 | /** |
856 | * root |
857 | * / \ |
858 | * v v |
859 | * alpha beta |
860 | * \ / |
861 | * v v |
862 | * gamma |
863 | * / \ |
864 | * v v |
865 | * delta eps |
866 | **/ |
867 | |
868 | testBuilder_.addNode("alpha" , testDeviceId, |
869 | /*parents=*/{}, /*inputs=*/{"alphaIn" }, |
870 | /*outputs=*/{"alphaOut" }, testRunId, true); |
871 | testBuilder_.addNode("beta" , testDeviceId, |
872 | /*parents=*/{}, /*inputs=*/{"betaIn" }, |
873 | /*outputs=*/{"betaOut" }, testRunId, true); |
874 | testBuilder_.addNode("gamma" , testDeviceId, |
875 | /*parents=*/{"alpha" , "beta" }, |
876 | /*inputs=*/{"alphaOut" , "betaOut" }, |
877 | /*outputs=*/{"deltaIn" , "epsIn" }, testRunId, true); |
878 | testBuilder_.addNode("delta" , testDeviceId, |
879 | /*parents=*/{"gamma" }, /*inputs=*/{"deltaIn" }, |
880 | /*outputs=*/{"deltaOut" }, testRunId, true); |
881 | testBuilder_.addNode("eps" , testDeviceId, |
882 | /*parents=*/{"gamma" }, /*inputs=*/{"epsIn" }, |
883 | /*outputs=*/{"epsOut" }, testRunId, true); |
884 | |
885 | ExecutorTest test = testBuilder_.emitTest(); |
886 | EXPECT_TRUE(test.run()); |
887 | } |
888 | |
889 | /// Tests that a DAG with a node that fails can run correctly. |
890 | TEST_F(ThreadPoolExecutorTest, MultiNodeWithFailure) { |
891 | constexpr RunIdentifierTy testRunId = 10; |
892 | constexpr DeviceIDTy testDeviceId = 111; |
893 | constexpr unsigned deviceManagerThreads = 3; |
894 | |
895 | // Make a TestDeviceManager and insert into the DeviceManagerMap map (which |
896 | // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map |
897 | // (which the ExecutorTestBuilder has a reference to). |
898 | auto deviceManager = glow::make_unique<TestDeviceManager>( |
899 | deviceManagerThreads, DeviceConfig("Interpreter" )); |
900 | deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager)); |
901 | |
902 | // Build the DAG. The DAG created below looks like this: |
903 | /** |
904 | * root |
905 | * / \ |
906 | * v v |
907 | * alpha delta |
908 | * | | |
909 | * v v |
910 | * beta eps |
911 | * | | |
912 | * v v |
913 | * gamma zeta |
914 | **/ |
915 | |
916 | testBuilder_.addNode("alpha" , testDeviceId, |
917 | /*parents=*/{}, /*inputs=*/{"alphaIn" }, |
918 | /*outputs=*/{"alphaOut" }, testRunId, true); |
919 | testBuilder_.addNode("beta" , testDeviceId, |
920 | /*parents=*/{"alpha" }, /*inputs=*/{"alphaOut" }, |
921 | /*outputs=*/{"betaOut" }, testRunId, true); |
922 | testBuilder_.addNode("gamma" , testDeviceId, |
923 | /*parents=*/{"beta" }, |
924 | /*inputs=*/{"betaOut" }, |
925 | /*outputs=*/{"gammaOut" }, testRunId, true); |
926 | testBuilder_.addNode("delta" , testDeviceId, |
927 | /*parents=*/{}, /*inputs=*/{"deltaIn" }, |
928 | /*outputs=*/{"deltaOut" }, testRunId, true); |
929 | testBuilder_.addNode("eps" , testDeviceId, |
930 | /*parents=*/{"delta" }, /*inputs=*/{"deltaOut" }, |
931 | /*outputs=*/{"epsOut" }, testRunId, false); |
932 | testBuilder_.addNode("zeta" , testDeviceId, |
933 | /*parents=*/{"eps" }, /*inputs=*/{"epsOut" }, |
934 | /*outputs=*/{"zetaOut" }, testRunId, true); |
935 | |
936 | ExecutorTest test = testBuilder_.emitTest(); |
937 | EXPECT_TRUE(test.run()); |
938 | } |
939 | |
940 | /// Tests that a DAG with nodes spread across multiple devices can run |
941 | /// correctly. |
942 | TEST_F(ThreadPoolExecutorTest, MultiNodeMultiDevice) { |
943 | constexpr RunIdentifierTy testRunId = 10; |
944 | constexpr DeviceIDTy testDeviceIdA = 111; |
945 | constexpr DeviceIDTy testDeviceIdB = 112; |
946 | constexpr DeviceIDTy testDeviceIdC = 113; |
947 | constexpr unsigned deviceManagerThreads = 3; |
948 | |
949 | // Make TestDeviceManagers and insert them into the DeviceManagerMap map |
950 | // (which the ThreadPoolExecutor has a reference to) and the TestDeviceManager |
951 | // map (which the ExecutorTestBuilder has a reference to). |
952 | for (DeviceIDTy deviceId : {testDeviceIdA, testDeviceIdB, testDeviceIdC}) { |
953 | auto deviceManager = glow::make_unique<TestDeviceManager>( |
954 | deviceManagerThreads, DeviceConfig("Interpreter" )); |
955 | deviceManagerMap_.emplace(deviceId, std::move(deviceManager)); |
956 | } |
957 | |
958 | // Build the DAG. The DAG created below looks like this: |
959 | /** |
960 | * root |
961 | * / \ |
962 | * v v |
963 | * alpha beta |
964 | * \ / |
965 | * v v |
966 | * gamma |
967 | * / \ |
968 | * v v |
969 | * delta eps |
970 | **/ |
971 | |
972 | testBuilder_.addNode("alpha" , testDeviceIdA, |
973 | /*parents=*/{}, /*inputs=*/{"alphaIn" }, |
974 | /*outputs=*/{"alphaOut" }, testRunId, true); |
975 | testBuilder_.addNode("beta" , testDeviceIdB, |
976 | /*parents=*/{}, /*inputs=*/{"betaIn" }, |
977 | /*outputs=*/{"betaOut" }, testRunId, true); |
978 | testBuilder_.addNode("gamma" , testDeviceIdC, |
979 | /*parents=*/{"alpha" , "beta" }, |
980 | /*inputs=*/{"alphaOut" , "betaOut" }, |
981 | /*outputs=*/{"deltaIn" , "epsIn" }, testRunId, true); |
982 | testBuilder_.addNode("delta" , testDeviceIdA, |
983 | /*parents=*/{"gamma" }, /*inputs=*/{"deltaIn" }, |
984 | /*outputs=*/{"deltaOut" }, testRunId, true); |
985 | testBuilder_.addNode("eps" , testDeviceIdB, |
986 | /*parents=*/{"gamma" }, /*inputs=*/{"epsIn" }, |
987 | /*outputs=*/{"epsOut" }, testRunId, true); |
988 | |
989 | ExecutorTest test = testBuilder_.emitTest(); |
990 | EXPECT_TRUE(test.run()); |
991 | } |
992 | |
993 | /// Tests that several instances of a DAG with multiple nodes can run correctly |
994 | /// in parallel. |
995 | TEST_F(ThreadPoolExecutorTest, ConcurrentMultiNode) { |
996 | constexpr RunIdentifierTy baseTestRunId = 10; |
997 | constexpr DeviceIDTy testDeviceId = 111; |
998 | constexpr unsigned deviceManagerThreads = 3; |
999 | unsigned numConcurrentRuns = 100; |
1000 | |
1001 | // Make a TestDeviceManager and insert it into the DeviceManagerMap map |
1002 | // (which the ThreadPoolExecutor has a reference to) and the TestDeviceManager |
1003 | // map (which the ExecutorTestBuilder has a reference to). |
1004 | auto deviceManager = glow::make_unique<TestDeviceManager>( |
1005 | deviceManagerThreads, DeviceConfig("Interpreter" )); |
1006 | deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager)); |
1007 | |
1008 | // Mutex for accessing threadsReady and testsPassed. |
1009 | std::mutex mtx; |
1010 | // Condition variables for signalling between the test runner threads |
1011 | // and this thread. These are used to implement a barrier that ensures |
1012 | // all test runner threads have been created and are executing before any |
1013 | // are allowed to run a test (in order to try and increase the number of |
1014 | // threads that call Executor::run() at the same time). |
1015 | std::condition_variable driverCV, threadCV; |
1016 | // Counters for implementing the aforementioned barrier and tracking the |
1017 | // number of tests that pass. |
1018 | unsigned threadsReady = 0, testsPassed = 0; |
1019 | std::vector<std::thread> threads; |
1020 | for (unsigned i = 0; i < numConcurrentRuns; ++i) { |
1021 | // Build the DAG. The DAG created below looks like this: |
1022 | /** |
1023 | * root |
1024 | * / \ |
1025 | * v v |
1026 | * alpha beta |
1027 | * \ / |
1028 | * v v |
1029 | * gamma |
1030 | * / \ |
1031 | * v v |
1032 | * delta eps |
1033 | **/ |
1034 | |
1035 | // The names must be distinct for each run since the DeviceManager |
1036 | // distinguishes based on function name. |
1037 | std::string alpha = strFormat("alpha_%d" , i); |
1038 | std::string beta = strFormat("beta_%d" , i); |
1039 | std::string gamma = strFormat("gamma_%d" , i); |
1040 | std::string delta = strFormat("delta_%d" , i); |
1041 | std::string eps = strFormat("eps_%d" , i); |
1042 | |
1043 | // The run IDs must be distinct as well to distinguish all the concurrent |
1044 | // runs from each other. |
1045 | testBuilder_.addNode(alpha, testDeviceId, |
1046 | /*parents=*/{}, /*inputs=*/{"alphaIn" }, |
1047 | /*outputs=*/{"alphaOut" }, baseTestRunId + i, true); |
1048 | testBuilder_.addNode(beta, testDeviceId, |
1049 | /*parents=*/{}, /*inputs=*/{"betaIn" }, |
1050 | /*outputs=*/{"betaOut" }, baseTestRunId + i, true); |
1051 | testBuilder_.addNode(gamma, testDeviceId, |
1052 | /*parents=*/{alpha, beta}, |
1053 | /*inputs=*/{"alphaOut" , "betaOut" }, |
1054 | /*outputs=*/{"deltaIn" , "epsIn" }, baseTestRunId + i, |
1055 | true); |
1056 | testBuilder_.addNode(delta, testDeviceId, |
1057 | /*parents=*/{gamma}, /*inputs=*/{"deltaIn" }, |
1058 | /*outputs=*/{"deltaOut" }, baseTestRunId + i, true); |
1059 | testBuilder_.addNode(eps, testDeviceId, |
1060 | /*parents=*/{gamma}, /*inputs=*/{"epsIn" }, |
1061 | /*outputs=*/{"epsOut" }, baseTestRunId + i, true); |
1062 | |
1063 | ExecutorTest t = testBuilder_.emitTest(); |
1064 | std::thread th([&mtx, &driverCV, &threadCV, &threadsReady, &testsPassed, |
1065 | test = std::move(t), numConcurrentRuns]() mutable { |
1066 | std::unique_lock<std::mutex> lock(mtx); |
1067 | // Increment threadsReady to mark this thread as ready to run the test. |
1068 | threadsReady++; |
1069 | // If threadsReady == numConcurrentRuns, this thread is the last to be |
1070 | // initialized and execute, so signal the driver that all threads are |
1071 | // ready. |
1072 | if (threadsReady == numConcurrentRuns) { |
1073 | driverCV.notify_one(); |
1074 | } |
1075 | // Wait for the driver's signal. |
1076 | threadCV.wait(lock); |
1077 | // Unlock the mutex to let all other threads run their tests concurrently. |
1078 | lock.unlock(); |
1079 | bool passed = test.run(); |
1080 | lock.lock(); |
1081 | |
1082 | if (passed) { |
1083 | testsPassed++; |
1084 | } |
1085 | }); |
1086 | threads.emplace_back(std::move(th)); |
1087 | } |
1088 | |
1089 | std::unique_lock<std::mutex> lock(mtx); |
1090 | // If threadsReady != numConcurrentRuns, not all threads are ready to run |
1091 | // their tests. Wait until they are. |
1092 | if (threadsReady != numConcurrentRuns) { |
1093 | driverCV.wait(lock, [&threadsReady, numConcurrentRuns] { |
1094 | return threadsReady == numConcurrentRuns; |
1095 | }); |
1096 | } |
1097 | // Wake up all test runners. |
1098 | threadCV.notify_all(); |
1099 | lock.unlock(); |
1100 | |
1101 | // Join all threads. |
1102 | for (unsigned i = 0; i < numConcurrentRuns; ++i) { |
1103 | threads[i].join(); |
1104 | } |
1105 | |
1106 | // All tests should pass. |
1107 | EXPECT_EQ(testsPassed, numConcurrentRuns); |
1108 | } |
1109 | |