1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/op_def_util.h"
17
#include <map>
#include <set>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
21
22#include "tensorflow/core/framework/attr_value.pb.h"
23#include "tensorflow/core/framework/attr_value_util.h"
24#include "tensorflow/core/framework/op_def.pb.h"
25#include "tensorflow/core/framework/types.h"
26#include "tensorflow/core/lib/core/errors.h"
27#include "tensorflow/core/lib/core/stringpiece.h"
28#include "tensorflow/core/lib/gtl/map_util.h"
29#include "tensorflow/core/lib/hash/hash.h"
30#include "tensorflow/core/lib/strings/proto_serialization.h"
31#include "tensorflow/core/lib/strings/scanner.h"
32#include "tensorflow/core/lib/strings/str_util.h"
33#include "tensorflow/core/lib/strings/strcat.h"
34#include "tensorflow/core/platform/mutex.h"
35#include "tensorflow/core/platform/protobuf.h"
36#include "tensorflow/core/platform/types.h"
37
38namespace tensorflow {
39namespace { // ------ Helper functions ------
40
41bool HasAttrStyleType(const OpDef::ArgDef& arg) {
42 return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
43 !arg.type_list_attr().empty();
44}
45
46Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
47 const AttrValue& allowed_values(attr.allowed_values());
48 for (auto allowed : allowed_values.list().type()) {
49 if (dt == allowed) {
50 return OkStatus();
51 }
52 }
53 string allowed_str;
54 for (int i = 0; i < allowed_values.list().type_size(); ++i) {
55 if (!allowed_str.empty()) {
56 strings::StrAppend(&allowed_str, ", ");
57 }
58 strings::StrAppend(&allowed_str,
59 DataTypeString(allowed_values.list().type(i)));
60 }
61 return errors::InvalidArgument(
62 "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
63 " is not in the list of allowed values: ", allowed_str);
64}
65
66Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
67 const AttrValue& allowed_values(attr.allowed_values());
68 for (const auto& allowed : allowed_values.list().s()) {
69 if (str == allowed) {
70 return OkStatus();
71 }
72 }
73 string allowed_str;
74 for (const string& allowed : allowed_values.list().s()) {
75 if (!allowed_str.empty()) {
76 strings::StrAppend(&allowed_str, ", ");
77 }
78 strings::StrAppend(&allowed_str, "\"", allowed, "\"");
79 }
80 return errors::InvalidArgument(
81 "Value for attr '", attr.name(), "' of \"", str,
82 "\" is not in the list of allowed values: ", allowed_str);
83}
84
85} // namespace
86
87// Requires: attr has already been validated.
88Status ValidateAttrValue(const AttrValue& attr_value,
89 const OpDef::AttrDef& attr) {
90 // Is it a valid value?
91 TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
92 " for attr '", attr.name(), "'");
93
94 // Does the value satisfy the minimum constraint in the AttrDef?
95 if (attr.has_minimum()) {
96 if (attr.type() == "int") {
97 if (attr_value.i() < attr.minimum()) {
98 return errors::InvalidArgument(
99 "Value for attr '", attr.name(), "' of ", attr_value.i(),
100 " must be at least minimum ", attr.minimum());
101 }
102 } else {
103 int length = -1;
104 if (attr.type() == "list(string)") {
105 length = attr_value.list().s_size();
106 } else if (attr.type() == "list(int)") {
107 length = attr_value.list().i_size();
108 } else if (attr.type() == "list(float)") {
109 length = attr_value.list().f_size();
110 } else if (attr.type() == "list(bool)") {
111 length = attr_value.list().b_size();
112 } else if (attr.type() == "list(type)") {
113 length = attr_value.list().type_size();
114 } else if (attr.type() == "list(shape)") {
115 length = attr_value.list().shape_size();
116 } else if (attr.type() == "list(tensor)") {
117 length = attr_value.list().tensor_size();
118 } else if (attr.type() == "list(func)") {
119 length = attr_value.list().func_size();
120 }
121 if (length < attr.minimum()) {
122 return errors::InvalidArgument(
123 "Length for attr '", attr.name(), "' of ", length,
124 " must be at least minimum ", attr.minimum());
125 }
126 }
127 }
128
129 // Does the value satisfy the allowed_value constraint in the AttrDef?
130 if (attr.has_allowed_values()) {
131 if (attr.type() == "type") {
132 TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
133 } else if (attr.type() == "list(type)") {
134 for (int dt : attr_value.list().type()) {
135 TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
136 }
137 } else if (attr.type() == "string") {
138 TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
139 } else if (attr.type() == "list(string)") {
140 for (const string& str : attr_value.list().s()) {
141 TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
142 }
143 } else {
144 return errors::Unimplemented(
145 "Support for allowed_values not implemented for type ", attr.type());
146 }
147 }
148 return OkStatus();
149}
150
151const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
152 for (int i = 0; i < op_def.attr_size(); ++i) {
153 if (op_def.attr(i).name() == name) {
154 return &op_def.attr(i);
155 }
156 }
157 return nullptr;
158}
159
160OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
161 for (int i = 0; i < op_def->attr_size(); ++i) {
162 if (op_def->attr(i).name() == name) {
163 return op_def->mutable_attr(i);
164 }
165 }
166 return nullptr;
167}
168
169const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
170 for (int i = 0; i < op_def.input_arg_size(); ++i) {
171 if (op_def.input_arg(i).name() == name) {
172 return &op_def.input_arg(i);
173 }
174 }
175 return nullptr;
176}
177
178const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
179 for (int i = 0; i < api_def.in_arg_size(); ++i) {
180 if (api_def.in_arg(i).name() == name) {
181 return &api_def.in_arg(i);
182 }
183 }
184 return nullptr;
185}
186
// VALIDATE(EXPR, msg...): if EXPR is false, returns from the enclosing function
// with errors::InvalidArgument(msg..., "; in OpDef: ", <op_def debug string>).
// The expansion refers to a variable named `op_def` that must be in scope at
// the use site. Wrapped in do/while(false) so each use is a single statement.
// Undefined again after ValidateOpDef().
#define VALIDATE(EXPR, ...)                                        \
  do {                                                             \
    if (!(EXPR)) {                                                 \
      return errors::InvalidArgument(                              \
          __VA_ARGS__, "; in OpDef: ", op_def.ShortDebugString()); \
    }                                                              \
  } while (false)
