1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/graph/collective_order.h" |
16 | |
17 | #include "absl/container/flat_hash_map.h" |
18 | #include "absl/container/flat_hash_set.h" |
19 | #include "tensorflow/core/graph/algorithm.h" |
20 | |
21 | namespace tensorflow { |
22 | namespace { |
23 | |
24 | // Find all CollectiveReduce nodes and the existing data dependencies between |
25 | // them. |
26 | Status DiscoverDataDependencies( |
27 | const Graph* graph, std::vector<Node*>* collective_nodes, |
28 | std::vector<int32>* instance_keys, |
29 | absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies) { |
30 | Status s; |
31 | // Algorithm: do Reverse DFS starting at sink. `node_leave` is called when |
32 | // all parents of `node` have been visited. At that point, |
33 | // `data_dependencies[node]` is a list containing `instance_key` of every |
34 | // `CollectiveReduce` on which `node` has a data dependency. |
35 | // For this node's children, add all these instance keys. Also, if this node |
36 | // is collective, add as a dependency for the children. |
37 | auto node_leave = [collective_nodes, instance_keys, data_dependencies, |
38 | &s](Node* node) { |
39 | int32_t instance_key; |
40 | bool enter_node = |
41 | node->IsCollective() && node->type_string() == "CollectiveReduce" ; |
42 | if (enter_node) { |
43 | Status get_attr_status = |
44 | GetNodeAttr(node->attrs(), "instance_key" , &instance_key); |
45 | s.Update(get_attr_status); |
46 | collective_nodes->push_back(node); |
47 | instance_keys->push_back(instance_key); |
48 | VLOG(2) << "collective node " << node->DebugString(); |
49 | } |
50 | // Avoid reference invalidation of `node_deps`. |
51 | data_dependencies->reserve(data_dependencies->size() + 1 + |
52 | node->out_edges().size()); |
53 | const auto& node_deps = (*data_dependencies)[node]; |
54 | for (const Edge* out_edge : node->out_edges()) { |
55 | auto& child_deps = (*data_dependencies)[out_edge->dst()]; |
56 | child_deps.insert(node_deps.begin(), node_deps.end()); |
57 | if (enter_node && s.ok()) { |
58 | child_deps.insert(instance_key); |
59 | } |
60 | } |
61 | }; |
62 | ReverseDFS(*graph, nullptr, node_leave); |
63 | return s; |
64 | } |
65 | |
66 | // Given a list of `collective_nodes` and `data_dependencies` between the |
67 | // collective nodes, create control dependencies between concurrent collectives |
68 | // and store in `dependency_edges`. |
69 | // If there exists an edge a -> b then `dependency_edges[a]` contains `b` |
70 | Status CreateControlDependencies( |
71 | const std::vector<Node*>& collective_nodes, |
72 | const std::vector<int32>& instance_keys, |
73 | absl::flat_hash_map<Node*, absl::flat_hash_set<int32>>* data_dependencies, |
74 | absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>* dependency_edges) { |
75 | // If there exists some path a -> ... -> b then `all_paths[a]` contains `b` |
76 | absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> all_paths; |
77 | for (int i = 0; i < collective_nodes.size() - 1; i++) { |
78 | if (!collective_nodes[i]->IsCollective() || |
79 | collective_nodes[i]->type_string() != "CollectiveReduce" ) { |
80 | return errors::Internal("Unexpected node " , |
81 | collective_nodes[i]->DebugString()); |
82 | } |
83 | const auto& deps_i = (*data_dependencies)[collective_nodes[i]]; |
84 | for (int j = i + 1; j < collective_nodes.size(); j++) { |
85 | if (collective_nodes[i]->requested_device() != |
86 | collective_nodes[j]->requested_device()) { |
87 | continue; |
88 | } |
89 | if (instance_keys[i] == instance_keys[j]) { |
90 | return errors::Internal("Unexpected same instance_key " , |
91 | instance_keys[i], |
92 | " on 2 nodes with the same device " , |
93 | collective_nodes[i]->requested_device()); |
94 | } |
95 | const auto& deps_j = (*data_dependencies)[collective_nodes[j]]; |
96 | if (deps_i.find(instance_keys[j]) == deps_i.end() && |
97 | deps_j.find(instance_keys[i]) == deps_j.end()) { |
98 | int src_idx = instance_keys[i] > instance_keys[j] ? i : j; |
99 | int dst_idx = instance_keys[i] > instance_keys[j] ? j : i; |
100 | Node* src_node = collective_nodes[src_idx]; |
101 | Node* dst_node = collective_nodes[dst_idx]; |
102 | VLOG(1) << "Adding control dependency from node " << src_node->name() |
103 | << " instance " << instance_keys[src_idx] << " to node " |
104 | << dst_node->name() << " instance " << instance_keys[dst_idx]; |
105 | (*dependency_edges)[src_node].insert(dst_node); |
106 | auto& src_paths = all_paths[src_node]; |
107 | src_paths.insert(dst_node); |
108 | for (Node* downstream_node : all_paths[dst_node]) { |
109 | src_paths.insert(downstream_node); |
110 | } |
111 | } |
112 | } |
113 | } |
114 | |
115 | // Prune dependency edges so that if there are edges a -> b, b -> c, and a -> |
116 | // c, then remove a -> c. This dependency would be handled naturally during |
117 | // op scheduling. |
118 | for (int i = 0; i < collective_nodes.size(); ++i) { |
119 | Node* node = collective_nodes[i]; |
120 | auto& neighbor_set = (*dependency_edges)[node]; |
121 | std::vector<Node*> neighbor_list(neighbor_set.begin(), neighbor_set.end()); |
122 | // For all n1, n2 in `neighbor_list` if there is a path from n1 -> n2 then |
123 | // eliminate n2 from `neighbor_set` and `neighbor_list`. We remove from |
124 | // `neighbor_list` by replacing with a `nullptr`, hence the `nullptr` checks |
125 | // below. |
126 | for (int j = 0; j < neighbor_list.size(); ++j) { |
127 | Node* n1 = neighbor_list[j]; |
128 | if (n1 == nullptr) continue; |
129 | auto& n1_paths = all_paths[n1]; |
130 | for (int k = 0; k < neighbor_list.size(); ++k) { |
131 | Node* n2 = neighbor_list[k]; |
132 | if (j == k || n2 == nullptr) continue; |
133 | if (n1_paths.find(n2) != n1_paths.end()) { |
134 | neighbor_set.erase(n2); |
135 | neighbor_list[k] = nullptr; |
136 | } |
137 | } |
138 | } |
139 | } |
140 | |
141 | return OkStatus(); |
142 | } |
143 | |
144 | // Insert control dependencies defined by `dependency_edges` in `graph`. If |
145 | // `order_type` is `kEdges`, insert explicit control edges, else if `order_type` |
146 | // is `kAttrs`, encode dependencies as an attribute on collective node. |
147 | Status InsertControlDependencies( |
148 | Graph* graph, GraphCollectiveOrder order_type, |
149 | const absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>>& |
150 | dependency_edges) { |
151 | if (order_type == GraphCollectiveOrder::kEdges) { |
152 | for (const auto& pair : dependency_edges) { |
153 | Node* src_node = pair.first; |
154 | for (Node* dst_node : pair.second) { |
155 | graph->AddControlEdge(src_node, dst_node); |
156 | } |
157 | } |
158 | } else if (order_type == GraphCollectiveOrder::kAttrs) { |
159 | // `wait_for` is the inverse of `dependency_edges`, i.e. `wait_for[node]` |
160 | // contains the list of instance keys for which `node` must wait. |
161 | absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> wait_for; |
162 | for (const auto& pair : dependency_edges) { |
163 | int32_t src_instance; |
164 | TF_RETURN_IF_ERROR( |
165 | GetNodeAttr(pair.first->attrs(), "instance_key" , &src_instance)); |
166 | for (Node* dst_node : pair.second) { |
167 | wait_for[dst_node].insert(src_instance); |
168 | } |
169 | } |
170 | for (const auto& pair : wait_for) { |
171 | std::vector<int32> wait_for_list(pair.second.begin(), pair.second.end()); |
172 | pair.first->ClearAttr("wait_for" ); |
173 | pair.first->AddAttr("wait_for" , wait_for_list); |
174 | } |
175 | } else { |
176 | return errors::Internal("Unexpected GraphCollectiveOrder type " , |
177 | static_cast<int>(order_type)); |
178 | } |
179 | return OkStatus(); |
180 | } |
181 | |
182 | } // namespace |
183 | |
184 | Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type) { |
185 | // `instance_keys[i]` corresponds to `collective_nodes[i]` |
186 | std::vector<Node*> collective_nodes; |
187 | std::vector<int32> instance_keys; |
188 | // node -> set of collectives on which node depends. |
189 | absl::flat_hash_map<Node*, absl::flat_hash_set<int32>> data_dependencies; |
190 | TF_RETURN_IF_ERROR(DiscoverDataDependencies( |
191 | graph, &collective_nodes, &instance_keys, &data_dependencies)); |
192 | |
193 | if (collective_nodes.empty()) return OkStatus(); |
194 | |
195 | absl::flat_hash_map<Node*, absl::flat_hash_set<Node*>> dependency_edges; |
196 | // For all pairs of collective nodes n1 and n2 on the same device, if n1 does |
197 | // not depend on n2 and n2 does not depend on n1, then they are potentially |
198 | // concurrent. Create an arbitrary, deterministic ordering between them. |
199 | TF_RETURN_IF_ERROR(CreateControlDependencies( |
200 | collective_nodes, instance_keys, &data_dependencies, &dependency_edges)); |
201 | |
202 | return InsertControlDependencies(graph, order_type, dependency_edges); |
203 | } |
204 | |
205 | } // namespace tensorflow |
206 | |