1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_
18
19#include "tensorflow/core/framework/graph.pb.h"
20#include "tensorflow/core/graph/graph.h"
21#include "tensorflow/core/graph/tensor_id.h"
22#include "tensorflow/core/lib/core/status.h"
23
24namespace tensorflow {
25class ShapeRefiner;
26
27// Construct a Graph *g out of a GraphDef gdef. Returns non-OK on
28// error, in which case *g is left in an incomplete state.
29//
30// *g is expected to be an empty graph (with no more than a source and sink
31// nodes) when provided to ConvertGraphDefToGraph. To enhance an existing Graph,
32// see ImportGraphDef.
33struct GraphConstructorOptions {
34 GraphConstructorOptions() {}
35
36 // If true, allows internal ops in the GraphDef.
37 bool allow_internal_ops = false;
38
39 // If true, the graph def is expected to have fully specified
40 // devices for all nodes. A node in the resulting graph "g" has the
41 // device name set accordingly.
42 //
43 // TODO(zhifengc): if possible, consider removing this option.
44 bool expect_device_spec = false;
45
46 // If true, validates that nodes being converted have all expected attrs
47 // set and no unknown attrs set by calling ValidateNodeDef().
48 // Setting validate_nodes without add_default_attributes, will fail if
49 // the GraphDef does not have all required attributes set.
50 bool validate_nodes = false;
51
52 // If true, GraphConstructor will add attributes with their default
53 // value to the Node when they are missing from the NodeDef.
54 bool add_default_attributes = true;
55};
56extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
57 const GraphDef& gdef, Graph* g);
58extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
59 GraphDef&& gdef, Graph* g);
60
61// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function
62// instantiation.
63// TODO(irving): This will turn into std::vector<NodeInfoPtr> soon.
64extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
65 gtl::ArraySlice<NodeDef> nodes, Graph* g);
66
67// Options for calling ImportGraphDef().
68struct ImportGraphDefOptions {
69 ImportGraphDefOptions()
70 : uniquify_names(false),
71 uniquify_prefix(false),
72 skip_mapped_nodes(false),
73 validate_shape(true) {}
74
75 // Name prefix to use for nodes imported from the GraphDef. For example, if
76 // prefix="animals" and GraphDef contains a node "bunny" then the node will be
77 // named "animals/bunny" in *g. Must not be already used as a node name or
78 // prefix in the graph.
79 string prefix;
80
81 // If true, imported node names will be modified if their name already exists
82 // in the graph. If false, conflicting names will be treated as an error. Note
83 // that this option has no effect if `prefix` is specified, since `prefix`
84 // will guarantee all node names are unique.
85 bool uniquify_names;
86
87 // If true, `prefix` will be modified if it already exists as a node name or
88 // prefix in the graph. If false, a conflicting prefix will be treated as an
89 // error. This option has no effect if `prefix` isn't specified.
90 bool uniquify_prefix;
91
92 // Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef`
93 // corresponding to `input_map` keys will be remapped to the nodes in `g`
94 // corresponding to the values.
95 //
96 // Keys should not include `prefix`, i.e., a key ID's name should be the name
97 // as it originally appears in `gdef`.
98 //
99 // If this is non-empty, ImportGraphDef must be called with the shape refiner
100 // used to create the existing nodes referenced in `input_map`.
101 // TODO(skyewm): can we remove this requirement? How do we access the original
102 // shape refiner?
103 std::map<SafeTensorId, SafeTensorId> input_map;
104
105 // If true, nodes that will have all output edges removed because of
106 // overrides in `input_map` will not be imported.
107 bool skip_mapped_nodes;
108
109 // The names of existing nodes in `g` that the imported graph should have
110 // control dependencies on.
111 //
112 // Note that to avoid creating many redundant control edges, ImportGraphDef()
113 // won't add control edges to nodes that will inherit the dependencies from
114 // other nodes in `gdef`.
115 std::vector<string> control_dependencies;
116
117 // Tensors in `gdef` that will be returned via the ImportGraphDefResults
118 // output parameter of `ImportGraphDef()`. If this list is non-empty, the
119 // caller must pass a results object to `ImportGraphDef()`. The
120 // `return_tensors` field will be populated with the imported nodes in `g`.
121 //
122 // Entries should not include `prefix`, i.e., each ID's name should be the
123 // name as it originally appears in `gdef`.
124 //
125 // If this contains a tensor that's also being remapped via `input_map`, the
126 // corresponding existing tensor in `g` will be returned.
127 std::vector<SafeTensorId> return_tensors;
128
129 // The names of nodes in `gdef` that will be returned via the
130 // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
131 // is non-empty, the caller must pass a results object to
132 // `ImportGraphDef()`. The `return_nodes` field will be populated with the
133 // imported nodes in `g`.
134 //
135 // Entries should not include `prefix`, i.e., each node's name should be the
136 // name as it originally appears in `gdef`.
137 //
138 // Unlike `return_tensors`, `input_map` has no effect on the nodes
139 // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
140 // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
141 std::vector<string> return_nodes;
142
143 // If true, checks that all colocation constraints are nodes in the GraphDef.
144 bool validate_colocation_constraints = true;
145
146 // If false skips shape validation.
147 bool validate_shape;
148
149 // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
150 // with ops that are not defined in the binary calling ImportGraphDef.
151 // Similar to the producer_op_list argument to import_graph_def in the
152 // python API.
153
154 // Try to set default execution device for this grapth.
155 string default_device;
156};
157
158// Optional results that may be returned by ImportGraphDef.
159struct ImportGraphDefResults {
160 // The requested tensors associated with
161 // ImportGraphDefOptions::return_tensors. Note that the index may be different
162 // than the requested index if the returned tensor has been remapped according
163 // to `input_map`.
164 typedef int Index;
165 std::vector<std::pair<Node*, Index>> return_tensors;
166
167 // The requested nodes associated with ImportGraphDefOptions::return_nodes.
168 std::vector<Node*> return_nodes;
169
170 // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
171 // weren't used as an input to any node in `gdef`. These keys are likely due
172 // to typos, and callers may wish to treat their existence as an error.
173 std::vector<SafeTensorId> missing_unused_input_map_keys;
174};
175
176// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
177//
178// On error, returns non-OK and leaves `*g` unmodified.
179//
180// `refiner` can be null. It should be non-null if the caller
181// intends to add additional nodes to the graph after the import. This
182// allows the caller to validate shapes of those nodes (since
183// ShapeRefiner::AddNode must be called in topological order).
184//
185// `results` must be non-null if `opts.return_tensors` or `opts.result_nodes` is
186// non-empty. It can also be set to fetch the unused input map keys. If it's
187// non-null, all the vector fields must be empty.
188//
189// TODO(ashankar): Push this mechanism and get rid of Session::Extend()
190// as a means of enhancing an existing Graph.
191extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
192 const GraphDef& gdef, Graph* g,
193 ShapeRefiner* refiner,
194 ImportGraphDefResults* results = nullptr);
195
196// Make a copy of "src" into "*dest".
197//
198// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
199// other than the implicit Source/Sink nodes.
200extern void CopyGraph(const Graph& src, Graph* dest);
201
202} // namespace tensorflow
203
204#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_
205