1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/node_def_builder.h"
17
18#include <vector>
19#include "tensorflow/core/framework/attr_value.pb.h"
20#include "tensorflow/core/framework/op.h"
21#include "tensorflow/core/framework/op_def_util.h"
22#include "tensorflow/core/lib/core/errors.h"
23#include "tensorflow/core/lib/strings/str_util.h"
24
25namespace tensorflow {
26
27NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
28 : node(n), index(i), data_type(dt) {}
29
30NodeDefBuilder::NodeOut::NodeOut() {
31 // uninitialized, call Reset() before use.
32}
33
34void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
35 node = string(n);
36 index = i;
37 data_type = dt;
38}
39
40NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
41 const OpRegistryInterface* op_registry,
42 const NodeDebugInfo* debug) {
43 node_def_.set_name(string(name));
44 const Status status = op_registry->LookUpOpDef(string(op_name), &op_def_);
45 if (status.ok()) {
46 Initialize();
47 } else {
48 errors_.push_back(status.error_message());
49 inputs_specified_ = 0;
50 }
51 if (debug != nullptr) MergeDebugInfo(*debug, &node_def_);
52}
53
54NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
55 const NodeDebugInfo& debug)
56 : NodeDefBuilder(name, op_name) {
57 MergeDebugInfo(debug, &node_def_);
58}
59
60NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
61 : op_def_(op_def) {
62 node_def_.set_name(string(name));
63 Initialize();
64}
65
66void NodeDefBuilder::Initialize() {
67 inputs_specified_ = 0;
68 node_def_.set_op(op_def_->name());
69}
70
71const OpDef::ArgDef* NodeDefBuilder::NextArgDef() {
72 if (!NextArgAvailable()) return nullptr;
73 return &op_def_->input_arg(inputs_specified_++);
74}
75
76bool NodeDefBuilder::NextArgAvailable() {
77 if (op_def_ == nullptr) {
78 return false;
79 } else if (inputs_specified_ >= op_def_->input_arg_size()) {
80 errors_.push_back(strings::StrCat("More Input() calls than the ",
81 op_def_->input_arg_size(),
82 " input_args"));
83 return false;
84 }
85 return true;
86}
87
88NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
89 if (NextArgAvailable()) {
90 Status status = fake_input(*op_def_, inputs_specified_, node_def_, this);
91 if (!status.ok()) errors_.push_back(status.error_message());
92 }
93 return *this;
94}
95
96NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index,
97 DataType dt) {
98 const OpDef::ArgDef* arg = NextArgDef();
99 if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
100 return *this;
101}
102
103NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) {
104 Input(src.node, src.index, src.data_type);
105 return *this;
106}
107
108// For inputs that take a list of tensors.
109NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
110 const OpDef::ArgDef* arg = NextArgDef();
111 if (arg != nullptr) ListInput(arg, src_list);
112 return *this;
113}
114
115void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
116 StringPiece src_node, int src_index,
117 DataType dt) {
118 AddInput(src_node, src_index);
119
120 if (!input_arg->number_attr().empty() ||
121 !input_arg->type_list_attr().empty()) {
122 errors_.push_back(strings::StrCat("Single tensor passed to '",
123 input_arg->name(), "', expected list"));
124 return;
125 }
126
127 if (input_arg->type() != DT_INVALID) {
128 const DataType expected = MaybeAddRef(input_arg, input_arg->type());
129 VerifyInputType(input_arg, expected, dt);
130 } else {
131 VerifyInputRef(input_arg, dt);
132 Attr(input_arg->type_attr(), BaseType(dt));
133 }
134}
135
136void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
137 gtl::ArraySlice<NodeOut> src_list) {
138 for (const auto& node_out : src_list) {
139 AddInput(node_out.node, node_out.index);
140 }
141
142 if (!input_arg->number_attr().empty()) {
143 Attr(input_arg->number_attr(), static_cast<int64_t>(src_list.size()));
144 if (input_arg->type() != DT_INVALID) {
145 const DataType expected = MaybeAddRef(input_arg, input_arg->type());
146 for (const auto& node_out : src_list) {
147 VerifyInputType(input_arg, expected, node_out.data_type);
148 }
149 } else if (!src_list.empty()) {
150 const DataType base = BaseType(src_list[0].data_type);
151 Attr(input_arg->type_attr(), base);
152 const DataType expected = MaybeAddRef(input_arg, base);
153 for (const auto& node_out : src_list) {
154 VerifyInputType(input_arg, expected, node_out.data_type);
155 }
156 }
157 } else if (!input_arg->type_list_attr().empty()) {
158 DataTypeVector type_vec;
159 type_vec.reserve(src_list.size());
160 for (const auto& node_out : src_list) {
161 const DataType dt = node_out.data_type;
162 VerifyInputRef(input_arg, dt);
163 type_vec.push_back(BaseType(dt));
164 }
165 Attr(input_arg->type_list_attr(), type_vec);
166 } else {
167 errors_.push_back(strings::StrCat("List provided to input '",
168 input_arg->name(),
169 "' when single Tensor expected"));
170 }
171}
172
173void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
174 if (src_node.empty()) {
175 errors_.push_back("Empty input node name");
176 } else if (src_node[0] == '^') {
177 errors_.push_back(
178 strings::StrCat("Non-control input starting with ^: ", src_node));
179 } else if (src_index > 0) {
180 node_def_.add_input(strings::StrCat(src_node, ":", src_index));
181 } else {
182 node_def_.add_input(string(src_node));
183 }
184}
185
186void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg,
187 DataType expected, DataType dt) {
188 if (!TypesCompatible(expected, dt)) {
189 errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
190 DataTypeString(dt), " expected ",
191 DataTypeString(expected)));
192 }
193}
194
195void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
196 DataType dt) {
197 if (input_arg->is_ref() && !IsRefType(dt)) {
198 errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
199 DataTypeString(dt),
200 " expected ref type"));
201 }
202}
203
204NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
205 control_inputs_.emplace_back(src_node);
206 return *this;
207}
208
209NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
210 node_def_.set_device(string(device_spec));
211 return *this;
212}
213
214Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) {
215 const std::vector<string>* errors_ptr = &errors_;
216 std::vector<string> errors_storage;
217 if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) {
218 // Since this is a const method, to add an error, we have to make
219 // a copy of the existing errors.
220 errors_storage = errors_;
221 errors_storage.push_back(
222 strings::StrCat(inputs_specified_, " inputs specified of ",
223 op_def_->input_arg_size(), " inputs in Op"));
224 errors_ptr = &errors_storage;
225 }
226
227 if (!errors_ptr->empty()) {
228 if (errors_ptr->size() == 1) {
229 if (op_def_ == nullptr) {
230 return errors::InvalidArgument((*errors_ptr)[0],
231 " while building NodeDef '",
232 node_def_.name(), "'");
233 }
234 return errors::InvalidArgument(
235 (*errors_ptr)[0], " while building NodeDef '", node_def_.name(),
236 "' using ", SummarizeOpDef(*op_def_));
237 } else {
238 return errors::InvalidArgument(
239 errors_ptr->size(), " errors while building NodeDef '",
240 node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n",
241 absl::StrJoin(*errors_ptr, "\n"));
242 }
243 } else {
244 NodeDef node_def_backup;
245 if (node_def == nullptr) node_def = &node_def_backup;
246 if (consume) {
247 *node_def = std::move(node_def_);
248 } else {
249 *node_def = node_def_;
250 }
251
252 // Add control inputs after the regular inputs.
253 for (const auto& control_input : control_inputs_) {
254 node_def->add_input(strings::StrCat("^", control_input));
255 }
256
257 // Add default values for unspecified attrs.
258 AddDefaultsToNodeDef(*op_def_, node_def);
259
260 return OkStatus();
261 }
262}
263
264bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name,
265 const AttrValue& value) {
266 if (const AttrValue* found = AttrSlice(node_def_).Find(name)) {
267 if (!AreAttrValuesEqual(*found, value)) {
268 errors_.push_back(strings::StrCat("Inconsistent values for attr '", name,
269 "' ", SummarizeAttrValue(*found),
270 " vs. ", SummarizeAttrValue(value)));
271 }
272 return true;
273 }
274 return false;
275}
276
277NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) {
278 if (!AttrValueAlreadyPresent(name, value)) {
279 AddNodeAttr(name, value, &node_def_);
280 }
281 return *this;
282}
283
284NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) {
285 if (!AttrValueAlreadyPresent(name, value)) {
286 AddNodeAttr(name, std::move(value), &node_def_);
287 }
288 return *this;
289}
290
291#define ATTR(T) \
292 NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \
293 AttrValue attr_value; \
294 SetAttrValue(value, &attr_value); \
295 return Attr(name, attr_value); \
296 }
297ATTR(StringPiece)
298ATTR(const char*)
299ATTR(int32_t)
300ATTR(int64_t)
301ATTR(float)
302ATTR(double)
303ATTR(bool)
304ATTR(DataType)
305ATTR(const PartialTensorShape&)
306ATTR(const Tensor&)
307ATTR(const TensorProto&)
308ATTR(const NameAttrList&)
309ATTR(gtl::ArraySlice<StringPiece>)
310ATTR(gtl::ArraySlice<const char*>)
311ATTR(gtl::ArraySlice<string>)
312ATTR(gtl::ArraySlice<tstring>)
313ATTR(gtl::ArraySlice<int32>)
314ATTR(gtl::ArraySlice<int64_t>)
315ATTR(gtl::ArraySlice<float>)
316ATTR(gtl::ArraySlice<bool>)
317ATTR(const std::vector<bool>&)
318ATTR(gtl::ArraySlice<DataType>)
319ATTR(gtl::ArraySlice<TensorShape>)
320ATTR(gtl::ArraySlice<PartialTensorShape>)
321ATTR(gtl::ArraySlice<TensorShapeProto>)
322ATTR(gtl::ArraySlice<Tensor>)
323ATTR(gtl::ArraySlice<NameAttrList>)
324#undef ATTR
325
326} // namespace tensorflow
327