1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_INFERENCE_UTIL_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_INFERENCE_UTIL_H_ |
18 | |
19 | #include <functional> |
20 | #include <string> |
21 | |
22 | #include "tensorflow/core/framework/full_type.pb.h" |
23 | #include "tensorflow/core/framework/op_def_builder.h" |
24 | #include "tensorflow/core/platform/statusor.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | namespace full_type { |
29 | |
30 | // TODO(mdan): Specific helpers won't get too far. Use a parser instead. |
31 | |
32 | // Helpers that allow shorthand expression for the more common kinds of type |
33 | // inference functions. |
34 | // TODO(mdan): Break into separate header if it grows. |
35 | // Note: The information contained in these functions is also expressed to some |
36 | // extent by opdef attributes of the kind "input: T, output: T". But in that |
37 | // context, T has strong DType semantics (i.e. T is DT_VARIANT for most |
38 | // interesting cases). The logic here extends to the op's FullType, so it's best |
39 | // to keep them separate, even though it leads to some redundancy. The |
40 | // same applies to the shape inference function. |
41 | |
42 | // Note: Unlike type constructors, which describe op definitions, type inference |
43 | // functions are meant to modify the type information of specific nodes (i.e. |
44 | // NodeDef proto). |
45 | |
46 | // Helper for a no-op type inference function, indicating that type inference |
47 | // should never alter the node's existing type. |
48 | // This is equivalent to not defining a type inference function at all, but |
49 | // communicates that intent explicitly. |
50 | ForwardTypeInferenceFn KeepExisting(); |
51 | |
52 | // Helper for a type inference function that infers the same type as the i'th |
53 | // input. |
54 | // The n arg allows multiple outputs, e.g. (T -> Product[T, T]). |
55 | // TODO(mdan): Drop defaults for readability if more non-(0, 1) cases appear. |
56 | // TODO(mdan): Rename to just Replicate. |
57 | ForwardTypeInferenceFn ReplicateInput(int i = 0, int n = 1); |
58 | |
59 | // Helper for a type inference function which has the same type as a variadic |
60 | // number of inputs, e.g. (T, T -> Product[T]), (T, T, T -> Product[T]), etc. |
61 | // Infers the meet of the input types, in the lattice-theoretic sense (see |
62 | // https://en.wikipedia.org/wiki/Join_and_meet). This implementation is |
63 | // simplified to require that the input types be subtypes of one another. |
64 | ForwardTypeInferenceFn Merge(); |
65 | |
66 | // Helper for ops with semantics of encoding an input, that is, |
67 | // `T -> Encoded[T, <t>]`, where <t> is the encoded type. Inverse of Decode. |
68 | ForwardTypeInferenceFn Encode(FullTypeId t, int i); |
69 | |
70 | // Helper for ops with semantics of decoding an input, that is, |
71 | // `Encoded[T, <t>] -> T`, where <t> is the encoded type. Inverse of Encode. |
72 | ForwardTypeInferenceFn Decode(FullTypeId t, int i); |
73 | |
74 | // Helper for the type inference counterpart of Unary, that is |
75 | // `U -> PRODUCT[<t>[U]]`, where <t> is parameterized by this factory and U is |
76 | // the type of the input specified by element_idx. |
77 | // Note: when we migrate to a more formal type definition of an op, this |
78 | // function and Unary will naturally merge. |
79 | ForwardTypeInferenceFn UnaryContainerCreate(FullTypeId t, int element_idx); |
80 | |
81 | // Helper for ops with semantics of adding an element to a container (<t>[T]), |
82 | // that is `<t>[U], V -> PRODUCT[<t>[Union[U, V]]]`, where <t> is parameterized |
83 | // by this factory, U is the type of the input specified by container_idx, and V |
84 | // is the type of the input specified by element_idx. The homogeneous arg allows |
85 | // for constraints which guarantee that U and V must have a subtyping |
86 | // relationship, in which case either U or V is selected, whichever is the |
87 | // supertype. |
88 | ForwardTypeInferenceFn UnaryContainerAdd(FullTypeId t, int container_idx, |
89 | int element_idx, bool homogeneous); |
90 | |
91 | // Helper for ops with semantics of unstacking multiple inputs T1, ..., Tn into |
92 | // a container, that is `T1, ..., Tn -> <t>[PRODUCT[U1, ..., Un]]`, where each |
93 | // Ui is obtained by applying an "unstack" mapping T -> U to Ti. Both <t> and |
94 | // the "unstack" mapping are parameterized by this factory. |
95 | // When "unstack" is the identity function, the result is |
96 | // `<t>[PRODUCT[T1, ..., Tn]]`, i.e. this becomes equivalent to ContainerCreate. |
97 | ForwardTypeInferenceFn MultiaryUnstack( |
98 | FullTypeId t, std::function<FullTypeDef(const FullTypeDef&)> unstack); |
99 | |
100 | // Helper for ops with semantics of applying some transformation to the |
101 | // elements of a container, that is |
102 | // `<t>[PRODUCT[T1, ..., Tn]] -> <t>[PRODUCT[U1, ..., Un]]`, |
103 | // where each Ui is obtained by applying a "map" function T -> U to Ti. Both |
104 | // <t> and the "map" function are parameterized by this factory. See |
105 | // BatchTensor and ShardTensor for examples of "map". |
106 | ForwardTypeInferenceFn ContainerMap( |
107 | FullTypeId t, int input_idx, |
108 | std::function<FullTypeDef(const FullTypeDef&)> map); |
109 | |
110 | // Helper for ops with semantics of repacking some element from a container to |
111 | // another `<t> -> <u>`, in a covariant way, that is, `<t>[T] -> <u>[T]`. <t> |
112 | // and <u> are parameterized by this factory. The input type is specified by |
113 | // input_idx. |
114 | ForwardTypeInferenceFn MapCovariant(FullTypeId t, FullTypeId u, int input_idx); |
115 | |
116 | // Auxiliary constructs to help creation of type inference functions. |
117 | // TODO(mdan): define these as type inference functions as well. |
118 | |
119 | // Mapping function representing the type function for unstacking of |
120 | // Tensor (or Tensor-like) types. Note that this is a helper to use with |
121 | // other type inference functions; it is not a type inference function itself. |
122 | // TODO(mdan): Replace with a trait, when available. |
123 | FullTypeDef UnstackTensor(const FullTypeDef& t); |
124 | |
125 | // Mapping function representing the type function for an op that changes the |
126 | // batch size of a dataset. Note that this is a helper to use with other type |
127 | // inference functions; it is not a type inference function itself. |
128 | // TODO(mdan): Replace with a trait, when available. |
129 | FullTypeDef BatchTensor(const FullTypeDef& t); |
130 | |
131 | // Mapping function representing the type function for an op that creates a |
132 | // fixed (given) number of tensors of a size calculated based on the input. Note |
133 | // that this is a helper to use with other type inference functions; it is not |
134 | // a type inference function itself. |
135 | // TODO(mdan): Replace with a trait, when available. |
136 | FullTypeDef ShardTensor(const FullTypeDef& t); |
137 | } // namespace full_type |
138 | |
139 | } // namespace tensorflow |
140 | |
141 | #endif // TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_INFERENCE_UTIL_H_ |
142 | |