1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/node_def_util.h" |
17 | |
18 | #include <algorithm> |
19 | #include <unordered_map> |
20 | #include <vector> |
21 | |
22 | #include "absl/strings/str_cat.h" |
23 | #include "absl/strings/str_join.h" |
24 | #include "tensorflow/core/framework/attr_value.pb.h" |
25 | #include "tensorflow/core/framework/attr_value_util.h" |
26 | #include "tensorflow/core/framework/node_def.pb.h" |
27 | #include "tensorflow/core/framework/op_def.pb.h" |
28 | #include "tensorflow/core/framework/op_def_util.h" |
29 | #include "tensorflow/core/framework/tensor.h" |
30 | #include "tensorflow/core/framework/tensor.pb.h" |
31 | #include "tensorflow/core/framework/tensor_shape.h" |
32 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
33 | #include "tensorflow/core/framework/types.h" |
34 | #include "tensorflow/core/framework/types.pb.h" |
35 | #include "tensorflow/core/lib/gtl/map_util.h" |
36 | #include "tensorflow/core/platform/errors.h" |
37 | #include "tensorflow/core/platform/scanner.h" |
38 | #include "tensorflow/core/platform/status.h" |
39 | #include "tensorflow/core/platform/strcat.h" |
40 | #include "tensorflow/core/platform/stringpiece.h" |
41 | #include "tensorflow/core/platform/types.h" |
42 | |
43 | namespace tensorflow { |
44 | |
// Attr name and value prefix used for colocation info; presumably a node's
// "_class" attr holds entries of the form "loc:@<node name>" — confirm against
// the colocation helpers that consume these.
const char* const kColocationAttrName = "_class";
const char* const kColocationGroupPrefix = "loc:@";
// For TPU distributed rewrite, TPU args are collected and "staged" on the local
// host using an IdentityN TF op. Some args may result from a remote source.
// When all arg tensors are available, the TPUExecute op can be invoked. See
// DistributedTPURewritePass for more details.
const char* const kTpuExecuteStagingOp = "IdentityN";
const char* const kTpuExecuteStagingNodeName = "_variable_copy";
53 | |
54 | AttrSlice::AttrSlice() : ndef_(nullptr) { |
55 | static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap; |
56 | attrs_ = kEmptyAttrValueMap; |
57 | } |
58 | |
59 | // Do not cache the map field reference because that may be invalidated on |
60 | // Clear. |
61 | AttrSlice::AttrSlice(const NodeDef& node_def) |
62 | : ndef_(&node_def), attrs_(nullptr) {} |
63 | |
64 | AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {} |
65 | |
66 | string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) { |
67 | string ret; |
68 | |
69 | // We sort the attrs so the output is deterministic. |
70 | std::vector<string> attr_names; |
71 | attr_names.reserve(attrs.size()); |
72 | for (const auto& attr : attrs) { |
73 | attr_names.push_back(attr.first); |
74 | } |
75 | std::sort(attr_names.begin(), attr_names.end()); |
76 | bool first = true; |
77 | for (const string& attr_name : attr_names) { |
78 | if (!first) strings::StrAppend(&ret, ", " ); |
79 | first = false; |
80 | strings::StrAppend(&ret, attr_name, "=" , |
81 | SummarizeAttrValue(*attrs.Find(attr_name))); |
82 | } |
83 | |
84 | // Consider the device to be a final attr with name "_device". |
85 | if (!device.empty()) { |
86 | if (!first) strings::StrAppend(&ret, ", " ); |
87 | first = false; |
88 | strings::StrAppend(&ret, "_device=\"" , device, "\"" ); |
89 | } |
90 | return ret; |
91 | } |
92 | |
93 | string AttrSlice::SummarizeNode() const { |
94 | return ndef_ ? SummarizeNodeDef(*ndef_) |
95 | : strings::StrCat( |
96 | "[" , SummarizeAttrsHelper(*this, StringPiece()), "]" ); |
97 | } |
98 | |
99 | string AttrSlice::DebugString() const { |
100 | std::vector<string> attr_key_vals; |
101 | attr_key_vals.reserve(attrs()->size()); |
102 | for (const auto& it : *this) { |
103 | const string& name = it.first; |
104 | const AttrValue& attr_value = it.second; |
105 | attr_key_vals.push_back( |
106 | absl::StrCat(name, "=" , SummarizeAttrValue(attr_value))); |
107 | } |
108 | return absl::StrJoin(attr_key_vals, ", " ); |
109 | } |
110 | |
111 | string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) { |
112 | string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()), |
113 | " = " , node_def.op(), "[" ); |
114 | strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device())); |
115 | strings::StrAppend(&ret, "](" ); |
116 | |
117 | // Output inputs, including control inputs, verbatim. |
118 | bool first = true; |
119 | for (const string& input : node_def.input()) { |
120 | if (!first) strings::StrAppend(&ret, ", " ); |
121 | first = false; |
122 | if (max_inputs_in_summary-- == 0) { |
123 | strings::StrAppend(&ret, "..." ); |
124 | break; |
125 | } |
126 | strings::StrAppend(&ret, input); |
127 | } |
128 | strings::StrAppend(&ret, ")" ); |
129 | return ret; |
130 | } |
131 | |
132 | string SummarizeAttrs(const NodeDef& node_def) { |
133 | return SummarizeAttrsHelper(node_def, node_def.device()); |
134 | } |
135 | |
136 | string FormatNodeDefForError( |
137 | StringPiece node_name, bool has_experimental_debug_info, |
138 | const NodeDef_ExperimentalDebugInfo& experimental_debug_info) { |
139 | return !has_experimental_debug_info || |
140 | experimental_debug_info.original_node_names().empty() |
141 | ? errors::FormatNodeNameForError(string(node_name)) |
142 | : errors::FormatOriginalNodeLocationForError( |
143 | experimental_debug_info.original_node_names(), |
144 | experimental_debug_info.original_func_names()); |
145 | } |
146 | |
147 | string FormatNodeDefForError(const NodeDef& node_def) { |
148 | return FormatNodeDefForError(node_def.name(), |
149 | node_def.has_experimental_debug_info(), |
150 | node_def.experimental_debug_info()); |
151 | } |
152 | |
153 | const AttrValue* AttrSlice::Find(StringPiece attr_name) const { |
154 | // Currently, the collection used for NodeDef::attr() (google::protobuf::Map) |
155 | // requires that the keys used for lookups have type 'const string&'. Because |
156 | // this method takes a StringPiece, it is necessary to allocate a temporary |
157 | // string, copy attr_name to it, and then use that temporary string for the |
158 | // lookup. This causes an excessive number of short-lived allocations, and for |
159 | // large graphs, this can be a significant cost. |
160 | // |
161 | // Because most nodes have a small number of attributes, a simple linear scan |
162 | // is generally more efficient than a hashed lookup. If google::protobuf::Map |
163 | // changes so that it supports efficient lookups using StringPiece instead of |
164 | // const string&, then this code could be changed to use attrs()->find() |
165 | // again. |
166 | |
167 | for (const auto& attr : *attrs()) { |
168 | if (attr.first == attr_name) { |
169 | return &attr.second; |
170 | } |
171 | } |
172 | return nullptr; |
173 | } |
174 | |
175 | const AttrValue* AttrSlice::FindByString(const string& attr_name) const { |
176 | auto iter = attrs()->find(attr_name); |
177 | if (iter != attrs()->end()) { |
178 | return &iter->second; |
179 | } else { |
180 | return nullptr; |
181 | } |
182 | } |
183 | |
184 | Status AttrSlice::CheckFind(StringPiece attr_name, |
185 | const AttrValue* attr_value) const { |
186 | if (attr_value != nullptr) { |
187 | return OkStatus(); |
188 | } |
189 | Status s = errors::NotFound("No attr named '" , attr_name, "' in NodeDef:" ); |
190 | // Skip AttachDef for internal attrs since it is a little bit |
191 | // expensive and it is common for them to correctly not be included |
192 | // in a NodeDef. |
193 | if (!absl::StartsWith(attr_name, "_" ) && ndef_ != nullptr) { |
194 | s = AttachDef(s, *ndef_); |
195 | } |
196 | return s; |
197 | } |
198 | |
199 | Status AttrSlice::Find(StringPiece attr_name, |
200 | const AttrValue** attr_value) const { |
201 | *attr_value = Find(attr_name); |
202 | return CheckFind(attr_name, *attr_value); |
203 | } |
204 | |
205 | Status AttrSlice::FindByString(const string& attr_name, |
206 | const AttrValue** attr_value) const { |
207 | *attr_value = FindByString(attr_name); |
208 | return CheckFind(attr_name, *attr_value); |
209 | } |
210 | |
211 | bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const { |
212 | if (size() != other.size()) return false; |
213 | |
214 | for (const auto& attr : *other.attrs()) { |
215 | auto iter = attrs()->find(attr.first); |
216 | if (iter == attrs()->end()) return false; |
217 | // TODO(irving): Comparing AttrValues by proto is slightly buggy, since |
218 | // TensorProto is a nonunique representation of Tensor. This bug will go |
219 | // away once AttrSlice switches over to NodeInfo. |
220 | iter->second.SerializeToString(&scratch->a); |
221 | attr.second.SerializeToString(&scratch->b); |
222 | if (scratch->a != scratch->b) return false; |
223 | } |
224 | return true; |
225 | } |
226 | |
// Defines a pair of GetNodeAttr() overloads for TYPE:
//   - one reading a scalar attr of type ATTR_TYPE into *value;
//   - one reading a "list(" ATTR_TYPE ")" attr into a std::vector<TYPE>.
// FIELD is the AttrValue accessor holding the value (e.g. i, s, shape),
// APPEND_OP is the vector member used to add each element, and CAST is an
// expression converting the proto value `v` into a TYPE.
// The ... is to allow the caller to inject some value validation code. Use
// just ; if no additional validation code is needed. The injected code runs
// with `v` and `attr_name` in scope, before CAST is evaluated, and may
// `return` an error Status.
// Note: the list overload appends to *value without clearing it first. If
// validation fails mid-list, the elements appended so far are left in place.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return OkStatus();                                                        \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    value->reserve(attr_value->list().FIELD().size());                        \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return OkStatus();                                                        \
  }
