1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file parallel_for.cc
 * \brief An implementation to run a loop in parallel.
 */
#include <tvm/runtime/logging.h>
#include <tvm/support/parallel_for.h>

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <exception>
#include <functional>
#include <future>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
31
32namespace tvm {
33namespace support {
34
35std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads) {
36 int total_task_count = (end - begin) / step;
37 ICHECK_GE(total_task_count, 0) << "Infinite loop condition with begin: " << begin
38 << " end: " << end << " step: " << step;
39 std::vector<std::vector<int>> ret;
40 ret.reserve(num_threads);
41 for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % num_threads) {
42 if (thread >= ret.size()) {
43 ret.push_back(std::vector<int>());
44 }
45 ret[thread].push_back(begin);
46 }
47 return ret;
48}
49
50void parallel_for(int begin, int end, const std::function<void(int)>& f, int step,
51 const PartitionerFuncType partitioner) {
52 static bool GLOBAL_PARALLEL_FOR_FLAG{false};
53 static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG;
54 {
55 std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
56 ICHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're "
57 << "currently inside another parallel_for loop.";
58 GLOBAL_PARALLEL_FOR_FLAG = true;
59 }
60
61 int default_num_threads = std::thread::hardware_concurrency();
62 const auto& run_partitions = partitioner(begin, end, step, default_num_threads);
63
64 std::vector<std::thread> threads;
65 threads.reserve(run_partitions.size());
66 std::vector<std::future<void>> res_vec;
67 res_vec.reserve(run_partitions.size());
68 for (const auto& run_partition : run_partitions) {
69 std::packaged_task<void(const std::vector<int>&, const std::function<void(int)>&)> task(
70 [](const std::vector<int>& run_partition, const std::function<void(int)>& f) {
71 for (const auto& i : run_partition) {
72 f(i);
73 }
74 });
75 res_vec.emplace_back(task.get_future());
76 threads.emplace_back(std::move(task), run_partition, f);
77 }
78
79 for (auto&& thread : threads) {
80 thread.join();
81 }
82 {
83 std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
84 ICHECK(GLOBAL_PARALLEL_FOR_FLAG);
85 GLOBAL_PARALLEL_FOR_FLAG = false;
86 }
87 try {
88 for (auto&& i : res_vec) {
89 i.get();
90 }
91 } catch (const std::exception& e) {
92 LOG(FATAL) << "Parallel_for error with " << e.what();
93 }
94}
95
96void parallel_for_dynamic(int begin, int end, int num_threads,
97 const std::function<void(int thread_id, int task_id)>& f) {
98 // Step 1. Sanity checks
99 if (begin == end) {
100 return;
101 }
102 CHECK_LE(begin, end) << "ValueError: The interval [begin, end) requires `begin <= end`";
103 CHECK_GT(num_threads, 0) << "ValueError: `num_threads` should be positive";
104 // Step 2. Launch threads
105 // Step 2.1. Launch worker 1 to worker `num_threads - 1`
106 std::atomic<int> counter{begin};
107 std::vector<std::future<void>> futures;
108 std::vector<std::thread> threads;
109 futures.reserve(num_threads - 1);
110 threads.reserve(num_threads - 1);
111 auto worker = [end, &counter, &f](int thread_id) -> void {
112 for (int task_id; (task_id = counter++) < end;) {
113 f(thread_id, task_id);
114 }
115 };
116 for (int thread_id = 1; thread_id < num_threads; ++thread_id) {
117 std::packaged_task<void(int)> task(worker);
118 futures.emplace_back(task.get_future());
119 threads.emplace_back(std::move(task), thread_id);
120 }
121 // Step 2.2. Launch worker 0 inplace
122 try {
123 worker(0);
124 } catch (const std::exception& e) {
125 for (auto&& thread : threads) {
126 thread.join();
127 }
128 LOG(FATAL) << "RuntimeError: parallel_for_dynamic error with " << e.what();
129 }
130 // Step 3. Join threads and check exceptions
131 for (auto&& thread : threads) {
132 thread.join();
133 }
134 try {
135 for (auto&& future : futures) {
136 future.get();
137 }
138 } catch (const std::exception& e) {
139 LOG(FATAL) << "RuntimeError: parallel_for_dynamic error with " << e.what();
140 }
141}
142
143} // namespace support
144} // namespace tvm
145