1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ |
16 | #define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ |
17 | |
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/intrusive_ptr.h"
29 | |
30 | namespace tensorflow { |
31 | |
32 | class BufRendezvous; |
33 | class CompleteGroupRequest; |
34 | class CompleteGroupResponse; |
35 | class CompleteInstanceRequest; |
36 | class CompleteInstanceResponse; |
37 | class Device; |
38 | class DeviceMgr; |
39 | class GetStepSequenceRequest; |
40 | class GetStepSequenceResponse; |
41 | class NcclManager; |
42 | class Tensor; |
43 | |
44 | // Types of supported collective operations. |
// Types of supported collective operations.
// NOTE(review): existing enumerator values appear to be relied upon by
// serialized params — add new types before UNDEFINED_COLLECTIVE and do not
// renumber existing entries; confirm before reordering.
enum CollectiveType {
  REDUCTION_COLLECTIVE = 0,  // e.g. all-reduce
  BROADCAST_COLLECTIVE,      // one source rank, all others receive
  GATHER_COLLECTIVE,         // all-gather of per-rank tensors
  PERMUTE_COLLECTIVE,        // point-to-point permutation between ranks
  ALL_TO_ALL_COLLECTIVE,     // each rank exchanges data with every rank
  UNDEFINED_COLLECTIVE,      // sentinel: type not yet determined
};
53 | |
54 | // Some collective op implementations require runtime group configuration from |
55 | // the OpKernel. Currently, this struct is used to set communicator key for |
56 | // NCCL-based collective implementation. |
// Some collective op implementations require runtime group configuration from
// the OpKernel. Currently, this struct is used to set communicator key for
// NCCL-based collective implementation.
struct CollGroupRuntimeDetails {
  // Opaque key identifying the communicator shared by group members;
  // for communicator-based techniques e.g. NCCL.
  string communicator_key;
  // Human-readable representation for logging and debugging.
  string ToString() const;
};
61 | |
// Describes one device participating in a collective group.
struct CollGroupMember {
  // Attributes (name, incarnation, etc.) of the member device.
  DeviceAttributes device;
  // Name of the task that owns the device.
  string task;
  // True iff this member is local to the current process.
  bool is_local;
  // User provided rank; -1 means the user did not specify one.
  int32 rank = -1;
};
69 | |
70 | // Data common to all members of a device group. |
71 | // All members share the same device set but its order is |
72 | // particular to an instance so it is stored there. |
// Data common to all members of a device group.
// All members share the same device set but its order is
// particular to an instance so it is stored there.
struct CollGroupParams {
  // Inputs from Collective ops:
  int32 group_key;   // identifier shared by all members of this group
  int32 group_size;  // total number of devices in the group
  DeviceType device_type;
  int user_specified_rank = -1;  // rank provided by the user.
  // Generated from Collective Group Resolver:
  // Members in this group, in default rank order.
  std::vector<CollGroupMember> members;
  // True if every task has the same number of devices.
  bool same_num_devices_per_task = false;
  // Task -> number of devices on that task.
  std::unordered_map<string, int32> num_devices_per_task;
  int32 num_tasks;  // number of distinct tasks in group
  // Runtime-resolved configuration, e.g. the NCCL communicator key.
  CollGroupRuntimeDetails runtime_details;
  // Human-readable representation for logging and debugging.
  string ToString() const;
  CollGroupParams()
      : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {}
};
92 | |
93 | // The best implementation of a collective op depends on many factors |
94 | // including the number of devices involved, the topology of |
95 | // interconnects between them and the sizes of inputs. This structure |
96 | // is used in generating and representing data movement choreography |
97 | // for each specific algorithm, hence it does not have a single, fixed |
98 | // interpretation. On first execution the runtime will update this |
99 | // structure with decisions that will guide all subsequent executions. |
// The best implementation of a collective op depends on many factors
// including the number of devices involved, the topology of
// interconnects between them and the sizes of inputs. This structure
// is used in generating and representing data movement choreography
// for each specific algorithm, hence it does not have a single, fixed
// interpretation. On first execution the runtime will update this
// structure with decisions that will guide all subsequent executions.
struct CollImplDetails {
  string collective_name;  // registered name of the chosen implementation
  // Device-rank permutation for each subdivision.
  std::vector<std::vector<int>> subdiv_permutations;
  // subdiv_offsets and max_subdivs_per_device are used together as follows:
  // When subdiv_offsets is provided (non-empty) it is used as is. When
  // subdiv_offsets is not provided subdivisions are generated dynamically
  // constrained by max_subdivs_per_device. When subdiv_offsets is empty AND
  // max_subdivs_per_device = 0 an internal default kMaxSubdivsPerDeviceDefault
  // is used. When max_subdivs_per_device = -1, no subdivision is done.
  int max_subdivs_per_device = -1;  // Upper bound on subdivisions per device.
  std::vector<int> subdiv_offsets;
  std::vector<int> subdiv_source_rank;  // rank of source in each subdiv
  std::vector<int32>
      dependencies;           // collective instances on which this node depends
  string communication_hint;  // user-supplied hint for implementation choice,
                              // e.g. ring or nccl
  float timeout_seconds;  // If non zero, set a completion timeout for the
                          // collective op to detect staleness.
};
119 | |
120 | // Data common to all members of a collective instance. |
121 | // TODO(b/163171014) Refactor this struct to not be a union of all fields. |
// Data common to all members of a collective instance.
// TODO(b/163171014) Refactor this struct to not be a union of all fields.
struct CollInstanceParams {
  // Identifies all participating graph nodes.
  int32 instance_key = -1;
  // Kind of collective this instance performs.
  CollectiveType type = UNDEFINED_COLLECTIVE;
  // Element type of the tensors moved by this instance.
  DataType data_type = DT_FLOAT;
  // Shape of the instance's input/output tensor.
  TensorShape shape = {0};
  // Implementation-specific choreography; see CollImplDetails.
  CollImplDetails impl_details;
  // Human-readable representation for logging and debugging.
  string ToString() const;
  CollInstanceParams& operator=(const struct CollInstanceParams& other);
  std::vector<string> devices;  // permuter only

  // For permuter only
  // Each rank in the permutation is a receiver.
  // Indices of each rank means a sender to that rank.
  // Example: permutation = {2,0,1} means
  //   rank 0 sends to rank 2
  //   rank 1 sends to rank 0
  //   rank 2 sends to rank 1
  std::vector<int> permutation;
};
142 | |
143 | // Unique to a single CollectiveOp node. |
// Unique to a single CollectiveOp node.
// Reference-counted: shared between the op kernel and the executor machinery.
struct CollectiveParams : public core::RefCounted {
  // Group-level data shared by all members of the device group.
  CollGroupParams group;
  // Instance-level data shared by all participants of this instance.
  CollInstanceParams instance;

  string name = "" ;      // node name used only for log or error messages
  int default_rank = -1;  // index of this op within device_names
  bool is_source = false; // broadcast only
  int source_rank = -1;   // broadcast only
  // Rank of this device in each subdivision permutation.
  std::vector<int> subdiv_rank;
  // NOTE(review): merge_op/final_op are raw pointers — ownership is not
  // expressed here; confirm lifetime is managed by the owning kernel/executor.
  OpKernel* merge_op = nullptr;  // reduction only
  OpKernel* final_op = nullptr;  // reduction only
  // Human-readable representation for logging and debugging.
  string ToString() const;
  // If false, the op skips the group-initialization phase (the group is
  // assumed to be resolved already).
  bool run_group_initialization = true;
};
159 | |
160 | class CollectiveExecutor; |
161 | |
162 | // Interface that provides resolution of device localities. |
// Interface that provides resolution of device localities.
// All methods return Status so failures can be surfaced to callers.
class DeviceResolverInterface {
 public:
  virtual ~DeviceResolverInterface() {}

  // Populates *attributes with the DeviceAttributes of the specified device.
  virtual Status GetDeviceAttributes(const string& device,
                                     DeviceAttributes* attributes) = 0;

  // Returns all device attributes of a task.
  virtual Status GetAllDeviceAttributes(
      const string& task, std::vector<DeviceAttributes>* attributes) = 0;

  // Updates device attributes. It returns error if any device already
  // exists in the DeviceResolver and has a different incarnation.
  virtual Status UpdateDeviceAttributes(
      const std::vector<DeviceAttributes>& attributes) = 0;
};
180 | |
181 | // Interface that provides resolution of shared CollectiveParams fields. |
// Interface that provides resolution of shared CollectiveParams fields.
// Implementations may be local (single process) or distributed.
class ParamResolverInterface {
 public:
  virtual ~ParamResolverInterface() {}

  // Called by each collective op at first execution in order to fill out
  // the CollectiveParams structure with data gathered from the full
  // (maybe distributed) collection of peer nodes. `done` is invoked with the
  // final status once resolution completes or fails.
  virtual void CompleteParamsAsync(const DeviceAttributes& device,
                                   CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   const StatusCallback& done) = 0;

  // Completes group_params with data gathered from all devices in the group.
  // This blocks until all devices are there.
  virtual void CompleteGroupAsync(const DeviceAttributes& device,
                                  CollGroupParams* group_params,
                                  CancellationManager* cancel_mgr,
                                  const StatusCallback& done) = 0;

  // Used within a distributed implementation to discover/verify data
  // shared across an instance group.
  // Note: this works differently from CompleteGroupAsync as a refactor is in
  // progress.
  virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request,
                                     CompleteInstanceResponse* response,
                                     CancellationManager* cancel_mgr,
                                     const StatusCallback& done) = 0;

  // Looks up a group. It returns an error if the group is not ready or not
  // found.
  virtual Status LookupGroup(int32_t group_key, CollGroupParams* group) = 0;

  // Aborts the resolver. After abortion the resolver can no longer be used.
  virtual void StartAbort(const Status& s) = 0;
};
217 | |
218 | // Graphs which utilize Collective Ops in a common instance must |
219 | // execute with identical step_ids even if they are disjoint graphs |
220 | // run by otherwise independent tasks. This interface supplies |
221 | // coordinated step_ids to use in such cases. |
// Graphs which utilize Collective Ops in a common instance must
// execute with identical step_ids even if they are disjoint graphs
// run by otherwise independent tasks. This interface supplies
// coordinated step_ids to use in such cases.
class StepSequenceInterface {
 public:
  virtual ~StepSequenceInterface() {}

