1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/full_type_util.h" |
17 | |
18 | #include <algorithm> |
19 | #include <string> |
20 | |
21 | #include "absl/container/flat_hash_map.h" |
22 | #include "tensorflow/core/framework/attr_value.pb.h" |
23 | #include "tensorflow/core/framework/full_type.pb.h" |
24 | #include "tensorflow/core/framework/node_def.pb.h" |
25 | #include "tensorflow/core/framework/node_def_util.h" |
26 | #include "tensorflow/core/framework/op_def.pb.h" |
27 | #include "tensorflow/core/framework/types.h" |
28 | #include "tensorflow/core/platform/errors.h" |
29 | #include "tensorflow/core/platform/hash.h" |
30 | #include "tensorflow/core/platform/statusor.h" |
31 | #include "tensorflow/core/protobuf/error_codes.pb.h" |
32 | |
33 | namespace tensorflow { |
34 | |
35 | namespace full_type { |
36 | |
37 | OpTypeConstructor NoOp() { |
38 | return nullptr; |
39 | } |
40 | |
41 | OpTypeConstructor NoOutputs() { |
42 | return [](OpDef* op_def) { |
43 | op_def->mutable_output_arg(); |
44 | return OkStatus(); |
45 | }; |
46 | } |
47 | |
48 | OpTypeConstructor Nullary(FullTypeId t) { |
49 | return [t](OpDef* op_def) { |
50 | FullTypeDef* tdef = |
51 | op_def->mutable_output_arg(0)->mutable_experimental_full_type(); |
52 | tdef->set_type_id(t); |
53 | return OkStatus(); |
54 | }; |
55 | } |
56 | |
57 | OpTypeConstructor Unary(FullTypeId t, const string& var_name) { |
58 | return [t, var_name](OpDef* op_def) { |
59 | FullTypeDef* tdef = |
60 | op_def->mutable_output_arg(0)->mutable_experimental_full_type(); |
61 | tdef->set_type_id(t); |
62 | |
63 | FullTypeDef* arg = tdef->add_args(); |
64 | arg->set_type_id(TFT_VAR); |
65 | arg->set_s(var_name); |
66 | |
67 | return OkStatus(); |
68 | }; |
69 | } |
70 | |
71 | OpTypeConstructor UnaryGeneric(FullTypeId t) { |
72 | return [t](OpDef* op_def) { |
73 | FullTypeDef* tdef = |
74 | op_def->mutable_output_arg(0)->mutable_experimental_full_type(); |
75 | tdef->set_type_id(t); |
76 | |
77 | FullTypeDef* arg = tdef->add_args(); |
78 | arg->set_type_id(TFT_ANY); |
79 | |
80 | return OkStatus(); |
81 | }; |
82 | } |
83 | |
84 | OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype) { |
85 | return [t, dtype](OpDef* op_def) { |
86 | FullTypeDef* tdef = |
87 | op_def->mutable_output_arg(0)->mutable_experimental_full_type(); |
88 | tdef->set_type_id(t); |
89 | |
90 | FullTypeDef* arg = tdef->add_args(); |
91 | arg->set_type_id(TFT_TENSOR); |
92 | FullTypeDef* targ = arg->add_args(); |
93 | targ->set_type_id(dtype); |
94 | |
95 | return OkStatus(); |
96 | }; |
97 | } |
98 | |
99 | OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name) { |
100 | return [t, var_name](OpDef* op_def) { |
101 | FullTypeDef* tdef = |
102 | op_def->mutable_output_arg(0)->mutable_experimental_full_type(); |
103 | tdef->set_type_id(t); |
104 | |
105 | FullTypeDef* targ = tdef->add_args(); |
106 | targ->set_type_id(TFT_TENSOR); |
107 | FullTypeDef* varg = targ->add_args(); |
108 | varg->set_type_id(TFT_VAR); |
109 | varg->set_s(var_name); |
110 | |
111 | return OkStatus(); |
112 | }; |
113 | } |
114 | |
115 | OpTypeConstructor VariadicTensorContainer(FullTypeId t, |
116 | const string& var_name) { |
117 | return [t, var_name](OpDef* op_def) { |
118 | FullTypeDef* tdef = |
119 | op_def->mutable_output_arg(0)->mutable_experimental_full_type(); |
120 | tdef->set_type_id(t); |
121 | |
122 | FullTypeDef* for_each = tdef->add_args(); |
123 | for_each->set_type_id(TFT_FOR_EACH); |
124 | for_each->add_args()->set_type_id(TFT_PRODUCT); |
125 | |
126 | FullTypeDef* tpl = for_each->add_args(); |
127 | tpl->set_type_id(TFT_TENSOR); |
128 | FullTypeDef* targ = tpl->add_args(); |
129 | targ->set_type_id(TFT_VAR); |
130 | targ->set_s(var_name); |
131 | |
132 | FullTypeDef* tvar = for_each->add_args(); |
133 | tvar->set_type_id(TFT_VAR); |
134 | tvar->set_s(var_name); |
135 | |
136 | return OkStatus(); |
137 | }; |
138 | } |
139 | |
140 | namespace { |
141 | |
142 | typedef absl::flat_hash_map<StringPiece, const AttrValue*> AttrMap; |
143 | |
144 | inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t); |
145 | |
146 | Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) { |
147 | DCHECK_EQ(t.args_size(), 0); |
148 | |
149 | StringPiece var_name = t.s(); |
150 | if (!attrs.contains(var_name)) { |
151 | return Status( |
152 | error::INVALID_ARGUMENT, |
153 | absl::StrCat("could not find an attribute for key '" , var_name, "'" )); |
154 | } |
155 | const AttrValue* attr = attrs.at(var_name); |
156 | |
157 | const auto attr_type = attr->value_case(); |
158 | if (attr_type == AttrValue::kType) { |
159 | map_dtype_to_tensor(attr->type(), t); |
160 | } else if (attr_type == AttrValue::kList) { |
161 | const auto& attr_list = attr->list(); |
162 | if (attr_list.type_size() != 1) { |
163 | return Status(error::UNIMPLEMENTED, |
164 | absl::StrCat("lists or other than one type element\n" , |
165 | attr_list.DebugString(), "\nkey=" , var_name)); |
166 | } |
167 | map_dtype_to_tensor(attr_list.type(0), t); |
168 | } else { |
169 | return Status(error::UNIMPLEMENTED, |
170 | absl::StrCat("unsupported attribute type " , |
171 | attr->DebugString(), " for name " , var_name)); |
172 | } |
173 | t.clear_s(); |
174 | return OkStatus(); |
175 | } |
176 | |
177 | Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { |
178 | if (t.args_size() != 3) { |
179 | return Status(error::INVALID_ARGUMENT, |
180 | absl::StrCat("illegal FOR_EACH type, expected 3 args, got " , |
181 | t.args_size())); |
182 | } |
183 | |
184 | const auto& cont = t.args(0); |
185 | const auto& tmpl = t.args(1); |
186 | const auto& t_var = t.args(2); |
187 | |
188 | StringPiece var_name = t_var.s(); |
189 | if (!attrs.contains(var_name)) { |
190 | return Status( |
191 | error::INVALID_ARGUMENT, |
192 | absl::StrCat("could not find an attribute for key '" , var_name, "'" )); |
193 | } |
194 | const AttrValue* attr = attrs.at(var_name); |
195 | |
196 | FullTypeDef result; |
197 | result.set_type_id(cont.type_id()); |
198 | |
199 | const auto attr_type = attr->value_case(); |
200 | if (attr_type == AttrValue::kType) { |
201 | FullTypeDef* target = result.add_args(); |
202 | *target = tmpl; |
203 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
204 | SubstituteFromAttrs(attrs, *target), "while substituting '" , var_name, |
205 | "' from\n" , attr->DebugString(), "\ninto " , target->DebugString()); |
206 | |
207 | } else if (attr_type == AttrValue::kList) { |
208 | const auto& attr_list = attr->list(); |
209 | int tsize = attr_list.type_size(); |
210 | if (tsize == 0) { |
211 | return Status(error::UNIMPLEMENTED, |
212 | absl::StrCat("unsupported list attribute type\n" , |
213 | attr_list.DebugString(), "\nkey=" , var_name)); |
214 | } |
215 | AttrValue replacement; |
216 | attrs[var_name] = &replacement; |
217 | for (int i = 0; i < tsize; i++) { |
218 | replacement.set_type(attr_list.type(i)); |
219 | FullTypeDef* target = result.add_args(); |
220 | *target = tmpl; |
221 | TF_RETURN_WITH_CONTEXT_IF_ERROR(SubstituteFromAttrs(attrs, *target), |
222 | "while substituting '" , var_name, |
223 | "' from\n" , attr->DebugString(), "\n[" , i, |
224 | "] into\n" , target->DebugString()); |
225 | } |
226 | // In case of error, it's ok for the attributes map to remain in an invalid |
227 | // state. |
228 | attrs[var_name] = attr; |
229 | |
230 | } else { |
231 | return Status(error::UNIMPLEMENTED, |
232 | absl::StrCat("unsupported attribute type\n" , |
233 | attr->DebugString(), "\nfor name " , var_name)); |
234 | } |
235 | t = result; |
236 | return OkStatus(); |
237 | } |
238 | |
239 | Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) { |
240 | int nargs = t.args_size(); |
241 | for (int j = 0; j < nargs; j++) { |
242 | FullTypeDef* arg_t = t.mutable_args(j); |
243 | TF_RETURN_WITH_CONTEXT_IF_ERROR(SubstituteFromAttrs(attrs, *arg_t), |
244 | "while substituting arg " , j, ": " , |
245 | arg_t->DebugString()); |
246 | |
247 | // Special case for DT_VARIANT tensors. We leave those unset to avoid even |
248 | // more special casing downstream. |
249 | if (arg_t->type_id() == TFT_TENSOR && arg_t->args_size() && |
250 | arg_t->args(0).type_id() == TFT_LEGACY_VARIANT) { |
251 | t.clear_args(); |
252 | break; |
253 | } |
254 | } |
255 | return OkStatus(); |
256 | } |
257 | |
258 | inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) { |
259 | // Resolve dependent types. The convention for op registrations is to use |
260 | // attributes as type variables. |
261 | // See https://www.tensorflow.org/guide/create_op#type_polymorphism. |
262 | // Once the op signature can be defined entirely in FullType, this |
263 | // convention can be deprecated. |
264 | // |
265 | // Note: While this code performs some basic verifications, it generally |
266 | // assumes consistent op defs and attributes. If more complete |
267 | // verifications are needed, they should be done by separately, and in a |
268 | // way that can be reused for type inference. |
269 | switch (t.type_id()) { |
270 | case TFT_VAR: |
271 | return SubstituteVar(attrs, t); |
272 | |
273 | case TFT_FOR_EACH: |
274 | return SubstituteForEach(attrs, t); |
275 | |
276 | default: |
277 | return SubstituteGeneric(attrs, t); |
278 | } |
279 | return OkStatus(); |
280 | } |
281 | |
282 | } // namespace |
283 | |
284 | Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def, |
285 | FullTypeDef& target) { |
286 | target.Clear(); |
287 | target.set_type_id(TFT_PRODUCT); |
288 | |
289 | AttrMap map; |
290 | for (const auto& attr : attrs) { |
291 | map.emplace(attr.first, &attr.second); |
292 | } |
293 | |
294 | int nargs = op_def.output_arg_size(); |
295 | for (int i = 0; i < nargs; i++) { |
296 | auto& t = *(target.add_args()); |
297 | t = op_def.output_arg(i).experimental_full_type(); |
298 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
299 | SubstituteFromAttrs(map, t), "while expanding vars of\n" , |
300 | t.DebugString(), "\nfrom\n" , attrs.SummarizeNode()); |
301 | } |
302 | |
303 | return OkStatus(); |
304 | } |
305 | |
306 | const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i) { |
307 | static FullTypeDef* unset_type = []() { |
308 | FullTypeDef* t = new FullTypeDef(); |
309 | return t; |
310 | }(); |
311 | |
312 | if (i < t.args_size()) { |
313 | return t.args(i); |
314 | } |
315 | return *unset_type; |
316 | } |
317 | |
318 | const FullTypeDef& GetArgDefaultAny(const FullTypeDef& t, int i) { |
319 | static FullTypeDef* any_type = []() { |
320 | FullTypeDef* t = new FullTypeDef(); |
321 | t->set_type_id(TFT_ANY); |
322 | return t; |
323 | }(); |
324 | |
325 | if (i < t.args_size()) { |
326 | const FullTypeDef& f_val = t.args(i); |
327 | if (f_val.type_id() == TFT_UNSET) { |
328 | return *any_type; |
329 | } |
330 | return f_val; |
331 | } |
332 | return *any_type; |
333 | } |
334 | |
335 | bool IsEqual(const FullTypeDef& lhs, const FullTypeDef& rhs) { |
336 | if (lhs.type_id() != rhs.type_id()) { |
337 | return false; |
338 | } |
339 | const auto& lhs_s = lhs.s(); |
340 | const auto& rhs_s = rhs.s(); |
341 | if (lhs_s.empty()) { |
342 | if (!rhs_s.empty()) { |
343 | return false; |
344 | } |
345 | } else if (rhs_s != lhs_s) { |
346 | return false; |
347 | } |
348 | for (int i = 0; i < std::max(lhs.args_size(), rhs.args_size()); i++) { |
349 | const FullTypeDef& lhs_arg = GetArgDefaultAny(lhs, i); |
350 | const FullTypeDef& rhs_arg = GetArgDefaultAny(rhs, i); |
351 | |
352 | if (!IsEqual(lhs_arg, rhs_arg)) { |
353 | return false; |
354 | } |
355 | } |
356 | return true; |
357 | } |
358 | |
359 | uint64_t Hash(const FullTypeDef& arg) { |
360 | // Following style of IsEqual above and walking across FullTypeDef. |
361 | uint64_t val = Hash64Combine(arg.type_id(), 0); |
362 | |
363 | const auto& arg_s = arg.s(); |
364 | val = Hash64Combine(val, Hash64(arg_s)); |
365 | for (int i = 0, e = arg.args_size(); i < e; ++i) { |
366 | const FullTypeDef& arg_arg = GetArgDefaultAny(arg, i); |
367 | val = Hash64Combine(val, Hash(arg_arg)); |
368 | } |
369 | |
370 | return val; |
371 | } |
372 | |
373 | bool IsSubtype(const FullTypeDef& lhs, const FullTypeDef& rhs, bool covariant) { |
374 | // Rule: ANY is a supertype of all types. |
375 | if (rhs.type_id() == TFT_ANY) { |
376 | return true; |
377 | } |
378 | // Compatibility rule: UNSET is treated as ANY for the purpose of subtyping. |
379 | if (rhs.type_id() == TFT_UNSET) { |
380 | return true; |
381 | } |
382 | // Compatibility rule: TENSOR[LEGACY_VARIANT] is treated as ANY for the |
383 | // purpose of subtyping. |
384 | if ((rhs.type_id() == TFT_TENSOR) && |
385 | (GetArgDefaultUnset(rhs, 0).type_id() == TFT_LEGACY_VARIANT)) { |
386 | return true; |
387 | } |
388 | // Rule: encodings are subtypes of the encoding type. |
389 | if (lhs.type_id() == TFT_ENCODED) { |
390 | return IsSubtype(GetArgDefaultAny(lhs, 1), rhs, true); |
391 | } |
392 | |
393 | // Default rule: type IDs must match. |
394 | if (lhs.type_id() != rhs.type_id()) { |
395 | return false; |
396 | } |
397 | |
398 | // Arguments must be subtypes of one another. |
399 | for (int i = 0; i < std::max(lhs.args_size(), rhs.args_size()); i++) { |
400 | const FullTypeDef& lhs_arg = GetArgDefaultAny(lhs, i); |
401 | const FullTypeDef& rhs_arg = GetArgDefaultAny(rhs, i); |
402 | |
403 | if (covariant) { |
404 | if (!IsSubtype(lhs_arg, rhs_arg)) { |
405 | return false; |
406 | } |
407 | } else { |
408 | if (!IsSubtype(rhs_arg, lhs_arg)) { |
409 | return false; |
410 | } |
411 | } |
412 | } |
413 | |
414 | // Invariant: type IDs are equal, and all args are subtype of one another. |
415 | return true; |
416 | } |
417 | |
418 | } // namespace full_type |
419 | |
420 | } // namespace tensorflow |
421 | |