1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/node_def_builder.h" |
17 | |
18 | #include <vector> |
19 | #include "tensorflow/core/framework/attr_value.pb.h" |
20 | #include "tensorflow/core/framework/op.h" |
21 | #include "tensorflow/core/framework/op_def_util.h" |
22 | #include "tensorflow/core/lib/core/errors.h" |
23 | #include "tensorflow/core/lib/strings/str_util.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt) |
28 | : node(n), index(i), data_type(dt) {} |
29 | |
30 | NodeDefBuilder::NodeOut::NodeOut() { |
31 | // uninitialized, call Reset() before use. |
32 | } |
33 | |
34 | void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) { |
35 | node = string(n); |
36 | index = i; |
37 | data_type = dt; |
38 | } |
39 | |
40 | NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, |
41 | const OpRegistryInterface* op_registry, |
42 | const NodeDebugInfo* debug) { |
43 | node_def_.set_name(string(name)); |
44 | const Status status = op_registry->LookUpOpDef(string(op_name), &op_def_); |
45 | if (status.ok()) { |
46 | Initialize(); |
47 | } else { |
48 | errors_.push_back(status.error_message()); |
49 | inputs_specified_ = 0; |
50 | } |
51 | if (debug != nullptr) MergeDebugInfo(*debug, &node_def_); |
52 | } |
53 | |
54 | NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, |
55 | const NodeDebugInfo& debug) |
56 | : NodeDefBuilder(name, op_name) { |
57 | MergeDebugInfo(debug, &node_def_); |
58 | } |
59 | |
60 | NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def) |
61 | : op_def_(op_def) { |
62 | node_def_.set_name(string(name)); |
63 | Initialize(); |
64 | } |
65 | |
66 | void NodeDefBuilder::Initialize() { |
67 | inputs_specified_ = 0; |
68 | node_def_.set_op(op_def_->name()); |
69 | } |
70 | |
71 | const OpDef::ArgDef* NodeDefBuilder::NextArgDef() { |
72 | if (!NextArgAvailable()) return nullptr; |
73 | return &op_def_->input_arg(inputs_specified_++); |
74 | } |
75 | |
76 | bool NodeDefBuilder::NextArgAvailable() { |
77 | if (op_def_ == nullptr) { |
78 | return false; |
79 | } else if (inputs_specified_ >= op_def_->input_arg_size()) { |
80 | errors_.push_back(strings::StrCat("More Input() calls than the " , |
81 | op_def_->input_arg_size(), |
82 | " input_args" )); |
83 | return false; |
84 | } |
85 | return true; |
86 | } |
87 | |
88 | NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) { |
89 | if (NextArgAvailable()) { |
90 | Status status = fake_input(*op_def_, inputs_specified_, node_def_, this); |
91 | if (!status.ok()) errors_.push_back(status.error_message()); |
92 | } |
93 | return *this; |
94 | } |
95 | |
96 | NodeDefBuilder& NodeDefBuilder::Input(StringPiece src_node, int src_index, |
97 | DataType dt) { |
98 | const OpDef::ArgDef* arg = NextArgDef(); |
99 | if (arg != nullptr) SingleInput(arg, src_node, src_index, dt); |
100 | return *this; |
101 | } |
102 | |
103 | NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) { |
104 | Input(src.node, src.index, src.data_type); |
105 | return *this; |
106 | } |
107 | |
108 | // For inputs that take a list of tensors. |
109 | NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice<NodeOut> src_list) { |
110 | const OpDef::ArgDef* arg = NextArgDef(); |
111 | if (arg != nullptr) ListInput(arg, src_list); |
112 | return *this; |
113 | } |
114 | |
115 | void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg, |
116 | StringPiece src_node, int src_index, |
117 | DataType dt) { |
118 | AddInput(src_node, src_index); |
119 | |
120 | if (!input_arg->number_attr().empty() || |
121 | !input_arg->type_list_attr().empty()) { |
122 | errors_.push_back(strings::StrCat("Single tensor passed to '" , |
123 | input_arg->name(), "', expected list" )); |
124 | return; |
125 | } |
126 | |
127 | if (input_arg->type() != DT_INVALID) { |
128 | const DataType expected = MaybeAddRef(input_arg, input_arg->type()); |
129 | VerifyInputType(input_arg, expected, dt); |
130 | } else { |
131 | VerifyInputRef(input_arg, dt); |
132 | Attr(input_arg->type_attr(), BaseType(dt)); |
133 | } |
134 | } |
135 | |
136 | void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg, |
137 | gtl::ArraySlice<NodeOut> src_list) { |
138 | for (const auto& node_out : src_list) { |
139 | AddInput(node_out.node, node_out.index); |
140 | } |
141 | |
142 | if (!input_arg->number_attr().empty()) { |
143 | Attr(input_arg->number_attr(), static_cast<int64_t>(src_list.size())); |
144 | if (input_arg->type() != DT_INVALID) { |
145 | const DataType expected = MaybeAddRef(input_arg, input_arg->type()); |
146 | for (const auto& node_out : src_list) { |
147 | VerifyInputType(input_arg, expected, node_out.data_type); |
148 | } |
149 | } else if (!src_list.empty()) { |
150 | const DataType base = BaseType(src_list[0].data_type); |
151 | Attr(input_arg->type_attr(), base); |
152 | const DataType expected = MaybeAddRef(input_arg, base); |
153 | for (const auto& node_out : src_list) { |
154 | VerifyInputType(input_arg, expected, node_out.data_type); |
155 | } |
156 | } |
157 | } else if (!input_arg->type_list_attr().empty()) { |
158 | DataTypeVector type_vec; |
159 | type_vec.reserve(src_list.size()); |
160 | for (const auto& node_out : src_list) { |
161 | const DataType dt = node_out.data_type; |
162 | VerifyInputRef(input_arg, dt); |
163 | type_vec.push_back(BaseType(dt)); |
164 | } |
165 | Attr(input_arg->type_list_attr(), type_vec); |
166 | } else { |
167 | errors_.push_back(strings::StrCat("List provided to input '" , |
168 | input_arg->name(), |
169 | "' when single Tensor expected" )); |
170 | } |
171 | } |
172 | |
173 | void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) { |
174 | if (src_node.empty()) { |
175 | errors_.push_back("Empty input node name" ); |
176 | } else if (src_node[0] == '^') { |
177 | errors_.push_back( |
178 | strings::StrCat("Non-control input starting with ^: " , src_node)); |
179 | } else if (src_index > 0) { |
180 | node_def_.add_input(strings::StrCat(src_node, ":" , src_index)); |
181 | } else { |
182 | node_def_.add_input(string(src_node)); |
183 | } |
184 | } |
185 | |
186 | void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg, |
187 | DataType expected, DataType dt) { |
188 | if (!TypesCompatible(expected, dt)) { |
189 | errors_.push_back(strings::StrCat("Input '" , input_arg->name(), "' passed " , |
190 | DataTypeString(dt), " expected " , |
191 | DataTypeString(expected))); |
192 | } |
193 | } |
194 | |
195 | void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg, |
196 | DataType dt) { |
197 | if (input_arg->is_ref() && !IsRefType(dt)) { |
198 | errors_.push_back(strings::StrCat("Input '" , input_arg->name(), "' passed " , |
199 | DataTypeString(dt), |
200 | " expected ref type" )); |
201 | } |
202 | } |
203 | |
204 | NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) { |
205 | control_inputs_.emplace_back(src_node); |
206 | return *this; |
207 | } |
208 | |
209 | NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) { |
210 | node_def_.set_device(string(device_spec)); |
211 | return *this; |
212 | } |
213 | |
214 | Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { |
215 | const std::vector<string>* errors_ptr = &errors_; |
216 | std::vector<string> errors_storage; |
217 | if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) { |
218 | // Since this is a const method, to add an error, we have to make |
219 | // a copy of the existing errors. |
220 | errors_storage = errors_; |
221 | errors_storage.push_back( |
222 | strings::StrCat(inputs_specified_, " inputs specified of " , |
223 | op_def_->input_arg_size(), " inputs in Op" )); |
224 | errors_ptr = &errors_storage; |
225 | } |
226 | |
227 | if (!errors_ptr->empty()) { |
228 | if (errors_ptr->size() == 1) { |
229 | if (op_def_ == nullptr) { |
230 | return errors::InvalidArgument((*errors_ptr)[0], |
231 | " while building NodeDef '" , |
232 | node_def_.name(), "'" ); |
233 | } |
234 | return errors::InvalidArgument( |
235 | (*errors_ptr)[0], " while building NodeDef '" , node_def_.name(), |
236 | "' using " , SummarizeOpDef(*op_def_)); |
237 | } else { |
238 | return errors::InvalidArgument( |
239 | errors_ptr->size(), " errors while building NodeDef '" , |
240 | node_def_.name(), "' using " , SummarizeOpDef(*op_def_), ":\n" , |
241 | absl::StrJoin(*errors_ptr, "\n" )); |
242 | } |
243 | } else { |
244 | NodeDef node_def_backup; |
245 | if (node_def == nullptr) node_def = &node_def_backup; |
246 | if (consume) { |
247 | *node_def = std::move(node_def_); |
248 | } else { |
249 | *node_def = node_def_; |
250 | } |
251 | |
252 | // Add control inputs after the regular inputs. |
253 | for (const auto& control_input : control_inputs_) { |
254 | node_def->add_input(strings::StrCat("^" , control_input)); |
255 | } |
256 | |
257 | // Add default values for unspecified attrs. |
258 | AddDefaultsToNodeDef(*op_def_, node_def); |
259 | |
260 | return OkStatus(); |
261 | } |
262 | } |
263 | |
264 | bool NodeDefBuilder::AttrValueAlreadyPresent(StringPiece name, |
265 | const AttrValue& value) { |
266 | if (const AttrValue* found = AttrSlice(node_def_).Find(name)) { |
267 | if (!AreAttrValuesEqual(*found, value)) { |
268 | errors_.push_back(strings::StrCat("Inconsistent values for attr '" , name, |
269 | "' " , SummarizeAttrValue(*found), |
270 | " vs. " , SummarizeAttrValue(value))); |
271 | } |
272 | return true; |
273 | } |
274 | return false; |
275 | } |
276 | |
277 | NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, const AttrValue& value) { |
278 | if (!AttrValueAlreadyPresent(name, value)) { |
279 | AddNodeAttr(name, value, &node_def_); |
280 | } |
281 | return *this; |
282 | } |
283 | |
284 | NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, AttrValue&& value) { |
285 | if (!AttrValueAlreadyPresent(name, value)) { |
286 | AddNodeAttr(name, std::move(value), &node_def_); |
287 | } |
288 | return *this; |
289 | } |
290 | |
291 | #define ATTR(T) \ |
292 | NodeDefBuilder& NodeDefBuilder::Attr(StringPiece name, T value) { \ |
293 | AttrValue attr_value; \ |
294 | SetAttrValue(value, &attr_value); \ |
295 | return Attr(name, attr_value); \ |
296 | } |
297 | ATTR(StringPiece) |
298 | ATTR(const char*) |
299 | ATTR(int32_t) |
300 | ATTR(int64_t) |
301 | ATTR(float) |
302 | ATTR(double) |
303 | ATTR(bool) |
304 | ATTR(DataType) |
305 | ATTR(const PartialTensorShape&) |
306 | ATTR(const Tensor&) |
307 | ATTR(const TensorProto&) |
308 | ATTR(const NameAttrList&) |
309 | ATTR(gtl::ArraySlice<StringPiece>) |
310 | ATTR(gtl::ArraySlice<const char*>) |
311 | ATTR(gtl::ArraySlice<string>) |
312 | ATTR(gtl::ArraySlice<tstring>) |
313 | ATTR(gtl::ArraySlice<int32>) |
314 | ATTR(gtl::ArraySlice<int64_t>) |
315 | ATTR(gtl::ArraySlice<float>) |
316 | ATTR(gtl::ArraySlice<bool>) |
317 | ATTR(const std::vector<bool>&) |
318 | ATTR(gtl::ArraySlice<DataType>) |
319 | ATTR(gtl::ArraySlice<TensorShape>) |
320 | ATTR(gtl::ArraySlice<PartialTensorShape>) |
321 | ATTR(gtl::ArraySlice<TensorShapeProto>) |
322 | ATTR(gtl::ArraySlice<Tensor>) |
323 | ATTR(gtl::ArraySlice<NameAttrList>) |
324 | #undef ATTR |
325 | |
326 | } // namespace tensorflow |
327 | |