1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Runtime/Executor/ThreadPoolExecutor.h"
18#include "glow/Backends/DeviceManager.h"
19#include "glow/Support/Support.h"
20#include "glow/Support/ThreadPool.h"
21
22#include "gtest/gtest.h"
23
24#include <chrono>
25#include <future>
26#include <thread>
27#include <unordered_set>
28
29using namespace glow;
30using namespace glow::runtime;
31
32/// This is an implementation of DeviceManager tailored for testing Executor
33/// implementations. registerResult() gives the caller the ability to
34/// dictate precisely what a subsequent call to runFunction() should return.
35/// registerResult() should be called before calling Executor::run() in each
36/// test. The rest of the implementation of the DeviceManager interface exists
37/// to satisfy the compiler.
38class TestDeviceManager final : public runtime::DeviceManager {
39public:
40 TestDeviceManager(unsigned numWorkers, const DeviceConfig &deviceConfig)
41 : DeviceManager(deviceConfig), threadPool_(numWorkers) {}
42
43 /// The functions below are the interface for DeviceManager. See
44 /// glow::DeviceManager for descriptions of what they do. Since this
45 /// class exists only to help test Executor implementations, the only
46 /// important function is runFunction().
47 void addNetwork(const Module *module, FunctionMapTy functions,
48 ReadyCBTy readyCB) override {}
49
50 void evictNetwork(std::string functionName,
51 EvictFunctionCBTy evictCB) override {
52 // Erase the entry so that the same function name can be used to register
53 // another result.
54
55 if (!resultMap_.erase(functionName)) {
56 evictCB(
57 functionName,
58 MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_NOT_FOUND,
59 strFormat("Could not find function with name %s to evict",
60 functionName.c_str())));
61 return;
62 }
63 evictCB(functionName, Error::success());
64 }
65
66 /// Look up the previously registered response for \p functionName and
67 /// call \p resultCB with it after checking that \p context contains the
68 /// expected Placeholder-Tensor mappings.
69 void doRunFunction(std::string functionName,
70 std::unique_ptr<ExecutionContext> context,
71 ResultCBTy resultCB) {
72
73 RunIdentifierTy runId = 0;
74 bool successResult = false;
75
76 // Retrieve the registered response for the function if there is one.
77 if (context && resultCB && resultMap_.count(functionName)) {
78 std::unique_ptr<RunFunctionResult> registeredResult =
79 std::move(resultMap_[functionName]);
80
81 // Check that context contains the expected Placeholder-Tensor mappings.
82 std::unique_ptr<ExecutionContext> inputContext =
83 std::move(registeredResult->inputContext);
84
85 bool equalInputs = true;
86 for (auto &p : inputContext->getPlaceholderBindings()->pairs()) {
87 Tensor *CT = context->getPlaceholderBindings()->get(p.first);
88 if (!CT) {
89 equalInputs = false;
90 break;
91 }
92 equalInputs &= p.second.isEqual(*CT, 0.0001, true);
93 }
94
95 if (equalInputs) {
96 // If bindings contains all expected mappings, overwrite the default
97 // runId, result and resultContext with the registered
98 // ones.
99 runId = registeredResult->runId;
100 successResult = registeredResult->success;
101
102 for (const auto &p :
103 registeredResult->resultContext->getPlaceholderBindings()
104 ->pairs()) {
105 context->getPlaceholderBindings()->get(p.first)->assign(&p.second);
106 }
107 }
108 }
109
110 if (successResult) {
111 resultCB(runId, Error::success(), std::move(context));
112 } else {
113 resultCB(runId, MAKE_ERR("An error occurred"), std::move(context));
114 }
115 }
116
117 /// Do not call this at the same time as registerResult().
118 runtime::RunIdentifierTy
119 runFunction(std::string functionName,
120 std::unique_ptr<ExecutionContext> context,
121 ResultCBTy resultCB) override {
122 // Give the call to the thread pool to process to make the tests
123 // multithreaded if needed.
124 this->threadPool_.submit(
125 [this, functionName, context = std::move(context), resultCB]() mutable {
126 this->doRunFunction(functionName, std::move(context), resultCB);
127 });
128 return 0;
129 }
130
131 uint64_t getMaximumMemory() const override {
132 return std::numeric_limits<uint64_t>::max();
133 }
134
135 uint64_t getAvailableMemory() const override {
136 return std::numeric_limits<uint64_t>::max();
137 }
138
139 bool isMemoryAvailable(uint64_t /*estimate*/) const override { return true; }
140
141 /// Register a result that should be returned by the subsequent call to
142 /// runFunction with the same \p functionName. The callback for that call
143 /// to runFunction will be called with \p runId, \p success, and \p
144 /// \p resultContext if the context passed in to runFunction
145 /// matches \p inputContext. \returns true if registration was
146 /// successful, false if not. Do not call this at the same time as
147 /// runFunction().
148 bool registerResult(const std::string &functionName, RunIdentifierTy runId,
149 bool success,
150 std::unique_ptr<ExecutionContext> inputContext,
151 std::unique_ptr<ExecutionContext> resultContext) {
152 bool registered = false;
153
154 if (!resultMap_.count(functionName)) {
155 // If the function name has not already been registered, insert it into
156 // resultMap_.
157 std::tie(std::ignore, registered) = resultMap_.insert(std::make_pair(
158 functionName, glow::make_unique<RunFunctionResult>(
159 runId, success, std::move(inputContext),
160 std::move(resultContext))));
161 }
162
163 return registered;
164 }
165
166private:
167 /// This struct wraps all of the data needed to reply to a runFunction() call.
168 /// It exists so that that all of these things can be stored in one map.
169 struct RunFunctionResult {
170 /// The run ID that should be returned.
171 RunIdentifierTy runId;
172 /// If success then no error should be returned otherwise an Error should be
173 /// returned.
174 bool success;
175 /// The expected input context for the invocation.
176 std::unique_ptr<ExecutionContext> inputContext;
177 /// The result context that should be returned.
178 std::unique_ptr<ExecutionContext> resultContext;
179
180 /// Constructor.
181 RunFunctionResult(RunIdentifierTy run, bool successParam,
182 std::unique_ptr<ExecutionContext> inputcontext,
183 std::unique_ptr<ExecutionContext> resultcontext)
184 : runId(run), success(successParam),
185 inputContext(std::move(inputcontext)),
186 resultContext(std::move(resultcontext)) {}
187 };
188
189 /// Map of function name -> RunFunctionResult instance containing the
190 /// RunFunctionResult instance for the function.
191 using TestDeviceManagerResultMapTy =
192 std::unordered_map<std::string, std::unique_ptr<RunFunctionResult>>;
193
194 /// Map for storing registered results.
195 TestDeviceManagerResultMapTy resultMap_;
196 /// Thread pool for executing runFunction() in a multithreaded fashion.
197 ThreadPool threadPool_;
198};
199
200using PlaceholderNameMapTy =
201 std::unordered_map<std::string, std::unique_ptr<Placeholder>>;
202using DAGNodeNameMapTy =
203 std::unordered_map<std::string, std::unique_ptr<DAGNode>>;
204
205/// This class serves as an interface to a test created by ExecutorTestBuilder.
206/// It also contains the resources necessary to run the test. Instances are
207/// meant to be created only by ExecutorTestBuilder.
208class ExecutorTest final {
209public:
210 /// Constructor.
211 ExecutorTest(const std::shared_ptr<Executor> &executor,
212 std::unique_ptr<DAGNode> root, std::unique_ptr<Module> module,
213 std::unique_ptr<Type> type, DAGNodeNameMapTy nodes,
214 PlaceholderNameMapTy placeholders,
215 std::unique_ptr<ExecutionContext> inputContext,
216 std::unique_ptr<ExecutionContext> outputContext,
217 RunIdentifierTy runId, bool expectSuccess)
218 : executor_(executor), root_(std::move(root)), module_(std::move(module)),
219 type_(std::move(type)), nodes_(std::move(nodes)),
220 placeholders_(std::move(placeholders)),
221 inputContext_(std::move(inputContext)),
222 outputContext_(std::move(outputContext)), runId_(runId),
223 expectSuccess_(expectSuccess), testRun_(false) {
224 root_->module = module_.get();
225 // Create context pool.
226 executor_->createPool(root_.get(), 1000, false, false);
227 }
228
229 /// Run the test.
230 bool run() {
231 if (testRun_) {
232 assert(!"Test has already been run!");
233 }
234
235 // Variables for storing runId actually returned by
236 // Executor::run() via its callback.
237 RunIdentifierTy executorRunId;
238 std::unique_ptr<ExecutionContext> executorOutputContext;
239
240 // Call Executor::run().
241 std::promise<bool> promise;
242 std::future<bool> future = promise.get_future();
243 executor_->run(root_.get(), std::move(inputContext_), runId_,
244 [&promise, &executorRunId, &executorOutputContext](
245 RunIdentifierTy runId, Error err,
246 std::unique_ptr<ExecutionContext> context) {
247 executorRunId = runId;
248 executorOutputContext = std::move(context);
249 promise.set_value(ERR_TO_BOOL(std::move(err)));
250 });
251
252 bool runSuccess = !future.get();
253
254 // Check that the values returned in the Executor callback match
255 // expectations.
256 bool runIdsMatch = executorRunId == runId_;
257 bool resultsMatch = runSuccess == expectSuccess_;
258
259 bool bindingsMatch = PlaceholderBindings::compare(
260 executorOutputContext->getPlaceholderBindings(),
261 outputContext_->getPlaceholderBindings());
262
263 // If the run failed, we shouldn't expect bindingsMatch to be true.
264 bool testPassed =
265 runIdsMatch && resultsMatch && (!runSuccess || bindingsMatch);
266
267 testRun_ = true;
268 executor_->freePool(root_.get());
269 return testPassed;
270 }
271
272private:
273 /// The Executor to run the test with.
274 std::shared_ptr<Executor> executor_;
275 /// The root node of the DAG being tested.
276 std::unique_ptr<DAGNode> root_;
277 /// The Module containing the PHs.
278 std::unique_ptr<Module> module_;
279 /// The Type for all of the Placeholders that will be used during execution.
280 std::unique_ptr<Type> type_;
281 /// All nodes in the DAG.
282 DAGNodeNameMapTy nodes_;
283 /// All Placeholders that will be used during execution.
284 PlaceholderNameMapTy placeholders_;
285 /// The input ExecutionContext that should be passed to Executor::run()
286 /// when running the test.
287 std::unique_ptr<ExecutionContext> inputContext_;
288 /// The expected ExecutionContext that the Executor should return.
289 std::unique_ptr<ExecutionContext> outputContext_;
290 /// The run ID that should be passed to Executor::run() when running
291 /// the test.
292 RunIdentifierTy runId_;
293 /// The expected result that the Executor should return.
294 bool expectSuccess_;
295 /// Tracks whether or not the test has already been run.
296 bool testRun_;
297};
298
299/// This class helps build tests for testing Executor implementations. It
300/// presents a simple interface for executor DAG construction; nodes are added
301/// by specifying its parents, device ID, and named inputs and outputs. This
302/// builder class takes care of all of the work needed to actually run this DAG:
303/// creation of Placeholders and Tensors for all inputs and outputs; creation of
304/// input/output ExecutionContext for each node to verify that each one
305/// receives the correct input and produces the correct output; and registration
306/// with the TestDeviceManager.
307class ExecutorTestBuilder final {
308public:
309 /// Constructor. The exact value of type_ doesn't really matter since the
310 /// important thing to test is that that Placeholder values are propagated
311 /// between ExecutionContexts correctly.
312 ExecutorTestBuilder(const std::shared_ptr<Executor> &executor,
313 const DeviceManagerMapTy &deviceManagers)
314 : executor_(executor), module_(glow::make_unique<Module>()),
315 root_(glow::make_unique<DAGNode>()),
316 bindings_(glow::make_unique<PlaceholderBindings>()),
317 type_(
318 std::unique_ptr<Type>(new Type(ElemKind::FloatTy, {32, 64, 128}))),
319 success_(true), deviceManagers_(deviceManagers) {}
320
321 /// Add a node named \p name to the DAG with parents \p parents that runs on a
322 /// device specified by \p deviceId. A RuntimeBundle is created for the node
323 /// with runtime symbol information created from \p inputs and \p outputs.
324 /// \p runId is the run ID for the node and \p success is the desired
325 /// execution status. If \p parents is empty, the new node is added as a child
326 /// of the root.
327 void addNode(const std::string &name, DeviceIDTy deviceId,
328 llvm::ArrayRef<llvm::StringRef> parents,
329 llvm::ArrayRef<llvm::StringRef> inputs,
330 llvm::ArrayRef<llvm::StringRef> outputs, RunIdentifierTy runId,
331 bool success) {
332 auto newNode = glow::make_unique<DAGNode>();
333 auto *newNodeRawPtr = newNode.get();
334
335 // If this is the first node being added, record the run ID for the graph.
336 // Otherwise, make sure that the runId matches that of the previous nodes.
337 if (nodes_.empty()) {
338 runId_ = runId;
339 } else {
340 assert(runId == runId_ && "Node run ID does not match rest of graph!");
341 }
342
343 // If the result for this node is false, set the expected
344 // result for the entire test to false.
345 success_ &= success;
346
347 // Add parents to the list of parents in the new node and add the newNode
348 // to the list of children in the parents. If the parent list is empty,
349 // make the root the only parent. Also, update the set of known leaves
350 // by removing any parents of the new node from it. This will be useful
351 // later.
352 if (!parents.empty()) {
353 for (const auto &parent : parents) {
354 auto it = nodes_.find(parent.str());
355 if (it == nodes_.end()) {
356 assert(!"Parent specified for node not found!");
357 }
358 DAGNode *parentPtr = (it->second).get();
359 (newNode->parents).emplace_back(parentPtr);
360 (parentPtr->children).emplace_back(newNodeRawPtr);
361 leaves_.erase(parentPtr);
362 }
363 } else {
364 (newNode->parents).emplace_back(root_.get());
365 (root_->children).emplace_back(newNode.get());
366 }
367
368 // Iterate through inputs and outputs and:
369 // 1) Create Placeholders and Tensors for inputs/output names that have not
370 // been mapped to a Placeholder yet.
371 // 2) Assemble the input ExecutionContexts that the node is expected to be
372 // called with
373 // and the ExecutionContexts that the node should produce as output.
374 // 3) Generate the symbol table for the new node by generating
375 // RuntimeSymbolInfo objects for each input and output.
376 SymbolTableTy symbolTable;
377 size_t offset = 0;
378
379 auto nodeInputContext = glow::make_unique<ExecutionContext>();
380 auto nodeOutputContext = glow::make_unique<ExecutionContext>();
381
382 auto nodeInputBindings = nodeInputContext->getPlaceholderBindings();
383 auto nodeOutputBindings = nodeOutputContext->getPlaceholderBindings();
384
385 for (const auto &input : inputs) {
386 // Both input and output bindings should contain bindings for the inputs.
387 insertSymbolIntoPlaceholderBindings(input, nodeInputBindings);
388 insertSymbolIntoPlaceholderBindings(input, nodeOutputBindings);
389
390 RuntimeSymbolInfo runtimeSymbolInfo;
391 runtimeSymbolInfo.size = type_->getSizeInBytes();
392 runtimeSymbolInfo.offset = offset;
393 runtimeSymbolInfo.type = *type_;
394 runtimeSymbolInfo.input = true;
395 runtimeSymbolInfo.output = false;
396 runtimeSymbolInfo.symbolCategory = SymbolCategory::Placeholder;
397 symbolTable.insert(std::make_pair(input, runtimeSymbolInfo));
398 offset += type_->getSizeInBytes();
399 }
400
401 for (const auto &output : outputs) {
402 insertSymbolIntoPlaceholderBindings(output, nodeOutputBindings);
403
404 RuntimeSymbolInfo runtimeSymbolInfo;
405 runtimeSymbolInfo.size = type_->getSizeInBytes();
406 runtimeSymbolInfo.offset = offset;
407 runtimeSymbolInfo.type = *type_;
408 runtimeSymbolInfo.input = false;
409 runtimeSymbolInfo.output = true;
410 runtimeSymbolInfo.symbolCategory = SymbolCategory::Placeholder;
411 symbolTable.insert(std::make_pair(output, runtimeSymbolInfo));
412 offset += type_->getSizeInBytes();
413 }
414
415 // Set the name, device ID, and RuntimeBundle of the new node.
416 newNode->name = name;
417 newNode->deviceRuntimeInfos[deviceId] = DeviceRuntimeInfo();
418
419 newNode->runtimeBundle = glow::make_unique<RuntimeBundle>(
420 symbolTable, /*constWeight=*/0, /*mutableWeight=*/0,
421 /*activations=*/0);
422
423 // Register node result with the appropriate DeviceManager.
424 auto it = deviceManagers_.find(deviceId);
425
426 if (it == deviceManagers_.end()) {
427 assert(!"No test device manager found for this device ID");
428 }
429
430 auto *deviceManagerPtr = it->second.get();
431 auto testDeviceManagerPtr =
432 static_cast<TestDeviceManager *>(deviceManagerPtr);
433
434 bool registered = testDeviceManagerPtr->registerResult(
435 name, runId, success, std::move(nodeInputContext),
436 std::move(nodeOutputContext));
437
438 (void)registered;
439 assert(registered && "Node registration was not successful");
440
441 // Add the new node to nodes_ and leaves_.
442 nodes_.insert(std::make_pair(name, std::move(newNode)));
443 leaves_.insert(newNodeRawPtr);
444 }
445
446 /// Emit the test built so far and clear any state in the builder.
447 ExecutorTest emitTest() {
448 // Get the input and output symbol names for the whole DAG.
449 std::vector<std::string> inputSymbols = gatherInputSymbols();
450 std::vector<std::string> outputSymbols = gatherOutputSymbols();
451
452 // Generate the input and output ExecutionContexts for the test. This
453 // input ExecutionContexts contains the input Placeholders of all root
454 // nodes and output Placeholders of all leaves (but backed by zero tensors).
455 // This is the ExecutionContexts that needs to be passed to
456 // Executor::run() to run the test. The output ExecutionContexts contains
457 // the same Placeholders as the input ExecutionContexts, but the leaves'
458 // output Placeholders are mapped to their expected output Tensors. This
459 // ExecutionContext is used to verify that the one returned by the
460 // Executor is correct.
461 auto inputContext = glow::make_unique<ExecutionContext>();
462 auto outputContext = glow::make_unique<ExecutionContext>();
463
464 for (const auto &symbol : inputSymbols) {
465 insertSymbolIntoPlaceholderBindings(
466 symbol, inputContext->getPlaceholderBindings());
467 insertSymbolIntoPlaceholderBindings(
468 symbol, outputContext->getPlaceholderBindings());
469 }
470
471 for (const auto &symbol : outputSymbols) {
472 auto *placeholder = bindings_->getPlaceholderByNameSlow(symbol);
473 if (!placeholder) {
474 assert(!"Placeholder for DAG output not found!");
475 }
476 insertSymbolIntoPlaceholderBindings(
477 symbol, inputContext->getPlaceholderBindings());
478 insertSymbolIntoPlaceholderBindings(
479 symbol, outputContext->getPlaceholderBindings());
480 }
481 // Create the test object.
482 ExecutorTest test(executor_, std::move(root_), std::move(module_),
483 std::move(type_), std::move(nodes_),
484 std::move(placeholders_), std::move(inputContext),
485 std::move(outputContext), runId_, success_);
486
487 // Reset builder state to allow a new test to be constructed with this
488 // instance.
489 root_ = glow::make_unique<DAGNode>();
490 module_ = glow::make_unique<Module>();
491 bindings_->clear();
492 type_ = std::unique_ptr<Type>(new Type(ElemKind::FloatTy, {1, 2, 2}));
493 nodes_.clear();
494 leaves_.clear();
495 placeholders_.clear();
496 success_ = true;
497
498 return test;
499 }
500
501private:
502 /// Collect all input symbol names for the test. \returns a vector containing
503 /// the names of all test input symbols.
504 std::vector<std::string> gatherInputSymbols() const {
505 std::vector<std::string> inputSymbols;
506
507 // Input symbols for the entire test are the inputs of all nodes that have
508 // no parents.
509 for (const auto &node : root_->children) {
510 const SymbolTableTy &symbolTable = node->runtimeBundle->getSymbolTable();
511
512 for (const auto &symbolPair : symbolTable) {
513 const auto &symbolName = symbolPair.first;
514 const auto &symbolInfo = symbolPair.second;
515
516 if (symbolInfo.input) {
517 inputSymbols.emplace_back(symbolName);
518 }
519 }
520 }
521
522 return inputSymbols;
523 }
524
525 /// Collect all output symbol names for the test. \returns a vector containing
526 /// the names of all test output symbols.
527 std::vector<std::string> gatherOutputSymbols() const {
528 std::vector<std::string> outputSymbols;
529
530 // Input symbols for the entire test are the outputs of all nodes that have
531 // no children.
532 for (const auto &node : leaves_) {
533 const SymbolTableTy &symbolTable = node->runtimeBundle->getSymbolTable();
534
535 for (const auto &symbolPair : symbolTable) {
536 const auto &symbolName = symbolPair.first;
537 const auto &symbolInfo = symbolPair.second;
538
539 if (symbolInfo.output) {
540 outputSymbols.emplace_back(symbolName);
541 }
542 }
543 }
544
545 return outputSymbols;
546 }
547
548 /// Insert a Placeholder named \p name with type type_ into \p bindings
549 /// and generate a random Tensor for it. If this Placeholder has already been
550 /// mapped for the test being created, reuse the existing value.
551 void insertSymbolIntoPlaceholderBindings(llvm::StringRef name,
552 PlaceholderBindings *bindings) {
553 auto ph = module_->getPlaceholderByNameSlow(name);
554
555 if (!ph) {
556 // This is a new symbol. Create a Placeholder and an initialize and new
557 // Tensor for it.
558 auto placeholder = module_->createPlaceholder(type_.get(), name, false);
559 auto *tensor = bindings_->allocate(placeholder);
560 tensor->init(Tensor::InitKind::Xavier, 1.0, rng_);
561 bindings->insert(placeholder, tensor->clone());
562 } else {
563 // This is a symbol that already has an associated Placeholder and Tensor.
564 // Copy that Tensor.
565 const auto *tensor = bindings_->get(ph);
566 bindings->insert(ph, tensor->clone());
567 }
568 }
569
570 /// The Executor being tested.
571 std::shared_ptr<Executor> executor_;
572 /// Module for holding PHs
573 std::unique_ptr<Module> module_;
574 /// The root of the DAG being constructed.
575 std::unique_ptr<DAGNode> root_;
576 /// This PlaceholderBindings holds all created and initialized Placeholders
577 /// for the test.
578 std::unique_ptr<PlaceholderBindings> bindings_;
579 /// The Type for all Placeholders and Tensors in the test. The exact value
580 /// is not important; the main thing being tested is the propagation of
581 /// Placeholders and Tensors as the DAG executes.
582 std::unique_ptr<Type> type_;
583 /// PRNG for filling Tensors.
584 PseudoRNG rng_;
585 /// The nodes in the DAG being constructed.
586 DAGNodeNameMapTy nodes_;
587 /// The leaves in the DAG being constructed. This helps collect output symbols
588 /// during test emission.
589 std::unordered_set<const DAGNode *> leaves_;
590 /// All Placeholders in the test.
591 PlaceholderNameMapTy placeholders_;
592 /// The run ID for the DAG.
593 RunIdentifierTy runId_;
594 /// The expected result for the DAG.
595 bool success_;
596 /// Map from DeviceIDTy -> TestDeviceManager. This enables the construction of
597 /// tests with nodes spread across devices.
598 const DeviceManagerMapTy &deviceManagers_;
599};
600
601/// This test fixture provides ThreadPoolExecutor, ExecutorTestBuilder,
602/// DeviceManagerMapTy instances to all tests.
603class ThreadPoolExecutorTest : public ::testing::Test {
604protected:
605 ThreadPoolExecutorTest()
606 : executor_(std::make_shared<ThreadPoolExecutor>(deviceManagerMap_)),
607 testBuilder_(executor_, deviceManagerMap_) {}
608 ~ThreadPoolExecutorTest() = default;
609
610 /// The Executor being tested.
611 std::shared_ptr<ThreadPoolExecutor> executor_;
612 /// An ExecutorTestBuilder instance for creating tests.
613 ExecutorTestBuilder testBuilder_;
614 /// DeviceManager map for initializing executor_.
615 DeviceManagerMapTy deviceManagerMap_;
616};
617
618/// Tests that an empty DAG is handled correctly.
619TEST_F(ThreadPoolExecutorTest, EmptyDAG) {
620 constexpr RunIdentifierTy testRunId = 10;
621
622 // Make a PlaceholderBindings with one Placeholder in it to make sure
623 // Executor::run() doesn't modify it when the root given to it is null. Make
624 // two identical copies; one to give to Executor::run(), and another to
625 // compare the returned PlaceholderBindings with.
626 PseudoRNG rng;
627 auto type = std::unique_ptr<Type>(new Type(ElemKind::FloatTy, {1, 2, 2}));
628 auto placeholder = glow::make_unique<Placeholder>(
629 "a", type.get(), /*trainable=*/false, ANY_LAYOUT);
630
631 auto testContext = glow::make_unique<ExecutionContext>();
632 auto refContext = glow::make_unique<ExecutionContext>();
633
634 auto *tensor =
635 testContext->getPlaceholderBindings()->allocate(placeholder.get());
636 tensor->init(Tensor::InitKind::Xavier, 1.0, rng);
637 refContext->getPlaceholderBindings()->insert(placeholder.get(),
638 tensor->clone());
639
640 // Variables for storing runId actually returned by
641 // Executor::run() via its callback.
642 RunIdentifierTy executorRunId;
643 std::unique_ptr<ExecutionContext> executorOutputContext;
644
645 // Call Executor::run().
646 std::promise<void> promise;
647 std::future<void> future = promise.get_future();
648 std::unique_ptr<Error> runErr;
649 executor_->run(nullptr, std::move(testContext), testRunId,
650 [&runErr, &promise, &executorRunId, &executorOutputContext](
651 RunIdentifierTy runId, Error err,
652 std::unique_ptr<ExecutionContext> context) {
653 executorRunId = runId;
654 executorOutputContext = std::move(context);
655 runErr = glow::make_unique<Error>(std::move(err));
656 promise.set_value();
657 });
658
659 EXPECT_FALSE(ERR_TO_BOOL(std::move(*DCHECK_NOTNULL(runErr.get()))));
660
661 EXPECT_EQ(executorRunId, testRunId);
662
663 EXPECT_TRUE(PlaceholderBindings::compare(
664 refContext->getPlaceholderBindings(),
665 executorOutputContext->getPlaceholderBindings()));
666}
667
668/// Tests that a single node can run correctly.
669TEST_F(ThreadPoolExecutorTest, SingleNode) {
670 constexpr RunIdentifierTy testRunId = 10;
671 constexpr DeviceIDTy testDeviceId = 111;
672 constexpr unsigned deviceManagerThreads = 1;
673
674 // Make a TestDeviceManager and insert into the DeviceManagerMap map (which
675 // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map
676 // (which the ExecutorTestBuilder has a reference to).
677 auto deviceManager = glow::make_unique<TestDeviceManager>(
678 deviceManagerThreads, DeviceConfig("Interpreter"));
679 deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager));
680
681 // Build the DAG. The DAG created below looks like this:
682 /**
683 * root
684 * |
685 * v
686 * net
687 **/
688
689 testBuilder_.addNode("net", testDeviceId,
690 /*parents=*/{}, {"netInput"}, {"netOutput"}, testRunId,
691 true);
692
693 ExecutorTest test = testBuilder_.emitTest();
694 EXPECT_TRUE(test.run());
695}
696
697/// Tests that several instances of a single node DAG can be run in parallel.
698TEST_F(ThreadPoolExecutorTest, ConcurrentSingleNode) {
699 constexpr RunIdentifierTy baseTestRunId = 10;
700 constexpr DeviceIDTy testDeviceId = 111;
701 constexpr unsigned deviceManagerThreads = 3;
702 unsigned numConcurrentRuns = 100;
703
704 // Make a TestDeviceManager and insert into the DeviceManagerMap map (which
705 // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map
706 // (which the ExecutorTestBuilder has a reference to).
707 auto deviceManager = glow::make_unique<TestDeviceManager>(
708 deviceManagerThreads, DeviceConfig("Interpreter"));
709 deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager));
710
711 // Mutex for accessing threadsReady and testsPassed.
712 std::mutex mtx;
713 // Condition variables for signalling between the test runner threads
714 // and this thread. These are used to implement a barrier that ensures
715 // all test runner threads have been created and are executing before any
716 // are allowed to run a test (in order to try and increase the number of
717 // threads that call Executor::run() at the same time).
718 std::condition_variable driverCV, threadCV;
719 // Counters for implementing the aforementioned barrier and tracking the
720 // number of tests that pass.
721 unsigned threadsReady = 0, testsPassed = 0;
722 std::vector<std::thread> threads;
723 for (unsigned i = 0; i < numConcurrentRuns; ++i) {
724 // Build the DAG. The DAG created below looks like this:
725 /**
726 * root
727 * |
728 * v
729 * net
730 **/
731
732 // The names must be distinct since the DeviceManager distinguishes based
733 // on function name. The run IDs must also be distinct (hence the +i).
734 testBuilder_.addNode(strFormat("net_%d", i), testDeviceId,
735 /*parents=*/{}, {"netInput"}, {"netOutput"},
736 baseTestRunId + i, true);
737 ExecutorTest t = testBuilder_.emitTest();
738
739 std::thread th([&mtx, &driverCV, &threadCV, &threadsReady, &testsPassed,
740 test = std::move(t), numConcurrentRuns]() mutable {
741 std::unique_lock<std::mutex> lock(mtx);
742 // Increment threadsReady to mark this thread as ready to run the test.
743 threadsReady++;
744 // If threadsReady == numConcurrentRuns, this thread is the last to be
745 // initialized and execute, so signal the driver that all threads are
746 // ready.
747 if (threadsReady == numConcurrentRuns) {
748 driverCV.notify_one();
749 }
750 // Wait for the driver's signal.
751 threadCV.wait(lock);
752 // Unlock the mutex to let all other threads run their tests concurrently.
753 lock.unlock();
754 bool passed = test.run();
755 lock.lock();
756
757 if (passed) {
758 testsPassed++;
759 }
760 });
761 threads.emplace_back(std::move(th));
762 }
763
764 std::unique_lock<std::mutex> lock(mtx);
765 // If threadsReady != numConcurrentRuns, not all threads are ready to run
766 // their tests. Wait until they are.
767 if (threadsReady != numConcurrentRuns) {
768 driverCV.wait(lock, [&threadsReady, numConcurrentRuns] {
769 return threadsReady == numConcurrentRuns;
770 });
771 }
772 // Wake up all test runners.
773 threadCV.notify_all();
774 lock.unlock();
775
776 // Join all threads.
777 for (unsigned i = 0; i < numConcurrentRuns; ++i) {
778 threads[i].join();
779 }
780
781 // All tests should pass.
782 EXPECT_EQ(testsPassed, numConcurrentRuns);
783}
784
785/// Tests that successive calls to ThreadPoolExecutor::run() with the same
786/// runId don't succeed.
787TEST_F(ThreadPoolExecutorTest, ConcurrentSingleNodeDuplicateRunId) {
788 constexpr RunIdentifierTy testRunId = 10;
789 constexpr DeviceIDTy testDeviceId = 111;
790 constexpr unsigned deviceManagerThreads = 1;
791 constexpr unsigned numConcurrentRuns = 100;
792
793 // Make a TestDeviceManager and insert into the DeviceManagerMap map (which
794 // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map
795 // (which the ExecutorTestBuilder has a reference to).
796 auto deviceManager = glow::make_unique<TestDeviceManager>(
797 deviceManagerThreads, DeviceConfig("Interpreter"));
798 deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager));
799
800 std::atomic<unsigned> testsPassed{0};
801 std::vector<std::thread> threads;
802 std::vector<ExecutorTest> tests;
803
804 // Build all tests.
805 for (unsigned i = 0; i < numConcurrentRuns; ++i) {
806 // Build the DAG. The DAG created below looks like this:
807 /**
808 * root
809 * |
810 * v
811 * net
812 **/
813
814 testBuilder_.addNode(strFormat("net_%d", i), testDeviceId,
815 /*parents=*/{}, {"netInput"}, {"netOutput"}, testRunId,
816 true);
817 tests.emplace_back(testBuilder_.emitTest());
818 }
819
820 // Run all tests.
821 for (unsigned i = 0; i < numConcurrentRuns; ++i) {
822 std::thread th([&testsPassed, test = std::move(tests[i])]() mutable {
823 bool passed = test.run();
824 if (passed) {
825 testsPassed++;
826 }
827 });
828 threads.emplace_back(std::move(th));
829 }
830
831 // Join all threads.
832 for (unsigned i = 0; i < numConcurrentRuns; ++i) {
833 threads[i].join();
834 }
835
836 // At least one test should pass. Depending on the interleaving, the
837 // rest can all pass or all fail or anything in between.
838 EXPECT_GE(testsPassed, 1);
839}
840
841/// Tests that a DAG with multiple nodes can run correctly.
842TEST_F(ThreadPoolExecutorTest, MultiNode) {
843 constexpr RunIdentifierTy testRunId = 10;
844 constexpr DeviceIDTy testDeviceId = 111;
845 constexpr unsigned deviceManagerThreads = 3;
846
847 // Make a TestDeviceManager and insert into the DeviceManagerMap map (which
848 // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map
849 // (which the ExecutorTestBuilder has a reference to).
850 auto deviceManager = glow::make_unique<TestDeviceManager>(
851 deviceManagerThreads, DeviceConfig("Interpreter"));
852 deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager));
853
854 // Build the DAG. The DAG created below looks like this:
855 /**
856 * root
857 * / \
858 * v v
859 * alpha beta
860 * \ /
861 * v v
862 * gamma
863 * / \
864 * v v
865 * delta eps
866 **/
867
868 testBuilder_.addNode("alpha", testDeviceId,
869 /*parents=*/{}, /*inputs=*/{"alphaIn"},
870 /*outputs=*/{"alphaOut"}, testRunId, true);
871 testBuilder_.addNode("beta", testDeviceId,
872 /*parents=*/{}, /*inputs=*/{"betaIn"},
873 /*outputs=*/{"betaOut"}, testRunId, true);
874 testBuilder_.addNode("gamma", testDeviceId,
875 /*parents=*/{"alpha", "beta"},
876 /*inputs=*/{"alphaOut", "betaOut"},
877 /*outputs=*/{"deltaIn", "epsIn"}, testRunId, true);
878 testBuilder_.addNode("delta", testDeviceId,
879 /*parents=*/{"gamma"}, /*inputs=*/{"deltaIn"},
880 /*outputs=*/{"deltaOut"}, testRunId, true);
881 testBuilder_.addNode("eps", testDeviceId,
882 /*parents=*/{"gamma"}, /*inputs=*/{"epsIn"},
883 /*outputs=*/{"epsOut"}, testRunId, true);
884
885 ExecutorTest test = testBuilder_.emitTest();
886 EXPECT_TRUE(test.run());
887}
888
889/// Tests that a DAG with a node that fails can run correctly.
890TEST_F(ThreadPoolExecutorTest, MultiNodeWithFailure) {
891 constexpr RunIdentifierTy testRunId = 10;
892 constexpr DeviceIDTy testDeviceId = 111;
893 constexpr unsigned deviceManagerThreads = 3;
894
895 // Make a TestDeviceManager and insert into the DeviceManagerMap map (which
896 // the ThreadPoolExecutor has a reference to) and the TestDeviceManager map
897 // (which the ExecutorTestBuilder has a reference to).
898 auto deviceManager = glow::make_unique<TestDeviceManager>(
899 deviceManagerThreads, DeviceConfig("Interpreter"));
900 deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager));
901
902 // Build the DAG. The DAG created below looks like this:
903 /**
904 * root
905 * / \
906 * v v
907 * alpha delta
908 * | |
909 * v v
910 * beta eps
911 * | |
912 * v v
913 * gamma zeta
914 **/
915
916 testBuilder_.addNode("alpha", testDeviceId,
917 /*parents=*/{}, /*inputs=*/{"alphaIn"},
918 /*outputs=*/{"alphaOut"}, testRunId, true);
919 testBuilder_.addNode("beta", testDeviceId,
920 /*parents=*/{"alpha"}, /*inputs=*/{"alphaOut"},
921 /*outputs=*/{"betaOut"}, testRunId, true);
922 testBuilder_.addNode("gamma", testDeviceId,
923 /*parents=*/{"beta"},
924 /*inputs=*/{"betaOut"},
925 /*outputs=*/{"gammaOut"}, testRunId, true);
926 testBuilder_.addNode("delta", testDeviceId,
927 /*parents=*/{}, /*inputs=*/{"deltaIn"},
928 /*outputs=*/{"deltaOut"}, testRunId, true);
929 testBuilder_.addNode("eps", testDeviceId,
930 /*parents=*/{"delta"}, /*inputs=*/{"deltaOut"},
931 /*outputs=*/{"epsOut"}, testRunId, false);
932 testBuilder_.addNode("zeta", testDeviceId,
933 /*parents=*/{"eps"}, /*inputs=*/{"epsOut"},
934 /*outputs=*/{"zetaOut"}, testRunId, true);
935
936 ExecutorTest test = testBuilder_.emitTest();
937 EXPECT_TRUE(test.run());
938}
939
940/// Tests that a DAG with nodes spread across multiple devices can run
941/// correctly.
942TEST_F(ThreadPoolExecutorTest, MultiNodeMultiDevice) {
943 constexpr RunIdentifierTy testRunId = 10;
944 constexpr DeviceIDTy testDeviceIdA = 111;
945 constexpr DeviceIDTy testDeviceIdB = 112;
946 constexpr DeviceIDTy testDeviceIdC = 113;
947 constexpr unsigned deviceManagerThreads = 3;
948
949 // Make TestDeviceManagers and insert them into the DeviceManagerMap map
950 // (which the ThreadPoolExecutor has a reference to) and the TestDeviceManager
951 // map (which the ExecutorTestBuilder has a reference to).
952 for (DeviceIDTy deviceId : {testDeviceIdA, testDeviceIdB, testDeviceIdC}) {
953 auto deviceManager = glow::make_unique<TestDeviceManager>(
954 deviceManagerThreads, DeviceConfig("Interpreter"));
955 deviceManagerMap_.emplace(deviceId, std::move(deviceManager));
956 }
957
958 // Build the DAG. The DAG created below looks like this:
959 /**
960 * root
961 * / \
962 * v v
963 * alpha beta
964 * \ /
965 * v v
966 * gamma
967 * / \
968 * v v
969 * delta eps
970 **/
971
972 testBuilder_.addNode("alpha", testDeviceIdA,
973 /*parents=*/{}, /*inputs=*/{"alphaIn"},
974 /*outputs=*/{"alphaOut"}, testRunId, true);
975 testBuilder_.addNode("beta", testDeviceIdB,
976 /*parents=*/{}, /*inputs=*/{"betaIn"},
977 /*outputs=*/{"betaOut"}, testRunId, true);
978 testBuilder_.addNode("gamma", testDeviceIdC,
979 /*parents=*/{"alpha", "beta"},
980 /*inputs=*/{"alphaOut", "betaOut"},
981 /*outputs=*/{"deltaIn", "epsIn"}, testRunId, true);
982 testBuilder_.addNode("delta", testDeviceIdA,
983 /*parents=*/{"gamma"}, /*inputs=*/{"deltaIn"},
984 /*outputs=*/{"deltaOut"}, testRunId, true);
985 testBuilder_.addNode("eps", testDeviceIdB,
986 /*parents=*/{"gamma"}, /*inputs=*/{"epsIn"},
987 /*outputs=*/{"epsOut"}, testRunId, true);
988
989 ExecutorTest test = testBuilder_.emitTest();
990 EXPECT_TRUE(test.run());
991}
992
993/// Tests that several instances of a DAG with multiple nodes can run correctly
994/// in parallel.
995TEST_F(ThreadPoolExecutorTest, ConcurrentMultiNode) {
996 constexpr RunIdentifierTy baseTestRunId = 10;
997 constexpr DeviceIDTy testDeviceId = 111;
998 constexpr unsigned deviceManagerThreads = 3;
999 unsigned numConcurrentRuns = 100;
1000
1001 // Make a TestDeviceManager and insert it into the DeviceManagerMap map
1002 // (which the ThreadPoolExecutor has a reference to) and the TestDeviceManager
1003 // map (which the ExecutorTestBuilder has a reference to).
1004 auto deviceManager = glow::make_unique<TestDeviceManager>(
1005 deviceManagerThreads, DeviceConfig("Interpreter"));
1006 deviceManagerMap_.emplace(testDeviceId, std::move(deviceManager));
1007
1008 // Mutex for accessing threadsReady and testsPassed.
1009 std::mutex mtx;
1010 // Condition variables for signalling between the test runner threads
1011 // and this thread. These are used to implement a barrier that ensures
1012 // all test runner threads have been created and are executing before any
1013 // are allowed to run a test (in order to try and increase the number of
1014 // threads that call Executor::run() at the same time).
1015 std::condition_variable driverCV, threadCV;
1016 // Counters for implementing the aforementioned barrier and tracking the
1017 // number of tests that pass.
1018 unsigned threadsReady = 0, testsPassed = 0;
1019 std::vector<std::thread> threads;
1020 for (unsigned i = 0; i < numConcurrentRuns; ++i) {
1021 // Build the DAG. The DAG created below looks like this:
1022 /**
1023 * root
1024 * / \
1025 * v v
1026 * alpha beta
1027 * \ /
1028 * v v
1029 * gamma
1030 * / \
1031 * v v
1032 * delta eps
1033 **/
1034
1035 // The names must be distinct for each run since the DeviceManager
1036 // distinguishes based on function name.
1037 std::string alpha = strFormat("alpha_%d", i);
1038 std::string beta = strFormat("beta_%d", i);
1039 std::string gamma = strFormat("gamma_%d", i);
1040 std::string delta = strFormat("delta_%d", i);
1041 std::string eps = strFormat("eps_%d", i);
1042
1043 // The run IDs must be distinct as well to distinguish all the concurrent
1044 // runs from each other.
1045 testBuilder_.addNode(alpha, testDeviceId,
1046 /*parents=*/{}, /*inputs=*/{"alphaIn"},
1047 /*outputs=*/{"alphaOut"}, baseTestRunId + i, true);
1048 testBuilder_.addNode(beta, testDeviceId,
1049 /*parents=*/{}, /*inputs=*/{"betaIn"},
1050 /*outputs=*/{"betaOut"}, baseTestRunId + i, true);
1051 testBuilder_.addNode(gamma, testDeviceId,
1052 /*parents=*/{alpha, beta},
1053 /*inputs=*/{"alphaOut", "betaOut"},
1054 /*outputs=*/{"deltaIn", "epsIn"}, baseTestRunId + i,
1055 true);
1056 testBuilder_.addNode(delta, testDeviceId,
1057 /*parents=*/{gamma}, /*inputs=*/{"deltaIn"},
1058 /*outputs=*/{"deltaOut"}, baseTestRunId + i, true);
1059 testBuilder_.addNode(eps, testDeviceId,
1060 /*parents=*/{gamma}, /*inputs=*/{"epsIn"},
1061 /*outputs=*/{"epsOut"}, baseTestRunId + i, true);
1062
1063 ExecutorTest t = testBuilder_.emitTest();
1064 std::thread th([&mtx, &driverCV, &threadCV, &threadsReady, &testsPassed,
1065 test = std::move(t), numConcurrentRuns]() mutable {
1066 std::unique_lock<std::mutex> lock(mtx);
1067 // Increment threadsReady to mark this thread as ready to run the test.
1068 threadsReady++;
1069 // If threadsReady == numConcurrentRuns, this thread is the last to be
1070 // initialized and execute, so signal the driver that all threads are
1071 // ready.
1072 if (threadsReady == numConcurrentRuns) {
1073 driverCV.notify_one();
1074 }
1075 // Wait for the driver's signal.
1076 threadCV.wait(lock);
1077 // Unlock the mutex to let all other threads run their tests concurrently.
1078 lock.unlock();
1079 bool passed = test.run();
1080 lock.lock();
1081
1082 if (passed) {
1083 testsPassed++;
1084 }
1085 });
1086 threads.emplace_back(std::move(th));
1087 }
1088
1089 std::unique_lock<std::mutex> lock(mtx);
1090 // If threadsReady != numConcurrentRuns, not all threads are ready to run
1091 // their tests. Wait until they are.
1092 if (threadsReady != numConcurrentRuns) {
1093 driverCV.wait(lock, [&threadsReady, numConcurrentRuns] {
1094 return threadsReady == numConcurrentRuns;
1095 });
1096 }
1097 // Wake up all test runners.
1098 threadCV.notify_all();
1099 lock.unlock();
1100
1101 // Join all threads.
1102 for (unsigned i = 0; i < numConcurrentRuns; ++i) {
1103 threads[i].join();
1104 }
1105
1106 // All tests should pass.
1107 EXPECT_EQ(testsPassed, numConcurrentRuns);
1108}
1109