252 | |
// Same parameters as DEFINE_GET_ATTR, but defines TryGetNodeAttr() overloads
// that return false instead of an error Status when the attr is missing or
// has the wrong type. Injected validation code (the ... argument) signals
// failure with `return false;`.
// Note: the list overload appends to *value without clearing it. Elements
// appended before a validation failure are left in place.
#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,       \
                      TYPE* value) {                                       \
    const AttrValue* attr_value = attrs.Find(attr_name);                   \
    if (attr_value == nullptr) {                                           \
      return false;                                                        \
    }                                                                      \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                   \
    if (!s.ok()) {                                                         \
      return false;                                                        \
    }                                                                      \
    const auto& v = attr_value->FIELD();                                   \
    __VA_ARGS__;                                                           \
    *value = CAST;                                                         \
    return true;                                                           \
  }                                                                        \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,       \
                      std::vector<TYPE>* value) {                          \
    const AttrValue* attr_value = attrs.Find(attr_name);                   \
    if (attr_value == nullptr) {                                           \
      return false;                                                        \
    }                                                                      \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");       \
    if (!s.ok()) {                                                         \
      return false;                                                        \
    }                                                                      \
    value->reserve(attr_value->list().FIELD().size());                     \
    for (const auto& v : attr_value->list().FIELD()) {                     \
      __VA_ARGS__;                                                         \
      value->APPEND_OP(CAST);                                              \
    }                                                                      \
    return true;                                                           \
  }
286 | DEFINE_GET_ATTR(tstring, s, "string" , emplace_back, v, ;) |
287 | DEFINE_TRY_GET_ATTR(tstring, s, "string" , emplace_back, v, ;) |
288 | DEFINE_GET_ATTR(string, s, "string" , emplace_back, v, ;) |
289 | DEFINE_TRY_GET_ATTR(string, s, "string" , emplace_back, v, ;) |
290 | DEFINE_GET_ATTR(int64_t, i, "int" , emplace_back, v, ;) |
291 | DEFINE_TRY_GET_ATTR(int64_t, i, "int" , emplace_back, v, ;) |
292 | DEFINE_GET_ATTR( |
293 | int32, i, "int" , emplace_back, static_cast<int32>(v), |
294 | if (static_cast<int64_t>(static_cast<int32>(v)) != v) { |
295 | return errors::InvalidArgument("Attr " , attr_name, " has value " , v, |
296 | " out of range for an int32" ); |
297 | }) |
298 | DEFINE_TRY_GET_ATTR( |
299 | int32, i, "int" , emplace_back, static_cast<int32>(v), |
300 | if (static_cast<int64_t>(static_cast<int32>(v)) != v) { |
301 | static int log_counter = 0; |
302 | if (log_counter < 10) { |
303 | log_counter++; |
304 | LOG(WARNING) << "Attr " << attr_name << " has value " << v |
305 | << " out of range for an int32" ; |
306 | } |
307 | return false; |
308 | }) |
309 | DEFINE_GET_ATTR(float, f, "float" , emplace_back, v, ;) |
310 | DEFINE_TRY_GET_ATTR(float, f, "float" , emplace_back, v, ;) |
311 | DEFINE_GET_ATTR(bool, b, "bool" , emplace_back, v, ;) |
312 | DEFINE_TRY_GET_ATTR(bool, b, "bool" , emplace_back, v, ;) |
313 | DEFINE_GET_ATTR(DataType, type, "type" , emplace_back, static_cast<DataType>(v), |
314 | ;) |
315 | DEFINE_TRY_GET_ATTR(DataType, type, "type" , emplace_back, |
316 | static_cast<DataType>(v), |
317 | ;) |
318 | DEFINE_GET_ATTR(TensorShapeProto, shape, "shape" , emplace_back, v, ;) |
319 | DEFINE_GET_ATTR(TensorShape, shape, "shape" , emplace_back, TensorShape(v), |
320 | TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));) |
321 | DEFINE_TRY_GET_ATTR( |
322 | TensorShape, shape, "shape" , emplace_back, TensorShape(v), |
323 | if (!TensorShape::IsValidShape(v).ok()) { |
324 | static int log_counter = 0; |
325 | if (log_counter < 10) { |
326 | log_counter++; |
327 | LOG(WARNING) << "Attr " << attr_name << " has invalid shape value " |
328 | << v.DebugString(); |
329 | } |
330 | return false; |
331 | }) |
332 | DEFINE_GET_ATTR(PartialTensorShape, shape, "shape" , emplace_back, |
333 | PartialTensorShape(v), |
334 | TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));) |
335 | DEFINE_GET_ATTR( |
336 | Tensor, tensor, "tensor" , emplace_back, t, Tensor t; if (!t.FromProto(v)) { |
337 | return errors::InvalidArgument("Attr " , attr_name, " has value " , |
338 | v.ShortDebugString(), |
339 | " that can't be converted to a Tensor" ); |
340 | }) |
341 | DEFINE_GET_ATTR(NameAttrList, func, "func" , emplace_back, v, ;); |
342 | #undef DEFINE_GET_ATTR |
343 | |
344 | bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) { |
345 | return node_def.attr().find(string(attr_name)) != node_def.attr().end(); |
346 | } |
347 | |
348 | static const string& kEmptyString = *new string(); |
349 | |
350 | const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) { |
351 | const AttrValue* attr_value = attrs.Find(attr_name); |
352 | if (attr_value == nullptr) { |
353 | return kEmptyString; |
354 | } |
355 | Status s = AttrValueHasType(*attr_value, "string" ); |
356 | if (!s.ok()) { |
357 | return kEmptyString; |
358 | } |
359 | return attr_value->s(); |
360 | } |
361 | |
362 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
363 | std::vector<const string*>* value) { |
364 | const AttrValue* attr_value = attrs.Find(attr_name); |
365 | if (attr_value == nullptr) { |
366 | return false; |
367 | } |
368 | Status s = AttrValueHasType(*attr_value, "list(string)" ); |
369 | if (!s.ok()) { |
370 | return false; |
371 | } |
372 | value->reserve(attr_value->list().s().size()); |
373 | for (const auto& v : attr_value->list().s()) { |
374 | value->push_back(&v); |
375 | } |
376 | return true; |
377 | } |
378 | |
379 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
380 | std::vector<const TensorShapeProto*>* value) { |
381 | const AttrValue* attr_value = attrs.Find(attr_name); |
382 | if (attr_value == nullptr) { |
383 | return false; |
384 | } |
385 | Status s = AttrValueHasType(*attr_value, "list(shape)" ); |
386 | if (!s.ok()) { |
387 | return false; |
388 | } |
389 | value->reserve(attr_value->list().shape().size()); |
390 | for (const auto& v : attr_value->list().shape()) { |
391 | value->push_back(&v); |
392 | } |
393 | return true; |
394 | } |
395 | |
396 | Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
397 | DataTypeVector* value) { |
398 | const AttrValue* attr_value; |
399 | TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); |
400 | TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)" )); |
401 | for (const auto& v : attr_value->list().type()) { |
402 | value->push_back(static_cast<DataType>(v)); |
403 | } |
404 | return OkStatus(); |
405 | } |
406 | |
407 | Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
408 | const TensorProto** value) { |
409 | const AttrValue* attr_value; |
410 | TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); |
411 | TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor" )); |
412 | *value = &attr_value->tensor(); |
413 | return OkStatus(); |
414 | } |
415 | |
416 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
417 | const TensorProto** value) { |
418 | const AttrValue* attr_value = attrs.Find(attr_name); |
419 | if (attr_value == nullptr) { |
420 | return false; |
421 | } |
422 | Status s = AttrValueHasType(*attr_value, "tensor" ); |
423 | if (!s.ok()) { |
424 | return false; |
425 | } |
426 | *value = &attr_value->tensor(); |
427 | return true; |
428 | } |
429 | |
430 | Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
431 | const NameAttrList** value) { |
432 | const AttrValue* attr_value; |
433 | TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); |
434 | TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func" )); |
435 | *value = &attr_value->func(); |
436 | return OkStatus(); |
437 | } |
438 | |
439 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
440 | const NameAttrList** value) { |
441 | const AttrValue* attr_value = attrs.Find(attr_name); |
442 | if (attr_value == nullptr) { |
443 | return false; |
444 | } |
445 | Status s = AttrValueHasType(*attr_value, "func" ); |
446 | if (!s.ok()) { |
447 | return false; |
448 | } |
449 | *value = &attr_value->func(); |
450 | return true; |
451 | } |
452 | |
453 | Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
454 | Padding* value) { |
455 | string str_value; |
456 | TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value)); |
457 | return GetPaddingFromString(str_value, value); |
458 | } |
459 | |
namespace {  // Helper for InOutTypesForNode().

// Appends to `sig` the DataType(s) that `arg_def` expands to for the node
// described by `node_or_attrs` (a NodeDef or an AttrSlice):
//   - number_attr set: the int attr gives a repeat count N. N copies of one
//     type are appended, taken from type_attr or else from arg_def.type().
//   - else type_attr set: one element, read from that attr.
//   - else type_list_attr set: one element per entry of that attr's list.
//   - else: one element of the fixed arg_def.type().
// If arg_def.is_ref(), every element appended by this call is made a ref type.
// Returns InvalidArgument for a repeat count outside [0, INT32_MAX], missing
// type information, or a ref-to-ref. Attr lookup errors are propagated.
// NOTE(review): the type_attr / type_list_attr branches read type() /
// list().type() without checking the attr actually holds a type. A mistyped
// attr presumably gets caught earlier (e.g. by ValidateAttrValue) — confirm.
template <class NodeDefOrAttrSlice>
Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
                   const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
  // Remember where this arg's entries start so is_ref only touches them.
  const int original_size = sig->size();
  if (!arg_def.number_attr().empty()) {
    // Same type repeated "repeats" times.
    int64_t repeats = -1;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
    // We can't handle outputs that are larger than int32 sizes.
    if (static_cast<int64_t>(static_cast<int32>(repeats)) != repeats) {
      return errors::InvalidArgument("Number of outputs is too big: ", repeats);
    }
    if (repeats < 0) {
      return errors::InvalidArgument("Value for number_attr() ", repeats,
                                     " < 0");
    }

    if (!arg_def.type_attr().empty()) {
      DataType dtype;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
      for (int i = 0; i < repeats; ++i) {
        sig->push_back(dtype);
      }
    } else if (arg_def.type() != DT_INVALID) {
      for (int i = 0; i < repeats; ++i) {
        sig->push_back(arg_def.type());
      }
    } else {
      return errors::InvalidArgument("Missing type or type_attr field in ",
                                     arg_def.ShortDebugString());
    }
  } else if (!arg_def.type_attr().empty()) {
    // Hashed FindByString lookup; the attr's type is not re-checked here.
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(AttrSlice(node_or_attrs)
                           .FindByString(arg_def.type_attr(), &attr_value));
    sig->push_back(attr_value->type());
  } else if (!arg_def.type_list_attr().empty()) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        AttrSlice(node_or_attrs)
            .FindByString(arg_def.type_list_attr(), &attr_value));
    for (int dtype : attr_value->list().type()) {
      sig->push_back(static_cast<DataType>(dtype));
    }
  } else if (arg_def.type() != DT_INVALID) {
    sig->push_back(arg_def.type());
  } else {
    return errors::InvalidArgument("No type fields in ",
                                   arg_def.ShortDebugString());
  }
  if (arg_def.is_ref()) {
    // For all types that were added by this function call, make them refs.
    for (size_t i = original_size; i < sig->size(); ++i) {
      if (IsRefType((*sig)[i])) {
        return errors::InvalidArgument(
            "Requested reference to a reference type: ",
            arg_def.ShortDebugString());
      }
      (*sig)[i] = MakeRefType((*sig)[i]);
    }
  }
  return OkStatus();
}

}  // namespace
529 | |
530 | Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, |
531 | int input_port, DataType* input_type) { |
532 | DataTypeVector input_types; |
533 | for (const auto& arg : op_def.input_arg()) { |
534 | TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types)); |
535 | int input_types_size = input_types.size(); |
536 | if (input_types_size > input_port) { |
537 | const DataType dtype = input_types[input_port]; |
538 | *input_type = dtype; |
539 | return OkStatus(); |
540 | } |
541 | } |
542 | return errors::InvalidArgument("Input " , input_port, " not found for node " , |
543 | node_def.name()); |
544 | } |
545 | |
546 | Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, |
547 | DataTypeVector* inputs) { |
548 | for (const auto& arg : op_def.input_arg()) { |
549 | TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); |
550 | } |
551 | return OkStatus(); |
552 | } |
553 | |
554 | Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, |
555 | int output_port, DataType* output_type) { |
556 | DataTypeVector output_types; |
557 | for (const auto& arg : op_def.output_arg()) { |
558 | TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types)); |
559 | int output_types_size = output_types.size(); |
560 | if (output_types_size > output_port) { |
561 | const DataType dtype = output_types[output_port]; |
562 | *output_type = dtype; |
563 | return OkStatus(); |
564 | } |
565 | } |
566 | return errors::InvalidArgument("Output " , output_port, " not found for node " , |
567 | node_def.name()); |
568 | } |
569 | |
570 | Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, |
571 | DataTypeVector* outputs) { |
572 | for (const auto& arg : op_def.output_arg()) { |
573 | TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs)); |
574 | } |
575 | return OkStatus(); |
576 | } |
577 | |
578 | Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, |
579 | DataTypeVector* outputs) { |
580 | for (const auto& arg : op_def.output_arg()) { |
581 | TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs)); |
582 | } |
583 | return OkStatus(); |
584 | } |
585 | |
586 | Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, |
587 | DataTypeVector* inputs, DataTypeVector* outputs) { |
588 | TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs)); |
589 | return OutputTypesForNode(node_def, op_def, outputs); |
590 | } |
591 | |
592 | Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, |
593 | int* num_outputs) { |
594 | DataTypeVector outputs; |
595 | TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs)); |
596 | *num_outputs = outputs.size(); |
597 | return OkStatus(); |
598 | } |
599 | |
600 | int OpPortIdToArgId(const NodeDef& node, |
601 | const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, |
602 | int port_id) { |
603 | for (int arg_id = 0; arg_id < args.size(); ++arg_id) { |
604 | if (port_id < 0) { |
605 | return -1; |
606 | } else if (port_id == 0) { |
607 | return arg_id; |
608 | } |
609 | |
610 | // Default is 1 port per arg. |
611 | int n = 1; |
612 | |
613 | const auto& arg = args.Get(arg_id); |
614 | if (!arg.number_attr().empty()) { |
615 | n = node.attr().at(arg.number_attr()).i(); |
616 | } else if (!arg.type_list_attr().empty()) { |
617 | n = node.attr().at(arg.type_list_attr()).list().type_size(); |
618 | } |
619 | |
620 | if (n < 0) { |
621 | // This should never happen. |
622 | DCHECK_GE(n, 0); |
623 | return -1; |
624 | } else if (port_id < n) { |
625 | return arg_id; |
626 | } |
627 | port_id -= n; |
628 | } |
629 | |
630 | return -1; |
631 | } |
632 | |
633 | Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { |
634 | if (node_def.op() != op_def.name()) { |
635 | return errors::InvalidArgument( |
636 | "NodeDef op '" , node_def.op(), "' does not match " , |
637 | SummarizeOpDef(op_def), "; NodeDef: " , FormatNodeDefForError(node_def)); |
638 | } |
639 | |
640 | bool seen_control = false; |
641 | size_t num_inputs = 0; |
642 | // TODO(josh11b): Unify the input field validation. |
643 | for (const string& input : node_def.input()) { |
644 | if (absl::StartsWith(input, "^" )) { |
645 | seen_control = true; |
646 | if (input.find(':') != string::npos) { |
647 | return errors::InvalidArgument("Control input '" , input, |
648 | "' must not have ':' in NodeDef: " , |
649 | FormatNodeDefForError(node_def)); |
650 | } |
651 | } else if (seen_control) { |
652 | return errors::InvalidArgument("Non-control input '" , input, |
653 | "' after control input in NodeDef: " , |
654 | FormatNodeDefForError(node_def)); |
655 | } else { |
656 | ++num_inputs; |
657 | } |
658 | } |
659 | |
660 | std::unordered_map<string, const OpDef::AttrDef*> op_attrs; |
661 | for (const auto& attr : op_def.attr()) { |
662 | if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) { |
663 | return errors::InvalidArgument("OpDef has duplicate attr name '" , |
664 | attr.name(), |
665 | "': " , SummarizeOpDef(op_def)); |
666 | } |
667 | } |
668 | for (const auto& attr : node_def.attr()) { |
669 | // Allow internal optional attributes with names starting with "_". |
670 | if (absl::StartsWith(attr.first, "_" )) { |
671 | continue; |
672 | } |
673 | auto iter = op_attrs.find(attr.first); |
674 | if (iter == op_attrs.end()) { |
675 | LOG_EVERY_N_SEC(ERROR, 5) |
676 | << "NodeDef mentions attribute " << attr.first |
677 | << " which is not in the op definition: " << SummarizeOpDef(op_def) |
678 | << " This may be expected if your graph generating binary is newer " |
679 | << " than this binary. Unknown attributes will be ignored." |
680 | << " NodeDef: " << FormatNodeDefForError(node_def); |
681 | continue; |
682 | } |
683 | |
684 | // If attr value is placeholder, do not check it. |
685 | if (attr.second.placeholder().empty()) { |
686 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
687 | ValidateAttrValue(attr.second, *iter->second), |
688 | "; NodeDef: " , FormatNodeDefForError(node_def), "; " , |
689 | SummarizeOpDef(op_def)); |
690 | } |
691 | |
692 | // Keep track of which attr names have (not) been found in the NodeDef. |
693 | op_attrs.erase(iter); |
694 | } |
695 | |
696 | // Were all attrs in the OpDef found in the NodeDef? |
697 | if (!op_attrs.empty()) { |
698 | string attrs; |
699 | for (const auto& attr_pair : op_attrs) { |
700 | if (!attrs.empty()) strings::StrAppend(&attrs, "', '" ); |
701 | strings::StrAppend(&attrs, attr_pair.first); |
702 | } |
703 | return errors::InvalidArgument( |
704 | "NodeDef missing attr" , op_attrs.size() == 1 ? " '" : "s '" , attrs, |
705 | "' from " , SummarizeOpDef(op_def), |
706 | "; NodeDef: " , FormatNodeDefForError(node_def)); |
707 | } |
708 | |
709 | // Validate the number of inputs. |
710 | DataTypeVector inputs, outputs; |
711 | TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs)); |
712 | |
713 | if (num_inputs != inputs.size()) { |
714 | return errors::InvalidArgument( |
715 | "NodeDef expected inputs '" , DataTypeVectorString(inputs), |
716 | "' do not match " , num_inputs, " inputs specified; " , |
717 | SummarizeOpDef(op_def), "; NodeDef: " , FormatNodeDefForError(node_def)); |
718 | } |
719 | |
720 | return OkStatus(); |
721 | } |
722 | |
723 | namespace { // Helpers for NameRangesForNode() |
724 | |
725 | Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def, |
726 | const OpDef& op_def, int* num) { |
727 | if (!arg_def.number_attr().empty()) { |
728 | // Same type repeated "num" times. |
729 | return GetNodeAttr(attrs, arg_def.number_attr(), num); |
730 | } else if (!arg_def.type_list_attr().empty()) { |
731 | const AttrValue* attr_value; |
732 | TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value)); |
733 | *num = attr_value->list().type_size(); |
734 | } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) { |
735 | *num = 1; |
736 | } else { |
737 | return errors::InvalidArgument( |
738 | "Argument '" , arg_def.name(), |
739 | "' incorrectly specified in op definition: " , SummarizeOpDef(op_def)); |
740 | } |
741 | return OkStatus(); |
742 | } |
743 | |
744 | Status NameRangesHelper(const AttrSlice& attrs, |
745 | const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, |
746 | const OpDef& op_def, NameRangeMap* result) { |
747 | int start = 0; |
748 | int num; |
749 | for (const auto& arg : args) { |
750 | TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num)); |
751 | (*result)[arg.name()] = std::make_pair(start, start + num); |
752 | start += num; |
753 | } |
754 | return OkStatus(); |
755 | } |
756 | |
757 | } // namespace |
758 | |
759 | Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def, |
760 | NameRangeMap* inputs, NameRangeMap* outputs) { |
761 | if (inputs != nullptr) { |
762 | TF_RETURN_IF_ERROR( |
763 | NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs)); |
764 | } |
765 | if (outputs != nullptr) { |
766 | return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs); |
767 | } |
768 | return OkStatus(); |
769 | } |
770 | |
771 | void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { |
772 | for (const auto& attr_def : op_def.attr()) { |
773 | AttrSlice attrs(*node_def); |
774 | if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) { |
775 | AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def); |
776 | } |
777 | } |
778 | } |
779 | |
780 | void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def) { |
781 | AttrSlice attrs(*node_def); |
782 | for (const auto& attr_def : op_def.attr()) { |
783 | if (attr_def.has_default_value()) { |
784 | const AttrValue* attr = attrs.Find(attr_def.name()); |
785 | if (attr && AreAttrValuesEqual(*attr, attr_def.default_value())) |
786 | node_def->mutable_attr()->erase(attr_def.name()); |
787 | } |
788 | } |
789 | } |
790 | |
namespace {

using ::tensorflow::tstring;
using ::tensorflow::strings::Scanner;

// True iff `sp` is one or more '>'-separated segments. Each segment starts
// with a character from Scanner::LETTER_DIGIT_DOT, followed by any run of
// Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE characters.
bool IsValidNodeName(StringPiece sp) {
  Scanner scanner(sp);
  scanner.One(Scanner::LETTER_DIGIT_DOT)
      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);

