/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/graph/collective_order.h"

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/graph/algorithm.h"

namespace tensorflow {
namespace {

// Find all CollectiveReduce nodes and the existing data dependencies between
// them.
Status DiscoverDataDependencies(
    const Graph* graph, std::vector<Node*>* collective_nodes,
    std::vector<int32>* instance_keys,
    absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies) {
  Status s;
  // Algorithm: do a reverse DFS starting at the sink. `node_leave` is called
  // when all parents of `node` have been visited. At that point,
  // `data_dependencies[node]` is the set of `instance_key`s of every
  // `CollectiveReduce` on which `node` has a data dependency. Propagate all of
  // these instance keys to this node's children; in addition, if this node is
  // itself a collective, add its instance key as a dependency of the children.
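  // For illustration, consider a hypothetical chain (not from this file):
  //   c1 (CollectiveReduce, instance_key=7) -> id (Identity) ->
  //   c2 (CollectiveReduce, instance_key=8).
  // Nodes leave in the order c1, id, c2, so after the walk
  // `data_dependencies[id] == {7}` and `data_dependencies[c2] == {7}`, and
  // c2 has propagated both keys 7 and 8 to any children of its own.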
  auto node_leave = [collective_nodes, instance_keys, data_dependencies,
                     &s](Node* node) {
    int32_t instance_key;
    bool enter_node =
        node->IsCollective() && node->type_string() == "CollectiveReduce";
    if (enter_node) {
      Status get_attr_status =
          GetNodeAttr(node->attrs(), "instance_key", &instance_key);
      s.Update(get_attr_status);
      collective_nodes->push_back(node);
      instance_keys->push_back(instance_key);
      VLOG(2) << "collective node " << node->DebugString();
    }
    // Reserve space for this node and all of its children up front, so that
    // inserting entries for the children below cannot rehash the map and
    // invalidate the `node_deps` reference.
    data_dependencies->reserve(data_dependencies->size() + 1 +
                               node->out_edges().size());
    const auto& node_deps = (*data_dependencies)[node];
    for (const Edge* out_edge : node->out_edges()) {
      auto& child_deps = (*data_dependencies)[out_edge->dst()];
      child_deps.insert(node_deps.begin(), node_deps.end());
      if (enter_node && s.ok()) {
        child_deps.insert(instance_key);
      }
    }
  };
  ReverseDFS(*graph, nullptr, node_leave);
  return s;
}

// Given a list of `collective_nodes` and the `data_dependencies` between the
// collective nodes, create control dependencies between concurrent collectives
// and store them in `dependency_edges`.
// If there exists an edge a -> b, then `dependency_edges[a]` contains `b`.
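// For illustration (hypothetical): two CollectiveReduce nodes on
// "/device:GPU:0" with instance keys 5 and 9 and no data path between them
// get an edge 9 -> 5, i.e. the collective with the smaller instance key is
// made to wait on the one with the larger key.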
Status CreateControlDependencies(
    const std::vector<Node*>& collective_nodes,
    const std::vector<int32>& instance_keys,
    absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies,
    absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>* dependency_edges) {
  // If there exists some path a -> ... -> b, then `all_paths[a]` contains `b`.
  absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> all_paths;
  for (int i = 0; i < collective_nodes.size() - 1; i++) {
    if (!collective_nodes[i]->IsCollective() ||
        collective_nodes[i]->type_string() != "CollectiveReduce") {
      return errors::Internal("Unexpected node ",
                              collective_nodes[i]->DebugString());
    }
    const auto& deps_i = (*data_dependencies)[collective_nodes[i]];
    for (int j = i + 1; j < collective_nodes.size(); j++) {
      if (collective_nodes[i]->requested_device() !=
          collective_nodes[j]->requested_device()) {
        continue;
      }
      if (instance_keys[i] == instance_keys[j]) {
        return errors::Internal("Unexpected same instance_key ",
                                instance_keys[i],
                                " on 2 nodes with the same device ",
                                collective_nodes[i]->requested_device());
      }
      const auto& deps_j = (*data_dependencies)[collective_nodes[j]];
      if (deps_i.find(instance_keys[j]) == deps_i.end() &&
          deps_j.find(instance_keys[i]) == deps_j.end()) {
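        // Neither node depends on the other, so they could run concurrently.
        // Order them deterministically: the collective with the larger
        // instance key becomes the source, so the smaller-key collective
        // waits on it.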
        int src_idx = instance_keys[i] > instance_keys[j] ? i : j;
        int dst_idx = instance_keys[i] > instance_keys[j] ? j : i;
        Node* src_node = collective_nodes[src_idx];
        Node* dst_node = collective_nodes[dst_idx];
        VLOG(1) << "Adding control dependency from node " << src_node->name()
                << " instance " << instance_keys[src_idx] << " to node "
                << dst_node->name() << " instance " << instance_keys[dst_idx];
        (*dependency_edges)[src_node].insert(dst_node);
        auto& src_paths = all_paths[src_node];
        src_paths.insert(dst_node);
        for (Node* downstream_node : all_paths[dst_node]) {
          src_paths.insert(downstream_node);
        }
      }
    }
  }

  // Prune dependency edges so that if there are edges a -> b, b -> c, and
  // a -> c, then remove a -> c. This dependency would be handled naturally
  // during op scheduling.
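  // For illustration (hypothetical): if `dependency_edges[a] == {b, c}` and
  // `all_paths[b]` contains `c`, the loop below erases `c` from a's neighbor
  // set, leaving only the direct edge a -> b.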
  for (int i = 0; i < collective_nodes.size(); ++i) {
    Node* node = collective_nodes[i];
    auto& neighbor_set = (*dependency_edges)[node];
    std::vector<Node*> neighbor_list(neighbor_set.begin(), neighbor_set.end());
    // For all n1, n2 in `neighbor_list`, if there is a path n1 -> n2, then
    // eliminate n2 from `neighbor_set` and `neighbor_list`. We remove from
    // `neighbor_list` by replacing with a `nullptr`, hence the `nullptr`
    // checks below.
    for (int j = 0; j < neighbor_list.size(); ++j) {
      Node* n1 = neighbor_list[j];
      if (n1 == nullptr) continue;
      auto& n1_paths = all_paths[n1];
      for (int k = 0; k < neighbor_list.size(); ++k) {
        Node* n2 = neighbor_list[k];
        if (j == k || n2 == nullptr) continue;
        if (n1_paths.find(n2) != n1_paths.end()) {
          neighbor_set.erase(n2);
          neighbor_list[k] = nullptr;
        }
      }
    }
  }

  return OkStatus();
}

// Insert control dependencies defined by `dependency_edges` in `graph`. If
// `order_type` is `kEdges`, insert explicit control edges; if `order_type` is
// `kAttrs`, encode the dependencies as an attribute on each collective node.
Status InsertControlDependencies(
    Graph* graph, GraphCollectiveOrder order_type,
    const absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>&
        dependency_edges) {
  if (order_type == GraphCollectiveOrder::kEdges) {
    for (const auto& pair : dependency_edges) {
      Node* src_node = pair.first;
      for (Node* dst_node : pair.second) {
        graph->AddControlEdge(src_node, dst_node);
      }
    }
  } else if (order_type == GraphCollectiveOrder::kAttrs) {
    // `wait_for` is the inverse of `dependency_edges`, i.e. `wait_for[node]`
    // contains the set of instance keys on which `node` must wait.
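    // For illustration (hypothetical): with a single dependency edge
    // A (instance_key=9) -> B (instance_key=5), `wait_for[B]` becomes {9} and
    // B ends up with the attr `wait_for = [9]`.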
    absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> wait_for;
    for (const auto& pair : dependency_edges) {
      int32_t src_instance;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(pair.first->attrs(), "instance_key", &src_instance));
      for (Node* dst_node : pair.second) {
        wait_for[dst_node].insert(src_instance);
      }
    }
    for (const auto& pair : wait_for) {
      std::vector<int32> wait_for_list(pair.second.begin(), pair.second.end());
      pair.first->ClearAttr("wait_for");
      pair.first->AddAttr("wait_for", wait_for_list);
    }
  } else {
    return errors::Internal("Unexpected GraphCollectiveOrder type ",
                            static_cast<int>(order_type));
  }
  return OkStatus();
}

}  // namespace

Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type) {
  // `instance_keys[i]` corresponds to `collective_nodes[i]`.
  std::vector<Node*> collective_nodes;
  std::vector<int32> instance_keys;
  // node -> set of collectives on which node depends.
  absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> data_dependencies;
  TF_RETURN_IF_ERROR(DiscoverDataDependencies(
      graph, &collective_nodes, &instance_keys, &data_dependencies));

  if (collective_nodes.empty()) return OkStatus();

  absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> dependency_edges;
  // For all pairs of collective nodes n1 and n2 on the same device, if n1 does
  // not depend on n2 and n2 does not depend on n1, then they are potentially
  // concurrent. Create an arbitrary, deterministic ordering between them.
  TF_RETURN_IF_ERROR(CreateControlDependencies(
      collective_nodes, instance_keys, &data_dependencies, &dependency_edges));

  return InsertControlDependencies(graph, order_type, dependency_edges);
}
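
// Usage sketch (hypothetical caller; the real call sites live in TF's graph
// optimization passes, not in this file):
//
//   Graph* graph = ...;  // contains CollectiveReduce nodes
//   TF_RETURN_IF_ERROR(OrderCollectives(graph, GraphCollectiveOrder::kEdges));
//
// With `kEdges`, concurrent collectives on the same device are serialized by
// explicit control edges; with `kAttrs`, each deferred collective instead
// carries a `wait_for` attr listing the instance keys it must wait on.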

}  // namespace tensorflow