1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
17#define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
18
#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status.h"
30
31namespace tensorflow {
32namespace graph_transforms {
33
// Used to quickly look up nodes in the graph def from a name: fills *result
// with an entry per node name. The stored pointers presumably point into
// graph_def itself (TODO confirm in the .cc), so keep graph_def alive and
// unmodified while using *result.
void MapNamesToNodes(const GraphDef& graph_def,
                     std::map<string, const NodeDef*>* result);

// For every node in the graph create a list of the nodes that use it as an
// input. The map is presumably keyed by the consumed node's name, and the
// listed NodeDef pointers presumably point into graph_def (TODO confirm), so
// graph_def must outlive *result.
void MapNodesToOutputs(const GraphDef& graph_def,
                       std::map<string, std::vector<const NodeDef*>>* result);
42
// NodeDef input strings can contain other information besides the name of an
// input node. These include:
//  - Optional '^' prefix, indicating this is a control edge.
//  - The required name of the input node.
//  - Optional ':<number>' suffix, showing which output of the node to use.
// This function takes a raw string, and breaks it into those component parts,
// writing them to *prefix, *node_name and *suffix respectively (an absent
// optional part presumably comes back as an empty string — TODO confirm).
// The rules for inputs in function libraries are a bit more complex, and
// aren't handled by this routine.
void NodeNamePartsFromInput(const string& input_name, string* prefix,
                            string* node_name, string* suffix);

// Returns input_name with a ':0' port added if it has no ':<number>' suffix,
// so that an implicit output-0 reference and an explicit one compare equal.
string CanonicalInputName(const string& input_name);

// Convenience function to strip the optional prefix and suffix components from
// a string pulled from a NodeDef input, and return the plain node name
// (so "^foo" and "foo:1" both reduce to "foo").
string NodeNameFromInput(const string& input_name);
60
// Returns a stable hash for the contents of the NodeDef, so that equivalent
// nodes should have equal hashes. Which fields participate in the hash is
// defined by the implementation, not this header.
uint64 HashNodeDef(const NodeDef& node);

// Appends input_name to the end of the node's input list. The string is
// presumably stored verbatim, so any '^' control prefix or ':<n>' output
// suffix is the caller's responsibility (TODO confirm).
void AddNodeInput(const string& input_name, NodeDef* node);

// Copies the attribute stored under source_key in `source` into `dest` under
// the (possibly different) name dest_key.
// NOTE(review): behavior when source_key is absent from `source` is not
// visible from this header — confirm in the .cc.
void CopyNodeAttr(const NodeDef& source, const string& source_key,
                  const string& dest_key, NodeDef* dest);
71
72// Inserts a value into a NodeDef's map of attributes.
73// This is a bit different than AddNodeAttr in node_def_util.h because it
74// overwrites any existing attributes with the same key.
75template <class T>
76inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
77 AttrValue attr_value;
78 SetAttrValue(value, &attr_value);
79 auto* attr_map = node->mutable_attr();
80 (*attr_map)[key] = attr_value;
81}
82
83template <class T>
84inline void SetNodeTensorAttr(const string& key, const Tensor& tensor,
85 NodeDef* node) {
86 TensorProto tensor_proto;
87 tensor.AsProtoTensorContent(&tensor_proto);
88 SetNodeAttr(key, tensor_proto, node);
89}
90
91// Inserts a Tensor into the specified attribute of a NodeDef.
92template <class T>
93inline void SetNodeTensorAttr(const string& key, const TensorShape& shape,
94 const std::vector<T>& values, NodeDef* node) {
95 const DataType dtype = DataTypeToEnum<T>::v();
96 CHECK_EQ(shape.num_elements(), values.size());
97 Tensor tensor(dtype, shape);
98 T* dest_data = tensor.flat<T>().data();
99 std::copy_n(values.data(), values.size(), dest_data);
100 SetNodeTensorAttr<T>(key, tensor, node);
101}
102
// Retrieves a tensor value from the attribute `key` of `node`.
// NOTE(review): behavior when `key` is missing or does not hold a tensor is
// not visible from this header — confirm in the implementation.
Tensor GetNodeTensorAttr(const NodeDef& node, const string& key);

// Creates a copy of the input GraphDef, but only containing the nodes where the
// supplied selector function returned true. The result is written to
// *output_graph_def; inputs of kept nodes that point at dropped nodes are not
// mentioned here, so presumably they are left as-is (TODO confirm).
void FilterGraphDef(const GraphDef& input_graph_def,
                    std::function<bool(const NodeDef&)> selector,
                    GraphDef* output_graph_def);

// Creates a copy of the input graph, with every attribute whose name appears
// in `attributes` removed from all node defs. The result is written to
// *output_graph_def.
void RemoveAttributes(const GraphDef& input_graph_def,
                      const std::vector<string>& attributes,
                      GraphDef* output_graph_def);
117
// For a lot of replacement and matching operations it's useful to have the
// nodes processed in a controlled order, so this does a topological sort to
// ensure that nodes always appear in the GraphDef.node list after their inputs.
// Returns a non-OK Status on failure (presumably e.g. when the graph has a
// cycle and no such order exists — TODO confirm).
Status SortByExecutionOrder(const GraphDef& input_graph_def,
                            GraphDef* output_graph_def);

// Finds inputs that refer to nodes that are not in the graph. Each reported
// pair is presumably (consuming node name, offending input string) — confirm
// against the implementation before relying on the order.
void FindInvalidInputs(const GraphDef& graph_def,
                       std::vector<std::pair<string, string>>* invalid_inputs);

// Returns a descriptive error status if there are problems spotted with the
// graph, and an OK status otherwise.
Status IsGraphValid(const GraphDef& graph_def);

// Fills *inputs and *outputs with the data types of node_def's inputs and
// outputs respectively; returns a non-OK Status if they can't be determined.
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
                     DataTypeVector* outputs);

