1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file parallel_for.cc |
22 | * \brief An implementation to run loop in parallel. |
23 | */ |
24 | #include <tvm/runtime/logging.h> |
25 | #include <tvm/support/parallel_for.h> |
26 | |
27 | #include <future> |
28 | #include <thread> |
29 | #include <utility> |
30 | #include <vector> |
31 | |
32 | namespace tvm { |
33 | namespace support { |
34 | |
35 | std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads) { |
36 | int total_task_count = (end - begin) / step; |
37 | ICHECK_GE(total_task_count, 0) << "Infinite loop condition with begin: " << begin |
38 | << " end: " << end << " step: " << step; |
39 | std::vector<std::vector<int>> ret; |
40 | ret.reserve(num_threads); |
41 | for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % num_threads) { |
42 | if (thread >= ret.size()) { |
43 | ret.push_back(std::vector<int>()); |
44 | } |
45 | ret[thread].push_back(begin); |
46 | } |
47 | return ret; |
48 | } |
49 | |
50 | void parallel_for(int begin, int end, const std::function<void(int)>& f, int step, |
51 | const PartitionerFuncType partitioner) { |
52 | static bool GLOBAL_PARALLEL_FOR_FLAG{false}; |
53 | static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG; |
54 | { |
55 | std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG); |
56 | ICHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're " |
57 | << "currently inside another parallel_for loop." ; |
58 | GLOBAL_PARALLEL_FOR_FLAG = true; |
59 | } |
60 | |
61 | int default_num_threads = std::thread::hardware_concurrency(); |
62 | const auto& run_partitions = partitioner(begin, end, step, default_num_threads); |
63 | |
64 | std::vector<std::thread> threads; |
65 | threads.reserve(run_partitions.size()); |
66 | std::vector<std::future<void>> res_vec; |
67 | res_vec.reserve(run_partitions.size()); |
68 | for (const auto& run_partition : run_partitions) { |
69 | std::packaged_task<void(const std::vector<int>&, const std::function<void(int)>&)> task( |
70 | [](const std::vector<int>& run_partition, const std::function<void(int)>& f) { |
71 | for (const auto& i : run_partition) { |
72 | f(i); |
73 | } |
74 | }); |
75 | res_vec.emplace_back(task.get_future()); |
76 | threads.emplace_back(std::move(task), run_partition, f); |
77 | } |
78 | |
79 | for (auto&& thread : threads) { |
80 | thread.join(); |
81 | } |
82 | { |
83 | std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG); |
84 | ICHECK(GLOBAL_PARALLEL_FOR_FLAG); |
85 | GLOBAL_PARALLEL_FOR_FLAG = false; |
86 | } |
87 | try { |
88 | for (auto&& i : res_vec) { |
89 | i.get(); |
90 | } |
91 | } catch (const std::exception& e) { |
92 | LOG(FATAL) << "Parallel_for error with " << e.what(); |
93 | } |
94 | } |
95 | |
void parallel_for_dynamic(int begin, int end, int num_threads,
                          const std::function<void(int thread_id, int task_id)>& f) {
  // Dynamically schedules tasks [begin, end) over `num_threads` workers: each
  // worker repeatedly claims the next task id from a shared atomic counter, so
  // faster workers naturally pick up more tasks (no static partitioning).
  // Step 1. Sanity checks
  if (begin == end) {
    return;
  }
  CHECK_LE(begin, end) << "ValueError: The interval [begin, end) requires `begin <= end`";
  CHECK_GT(num_threads, 0) << "ValueError: `num_threads` should be positive";
  // Step 2. Launch threads
  // Step 2.1. Launch worker 1 to worker `num_threads - 1`
  std::atomic<int> counter{begin};
  std::vector<std::future<void>> futures;
  std::vector<std::thread> threads;
  futures.reserve(num_threads - 1);
  threads.reserve(num_threads - 1);
  // Post-increment atomically claims one task id per iteration; the loop exits
  // once the counter passes `end` (other workers may have consumed the rest).
  auto worker = [end, &counter, &f](int thread_id) -> void {
    for (int task_id; (task_id = counter++) < end;) {
      f(thread_id, task_id);
    }
  };
  // Exceptions thrown inside spawned workers are captured by the
  // packaged_task and surface later via future::get() in Step 3.
  for (int thread_id = 1; thread_id < num_threads; ++thread_id) {
    std::packaged_task<void(int)> task(worker);
    futures.emplace_back(task.get_future());
    threads.emplace_back(std::move(task), thread_id);
  }
  // Step 2.2. Launch worker 0 inplace
  try {
    worker(0);
  } catch (const std::exception& e) {
    // Join the spawned threads before dying so they don't terminate() at
    // destruction. NOTE(review): in this path only worker 0's error is
    // reported; exceptions stored in `futures` are silently discarded.
    for (auto&& thread : threads) {
      thread.join();
    }
    LOG(FATAL) << "RuntimeError: parallel_for_dynamic error with " << e.what();
  }
  // Step 3. Join threads and check exceptions
  for (auto&& thread : threads) {
    thread.join();
  }
  // Report the first captured worker exception, if any, as a fatal error.
  try {
    for (auto&& future : futures) {
      future.get();
    }
  } catch (const std::exception& e) {
    LOG(FATAL) << "RuntimeError: parallel_for_dynamic error with " << e.what();
  }
}
142 | |
143 | } // namespace support |
144 | } // namespace tvm |
145 | |