/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/tensor_list.h"
#include "tensorflow/core/kernels/tensor_list_util.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/util/tensor_ops_util.h"
#include "tensorflow/core/util/util.h"

// stream.h isn't available on some platforms, such as Android, iOS, and
// ChromiumOS. Only include it for platforms that PluggableDevice is tested on.
#if !defined(PLUGGABLE_DEVICE_SUPPORTED) &&                              \
    (__x86_64__ || __i386__ || defined(__APPLE__) || defined(_WIN32)) && \
    !defined(ANDROID) && !defined(__ANDROID__) && !TARGET_OS_IOS &&      \
    !defined(PLATFORM_CHROMIUMOS)
#define PLUGGABLE_DEVICE_SUPPORTED
#endif

#ifdef PLUGGABLE_DEVICE_SUPPORTED
#include "tensorflow/compiler/xla/stream_executor/stream.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

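// Helpers shared by the list kernels below (declared here, implemented
// elsewhere):
// - TensorShapeFromTensor builds a PartialTensorShape from a shape tensor.
// - GetElementShapeFromInput reads the `element_shape` input at `index` and
//   merges it with the element shape recorded in `tensor_list`.
// - GetInputList verifies that input `index` holds a TensorList variant and
//   returns a pointer to it.
// - ForwardInputOrCreateNewList forwards the input list to the output when
//   possible and otherwise creates a copy that can be mutated.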
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);

Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape);

Status GetInputList(OpKernelContext* c, int index, const TensorList** list);

Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index,
                                   int32_t output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list);

// TODO(penporn): Move this to a proper place.
inline bool IsPluggableDevice(OpKernelContext* c) {
  return c->op_device_context() && c->op_device_context()->IsPluggableDevice();
}

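// Zero-fills `tensor`. On pluggable devices this issues a StreamExecutor
// MemZero on the op's stream; otherwise it uses SetZeroFunctor on the Eigen
// device.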
template <typename Device, typename T>
inline void SetZero(OpKernelContext* ctx, Tensor& tensor) {
#ifdef PLUGGABLE_DEVICE_SUPPORTED
  if (IsPluggableDevice(ctx)) {
    auto ptr =
        se::DeviceMemoryBase(tensor.flat<T>().data(), tensor.TotalBytes());
    auto stream = ctx->op_device_context()->stream();
    auto result = stream->ThenMemZero(&ptr, tensor.TotalBytes()).ok();
    DCHECK_EQ(true, result);
  } else {
#endif  // PLUGGABLE_DEVICE_SUPPORTED
    functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                         tensor.flat<T>());
#ifdef PLUGGABLE_DEVICE_SUPPORTED
  }
#endif  // PLUGGABLE_DEVICE_SUPPORTED
}

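// Copies `src` into `dst` with a StreamExecutor memcpy on the op's stream.
// Only meaningful when PLUGGABLE_DEVICE_SUPPORTED is defined; otherwise it
// crashes, since pluggable devices cannot exist on such platforms.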
template <typename T>
inline void CopyTensorPluggableDevice(OpKernelContext* ctx, Tensor& src,
                                      Tensor& dst) {
#ifdef PLUGGABLE_DEVICE_SUPPORTED
  auto src_t = src.unaligned_flat<T>();
  auto dst_t = dst.flat<T>();
  DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));
  auto src_ptr = se::DeviceMemoryBase(src_t.data(), src.TotalBytes());
  auto dst_ptr = se::DeviceMemoryBase(dst_t.data(), dst.TotalBytes());
  auto stream = ctx->op_device_context()->stream();
  auto result = stream->ThenMemcpy(&dst_ptr, src_ptr, src.TotalBytes()).ok();
  DCHECK_EQ(true, result);
#else
  LOG(FATAL)  // Crash OK.
      << "PluggableDevice is not supported on this platform.";
#endif  // PLUGGABLE_DEVICE_SUPPORTED
}

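// Copies `src` into `dst` through an Eigen device assignment. `src` may be
// unaligned (e.g. a Tensor::Slice), which is why it is read via
// unaligned_flat.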
template <typename Device, typename T>
inline void CopyTensor(OpKernelContext* ctx, Tensor& src, Tensor& dst) {
  auto src_t = src.unaligned_flat<T>();
  auto dst_t = dst.flat<T>();
  dst_t.device(ctx->eigen_device<Device>()) = src_t;
}

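// Concatenates the flattened `inputs` matrices into `output` on a pluggable
// device by issuing one StreamExecutor memcpy per (row, input) pair, so that
// the rows of the inputs are interleaved into the output row by row.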
template <typename T>
void ConcatPluggableDevice(
    OpKernelContext* context,
    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
        inputs,
    typename TTypes<T, 2>::Matrix* output) {
#ifdef PLUGGABLE_DEVICE_SUPPORTED
  DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));

  se::Stream* stream = context->op_device_context()->stream();

  size_t num_inputs = inputs.size();
  std::vector<ptrdiff_t> sizes;
  sizes.reserve(num_inputs);
  int64 row_size = 0;
  for (const auto& input : inputs) {
    sizes.push_back(input->dimension(1));
    row_size += sizes.back();
  }

  T* out = &(*output)(0, 0);
  std::vector<const T*> inp;
  inp.reserve(num_inputs);
  for (const auto& input : inputs) {
    inp.push_back(&(*input)(0, 0));
  }
  const int64 dim0 = output->dimension(0);
  for (int64 i = 0; i < dim0; ++i) {
    for (int64 j = 0; j < num_inputs; ++j) {
      auto size = sizes[j];
      se::DeviceMemoryBase out_base{out, size * sizeof(T)};
      se::DeviceMemoryBase inp_base{const_cast<T*>(inp[j]), size * sizeof(T)};
      stream->ThenMemcpy(&out_base, inp_base, size * sizeof(T));
      out += size;
      inp[j] += size;
    }
  }
#else
  LOG(FATAL)  // Crash OK.
      << "PluggableDevice is not supported on this platform.";
#endif  // PLUGGABLE_DEVICE_SUPPORTED
}

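// Stacks all elements of a TensorList into one tensor with an added leading
// dimension of size tensor_list->tensors().size(). Uninitialized elements are
// materialized as zero tensors of the (fully defined) element shape.
//
// The kernels in this header are templates; they are instantiated and
// registered elsewhere. A minimal registration sketch (assuming a CPU build
// and a float element type; the real registrations use type-iterating macros
// and may differ):
//
//   REGISTER_KERNEL_BUILDER(Name("TensorListStack")
//                               .Device(DEVICE_CPU)
//                               .TypeConstraint<float>("element_dtype"),
//                           TensorListStack<CPUDevice, float>);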
template <typename Device, typename T>
class TensorListStack : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;
  explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
    OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    if (num_elements_ != -1) {
      OP_REQUIRES(c, tensor_list->tensors().size() == num_elements_,
                  errors::InvalidArgument(
                      "Operation expected a list with ", num_elements_,
                      " elements but got a list with ",
                      tensor_list->tensors().size(), " elements."));
    }
    PartialTensorShape partial_element_shape;
    OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1,
                                               &partial_element_shape));
    OP_REQUIRES(
        c,
        partial_element_shape.IsFullyDefined() ||
            !tensor_list->tensors().empty(),
        errors::InvalidArgument("Tried to stack elements of an empty ",
                                "list with non-fully-defined element_shape: ",
                                partial_element_shape.DebugString()));

    // Check that the `element_shape` input tensor is compatible with the
    // shapes of the element tensors.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = partial_element_shape;
          OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
        }
      }
    }

    // Compute the shape of the output tensor by prepending the leading dim to
    // the element_shape.
    TensorShape element_shape;
    OP_REQUIRES(c, partial_element_shape.AsTensorShape(&element_shape),
                errors::InvalidArgument(
                    "Tried to stack list which only contains uninitialized ",
                    "tensors and has a non-fully-defined element_shape: ",
                    partial_element_shape.DebugString()));
    TensorShape output_shape = element_shape;
    output_shape.InsertDim(0, tensor_list->tensors().size());
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(tensor_list->tensors().size());
    Tensor zeros;
    for (const auto& t : tensor_list->tensors()) {
      if (t.dtype() != DT_INVALID) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            t.shaped<T, 2>({1, t.NumElements()})));
      } else {
        if (!zeros.NumElements()) {
          AllocatorAttributes attr;
          if (element_dtype_ == DT_VARIANT) {
            attr.set_on_host(true);
          }
          OP_REQUIRES_OK(
              c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
          SetZero<Device, T>(c, zeros);
        }
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (IsPluggableDevice(c)) {
      ConcatPluggableDevice<T>(c, inputs_flat, &output_flat);
    } else {
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  int num_elements_;
  DataType element_dtype_;
};

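// Returns the element at `index` of a TensorList. If that element is
// uninitialized, returns a zero tensor whose shape is inferred from the
// element_shape input, the list's element_shape, or the other initialized
// elements.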
template <typename Device, typename T>
class TensorListGetItem : public OpKernel {
 public:
  explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32_t index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors().size(),
                errors::InvalidArgument("Trying to access element ", index,
                                        " in a list with ", l->tensors().size(),
                                        " elements."));
    if (l->tensors()[index].dtype() != DT_INVALID) {
      c->set_output(0, l->tensors()[index]);
    } else {
      PartialTensorShape partial_element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *l, 2, &partial_element_shape));
      TensorShape element_shape;
      // If l->element_shape and the element_shape input are both not fully
      // defined, try to infer the shape from other list elements. This
      // requires that all initialized list elements have the same shape.
      // NOTE(srbs): This might be a performance bottleneck since we are
      // iterating over the entire list here. This is necessary for feature
      // parity with TensorArray.read. TensorArray has a mode in which all
      // elements are required to be of the same shape; TensorList does not.
      // In that mode TensorArray sets the array's element_shape on the first
      // write call. We could do something similar here if needed.
      if (!partial_element_shape.IsFullyDefined()) {
        for (const Tensor& t : l->tensors()) {
          if (t.dtype() != DT_INVALID) {
            PartialTensorShape tmp = partial_element_shape;
            OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
          }
        }
      }
      OP_REQUIRES(
          c, partial_element_shape.AsTensorShape(&element_shape),
          errors::InvalidArgument("Trying to read an uninitialized tensor but ",
                                  "element_shape is not fully defined: ",
                                  partial_element_shape.DebugString(),
                                  " and no list element is set."));
      Tensor* result;
      AllocatorAttributes attr;
      if (element_dtype_ == DT_VARIANT) {
        attr.set_on_host(true);
      }
      OP_REQUIRES_OK(c, c->allocate_output(0, element_shape, &result, attr));
      SetZero<Device, T>(c, *result);
    }
  }

 private:
  DataType element_dtype_;
};

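// Removes the last element of a TensorList, returning both the shortened list
// and the removed element. An uninitialized last element is returned as a
// zero tensor of the element shape.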
template <typename Device, typename T>
class TensorListPopBack : public OpKernel {
 public:
  explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    OP_REQUIRES(c, !l->tensors().empty(),
                errors::InvalidArgument("Trying to pop from an empty list."));

    const Tensor& t = l->tensors().back();
    if (t.dtype() != DT_INVALID) {
      c->set_output(1, t);
    } else {
      PartialTensorShape partial_element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *l, 1, &partial_element_shape));
      TensorShape element_shape;
      OP_REQUIRES(
          c, partial_element_shape.AsTensorShape(&element_shape),
          errors::InvalidArgument("Trying to read an uninitialized tensor but ",
                                  "element_shape is not fully defined: ",
                                  partial_element_shape.DebugString()));
      Tensor* result;
      AllocatorAttributes attr;
      if (element_dtype_ == DT_VARIANT) {
        attr.set_on_host(true);
      }
      OP_REQUIRES_OK(c, c->allocate_output(1, element_shape, &result, attr));
      SetZero<Device, T>(c, *result);
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors().pop_back();
  }

 private:
  DataType element_dtype_;
};

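// Concatenates all list elements along their first dimension into a single
// tensor and also outputs a `lengths` vector holding each element's leading
// dimension. Uninitialized elements are replaced with zero tensors whose
// leading dimension comes from the `leading_dims` input or is inferred from
// the initialized elements.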
template <typename Device, typename T>
class TensorListConcat : public OpKernel {
 public:
  using ConstMatrixVector =
      std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>;
  explicit TensorListConcat(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
    if (c->HasAttr("element_shape")) {
      OP_REQUIRES_OK(c, c->GetAttr("element_shape", &element_shape_));
    }
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape_except_first_dim;
    if (!element_shape_.unknown_rank()) {
      element_shape_except_first_dim = PartialTensorShape(
          gtl::ArraySlice<int64_t>(element_shape_.dim_sizes()).subspan(1));
    }
    // Check that the input Variant tensor is indeed a TensorList and has the
    // correct element type.
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    // The leading dimension shared by all list elements, if they all have the
    // same one. This is used as the leading dim of uninitialized tensors in
    // the list if leading_dims is not provided.
    int64_t first_dim = -1;
    if (c->num_inputs() > 1) {
      // TensorListConcatV2
      PartialTensorShape element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *tensor_list, 1, &element_shape));
      OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
                  errors::InvalidArgument(
                      "Concat requires elements to be at least vectors, ",
                      "found scalars instead."));
      // Split `element_shape` into `first_dim` and
      // `element_shape_except_first_dim`.
      first_dim = element_shape.dim_size(0);
      element_shape_except_first_dim = element_shape;
      element_shape_except_first_dim.RemoveDim(0);
    }
    // If the TensorList is empty, element_shape_except_first_dim must be fully
    // defined.
    OP_REQUIRES(c,
                !tensor_list->tensors().empty() ||
                    element_shape_except_first_dim.IsFullyDefined(),
                errors::InvalidArgument(
                    "All except the first dimension must be fully defined ",
                    "when concatenating an empty tensor list. element_shape: ",
                    element_shape_except_first_dim.DebugString()));
    // 1. Check that the `element_shape_except_first_dim` input tensor is
    //    compatible with the shapes of the element tensors.
    // 2. Check that the elements have the same shape except the first dim.
    // 3. If `first_dim` is known, check that it is compatible with the leading
    //    dims of all elements.
    // 4. If `first_dim` is unknown (-1), check whether all initialized
    //    elements have the same leading dim and if so set `first_dim` to that
    //    value.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      bool check_dim = (first_dim == -1);
      int64_t inferred_first_dim = first_dim;
      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = element_shape_except_first_dim;
          OP_REQUIRES(
              c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
              errors::InvalidArgument("Concat saw a scalar shape at index ", i,
                                      " but requires at least vectors."));
          TensorShape shape_except_first_dim = TensorShape(
              gtl::ArraySlice<int64_t>(t.shape().dim_sizes()).subspan(1));
          OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
                                          &element_shape_except_first_dim));
          OP_REQUIRES(c, first_dim == -1 || first_dim == t.shape().dim_size(0),
                      errors::InvalidArgument(
                          "First entry of element_shape input does not match ",
                          "the first dim of list element at index: ", i,
                          " Expected: ", first_dim,
                          " Actual: ", t.shape().dim_size(0)));
          if (check_dim) {
            if (inferred_first_dim == -1) {
              inferred_first_dim = t.shape().dim_size(0);
            } else if (inferred_first_dim != t.shape().dim_size(0)) {
              inferred_first_dim = -1;
              check_dim = false;
            }
          }
        }
      }
      first_dim = inferred_first_dim;
    }
    TensorShape output_shape;
    OP_REQUIRES(c, element_shape_except_first_dim.AsTensorShape(&output_shape),
                errors::InvalidArgument(
                    "Trying to concat list with only uninitialized tensors ",
                    "but element_shape_except_first_dim is not fully defined: ",
                    element_shape_except_first_dim.DebugString()));
    // Build the lengths_tensor and the leading dim of the output tensor by
    // iterating over all element tensors.
    Tensor* lengths_tensor = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(1,
                                         TensorShape({static_cast<int64_t>(
                                             tensor_list->tensors().size())}),
                                         &lengths_tensor));
    auto lengths_tensor_vec = lengths_tensor->vec<int64_t>();
    int64_t leading_dim = 0;
    for (size_t i = 0; i < tensor_list->tensors().size(); i++) {
      int64_t dim;
      if (tensor_list->tensors()[i].dtype() != DT_INVALID) {
        dim = tensor_list->tensors()[i].shape().dim_size(0);
      } else {
        // If leading_dims is not provided or does not contain an entry for
        // index i, use the inferred `first_dim` if set.
        if ((c->num_inputs() <= 2 || i >= c->input(2).NumElements()) &&
            first_dim != -1) {
          dim = first_dim;
        } else {
          OP_REQUIRES(c, c->num_inputs() > 2,
                      errors::InvalidArgument(
                          "Concatenating lists with uninitialized tensors is ",
                          "not supported in this version of TensorListConcat. ",
                          "Consider updating your GraphDef to run the newer ",
                          "version."));
          OP_REQUIRES(c, i < c->input(2).NumElements(),
                      errors::InvalidArgument(
                          "List contains uninitialized tensor at index ", i,
                          " but leading_dims has only ",
                          c->input(2).NumElements(), " elements."));
          dim = c->input(2).vec<int64_t>()(i);
        }
      }
      leading_dim += dim;
      lengths_tensor_vec(i) = dim;
    }
    output_shape.InsertDim(0, leading_dim);
    Tensor* output;
    // Allocate the output tensor and fill it up with the concatenated element
    // tensors.
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(tensor_list->tensors().size());
    // Store the zeros tensors in a vector to prevent them from being GC'ed
    // until concat is complete.
    std::vector<Tensor> zeros_vec;
    for (int i = 0; i < tensor_list->tensors().size(); i++) {
      const Tensor& element_tensor = tensor_list->tensors()[i];
      if (element_tensor.dtype() != DT_INVALID) {
        if (element_tensor.NumElements() > 0) {
          inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
              element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
        }
      } else {
        AllocatorAttributes attr;
        if (element_dtype_ == DT_VARIANT) {
          attr.set_on_host(true);
        }
        TensorShape element_shape = output_shape;
        element_shape.set_dim(0, lengths_tensor_vec(i));
        zeros_vec.emplace_back();
        Tensor& zeros = zeros_vec.back();
        OP_REQUIRES_OK(
            c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
        SetZero<Device, T>(c, zeros);
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (IsPluggableDevice(c)) {
      ConcatPluggableDevice<T>(c, inputs_flat, &output_flat);
    } else {
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  DataType element_dtype_;
  PartialTensorShape element_shape_;
};

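// Splits a tensor along its first dimension into a TensorList, where the
// `lengths` input gives the leading dimension of each produced element. Each
// slice is copied into an aligned temporary before being stored in the list.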
template <typename Device, typename T>
class TensorListSplit : public OpKernel {
 public:
  TensorListSplit(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
    OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
                errors::InvalidArgument(
                    "TensorListSplit requires element_shape to be at least of ",
                    "rank 1, but saw: ", element_shape.DebugString()));
    TensorList output_list;
    const Tensor& input_tensor = c->input(0);
    output_list.element_dtype = input_tensor.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    TensorShape tensor_shape_without_first_dim(input_tensor.shape());
    tensor_shape_without_first_dim.RemoveDim(0);
    PartialTensorShape element_shape_without_first_dim;
    if (!element_shape.unknown_rank()) {
      element_shape_without_first_dim =
          PartialTensorShape(element_shape.dim_sizes());
      element_shape_without_first_dim.RemoveDim(0);
    }
    OP_REQUIRES(c,
                element_shape_without_first_dim.IsCompatibleWith(
                    tensor_shape_without_first_dim),
                errors::InvalidArgument(
                    "tensor shape ", input_tensor.shape().DebugString(),
                    " is not compatible with element_shape ",
                    element_shape.DebugString()));
    output_list.element_shape = element_shape;
    const Tensor& lengths = c->input(2);
    OP_REQUIRES(c, TensorShapeUtils::IsVector(lengths.shape()),
                errors::InvalidArgument(
                    "Expected lengths to be a vector, received shape: ",
                    lengths.shape().DebugString()));
    output_list.tensors().reserve(lengths.shape().dim_size(0));

    const auto copy_tensor = IsPluggableDevice(c)
                                 ? &CopyTensorPluggableDevice<T>
                                 : &CopyTensor<Device, T>;

    int64_t start = 0;
    int64_t end = 0;
    for (int i = 0; i < lengths.shape().dim_size(0); ++i) {
      int64_t length = lengths.vec<int64_t>()(i);
      OP_REQUIRES(
          c, length >= 0,
          errors::InvalidArgument("Invalid value in lengths: ", length));
      end = start + length;
      OP_REQUIRES(c, end <= input_tensor.shape().dim_size(0),
                  errors::InvalidArgument("Attempting to slice [", start, ", ",
                                          end, "] from tensor with length ",
                                          input_tensor.shape().dim_size(0)));
      Tensor tmp = input_tensor.Slice(start, end);
      start = end;
      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
      // prevent this.
      Tensor aligned;
      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
      copy_tensor(c, tmp, aligned);
      output_list.tensors().emplace_back(aligned);
    }
    OP_REQUIRES(c, end == input_tensor.shape().dim_size(0),
                errors::InvalidArgument(
                    "Unused values in tensor. Length of tensor: ",
                    input_tensor.shape().dim_size(0), " Values used: ", end));
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

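// Gathers the list elements selected by the `indices` input and stacks them
// into one tensor with a leading dimension of indices.NumElements().
// Uninitialized elements are gathered as zero tensors of the element shape.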
template <typename Device, typename T>
class TensorListGather : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;
  explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    const Tensor& indices = c->input(1);
    PartialTensorShape partial_element_shape;
    OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 2,
                                               &partial_element_shape));
    OP_REQUIRES(
        c, partial_element_shape.IsFullyDefined() || indices.NumElements() > 0,
        errors::InvalidArgument("Tried to gather 0-elements from "
                                "a list with non-fully-defined shape: ",
                                partial_element_shape.DebugString()));

    // Check that the `element_shape` input tensor is compatible with the
    // shapes of the element tensors.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      for (int index = 0; index < indices.NumElements(); ++index) {
        const int i = indices.flat<int32>()(index);
        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = partial_element_shape;
          OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
        }
      }
    }

    // Compute the shape of the output tensor by prepending the leading dim to
    // the element_shape.
    TensorShape element_shape;
    OP_REQUIRES(
        c, partial_element_shape.AsTensorShape(&element_shape),
        errors::InvalidArgument("Tried to gather uninitialized tensors from a ",
                                "list with non-fully-defined element_shape: ",
                                partial_element_shape.DebugString()));
    TensorShape output_shape = element_shape;
    output_shape.InsertDim(0, indices.NumElements());
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(indices.NumElements());
    Tensor zeros;
    for (int index = 0; index < indices.NumElements(); ++index) {
      const int i = indices.flat<int32>()(index);
      OP_REQUIRES(
          c, i < tensor_list->tensors().size(),
          errors::InvalidArgument("Index ", i, " out of range; list only has ",
                                  tensor_list->tensors().size(), " elements."));
      const Tensor& t = tensor_list->tensors()[i];
      if (t.dtype() != DT_INVALID) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            t.shaped<T, 2>({1, t.NumElements()})));
      } else {
        if (!zeros.NumElements()) {
          AllocatorAttributes attr;
          if (element_dtype_ == DT_VARIANT) {
            attr.set_on_host(true);
          }
          OP_REQUIRES_OK(
              c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
          SetZero<Device, T>(c, zeros);
        }
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (IsPluggableDevice(c)) {
      ConcatPluggableDevice<T>(c, inputs_flat, &output_flat);
    } else {
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  DataType element_dtype_;
};

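// Converts a tensor into a TensorList by slicing it along its first
// dimension: element i of the list is tensor[i], copied into an aligned
// temporary.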
template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
 public:
  TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    PartialTensorShape element_shape;
    OP_REQUIRES(
        c, !TensorShapeUtils::IsMatrixOrHigher(c->input(1).shape()),
        errors::InvalidArgument(
            "TensorListFromTensor: element_shape must be at most rank 1 but ",
            "has the shape of ", c->input(1).shape().DebugString()));
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
    TensorList output_list;
    const Tensor& t = c->input(0);
    output_list.element_dtype = t.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    t.shape().DebugString()));
    TensorShape output_shape(t.shape());
    output_shape.RemoveDim(0);
    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
                errors::InvalidArgument(
                    "Specified a list with shape ", element_shape.DebugString(),
                    " from a tensor with shape ", output_shape.DebugString()));
    output_list.element_shape = element_shape;
    output_list.tensors().reserve(t.shape().dim_size(0));

    const auto copy_tensor = IsPluggableDevice(c)
                                 ? &CopyTensorPluggableDevice<T>
                                 : &CopyTensor<Device, T>;

    for (int i = 0; i < t.shape().dim_size(0); ++i) {
      Tensor tmp = t.Slice(i, i + 1);
      TensorShape tmp_shape = tmp.shape();
      tmp_shape.RemoveDim(0);
      OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
                  errors::Unknown("Unexpected shape error."));
      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
      // prevent this.
      Tensor aligned;
      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
      copy_tensor(c, tmp, aligned);
      output_list.tensors().push_back(aligned);
    }
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

// Scatters values in `value` into `list`. Assumes that `indices` are valid.
template <typename Device, typename T>
Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices,
               TensorList* list) {
  const auto copy_tensor = IsPluggableDevice(c) ? &CopyTensorPluggableDevice<T>
                                                : &CopyTensor<Device, T>;
  for (int index = 0; index < indices.NumElements(); ++index) {
    const int i = indices.flat<int32>()(index);
    Tensor tmp = value.Slice(index, index + 1);
    TensorShape tmp_shape = tmp.shape();
    tmp_shape.RemoveDim(0);
    if (!tmp.CopyFrom(tmp, tmp_shape)) {
      return errors::Unknown("Unexpected shape error.");
    }
    // TODO(apassos) maybe not always align; but weird compiler bugs seem to
    // prevent this.
    Tensor aligned;
    TF_RETURN_IF_ERROR(c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
    // TODO(apassos) do all slices in a single kernel invocation instead of
    // many small ones.
    copy_tensor(c, tmp, aligned);
    std::swap(list->tensors()[i], aligned);
  }
  return OkStatus();
}

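// Scatters rows of the input tensor into an existing TensorList at the given
// indices, growing the list if the largest index does not fit.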
template <typename Device, typename T>
class TensorListScatterIntoExistingList : public OpKernel {
 public:
  TensorListScatterIntoExistingList(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    const Tensor& input_tensor = c->input(1);
    const Tensor& indices = c->input(2);

    // Check that inputs are valid.
    OP_REQUIRES(c, input_tensor.dtype() == l->element_dtype,
                errors::InvalidArgument(
                    "Invalid data types; input tensor type: ",
                    DataTypeString(input_tensor.dtype()),
                    " list element_type: ", DataTypeString(l->element_dtype)));
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    OP_REQUIRES(c, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument(
                    "Expected indices to be a vector, but received shape: ",
                    indices.shape().DebugString()));
    OP_REQUIRES(
        c, indices.NumElements() == input_tensor.shape().dim_size(0),
        errors::InvalidArgument(
            "Expected len(indices) == tensor.shape[0], but saw: ",
            indices.NumElements(), " vs. ", input_tensor.shape().dim_size(0)));

    // Resize the list if needed to accommodate all indices.
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    const auto indices_vec = indices.vec<int32>();
    int32_t max_index =
        (indices.NumElements() == 0)
            ? -1
            : *std::max_element(indices_vec.data(),
                                indices_vec.data() + indices.NumElements());
    if (max_index + 1 > output_list->tensors().size()) {
      output_list->tensors().resize(max_index + 1);
    }

    // Scatter the values.
    OP_REQUIRES_OK(c,
                   Scatter<Device, T>(c, input_tensor, indices, output_list));
  }
};

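// Creates a new TensorList and scatters rows of the input tensor into it at
// the given indices. TensorListScatterV2 additionally takes a num_elements
// input that fixes the minimum size of the resulting list.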
template <typename Device, typename T>
class TensorListScatter : public OpKernel {
 public:
  TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    Tensor indices = c->input(1);
    PartialTensorShape element_shape;
    OP_REQUIRES(
        c, !TensorShapeUtils::IsMatrixOrHigher(c->input(2).shape()),
        errors::InvalidArgument(
            "TensorListScatter: element_shape must be at most rank 1 but has ",
            "the shape of ", c->input(2).shape().DebugString()));
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
    // TensorListScatterV2 passes the num_elements input; TensorListScatter
    // does not.
    int num_elements = c->num_inputs() >= 4 ? c->input(3).scalar<int>()() : -1;
    OP_REQUIRES(c, num_elements >= -1,
                errors::InvalidArgument(
                    "TensorListScatter expects num_elements >= -1, found: ",
                    num_elements));
    TensorList output_list;
    const Tensor& input_tensor = c->input(0);
    output_list.element_dtype = input_tensor.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    TensorShape output_shape(input_tensor.shape());
    output_shape.RemoveDim(0);
    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
                errors::InvalidArgument(
                    "Specified a list with shape ", element_shape.DebugString(),
                    " from a tensor with shape ", output_shape.DebugString()));
    output_list.element_shape = element_shape;

    OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0),
                errors::InvalidArgument(
                    "Invalid number of rows in input tensor. Expected: ",
                    indices.NumElements(),
                    " Actual: ", input_tensor.shape().dim_size(0)));

    // Validate indices and resize output_list.tensors to fit the highest
    // index.
    {
      int highest_index = -1;
      for (int index = 0; index < indices.NumElements(); ++index) {
        const int i = indices.flat<int32>()(index);
        OP_REQUIRES(
            c, i >= 0,
            errors::InvalidArgument(
                "Indices in TensorListScatter must all be non-negative."));
        OP_REQUIRES(c, num_elements == -1 || i < num_elements,
                    errors::InvalidArgument(
                        "TensorListScatter: Trying to scatter at index ", i,
                        " in list with size ", num_elements));
        if (i > highest_index) {
          highest_index = i;
        }
      }
      output_list.tensors().resize(std::max(highest_index + 1, num_elements),
                                   Tensor(DT_INVALID));
    }

    OP_REQUIRES_OK(c,
                   Scatter<Device, T>(c, input_tensor, indices, &output_list));
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

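// Adapters that implement element-wise add and zeros-like for TensorList
// variants in terms of the per-tensor helpers BinaryAddTensors and
// ZerosLikeTensor for the given Device.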
template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
                           const TensorList& b, TensorList* out) {
  return TensorListBinaryAdd(c, a, b, out, BinaryAddTensors<Device>);
}

template <typename Device>
Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
                           TensorList* y) {
  return TensorListZerosLike(c, x, y, ZerosLikeTensor<Device>);
}

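// Appends one row of the input tensor to each list in a batch of TensorLists.
// Input 0 is a DT_VARIANT vector of lists; input 1 is a tensor whose leading
// dimension equals the batch size. The variant input is aliased in place when
// every list in it is uniquely owned; otherwise the lists are copied.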
template <typename Device, typename T>
class TensorListPushBackBatch : public OpKernel {
 public:
  explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument(
                    "Expected tensor to be at least a vector, but saw shape: ",
                    input.shape().DebugString()));

    const TensorShape& tls_shape = c->input(0).shape();

    // For purposes of input forwarding, we want the least restrictive
    // AllocatorAttributes possible. If we need to allocate later,
    // we'll request the DT_VARIANT be allocated on host.
    AllocatorAttributes attr;

    std::unique_ptr<Tensor> tls_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    bool ok_to_alias = tls_alias != nullptr;
    if (tls_alias && tls_alias->dtype() == DT_VARIANT &&
        tls_alias->NumElements() > 0) {
      auto alias_t = tls_alias->flat<Variant>();
      for (int i = 0; i < tls_alias->NumElements(); ++i) {
        TensorList* tl_i = alias_t(i).get<TensorList>();
        if (tl_i == nullptr || !tl_i->RefCountIsOne()) {
          ok_to_alias = false;
          break;
        }
      }
    }
    const Tensor& tls = ok_to_alias ? *tls_alias : c->input(0);

    OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Expected input_handles dtype to be Variant, but saw: ",
                    DataTypeString(tls.dtype())));
    OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
                errors::InvalidArgument(
                    "Expected input_handles to be a vector, but saw shape: ",
                    tls_shape.DebugString()));
    const int64_t batch_size = tls.NumElements();
    OP_REQUIRES(c, input.dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "Expected tensor.shape[0] == input_handles.size, but saw ",
                    input.dim_size(0), " vs. ", batch_size));
    auto tls_t = tls.vec<Variant>();

    TensorShape input_element_shape = input.shape();
    input_element_shape.RemoveDim(0);
    std::vector<const TensorList*> tl_batch;
    for (int64_t b = 0; b < batch_size; ++b) {
      const TensorList* l = tls_t(b).get<TensorList>();
      OP_REQUIRES(c, l != nullptr,
                  errors::InvalidArgument("Input handle at index ", b,
                                          " is not a list. Saw: '",
                                          tls_t(b).DebugString(), "'"));
      OP_REQUIRES(
          c, l->element_shape.IsCompatibleWith(input_element_shape),
          errors::InvalidArgument(
              "Tried to append a tensor with incompatible shape to a "
              "list at index ",
              b, ". Op element shape: ", input_element_shape.DebugString(),
              " list shape: ", l->element_shape.DebugString()));
      OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                  errors::InvalidArgument(
                      "Invalid data type at index ", b, "; op elements ",
                      DataTypeString(element_dtype_), " but list elements ",
                      DataTypeString(l->element_dtype)));
      tl_batch.push_back(l);
    }

    Tensor* result;

    if (ok_to_alias) {
      result = tls_alias.get();
      c->set_output(0, *result);
    } else {
      // DT_VARIANT tensors are always allocated on host.
      AllocatorAttributes attr;
      attr.set_on_host(true);
      OP_REQUIRES_OK(
          c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
    }

    if (batch_size == 0) {
      return;
    }

    auto input_t = input.flat_outer_dims<T, 2>();
    auto result_t = result->vec<Variant>();

    for (int64_t b = 0; b < batch_size; ++b) {
      if (!ok_to_alias) {
        result_t(b) = tl_batch[b]->Copy();
      }
      TensorList* output = result_t(b).get<TensorList>();
      DCHECK(output != nullptr);
      Tensor frame;
      OP_REQUIRES_OK(
          c, c->allocate_temp(element_dtype_, input_element_shape, &frame));
      if (input_element_shape.num_elements() > 0) {
        auto frame_t = frame.flat<T>();
        // TODO(penporn): Get this if out of the batch loop.
        if (IsPluggableDevice(c)) {
          // The chip method needs an Eigen device, so use Tensor::Slice
          // instead of chip for pluggable devices. The input is reshaped to
          // 2-D so it can be sliced along the batch dim.
          auto input_t_shape =
              TensorShape({input_t.dimension(0), input_t.dimension(1)});
          auto input_reshaped = Tensor();
          OP_REQUIRES(c, input_reshaped.CopyFrom(input, input_t_shape),
                      errors::Unknown("Unexpected shape error."));

          auto input_batch = input_reshaped.Slice(b, b + 1);
          CopyTensorPluggableDevice<T>(c, input_batch, frame);
        } else {
          frame_t.device(c->eigen_device<Device>()) =
              input_t.template chip<0>(b);
        }
      }
      output->tensors().push_back(std::move(frame));
    }
  }

 private:
  DataType element_dtype_;
};

}  // namespace tensorflow

#undef PLUGGABLE_DEVICE_SUPPORTED
#endif  // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_