1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op_def_util.h" |
17 | |
18 | #include <set> |
19 | #include <unordered_map> |
20 | #include <unordered_set> |
21 | |
22 | #include "tensorflow/core/framework/attr_value.pb.h" |
23 | #include "tensorflow/core/framework/attr_value_util.h" |
24 | #include "tensorflow/core/framework/op_def.pb.h" |
25 | #include "tensorflow/core/framework/types.h" |
26 | #include "tensorflow/core/lib/core/errors.h" |
27 | #include "tensorflow/core/lib/core/stringpiece.h" |
28 | #include "tensorflow/core/lib/gtl/map_util.h" |
29 | #include "tensorflow/core/lib/hash/hash.h" |
30 | #include "tensorflow/core/lib/strings/proto_serialization.h" |
31 | #include "tensorflow/core/lib/strings/scanner.h" |
32 | #include "tensorflow/core/lib/strings/str_util.h" |
33 | #include "tensorflow/core/lib/strings/strcat.h" |
34 | #include "tensorflow/core/platform/mutex.h" |
35 | #include "tensorflow/core/platform/protobuf.h" |
36 | #include "tensorflow/core/platform/types.h" |
37 | |
38 | namespace tensorflow { |
39 | namespace { // ------ Helper functions ------ |
40 | |
41 | bool HasAttrStyleType(const OpDef::ArgDef& arg) { |
42 | return arg.type() != DT_INVALID || !arg.type_attr().empty() || |
43 | !arg.type_list_attr().empty(); |
44 | } |
45 | |
46 | Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { |
47 | const AttrValue& allowed_values(attr.allowed_values()); |
48 | for (auto allowed : allowed_values.list().type()) { |
49 | if (dt == allowed) { |
50 | return OkStatus(); |
51 | } |
52 | } |
53 | string allowed_str; |
54 | for (int i = 0; i < allowed_values.list().type_size(); ++i) { |
55 | if (!allowed_str.empty()) { |
56 | strings::StrAppend(&allowed_str, ", " ); |
57 | } |
58 | strings::StrAppend(&allowed_str, |
59 | DataTypeString(allowed_values.list().type(i))); |
60 | } |
61 | return errors::InvalidArgument( |
62 | "Value for attr '" , attr.name(), "' of " , DataTypeString(dt), |
63 | " is not in the list of allowed values: " , allowed_str); |
64 | } |
65 | |
66 | Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { |
67 | const AttrValue& allowed_values(attr.allowed_values()); |
68 | for (const auto& allowed : allowed_values.list().s()) { |
69 | if (str == allowed) { |
70 | return OkStatus(); |
71 | } |
72 | } |
73 | string allowed_str; |
74 | for (const string& allowed : allowed_values.list().s()) { |
75 | if (!allowed_str.empty()) { |
76 | strings::StrAppend(&allowed_str, ", " ); |
77 | } |
78 | strings::StrAppend(&allowed_str, "\"" , allowed, "\"" ); |
79 | } |
80 | return errors::InvalidArgument( |
81 | "Value for attr '" , attr.name(), "' of \"" , str, |
82 | "\" is not in the list of allowed values: " , allowed_str); |
83 | } |
84 | |
85 | } // namespace |
86 | |
87 | // Requires: attr has already been validated. |
88 | Status ValidateAttrValue(const AttrValue& attr_value, |
89 | const OpDef::AttrDef& attr) { |
90 | // Is it a valid value? |
91 | TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()), |
92 | " for attr '" , attr.name(), "'" ); |
93 | |
94 | // Does the value satisfy the minimum constraint in the AttrDef? |
95 | if (attr.has_minimum()) { |
96 | if (attr.type() == "int" ) { |
97 | if (attr_value.i() < attr.minimum()) { |
98 | return errors::InvalidArgument( |
99 | "Value for attr '" , attr.name(), "' of " , attr_value.i(), |
100 | " must be at least minimum " , attr.minimum()); |
101 | } |
102 | } else { |
103 | int length = -1; |
104 | if (attr.type() == "list(string)" ) { |
105 | length = attr_value.list().s_size(); |
106 | } else if (attr.type() == "list(int)" ) { |
107 | length = attr_value.list().i_size(); |
108 | } else if (attr.type() == "list(float)" ) { |
109 | length = attr_value.list().f_size(); |
110 | } else if (attr.type() == "list(bool)" ) { |
111 | length = attr_value.list().b_size(); |
112 | } else if (attr.type() == "list(type)" ) { |
113 | length = attr_value.list().type_size(); |
114 | } else if (attr.type() == "list(shape)" ) { |
115 | length = attr_value.list().shape_size(); |
116 | } else if (attr.type() == "list(tensor)" ) { |
117 | length = attr_value.list().tensor_size(); |
118 | } else if (attr.type() == "list(func)" ) { |
119 | length = attr_value.list().func_size(); |
120 | } |
121 | if (length < attr.minimum()) { |
122 | return errors::InvalidArgument( |
123 | "Length for attr '" , attr.name(), "' of " , length, |
124 | " must be at least minimum " , attr.minimum()); |
125 | } |
126 | } |
127 | } |
128 | |
129 | // Does the value satisfy the allowed_value constraint in the AttrDef? |
130 | if (attr.has_allowed_values()) { |
131 | if (attr.type() == "type" ) { |
132 | TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr)); |
133 | } else if (attr.type() == "list(type)" ) { |
134 | for (int dt : attr_value.list().type()) { |
135 | TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr)); |
136 | } |
137 | } else if (attr.type() == "string" ) { |
138 | TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr)); |
139 | } else if (attr.type() == "list(string)" ) { |
140 | for (const string& str : attr_value.list().s()) { |
141 | TF_RETURN_IF_ERROR(AllowedStringValue(str, attr)); |
142 | } |
143 | } else { |
144 | return errors::Unimplemented( |
145 | "Support for allowed_values not implemented for type " , attr.type()); |
146 | } |
147 | } |
148 | return OkStatus(); |
149 | } |
150 | |
151 | const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { |
152 | for (int i = 0; i < op_def.attr_size(); ++i) { |
153 | if (op_def.attr(i).name() == name) { |
154 | return &op_def.attr(i); |
155 | } |
156 | } |
157 | return nullptr; |
158 | } |
159 | |
160 | OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { |
161 | for (int i = 0; i < op_def->attr_size(); ++i) { |
162 | if (op_def->attr(i).name() == name) { |
163 | return op_def->mutable_attr(i); |
164 | } |
165 | } |
166 | return nullptr; |
167 | } |
168 | |
169 | const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) { |
170 | for (int i = 0; i < op_def.input_arg_size(); ++i) { |
171 | if (op_def.input_arg(i).name() == name) { |
172 | return &op_def.input_arg(i); |
173 | } |
174 | } |
175 | return nullptr; |
176 | } |
177 | |
178 | const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { |
179 | for (int i = 0; i < api_def.in_arg_size(); ++i) { |
180 | if (api_def.in_arg(i).name() == name) { |
181 | return &api_def.in_arg(i); |
182 | } |
183 | } |
184 | return nullptr; |
185 | } |
186 | |
// When EXPR is false, returns from the enclosing function with
// errors::InvalidArgument(<__VA_ARGS__>, "; in OpDef: ", <op_def summary>).
// The expansion refers to a variable named `op_def`, which must be in scope at
// every use site. The do/while(false) wrapper makes each use a single
// statement.
#define VALIDATE(EXPR, ...)                                        \
  do {                                                             \
    if (!(EXPR)) {                                                 \
      return errors::InvalidArgument(                              \
          __VA_ARGS__, "; in OpDef: ", op_def.ShortDebugString()); \
    }                                                              \
  } while (false)
