/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/tensor_list.h"
#include "tensorflow/core/kernels/tensor_list_util.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/util/tensor_ops_util.h"
#include "tensorflow/core/util/util.h"

// stream.h isn't available on some platforms, such as Android, iOS, and
// ChromiumOS. Only include it on platforms that PluggableDevice is tested on.
#if !defined(PLUGGABLE_DEVICE_SUPPORTED) && \
    (__x86_64__ || __i386__ || defined(__APPLE__) || defined(_WIN32)) && \
    !defined(ANDROID) && !defined(__ANDROID__) && !TARGET_OS_IOS && \
    !defined(PLATFORM_CHROMIUMOS)
#define PLUGGABLE_DEVICE_SUPPORTED
#endif

#ifdef PLUGGABLE_DEVICE_SUPPORTED
#include "tensorflow/compiler/xla/stream_executor/stream.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);

Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape);

Status GetInputList(OpKernelContext* c, int index, const TensorList** list);

Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index,
                                   int32_t output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list);

// TODO(penporn): Move this to a proper place.
inline bool IsPluggableDevice(OpKernelContext* c) {
  return c->op_device_context() && c->op_device_context()->IsPluggableDevice();
}

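// Fills `tensor` with zeros, via a StreamExecutor memzero when running on a
// PluggableDevice and via the Eigen SetZeroFunctor otherwise.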
template <typename Device, typename T>
inline void SetZero(OpKernelContext* ctx, Tensor& tensor) {
#ifdef PLUGGABLE_DEVICE_SUPPORTED
  if (IsPluggableDevice(ctx)) {
    auto ptr =
        se::DeviceMemoryBase(tensor.flat<T>().data(), tensor.TotalBytes());
    auto stream = ctx->op_device_context()->stream();
    auto result = stream->ThenMemZero(&ptr, tensor.TotalBytes()).ok();
    DCHECK_EQ(true, result);
  } else {
#endif  // PLUGGABLE_DEVICE_SUPPORTED
    functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                         tensor.flat<T>());
#ifdef PLUGGABLE_DEVICE_SUPPORTED
  }
#endif  // PLUGGABLE_DEVICE_SUPPORTED
}

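// Copies `src` into `dst` with a StreamExecutor device-to-device memcpy. Only
// valid for element types that can be safely memcpy'd.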
template <typename T>
inline void CopyTensorPluggableDevice(OpKernelContext* ctx, Tensor& src,
                                      Tensor& dst) {
#ifdef PLUGGABLE_DEVICE_SUPPORTED
  auto src_t = src.unaligned_flat<T>();
  auto dst_t = dst.flat<T>();
  DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));
  auto src_ptr = se::DeviceMemoryBase(src_t.data(), src.TotalBytes());
  auto dst_ptr = se::DeviceMemoryBase(dst_t.data(), dst.TotalBytes());
  auto stream = ctx->op_device_context()->stream();
  auto result = stream->ThenMemcpy(&dst_ptr, src_ptr, src.TotalBytes()).ok();
  DCHECK_EQ(true, result);
#else
  LOG(FATAL)  // Crash OK.
      << "PluggableDevice is not supported on this platform.";
#endif  // PLUGGABLE_DEVICE_SUPPORTED
}

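// Copies `src` into `dst` with an Eigen device assignment.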
template <typename Device, typename T>
inline void CopyTensor(OpKernelContext* ctx, Tensor& src, Tensor& dst) {
  auto src_t = src.unaligned_flat<T>();
  auto dst_t = dst.flat<T>();
  dst_t.device(ctx->eigen_device<Device>()) = src_t;
}

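// Concatenates `inputs` along their second dimension into `output`, issuing
// one StreamExecutor memcpy per (row, input) pair. For example, input rows
// a = [a0, a1] and b = [b0, b1, b2] produce the output row
// [a0, a1, b0, b1, b2].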
template <typename T>
void ConcatPluggableDevice(
    OpKernelContext* context,
    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
        inputs,
    typename TTypes<T, 2>::Matrix* output) {
#ifdef PLUGGABLE_DEVICE_SUPPORTED
  DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));

  se::Stream* stream = context->op_device_context()->stream();

  size_t num_inputs = inputs.size();
  std::vector<ptrdiff_t> sizes;
  sizes.reserve(num_inputs);
  int64_t row_size = 0;
  for (const auto& input : inputs) {
    sizes.push_back(input->dimension(1));
    row_size += sizes.back();
  }

  T* out = &(*output)(0, 0);
  std::vector<const T*> inp;
  inp.reserve(num_inputs);
  for (const auto& input : inputs) {
    inp.push_back(&(*input)(0, 0));
  }
  const int64_t dim0 = output->dimension(0);
  for (int64_t i = 0; i < dim0; ++i) {
    for (int64_t j = 0; j < num_inputs; ++j) {
      auto size = sizes[j];
      se::DeviceMemoryBase out_base{out, size * sizeof(T)};
      se::DeviceMemoryBase inp_base{const_cast<T*>(inp[j]), size * sizeof(T)};
      stream->ThenMemcpy(&out_base, inp_base, size * sizeof(T));
      out += size;
      inp[j] += size;
    }
  }
#else
  LOG(FATAL)  // Crash OK.
      << "PluggableDevice is not supported on this platform.";
#endif  // PLUGGABLE_DEVICE_SUPPORTED
}

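// Stacks all N element tensors of a list into a single tensor of shape
// [N, <element_shape>], materializing uninitialized elements as zeros. For
// example, a list holding three [2, 3] elements stacks into a [3, 2, 3]
// tensor.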
template <typename Device, typename T>
class TensorListStack : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;
  explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
    OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    if (num_elements_ != -1) {
      OP_REQUIRES(c, tensor_list->tensors().size() == num_elements_,
                  errors::InvalidArgument(
                      "Operation expected a list with ", num_elements_,
                      " elements but got a list with ",
                      tensor_list->tensors().size(), " elements."));
    }
    PartialTensorShape partial_element_shape;
    OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1,
                                               &partial_element_shape));
    OP_REQUIRES(
        c,
        partial_element_shape.IsFullyDefined() ||
            !tensor_list->tensors().empty(),
        errors::InvalidArgument("Tried to stack elements of an empty ",
                                "list with non-fully-defined element_shape: ",
                                partial_element_shape.DebugString()));

    // Check that `element_shape` input tensor is compatible with the shapes of
    // element tensors.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = partial_element_shape;
          OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
        }
      }
    }

    // Compute the shape of the output tensor by pre-pending the leading dim to
    // the element_shape.
    TensorShape element_shape;
    OP_REQUIRES(c, partial_element_shape.AsTensorShape(&element_shape),
                errors::InvalidArgument(
                    "Tried to stack list which only contains uninitialized ",
                    "tensors and has a non-fully-defined element_shape: ",
                    partial_element_shape.DebugString()));
    TensorShape output_shape = element_shape;
    output_shape.InsertDim(0, tensor_list->tensors().size());
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(tensor_list->tensors().size());
    Tensor zeros;
    for (const auto& t : tensor_list->tensors()) {
      if (t.dtype() != DT_INVALID) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            t.shaped<T, 2>({1, t.NumElements()})));
      } else {
        if (!zeros.NumElements()) {
          AllocatorAttributes attr;
          if (element_dtype_ == DT_VARIANT) {
            attr.set_on_host(true);
          }
          OP_REQUIRES_OK(
              c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
          SetZero<Device, T>(c, zeros);
        }
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (IsPluggableDevice(c)) {
      ConcatPluggableDevice<T>(c, inputs_flat, &output_flat);
    } else {
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  int num_elements_;
  DataType element_dtype_;
};

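// Returns the list element at `index`. If that element is uninitialized,
// returns zeros with a shape inferred from the element_shape input, the
// list's element_shape, or the other initialized elements.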
template <typename Device, typename T>
class TensorListGetItem : public OpKernel {
 public:
  explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32_t index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors().size(),
                errors::InvalidArgument("Trying to access element ", index,
                                        " in a list with ", l->tensors().size(),
                                        " elements."));
    if (l->tensors()[index].dtype() != DT_INVALID) {
      c->set_output(0, l->tensors()[index]);
    } else {
      PartialTensorShape partial_element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *l, 2, &partial_element_shape));
      TensorShape element_shape;
      // If l->element_shape and the element_shape input are both not fully
      // defined, try to infer the shape from other list elements. This requires
      // that all initialized list elements have the same shape.
      // NOTE(srbs): This might be a performance bottleneck since we are
      // iterating over the entire list here. This is necessary for feature
      // parity with TensorArray.read. TensorArray has a mode in which all
      // elements are required to be of the same shape, TensorList does not.
      // In that mode TensorArray sets the array's element_shape on the first
      // write call. We could do something similar here if needed.
      if (!partial_element_shape.IsFullyDefined()) {
        for (const Tensor& t : l->tensors()) {
          if (t.dtype() != DT_INVALID) {
            PartialTensorShape tmp = partial_element_shape;
            OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
          }
        }
      }
      OP_REQUIRES(
          c, partial_element_shape.AsTensorShape(&element_shape),
          errors::InvalidArgument("Trying to read an uninitialized tensor but ",
                                  "element_shape is not fully defined: ",
                                  partial_element_shape.DebugString(),
                                  " and no list element is set."));
      Tensor* result;
      AllocatorAttributes attr;
      if (element_dtype_ == DT_VARIANT) {
        attr.set_on_host(true);
      }
      OP_REQUIRES_OK(c, c->allocate_output(0, element_shape, &result, attr));
      SetZero<Device, T>(c, *result);
    }
  }

 private:
  DataType element_dtype_;
};

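// Returns the last element of the list (or zeros with the inferred element
// shape if it is uninitialized) together with a new list holding all but that
// last element.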
template <typename Device, typename T>
class TensorListPopBack : public OpKernel {
 public:
  explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    OP_REQUIRES(c, !l->tensors().empty(),
                errors::InvalidArgument("Trying to pop from an empty list."));

    const Tensor& t = l->tensors().back();
    if (t.dtype() != DT_INVALID) {
      c->set_output(1, t);
    } else {
      PartialTensorShape partial_element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *l, 1, &partial_element_shape));
      TensorShape element_shape;
      OP_REQUIRES(
          c, partial_element_shape.AsTensorShape(&element_shape),
          errors::InvalidArgument("Trying to read an uninitialized tensor but ",
                                  "element_shape is not fully defined: ",
                                  partial_element_shape.DebugString()));
      Tensor* result;
      AllocatorAttributes attr;
      if (element_dtype_ == DT_VARIANT) {
        attr.set_on_host(true);
      }
      OP_REQUIRES_OK(c, c->allocate_output(1, element_shape, &result, attr));
      SetZero<Device, T>(c, *result);
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors().pop_back();
  }

 private:
  DataType element_dtype_;
};

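// Concatenates all element tensors of a list along their first dimension into
// a single tensor, also returning the leading dimension of each element. For
// example, a list holding elements of shapes [2, 3] and [4, 3] concatenates
// into a [6, 3] tensor with a lengths output of [2, 4].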
template <typename Device, typename T>
class TensorListConcat : public OpKernel {
 public:
  using ConstMatrixVector =
      std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>;
  explicit TensorListConcat(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
    if (c->HasAttr("element_shape")) {
      OP_REQUIRES_OK(c, c->GetAttr("element_shape", &element_shape_));
    }
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape_except_first_dim;
    if (!element_shape_.unknown_rank()) {
      element_shape_except_first_dim = PartialTensorShape(
          gtl::ArraySlice<int64_t>(element_shape_.dim_sizes()).subspan(1));
    }
    // Check that the input Variant tensor is indeed a TensorList and has the
    // correct element type.
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    // The leading dimension of all list elements if they are all the same.
    // This is used as the leading dim of uninitialized tensors in the list
    // if leading_dims is not provided.
    int64_t first_dim = -1;
    if (c->num_inputs() > 1) {
      // TensorListConcatV2
      PartialTensorShape element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *tensor_list, 1, &element_shape));
      OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
                  errors::InvalidArgument(
                      "Concat requires elements to be at least vectors, ",
                      "found scalars instead."));
      // Split `element_shape` into `first_dim` and
      // `element_shape_except_first_dim`.
      first_dim = element_shape.dim_size(0);
      element_shape_except_first_dim = element_shape;
      element_shape_except_first_dim.RemoveDim(0);
    }
    // If the TensorList is empty, element_shape_except_first_dim must be fully
    // defined.
    OP_REQUIRES(c,
                !tensor_list->tensors().empty() ||
                    element_shape_except_first_dim.IsFullyDefined(),
                errors::InvalidArgument(
                    "All except the first dimension must be fully defined ",
                    "when concatenating an empty tensor list. element_shape: ",
                    element_shape_except_first_dim.DebugString()));
    // 1. Check that `element_shape_except_first_dim` input tensor is
    //    compatible with the shapes of element tensors.
    // 2. Check that the elements have the same shape except the first dim.
    // 3. If `first_dim` is known, check that it is compatible with the leading
    //    dims of all elements.
    // 4. If `first_dim` is unknown (-1), check whether all initialized
    //    elements have the same leading dim and if so set `first_dim` to that
    //    value.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      bool check_dim = (first_dim == -1);
      int64_t inferred_first_dim = first_dim;
      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = element_shape_except_first_dim;
          OP_REQUIRES(
              c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
              errors::InvalidArgument("Concat saw a scalar shape at index ", i,
                                      " but requires at least vectors."));
          TensorShape shape_except_first_dim = TensorShape(
              gtl::ArraySlice<int64_t>(t.shape().dim_sizes()).subspan(1));
          OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
                                          &element_shape_except_first_dim));
          OP_REQUIRES(c, first_dim == -1 || first_dim == t.shape().dim_size(0),
                      errors::InvalidArgument(
                          "First entry of element_shape input does not match ",
                          "the first dim of list element at index: ", i,
                          " Expected: ", first_dim,
                          " Actual: ", t.shape().dim_size(0)));
          if (check_dim) {
            if (inferred_first_dim == -1) {
              inferred_first_dim = t.shape().dim_size(0);
            } else if (inferred_first_dim != t.shape().dim_size(0)) {
              inferred_first_dim = -1;
              check_dim = false;
            }
          }
        }
      }
      first_dim = inferred_first_dim;
    }
    TensorShape output_shape;
    OP_REQUIRES(c, element_shape_except_first_dim.AsTensorShape(&output_shape),
                errors::InvalidArgument(
                    "Trying to concat list with only uninitialized tensors ",
                    "but element_shape_except_first_dim is not fully defined: ",
                    element_shape_except_first_dim.DebugString()));
    // Build the lengths_tensor and leading dim of the output tensor by
    // iterating over all element tensors.
    Tensor* lengths_tensor = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(1,
                                         TensorShape({static_cast<int64_t>(
                                             tensor_list->tensors().size())}),
                                         &lengths_tensor));
    auto lengths_tensor_vec = lengths_tensor->vec<int64_t>();
    int64_t leading_dim = 0;
    for (size_t i = 0; i < tensor_list->tensors().size(); i++) {
      int64_t dim;
      if (tensor_list->tensors()[i].dtype() != DT_INVALID) {
        dim = tensor_list->tensors()[i].shape().dim_size(0);
      } else {
        // If leading_dims is not provided or does not contain an entry for
        // index i, use the inferred `first_dim` if set.
        if ((c->num_inputs() <= 2 || i >= c->input(2).NumElements()) &&
            first_dim != -1) {
          dim = first_dim;
        } else {
          OP_REQUIRES(c, c->num_inputs() > 2,
                      errors::InvalidArgument(
                          "Concatenating lists with uninitialized tensors is ",
                          "not supported in this version of TensorListConcat. ",
                          "Consider updating your GraphDef to run the newer ",
                          "version."));
          OP_REQUIRES(c, i < c->input(2).NumElements(),
                      errors::InvalidArgument(
                          "List contains uninitialized tensor at index ", i,
                          " but leading_dims has only ",
                          c->input(2).NumElements(), " elements."));
          dim = c->input(2).vec<int64_t>()(i);
        }
      }
      leading_dim += dim;
      lengths_tensor_vec(i) = dim;
    }
    output_shape.InsertDim(0, leading_dim);
    Tensor* output;
    // Allocate the output tensor and fill it up with the concatenated element
    // tensors.
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(tensor_list->tensors().size());
    // Store the zeros tensors in a vector to prevent them from being GC'ed
    // until concat is complete.
    std::vector<Tensor> zeros_vec;
    for (int i = 0; i < tensor_list->tensors().size(); i++) {
      const Tensor& element_tensor = tensor_list->tensors()[i];
      if (element_tensor.dtype() != DT_INVALID) {
        if (element_tensor.NumElements() > 0) {
          inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
              element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
        }
      } else {
        AllocatorAttributes attr;
        if (element_dtype_ == DT_VARIANT) {
          attr.set_on_host(true);
        }
        TensorShape element_shape = output_shape;
        element_shape.set_dim(0, lengths_tensor_vec(i));
        zeros_vec.emplace_back();
        Tensor& zeros = zeros_vec.back();
        OP_REQUIRES_OK(
            c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
        SetZero<Device, T>(c, zeros);
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (IsPluggableDevice(c)) {
      ConcatPluggableDevice<T>(c, inputs_flat, &output_flat);
    } else {
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  DataType element_dtype_;
  PartialTensorShape element_shape_;
};

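// Splits a tensor along its first dimension into the per-element slices given
// by `lengths` and returns them as a new list. For example, a [5, 2] tensor
// with lengths [2, 3] becomes a list holding a [2, 2] and a [3, 2] tensor.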
template <typename Device, typename T>
class TensorListSplit : public OpKernel {
 public:
  TensorListSplit(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
    OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
                errors::InvalidArgument(
                    "TensorListSplit requires element_shape to be at least of ",
                    "rank 1, but saw: ", element_shape.DebugString()));
    TensorList output_list;
    const Tensor& input_tensor = c->input(0);
    output_list.element_dtype = input_tensor.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    TensorShape tensor_shape_without_first_dim(input_tensor.shape());
    tensor_shape_without_first_dim.RemoveDim(0);
    PartialTensorShape element_shape_without_first_dim;
    if (!element_shape.unknown_rank()) {
      element_shape_without_first_dim =
          PartialTensorShape(element_shape.dim_sizes());
      element_shape_without_first_dim.RemoveDim(0);
    }
    OP_REQUIRES(c,
                element_shape_without_first_dim.IsCompatibleWith(
                    tensor_shape_without_first_dim),
                errors::InvalidArgument(
                    "tensor shape ", input_tensor.shape().DebugString(),
                    " is not compatible with element_shape ",
                    element_shape.DebugString()));
    output_list.element_shape = element_shape;
    const Tensor& lengths = c->input(2);
    OP_REQUIRES(c, TensorShapeUtils::IsVector(lengths.shape()),
                errors::InvalidArgument(
                    "Expected lengths to be a vector, received shape: ",
                    lengths.shape().DebugString()));
    output_list.tensors().reserve(lengths.shape().dim_size(0));

    const auto copy_tensor = IsPluggableDevice(c)
                                 ? &CopyTensorPluggableDevice<T>
                                 : &CopyTensor<Device, T>;

    int64_t start = 0;
    int64_t end = 0;
    for (int i = 0; i < lengths.shape().dim_size(0); ++i) {
      int64_t length = lengths.vec<int64_t>()(i);
      OP_REQUIRES(
          c, length >= 0,
          errors::InvalidArgument("Invalid value in lengths: ", length));
      end = start + length;
      OP_REQUIRES(c, end <= input_tensor.shape().dim_size(0),
                  errors::InvalidArgument("Attempting to slice [", start, ", ",
                                          end, "] from tensor with length ",
                                          input_tensor.shape().dim_size(0)));
      Tensor tmp = input_tensor.Slice(start, end);
      start = end;
      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
      // prevent this.
      Tensor aligned;
      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
      copy_tensor(c, tmp, aligned);
      output_list.tensors().emplace_back(aligned);
    }
    OP_REQUIRES(c, end == input_tensor.shape().dim_size(0),
                errors::InvalidArgument(
                    "Unused values in tensor. Length of tensor: ",
                    input_tensor.shape().dim_size(0), " Values used: ", end));
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

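// Gathers the list elements at `indices` and stacks them into a single tensor
// of shape [indices.size(), <element_shape>], materializing uninitialized
// elements as zeros.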
template <typename Device, typename T>
class TensorListGather : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;
  explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    const Tensor& indices = c->input(1);
    PartialTensorShape partial_element_shape;
    OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 2,
                                               &partial_element_shape));
    OP_REQUIRES(
        c, partial_element_shape.IsFullyDefined() || indices.NumElements() > 0,
        errors::InvalidArgument("Tried to gather 0-elements from "
                                "a list with non-fully-defined shape: ",
                                partial_element_shape.DebugString()));

    // Check that `element_shape` input tensor is compatible with the shapes of
    // element tensors.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      for (int index = 0; index < indices.NumElements(); ++index) {
        const int i = indices.flat<int32>()(index);
        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = partial_element_shape;
          OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
        }
      }
    }

    // Compute the shape of the output tensor by pre-pending the leading dim to
    // the element_shape.
    TensorShape element_shape;
    OP_REQUIRES(
        c, partial_element_shape.AsTensorShape(&element_shape),
        errors::InvalidArgument("Tried to gather uninitialized tensors from a ",
                                "list with non-fully-defined element_shape: ",
                                partial_element_shape.DebugString()));
    TensorShape output_shape = element_shape;
    output_shape.InsertDim(0, indices.NumElements());
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(indices.NumElements());
    Tensor zeros;
    for (int index = 0; index < indices.NumElements(); ++index) {
      const int i = indices.flat<int32>()(index);
      OP_REQUIRES(
          c, i < tensor_list->tensors().size(),
          errors::InvalidArgument("Index ", i, " out of range; list only has ",
                                  tensor_list->tensors().size(), " elements."));
      const Tensor& t = tensor_list->tensors()[i];
      if (t.dtype() != DT_INVALID) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            t.shaped<T, 2>({1, t.NumElements()})));
      } else {
        if (!zeros.NumElements()) {
          AllocatorAttributes attr;
          if (element_dtype_ == DT_VARIANT) {
            attr.set_on_host(true);
          }
          OP_REQUIRES_OK(
              c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
          SetZero<Device, T>(c, zeros);
        }
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (IsPluggableDevice(c)) {
      ConcatPluggableDevice<T>(c, inputs_flat, &output_flat);
    } else {
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  DataType element_dtype_;
};

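// Converts a tensor into a list by slicing it along its first dimension, so a
// [N, d0, ...] tensor becomes a list of N tensors of shape [d0, ...].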
template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
 public:
  TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    PartialTensorShape element_shape;
    OP_REQUIRES(
        c, !TensorShapeUtils::IsMatrixOrHigher(c->input(1).shape()),
        errors::InvalidArgument(
            "TensorListFromTensor: element_shape must be at most rank 1 but ",
            "has the shape of ", c->input(1).shape().DebugString()));
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
    TensorList output_list;
    const Tensor& t = c->input(0);
    output_list.element_dtype = t.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    t.shape().DebugString()));
    TensorShape output_shape(t.shape());
    output_shape.RemoveDim(0);
    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
                errors::InvalidArgument(
                    "Specified a list with shape ", element_shape.DebugString(),
                    " from a tensor with shape ", output_shape.DebugString()));
    output_list.element_shape = element_shape;
    output_list.tensors().reserve(t.shape().dim_size(0));

    const auto copy_tensor = IsPluggableDevice(c)
                                 ? &CopyTensorPluggableDevice<T>
                                 : &CopyTensor<Device, T>;

    for (int i = 0; i < t.shape().dim_size(0); ++i) {
      Tensor tmp = t.Slice(i, i + 1);
      TensorShape tmp_shape = tmp.shape();
      tmp_shape.RemoveDim(0);
      OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
                  errors::Unknown("Unexpected shape error."));
      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
      // prevent this.
      Tensor aligned;
      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
      copy_tensor(c, tmp, aligned);
      output_list.tensors().push_back(aligned);
    }
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

// Scatters values in `value` into `list`. Assumes that `indices` are valid.
template <typename Device, typename T>
Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices,
               TensorList* list) {
  const auto copy_tensor = IsPluggableDevice(c) ? &CopyTensorPluggableDevice<T>
                                                : &CopyTensor<Device, T>;
  for (int index = 0; index < indices.NumElements(); ++index) {
    const int i = indices.flat<int32>()(index);
    Tensor tmp = value.Slice(index, index + 1);
    TensorShape tmp_shape = tmp.shape();
    tmp_shape.RemoveDim(0);
    if (!tmp.CopyFrom(tmp, tmp_shape)) {
      return errors::Unknown("Unexpected shape error.");
    }
    // TODO(apassos) maybe not always align; but weird compiler bugs seem to
    // prevent this.
    Tensor aligned;
    TF_RETURN_IF_ERROR(c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
    // TODO(apassos) do all slices in a single kernel invocation instead of
    // many small ones.
    copy_tensor(c, tmp, aligned);
    std::swap(list->tensors()[i], aligned);
  }
  return OkStatus();
}

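// Scatters the rows of the input tensor into an existing list at the
// positions given by `indices`, growing the list when an index is past its
// current size.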
template <typename Device, typename T>
class TensorListScatterIntoExistingList : public OpKernel {
 public:
  TensorListScatterIntoExistingList(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    const Tensor& input_tensor = c->input(1);
    const Tensor& indices = c->input(2);

    // Check that inputs are valid.
    OP_REQUIRES(c, input_tensor.dtype() == l->element_dtype,
                errors::InvalidArgument(
                    "Invalid data types; input tensor type: ",
                    DataTypeString(input_tensor.dtype()),
                    " list element_type: ", DataTypeString(l->element_dtype)));
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    OP_REQUIRES(c, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument(
                    "Expected indices to be a vector, but received shape: ",
                    indices.shape().DebugString()));
    OP_REQUIRES(
        c, indices.NumElements() == input_tensor.shape().dim_size(0),
        errors::InvalidArgument(
            "Expected len(indices) == tensor.shape[0], but saw: ",
            indices.NumElements(), " vs. ", input_tensor.shape().dim_size(0)));

    // Resize the list if needed to accommodate all indices.
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    const auto indices_vec = indices.vec<int32>();
    int32_t max_index =
        (indices.NumElements() == 0)
            ? -1
            : *std::max_element(indices_vec.data(),
                                indices_vec.data() + indices.NumElements());
    if (max_index + 1 > output_list->tensors().size()) {
      output_list->tensors().resize(max_index + 1);
    }

    // Scatter the values.
    OP_REQUIRES_OK(c,
                   Scatter<Device, T>(c, input_tensor, indices, output_list));
  }
};

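// Creates a new list and scatters the rows of the input tensor into it at the
// positions given by `indices`; positions not covered by `indices` are left
// uninitialized.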
template <typename Device, typename T>
class TensorListScatter : public OpKernel {
 public:
  TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    Tensor indices = c->input(1);
    PartialTensorShape element_shape;
    OP_REQUIRES(
        c, !TensorShapeUtils::IsMatrixOrHigher(c->input(2).shape()),
        errors::InvalidArgument(
            "TensorListScatter: element_shape must be at most rank 1 but has ",
            "the shape of ", c->input(2).shape().DebugString()));
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
    // TensorListScatterV2 passes the num_elements input, TensorListScatter
    // does not.
    int num_elements = c->num_inputs() >= 4 ? c->input(3).scalar<int>()() : -1;
    OP_REQUIRES(c, num_elements >= -1,
                errors::InvalidArgument(
                    "TensorListScatter expects num_elements >= -1, found: ",
                    num_elements));
    TensorList output_list;
    const Tensor& input_tensor = c->input(0);
    output_list.element_dtype = input_tensor.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    TensorShape output_shape(input_tensor.shape());
    output_shape.RemoveDim(0);
    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
                errors::InvalidArgument(
                    "Specified a list with shape ", element_shape.DebugString(),
                    " from a tensor with shape ", output_shape.DebugString()));
    output_list.element_shape = element_shape;

    OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0),
                errors::InvalidArgument(
                    "Invalid number of rows in input tensor. Expected: ",
                    indices.NumElements(),
                    " Actual: ", input_tensor.shape().dim_size(0)));

    // Validate indices and resize output_list.tensors to fit the highest
    // index.
    {
      int highest_index = -1;
      for (int index = 0; index < indices.NumElements(); ++index) {
        const int i = indices.flat<int32>()(index);
        OP_REQUIRES(
            c, i >= 0,
            errors::InvalidArgument(
                "Indices in TensorListScatter must all be non-negative."));
        OP_REQUIRES(c, num_elements == -1 || i < num_elements,
                    errors::InvalidArgument(
                        "TensorListScatter: Trying to scatter at index ", i,
                        " in list with size ", num_elements));
        if (i > highest_index) {
          highest_index = i;
        }
      }
      output_list.tensors().resize(std::max(highest_index + 1, num_elements),
                                   Tensor(DT_INVALID));
    }

    OP_REQUIRES_OK(c,
                   Scatter<Device, T>(c, input_tensor, indices, &output_list));
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

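// List-wide wrappers that apply BinaryAddTensors / ZerosLikeTensor to every
// element tensor of a list.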
template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
                           const TensorList& b, TensorList* out) {
  return TensorListBinaryAdd(c, a, b, out, BinaryAddTensors<Device>);
}

template <typename Device>
Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
                           TensorList* y) {
  return TensorListZerosLike(c, x, y, ZerosLikeTensor<Device>);
}

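// Appends row `b` of the input tensor to the `b`-th list in a batch of lists,
// aliasing the input batch when it can be forwarded and copying the lists
// otherwise.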
template <typename Device, typename T>
class TensorListPushBackBatch : public OpKernel {
 public:
  explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument(
                    "Expected tensor to be at least a vector, but saw shape: ",
                    input.shape().DebugString()));

    const TensorShape& tls_shape = c->input(0).shape();

    // For purposes of input forwarding, we want the least restrictive
    // AllocatorAttributes possible. If we need to allocate later,
    // we'll request the DT_VARIANT be allocated on host.
    AllocatorAttributes attr;

    std::unique_ptr<Tensor> tls_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    bool ok_to_alias = tls_alias != nullptr;
    if (tls_alias && tls_alias->dtype() == DT_VARIANT &&
        tls_alias->NumElements() > 0) {
      auto alias_t = tls_alias->flat<Variant>();
      for (int i = 0; i < tls_alias->NumElements(); ++i) {
        TensorList* tl_i = alias_t(i).get<TensorList>();
        if (tl_i == nullptr || !tl_i->RefCountIsOne()) {
          ok_to_alias = false;
          break;
        }
      }
    }
    const Tensor& tls = ok_to_alias ? *tls_alias : c->input(0);

    OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Expected input_handles dtype to be Variant, but saw: ",
                    DataTypeString(tls.dtype())));
    OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
                errors::InvalidArgument(
                    "Expected input_handles to be a vector, but saw shape: ",
                    tls_shape.DebugString()));
    const int64_t batch_size = tls.NumElements();
    OP_REQUIRES(c, input.dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "Expected tensor.shape[0] == input_handles.size, but saw ",
                    input.dim_size(0), " vs. ", batch_size));
    auto tls_t = tls.vec<Variant>();

    TensorShape input_element_shape = input.shape();
    input_element_shape.RemoveDim(0);
    std::vector<const TensorList*> tl_batch;
    for (int64_t b = 0; b < batch_size; ++b) {
      const TensorList* l = tls_t(b).get<TensorList>();
      OP_REQUIRES(c, l != nullptr,
                  errors::InvalidArgument("Input handle at index ", b,
                                          " is not a list. Saw: '",
                                          tls_t(b).DebugString(), "'"));
      OP_REQUIRES(
          c, l->element_shape.IsCompatibleWith(input_element_shape),
          errors::InvalidArgument(
              "Tried to append a tensor with incompatible shape to a "
              "list at index ",
              b, ". Op element shape: ", input_element_shape.DebugString(),
              " list shape: ", l->element_shape.DebugString()));
      OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                  errors::InvalidArgument(
                      "Invalid data type at index ", b, "; op elements ",
                      DataTypeString(element_dtype_), " but list elements ",
                      DataTypeString(l->element_dtype)));
      tl_batch.push_back(l);
    }

    Tensor* result;

    if (ok_to_alias) {
      result = tls_alias.get();
      c->set_output(0, *result);
    } else {
      // DT_VARIANT tensors are always allocated on host.
      AllocatorAttributes attr;
      attr.set_on_host(true);
      OP_REQUIRES_OK(
          c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
    }

    if (batch_size == 0) {
      return;
    }

    auto input_t = input.flat_outer_dims<T, 2>();
    auto result_t = result->vec<Variant>();

    for (int64_t b = 0; b < batch_size; ++b) {
      if (!ok_to_alias) {
        result_t(b) = tl_batch[b]->Copy();
      }
      TensorList* output = result_t(b).get<TensorList>();
      DCHECK(output != nullptr);
      Tensor frame;
      OP_REQUIRES_OK(
          c, c->allocate_temp(element_dtype_, input_element_shape, &frame));
      if (input_element_shape.num_elements() > 0) {
        auto frame_t = frame.flat<T>();
        // TODO(penporn): Get this `if` out of the batch loop.
        if (IsPluggableDevice(c)) {
          // The chip method needs an Eigen device, so for pluggable devices we
          // use Tensor::Slice instead of chip. The input is reshaped to 2-D so
          // it can be sliced along the batch dimension.
          auto input_t_shape =
              TensorShape({input_t.dimension(0), input_t.dimension(1)});
          auto input_reshaped = Tensor();
          OP_REQUIRES(c, input_reshaped.CopyFrom(input, input_t_shape),
                      errors::Unknown("Unexpected shape error."));

          auto input_batch = input_reshaped.Slice(b, b + 1);
          CopyTensorPluggableDevice<T>(c, input_batch, frame);
        } else {
          frame_t.device(c->eigen_device<Device>()) =
              input_t.template chip<0>(b);
        }
      }
      output->tensors().push_back(std::move(frame));
    }
  }

 private:
  DataType element_dtype_;
};

}  // namespace tensorflow

#undef PLUGGABLE_DEVICE_SUPPORTED
#endif  // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_