1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/framework/attr_value_util.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/attr_value.pb_text.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/protobuf.h"
35 | |
36 | namespace tensorflow { |
37 | namespace { |
38 | |
// Largest expanded tensor size, in bytes, that this file is willing to
// materialize as a Tensor merely to hash it or to compare it for equality.
// Anything bigger is handled through its TensorProto representation.
constexpr int kMaxAttrValueTensorByteSize = 32 << 20;  // 32 MiB

// Maximum '{' / '<' nesting depth tolerated in the text of a tensor-valued
// attr before it is handed to the proto text parser, to prevent memory
// overflow while parsing.
constexpr int kMaxTensorNestDepth = 100;
44 | |
45 | // Return the size of the tensor represented by this TensorProto. If shape is |
46 | // not fully defined return -1. |
47 | int64_t TensorByteSize(const TensorProto& t) { |
48 | // num_elements returns -1 if shape is not fully defined. |
49 | int64_t num_elems = PartialTensorShape(t.tensor_shape()).num_elements(); |
50 | return num_elems < 0 ? -1 : num_elems * DataTypeSize(t.dtype()); |
51 | } |
52 | |
53 | // Compute TensorProto hash by creating a Tensor, serializing it as tensor |
54 | // content, and computing a hash of it's string representation. This is unsafe |
55 | // operation, because large tensors can be represented as TensorProto, but can't |
56 | // be serialized to tensor content. |
57 | uint64 TensorProtoHash(const TensorProto& tp) { |
58 | Tensor tensor(tp.dtype()); |
59 | bool success = tensor.FromProto(tp); |
60 | DCHECK(success); |
61 | TensorProto p; |
62 | tensor.AsProtoTensorContent(&p); |
63 | return DeterministicProtoHash64(p); |
64 | } |
65 | |
66 | // Do not create large tensors in memory, compute hash based on TensorProto |
67 | // string representation. Tensors with identical content potentially can have a |
68 | // different hash code if they are defined with different TensorProto |
69 | // representations. |
70 | uint64 FastTensorProtoHash(const TensorProto& tp) { |
71 | if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) { |
72 | return DeterministicProtoHash64(tp); |
73 | } else { |
74 | return TensorProtoHash(tp); |
75 | } |
76 | } |
77 | |
78 | bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs, |
79 | bool allow_false_negatives) { |
80 | // A small TensorProto can expand into a giant Tensor. So we avoid |
81 | // conversion to an actual Tensor if we can quickly rule out equality |
82 | // by comparing the Tensor size since different sized Tensors are definitely |
83 | // different. |
84 | const int64_t lhs_tensor_bytes = TensorByteSize(lhs); |
85 | const int64_t rhs_tensor_bytes = TensorByteSize(rhs); |
86 | if (lhs_tensor_bytes != rhs_tensor_bytes) { |
87 | return false; |
88 | } |
89 | |
90 | // If the TensorProto representation expands into a much bigger Tensor, |
91 | // we have a fast-path that first compares the protos. |
92 | const int64_t lhs_proto_bytes = lhs.ByteSizeLong(); |
93 | const bool large_expansion = |
94 | (lhs_proto_bytes < 512 && lhs_tensor_bytes > 4096); |
95 | |
96 | // If the tensor is very large, we'll only compare the proto representation if |
97 | // false negatives are allowed. This may miss some equivalent tensors whose |
98 | // actual tensor values are the same but which are described by different |
99 | // TensorProtos. This avoids construction of large protos in memory. |
100 | const bool only_compare_proto = |
101 | (allow_false_negatives && lhs_tensor_bytes > kMaxAttrValueTensorByteSize); |
102 | if (large_expansion || only_compare_proto) { |
103 | if (AreSerializedProtosEqual(lhs, rhs)) |
104 | return true; |
105 | else if (only_compare_proto) |
106 | return false; |
107 | } |
108 | |
109 | // Finally, compare them by constructing Tensors and serializing them back. |
110 | // There are multiple equivalent representations of attr values containing |
111 | // TensorProtos. Comparing Tensor objects is pretty tricky. This is unsafe |
112 | // operation, because large tensors can be represented as TensorProto, but |
113 | // can't be serialized to tensor content. |
114 | Tensor lhs_t(lhs.dtype()); |
115 | bool success = lhs_t.FromProto(lhs); |
116 | if (!success) { |
117 | return false; |
118 | } |
119 | |
120 | Tensor rhs_t(rhs.dtype()); |
121 | success = rhs_t.FromProto(rhs); |
122 | if (!success) { |
123 | return false; |
124 | } |
125 | |
126 | TensorProto lhs_tp; |
127 | lhs_t.AsProtoTensorContent(&lhs_tp); |
128 | |
129 | TensorProto rhs_tp; |
130 | rhs_t.AsProtoTensorContent(&rhs_tp); |
131 | |
132 | return AreSerializedProtosEqual(lhs_tp, rhs_tp); |
133 | } |
134 | |
135 | using TensorProtoHasher = std::function<uint64(const TensorProto&)>; |
136 | |
137 | uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) { |
138 | if (a.has_tensor()) return tensor_hash(a.tensor()); |
139 | |
140 | if (a.has_func()) { |
141 | const NameAttrList& func = a.func(); |
142 | uint64 h = Hash64(func.name()); |
143 | std::map<string, AttrValue> map(func.attr().begin(), func.attr().end()); |
144 | for (const auto& pair : map) { |
145 | h = Hash64(pair.first.data(), pair.first.size(), h); |
146 | h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h); |
147 | } |
148 | return h; |
149 | } |
150 | |
151 | // If `a` is not a tensor or func, get a hash of serialized string. |
152 | return DeterministicProtoHash64(a); |
153 | } |
154 | |
155 | string SummarizeString(const string& str) { |
156 | string escaped = absl::CEscape(str); |
157 | |
158 | // If the string is long, replace the middle with ellipses. |
159 | constexpr int kMaxStringSummarySize = 80; |
160 | if (escaped.size() >= kMaxStringSummarySize) { |
161 | StringPiece prefix(escaped); |
162 | StringPiece suffix = prefix; |
163 | prefix.remove_suffix(escaped.size() - 10); |
164 | suffix.remove_prefix(escaped.size() - 10); |
165 | return strings::StrCat("\"" , prefix, "..." , suffix, "\"" ); |
166 | } else { |
167 | return strings::StrCat("\"" , escaped, "\"" ); |
168 | } |
169 | } |
170 | |
171 | string SummarizeTensor(const TensorProto& tensor_proto) { |
172 | Tensor t; |
173 | if (!t.FromProto(tensor_proto)) { |
174 | return strings::StrCat( |
175 | "<Invalid TensorProto: " , tensor_proto.ShortDebugString(), ">" ); |
176 | } |
177 | return t.DebugString(); |
178 | } |
179 | |
180 | string SummarizeFunc(const NameAttrList& func) { |
181 | std::vector<string> entries; |
182 | for (const auto& p : func.attr()) { |
183 | entries.push_back( |
184 | strings::StrCat(p.first, "=" , SummarizeAttrValue(p.second))); |
185 | } |
186 | std::sort(entries.begin(), entries.end()); |
187 | return strings::StrCat(func.name(), "[" , absl::StrJoin(entries, ", " ), "]" ); |
188 | } |
189 | |
// Returns true if the bracket nesting depth of `to_parse` stays strictly below
// `limit`, i.e. if the text may safely be handed to the proto text parser.
//
// This is a lexical scan. Each '{' or '<' opens a level and each '}' or '>'
// closes one, in order of appearance. That includes characters inside quoted
// strings or comments, so openers there make the check more conservative.
// Closers never push the depth below zero, so a run of stray closers ahead of
// the real structure cannot offset later openers.
// NOTE(review): closers embedded in string literals *inside* the nested
// structure still lower the count; a fully precise check would need a
// string/comment-aware tokenizer.
bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit,
                                                const std::string& to_parse) {
  int depth = 0;
  for (const char c : to_parse) {
    switch (c) {
      case '{':
      case '<':
        if (++depth >= limit) return false;
        break;
      case '}':
      case '>':
        if (depth > 0) --depth;
        break;
      default:
        break;
    }
  }
  return true;
}
237 | |
238 | } // namespace |
239 | |
240 | string SummarizeAttrValue(const AttrValue& attr_value) { |
241 | switch (attr_value.value_case()) { |
242 | case AttrValue::kS: |
243 | return SummarizeString(attr_value.s()); |
244 | case AttrValue::kI: |
245 | return strings::StrCat(attr_value.i()); |
246 | case AttrValue::kF: |
247 | return strings::StrCat(attr_value.f()); |
248 | case AttrValue::kB: |
249 | return attr_value.b() ? "true" : "false" ; |
250 | case AttrValue::kType: |
251 | return EnumName_DataType(attr_value.type()); |
252 | case AttrValue::kShape: |
253 | return PartialTensorShape::DebugString(attr_value.shape()); |
254 | case AttrValue::kTensor: |
255 | return SummarizeTensor(attr_value.tensor()); |
256 | case AttrValue::kList: { |
257 | std::vector<string> pieces; |
258 | if (attr_value.list().s_size() > 0) { |
259 | for (int i = 0; i < attr_value.list().s_size(); ++i) { |
260 | pieces.push_back(SummarizeString(attr_value.list().s(i))); |
261 | } |
262 | } else if (attr_value.list().i_size() > 0) { |
263 | for (int i = 0; i < attr_value.list().i_size(); ++i) { |
264 | pieces.push_back(strings::StrCat(attr_value.list().i(i))); |
265 | } |
266 | } else if (attr_value.list().f_size() > 0) { |
267 | for (int i = 0; i < attr_value.list().f_size(); ++i) { |
268 | pieces.push_back(strings::StrCat(attr_value.list().f(i))); |
269 | } |
270 | } else if (attr_value.list().b_size() > 0) { |
271 | for (int i = 0; i < attr_value.list().b_size(); ++i) { |
272 | pieces.push_back(attr_value.list().b(i) ? "true" : "false" ); |
273 | } |
274 | } else if (attr_value.list().type_size() > 0) { |
275 | for (int i = 0; i < attr_value.list().type_size(); ++i) { |
276 | pieces.push_back(EnumName_DataType(attr_value.list().type(i))); |
277 | } |
278 | } else if (attr_value.list().shape_size() > 0) { |
279 | for (int i = 0; i < attr_value.list().shape_size(); ++i) { |
280 | pieces.push_back( |
281 | TensorShape::DebugString(attr_value.list().shape(i))); |
282 | } |
283 | } else if (attr_value.list().tensor_size() > 0) { |
284 | for (int i = 0; i < attr_value.list().tensor_size(); ++i) { |
285 | pieces.push_back(SummarizeTensor(attr_value.list().tensor(i))); |
286 | } |
287 | } else if (attr_value.list().func_size() > 0) { |
288 | for (int i = 0; i < attr_value.list().func_size(); ++i) { |
289 | pieces.push_back(SummarizeFunc(attr_value.list().func(i))); |
290 | } |
291 | } |
292 | constexpr int kMaxListSummarySize = 15; |
293 | if (pieces.size() >= kMaxListSummarySize) { |
294 | pieces[5] = strings::StrCat(Fingerprint64( |
295 | absl::StrJoin(pieces.begin() + 5, pieces.end() - 5, "," ))); |
296 | pieces.erase(pieces.begin() + 6, pieces.end() - 5); |
297 | } |
298 | return strings::StrCat("[" , absl::StrJoin(pieces, ", " ), "]" ); |
299 | } |
300 | case AttrValue::kFunc: { |
301 | return SummarizeFunc(attr_value.func()); |
302 | } |
303 | case AttrValue::kPlaceholder: |
304 | return strings::StrCat("$" , attr_value.placeholder()); |
305 | case AttrValue::VALUE_NOT_SET: |
306 | return "<Unknown AttrValue type>" ; |
307 | } |
308 | return "<Unknown AttrValue type>" ; // Prevent missing return warning |
309 | } |
310 | |
// Checks that `attr_value` holds a value compatible with the attr type string
// `type` (e.g. "int", "list(string)").
//
// Returns InvalidArgument if:
// - a value of a different type is set;
// - a placeholder is set;
// - a non-list `type` has no value set;
// - a "list(...)" `type` has no list but another value is counted as set;
// - for "type" / "list(type)", a DataType is not a valid enum value, is a
//   reference type, or is DT_INVALID.
// An empty (or absent) list is accepted for any "list(...)" type.
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
  // Number of value fields found populated and matching `type`.
  int num_set = 0;

