1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
35
36namespace py = pybind11;
37
// Union-find (disjoint-set) structure over node names. Names are added lazily
// the first time Group() sees them. Group() merges two names into one set, and
// ExtractGroups() lists the resulting sets.
//
// All set representatives are owned by `nodes_`, so nothing leaks when the
// object is destroyed. The class is move-only because representatives link to
// each other through raw (non-owning) parent pointers.
class ColocationGroups {
 public:
  // Merges the sets containing `x` and `y`. If either name has not been seen
  // before, it is first added as a singleton set. Uses union by rank.
  void Group(const std::string& x, const std::string& y) {
    Rep* x_root = Find(x);
    Rep* y_root = Find(y);

    // x and y are already in the same set.
    if (x_root == y_root) {
      return;
    }
    // Union by rank: hang the lower-rank root under the higher-rank one.
    if (x_root->rank < y_root->rank) {
      x_root->parent = y_root;
    } else if (x_root->rank > y_root->rank) {
      y_root->parent = x_root;
    } else {
      // Equal ranks: arbitrarily keep x_root as the root and bump its rank.
      y_root->parent = x_root;
      ++x_root->rank;
    }
  }

  // Appends one vector per disjoint set to `*groups`, each holding the names
  // in that set. Any vectors already present in `*groups` are left untouched.
  // The order of groups, and of names within a group, follows unordered_map
  // iteration order and is therefore unspecified.
  void ExtractGroups(std::vector<std::vector<std::string>>* groups) {
    // Reserve for the worst case of every name being in its own set.
    groups->reserve(groups->size() + nodes_.size());
    // Maps each set root to the index of its group in `*groups`.
    std::unordered_map<const Rep*, size_t> group_ids;
    for (const auto& entry : nodes_) {
      // Find() on an existing name never inserts, so iterating nodes_ here is
      // safe.
      const Rep* root = Find(entry.first);
      auto it = group_ids.find(root);
      std::vector<std::string>* group;
      if (it == group_ids.end()) {
        // First member of this set: append a new group at the end.
        group_ids.emplace(root, groups->size());
        groups->emplace_back();
        group = &groups->back();
      } else {
        group = &(*groups)[it->second];
      }
      group->push_back(entry.first);
    }
  }

 private:
  struct Rep {
    // Parent in the tree that encodes the set; a root points to itself.
    Rep* parent;
    // Rank of the tree rooted here. Group() compares ranks to decide which
    // root becomes the parent; it only grows on an equal-rank merge.
    int rank;
    // The node name this entry represents.
    std::string value;
  };

  // Returns the root (representative) of the set containing `n`. If `n` has
  // not been seen before, it is inserted as a new singleton set. Applies path
  // compression to speed up later queries.
  Rep* Find(const std::string& n) {
    auto it = nodes_.find(n);
    if (it == nodes_.end()) {
      // First time we see this name: create a singleton set for it.
      auto owned = std::make_unique<Rep>();
      Rep* node = owned.get();
      node->parent = node;
      node->rank = 0;
      node->value = n;
      nodes_.emplace(n, std::move(owned));
      return node;
    }
    Rep* node = it->second.get();
    Rep* root = node->parent;
    while (root != root->parent) {
      root = root->parent;
    }
    // Path compression: point every node on the path directly at the root.
    while (node->parent != root) {
      Rep* next = node->parent;
      node->parent = root;
      node = next;
    }
    return root;
  }

