1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ |
17 | |
18 | #include <vector> |
19 | |
20 | #include "tensorflow/core/common_runtime/base_collective_executor.h" |
21 | #include "tensorflow/core/framework/collective.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | // Hierarchical tree-algorithm implementation of collective broadcast. |
// Hierarchical tree-algorithm implementation of collective broadcast.
//
// The broadcast is organized as a tree of "subdivisions": one cross-task
// subdiv that moves the value between tasks, followed by one task-local
// subdiv per task that fans the value out to that task's devices (see
// InitializeCollectiveParams below for the exact layout).
class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
 public:
  HierarchicalTreeBroadcaster();
  ~HierarchicalTreeBroadcaster() override = default;

  // Establishes the subdiv permutations needed for a hierarchical broadcast.
  // If all devices are local, establishes a single subdiv comprising all
  // devices. If any devices are on a different task, establishes n+1 subdivs
  // for n tasks.
  // The first subdiv comprises one device per task which gets the tensor on
  // each task. Subdiv i+1 corresponds to a task-local tree-broadcast for task
  // i.
  Status InitializeCollectiveParams(CollectiveParams* col_params) override;

  // Initializes members of CollectiveContext not yet initialized, i.e. device
  // and device_locality. Also saves the CollectiveContext in this object.
  Status InitializeCollectiveContext(
      std::shared_ptr<CollectiveContext> col_ctx) override;

  // Begins async execution of the hierarchical tree broadcast.
  // Must be called in a blockable thread.
  // TODO(b/80529858): remove the previous warning when we have a dedicated
  // collective threadpool.
  void Run(StatusCallback done) override;

  // Returns the rank of the device from which this device should receive
  // its value within `subdiv`, or -1 if no value should be received (e.g.
  // this device is the subdiv's source).
  static int TreeRecvFrom(const CollectiveParams& cp, int subdiv);

  // Populates `targets` with the ranks of the devices within `subdiv` to
  // which this device should forward the value.
  static void TreeSendTo(const CollectiveParams& cp, int subdiv,
                         std::vector<int>* targets);

 private:
  // Get the task to which the device at `device_rank` belongs.
  // `dev_per_task` holds the number of devices on each task, indexed by task.
  int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task);

  // Sends `src_tensor` asynchronously from this device (rank `src_rank`) to
  // the device at `dst_rank` in `subdiv`. Calls `done` upon completion.
  void DispatchSend(int subdiv, int dst_rank, int src_rank,
                    const Tensor* src_tensor, const StatusCallback& done);

  // Receives a tensor into the memory buffer owned by `dst_tensor` at this
  // device (rank `dst_rank`) from the device at `src_rank` in `subdiv`.
  // Calls `done` upon completion.
  void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor,
                    const StatusCallback& done);

  // Executes the hierarchical broadcast defined by this op.
  void RunTree();

  // Collective context saved by InitializeCollectiveContext.
  std::shared_ptr<CollectiveContext> col_ctx_;
  const CollectiveParams* col_params_;  // Not owned
  // Completion callback handed to Run(); invoked when the broadcast finishes.
  StatusCallback done_;
  // Accumulated status of the in-flight broadcast.
  Status status_;
  // Presumably true iff this device is the broadcast source — set during
  // initialization; confirm against the .cc implementation.
  bool is_source_;
};
84 | |
85 | } // namespace tensorflow |
86 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ |
87 | |