194 | |
// Validates one input or output ArgDef of `op_def`. Returns InvalidArgument
// (via VALIDATE) for the first violation, in this order:
//  * arg.name() must not already be in *names; it is inserted as a side
//    effect, so *names accumulates every name checked so far;
//  * some type specification must be present (HasAttrStyleType);
//  * with number_attr: the attr must exist, be "int" with a minimum >= 0,
//    type_list_attr must be empty, and exactly one of type/type_attr is set;
//    without number_attr: exactly one of type/type_attr/type_list_attr is set;
//  * a type_attr must name an existing "type" attr, and a type_list_attr an
//    existing "list(type)" attr;
//  * a fixed `type` must not be a reference dtype (is_ref is used instead).
// `output` only changes the wording of error messages.
static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
                          bool output, std::set<string>* names) {
  // Appended to most error messages to identify the offending arg.
  const string suffix = strings::StrCat(
      output ? " for output '" : " for input '", arg.name(), "'");
  VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
           "Duplicate name: ", arg.name());
  VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);

  if (!arg.number_attr().empty()) {
    // "N * type" form: the repeat count comes from an int attr.
    const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
    // VALIDATE returns on nullptr, so the dereferences below are safe.
    VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
             suffix);
    VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
             suffix, " has type ", attr->type(), " != int");
    VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
             suffix, " must have minimum");
    VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
             suffix, " must have minimum >= 0");
    VALIDATE(arg.type_list_attr().empty(),
             "Can't have both number_attr and type_list_attr", suffix);
    VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
                     (!arg.type_attr().empty() ? 1 : 0) ==
                 1,
             "Exactly one of type, type_attr must be set", suffix);
  } else {
    const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
                                (!arg.type_attr().empty() ? 1 : 0) +
                                (!arg.type_list_attr().empty() ? 1 : 0);
    VALIDATE(num_type_fields == 1,
             "Exactly one of type, type_attr, type_list_attr must be set",
             suffix);
  }

  if (!arg.type_attr().empty()) {
    const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
    VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
             suffix);
    VALIDATE(attr->type() == "type", "Attr '", attr->name(),
             "' used as type_attr", suffix, " has type ", attr->type(),
             " != type");
  } else if (!arg.type_list_attr().empty()) {
    const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
    VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
             suffix);
    VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
             "' used as type_list_attr", suffix, " has type ", attr->type(),
             " != list(type)");
  } else {
    // All argument types should be non-reference types at this point.
    // ArgDef.is_ref is set to true for reference arguments.
    VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
             DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
  }

  return OkStatus();
}
251 | |
252 | bool IsValidOpName(StringPiece sp) { |
253 | using ::tensorflow::strings::Scanner; |
254 | |
255 | Scanner scanner(sp); |
256 | scanner.One(Scanner::UPPERLETTER).Any(Scanner::LETTER_DIGIT_UNDERSCORE); |
257 | |
258 | while (true) { |
259 | if (!scanner.GetResult()) // Some error in previous iteration. |
260 | return false; |
261 | if (scanner.empty()) // No error, but nothing left, good. |
262 | return true; |
263 | |
264 | // Absorb another name/namespace, starting with a '>' |
265 | scanner.One(Scanner::RANGLE) |
266 | .One(Scanner::UPPERLETTER) |
267 | .Any(Scanner::LETTER_DIGIT_UNDERSCORE); |
268 | } |
269 | } |
270 | |
// Validates a whole OpDef and returns the first problem found (InvalidArgument
// via VALIDATE, or the error from a helper):
//  * names not starting with '_' must satisfy IsValidOpName();
//  * attr names must be unique and must not parse as a DataType name;
//  * attr types must be string/int/float/bool/type/shape/tensor/func,
//    optionally wrapped in "list(...)";
//  * `minimum` is only allowed on "int" and list types (non-negative for
//    lists), and must be 0 when has_minimum is false;
//  * allowed_values must have the list form of the attr's type;
//  * default_value must pass ValidateAttrValue();
//  * every input/output arg must pass ValidateArg(). Arg names share one
//    duplicate-detection set with the attr names.
Status ValidateOpDef(const OpDef& op_def) {
  // Names starting with '_' are exempt from the naming check.
  if (!absl::StartsWith(op_def.name(), "_")) {
    VALIDATE(IsValidOpName(op_def.name()), "Invalid name: ", op_def.name(),
             " (Did you use CamelCase?)");
  }

  std::set<string> names;  // for detecting duplicate names
  for (const auto& attr : op_def.attr()) {
    // Validate name
    VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
             "Duplicate name: ", attr.name());
    DataType dt;
    VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
             attr.name(), " that matches a data type");

    // Validate type: strip an optional "list(" prefix, then one recognized
    // base type name, then (for lists) the closing ")"; nothing may remain.
    StringPiece type(attr.type());
    bool is_list = absl::ConsumePrefix(&type, "list(");
    bool found = false;
    for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
                              "tensor", "func"}) {
      if (absl::ConsumePrefix(&type, valid)) {
        found = true;
        break;
      }
    }
    VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
             "'");
    if (is_list) {
      VALIDATE(absl::ConsumePrefix(&type, ")"),
               "'list(' is missing ')' in attr ", attr.name(), "'s type ",
               attr.type());
    }
    VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
             attr.name(), "'s type ", attr.type());

    // Validate minimum
    if (attr.has_minimum()) {
      VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
               "' has minimum for unsupported type ", attr.type());
      if (is_list) {
        VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
                 "' with list type must have a non-negative minimum, not ",
                 attr.minimum());
      }
    } else {
      VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
               "' with has_minimum = false but minimum ", attr.minimum(),
               " not equal to default of 0");
    }

    // Validate allowed_values: they must form a list of the attr's element
    // type (e.g. "list(int)" for an "int" attr).
    if (attr.has_allowed_values()) {
      const string list_type =
          is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
          attr.name(), "' in Op '", op_def.name(), "'");
    }

    // Validate default_value (after we have validated the rest of the attr,
    // so we can use ValidateAttrValue()).
    if (attr.has_default_value()) {
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ValidateAttrValue(attr.default_value(), attr), " in Op '",
          op_def.name(), "'");
    }
  }

  // Args are checked after all attrs, so ValidateArg() can look attrs up.
  for (const auto& arg : op_def.input_arg()) {
    TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
  }

  for (const auto& arg : op_def.output_arg()) {
    TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
  }

  return OkStatus();
}

