1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/collective_executor_mgr.h" |
16 | |
17 | #include "absl/memory/memory.h" |
18 | #include "tensorflow/core/common_runtime/base_collective_executor.h" |
19 | #include "tensorflow/core/common_runtime/build_graph_options.h" |
20 | #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" |
21 | #include "tensorflow/core/common_runtime/collective_rma_local.h" |
22 | #include "tensorflow/core/common_runtime/device_mgr.h" |
23 | #include "tensorflow/core/common_runtime/device_resolver_local.h" |
24 | #include "tensorflow/core/framework/collective.h" |
25 | #include "tensorflow/core/protobuf/config.pb.h" |
26 | |
27 | namespace tensorflow { |
28 | |
// Constructs a local collective-op manager.
//
// Takes ownership (via move) of the device/param resolvers and the optional
// NCCL communicator; `dev_mgr` is borrowed and must outlive this object.
// The shared work queue runs collective callbacks on threads named
// "collective_ops".
CollectiveExecutorMgr::CollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    std::unique_ptr<DeviceResolverInterface> dev_resolver,
    std::unique_ptr<ParamResolverInterface> param_resolver,
    std::unique_ptr<NcclCommunicatorInterface> nccl_communicator)
    : dev_mgr_(dev_mgr),
      dev_resolver_(std::move(dev_resolver)),
      param_resolver_(std::move(param_resolver)),
      // Ring order comes from the session config, e.g. "0,1,2,3".
      gpu_ring_order_(
          config.gpu_options().experimental().collective_ring_order()),
      nccl_communicator_(std::move(nccl_communicator)),
      work_queue_(std::make_shared<UnboundedWorkQueue>(Env::Default(),
                                                       "collective_ops")) {}
42 | |
43 | CollectiveExecutorMgr::~CollectiveExecutorMgr() { |
44 | for (auto iter : executor_table_) { |
45 | iter.second->Unref(); |
46 | } |
47 | } |
48 | |
49 | CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64_t step_id) { |
50 | CollectiveExecutor* ce = nullptr; |
51 | { |
52 | mutex_lock l(exec_mu_); |
53 | auto it = executor_table_.find(step_id); |
54 | if (it != executor_table_.end()) { |
55 | ce = it->second; |
56 | } else { |
57 | ce = Create(step_id); |
58 | executor_table_[step_id] = ce; |
59 | } |
60 | ce->Ref(); |
61 | } |
62 | return ce; |
63 | } |
64 | |
65 | CollectiveExecutor* CollectiveExecutorMgr::Create(int64_t step_id) { |
66 | CollectiveRemoteAccessLocal* rma = |
67 | new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id); |
68 | return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, work_queue_); |
69 | } |
70 | |
71 | void CollectiveExecutorMgr::Cleanup(int64_t step_id) { |
72 | CollectiveExecutor* ce = nullptr; |
73 | { |
74 | mutex_lock l(exec_mu_); |
75 | auto it = executor_table_.find(step_id); |
76 | if (it != executor_table_.end()) { |
77 | ce = it->second; |
78 | executor_table_.erase(it); |
79 | } |
80 | } |
81 | if (ce) ce->Unref(); |
82 | } |
83 | |
84 | void CollectiveExecutorMgr::GetStepSequenceAsync( |
85 | const GetStepSequenceRequest* request, GetStepSequenceResponse* response, |
86 | const StatusCallback& done) { |
87 | done(errors::Internal( |
88 | "CollectiveExecutorMgr does not implement GetStepSequence." )); |
89 | } |
90 | |
91 | void CollectiveExecutorMgr::RefreshStepIdSequenceAsync( |
92 | int64_t graph_key, const StatusCallback& done) { |
93 | done(errors::Internal( |
94 | "CollectiveExecutorMgr does not implement RefreshStepIdSequence." )); |
95 | } |
96 | |
97 | std::unique_ptr<CollectiveExecutorMgr> CreateProdLocalCollectiveExecutorMgr( |
98 | const ConfigProto& config, const DeviceMgr* device_mgr, |
99 | std::unique_ptr<NcclCommunicatorInterface> nccl_communicator) { |
100 | auto device_resolver = std::make_unique<DeviceResolverLocal>(device_mgr); |
101 | auto param_resolver = std::make_unique<CollectiveParamResolverLocal>( |
102 | config, device_mgr, device_resolver.get(), nccl_communicator.get(), |
103 | "/job:localhost/replica:0/task:0" ); |
104 | return std::make_unique<CollectiveExecutorMgr>( |
105 | config, device_mgr, std::move(device_resolver), std::move(param_resolver), |
106 | std::move(nccl_communicator)); |
107 | } |
108 | |
109 | } // namespace tensorflow |
110 | |