1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
16#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
17
18#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
19#include "tensorflow/core/framework/collective.h"
20
21namespace tensorflow {
// Forward declarations of the types this header names without needing their
// full definitions (they appear only as pointer, reference, or
// std::unique_ptr parameter types in the declarations below).
class CollectiveParamResolverDistributed;
class ConfigProto;
class DeviceMgr;
class DeviceResolverDistributed;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class WorkerCacheInterface;
29
30// An implementation of CollectiveExecutorMgr for a distributed environment
31// that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs.
32//
33// In some execution environments it may be possible to implement a
34// higher-performance solution and use it in place of this class.
35class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
36 public:
37 RpcCollectiveExecutorMgr(
38 const ConfigProto& config, const DeviceMgr* dev_mgr,
39 std::unique_ptr<DeviceResolverDistributed> dev_resolver,
40 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
41 std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
42 WorkerCacheInterface* worker_cache, const string& task_name);
43
44 virtual ~RpcCollectiveExecutorMgr();
45
46 // This function should only be called at the group_leader, by an RPC.
47 // Other needs for StepIds should be satisfied by NextStepId.
48 void GetStepSequenceAsync(const GetStepSequenceRequest* request,
49 GetStepSequenceResponse* response,
50 const StatusCallback& done) override;
51
52 void RefreshStepIdSequenceAsync(int64_t graph_key,
53 const StatusCallback& done) override;
54
55 int64_t NextStepId(int64_t graph_key) override;
56
57 void RetireStepId(int64_t graph_key, int64_t step_id) override;
58
59 protected:
60 virtual CollectiveExecutor* Create(int64_t step_id) override;
61
62 WorkerCacheInterface* const worker_cache_; // Not owned.
63 const string task_name_;
64 string group_leader_;
65 friend class RpcCollectiveExecutorMgrTest;
66
67 private:
68 Status UpdateStepSequences(const GetStepSequenceResponse& resp);
69
70 // This class maintains the step_id sequencing for a single
71 // collective_graph_key.
72 struct GraphKeySequence {
73 explicit GraphKeySequence(int64_t k)
74 : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {}
75
76 const int64_t graph_key_;
77 int64_t next_step_id_;
78 };
79
80 mutex sequence_mu_;
81 gtl::FlatMap<int64_t, GraphKeySequence*> sequence_table_
82 TF_GUARDED_BY(sequence_mu_);
83};
84
85// Creates a distributed CollectiveExecutorMgr with production implementations
86// of each components. Cases that need to inject other implementations of these
87// components should call CollectiveExecutorMgr constructor directly.
88std::unique_ptr<RpcCollectiveExecutorMgr> CreateProdRpcCollectiveExecutorMgr(
89 const ConfigProto& config, const DeviceMgr* device_mgr,
90 std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
91 WorkerCacheInterface* worker_cache, const string& default_worker_name);
92
93} // namespace tensorflow
94#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
95