// For a list value, a non-empty repeated field `name` must match
// "list(<type_string>)". Otherwise, a scalar oneof case equal to `oneof_case`
// must match `type_string`. On mismatch, returns InvalidArgument from
// AttrValueHasType directly; on match, increments num_set.
#define VALIDATE_FIELD(name, type_string, oneof_case)                         \
  do {                                                                        \
    if (attr_value.has_list()) {                                              \
      if (attr_value.list().name##_size() > 0) {                              \
        if (type != "list(" type_string ")") {                                \
          return errors::InvalidArgument(                                     \
              "AttrValue had value with type 'list(" type_string ")' when '", \
              type, "' expected");                                            \
        }                                                                     \
        ++num_set;                                                            \
      }                                                                       \
    } else if (attr_value.value_case() == AttrValue::oneof_case) {            \
      if (type != type_string) {                                              \
        return errors::InvalidArgument(                                       \
            "AttrValue had value with type '" type_string "' when '", type,   \
            "' expected");                                                    \
      }                                                                       \
      ++num_set;                                                              \
    }                                                                         \
  } while (false)

  VALIDATE_FIELD(s, "string", kS);
  VALIDATE_FIELD(i, "int", kI);
  VALIDATE_FIELD(f, "float", kF);
  VALIDATE_FIELD(b, "bool", kB);
  VALIDATE_FIELD(type, "type", kType);
  VALIDATE_FIELD(shape, "shape", kShape);
  VALIDATE_FIELD(tensor, "tensor", kTensor);
  VALIDATE_FIELD(func, "func", kFunc);

