1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/device_propagation.h" |
17 | |
18 | #include <string> |
19 | #include <utility> |
20 | |
21 | #include "absl/container/flat_hash_set.h" |
22 | #include "tensorflow/core/framework/types.h" |
23 | #include "tensorflow/core/graph/algorithm.h" |
24 | #include "tensorflow/core/graph/graph.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | namespace { |
29 | |
30 | const std::string& AssignedOrRequestedDevice(const Node& node) { |
31 | if (!node.assigned_device_name().empty()) { |
32 | return node.assigned_device_name(); |
33 | } |
34 | return node.requested_device(); |
35 | } |
36 | |
37 | void UpdateDeviceFromInputs( |
38 | const device_propagation::NodeFilter& node_filter, |
39 | const device_propagation::DeviceFilter& device_filter, Node* node) { |
40 | if (!AssignedOrRequestedDevice(*node).empty() || !node_filter(*node)) { |
41 | return; |
42 | } |
43 | string proposed_device = "" ; |
44 | Node* proposed_src = nullptr; |
45 | // Scan the input edges, propagate device assignment from its inputs to this |
46 | // node iff all input nodes has the same device assignment and the device is |
47 | // propagatable (checked by `device_filter`). Some kinds of edges are |
48 | // ignored. |
49 | for (const Edge* e : node->in_edges()) { |
50 | // Ignore control edge. |
51 | if (e->IsControlEdge()) { |
52 | continue; |
53 | } |
54 | Node* src = e->src(); |
55 | const string& src_device = AssignedOrRequestedDevice(*src); |
56 | |
57 | // Ignore LoopCond -> Switch and Enter -> Merge. In other words, the device |
58 | // placement of a Switch op is determined by all its non-LoopCond inputs and |
59 | // that of a Merge op is determined by all its non-Enter inputs. |
60 | if ((node->IsSwitch() && src->IsLoopCond()) || |
61 | (node->IsMerge() && src->IsEnter())) { |
62 | continue; |
63 | } |
64 | |
65 | // If a source device is not propagatable, stop. |
66 | if (!device_filter(src_device)) return; |
67 | |
68 | if (proposed_src == nullptr) { |
69 | proposed_device = src_device; |
70 | proposed_src = src; |
71 | } else if (proposed_device != src_device) { |
72 | // The device assignments of some input nodes are not the same. Stop. |
73 | return; |
74 | } |
75 | } |
76 | if (proposed_src) { |
77 | node->set_assigned_device_name(proposed_src->assigned_device_name()); |
78 | node->set_requested_device(proposed_src->requested_device()); |
79 | } |
80 | } |
81 | |
82 | } // namespace |
83 | |
84 | void PropagateDevices(const device_propagation::NodeFilter& node_filter, |
85 | const device_propagation::DeviceFilter& device_filter, |
86 | Graph* graph) { |
87 | ReverseDFS(*graph, {}, [&node_filter, &device_filter](Node* node) { |
88 | UpdateDeviceFromInputs(node_filter, device_filter, node); |
89 | }); |
90 | } |
91 | |
92 | } // namespace tensorflow |
93 | |