1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ |
17 | #define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ |
18 | |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/framework/function.pb.h" |
22 | #include "tensorflow/core/framework/graph.pb.h" |
23 | #include "tensorflow/core/framework/op.h" |
24 | #include "tensorflow/core/graph/graph.h" |
25 | #include "tensorflow/core/graph/node_builder.h" |
26 | #include "tensorflow/core/lib/core/status.h" |
27 | #include "tensorflow/core/lib/core/stringpiece.h" |
28 | #include "tensorflow/core/lib/gtl/array_slice.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | // Given a function like: |
33 | // namespace ops { |
34 | // Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { |
35 | // if (opts.HaveError()) return nullptr; |
36 | // static const string kOpName = "Identity"; |
37 | // NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName, |
38 | // opts.op_registry()); |
39 | // node_builder.Input(input); |
40 | // return opts.FinalizeBuilder(&node_builder); |
41 | // } |
42 | // } // namespace ops |
43 | // |
44 | // // Or, alternatively: |
45 | // namespace ops { |
46 | // Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { |
47 | // static const string kOpName = "Identity"; |
48 | // return UnaryOp(kOpName, input, opts); |
49 | // } |
50 | // } // namespace ops |
51 | // |
52 | // You call it like: |
53 | // GraphDefBuilder b; |
54 | // using namespace ::tensorflow::ops; // NOLINT(build/namespaces) |
55 | // Node* na = Const(7, b.opts()); |
56 | // // Note: WithName() returns a copy, opts is unchanged. |
57 | // Node* nb = Const(5, b.opts().WithName("control-input")); |
58 | // Node* nc = Identity(na, b.opts().WithControlInput(nb)); |
59 | // GraphDef graph_def; |
60 | // Status status = b.ToGraphDef(&graph_def); |
61 | // if (!status.ok()) { /* Handle error */ } |
62 | // |
63 | // In tests you can skip the status handling via: |
64 | // GraphDefBuilder b(GraphDefBuilder::kFailImmediately); |
65 | // ... |
66 | // b.ToGraphDef(&graph_def); |
67 | |
68 | class GraphDefBuilder { |
69 | public: |
70 | // Options for adding a Node to a Graph. |
71 | class Options { |
72 | public: |
73 | // Sets the Graph (that Nodes will be added to) and the status. The |
74 | // status may be set to nullptr, in which case errors cause CHECK |
75 | // failures. The graph and status must outlive *this. |
76 | Options(Graph* graph, Status* status); |
77 | ~Options(); |
78 | |
79 | // Methods for setting options. These are const methods: they |
80 | // return a copy of *this with the option set. |
81 | Options WithName(StringPiece name) const; |
82 | Options WithDevice(StringPiece device) const; |
83 | Options WithControlInput(Node* control_input) const; |
84 | Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const; |
85 | |
86 | // Override the default value for an optional attr. |
87 | template <class T> |
88 | Options WithAttr(StringPiece attr_name, T&& value) const { |
89 | return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value)); |
90 | } |
91 | // Note: overload needed to allow {...} expressions for value. |
92 | template <class T> |
93 | Options WithAttr(StringPiece attr_name, |
94 | std::initializer_list<T> value) const { |
95 | return WithAttr<std::initializer_list<T>>(attr_name, std::move(value)); |
96 | } |
97 | |
98 | // Methods for using options from a function that creates a Node. |
99 | |
100 | // Returns true if the status associated with *this has an error. |
101 | // Use this to skip processing that may depend on prior results. |
102 | bool HaveError() const { return status_ != nullptr && !status_->ok(); } |
103 | |
104 | // Returns a string representation of the status associated with *this. |
105 | // Returns the string `"OK"` if the status doesn't have any error. |
106 | string StatusToString() const { |
107 | return status_->ok() ? "OK" : status_->error_message(); |
108 | } |
109 | |
110 | // Given the Op type name, return a name for a node of that type. |
111 | // Uses the value set in WithName() if that has been called. Otherwise, |
112 | // returns a name built out of the Op type name. |
113 | string GetNameForOp(StringPiece op) const; |
114 | |
115 | // Sets the device, adds control inputs, adds attrs, and calls Finalize(). |
116 | // If Finalize returns an error, it is saved and this function returns |
117 | // nullptr. |
118 | Node* FinalizeBuilder(NodeBuilder* builder) const; |
119 | |
120 | // Updates the associated status, if any, or calls TF_CHECK_OK if none. |
121 | void UpdateStatus(const Status& status) const; |
122 | |
123 | // Accessor |
124 | const OpRegistryInterface* op_registry() const { |
125 | return graph_->op_registry(); |
126 | } |
127 | |
128 | private: |
129 | Options WithNameImpl(StringPiece name); |
130 | Options WithDeviceImpl(StringPiece device); |
131 | Options WithControlInputImpl(Node* control_input); |
132 | Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs); |
133 | template <class T> |
134 | Options WithAttrImpl(StringPiece name, T&& value) { |
135 | attrs_.emplace_back(string(name), AttrValue()); |
136 | SetAttrValue(std::forward<T>(value), &attrs_.back().second); |
137 | return *this; |
138 | } |
139 | |
140 | Graph* const graph_; |
141 | Status* const status_; |
142 | string name_; |
143 | string device_; |
144 | std::vector<Node*> control_inputs_; |
145 | std::vector<std::pair<string, AttrValue>> attrs_; |
146 | }; |
147 | |
148 | // Start building a new graph. |
149 | explicit GraphDefBuilder( |
150 | const OpRegistryInterface* op_registry = OpRegistry::Global()) |
151 | : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, &status_) {} |
152 | |
153 | // For use in tests, where you want to fail immediately on error instead |
154 | // of checking the status at the end. |
155 | enum TestFailImmediatelyType { kFailImmediately }; |
156 | explicit GraphDefBuilder( |
157 | TestFailImmediatelyType, |
158 | const OpRegistryInterface* op_registry = OpRegistry::Global()) |
159 | : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, nullptr) {} |
160 | |
161 | // Gets the Options with the associated Graph and Status. |
162 | const Options& opts() const { return opts_; } |
163 | |
164 | // Once all the nodes have been added, call this to get whether it was |
165 | // successful, and if so fill *graph_def. |
166 | Status ToGraphDef(GraphDef* graph_def) const; |
167 | |
168 | // Adds the function and gradient definitions in `fdef_lib` to this graph's op |
169 | // registry. Ignores duplicate functions, and returns a bad status if an |
170 | // imported function differs from an existing function or op with the same |
171 | // name. |
172 | Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { |
173 | return flib_def_.AddLibrary(fdef_lib); |
174 | } |
175 | |
176 | // Returns whether a user-defined function with `name` already exists in the |
177 | // graph. |
178 | bool HasFunction(const string& name) { |
179 | return flib_def_.Find(name) != nullptr; |
180 | } |
181 | |
182 | private: |
183 | Graph graph_; |
184 | FunctionLibraryDefinition flib_def_; |
185 | Status status_; |
186 | Options opts_; |
187 | }; |
188 | |
189 | namespace ops { |
190 | |
191 | // A NodeOut may either be a regular input or back input. Regular |
192 | // inputs are specified via either a Node* or a Node* and an output |
193 | // index. Back inputs are specified by a node name, output index, and |
194 | // output type. |
195 | typedef NodeBuilder::NodeOut NodeOut; |
196 | |
197 | // For adding an Op with no inputs to a GraphDefBuilder. |
198 | Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts); |
199 | |
200 | // For adding an Op with one input to a GraphDefBuilder. |
201 | Node* UnaryOp(const string& op_name, NodeOut input, |
202 | const GraphDefBuilder::Options& opts); |
203 | |
204 | // For adding an Op with two inputs to a GraphDefBuilder. |
205 | Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, |
206 | const GraphDefBuilder::Options& opts); |
207 | |
208 | // For adding an Op with three inputs to a GraphDefBuilder. |
209 | Node* TernaryOp(const string& op_name, NodeOut a, NodeOut b, NodeOut c, |
210 | const GraphDefBuilder::Options& opts); |
211 | |
212 | } // namespace ops |
213 | } // namespace tensorflow |
214 | |
215 | #endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ |
216 | |