/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/full_type.pb.h"
18#include "tensorflow/core/framework/op.h"
19#include "tensorflow/core/framework/shape_inference.h"
20#include "tensorflow/core/framework/types.pb.h"
21
22namespace tensorflow {
23namespace {
24
25// Verifies that `shapes_and_types` is a valid list handle and has the right
26// dtype.
27Status VerifyHandleData(
28 shape_inference::InferenceContext* c,
29 const std::vector<shape_inference::ShapeAndType>& shapes_and_types,
30 DataType element_dtype) {
31 if (shapes_and_types.size() != 1) {
32 return errors::InvalidArgument(
33 "Invalid handle_data for input list. Expected length of "
34 "shape_and_types: ",
35 1, " Saw: ", shapes_and_types.size());
36 }
37 const shape_inference::ShapeAndType& list_shape_type = shapes_and_types[0];
38 if (list_shape_type.dtype != element_dtype) {
39 return errors::InvalidArgument("Expected list with element dtype ",
40 DataTypeString(element_dtype),
41 " but got list with element dtype ",
42 DataTypeString(list_shape_type.dtype));
43 }
44 return OkStatus();
45}
46
47bool IsValidTensorListHandleData(
48 const std::vector<shape_inference::ShapeAndType>* handle_data) {
49 return handle_data != nullptr && handle_data->size() == 1;
50}
51
52// Assumes that the handle_data is valid.
53shape_inference::ShapeHandle GetElementShapeFromHandleData(
54 const std::vector<shape_inference::ShapeAndType>& shapes_and_types) {
55 return shapes_and_types[0].shape;
56}
57
// Creates an empty tensor list. The output is a scalar variant handle whose
// handle data records the declared element shape and dtype for downstream
// list ops.
REGISTER_OP("EmptyTensorList")
    .Input("element_shape: shape_type")
    .Input("max_num_elements: int32")
    .Output("handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The list handle itself is a scalar variant tensor.
      c->set_output(0, c->Scalar());
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle element_shape;
      // A scalar `element_shape` input is treated as "unknown element shape".
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          0, &element_shape));
      const FullTypeDef& ret_types = c->ret_types();
      // Record element shape/dtype (plus full type) on the output handle so
      // later list ops can recover them.
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype, ret_types.args(0)}});
      return OkStatus();
    });
79
// Appends `tensor` to the list, producing the handle of the updated list.
REGISTER_OP("TensorListPushBack")
    .Input("input_handle: variant")
    .Input("tensor: element_dtype")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The output list handle is a scalar variant.
      c->set_output(0, c->Scalar());
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle element_shape = c->UnknownShape();

      auto* handle_data = c->input_handle_shapes_and_types(0);
      // A valid list handle carries at most one ShapeAndType entry.
      if (handle_data != nullptr && handle_data->size() > 1) {
        return errors::InvalidArgument(
            "Trying to push to list with wrong variant data.");
      }
      if (IsValidTensorListHandleData(handle_data)) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        if (list_shape_type.dtype != element_dtype) {
          return errors::InvalidArgument(
              "Trying to push to list with wrong element dtype. List has type ",
              DataTypeString(list_shape_type.dtype),
              " but trying to push element with type ",
              DataTypeString(element_dtype));
        }
        shape_inference::ShapeHandle ignored;
        // Check compatibility with the recorded element shape, then adopt it.
        TF_RETURN_IF_ERROR(
            c->Merge(element_shape, list_shape_type.shape, &ignored));
        element_shape = list_shape_type.shape;
      }
      const FullTypeDef& ret_types = c->ret_types();
      // Propagate element shape/dtype onto the output handle.
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype, ret_types.args(0)}});
      return OkStatus();
    });
