1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h"
16
17#include <algorithm>
18#include <queue>
19
20#include "absl/strings/str_cat.h"
21#include "tensorflow/core/common_runtime/optimize_cross_host_control_deps.h"
22#include "tensorflow/core/framework/node_def.pb.h"
23#include "tensorflow/core/framework/node_def_builder.h"
24#include "tensorflow/core/platform/errors.h"
25
26namespace tensorflow {
27namespace {
28
// Threshold handed to the cross-host control-edge optimization passes invoked
// at the end of ReplicatePerReplicaNodesInFunctionGraph; presumably the
// minimum number of cross-host control edges on a node before they are
// grouped — confirm against optimize_cross_host_control_deps.h.
// NOTE(review): the name contains a typo ("Theshold" -> "Threshold"). It is
// file-local (anonymous namespace), so a rename is safe, but every use below
// must be updated in the same change.
constexpr int kOptimizeCrossHostEdgesTheshold = 8;
30
31// A helper for rewriting nodes assigned to a virtual composite device.
32class ReplicateHelper {
33 public:
34 // Initialize replicated nodes with nullptr.
35 Status InitializeNode(const Node* node, int num_allowed_devices) {
36 if (replicated_nodes_map_.find(node) != replicated_nodes_map_.end()) {
37 return errors::InvalidArgument("Node ", node->name(),
38 " has been replicated.");
39 }
40 std::vector<Node*> replicated_nodes(num_allowed_devices, nullptr);
41 replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
42 return OkStatus();
43 }
44
45 // Replicate the given node to an allowed device.
46 Status ReplicateNode(const Node* node,
47 const std::vector<string>& allowed_devices,
48 int allowed_device_index, Graph* graph) {
49 auto& replicated_nodes = replicated_nodes_map_.at(node);
50 if (replicated_nodes[allowed_device_index] != nullptr) {
51 return OkStatus();
52 }
53 const auto& device = allowed_devices.at(allowed_device_index);
54 NodeDef node_def = node->def();
55 const string suffix = strings::StrCat("/R", allowed_device_index);
56 node_def.set_name(graph->NewName(strings::StrCat(node_def.name(), suffix)));
57 TF_ASSIGN_OR_RETURN(Node * replicated_node, graph->AddNode(node_def));
58 replicated_node->set_assigned_device_name(device);
59 if (replicated_node->IsArg()) {
60 replicated_node->AddAttr("sub_index", allowed_device_index);
61 }
62 replicated_nodes[allowed_device_index] = replicated_node;
63 return OkStatus();
64 }
65
66 // Replace an edge (a regular device -> composite device) with
67 // N edges (a regular device -> allowed devices).
68 void ReplicateFromRegularDeviceToCompositeDevice(const Edge* edge,
69 Graph* graph) const {
70 Node* src = edge->src();
71 const std::vector<Node*>& dst_replicated_nodes =
72 replicated_nodes_map_.at(edge->dst());
73 for (Node* dst : dst_replicated_nodes) {
74 // Skip a replicated dst node without any consumer.
75 if (dst == nullptr) {
76 continue;
77 }
78 graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
79 }
80 }
81
82 // Replace an edge (composite device -> composite device) with
83 // N edges (allowed devices -> allowed devices).
84 Status ReplicateFromCompositeDeviceToCompositeDevice(
85 const Edge* edge, const std::vector<string>& allowed_devices,
86 Graph* graph) {
87 const std::vector<Node*>& src_replicated_nodes =
88 replicated_nodes_map_.at(edge->src());
89 const std::vector<Node*>& dst_replicated_nodes =
90 replicated_nodes_map_.at(edge->dst());
91 if (src_replicated_nodes.size() != dst_replicated_nodes.size()) {
92 return errors::InvalidArgument(
93 "Nodes assigned to the same composite device should have the "
94 "same number of replicated nodes. Found an edge from node ",
95 edge->src()->name(), " (", src_replicated_nodes.size(),
96 " replicated nodes) to node ", edge->dst()->name(), " (",
97 dst_replicated_nodes.size(), " replicated nodes).");
98 }
99 for (int i = 0; i < src_replicated_nodes.size(); ++i) {
100 Node* dst = dst_replicated_nodes.at(i);
101 // Skip a replicated dst node without any consumer.
102 if (dst == nullptr) {
103 continue;
104 }
105 TF_RETURN_IF_ERROR(ReplicateNode(edge->src(), allowed_devices, i, graph));
106 graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
107 edge->dst_input());
108 }
109 return OkStatus();
110 }
111
112 // Data edge: replace an edge (composite device -> a regular device) with
113 // one edge (one allowed device -> a regular device).
114 // Control edge: replace an edge (composite device -> a regular device) with
115 // N edges (allowed devices -> a regular device).
116 Status ReplicateFromCompositeDeviceToRegularDevice(
117 const Edge* edge, const std::vector<string>& allowed_devices,
118 Graph* graph) {
119 const std::vector<Node*>& src_replicated_nodes =
120 replicated_nodes_map_.at(edge->src());
121 Node* dst = edge->dst();
122 const string& dst_device = dst->assigned_device_name();
123 bool found_src_node = false;
124 for (int i = 0; i < allowed_devices.size(); ++i) {
125 if (allowed_devices.at(i) == dst_device) {
126 TF_RETURN_IF_ERROR(
127 ReplicateNode(edge->src(), allowed_devices, i, graph));
128 graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
129 edge->dst_input());
130 found_src_node = true;
131 break;
132 }
133 }
134 if (!found_src_node) {
135 for (int i = 0; i < allowed_devices.size(); ++i) {
136 TF_RETURN_IF_ERROR(
137 ReplicateNode(edge->src(), allowed_devices, i, graph));
138 }
139 if (edge->IsControlEdge()) {
140 for (Node* replicated_node : src_replicated_nodes) {
141 // Duplication check in `Graph::AddControlEdge` is expensive for the
142 // dst node with a lot of input edges. Here each (src, dst) pair
143 // will only occur once so it is safe to skip the duplication check.
144 graph->AddControlEdge(replicated_node, dst,
145 /*allow_duplicates=*/true);
146 }
147 return OkStatus();
148 }
149 if (edge->src()->type_string() == "_Arg") {
150 // This happens when the dst node runs on a host CPU and
151 // captures a function with an arg node assigned to the same
152 // composite device (e.g. ScanDataset).
153 // For this case, we insert a PackOp between replicated nodes and the
154 // dst node. The dst node is responsible for unpacking the packed
155 // tensor.
156 // Add '/Packed' as a substring to the name of the new node, which
157 // could be helpful when debugging the graph.
158 NodeDefBuilder pack_builder(
159 graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
160 "Pack");
161 const int num_replicas = src_replicated_nodes.size();
162 pack_builder.Attr("N", num_replicas);
163 const DataType dtype = edge->src()->output_type(edge->src_output());
164 pack_builder.Attr("T", dtype);
165 std::vector<NodeDefBuilder::NodeOut> inputs;
166 inputs.reserve(src_replicated_nodes.size());
167 for (Node* replicated_node : src_replicated_nodes) {
168 inputs.emplace_back(NodeDefBuilder::NodeOut{
169 replicated_node->name(), edge->src_output(), dtype});
170 }
171 pack_builder.Input(inputs);
172 NodeDef pack_def;
173 TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
174 TF_ASSIGN_OR_RETURN(Node * pack_node, graph->AddNode(pack_def));
175 pack_node->set_assigned_device_name(dst->assigned_device_name());
176 for (int i = 0; i < src_replicated_nodes.size(); ++i) {
177 graph->AddEdge(src_replicated_nodes[i], edge->src_output(), pack_node,
178 i);
179 }
180 graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
181 } else {
182 return errors::InvalidArgument(
183 "Dst node should be assigned to an allowed device. Found an "
184 "edge from node ",
185 edge->src()->name(), " assigned to ",
186 edge->src()->assigned_device_name(), " to node ", dst->name(),
187 " assigned to ", dst_device);
188 }
189 }
190 return OkStatus();
191 }
192
193 private:
194 // Map from original nodes to corresponding replicated nodes.
195 absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
196};
197
// Replicates the nodes in `cluster_nodes` (all assigned to one composite
// device) and rewires their edges, deleting each original node once all of
// its consumers live on physical devices.
//
// `cluster_nodes` maps each original node to the number of its out edges
// that still point at a composite-device node; on success the map is
// emptied. The algorithm is a reverse-topological worklist: a node is only
// deleted (and its input edges rewired) after every consumer has been
// handled, so nodes whose consumers form a cycle within the cluster would
// remain in the map — the caller checks for leftovers.
Status ReplicateNodesAndEdges(const std::vector<string>& allowed_devices,
                              absl::flat_hash_map<Node*, int>* cluster_nodes,
                              ReplicateHelper* helper, Graph* graph) {
  // Contains nodes in cluster_nodes whose out nodes are all on physical
  // devices.
  std::queue<Node*> nodes_ready_to_delete;
  for (auto& pair : *cluster_nodes) {
    Node* node = pair.first;
    for (const Edge* edge : node->out_edges()) {
      Node* dst = edge->dst();
      if (dst->assigned_device_name() != node->assigned_device_name()) {
        // The dst node is assigned to a different device.
        TF_RETURN_IF_ERROR(helper->ReplicateFromCompositeDeviceToRegularDevice(
            edge, allowed_devices, graph));
        // One fewer composite-device consumer left to process for `node`.
        --pair.second;
      }
    }
    // Node is ready to delete when all its consumer nodes are assigned to a
    // physical device.
    if (cluster_nodes->at(node) == 0) {
      nodes_ready_to_delete.push(node);
    }
  }

  while (!nodes_ready_to_delete.empty()) {
    Node* node = nodes_ready_to_delete.front();
    nodes_ready_to_delete.pop();

    // Update input edges.
    for (const Edge* edge : node->in_edges()) {
      Node* src = edge->src();
      if (src->assigned_device_name() != node->assigned_device_name()) {
        // The source node is assigned to a different device.
        helper->ReplicateFromRegularDeviceToCompositeDevice(edge, graph);
      } else {
        // The source node is assigned to the same composite device.
        TF_RETURN_IF_ERROR(
            helper->ReplicateFromCompositeDeviceToCompositeDevice(
                edge, allowed_devices, graph));
        // This consumer of `src` is now handled; once `src` has no remaining
        // composite-device consumers it becomes deletable as well.
        if (--(*cluster_nodes)[src] == 0) {
          nodes_ready_to_delete.push(src);
        }
      }
    }

    // Remove the original node.
    cluster_nodes->erase(node);
    graph->RemoveNode(node);
  }
  return OkStatus();
}
250
251} // namespace
252
253Status ReplicatePerReplicaNodesInFunctionGraph(
254 const absl::flat_hash_map<string, const std::vector<string>*>&
255 composite_devices,
256 Graph* graph) {
257 std::set<string> composite_device_names;
258 for (const auto& it : composite_devices) {
259 composite_device_names.insert(it.first);
260 }
261 // Map from a composite device to a cluster of nodes assigned to the
262 // composite device and the numbers of their out edges to process.
263 absl::flat_hash_map<string, absl::flat_hash_map<Node*, int>>
264 composite_device_to_cluster_nodes;
265 for (Node* n : graph->op_nodes()) {
266 if (composite_device_names.find(n->assigned_device_name()) !=
267 composite_device_names.end()) {
268 // TODO(b/145922293): Validate that an _Arg node assigned to a
269 // CompositeDevice should have an attribute indicating that the _Arg node
270 // represents a packed input.
271 composite_device_to_cluster_nodes[n->assigned_device_name()].emplace(
272 n, n->out_edges().size());
273 }
274 }
275
276 if (composite_device_to_cluster_nodes.empty()) {
277 VLOG(1) << "No nodes with composiste device found.";
278 return OkStatus();
279 }
280
281 for (auto& it : composite_device_to_cluster_nodes) {
282 const std::vector<string>& allowed_devices =
283 *composite_devices.at(it.first);
284 if (allowed_devices.empty()) {
285 return errors::InvalidArgument("No allowed device of composite device: ",
286 it.first);
287 }
288 absl::flat_hash_map<Node*, int>& cluster_nodes = it.second;
289 if (allowed_devices.size() == 1) {
290 // Reuse the original nodes if there is only one allowed device.
291 for (const auto& pair : it.second) {
292 Node* n = pair.first;
293 n->set_assigned_device_name(allowed_devices.at(0));
294 if (n->IsArg()) {
295 n->AddAttr("sub_index", 0);
296 }
297 }
298 continue;
299 }
300 ReplicateHelper helper;
301 for (const auto& pair : cluster_nodes) {
302 TF_RETURN_IF_ERROR(
303 helper.InitializeNode(pair.first, allowed_devices.size()));
304 }
305
306 TF_RETURN_IF_ERROR(ReplicateNodesAndEdges(allowed_devices, &cluster_nodes,
307 &helper, graph));
308
309 if (!cluster_nodes.empty()) {
310 return errors::InvalidArgument(
311 "There are still ", cluster_nodes.size(),
312 " nodes on CompositiveDevice ",
313 cluster_nodes.begin()->first->assigned_device_name());
314 }
315 }
316
317 // Optimize cross host control output/input edges. We apply the optimizations
318 // at the end to reduce the newly created cross-host edges caused by
319 // per-replica nodes/edges replications.
320 TF_RETURN_IF_ERROR(OptimizeCrossHostControlOutputEdges(
321 graph, kOptimizeCrossHostEdgesTheshold));
322 TF_RETURN_IF_ERROR(OptimizeCrossHostControlInputEdges(
323 graph, kOptimizeCrossHostEdgesTheshold));
324
325 return OkStatus();
326}
327
328} // namespace tensorflow
329