1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "glow/Support/ThreadPool.h"
#include "glow/Support/Memory.h"

#include "gtest/gtest.h"

#include "llvm/ADT/STLExtras.h"

#include <atomic>
#include <future>
#include <mutex>
#include <set>
#include <thread>
#include <vector>
26
27using namespace glow;
28
29TEST(ThreadPool, BasicTest) {
30 const unsigned numWorkers = 100;
31 const unsigned numWorkItems = 1000;
32 ThreadPool tp(numWorkers);
33
34 // Create vectors to store futures and promises for
35 // communicating results of work done on thread pool.
36 std::vector<std::future<int>> futures;
37 futures.reserve(numWorkItems);
38 std::vector<std::promise<int>> promises(numWorkItems);
39
40 // Submit 'numWorkItems' work items to the thread pool;
41 // each task takes its index and computes and returns
42 // 2x its index.
43 for (unsigned i = 0; i < numWorkItems; ++i) {
44 auto &p = promises[i];
45 futures.emplace_back(p.get_future());
46 tp.submit([&p, i]() { p.set_value(2 * i); });
47 }
48
49 // Check that every future holds the expected result
50 // (2x its index).
51 for (unsigned i = 0; i < numWorkItems; ++i) {
52 futures[i].wait();
53 auto result = futures[i].get();
54 EXPECT_EQ(result, 2 * i);
55 }
56}
57
58TEST(ThreadPool, moveCaptureTest) {
59 ThreadPool tp(1);
60
61 std::unique_ptr<int> input = glow::make_unique<int>(42);
62 int output = 0;
63 auto func = [input = std::move(input), &output]() { output = (*input) * 2; };
64
65 auto done = tp.submit(std::move(func));
66
67 done.wait();
68 EXPECT_EQ(output, 84);
69}
70
71TEST(ThreadPool, completionFutureTest) {
72 ThreadPool tp(1);
73
74 int input = 42, output = 0;
75 std::packaged_task<void(void)> task(
76 [&input, &output]() { output = input * 3; });
77
78 auto done = tp.submit(std::move(task));
79
80 done.wait();
81 EXPECT_EQ(output, 126);
82}
83
84/// Verify that we can get an Executor that runs tasks consistently on the same
85/// thread.
86TEST(ThreadPool, getExecutor) {
87 ThreadPool tp(3);
88
89 std::thread::id t1;
90 std::thread::id t2;
91
92 /// Check that runs on the same executor run on the same thread.
93 auto *ex = tp.getExecutor();
94 auto fut1 = ex->submit([&t1]() { t1 = std::this_thread::get_id(); });
95 auto fut2 = ex->submit([&t2]() { t2 = std::this_thread::get_id(); });
96
97 fut1.get();
98 fut2.get();
99
100 ASSERT_EQ(t1, t2);
101 ASSERT_NE(t1, std::thread::id());
102
103 /// Now verify this isn't always true.
104 t1 = t2 = std::thread::id();
105 auto *ex2 = tp.getExecutor();
106
107 fut1 = ex->submit([&t1] { t1 = std::this_thread::get_id(); });
108 fut2 = ex2->submit([&t2] { t2 = std::this_thread::get_id(); });
109
110 fut1.get();
111 fut2.get();
112
113 ASSERT_NE(t1, t2);
114 ASSERT_NE(t1, std::thread::id());
115}
116
117/// Verify that you can get more executors than there are threads in the pool.
118TEST(ThreadPool, getManyExecutors) {
119 ThreadPool tp(3);
120
121 std::atomic<size_t> left{20};
122 std::promise<void> finished;
123
124 auto F = [&left, &finished]() {
125 if (--left == 0) {
126 finished.set_value();
127 }
128 };
129
130 for (int i = 0; i < 10; ++i) {
131 auto *ex = tp.getExecutor();
132 // Submit two tasks
133 ex->submit(F);
134 ex->submit(F);
135 }
136
137 finished.get_future().get();
138 ASSERT_EQ(left, 0);
139}
140
141/// Verify we can run on all threads and that they are different.
142TEST(ThreadPool, runOnAllThreads) {
143 ThreadPool tp(3);
144 std::vector<std::thread::id> threadIds;
145
146 std::mutex vecLock;
147
148 auto fut = tp.runOnAllThreads([&threadIds, &vecLock]() {
149 std::lock_guard<std::mutex> l(vecLock);
150 threadIds.push_back(std::this_thread::get_id());
151 });
152
153 fut.get();
154
155 ASSERT_EQ(threadIds.size(), 3);
156 ASSERT_NE(threadIds[0], threadIds[1]);
157 ASSERT_NE(threadIds[1], threadIds[2]);
158 ASSERT_NE(threadIds[2], threadIds[0]);
159}
160