1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/attr_value_util.h"
17
18#include <string>
19#include <unordered_map>
20#include <vector>
21
22#include "absl/strings/escaping.h"
23#include "tensorflow/core/framework/attr_value.pb_text.h"
24#include "tensorflow/core/framework/tensor.pb_text.h"
25#include "tensorflow/core/framework/tensor_shape.pb.h"
26#include "tensorflow/core/framework/types.h"
27#include "tensorflow/core/framework/types.pb_text.h"
28#include "tensorflow/core/lib/core/errors.h"
29#include "tensorflow/core/lib/core/stringpiece.h"
30#include "tensorflow/core/lib/hash/hash.h"
31#include "tensorflow/core/lib/strings/proto_serialization.h"
32#include "tensorflow/core/lib/strings/str_util.h"
33#include "tensorflow/core/platform/fingerprint.h"
34#include "tensorflow/core/platform/protobuf.h"
35
36namespace tensorflow {
37namespace {
38
// Size threshold, in bytes, above which a tensor is considered too large to be
// constructed in memory just to compute its hash or compare it for equality.
constexpr int kMaxAttrValueTensorByteSize = 32 << 20;  // 32 MiB

// Tensor nesting is limited to this depth to prevent memory overflow.
constexpr int kMaxTensorNestDepth = 100;
44
45// Return the size of the tensor represented by this TensorProto. If shape is
46// not fully defined return -1.
47int64_t TensorByteSize(const TensorProto& t) {
48 // num_elements returns -1 if shape is not fully defined.
49 int64_t num_elems = PartialTensorShape(t.tensor_shape()).num_elements();
50 return num_elems < 0 ? -1 : num_elems * DataTypeSize(t.dtype());
51}
52
53// Compute TensorProto hash by creating a Tensor, serializing it as tensor
54// content, and computing a hash of it's string representation. This is unsafe
55// operation, because large tensors can be represented as TensorProto, but can't
56// be serialized to tensor content.
57uint64 TensorProtoHash(const TensorProto& tp) {
58 Tensor tensor(tp.dtype());
59 bool success = tensor.FromProto(tp);
60 DCHECK(success);
61 TensorProto p;
62 tensor.AsProtoTensorContent(&p);
63 return DeterministicProtoHash64(p);
64}
65
66// Do not create large tensors in memory, compute hash based on TensorProto
67// string representation. Tensors with identical content potentially can have a
68// different hash code if they are defined with different TensorProto
69// representations.
70uint64 FastTensorProtoHash(const TensorProto& tp) {
71 if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) {
72 return DeterministicProtoHash64(tp);
73 } else {
74 return TensorProtoHash(tp);
75 }
76}
77
78bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs,
79 bool allow_false_negatives) {
80 // A small TensorProto can expand into a giant Tensor. So we avoid
81 // conversion to an actual Tensor if we can quickly rule out equality
82 // by comparing the Tensor size since different sized Tensors are definitely
83 // different.
84 const int64_t lhs_tensor_bytes = TensorByteSize(lhs);
85 const int64_t rhs_tensor_bytes = TensorByteSize(rhs);
86 if (lhs_tensor_bytes != rhs_tensor_bytes) {
87 return false;
88 }
89
90 // If the TensorProto representation expands into a much bigger Tensor,
91 // we have a fast-path that first compares the protos.
92 const int64_t lhs_proto_bytes = lhs.ByteSizeLong();
93 const bool large_expansion =
94 (lhs_proto_bytes < 512 && lhs_tensor_bytes > 4096);
95
96 // If the tensor is very large, we'll only compare the proto representation if
97 // false negatives are allowed. This may miss some equivalent tensors whose
98 // actual tensor values are the same but which are described by different
99 // TensorProtos. This avoids construction of large protos in memory.
100 const bool only_compare_proto =
101 (allow_false_negatives && lhs_tensor_bytes > kMaxAttrValueTensorByteSize);
102 if (large_expansion || only_compare_proto) {
103 if (AreSerializedProtosEqual(lhs, rhs))
104 return true;
105 else if (only_compare_proto)
106 return false;
107 }
108
109 // Finally, compare them by constructing Tensors and serializing them back.
110 // There are multiple equivalent representations of attr values containing
111 // TensorProtos. Comparing Tensor objects is pretty tricky. This is unsafe
112 // operation, because large tensors can be represented as TensorProto, but
113 // can't be serialized to tensor content.
114 Tensor lhs_t(lhs.dtype());
115 bool success = lhs_t.FromProto(lhs);
116 if (!success) {
117 return false;
118 }
119
120 Tensor rhs_t(rhs.dtype());
121 success = rhs_t.FromProto(rhs);
122 if (!success) {
123 return false;
124 }
125
126 TensorProto lhs_tp;
127 lhs_t.AsProtoTensorContent(&lhs_tp);
128
129 TensorProto rhs_tp;
130 rhs_t.AsProtoTensorContent(&rhs_tp);
131
132 return AreSerializedProtosEqual(lhs_tp, rhs_tp);
133}
134
135using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
136
137uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
138 if (a.has_tensor()) return tensor_hash(a.tensor());
139
140 if (a.has_func()) {
141 const NameAttrList& func = a.func();
142 uint64 h = Hash64(func.name());
143 std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
144 for (const auto& pair : map) {
145 h = Hash64(pair.first.data(), pair.first.size(), h);
146 h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h);
147 }
148 return h;
149 }
150
151 // If `a` is not a tensor or func, get a hash of serialized string.
152 return DeterministicProtoHash64(a);
153}
154
155string SummarizeString(const string& str) {
156 string escaped = absl::CEscape(str);
157
158 // If the string is long, replace the middle with ellipses.
159 constexpr int kMaxStringSummarySize = 80;
160 if (escaped.size() >= kMaxStringSummarySize) {
161 StringPiece prefix(escaped);
162 StringPiece suffix = prefix;
163 prefix.remove_suffix(escaped.size() - 10);
164 suffix.remove_prefix(escaped.size() - 10);
165 return strings::StrCat("\"", prefix, "...", suffix, "\"");
166 } else {
167 return strings::StrCat("\"", escaped, "\"");
168 }
169}
170
171string SummarizeTensor(const TensorProto& tensor_proto) {
172 Tensor t;
173 if (!t.FromProto(tensor_proto)) {
174 return strings::StrCat(
175 "<Invalid TensorProto: ", tensor_proto.ShortDebugString(), ">");
176 }
177 return t.DebugString();
178}
179
180string SummarizeFunc(const NameAttrList& func) {
181 std::vector<string> entries;
182 for (const auto& p : func.attr()) {
183 entries.push_back(
184 strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
185 }
186 std::sort(entries.begin(), entries.end());
187 return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
188}
189
// Returns true if the message nesting in `to_parse` stays strictly below
// `limit` levels, and false as soon as the nesting depth reaches `limit`.
//
// Nesting is estimated textually: every '{' or '<' opens a level and every
// '}' or '>' closes one. Delimiters inside quoted strings are counted too, so
// this is a conservative pre-check rather than a parser. Unmatched closing
// delimiters never drive the depth below zero, so they cannot hide later
// nesting.
//
// Args:
//   limit:    first nesting depth that is rejected.
//   to_parse: text-format proto to inspect.
bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit,
                                                const std::string& to_parse) {
  int depth = 0;
  for (const char c : to_parse) {
    if (c == '{' || c == '<') {
      ++depth;
      if (depth >= limit) {
        return false;
      }
    } else if ((c == '}' || c == '>') && depth > 0) {
      --depth;
    }
  }
  return true;
}
237
238} // namespace
239
240string SummarizeAttrValue(const AttrValue& attr_value) {
241 switch (attr_value.value_case()) {
242 case AttrValue::kS:
243 return SummarizeString(attr_value.s());
244 case AttrValue::kI:
245 return strings::StrCat(attr_value.i());
246 case AttrValue::kF:
247 return strings::StrCat(attr_value.f());
248 case AttrValue::kB:
249 return attr_value.b() ? "true" : "false";
250 case AttrValue::kType:
251 return EnumName_DataType(attr_value.type());
252 case AttrValue::kShape:
253 return PartialTensorShape::DebugString(attr_value.shape());
254 case AttrValue::kTensor:
255 return SummarizeTensor(attr_value.tensor());
256 case AttrValue::kList: {
257 std::vector<string> pieces;
258 if (attr_value.list().s_size() > 0) {
259 for (int i = 0; i < attr_value.list().s_size(); ++i) {
260 pieces.push_back(SummarizeString(attr_value.list().s(i)));
261 }
262 } else if (attr_value.list().i_size() > 0) {
263 for (int i = 0; i < attr_value.list().i_size(); ++i) {
264 pieces.push_back(strings::StrCat(attr_value.list().i(i)));
265 }
266 } else if (attr_value.list().f_size() > 0) {
267 for (int i = 0; i < attr_value.list().f_size(); ++i) {
268 pieces.push_back(strings::StrCat(attr_value.list().f(i)));
269 }
270 } else if (attr_value.list().b_size() > 0) {
271 for (int i = 0; i < attr_value.list().b_size(); ++i) {
272 pieces.push_back(attr_value.list().b(i) ? "true" : "false");
273 }
274 } else if (attr_value.list().type_size() > 0) {
275 for (int i = 0; i < attr_value.list().type_size(); ++i) {
276 pieces.push_back(EnumName_DataType(attr_value.list().type(i)));
277 }
278 } else if (attr_value.list().shape_size() > 0) {
279 for (int i = 0; i < attr_value.list().shape_size(); ++i) {
280 pieces.push_back(
281 TensorShape::DebugString(attr_value.list().shape(i)));
282 }
283 } else if (attr_value.list().tensor_size() > 0) {
284 for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
285 pieces.push_back(SummarizeTensor(attr_value.list().tensor(i)));
286 }
287 } else if (attr_value.list().func_size() > 0) {
288 for (int i = 0; i < attr_value.list().func_size(); ++i) {
289 pieces.push_back(SummarizeFunc(attr_value.list().func(i)));
290 }
291 }
292 constexpr int kMaxListSummarySize = 15;
293 if (pieces.size() >= kMaxListSummarySize) {
294 pieces[5] = strings::StrCat(Fingerprint64(
295 absl::StrJoin(pieces.begin() + 5, pieces.end() - 5, ",")));
296 pieces.erase(pieces.begin() + 6, pieces.end() - 5);
297 }
298 return strings::StrCat("[", absl::StrJoin(pieces, ", "), "]");
299 }
300 case AttrValue::kFunc: {
301 return SummarizeFunc(attr_value.func());
302 }
303 case AttrValue::kPlaceholder:
304 return strings::StrCat("$", attr_value.placeholder());
305 case AttrValue::VALUE_NOT_SET:
306 return "<Unknown AttrValue type>";
307 }
308 return "<Unknown AttrValue type>"; // Prevent missing return warning
309}
310
311Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
312 int num_set = 0;
313
314#define VALIDATE_FIELD(name, type_string, oneof_case) \
315 do { \
316 if (attr_value.has_list()) { \
317 if (attr_value.list().name##_size() > 0) { \
318 if (type != "list(" type_string ")") { \
319 return errors::InvalidArgument( \
320 "AttrValue had value with type 'list(" type_string ")' when '", \
321 type, "' expected"); \
322 } \
323 ++num_set; \
324 } \
325 } else if (attr_value.value_case() == AttrValue::oneof_case) { \
326 if (type != type_string) { \
327 return errors::InvalidArgument( \
328 "AttrValue had value with type '" type_string "' when '", type, \
329 "' expected"); \
330 } \
331 ++num_set; \
332 } \
333 } while (false)
334
335 VALIDATE_FIELD(s, "string", kS);
336 VALIDATE_FIELD(i, "int", kI);
337 VALIDATE_FIELD(f, "float", kF);
338 VALIDATE_FIELD(b, "bool", kB);
339 VALIDATE_FIELD(type, "type", kType);
340 VALIDATE_FIELD(shape, "shape", kShape);
341 VALIDATE_FIELD(tensor, "tensor", kTensor);
342 VALIDATE_FIELD(func, "func", kFunc);
343
344#undef VALIDATE_FIELD
345
346 if (attr_value.value_case() == AttrValue::kPlaceholder) {
347 return errors::InvalidArgument(
348 "AttrValue had value with unexpected type 'placeholder'");
349 }
350
351 // If the attr type is 'list', we expect attr_value.has_list() to be
352 // true. However, proto3's attr_value.has_list() can be false when
353 // set to an empty list for GraphDef versions <= 4. So we simply
354 // check if has_list is false and some other field in attr_value is
355 // set to flag the error. This test can be made more strict once
356 // support for GraphDef versions <= 4 is dropped.
357 if (absl::StartsWith(type, "list(") && !attr_value.has_list()) {
358 if (num_set) {
359 return errors::InvalidArgument(
360 "AttrValue missing value with expected type '", type, "'");
361 } else {
362 // Indicate that we have a list, but an empty one.
363 ++num_set;
364 }
365 }
366
367 // Okay to have an empty list, but not to be missing a non-list value.
368 if (num_set == 0 && !absl::StartsWith(type, "list(")) {
369 return errors::InvalidArgument(
370 "AttrValue missing value with expected type '", type, "'");
371 }
372
373 // Ref types and DT_INVALID are illegal, and DataTypes must
374 // be a valid enum type.
375 if (type == "type") {
376 if (!DataType_IsValid(attr_value.type())) {
377 return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
378 attr_value.type());
379 }
380 if (IsRefType(attr_value.type())) {
381 return errors::InvalidArgument(
382 "AttrValue must not have reference type value of ",
383 DataTypeString(attr_value.type()));
384 }
385 if (attr_value.type() == DT_INVALID) {
386 return errors::InvalidArgument("AttrValue has invalid DataType");
387 }
388 } else if (type == "list(type)") {
389 for (auto as_int : attr_value.list().type()) {
390 const DataType dtype = static_cast<DataType>(as_int);
391 if (!DataType_IsValid(dtype)) {
392 return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
393 as_int);
394 }
395 if (IsRefType(dtype)) {
396 return errors::InvalidArgument(
397 "AttrValue must not have reference type value of ",
398 DataTypeString(dtype));
399 }
400 if (dtype == DT_INVALID) {
401 return errors::InvalidArgument("AttrValue contains invalid DataType");
402 }
403 }
404 }
405
406 return OkStatus();
407}
408
409bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
410 // Parse type.
411 string field_name;
412 bool is_list = absl::ConsumePrefix(&type, "list(");
413 if (absl::ConsumePrefix(&type, "string")) {
414 field_name = "s";
415 } else if (absl::ConsumePrefix(&type, "int")) {
416 field_name = "i";
417 } else if (absl::ConsumePrefix(&type, "float")) {
418 field_name = "f";
419 } else if (absl::ConsumePrefix(&type, "bool")) {
420 field_name = "b";
421 } else if (absl::ConsumePrefix(&type, "type")) {
422 field_name = "type";
423 } else if (absl::ConsumePrefix(&type, "shape")) {
424 field_name = "shape";
425 } else if (absl::ConsumePrefix(&type, "tensor")) {
426 field_name = "tensor";
427 } else if (absl::ConsumePrefix(&type, "func")) {
428 field_name = "func";
429 } else if (absl::ConsumePrefix(&type, "placeholder")) {
430 field_name = "placeholder";
431 } else {
432 return false;
433 }
434 if (is_list && !absl::ConsumePrefix(&type, ")")) {
435 return false;
436 }
437
438 // Construct a valid text proto message to parse.
439 string to_parse;
440 if (is_list) {
441 // TextFormat parser considers "i: 7" to be the same as "i: [7]",
442 // but we only want to allow list values with [].
443 StringPiece cleaned = text;
444 str_util::RemoveLeadingWhitespace(&cleaned);
445 str_util::RemoveTrailingWhitespace(&cleaned);
446 if (cleaned.size() < 2 || cleaned[0] != '[' ||
447 cleaned[cleaned.size() - 1] != ']') {
448 return false;
449 }
450 cleaned.remove_prefix(1);
451 str_util::RemoveLeadingWhitespace(&cleaned);
452 if (cleaned.size() == 1) {
453 // User wrote "[]", so return empty list without invoking the TextFormat
454 // parse which returns an error for "i: []".
455 out->Clear();
456 out->mutable_list();
457 return true;
458 }
459 to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
460 } else {
461 to_parse = strings::StrCat(field_name, ": ", text);
462 }
463 if (field_name == "tensor") {
464 if (!ParseAttrValueHelper_TensorNestsUnderLimit(kMaxTensorNestDepth,
465 to_parse)) {
466 return false;
467 }
468 }
469 return ProtoParseFromString(to_parse, out);
470}
471
472void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
473
474#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
475 void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
476
477#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \
478 void SetAttrValue(ARG_TYPE value, AttrValue* out) { \
479 out->mutable_list()->Clear(); /* create list() even if value empty */ \
480 for (const auto& v : value) { \
481 out->mutable_list()->add_##FIELD(v); \
482 } \
483 }
484
485#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
486 DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
487 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
488
489DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
490DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
491DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
492DEFINE_SET_ATTR_VALUE_BOTH(int64_t, i)
493DEFINE_SET_ATTR_VALUE_BOTH(int32_t, i)
494DEFINE_SET_ATTR_VALUE_BOTH(float, f)
495DEFINE_SET_ATTR_VALUE_BOTH(double, f)
496DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
497DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
498DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
499DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
500
501void SetAttrValue(const tstring& value, AttrValue* out) {
502 out->set_s(value.data(), value.size());
503}
504
505void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out) {
506 out->mutable_list()->Clear();
507 for (const auto& v : value) {
508 out->mutable_list()->add_s(v.data(), v.size());
509 }
510}
511
512void SetAttrValue(StringPiece value, AttrValue* out) {
513 out->set_s(value.data(), value.size());
514}
515
516void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
517 out->mutable_list()->Clear(); // Create list() even if value empty.
518 for (const auto& v : value) {
519 out->mutable_list()->add_s(v.data(), v.size());
520 }
521}
522
523void MoveAttrValue(std::vector<string>&& value, AttrValue* out) {
524 out->mutable_list()->Clear(); // Create list() even if value empty.
525 for (auto& v : value) {
526 out->mutable_list()->add_s(std::move(v));
527 }
528}
529
530void SetAttrValue(const TensorShape& value, AttrValue* out) {
531 value.AsProto(out->mutable_shape());
532}
533
534void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
535 *out->mutable_shape() = value;
536}
537
538void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
539 value.AsProto(out->mutable_shape());
540}
541
542void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
543 out->mutable_list()->Clear(); // Create list() even if value empty.
544 for (const auto& v : value) {
545 v.AsProto(out->mutable_list()->add_shape());
546 }
547}
548
549void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
550 out->mutable_list()->Clear(); // Create list() even if value empty.
551 for (const auto& v : value) {
552 *out->mutable_list()->add_shape() = v;
553 }
554}
555
556void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
557 AttrValue* out) {
558 out->mutable_list()->Clear(); // Create list() even if value empty.
559 for (const auto& v : value) {
560 v.AsProto(out->mutable_list()->add_shape());
561 }
562}
563
564void SetAttrValue(const Tensor& value, AttrValue* out) {
565 if (value.NumElements() > 1) {
566 value.AsProtoTensorContent(out->mutable_tensor());
567 } else {
568 value.AsProtoField(out->mutable_tensor());
569 }
570}
571
572void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
573 out->mutable_list()->Clear(); // Create list() even if value empty.
574 for (const auto& v : value) {
575 if (v.NumElements() > 1) {
576 v.AsProtoTensorContent(out->mutable_list()->add_tensor());
577 } else {
578 v.AsProtoField(out->mutable_list()->add_tensor());
579 }
580 }
581}
582
583void SetAttrValue(const TensorProto& value, AttrValue* out) {
584 *out->mutable_tensor() = value;
585}
586
587void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
588 out->mutable_list()->Clear(); // Create list() even if value empty.
589 for (const auto& v : value) {
590 *out->mutable_list()->add_tensor() = v;
591 }
592}
593
594void SetAttrValue(const NameAttrList& value, AttrValue* out) {
595 *out->mutable_func() = value;
596}
597
598void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
599 out->mutable_list()->Clear(); // Create list() even if value empty.
600 for (const auto& v : value) {
601 *out->mutable_list()->add_func() = v;
602 }
603}
604
605bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
606 bool allow_false_negatives) {
607 if (a.type() != b.type()) {
608 return false;
609 } else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {
610 return a.type() == b.type();
611 }
612
613 if (a.has_tensor() != b.has_tensor()) {
614 return false;
615 } else if (a.has_tensor() && b.has_tensor()) {
616 return AreTensorProtosEqual(a.tensor(), b.tensor(), allow_false_negatives);
617 }
618
619 // `func` field contains a nested AttrValue. Compare such AttrValues
620 // recursively.
621 if (a.has_func() != b.has_func()) {
622 return false;
623 } else if (a.has_func() && b.has_func()) {
624 const NameAttrList& af = a.func();
625 const NameAttrList& bf = b.func();
626 if (af.name() != bf.name()) return false;
627 std::unordered_map<string, AttrValue> am(af.attr().begin(),
628 af.attr().end());
629 for (const auto& bm_pair : bf.attr()) {
630 const auto& iter = am.find(bm_pair.first);
631 if (iter == am.end()) return false;
632 if (!AreAttrValuesEqual(iter->second, bm_pair.second,
633 allow_false_negatives))
634 return false;
635 am.erase(iter);
636 }
637 if (!am.empty()) return false;
638 return true;
639 }
640
641 // All other fields in AttrValue have deterministic representations.
642 // It is safe to compare their serialized strings.
643 return AreSerializedProtosEqual(a, b);
644}
645
646uint64 AttrValueHash(const AttrValue& a) {
647 return AttrValueHash(a, TensorProtoHash);
648}
649
650uint64 FastAttrValueHash(const AttrValue& a) {
651 return AttrValueHash(a, FastTensorProtoHash);
652}
653
654bool HasPlaceHolder(const AttrValue& val) {
655 switch (val.value_case()) {
656 case AttrValue::kList: {
657 for (const NameAttrList& func : val.list().func()) {
658 for (const auto& p : func.attr()) {
659 if (HasPlaceHolder(p.second)) {
660 return true;
661 }
662 }
663 }
664 break;
665 }
666 case AttrValue::kFunc:
667 for (const auto& p : val.func().attr()) {
668 if (HasPlaceHolder(p.second)) {
669 return true;
670 }
671 }
672 break;
673 case AttrValue::kPlaceholder:
674 return true;
675 default:
676 break;
677 }
678 return false;
679}
680
681bool SubstitutePlaceholders(const SubstituteFunc& substitute,
682 AttrValue* value) {
683 switch (value->value_case()) {
684 case AttrValue::kList: {
685 for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
686 for (auto& p : *func.mutable_attr()) {
687 if (!SubstitutePlaceholders(substitute, &p.second)) {
688 return false;
689 }
690 }
691 }
692 break;
693 }
694 case AttrValue::kFunc:
695 for (auto& p : *(value->mutable_func()->mutable_attr())) {
696 if (!SubstitutePlaceholders(substitute, &p.second)) {
697 return false;
698 }
699 }
700 break;
701 case AttrValue::kPlaceholder:
702 return substitute(value->placeholder(), value);
703 case AttrValue::VALUE_NOT_SET:
704 return false;
705 default:
706 break;
707 }
708 return true;
709}
710
711} // namespace tensorflow
712