  while (true) {
    if (!scanner.GetResult())  // Some error in previous iteration.
      return false;
    if (scanner.empty())  // No error, but nothing left, good.
      return true;

    // Absorb another name/namespace, starting with a '>'
    scanner.One(Scanner::RANGLE)
        .One(Scanner::LETTER_DIGIT_DOT)
        .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
  }
}

// True iff `sp` is a node name (same grammar as IsValidNodeName) with optional
// ":<index>" suffixes. The index is "0" or a digit run that does not start
// with '0'.
// NOTE(review): the loop keeps accepting segments after an index, so inputs
// such as "a:1:2" or "a:0>b" are also accepted — confirm this is intended.
bool IsValidDataInputName(StringPiece sp) {
  // Data inputs are op_name, op_name:0, or op_name:12345.
  Scanner scan(sp);
  scan.One(Scanner::LETTER_DIGIT_DOT)
      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);

  while (true) {
    if (!scan.GetResult())  // Some error in previous iteration.
      return false;
    if (scan.empty())  // No error, but nothing left, good.
      return true;

    if (scan.Peek() == ':') {  // Absorb identifier after the colon
      scan.OneLiteral(":");
      if (scan.Peek() == '0') {
        scan.OneLiteral("0");  // :0
      } else {
        scan.Many(Scanner::DIGIT);  // :[1-9][0-9]*
      }
    } else {
      // Absorb another name/namespace, starting with a '>'
      scan.One(Scanner::RANGLE)
          .One(Scanner::LETTER_DIGIT_DOT)
          .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
    }
  }
}