#undef VALIDATE_FIELD

  // Placeholders must be substituted before type checking.
  if (attr_value.value_case() == AttrValue::kPlaceholder) {
    return errors::InvalidArgument(
        "AttrValue had value with unexpected type 'placeholder'");
  }

  // If the attr type is 'list', we expect attr_value.has_list() to be
  // true. However, proto3's attr_value.has_list() can be false when
  // set to an empty list for GraphDef versions <= 4. So we simply
  // check if has_list is false and some other field in attr_value is
  // set to flag the error. This test can be made more strict once
  // support for GraphDef versions <= 4 is dropped.
  if (absl::StartsWith(type, "list(") && !attr_value.has_list()) {
    if (num_set) {
      return errors::InvalidArgument(
          "AttrValue missing value with expected type '", type, "'");
    } else {
      // Indicate that we have a list, but an empty one.
      ++num_set;
    }
  }

  // Okay to have an empty list, but not to be missing a non-list value.
  if (num_set == 0 && !absl::StartsWith(type, "list(")) {
    return errors::InvalidArgument(
        "AttrValue missing value with expected type '", type, "'");
  }

  // Ref types and DT_INVALID are illegal, and DataTypes must
  // be a valid enum type.
  if (type == "type") {
    if (!DataType_IsValid(attr_value.type())) {
      return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                     attr_value.type());
    }
    if (IsRefType(attr_value.type())) {
      return errors::InvalidArgument(
          "AttrValue must not have reference type value of ",
          DataTypeString(attr_value.type()));
    }
    if (attr_value.type() == DT_INVALID) {
      return errors::InvalidArgument("AttrValue has invalid DataType");
    }
  } else if (type == "list(type)") {
    // Repeated enum elements are raw ints; validate before casting.
    for (auto as_int : attr_value.list().type()) {
      const DataType dtype = static_cast<DataType>(as_int);
      if (!DataType_IsValid(dtype)) {
        return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                       as_int);
      }
      if (IsRefType(dtype)) {
        return errors::InvalidArgument(
            "AttrValue must not have reference type value of ",
            DataTypeString(dtype));
      }
      if (dtype == DT_INVALID) {
        return errors::InvalidArgument("AttrValue contains invalid DataType");
      }
    }
  }

  return OkStatus();
}
408 | |
409 | bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { |
410 | // Parse type. |
411 | string field_name; |
412 | bool is_list = absl::ConsumePrefix(&type, "list(" ); |
413 | if (absl::ConsumePrefix(&type, "string" )) { |
414 | field_name = "s" ; |
415 | } else if (absl::ConsumePrefix(&type, "int" )) { |
416 | field_name = "i" ; |
417 | } else if (absl::ConsumePrefix(&type, "float" )) { |
418 | field_name = "f" ; |
419 | } else if (absl::ConsumePrefix(&type, "bool" )) { |
420 | field_name = "b" ; |
421 | } else if (absl::ConsumePrefix(&type, "type" )) { |
422 | field_name = "type" ; |
423 | } else if (absl::ConsumePrefix(&type, "shape" )) { |
424 | field_name = "shape" ; |
425 | } else if (absl::ConsumePrefix(&type, "tensor" )) { |
426 | field_name = "tensor" ; |
427 | } else if (absl::ConsumePrefix(&type, "func" )) { |
428 | field_name = "func" ; |
429 | } else if (absl::ConsumePrefix(&type, "placeholder" )) { |
430 | field_name = "placeholder" ; |
431 | } else { |
432 | return false; |
433 | } |
434 | if (is_list && !absl::ConsumePrefix(&type, ")" )) { |
435 | return false; |
436 | } |
437 | |
438 | // Construct a valid text proto message to parse. |
439 | string to_parse; |
440 | if (is_list) { |
441 | // TextFormat parser considers "i: 7" to be the same as "i: [7]", |
442 | // but we only want to allow list values with []. |
443 | StringPiece cleaned = text; |
444 | str_util::RemoveLeadingWhitespace(&cleaned); |
445 | str_util::RemoveTrailingWhitespace(&cleaned); |
446 | if (cleaned.size() < 2 || cleaned[0] != '[' || |
447 | cleaned[cleaned.size() - 1] != ']') { |
448 | return false; |
449 | } |
450 | cleaned.remove_prefix(1); |
451 | str_util::RemoveLeadingWhitespace(&cleaned); |
452 | if (cleaned.size() == 1) { |
453 | // User wrote "[]", so return empty list without invoking the TextFormat |
454 | // parse which returns an error for "i: []". |
455 | out->Clear(); |
456 | out->mutable_list(); |
457 | return true; |
458 | } |
459 | to_parse = strings::StrCat("list { " , field_name, ": " , text, " }" ); |
460 | } else { |
461 | to_parse = strings::StrCat(field_name, ": " , text); |
462 | } |
463 | if (field_name == "tensor" ) { |
464 | if (!ParseAttrValueHelper_TensorNestsUnderLimit(kMaxTensorNestDepth, |
465 | to_parse)) { |
466 | return false; |
467 | } |
468 | } |
469 | return ProtoParseFromString(to_parse, out); |
470 | } |
471 | |
// Copies `value` into `*out`.
void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }

