1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ |
17 | #define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ |
18 | |
#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status.h"
30 | |
31 | namespace tensorflow { |
32 | namespace graph_transforms { |
33 | |
34 | // Used to quickly look up nodes in the graph def from a name. |
35 | void MapNamesToNodes(const GraphDef& graph_def, |
36 | std::map<string, const NodeDef*>* result); |
37 | |
38 | // For every node in the graph create a list of the nodes that use it as an |
39 | // input. |
40 | void MapNodesToOutputs(const GraphDef& graph_def, |
41 | std::map<string, std::vector<const NodeDef*>>* result); |
42 | |
43 | // NodeDef input strings can contain other information besides the name of an |
44 | // input node. These include: |
45 | // - Optional '^' prefix, indicating this is a control edge. |
46 | // - The required name of the input node. |
47 | // - Optional ':<number>' suffix, showing which output of the node to use. |
48 | // This function takes a raw string, and breaks it into those component parts. |
49 | // The rules for inputs in function libraries are a bit more complex, and |
50 | // aren't handled by this routine. |
51 | void NodeNamePartsFromInput(const string& input_name, string* prefix, |
52 | string* node_name, string* suffix); |
53 | |
54 | // Adds a ':0' port to any inputs with no suffix, to make comparisons easier. |
55 | string CanonicalInputName(const string& input_name); |
56 | |
57 | // Convenience function to strip the optional prefix and suffix components from |
58 | // a string pulled from a NodeDef input, and return the plain node name. |
59 | string NodeNameFromInput(const string& input_name); |
60 | |
61 | // Returns a stable hash for the contents of the NodeDef, so that equivalent |
62 | // nodes should have equal hashes. |
63 | uint64 HashNodeDef(const NodeDef& node); |
64 | |
65 | // Adds the given node name to the end of the node's inputs. |
66 | void AddNodeInput(const string& input_name, NodeDef* node); |
67 | |
68 | // Copies an attribute from one NodeDef to another. |
69 | void CopyNodeAttr(const NodeDef& source, const string& source_key, |
70 | const string& dest_key, NodeDef* dest); |
71 | |
72 | // Inserts a value into a NodeDef's map of attributes. |
73 | // This is a bit different than AddNodeAttr in node_def_util.h because it |
74 | // overwrites any existing attributes with the same key. |
75 | template <class T> |
76 | inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) { |
77 | AttrValue attr_value; |
78 | SetAttrValue(value, &attr_value); |
79 | auto* attr_map = node->mutable_attr(); |
80 | (*attr_map)[key] = attr_value; |
81 | } |
82 | |
83 | template <class T> |
84 | inline void SetNodeTensorAttr(const string& key, const Tensor& tensor, |
85 | NodeDef* node) { |
86 | TensorProto tensor_proto; |
87 | tensor.AsProtoTensorContent(&tensor_proto); |
88 | SetNodeAttr(key, tensor_proto, node); |
89 | } |
90 | |
91 | // Inserts a Tensor into the specified attribute of a NodeDef. |
92 | template <class T> |
93 | inline void SetNodeTensorAttr(const string& key, const TensorShape& shape, |
94 | const std::vector<T>& values, NodeDef* node) { |
95 | const DataType dtype = DataTypeToEnum<T>::v(); |
96 | CHECK_EQ(shape.num_elements(), values.size()); |
97 | Tensor tensor(dtype, shape); |
98 | T* dest_data = tensor.flat<T>().data(); |
99 | std::copy_n(values.data(), values.size(), dest_data); |
100 | SetNodeTensorAttr<T>(key, tensor, node); |
101 | } |
102 | |
103 | // Retrieves a tensor value from a NodeDef attribute. |
104 | Tensor GetNodeTensorAttr(const NodeDef& node, const string& key); |
105 | |
106 | // Creates a copy of the input GraphDef, but only containing the nodes where the |
107 | // supplied selector function returned true. |
108 | void FilterGraphDef(const GraphDef& input_graph_def, |
109 | std::function<bool(const NodeDef&)> selector, |
110 | GraphDef* output_graph_def); |
111 | |
112 | // Creates a copy of the input graph, with all occurrences of the attributes |
113 | // with the names in the argument removed from the node defs. |
114 | void RemoveAttributes(const GraphDef& input_graph_def, |
115 | const std::vector<string>& attributes, |
116 | GraphDef* output_graph_def); |
117 | |
118 | // For a lot of replacement and matching operations it's useful to have the |
119 | // nodes processed in a controlled order, so this does a topological sort to |
120 | // ensure that nodes always appear in the GraphDef.node list after their inputs. |
121 | Status SortByExecutionOrder(const GraphDef& input_graph_def, |
122 | GraphDef* output_graph_def); |
123 | |
124 | // Finds inputs that refer to nodes that are not in the graph. |
125 | void FindInvalidInputs(const GraphDef& graph_def, |
126 | std::vector<std::pair<string, string>>* invalid_inputs); |
127 | |
128 | // Returns a descriptive error status if there are problems spotted with the |
129 | // graph. |
130 | Status IsGraphValid(const GraphDef& graph_def); |
131 | |
132 | // Returns input and output types for a particular NodeDef. |
133 | Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, |
134 | DataTypeVector* outputs); |
135 | |
136 | // Takes a comma-separated string of numbers and parses them into a shape. |
137 | Status TensorShapeFromString(const string& shape_string, TensorShape* result); |
138 | |
139 | // This is used to spot particular subgraphs in a larger model. To use it, |
140 | // create a pattern like: |
141 | // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}}); |
142 | // This defines a subgraph where a Conv2D has a ResizeBilinear input, which |
143 | // pulls from a MirrorPad op. |
144 | // Regular expressions aren't supported for the op names, but you can use "*" to |
145 | // match any op. You can also use | as a separator to match multiple op names, |
146 | // like "Reshape|Concat|Conv2D". |
147 | struct OpTypePattern { |
148 | string op; |
149 | std::vector<OpTypePattern> inputs; |
150 | string DebugString() const; |
151 | }; |
152 | |
153 | // Returns a sub-graph of nodes that match a pattern. |
154 | struct NodeMatch { |
155 | NodeMatch() : node() {} |
156 | NodeDef node; |
157 | std::vector<NodeMatch> inputs; |
158 | string DebugString() const; |
159 | }; |
160 | |
161 | // Utility class to spot subgraphs matching particular patterns. |
162 | class GraphMatcher { |
163 | public: |
164 | GraphMatcher(const GraphDef& graph_def); |
165 | |
166 | // Sorts the input nodes into execution order, and then skips any previously |
167 | // matches so that no node appears in more than one match. The NodeDef |
168 | // pointers contained in the results are owned by the GraphMatcher object, and |
169 | // so will be invalid after its lifetime. |
170 | Status GetOpTypeMatches(const OpTypePattern& pattern, |
171 | std::vector<NodeMatch>* matches); |
172 | |
173 | private: |
174 | bool DoesOpTypeMatch(const NodeDef& node, const OpTypePattern& pattern, |
175 | const std::set<string>& previously_matched_nodes, |
176 | NodeMatch* match); |
177 | |
178 | GraphDef graph_def_; |
179 | std::map<string, const NodeDef*> node_map_; |
180 | }; |
181 | |
// Options controlling ReplaceMatchingOpTypes().
struct ReplaceMatchingOpTypesOptions {
  // When false (the default), an error is raised if the graph is left with
  // dangling inputs. When true, such inconsistencies are tolerated, and you
  // must fix them in a later pass.
  // The explicit default keeps a default-constructed options object from
  // carrying an indeterminate value.
  bool allow_inconsistencies = false;
};
187 | |
188 | // Replaces all of the matching sub-graphs with new ops. This calls into the |
189 | // given function, and expects to receive a set of new nodes to replace each |
190 | // matched sub-graph. It has some logic to protect the integrity of the |
191 | // resulting graph, for example making sure that nodes needed by other nodes |
192 | // outside the sub-graph aren't removed. These are passed in as the set of |
193 | // outputs, and nodes with the same names must be added to the new nodes |
194 | // produced by the replacement function. Many of these checks can be disabled |
195 | // by setting allow_inconsistencies to true in the options, but then it's the |
196 | // caller's responsibility to patch up any problems before passing on the graph |
197 | // to others. There's more comprehensive usage documentation in the README. |
198 | Status ReplaceMatchingOpTypes( |
199 | const GraphDef& input_graph_def, const OpTypePattern& pattern, |
200 | const std::function<Status(const NodeMatch&, const std::set<string>&, |
201 | const std::set<string>&, std::vector<NodeDef>*)>& |
202 | node_generator, |
203 | const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def); |
204 | |
205 | // Returns a list of the unique nodes found in this match. |
206 | void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result); |
207 | |
208 | // Changes all input references to a particular node name. Any nodes with names |
209 | // listed in nodes_to_ignore will not have their inputs rewritten. |
210 | Status RenameNodeInputs(const GraphDef& input_graph_def, |
211 | const std::map<string, string>& inputs_to_rename, |
212 | const std::unordered_set<string>& nodes_to_ignore, |
213 | GraphDef* output_graph_def); |
214 | |
215 | // Utility function that copies all the nodes found in a match into the |
216 | // new_nodes list. This is useful in replacement functions when you decide to |
217 | // leave the original matched subgraph untouched and make no changes. |
218 | void CopyOriginalMatch(const NodeMatch& match, std::vector<NodeDef>* new_nodes); |
219 | |
220 | // Holds information that's needed for transform functions. |
221 | typedef std::map<string, std::vector<string>> TransformFuncParameters; |
222 | struct TransformFuncContext { |
223 | std::vector<string> input_names; |
224 | std::vector<string> output_names; |
225 | TransformFuncParameters params; |
226 | |
227 | // Returns how many occurrences of the given parameter are present. |
228 | int CountParameters(const string& name) const; |
229 | |
230 | // Gets a single instance of a parameter, using a default if it's not present. |
231 | Status GetOneStringParameter(const string& name, const string& default_value, |
232 | string* result) const; |
233 | |
234 | // Gets a single occurrence of a parameter as a 32-bit integer, falling back |
235 | // to a default if it isn't present and returning an error if it isn't |
236 | // convertible to a number. |
237 | Status GetOneInt32Parameter(const string& name, int32_t default_value, |
238 | int32* result) const; |
239 | |
240 | // Gets a single occurrence of a parameter as a 64-bit integer, falling back |
241 | // to a default if it isn't present and returning an error if it isn't |
242 | // convertible to a number. |
243 | Status GetOneInt64Parameter(const string& name, int64_t default_value, |
244 | int64_t* result) const; |
245 | |
246 | // Gets a single occurrence of a parameter as a floating point number, falling |
247 | // back to a default if it isn't present and returning an error if it isn't |
248 | // convertible to a number. |
249 | Status GetOneFloatParameter(const string& name, float default_value, |
250 | float* result) const; |
251 | |
252 | // Gets a single occurrence of a parameter as a boolean, falling back to a |
253 | // default if it isn't present and returning an error if it's not one of |
254 | // "true", "1", "false", or "0". |
255 | Status GetOneBoolParameter(const string& name, bool default_value, |
256 | bool* result) const; |
257 | }; |
258 | |
259 | // This is the function API for all graph transformations, taking an input |
260 | // GraphDef and other arguments, and returning a transformed GraphDef. |
261 | typedef std::function<Status(const GraphDef&, |
262 | const TransformFuncContext& context, GraphDef*)> |
263 | TransformFunc; |
264 | |
265 | // To add a new graph transform function, call the macro: |
266 | // REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants); |
267 | // Under the hood this adds the function to the list of known transforms, so you |
268 | // just need to link in the .cc file with your registration call to have access |
269 | // to it through the command line tool. |
270 | // The rest of the machinery below is to enable that automagical registration. |
271 | typedef std::map<string, TransformFunc> TransformRegistry; |
272 | TransformRegistry* GetTransformRegistry(); |
273 | class TransformRegistrar { |
274 | public: |
275 | TransformRegistrar(const string& name, TransformFunc transform_func) { |
276 | TransformRegistry* transform_registry = GetTransformRegistry(); |
277 | (*transform_registry)[name] = transform_func; |
278 | } |
279 | }; |
// Registers `func` under `name` by defining a static TransformRegistrar object
// whose constructor performs the registration.
// The extra _UNIQ_HELPER level makes the preprocessor expand __COUNTER__ to
// its numeric value before it is pasted into the variable name, so each use
// within a translation unit gets a distinct identifier.
// The generated name deliberately avoids double underscores, which are
// reserved identifiers in C++. Nothing refers to the object by name.
// The expansion already ends in a semicolon; this is kept so existing call
// sites, with or without their own trailing semicolon, continue to compile.
#define REGISTER_GRAPH_TRANSFORM(name, func) \
  REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(__COUNTER__, name, func)
#define REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(ctr, name, func) \
  REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func)
#define REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func)    \
  static tensorflow::graph_transforms::TransformRegistrar \
      graph_transform_registrar_##ctr##_object(name, func);
287 | |
288 | } // namespace graph_transforms |
289 | } // namespace tensorflow |
290 | |
291 | #endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ |
292 | |