1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_UTIL_H_
17#define TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_UTIL_H_
18
19#include <functional>
20#include <string>
21
22#include "tensorflow/core/framework/full_type.pb.h"
23#include "tensorflow/core/framework/node_def.pb.h"
24#include "tensorflow/core/framework/node_def_util.h"
25#include "tensorflow/core/framework/op_def.pb.h"
26#include "tensorflow/core/framework/op_def_builder.h"
27#include "tensorflow/core/platform/statusor.h"
28
29namespace tensorflow {
30
31namespace full_type {
32
33// TODO(mdan): Specific helpers won't get too far. Use a parser instead.
34// TODO(mdan): Move constructors into a separate file.
35
36// Helpers that allow shorthand expression for the more common kinds of type
37// constructors.
38// Note: The arity below refers to the number of arguments of parametric types,
39// not to the number of return values from a particular op.
40// Note: Type constructors are meant to create static type definitions in the
41// op definition (i.e. the OpDef proto).
42
43// Helper for a no-op type constructor that indicates that the node's type
44// should be set by external means (typically by the user).
45OpTypeConstructor NoOp();
46
47// Helper for a trivial type constructor that indicates a node has no
48// outputs (that is, its output type is an empty TFT_PRODUCT).
49OpTypeConstructor NoOutputs();
50
51// Helper for a type constructor of <t>[] (with no parameters).
52OpTypeConstructor Nullary(FullTypeId t);
53
54// Helper for a type constructor of <t>[FT_VAR[<var_name>]].
55OpTypeConstructor Unary(FullTypeId t, const string& var_name);
56
57// Helper for a type constructor of <t>[FT_ANY].
58OpTypeConstructor UnaryGeneric(FullTypeId t);
59
60// Helper for a type constructor of <t>[FT_TENSOR[<dtype>]].
61OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype);
62
63// Helper for a type constructor of <t>[FT_VAR[<var_name>]].
64OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name);
65
66// Helper for a type constructor of
67// <t>[FT_FOR_EACH[
68// FT_PRODUCT,
69// FT_TENSOR[FT_VAR[<var_name>]],
70// FT_VAR[<var_name>]].
71// Multi-valued type variables will expand the template (see full_type.proto).
72OpTypeConstructor VariadicTensorContainer(FullTypeId t, const string& var_name);
73
74// Type specialization and inference logic. This function narrows the type
75// specified in an op definition. Such types are usually generic and dependent
76// on input types. This function resolves the output types based on the input
77// types specified in a given node def.
78Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def,
79 FullTypeDef& target);
80
81const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i);
82const FullTypeDef& GetArgDefaultAny(const FullTypeDef& t, int i);
83
84bool IsEqual(const FullTypeDef& lhs, const FullTypeDef& rhs);
85
86bool IsSubtype(const FullTypeDef& lhs, const FullTypeDef& rhs,
87 bool covariant = true);
88
89uint64_t Hash(const FullTypeDef& arg);
90
// Determines whether the given full type is, or (through its type arguments)
// contains, a host memory type.
// While it is preferred that Placer (placer.cc and colocation_graph.cc) make
// all host memory type placement decisions, any decision made elsewhere
// should use this function (e.g. instead of assuming that all variants never
// contain host memory types).
96inline bool IsHostMemoryType(const FullTypeDef& t) {
97 switch (t.type_id()) {
98 case TFT_TENSOR:
99 return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
100 case TFT_ARRAY:
101 return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
102 case TFT_DATASET:
103 return true;
104 case TFT_MUTEX_LOCK:
105 return true;
106 case TFT_RAGGED:
107 return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
108 case TFT_STRING:
109 return true;
110 case TFT_ITERATOR:
111 return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
112 case TFT_OPTIONAL:
113 return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0));
114 case TFT_PRODUCT:
115 for (int i = 0; i < t.args_size(); i++) {
116 if (IsHostMemoryType(full_type::GetArgDefaultAny(t, i))) {
117 return true;
118 }
119 }
120 return false;
121 default:
122 return false;
123 }
124}
125
126} // namespace full_type
127
128} // namespace tensorflow
129
130#endif // TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_UTIL_H_
131