1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
35 | |
36 | namespace py = pybind11; |
37 | |
// Disjoint-set (union-find) structure keyed by node name.
//
// Names are added lazily the first time they are passed to Group(). Every
// node is heap-allocated and owned by `nodes_`, so the raw `Rep*` links
// between nodes stay valid for the lifetime of the object and are released
// automatically when it is destroyed.
class ColocationGroups {
 public:
  // Records that `x` and `y` belong to the same group. Either name is
  // registered as a singleton group first if it has not been seen before.
  void Group(const std::string& x, const std::string& y) {
    Rep* x_root = Find(x);
    Rep* y_root = Find(y);

    // x and y are already in the same set.
    if (x_root == y_root) {
      return;
    }
    // Union by rank: attach the root of lower rank under the other root.
    if (x_root->rank < y_root->rank) {
      x_root->parent = y_root;
    } else {
      // On a tie x_root is arbitrarily picked as the parent, and its rank
      // grows by one.
      if (x_root->rank == y_root->rank) {
        ++x_root->rank;
      }
      y_root->parent = x_root;
    }
  }

  // Writes one vector of names per group into `*groups`. Any previous
  // contents of `*groups` are discarded. The order of the groups, and of the
  // names inside a group, follows the iteration order of an unordered map
  // and is therefore unspecified.
  //
  // Not const: looking up a root compresses paths in the underlying forest.
  void ExtractGroups(std::vector<std::vector<std::string>>* groups) {
    groups->clear();
    // There can never be more groups than nodes.
    groups->reserve(nodes_.size());
    std::unordered_map<const Rep*, int> group_ids;
    for (const auto& entry : nodes_) {
      const Rep* root = FindRoot(entry.second.get());
      auto it = group_ids.find(root);
      if (it == group_ids.end()) {
        // First member of this set: open a new output group for it.
        const int id = static_cast<int>(group_ids.size());
        group_ids[root] = id;
        groups->resize(id + 1);
        groups->back().push_back(entry.first);
      } else {
        (*groups)[it->second].push_back(entry.first);
      }
    }
  }

 private:
  struct Rep {
    // Parent in the tree used to encode the set. A root is its own parent.
    Rep* parent = nullptr;
    // Rank of the tree rooted here; used by Group() to decide which root
    // becomes the parent when two sets are merged.
    int rank = 0;
  };

  // Returns the root of the tree containing `node` and re-parents every node
  // visited on the way directly to that root to speed up future queries.
  static Rep* FindRoot(Rep* node) {
    Rep* root = node->parent;
    while (root != root->parent) {
      root = root->parent;
    }
    while (node->parent != root) {
      Rep* next = node->parent;
      node->parent = root;
      node = next;
    }
    return root;
  }

  // Returns the representative of the set containing `n`, creating a
  // singleton set for `n` the first time it is seen.
  Rep* Find(const std::string& n) {
    // operator[] default-constructs an empty slot for unseen names, so a
    // single hash lookup serves both the "new" and the "existing" case.
    std::unique_ptr<Rep>& slot = nodes_[n];
    if (!slot) {
      slot = std::make_unique<Rep>();
      slot->parent = slot.get();
      return slot.get();
    }
    return FindRoot(slot.get());
  }

