1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/constant_folding.h"
17#include "tensorflow/core/common_runtime/graph_constructor.h"
18#include "tensorflow/core/graph/node_builder.h"
19#include "tensorflow/core/graph/subgraph.h"
20#include "tensorflow/core/platform/init_main.h"
21#include "tensorflow/core/public/session.h"
22#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
23#include "tensorflow/tools/graph_transforms/transform_utils.h"
24
25namespace tensorflow {
26namespace graph_transforms {
27
28namespace {
29
30Status TypeForPlaceholder(const TransformFuncContext& context,
31 const string& node_name, DataType* result) {
32 // If we don't find anything else, return float.
33 *result = DT_FLOAT;
34
35 // Check to see if we have been given a default for all placeholders.
36 if (context.params.count("type")) {
37 if (context.params.at("type").size() != 1) {
38 return errors::InvalidArgument(
39 "You must pass no more than one default 'type' to "
40 "strip_unused_nodes");
41 }
42 const string& type_string = context.params.at("type")[0];
43 if (!DataTypeFromString(type_string, result)) {
44 return errors::InvalidArgument("Couldn't understand type argument '",
45 type_string, "'");
46 }
47 }
48
49 // See if there's a particular type specified for this placeholder.
50 if (context.params.count("name") || context.params.count("type_for_name")) {
51 if (!context.params.count("name") ||
52 !context.params.count("type_for_name") ||
53 (context.params.at("type_for_name").size() !=
54 context.params.at("name").size())) {
55 return errors::InvalidArgument(
56 "You must pass a 'type_for_name' arg for every 'name', e.g. "
57 "strip_unused_nodes(name=foo, type_for_name=float, name=bar, "
58 "type_for_name=quint8");
59 }
60 const int name_count = context.params.at("name").size();
61 for (int i = 0; i < name_count; ++i) {
62 if (context.params.at("name")[i] == node_name) {
63 const string& type_string = context.params.at("type_for_name")[i];
64 if (!DataTypeFromString(type_string, result)) {
65 return errors::InvalidArgument("Couldn't understand type argument '",
66 type_string, "'");
67 }
68 }
69 }
70 }
71
72 return OkStatus();
73}
74
75Status ShapeForPlaceholder(const TransformFuncContext& context,
76 const string& node_name, TensorShape* result) {
77 // If we don't find anything else, return scalar.
78 *result = {};
79
80 // Check to see if we have been given a default for all placeholders.
81 if (context.params.count("shape")) {
82 if (context.params.at("shape").size() != 1) {
83 return errors::InvalidArgument(
84 "You must pass no more than one default 'shape' to "
85 "strip_unused_nodes");
86 }
87 const string& shape_string = context.params.at("shape")[0];
88 TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
89 }
90
91 // See if there's a particular type specified for this placeholder.
92 if (context.params.count("name") || context.params.count("shape_for_name")) {
93 if (!context.params.count("name") ||
94 !context.params.count("shape_for_name") ||
95 (context.params.at("shape_for_name").size() !=
96 context.params.at("name").size())) {
97 return errors::InvalidArgument(
98 "You must pass a 'shape_for_name' arg for every 'name', e.g. "
99 "strip_unused_nodes(name=foo, shape_for_name=\"2,2,1\", name=bar, "
100 "shape_for_name=\"1\"");
101 }
102 const int name_count = context.params.at("name").size();
103 for (int i = 0; i < name_count; ++i) {
104 if (context.params.at("name")[i] == node_name) {
105 const string& shape_string = context.params.at("shape_for_name")[i];
106 TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
107 }
108 }
109 }
110
111 return OkStatus();
112}
113} // namespace
114
115// Delete any nodes that don't contribute to the inference result.
116Status StripUnusedNodes(const GraphDef& input_graph_def,
117 const TransformFuncContext& context,
118 GraphDef* output_graph_def) {
119 std::set<string> required_nodes;
120 std::set<string> input_nodes;
121 for (const string& input : context.input_names) {
122 required_nodes.insert(NodeNameFromInput(input));
123 input_nodes.insert(NodeNameFromInput(input));
124 }
125 for (const string& output : context.output_names) {
126 required_nodes.insert(output);
127 }
128
129 std::map<string, const NodeDef*> node_lookup;
130 MapNamesToNodes(input_graph_def, &node_lookup);
131
132 std::vector<string> current_inputs;
133 for (const string& output_name : context.output_names) {
134 current_inputs.push_back(NodeNameFromInput(output_name));
135 }
136
137 while (!current_inputs.empty()) {
138 std::set<string> next_inputs;
139 for (const string& current_input : current_inputs) {
140 required_nodes.insert(current_input);
141 if (input_nodes.count(current_input)) {
142 continue;
143 }
144 if (!node_lookup.count(current_input)) {
145 return errors::InvalidArgument("Input node ", current_input,
146 " not found in graph");
147 }
148 const NodeDef* current_node = node_lookup[current_input];
149 for (const string& input_name : current_node->input()) {
150 string input_node_name = NodeNameFromInput(input_name);
151 if (!required_nodes.count(input_node_name)) {
152 next_inputs.insert(input_node_name);
153 }
154 }
155 }
156 current_inputs =
157 std::vector<string>(next_inputs.begin(), next_inputs.end());
158 }
159
160 GraphDef filtered_graph_def;
161 FilterGraphDef(input_graph_def,
162 [&](const NodeDef& node) {
163 return required_nodes.count(node.name()) > 0;
164 },
165 &filtered_graph_def);
166
167 output_graph_def->Clear();
168 for (const NodeDef& node : filtered_graph_def.node()) {
169 if (input_nodes.count(node.name())) {
170 NodeDef placeholder_node;
171 if (node.op() == "Placeholder") {
172 placeholder_node = node;
173 } else {
174 placeholder_node.set_op("Placeholder");
175 placeholder_node.set_name(node.name());
176 DataType type;
177 TF_RETURN_IF_ERROR(TypeForPlaceholder(context, node.name(), &type));
178 TensorShape shape;
179 TF_RETURN_IF_ERROR(ShapeForPlaceholder(context, node.name(), &shape));
180 SetNodeAttr("dtype", type, &placeholder_node);
181 SetNodeAttr("shape", shape, &placeholder_node);
182 }
183 *(output_graph_def->mutable_node()->Add()) = placeholder_node;
184 } else {
185 *(output_graph_def->mutable_node()->Add()) = node;
186 }
187 }
188 return OkStatus();
189}
190
191REGISTER_GRAPH_TRANSFORM("strip_unused_nodes", StripUnusedNodes);
192
193} // namespace graph_transforms
194} // namespace tensorflow
195