/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_
#define TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_

#include <functional>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// DEPRECATED: Prefer threadpool->ParallelFor with SchedulingStrategy, which
// allows you to specify the strategy for choosing shard sizes, including using
// a fixed shard size. Use this function only if you want to manually cap
// parallelism.
//
// Shards the "total" units of work, assuming each unit of work costs roughly
// "cost_per_unit". The units of work are indexed 0, 1, ..., total - 1. Each
// shard contains one or more units of work, and the total cost of each shard
// is roughly the same. The calling thread and the "workers" are used to
// compute each shard (by calling work(start, limit)). A common configuration
// is that "workers" is a thread pool with at least "max_parallelism" threads.
//
// "cost_per_unit" is an estimate of the number of CPU cycles (or nanoseconds
// if not CPU-bound) to complete a unit of work. Overestimating creates too
// many shards, and CPU time will be dominated by per-shard overhead, such as
// Context creation. Underestimating may not fully make use of the specified
// parallelism.
//
// "work" should be a callable taking (int64_t, int64_t) arguments.
// work(start, limit) computes the work units in the range [start, limit),
// i.e., [start, limit) is a shard.
//
// Too much parallelism can also cause excessive thread switches, so Shard()
// often limits the maximum parallelism. Each caller can cap parallelism via
// the first argument, max_parallelism. A thread can also call
// SetPerThreadMaxParallelism() so that all subsequent Shard() calls on that
// thread limit their parallelism accordingly.
//
// REQUIRES: max_parallelism >= 0
// REQUIRES: workers != nullptr
// REQUIRES: total >= 0
// REQUIRES: cost_per_unit >= 0
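//
// Example (a minimal sketch, not part of this header; "pool" is assumed to
// be an existing thread::ThreadPool and "data" a hypothetical buffer being
// processed):
//
//   std::vector<float> data(1 << 20);
//   Shard(/*max_parallelism=*/4, &pool, data.size(), /*cost_per_unit=*/20,
//         [&data](int64_t start, int64_t limit) {
//           for (int64_t i = start; i < limit; ++i) data[i] *= 2.0f;
//         });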
void Shard(int max_parallelism, thread::ThreadPool* workers, int64_t total,
           int64_t cost_per_unit, std::function<void(int64_t, int64_t)> work);

// Each thread has an associated option to express the desired maximum
// parallelism. Its default is a very large quantity.
//
// Within the TF runtime, per-thread max parallelism affects Shard() and
// intra-op parallelism. E.g., if SetPerThreadMaxParallelism(1) is called on a
// tf_compute thread, subsequent Shard() calls and Eigen device work on that
// thread become single-threaded.
void SetPerThreadMaxParallelism(int max_parallelism);
int GetPerThreadMaxParallelism();

// Helper to set and unset the per-thread max parallelism.
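//
// Example (a minimal sketch; the surrounding code is hypothetical):
//
//   {
//     ScopedPerThreadMaxParallelism limit(1);
//     // Shard() calls made on this thread here run single-threaded.
//   }
//   // The previous per-thread max parallelism is restored on scope exit.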
class ScopedPerThreadMaxParallelism {
 public:
  ScopedPerThreadMaxParallelism(int max_parallelism)
      : previous_(GetPerThreadMaxParallelism()) {
    SetPerThreadMaxParallelism(max_parallelism);
  }

  ~ScopedPerThreadMaxParallelism() { SetPerThreadMaxParallelism(previous_); }

 private:
  int previous_ = -1;
};

// Implementation details for Shard().
class Sharder {
 public:
  typedef std::function<void()> Closure;
  typedef std::function<void(Closure)> Runner;
  typedef std::function<void(int64_t, int64_t)> Work;

  // Refer to Shard()'s comment for the meaning of total, cost_per_unit,
  // work, and max_parallelism. "runner" is an interface used to schedule a
  // closure; Shard() uses a thread::ThreadPool instead.
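  //
  // Example (a minimal sketch; "pool" is assumed to be an existing
  // thread::ThreadPool on which closures are scheduled):
  //
  //   Sharder::Do(
  //       total, cost_per_unit, work,
  //       [&pool](Sharder::Closure c) { pool.Schedule(std::move(c)); },
  //       max_parallelism);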
  static void Do(int64_t total, int64_t cost_per_unit, const Work& work,
                 const Runner& runner, int max_parallelism);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_