194
// Validates a single input (`output` == false) or output (`output` == true)
// ArgDef of `op_def`. `output` only affects the wording of error messages.
// `names` holds names already in use; arg.name() is inserted and must not
// already be present (ValidateOpDef shares one set across attrs, inputs and
// outputs). Checks run in this order and the first failure is returned:
//  1. the name is unique and some type is given;
//  2. if number_attr is set: it names an "int" attr with a minimum >= 0, and
//     the arg uses exactly one of type/type_attr (no type_list_attr);
//     otherwise exactly one of type/type_attr/type_list_attr is set;
//  3. type_attr names a "type" attr, type_list_attr names a "list(type)" attr,
//     and a fixed type is not a ref type (refs are expressed via is_ref).
static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
                          bool output, std::set<string>* names) {
  const string suffix = strings::StrCat(
      output ? " for output '" : " for input '", arg.name(), "'");
  VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
           "Duplicate name: ", arg.name());
  VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);

  if (!arg.number_attr().empty()) {
    // The arg is a sequence whose length is given by an int attr.
    const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
    VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
             suffix);
    VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
             suffix, " has type ", attr->type(), " != int");
    VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
             suffix, " must have minimum");
    VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
             suffix, " must have minimum >= 0");
    VALIDATE(arg.type_list_attr().empty(),
             "Can't have both number_attr and type_list_attr", suffix);
    VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
                     (!arg.type_attr().empty() ? 1 : 0) ==
                 1,
             "Exactly one of type, type_attr must be set", suffix);
  } else {
    const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
                                (!arg.type_attr().empty() ? 1 : 0) +
                                (!arg.type_list_attr().empty() ? 1 : 0);
    VALIDATE(num_type_fields == 1,
             "Exactly one of type, type_attr, type_list_attr must be set",
             suffix);
  }

  if (!arg.type_attr().empty()) {
    const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
    VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
             suffix);
    VALIDATE(attr->type() == "type", "Attr '", attr->name(),
             "' used as type_attr", suffix, " has type ", attr->type(),
             " != type");
  } else if (!arg.type_list_attr().empty()) {
    const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
    VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
             suffix);
    VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
             "' used as type_list_attr", suffix, " has type ", attr->type(),
             " != list(type)");
  } else {
    // All argument types should be non-reference types at this point.
    // ArgDef.is_ref is set to true for reference arguments.
    VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
             DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
  }

  return OkStatus();
}
251
252bool IsValidOpName(StringPiece sp) {
253 using ::tensorflow::strings::Scanner;
254
255 Scanner scanner(sp);
256 scanner.One(Scanner::UPPERLETTER).Any(Scanner::LETTER_DIGIT_UNDERSCORE);
257
258 while (true) {
259 if (!scanner.GetResult()) // Some error in previous iteration.
260 return false;
261 if (scanner.empty()) // No error, but nothing left, good.
262 return true;
263
264 // Absorb another name/namespace, starting with a '>'
265 scanner.One(Scanner::RANGLE)
266 .One(Scanner::UPPERLETTER)
267 .Any(Scanner::LETTER_DIGIT_UNDERSCORE);
268 }
269}
270
// Validates the structure of `op_def`. Returns the first problem found, as
// InvalidArgument or as a context-annotated error from the attr value checks:
//  * the op name must satisfy IsValidOpName() unless it starts with '_';
//  * attr names must be unique and must not parse as a DataType name;
//  * each attr type must be one of string/int/float/bool/type/shape/tensor/
//    func, optionally wrapped in "list(...)", with nothing trailing;
//  * a minimum is only allowed on "int" and list attrs (and must be >= 0 for
//    lists); without has_minimum, minimum must be 0;
//  * allowed_values must have the list form of the attr's type, and a
//    default_value must pass ValidateAttrValue();
//  * every input and output arg must pass ValidateArg(), with arg names
//    unique across attrs, inputs and outputs.
Status ValidateOpDef(const OpDef& op_def) {
  // Names beginning with '_' are exempt from the naming check.
  if (!absl::StartsWith(op_def.name(), "_")) {
    VALIDATE(IsValidOpName(op_def.name()), "Invalid name: ", op_def.name(),
             " (Did you use CamelCase?)");
  }

  std::set<string> names;  // for detecting duplicate names
  for (const auto& attr : op_def.attr()) {
    // Validate name
    VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
             "Duplicate name: ", attr.name());
    DataType dt;
    VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
             attr.name(), " that matches a data type");

    // Validate type. The type string is parsed by consuming an optional
    // "list(" prefix, then a known base type name as a prefix, then the
    // closing ')' for lists. Anything left over is rejected below.
    StringPiece type(attr.type());
    bool is_list = absl::ConsumePrefix(&type, "list(");
    bool found = false;
    for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
                              "tensor", "func"}) {
      if (absl::ConsumePrefix(&type, valid)) {
        found = true;
        break;
      }
    }
    // NOTE(review): for list types `type` has already had "list(" stripped
    // here, so this message shows only the remainder of the type string.
    VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
             "'");
    if (is_list) {
      VALIDATE(absl::ConsumePrefix(&type, ")"),
               "'list(' is missing ')' in attr ", attr.name(), "'s type ",
               attr.type());
    }
    // Catches suffixes left by the prefix match (e.g. "integer" -> "eger").
    VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
             attr.name(), "'s type ", attr.type());

    // Validate minimum
    if (attr.has_minimum()) {
      VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
               "' has minimum for unsupported type ", attr.type());
      if (is_list) {
        VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
                 "' with list type must have a non-negative minimum, not ",
                 attr.minimum());
      }
    } else {
      VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
               "' with has_minimum = false but minimum ", attr.minimum(),
               " not equal to default of 0");
    }

    // Validate allowed_values. They are stored as a list of the attr's
    // element type, so a scalar type T is checked as "list(T)".
    if (attr.has_allowed_values()) {
      const string list_type =
          is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
          attr.name(), "' in Op '", op_def.name(), "'");
    }

    // Validate default_value (after we have validated the rest of the attr,
    // so we can use ValidateAttrValue()).
    if (attr.has_default_value()) {
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ValidateAttrValue(attr.default_value(), attr), " in Op '",
          op_def.name(), "'");
    }
  }

  // `names` already holds every attr name, so arg names may not reuse them.
  for (const auto& arg : op_def.input_arg()) {
    TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
  }

  for (const auto& arg : op_def.output_arg()) {
    TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
  }

  return OkStatus();
}
350
351#undef VALIDATE
352
353Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) {
354 if (op_def.has_deprecation()) {
355 const OpDeprecation& dep = op_def.deprecation();
356 if (graph_def_version >= dep.version()) {
357 return errors::Unimplemented(
358 "Op ", op_def.name(), " is not available in GraphDef version ",
359 graph_def_version, ". It has been removed in version ", dep.version(),
360 ". ", dep.explanation(), ".");
361 } else {
362 // Warn only once for each op name, and do it in a threadsafe manner.
363 static mutex mu(LINKER_INITIALIZED);
364 static std::unordered_set<string> warned;
365 bool warn;
366 {
367 mutex_lock lock(mu);
368 warn = warned.insert(op_def.name()).second;
369 }
370 if (warn) {
371 LOG(WARNING) << "Op " << op_def.name() << " is deprecated."
372 << " It will cease to work in GraphDef version "
373 << dep.version() << ". " << dep.explanation() << ".";
374 }
375 }
376 }
377 return OkStatus();
378}
379
380namespace {
381
382string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
383 string ret;
384 for (const OpDef::ArgDef& arg : args) {
385 if (!ret.empty()) strings::StrAppend(&ret, ", ");
386 strings::StrAppend(&ret, arg.name(), ":");
387 if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
388 if (!arg.number_attr().empty()) {
389 strings::StrAppend(&ret, arg.number_attr(), "*");
390 }
391 if (arg.type() != DT_INVALID) {
392 strings::StrAppend(&ret, DataTypeString(arg.type()));
393 } else {
394 strings::StrAppend(&ret, arg.type_attr());
395 }
396 if (arg.is_ref()) strings::StrAppend(&ret, ")");
397 }
398 return ret;
399}
400
401} // namespace
402
403string SummarizeOpDef(const OpDef& op_def) {
404 string ret = strings::StrCat("Op<name=", op_def.name());
405 strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
406 " -> ", SummarizeArgs(op_def.output_arg()));
407 for (int i = 0; i < op_def.attr_size(); ++i) {
408 strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
409 op_def.attr(i).type());
410 if (op_def.attr(i).has_default_value()) {
411 strings::StrAppend(&ret, ",default=",
412 SummarizeAttrValue(op_def.attr(i).default_value()));
413 }
414 if (op_def.attr(i).has_minimum()) {
415 strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
416 }
417 if (op_def.attr(i).has_allowed_values()) {
418 strings::StrAppend(&ret, ",allowed=",
419 SummarizeAttrValue(op_def.attr(i).allowed_values()));
420 }
421 }
422 if (op_def.is_commutative()) {
423 strings::StrAppend(&ret, "; is_commutative=true");
424 }
425 if (op_def.is_aggregate()) {
426 strings::StrAppend(&ret, "; is_aggregate=true");
427 }
428 if (op_def.is_stateful()) {
429 strings::StrAppend(&ret, "; is_stateful=true");
430 }
431 if (op_def.allows_uninitialized_input()) {
432 strings::StrAppend(&ret, "; allows_uninitialized_input=true");
433 }
434 if (op_def.is_distributed_communication()) {
435 strings::StrAppend(&ret, "; is_distributed_communication=true");
436 }
437 strings::StrAppend(&ret, ">");
438 return ret;
439}
440
441namespace {
442
// Returns true if every element of `sub` compares equal (via ==) to some
// element of `super`. Works on any iterable container; runs a linear scan of
// `super` for each element of `sub`. An empty `sub` is a subset of anything.
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  for (const auto& needle : sub) {
    auto it = super.begin();
    const auto last = super.end();
    while (it != last && !(needle == *it)) ++it;
    if (it == last) return false;  // `needle` is missing from `super`.
  }
  return true;
}
458
459bool MoreRestrictive(const OpDef::AttrDef& old_attr,
460 const OpDef::AttrDef& new_attr) {
461 // Anything -> no restriction : not more restrictive.
462 if (!new_attr.has_allowed_values()) return false;
463 // No restriction -> restriction : more restrictive.
464 if (!old_attr.has_allowed_values()) return true;
465 // If anything that was previously allowed is no longer allowed:
466 // more restrictive.
467 if (!IsSubsetOf(old_attr.allowed_values().list().type(),
468 new_attr.allowed_values().list().type())) {
469 return true;
470 }
471 if (!IsSubsetOf(old_attr.allowed_values().list().s(),
472 new_attr.allowed_values().list().s())) {
473 return true;
474 }
475 return false;
476}
477
478string AllowedStr(const OpDef::AttrDef& attr) {
479 if (!attr.has_allowed_values()) return "no restriction";
480 return SummarizeAttrValue(attr.allowed_values());
481}
482
483string DefaultAttrStr(const OpDef::AttrDef& attr) {
484 if (!attr.has_default_value()) return "no default";
485 return SummarizeAttrValue(attr.default_value());
486}
487
488bool HigherMinimum(const OpDef::AttrDef& old_attr,
489 const OpDef::AttrDef& new_attr) {
490 // Anything -> no restriction : not more restrictive.
491 if (!new_attr.has_minimum()) return false;
492 // No restriction -> restriction : more restrictive.
493 if (!old_attr.has_minimum()) return true;
494 // If anything that was previously allowed is no longer allowed:
495 // more restrictive.
496 return new_attr.minimum() > old_attr.minimum();
497}
498
499string MinStr(const OpDef::AttrDef& attr) {
500 if (!attr.has_minimum()) return "no minimum";
501 return strings::StrCat(attr.minimum());
502}
503
504typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
505void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
506 for (const auto& attr : op_def.attr()) {
507 (*attr_map)[attr.name()] = &attr;
508 }
509}
510
511// Add a comma to *s every call but the first (*add_comma should be
512// initialized to false).
513void AddComma(string* s, bool* add_comma) {
514 if (*add_comma) {
515 strings::StrAppend(s, ", ");
516 } else {
517 *add_comma = true;
518 }
519}
520
521// Will add the `name` from arg if name is true.
522void AddName(string* s, bool name, const OpDef::ArgDef& arg) {
523 if (name) {
524 strings::StrAppend(s, arg.name(), ":");
525 }
526}
527
// Compute a signature for either inputs or outputs that will be the
// same for both the old and new OpDef if they are compatible. We
// assume that new_attrs is a superset of old_attrs, and that any attr
// in the difference has a default. Our strategy is to make a
// comma-separated list of types, where the types are things like:
// * "int32", "float", etc., for args with a fixed type,
// * "T" for some attr "T" in old_attrs, or
// * "N * type" when "N" is a number_attr present in old_attrs. When "N" is
//   absent from old_attrs, "type" is instead repeated N times, using N's
//   default value from new_attrs.
//
// We get the types by either using the attrs in args if they are in
// old_attrs, or substituting the default value from new_attrs.
// A type_list_attr missing from old_attrs expands to one entry per type in
// its new default (and to nothing if that default list is empty).
//
// For each emitted entry, arg.is_ref() is appended to *ref. When `names` is
// true, each entry is prefixed with "<arg name>:".
// NOTE(review): new_attr is dereferenced without a null check; this relies on
// every referenced attr being present in new_attrs (e.g. both OpDefs having
// passed ValidateOpDef).
string ComputeArgSignature(
    const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
    const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref,
    bool names) {
  string s;
  bool add_comma = false;
  for (const OpDef::ArgDef& arg : args) {
    if (!arg.type_list_attr().empty()) {
      const OpDef::AttrDef* old_attr =
          gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
      if (old_attr) {
        // Both old and new have the list(type) attr, so can use it directly.
        AddComma(&s, &add_comma);
        AddName(&s, names, arg);
        strings::StrAppend(&s, arg.type_list_attr());
        ref->push_back(arg.is_ref());
      } else {
        // Missing the list(type) attr in the old, so use the default
        // value for the attr from new instead.
        const OpDef::AttrDef* new_attr =
            gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
        const auto& type_list = new_attr->default_value().list().type();
        if (type_list.empty()) continue;
        for (int i = 0; i < type_list.size(); ++i) {
          AddComma(&s, &add_comma);
          AddName(&s, names, arg);
          strings::StrAppend(
              &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
          ref->push_back(arg.is_ref());
        }
      }
    } else {
      int num = 1;  // How many input/outputs does this represent?
      string type;  // What is the type of this arg?
      AddName(&type, names, arg);
      if (!arg.number_attr().empty()) {
        // N * type case.
        const OpDef::AttrDef* old_attr =
            gtl::FindPtrOrNull(old_attrs, arg.number_attr());
        if (old_attr) {
          // Both old and new have the number attr, so can use it directly.
          strings::StrAppend(&type, arg.number_attr(), " * ");
        } else {
          // Missing the number attr in the old, so use the default
          // value for the attr from new instead.
          const OpDef::AttrDef* new_attr =
              gtl::FindPtrOrNull(new_attrs, arg.number_attr());
          num = new_attr->default_value().i();
        }
      }

      if (arg.type() != DT_INVALID) {
        // int32, float, etc. case
        strings::StrAppend(&type, DataTypeString(arg.type()));
      } else {
        const OpDef::AttrDef* old_attr =
            gtl::FindPtrOrNull(old_attrs, arg.type_attr());
        if (old_attr) {
          // Both old and new have the type attr, so can use it directly.
          strings::StrAppend(&type, arg.type_attr());
        } else {
          // Missing the type attr in the old, so use the default
          // value for the attr from new instead.
          const OpDef::AttrDef* new_attr =
              gtl::FindPtrOrNull(new_attrs, arg.type_attr());
          strings::StrAppend(&type,
                             DataTypeString(new_attr->default_value().type()));
        }
      }

      // Record `num` * `type` in the signature.
      for (int i = 0; i < num; ++i) {
        AddComma(&s, &add_comma);
        strings::StrAppend(&s, type);
        ref->push_back(arg.is_ref());
      }
    }
  }

  return s;
}
620
621} // namespace
622
623Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
624#define VALIDATE(CONDITION, ...) \
625 if (!(CONDITION)) { \
626 return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
627 "; old: ", SummarizeOpDef(old_op), \
628 "; new: ", SummarizeOpDef(new_op)); \
629 }
630
631 VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
632
633 AttrMap new_attrs, old_attrs;
634 FillAttrMap(old_op, &old_attrs);
635 FillAttrMap(new_op, &new_attrs);
636 for (const auto& old_attr : old_op.attr()) {
637 const OpDef::AttrDef* new_attr =
638 gtl::FindPtrOrNull(new_attrs, old_attr.name());
639 VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
640 VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
641 "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
642 "'");
643 VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(),
644 "' has a stricter set of allowed values; from ",
645 AllowedStr(old_attr), " to ", AllowedStr(*new_attr));
646 VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(),
647 "' has a higher minimum; from ", MinStr(old_attr), " to ",
648 MinStr(*new_attr));
649 }
650
651 for (const auto& new_attr : new_op.attr()) {
652 const OpDef::AttrDef* old_attr =
653 gtl::FindPtrOrNull(old_attrs, new_attr.name());
654 VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
655 new_attr.name(), "' added without default");
656 }
657
658 std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref;
659 const string old_in_sig = ComputeArgSignature(
660 old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */);
661 const string new_in_sig = ComputeArgSignature(
662 new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */);
663 VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
664 "' vs. '", new_in_sig, "'");
665 VALIDATE(old_in_ref.size() == new_in_ref.size(), // Should not happen
666 "Unexpected change in input ref lists.");
667 for (int i = 0, end = old_in_ref.size(); i < end; ++i) {
668 // Allowed to remove "ref" from an input (or leave it unchanged).
669 VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i,
670 " changed from non-ref to ref");
671 }
672
673 const string old_out_sig =
674 ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs,
675 &old_out_ref, true /* names */);
676 const string new_out_sig =
677 ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs,
678 &new_out_ref, true /* names */);
679 VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
680 old_out_sig, "' vs. '", new_out_sig, "'");
681 VALIDATE(old_out_ref.size() == new_out_ref.size(), // Should not happen
682 "Unexpected change in output ref lists");
683 for (int i = 0, end = old_out_ref.size(); i < end; ++i) {
684 // Allowed to add "ref" to an output (or leave it unchanged).
685 VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i,
686 " changed from ref to non-ref");
687 }
688
689 return OkStatus();
690}
691
692Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
693 const OpDef& penultimate_op,
694 const OpDef& new_op) {
695 AttrMap new_attrs, old_attrs;
696 FillAttrMap(old_op, &old_attrs);
697 FillAttrMap(new_op, &new_attrs);
698
699 for (const auto& penultimate_attr : penultimate_op.attr()) {
700 const OpDef::AttrDef* old_attr =
701 gtl::FindPtrOrNull(old_attrs, penultimate_attr.name());
702 if (old_attr != nullptr) continue; // attr wasn't added
703 const OpDef::AttrDef* new_attr =
704 gtl::FindPtrOrNull(new_attrs, penultimate_attr.name());
705
706 // These shouldn't happen if the op passed OpDefCompatible().
707 if (new_attr == nullptr) {
708 return errors::InvalidArgument("Missing attr '", penultimate_attr.name(),
709 "' in op: ", SummarizeOpDef(new_op));
710 }
711 if (!penultimate_attr.has_default_value() ||
712 !new_attr->has_default_value()) {
713 return errors::InvalidArgument("Missing default for attr '",
714 penultimate_attr.name(),
715 "' in op: ", SummarizeOpDef(new_op));
716 }
717
718 // Actually test that the attr's default value hasn't changed.
719 if (!AreAttrValuesEqual(penultimate_attr.default_value(),
720 new_attr->default_value())) {
721 return errors::InvalidArgument(
722 "Can't change default value for attr '", penultimate_attr.name(),
723 "' from ", SummarizeAttrValue(penultimate_attr.default_value()),
724 " in op: ", SummarizeOpDef(new_op));
725 }
726 }
727
728 return OkStatus();
729}
730
731Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
732 AttrMap new_attrs, old_attrs;
733 FillAttrMap(old_op, &old_attrs);
734 FillAttrMap(new_op, &new_attrs);
735
736 for (const auto& old_attr : old_op.attr()) {
737 const OpDef::AttrDef* new_attr =
738 gtl::FindPtrOrNull(new_attrs, old_attr.name());
739 if (new_attr == nullptr) continue;
740 if (new_attr->has_default_value() && !old_attr.has_default_value()) {
741 continue; // Adding new default values is safe.
742 }
743 if (old_attr.has_default_value() && !new_attr->has_default_value()) {
744 return errors::InvalidArgument(
745 "Attr '", old_attr.name(), "' has removed it's default; ", "from ",
746 DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
747 }
748 if (old_attr.has_default_value() &&
749 !AreAttrValuesEqual(old_attr.default_value(),
750 new_attr->default_value())) {
751 return errors::InvalidArgument(
752 "Attr '", old_attr.name(), "' has changed it's default value; ",
753 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
754 }
755 }
756
757 return OkStatus();
758}
759
760void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
761 for (int i = 0; i < op_def->input_arg_size(); ++i) {
762 op_def->mutable_input_arg(i)->clear_description();
763 }
764 for (int i = 0; i < op_def->output_arg_size(); ++i) {
765 op_def->mutable_output_arg(i)->clear_description();
766 }
767 for (int i = 0; i < op_def->attr_size(); ++i) {
768 op_def->mutable_attr(i)->clear_description();
769 }
770 op_def->clear_summary();
771 op_def->clear_description();
772}
773
774void RemoveDescriptionsFromOpDef(OpDef* op_def) {
775 RemoveNonDeprecationDescriptionsFromOpDef(op_def);
776 if (op_def->has_deprecation()) {
777 op_def->mutable_deprecation()->clear_explanation();
778 }
779}
780
781void RemoveDescriptionsFromOpList(OpList* op_list) {
782 for (int i = 0; i < op_list->op_size(); ++i) {
783 OpDef* op_def = op_list->mutable_op(i);
784 RemoveDescriptionsFromOpDef(op_def);
785 }
786}
787
788bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) {
789 if (std::is_base_of<protobuf::Message, OpDef::AttrDef>()) {
790 DCHECK_EQ(7, reinterpret_cast<const protobuf::Message*>(&a1)
791 ->GetDescriptor()
792 ->field_count())
793 << "Please modify these equality and hash functions to reflect the "
794 "changes to the AttrDef protobuf";
795 }
796
797 if (a1.name() != a2.name()) return false;
798 if (a1.type() != a2.type()) return false;
799 if (a1.description() != a2.description()) return false;
800 if (a1.has_minimum() != a2.has_minimum()) return false;
801 if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false;
802 if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false;
803 if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values()))
804 return false;
805 return true;
806}
807
808uint64 AttrDefHash(const OpDef::AttrDef& a) {
809 uint64 h = Hash64(a.name());
810 h = Hash64(a.type().data(), a.type().size(), h);
811 h = Hash64Combine(AttrValueHash(a.default_value()), h);
812 h = Hash64(a.description().data(), a.description().size(), h);
813 h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h);
814 h = Hash64Combine(static_cast<uint64>(a.minimum()), h);
815 h = Hash64Combine(AttrValueHash(a.allowed_values()), h);
816 return h;
817}
818
819bool RepeatedAttrDefEqual(
820 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1,
821 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) {
822 std::unordered_map<string, const OpDef::AttrDef*> a1_set;
823 for (const OpDef::AttrDef& def : a1) {
824 if (a1_set.find(def.name()) != a1_set.end()) {
825 LOG(ERROR) << "AttrDef names must be unique, but '" << def.name()
826 << "' appears more than once";
827 }
828 a1_set[def.name()] = &def;
829 }
830 for (const OpDef::AttrDef& def : a2) {
831 auto iter = a1_set.find(def.name());
832 if (iter == a1_set.end()) return false;
833 if (!AttrDefEqual(*iter->second, def)) return false;
834 a1_set.erase(iter);
835 }
836 if (!a1_set.empty()) return false;
837 return true;
838}
839
840uint64 RepeatedAttrDefHash(
841 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) {
842 // Insert AttrDefs into map to deterministically sort by name
843 std::map<string, const OpDef::AttrDef*> a_set;
844 for (const OpDef::AttrDef& def : a) {
845 a_set[def.name()] = &def;
846 }
847 // Iterate and combines hashes of keys and values
848 uint64 h = 0xDECAFCAFFE;
849 for (const auto& pair : a_set) {
850 h = Hash64(pair.first.data(), pair.first.size(), h);
851 h = Hash64Combine(AttrDefHash(*pair.second), h);
852 }
853 return h;
854}
855
856bool OpDefEqual(const OpDef& o1, const OpDef& o2) {
857 // attr order doesn't matter.
858 // Compare it separately here instead of serializing below.
859 if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false;
860
861 // `control_output` order doesn't matter.
862 std::set<string> control_output1(o1.control_output().begin(),
863 o1.control_output().end());
864 std::set<string> control_output2(o2.control_output().begin(),
865 o2.control_output().end());
866 if (control_output1 != control_output2) return false;
867
868 // Clear `attr` and `control_output` fields, serialize, and compare serialized
869 // strings.
870 OpDef o1_copy = o1;
871 OpDef o2_copy = o2;
872 o1_copy.clear_attr();
873 o1_copy.clear_control_output();
874 o2_copy.clear_attr();
875 o2_copy.clear_control_output();
876
877 return AreSerializedProtosEqual(o1_copy, o2_copy);
878}
879
880uint64 OpDefHash(const OpDef& o) {
881 uint64 h = RepeatedAttrDefHash(o.attr());
882
883 // Compute deterministic order-independent control outputs hash.
884 std::set<string> control_output(o.control_output().begin(),
885 o.control_output().end());
886 for (const auto& co : control_output) h = Hash64Combine(h, Hash64(co));
887
888 OpDef o_copy = o;
889 o_copy.clear_attr();
890 o_copy.clear_control_output();
891 return DeterministicProtoHash64(o_copy, h);
892}
893
894} // namespace tensorflow
895