1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/full_type_inference_util.h"
17
18#include <functional>
19#include <string>
20
21#include "absl/strings/str_cat.h"
22#include "tensorflow/core/framework/full_type.pb.h"
23#include "tensorflow/core/framework/full_type_util.h"
24#include "tensorflow/core/framework/op_def_builder.h"
25#include "tensorflow/core/platform/status.h"
26#include "tensorflow/core/platform/statusor.h"
27#include "tensorflow/core/protobuf/error_codes.pb.h"
28
29namespace tensorflow {
30
31namespace full_type {
32
33// Note about error handling:
34// For inputs which depend on the correctness of the op definition
35// (i.e. if the op has three inputs, don't set an `i` that exceeds that),
36// use DCHECK - an incorrect op def is considered a bug.
37// Whereas for inputs that depend on the correctness of the graph (i.e. user
38// used the correct ops), use Status - an incorrect graph is considered a user
39// error.
40
41ForwardTypeInferenceFn KeepExisting() { return nullptr; }
42
43ForwardTypeInferenceFn ReplicateInput(int i, int n) {
44 return [i, n](const TypeRefVector& input_types, const TypeRefMap& type_vars) {
45 const FullTypeDef& in_type = input_types.at(i).get();
46 FullTypeDef ret_type;
47 if (in_type.type_id() != TFT_UNSET) {
48 ret_type.set_type_id(TFT_PRODUCT);
49 for (int k = 0; k < n; k++) {
50 *(ret_type.add_args()) = in_type;
51 }
52 }
53 return ret_type;
54 };
55}
56
57ForwardTypeInferenceFn Merge() {
58 return [](const TypeRefVector& input_types,
59 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
60 DCHECK(!input_types.empty());
61
62 FullTypeDef merged;
63 for (int i = 0; i < input_types.size(); i++) {
64 const auto& t = input_types[i].get();
65
66 if (t.type_id() == TFT_UNSET) {
67 continue;
68 }
69
70 if (IsSubtype(t, merged)) {
71 merged = t;
72 continue;
73 }
74 if (IsSubtype(merged, t)) {
75 continue;
76 }
77
78 return Status(error::INVALID_ARGUMENT,
79 absl::StrCat("expected compatible input types, but input ",
80 i, ":\n", t.DebugString(),
81 " is neither a subtype nor a supertype of the "
82 "combined inputs preceding it:\n",
83 merged.DebugString()));
84 }
85
86 FullTypeDef ret_type;
87 if (merged.type_id() != TFT_UNSET) {
88 ret_type.set_type_id(TFT_PRODUCT);
89 *(ret_type.add_args()) = merged;
90 }
91 return ret_type;
92 };
93}
94
95ForwardTypeInferenceFn Encode(FullTypeId t, int i) {
96 return [t, i](const TypeRefVector& input_types,
97 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
98 DCHECK(input_types.size() >= i);
99
100 FullTypeDef ret_type;
101 const FullTypeDef& in_t = input_types[i].get();
102 if (in_t.type_id() == TFT_UNSET) {
103 return ret_type;
104 }
105
106 ret_type.set_type_id(TFT_PRODUCT);
107
108 auto* enc_type = ret_type.add_args();
109 enc_type->set_type_id(TFT_ENCODED);
110 *enc_type->add_args() = in_t;
111 enc_type->add_args()->set_type_id(t);
112 return ret_type;
113 };
114}
115
116ForwardTypeInferenceFn Decode(FullTypeId t, int i) {
117 return [t, i](const TypeRefVector& input_types,
118 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
119 DCHECK(input_types.size() >= i);
120
121 const FullTypeDef& in_t = input_types[i].get();
122
123 const FullTypeId enc_tid = GetArgDefaultUnset(in_t, 1).type_id();
124 if ((enc_tid != TFT_UNSET) && (enc_tid != t)) {
125 return Status(error::INVALID_ARGUMENT,
126 absl::StrCat("expected encoded type ", t, " for input ", i,
127 ", got ", in_t.DebugString()));
128 }
129
130 FullTypeDef ret_type;
131
132 const FullTypeDef& out_t = GetArgDefaultUnset(in_t, 0);
133 if (in_t.type_id() == TFT_UNSET) {
134 return ret_type;
135 }
136
137 ret_type.set_type_id(TFT_PRODUCT);
138 *ret_type.add_args() = out_t;
139 return ret_type;
140 };
141}
142
143ForwardTypeInferenceFn UnaryContainerCreate(FullTypeId t, int element_idx) {
144 return
145 [t, element_idx](const TypeRefVector& input_types,
146 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
147 DCHECK(input_types.size() >= element_idx);
148
149 FullTypeDef ret_type;
150 ret_type.set_type_id(TFT_PRODUCT);
151 FullTypeDef* arg_t = ret_type.add_args();
152 arg_t->set_type_id(t);
153 *(arg_t->add_args()) = input_types[element_idx].get();
154
155 return ret_type;
156 };
157}
158
159ForwardTypeInferenceFn UnaryContainerAdd(FullTypeId t, int container_idx,
160 int element_idx, bool homogeneous) {
161 return [t, container_idx, element_idx, homogeneous](
162 const TypeRefVector& input_types,
163 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
164 DCHECK(input_types.size() >= container_idx);
165 DCHECK(input_types.size() >= element_idx);
166
167 FullTypeDef ret_type;
168 ret_type.set_type_id(TFT_PRODUCT);
169 FullTypeDef* cont_t = ret_type.add_args();
170 cont_t->set_type_id(t);
171
172 const FullTypeDef& in_cont_t = input_types[container_idx].get();
173 const FullTypeDef& in_el_t = input_types[element_idx].get();
174
175 if (in_cont_t.type_id() != TFT_UNSET) {
176 if (in_cont_t.type_id() != t) {
177 return Status(
178 error::INVALID_ARGUMENT,
179 absl::StrCat("expected container type ", t, " for input ",
180 container_idx, ", got ", in_cont_t.DebugString()));
181 }
182 *cont_t = in_cont_t;
183 }
184
185 VLOG(1) << "ContainerAddUnary: " << cont_t->DebugString() << ", "
186 << in_el_t.DebugString() << ", " << container_idx << "; "
187 << element_idx;
188 for (const auto& tmp : input_types) {
189 VLOG(1) << " input: " << tmp.get().DebugString();
190 }
191
192 if (in_el_t.type_id() == TFT_UNSET) {
193 return ret_type;
194 }
195
196 const FullTypeDef& el_t = GetArgDefaultUnset(*cont_t, 0);
197
198 if (el_t.type_id() == TFT_UNSET) {
199 cont_t->clear_args();
200 *(cont_t->add_args()) = in_el_t;
201 return ret_type;
202 }
203
204 if (IsSubtype(in_el_t, el_t)) {
205 // Nothing to do, will not refine the container type based on a single
206 // addition.
207 return ret_type;
208 }
209
210 if (homogeneous) {
211 return Status(error::INVALID_ARGUMENT,
212 absl::StrCat("expected a subtype of ", el_t.DebugString(),
213 " for input ", element_idx,
214 " of a homogeneous container ", t, ", got ",
215 in_el_t.DebugString()));
216 } else {
217 // TODO(mdan): Implement if needed.
218 return Status(
219 error::UNIMPLEMENTED,
220 absl::StrCat("need union types for heterogeneous containers.\n"
221 "A homogeneous container would expect a subtype of ",
222 el_t.DebugString(), " for input ", element_idx,
223 ", but got ", in_el_t.DebugString()));
224 }
225 };
226}
227
228ForwardTypeInferenceFn MultiaryUnstack(
229 FullTypeId t, std::function<FullTypeDef(const FullTypeDef&)> unstack) {
230 return [t, unstack](const TypeRefVector& input_types,
231 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
232 FullTypeDef ret_type;
233 ret_type.set_type_id(TFT_PRODUCT);
234 FullTypeDef* cont_t = ret_type.add_args();
235 cont_t->set_type_id(t);
236 FullTypeDef* el_t = cont_t->add_args();
237 el_t->set_type_id(TFT_PRODUCT);
238 for (int element_idx = 0; element_idx < input_types.size(); ++element_idx) {
239 *(el_t->add_args()) = unstack(input_types[element_idx].get());
240 }
241 return ret_type;
242 };
243}
244
245FullTypeDef UnstackTensor(const FullTypeDef& t) {
246 // For now, only TFT_TENSOR and TFT_RAGGED are supported and
247 // only if they have a single argument (i.e. they don't specify a shape).
248 // If these have a shape in the future, this function needs to changed
249 // so that the output shape is computed based on the input shape and the
250 // effect of the unstack operation (e.g. a dimension is removed).
251 // TFT_UNSET is also allowed to support weak type inference where
252 // not having a fulltype is allowed.
253 DCHECK((t.type_id() == TFT_TENSOR) || (t.type_id() == TFT_RAGGED) ||
254 (t.type_id() == TFT_UNSET));
255 DCHECK_LE(t.args_size(), 1);
256 return t;
257}
258
259ForwardTypeInferenceFn ContainerMap(
260 FullTypeId t, int input_idx,
261 std::function<FullTypeDef(const FullTypeDef&)> map) {
262 return [t, input_idx, map](
263 const TypeRefVector& input_types,
264 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
265 DCHECK_GE(input_types.size(), input_idx);
266 const FullTypeDef& in_cont_t = input_types.at(input_idx).get();
267 FullTypeDef ret_type;
268 if (in_cont_t.type_id() == TFT_UNSET) {
269 return ret_type;
270 }
271 if (in_cont_t.type_id() != t) {
272 return Status(error::INVALID_ARGUMENT,
273 absl::StrCat("expected type ", t, " for input ", input_idx,
274 ", got ", in_cont_t.DebugString()));
275 }
276 ret_type.set_type_id(TFT_PRODUCT);
277 FullTypeDef* out_cont_t = ret_type.add_args();
278 out_cont_t->set_type_id(t);
279 const FullTypeDef& in_el_t = GetArgDefaultUnset(in_cont_t, 0);
280 if (in_el_t.type_id() == TFT_UNSET) {
281 return ret_type;
282 }
283 if (in_el_t.type_id() != TFT_PRODUCT) {
284 return Status(error::INVALID_ARGUMENT,
285 absl::StrCat("expected PRODUCT element type for input ",
286 input_idx, ", got ", in_el_t.DebugString()));
287 }
288 FullTypeDef* out_el_t = out_cont_t->add_args();
289 out_el_t->set_type_id(TFT_PRODUCT);
290 for (int k = 0; k < in_el_t.args_size(); k++) {
291 *(out_el_t->add_args()) = map(in_el_t.args(k));
292 }
293 return ret_type;
294 };
295}
296
297ForwardTypeInferenceFn MapCovariant(FullTypeId t, FullTypeId u, int input_idx) {
298 return
299 [t, u, input_idx](const TypeRefVector& input_types,
300 const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> {
301 DCHECK_GE(input_types.size(), input_idx);
302 const FullTypeDef& in_t = input_types.at(input_idx).get();
303 FullTypeDef ret_type;
304 if (in_t.type_id() == TFT_UNSET) {
305 return ret_type;
306 }
307 if (in_t.type_id() != t) {
308 return Status(error::INVALID_ARGUMENT,
309 absl::StrCat("expected type ", t, " for input ",
310 input_idx, ", got ", in_t.DebugString()));
311 }
312 ret_type.set_type_id(TFT_PRODUCT);
313 FullTypeDef* t = ret_type.add_args();
314 t->set_type_id(u);
315 *t->mutable_args() = in_t.args();
316 return ret_type;
317 };
318}
319
320FullTypeDef BatchTensor(const FullTypeDef& t) {
321 // For now, just return the input type.
322 // If the input type has a shape in the future, this function needs to be
323 // changed so that the output shape is computed based on the input shape and
324 // the effect of the op that changes the batch size (and this function would
325 // require more information to do this computation).
326 return t;
327}
328
329FullTypeDef ShardTensor(const FullTypeDef& t) {
330 // For now, just return the input type.
331 // If the input type has a shape in the future, this function needs to be
332 // changed so that the output shape is computed based on the input shape and
333 // the effect of the op that shards the input into multiple tensors (and this
334 // function would require more information to do this computation).
335 return t;
336}
337
338} // namespace full_type
339
340} // namespace tensorflow
341