1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/constant_folding.h" |
17 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
18 | #include "tensorflow/core/graph/node_builder.h" |
19 | #include "tensorflow/core/graph/subgraph.h" |
20 | #include "tensorflow/core/platform/init_main.h" |
21 | #include "tensorflow/core/public/session.h" |
22 | #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" |
23 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace graph_transforms { |
27 | |
28 | // Deletes any specified types of nodes, unless they're necessary for the |
29 | // graph's inputs or outputs. |
30 | Status RemoveNodes(const GraphDef& input_graph_def, |
31 | const TransformFuncContext& context, |
32 | GraphDef* output_graph_def) { |
33 | if (!context.params.count("op" )) { |
34 | return errors::InvalidArgument( |
35 | "remove_nodes expects at least one 'op'" |
36 | "argument, e.g. remove_nodes(op=Identity)" ); |
37 | } |
38 | int32_t max_inputs; |
39 | TF_RETURN_IF_ERROR( |
40 | context.GetOneInt32Parameter("max_inputs" , 1, &max_inputs)); |
41 | |
42 | // Make sure we don't get rid of any nodes used as graph inputs or outputs. |
43 | std::set<string> required_nodes; |
44 | for (const string& input : context.input_names) { |
45 | required_nodes.insert(NodeNameFromInput(input)); |
46 | } |
47 | for (const string& output : context.output_names) { |
48 | required_nodes.insert(NodeNameFromInput(output)); |
49 | } |
50 | |
51 | std::vector<string> ops_to_remove = context.params.at("op" ); |
52 | GraphDef current_graph_def = input_graph_def; |
53 | for (const string& op : ops_to_remove) { |
54 | for (int num_inputs = 1; num_inputs <= max_inputs; ++num_inputs) { |
55 | // Look for a variable number of inputs. |
56 | OpTypePattern pattern = {op}; |
57 | pattern.inputs.resize(num_inputs); |
58 | for (int i = 0; i < num_inputs; ++i) { |
59 | pattern.inputs[i] = {"*" }; |
60 | } |
61 | // Keep looking for nodes to remove until there are no more changes. |
62 | bool any_nodes_removed; |
63 | do { |
64 | any_nodes_removed = false; |
65 | std::map<string, string> inputs_to_rename; |
66 | GraphDef replaced_graph_def; |
67 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
68 | current_graph_def, pattern, |
69 | [&inputs_to_rename, &required_nodes, &any_nodes_removed]( |
70 | const NodeMatch& match, const std::set<string>& input_nodes, |
71 | const std::set<string>& output_nodes, |
72 | std::vector<NodeDef>* new_nodes) { |
73 | const NodeDef& replace_node = match.node; |
74 | // If this node is needed in the inputs or outputs don't replace |
75 | // it. |
76 | if (required_nodes.count(replace_node.name())) { |
77 | LOG(INFO) << "Skipping replacement for " << replace_node.name(); |
78 | CopyOriginalMatch(match, new_nodes); |
79 | return OkStatus(); |
80 | } |
81 | const NodeDef& input_node = match.inputs[0].node; |
82 | string target_name = input_node.name(); |
83 | for (const string& input : replace_node.input()) { |
84 | if (!input.compare(0, target_name.size(), target_name)) { |
85 | if (input.size() == target_name.size() || |
86 | input[target_name.size()] == ':') { |
87 | target_name = input; |
88 | break; |
89 | } |
90 | } |
91 | } |
92 | inputs_to_rename[replace_node.name()] = target_name; |
93 | inputs_to_rename["^" + replace_node.name()] = |
94 | "^" + input_node.name(); |
95 | new_nodes->push_back(input_node); |
96 | any_nodes_removed = true; |
97 | return OkStatus(); |
98 | }, |
99 | {true}, &replaced_graph_def)); |
100 | // Make sure all references to removed nodes now point to their inputs. |
101 | TF_RETURN_IF_ERROR( |
102 | RenameNodeInputs(replaced_graph_def, inputs_to_rename, |
103 | std::unordered_set<string>(), ¤t_graph_def)); |
104 | } while (any_nodes_removed); |
105 | } |
106 | } |
107 | |
108 | *output_graph_def = current_graph_def; |
109 | return OkStatus(); |
110 | } |
111 | |
112 | REGISTER_GRAPH_TRANSFORM("remove_nodes" , RemoveNodes); |
113 | |
114 | } // namespace graph_transforms |
115 | } // namespace tensorflow |
116 | |