/*******************************************************************************
    Copyright (c) The Taichi Authors (2016- ). All Rights Reserved.
    The use of this software is governed by the LICENSE file.
*******************************************************************************/

#include "taichi/system/threading.h"

#include <algorithm>
#include <condition_variable>
#include <thread>
#include <vector>

namespace taichi {

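// Simple stress test: launch 100 parallel range-for tasks on a 20-thread
// pool. Each launch is split into 10 chunks and requests (j + 1) worker
// threads; the address of the loop variable is passed as the task context.
// Since run() blocks until the launch completes, passing &j here is safe.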
bool test_threading() {
  auto tp = ThreadPool(20);
  for (int j = 0; j < 100; j++) {
    tp.run(10, j + 1, &j, [](void *j, int _thread_id, int i) {
      double ret = 0.0;
      for (int t = 0; t < 10000000; t++) {
        ret += t * 1e-20;
      }
      TI_P(int(i + ret + 10 * *(int *)j));
    });
  }
  return true;
}

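// Spawn all max_num_threads workers up front; each worker obtains a unique
// thread id from thread_counter and then blocks in target() until work
// arrives or the pool is destroyed.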
ThreadPool::ThreadPool(int max_num_threads)
    : max_num_threads(max_num_threads) {
  exiting = false;
  started = false;
  running_threads = 0;
  timestamp = 1;
  last_finished = 0;
  task_head = 0;
  task_tail = 0;
  thread_counter = 0;
  threads.resize((std::size_t)max_num_threads);
  for (int i = 0; i < max_num_threads; i++) {
    threads[i] = std::thread([this] { this->target(); });
  }
}

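// Launches one parallel range-for task and blocks until it is complete.
// `splits` is the number of chunks the range is divided into,
// `desired_num_threads` is clamped to max_num_threads, and `func` is invoked
// once per chunk as func(range_for_task_context, thread_id, task_id).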
void ThreadPool::run(int splits,
                     int desired_num_threads,
                     void *range_for_task_context,
                     RangeForTaskFunc *func) {
  {
    std::lock_guard _(mutex);
    this->range_for_task_context = range_for_task_context;
    this->func = func;
    this->desired_num_threads = std::min(desired_num_threads, max_num_threads);
    TI_ASSERT(this->desired_num_threads > 0);
    // TI_P(this->desired_num_threads);
    started = false;
    task_head = 0;
    task_tail = splits;
    timestamp++;
    TI_ASSERT(timestamp < (1LL << 62));  // avoid overflowing here
  }

  // Wake up all the slave (worker) threads.
  slave_cv.notify_all();
  {
    std::unique_lock<std::mutex> lock(mutex);
    // TODO: the workers may have finished before the master starts waiting
    // on master_cv here.
    master_cv.wait(lock, [this] { return started && running_threads == 0; });
  }
  TI_ASSERT(task_head >= task_tail);
}

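// Worker main loop. A worker sleeps on slave_cv until a new launch advances
// `timestamp` (and its thread id is below desired_num_threads), then keeps
// claiming chunk ids from task_head with an atomic fetch_add until task_tail
// is reached. The last worker to finish records last_finished and notifies
// the master.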
void ThreadPool::target() {
  uint64 last_timestamp = 0;
  int thread_id;
  {
    std::lock_guard<std::mutex> lock(mutex);
    thread_id = thread_counter++;
  }
  while (true) {
    {
      std::unique_lock<std::mutex> lock(mutex);
      slave_cv.wait(lock, [this, last_timestamp, thread_id] {
        return (timestamp > last_timestamp &&
                thread_id < desired_num_threads) ||
               this->exiting;
      });
      last_timestamp = timestamp;
      if (exiting) {
        break;
      } else {
        if (last_finished >= last_timestamp) {
          // This can happen when some of the desired threads wake up first
          // and finish all the task splits, so this thread wakes up to find
          // nothing left to do. Skip this launch.
          continue;
        } else {
          started = true;
          running_threads++;
        }
      }
    }

    while (true) {
      // Process the splits of a single parallel task.
      int task_id;
      {
        task_id = task_head.fetch_add(1, std::memory_order_relaxed);
        if (task_id >= task_tail)
          break;
      }

      func(this->range_for_task_context, thread_id, task_id);
    }

    bool all_finished = false;
    {
      std::lock_guard<std::mutex> lock(mutex);
      running_threads--;
      if (running_threads == 0) {
        all_finished = true;
        last_finished = last_timestamp;
      }
    }
    if (all_finished)
      master_cv.notify_one();
  }
}

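// Sets the exiting flag under the lock, wakes every sleeping worker so it can
// observe the flag, and joins all worker threads.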
ThreadPool::~ThreadPool() {
  {
    std::lock_guard<std::mutex> lg(mutex);
    exiting = true;
  }
  slave_cv.notify_all();
  for (auto &th : threads)
    th.join();
}

}  // namespace taichi