/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/full_type_util.h"

#include <algorithm>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {

namespace full_type {

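// Shorthand constructors for the static output type attached to an op
// definition. Except for NoOp and NoOutputs, each helper returns a callback
// that fills in the `experimental_full_type` field of the op's first output
// argument. A registration typically wires one in through
// OpDefBuilder::SetTypeConstructor, e.g. (sketch only; "MyListOp" and the
// attribute name "element_dtype" are placeholders):
//
//   REGISTER_OP("MyListOp")
//       ...
//       .SetTypeConstructor(full_type::Unary(TFT_ARRAY, "element_dtype"));

// A null constructor: the op declares no static output type information.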
OpTypeConstructor NoOp() {
  return nullptr;
}

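// A constructor for ops with no outputs; it attaches no type information.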
OpTypeConstructor NoOutputs() {
  return [](OpDef* op_def) {
    op_def->mutable_output_arg();
    return OkStatus();
  };
}

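// Attaches the non-parametric type `t` to output 0.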
OpTypeConstructor Nullary(FullTypeId t) {
  return [t](OpDef* op_def) {
    FullTypeDef* tdef =
        op_def->mutable_output_arg(0)->mutable_experimental_full_type();
    tdef->set_type_id(t);
    return OkStatus();
  };
}

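// Attaches `t[TFT_VAR[var_name]]` to output 0; the variable is resolved from
// the node's attributes at specialization time (see SpecializeType below).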
OpTypeConstructor Unary(FullTypeId t, const string& var_name) {
  return [t, var_name](OpDef* op_def) {
    FullTypeDef* tdef =
        op_def->mutable_output_arg(0)->mutable_experimental_full_type();
    tdef->set_type_id(t);

    FullTypeDef* arg = tdef->add_args();
    arg->set_type_id(TFT_VAR);
    arg->set_s(var_name);

    return OkStatus();
  };
}

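// Attaches `t[TFT_ANY]` to output 0.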
OpTypeConstructor UnaryGeneric(FullTypeId t) {
  return [t](OpDef* op_def) {
    FullTypeDef* tdef =
        op_def->mutable_output_arg(0)->mutable_experimental_full_type();
    tdef->set_type_id(t);

    FullTypeDef* arg = tdef->add_args();
    arg->set_type_id(TFT_ANY);

    return OkStatus();
  };
}

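// Attaches `t[TFT_TENSOR[dtype]]` to output 0.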
OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype) {
  return [t, dtype](OpDef* op_def) {
    FullTypeDef* tdef =
        op_def->mutable_output_arg(0)->mutable_experimental_full_type();
    tdef->set_type_id(t);

    FullTypeDef* arg = tdef->add_args();
    arg->set_type_id(TFT_TENSOR);
    FullTypeDef* targ = arg->add_args();
    targ->set_type_id(dtype);

    return OkStatus();
  };
}

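// Attaches `t[TFT_TENSOR[TFT_VAR[var_name]]]` to output 0.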
OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name) {
  return [t, var_name](OpDef* op_def) {
    FullTypeDef* tdef =
        op_def->mutable_output_arg(0)->mutable_experimental_full_type();
    tdef->set_type_id(t);

    FullTypeDef* targ = tdef->add_args();
    targ->set_type_id(TFT_TENSOR);
    FullTypeDef* varg = targ->add_args();
    varg->set_type_id(TFT_VAR);
    varg->set_s(var_name);

    return OkStatus();
  };
}

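// Attaches `t[TFT_FOR_EACH[TFT_PRODUCT, TFT_TENSOR[TFT_VAR[var_name]],
// TFT_VAR[var_name]]]` to output 0. At specialization time, the FOR_EACH
// expands into one TENSOR per element of the list attribute named `var_name`
// (see SubstituteForEach below).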
OpTypeConstructor VariadicTensorContainer(FullTypeId t,
                                          const string& var_name) {
  return [t, var_name](OpDef* op_def) {
    FullTypeDef* tdef =
        op_def->mutable_output_arg(0)->mutable_experimental_full_type();
    tdef->set_type_id(t);

    FullTypeDef* for_each = tdef->add_args();
    for_each->set_type_id(TFT_FOR_EACH);
    for_each->add_args()->set_type_id(TFT_PRODUCT);

    FullTypeDef* tpl = for_each->add_args();
    tpl->set_type_id(TFT_TENSOR);
    FullTypeDef* targ = tpl->add_args();
    targ->set_type_id(TFT_VAR);
    targ->set_s(var_name);

    FullTypeDef* tvar = for_each->add_args();
    tvar->set_type_id(TFT_VAR);
    tvar->set_s(var_name);

    return OkStatus();
  };
}

namespace {

typedef absl::flat_hash_map<StringPiece, const AttrValue*> AttrMap;

inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t);

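// Resolves a TFT_VAR node in place: looks up the attribute named by `t.s()`
// and rewrites `t` with the type derived from that attribute's dtype (via
// map_dtype_to_tensor). A single-element type list is accepted as if it were
// a plain type attribute.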
Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) {
  DCHECK_EQ(t.args_size(), 0);

  StringPiece var_name = t.s();
  if (!attrs.contains(var_name)) {
    return Status(
        error::INVALID_ARGUMENT,
        absl::StrCat("could not find an attribute for key '", var_name, "'"));
  }
  const AttrValue* attr = attrs.at(var_name);

  const auto attr_type = attr->value_case();
  if (attr_type == AttrValue::kType) {
    map_dtype_to_tensor(attr->type(), t);
  } else if (attr_type == AttrValue::kList) {
    const auto& attr_list = attr->list();
    if (attr_list.type_size() != 1) {
      return Status(error::UNIMPLEMENTED,
                    absl::StrCat("lists with other than one type element\n",
                                 attr_list.DebugString(), "\nkey=", var_name));
    }
    map_dtype_to_tensor(attr_list.type(0), t);
  } else {
    return Status(error::UNIMPLEMENTED,
                  absl::StrCat("unsupported attribute type ",
                               attr->DebugString(), " for name ", var_name));
  }
  t.clear_s();
  return OkStatus();
}

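// Expands a TFT_FOR_EACH node in place. Its three arguments are the container
// type, the template to instantiate, and a TFT_VAR naming the attribute to
// iterate over. For a list attribute, the result contains one instantiated
// template per list element; for a single type attribute, exactly one.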
Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) {
  if (t.args_size() != 3) {
    return Status(error::INVALID_ARGUMENT,
                  absl::StrCat("illegal FOR_EACH type, expected 3 args, got ",
                               t.args_size()));
  }

  const auto& cont = t.args(0);
  const auto& tmpl = t.args(1);
  const auto& t_var = t.args(2);

  StringPiece var_name = t_var.s();
  if (!attrs.contains(var_name)) {
    return Status(
        error::INVALID_ARGUMENT,
        absl::StrCat("could not find an attribute for key '", var_name, "'"));
  }
  const AttrValue* attr = attrs.at(var_name);

  FullTypeDef result;
  result.set_type_id(cont.type_id());

  const auto attr_type = attr->value_case();
  if (attr_type == AttrValue::kType) {
    FullTypeDef* target = result.add_args();
    *target = tmpl;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        SubstituteFromAttrs(attrs, *target), "while substituting '", var_name,
        "' from\n", attr->DebugString(), "\ninto ", target->DebugString());

  } else if (attr_type == AttrValue::kList) {
    const auto& attr_list = attr->list();
    int tsize = attr_list.type_size();
    if (tsize == 0) {
      return Status(error::UNIMPLEMENTED,
                    absl::StrCat("unsupported list attribute type\n",
                                 attr_list.DebugString(), "\nkey=", var_name));
    }
    AttrValue replacement;
    attrs[var_name] = &replacement;
    for (int i = 0; i < tsize; i++) {
      replacement.set_type(attr_list.type(i));
      FullTypeDef* target = result.add_args();
      *target = tmpl;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(SubstituteFromAttrs(attrs, *target),
                                      "while substituting '", var_name,
                                      "' from\n", attr->DebugString(), "\n[", i,
                                      "] into\n", target->DebugString());
    }
    // In case of error, it's ok for the attributes map to remain in an invalid
    // state.
    attrs[var_name] = attr;

  } else {
    return Status(error::UNIMPLEMENTED,
                  absl::StrCat("unsupported attribute type\n",
                               attr->DebugString(), "\nfor name ", var_name));
  }
  t = result;
  return OkStatus();
}

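// Recursively substitutes the type arguments of any other parametric type. If
// an argument resolves to TENSOR[LEGACY_VARIANT], all arguments are dropped,
// leaving the parent type unparameterized (see the comment in the loop).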
Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) {
  int nargs = t.args_size();
  for (int j = 0; j < nargs; j++) {
    FullTypeDef* arg_t = t.mutable_args(j);
    TF_RETURN_WITH_CONTEXT_IF_ERROR(SubstituteFromAttrs(attrs, *arg_t),
                                    "while substituting arg ", j, ": ",
                                    arg_t->DebugString());

    // Special case for DT_VARIANT tensors. We leave those unset to avoid even
    // more special casing downstream.
    if (arg_t->type_id() == TFT_TENSOR && arg_t->args_size() &&
        arg_t->args(0).type_id() == TFT_LEGACY_VARIANT) {
      t.clear_args();
      break;
    }
  }
  return OkStatus();
}

inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) {
  // Resolve dependent types. The convention for op registrations is to use
  // attributes as type variables.
  // See https://www.tensorflow.org/guide/create_op#type_polymorphism.
  // Once the op signature can be defined entirely in FullType, this
  // convention can be deprecated.
  //
  // Note: While this code performs some basic verifications, it generally
  // assumes consistent op defs and attributes. If more complete verifications
  // are needed, they should be done separately, and in a way that can be
  // reused for type inference.
  switch (t.type_id()) {
    case TFT_VAR:
      return SubstituteVar(attrs, t);

    case TFT_FOR_EACH:
      return SubstituteForEach(attrs, t);

    default:
      return SubstituteGeneric(attrs, t);
  }
  return OkStatus();
}

}  // namespace

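// Materializes the concrete output types of a node: copies the (possibly
// generic) full type of each output arg of `op_def` into a TFT_PRODUCT and
// substitutes any type variables using the node's attributes.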
Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def,
                      FullTypeDef& target) {
  target.Clear();
  target.set_type_id(TFT_PRODUCT);

  AttrMap map;
  for (const auto& attr : attrs) {
    map.emplace(attr.first, &attr.second);
  }

  int nargs = op_def.output_arg_size();
  for (int i = 0; i < nargs; i++) {
    auto& t = *(target.add_args());
    t = op_def.output_arg(i).experimental_full_type();
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        SubstituteFromAttrs(map, t), "while expanding vars of\n",
        t.DebugString(), "\nfrom\n", attrs.SummarizeNode());
  }

  return OkStatus();
}

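// Returns the i-th type argument of `t`, or a statically allocated,
// default-constructed (unset) type if `t` has fewer than i+1 arguments.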
const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i) {
  static FullTypeDef* unset_type = []() {
    FullTypeDef* t = new FullTypeDef();
    return t;
  }();

  if (i < t.args_size()) {
    return t.args(i);
  }
  return *unset_type;
}

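// Like GetArgDefaultUnset, but missing or unset arguments are reported as
// TFT_ANY.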
const FullTypeDef& GetArgDefaultAny(const FullTypeDef& t, int i) {
  static FullTypeDef* any_type = []() {
    FullTypeDef* t = new FullTypeDef();
    t->set_type_id(TFT_ANY);
    return t;
  }();

  if (i < t.args_size()) {
    const FullTypeDef& f_val = t.args(i);
    if (f_val.type_id() == TFT_UNSET) {
      return *any_type;
    }
    return f_val;
  }
  return *any_type;
}

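// Deep structural equality; missing or unset type arguments compare as
// TFT_ANY.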
bool IsEqual(const FullTypeDef& lhs, const FullTypeDef& rhs) {
  if (lhs.type_id() != rhs.type_id()) {
    return false;
  }
  const auto& lhs_s = lhs.s();
  const auto& rhs_s = rhs.s();
  if (lhs_s.empty()) {
    if (!rhs_s.empty()) {
      return false;
    }
  } else if (rhs_s != lhs_s) {
    return false;
  }
  for (int i = 0; i < std::max(lhs.args_size(), rhs.args_size()); i++) {
    const FullTypeDef& lhs_arg = GetArgDefaultAny(lhs, i);
    const FullTypeDef& rhs_arg = GetArgDefaultAny(rhs, i);

    if (!IsEqual(lhs_arg, rhs_arg)) {
      return false;
    }
  }
  return true;
}

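// Deep hash of a FullTypeDef, combining the type id, the string payload and
// the hashes of all type arguments.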
uint64_t Hash(const FullTypeDef& arg) {
  // Following style of IsEqual above and walking across FullTypeDef.
  uint64_t val = Hash64Combine(arg.type_id(), 0);

  const auto& arg_s = arg.s();
  val = Hash64Combine(val, Hash64(arg_s));
  for (int i = 0, e = arg.args_size(); i < e; ++i) {
    const FullTypeDef& arg_arg = GetArgDefaultAny(arg, i);
    val = Hash64Combine(val, Hash(arg_arg));
  }

  return val;
}

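// Structural subtyping check. Type arguments are compared covariantly when
// `covariant` is true and contravariantly otherwise; ANY, UNSET and
// TENSOR[LEGACY_VARIANT] on the right-hand side accept any left-hand side.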
bool IsSubtype(const FullTypeDef& lhs, const FullTypeDef& rhs, bool covariant) {
  // Rule: ANY is a supertype of all types.
  if (rhs.type_id() == TFT_ANY) {
    return true;
  }
  // Compatibility rule: UNSET is treated as ANY for the purpose of subtyping.
  if (rhs.type_id() == TFT_UNSET) {
    return true;
  }
  // Compatibility rule: TENSOR[LEGACY_VARIANT] is treated as ANY for the
  // purpose of subtyping.
  if ((rhs.type_id() == TFT_TENSOR) &&
      (GetArgDefaultUnset(rhs, 0).type_id() == TFT_LEGACY_VARIANT)) {
    return true;
  }
  // Rule: encodings are subtypes of the encoding type.
  if (lhs.type_id() == TFT_ENCODED) {
    return IsSubtype(GetArgDefaultAny(lhs, 1), rhs, true);
  }

  // Default rule: type IDs must match.
  if (lhs.type_id() != rhs.type_id()) {
    return false;
  }

  // Arguments must be subtypes of one another.
  for (int i = 0; i < std::max(lhs.args_size(), rhs.args_size()); i++) {
    const FullTypeDef& lhs_arg = GetArgDefaultAny(lhs, i);
    const FullTypeDef& rhs_arg = GetArgDefaultAny(rhs, i);

    if (covariant) {
      if (!IsSubtype(lhs_arg, rhs_arg)) {
        return false;
      }
    } else {
      if (!IsSubtype(rhs_arg, lhs_arg)) {
        return false;
      }
    }
  }

  // Invariant: type IDs are equal, and all args are subtypes of one another.
  return true;
}

}  // namespace full_type

}  // namespace tensorflow