1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
17#define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
18
19#include <vector>
20
21#include "tensorflow/core/framework/function.pb.h"
22#include "tensorflow/core/framework/graph.pb.h"
23#include "tensorflow/core/framework/op.h"
24#include "tensorflow/core/graph/graph.h"
25#include "tensorflow/core/graph/node_builder.h"
26#include "tensorflow/core/lib/core/status.h"
27#include "tensorflow/core/lib/core/stringpiece.h"
28#include "tensorflow/core/lib/gtl/array_slice.h"
29
30namespace tensorflow {
31
// Given a function like:
//   namespace ops {
//   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
//     if (opts.HaveError()) return nullptr;
//     static const string kOpName = "Identity";
//     NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,
//                              opts.op_registry());
//     node_builder.Input(input);
//     return opts.FinalizeBuilder(&node_builder);
//   }
//   }  // namespace ops
//
//   // Or, alternatively:
//   namespace ops {
//   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
//     static const string kOpName = "Identity";
//     return UnaryOp(kOpName, input, opts);
//   }
//   }  // namespace ops
//
// You call it like:
//   GraphDefBuilder b;
//   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
//   Node* na = Const(7, b.opts());
//   // Note: WithName() returns a copy, so b.opts() itself is unchanged.
//   Node* nb = Const(5, b.opts().WithName("control-input"));
//   Node* nc = Identity(na, b.opts().WithControlInput(nb));
//   GraphDef graph_def;
//   Status status = b.ToGraphDef(&graph_def);
//   if (!status.ok()) { /* Handle error */ }
//
// In tests you can skip the status handling (errors then cause CHECK
// failures as soon as they occur) via:
//   GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
//   ...
//   b.ToGraphDef(&graph_def);
67
68class GraphDefBuilder {
69 public:
70 // Options for adding a Node to a Graph.
71 class Options {
72 public:
73 // Sets the Graph (that Nodes will be added to) and the status. The
74 // status may be set to nullptr, in which case errors cause CHECK
75 // failures. The graph and status must outlive *this.
76 Options(Graph* graph, Status* status);
77 ~Options();
78
79 // Methods for setting options. These are const methods: they
80 // return a copy of *this with the option set.
81 Options WithName(StringPiece name) const;
82 Options WithDevice(StringPiece device) const;
83 Options WithControlInput(Node* control_input) const;
84 Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
85
86 // Override the default value for an optional attr.
87 template <class T>
88 Options WithAttr(StringPiece attr_name, T&& value) const {
89 return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value));
90 }
91 // Note: overload needed to allow {...} expressions for value.
92 template <class T>
93 Options WithAttr(StringPiece attr_name,
94 std::initializer_list<T> value) const {
95 return WithAttr<std::initializer_list<T>>(attr_name, std::move(value));
96 }
97
98 // Methods for using options from a function that creates a Node.
99
100 // Returns true if the status associated with *this has an error.
101 // Use this to skip processing that may depend on prior results.
102 bool HaveError() const { return status_ != nullptr && !status_->ok(); }
103
104 // Returns a string representation of the status associated with *this.
105 // Returns the string `"OK"` if the status doesn't have any error.
106 string StatusToString() const {
107 return status_->ok() ? "OK" : status_->error_message();
108 }
109
110 // Given the Op type name, return a name for a node of that type.
111 // Uses the value set in WithName() if that has been called. Otherwise,
112 // returns a name built out of the Op type name.
113 string GetNameForOp(StringPiece op) const;
114
115 // Sets the device, adds control inputs, adds attrs, and calls Finalize().
116 // If Finalize returns an error, it is saved and this function returns
117 // nullptr.
118 Node* FinalizeBuilder(NodeBuilder* builder) const;
119
120 // Updates the associated status, if any, or calls TF_CHECK_OK if none.
121 void UpdateStatus(const Status& status) const;
122
123 // Accessor
124 const OpRegistryInterface* op_registry() const {
125 return graph_->op_registry();
126 }
127
128 private:
129 Options WithNameImpl(StringPiece name);
130 Options WithDeviceImpl(StringPiece device);
131 Options WithControlInputImpl(Node* control_input);
132 Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
133 template <class T>
134 Options WithAttrImpl(StringPiece name, T&& value) {
135 attrs_.emplace_back(string(name), AttrValue());
136 SetAttrValue(std::forward<T>(value), &attrs_.back().second);
137 return *this;
138 }
139
140 Graph* const graph_;
141 Status* const status_;
142 string name_;
143 string device_;
144 std::vector<Node*> control_inputs_;
145 std::vector<std::pair<string, AttrValue>> attrs_;
146 };
147
148 // Start building a new graph.
149 explicit GraphDefBuilder(
150 const OpRegistryInterface* op_registry = OpRegistry::Global())
151 : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, &status_) {}
152
153 // For use in tests, where you want to fail immediately on error instead
154 // of checking the status at the end.
155 enum TestFailImmediatelyType { kFailImmediately };
156 explicit GraphDefBuilder(
157 TestFailImmediatelyType,
158 const OpRegistryInterface* op_registry = OpRegistry::Global())
159 : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, nullptr) {}
160
161 // Gets the Options with the associated Graph and Status.
162 const Options& opts() const { return opts_; }
163
164 // Once all the nodes have been added, call this to get whether it was
165 // successful, and if so fill *graph_def.
166 Status ToGraphDef(GraphDef* graph_def) const;
167
168 // Adds the function and gradient definitions in `fdef_lib` to this graph's op
169 // registry. Ignores duplicate functions, and returns a bad status if an
170 // imported function differs from an existing function or op with the same
171 // name.
172 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
173 return flib_def_.AddLibrary(fdef_lib);
174 }
175
176 // Returns whether a user-defined function with `name` already exists in the
177 // graph.
178 bool HasFunction(const string& name) {
179 return flib_def_.Find(name) != nullptr;
180 }
181
182 private:
183 Graph graph_;
184 FunctionLibraryDefinition flib_def_;
185 Status status_;
186 Options opts_;
187};
188
189namespace ops {
190
191// A NodeOut may either be a regular input or back input. Regular
192// inputs are specified via either a Node* or a Node* and an output
193// index. Back inputs are specified by a node name, output index, and
194// output type.
195typedef NodeBuilder::NodeOut NodeOut;
196
197// For adding an Op with no inputs to a GraphDefBuilder.
198Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts);
199
200// For adding an Op with one input to a GraphDefBuilder.
201Node* UnaryOp(const string& op_name, NodeOut input,
202 const GraphDefBuilder::Options& opts);
203
204// For adding an Op with two inputs to a GraphDefBuilder.
205Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
206 const GraphDefBuilder::Options& opts);
207
208// For adding an Op with three inputs to a GraphDefBuilder.
209Node* TernaryOp(const string& op_name, NodeOut a, NodeOut b, NodeOut c,
210 const GraphDefBuilder::Options& opts);
211
212} // namespace ops
213} // namespace tensorflow
214
215#endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
216