119
// Batched variant of TensorListPushBack: input 0 is a vector of list handles
// and `tensor`'s leading dimension must match that vector's length.
REGISTER_OP("TensorListPushBackBatch")
    .Input("input_handles: variant")
    .Input("tensor: element_dtype")
    .Output("output_handles: variant")
    .Attr("element_dtype: type")
    // TODO(mdan): Also support for inferring from an input type as well.
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The handles input must be a vector (one handle per list).
      shape_inference::ShapeHandle input_handles;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input_handles));

      // The pushed tensor must have at least a leading batch dimension.
      shape_inference::ShapeHandle tensor;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &tensor));

      // tensor's leading dim must agree with the number of handles.
      TF_RETURN_IF_ERROR(
          c->MergePrefix(tensor, input_handles, &tensor, &input_handles));

      c->set_output(0, input_handles);

      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle element_shape = c->UnknownShape();

      auto* handle_data = c->input_handle_shapes_and_types(0);
      // A valid list handle carries at most one ShapeAndType entry.
      if (handle_data != nullptr && handle_data->size() > 1) {
        return errors::InvalidArgument(
            "Trying to push to list with wrong variant data.");
      }
      if (IsValidTensorListHandleData(handle_data)) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        if (list_shape_type.dtype != element_dtype) {
          return errors::InvalidArgument(
              "Trying to push to list with wrong element dtype. List has type ",
              DataTypeString(list_shape_type.dtype),
              " but trying to push element with type ",
              DataTypeString(element_dtype));
        }
        shape_inference::ShapeHandle ignored;
        // Check compatibility with the recorded element shape, then adopt it.
        TF_RETURN_IF_ERROR(
            c->Merge(element_shape, list_shape_type.shape, &ignored));
        element_shape = list_shape_type.shape;
      }
      const FullTypeDef& ret_types = c->ret_types();
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype, ret_types.args(0)}});
      return OkStatus();
    });
170
// Returns the number of elements in the list as a scalar int32.
REGISTER_OP("TensorListLength")
    .Input("input_handle: variant")
    .Output("length: int32")
    .SetShapeFn(shape_inference::ScalarShape);
175
176REGISTER_OP("TensorListPopBack")
177 .Input("input_handle: variant")
178 .Input("element_shape: int32")
179 .Output("output_handle: variant")
180 .Output("tensor: element_dtype")
181 .Attr("element_dtype: type")
182 .SetShapeFn([](shape_inference::InferenceContext* c) {
183 DataType element_dtype;
184 TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
185 shape_inference::ShapeHandle tensor_shape = c->UnknownShape();
186 auto* handle_data = c->input_handle_shapes_and_types(0);
187 if (handle_data != nullptr && handle_data->size() > 1) {
188 return errors::InvalidArgument(
189 "Trying to read from list with invalid variant data.");
190 }
191 if (IsValidTensorListHandleData(handle_data)) {
192 const shape_inference::ShapeAndType& list_shape_type =
193 (*handle_data)[0];
194 if (list_shape_type.type.type_id() != TFT_ARRAY) {
195 return errors::InvalidArgument("Input argument must be a list.");
196 }
197 if (list_shape_type.dtype != element_dtype) {
198 return errors::InvalidArgument(
199 "Trying to read from list with wrong element dtype. List has "
200 "type ",
201 DataTypeString(list_shape_type.dtype),
202 " but trying to push element with type ",
203 DataTypeString(element_dtype));
204 }
205 shape_inference::ShapeHandle ignored;
206 TF_RETURN_IF_ERROR(
207 c->Merge(tensor_shape, list_shape_type.shape, &ignored));
208 c->set_output_handle_shapes_and_types(0, *handle_data);
209 tensor_shape = list_shape_type.shape;
210 }
211 c->set_output(1, tensor_shape);
212 c->set_output(0, c->Scalar());
213 return OkStatus();
214 });
215
// Stacks all list elements into one tensor with a new leading dimension of
// size `num_elements` (unknown when the attr keeps its default of -1).
REGISTER_OP("TensorListStack")
    .Input("input_handle: variant")
    .Input("element_shape: int32")
    .Output("tensor: element_dtype")
    .Attr("element_dtype: type")
    .Attr("num_elements: int = -1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle element_shape = c->UnknownShape();
      auto* handle_data = c->input_handle_shapes_and_types(0);
      // A valid list handle carries at most one ShapeAndType entry.
      if (handle_data != nullptr && handle_data->size() > 1) {
        return errors::InvalidArgument(
            "Trying to read from list with wrong variant data.");
      }
      if (IsValidTensorListHandleData(handle_data)) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        if (list_shape_type.dtype != element_dtype) {
          return errors::InvalidArgument(
              "Trying to read from list with wrong element dtype. List has "
              "type ",
              DataTypeString(list_shape_type.dtype), " but expected type ",
              DataTypeString(element_dtype));
        }
        shape_inference::ShapeHandle ignored;
        // Check compatibility with the recorded element shape, then adopt it.
        TF_RETURN_IF_ERROR(
            c->Merge(element_shape, list_shape_type.shape, &ignored));
        element_shape = list_shape_type.shape;
      }
      // Also merge in the element shape provided via the `element_shape`
      // input (a scalar input means "unknown").
      shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          1, &element_shape_input));
      TF_RETURN_IF_ERROR(
          c->Merge(element_shape, element_shape_input, &element_shape));
      int expected_num_elements = -1;
      TF_RETURN_IF_ERROR(c->GetAttr("num_elements", &expected_num_elements));
      shape_inference::ShapeHandle num_elements;
      if (expected_num_elements == -1) {
        // -1 means the list length is not statically known.
        num_elements = c->MakeShape({c->UnknownDim()});
      } else {
        num_elements = c->MakeShape({expected_num_elements});
      }
      // Output is [num_elements] + element_shape.
      shape_inference::ShapeHandle result;
      TF_RETURN_IF_ERROR(c->Concatenate(num_elements, element_shape, &result));
      c->set_output(0, result);
      return OkStatus();
    });