// Defines `void SetAttrValue(ARG_TYPE, AttrValue*)`, which stores `value` in
// the scalar field FIELD via out->set_FIELD(value).
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
  void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }

// Defines `void SetAttrValue(ARG_TYPE, AttrValue*)` for a container of values.
// It clears list(), which also creates it, so an empty container still
// produces an explicit empty list. It then appends each element to
// list().FIELD.
#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \
  void SetAttrValue(ARG_TYPE value, AttrValue* out) { \
    out->mutable_list()->Clear(); /* create list() even if value empty */ \
    for (const auto& v : value) { \
      out->mutable_list()->add_##FIELD(v); \
    } \
  }

// Defines both the scalar overload for ARG_TYPE and the
// gtl::ArraySlice<ARG_TYPE> list overload.
#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
  DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
  DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)

DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
DEFINE_SET_ATTR_VALUE_BOTH(int64_t, i)
DEFINE_SET_ATTR_VALUE_BOTH(int32_t, i)
DEFINE_SET_ATTR_VALUE_BOTH(float, f)
DEFINE_SET_ATTR_VALUE_BOTH(double, f)
DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
// Dedicated bool list overloads, presumably because std::vector<bool> (a
// packed proxy container) and initializer lists do not convert to
// gtl::ArraySlice<bool> — confirm.
DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
500 | |
501 | void SetAttrValue(const tstring& value, AttrValue* out) { |
502 | out->set_s(value.data(), value.size()); |
503 | } |
504 | |
505 | void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out) { |
506 | out->mutable_list()->Clear(); |
507 | for (const auto& v : value) { |
508 | out->mutable_list()->add_s(v.data(), v.size()); |
509 | } |
510 | } |
511 | |
512 | void SetAttrValue(StringPiece value, AttrValue* out) { |
513 | out->set_s(value.data(), value.size()); |
514 | } |
515 | |
516 | void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) { |
517 | out->mutable_list()->Clear(); // Create list() even if value empty. |
518 | for (const auto& v : value) { |
519 | out->mutable_list()->add_s(v.data(), v.size()); |
520 | } |
521 | } |
522 | |
523 | void MoveAttrValue(std::vector<string>&& value, AttrValue* out) { |
524 | out->mutable_list()->Clear(); // Create list() even if value empty. |
525 | for (auto& v : value) { |
526 | out->mutable_list()->add_s(std::move(v)); |
527 | } |
528 | } |
529 | |
530 | void SetAttrValue(const TensorShape& value, AttrValue* out) { |
531 | value.AsProto(out->mutable_shape()); |
532 | } |
533 | |
534 | void SetAttrValue(const TensorShapeProto& value, AttrValue* out) { |
535 | *out->mutable_shape() = value; |
536 | } |
537 | |
538 | void SetAttrValue(const PartialTensorShape& value, AttrValue* out) { |
539 | value.AsProto(out->mutable_shape()); |
540 | } |
541 | |
542 | void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) { |
543 | out->mutable_list()->Clear(); // Create list() even if value empty. |
544 | for (const auto& v : value) { |
545 | v.AsProto(out->mutable_list()->add_shape()); |
546 | } |
547 | } |
548 | |
549 | void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) { |
550 | out->mutable_list()->Clear(); // Create list() even if value empty. |
551 | for (const auto& v : value) { |
552 | *out->mutable_list()->add_shape() = v; |
553 | } |
554 | } |
555 | |
556 | void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value, |
557 | AttrValue* out) { |
558 | out->mutable_list()->Clear(); // Create list() even if value empty. |
559 | for (const auto& v : value) { |
560 | v.AsProto(out->mutable_list()->add_shape()); |
561 | } |
562 | } |
563 | |
564 | void SetAttrValue(const Tensor& value, AttrValue* out) { |
565 | if (value.NumElements() > 1) { |
566 | value.AsProtoTensorContent(out->mutable_tensor()); |
567 | } else { |
568 | value.AsProtoField(out->mutable_tensor()); |
569 | } |
570 | } |
571 | |
572 | void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) { |
573 | out->mutable_list()->Clear(); // Create list() even if value empty. |
574 | for (const auto& v : value) { |
575 | if (v.NumElements() > 1) { |
576 | v.AsProtoTensorContent(out->mutable_list()->add_tensor()); |
577 | } else { |
578 | v.AsProtoField(out->mutable_list()->add_tensor()); |
579 | } |
580 | } |
581 | } |
582 | |
583 | void SetAttrValue(const TensorProto& value, AttrValue* out) { |
584 | *out->mutable_tensor() = value; |
585 | } |
586 | |
587 | void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) { |
588 | out->mutable_list()->Clear(); // Create list() even if value empty. |
589 | for (const auto& v : value) { |
590 | *out->mutable_list()->add_tensor() = v; |
591 | } |
592 | } |
593 | |
594 | void SetAttrValue(const NameAttrList& value, AttrValue* out) { |
595 | *out->mutable_func() = value; |
596 | } |
597 | |
598 | void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) { |
599 | out->mutable_list()->Clear(); // Create list() even if value empty. |
600 | for (const auto& v : value) { |
601 | *out->mutable_list()->add_func() = v; |
602 | } |
603 | } |
604 | |
605 | bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b, |
606 | bool allow_false_negatives) { |
607 | if (a.type() != b.type()) { |
608 | return false; |
609 | } else if (a.type() != DT_INVALID && b.type() != DT_INVALID) { |
610 | return a.type() == b.type(); |
611 | } |
612 | |
613 | if (a.has_tensor() != b.has_tensor()) { |
614 | return false; |
615 | } else if (a.has_tensor() && b.has_tensor()) { |
616 | return AreTensorProtosEqual(a.tensor(), b.tensor(), allow_false_negatives); |
617 | } |
618 | |
619 | // `func` field contains a nested AttrValue. Compare such AttrValues |
620 | // recursively. |
621 | if (a.has_func() != b.has_func()) { |
622 | return false; |
623 | } else if (a.has_func() && b.has_func()) { |
624 | const NameAttrList& af = a.func(); |
625 | const NameAttrList& bf = b.func(); |
626 | if (af.name() != bf.name()) return false; |
627 | std::unordered_map<string, AttrValue> am(af.attr().begin(), |
628 | af.attr().end()); |
629 | for (const auto& bm_pair : bf.attr()) { |
630 | const auto& iter = am.find(bm_pair.first); |
631 | if (iter == am.end()) return false; |
632 | if (!AreAttrValuesEqual(iter->second, bm_pair.second, |
633 | allow_false_negatives)) |
634 | return false; |
635 | am.erase(iter); |
636 | } |
637 | if (!am.empty()) return false; |
638 | return true; |
639 | } |
640 | |
641 | // All other fields in AttrValue have deterministic representations. |
642 | // It is safe to compare their serialized strings. |
643 | return AreSerializedProtosEqual(a, b); |
644 | } |
645 | |
646 | uint64 AttrValueHash(const AttrValue& a) { |
647 | return AttrValueHash(a, TensorProtoHash); |
648 | } |
649 | |
650 | uint64 FastAttrValueHash(const AttrValue& a) { |
651 | return AttrValueHash(a, FastTensorProtoHash); |
652 | } |
653 | |
654 | bool HasPlaceHolder(const AttrValue& val) { |
655 | switch (val.value_case()) { |
656 | case AttrValue::kList: { |
657 | for (const NameAttrList& func : val.list().func()) { |
658 | for (const auto& p : func.attr()) { |
659 | if (HasPlaceHolder(p.second)) { |
660 | return true; |
661 | } |
662 | } |
663 | } |
664 | break; |
665 | } |
666 | case AttrValue::kFunc: |
667 | for (const auto& p : val.func().attr()) { |
668 | if (HasPlaceHolder(p.second)) { |
669 | return true; |
670 | } |
671 | } |
672 | break; |
673 | case AttrValue::kPlaceholder: |
674 | return true; |
675 | default: |
676 | break; |
677 | } |
678 | return false; |
679 | } |
680 | |
681 | bool SubstitutePlaceholders(const SubstituteFunc& substitute, |
682 | AttrValue* value) { |
683 | switch (value->value_case()) { |
684 | case AttrValue::kList: { |
685 | for (NameAttrList& func : *value->mutable_list()->mutable_func()) { |
686 | for (auto& p : *func.mutable_attr()) { |
687 | if (!SubstitutePlaceholders(substitute, &p.second)) { |
688 | return false; |
689 | } |
690 | } |
691 | } |
692 | break; |
693 | } |
694 | case AttrValue::kFunc: |
695 | for (auto& p : *(value->mutable_func()->mutable_attr())) { |
696 | if (!SubstitutePlaceholders(substitute, &p.second)) { |
697 | return false; |
698 | } |
699 | } |
700 | break; |
701 | case AttrValue::kPlaceholder: |
702 | return substitute(value->placeholder(), value); |
703 | case AttrValue::VALUE_NOT_SET: |
704 | return false; |
705 | default: |
706 | break; |
707 | } |
708 | return true; |
709 | } |
710 | |
711 | } // namespace tensorflow |
712 | |