1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/inspecting_placer.h"
16
17#include <memory>
18#include <unordered_map>
19#include <vector>
20
21#include "absl/strings/str_join.h"
22#include "tensorflow/core/common_runtime/colocation_graph.h"
23#include "tensorflow/core/common_runtime/device.h"
24#include "tensorflow/core/common_runtime/function_body.h"
25#include "tensorflow/core/common_runtime/function_def_utils.h"
26#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
27#include "tensorflow/core/framework/function.h"
28#include "tensorflow/core/framework/node_def_util.h"
29#include "tensorflow/core/framework/types.h"
30#include "tensorflow/core/graph/graph_node_util.h"
31#include "tensorflow/core/lib/core/errors.h"
32
33namespace tensorflow {
34
35string IOColocationGroups::DebugString() const {
36 std::unordered_map<int, std::vector<string>> group_members;
37 for (int arg_index = 0; arg_index < input_groups.size(); ++arg_index) {
38 int group_id = input_groups[arg_index];
39 group_members[group_id].push_back(strings::StrCat("i:", arg_index));
40 }
41 for (int ret_index = 0; ret_index < output_groups.size(); ++ret_index) {
42 int group_id = output_groups[ret_index];
43 group_members[group_id].push_back(strings::StrCat("o:", ret_index));
44 }
45
46 std::vector<string> group_strings;
47 for (const auto& it : group_members) {
48 int group_id = it.first;
49 const std::vector<string>& members = it.second;
50 const PossibleDevices& devices = group_devices[group_id];
51 group_strings.push_back(strings::StrCat(
52 "Group(", group_id, " members = [", absl::StrJoin(members, ", "),
53 "] requested_device_name = \"",
54 DeviceNameUtils::ParsedNameToString(devices.requested_device_name),
55 "\" resource_device_name = \"",
56 DeviceNameUtils::ParsedNameToString(devices.resource_device_name),
57 "\" device_types = [",
58 absl::StrJoin(
59 devices.device_types, ", ",
60 [](string* out, const std::pair<DeviceType, int32>& type_and_pref) {
61 out->append(DeviceTypeString(type_and_pref.first));
62 }),
63 "])"));
64 }
65
66 return absl::StrJoin(group_strings, "\n\t");
67}
68
69// Utility class for constructing IOColocationGroups from a ColocationGraph.
70class ColocationGraphToIOColocationGroups {
71 public:
72 // colocation_graph is mutable because finding root nodes can update
73 // parent pointers. It is not modified otherwise.
74 explicit ColocationGraphToIOColocationGroups(
75 ColocationGraph* colocation_graph)
76 : colocation_graph_(colocation_graph), next_group_id_(0) {}
77
78 void AssignGroups(const gtl::InlinedVector<Node*, 4>& nodes,
79 std::vector<int>* groups) {
80 for (int i = 0; i < nodes.size(); ++i) {
81 int root_id = colocation_graph_->FindAndUpdateRoot(nodes[i]->id());
82 const auto& it = group_ids_.find(root_id);
83 int assigned_group_id;
84 if (it == group_ids_.end()) {
85 group_ids_[root_id] = next_group_id_;
86 assigned_group_id = next_group_id_;
87 ++next_group_id_;
88 } else {
89 assigned_group_id = it->second;
90 }
91 groups->push_back(assigned_group_id);
92 }
93 }
94
95 Status FillGroups(std::vector<PossibleDevices>* group_devices) {
96 group_devices->resize(group_ids_.size());
97 for (const auto& it : group_ids_) {
98 int assigned_group_id = it.second;
99 PossibleDevices& possible_devices = (*group_devices)[assigned_group_id];
100 const Member& member = colocation_graph_->members()[it.first];
101 TF_RETURN_IF_ERROR(member.FillPossibleDevices(&possible_devices));
102 }
103 return OkStatus();
104 }
105
106 private:
107 ColocationGraph* colocation_graph_;
108 // Allocated group ids: collocation_graph root id -> allocated group id.
109 std::unordered_map<int, int> group_ids_;
110 int next_group_id_;
111};
112
// Creates an InspectingPlacer that records the given placement context for
// later use by ComputeIOColocationGroups().
//
// `flib_def` and `device_set` are dereferenced in the initializer list below,
// so both must be non-null. `default_device` is stored as passed and is not
// dereferenced here.
// NOTE(review): whether `flib_def_` / `device_set_` hold references or copies
// depends on the member declarations in the header -- if they are references,
// the pointees must outlive this object; confirm against the header.
InspectingPlacer::InspectingPlacer(const FunctionStack& stack,
                                   const FunctionLibraryDefinition* flib_def,
                                   const DeviceSet* device_set,
                                   const Device* default_device,
                                   bool allow_soft_placement,
                                   bool log_device_placement)
    : stack_(stack),
      flib_def_(*flib_def),
      device_set_(*device_set),
      default_device_(default_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {}
125
126Status InspectingPlacer::ComputeIOColocationGroups(const Node& node,
127 IOColocationGroups* groups) {
128 const FunctionDef* fdef;
129 NameAttrList func;
130 TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
131 std::unique_ptr<FunctionBody> fbody;
132
133 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
134 &flib_def_, &fbody));
135
136 TF_RETURN_IF_ERROR(
137 IsolatePlacerInspectionRequiredOps(flib_def_, fbody->graph));
138 if (stack_.HasFunction(func.name())) {
139 return errors::Unimplemented(
140 "Recursive function calls are not supported. Node ",
141 FormatNodeForError(node), " inside the body of ",
142 errors::FormatFunctionForError(stack_.current_function_name()),
143 " calls function ", errors::FormatFunctionForError(func.name()),
144 " which is already present in the call stack:\n ",
145 stack_.FormatForError());
146 }
147
148 ColocationGraph colocation_graph(
149 fbody->graph, stack_.Push(&node, func.name()), &flib_def_, &device_set_,
150 default_device_, allow_soft_placement_, log_device_placement_);
151 TF_RETURN_IF_ERROR(colocation_graph.Initialize());
152
153 ColocationGraphToIOColocationGroups converter(&colocation_graph);
154 converter.AssignGroups(fbody->arg_nodes, &groups->input_groups);
155 converter.AssignGroups(fbody->ret_nodes, &groups->output_groups);
156 TF_RETURN_IF_ERROR(converter.FillGroups(&groups->group_devices));
157 return OkStatus();
158}
159
160} // namespace tensorflow
161