1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/partitioning_utils.h"
16
17#include <algorithm>
18#include <functional>
19#include <memory>
20#include <optional>
21#include <string>
22#include <unordered_map>
23#include <utility>
24
25#include "tensorflow/core/common_runtime/graph_constructor.h"
26#include "tensorflow/core/framework/function.h"
27#include "tensorflow/core/framework/types.h"
28#include "tensorflow/core/graph/graph.h"
29#include "tensorflow/core/graph/graph_partition.h"
30
31namespace tensorflow {
32
33namespace {
34
35// A helper to partiton a `graph` given a `device_set` and a `graph`.
36// `partitions` maps device names to the graphdef assigned to that device.
37Status PartitionFunctionGraph(
38 const DeviceSet& device_set, Graph* graph,
39 std::unordered_map<string, GraphDef>* partitions,
40 std::function<string(const Node*)> node_to_loc,
41 std::function<string(const Edge*)> get_tensor_name_attr) {
42 PartitionOptions partition_options;
43 if (node_to_loc != nullptr) {
44 partition_options.node_to_loc = node_to_loc;
45 } else {
46 partition_options.node_to_loc = [](const Node* node) {
47 // TODO(iga): To support the distributed case, first split the graph by
48 // worker (e.g,. using the master session's `SplitByWorker` policy), and
49 // then recursively partition the per-worker shards at the remote
50 // worker(s). Currently, we simply split the graph at device boundaries.
51 return node->assigned_device_name();
52 };
53 }
54 int64_t edge_name_counter = 0;
55 partition_options.new_name = [&edge_name_counter](const string& prefix) {
56 return strings::StrCat(prefix, "/_", ++edge_name_counter);
57 };
58 partition_options.get_incarnation =
59 [&device_set](const string& name) -> int64 {
60 const Device* d = device_set.FindDeviceByName(name);
61 if (d == nullptr) {
62 return PartitionOptions::kIllegalIncarnation;
63 } else {
64 return d->attributes().incarnation();
65 }
66 };
67 partition_options.control_flow_added = false;
68 partition_options.get_tensor_name_attr = get_tensor_name_attr;
69
70 return Partition(partition_options, graph, partitions);
71}
72
// A pair of matching Send/Recv ops, keyed by their shared tensor name.
// Either pointer may be null while the pair is being populated; a complete
// pair has both set.
struct SendRecvPair {
  Node* send_node = nullptr;  // The Send side of the transfer, if found.
  Node* recv_node = nullptr;  // The Recv side of the transfer, if found.
};
// Attribute on Send/Recv nodes that names the transferred tensor; matching
// Send/Recv ops carry the same value.
constexpr char kTensorNameAttr[] = "tensor_name";
79
80// Adds a dependency to each pair of matching Send/Recv ops to make the
81// dependency explicit.
82Status MakeSendRecvDependencyExplicit(Graph* graph) {
83 // Find all matching Send/Recv pairs.
84 absl::flat_hash_map<std::string, SendRecvPair> send_recv_pairs;
85 for (Node* node : graph->op_nodes()) {
86 if (node->IsSend() || node->IsRecv()) {
87 auto tensor_name_it = node->def().attr().find(kTensorNameAttr);
88 if (tensor_name_it == node->def().attr().end()) {
89 return errors::Internal(
90 "'", kTensorNameAttr,
91 "' attribute is not found from node: ", node->DebugString());
92 }
93 if (node->IsSend()) {
94 send_recv_pairs[tensor_name_it->second.s()].send_node = node;
95 } else {
96 send_recv_pairs[tensor_name_it->second.s()].recv_node = node;
97 }
98 }
99 }
100
101 // Add a control dependency to each pair of matching Send/Recv.
102 for (const auto& [tensor_name, send_recv_pair] : send_recv_pairs) {
103 if (send_recv_pair.send_node == nullptr ||
104 send_recv_pair.recv_node == nullptr) {
105 return errors::Internal(
106 "No matching Send/Recv nodes found for tensor_name = ", tensor_name);
107 }
108 graph->AddControlEdge(send_recv_pair.send_node, send_recv_pair.recv_node);
109 }
110 return OkStatus();
111}
112
113} // namespace
114
115Status PartitionFunctionGraph(
116 const DeviceSet& device_set, std::unique_ptr<Graph> graph,
117 std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs,
118 std::function<string(const Edge*)> get_tensor_name_attr) {
119 std::unordered_map<string, GraphDef> partitions;
120 TF_RETURN_IF_ERROR(
121 PartitionFunctionGraph(device_set, graph.get(), &partitions,
122 /*node_to_loc=*/nullptr, get_tensor_name_attr));
123
124 for (auto& partition : partitions) {
125 const string& device = partition.first;
126 GraphDef& graph_def = partition.second;
127 // Each partition gets a new graph.
128 std::unique_ptr<Graph> subgraph(
129 new Graph(graph->flib_def().default_registry()));
130 GraphConstructorOptions opts;
131 opts.allow_internal_ops = true;
132 opts.expect_device_spec = true;
133 TF_RETURN_IF_ERROR(
134 ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
135 subgraphs->emplace(device, std::move(subgraph));
136 }
137
138 return OkStatus();
139}
140
141StatusOr<std::unique_ptr<Graph>> InsertTransferOps(
142 const DeviceSet& device_set, std::unique_ptr<Graph> graph) {
143 // Skip transfer op insertion if the graph nodes are not assigned to multiple
144 // devices.
145 auto node_to_loc = [](const Node* node) {
146 return node->assigned_device_name();
147 };
148 bool has_multiple_devices = false;
149 absl::optional<std::string> location;
150 for (const Node* node : graph->op_nodes()) {
151 if (location) {
152 if (*location != node_to_loc(node)) {
153 has_multiple_devices = true;
154 break;
155 }
156 } else {
157 location = node_to_loc(node);
158 }
159 }
160 if (!has_multiple_devices) {
161 return graph;
162 }
163
164 // Transfer ops are needed as there are multiple devices, so proceed with the
165 // partitioning.
166 auto new_graph = std::make_unique<Graph>(graph->flib_def());
167
168 std::unordered_map<string, GraphDef> partitions;
169 TF_RETURN_IF_ERROR(PartitionFunctionGraph(device_set, graph.get(),
170 &partitions, node_to_loc,
171 /*get_tensor_name_attr=*/nullptr));
172
173 GraphDef merged_graph_def;
174 if (!partitions.empty()) {
175 auto iter = partitions.begin();
176 merged_graph_def = std::move(iter->second);
177 while (++iter != partitions.end()) {
178 // TODO(b/220440252): MergeFrom() does memory copies when merging repeated
179 // fields. Ideally, we can merge repeated fields by 'moving' data.
180 // Consider using `proto2::util::MoveToEnd()` or so, once it is open
181 // sourced.
182 merged_graph_def.MergeFrom(iter->second);
183 }
184 }
185
186 GraphConstructorOptions opts;
187 opts.allow_internal_ops = true;
188 opts.expect_device_spec = true;
189 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(merged_graph_def),
190 new_graph.get()));
191
192 TF_RETURN_IF_ERROR(MakeSendRecvDependencyExplicit(new_graph.get()));
193
194 return std::move(new_graph);
195}
196
197Status UpdateArgAndRetvalMetadata(
198 Graph* graph, std::vector<FunctionArgIndex>* arg_indices,
199 std::vector<int>* ret_indices,
200 std::vector<AllocatorAttributes>* arg_alloc_attrs,
201 std::vector<AllocatorAttributes>* ret_alloc_attrs, bool ints_on_device) {
202 std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
203 std::vector<std::pair<Node*, int>> ret_nodes;
204 const AttrValue* attr_value;
205
206 // Find the Arg and Retval nodes, along with their corresponding indices
207 // in the original function.
208 for (Node* node : graph->op_nodes()) {
209 if (node->IsArg()) {
210 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
211 int index = static_cast<int>(attr_value->i());
212 int sub_index = -1;
213 if (node->attrs().Find("sub_index", &attr_value).ok()) {
214 sub_index = static_cast<int>(attr_value->i());
215 }
216 arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
217 } else if (node->IsRetval()) {
218 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
219 int index = static_cast<int>(attr_value->i());
220 ret_nodes.emplace_back(node, index);
221 }
222 }
223
224 // Sort the nodes by index so that the order is stable.
225 //
226 // In particular, this enables calling a single-partition function with
227 // the same signature as the original unpartitioned function.
228 auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
229 std::pair<Node*, FunctionArgIndex> b) {
230 return std::tie(a.second.index, a.second.sub_index) <
231 std::tie(b.second.index, b.second.sub_index);
232 };
233 std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
234 auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
235 return a.second < b.second;
236 };
237 std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);
238
239 arg_indices->reserve(arg_nodes.size());
240 for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
241 ret_indices->reserve(ret_nodes.size());
242 for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);
243
244 for (int i = 0; i < arg_nodes.size(); ++i) {
245 Node* arg = arg_nodes[i].first;
246 arg->AddAttr("index", i);
247 TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
248 if (arg_alloc_attrs != nullptr) {
249 AllocatorAttributes alloc_attr;
250 DataType type = attr_value->type();
251 MemoryType mtype = ints_on_device ? MTypeFromDTypeIntsOnDevice(type)
252 : MTypeFromDType(type);
253 if (mtype == HOST_MEMORY) {
254 alloc_attr.set_on_host(true);
255 }
256 arg_alloc_attrs->push_back(alloc_attr);
257 }
258 }
259 for (int i = 0; i < ret_nodes.size(); ++i) {
260 Node* ret = ret_nodes[i].first;
261 ret->AddAttr("index", i);
262 TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
263 if (ret_alloc_attrs) {
264 AllocatorAttributes alloc_attr;
265 DataType type = attr_value->type();
266 MemoryType mtype = ints_on_device ? MTypeFromDTypeIntsOnDevice(type)
267 : MTypeFromDType(type);
268 if (mtype == HOST_MEMORY) {
269 alloc_attr.set_on_host(true);
270 }
271 ret_alloc_attrs->push_back(alloc_attr);
272 }
273 }
274
275 return OkStatus();
276}
277
278string FunctionNameGenerator::GetName() {
279 while (true) {
280 const string candidate = strings::StrCat(name_, "_", counter_++);
281 if (flib_def_->Find(candidate) == nullptr) {
282 return candidate;
283 }
284 }
285}
286
287} // namespace tensorflow
288