  // Owns every node; the pointed-to Rep objects never move.
  std::unordered_map<std::string, std::unique_ptr<Rep>> nodes_;
};
121 | |
// Opt GrapplerItem out of pybind11's template-based type conversion so that
// it is handled only through its explicit class binding in this module.
PYBIND11_MAKE_OPAQUE(tensorflow::grappler::GrapplerItem);
123 | |
124 | PYBIND11_MODULE(_pywrap_tf_item, m) { |
125 | py::class_<tensorflow::grappler::GrapplerItem> grappler_item( |
126 | m, "tensorflow::grappler::GrapplerItem" ); |
127 | |
128 | m.def("TF_NewItem" , |
129 | [](const py::bytes& serialized_metagraph, bool ignore_colocation, |
130 | bool ignore_user_placement) -> tensorflow::grappler::GrapplerItem* { |
131 | tensorflow::MetaGraphDef metagraph; |
132 | if (!metagraph.ParseFromString(std::string(serialized_metagraph))) { |
133 | throw std::invalid_argument( |
134 | "The MetaGraphDef could not be parsed as a valid protocol " |
135 | "buffer" ); |
136 | } |
137 | if (metagraph.collection_def().count("train_op" ) == 0) { |
138 | MaybeRaiseRegisteredFromStatus(tensorflow::errors::InvalidArgument( |
139 | "train_op not specified in the metagraph" )); |
140 | } |
141 | |
142 | tensorflow::grappler::ItemConfig cfg; |
143 | cfg.ignore_user_placement = ignore_user_placement; |
144 | cfg.ignore_colocation = ignore_colocation; |
145 | std::unique_ptr<tensorflow::grappler::GrapplerItem> item = |
146 | tensorflow::grappler::GrapplerItemFromMetaGraphDef( |
147 | "item" , metagraph, cfg); |
148 | if (item == nullptr) { |
149 | MaybeRaiseRegisteredFromStatus( |
150 | tensorflow::errors::InvalidArgument("Invalid metagraph" )); |
151 | } |
152 | return item.release(); |
153 | }); |
154 | |
155 | m.def("TF_IdentifyImportantOps" , |
156 | [](tensorflow::grappler::GrapplerItem* item, |
157 | bool sort_topologically) -> std::vector<std::string> { |
158 | std::vector<const tensorflow::NodeDef*> main_ops = |
159 | item->MainOpsFanin(); |
160 | std::vector<const tensorflow::NodeDef*> enqueue_ops = |
161 | item->EnqueueOpsFanin(); |
162 | std::unordered_set<std::string> op_names; |
163 | for (auto op : main_ops) { |
164 | op_names.insert(op->name()); |
165 | } |
166 | for (auto op : enqueue_ops) { |
167 | op_names.insert(op->name()); |
168 | } |
169 | |
170 | std::vector<std::string> ops; |
171 | if (sort_topologically) { |
172 | tensorflow::GraphDef subgraph; |
173 | for (const tensorflow::NodeDef& node : item->graph.node()) { |
174 | if (op_names.find(node.name()) != op_names.end()) { |
175 | *subgraph.add_node() = node; |
176 | } |
177 | } |
178 | tensorflow::MaybeRaiseFromStatus( |
179 | tensorflow::grappler::TopologicalSort(&subgraph)); |
180 | for (const tensorflow::NodeDef& node : subgraph.node()) { |
181 | ops.push_back(node.name()); |
182 | } |
183 | } else { |
184 | for (const auto& op_name : op_names) { |
185 | ops.push_back(op_name); |
186 | } |
187 | } |
188 | return ops; |
189 | }); |
190 | |
191 | m.def("TF_GetOpProperties" , |
192 | [](tensorflow::grappler::GrapplerItem* item) |
193 | -> std::unordered_map<std::string, std::vector<py::bytes>> { |
194 | tensorflow::grappler::GraphProperties properties(*item); |
195 | tensorflow::MaybeRaiseFromStatus(properties.InferStatically(false)); |
196 | |
197 | std::unordered_map<std::string, std::vector<py::bytes>> props; |
198 | for (const auto& node : item->graph.node()) { |
199 | const std::string& node_name = node.name(); |
200 | const std::vector<tensorflow::OpInfo::TensorProperties>& |
201 | output_props = properties.GetOutputProperties(node_name); |
202 | |
203 | std::vector<py::bytes> prop; |
204 | prop.reserve(output_props.size()); |
205 | for (const auto& output_prop : output_props) { |
206 | prop.push_back(output_prop.SerializeAsString()); |
207 | } |
208 | props[node_name] = prop; |
209 | } |
210 | return props; |
211 | }); |
212 | |
213 | m.def("TF_GetColocationGroups" , |
214 | [](tensorflow::grappler::GrapplerItem* item) |
215 | -> std::vector<std::vector<std::string>> { |
216 | ColocationGroups groupings; |
217 | tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global(); |
218 | for (const auto& node : item->graph.node()) { |
219 | const tensorflow::OpDef* op_def; |
220 | if (!registry->LookUpOpDef(node.op(), &op_def).ok()) { |
221 | continue; |
222 | } |
223 | tensorflow::NameRangeMap inputs; |
224 | tensorflow::NameRangeMap outputs; |
225 | if (!tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs) |
226 | .ok()) { |
227 | continue; |
228 | } |
229 | for (const auto& arg : op_def->input_arg()) { |
230 | if (!arg.is_ref()) { |
231 | continue; |
232 | } |
233 | const auto& range = inputs[arg.name()]; |
234 | for (int i = range.first; i < range.second; ++i) { |
235 | groupings.Group(node.name(), |
236 | tensorflow::grappler::NodeName(node.input(i))); |
237 | } |
238 | } |
239 | } |
240 | |
241 | std::vector<std::vector<std::string>> groups; |
242 | groupings.ExtractGroups(&groups); |
243 | return groups; |
244 | }); |
245 | } |
246 | |