1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/inspecting_placer.h" |
16 | |
17 | #include <memory> |
18 | #include <unordered_map> |
19 | #include <vector> |
20 | |
21 | #include "absl/strings/str_join.h" |
22 | #include "tensorflow/core/common_runtime/colocation_graph.h" |
23 | #include "tensorflow/core/common_runtime/device.h" |
24 | #include "tensorflow/core/common_runtime/function_body.h" |
25 | #include "tensorflow/core/common_runtime/function_def_utils.h" |
26 | #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h" |
27 | #include "tensorflow/core/framework/function.h" |
28 | #include "tensorflow/core/framework/node_def_util.h" |
29 | #include "tensorflow/core/framework/types.h" |
30 | #include "tensorflow/core/graph/graph_node_util.h" |
31 | #include "tensorflow/core/lib/core/errors.h" |
32 | |
33 | namespace tensorflow { |
34 | |
35 | string IOColocationGroups::DebugString() const { |
36 | std::unordered_map<int, std::vector<string>> group_members; |
37 | for (int arg_index = 0; arg_index < input_groups.size(); ++arg_index) { |
38 | int group_id = input_groups[arg_index]; |
39 | group_members[group_id].push_back(strings::StrCat("i:" , arg_index)); |
40 | } |
41 | for (int ret_index = 0; ret_index < output_groups.size(); ++ret_index) { |
42 | int group_id = output_groups[ret_index]; |
43 | group_members[group_id].push_back(strings::StrCat("o:" , ret_index)); |
44 | } |
45 | |
46 | std::vector<string> group_strings; |
47 | for (const auto& it : group_members) { |
48 | int group_id = it.first; |
49 | const std::vector<string>& members = it.second; |
50 | const PossibleDevices& devices = group_devices[group_id]; |
51 | group_strings.push_back(strings::StrCat( |
52 | "Group(" , group_id, " members = [" , absl::StrJoin(members, ", " ), |
53 | "] requested_device_name = \"" , |
54 | DeviceNameUtils::ParsedNameToString(devices.requested_device_name), |
55 | "\" resource_device_name = \"" , |
56 | DeviceNameUtils::ParsedNameToString(devices.resource_device_name), |
57 | "\" device_types = [" , |
58 | absl::StrJoin( |
59 | devices.device_types, ", " , |
60 | [](string* out, const std::pair<DeviceType, int32>& type_and_pref) { |
61 | out->append(DeviceTypeString(type_and_pref.first)); |
62 | }), |
63 | "])" )); |
64 | } |
65 | |
66 | return absl::StrJoin(group_strings, "\n\t" ); |
67 | } |
68 | |
69 | // Utility class for constructing IOColocationGroups from a ColocationGraph. |
70 | class ColocationGraphToIOColocationGroups { |
71 | public: |
72 | // colocation_graph is mutable because finding root nodes can update |
73 | // parent pointers. It is not modified otherwise. |
74 | explicit ColocationGraphToIOColocationGroups( |
75 | ColocationGraph* colocation_graph) |
76 | : colocation_graph_(colocation_graph), next_group_id_(0) {} |
77 | |
78 | void AssignGroups(const gtl::InlinedVector<Node*, 4>& nodes, |
79 | std::vector<int>* groups) { |
80 | for (int i = 0; i < nodes.size(); ++i) { |
81 | int root_id = colocation_graph_->FindAndUpdateRoot(nodes[i]->id()); |
82 | const auto& it = group_ids_.find(root_id); |
83 | int assigned_group_id; |
84 | if (it == group_ids_.end()) { |
85 | group_ids_[root_id] = next_group_id_; |
86 | assigned_group_id = next_group_id_; |
87 | ++next_group_id_; |
88 | } else { |
89 | assigned_group_id = it->second; |
90 | } |
91 | groups->push_back(assigned_group_id); |
92 | } |
93 | } |
94 | |
95 | Status FillGroups(std::vector<PossibleDevices>* group_devices) { |
96 | group_devices->resize(group_ids_.size()); |
97 | for (const auto& it : group_ids_) { |
98 | int assigned_group_id = it.second; |
99 | PossibleDevices& possible_devices = (*group_devices)[assigned_group_id]; |
100 | const Member& member = colocation_graph_->members()[it.first]; |
101 | TF_RETURN_IF_ERROR(member.FillPossibleDevices(&possible_devices)); |
102 | } |
103 | return OkStatus(); |
104 | } |
105 | |
106 | private: |
107 | ColocationGraph* colocation_graph_; |
108 | // Allocated group ids: collocation_graph root id -> allocated group id. |
109 | std::unordered_map<int, int> group_ids_; |
110 | int next_group_id_; |
111 | }; |
112 | |
113 | InspectingPlacer::InspectingPlacer(const FunctionStack& stack, |
114 | const FunctionLibraryDefinition* flib_def, |
115 | const DeviceSet* device_set, |
116 | const Device* default_device, |
117 | bool allow_soft_placement, |
118 | bool log_device_placement) |
119 | : stack_(stack), |
120 | flib_def_(*flib_def), |
121 | device_set_(*device_set), |
122 | default_device_(default_device), |
123 | allow_soft_placement_(allow_soft_placement), |
124 | log_device_placement_(log_device_placement) {} |
125 | |
126 | Status InspectingPlacer::ComputeIOColocationGroups(const Node& node, |
127 | IOColocationGroups* groups) { |
128 | const FunctionDef* fdef; |
129 | NameAttrList func; |
130 | TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func)); |
131 | std::unique_ptr<FunctionBody> fbody; |
132 | |
133 | TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()), |
134 | &flib_def_, &fbody)); |
135 | |
136 | TF_RETURN_IF_ERROR( |
137 | IsolatePlacerInspectionRequiredOps(flib_def_, fbody->graph)); |
138 | if (stack_.HasFunction(func.name())) { |
139 | return errors::Unimplemented( |
140 | "Recursive function calls are not supported. Node " , |
141 | FormatNodeForError(node), " inside the body of " , |
142 | errors::FormatFunctionForError(stack_.current_function_name()), |
143 | " calls function " , errors::FormatFunctionForError(func.name()), |
144 | " which is already present in the call stack:\n " , |
145 | stack_.FormatForError()); |
146 | } |
147 | |
148 | ColocationGraph colocation_graph( |
149 | fbody->graph, stack_.Push(&node, func.name()), &flib_def_, &device_set_, |
150 | default_device_, allow_soft_placement_, log_device_placement_); |
151 | TF_RETURN_IF_ERROR(colocation_graph.Initialize()); |
152 | |
153 | ColocationGraphToIOColocationGroups converter(&colocation_graph); |
154 | converter.AssignGroups(fbody->arg_nodes, &groups->input_groups); |
155 | converter.AssignGroups(fbody->ret_nodes, &groups->output_groups); |
156 | TF_RETURN_IF_ERROR(converter.FillGroups(&groups->group_devices)); |
157 | return OkStatus(); |
158 | } |
159 | |
160 | } // namespace tensorflow |
161 | |