1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/constant_folding.h" |
17 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
18 | #include "tensorflow/core/graph/node_builder.h" |
19 | #include "tensorflow/core/graph/subgraph.h" |
20 | #include "tensorflow/core/platform/init_main.h" |
21 | #include "tensorflow/core/public/session.h" |
22 | #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" |
23 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace graph_transforms { |
27 | |
28 | namespace { |
29 | |
30 | Status TypeForPlaceholder(const TransformFuncContext& context, |
31 | const string& node_name, DataType* result) { |
32 | // If we don't find anything else, return float. |
33 | *result = DT_FLOAT; |
34 | |
35 | // Check to see if we have been given a default for all placeholders. |
36 | if (context.params.count("type" )) { |
37 | if (context.params.at("type" ).size() != 1) { |
38 | return errors::InvalidArgument( |
39 | "You must pass no more than one default 'type' to " |
40 | "strip_unused_nodes" ); |
41 | } |
42 | const string& type_string = context.params.at("type" )[0]; |
43 | if (!DataTypeFromString(type_string, result)) { |
44 | return errors::InvalidArgument("Couldn't understand type argument '" , |
45 | type_string, "'" ); |
46 | } |
47 | } |
48 | |
49 | // See if there's a particular type specified for this placeholder. |
50 | if (context.params.count("name" ) || context.params.count("type_for_name" )) { |
51 | if (!context.params.count("name" ) || |
52 | !context.params.count("type_for_name" ) || |
53 | (context.params.at("type_for_name" ).size() != |
54 | context.params.at("name" ).size())) { |
55 | return errors::InvalidArgument( |
56 | "You must pass a 'type_for_name' arg for every 'name', e.g. " |
57 | "strip_unused_nodes(name=foo, type_for_name=float, name=bar, " |
58 | "type_for_name=quint8" ); |
59 | } |
60 | const int name_count = context.params.at("name" ).size(); |
61 | for (int i = 0; i < name_count; ++i) { |
62 | if (context.params.at("name" )[i] == node_name) { |
63 | const string& type_string = context.params.at("type_for_name" )[i]; |
64 | if (!DataTypeFromString(type_string, result)) { |
65 | return errors::InvalidArgument("Couldn't understand type argument '" , |
66 | type_string, "'" ); |
67 | } |
68 | } |
69 | } |
70 | } |
71 | |
72 | return OkStatus(); |
73 | } |
74 | |
75 | Status ShapeForPlaceholder(const TransformFuncContext& context, |
76 | const string& node_name, TensorShape* result) { |
77 | // If we don't find anything else, return scalar. |
78 | *result = {}; |
79 | |
80 | // Check to see if we have been given a default for all placeholders. |
81 | if (context.params.count("shape" )) { |
82 | if (context.params.at("shape" ).size() != 1) { |
83 | return errors::InvalidArgument( |
84 | "You must pass no more than one default 'shape' to " |
85 | "strip_unused_nodes" ); |
86 | } |
87 | const string& shape_string = context.params.at("shape" )[0]; |
88 | TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result)); |
89 | } |
90 | |
91 | // See if there's a particular type specified for this placeholder. |
92 | if (context.params.count("name" ) || context.params.count("shape_for_name" )) { |
93 | if (!context.params.count("name" ) || |
94 | !context.params.count("shape_for_name" ) || |
95 | (context.params.at("shape_for_name" ).size() != |
96 | context.params.at("name" ).size())) { |
97 | return errors::InvalidArgument( |
98 | "You must pass a 'shape_for_name' arg for every 'name', e.g. " |
99 | "strip_unused_nodes(name=foo, shape_for_name=\"2,2,1\", name=bar, " |
100 | "shape_for_name=\"1\"" ); |
101 | } |
102 | const int name_count = context.params.at("name" ).size(); |
103 | for (int i = 0; i < name_count; ++i) { |
104 | if (context.params.at("name" )[i] == node_name) { |
105 | const string& shape_string = context.params.at("shape_for_name" )[i]; |
106 | TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result)); |
107 | } |
108 | } |
109 | } |
110 | |
111 | return OkStatus(); |
112 | } |
113 | } // namespace |
114 | |
115 | // Delete any nodes that don't contribute to the inference result. |
116 | Status StripUnusedNodes(const GraphDef& input_graph_def, |
117 | const TransformFuncContext& context, |
118 | GraphDef* output_graph_def) { |
119 | std::set<string> required_nodes; |
120 | std::set<string> input_nodes; |
121 | for (const string& input : context.input_names) { |
122 | required_nodes.insert(NodeNameFromInput(input)); |
123 | input_nodes.insert(NodeNameFromInput(input)); |
124 | } |
125 | for (const string& output : context.output_names) { |
126 | required_nodes.insert(output); |
127 | } |
128 | |
129 | std::map<string, const NodeDef*> node_lookup; |
130 | MapNamesToNodes(input_graph_def, &node_lookup); |
131 | |
132 | std::vector<string> current_inputs; |
133 | for (const string& output_name : context.output_names) { |
134 | current_inputs.push_back(NodeNameFromInput(output_name)); |
135 | } |
136 | |
137 | while (!current_inputs.empty()) { |
138 | std::set<string> next_inputs; |
139 | for (const string& current_input : current_inputs) { |
140 | required_nodes.insert(current_input); |
141 | if (input_nodes.count(current_input)) { |
142 | continue; |
143 | } |
144 | if (!node_lookup.count(current_input)) { |
145 | return errors::InvalidArgument("Input node " , current_input, |
146 | " not found in graph" ); |
147 | } |
148 | const NodeDef* current_node = node_lookup[current_input]; |
149 | for (const string& input_name : current_node->input()) { |
150 | string input_node_name = NodeNameFromInput(input_name); |
151 | if (!required_nodes.count(input_node_name)) { |
152 | next_inputs.insert(input_node_name); |
153 | } |
154 | } |
155 | } |
156 | current_inputs = |
157 | std::vector<string>(next_inputs.begin(), next_inputs.end()); |
158 | } |
159 | |
160 | GraphDef filtered_graph_def; |
161 | FilterGraphDef(input_graph_def, |
162 | [&](const NodeDef& node) { |
163 | return required_nodes.count(node.name()) > 0; |
164 | }, |
165 | &filtered_graph_def); |
166 | |
167 | output_graph_def->Clear(); |
168 | for (const NodeDef& node : filtered_graph_def.node()) { |
169 | if (input_nodes.count(node.name())) { |
170 | NodeDef placeholder_node; |
171 | if (node.op() == "Placeholder" ) { |
172 | placeholder_node = node; |
173 | } else { |
174 | placeholder_node.set_op("Placeholder" ); |
175 | placeholder_node.set_name(node.name()); |
176 | DataType type; |
177 | TF_RETURN_IF_ERROR(TypeForPlaceholder(context, node.name(), &type)); |
178 | TensorShape shape; |
179 | TF_RETURN_IF_ERROR(ShapeForPlaceholder(context, node.name(), &shape)); |
180 | SetNodeAttr("dtype" , type, &placeholder_node); |
181 | SetNodeAttr("shape" , shape, &placeholder_node); |
182 | } |
183 | *(output_graph_def->mutable_node()->Add()) = placeholder_node; |
184 | } else { |
185 | *(output_graph_def->mutable_node()->Add()) = node; |
186 | } |
187 | } |
188 | return OkStatus(); |
189 | } |
190 | |
// Associates StripUnusedNodes with the transform name "strip_unused_nodes".
// NOTE(review): REGISTER_GRAPH_TRANSFORM is not defined in this file; it is
// presumably provided by transform_utils.h and makes the transform selectable
// by that name — confirm against the macro's definition.
REGISTER_GRAPH_TRANSFORM("strip_unused_nodes", StripUnusedNodes);
192 | |
193 | } // namespace graph_transforms |
194 | } // namespace tensorflow |
195 | |