  // Used with a distributed implementation to coordinate step_id
  // sequences across tasks.
  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                    GetStepSequenceResponse* response,
                                    const StatusCallback& done) = 0;

  // Refresh the local per-graph_key step_id sequence from collective
  // group leader, if applicable.
  virtual void RefreshStepIdSequenceAsync(int64_t graph_key,
                                          const StatusCallback& done) = 0;

  // Returns the step_id that should be used for initiating a new execution
  // on the specified graph. May return the same step_id multiple times if
  // RetireStepId or RefreshStepIdReservation is not called.
  virtual int64_t NextStepId(int64_t graph_key) = 0;

  // Reports that execution of the given step has completed successfully.
  // Should be called immediately after a step completes with OK status,
  // prior to calling NextStepId(). If the step fails, don't call.
  virtual void RetireStepId(int64_t graph_key, int64_t step_id) = 0;
};
247 | |
248 | class NcclCommunicatorInterface; |
249 | |
250 | // Interface that provides access to per-step CollectiveExecutor |
251 | // instances and various distributed resolution capabilities. |
// Interface that provides access to per-step CollectiveExecutor
// instances and various distributed resolution capabilities. Also supplies
// coordinated step ids via the inherited StepSequenceInterface.
class CollectiveExecutorMgrInterface : public StepSequenceInterface {
 public:
  virtual ~CollectiveExecutorMgrInterface() {}