// End of the op_def-based VALIDATE macro's scope.
#undef VALIDATE
352 | |
353 | Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) { |
354 | if (op_def.has_deprecation()) { |
355 | const OpDeprecation& dep = op_def.deprecation(); |
356 | if (graph_def_version >= dep.version()) { |
357 | return errors::Unimplemented( |
358 | "Op " , op_def.name(), " is not available in GraphDef version " , |
359 | graph_def_version, ". It has been removed in version " , dep.version(), |
360 | ". " , dep.explanation(), "." ); |
361 | } else { |
362 | // Warn only once for each op name, and do it in a threadsafe manner. |
363 | static mutex mu(LINKER_INITIALIZED); |
364 | static std::unordered_set<string> warned; |
365 | bool warn; |
366 | { |
367 | mutex_lock lock(mu); |
368 | warn = warned.insert(op_def.name()).second; |
369 | } |
370 | if (warn) { |
371 | LOG(WARNING) << "Op " << op_def.name() << " is deprecated." |
372 | << " It will cease to work in GraphDef version " |
373 | << dep.version() << ". " << dep.explanation() << "." ; |
374 | } |
375 | } |
376 | } |
377 | return OkStatus(); |
378 | } |
379 | |
380 | namespace { |
381 | |
382 | string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { |
383 | string ret; |
384 | for (const OpDef::ArgDef& arg : args) { |
385 | if (!ret.empty()) strings::StrAppend(&ret, ", " ); |
386 | strings::StrAppend(&ret, arg.name(), ":" ); |
387 | if (arg.is_ref()) strings::StrAppend(&ret, "Ref(" ); |
388 | if (!arg.number_attr().empty()) { |
389 | strings::StrAppend(&ret, arg.number_attr(), "*" ); |
390 | } |
391 | if (arg.type() != DT_INVALID) { |
392 | strings::StrAppend(&ret, DataTypeString(arg.type())); |
393 | } else { |
394 | strings::StrAppend(&ret, arg.type_attr()); |
395 | } |
396 | if (arg.is_ref()) strings::StrAppend(&ret, ")" ); |
397 | } |
398 | return ret; |
399 | } |
400 | |
401 | } // namespace |
402 | |
403 | string SummarizeOpDef(const OpDef& op_def) { |
404 | string ret = strings::StrCat("Op<name=" , op_def.name()); |
405 | strings::StrAppend(&ret, "; signature=" , SummarizeArgs(op_def.input_arg()), |
406 | " -> " , SummarizeArgs(op_def.output_arg())); |
407 | for (int i = 0; i < op_def.attr_size(); ++i) { |
408 | strings::StrAppend(&ret, "; attr=" , op_def.attr(i).name(), ":" , |
409 | op_def.attr(i).type()); |
410 | if (op_def.attr(i).has_default_value()) { |
411 | strings::StrAppend(&ret, ",default=" , |
412 | SummarizeAttrValue(op_def.attr(i).default_value())); |
413 | } |
414 | if (op_def.attr(i).has_minimum()) { |
415 | strings::StrAppend(&ret, ",min=" , op_def.attr(i).minimum()); |
416 | } |
417 | if (op_def.attr(i).has_allowed_values()) { |
418 | strings::StrAppend(&ret, ",allowed=" , |
419 | SummarizeAttrValue(op_def.attr(i).allowed_values())); |
420 | } |
421 | } |
422 | if (op_def.is_commutative()) { |
423 | strings::StrAppend(&ret, "; is_commutative=true" ); |
424 | } |
425 | if (op_def.is_aggregate()) { |
426 | strings::StrAppend(&ret, "; is_aggregate=true" ); |
427 | } |
428 | if (op_def.is_stateful()) { |
429 | strings::StrAppend(&ret, "; is_stateful=true" ); |
430 | } |
431 | if (op_def.allows_uninitialized_input()) { |
432 | strings::StrAppend(&ret, "; allows_uninitialized_input=true" ); |
433 | } |
434 | if (op_def.is_distributed_communication()) { |
435 | strings::StrAppend(&ret, "; is_distributed_communication=true" ); |
436 | } |
437 | strings::StrAppend(&ret, ">" ); |
438 | return ret; |
439 | } |
440 | |
441 | namespace { |
442 | |
// Returns true if every element of `sub` is contained in `super` (compared
// with ==). Uses a linear scan of `super` per element of `sub`.
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  const auto contains = [&super](const auto& needle) {
    for (const auto& candidate : super) {
      if (needle == candidate) return true;
    }
    return false;
  };
  for (const auto& element : sub) {
    if (!contains(element)) return false;
  }
  return true;
}
458 | |
459 | bool MoreRestrictive(const OpDef::AttrDef& old_attr, |
460 | const OpDef::AttrDef& new_attr) { |
461 | // Anything -> no restriction : not more restrictive. |
462 | if (!new_attr.has_allowed_values()) return false; |
463 | // No restriction -> restriction : more restrictive. |
464 | if (!old_attr.has_allowed_values()) return true; |
465 | // If anything that was previously allowed is no longer allowed: |
466 | // more restrictive. |
467 | if (!IsSubsetOf(old_attr.allowed_values().list().type(), |
468 | new_attr.allowed_values().list().type())) { |
469 | return true; |
470 | } |
471 | if (!IsSubsetOf(old_attr.allowed_values().list().s(), |
472 | new_attr.allowed_values().list().s())) { |
473 | return true; |
474 | } |
475 | return false; |
476 | } |
477 | |
478 | string AllowedStr(const OpDef::AttrDef& attr) { |
479 | if (!attr.has_allowed_values()) return "no restriction" ; |
480 | return SummarizeAttrValue(attr.allowed_values()); |
481 | } |
482 | |
483 | string DefaultAttrStr(const OpDef::AttrDef& attr) { |
484 | if (!attr.has_default_value()) return "no default" ; |
485 | return SummarizeAttrValue(attr.default_value()); |
486 | } |
487 | |
488 | bool HigherMinimum(const OpDef::AttrDef& old_attr, |
489 | const OpDef::AttrDef& new_attr) { |
490 | // Anything -> no restriction : not more restrictive. |
491 | if (!new_attr.has_minimum()) return false; |
492 | // No restriction -> restriction : more restrictive. |
493 | if (!old_attr.has_minimum()) return true; |
494 | // If anything that was previously allowed is no longer allowed: |
495 | // more restrictive. |
496 | return new_attr.minimum() > old_attr.minimum(); |
497 | } |
498 | |
499 | string MinStr(const OpDef::AttrDef& attr) { |
500 | if (!attr.has_minimum()) return "no minimum" ; |
501 | return strings::StrCat(attr.minimum()); |
502 | } |
503 | |
504 | typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap; |
505 | void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) { |
506 | for (const auto& attr : op_def.attr()) { |
507 | (*attr_map)[attr.name()] = &attr; |
508 | } |
509 | } |
510 | |
511 | // Add a comma to *s every call but the first (*add_comma should be |
512 | // initialized to false). |
513 | void AddComma(string* s, bool* add_comma) { |
514 | if (*add_comma) { |
515 | strings::StrAppend(s, ", " ); |
516 | } else { |
517 | *add_comma = true; |
518 | } |
519 | } |
520 | |
521 | // Will add the `name` from arg if name is true. |
522 | void AddName(string* s, bool name, const OpDef::ArgDef& arg) { |
523 | if (name) { |
524 | strings::StrAppend(s, arg.name(), ":" ); |
525 | } |
526 | } |
527 | |
// Compute a signature for either inputs or outputs that will be the
// same for both the old and new OpDef if they are compatible. We
// assume that new_attrs is a superset of old_attrs, and that any attr
// in the difference has a default. Our strategy is to make a list of
// types, where the types are things like:
//  * "int32", "float", etc.,
//  * "T" for some attr "T" in old_attrs, or
//  * "N * type" for some attr "N" in old_attrs.
//
// We get the types by either using the attrs in args if they are in
// old_attrs, or substituting the default value from new_attrs.
//
// Entries are joined with ", ". When `names` is true each entry is prefixed
// with "<arg name>:". For every entry appended to the signature, the arg's
// is_ref() flag is pushed onto *ref, so *ref lines up with the entries.
//
// NOTE(review): `new_attr` is dereferenced without a null check. This relies
// on every attr referenced by `args` being in old_attrs or new_attrs, which
// holds for ops that passed ValidateOpDef() but is not checked here.
string ComputeArgSignature(
    const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
    const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref,
    bool names) {
  string s;
  bool add_comma = false;
  for (const OpDef::ArgDef& arg : args) {
    if (!arg.type_list_attr().empty()) {
      const OpDef::AttrDef* old_attr =
          gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
      if (old_attr) {
        // Both old and new have the list(type) attr, so can use it directly.
        AddComma(&s, &add_comma);
        AddName(&s, names, arg);
        strings::StrAppend(&s, arg.type_list_attr());
        ref->push_back(arg.is_ref());
      } else {
        // Missing the list(type) attr in the old, so use the default
        // value for the attr from new instead: one entry per default dtype.
        const OpDef::AttrDef* new_attr =
            gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
        const auto& type_list = new_attr->default_value().list().type();
        // An empty default list contributes no entries.
        if (type_list.empty()) continue;
        for (int i = 0; i < type_list.size(); ++i) {
          AddComma(&s, &add_comma);
          AddName(&s, names, arg);
          strings::StrAppend(
              &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
          ref->push_back(arg.is_ref());
        }
      }
    } else {
      int num = 1;  // How many input/outputs does this represent?
      string type;  // What is the type of this arg?
      AddName(&type, names, arg);
      if (!arg.number_attr().empty()) {
        // N * type case.
        const OpDef::AttrDef* old_attr =
            gtl::FindPtrOrNull(old_attrs, arg.number_attr());
        if (old_attr) {
          // Both old and new have the number attr, so can use it directly.
          strings::StrAppend(&type, arg.number_attr(), " * ");
        } else {
          // Missing the number attr in the old, so use the default
          // value for the attr from new instead; the arg is expanded into
          // that many separate entries below.
          const OpDef::AttrDef* new_attr =
              gtl::FindPtrOrNull(new_attrs, arg.number_attr());
          num = new_attr->default_value().i();
        }
      }

      if (arg.type() != DT_INVALID) {
        // int32, float, etc. case
        strings::StrAppend(&type, DataTypeString(arg.type()));
      } else {
        const OpDef::AttrDef* old_attr =
            gtl::FindPtrOrNull(old_attrs, arg.type_attr());
        if (old_attr) {
          // Both old and new have the type attr, so can use it directly.
          strings::StrAppend(&type, arg.type_attr());
        } else {
          // Missing the type attr in the old, so use the default
          // value for the attr from new instead.
          const OpDef::AttrDef* new_attr =
              gtl::FindPtrOrNull(new_attrs, arg.type_attr());
          strings::StrAppend(&type,
                             DataTypeString(new_attr->default_value().type()));
        }
      }

      // Record `num` * `type` in the signature.
      for (int i = 0; i < num; ++i) {
        AddComma(&s, &add_comma);
        strings::StrAppend(&s, type);
        ref->push_back(arg.is_ref());
      }
    }
  }

  return s;
}
620 | |
621 | } // namespace |
622 | |
623 | Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { |
624 | #define VALIDATE(CONDITION, ...) \ |
625 | if (!(CONDITION)) { \ |
626 | return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \ |
627 | "; old: ", SummarizeOpDef(old_op), \ |
628 | "; new: ", SummarizeOpDef(new_op)); \ |
629 | } |
630 | |
631 | VALIDATE(old_op.name() == new_op.name(), "Name mismatch" ); |
632 | |
633 | AttrMap new_attrs, old_attrs; |
634 | FillAttrMap(old_op, &old_attrs); |
635 | FillAttrMap(new_op, &new_attrs); |
636 | for (const auto& old_attr : old_op.attr()) { |
637 | const OpDef::AttrDef* new_attr = |
638 | gtl::FindPtrOrNull(new_attrs, old_attr.name()); |
639 | VALIDATE(new_attr != nullptr, "Attr '" , old_attr.name(), "' removed" ); |
640 | VALIDATE(old_attr.type() == new_attr->type(), "Attr '" , old_attr.name(), |
641 | "' changed type '" , old_attr.type(), "' -> '" , new_attr->type(), |
642 | "'" ); |
643 | VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '" , old_attr.name(), |
644 | "' has a stricter set of allowed values; from " , |
645 | AllowedStr(old_attr), " to " , AllowedStr(*new_attr)); |
646 | VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '" , old_attr.name(), |
647 | "' has a higher minimum; from " , MinStr(old_attr), " to " , |
648 | MinStr(*new_attr)); |
649 | } |
650 | |
651 | for (const auto& new_attr : new_op.attr()) { |
652 | const OpDef::AttrDef* old_attr = |
653 | gtl::FindPtrOrNull(old_attrs, new_attr.name()); |
654 | VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '" , |
655 | new_attr.name(), "' added without default" ); |
656 | } |
657 | |
658 | std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref; |
659 | const string old_in_sig = ComputeArgSignature( |
660 | old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */); |
661 | const string new_in_sig = ComputeArgSignature( |
662 | new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */); |
663 | VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '" , old_in_sig, |
664 | "' vs. '" , new_in_sig, "'" ); |
665 | VALIDATE(old_in_ref.size() == new_in_ref.size(), // Should not happen |
666 | "Unexpected change in input ref lists." ); |
667 | for (int i = 0, end = old_in_ref.size(); i < end; ++i) { |
668 | // Allowed to remove "ref" from an input (or leave it unchanged). |
669 | VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input " , i, |
670 | " changed from non-ref to ref" ); |
671 | } |
672 | |
673 | const string old_out_sig = |
674 | ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs, |
675 | &old_out_ref, true /* names */); |
676 | const string new_out_sig = |
677 | ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs, |
678 | &new_out_ref, true /* names */); |
679 | VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '" , |
680 | old_out_sig, "' vs. '" , new_out_sig, "'" ); |
681 | VALIDATE(old_out_ref.size() == new_out_ref.size(), // Should not happen |
682 | "Unexpected change in output ref lists" ); |
683 | for (int i = 0, end = old_out_ref.size(); i < end; ++i) { |
684 | // Allowed to add "ref" to an output (or leave it unchanged). |
685 | VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output " , i, |
686 | " changed from ref to non-ref" ); |
687 | } |
688 | |
689 | return OkStatus(); |
690 | } |
691 | |
692 | Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, |
693 | const OpDef& penultimate_op, |
694 | const OpDef& new_op) { |
695 | AttrMap new_attrs, old_attrs; |
696 | FillAttrMap(old_op, &old_attrs); |
697 | FillAttrMap(new_op, &new_attrs); |
698 | |
699 | for (const auto& penultimate_attr : penultimate_op.attr()) { |
700 | const OpDef::AttrDef* old_attr = |
701 | gtl::FindPtrOrNull(old_attrs, penultimate_attr.name()); |
702 | if (old_attr != nullptr) continue; // attr wasn't added |
703 | const OpDef::AttrDef* new_attr = |
704 | gtl::FindPtrOrNull(new_attrs, penultimate_attr.name()); |
705 | |
706 | // These shouldn't happen if the op passed OpDefCompatible(). |
707 | if (new_attr == nullptr) { |
708 | return errors::InvalidArgument("Missing attr '" , penultimate_attr.name(), |
709 | "' in op: " , SummarizeOpDef(new_op)); |
710 | } |
711 | if (!penultimate_attr.has_default_value() || |
712 | !new_attr->has_default_value()) { |
713 | return errors::InvalidArgument("Missing default for attr '" , |
714 | penultimate_attr.name(), |
715 | "' in op: " , SummarizeOpDef(new_op)); |
716 | } |
717 | |
718 | // Actually test that the attr's default value hasn't changed. |
719 | if (!AreAttrValuesEqual(penultimate_attr.default_value(), |
720 | new_attr->default_value())) { |
721 | return errors::InvalidArgument( |
722 | "Can't change default value for attr '" , penultimate_attr.name(), |
723 | "' from " , SummarizeAttrValue(penultimate_attr.default_value()), |
724 | " in op: " , SummarizeOpDef(new_op)); |
725 | } |
726 | } |
727 | |
728 | return OkStatus(); |
729 | } |
730 | |
731 | Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) { |
732 | AttrMap new_attrs, old_attrs; |
733 | FillAttrMap(old_op, &old_attrs); |
734 | FillAttrMap(new_op, &new_attrs); |
735 | |
736 | for (const auto& old_attr : old_op.attr()) { |
737 | const OpDef::AttrDef* new_attr = |
738 | gtl::FindPtrOrNull(new_attrs, old_attr.name()); |
739 | if (new_attr == nullptr) continue; |
740 | if (new_attr->has_default_value() && !old_attr.has_default_value()) { |
741 | continue; // Adding new default values is safe. |
742 | } |
743 | if (old_attr.has_default_value() && !new_attr->has_default_value()) { |
744 | return errors::InvalidArgument( |
745 | "Attr '" , old_attr.name(), "' has removed it's default; " , "from " , |
746 | DefaultAttrStr(old_attr), " to " , DefaultAttrStr(*new_attr)); |
747 | } |
748 | if (old_attr.has_default_value() && |
749 | !AreAttrValuesEqual(old_attr.default_value(), |
750 | new_attr->default_value())) { |
751 | return errors::InvalidArgument( |
752 | "Attr '" , old_attr.name(), "' has changed it's default value; " , |
753 | "from " , DefaultAttrStr(old_attr), " to " , DefaultAttrStr(*new_attr)); |
754 | } |
755 | } |
756 | |
757 | return OkStatus(); |
758 | } |
759 | |
760 | void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) { |
761 | for (int i = 0; i < op_def->input_arg_size(); ++i) { |
762 | op_def->mutable_input_arg(i)->clear_description(); |
763 | } |
764 | for (int i = 0; i < op_def->output_arg_size(); ++i) { |
765 | op_def->mutable_output_arg(i)->clear_description(); |
766 | } |
767 | for (int i = 0; i < op_def->attr_size(); ++i) { |
768 | op_def->mutable_attr(i)->clear_description(); |
769 | } |
770 | op_def->clear_summary(); |
771 | op_def->clear_description(); |
772 | } |
773 | |
774 | void RemoveDescriptionsFromOpDef(OpDef* op_def) { |
775 | RemoveNonDeprecationDescriptionsFromOpDef(op_def); |
776 | if (op_def->has_deprecation()) { |
777 | op_def->mutable_deprecation()->clear_explanation(); |
778 | } |
779 | } |
780 | |
781 | void RemoveDescriptionsFromOpList(OpList* op_list) { |
782 | for (int i = 0; i < op_list->op_size(); ++i) { |
783 | OpDef* op_def = op_list->mutable_op(i); |
784 | RemoveDescriptionsFromOpDef(op_def); |
785 | } |
786 | } |
787 | |
788 | bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) { |
789 | if (std::is_base_of<protobuf::Message, OpDef::AttrDef>()) { |
790 | DCHECK_EQ(7, reinterpret_cast<const protobuf::Message*>(&a1) |
791 | ->GetDescriptor() |
792 | ->field_count()) |
793 | << "Please modify these equality and hash functions to reflect the " |
794 | "changes to the AttrDef protobuf" ; |
795 | } |
796 | |
797 | if (a1.name() != a2.name()) return false; |
798 | if (a1.type() != a2.type()) return false; |
799 | if (a1.description() != a2.description()) return false; |
800 | if (a1.has_minimum() != a2.has_minimum()) return false; |
801 | if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false; |
802 | if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false; |
803 | if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values())) |
804 | return false; |
805 | return true; |
806 | } |
807 | |
808 | uint64 AttrDefHash(const OpDef::AttrDef& a) { |
809 | uint64 h = Hash64(a.name()); |
810 | h = Hash64(a.type().data(), a.type().size(), h); |
811 | h = Hash64Combine(AttrValueHash(a.default_value()), h); |
812 | h = Hash64(a.description().data(), a.description().size(), h); |
813 | h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h); |
814 | h = Hash64Combine(static_cast<uint64>(a.minimum()), h); |
815 | h = Hash64Combine(AttrValueHash(a.allowed_values()), h); |
816 | return h; |
817 | } |
818 | |
819 | bool RepeatedAttrDefEqual( |
820 | const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1, |
821 | const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) { |
822 | std::unordered_map<string, const OpDef::AttrDef*> a1_set; |
823 | for (const OpDef::AttrDef& def : a1) { |
824 | if (a1_set.find(def.name()) != a1_set.end()) { |
825 | LOG(ERROR) << "AttrDef names must be unique, but '" << def.name() |
826 | << "' appears more than once" ; |
827 | } |
828 | a1_set[def.name()] = &def; |
829 | } |
830 | for (const OpDef::AttrDef& def : a2) { |
831 | auto iter = a1_set.find(def.name()); |
832 | if (iter == a1_set.end()) return false; |
833 | if (!AttrDefEqual(*iter->second, def)) return false; |
834 | a1_set.erase(iter); |
835 | } |
836 | if (!a1_set.empty()) return false; |
837 | return true; |
838 | } |
839 | |
840 | uint64 RepeatedAttrDefHash( |
841 | const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) { |
842 | // Insert AttrDefs into map to deterministically sort by name |
843 | std::map<string, const OpDef::AttrDef*> a_set; |
844 | for (const OpDef::AttrDef& def : a) { |
845 | a_set[def.name()] = &def; |
846 | } |
847 | // Iterate and combines hashes of keys and values |
848 | uint64 h = 0xDECAFCAFFE; |
849 | for (const auto& pair : a_set) { |
850 | h = Hash64(pair.first.data(), pair.first.size(), h); |
851 | h = Hash64Combine(AttrDefHash(*pair.second), h); |
852 | } |
853 | return h; |
854 | } |
855 | |
856 | bool OpDefEqual(const OpDef& o1, const OpDef& o2) { |
857 | // attr order doesn't matter. |
858 | // Compare it separately here instead of serializing below. |
859 | if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false; |
860 | |
861 | // `control_output` order doesn't matter. |
862 | std::set<string> control_output1(o1.control_output().begin(), |
863 | o1.control_output().end()); |
864 | std::set<string> control_output2(o2.control_output().begin(), |
865 | o2.control_output().end()); |
866 | if (control_output1 != control_output2) return false; |
867 | |
868 | // Clear `attr` and `control_output` fields, serialize, and compare serialized |
869 | // strings. |
870 | OpDef o1_copy = o1; |
871 | OpDef o2_copy = o2; |
872 | o1_copy.clear_attr(); |
873 | o1_copy.clear_control_output(); |
874 | o2_copy.clear_attr(); |
875 | o2_copy.clear_control_output(); |
876 | |
877 | return AreSerializedProtosEqual(o1_copy, o2_copy); |
878 | } |
879 | |
880 | uint64 OpDefHash(const OpDef& o) { |
881 | uint64 h = RepeatedAttrDefHash(o.attr()); |
882 | |
883 | // Compute deterministic order-independent control outputs hash. |
884 | std::set<string> control_output(o.control_output().begin(), |
885 | o.control_output().end()); |
886 | for (const auto& co : control_output) h = Hash64Combine(h, Hash64(co)); |
887 | |
888 | OpDef o_copy = o; |
889 | o_copy.clear_attr(); |
890 | o_copy.clear_control_output(); |
891 | return DeterministicProtoHash64(o_copy, h); |
892 | } |
893 | |
894 | } // namespace tensorflow |
895 | |