1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_ |
17 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_ |
18 | |
19 | #include <functional> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/distributed_runtime/worker_cache.h" |
23 | #include "tensorflow/core/protobuf/cluster.pb.h" |
24 | #include "tensorflow/core/protobuf/config.pb.h" |
25 | #include "tensorflow/core/protobuf/tensorflow_server.pb.h" |
26 | #include "tensorflow/core/public/session_options.h" |
27 | |
28 | namespace tsl { |
29 | class Env; |
30 | } // namespace tsl |
31 | namespace tensorflow { |
32 | using Env = tsl::Env; |
33 | |
34 | class CollectiveExecutorMgrInterface; |
35 | class Device; |
36 | class DeviceSet; |
37 | class MasterSession; |
38 | class OpRegistryInterface; |
39 | |
40 | // Options passed to the worker_cache_factory function. |
41 | struct WorkerCacheFactoryOptions { |
42 | const ClusterDef* cluster_def = nullptr; |
43 | const string* job_name = nullptr; |
44 | int task_index; |
45 | const string* protocol = nullptr; |
46 | const RPCOptions* rpc_options = nullptr; |
47 | |
48 | WorkerCacheFactoryOptions() {} |
49 | |
50 | // Construct from a ServerDef proto. |
51 | // |
52 | // Note: server_def must outlive WorkerCacheFactoryOptions! |
53 | WorkerCacheFactoryOptions(const ServerDef& server_def) { |
54 | if (server_def.has_cluster() && !server_def.job_name().empty()) { |
55 | cluster_def = &server_def.cluster(); |
56 | job_name = &server_def.job_name(); |
57 | task_index = server_def.task_index(); |
58 | protocol = &server_def.protocol(); |
59 | rpc_options = &server_def.default_session_config().rpc_options(); |
60 | } |
61 | } |
62 | }; |
63 | |
64 | // The master environment class, which holds a bag of pointers to |
65 | // per-master state. |
66 | // |
67 | // MasterEnv does not own its member pointers. |
68 | struct MasterEnv { |
69 | Env* env = nullptr; |
70 | |
71 | // Object from which WorkerInterface instances can be obtained. Not owned. |
72 | WorkerCacheInterface* worker_cache = nullptr; |
73 | |
74 | // The operation definitions to use. Must be filled before use. |
75 | const OpRegistryInterface* ops = nullptr; |
76 | |
77 | // Local devices co-located with this master. Devices are not owned |
78 | // by the master service. |
79 | // |
80 | // REQUIRES: !local_devices.empty(). |
81 | std::vector<Device*> local_devices; |
82 | |
83 | // In large scaled distributed training, many singleton components (e.g. |
84 | // Rendezvous) can becomes the bottleneck of the system. This field allows |
85 | // us to shard the single components. This number will scale up with number |
86 | // of tasks in this cluster. It is always greater than 1. |
87 | int experimental_num_shards = 1; |
88 | |
89 | // Factory for creating master sessions, given session options and a |
90 | // vector of devices. |
91 | // |
92 | // The caller of the function takes ownership of the returned |
93 | // `MasterSession`, which may not be null. Ownership of the |
94 | // `MasterEnv*` is retained by the caller. |
95 | std::function<MasterSession*( |
96 | SessionOptions, MasterEnv*, |
97 | std::unique_ptr<std::vector<std::unique_ptr<Device>>>, |
98 | std::unique_ptr<WorkerCacheInterface>, |
99 | std::unique_ptr<DeviceSet> device_set, |
100 | std::vector<string> filtered_worker_list)> |
101 | master_session_factory; |
102 | |
103 | std::function<Status(const WorkerCacheFactoryOptions&, |
104 | WorkerCacheInterface**)> |
105 | worker_cache_factory; |
106 | |
107 | // Generates per-step CollectiveExecutors and has access to utilities |
108 | // supporting collective operations. Not owned. |
109 | CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr; |
110 | }; |
111 | |
112 | } // end namespace tensorflow |
113 | |
114 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_ |
115 | |