1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ |
17 | |
18 | #include <functional> |
19 | #include <memory> |
20 | #include <set> |
21 | #include <string> |
22 | #include <unordered_map> |
23 | #include <vector> |
24 | |
25 | #include "tensorflow/core/framework/collective.h" |
26 | #include "tensorflow/core/framework/device_attributes.pb.h" |
27 | #include "tensorflow/core/lib/gtl/flatmap.h" |
28 | #include "tensorflow/core/platform/thread_annotations.h" |
29 | |
30 | namespace tensorflow { |
31 | class CompleteGroupRequest; |
32 | class CompleteGroupResponse; |
33 | class CompleteInstanceRequest; |
34 | class CompleteInstanceResponse; |
35 | class ConfigProto; |
36 | class DeviceMgr; |
37 | |
38 | // Implements ParamResolverInterface for a single-task context. |
39 | // It also implements the functionality necessary to serve as the |
40 | // group leader for param resolution in a multi-task context. |
41 | class CollectiveParamResolverLocal : public ParamResolverInterface { |
42 | public: |
43 | CollectiveParamResolverLocal(const ConfigProto& config, |
44 | const DeviceMgr* dev_mgr, |
45 | DeviceResolverInterface* dev_resolver, |
46 | NcclCommunicatorInterface* nccl_communicator, |
47 | const string& task_name); |
48 | |
49 | ~CollectiveParamResolverLocal() override {} |
50 | |
51 | void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, |
52 | CancellationManager* cancel_mgr, |
53 | const StatusCallback& done) override; |
54 | |
55 | void CompleteGroupAsync(const DeviceAttributes& device, |
56 | CollGroupParams* group_params, |
57 | CancellationManager* cancel_mgr, |
58 | const StatusCallback& done) override; |
59 | |
60 | void CompleteInstanceAsync(const CompleteInstanceRequest* request, |
61 | CompleteInstanceResponse* response, |
62 | CancellationManager* cancel_mgr, |
63 | const StatusCallback& done) override; |
64 | |
65 | Status LookupGroup(int32_t group_key, CollGroupParams* group) override; |
66 | |
67 | void StartAbort(const Status& s) override; |
68 | |
69 | protected: |
70 | // For access to InstanceRec and CompleteDefaultRanking. |
71 | friend class CollectiveParamResolverLocalTest; |
72 | |
73 | // Used to complete/verify CollGroup. |
74 | struct GroupRec { |
75 | mutable mutex mu; |
76 | CollGroupParams group TF_GUARDED_BY(mu); |
77 | Status status TF_GUARDED_BY(mu); |
78 | std::unordered_map<string, int64_t> incarnations_by_device_name |
79 | TF_GUARDED_BY(mu); |
80 | std::vector<CollGroupParams*> pending_params TF_GUARDED_BY(mu); |
81 | std::vector<StatusCallback> pending_done TF_GUARDED_BY(mu); |
82 | }; |
83 | |
84 | // Finds the GroupRec that corresponds to group_params->group_key. |
85 | // Also populates group_params from that group_rec. |
86 | // Will wait until GroupRec is fully populated or an error arises before |
87 | // calling done. Callback GroupRec* arg is only valid if status is ok. |
88 | // Ownership of GroupRec stays with this object and does not pass to the |
89 | // callback. |
90 | void CompleteGroupLocal(const DeviceAttributes& device, |
91 | CollGroupParams* group_params, |
92 | CancellationManager* cancel_mgr, StatusCallback done) |
93 | TF_LOCKS_EXCLUDED(group_mu_); |
94 | |
95 | // Finishes the group parameters once all members of the group are there. |
96 | void FinishGroup(GroupRec* gr) TF_EXCLUSIVE_LOCKS_REQUIRED(gr->mu); |
97 | |
98 | // Cancels the group if it's still pending. |
99 | void CancelGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_); |
100 | |
101 | // Lookup and populate parameters from an already initialized group. |
102 | Status LookupAndPopulateGroupParams(CollGroupParams* group_params); |
103 | |
104 | // Used to complete/verify CollInstance. |
105 | struct InstanceRec; |
106 | |
107 | typedef std::function<void(InstanceRec*)> IRConsumer; |
108 | struct InstanceRec { |
109 | mutex mu; |
110 | // Values to be shared by all instances, constant after initialization. |
111 | CollectiveParams* shared; |
112 | // If an error occurs during initialization this structure stays in the |
113 | // table with a non-OK status. Purging the table and restarting needs to be |
114 | // done at a higher level. |
115 | Status status TF_GUARDED_BY(mu); |
116 | |
117 | // These fields are used to count the instances that have called |
118 | // in and become known while resolving broadcast source identity and |
119 | // communicator key. |
120 | int source_rank TF_GUARDED_BY(mu); |
121 | string communicator_key TF_GUARDED_BY(mu); |
122 | int known_count TF_GUARDED_BY(mu); |
123 | std::vector<bool> known TF_GUARDED_BY(mu); |
124 | std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu); |
125 | |
126 | InstanceRec() |
127 | : shared(new CollectiveParams()), source_rank(-1), known_count(0) {} |
128 | ~InstanceRec() { shared->Unref(); } |
129 | }; |
130 | |
131 | // Find the InstanceRec with the same instance_key as cp. If it doesn't |
132 | // already exist, create and initialize from gr and cp. |
133 | // created is set to true if a new IRec is created, false otherwise. |
134 | // |
135 | // Precondition: *gr must be a complete GroupRec, i.e. the value set |
136 | // by CompleteGroupLocal. *cp must be populated with all the fields |
137 | // required by InitInstanceSharedParams. Ownership of InstanceRec stays |
138 | // with this object and does not pass to the callback. |
139 | InstanceRec* GetOrCreateInstanceRec(CollectiveParams* cp, bool* created) |
140 | TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); |
141 | |
142 | // Populate *ir with device membership from gr, then initialize to be specific |
143 | // to cp->instance_key, i.e. order the devices and tasks. |
144 | // |
145 | // Preconditions: |
146 | // cp is populated with all DeviceLocalities |
147 | void InitInstanceSharedParams(const CollectiveParams* cp, InstanceRec* ir); |
148 | |
149 | // Establishes the final order of gp->device_names and gp->task_names by |
150 | // considering localities of all devices. |
151 | void CompleteDefaultRanking(CollGroupParams* gp); |
152 | |
153 | // Finish populating *cp. |
154 | // Precondition: *gr has been fully populated by CompleteGroupLocal. |
155 | void CompleteInstanceLocal(const string& device, CollectiveParams* cp, |
156 | const StatusCallback& done) |
157 | TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); |
158 | |
159 | // Finish populating *cp from fully initialized *ir. |
160 | // Precondition: *gr and *ir are fully populated. |
161 | void CompleteInstanceFromInitializedIRec(const string& device, |
162 | CollectiveParams* cp, |
163 | InstanceRec* ir, |
164 | const StatusCallback& done) |
165 | TF_LOCKS_EXCLUDED(ir->mu); |
166 | |
167 | // Complete instance params after waiting for group. |
168 | // Precondition: *cp has complete group data and default_rank. |
169 | void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, const IRConsumer& f) |
170 | TF_LOCKS_EXCLUDED(ir->mu); |
171 | |
172 | // If cp.device_names contains only devices local to this process |
173 | // populates *localities, else returns an error. |
174 | Status GetLocalDeviceLocalities(const CollectiveParams& cp, |
175 | std::vector<DeviceLocality>* localities); |
176 | |
177 | // Sets cp->instance_default_rank according to location of device in |
178 | // current ordering of cp->instance.device_names. |
179 | void SetDefaultRank(const string& device, CollectiveParams* cp); |
180 | |
181 | // Sets cp->instance.type based on collective op type, and attempts to assign |
182 | // best implementation. |
183 | void AssignCollectiveType(CollectiveParams* cp); |
184 | |
185 | void StartAbortLocal(const Status& s) |
186 | TF_LOCKS_EXCLUDED(status_mu_, group_mu_, instance_mu_); |
187 | |
188 | const bool nccl_; |
189 | const DeviceMgr* dev_mgr_; |
190 | DeviceResolverInterface* dev_resolver_; // Not owned. |
191 | NcclCommunicatorInterface* nccl_communicator_; // Not owned. |
192 | string task_name_; |
193 | string gpu_ring_order_; |
194 | mutex group_mu_; |
195 | gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_ |
196 | TF_GUARDED_BY(group_mu_); |
197 | mutex instance_mu_; |
198 | gtl::FlatMap<int32, gtl::FlatMap<int32, std::unique_ptr<InstanceRec>>> |
199 | instance_table_ TF_GUARDED_BY(instance_mu_); |
200 | mutex status_mu_; |
201 | Status status_ TF_GUARDED_BY(status_mu_); |
202 | }; |
203 | |
204 | } // namespace tensorflow |
205 | |
206 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ |
207 | |