/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_INFERENCE_UTIL_H_
#define TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_INFERENCE_UTIL_H_

#include <functional>
#include <string>

#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {

namespace full_type {

// TODO(mdan): Specific helpers won't get too far. Use a parser instead.

// Helpers that allow shorthand expression for the more common kinds of type
// inference functions.
// TODO(mdan): Break into separate header if it grows.
// Note: The information contained in these functions is also expressed to some
// extent by opdef attributes of the kind "input: T, output: T". But in that
// context, T has strong DType semantics (i.e. T is DT_VARIANT for most
// interesting cases). The logic here extends to the op's FullType, so it's best
// to keep them separate, even though it leads to some redundancy. The same can
// be said about the shape inference function.

// Note: Unlike type constructors, which describe op definitions, type inference
// functions are meant to modify the type information of specific nodes (i.e.
// NodeDef protos).

// Helper for a no-op type inference function that indicates type inference
// should never alter the node's existing type.
// This is the same as not defining a type inference function at all, but
// explicitly communicates that intent.
ForwardTypeInferenceFn KeepExisting();
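// Example usage (sketch; "MyOpaqueOp" is a hypothetical op, and it is assumed
// that the op registration wrapper exposes OpDefBuilder::SetForwardTypeFn):
//
//   REGISTER_OP("MyOpaqueOp")
//       .Input("x: variant")
//       .Output("y: variant")
//       .SetForwardTypeFn(full_type::KeepExisting());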

// Helper for a type inference function whose output has the same type as the
// i'th input.
// The n arg allows multiple outputs, e.g. (T -> Product[T, T]).
// TODO(mdan): Drop defaults for readability if more non-(0, 1) cases appear.
// TODO(mdan): Rename to just Replicate.
ForwardTypeInferenceFn ReplicateInput(int i = 0, int n = 1);
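// Example usage (sketch; "MyTwoCopies" is a hypothetical op):
//
//   // T -> Product[T, T]: both outputs take the full type of input 0.
//   REGISTER_OP("MyTwoCopies")
//       .Input("x: T")
//       .Output("y: T")
//       .Output("z: T")
//       .Attr("T: type")
//       .SetForwardTypeFn(full_type::ReplicateInput(0, 2));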

// Helper for a type inference function whose output has the same type as a
// variadic number of inputs, e.g. (T, T -> Product[T]),
// (T, T, T -> Product[T]), etc.
// Infers the meet of the input types, in the sense of type meets (see
// https://en.wikipedia.org/wiki/Join_and_meet). This implementation is
// simplified to require that one input type be a subtype of the other.
ForwardTypeInferenceFn Merge();
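// Example usage (sketch; "MySelect" is a hypothetical op whose output carries
// the merged type of its two inputs):
//
//   REGISTER_OP("MySelect")
//       .Input("a: T")
//       .Input("b: T")
//       .Output("out: T")
//       .Attr("T: type")
//       .SetForwardTypeFn(full_type::Merge());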

// Helper for ops with semantics of encoding an input, that is,
// `T -> Encoded[T, <t>]`, where <t> is the encoded type.
ForwardTypeInferenceFn Encode(FullTypeId t, int i);
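// Example usage (sketch; "MySerialize" is a hypothetical op and the choice of
// TFT_STRING as the encoding tag is purely illustrative):
//
//   // T -> Encoded[T, TFT_STRING], with T read from input 0.
//   REGISTER_OP("MySerialize")
//       .Input("x: T")
//       .Output("encoded: variant")
//       .Attr("T: type")
//       .SetForwardTypeFn(full_type::Encode(TFT_STRING, 0));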

// Helper for ops with semantics of decoding an input, that is,
// `Encoded[T, <t>] -> T`, where <t> is the encoded type.
ForwardTypeInferenceFn Decode(FullTypeId t, int i);
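// Example usage (sketch; mirrors the Encode example above, with the same
// illustrative assumptions):
//
//   // Encoded[T, TFT_STRING] -> T, with the encoded type read from input 0.
//   REGISTER_OP("MyDeserialize")
//       .Input("encoded: variant")
//       .Output("x: T")
//       .Attr("T: type")
//       .SetForwardTypeFn(full_type::Decode(TFT_STRING, 0));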

// Helper for the type inference counterpart of Unary, that is (U ->
// PRODUCT[<t>[U]]), where <t> is parameterized by this factory, and U is the
// type of the input specified by element_idx.
// Note: when we migrate to a more formal type definition of an op, this
// function and its type constructor counterpart will naturally merge.
ForwardTypeInferenceFn UnaryContainerCreate(FullTypeId t, int element_idx);
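// Example usage (sketch; "MyListFromTensor" is a hypothetical op):
//
//   // U -> PRODUCT[TFT_ARRAY[U]], where U is the type of input 0.
//   REGISTER_OP("MyListFromTensor")
//       .Input("element: element_dtype")
//       .Output("list: variant")
//       .Attr("element_dtype: type")
//       .SetForwardTypeFn(full_type::UnaryContainerCreate(TFT_ARRAY, 0));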

// Helper for ops with semantics of adding an element to a container (<t>[T]),
// that is (<t>[U], V -> PRODUCT[<t>[Union[U, V]]]), where <t> is parameterized
// by this factory, U is the type of the input specified by container_idx, and V
// is the type of the input specified by element_idx. The homogeneous arg allows
// for constraints which guarantee that U and V have a subtyping relationship,
// in which case whichever of U and V is the supertype is selected.
ForwardTypeInferenceFn UnaryContainerAdd(FullTypeId t, int container_idx,
                                         int element_idx, bool homogeneous);
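// Example usage (sketch; "MyListPushBack" is a hypothetical op):
//
//   // (TFT_ARRAY[U], V) -> PRODUCT[TFT_ARRAY[Union[U, V]]], with the
//   // container at input 0 and the new element at input 1.
//   REGISTER_OP("MyListPushBack")
//       .Input("list: variant")
//       .Input("element: element_dtype")
//       .Output("out_list: variant")
//       .Attr("element_dtype: type")
//       .SetForwardTypeFn(full_type::UnaryContainerAdd(
//           TFT_ARRAY, /*container_idx=*/0, /*element_idx=*/1,
//           /*homogeneous=*/false));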

// Helper for ops with semantics of unstacking multiple inputs into a container
// `<t>[T1, ..., Tn]`, that is `T1, ..., Tn -> <t>[PRODUCT[U1, ..., Un]]`,
// where Ui is obtained from an "unstack" mapping T -> U. Both <t> and the
// "unstack" mapping are parameterized by this factory.
// Note that when the "unstack" function is the identity function, this becomes
// equivalent to ContainerCreate.
ForwardTypeInferenceFn MultiaryUnstack(
    FullTypeId t, std::function<FullTypeDef(const FullTypeDef&)> unstack);
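// Example usage (sketch; "MySliceDataset" is a hypothetical op; UnstackTensor,
// declared below, is used as the "unstack" mapping):
//
//   // T1, ..., Tn -> TFT_DATASET[PRODUCT[U1, ..., Un]].
//   REGISTER_OP("MySliceDataset")
//       .Input("components: Toutput_types")
//       .Output("handle: variant")
//       .Attr("Toutput_types: list(type) >= 1")
//       .SetForwardTypeFn(full_type::MultiaryUnstack(
//           TFT_DATASET, full_type::UnstackTensor));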

// Helper for ops with semantics of applying some transformation to the
// elements of a container:
// `<t>[PRODUCT[T1, ..., Tn]] -> <t>[PRODUCT[U1, ..., Un]]`,
// where Ui is obtained by applying a map T -> U. Both <t> and the "map"
// function are parameterized by this factory. See BatchTensor and ShardTensor
// for examples of "map".
ForwardTypeInferenceFn ContainerMap(
    FullTypeId t, int input_idx,
    std::function<FullTypeDef(const FullTypeDef&)> map);
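// Example usage (sketch; "MyBatchDataset" is a hypothetical op; BatchTensor,
// declared below, is used as the "map" function):
//
//   // Rewrites each element type of the dataset given by input 0.
//   REGISTER_OP("MyBatchDataset")
//       .Input("input_dataset: variant")
//       .Input("batch_size: int64")
//       .Output("handle: variant")
//       .SetForwardTypeFn(full_type::ContainerMap(
//           TFT_DATASET, /*input_idx=*/0, full_type::BatchTensor));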

// Helper for ops with semantics of repacking the elements of one container
// type into another, `<t> -> <u>`, in a covariant way, that is,
// `<t>[T] -> <u>[T]`. <t> and <u> are parameterized by this factory. The input
// type is specified by input_idx.
ForwardTypeInferenceFn MapCovariant(FullTypeId t, FullTypeId u, int input_idx);
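// Example usage (sketch; "MyMakeIterator" is a hypothetical op and the use of
// TFT_DATASET/TFT_ITERATOR as the container ids is purely illustrative):
//
//   // TFT_DATASET[T] -> TFT_ITERATOR[T], with T read from input 0.
//   REGISTER_OP("MyMakeIterator")
//       .Input("dataset: variant")
//       .Output("iterator: resource")
//       .SetForwardTypeFn(
//           full_type::MapCovariant(TFT_DATASET, TFT_ITERATOR, 0));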

// Auxiliary constructs to help create type inference functions.
// TODO(mdan): Define these as type inference functions as well.

// Mapping function representing the type function for unstacking of
// Tensor (or Tensor-like) types. Note that this is a helper to use with
// other type inference functions; it's not a type inference function itself.
// TODO(mdan): Replace with a trait, when available.
FullTypeDef UnstackTensor(const FullTypeDef& t);

// Mapping function representing the type function for an op that changes the
// batch size of a dataset. Note that this is a helper to use with other type
// inference functions; it's not a type inference function itself.
// TODO(mdan): Replace with a trait, when available.
FullTypeDef BatchTensor(const FullTypeDef& t);

// Mapping function representing the type function for an op that creates a
// fixed (given) number of tensors of a size calculated based on the input. Note
// that this is a helper to use with other type inference functions; it's not a
// type inference function itself.
// TODO(mdan): Replace with a trait, when available.
FullTypeDef ShardTensor(const FullTypeDef& t);

}  // namespace full_type

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_INFERENCE_UTIL_H_