  // Returns the step-specific CollectiveExecutor, creating if one does not
  // already exist. The caller assumes ownership of one Ref on the object.
  virtual CollectiveExecutor* FindOrCreate(int64_t step_id) = 0;

  // If there is a CollectiveExecutor for step_id, remove it from the
  // table.
  virtual void Cleanup(int64_t step_id) = 0;

  // Accessors below return non-owning pointers to manager-owned components.
  virtual ParamResolverInterface* GetParamResolver() const = 0;

  virtual DeviceResolverInterface* GetDeviceResolver() const = 0;

  virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0;
};
270 | |
271 | // Interface that a Collective Op implementation uses to exchange data |
272 | // with peers. Note that data exchange is currently limited to types |
273 | // for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric |
274 | // types. |
// Interface that a Collective Op implementation uses to exchange data
// with peers. Note that data exchange is currently limited to types
// for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric
// types.
class CollectiveRemoteAccess {
 public:
  virtual ~CollectiveRemoteAccess() {}

  // Receives the tensor identified by `key` from `peer_device` on `peer_task`
  // into `to_tensor`, invoking `done` on completion. The receive can be
  // cancelled via `cancellation_manager`.
  virtual void RecvFromPeer(const string& peer_device, const string& peer_task,
                            bool peer_is_local, const string& key,
                            Device* to_device, DeviceContext* to_device_ctx,
                            const AllocatorAttributes& to_alloc_attr,
                            Tensor* to_tensor,
                            const DeviceLocality& client_locality,
                            int dev_to_dev_stream_index,
                            CancellationManager* cancellation_manager,
                            const StatusCallback& done) = 0;

