1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/kernel_def_util.h" |
17 | |
18 | #include "tensorflow/core/framework/attr_value.pb.h" |
19 | #include "tensorflow/core/framework/attr_value_util.h" |
20 | #include "tensorflow/core/framework/kernel_def.pb.h" |
21 | #include "tensorflow/core/framework/node_def_util.h" |
22 | #include "tensorflow/core/framework/types.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | namespace { |
27 | // Helper for KernelAttrsMatch(). |
28 | bool InTypeList(DataType dt, const AttrValue& type_list) { |
29 | for (int in_list : type_list.list().type()) { |
30 | if (dt == in_list) return true; |
31 | } |
32 | return false; |
33 | } |
34 | } // namespace |
35 | |
36 | Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, |
37 | bool* match) { |
38 | *match = false; |
39 | for (const auto& constraint : kernel_def.constraint()) { |
40 | auto constraint_value_case = AttrValue::VALUE_NOT_SET; |
41 | int value_type_num = 0; |
42 | if (constraint.allowed_values().list().type_size() > 0) { |
43 | constraint_value_case = AttrValue::kType; |
44 | value_type_num++; |
45 | } |
46 | if (constraint.allowed_values().list().s_size() > 0) { |
47 | constraint_value_case = AttrValue::kS; |
48 | value_type_num++; |
49 | } |
50 | if (constraint.allowed_values().list().i_size() > 0) { |
51 | constraint_value_case = AttrValue::kI; |
52 | value_type_num++; |
53 | } |
54 | if (constraint.allowed_values().list().b_size() > 0) { |
55 | constraint_value_case = AttrValue::kB; |
56 | value_type_num++; |
57 | } |
58 | |
59 | if (value_type_num == 0) { |
60 | return errors::Unimplemented( |
61 | "KernelDef '" , kernel_def.ShortDebugString(), |
62 | " has constraint on attr '" , constraint.name(), |
63 | "' with unsupported type: " , |
64 | SummarizeAttrValue(constraint.allowed_values())); |
65 | } |
66 | if (value_type_num > 1) { |
67 | return errors::InvalidArgument( |
68 | "KernelDef '" , kernel_def.ShortDebugString(), |
69 | " has constraint on attr '" , constraint.name(), |
70 | "' with more than one value type: " , |
71 | SummarizeAttrValue(constraint.allowed_values())); |
72 | } |
73 | |
74 | const AttrValue* attr_value = attrs.Find(constraint.name()); |
75 | if (attr_value == nullptr) { |
76 | return errors::InvalidArgument( |
77 | "OpKernel '" , kernel_def.op(), "' has constraint on attr '" , |
78 | constraint.name(), "' not in NodeDef '" , attrs.SummarizeNode(), |
79 | "', KernelDef: '" , kernel_def.ShortDebugString(), "'" ); |
80 | } |
81 | |
82 | #define RETURN_IF_ATTR_NOT_FOUND(n, oneof_case, type_str) \ |
83 | do { \ |
84 | if (constraint_value_case == AttrValue::oneof_case) { \ |
85 | Status s = AttrValueHasType(*attr_value, type_str); \ |
86 | if (!s.ok()) { \ |
87 | return errors::InvalidArgument( \ |
88 | "KernelDef '", kernel_def.ShortDebugString(), \ |
89 | "' has constraint on attr '", constraint.name(), \ |
90 | "' that has value '", SummarizeAttrValue(*attr_value), \ |
91 | "' that does not have the same type in NodeDef " \ |
92 | "'", \ |
93 | attrs.SummarizeNode(), "'"); \ |
94 | } \ |
95 | bool found = false; \ |
96 | for (auto& value : constraint.allowed_values().list().n()) { \ |
97 | if (value == attr_value->n()) { \ |
98 | found = true; \ |
99 | break; \ |
100 | } \ |
101 | } \ |
102 | if (!found) { \ |
103 | return OkStatus(); \ |
104 | } \ |
105 | } \ |
106 | } while (false) |
107 | |
108 | RETURN_IF_ATTR_NOT_FOUND(s, kS, "string" ); |
109 | RETURN_IF_ATTR_NOT_FOUND(i, kI, "int" ); |
110 | RETURN_IF_ATTR_NOT_FOUND(b, kB, "bool" ); |
111 | |
112 | #undef RETURN_IF_ATTR_NOT_FOUND |
113 | |
114 | if (constraint_value_case != AttrValue::kType) { |
115 | continue; |
116 | } |
117 | |
118 | if (attr_value->type() != DT_INVALID) { |
119 | if (!InTypeList(attr_value->type(), constraint.allowed_values())) { |
120 | return OkStatus(); |
121 | } |
122 | } else { |
123 | if (!AttrValueHasType(*attr_value, "list(type)" ).ok()) { |
124 | return errors::InvalidArgument( |
125 | "KernelDef '" , kernel_def.ShortDebugString(), |
126 | "' has constraint on attr '" , constraint.name(), |
127 | "' that has value '" , SummarizeAttrValue(*attr_value), |
128 | "' that does not have type 'type' or 'list(type)' in NodeDef " |
129 | "'" , |
130 | attrs.SummarizeNode(), "'" ); |
131 | } |
132 | |
133 | for (int t : attr_value->list().type()) { |
134 | if (!InTypeList(static_cast<DataType>(t), |
135 | constraint.allowed_values())) { |
136 | return OkStatus(); |
137 | } |
138 | } |
139 | } |
140 | } |
141 | *match = true; |
142 | return OkStatus(); |
143 | } |
144 | |
145 | } // namespace tensorflow |
146 | |