1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/core/framework/node_def_util.h"

#include <algorithm>
#include <limits>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/scanner.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
42
43namespace tensorflow {
44
// Name of the node attr used for colocation, and the prefix carried by its
// entries (presumably an entry reads "loc:@<node name>" -- confirm against the
// declarations in node_def_util.h).
const char* const kColocationAttrName = "_class";
const char* const kColocationGroupPrefix = "loc:@";
// For TPU distributed rewrite, TPU args are collected and "staged" on the local
// host using an IdentityN TF op. Some args may result from a remote source.
// When all arg tensors are available, the TPUExecute op can be invoked. See
// DistributedTPURewritePass for more details.
// kTpuExecuteStagingOp is the op type used for staging; the node-name constant
// is presumably the name (or name component) given to that staging node.
const char* const kTpuExecuteStagingOp = "IdentityN";
const char* const kTpuExecuteStagingNodeName = "_variable_copy";
53
54AttrSlice::AttrSlice() : ndef_(nullptr) {
55 static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
56 attrs_ = kEmptyAttrValueMap;
57}
58
59// Do not cache the map field reference because that may be invalidated on
60// Clear.
61AttrSlice::AttrSlice(const NodeDef& node_def)
62 : ndef_(&node_def), attrs_(nullptr) {}
63
64AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
65
66string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
67 string ret;
68
69 // We sort the attrs so the output is deterministic.
70 std::vector<string> attr_names;
71 attr_names.reserve(attrs.size());
72 for (const auto& attr : attrs) {
73 attr_names.push_back(attr.first);
74 }
75 std::sort(attr_names.begin(), attr_names.end());
76 bool first = true;
77 for (const string& attr_name : attr_names) {
78 if (!first) strings::StrAppend(&ret, ", ");
79 first = false;
80 strings::StrAppend(&ret, attr_name, "=",
81 SummarizeAttrValue(*attrs.Find(attr_name)));
82 }
83
84 // Consider the device to be a final attr with name "_device".
85 if (!device.empty()) {
86 if (!first) strings::StrAppend(&ret, ", ");
87 first = false;
88 strings::StrAppend(&ret, "_device=\"", device, "\"");
89 }
90 return ret;
91}
92
93string AttrSlice::SummarizeNode() const {
94 return ndef_ ? SummarizeNodeDef(*ndef_)
95 : strings::StrCat(
96 "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
97}
98
99string AttrSlice::DebugString() const {
100 std::vector<string> attr_key_vals;
101 attr_key_vals.reserve(attrs()->size());
102 for (const auto& it : *this) {
103 const string& name = it.first;
104 const AttrValue& attr_value = it.second;
105 attr_key_vals.push_back(
106 absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
107 }
108 return absl::StrJoin(attr_key_vals, ", ");
109}
110
111string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
112 string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
113 " = ", node_def.op(), "[");
114 strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
115 strings::StrAppend(&ret, "](");
116
117 // Output inputs, including control inputs, verbatim.
118 bool first = true;
119 for (const string& input : node_def.input()) {
120 if (!first) strings::StrAppend(&ret, ", ");
121 first = false;
122 if (max_inputs_in_summary-- == 0) {
123 strings::StrAppend(&ret, "...");
124 break;
125 }
126 strings::StrAppend(&ret, input);
127 }
128 strings::StrAppend(&ret, ")");
129 return ret;
130}
131
132string SummarizeAttrs(const NodeDef& node_def) {
133 return SummarizeAttrsHelper(node_def, node_def.device());
134}
135
136string FormatNodeDefForError(
137 StringPiece node_name, bool has_experimental_debug_info,
138 const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
139 return !has_experimental_debug_info ||
140 experimental_debug_info.original_node_names().empty()
141 ? errors::FormatNodeNameForError(string(node_name))
142 : errors::FormatOriginalNodeLocationForError(
143 experimental_debug_info.original_node_names(),
144 experimental_debug_info.original_func_names());
145}
146
147string FormatNodeDefForError(const NodeDef& node_def) {
148 return FormatNodeDefForError(node_def.name(),
149 node_def.has_experimental_debug_info(),
150 node_def.experimental_debug_info());
151}
152
153const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
154 // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
155 // requires that the keys used for lookups have type 'const string&'. Because
156 // this method takes a StringPiece, it is necessary to allocate a temporary
157 // string, copy attr_name to it, and then use that temporary string for the
158 // lookup. This causes an excessive number of short-lived allocations, and for
159 // large graphs, this can be a significant cost.
160 //
161 // Because most nodes have a small number of attributes, a simple linear scan
162 // is generally more efficient than a hashed lookup. If google::protobuf::Map
163 // changes so that it supports efficient lookups using StringPiece instead of
164 // const string&, then this code could be changed to use attrs()->find()
165 // again.
166
167 for (const auto& attr : *attrs()) {
168 if (attr.first == attr_name) {
169 return &attr.second;
170 }
171 }
172 return nullptr;
173}
174
175const AttrValue* AttrSlice::FindByString(const string& attr_name) const {
176 auto iter = attrs()->find(attr_name);
177 if (iter != attrs()->end()) {
178 return &iter->second;
179 } else {
180 return nullptr;
181 }
182}
183
184Status AttrSlice::CheckFind(StringPiece attr_name,
185 const AttrValue* attr_value) const {
186 if (attr_value != nullptr) {
187 return OkStatus();
188 }
189 Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
190 // Skip AttachDef for internal attrs since it is a little bit
191 // expensive and it is common for them to correctly not be included
192 // in a NodeDef.
193 if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
194 s = AttachDef(s, *ndef_);
195 }
196 return s;
197}
198
199Status AttrSlice::Find(StringPiece attr_name,
200 const AttrValue** attr_value) const {
201 *attr_value = Find(attr_name);
202 return CheckFind(attr_name, *attr_value);
203}
204
205Status AttrSlice::FindByString(const string& attr_name,
206 const AttrValue** attr_value) const {
207 *attr_value = FindByString(attr_name);
208 return CheckFind(attr_name, *attr_value);
209}
210
211bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
212 if (size() != other.size()) return false;
213
214 for (const auto& attr : *other.attrs()) {
215 auto iter = attrs()->find(attr.first);
216 if (iter == attrs()->end()) return false;
217 // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
218 // TensorProto is a nonunique representation of Tensor. This bug will go
219 // away once AttrSlice switches over to NodeInfo.
220 iter->second.SerializeToString(&scratch->a);
221 attr.second.SerializeToString(&scratch->b);
222 if (scratch->a != scratch->b) return false;
223 }
224 return true;
225}
226
// Defines two GetNodeAttr() overloads for TYPE:
//   * GetNodeAttr(attrs, attr_name, TYPE*) requires an attr of type ATTR_TYPE
//     and reads its raw value from AttrValue::FIELD();
//   * GetNodeAttr(attrs, attr_name, std::vector<TYPE>*) requires an attr of
//     type "list(ATTR_TYPE)", reads the raw values from
//     AttrValue::list().FIELD(), and adds each with value->APPEND_OP(CAST).
//     It appends to *value without clearing it first.
// Both return an error Status if the attr is missing or has the wrong type.
// CAST is an expression converting the raw value `v` to TYPE.
// The ... is to allow the caller to inject some value validation code. Use
// just ; if no additional validation code is needed. The injected code runs
// before CAST, may refer to `v` and `attr_name`, and may return an error
// Status to reject the value.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return OkStatus();                                                        \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    value->reserve(attr_value->list().FIELD().size());                        \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return OkStatus();                                                        \
  }