// Takes a comma-separated string of numbers (for example "1,224,224,3") and
// parses them into a shape stored in *result. Presumably returns a non-OK
// Status for malformed input (TODO confirm).
Status TensorShapeFromString(const string& shape_string, TensorShape* result);
138
// This is used to spot particular subgraphs in a larger model. To use it,
// create a pattern like:
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
// This defines a subgraph where a Conv2D has a ResizeBilinear input, which
// pulls from a MirrorPad op.
// Regular expressions aren't supported for the op names, but you can use "*" to
// match any op. You can also use | as a separator to match multiple op names,
// like "Reshape|Concat|Conv2D".
// Patterns are written with aggregate initialization (as above), so the
// declaration order of the data members (op, then inputs) is part of the
// interface and must not change.
struct OpTypePattern {
  // Op type to match: one op name, "*" for any op, or '|'-separated names.
  string op;
  // Sub-patterns for the matched node's inputs (presumably matched against
  // the node's inputs in order — TODO confirm in GraphMatcher).
  std::vector<OpTypePattern> inputs;
  // Returns a human-readable description of the pattern, for debugging.
  string DebugString() const;
};
152
// Holds a sub-graph of nodes that matched a pattern: the node matched at this
// level, plus (recursively) the matches for its inputs.
struct NodeMatch {
  NodeMatch() : node() {}
  // The matched node, stored by value (a copy, not a reference into a graph).
  NodeDef node;
  // Matches for the pattern's inputs — presumably in the same order as
  // OpTypePattern::inputs (TODO confirm against GraphMatcher).
  std::vector<NodeMatch> inputs;
  // Returns a human-readable description of the match, for debugging.
  string DebugString() const;
};
160
// Utility class to spot subgraphs matching particular patterns.
// NOTE(review): node_map_ holds raw NodeDef pointers, presumably into
// graph_def_. The implicitly generated copy operations would copy those
// pointers unchanged, so a copied GraphMatcher could point into the source
// object's graph_def_. Avoid copying, or confirm against the .cc and delete
// the copy operations.
class GraphMatcher {
 public:
  // Builds a matcher for graph_def. The matcher keeps its own GraphDef
  // (graph_def_), presumably populated from graph_def in the constructor.
  GraphMatcher(const GraphDef& graph_def);

  // Finds the subgraphs matching `pattern` and stores them in *matches.
  // Sorts the input nodes into execution order, and then skips any previously
  // matched nodes so that no node appears in more than one match. Each
  // NodeMatch holds NodeDef copies (NodeMatch::node is stored by value), so the
  // results do not refer back to this GraphMatcher.
  Status GetOpTypeMatches(const OpTypePattern& pattern,
                          std::vector<NodeMatch>* matches);

 private:
  // Presumably checks whether `node` (and, recursively, its inputs) matches
  // `pattern` while rejecting nodes in previously_matched_nodes, filling
  // *match on success — TODO confirm in the .cc.
  bool DoesOpTypeMatch(const NodeDef& node, const OpTypePattern& pattern,
                       const std::set<string>& previously_matched_nodes,
                       NodeMatch* match);

  // The graph being searched.
  GraphDef graph_def_;
  // Name-to-node lookup table, presumably built over graph_def_.
  std::map<string, const NodeDef*> node_map_;
};
181
struct ReplaceMatchingOpTypesOptions {
  // Whether to tolerate a graph that is left with dangling inputs (and skip the
  // related integrity checks) instead of raising an error. If you enable this
  // option, you must fix inconsistencies in a later pass.
  // Defaults to false (strict checking) so that a default-constructed options
  // object never reads an uninitialized flag.
  bool allow_inconsistencies = false;
};
187
// Replaces all of the matching sub-graphs with new ops. This calls into the
// given function, and expects to receive a set of new nodes to replace each
// matched sub-graph. It has some logic to protect the integrity of the
// resulting graph, for example making sure that nodes needed by other nodes
// outside the sub-graph aren't removed. These are passed in as the set of
// outputs, and nodes with the same names must be added to the new nodes
// produced by the replacement function. Many of these checks can be disabled
// by setting allow_inconsistencies to true in the options, but then it's the
// caller's responsibility to patch up any problems before passing on the graph
// to others. There's more comprehensive usage documentation in the README.
//
// node_generator receives, in order: the match, two sets of node names
// (presumably the match's external inputs, then the outputs described above
// — TODO confirm in the .cc), and the vector to fill with replacement nodes.
// A non-OK Status from node_generator is presumably propagated to the caller.
Status ReplaceMatchingOpTypes(
    const GraphDef& input_graph_def, const OpTypePattern& pattern,
    const std::function<Status(const NodeMatch&, const std::set<string>&,
                               const std::set<string>&, std::vector<NodeDef>*)>&
        node_generator,
    const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def);