264
// Shared shape function for TensorListConcat and TensorListConcatV2.
// `element_shape` is the statically-declared element shape (possibly of
// unknown rank); it is merged with any shape recorded in the input handle
// data. Output 0 is the concatenated tensor: an unknown leading dimension
// followed by element_shape[1:]. Output 1 is the `lengths` vector, whose
// length is unknown at inference time.
Status TensorListConcatShapeInference(
    shape_inference::InferenceContext* c,
    shape_inference::ShapeHandle element_shape) {
  DataType element_dtype;
  TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
  auto* handle_data = c->input_handle_shapes_and_types(0);
  // A valid list handle carries at most one ShapeAndType entry.
  if (handle_data != nullptr && handle_data->size() > 1) {
    return errors::InvalidArgument(
        "Trying to read from list with wrong variant data.");
  }
  if (IsValidTensorListHandleData(handle_data)) {
    const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0];
    if (list_shape_type.dtype != element_dtype) {
      return errors::InvalidArgument(
          "Trying to read from list with wrong element dtype. List has "
          "type ",
          DataTypeString(list_shape_type.dtype), " but expected type ",
          DataTypeString(element_dtype));
    }
    shape_inference::ShapeHandle merged;
    TF_RETURN_IF_ERROR(c->Merge(element_shape, list_shape_type.shape, &merged));
    element_shape = merged;
  }
  if (c->RankKnown(element_shape)) {
    // Concatenation is along the first axis: drop dim 0 and prepend an
    // unknown dimension for the total concatenated length.
    shape_inference::ShapeHandle result;
    TF_RETURN_IF_ERROR(c->Subshape(element_shape, 1, &result));
    TF_RETURN_IF_ERROR(
        c->Concatenate(c->MakeShape({c->UnknownDim()}), result, &result));
    c->set_output(0, result);
  } else {
    c->set_output(0, c->UnknownShape());
  }
  c->set_output(1, c->MakeShape({c->UnknownDim()}));
  return OkStatus();
}
300
// Concatenates all list elements along their first dimension. The element
// shape comes from the `element_shape` attr (defaults to unknown rank).
REGISTER_OP("TensorListConcat")
    .Input("input_handle: variant")
    .Output("tensor: element_dtype")
    .Output("lengths: int64")
    .Attr("element_dtype: type")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      PartialTensorShape raw_element_shape;
      TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &raw_element_shape));
      shape_inference::ShapeHandle element_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(raw_element_shape,
                                                            &element_shape));
      // Outputs are computed by the shared concat shape helper.
      return TensorListConcatShapeInference(c, element_shape);
    });
315
// V2 of TensorListConcat: the element shape arrives as a tensor input
// (input 1) instead of an attr, plus a `leading_dims` input.
REGISTER_OP("TensorListConcatV2")
    .Input("input_handle: variant")
    .Input("element_shape: shape_type")
    .Input("leading_dims: int64")
    .Output("tensor: element_dtype")
    .Output("lengths: int64")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // A scalar `element_shape` input is treated as "unknown shape".
      shape_inference::ShapeHandle element_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          1, &element_shape));
      // Outputs are computed by the shared concat shape helper.
      return TensorListConcatShapeInference(c, element_shape);
    });
