/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_

#include <vector>

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"

namespace tensorflow {

// Hierarchical tree-algorithm implementation of collective broadcast.
class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
 public:
  HierarchicalTreeBroadcaster();
  ~HierarchicalTreeBroadcaster() override = default;

  // Establishes the subdiv permutations needed for a hierarchical broadcast.
  // If all devices belong to a single task, establishes one subdiv comprising
  // all devices.  If the devices span n > 1 tasks, establishes n+1 subdivs:
  // the first subdiv comprises one device per task, which receives the tensor
  // on behalf of that task, and subdiv i+1 corresponds to a task-local
  // tree-broadcast for task i.
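  //
  // Illustrative sketch only (the device layout below is hypothetical, not
  // part of this interface): with 2 tasks of 2 devices each, 3 subdivs would
  // be established, e.g.
  //   subdiv 0: {task0:dev0, task1:dev0}   // one device per task
  //   subdiv 1: {task0:dev0, task0:dev1}   // task-local broadcast for task 0
  //   subdiv 2: {task1:dev0, task1:dev1}   // task-local broadcast for task 1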
  Status InitializeCollectiveParams(CollectiveParams* col_params) override;

  // Initializes members of CollectiveContext not yet initialized, i.e. device
  // and device_locality.  Also saves the CollectiveContext in this object.
  Status InitializeCollectiveContext(
      std::shared_ptr<CollectiveContext> col_ctx) override;

  // Begins async execution of the hierarchical tree broadcast.
  // Must be called in a blockable thread.
  // TODO(b/80529858): remove the previous warning when we have a dedicated
  // collective threadpool.
  void Run(StatusCallback done) override;

  // Returns the rank of the device from which this device should receive
  // its value, or -1 if no value should be received.
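  //
  // Illustrative sketch only (it assumes a binary tree rooted at the source
  // rank; the actual tree shape is determined by the implementation): in a
  // subdiv of ranks 0..6 with source rank 0, rank 3 would receive from
  // rank 1, and the source itself would get -1.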
  static int TreeRecvFrom(const CollectiveParams& cp, int subdiv);

  // Populates `targets` with the ranks of the devices to which this device
  // should forward the value.
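  //
  // Continuing the illustrative binary-tree sketch above (an assumption, not
  // stated by this header): with ranks 0..6 and source rank 0, rank 1 would
  // forward to ranks 3 and 4, while leaf ranks 3..6 would get an empty
  // `targets`.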
  static void TreeSendTo(const CollectiveParams& cp, int subdiv,
                         std::vector<int>* targets);

 private:
  // Returns the task to which the device at `device_rank` belongs.
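  //
  // Illustrative only, assuming `dev_per_task` lists the number of devices
  // on each task (an interpretation, not stated by this header): with
  // dev_per_task = {2, 3}, device ranks 0-1 would map to task 0 and ranks
  // 2-4 to task 1.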
  int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task);

  // Sends `src_tensor` asynchronously from this device to the device at
  // `dst_rank` in `subdiv`.  Calls `done` upon completion.
  void DispatchSend(int subdiv, int dst_rank, int src_rank,
                    const Tensor* src_tensor, const StatusCallback& done);

  // Receives a tensor into the memory buffer owned by `dst_tensor` at this
  // device from the device at `src_rank` in `subdiv`.  Calls `done` upon
  // completion.
  void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor,
                    const StatusCallback& done);

  // Executes the hierarchical broadcast defined by this op.
  void RunTree();

  std::shared_ptr<CollectiveContext> col_ctx_;
  const CollectiveParams* col_params_;  // Not owned
  StatusCallback done_;
  Status status_;
  bool is_source_;
};

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_