252
// Like DEFINE_GET_ATTR, but defines TryGetNodeAttr() overloads that return
// false (instead of an error Status) when the attr is missing or does not have
// type ATTR_TYPE / "list(ATTR_TYPE)", and true on success. The injected
// validation code (the ...) may refer to `v` and `attr_name` and may
// `return false;` to reject a value. The vector overload appends to *value
// without clearing it, so on a rejection partway through the list *value may
// already hold some appended elements.
#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,     \
                      TYPE* value) {                                     \
    const AttrValue* attr_value = attrs.Find(attr_name);                 \
    if (attr_value == nullptr) {                                         \
      return false;                                                      \
    }                                                                    \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                 \
    if (!s.ok()) {                                                       \
      return false;                                                      \
    }                                                                    \
    const auto& v = attr_value->FIELD();                                 \
    __VA_ARGS__;                                                         \
    *value = CAST;                                                       \
    return true;                                                         \
  }                                                                      \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,     \
                      std::vector<TYPE>* value) {                        \
    const AttrValue* attr_value = attrs.Find(attr_name);                 \
    if (attr_value == nullptr) {                                         \
      return false;                                                      \
    }                                                                    \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");     \
    if (!s.ok()) {                                                       \
      return false;                                                      \
    }                                                                    \
    value->reserve(attr_value->list().FIELD().size());                   \
    for (const auto& v : attr_value->list().FIELD()) {                   \
      __VA_ARGS__;                                                       \
      value->APPEND_OP(CAST);                                            \
    }                                                                    \
    return true;                                                         \
  }
286DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
287DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
288DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
289DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
290DEFINE_GET_ATTR(int64_t, i, "int", emplace_back, v, ;)
291DEFINE_TRY_GET_ATTR(int64_t, i, "int", emplace_back, v, ;)
292DEFINE_GET_ATTR(
293 int32, i, "int", emplace_back, static_cast<int32>(v),
294 if (static_cast<int64_t>(static_cast<int32>(v)) != v) {
295 return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
296 " out of range for an int32");
297 })
298DEFINE_TRY_GET_ATTR(
299 int32, i, "int", emplace_back, static_cast<int32>(v),
300 if (static_cast<int64_t>(static_cast<int32>(v)) != v) {
301 static int log_counter = 0;
302 if (log_counter < 10) {
303 log_counter++;
304 LOG(WARNING) << "Attr " << attr_name << " has value " << v
305 << " out of range for an int32";
306 }
307 return false;
308 })
309DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
310DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
311DEFINE_GET_ATTR(bool, b, "bool", emplace_back, v, ;)
312DEFINE_TRY_GET_ATTR(bool, b, "bool", emplace_back, v, ;)
313DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
314 ;)
315DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
316 static_cast<DataType>(v),
317 ;)
318DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
319DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
320 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
321DEFINE_TRY_GET_ATTR(
322 TensorShape, shape, "shape", emplace_back, TensorShape(v),
323 if (!TensorShape::IsValidShape(v).ok()) {
324 static int log_counter = 0;
325 if (log_counter < 10) {
326 log_counter++;
327 LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
328 << v.DebugString();
329 }
330 return false;
331 })
332DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
333 PartialTensorShape(v),
334 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
335DEFINE_GET_ATTR(
336 Tensor, tensor, "tensor", emplace_back, t, Tensor t; if (!t.FromProto(v)) {
337 return errors::InvalidArgument("Attr ", attr_name, " has value ",
338 v.ShortDebugString(),
339 " that can't be converted to a Tensor");
340 })
341DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
342#undef DEFINE_GET_ATTR
343
344bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
345 return node_def.attr().find(string(attr_name)) != node_def.attr().end();
346}
347
348static const string& kEmptyString = *new string();
349
350const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
351 const AttrValue* attr_value = attrs.Find(attr_name);
352 if (attr_value == nullptr) {
353 return kEmptyString;
354 }
355 Status s = AttrValueHasType(*attr_value, "string");
356 if (!s.ok()) {
357 return kEmptyString;
358 }
359 return attr_value->s();
360}
361
362bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
363 std::vector<const string*>* value) {
364 const AttrValue* attr_value = attrs.Find(attr_name);
365 if (attr_value == nullptr) {
366 return false;
367 }
368 Status s = AttrValueHasType(*attr_value, "list(string)");
369 if (!s.ok()) {
370 return false;
371 }
372 value->reserve(attr_value->list().s().size());
373 for (const auto& v : attr_value->list().s()) {
374 value->push_back(&v);
375 }
376 return true;
377}
378
379bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
380 std::vector<const TensorShapeProto*>* value) {
381 const AttrValue* attr_value = attrs.Find(attr_name);
382 if (attr_value == nullptr) {
383 return false;
384 }
385 Status s = AttrValueHasType(*attr_value, "list(shape)");
386 if (!s.ok()) {
387 return false;
388 }
389 value->reserve(attr_value->list().shape().size());
390 for (const auto& v : attr_value->list().shape()) {
391 value->push_back(&v);
392 }
393 return true;
394}
395
396Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
397 DataTypeVector* value) {
398 const AttrValue* attr_value;
399 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
400 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
401 for (const auto& v : attr_value->list().type()) {
402 value->push_back(static_cast<DataType>(v));
403 }
404 return OkStatus();
405}
406
407Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
408 const TensorProto** value) {
409 const AttrValue* attr_value;
410 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
411 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
412 *value = &attr_value->tensor();
413 return OkStatus();
414}
415
416bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
417 const TensorProto** value) {
418 const AttrValue* attr_value = attrs.Find(attr_name);
419 if (attr_value == nullptr) {
420 return false;
421 }
422 Status s = AttrValueHasType(*attr_value, "tensor");
423 if (!s.ok()) {
424 return false;
425 }
426 *value = &attr_value->tensor();
427 return true;
428}
429
430Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
431 const NameAttrList** value) {
432 const AttrValue* attr_value;
433 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
434 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
435 *value = &attr_value->func();
436 return OkStatus();
437}
438
439bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
440 const NameAttrList** value) {
441 const AttrValue* attr_value = attrs.Find(attr_name);
442 if (attr_value == nullptr) {
443 return false;
444 }
445 Status s = AttrValueHasType(*attr_value, "func");
446 if (!s.ok()) {
447 return false;
448 }
449 *value = &attr_value->func();
450 return true;
451}
452
453Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
454 Padding* value) {
455 string str_value;
456 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value));
457 return GetPaddingFromString(str_value, value);
458}
459
460namespace { // Helper for InOutTypesForNode().
461
462template <class NodeDefOrAttrSlice>
463Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
464 const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
465 const int original_size = sig->size();
466 if (!arg_def.number_attr().empty()) {
467 // Same type repeated "repeats" times.
468 int64_t repeats = -1;
469 TF_RETURN_IF_ERROR(
470 GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
471 // We can't handle outputs that are larger than int32 sizes.
472 if (static_cast<int64_t>(static_cast<int32>(repeats)) != repeats) {
473 return errors::InvalidArgument("Number of outputs is too big: ", repeats);
474 }
475 if (repeats < 0) {
476 return errors::InvalidArgument("Value for number_attr() ", repeats,
477 " < 0");
478 }
479
480 if (!arg_def.type_attr().empty()) {
481 DataType dtype;
482 TF_RETURN_IF_ERROR(
483 GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
484 for (int i = 0; i < repeats; ++i) {
485 sig->push_back(dtype);
486 }
487 } else if (arg_def.type() != DT_INVALID) {
488 for (int i = 0; i < repeats; ++i) {
489 sig->push_back(arg_def.type());
490 }
491 } else {
492 return errors::InvalidArgument("Missing type or type_attr field in ",
493 arg_def.ShortDebugString());
494 }
495 } else if (!arg_def.type_attr().empty()) {
496 const AttrValue* attr_value;
497 TF_RETURN_IF_ERROR(AttrSlice(node_or_attrs)
498 .FindByString(arg_def.type_attr(), &attr_value));
499 sig->push_back(attr_value->type());
500 } else if (!arg_def.type_list_attr().empty()) {
501 const AttrValue* attr_value;
502 TF_RETURN_IF_ERROR(
503 AttrSlice(node_or_attrs)
504 .FindByString(arg_def.type_list_attr(), &attr_value));
505 for (int dtype : attr_value->list().type()) {
506 sig->push_back(static_cast<DataType>(dtype));
507 }
508 } else if (arg_def.type() != DT_INVALID) {
509 sig->push_back(arg_def.type());
510 } else {
511 return errors::InvalidArgument("No type fields in ",
512 arg_def.ShortDebugString());
513 }
514 if (arg_def.is_ref()) {
515 // For all types that were added by this function call, make them refs.
516 for (size_t i = original_size; i < sig->size(); ++i) {
517 if (IsRefType((*sig)[i])) {
518 return errors::InvalidArgument(
519 "Requested reference to a reference type: ",
520 arg_def.ShortDebugString());
521 }
522 (*sig)[i] = MakeRefType((*sig)[i]);
523 }
524 }
525 return OkStatus();
526}
527
528} // namespace
529
530Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
531 int input_port, DataType* input_type) {
532 DataTypeVector input_types;
533 for (const auto& arg : op_def.input_arg()) {
534 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
535 int input_types_size = input_types.size();
536 if (input_types_size > input_port) {
537 const DataType dtype = input_types[input_port];
538 *input_type = dtype;
539 return OkStatus();
540 }
541 }
542 return errors::InvalidArgument("Input ", input_port, " not found for node ",
543 node_def.name());
544}
545
546Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
547 DataTypeVector* inputs) {
548 for (const auto& arg : op_def.input_arg()) {
549 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
550 }
551 return OkStatus();
552}
553
554Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
555 int output_port, DataType* output_type) {
556 DataTypeVector output_types;
557 for (const auto& arg : op_def.output_arg()) {
558 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
559 int output_types_size = output_types.size();
560 if (output_types_size > output_port) {
561 const DataType dtype = output_types[output_port];
562 *output_type = dtype;
563 return OkStatus();
564 }
565 }
566 return errors::InvalidArgument("Output ", output_port, " not found for node ",
567 node_def.name());
568}
569
570Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
571 DataTypeVector* outputs) {
572 for (const auto& arg : op_def.output_arg()) {
573 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
574 }
575 return OkStatus();
576}
577
578Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
579 DataTypeVector* outputs) {
580 for (const auto& arg : op_def.output_arg()) {
581 TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs));
582 }
583 return OkStatus();
584}
585
586Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
587 DataTypeVector* inputs, DataTypeVector* outputs) {
588 TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
589 return OutputTypesForNode(node_def, op_def, outputs);
590}
591
592Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
593 int* num_outputs) {
594 DataTypeVector outputs;
595 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
596 *num_outputs = outputs.size();
597 return OkStatus();
598}
599
600int OpPortIdToArgId(const NodeDef& node,
601 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
602 int port_id) {
603 for (int arg_id = 0; arg_id < args.size(); ++arg_id) {
604 if (port_id < 0) {
605 return -1;
606 } else if (port_id == 0) {
607 return arg_id;
608 }
609
610 // Default is 1 port per arg.
611 int n = 1;
612
613 const auto& arg = args.Get(arg_id);
614 if (!arg.number_attr().empty()) {
615 n = node.attr().at(arg.number_attr()).i();
616 } else if (!arg.type_list_attr().empty()) {
617 n = node.attr().at(arg.type_list_attr()).list().type_size();
618 }
619
620 if (n < 0) {
621 // This should never happen.
622 DCHECK_GE(n, 0);
623 return -1;
624 } else if (port_id < n) {
625 return arg_id;
626 }
627 port_id -= n;
628 }
629
630 return -1;
631}
632
633Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
634 if (node_def.op() != op_def.name()) {
635 return errors::InvalidArgument(
636 "NodeDef op '", node_def.op(), "' does not match ",
637 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
638 }
639
640 bool seen_control = false;
641 size_t num_inputs = 0;
642 // TODO(josh11b): Unify the input field validation.
643 for (const string& input : node_def.input()) {
644 if (absl::StartsWith(input, "^")) {
645 seen_control = true;
646 if (input.find(':') != string::npos) {
647 return errors::InvalidArgument("Control input '", input,
648 "' must not have ':' in NodeDef: ",
649 FormatNodeDefForError(node_def));
650 }
651 } else if (seen_control) {
652 return errors::InvalidArgument("Non-control input '", input,
653 "' after control input in NodeDef: ",
654 FormatNodeDefForError(node_def));
655 } else {
656 ++num_inputs;
657 }
658 }
659
660 std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
661 for (const auto& attr : op_def.attr()) {
662 if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
663 return errors::InvalidArgument("OpDef has duplicate attr name '",
664 attr.name(),
665 "': ", SummarizeOpDef(op_def));
666 }
667 }
668 for (const auto& attr : node_def.attr()) {
669 // Allow internal optional attributes with names starting with "_".
670 if (absl::StartsWith(attr.first, "_")) {
671 continue;
672 }
673 auto iter = op_attrs.find(attr.first);
674 if (iter == op_attrs.end()) {
675 LOG_EVERY_N_SEC(ERROR, 5)
676 << "NodeDef mentions attribute " << attr.first
677 << " which is not in the op definition: " << SummarizeOpDef(op_def)
678 << " This may be expected if your graph generating binary is newer "
679 << " than this binary. Unknown attributes will be ignored."
680 << " NodeDef: " << FormatNodeDefForError(node_def);
681 continue;
682 }
683
684 // If attr value is placeholder, do not check it.
685 if (attr.second.placeholder().empty()) {
686 TF_RETURN_WITH_CONTEXT_IF_ERROR(
687 ValidateAttrValue(attr.second, *iter->second),
688 "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
689 SummarizeOpDef(op_def));
690 }
691
692 // Keep track of which attr names have (not) been found in the NodeDef.
693 op_attrs.erase(iter);
694 }
695
696 // Were all attrs in the OpDef found in the NodeDef?
697 if (!op_attrs.empty()) {
698 string attrs;
699 for (const auto& attr_pair : op_attrs) {
700 if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
701 strings::StrAppend(&attrs, attr_pair.first);
702 }
703 return errors::InvalidArgument(
704 "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
705 "' from ", SummarizeOpDef(op_def),
706 "; NodeDef: ", FormatNodeDefForError(node_def));
707 }
708
709 // Validate the number of inputs.
710 DataTypeVector inputs, outputs;
711 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
712
713 if (num_inputs != inputs.size()) {
714 return errors::InvalidArgument(
715 "NodeDef expected inputs '", DataTypeVectorString(inputs),
716 "' do not match ", num_inputs, " inputs specified; ",
717 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
718 }
719
720 return OkStatus();
721}
722
723namespace { // Helpers for NameRangesForNode()
724
725Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
726 const OpDef& op_def, int* num) {
727 if (!arg_def.number_attr().empty()) {
728 // Same type repeated "num" times.
729 return GetNodeAttr(attrs, arg_def.number_attr(), num);
730 } else if (!arg_def.type_list_attr().empty()) {
731 const AttrValue* attr_value;
732 TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
733 *num = attr_value->list().type_size();
734 } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
735 *num = 1;
736 } else {
737 return errors::InvalidArgument(
738 "Argument '", arg_def.name(),
739 "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
740 }
741 return OkStatus();
742}
743
744Status NameRangesHelper(const AttrSlice& attrs,
745 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
746 const OpDef& op_def, NameRangeMap* result) {
747 int start = 0;
748 int num;
749 for (const auto& arg : args) {
750 TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
751 (*result)[arg.name()] = std::make_pair(start, start + num);
752 start += num;
753 }
754 return OkStatus();
755}
756
757} // namespace
758
759Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
760 NameRangeMap* inputs, NameRangeMap* outputs) {
761 if (inputs != nullptr) {
762 TF_RETURN_IF_ERROR(
763 NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
764 }
765 if (outputs != nullptr) {
766 return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
767 }
768 return OkStatus();
769}
770
771void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
772 for (const auto& attr_def : op_def.attr()) {
773 AttrSlice attrs(*node_def);
774 if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
775 AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
776 }
777 }
778}
779
780void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def) {
781 AttrSlice attrs(*node_def);
782 for (const auto& attr_def : op_def.attr()) {
783 if (attr_def.has_default_value()) {
784 const AttrValue* attr = attrs.Find(attr_def.name());
785 if (attr && AreAttrValuesEqual(*attr, attr_def.default_value()))
786 node_def->mutable_attr()->erase(attr_def.name());
787 }
788 }
789}
790
791namespace {
792
793using ::tensorflow::tstring;
794using ::tensorflow::strings::Scanner;
795
796bool IsValidNodeName(StringPiece sp) {
797 Scanner scanner(sp);
798 scanner.One(Scanner::LETTER_DIGIT_DOT)
799 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
800
801 while (true) {
802 if (!scanner.GetResult()) // Some error in previous iteration.
803 return false;
804 if (scanner.empty()) // No error, but nothing left, good.
805 return true;
806
807 // Absorb another name/namespace, starting with a '>'
808 scanner.One(Scanner::RANGLE)
809 .One(Scanner::LETTER_DIGIT_DOT)
810 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
811 }
812}
813
814bool IsValidDataInputName(StringPiece sp) {
815 // Data inputs are op_name, op_name:0, or op_name:12345.
816 Scanner scan(sp);
817 scan.One(Scanner::LETTER_DIGIT_DOT)
818 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
819
820 while (true) {
821 if (!scan.GetResult()) // Some error in previous iteration.
822 return false;
823 if (scan.empty()) // No error, but nothing left, good.
824 return true;
825
826 if (scan.Peek() == ':') { // Absorb identifier after the colon
827 scan.OneLiteral(":");
828 if (scan.Peek() == '0') {
829 scan.OneLiteral("0"); // :0
830 } else {
831 scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
832 }
833 } else {
834 // Absorb another name/namespace, starting with a '>'
835 scan.One(Scanner::RANGLE)
836 .One(Scanner::LETTER_DIGIT_DOT)
837 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
838 }
839 }
840}
841
842bool IsValidControlInputName(StringPiece sp) {
843 Scanner scan(sp);
844 scan.OneLiteral("^")
845 .One(Scanner::LETTER_DIGIT_DOT)
846 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
847
848 while (true) {
849 if (!scan.GetResult()) // Some error in previous iteration.
850 return false;
851 if (scan.empty()) // No error, but nothing left, good.
852 return true;
853
854 // Absorb another name/namespace, starting with a '>'
855 scan.One(Scanner::RANGLE)
856 .One(Scanner::LETTER_DIGIT_DOT)
857 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
858 }
859}
860
861const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
862
863} // namespace
864
865Status ValidateOpInput(const string& input_name, bool* is_control_input) {
866 *is_control_input = false;
867 if (IsValidDataInputName(input_name)) {
868 return OkStatus();
869 } else if (IsValidControlInputName(input_name)) {
870 *is_control_input = true;
871 return OkStatus();
872 } else {
873 return errors::InvalidArgument("Illegal op input name '", input_name, "'");
874 }
875}
876
877Status ValidateNodeName(const string& node_name) {
878 if (IsValidNodeName(node_name)) {
879 return OkStatus();
880 } else {
881 return errors::InvalidArgument("Illegal op name '", node_name, "'");
882 }
883}
884
885Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
886 Status s = ValidateNodeName(node_def.name());
887 if (!s.ok()) {
888 return AttachDef(s, node_def);
889 }
890 bool in_control_inputs = false;
891 for (const string& input_name : node_def.input()) {
892 bool is_control_input;
893 s = ValidateOpInput(input_name, &is_control_input);
894 if (!s.ok()) {
895 return AttachDef(s, node_def);
896 }
897
898 if (in_control_inputs && !is_control_input) {
899 return AttachDef(errors::InvalidArgument(
900 "All control inputs must follow all data inputs"),
901 node_def);
902 }
903 in_control_inputs = is_control_input;
904 }
905 return OkStatus();
906}
907
908Status AttachDef(const Status& status, const NodeDef& node_def,
909 bool allow_multiple_formatted_node) {
910 Status ret = status;
911 string node_error;
912 if (!allow_multiple_formatted_node &&
913 status.error_message().find("{{node ") != string::npos) {
914 node_error = node_def.name();
915 } else {
916 node_error = FormatNodeDefForError(node_def);
917 }
918 errors::AppendToMessage(&ret, strings::StrCat(" [[", node_error, "]]"));
919 return ret;
920}
921
922void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
923 node_def->mutable_attr()->insert(
924 AttrValueMap::value_type(string(name), value));
925}
926
927void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
928 (*node_def->mutable_attr())[string(name)] = std::move(value);
929}
930
931#define ADD_NODE_ATTR(T) \
932 void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
933 AttrValue attr_value; \
934 SetAttrValue(value, &attr_value); \
935 AddNodeAttr(name, attr_value, node_def); \
936 }
937ADD_NODE_ATTR(StringPiece)
938ADD_NODE_ATTR(const char*)
939ADD_NODE_ATTR(int32_t)
940ADD_NODE_ATTR(int64_t)
941ADD_NODE_ATTR(float)
942ADD_NODE_ATTR(double)
943ADD_NODE_ATTR(bool)
944ADD_NODE_ATTR(DataType)
945ADD_NODE_ATTR(const PartialTensorShape&)
946ADD_NODE_ATTR(const Tensor&)
947ADD_NODE_ATTR(const TensorProto&)
948ADD_NODE_ATTR(const NameAttrList&)
949ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
950ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
951ADD_NODE_ATTR(gtl::ArraySlice<string>)
952ADD_NODE_ATTR(gtl::ArraySlice<int32>)
953ADD_NODE_ATTR(gtl::ArraySlice<int64_t>)
954ADD_NODE_ATTR(gtl::ArraySlice<float>)
955ADD_NODE_ATTR(gtl::ArraySlice<bool>)
956ADD_NODE_ATTR(const std::vector<bool>&)
957ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
958ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
959ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
960ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
961ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
962ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
963#undef ADD_NODE_ATTR
964
965void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
966 map->insert(AttrValueMap::value_type(string(name), value));
967}
968
969#define ADD_ATTR(T) \
970 void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
971 AttrValue attr_value; \
972 SetAttrValue(value, &attr_value); \
973 AddAttr(name, attr_value, map); \
974 }
975ADD_ATTR(bool)
976#undef ADD_ATTR
977
978Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
979 NodeDef* node_def, bool uniquify_frame_name) {
980 node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
981
982 // Update frame name to avoid multiple LoopCond nodes in one frame.
983 if (uniquify_frame_name &&
984 (node_def->op() == "Enter" || node_def->op() == "RefEnter")) {
985 string frame_name;
986 TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
987 AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
988 frame_name = strings::StrCat(prefix, frame_name, suffix);
989 attr.set_s(frame_name);
990 }
991
992 return OkStatus();
993}
994
995Status MaybeAddPrefixToColocationConstraints(
996 const std::unordered_set<string>& match, StringPiece prefix,
997 NodeDef* node_def) {
998 auto attr = node_def->mutable_attr()->find(kColocationAttrName);
999 if (attr == node_def->mutable_attr()->end()) {
1000 return OkStatus();
1001 }
1002 auto constraints_list = attr->second.mutable_list();
1003 auto constraints_size = constraints_list->s_size();
1004 for (size_t i = 0; i < constraints_size; ++i) {
1005 StringPiece original(constraints_list->s(i));
1006 if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
1007 if (match.find(string(original)) != match.end()) {
1008 (*constraints_list->mutable_s(i)) =
1009 strings::StrCat(kColocationGroupPrefix, prefix, original);
1010 }
1011 }
1012 }
1013 return OkStatus();
1014}
1015
1016Status MaybeUpdateColocationConstraintsWithMap(
1017 const std::map<absl::string_view, absl::string_view>& node_name_map,
1018 NodeDef* node_def) {
1019 auto attr = node_def->mutable_attr()->find(kColocationAttrName);
1020 if (attr == node_def->mutable_attr()->end()) {
1021 return OkStatus();
1022 }
1023 auto constraints_list = attr->second.mutable_list();
1024 auto constraints_size = constraints_list->s_size();
1025 for (size_t i = 0; i < constraints_size; ++i) {
1026 StringPiece original(constraints_list->s(i));
1027 if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
1028 if (node_name_map.find(original) != node_name_map.end()) {
1029 (*constraints_list->mutable_s(i)) =
1030 strings::StrCat(kColocationGroupPrefix, node_name_map.at(original));
1031 }
1032 }
1033 }
1034 return OkStatus();
1035}
1036
1037} // namespace tensorflow
1038