1#pragma once
2
3#include <ATen/Parallel.h>
4#include <c10/core/thread_pool.h>
5
6namespace at {
7
8class TORCH_API PTThreadPool : public c10::ThreadPool {
9 public:
10 explicit PTThreadPool(int pool_size, int numa_node_id = -1)
11 : c10::ThreadPool(pool_size, numa_node_id, []() {
12 c10::setThreadName("PTThreadPool");
13 at::init_num_threads();
14 }) {}
15};
16
17} // namespace at
18