1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_ |
18 | |
19 | #include "tensorflow/core/framework/graph.pb.h" |
20 | #include "tensorflow/core/graph/graph.h" |
21 | #include "tensorflow/core/graph/tensor_id.h" |
22 | #include "tensorflow/core/lib/core/status.h" |
23 | |
24 | namespace tensorflow { |
25 | class ShapeRefiner; |
26 | |
// Options controlling ConvertGraphDefToGraph() and ConvertNodeDefsToGraph().
//
// ConvertGraphDefToGraph() constructs a Graph *g out of a GraphDef gdef. It
// returns non-OK on error, in which case *g is left in an incomplete state.
//
// *g is expected to be an empty graph (containing no nodes other than the
// source and sink nodes) when provided to ConvertGraphDefToGraph. To add to an
// existing Graph, see ImportGraphDef.
struct GraphConstructorOptions {
  // User-provided (not defaulted) so the struct is not an aggregate.
  GraphConstructorOptions() {}

  // When true, internal ops are permitted in the input GraphDef.
  bool allow_internal_ops{false};

  // When true, every node in the GraphDef must carry a fully specified device,
  // and each node created in the resulting graph "g" gets that device name.
  //
  // TODO(zhifengc): if possible, consider removing this option.
  bool expect_device_spec{false};

  // When true, each converted node is checked with ValidateNodeDef() for
  // missing expected attrs and for unknown attrs. If this is enabled while
  // add_default_attributes is disabled, conversion fails whenever the GraphDef
  // omits a required attribute.
  bool validate_nodes{false};

  // When true (the default), attributes missing from a NodeDef are added to the
  // created Node with their default values.
  bool add_default_attributes{true};
};
56 | extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, |
57 | const GraphDef& gdef, Graph* g); |
58 | extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, |
59 | GraphDef&& gdef, Graph* g); |
60 | |
61 | // Same as ConvertGraphDefToGraph, but takes just nodes. Used by function |
62 | // instantiation. |
63 | // TODO(irving): This will turn into std::vector<NodeInfoPtr> soon. |
64 | extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, |
65 | gtl::ArraySlice<NodeDef> nodes, Graph* g); |
66 | |
67 | // Options for calling ImportGraphDef(). |
68 | struct ImportGraphDefOptions { |
69 | ImportGraphDefOptions() |
70 | : uniquify_names(false), |
71 | uniquify_prefix(false), |
72 | skip_mapped_nodes(false), |
73 | validate_shape(true) {} |
74 | |
75 | // Name prefix to use for nodes imported from the GraphDef. For example, if |
76 | // prefix="animals" and GraphDef contains a node "bunny" then the node will be |
77 | // named "animals/bunny" in *g. Must not be already used as a node name or |
78 | // prefix in the graph. |
79 | string prefix; |
80 | |
81 | // If true, imported node names will be modified if their name already exists |
82 | // in the graph. If false, conflicting names will be treated as an error. Note |
83 | // that this option has no effect if `prefix` is specified, since `prefix` |
84 | // will guarantee all node names are unique. |
85 | bool uniquify_names; |
86 | |
87 | // If true, `prefix` will be modified if it already exists as a node name or |
88 | // prefix in the graph. If false, a conflicting prefix will be treated as an |
89 | // error. This option has no effect if `prefix` isn't specified. |
90 | bool uniquify_prefix; |
91 | |
92 | // Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef` |
93 | // corresponding to `input_map` keys will be remapped to the nodes in `g` |
94 | // corresponding to the values. |
95 | // |
96 | // Keys should not include `prefix`, i.e., a key ID's name should be the name |
97 | // as it originally appears in `gdef`. |
98 | // |
99 | // If this is non-empty, ImportGraphDef must be called with the shape refiner |
100 | // used to create the existing nodes referenced in `input_map`. |
101 | // TODO(skyewm): can we remove this requirement? How do we access the original |
102 | // shape refiner? |
103 | std::map<SafeTensorId, SafeTensorId> input_map; |
104 | |
105 | // If true, nodes that will have all output edges removed because of |
106 | // overrides in `input_map` will not be imported. |
107 | bool skip_mapped_nodes; |
108 | |
109 | // The names of existing nodes in `g` that the imported graph should have |
110 | // control dependencies on. |
111 | // |
112 | // Note that to avoid creating many redundant control edges, ImportGraphDef() |
113 | // won't add control edges to nodes that will inherit the dependencies from |
114 | // other nodes in `gdef`. |
115 | std::vector<string> control_dependencies; |
116 | |
117 | // Tensors in `gdef` that will be returned via the ImportGraphDefResults |
118 | // output parameter of `ImportGraphDef()`. If this list is non-empty, the |
119 | // caller must pass a results object to `ImportGraphDef()`. The |
120 | // `return_tensors` field will be populated with the imported nodes in `g`. |
121 | // |
122 | // Entries should not include `prefix`, i.e., each ID's name should be the |
123 | // name as it originally appears in `gdef`. |
124 | // |
125 | // If this contains a tensor that's also being remapped via `input_map`, the |
126 | // corresponding existing tensor in `g` will be returned. |
127 | std::vector<SafeTensorId> return_tensors; |
128 | |
129 | // The names of nodes in `gdef` that will be returned via the |
130 | // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list |
131 | // is non-empty, the caller must pass a results object to |
132 | // `ImportGraphDef()`. The `return_nodes` field will be populated with the |
133 | // imported nodes in `g`. |
134 | // |
135 | // Entries should not include `prefix`, i.e., each node's name should be the |
136 | // name as it originally appears in `gdef`. |
137 | // |
138 | // Unlike `return_tensors`, `input_map` has no effect on the nodes |
139 | // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true. |
140 | // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need. |
141 | std::vector<string> return_nodes; |
142 | |
143 | // If true, checks that all colocation constraints are nodes in the GraphDef. |
144 | bool validate_colocation_constraints = true; |
145 | |
146 | // If false skips shape validation. |
147 | bool validate_shape; |
148 | |
149 | // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries |
150 | // with ops that are not defined in the binary calling ImportGraphDef. |
151 | // Similar to the producer_op_list argument to import_graph_def in the |
152 | // python API. |
153 | |
154 | // Try to set default execution device for this grapth. |
155 | string default_device; |
156 | }; |
157 | |
158 | // Optional results that may be returned by ImportGraphDef. |
159 | struct ImportGraphDefResults { |
160 | // The requested tensors associated with |
161 | // ImportGraphDefOptions::return_tensors. Note that the index may be different |
162 | // than the requested index if the returned tensor has been remapped according |
163 | // to `input_map`. |
164 | typedef int Index; |
165 | std::vector<std::pair<Node*, Index>> return_tensors; |
166 | |
167 | // The requested nodes associated with ImportGraphDefOptions::return_nodes. |
168 | std::vector<Node*> return_nodes; |
169 | |
170 | // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and |
171 | // weren't used as an input to any node in `gdef`. These keys are likely due |
172 | // to typos, and callers may wish to treat their existence as an error. |
173 | std::vector<SafeTensorId> missing_unused_input_map_keys; |
174 | }; |
175 | |
176 | // Adds the graph in GraphDef `gdef` into an existing Graph `*g`. |
177 | // |
178 | // On error, returns non-OK and leaves `*g` unmodified. |
179 | // |
180 | // `refiner` can be null. It should be non-null if the caller |
181 | // intends to add additional nodes to the graph after the import. This |
182 | // allows the caller to validate shapes of those nodes (since |
183 | // ShapeRefiner::AddNode must be called in topological order). |
184 | // |
185 | // `results` must be non-null if `opts.return_tensors` or `opts.result_nodes` is |
186 | // non-empty. It can also be set to fetch the unused input map keys. If it's |
187 | // non-null, all the vector fields must be empty. |
188 | // |
189 | // TODO(ashankar): Push this mechanism and get rid of Session::Extend() |
190 | // as a means of enhancing an existing Graph. |
191 | extern Status ImportGraphDef(const ImportGraphDefOptions& opts, |
192 | const GraphDef& gdef, Graph* g, |
193 | ShapeRefiner* refiner, |
194 | ImportGraphDefResults* results = nullptr); |
195 | |
196 | // Make a copy of "src" into "*dest". |
197 | // |
198 | // REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges |
199 | // other than the implicit Source/Sink nodes. |
200 | extern void CopyGraph(const Graph& src, Graph* dest); |
201 | |
202 | } // namespace tensorflow |
203 | |
204 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_ |
205 | |