1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
17#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
18
19#include <functional>
20#include <vector>
21
22#include "tensorflow/core/framework/attr_value_util.h"
23#include "tensorflow/core/framework/node_def.pb.h"
24#include "tensorflow/core/framework/node_def_util.h"
25#include "tensorflow/core/framework/op.h"
26#include "tensorflow/core/framework/op_def.pb.h"
27#include "tensorflow/core/framework/types.h"
28#include "tensorflow/core/graph/graph.h"
29#include "tensorflow/core/graph/graph_node_util.h"
30#include "tensorflow/core/lib/core/status.h"
31#include "tensorflow/core/lib/gtl/array_slice.h"
32#include "tensorflow/core/lib/strings/strcat.h"
33
34namespace tensorflow {
35
36class NodeDefBuilder;
37typedef std::function<Status(const OpDef&, int, const NodeDef&,
38 NodeDefBuilder*)>
39 FakeInputFunctor;
40
41// This is a helper for creating a NodeDef. Automatically sets attrs
42// that can be inferred from the inputs, and uses default values
43// (where they exist) for unspecified attrs. Example usage:
44//
45// NodeDef node_def;
46// Status status = NodeDefBuilder(node_name, op_name)
47// .Input(...)
48// .Attr(...)
49// .Finalize(&node_def);
50// if (!status.ok()) return status;
51// // Use node_def here.
52class NodeDefBuilder {
53 public:
54 // To specify an output to be consumed by one of the Input() methods below.
55 struct NodeOut {
56 NodeOut(StringPiece n, int i, DataType dt);
57 NodeOut(); // uninitialized, call Reset() before use.
58 void Reset(StringPiece n, int i, DataType dt);
59 string node;
60 int index;
61 DataType data_type;
62 };
63
64 // Specify the name and the Op (either via an OpDef or the name of
65 // the Op plus a registry) for the NodeDef. Other fields are
66 // specified by calling the methods below.
67 // REQUIRES: The OpDef must satisfy ValidateOpDef().
68 NodeDefBuilder(StringPiece name, StringPiece op_name,
69 const OpRegistryInterface* op_registry = OpRegistry::Global(),
70 const NodeDebugInfo* debug = nullptr);
71 NodeDefBuilder(StringPiece name, StringPiece op_name,
72 const NodeDebugInfo& debug);
73 // REQUIRES: in addition, *op_def must outlive *this.
74 NodeDefBuilder(StringPiece name, const OpDef* op_def);
75
76 // You must call one Input() function per input_arg in the Op,
77 // *and in the same order as the input_args appear in the OpDef.*
78
79 // For inputs that take a single tensor.
80 NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt);
81 NodeDefBuilder& Input(const NodeOut& src);
82
83 // For inputs that take a list of tensors.
84 NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
85
86 // To create inputs in tests, see fake_input.h.
87 NodeDefBuilder& Input(FakeInputFunctor fake_input);
88
89 // Specify that this node must only run after src_node.
90 NodeDefBuilder& ControlInput(StringPiece src_node);
91
92 // Constrains what devices this node may be scheduled on.
93 NodeDefBuilder& Device(StringPiece device_spec);
94
95 // Sets the attr, if not already set. If already set with a different
96 // value, an error will be returned from Finalize().
97 NodeDefBuilder& Attr(StringPiece name, const AttrValue& value);
98 NodeDefBuilder& Attr(StringPiece name, AttrValue&& value);
99 NodeDefBuilder& Attr(StringPiece name, StringPiece value);
100 NodeDefBuilder& Attr(StringPiece name, const char* value);
101 NodeDefBuilder& Attr(StringPiece name, int32_t value);
102 NodeDefBuilder& Attr(StringPiece name, int64_t value);
103 NodeDefBuilder& Attr(StringPiece name, float value);
104 NodeDefBuilder& Attr(StringPiece name, double value);
105 NodeDefBuilder& Attr(StringPiece name, bool value);
106 NodeDefBuilder& Attr(StringPiece name, DataType value);
107 NodeDefBuilder& Attr(StringPiece name, const PartialTensorShape& value);
108 NodeDefBuilder& Attr(StringPiece name, const Tensor& value);
109 NodeDefBuilder& Attr(StringPiece name, const TensorProto& value);
110 NodeDefBuilder& Attr(StringPiece name, const NameAttrList& value);
111 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<StringPiece> value);
112 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<const char*> value);
113 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<string> value);
114 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<tstring> value);
115 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int32> value);
116 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<int64_t> value);
117 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<float> value);
118 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<bool> value);
119 NodeDefBuilder& Attr(StringPiece name, const std::vector<bool>& value);
120 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<DataType> value);
121 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<TensorShape> value);
122 NodeDefBuilder& Attr(StringPiece name,
123 gtl::ArraySlice<PartialTensorShape> value);
124 NodeDefBuilder& Attr(StringPiece name,
125 gtl::ArraySlice<TensorShapeProto> value);
126 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<Tensor> value);
127 NodeDefBuilder& Attr(StringPiece name, gtl::ArraySlice<NameAttrList> value);
128
129 template <class T>
130 NodeDefBuilder& Attr(StringPiece name, std::initializer_list<T> value) {
131 return Attr(name, gtl::ArraySlice<T>(value));
132 }
133
134 // Finish building the NodeDef, returning any errors or setting
135 // *node_def if none.
136 // If `consume` is true, the builder state will be moved into `node_def`,
137 // and the builder will be left in an undefined state.
138 // WARNING: Not all problems are detected! The resulting NodeDef may
139 // not be valid! Call ValidateNodeDef() from node_def_utils to be sure.
140 Status Finalize(NodeDef* node_def, bool consume = false);
141
142 // Accessors for the values set in the constructor.
143 const string& node_name() const { return node_def_.name(); }
144 const OpDef& op_def() const { return *op_def_; }
145
146 private:
147 // Called in the constructors.
148 void Initialize();
149
150 // Get the current ArgDef and advance to the next one. Returns nullptr
151 // if no more inputs are available.
152 const OpDef::ArgDef* NextArgDef();
153
154 // Returns true if there is still an input_arg available in *op_def_,
155 // otherwise adds to error_ and returns false.
156 bool NextArgAvailable();
157
158 // These do the main work of the Input() methods.
159 void SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node,
160 int src_index, DataType dt);
161 void ListInput(const OpDef::ArgDef* input_arg,
162 gtl::ArraySlice<NodeOut> src_list);
163
164 // Add "src_node:src_index" to the list of inputs in the node_def_.
165 void AddInput(StringPiece src_node, int src_index);
166
167 // Generate an error if you can't pass dt when expected is expected.
168 void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected,
169 DataType dt);
170
171 // If input_arg->is_ref() is true, generate an error if dt is not a ref.
172 void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt);
173
174 // Makes dt a ref type if that is what the input_arg specifies.
175 DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) {
176 return input_arg->is_ref() ? MakeRefType(dt) : dt;
177 }
178
179 // Returns true if an attr named `name` is already present in the node_def_.
180 // If such an attr is already present and `value` is not equal to the present
181 // value, an error is generated.
182 bool AttrValueAlreadyPresent(StringPiece name, const AttrValue& value);
183
184 const OpDef* op_def_;
185 NodeDef node_def_;
186 int inputs_specified_;
187 std::vector<string> control_inputs_;
188 std::vector<string> errors_;
189};
190
191} // namespace tensorflow
192
193#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
194