1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/partitioning_utils.h" |
16 | |
17 | #include <algorithm> |
18 | #include <functional> |
19 | #include <memory> |
20 | #include <optional> |
21 | #include <string> |
22 | #include <unordered_map> |
23 | #include <utility> |
24 | |
25 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
26 | #include "tensorflow/core/framework/function.h" |
27 | #include "tensorflow/core/framework/types.h" |
28 | #include "tensorflow/core/graph/graph.h" |
29 | #include "tensorflow/core/graph/graph_partition.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | namespace { |
34 | |
// A helper to partition a `graph` given a `device_set`.
// `partitions` maps device names to the GraphDef assigned to that device.
37 | Status PartitionFunctionGraph( |
38 | const DeviceSet& device_set, Graph* graph, |
39 | std::unordered_map<string, GraphDef>* partitions, |
40 | std::function<string(const Node*)> node_to_loc, |
41 | std::function<string(const Edge*)> get_tensor_name_attr) { |
42 | PartitionOptions partition_options; |
43 | if (node_to_loc != nullptr) { |
44 | partition_options.node_to_loc = node_to_loc; |
45 | } else { |
46 | partition_options.node_to_loc = [](const Node* node) { |
47 | // TODO(iga): To support the distributed case, first split the graph by |
      // worker (e.g., using the master session's `SplitByWorker` policy), and
49 | // then recursively partition the per-worker shards at the remote |
50 | // worker(s). Currently, we simply split the graph at device boundaries. |
51 | return node->assigned_device_name(); |
52 | }; |
53 | } |
54 | int64_t edge_name_counter = 0; |
55 | partition_options.new_name = [&edge_name_counter](const string& prefix) { |
    return strings::StrCat(prefix, "/_", ++edge_name_counter);
57 | }; |
58 | partition_options.get_incarnation = |
59 | [&device_set](const string& name) -> int64 { |
60 | const Device* d = device_set.FindDeviceByName(name); |
61 | if (d == nullptr) { |
62 | return PartitionOptions::kIllegalIncarnation; |
63 | } else { |
64 | return d->attributes().incarnation(); |
65 | } |
66 | }; |
67 | partition_options.control_flow_added = false; |
68 | partition_options.get_tensor_name_attr = get_tensor_name_attr; |
69 | |
70 | return Partition(partition_options, graph, partitions); |
71 | } |
72 | |
73 | // A pair of matching Send/Recv ops. |
74 | struct SendRecvPair { |
75 | Node* send_node = nullptr; |
76 | Node* recv_node = nullptr; |
77 | }; |
constexpr char kTensorNameAttr[] = "tensor_name";
79 | |
// Adds a control edge to each pair of matching Send/Recv ops to make the
// dependency between them explicit.
82 | Status MakeSendRecvDependencyExplicit(Graph* graph) { |
83 | // Find all matching Send/Recv pairs. |
84 | absl::flat_hash_map<std::string, SendRecvPair> send_recv_pairs; |
85 | for (Node* node : graph->op_nodes()) { |
86 | if (node->IsSend() || node->IsRecv()) { |
87 | auto tensor_name_it = node->def().attr().find(kTensorNameAttr); |
88 | if (tensor_name_it == node->def().attr().end()) { |
89 | return errors::Internal( |
90 | "'" , kTensorNameAttr, |
91 | "' attribute is not found from node: " , node->DebugString()); |
92 | } |
93 | if (node->IsSend()) { |
94 | send_recv_pairs[tensor_name_it->second.s()].send_node = node; |
95 | } else { |
96 | send_recv_pairs[tensor_name_it->second.s()].recv_node = node; |
97 | } |
98 | } |
99 | } |
100 | |
101 | // Add a control dependency to each pair of matching Send/Recv. |
102 | for (const auto& [tensor_name, send_recv_pair] : send_recv_pairs) { |
103 | if (send_recv_pair.send_node == nullptr || |
104 | send_recv_pair.recv_node == nullptr) { |
105 | return errors::Internal( |
106 | "No matching Send/Recv nodes found for tensor_name = " , tensor_name); |
107 | } |
108 | graph->AddControlEdge(send_recv_pair.send_node, send_recv_pair.recv_node); |
109 | } |
110 | return OkStatus(); |
111 | } |
112 | |
113 | } // namespace |
114 | |
115 | Status PartitionFunctionGraph( |
116 | const DeviceSet& device_set, std::unique_ptr<Graph> graph, |
117 | std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs, |
118 | std::function<string(const Edge*)> get_tensor_name_attr) { |
119 | std::unordered_map<string, GraphDef> partitions; |
120 | TF_RETURN_IF_ERROR( |
121 | PartitionFunctionGraph(device_set, graph.get(), &partitions, |
122 | /*node_to_loc=*/nullptr, get_tensor_name_attr)); |
123 | |
124 | for (auto& partition : partitions) { |
125 | const string& device = partition.first; |
126 | GraphDef& graph_def = partition.second; |
127 | // Each partition gets a new graph. |
    auto subgraph =
        std::make_unique<Graph>(graph->flib_def().default_registry());
130 | GraphConstructorOptions opts; |
131 | opts.allow_internal_ops = true; |
132 | opts.expect_device_spec = true; |
133 | TF_RETURN_IF_ERROR( |
134 | ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get())); |
135 | subgraphs->emplace(device, std::move(subgraph)); |
136 | } |
137 | |
138 | return OkStatus(); |
139 | } |
140 | |
141 | StatusOr<std::unique_ptr<Graph>> InsertTransferOps( |
142 | const DeviceSet& device_set, std::unique_ptr<Graph> graph) { |
143 | // Skip transfer op insertion if the graph nodes are not assigned to multiple |
144 | // devices. |
145 | auto node_to_loc = [](const Node* node) { |
146 | return node->assigned_device_name(); |
147 | }; |
148 | bool has_multiple_devices = false; |
  std::optional<std::string> location;
150 | for (const Node* node : graph->op_nodes()) { |
151 | if (location) { |
152 | if (*location != node_to_loc(node)) { |
153 | has_multiple_devices = true; |
154 | break; |
155 | } |
156 | } else { |
157 | location = node_to_loc(node); |
158 | } |
159 | } |
160 | if (!has_multiple_devices) { |
161 | return graph; |
162 | } |
163 | |
164 | // Transfer ops are needed as there are multiple devices, so proceed with the |
165 | // partitioning. |
166 | auto new_graph = std::make_unique<Graph>(graph->flib_def()); |
167 | |
168 | std::unordered_map<string, GraphDef> partitions; |
169 | TF_RETURN_IF_ERROR(PartitionFunctionGraph(device_set, graph.get(), |
170 | &partitions, node_to_loc, |
171 | /*get_tensor_name_attr=*/nullptr)); |
172 | |
173 | GraphDef merged_graph_def; |
174 | if (!partitions.empty()) { |
175 | auto iter = partitions.begin(); |
176 | merged_graph_def = std::move(iter->second); |
177 | while (++iter != partitions.end()) { |
178 | // TODO(b/220440252): MergeFrom() does memory copies when merging repeated |
179 | // fields. Ideally, we can merge repeated fields by 'moving' data. |
180 | // Consider using `proto2::util::MoveToEnd()` or so, once it is open |
181 | // sourced. |
182 | merged_graph_def.MergeFrom(iter->second); |
183 | } |
184 | } |
185 | |
186 | GraphConstructorOptions opts; |
187 | opts.allow_internal_ops = true; |
188 | opts.expect_device_spec = true; |
189 | TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(merged_graph_def), |
190 | new_graph.get())); |
191 | |
192 | TF_RETURN_IF_ERROR(MakeSendRecvDependencyExplicit(new_graph.get())); |
193 | |
194 | return std::move(new_graph); |
195 | } |
196 | |
197 | Status UpdateArgAndRetvalMetadata( |
198 | Graph* graph, std::vector<FunctionArgIndex>* arg_indices, |
199 | std::vector<int>* ret_indices, |
200 | std::vector<AllocatorAttributes>* arg_alloc_attrs, |
201 | std::vector<AllocatorAttributes>* ret_alloc_attrs, bool ints_on_device) { |
202 | std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes; |
203 | std::vector<std::pair<Node*, int>> ret_nodes; |
204 | const AttrValue* attr_value; |
205 | |
206 | // Find the Arg and Retval nodes, along with their corresponding indices |
207 | // in the original function. |
208 | for (Node* node : graph->op_nodes()) { |
209 | if (node->IsArg()) { |
210 | TF_RETURN_IF_ERROR(node->attrs().Find("index" , &attr_value)); |
211 | int index = static_cast<int>(attr_value->i()); |
212 | int sub_index = -1; |
213 | if (node->attrs().Find("sub_index" , &attr_value).ok()) { |
214 | sub_index = static_cast<int>(attr_value->i()); |
215 | } |
216 | arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index)); |
217 | } else if (node->IsRetval()) { |
218 | TF_RETURN_IF_ERROR(node->attrs().Find("index" , &attr_value)); |
219 | int index = static_cast<int>(attr_value->i()); |
220 | ret_nodes.emplace_back(node, index); |
221 | } |
222 | } |
223 | |
224 | // Sort the nodes by index so that the order is stable. |
225 | // |
226 | // In particular, this enables calling a single-partition function with |
227 | // the same signature as the original unpartitioned function. |
228 | auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a, |
229 | std::pair<Node*, FunctionArgIndex> b) { |
230 | return std::tie(a.second.index, a.second.sub_index) < |
231 | std::tie(b.second.index, b.second.sub_index); |
232 | }; |
233 | std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator); |
234 | auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) { |
235 | return a.second < b.second; |
236 | }; |
237 | std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator); |
238 | |
239 | arg_indices->reserve(arg_nodes.size()); |
240 | for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second); |
241 | ret_indices->reserve(ret_nodes.size()); |
242 | for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second); |
243 | |
244 | for (int i = 0; i < arg_nodes.size(); ++i) { |
245 | Node* arg = arg_nodes[i].first; |
246 | arg->AddAttr("index" , i); |
247 | TF_RETURN_IF_ERROR(arg->attrs().Find("T" , &attr_value)); |
248 | if (arg_alloc_attrs != nullptr) { |
249 | AllocatorAttributes alloc_attr; |
250 | DataType type = attr_value->type(); |
251 | MemoryType mtype = ints_on_device ? MTypeFromDTypeIntsOnDevice(type) |
252 | : MTypeFromDType(type); |
253 | if (mtype == HOST_MEMORY) { |
254 | alloc_attr.set_on_host(true); |
255 | } |
256 | arg_alloc_attrs->push_back(alloc_attr); |
257 | } |
258 | } |
259 | for (int i = 0; i < ret_nodes.size(); ++i) { |
260 | Node* ret = ret_nodes[i].first; |
261 | ret->AddAttr("index" , i); |
262 | TF_RETURN_IF_ERROR(ret->attrs().Find("T" , &attr_value)); |
    if (ret_alloc_attrs != nullptr) {
264 | AllocatorAttributes alloc_attr; |
265 | DataType type = attr_value->type(); |
266 | MemoryType mtype = ints_on_device ? MTypeFromDTypeIntsOnDevice(type) |
267 | : MTypeFromDType(type); |
268 | if (mtype == HOST_MEMORY) { |
269 | alloc_attr.set_on_host(true); |
270 | } |
271 | ret_alloc_attrs->push_back(alloc_attr); |
272 | } |
273 | } |
274 | |
275 | return OkStatus(); |
276 | } |
277 | |
278 | string FunctionNameGenerator::GetName() { |
279 | while (true) { |
280 | const string candidate = strings::StrCat(name_, "_" , counter_++); |
281 | if (flib_def_->Find(candidate) == nullptr) { |
282 | return candidate; |
283 | } |
284 | } |
285 | } |
286 | |
287 | } // namespace tensorflow |
288 | |