1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/full_type_inference_util.h" |
17 | |
18 | #include <functional> |
19 | #include <string> |
20 | |
21 | #include "absl/strings/str_cat.h" |
22 | #include "tensorflow/core/framework/full_type.pb.h" |
23 | #include "tensorflow/core/framework/full_type_util.h" |
24 | #include "tensorflow/core/framework/op_def_builder.h" |
25 | #include "tensorflow/core/platform/status.h" |
26 | #include "tensorflow/core/platform/statusor.h" |
27 | #include "tensorflow/core/protobuf/error_codes.pb.h" |
28 | |
29 | namespace tensorflow { |
30 | |
31 | namespace full_type { |
32 | |
// Note about error handling:
// For problems that stem from the op definition itself (e.g. an op declares
// three inputs, so its inference functions must not be constructed with an
// index `i` exceeding that), use DCHECK - an incorrect op definition is
// considered a bug.
// For problems that stem from the graph being processed (e.g. whether the
// user connected ops with compatible types), return a Status - an incorrect
// graph is considered a user error.
40 | |
41 | ForwardTypeInferenceFn KeepExisting() { return nullptr; } |
42 | |
43 | ForwardTypeInferenceFn ReplicateInput(int i, int n) { |
44 | return [i, n](const TypeRefVector& input_types, const TypeRefMap& type_vars) { |
45 | const FullTypeDef& in_type = input_types.at(i).get(); |
46 | FullTypeDef ret_type; |
47 | if (in_type.type_id() != TFT_UNSET) { |
48 | ret_type.set_type_id(TFT_PRODUCT); |
49 | for (int k = 0; k < n; k++) { |
50 | *(ret_type.add_args()) = in_type; |
51 | } |
52 | } |
53 | return ret_type; |
54 | }; |
55 | } |
56 | |
57 | ForwardTypeInferenceFn Merge() { |
58 | return [](const TypeRefVector& input_types, |
59 | const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> { |
60 | DCHECK(!input_types.empty()); |
61 | |
62 | FullTypeDef merged; |
63 | for (int i = 0; i < input_types.size(); i++) { |
64 | const auto& t = input_types[i].get(); |
65 | |
66 | if (t.type_id() == TFT_UNSET) { |
67 | continue; |
68 | } |
69 | |
70 | if (IsSubtype(t, merged)) { |
71 | merged = t; |
72 | continue; |
73 | } |
74 | if (IsSubtype(merged, t)) { |
75 | continue; |
76 | } |
77 | |
      return Status(error::INVALID_ARGUMENT,
                    absl::StrCat("expected compatible input types, but input ",
                                 i, ":\n", t.DebugString(),
                                 " is neither a subtype nor a supertype of the "
                                 "combined inputs preceding it:\n",
                                 merged.DebugString()));
84 | } |
85 | |
86 | FullTypeDef ret_type; |
87 | if (merged.type_id() != TFT_UNSET) { |
88 | ret_type.set_type_id(TFT_PRODUCT); |
89 | *(ret_type.add_args()) = merged; |
90 | } |
91 | return ret_type; |
92 | }; |
93 | } |
94 | |
95 | ForwardTypeInferenceFn Encode(FullTypeId t, int i) { |
96 | return [t, i](const TypeRefVector& input_types, |
97 | const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> { |
    DCHECK_GT(input_types.size(), i);
99 | |
100 | FullTypeDef ret_type; |
101 | const FullTypeDef& in_t = input_types[i].get(); |
102 | if (in_t.type_id() == TFT_UNSET) { |
103 | return ret_type; |
104 | } |
105 | |
106 | ret_type.set_type_id(TFT_PRODUCT); |
107 | |
108 | auto* enc_type = ret_type.add_args(); |
109 | enc_type->set_type_id(TFT_ENCODED); |
110 | *enc_type->add_args() = in_t; |
111 | enc_type->add_args()->set_type_id(t); |
112 | return ret_type; |
113 | }; |
114 | } |
115 | |
116 | ForwardTypeInferenceFn Decode(FullTypeId t, int i) { |
117 | return [t, i](const TypeRefVector& input_types, |
118 | const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> { |
    DCHECK_GT(input_types.size(), i);
120 | |
121 | const FullTypeDef& in_t = input_types[i].get(); |
122 | |
123 | const FullTypeId enc_tid = GetArgDefaultUnset(in_t, 1).type_id(); |
124 | if ((enc_tid != TFT_UNSET) && (enc_tid != t)) { |
      return Status(error::INVALID_ARGUMENT,
                    absl::StrCat("expected encoded type ", t, " for input ", i,
                                 ", got ", in_t.DebugString()));
128 | } |
129 | |
130 | FullTypeDef ret_type; |
131 | |
132 | const FullTypeDef& out_t = GetArgDefaultUnset(in_t, 0); |
133 | if (in_t.type_id() == TFT_UNSET) { |
134 | return ret_type; |
135 | } |
136 | |
137 | ret_type.set_type_id(TFT_PRODUCT); |
138 | *ret_type.add_args() = out_t; |
139 | return ret_type; |
140 | }; |
141 | } |
142 | |
143 | ForwardTypeInferenceFn UnaryContainerCreate(FullTypeId t, int element_idx) { |
144 | return |
145 | [t, element_idx](const TypeRefVector& input_types, |
146 | const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> { |
        DCHECK_GT(input_types.size(), element_idx);
148 | |
149 | FullTypeDef ret_type; |
150 | ret_type.set_type_id(TFT_PRODUCT); |
151 | FullTypeDef* arg_t = ret_type.add_args(); |
152 | arg_t->set_type_id(t); |
153 | *(arg_t->add_args()) = input_types[element_idx].get(); |
154 | |
155 | return ret_type; |
156 | }; |
157 | } |
158 | |
159 | ForwardTypeInferenceFn UnaryContainerAdd(FullTypeId t, int container_idx, |
160 | int element_idx, bool homogeneous) { |
161 | return [t, container_idx, element_idx, homogeneous]( |
162 | const TypeRefVector& input_types, |
163 | const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> { |
    DCHECK_GT(input_types.size(), container_idx);
    DCHECK_GT(input_types.size(), element_idx);
166 | |
167 | FullTypeDef ret_type; |
168 | ret_type.set_type_id(TFT_PRODUCT); |
169 | FullTypeDef* cont_t = ret_type.add_args(); |
170 | cont_t->set_type_id(t); |
171 | |
172 | const FullTypeDef& in_cont_t = input_types[container_idx].get(); |
173 | const FullTypeDef& in_el_t = input_types[element_idx].get(); |
174 | |
175 | if (in_cont_t.type_id() != TFT_UNSET) { |
176 | if (in_cont_t.type_id() != t) { |
        return Status(
            error::INVALID_ARGUMENT,
            absl::StrCat("expected container type ", t, " for input ",
                         container_idx, ", got ", in_cont_t.DebugString()));
181 | } |
182 | *cont_t = in_cont_t; |
183 | } |
184 | |
185 | VLOG(1) << "ContainerAddUnary: " << cont_t->DebugString() << ", " |
186 | << in_el_t.DebugString() << ", " << container_idx << "; " |
187 | << element_idx; |
188 | for (const auto& tmp : input_types) { |
189 | VLOG(1) << " input: " << tmp.get().DebugString(); |
190 | } |
191 | |
192 | if (in_el_t.type_id() == TFT_UNSET) { |
193 | return ret_type; |
194 | } |
195 | |
196 | const FullTypeDef& el_t = GetArgDefaultUnset(*cont_t, 0); |
197 | |
198 | if (el_t.type_id() == TFT_UNSET) { |
199 | cont_t->clear_args(); |
200 | *(cont_t->add_args()) = in_el_t; |
201 | return ret_type; |
202 | } |
203 | |
204 | if (IsSubtype(in_el_t, el_t)) { |
205 | // Nothing to do, will not refine the container type based on a single |
206 | // addition. |
207 | return ret_type; |
208 | } |
209 | |
210 | if (homogeneous) { |
      return Status(error::INVALID_ARGUMENT,
                    absl::StrCat("expected a subtype of ", el_t.DebugString(),
                                 " for input ", element_idx,
                                 " of a homogeneous container ", t, ", got ",
                                 in_el_t.DebugString()));
216 | } else { |
217 | // TODO(mdan): Implement if needed. |
      return Status(
          error::UNIMPLEMENTED,
          absl::StrCat("need union types for heterogeneous containers.\n"
                       "A homogeneous container would expect a subtype of ",
                       el_t.DebugString(), " for input ", element_idx,
                       ", but got ", in_el_t.DebugString()));
224 | } |
225 | }; |
226 | } |
227 | |
228 | ForwardTypeInferenceFn MultiaryUnstack( |
229 | FullTypeId t, std::function<FullTypeDef(const FullTypeDef&)> unstack) { |
230 | return [t, unstack](const TypeRefVector& input_types, |
231 | const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> { |
232 | FullTypeDef ret_type; |
233 | ret_type.set_type_id(TFT_PRODUCT); |
234 | FullTypeDef* cont_t = ret_type.add_args(); |
235 | cont_t->set_type_id(t); |
236 | FullTypeDef* el_t = cont_t->add_args(); |
237 | el_t->set_type_id(TFT_PRODUCT); |
238 | for (int element_idx = 0; element_idx < input_types.size(); ++element_idx) { |
239 | *(el_t->add_args()) = unstack(input_types[element_idx].get()); |
240 | } |
241 | return ret_type; |
242 | }; |
243 | } |
244 | |
245 | FullTypeDef UnstackTensor(const FullTypeDef& t) { |
246 | // For now, only TFT_TENSOR and TFT_RAGGED are supported and |
247 | // only if they have a single argument (i.e. they don't specify a shape). |
// If these have a shape in the future, this function needs to be changed
249 | // so that the output shape is computed based on the input shape and the |
250 | // effect of the unstack operation (e.g. a dimension is removed). |
251 | // TFT_UNSET is also allowed to support weak type inference where |
252 | // not having a fulltype is allowed. |
253 | DCHECK((t.type_id() == TFT_TENSOR) || (t.type_id() == TFT_RAGGED) || |
254 | (t.type_id() == TFT_UNSET)); |
255 | DCHECK_LE(t.args_size(), 1); |
256 | return t; |
257 | } |
258 | |
259 | ForwardTypeInferenceFn ContainerMap( |
260 | FullTypeId t, int input_idx, |
261 | std::function<FullTypeDef(const FullTypeDef&)> map) { |
262 | return [t, input_idx, map]( |
263 | const TypeRefVector& input_types, |
264 | const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> { |
    DCHECK_GT(input_types.size(), input_idx);
266 | const FullTypeDef& in_cont_t = input_types.at(input_idx).get(); |
267 | FullTypeDef ret_type; |
268 | if (in_cont_t.type_id() == TFT_UNSET) { |
269 | return ret_type; |
270 | } |
271 | if (in_cont_t.type_id() != t) { |
      return Status(error::INVALID_ARGUMENT,
                    absl::StrCat("expected type ", t, " for input ", input_idx,
                                 ", got ", in_cont_t.DebugString()));
275 | } |
276 | ret_type.set_type_id(TFT_PRODUCT); |
277 | FullTypeDef* out_cont_t = ret_type.add_args(); |
278 | out_cont_t->set_type_id(t); |
279 | const FullTypeDef& in_el_t = GetArgDefaultUnset(in_cont_t, 0); |
280 | if (in_el_t.type_id() == TFT_UNSET) { |
281 | return ret_type; |
282 | } |
283 | if (in_el_t.type_id() != TFT_PRODUCT) { |
      return Status(error::INVALID_ARGUMENT,
                    absl::StrCat("expected PRODUCT element type for input ",
                                 input_idx, ", got ", in_el_t.DebugString()));
287 | } |
288 | FullTypeDef* out_el_t = out_cont_t->add_args(); |
289 | out_el_t->set_type_id(TFT_PRODUCT); |
290 | for (int k = 0; k < in_el_t.args_size(); k++) { |
291 | *(out_el_t->add_args()) = map(in_el_t.args(k)); |
292 | } |
293 | return ret_type; |
294 | }; |
295 | } |
296 | |
297 | ForwardTypeInferenceFn MapCovariant(FullTypeId t, FullTypeId u, int input_idx) { |
298 | return |
299 | [t, u, input_idx](const TypeRefVector& input_types, |
300 | const TypeRefMap& type_vars) -> StatusOr<FullTypeDef> { |
        DCHECK_GT(input_types.size(), input_idx);
302 | const FullTypeDef& in_t = input_types.at(input_idx).get(); |
303 | FullTypeDef ret_type; |
304 | if (in_t.type_id() == TFT_UNSET) { |
305 | return ret_type; |
306 | } |
307 | if (in_t.type_id() != t) { |
          return Status(error::INVALID_ARGUMENT,
                        absl::StrCat("expected type ", t, " for input ",
                                     input_idx, ", got ", in_t.DebugString()));
311 | } |
312 | ret_type.set_type_id(TFT_PRODUCT); |
        FullTypeDef* out_t = ret_type.add_args();
        out_t->set_type_id(u);
        *out_t->mutable_args() = in_t.args();
316 | return ret_type; |
317 | }; |
318 | } |
319 | |
320 | FullTypeDef BatchTensor(const FullTypeDef& t) { |
321 | // For now, just return the input type. |
322 | // If the input type has a shape in the future, this function needs to be |
323 | // changed so that the output shape is computed based on the input shape and |
324 | // the effect of the op that changes the batch size (and this function would |
325 | // require more information to do this computation). |
326 | return t; |
327 | } |
328 | |
329 | FullTypeDef ShardTensor(const FullTypeDef& t) { |
330 | // For now, just return the input type. |
331 | // If the input type has a shape in the future, this function needs to be |
332 | // changed so that the output shape is computed based on the input shape and |
333 | // the effect of the op that shards the input into multiple tensors (and this |
334 | // function would require more information to do this computation). |
335 | return t; |
336 | } |
337 | |
338 | } // namespace full_type |
339 | |
340 | } // namespace tensorflow |
341 | |