330
// Splits `tensor` along its first dimension into a list; `lengths` gives the
// first-dimension extent of each resulting element.
REGISTER_OP("TensorListSplit")
    .Input("tensor: element_dtype")
    .Input("element_shape: shape_type")
    .Input("lengths: int64")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle tensor_shape = c->input(0);
      shape_inference::ShapeHandle ignored;
      // Check that tensor is at least a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(tensor_shape, 1, &ignored));
      // Check that lengths is a vector.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &ignored));
      // Element shape implied by the tensor: unknown first dim (per-element
      // length) followed by tensor_shape[1:].
      shape_inference::ShapeHandle element_shape_from_tensor_shape;
      TF_RETURN_IF_ERROR(
          c->Subshape(tensor_shape, 1, &element_shape_from_tensor_shape));
      TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape({c->UnknownDim()}),
                                        element_shape_from_tensor_shape,
                                        &element_shape_from_tensor_shape));
      shape_inference::ShapeHandle element_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          1, &element_shape));
      // Validate compatibility between the two element-shape sources.
      TF_RETURN_IF_ERROR(c->Merge(element_shape_from_tensor_shape,
                                  element_shape,
                                  &element_shape_from_tensor_shape));
      const FullTypeDef& ret_types = c->ret_types();
      // NOTE(review): the handle data records `element_shape` (the raw
      // input), not the merged shape — confirm this is intentional.
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype, ret_types.args(0)}});
      return OkStatus();
    });
368
// Builds a list whose elements are the slices of `tensor` along dimension 0.
REGISTER_OP("TensorListFromTensor")
    .Input("tensor: element_dtype")
    .Input("element_shape: shape_type")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetForwardTypeFn(full_type::UnaryContainerCreate(TFT_ARRAY, 0))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle tensor_shape = c->input(0);
      // Elements have the tensor's shape minus the leading (list) dim.
      shape_inference::ShapeHandle tensor_shape_except_first_dim;
      TF_RETURN_IF_ERROR(
          c->Subshape(tensor_shape, 1, &tensor_shape_except_first_dim));
      shape_inference::ShapeHandle element_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          1, &element_shape));
      // Validate that the declared element shape matches the tensor's.
      TF_RETURN_IF_ERROR(c->Merge(tensor_shape_except_first_dim, element_shape,
                                  &tensor_shape_except_first_dim));
      const FullTypeDef& ret_types = c->ret_types();
      // NOTE(review): the handle data records `element_shape` (the raw
      // input), not the merged shape — confirm this is intentional.
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype, ret_types.args(0)}});
      return OkStatus();
    });
397
398REGISTER_OP("TensorListElementShape")
399 .Input("input_handle: variant")
400 .Output("element_shape: shape_type")
401 .Attr("shape_type: {int32, int64}")
402 .SetShapeFn([](shape_inference::InferenceContext* c) {
403 auto* handle_data = c->input_handle_shapes_and_types(0);
404 // `TensorListElementShape` returns the scalar -1 if the rank of
405 // element_shape is unknown else returns the shape vector (with possibly
406 // unknown dims).
407 if (!IsValidTensorListHandleData(handle_data)) {
408 c->set_output(0, c->UnknownShape());
409 return OkStatus();
410 }
411 if (c->RankKnown((*handle_data)[0].shape)) {
412 c->set_output(0, c->Vector(c->Rank((*handle_data)[0].shape)));
413 } else {
414 c->set_output(0, c->UnknownShape());
415 }
416 return OkStatus();
417 });
418
// Creates a list handle with the requested element shape and dtype recorded
// in the handle data. `num_elements` does not affect shape inference here.
REGISTER_OP("TensorListReserve")
    .Input("element_shape: shape_type")
    .Input("num_elements: int32")
    .Output("handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The list handle itself is a scalar variant tensor.
      c->set_output(0, c->Scalar());
      shape_inference::ShapeHandle element_shape;
      // A scalar `element_shape` input is treated as "unknown element shape".
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          0, &element_shape));
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      const FullTypeDef& ret_types = c->ret_types();
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype, ret_types.args(0)}});
      return OkStatus();
    });