  // Makes `from_tensor`, identified by `key`, available for a peer to
  // receive, invoking `done` on completion.
  virtual void PostToPeer(const string& peer_device, const string& peer_task,
                          const string& key, Device* from_device,
                          DeviceContext* from_device_ctx,
                          const AllocatorAttributes& from_alloc_attr,
                          const Tensor* from_tensor,
                          const DeviceLocality& client_locality,
                          CancellationManager* cancellation_manager,
                          const StatusCallback& done) = 0;

  // Checks the health of a collective peer. It probes the peer to see if it is
  // alive. Note that if a peer has restarted, it's considered a different one,
  // so CheckPeerHealth fails.
  virtual void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms,
                               const StatusCallback& done) = 0;

  // Returns the rendezvous used for buffer exchange (non-owning).
  virtual BufRendezvous* buf_rendezvous() = 0;

  // Aborts all pending and future exchanges with status `s`.
  virtual void StartAbort(const Status& s) = 0;
};
308 | |
309 | // A step-specific object that can execute a collective operation completely |
310 | // described by a CollectiveParams object. |
// A step-specific object that can execute a collective operation completely
// described by a CollectiveParams object.
class CollectiveExecutor : public core::RefCounted {
 public:
  // Aborts in-flight collectives with status `s`. Default is a no-op.
  virtual void StartAbort(const Status& s) {}

  // Executes the collective described by `col_params`. The default
  // implementation immediately fails `done`, signalling that no real
  // CollectiveExecutor was provided for this context.
  virtual void ExecuteAsync(OpKernelContext* ctx,
                            const CollectiveParams* col_params,
                            const string& exec_key, StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided." ));
  }

  // Resolves `cp` with data from peers; the default implementation fails
  // `done` just like ExecuteAsync above.
  virtual void CompleteParamsAsync(const DeviceAttributes& device,
                                   CollectiveParams* cp,
                                   CancellationManager* cancel_mgr,
                                   StatusCallback done) {
    done(errors::Internal(
        "A collective Op has been called in a context in which "
        "a CollectiveExecutor has not been provided." ));
  }

  // Delegates group completion to the manager's param resolver.
  virtual void CompleteGroupAsync(const DeviceAttributes& device,
                                  CollGroupParams* group_params,
                                  CancellationManager* cancel_mgr,
                                  StatusCallback done) {
    return cem_->GetParamResolver()->CompleteGroupAsync(device, group_params,
                                                        cancel_mgr, done);
  }

