1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/Parallel.h> |
4 | #include <c10/core/thread_pool.h> |
5 | |
6 | namespace at { |
7 | |
8 | class TORCH_API PTThreadPool : public c10::ThreadPool { |
9 | public: |
10 | explicit PTThreadPool(int pool_size, int numa_node_id = -1) |
11 | : c10::ThreadPool(pool_size, numa_node_id, []() { |
12 | c10::setThreadName("PTThreadPool"); |
13 | at::init_num_threads(); |
14 | }) {} |
15 | }; |
16 | |
17 | } // namespace at |
18 |