440
// Returns the element stored at `index`. The output shape is the merge of
// the list's recorded element shape and the `element_shape` input.
REGISTER_OP("TensorListGetItem")
    .Input("input_handle: variant")
    .Input("index: int32")
    .Input("element_shape: int32")
    .Output("item: element_dtype")
    .Attr("element_dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      auto* handle_data = c->input_handle_shapes_and_types(0);
      shape_inference::ShapeHandle element_shape = c->UnknownShape();
      if (IsValidTensorListHandleData(handle_data)) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        element_shape = list_shape_type.shape;
        if (list_shape_type.dtype != element_dtype) {
          return errors::InvalidArgument("Expected list with element dtype ",
                                         DataTypeString(element_dtype),
                                         " but got list with element dtype ",
                                         DataTypeString(list_shape_type.dtype));
        }
      }
      // A scalar `element_shape` input is treated as "unknown shape".
      shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          2, &element_shape_input));
      TF_RETURN_IF_ERROR(
          c->Merge(element_shape, element_shape_input, &element_shape));
      c->set_output(0, element_shape);
      return OkStatus();
    });
471
472REGISTER_OP("TensorListResize")
473 .Input("input_handle: variant")
474 .Input("size: int32")
475 .Output("output_handle: variant")
476 .SetShapeFn([](shape_inference::InferenceContext* c) {
477 // Check that `size` has scalar shape.
478 shape_inference::ShapeHandle size_shape = c->input(1);
479 shape_inference::ShapeHandle unused;
480 TF_RETURN_IF_ERROR(c->WithRank(size_shape, 0, &unused));
481 c->set_output(0, c->Scalar());
482 auto* handle_data = c->input_handle_shapes_and_types(0);
483 if (IsValidTensorListHandleData(handle_data)) {
484 c->set_output_handle_shapes_and_types(0, *handle_data);
485 }
486 return OkStatus();
487 });
488
// Writes `item` at position `index`, producing the updated list handle.
REGISTER_OP("TensorListSetItem")
    .Input("input_handle: variant")
    .Input("index: int32")
    .Input("item: element_dtype")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetForwardTypeFn(full_type::UnaryContainerAdd(TFT_ARRAY,
                                                   /*container_idx=*/0,
                                                   /*element_idx=*/2,
                                                   /*homogeneous=*/true))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      auto* handle_data = c->input_handle_shapes_and_types(0);
      c->set_output(0, c->Scalar());
      if (IsValidTensorListHandleData(handle_data)) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        // The written item must be compatible with the list's recorded
        // element shape; the list's handle data is forwarded unchanged.
        shape_inference::ShapeHandle item_shape = c->input(2);
        TF_RETURN_IF_ERROR(
            c->Merge(item_shape, list_shape_type.shape, &item_shape));
        c->set_output_handle_shapes_and_types(0, *handle_data);
      } else {
        // No usable handle data: fall back to an unknown element shape.
        const FullTypeDef& ret_types = c->ret_types();
        c->set_output_handle_shapes_and_types(
            0, std::vector<shape_inference::ShapeAndType>{
                   {c->UnknownShape(), element_dtype, ret_types.args(0)}});
      }
      return OkStatus();
    });
521
// Gathers the elements at `indices` into one tensor. The output shape is
// indices_shape + element_shape.
REGISTER_OP("TensorListGather")
    .Input("input_handle: variant")
    .Input("indices: int32")
    .Input("element_shape: int32")
    .Output("values: element_dtype")
    .Attr("element_dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      auto* handle_data = c->input_handle_shapes_and_types(0);
      shape_inference::ShapeHandle element_shape = c->UnknownShape();
      if (IsValidTensorListHandleData(handle_data)) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        element_shape = list_shape_type.shape;
        if (list_shape_type.dtype != element_dtype) {
          return errors::InvalidArgument("Expected list with element dtype ",
                                         DataTypeString(element_dtype),
                                         " but got list with element dtype ",
                                         DataTypeString(list_shape_type.dtype));
        }
      }
      // A scalar `element_shape` input is treated as "unknown shape".
      shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          2, &element_shape_input));
      TF_RETURN_IF_ERROR(
          c->Merge(element_shape, element_shape_input, &element_shape));
      // Output shape is the indices shape followed by the element shape.
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out));
      c->set_output(0, out);
      return OkStatus();
    });
554
// Builds a list by scattering slices of `tensor` to positions `indices`.
// Shape inference only records the declared element shape/dtype.
REGISTER_OP("TensorListScatter")
    .Input("tensor: element_dtype")
    .Input("indices: int32")
    .Input("element_shape: shape_type")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle element_shape;
      // A scalar `element_shape` input is treated as "unknown element shape".
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          2, &element_shape));
      const FullTypeDef& ret_types = c->ret_types();
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype, ret_types.args(0)}});
      // The list handle itself is a scalar variant tensor.
      c->set_output(0, c->Scalar());
      // NOTE(review): unlike TensorListScatterIntoExistingList, the ranks of
      // `tensor` and `indices` are not validated here.
      return OkStatus();
    });