  // Delegates group lookup to the manager's param resolver.
  virtual Status LookupGroup(int32_t group_key, CollGroupParams* group) {
    return cem_->GetParamResolver()->LookupGroup(group_key, group);
  }

  // Runs the potentially-blocking closure/expensive callback.
  virtual void RunClosure(std::function<void()> closure) = 0;

  // Returns the peer-data-exchange interface, or nullptr if unsupported.
  virtual CollectiveRemoteAccess* remote_access() { return nullptr; }

  // `WaitForDependencies` and `Launched` are used for fine-grained control of
  // execution order between collective instances. These functions are intended
  // to be called in `Run` function of collective implementations, and may be
  // used to make part, or whole, of the collective execution ordered with
  // respect to other collective instances.
  //
  // `WaitForDependencies` will block until it is safe to continue the callee's
  // execution, where safety is defined as: ordered with respect to the
  // collective instances defined in the callee's `wait_for` attribute.
  virtual void WaitForDependencies(const CollectiveParams& col_params) {}
  // `UnblockDependencies` unblocks the dependent collective instances by
  // recording that this caller's device has completed the critical portion of
  // the collective execution.
  virtual void UnblockDependencies(const CollectiveParams& col_params) {}

  // Used to designate an invalid group or instance key.
  static int64_t kInvalidId;

  // Lexically scoped handle for Ref. Takes one Ref on construction (unless
  // `inherit_ref` says the caller already holds it) and Unrefs on destruction.
  class Handle {
   public:
    explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) {
      if (!inherit_ref) ce->Ref();
    }
    ~Handle() { ce_->Unref(); }
    CollectiveExecutor* get() const { return ce_; }

   private:
    CollectiveExecutor* ce_;
  };

 protected:
  explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem)
      : cem_(cem) {}

  // For use only by derived classes
  static OpKernelContext::Params* CtxParams(OpKernelContext* ctx);
  // Non-owning pointer to the manager that created this executor.
  CollectiveExecutorMgrInterface* cem_;

  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveExecutor);
};
390 | |
// Bundles everything a CollectiveImplementation needs to execute one
// collective on one device within one step. All raw pointers are non-owning;
// the caller must keep them alive for the duration of the collective.
struct CollectiveContext {
  CollectiveExecutor* col_exec;                 // Not owned
  NcclCommunicatorInterface* nccl_communicator; // Not owned
  const DeviceMgr* dev_mgr;                     // Not owned
  OpKernelContext* op_ctx;                      // Not owned
  OpKernelContext::Params* op_params;           // Not owned
  // Shared, ref-counted params describing the collective.
  core::IntrusivePtr<const CollectiveParams> col_params;
  const string exec_key;   // key identifying this execution
  const int64_t step_id;   // step in which this collective runs
  const Tensor* input;     // Not owned
  Tensor* output;          // Not owned
  Device* device;          // The device for which this instance labors
  const string device_name;
  DeviceLocality device_locality;

  CollectiveContext(CollectiveExecutor* col_exec,
                    NcclCommunicatorInterface* nccl_communicator,
                    const DeviceMgr* dev_mgr, OpKernelContext* ctx,
                    OpKernelContext::Params* op_params,
                    const CollectiveParams* col_params, const string& exec_key,
                    int64_t step_id, const Tensor* input, Tensor* output);
};
413 | |
// Interface to an NCCL communicator used by NCCL-based collective
// implementations.
class NcclCommunicatorInterface {
 public:
  virtual ~NcclCommunicatorInterface() = default;

  // Generates a new, unique communicator key.
  virtual string GenerateCommunicatorKey() = 0;

  // Enqueues the collective described by `col_ctx` for execution; `done` is
  // invoked with the final status.
  virtual void Enqueue(std::shared_ptr<CollectiveContext> col_ctx,
                       StatusCallback done) = 0;