  // Owns every Rep; the parent pointers between Reps are non-owning.
  std::unordered_map<std::string, std::unique_ptr<Rep>> nodes_;
};
121
122PYBIND11_MAKE_OPAQUE(tensorflow::grappler::GrapplerItem);
123
124PYBIND11_MODULE(_pywrap_tf_item, m) {
125 py::class_<tensorflow::grappler::GrapplerItem> grappler_item(
126 m, "tensorflow::grappler::GrapplerItem");
127
128 m.def("TF_NewItem",
129 [](const py::bytes& serialized_metagraph, bool ignore_colocation,
130 bool ignore_user_placement) -> tensorflow::grappler::GrapplerItem* {
131 tensorflow::MetaGraphDef metagraph;
132 if (!metagraph.ParseFromString(std::string(serialized_metagraph))) {
133 throw std::invalid_argument(
134 "The MetaGraphDef could not be parsed as a valid protocol "
135 "buffer");
136 }
137 if (metagraph.collection_def().count("train_op") == 0) {
138 MaybeRaiseRegisteredFromStatus(tensorflow::errors::InvalidArgument(
139 "train_op not specified in the metagraph"));
140 }
141
142 tensorflow::grappler::ItemConfig cfg;
143 cfg.ignore_user_placement = ignore_user_placement;
144 cfg.ignore_colocation = ignore_colocation;
145 std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
146 tensorflow::grappler::GrapplerItemFromMetaGraphDef(
147 "item", metagraph, cfg);
148 if (item == nullptr) {
149 MaybeRaiseRegisteredFromStatus(
150 tensorflow::errors::InvalidArgument("Invalid metagraph"));
151 }
152 return item.release();
153 });
154
155 m.def("TF_IdentifyImportantOps",
156 [](tensorflow::grappler::GrapplerItem* item,
157 bool sort_topologically) -> std::vector<std::string> {
158 std::vector<const tensorflow::NodeDef*> main_ops =
159 item->MainOpsFanin();
160 std::vector<const tensorflow::NodeDef*> enqueue_ops =
161 item->EnqueueOpsFanin();
162 std::unordered_set<std::string> op_names;
163 for (auto op : main_ops) {
164 op_names.insert(op->name());
165 }
166 for (auto op : enqueue_ops) {
167 op_names.insert(op->name());
168 }
169
170 std::vector<std::string> ops;
171 if (sort_topologically) {
172 tensorflow::GraphDef subgraph;
173 for (const tensorflow::NodeDef& node : item->graph.node()) {
174 if (op_names.find(node.name()) != op_names.end()) {
175 *subgraph.add_node() = node;
176 }
177 }
178 tensorflow::MaybeRaiseFromStatus(
179 tensorflow::grappler::TopologicalSort(&subgraph));
180 for (const tensorflow::NodeDef& node : subgraph.node()) {
181 ops.push_back(node.name());
182 }
183 } else {
184 for (const auto& op_name : op_names) {
185 ops.push_back(op_name);
186 }
187 }
188 return ops;
189 });
190
191 m.def("TF_GetOpProperties",
192 [](tensorflow::grappler::GrapplerItem* item)
193 -> std::unordered_map<std::string, std::vector<py::bytes>> {
194 tensorflow::grappler::GraphProperties properties(*item);
195 tensorflow::MaybeRaiseFromStatus(properties.InferStatically(false));
196
197 std::unordered_map<std::string, std::vector<py::bytes>> props;
198 for (const auto& node : item->graph.node()) {
199 const std::string& node_name = node.name();
200 const std::vector<tensorflow::OpInfo::TensorProperties>&
201 output_props = properties.GetOutputProperties(node_name);
202
203 std::vector<py::bytes> prop;
204 prop.reserve(output_props.size());
205 for (const auto& output_prop : output_props) {
206 prop.push_back(output_prop.SerializeAsString());
207 }
208 props[node_name] = prop;
209 }
210 return props;
211 });
212
213 m.def("TF_GetColocationGroups",
214 [](tensorflow::grappler::GrapplerItem* item)
215 -> std::vector<std::vector<std::string>> {
216 ColocationGroups groupings;
217 tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
218 for (const auto& node : item->graph.node()) {
219 const tensorflow::OpDef* op_def;
220 if (!registry->LookUpOpDef(node.op(), &op_def).ok()) {
221 continue;
222 }
223 tensorflow::NameRangeMap inputs;
224 tensorflow::NameRangeMap outputs;
225 if (!tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs)
226 .ok()) {
227 continue;
228 }
229 for (const auto& arg : op_def->input_arg()) {
230 if (!arg.is_ref()) {
231 continue;
232 }
233 const auto& range = inputs[arg.name()];
234 for (int i = range.first; i < range.second; ++i) {
235 groupings.Group(node.name(),
236 tensorflow::grappler::NodeName(node.input(i)));
237 }
238 }
239 }
240
241 std::vector<std::vector<std::string>> groups;
242 groupings.ExtractGroups(&groups);
243 return groups;
244 });
245}
246