// True iff `sp` is '^' followed by a name in the IsValidNodeName grammar.
// No ":<index>" suffix is allowed.
bool IsValidControlInputName(StringPiece sp) {
  Scanner scan(sp);
  scan.OneLiteral("^")
      .One(Scanner::LETTER_DIGIT_DOT)
      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);

  while (true) {
    if (!scan.GetResult())  // Some error in previous iteration.
      return false;
    if (scan.empty())  // No error, but nothing left, good.
      return true;

    // Absorb another name/namespace, starting with a '>'
    scan.One(Scanner::RANGLE)
        .One(Scanner::LETTER_DIGIT_DOT)
        .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
  }
}

// StringPiece view of kColocationGroupPrefix ("loc:@"), presumably for
// prefix checks later in this file.
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);

}  // namespace
864 | |
865 | Status ValidateOpInput(const string& input_name, bool* is_control_input) { |
866 | *is_control_input = false; |
867 | if (IsValidDataInputName(input_name)) { |
868 | return OkStatus(); |
869 | } else if (IsValidControlInputName(input_name)) { |
870 | *is_control_input = true; |
871 | return OkStatus(); |
872 | } else { |
873 | return errors::InvalidArgument("Illegal op input name '" , input_name, "'" ); |
874 | } |
875 | } |
876 | |
877 | Status ValidateNodeName(const string& node_name) { |
878 | if (IsValidNodeName(node_name)) { |
879 | return OkStatus(); |
880 | } else { |
881 | return errors::InvalidArgument("Illegal op name '" , node_name, "'" ); |
882 | } |
883 | } |
884 | |
885 | Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) { |
886 | Status s = ValidateNodeName(node_def.name()); |
887 | if (!s.ok()) { |
888 | return AttachDef(s, node_def); |
889 | } |
890 | bool in_control_inputs = false; |
891 | for (const string& input_name : node_def.input()) { |
892 | bool is_control_input; |
893 | s = ValidateOpInput(input_name, &is_control_input); |
894 | if (!s.ok()) { |
895 | return AttachDef(s, node_def); |
896 | } |
897 | |
898 | if (in_control_inputs && !is_control_input) { |
899 | return AttachDef(errors::InvalidArgument( |
900 | "All control inputs must follow all data inputs" ), |
901 | node_def); |
902 | } |
903 | in_control_inputs = is_control_input; |
904 | } |
905 | return OkStatus(); |
906 | } |
907 | |
908 | Status AttachDef(const Status& status, const NodeDef& node_def, |
909 | bool allow_multiple_formatted_node) { |
910 | Status ret = status; |
911 | string node_error; |
912 | if (!allow_multiple_formatted_node && |
913 | status.error_message().find("{{node " ) != string::npos) { |
914 | node_error = node_def.name(); |
915 | } else { |
916 | node_error = FormatNodeDefForError(node_def); |
917 | } |
918 | errors::AppendToMessage(&ret, strings::StrCat(" [[" , node_error, "]]" )); |
919 | return ret; |
920 | } |
921 | |
922 | void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) { |
923 | node_def->mutable_attr()->insert( |
924 | AttrValueMap::value_type(string(name), value)); |
925 | } |
926 | |
927 | void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) { |
928 | (*node_def->mutable_attr())[string(name)] = std::move(value); |
929 | } |
930 | |
931 | #define ADD_NODE_ATTR(T) \ |
932 | void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \ |
933 | AttrValue attr_value; \ |
934 | SetAttrValue(value, &attr_value); \ |
935 | AddNodeAttr(name, attr_value, node_def); \ |
936 | } |
937 | ADD_NODE_ATTR(StringPiece) |
938 | ADD_NODE_ATTR(const char*) |
939 | ADD_NODE_ATTR(int32_t) |
940 | ADD_NODE_ATTR(int64_t) |
941 | ADD_NODE_ATTR(float) |
942 | ADD_NODE_ATTR(double) |
943 | ADD_NODE_ATTR(bool) |
944 | ADD_NODE_ATTR(DataType) |
945 | ADD_NODE_ATTR(const PartialTensorShape&) |
946 | ADD_NODE_ATTR(const Tensor&) |
947 | ADD_NODE_ATTR(const TensorProto&) |
948 | ADD_NODE_ATTR(const NameAttrList&) |
949 | ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>) |
950 | ADD_NODE_ATTR(gtl::ArraySlice<const char*>) |
951 | ADD_NODE_ATTR(gtl::ArraySlice<string>) |
952 | ADD_NODE_ATTR(gtl::ArraySlice<int32>) |
953 | ADD_NODE_ATTR(gtl::ArraySlice<int64_t>) |
954 | ADD_NODE_ATTR(gtl::ArraySlice<float>) |
955 | ADD_NODE_ATTR(gtl::ArraySlice<bool>) |
956 | ADD_NODE_ATTR(const std::vector<bool>&) |
957 | ADD_NODE_ATTR(gtl::ArraySlice<DataType>) |
958 | ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>) |
959 | ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>) |
960 | ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>) |
961 | ADD_NODE_ATTR(gtl::ArraySlice<Tensor>) |
962 | ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>) |
963 | #undef ADD_NODE_ATTR |
964 | |
965 | void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) { |
966 | map->insert(AttrValueMap::value_type(string(name), value)); |
967 | } |
968 | |
969 | #define ADD_ATTR(T) \ |
970 | void AddAttr(StringPiece name, T value, AttrValueMap* map) { \ |
971 | AttrValue attr_value; \ |
972 | SetAttrValue(value, &attr_value); \ |
973 | AddAttr(name, attr_value, map); \ |
974 | } |
975 | ADD_ATTR(bool) |
976 | #undef ADD_ATTR |
977 | |
978 | Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, |
979 | NodeDef* node_def, bool uniquify_frame_name) { |
980 | node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix)); |
981 | |
982 | // Update frame name to avoid multiple LoopCond nodes in one frame. |
983 | if (uniquify_frame_name && |
984 | (node_def->op() == "Enter" || node_def->op() == "RefEnter" )) { |
985 | string frame_name; |
986 | TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name" , &frame_name)); |
987 | AttrValue& attr = (*node_def->mutable_attr())["frame_name" ]; |
988 | frame_name = strings::StrCat(prefix, frame_name, suffix); |
989 | attr.set_s(frame_name); |
990 | } |
991 | |
992 | return OkStatus(); |
993 | } |
994 | |
995 | Status MaybeAddPrefixToColocationConstraints( |
996 | const std::unordered_set<string>& match, StringPiece prefix, |
997 | NodeDef* node_def) { |
998 | auto attr = node_def->mutable_attr()->find(kColocationAttrName); |
999 | if (attr == node_def->mutable_attr()->end()) { |
1000 | return OkStatus(); |
1001 | } |
1002 | auto constraints_list = attr->second.mutable_list(); |
1003 | auto constraints_size = constraints_list->s_size(); |
1004 | for (size_t i = 0; i < constraints_size; ++i) { |
1005 | StringPiece original(constraints_list->s(i)); |
1006 | if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) { |
1007 | if (match.find(string(original)) != match.end()) { |
1008 | (*constraints_list->mutable_s(i)) = |
1009 | strings::StrCat(kColocationGroupPrefix, prefix, original); |
1010 | } |
1011 | } |
1012 | } |
1013 | return OkStatus(); |
1014 | } |
1015 | |
1016 | Status MaybeUpdateColocationConstraintsWithMap( |
1017 | const std::map<absl::string_view, absl::string_view>& node_name_map, |
1018 | NodeDef* node_def) { |
1019 | auto attr = node_def->mutable_attr()->find(kColocationAttrName); |
1020 | if (attr == node_def->mutable_attr()->end()) { |
1021 | return OkStatus(); |
1022 | } |
1023 | auto constraints_list = attr->second.mutable_list(); |
1024 | auto constraints_size = constraints_list->s_size(); |
1025 | for (size_t i = 0; i < constraints_size; ++i) { |
1026 | StringPiece original(constraints_list->s(i)); |
1027 | if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) { |
1028 | if (node_name_map.find(original) != node_name_map.end()) { |
1029 | (*constraints_list->mutable_s(i)) = |
1030 | strings::StrCat(kColocationGroupPrefix, node_name_map.at(original)); |
1031 | } |
1032 | } |
1033 | } |
1034 | return OkStatus(); |
1035 | } |
1036 | |
1037 | } // namespace tensorflow |
1038 | |