204
// Returns a list of the unique nodes found in this match, via *result. The
// nodes are NodeDef copies (the vector holds values).
void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result);

// Changes all input references to a particular node name. inputs_to_rename
// maps an existing input name to its replacement (whether keys must include a
// ':<n>' suffix is not visible here — TODO confirm). Any nodes with names
// listed in nodes_to_ignore will not have their inputs rewritten.
Status RenameNodeInputs(const GraphDef& input_graph_def,
                        const std::map<string, string>& inputs_to_rename,
                        const std::unordered_set<string>& nodes_to_ignore,
                        GraphDef* output_graph_def);

// Utility function that copies all the nodes found in a match into the
// new_nodes list. This is useful in replacement functions when you decide to
// leave the original matched subgraph untouched and make no changes.
void CopyOriginalMatch(const NodeMatch& match, std::vector<NodeDef>* new_nodes);
219
220// Holds information that's needed for transform functions.
221typedef std::map<string, std::vector<string>> TransformFuncParameters;
222struct TransformFuncContext {
223 std::vector<string> input_names;
224 std::vector<string> output_names;
225 TransformFuncParameters params;
226
227 // Returns how many occurrences of the given parameter are present.
228 int CountParameters(const string& name) const;
229
230 // Gets a single instance of a parameter, using a default if it's not present.
231 Status GetOneStringParameter(const string& name, const string& default_value,
232 string* result) const;
233
234 // Gets a single occurrence of a parameter as a 32-bit integer, falling back
235 // to a default if it isn't present and returning an error if it isn't
236 // convertible to a number.
237 Status GetOneInt32Parameter(const string& name, int32_t default_value,
238 int32* result) const;
239
240 // Gets a single occurrence of a parameter as a 64-bit integer, falling back
241 // to a default if it isn't present and returning an error if it isn't
242 // convertible to a number.
243 Status GetOneInt64Parameter(const string& name, int64_t default_value,
244 int64_t* result) const;
245
246 // Gets a single occurrence of a parameter as a floating point number, falling
247 // back to a default if it isn't present and returning an error if it isn't
248 // convertible to a number.
249 Status GetOneFloatParameter(const string& name, float default_value,
250 float* result) const;
251
252 // Gets a single occurrence of a parameter as a boolean, falling back to a
253 // default if it isn't present and returning an error if it's not one of
254 // "true", "1", "false", or "0".
255 Status GetOneBoolParameter(const string& name, bool default_value,
256 bool* result) const;
257};
258
259// This is the function API for all graph transformations, taking an input
260// GraphDef and other arguments, and returning a transformed GraphDef.
261typedef std::function<Status(const GraphDef&,
262 const TransformFuncContext& context, GraphDef*)>
263 TransformFunc;
264
265// To add a new graph transform function, call the macro:
266// REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);
267// Under the hood this adds the function to the list of known transforms, so you
268// just need to link in the .cc file with your registration call to have access
269// to it through the command line tool.
270// The rest of the machinery below is to enable that automagical registration.
271typedef std::map<string, TransformFunc> TransformRegistry;
272TransformRegistry* GetTransformRegistry();
273class TransformRegistrar {
274 public:
275 TransformRegistrar(const string& name, TransformFunc transform_func) {
276 TransformRegistry* transform_registry = GetTransformRegistry();
277 (*transform_registry)[name] = transform_func;
278 }
279};
// Registers `func` under the transform name `name` by defining a static
// TransformRegistrar object; the registration happens when that object is
// initialized.
#define REGISTER_GRAPH_TRANSFORM(name, func) \
  REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(__COUNTER__, name, func)
// Extra level of indirection so that __COUNTER__ is expanded to a number
// before it is token-pasted below: macro arguments adjacent to ## are not
// expanded first.
#define REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(ctr, name, func) \
  REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func)
// Defines a uniquely named static registrar object. The expansion already ends
// in ';', so a call site's own trailing ';' leaves an empty declaration.
#define REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func) \
  static tensorflow::graph_transforms::TransformRegistrar \
      registrar__body__##ctr##__object(name, func);
287
288} // namespace graph_transforms
289} // namespace tensorflow
290
291#endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
292