1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
16
17#include "tensorflow/core/common_runtime/base_collective_executor.h"
18#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
19#include "tensorflow/core/common_runtime/collective_rma_local.h"
20#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
21#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
22#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
23#include "tensorflow/core/distributed_runtime/worker_cache.h"
24#include "tensorflow/core/lib/random/random.h"
25
26namespace tensorflow {
27
28RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
29 const ConfigProto& config, const DeviceMgr* dev_mgr,
30 std::unique_ptr<DeviceResolverDistributed> dev_resolver,
31 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
32 std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
33 WorkerCacheInterface* worker_cache, const string& task_name)
34 : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
35 std::move(param_resolver),
36 std::move(nccl_communicator)),
37 worker_cache_(worker_cache),
38 task_name_(task_name) {
39 group_leader_ = (task_name == config.experimental().collective_group_leader())
40 ? ""
41 : config.experimental().collective_group_leader();
42}
43
44RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
45 for (auto it : sequence_table_) {
46 delete it.second;
47 }
48}
49
50CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64_t step_id) {
51 CollectiveRemoteAccessDistributed* rma =
52 new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
53 work_queue_, worker_cache_, step_id,
54 task_name_);
55 return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, work_queue_);
56}
57
58namespace {
59// StepId must leave the most-significant 7 bits empty for future use.
60static const int64_t kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));
61
62int64_t NewRandomStepId() {
63 int64_t step_id = random::New64();
64 // Leave MS 8 bits clear for future use.
65 step_id &= kStepIdMask;
66 return step_id;
67}
68} // namespace
69
// Advances the step-id sequence for `graph_key`, then invokes `done`.
//
// When this task is the group leader (group_leader_ is empty) the refresh is
// purely local: the graph key's sequence record (created on first use) gets a
// fresh random step id. Otherwise the sequence is requested from the group
// leader via an async RPC and merged into sequence_table_ by
// UpdateStepSequences before `done` is called with the resulting status.
void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
    int64_t graph_key, const StatusCallback& done) {
  if (group_leader_.empty()) {
    mutex_lock l(sequence_mu_);
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(graph_key);
    if (it == sequence_table_.end()) {
      // First refresh for this graph key: create its sequence record.
      gks = new GraphKeySequence(graph_key);
      sequence_table_[graph_key] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = NewRandomStepId();
    done(OkStatus());
  } else {
    WorkerInterface* wi = worker_cache_->GetOrCreateWorker(group_leader_);
    // req/resp are heap-allocated because the RPC completes asynchronously;
    // the completion callback below owns and deletes both.
    GetStepSequenceRequest* req = new GetStepSequenceRequest;
    GetStepSequenceResponse* resp = new GetStepSequenceResponse;
    req->add_graph_key(graph_key);
    wi->GetStepSequenceAsync(
        req, resp, [this, req, resp, done](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Bad response [" << s
                       << "] from GetStepSequenceAsync call to "
                       << group_leader_;
            done(s);
          } else {
            // Merge the leader's sequences into the local table, then
            // propagate the merge status to the caller.
            done(UpdateStepSequences(*resp));
          }
          delete req;
          delete resp;
        });
  }
}
104
105void RpcCollectiveExecutorMgr::GetStepSequenceAsync(
106 const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
107 const StatusCallback& done) {
108 if (!group_leader_.empty()) {
109 LOG(ERROR) << "GetStepSequence called at non-group-leader";
110 done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));
111 } else {
112 mutex_lock l(sequence_mu_);
113 for (int64_t graph_key : request->graph_key()) {
114 auto it = sequence_table_.find(graph_key);
115 GraphKeySequence* gks = nullptr;
116 if (it == sequence_table_.end()) {
117 gks = new GraphKeySequence(graph_key);
118 gks->next_step_id_ = NewRandomStepId();
119 sequence_table_[graph_key] = gks;
120 } else {
121 gks = it->second;
122 }
123 StepSequence* ss = response->add_step_sequence();
124 ss->set_graph_key(graph_key);
125 ss->set_next_step_id(gks->next_step_id_);
126 }
127 done(OkStatus());
128 }
129}
130
131Status RpcCollectiveExecutorMgr::UpdateStepSequences(
132 const GetStepSequenceResponse& resp) {
133 mutex_lock l(sequence_mu_);
134 for (const StepSequence& ss : resp.step_sequence()) {
135 GraphKeySequence* gks = nullptr;
136 auto it = sequence_table_.find(ss.graph_key());
137 if (it == sequence_table_.end()) {
138 gks = new GraphKeySequence(ss.graph_key());
139 sequence_table_[ss.graph_key()] = gks;
140 } else {
141 gks = it->second;
142 }
143 gks->next_step_id_ = ss.next_step_id();
144 }
145 return OkStatus();
146}
147
148int64_t RpcCollectiveExecutorMgr::NextStepId(int64_t graph_key) {
149 mutex_lock l(sequence_mu_);
150 auto it = sequence_table_.find(graph_key);
151 if (it != sequence_table_.end()) {
152 return it->second->next_step_id_;
153 }
154 return CollectiveExecutor::kInvalidId;
155}
156
157void RpcCollectiveExecutorMgr::RetireStepId(int64_t graph_key,
158 int64_t step_id) {
159 mutex_lock l(sequence_mu_);
160 auto it = sequence_table_.find(graph_key);
161 if (it != sequence_table_.end()) {
162 if (step_id == it->second->next_step_id_) {
163 it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
164 } else {
165 it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
166 }
167 } else {
168 LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
169 }
170}
171
172std::unique_ptr<RpcCollectiveExecutorMgr> CreateProdRpcCollectiveExecutorMgr(
173 const ConfigProto& config, const DeviceMgr* device_mgr,
174 std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
175 WorkerCacheInterface* worker_cache, const string& default_worker_name) {
176 auto dev_resolver = std::make_unique<DeviceResolverDistributed>(device_mgr);
177 auto param_resolver = std::make_unique<CollectiveParamResolverDistributed>(
178 config, device_mgr, dev_resolver.get(), nccl_communicator.get(),
179 worker_cache, default_worker_name);
180 return std::make_unique<RpcCollectiveExecutorMgr>(
181 config, device_mgr, std::move(dev_resolver), std::move(param_resolver),
182 std::move(nccl_communicator), worker_cache, default_worker_name);
183}
184
185} // namespace tensorflow
186