/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/list_kernels.h"

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

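// Converts a shape tensor into a PartialTensorShape. The tensor must be
// either the int32/int64 scalar -1, denoting a fully unknown shape, or a
// rank-1 int32/int64 tensor whose entries are dimension sizes (with -1
// marking an unknown dimension).
//
// Illustrative examples (not exercised in this file):
//   scalar -1         -> unknown-rank PartialTensorShape
//   vector {2, -1, 3} -> [2, ?, 3]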
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
  if (t.shape() == TensorShape({})) {
    if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
        (t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1)) {
      *out = PartialTensorShape();
      return OkStatus();
    }
    return errors::InvalidArgument(
        "The only valid scalar shape tensor is the fully unknown shape "
        "specified as -1.");
  } else if (t.shape().dims() != 1) {
    return errors::InvalidArgument("Shape must be at most rank 1 but is rank ",
                                   t.shape().dims());
  }
  if (t.dtype() == DT_INT32) {
    return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
                                                t.NumElements(), out);
  } else if (t.dtype() == DT_INT64) {
    return PartialTensorShape::MakePartialShape(t.vec<int64_t>().data(),
                                                t.NumElements(), out);
  }
  return errors::InvalidArgument(
      "Expected an int32 or int64 shape tensor; found ",
      DataTypeString(t.dtype()));
}

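// Reads the element shape from input `index` and merges it with the shape
// already stored on `tensor_list`, following PartialTensorShape::MergeWith
// semantics: for example, [2, ?] merged with [?, 3] yields [2, 3], while
// incompatible shapes such as [2] vs. [3] produce an error.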
Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape) {
  TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape));
  // Check that `element_shape` and `tensor_list.element_shape` are
  // compatible and store the merged shape in `element_shape`.
  PartialTensorShape tmp = *element_shape;
  TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape));
  return OkStatus();
}

Status GetInputList(OpKernelContext* c, int index, const TensorList** list) {
  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
    return errors::InvalidArgument("Input list must be a scalar. Saw: ",
                                   c->input(index).shape().DebugString());
  }
  const TensorList* l = c->input(index).scalar<Variant>()().get<TensorList>();
  if (l == nullptr) {
    return errors::InvalidArgument(
        "Input handle is not a list. Saw: '",
        c->input(index).scalar<Variant>()().DebugString(), "'");
  }
  *list = l;
  return OkStatus();
}

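// Tries to reuse the variant tensor at `input_index` as the output at
// `output_index`. Forwarding succeeds only when the input buffer can be taken
// over and the wrapped TensorList is uniquely owned, in which case the list
// is mutated in place. Otherwise a new host-resident variant scalar is
// allocated and `input_list` is copied into it.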
Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index,
                                   int32_t output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list) {
  // Attempt to forward the input tensor to the output if possible.
  std::unique_ptr<Tensor> maybe_output = c->forward_input(
      input_index, output_index, DT_VARIANT, TensorShape{},
      c->input_memory_type(input_index), AllocatorAttributes());
  Tensor* output_tensor;
  if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
      maybe_output->NumElements() == 1) {
    output_tensor = maybe_output.get();
    TensorList* tmp_out = output_tensor->scalar<Variant>()().get<TensorList>();
    if (tmp_out == nullptr) {
      return errors::InvalidArgument(
          "Expected input ", input_index, " to be a TensorList but saw ",
          output_tensor->scalar<Variant>()().TypeName());
    }
    if (tmp_out->RefCountIsOne()) {
      // Woohoo, forwarding succeeded!
      c->set_output(output_index, *output_tensor);
      *output_list = tmp_out;
      return OkStatus();
    }
  }

  // If forwarding is not possible, allocate a new output tensor and copy
  // the `input_list` to it.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(
      c->allocate_output(output_index, {}, &output_tensor, attr));
  output_tensor->scalar<Variant>()() = input_list.Copy();

  *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
  return OkStatus();
}

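// Creates an empty TensorList. Input 0 is the element shape tensor and
// input 1 is the scalar max_num_elements; a value of -1 leaves the list
// unbounded. The list is produced as a DT_VARIANT scalar on the host.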
class EmptyTensorList : public OpKernel {
 public:
  explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& max_num_elements_t = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(max_num_elements_t.shape()),
        errors::InvalidArgument(
            "max_num_elements expected to be a scalar ",
            "but got shape: ", max_num_elements_t.shape().DebugString()));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
    TensorList empty;
    empty.element_dtype = element_dtype_;
    empty.max_num_elements = max_num_elements_t.scalar<int32>()();
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
    empty.element_shape = element_shape;
    result->scalar<Variant>()() = std::move(empty);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
                        EmptyTensorList);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

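// Appends a tensor to the end of a list after checking that its dtype and
// shape are compatible with the list and that the list is not already at
// max_num_elements. The input list is forwarded (or copied) into the output
// before the new element is pushed.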
class TensorListPushBack : public OpKernel {
 public:
  explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  ~TensorListPushBack() override {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));

    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
                errors::InvalidArgument(
                    "Tried to append a tensor with incompatible shape to a "
                    "list. Op element shape: ",
                    input.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    if (l->max_num_elements != -1) {
      OP_REQUIRES(
          c, l->tensors().size() < l->max_num_elements,
          errors::InvalidArgument("Tried to push item into a full list",
                                  " list size: ", l->tensors().size(),
                                  " max_num_elements: ", l->max_num_elements));
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors().push_back(input);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
                        TensorListPushBack);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
                        TensorListPushBack);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_DEFAULT),
                        TensorListPushBack);

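// Returns the number of elements in the list as an int32 scalar.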
class TensorListLength : public OpKernel {
 public:
  explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
  ~TensorListLength() override {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
    result->scalar<int32>()() = l->tensors().size();
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
                        TensorListLength);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
    TensorListLength);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_DEFAULT).HostMemory("length"),
    TensorListLength);

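// Returns the list's element shape as an int32 or int64 tensor: the scalar
// -1 when the rank is unknown, otherwise a rank-1 tensor whose i-th entry is
// the size of dimension i (-1 for unknown dimensions).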
class TensorListElementShape : public OpKernel {
 public:
  explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    if (l->element_shape.unknown_rank()) {
      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
      if (result->dtype() == DT_INT32) {
        result->scalar<int32>()() = -1;
      } else {
        result->scalar<int64_t>()() = -1;
      }
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(
                            0, TensorShape{l->element_shape.dims()}, &result));
      for (int i = 0; i < l->element_shape.dims(); ++i) {
        if (result->dtype() == DT_INT32) {
          result->flat<int32>()(i) = l->element_shape.dim_size(i);
        } else {
          result->flat<int64_t>()(i) = l->element_shape.dim_size(i);
        }
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
                        TensorListElementShape);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

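// Creates a list with `num_elements` uninitialized (DT_INVALID) entries and
// the given element shape and dtype. Entries are expected to be filled in
// later, e.g. with TensorListSetItem.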
class TensorListReserve : public OpKernel {
 public:
  explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
    OP_REQUIRES(
        c, TensorShapeUtils::IsScalar(c->input(1).shape()),
        errors::InvalidArgument(
            "The num_elements to reserve must be a scalar, but got shape ",
            c->input(1).shape()));
    int32_t num_elements = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, num_elements >= 0,
                errors::InvalidArgument("The num_elements to reserve must be "
                                        "a non-negative number, but got ",
                                        num_elements));
    TensorList output;
    output.element_shape = element_shape;
    output.element_dtype = element_dtype_;
    output.tensors().resize(num_elements, Tensor(DT_INVALID));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
                        TensorListReserve);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

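// Resizes a list to `size` elements. Growing appends DT_INVALID placeholder
// entries and shrinking drops elements from the end; for example, resizing a
// 3-element list to 5 appends two placeholders, while resizing it to 2 keeps
// only the first two elements. The input list is mutated in place when it can
// be forwarded.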
class TensorListResize : public OpKernel {
 public:
  explicit TensorListResize(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* input_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
    OP_REQUIRES(c, TensorShapeUtils::IsScalar(c->input(1).shape()),
                errors::InvalidArgument("size must be a scalar"));
    int32_t size = c->input(1).scalar<int32>()();
    OP_REQUIRES(
        c, size >= 0,
        errors::InvalidArgument(
            "TensorListResize expects size to be non-negative. Got: ", size));

    std::unique_ptr<Tensor> maybe_result =
        c->forward_input(0, 0, DT_VARIANT, TensorShape{},
                         c->input_memory_type(0), AllocatorAttributes());
    if (maybe_result != nullptr) {
      TensorList* out = maybe_result->scalar<Variant>()().get<TensorList>();
      if (out->RefCountIsOne()) {
        // We are able to forward the input.
        out->tensors().resize(size, Tensor(DT_INVALID));
        c->set_output(0, *maybe_result);
        return;
      }
    }

    // We were not able to forward the input. Will have to resize from scratch.
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    TensorList output_list;
    output_list.element_shape = input_list->element_shape;
    output_list.element_dtype = input_list->element_dtype;
    output_list.max_num_elements = input_list->max_num_elements;
    if (size > input_list->tensors().size()) {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().end());
      // Add DT_INVALID tensors to the end of the list if the requested size
      // is larger than the list length.
      output_list.tensors().resize(size, Tensor(DT_INVALID));
    } else {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().begin() + size);
    }
    result->scalar<Variant>()() = std::move(output_list);
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListResize").Device(DEVICE_CPU),
                        TensorListResize);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_GPU).HostMemory("size"),
    TensorListResize);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_DEFAULT).HostMemory("size"),
    TensorListResize);

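// Overwrites the element at `index` (input 1) with the tensor at input 2,
// after checking the index against the list length and the value's dtype and
// shape against the list's element dtype and shape.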
class TensorListSetItem : public OpKernel {
 public:
  explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32_t index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors().size(),
                errors::InvalidArgument("Trying to modify element ", index,
                                        " in a list with ", l->tensors().size(),
                                        " elements."));
    const Tensor& value = c->input(2);
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
                errors::InvalidArgument(
                    "Tried to set a tensor with incompatible shape at a "
                    "list index. Item element shape: ",
                    value.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors()[index] = value;
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
                        TensorListSetItem);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_GPU(T) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("index"), \
                          TensorListSetItem);

TF_CALL_GPU_ALL_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
REGISTER_TENSOR_LIST_SET_ITEM_GPU(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT(T) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT) \
                              .HostMemory("index"), \
                          TensorListSetItem);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT

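// Element-wise concatenation of two batches of lists: for every index i of
// the identically shaped variant inputs, the elements of input_b[i] are
// appended to input_a[i]. When the first input can be forwarded and every
// list in it is uniquely owned, the concatenation happens in place.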
class TensorListConcatLists : public OpKernel {
 public:
  explicit TensorListConcatLists(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorShape& tl_a_shape = c->input(0).shape();
    const TensorShape& tl_b_shape = c->input(1).shape();
    OP_REQUIRES(
        c, tl_a_shape == tl_b_shape,
        errors::InvalidArgument("Incompatible input TensorList tensor shapes: ",
                                tl_a_shape.DebugString(), " vs. ",
                                tl_b_shape.DebugString()));
    AllocatorAttributes attr;
    std::unique_ptr<Tensor> tl_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tl_a_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    // tl_a may be aliased by tl_alias.
    const Tensor& tl_a = c->input(0);
    const Tensor& tl_b = c->input(1);

    Tensor* output = nullptr;
    bool ok_to_alias = tl_alias != nullptr;
    if (tl_alias && tl_alias->dtype() == DT_VARIANT &&
        tl_alias->NumElements() > 0) {
      auto tl_a_t = tl_alias->flat<Variant>();
      for (int64_t i = 0; i < tl_alias->NumElements(); ++i) {
        TensorList* aliased = tl_a_t(i).get<TensorList>();
        if (aliased == nullptr || !aliased->RefCountIsOne()) {
          ok_to_alias = false;
          break;
        }
      }
      if (ok_to_alias) {
        c->set_output(0, *tl_alias);
        output = tl_alias.get();
      }
    }
    if (!ok_to_alias) {
      // Couldn't alias the entire Tensor. We'll be conservative and not try
      // to alias individual batch entries.
      attr.set_on_host(true);
      OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
    }

    auto output_t = output->flat<Variant>();
    auto tl_a_t = tl_a.flat<Variant>();
    auto tl_b_t = tl_b.flat<Variant>();

    for (int64_t i = 0; i < tl_a.NumElements(); ++i) {
      const TensorList* l_a = tl_a_t(i).get<TensorList>();
      const TensorList* l_b = tl_b_t(i).get<TensorList>();
      OP_REQUIRES(
          c, l_a != nullptr,
          errors::InvalidArgument("input_a is not a TensorList at index ", i,
                                  ". Saw: '", tl_a_t(i).DebugString(), "'"));
      OP_REQUIRES(
          c, l_b != nullptr,
          errors::InvalidArgument("input_b is not a TensorList at index ", i,
                                  ". Saw: '", tl_b_t(i).DebugString(), "'"));
      OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_a[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_a->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_b[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_b->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
                  errors::InvalidArgument(
                      "input_a and input_b TensorList element shapes are not "
                      "identical at index ",
                      i, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
                      l_b->element_shape.DebugString()));
      if (ok_to_alias) {
        TensorList* out = output_t(i).get<TensorList>();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out->tensors()));
      } else {
        TensorList out = l_a->Copy();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out.tensors()));
        output_t(i) = std::move(out);
      }
    }
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_CPU),
                        TensorListConcatLists);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_GPU),
                        TensorListConcatLists);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_DEFAULT),
                        TensorListConcatLists);

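// The element-typed kernels below (Stack, Gather, Concat, GetItem, PopBack,
// FromTensor, Scatter, Split, PushBackBatch) are templated on device and
// element type and are implemented in list_kernels.h; this macro instantiates
// their CPU registrations for a single element type T.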
#define REGISTER_TENSOR_LIST_OPS_CPU(T) \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListStack<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListGather<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListConcat<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListConcat<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListGetItem<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListPopBack<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListFromTensor<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListScatter<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListScatter<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListSplit<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListPushBackBatch<CPUDevice, T>)

TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_OPS_CPU);
REGISTER_TENSOR_LIST_OPS_CPU(quint8);
REGISTER_TENSOR_LIST_OPS_CPU(qint8);
REGISTER_TENSOR_LIST_OPS_CPU(quint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint32);
REGISTER_TENSOR_LIST_OPS_CPU(Variant);

#undef REGISTER_TENSOR_LIST_OPS_CPU

#define REGISTER_TENSOR_LIST_OPS_CPU(T)

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          TensorList,
                                          TensorListBinaryAdd<CPUDevice>);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, TensorList,
                                         TensorListZerosLike<CPUDevice>);

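// Copies a TensorList between devices when the DT_VARIANT tensor holding it
// is transferred. Metadata is copied directly; each initialized element is
// copied with the provided async device-copy function, while DT_INVALID
// placeholders are recreated without copying any data.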
static Status TensorListDeviceCopy(
    const TensorList& from, TensorList* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  to->element_shape = from.element_shape;
  to->element_dtype = from.element_dtype;
  to->max_num_elements = from.max_num_elements;
  to->tensors().reserve(from.tensors().size());
  for (const Tensor& t : from.tensors()) {
    to->tensors().emplace_back(t.dtype());
    if (t.dtype() != DT_INVALID) {
      TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
    }
  }
  return OkStatus();
}

#define REGISTER_LIST_COPY(DIRECTION) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
                                                       TensorListDeviceCopy)

REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);

#define REGISTER_TENSOR_LIST_OPS_DEFAULT(T) \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
                              .TypeConstraint<T>("element_dtype") \
                              .HostMemory("element_shape") \
                              .Device(DEVICE_DEFAULT), \
                          TensorListStack<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
                              .TypeConstraint<T>("element_dtype") \
                              .HostMemory("indices") \
                              .HostMemory("element_shape") \
                              .Device(DEVICE_DEFAULT), \
                          TensorListGather<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat") \
                              .TypeConstraint<T>("element_dtype") \
                              .HostMemory("lengths") \
                              .Device(DEVICE_DEFAULT), \
                          TensorListConcat<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2") \
                              .TypeConstraint<T>("element_dtype") \
                              .HostMemory("leading_dims") \
                              .HostMemory("element_shape") \
                              .HostMemory("lengths") \
                              .Device(DEVICE_DEFAULT), \
                          TensorListConcat<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT) \
                              .HostMemory("index") \
                              .HostMemory("element_shape"), \
                          TensorListGetItem<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT) \
                              .HostMemory("element_shape"), \
                          TensorListPopBack<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT), \
                          TensorListPushBackBatch<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT) \
                              .HostMemory("element_shape"), \
                          TensorListFromTensor<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT) \
                              .HostMemory("element_shape") \
                              .HostMemory("indices"), \
                          TensorListScatter<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT) \
                              .HostMemory("element_shape") \
                              .HostMemory("num_elements") \
                              .HostMemory("indices"), \
                          TensorListScatter<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT) \
                              .HostMemory("indices"), \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT) \
                              .HostMemory("element_shape") \
                              .HostMemory("lengths"), \
                          TensorListSplit<CPUDevice, T>)

TF_CALL_int32(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_int64(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_bfloat16(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_OPS_DEFAULT);

#undef REGISTER_TENSOR_LIST_OPS_DEFAULT

}  // namespace tensorflow