1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
17#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
18
19#include <functional>
20#include <vector>
21
22#include "tensorflow/core/distributed_runtime/worker_cache.h"
23#include "tensorflow/core/protobuf/cluster.pb.h"
24#include "tensorflow/core/protobuf/config.pb.h"
25#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
26#include "tensorflow/core/public/session_options.h"
27
28namespace tsl {
29class Env;
30} // namespace tsl
31namespace tensorflow {
32using Env = tsl::Env;
33
34class CollectiveExecutorMgrInterface;
35class Device;
36class DeviceSet;
37class MasterSession;
38class OpRegistryInterface;
39
40// Options passed to the worker_cache_factory function.
41struct WorkerCacheFactoryOptions {
42 const ClusterDef* cluster_def = nullptr;
43 const string* job_name = nullptr;
44 int task_index;
45 const string* protocol = nullptr;
46 const RPCOptions* rpc_options = nullptr;
47
48 WorkerCacheFactoryOptions() {}
49
50 // Construct from a ServerDef proto.
51 //
52 // Note: server_def must outlive WorkerCacheFactoryOptions!
53 WorkerCacheFactoryOptions(const ServerDef& server_def) {
54 if (server_def.has_cluster() && !server_def.job_name().empty()) {
55 cluster_def = &server_def.cluster();
56 job_name = &server_def.job_name();
57 task_index = server_def.task_index();
58 protocol = &server_def.protocol();
59 rpc_options = &server_def.default_session_config().rpc_options();
60 }
61 }
62};
63
64// The master environment class, which holds a bag of pointers to
65// per-master state.
66//
67// MasterEnv does not own its member pointers.
68struct MasterEnv {
69 Env* env = nullptr;
70
71 // Object from which WorkerInterface instances can be obtained. Not owned.
72 WorkerCacheInterface* worker_cache = nullptr;
73
74 // The operation definitions to use. Must be filled before use.
75 const OpRegistryInterface* ops = nullptr;
76
77 // Local devices co-located with this master. Devices are not owned
78 // by the master service.
79 //
80 // REQUIRES: !local_devices.empty().
81 std::vector<Device*> local_devices;
82
83 // In large scaled distributed training, many singleton components (e.g.
84 // Rendezvous) can becomes the bottleneck of the system. This field allows
85 // us to shard the single components. This number will scale up with number
86 // of tasks in this cluster. It is always greater than 1.
87 int experimental_num_shards = 1;
88
89 // Factory for creating master sessions, given session options and a
90 // vector of devices.
91 //
92 // The caller of the function takes ownership of the returned
93 // `MasterSession`, which may not be null. Ownership of the
94 // `MasterEnv*` is retained by the caller.
95 std::function<MasterSession*(
96 SessionOptions, MasterEnv*,
97 std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
98 std::unique_ptr<WorkerCacheInterface>,
99 std::unique_ptr<DeviceSet> device_set,
100 std::vector<string> filtered_worker_list)>
101 master_session_factory;
102
103 std::function<Status(const WorkerCacheFactoryOptions&,
104 WorkerCacheInterface**)>
105 worker_cache_factory;
106
107 // Generates per-step CollectiveExecutors and has access to utilities
108 // supporting collective operations. Not owned.
109 CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr;
110};
111
112} // end namespace tensorflow
113
114#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
115