1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/full_type.pb.h" |
18 | #include "tensorflow/core/framework/op.h" |
19 | #include "tensorflow/core/framework/shape_inference.h" |
20 | #include "tensorflow/core/framework/types.pb.h" |
21 | |
22 | namespace tensorflow { |
23 | namespace { |
24 | |
25 | // Verifies that `shapes_and_types` is a valid list handle and has the right |
26 | // dtype. |
27 | Status VerifyHandleData( |
28 | shape_inference::InferenceContext* c, |
29 | const std::vector<shape_inference::ShapeAndType>& shapes_and_types, |
30 | DataType element_dtype) { |
31 | if (shapes_and_types.size() != 1) { |
32 | return errors::InvalidArgument( |
33 | "Invalid handle_data for input list. Expected length of " |
34 | "shape_and_types: " , |
35 | 1, " Saw: " , shapes_and_types.size()); |
36 | } |
37 | const shape_inference::ShapeAndType& list_shape_type = shapes_and_types[0]; |
38 | if (list_shape_type.dtype != element_dtype) { |
39 | return errors::InvalidArgument("Expected list with element dtype " , |
40 | DataTypeString(element_dtype), |
41 | " but got list with element dtype " , |
42 | DataTypeString(list_shape_type.dtype)); |
43 | } |
44 | return OkStatus(); |
45 | } |
46 | |
47 | bool IsValidTensorListHandleData( |
48 | const std::vector<shape_inference::ShapeAndType>* handle_data) { |
49 | return handle_data != nullptr && handle_data->size() == 1; |
50 | } |
51 | |
52 | // Assumes that the handle_data is valid. |
53 | shape_inference::ShapeHandle GetElementShapeFromHandleData( |
54 | const std::vector<shape_inference::ShapeAndType>& shapes_and_types) { |
55 | return shapes_and_types[0].shape; |
56 | } |
57 | |
58 | REGISTER_OP("EmptyTensorList" ) |
59 | .Input("element_shape: shape_type" ) |
60 | .Input("max_num_elements: int32" ) |
61 | .Output("handle: variant" ) |
62 | .Attr("element_dtype: type" ) |
63 | .Attr("shape_type: {int32, int64}" ) |
64 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
65 | "element_dtype" )) |
66 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
67 | c->set_output(0, c->Scalar()); |
68 | DataType element_dtype; |
69 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
70 | shape_inference::ShapeHandle element_shape; |
71 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
72 | 0, &element_shape)); |
73 | const FullTypeDef& ret_types = c->ret_types(); |
74 | c->set_output_handle_shapes_and_types( |
75 | 0, std::vector<shape_inference::ShapeAndType>{ |
76 | {element_shape, element_dtype, ret_types.args(0)}}); |
77 | return OkStatus(); |
78 | }); |
79 | |
80 | REGISTER_OP("TensorListPushBack" ) |
81 | .Input("input_handle: variant" ) |
82 | .Input("tensor: element_dtype" ) |
83 | .Output("output_handle: variant" ) |
84 | .Attr("element_dtype: type" ) |
85 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
86 | "element_dtype" )) |
87 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
88 | c->set_output(0, c->Scalar()); |
89 | DataType element_dtype; |
90 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
91 | shape_inference::ShapeHandle element_shape = c->UnknownShape(); |
92 | |
93 | auto* handle_data = c->input_handle_shapes_and_types(0); |
94 | if (handle_data != nullptr && handle_data->size() > 1) { |
95 | return errors::InvalidArgument( |
96 | "Trying to push to list with wrong variant data." ); |
97 | } |
98 | if (IsValidTensorListHandleData(handle_data)) { |
99 | const shape_inference::ShapeAndType& list_shape_type = |
100 | (*handle_data)[0]; |
101 | if (list_shape_type.dtype != element_dtype) { |
102 | return errors::InvalidArgument( |
103 | "Trying to push to list with wrong element dtype. List has type " , |
104 | DataTypeString(list_shape_type.dtype), |
105 | " but trying to push element with type " , |
106 | DataTypeString(element_dtype)); |
107 | } |
108 | shape_inference::ShapeHandle ignored; |
109 | TF_RETURN_IF_ERROR( |
110 | c->Merge(element_shape, list_shape_type.shape, &ignored)); |
111 | element_shape = list_shape_type.shape; |
112 | } |
113 | const FullTypeDef& ret_types = c->ret_types(); |
114 | c->set_output_handle_shapes_and_types( |
115 | 0, std::vector<shape_inference::ShapeAndType>{ |
116 | {element_shape, element_dtype, ret_types.args(0)}}); |
117 | return OkStatus(); |
118 | }); |
119 | |
120 | REGISTER_OP("TensorListPushBackBatch" ) |
121 | .Input("input_handles: variant" ) |
122 | .Input("tensor: element_dtype" ) |
123 | .Output("output_handles: variant" ) |
124 | .Attr("element_dtype: type" ) |
125 | // TODO(mdan): Also support for inferring from an input type as well. |
126 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
127 | "element_dtype" )) |
128 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
129 | shape_inference::ShapeHandle input_handles; |
130 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input_handles)); |
131 | |
132 | shape_inference::ShapeHandle tensor; |
133 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &tensor)); |
134 | |
135 | TF_RETURN_IF_ERROR( |
136 | c->MergePrefix(tensor, input_handles, &tensor, &input_handles)); |
137 | |
138 | c->set_output(0, input_handles); |
139 | |
140 | DataType element_dtype; |
141 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
142 | shape_inference::ShapeHandle element_shape = c->UnknownShape(); |
143 | |
144 | auto* handle_data = c->input_handle_shapes_and_types(0); |
145 | if (handle_data != nullptr && handle_data->size() > 1) { |
146 | return errors::InvalidArgument( |
147 | "Trying to push to list with wrong variant data." ); |
148 | } |
149 | if (IsValidTensorListHandleData(handle_data)) { |
150 | const shape_inference::ShapeAndType& list_shape_type = |
151 | (*handle_data)[0]; |
152 | if (list_shape_type.dtype != element_dtype) { |
153 | return errors::InvalidArgument( |
154 | "Trying to push to list with wrong element dtype. List has type " , |
155 | DataTypeString(list_shape_type.dtype), |
156 | " but trying to push element with type " , |
157 | DataTypeString(element_dtype)); |
158 | } |
159 | shape_inference::ShapeHandle ignored; |
160 | TF_RETURN_IF_ERROR( |
161 | c->Merge(element_shape, list_shape_type.shape, &ignored)); |
162 | element_shape = list_shape_type.shape; |
163 | } |
164 | const FullTypeDef& ret_types = c->ret_types(); |
165 | c->set_output_handle_shapes_and_types( |
166 | 0, std::vector<shape_inference::ShapeAndType>{ |
167 | {element_shape, element_dtype, ret_types.args(0)}}); |
168 | return OkStatus(); |
169 | }); |
170 | |
171 | REGISTER_OP("TensorListLength" ) |
172 | .Input("input_handle: variant" ) |
173 | .Output("length: int32" ) |
174 | .SetShapeFn(shape_inference::ScalarShape); |
175 | |
176 | REGISTER_OP("TensorListPopBack" ) |
177 | .Input("input_handle: variant" ) |
178 | .Input("element_shape: int32" ) |
179 | .Output("output_handle: variant" ) |
180 | .Output("tensor: element_dtype" ) |
181 | .Attr("element_dtype: type" ) |
182 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
183 | DataType element_dtype; |
184 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
185 | shape_inference::ShapeHandle tensor_shape = c->UnknownShape(); |
186 | auto* handle_data = c->input_handle_shapes_and_types(0); |
187 | if (handle_data != nullptr && handle_data->size() > 1) { |
188 | return errors::InvalidArgument( |
189 | "Trying to read from list with invalid variant data." ); |
190 | } |
191 | if (IsValidTensorListHandleData(handle_data)) { |
192 | const shape_inference::ShapeAndType& list_shape_type = |
193 | (*handle_data)[0]; |
194 | if (list_shape_type.type.type_id() != TFT_ARRAY) { |
195 | return errors::InvalidArgument("Input argument must be a list." ); |
196 | } |
197 | if (list_shape_type.dtype != element_dtype) { |
198 | return errors::InvalidArgument( |
199 | "Trying to read from list with wrong element dtype. List has " |
200 | "type " , |
201 | DataTypeString(list_shape_type.dtype), |
202 | " but trying to push element with type " , |
203 | DataTypeString(element_dtype)); |
204 | } |
205 | shape_inference::ShapeHandle ignored; |
206 | TF_RETURN_IF_ERROR( |
207 | c->Merge(tensor_shape, list_shape_type.shape, &ignored)); |
208 | c->set_output_handle_shapes_and_types(0, *handle_data); |
209 | tensor_shape = list_shape_type.shape; |
210 | } |
211 | c->set_output(1, tensor_shape); |
212 | c->set_output(0, c->Scalar()); |
213 | return OkStatus(); |
214 | }); |
215 | |
216 | REGISTER_OP("TensorListStack" ) |
217 | .Input("input_handle: variant" ) |
218 | .Input("element_shape: int32" ) |
219 | .Output("tensor: element_dtype" ) |
220 | .Attr("element_dtype: type" ) |
221 | .Attr("num_elements: int = -1" ) |
222 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
223 | DataType element_dtype; |
224 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
225 | shape_inference::ShapeHandle element_shape = c->UnknownShape(); |
226 | auto* handle_data = c->input_handle_shapes_and_types(0); |
227 | if (handle_data != nullptr && handle_data->size() > 1) { |
228 | return errors::InvalidArgument( |
229 | "Trying to read from list with wrong variant data." ); |
230 | } |
231 | if (IsValidTensorListHandleData(handle_data)) { |
232 | const shape_inference::ShapeAndType& list_shape_type = |
233 | (*handle_data)[0]; |
234 | if (list_shape_type.dtype != element_dtype) { |
235 | return errors::InvalidArgument( |
236 | "Trying to read from list with wrong element dtype. List has " |
237 | "type " , |
238 | DataTypeString(list_shape_type.dtype), " but expected type " , |
239 | DataTypeString(element_dtype)); |
240 | } |
241 | shape_inference::ShapeHandle ignored; |
242 | TF_RETURN_IF_ERROR( |
243 | c->Merge(element_shape, list_shape_type.shape, &ignored)); |
244 | element_shape = list_shape_type.shape; |
245 | } |
246 | shape_inference::ShapeHandle element_shape_input = c->UnknownShape(); |
247 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
248 | 1, &element_shape_input)); |
249 | TF_RETURN_IF_ERROR( |
250 | c->Merge(element_shape, element_shape_input, &element_shape)); |
251 | int expected_num_elements = -1; |
252 | TF_RETURN_IF_ERROR(c->GetAttr("num_elements" , &expected_num_elements)); |
253 | shape_inference::ShapeHandle num_elements; |
254 | if (expected_num_elements == -1) { |
255 | num_elements = c->MakeShape({c->UnknownDim()}); |
256 | } else { |
257 | num_elements = c->MakeShape({expected_num_elements}); |
258 | } |
259 | shape_inference::ShapeHandle result; |
260 | TF_RETURN_IF_ERROR(c->Concatenate(num_elements, element_shape, &result)); |
261 | c->set_output(0, result); |
262 | return OkStatus(); |
263 | }); |
264 | |
265 | Status TensorListConcatShapeInference( |
266 | shape_inference::InferenceContext* c, |
267 | shape_inference::ShapeHandle element_shape) { |
268 | DataType element_dtype; |
269 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
270 | auto* handle_data = c->input_handle_shapes_and_types(0); |
271 | if (handle_data != nullptr && handle_data->size() > 1) { |
272 | return errors::InvalidArgument( |
273 | "Trying to read from list with wrong variant data." ); |
274 | } |
275 | if (IsValidTensorListHandleData(handle_data)) { |
276 | const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0]; |
277 | if (list_shape_type.dtype != element_dtype) { |
278 | return errors::InvalidArgument( |
279 | "Trying to read from list with wrong element dtype. List has " |
280 | "type " , |
281 | DataTypeString(list_shape_type.dtype), " but expected type " , |
282 | DataTypeString(element_dtype)); |
283 | } |
284 | shape_inference::ShapeHandle merged; |
285 | TF_RETURN_IF_ERROR(c->Merge(element_shape, list_shape_type.shape, &merged)); |
286 | element_shape = merged; |
287 | } |
288 | if (c->RankKnown(element_shape)) { |
289 | shape_inference::ShapeHandle result; |
290 | TF_RETURN_IF_ERROR(c->Subshape(element_shape, 1, &result)); |
291 | TF_RETURN_IF_ERROR( |
292 | c->Concatenate(c->MakeShape({c->UnknownDim()}), result, &result)); |
293 | c->set_output(0, result); |
294 | } else { |
295 | c->set_output(0, c->UnknownShape()); |
296 | } |
297 | c->set_output(1, c->MakeShape({c->UnknownDim()})); |
298 | return OkStatus(); |
299 | } |
300 | |
301 | REGISTER_OP("TensorListConcat" ) |
302 | .Input("input_handle: variant" ) |
303 | .Output("tensor: element_dtype" ) |
304 | .Output("lengths: int64" ) |
305 | .Attr("element_dtype: type" ) |
306 | .Attr("element_shape: shape = { unknown_rank: true }" ) |
307 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
308 | PartialTensorShape raw_element_shape; |
309 | TF_RETURN_IF_ERROR(c->GetAttr("element_shape" , &raw_element_shape)); |
310 | shape_inference::ShapeHandle element_shape; |
311 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(raw_element_shape, |
312 | &element_shape)); |
313 | return TensorListConcatShapeInference(c, element_shape); |
314 | }); |
315 | |
316 | REGISTER_OP("TensorListConcatV2" ) |
317 | .Input("input_handle: variant" ) |
318 | .Input("element_shape: shape_type" ) |
319 | .Input("leading_dims: int64" ) |
320 | .Output("tensor: element_dtype" ) |
321 | .Output("lengths: int64" ) |
322 | .Attr("element_dtype: type" ) |
323 | .Attr("shape_type: {int32, int64}" ) |
324 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
325 | shape_inference::ShapeHandle element_shape; |
326 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
327 | 1, &element_shape)); |
328 | return TensorListConcatShapeInference(c, element_shape); |
329 | }); |
330 | |
331 | REGISTER_OP("TensorListSplit" ) |
332 | .Input("tensor: element_dtype" ) |
333 | .Input("element_shape: shape_type" ) |
334 | .Input("lengths: int64" ) |
335 | .Output("output_handle: variant" ) |
336 | .Attr("element_dtype: type" ) |
337 | .Attr("shape_type: {int32, int64}" ) |
338 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
339 | "element_dtype" )) |
340 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
341 | c->set_output(0, c->Scalar()); |
342 | DataType element_dtype; |
343 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
344 | shape_inference::ShapeHandle tensor_shape = c->input(0); |
345 | shape_inference::ShapeHandle ignored; |
346 | // Check that tensor is at least a vector. |
347 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(tensor_shape, 1, &ignored)); |
348 | // Check that lengths is a vector. |
349 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &ignored)); |
350 | shape_inference::ShapeHandle element_shape_from_tensor_shape; |
351 | TF_RETURN_IF_ERROR( |
352 | c->Subshape(tensor_shape, 1, &element_shape_from_tensor_shape)); |
353 | TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape({c->UnknownDim()}), |
354 | element_shape_from_tensor_shape, |
355 | &element_shape_from_tensor_shape)); |
356 | shape_inference::ShapeHandle element_shape; |
357 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
358 | 1, &element_shape)); |
359 | TF_RETURN_IF_ERROR(c->Merge(element_shape_from_tensor_shape, |
360 | element_shape, |
361 | &element_shape_from_tensor_shape)); |
362 | const FullTypeDef& ret_types = c->ret_types(); |
363 | c->set_output_handle_shapes_and_types( |
364 | 0, std::vector<shape_inference::ShapeAndType>{ |
365 | {element_shape, element_dtype, ret_types.args(0)}}); |
366 | return OkStatus(); |
367 | }); |
368 | |
369 | REGISTER_OP("TensorListFromTensor" ) |
370 | .Input("tensor: element_dtype" ) |
371 | .Input("element_shape: shape_type" ) |
372 | .Output("output_handle: variant" ) |
373 | .Attr("element_dtype: type" ) |
374 | .Attr("shape_type: {int32, int64}" ) |
375 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
376 | "element_dtype" )) |
377 | .SetForwardTypeFn(full_type::UnaryContainerCreate(TFT_ARRAY, 0)) |
378 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
379 | c->set_output(0, c->Scalar()); |
380 | DataType element_dtype; |
381 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
382 | shape_inference::ShapeHandle tensor_shape = c->input(0); |
383 | shape_inference::ShapeHandle tensor_shape_except_first_dim; |
384 | TF_RETURN_IF_ERROR( |
385 | c->Subshape(tensor_shape, 1, &tensor_shape_except_first_dim)); |
386 | shape_inference::ShapeHandle element_shape; |
387 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
388 | 1, &element_shape)); |
389 | TF_RETURN_IF_ERROR(c->Merge(tensor_shape_except_first_dim, element_shape, |
390 | &tensor_shape_except_first_dim)); |
391 | const FullTypeDef& ret_types = c->ret_types(); |
392 | c->set_output_handle_shapes_and_types( |
393 | 0, std::vector<shape_inference::ShapeAndType>{ |
394 | {element_shape, element_dtype, ret_types.args(0)}}); |
395 | return OkStatus(); |
396 | }); |
397 | |
398 | REGISTER_OP("TensorListElementShape" ) |
399 | .Input("input_handle: variant" ) |
400 | .Output("element_shape: shape_type" ) |
401 | .Attr("shape_type: {int32, int64}" ) |
402 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
403 | auto* handle_data = c->input_handle_shapes_and_types(0); |
404 | // `TensorListElementShape` returns the scalar -1 if the rank of |
405 | // element_shape is unknown else returns the shape vector (with possibly |
406 | // unknown dims). |
407 | if (!IsValidTensorListHandleData(handle_data)) { |
408 | c->set_output(0, c->UnknownShape()); |
409 | return OkStatus(); |
410 | } |
411 | if (c->RankKnown((*handle_data)[0].shape)) { |
412 | c->set_output(0, c->Vector(c->Rank((*handle_data)[0].shape))); |
413 | } else { |
414 | c->set_output(0, c->UnknownShape()); |
415 | } |
416 | return OkStatus(); |
417 | }); |
418 | |
419 | REGISTER_OP("TensorListReserve" ) |
420 | .Input("element_shape: shape_type" ) |
421 | .Input("num_elements: int32" ) |
422 | .Output("handle: variant" ) |
423 | .Attr("element_dtype: type" ) |
424 | .Attr("shape_type: {int32, int64}" ) |
425 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
426 | "element_dtype" )) |
427 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
428 | c->set_output(0, c->Scalar()); |
429 | shape_inference::ShapeHandle element_shape; |
430 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
431 | 0, &element_shape)); |
432 | DataType element_dtype; |
433 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
434 | const FullTypeDef& ret_types = c->ret_types(); |
435 | c->set_output_handle_shapes_and_types( |
436 | 0, std::vector<shape_inference::ShapeAndType>{ |
437 | {element_shape, element_dtype, ret_types.args(0)}}); |
438 | return OkStatus(); |
439 | }); |
440 | |
441 | REGISTER_OP("TensorListGetItem" ) |
442 | .Input("input_handle: variant" ) |
443 | .Input("index: int32" ) |
444 | .Input("element_shape: int32" ) |
445 | .Output("item: element_dtype" ) |
446 | .Attr("element_dtype: type" ) |
447 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
448 | DataType element_dtype; |
449 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
450 | auto* handle_data = c->input_handle_shapes_and_types(0); |
451 | shape_inference::ShapeHandle element_shape = c->UnknownShape(); |
452 | if (IsValidTensorListHandleData(handle_data)) { |
453 | const shape_inference::ShapeAndType& list_shape_type = |
454 | (*handle_data)[0]; |
455 | element_shape = list_shape_type.shape; |
456 | if (list_shape_type.dtype != element_dtype) { |
457 | return errors::InvalidArgument("Expected list with element dtype " , |
458 | DataTypeString(element_dtype), |
459 | " but got list with element dtype " , |
460 | DataTypeString(list_shape_type.dtype)); |
461 | } |
462 | } |
463 | shape_inference::ShapeHandle element_shape_input = c->UnknownShape(); |
464 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
465 | 2, &element_shape_input)); |
466 | TF_RETURN_IF_ERROR( |
467 | c->Merge(element_shape, element_shape_input, &element_shape)); |
468 | c->set_output(0, element_shape); |
469 | return OkStatus(); |
470 | }); |
471 | |
472 | REGISTER_OP("TensorListResize" ) |
473 | .Input("input_handle: variant" ) |
474 | .Input("size: int32" ) |
475 | .Output("output_handle: variant" ) |
476 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
477 | // Check that `size` has scalar shape. |
478 | shape_inference::ShapeHandle size_shape = c->input(1); |
479 | shape_inference::ShapeHandle unused; |
480 | TF_RETURN_IF_ERROR(c->WithRank(size_shape, 0, &unused)); |
481 | c->set_output(0, c->Scalar()); |
482 | auto* handle_data = c->input_handle_shapes_and_types(0); |
483 | if (IsValidTensorListHandleData(handle_data)) { |
484 | c->set_output_handle_shapes_and_types(0, *handle_data); |
485 | } |
486 | return OkStatus(); |
487 | }); |
488 | |
489 | REGISTER_OP("TensorListSetItem" ) |
490 | .Input("input_handle: variant" ) |
491 | .Input("index: int32" ) |
492 | .Input("item: element_dtype" ) |
493 | .Output("output_handle: variant" ) |
494 | .Attr("element_dtype: type" ) |
495 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
496 | "element_dtype" )) |
497 | .SetForwardTypeFn(full_type::UnaryContainerAdd(TFT_ARRAY, |
498 | /*container_idx=*/0, |
499 | /*element_idx=*/2, |
500 | /*homogeneous=*/true)) |
501 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
502 | DataType element_dtype; |
503 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
504 | auto* handle_data = c->input_handle_shapes_and_types(0); |
505 | c->set_output(0, c->Scalar()); |
506 | if (IsValidTensorListHandleData(handle_data)) { |
507 | const shape_inference::ShapeAndType& list_shape_type = |
508 | (*handle_data)[0]; |
509 | shape_inference::ShapeHandle item_shape = c->input(2); |
510 | TF_RETURN_IF_ERROR( |
511 | c->Merge(item_shape, list_shape_type.shape, &item_shape)); |
512 | c->set_output_handle_shapes_and_types(0, *handle_data); |
513 | } else { |
514 | const FullTypeDef& ret_types = c->ret_types(); |
515 | c->set_output_handle_shapes_and_types( |
516 | 0, std::vector<shape_inference::ShapeAndType>{ |
517 | {c->UnknownShape(), element_dtype, ret_types.args(0)}}); |
518 | } |
519 | return OkStatus(); |
520 | }); |
521 | |
522 | REGISTER_OP("TensorListGather" ) |
523 | .Input("input_handle: variant" ) |
524 | .Input("indices: int32" ) |
525 | .Input("element_shape: int32" ) |
526 | .Output("values: element_dtype" ) |
527 | .Attr("element_dtype: type" ) |
528 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
529 | DataType element_dtype; |
530 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
531 | auto* handle_data = c->input_handle_shapes_and_types(0); |
532 | shape_inference::ShapeHandle element_shape = c->UnknownShape(); |
533 | if (IsValidTensorListHandleData(handle_data)) { |
534 | const shape_inference::ShapeAndType& list_shape_type = |
535 | (*handle_data)[0]; |
536 | element_shape = list_shape_type.shape; |
537 | if (list_shape_type.dtype != element_dtype) { |
538 | return errors::InvalidArgument("Expected list with element dtype " , |
539 | DataTypeString(element_dtype), |
540 | " but got list with element dtype " , |
541 | DataTypeString(list_shape_type.dtype)); |
542 | } |
543 | } |
544 | shape_inference::ShapeHandle element_shape_input = c->UnknownShape(); |
545 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
546 | 2, &element_shape_input)); |
547 | TF_RETURN_IF_ERROR( |
548 | c->Merge(element_shape, element_shape_input, &element_shape)); |
549 | shape_inference::ShapeHandle out; |
550 | TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out)); |
551 | c->set_output(0, out); |
552 | return OkStatus(); |
553 | }); |
554 | |
555 | REGISTER_OP("TensorListScatter" ) |
556 | .Input("tensor: element_dtype" ) |
557 | .Input("indices: int32" ) |
558 | .Input("element_shape: shape_type" ) |
559 | .Output("output_handle: variant" ) |
560 | .Attr("element_dtype: type" ) |
561 | .Attr("shape_type: {int32, int64}" ) |
562 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
563 | "element_dtype" )) |
564 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
565 | DataType element_dtype; |
566 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
567 | shape_inference::ShapeHandle element_shape; |
568 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
569 | 2, &element_shape)); |
570 | const FullTypeDef& ret_types = c->ret_types(); |
571 | c->set_output_handle_shapes_and_types( |
572 | 0, std::vector<shape_inference::ShapeAndType>{ |
573 | {element_shape, element_dtype, ret_types.args(0)}}); |
574 | c->set_output(0, c->Scalar()); |
575 | return OkStatus(); |
576 | }); |
577 | |
578 | REGISTER_OP("TensorListScatterV2" ) |
579 | .Input("tensor: element_dtype" ) |
580 | .Input("indices: int32" ) |
581 | .Input("element_shape: shape_type" ) |
582 | .Input("num_elements: int32" ) |
583 | .Output("output_handle: variant" ) |
584 | .Attr("element_dtype: type" ) |
585 | .Attr("shape_type: {int32, int64}" ) |
586 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
587 | "element_dtype" )) |
588 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
589 | DataType element_dtype; |
590 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
591 | shape_inference::ShapeHandle element_shape; |
592 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( |
593 | 2, &element_shape)); |
594 | const FullTypeDef& ret_types = c->ret_types(); |
595 | c->set_output_handle_shapes_and_types( |
596 | 0, std::vector<shape_inference::ShapeAndType>{ |
597 | {element_shape, element_dtype, ret_types.args(0)}}); |
598 | c->set_output(0, c->Scalar()); |
599 | return OkStatus(); |
600 | }); |
601 | |
602 | REGISTER_OP("TensorListScatterIntoExistingList" ) |
603 | .Input("input_handle: variant" ) |
604 | .Input("tensor: element_dtype" ) |
605 | .Input("indices: int32" ) |
606 | .Output("output_handle: variant" ) |
607 | .Attr("element_dtype: type" ) |
608 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
609 | "element_dtype" )) |
610 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
611 | shape_inference::ShapeHandle ignored; |
612 | // Check that tensor is at least a vector. |
613 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &ignored)); |
614 | // Check that indices is a vector. |
615 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &ignored)); |
616 | |
617 | DataType element_dtype; |
618 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
619 | shape_inference::ShapeHandle element_shape = c->UnknownShape(); |
620 | |
621 | auto* handle_data = c->input_handle_shapes_and_types(0); |
622 | if (IsValidTensorListHandleData(handle_data)) { |
623 | TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype)); |
624 | element_shape = GetElementShapeFromHandleData(*handle_data); |
625 | } |
626 | const FullTypeDef& ret_types = c->ret_types(); |
627 | c->set_output_handle_shapes_and_types( |
628 | 0, std::vector<shape_inference::ShapeAndType>{ |
629 | {element_shape, element_dtype, ret_types.args(0)}}); |
630 | c->set_output(0, c->Scalar()); |
631 | return OkStatus(); |
632 | }); |
633 | |
634 | REGISTER_OP("TensorListConcatLists" ) |
635 | .Input("input_a: variant" ) |
636 | .Input("input_b: variant" ) |
637 | .Attr("element_dtype: type" ) |
638 | .Output("output: variant" ) |
639 | .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY, |
640 | "element_dtype" )) |
641 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
642 | auto input_a = c->input(0); |
643 | auto input_b = c->input(1); |
644 | TF_RETURN_IF_ERROR(c->Merge(input_a, input_b, &input_a)); |
645 | c->set_output(0, input_a); |
646 | |
647 | DataType element_dtype; |
648 | TF_RETURN_IF_ERROR(c->GetAttr("element_dtype" , &element_dtype)); |
649 | |
650 | auto* handle_data_a = c->input_handle_shapes_and_types(0); |
651 | auto* handle_data_b = c->input_handle_shapes_and_types(1); |
652 | bool handle_data_a_nonempty = handle_data_a && !handle_data_a->empty(); |
653 | bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty(); |
654 | if (!(handle_data_a_nonempty || handle_data_b_nonempty)) { |
655 | const FullTypeDef& ret_types = c->ret_types(); |
656 | c->set_output_handle_shapes_and_types( |
657 | 0, {{c->UnknownShape(), element_dtype, ret_types.args(0)}}); |
658 | return OkStatus(); |
659 | } |
660 | shape_inference::ShapeAndType list_shape_type_a = |
661 | handle_data_a_nonempty ? handle_data_a->at(0) : handle_data_b->at(0); |
662 | const shape_inference::ShapeAndType& list_shape_type_b = |
663 | handle_data_b_nonempty ? handle_data_b->at(0) : handle_data_a->at(0); |
664 | if (list_shape_type_a.dtype != element_dtype) { |
665 | return errors::InvalidArgument("input_a.type != element_dtype: " , |
666 | DataTypeString(list_shape_type_a.dtype), |
667 | " vs. " , DataTypeString(element_dtype)); |
668 | } |
669 | if (list_shape_type_b.dtype != element_dtype) { |
670 | return errors::InvalidArgument("input_b.type != element_dtype: " , |
671 | DataTypeString(list_shape_type_b.dtype), |
672 | " vs. " , DataTypeString(element_dtype)); |
673 | } |
674 | TF_RETURN_IF_ERROR(c->Merge(list_shape_type_a.shape, |
675 | list_shape_type_b.shape, |
676 | &list_shape_type_a.shape)); |
677 | c->set_output_handle_shapes_and_types(0, {list_shape_type_a}); |
678 | return OkStatus(); |
679 | }); |
680 | |
681 | } // namespace |
682 | } // namespace tensorflow |
683 | |