1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_ |
17 | #define TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_ |
18 | |
19 | #include <functional> |
20 | |
21 | #include "tensorflow/core/lib/core/threadpool.h" |
22 | #include "tensorflow/core/platform/types.h" |
23 | |
24 | namespace tensorflow { |
25 | |
// DEPRECATED: Prefer threadpool->ParallelFor with SchedulingStrategy, which
// allows you to specify the strategy for choosing shard sizes, including using
// a fixed shard size. Use this function only if you want to manually cap
// parallelism.
//
// Shards the "total" units of work assuming each unit of work has
// roughly "cost_per_unit". Each unit of work is indexed 0, 1, ...,
// total - 1. Each shard contains 1 or more units of work and the
// total cost of each shard is roughly the same. The calling thread and the
// "workers" are used to compute each shard (calling work(start,
// limit)). A common configuration is that "workers" is a thread pool
// with at least "max_parallelism" threads.
//
// "cost_per_unit" is an estimate of the number of CPU cycles (or nanoseconds
// if not CPU-bound) to complete a unit of work. Overestimating creates too
// many shards and CPU time will be dominated by per-shard overhead, such as
// Context creation. Underestimating may not fully make use of the specified
// parallelism.
//
// "work" should be a callable taking (int64, int64) arguments.
// work(start, limit) computes the work units in the half-open range
// [start, limit), i.e., [start, limit) is a shard.
//
// Too much parallelism can also cause excessive thread switches;
// therefore, Shard() often limits the maximum parallelism. Each
// caller can provide the 1st argument max_parallelism. A thread can
// also call SetPerThreadMaxParallelism() (declared below) so that all
// later Shard() calls on that thread limit their parallelism.
//
// REQUIRES: max_parallelism >= 0
// REQUIRES: workers != nullptr
// REQUIRES: total >= 0
// REQUIRES: cost_per_unit >= 0
void Shard(int max_parallelism, thread::ThreadPool* workers, int64_t total,
           int64_t cost_per_unit, std::function<void(int64_t, int64_t)> work);
61 | |
// Each thread has an associated option to express the desired maximum
// parallelism. Its default is a very large quantity.
//
// Within the TF runtime, per-thread max parallelism affects Shard() and
// intra-op parallelism. E.g., if SetPerThreadMaxParallelism(1) is called
// on a tf_compute thread, then subsequent Shard() calls and Eigen device
// work issued from that thread become single-threaded.
void SetPerThreadMaxParallelism(int max_parallelism);
int GetPerThreadMaxParallelism();
72 | |
73 | // Helper to set and unset per-thread max parallelism. |
74 | class ScopedPerThreadMaxParallelism { |
75 | public: |
76 | ScopedPerThreadMaxParallelism(int max_parallelism) |
77 | : previous_(GetPerThreadMaxParallelism()) { |
78 | SetPerThreadMaxParallelism(max_parallelism); |
79 | } |
80 | |
81 | ~ScopedPerThreadMaxParallelism() { SetPerThreadMaxParallelism(previous_); } |
82 | |
83 | private: |
84 | int previous_ = -1; |
85 | }; |
86 | |
// Implementation details for Shard().
class Sharder {
 public:
  // A schedulable unit of work.
  using Closure = std::function<void()>;
  // Schedules a Closure for execution, e.g. on a thread pool.
  using Runner = std::function<void(Closure)>;
  // Processes the work units in the half-open range [start, limit).
  using Work = std::function<void(int64_t, int64_t)>;

  // See Shard()'s comment for the meaning of total, cost_per_unit,
  // work, and max_parallelism. `runner` is an interface used to
  // schedule a closure; Shard() uses thread::ThreadPool instead.
  static void Do(int64_t total, int64_t cost_per_unit, const Work& work,
                 const Runner& runner, int max_parallelism);
};
100 | |
101 | } // end namespace tensorflow |
102 | |
103 | #endif // TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_ |
104 | |