1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/util/work_sharder.h" |
17 | |
18 | #include "tensorflow/core/platform/blocking_counter.h" |
19 | #include "tensorflow/core/platform/logging.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | /* ABSL_CONST_INIT */ thread_local int per_thread_max_parallelism = 1000000; |
24 | |
25 | void SetPerThreadMaxParallelism(int max_parallelism) { |
26 | CHECK_LE(0, max_parallelism); |
27 | per_thread_max_parallelism = max_parallelism; |
28 | } |
29 | |
30 | int GetPerThreadMaxParallelism() { return per_thread_max_parallelism; } |
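
// Example (illustrative sketch): temporarily cap the parallelism of any
// Shard() calls made from the current thread, restoring the old limit
// afterwards. The header's ScopedPerThreadMaxParallelism wraps this
// save/restore pattern in RAII form.
//
//   int prev = GetPerThreadMaxParallelism();
//   SetPerThreadMaxParallelism(1);  // Force sequential execution.
//   // ... run code that calls Shard() ...
//   SetPerThreadMaxParallelism(prev);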
31 | |
32 | void Shard(int max_parallelism, thread::ThreadPool* workers, int64_t total, |
33 | int64_t cost_per_unit, std::function<void(int64_t, int64_t)> work) { |
34 | CHECK_GE(total, 0); |
35 | if (total == 0) { |
36 | return; |
37 | } |
38 | max_parallelism = std::min(max_parallelism, GetPerThreadMaxParallelism()); |
39 | if (max_parallelism <= 1) { |
    // Run all the work inline since at most one thread (core) may be used.
41 | work(0, total); |
42 | return; |
43 | } |
44 | if (max_parallelism >= workers->NumThreads()) { |
45 | workers->ParallelFor(total, cost_per_unit, work); |
46 | return; |
47 | } |
48 | Sharder::Do( |
49 | total, cost_per_unit, work, |
      [workers](Sharder::Closure c) { workers->Schedule(std::move(c)); },
51 | max_parallelism); |
52 | } |
53 | |
54 | // DEPRECATED: Prefer threadpool->ParallelFor with SchedulingStrategy, which |
55 | // allows you to specify the strategy for choosing shard sizes, including using |
56 | // a fixed shard size. |
57 | void Sharder::Do(int64_t total, int64_t cost_per_unit, const Work& work, |
58 | const Runner& runner, int max_parallelism) { |
59 | cost_per_unit = std::max(int64_t{1}, cost_per_unit); |
  // We shard [0, total) into "num_shards" shards.
  //   1 <= num_shards <= max_parallelism
  //
  // If total * cost_per_unit is small, it is not worth sharding too
  // finely. Assuming each cost unit takes roughly 1ns,
  // kMinCostPerShard = 10000 corresponds to 10us of work per shard.
66 | static const int64_t kMinCostPerShard = 10000; |
67 | const int num_shards = |
68 | std::max<int>(1, std::min(static_cast<int64_t>(max_parallelism), |
69 | total * cost_per_unit / kMinCostPerShard)); |
70 | |
  // Each shard contains up to "block_size" units. [0, total) is sharded
  // into:
  //   [0, block_size), [block_size, 2*block_size), ...
  // The first shard is executed inline by the caller thread; the remaining
  // shards are dispatched to the worker threads via "runner". The last
  // shard may be smaller than block_size.
77 | const int64_t block_size = (total + num_shards - 1) / num_shards; |
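
  // Worked example (illustrative): with total = 1000000, cost_per_unit = 5,
  // and max_parallelism = 8, total * cost_per_unit / kMinCostPerShard = 500,
  // capped at num_shards = 8, giving block_size = 125000. Conversely,
  // total = 100 with cost_per_unit = 20 estimates only 2000ns of work,
  // yielding num_shards = 1, so the whole range runs inline below.
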
78 | CHECK_GT(block_size, 0); // total > 0 guarantees this. |
79 | if (block_size >= total) { |
80 | work(0, total); |
81 | return; |
82 | } |
83 | const int num_shards_used = (total + block_size - 1) / block_size; |
84 | BlockingCounter counter(num_shards_used - 1); |
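  // Dispatch all shards except the first to the runner; each one signals
  // "counter" when it finishes.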
85 | for (int64_t start = block_size; start < total; start += block_size) { |
86 | auto limit = std::min(start + block_size, total); |
87 | runner([&work, &counter, start, limit]() { |
88 | work(start, limit); // Compute the shard. |
89 | counter.DecrementCount(); // The shard is done. |
90 | }); |
91 | } |
92 | |
  // Execute the first shard inline on the caller thread.
94 | work(0, std::min(block_size, total)); |
95 | counter.Wait(); |
96 | } |
97 | |
}  // namespace tensorflow
99 | |