1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_ |
16 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_ |
17 | |
18 | #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" |
19 | #include "tensorflow/core/framework/cancellation.h" |
20 | #include "tensorflow/core/framework/device_attributes.pb.h" |
21 | #include "tensorflow/core/platform/status.h" |
22 | |
23 | namespace tensorflow { |
// Forward declarations. Every use of these types in this header is through a
// pointer or reference, so their full definitions are not required here.
class ConfigProto;
class DeviceMgr;
class DeviceResolverDistributed;
class WorkerCacheInterface;
28 | |
29 | class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { |
30 | public: |
31 | CollectiveParamResolverDistributed( |
32 | const ConfigProto& config, const DeviceMgr* dev_mgr, |
33 | DeviceResolverDistributed* dev_resolver, |
34 | NcclCommunicatorInterface* nccl_communicator, |
35 | WorkerCacheInterface* worker_cache, const string& task_name); |
36 | |
37 | void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, |
38 | CancellationManager* cancel_mgr, |
39 | const StatusCallback& done) override; |
40 | |
41 | void CompleteGroupAsync(const DeviceAttributes& device, |
42 | CollGroupParams* group_params, |
43 | CancellationManager* cancel_mgr, |
44 | const StatusCallback& done) override; |
45 | |
46 | void CompleteInstanceAsync(const CompleteInstanceRequest* request, |
47 | CompleteInstanceResponse* response, |
48 | CancellationManager* cancel_mgr, |
49 | const StatusCallback& done) override; |
50 | |
51 | void StartAbort(const Status& s) override; |
52 | |
53 | protected: |
54 | // Returns the cached group iff there's an entry for this group_key in the |
55 | // local group_table_; returns nullptr otherwise. |
56 | GroupRec* GetCachedGroup(int32_t group_key) TF_LOCKS_EXCLUDED(group_mu_); |
57 | |
58 | // Updates group_table_ with contents of resp. |
59 | Status UpdateGroupCache(const CompleteGroupResponse& resp) |
60 | TF_LOCKS_EXCLUDED(group_mu_); |
61 | |
62 | // Finds the GroupRec that corresponds to cp->group_key and also |
63 | // populates cp->group from that GroupRec. |
64 | // |
65 | // Semantics are like those of CompleteGroupLocal but will make a |
66 | // remote call to the group leader if necessary. |
67 | void CompleteGroupDistributed(const DeviceAttributes& device, |
68 | CollGroupParams* group_params, |
69 | CancellationManager* cancel_mgr, |
70 | const StatusCallback& done); |
71 | |
72 | // Returns true iff there's an entry for this instance_key in the |
73 | // local instance_table_. |
74 | bool InstanceIsCached(int32_t group_key, int32_t instance_key) |
75 | TF_LOCKS_EXCLUDED(instance_mu_); |
76 | |
77 | // Updates instance_table_ with contents of resp. |
78 | Status UpdateInstanceCache(CollectiveParams* cp, |
79 | const CompleteInstanceResponse& resp) |
80 | TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); |
81 | |
82 | // Finish populating *cp. Semantics are like those of |
83 | // CompleteInstanceLocal but will make a remote call to the group |
84 | // leader if necessary. |
85 | void CompleteInstanceDistributed(const string& device, CollectiveParams* cp, |
86 | CancellationManager* cancel_mgr, |
87 | const StatusCallback& done) |
88 | TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); |
89 | |
90 | WorkerCacheInterface* worker_cache_; // Not owned |
91 | const string group_leader_; |
92 | CancellationManager abortion_cancel_mgr_; |
93 | }; |
94 | |
95 | } // namespace tensorflow |
96 | #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_ |
97 | |