577
// V2 of TensorListScatter with an explicit `num_elements` input. The extra
// input does not affect shape inference here.
REGISTER_OP("TensorListScatterV2")
    .Input("tensor: element_dtype")
    .Input("indices: int32")
    .Input("element_shape: shape_type")
    .Input("num_elements: int32")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle element_shape;
      // A scalar `element_shape` input is treated as "unknown element shape".
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          2, &element_shape));
      const FullTypeDef& ret_types = c->ret_types();
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype, ret_types.args(0)}});
      // The list handle itself is a scalar variant tensor.
      c->set_output(0, c->Scalar());
      return OkStatus();
    });
601
// Scatters slices of `tensor` into an existing list at positions `indices`,
// producing the updated list handle.
REGISTER_OP("TensorListScatterIntoExistingList")
    .Input("input_handle: variant")
    .Input("tensor: element_dtype")
    .Input("indices: int32")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle ignored;
      // Check that tensor is at least a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &ignored));
      // Check that indices is a vector.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &ignored));

      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle element_shape = c->UnknownShape();

      // Reuse the element shape recorded on the incoming list, if any.
      auto* handle_data = c->input_handle_shapes_and_types(0);
      if (IsValidTensorListHandleData(handle_data)) {
        TF_RETURN_IF_ERROR(VerifyHandleData(c, *handle_data, element_dtype));
        element_shape = GetElementShapeFromHandleData(*handle_data);
      }
      const FullTypeDef& ret_types = c->ret_types();
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype, ret_types.args(0)}});
      // The list handle itself is a scalar variant tensor.
      c->set_output(0, c->Scalar());
      return OkStatus();
    });
633
// Combines two lists into one output list handle. The output's element
// shape is the merge of the element shapes recorded on both inputs; when
// only one input carries handle data, that side is used for both.
REGISTER_OP("TensorListConcatLists")
    .Input("input_a: variant")
    .Input("input_b: variant")
    .Attr("element_dtype: type")
    .Output("output: variant")
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_ARRAY,
                                                        "element_dtype"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      auto input_a = c->input(0);
      auto input_b = c->input(1);
      // The two handle tensors must have compatible shapes; forward the
      // merged shape as the output.
      TF_RETURN_IF_ERROR(c->Merge(input_a, input_b, &input_a));
      c->set_output(0, input_a);

      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));

      auto* handle_data_a = c->input_handle_shapes_and_types(0);
      auto* handle_data_b = c->input_handle_shapes_and_types(1);
      bool handle_data_a_nonempty = handle_data_a && !handle_data_a->empty();
      bool handle_data_b_nonempty = handle_data_b && !handle_data_b->empty();
      // Neither input carries handle data: the element shape is unknown.
      if (!(handle_data_a_nonempty || handle_data_b_nonempty)) {
        const FullTypeDef& ret_types = c->ret_types();
        c->set_output_handle_shapes_and_types(
            0, {{c->UnknownShape(), element_dtype, ret_types.args(0)}});
        return OkStatus();
      }
      // When one side lacks handle data, fall back to the other side's entry.
      shape_inference::ShapeAndType list_shape_type_a =
          handle_data_a_nonempty ? handle_data_a->at(0) : handle_data_b->at(0);
      const shape_inference::ShapeAndType& list_shape_type_b =
          handle_data_b_nonempty ? handle_data_b->at(0) : handle_data_a->at(0);
      if (list_shape_type_a.dtype != element_dtype) {
        return errors::InvalidArgument("input_a.type != element_dtype: ",
                                       DataTypeString(list_shape_type_a.dtype),
                                       " vs. ", DataTypeString(element_dtype));
      }
      if (list_shape_type_b.dtype != element_dtype) {
        return errors::InvalidArgument("input_b.type != element_dtype: ",
                                       DataTypeString(list_shape_type_b.dtype),
                                       " vs. ", DataTypeString(element_dtype));
      }
      // The output element shape is the merge of both element shapes.
      TF_RETURN_IF_ERROR(c->Merge(list_shape_type_a.shape,
                                  list_shape_type_b.shape,
                                  &list_shape_type_a.shape));
      c->set_output_handle_shapes_and_types(0, {list_shape_type_a});
      return OkStatus();
    });
680
681} // namespace
682} // namespace tensorflow
683