1 | #include <taichi/system/timeline.h> |
---|---|
2 | #include "taichi/program/parallel_executor.h" |
3 | |
4 | namespace taichi::lang { |
5 | |
6 | ParallelExecutor::ParallelExecutor(const std::string &name, int num_threads) |
7 | : name_(name), |
8 | num_threads_(num_threads), |
9 | status_(ExecutorStatus::uninitialized), |
10 | running_threads_(0) { |
11 | if (num_threads <= 0) { |
12 | return; |
13 | } |
14 | { |
15 | auto _ = std::lock_guard<std::mutex>(mut_); |
16 | |
17 | for (int i = 0; i < num_threads; i++) { |
18 | threads_.emplace_back([this]() { this->worker_loop(); }); |
19 | } |
20 | |
21 | status_ = ExecutorStatus::initialized; |
22 | } |
23 | init_cv_.notify_all(); |
24 | } |
25 | |
26 | ParallelExecutor::~ParallelExecutor() { |
27 | // TODO: We should have a new ExecutorStatus, e.g. shutting_down, to prevent |
28 | // new tasks from being enqueued during shut down. |
29 | if (num_threads_ <= 0) { |
30 | return; |
31 | } |
32 | flush(); |
33 | { |
34 | auto _ = std::lock_guard<std::mutex>(mut_); |
35 | status_ = ExecutorStatus::finalized; |
36 | } |
37 | // Signal the workers that they need to shutdown. |
38 | worker_cv_.notify_all(); |
39 | for (auto &th : threads_) { |
40 | th.join(); |
41 | } |
42 | } |
43 | |
44 | void ParallelExecutor::enqueue(const TaskType &func) { |
45 | if (num_threads_ <= 0) { |
46 | func(); |
47 | return; |
48 | } |
49 | { |
50 | std::lock_guard<std::mutex> _(mut_); |
51 | task_queue_.push_back(func); |
52 | } |
53 | worker_cv_.notify_all(); |
54 | } |
55 | |
56 | void ParallelExecutor::flush() { |
57 | if (num_threads_ <= 0) { |
58 | return; |
59 | } |
60 | std::unique_lock<std::mutex> lock(mut_); |
61 | while (!flush_cv_cond()) { |
62 | flush_cv_.wait(lock); |
63 | } |
64 | } |
65 | |
66 | bool ParallelExecutor::flush_cv_cond() { |
67 | return (task_queue_.empty() && running_threads_ == 0); |
68 | } |
69 | |
70 | void ParallelExecutor::worker_loop() { |
71 | TI_DEBUG("Starting worker thread."); |
72 | auto thread_id = thread_counter_++; |
73 | |
74 | std::string thread_name = name_; |
75 | if (num_threads_ != 1) |
76 | thread_name += fmt::format("_{}", thread_id); |
77 | Timeline::get_this_thread_instance().set_name(thread_name); |
78 | |
79 | { |
80 | std::unique_lock<std::mutex> lock(mut_); |
81 | while (status_ == ExecutorStatus::uninitialized) { |
82 | init_cv_.wait(lock); |
83 | } |
84 | } |
85 | |
86 | TI_DEBUG("Worker thread initialized and running."); |
87 | bool done = false; |
88 | while (!done) { |
89 | bool notify_flush_cv = false; |
90 | { |
91 | std::unique_lock<std::mutex> lock(mut_); |
92 | while (task_queue_.empty() && status_ == ExecutorStatus::initialized) { |
93 | worker_cv_.wait(lock); |
94 | } |
95 | // So long as |task_queue| is not empty, we keep running. |
96 | if (!task_queue_.empty()) { |
97 | auto task = task_queue_.front(); |
98 | running_threads_++; |
99 | task_queue_.pop_front(); |
100 | lock.unlock(); |
101 | |
102 | // Run the task |
103 | task(); |
104 | |
105 | lock.lock(); |
106 | running_threads_--; |
107 | } |
108 | notify_flush_cv = flush_cv_cond(); |
109 | if (status_ == ExecutorStatus::finalized && task_queue_.empty()) { |
110 | done = true; |
111 | } |
112 | } |
113 | if (notify_flush_cv) { |
114 | // It is fine to notify |flush_cv_| while nobody is waiting on it. |
115 | flush_cv_.notify_one(); |
116 | } |
117 | } |
118 | } |
119 | } // namespace taichi::lang |
120 |