1 | /******************************************************************************* |
2 | Copyright (c) The Taichi Authors (2016- ). All Rights Reserved. |
3 | The use of this software is governed by the LICENSE file. |
4 | *******************************************************************************/ |
5 | |
6 | #include "taichi/system/threading.h" |
7 | |
8 | #include <algorithm> |
9 | #include <condition_variable> |
10 | #include <thread> |
11 | #include <vector> |
12 | |
13 | namespace taichi { |
14 | |
15 | bool test_threading() { |
16 | auto tp = ThreadPool(20); |
17 | for (int j = 0; j < 100; j++) { |
18 | tp.run(10, j + 1, &j, [](void *j, int _thread_id, int i) { |
19 | double ret = 0.0; |
20 | for (int t = 0; t < 10000000; t++) { |
21 | ret += t * 1e-20; |
22 | } |
23 | TI_P(int(i + ret + 10 * *(int *)j)); |
24 | }); |
25 | } |
26 | return true; |
27 | } |
28 | |
29 | ThreadPool::ThreadPool(int max_num_threads) : max_num_threads(max_num_threads) { |
30 | exiting = false; |
31 | started = false; |
32 | running_threads = 0; |
33 | timestamp = 1; |
34 | last_finished = 0; |
35 | task_head = 0; |
36 | task_tail = 0; |
37 | thread_counter = 0; |
38 | threads.resize((std::size_t)max_num_threads); |
39 | for (int i = 0; i < max_num_threads; i++) { |
40 | threads[i] = std::thread([this] { this->target(); }); |
41 | } |
42 | } |
43 | |
// Runs one parallel range-for task and blocks until it completes.
//
// `splits` is the number of sub-ranges (task ids 0..splits-1) handed out to
// workers; `desired_num_threads` is clamped to max_num_threads; `func` is
// invoked as func(range_for_task_context, thread_id, task_id) once per split.
//
// Thread-safety: must be called from a single master thread at a time; the
// task slots (func, context, head/tail, timestamp) are shared state guarded
// by `mutex`.
void ThreadPool::run(int splits,
                     int desired_num_threads,
                     void *range_for_task_context,
                     RangeForTaskFunc *func) {
  {
    // Publish the new task under the lock so workers observe a consistent
    // (func, context, head, tail, timestamp) tuple when they wake.
    std::lock_guard _(mutex);
    this->range_for_task_context = range_for_task_context;
    this->func = func;
    this->desired_num_threads = std::min(desired_num_threads, max_num_threads);
    TI_ASSERT(this->desired_num_threads > 0);
    // TI_P(this->desired_num_threads);
    started = false;
    task_head = 0;
    task_tail = splits;
    // Bumping the timestamp is what makes the slave_cv predicate in target()
    // become true for the first `desired_num_threads` workers.
    timestamp++;
    TI_ASSERT(timestamp < (1LL << 62));  // avoid overflowing here
  }

  // wake up all slaves
  slave_cv.notify_all();
  {
    std::unique_lock<std::mutex> lock(mutex);
    // TODO: the workers may have finished before master waiting on master_cv
    // Wait until at least one worker has picked up this task (`started`) and
    // the last of them has finished (`running_threads == 0`). Because the
    // predicate is re-checked under the lock, a notify that fires before we
    // start waiting is not lost.
    master_cv.wait(lock, [this] { return started && running_threads == 0; });
  }
  // All splits must have been claimed by now.
  TI_ASSERT(task_head >= task_tail);
}
71 | |
// Worker-thread main loop. Each worker sleeps on slave_cv until either a new
// task is published (timestamp advanced past its last seen value and its
// thread_id is within the desired thread count) or the pool is shutting down.
// Workers then claim splits from the shared task counter until exhausted, and
// the last one to finish notifies the master.
void ThreadPool::target() {
  // Timestamp of the last task this worker participated in; 0 is strictly
  // older than any real task stamp (the pool starts at timestamp == 1).
  uint64 last_timestamp = 0;
  int thread_id;
  {
    // Assign a stable, unique id; used to limit participation to the first
    // `desired_num_threads` workers of each task.
    std::lock_guard<std::mutex> lock(mutex);
    thread_id = thread_counter++;
  }
  while (true) {
    {
      std::unique_lock<std::mutex> lock(mutex);
      slave_cv.wait(lock, [this, last_timestamp, thread_id] {
        return (timestamp > last_timestamp &&
                thread_id < desired_num_threads) ||
               this->exiting;
      });
      last_timestamp = timestamp;
      if (exiting) {
        break;
      } else {
        if (last_finished >= last_timestamp) {
          continue;
          // This could happen when part of the desired threads wake up and
          // finish all the task, and then this thread wake up finding nothing
          // to do. Should skip this task directly.
        } else {
          // Mark participation before releasing the lock so the master's
          // predicate (started && running_threads == 0) cannot pass while
          // this worker is still about to run.
          started = true;
          running_threads++;
        }
      }
    }

    while (true) {
      // For a single parallel task
      int task_id;
      {
        // task_head is atomic, so splits are claimed lock-free; the first
        // fetch past task_tail ends this worker's participation.
        task_id = task_head.fetch_add(1, std::memory_order_relaxed);
        if (task_id >= task_tail)
          break;
      }

      func(this->range_for_task_context, thread_id, task_id);
    }

    bool all_finished = false;
    {
      std::lock_guard<std::mutex> lock(mutex);
      running_threads--;
      if (running_threads == 0) {
        // Record the completed task's stamp so late-waking workers of the
        // same task take the `continue` path above instead of re-running it.
        all_finished = true;
        last_finished = last_timestamp;
      }
    }
    // Notify outside the lock; only the last finisher wakes the master.
    if (all_finished)
      master_cv.notify_one();
  }
}
128 | |
129 | ThreadPool::~ThreadPool() { |
130 | { |
131 | std::lock_guard<std::mutex> lg(mutex); |
132 | exiting = true; |
133 | } |
134 | slave_cv.notify_all(); |
135 | for (auto &th : threads) |
136 | th.join(); |
137 | } |
138 | |
139 | } // namespace taichi |
140 | |