1 | #pragma once |
2 | |
#include <atomic>
#include <condition_variable>
#include <deque>
#include <functional>
#include <future>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "taichi/common/core.h"
12 | |
namespace taichi::lang {
// A fixed-size pool of worker threads that execute queued tasks.
//
// Instances own their worker threads and synchronization primitives, so they
// are neither copyable nor movable.
class ParallelExecutor {
 public:
  using TaskType = std::function<void()>;

  // |name| identifies this executor; |num_threads| is the requested number of
  // worker threads (presumably the number of std::threads spawned into
  // |threads_| — confirm against the constructor definition).
  explicit ParallelExecutor(const std::string &name, int num_threads);

  // Presumably signals shutdown to the workers through |worker_cv_| and joins
  // them (see the |worker_cv_| comment below) — confirm in the .cpp.
  ~ParallelExecutor();

  // Owns threads, a mutex and condition variables: copying makes no sense.
  // (Declaring these also suppresses the implicit move operations.)
  ParallelExecutor(const ParallelExecutor &) = delete;
  ParallelExecutor &operator=(const ParallelExecutor &) = delete;

  // Submits |func| for execution by a worker thread. Presumably appends it to
  // |task_queue_| and wakes a worker via |worker_cv_|.
  void enqueue(const TaskType &func);

  // Blocks the caller until the pending work has been processed; a worker
  // thread unblocks the caller through |flush_cv_|.
  void flush();

  // Returns the thread count this executor was configured with. Does not
  // touch any state guarded by |mut_|.
  int get_num_threads() const {
    return num_threads_;
  }

 private:
  enum class ExecutorStatus {
    uninitialized,
    initialized,
    finalized,
  };

  // Entry point executed by each worker thread.
  void worker_loop();

  // Must be called while holding |mut_|. Presumably the wait predicate used
  // with |flush_cv_| (i.e. "nothing left to flush") — confirm in the .cpp.
  bool flush_cv_cond();

  std::string name_;
  int num_threads_{0};
  // Not guarded by |mut_|: atomic so workers can update it concurrently
  // (presumably used to hand out per-worker ids).
  std::atomic<int> thread_counter_{0};
  std::mutex mut_;

  // All guarded by |mut_|.
  // Default-initialized here so they never hold indeterminate values; any
  // constructor mem-initializer still takes precedence.
  ExecutorStatus status_{ExecutorStatus::uninitialized};
  std::vector<std::thread> threads_;
  std::deque<TaskType> task_queue_;
  // Presumably the number of workers currently executing a task.
  int running_threads_{0};

  // Used to signal the workers that they can start polling from |task_queue_|.
  std::condition_variable init_cv_;
  // Used by |this| to notify the worker threads that there is an event:
  //   * a task being enqueued
  //   * shutting down
  std::condition_variable worker_cv_;
  // Used by a worker thread to unblock the caller waiting in flush().
  //
  // TODO: Instead of keeping this as a member variable, flush() could enqueue
  // a callback and block until that callback has executed.
  std::condition_variable flush_cv_;
};
}  // namespace taichi::lang
66 | |