  // Aborts pending and future work with status `s`.
  virtual void StartAbort(const Status& s) = 0;
};
425 | |
426 | // Interface of a Collective Op implementation. Each specific CollectiveOp will |
427 | // implement this interface and register the implementation via the |
428 | // CollectiveRegistry detailed below. See common_runtime/ring_reducer and |
429 | // common_runtime/hierarchical_tree_broadcaster for examples. |
// Interface of a Collective Op implementation. Each specific CollectiveOp will
// implement this interface and register the implementation via the
// CollectiveRegistry detailed below. See common_runtime/ring_reducer and
// common_runtime/hierarchical_tree_broadcaster for examples.
class CollectiveImplementationInterface : public core::RefCounted {
 public:
  virtual ~CollectiveImplementationInterface() = default;

  // Initializes the portions of `col_params` specific to this
  // implementation. Called exactly once for every Collective instance during
  // the CollectiveParams resolution process when the graph is first executed,
  // at the end of `CompleteInstanceLocal()`.
  // NOTE(ayushd): This is effectively a static function because it modifies the
  // `col_params` passed in and should not manipulate any data members. However
  // because it is virtual and needs to be implemented by every derived class we
  // do not mark it as static.
  virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;

  // Prepares the CollectiveContext for executing this CollectiveImplementation.
  // Called from CollectiveExecutor right before calling Run(). The
  // CollectiveContext passed in must outlive the CollectiveImplementation
  // object.
  virtual Status InitializeCollectiveContext(
      std::shared_ptr<CollectiveContext> col_ctx) = 0;

  // Processes and moves data according to the logic of this Collective
  // implementation. Relies on appropriate initialization of op-specific
  // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
  // context initialization in InitializeCollectiveContext().
  virtual void Run(StatusCallback done) = 0;
};
457 | |
458 | // Static-methods only class for registering and looking up collective |
459 | // implementations. |
// Static-methods only class for registering and looking up collective
// implementations.
class CollectiveRegistry {
 public:
  // A Factory creates a new instance of a registered implementation.
  using Factory = std::function<CollectiveImplementationInterface*()>;
  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`. If found, creates an instance of the implementation and
  // assign to `implementation`.
  static Status Lookup(const string& collective_name,
                       CollectiveImplementationInterface** implementation);

  // Looks up a previously registered CollectiveImplementation under
  // `collective_name`. If found, returns the static instance of this
  // implementation via `implementation`. This instance should only be used to
  // call InitializeCollectiveParams.
  static Status LookupParamResolverInstance(
      const string& collective_name,
      CollectiveImplementationInterface** implementation);

  // Returns all registered collective implementations.
  static void GetAll(
      std::vector<CollectiveImplementationInterface*>* implementations);

 private:
  friend class CollectiveRegistration;
  // Registers a CollectiveImplementation with name `collective_name` and
  // factory `factory`. The latter is a function used to create instances of
  // the CollectiveImplementation. Also creates a static instance of the
  // implementation - this instance is used during param resolution and should
  // only be used to call InitializeCollectiveParams.
  static Status Register(const string& collective_name, Factory factory);

  // Shared helper for Lookup and LookupParamResolverInstance;
  // `param_resolver` selects which lookup behavior is used.
  static Status LookupHelper(const string& collective_name,
                             CollectiveImplementationInterface** implementation,
                             bool param_resolver);
};
494 | |
495 | // Class used to call CollectiveRegistry::Register. This should only be used to |
496 | // create a global static object. |
497 | class CollectiveRegistration { |
498 | public: |
499 | CollectiveRegistration(const string& collective_name, |
500 | CollectiveRegistry::Factory factory) { |
501 | TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory)); |
502 | } |
503 | }; |
504 | |
// Registers CollectiveImplementationInterface subclass `implementation` under
// the string `name` by defining a file-scope static CollectiveRegistration.
// Intended for use at namespace scope in an implementation (.cc) file.
#define REGISTER_COLLECTIVE(name, implementation) \
  static CollectiveRegistration register_##name##_collective( \
      #name, []() { return new implementation; });
508 | |
509 | } // namespace tensorflow |
510 | |
511 | #endif // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ |
512 | |