/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/list_kernels.h"

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
  if (t.shape() == TensorShape({})) {
    if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
        (t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1)) {
      *out = PartialTensorShape();
      return OkStatus();
    }
    return errors::InvalidArgument(
        "The only valid scalar shape tensor is the fully unknown shape "
        "specified as -1.");
  } else if (t.shape().dims() != 1) {
    return errors::InvalidArgument("Shape must be at most rank 1 but is rank ",
                                   t.shape().dims());
  }
  if (t.dtype() == DT_INT32) {
    return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
                                                t.NumElements(), out);
  } else if (t.dtype() == DT_INT64) {
    return PartialTensorShape::MakePartialShape(t.vec<int64_t>().data(),
                                                t.NumElements(), out);
  }
  return errors::InvalidArgument(
      "Expected an int32 or int64 shape tensor; found ",
      DataTypeString(t.dtype()));
}
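
// Shape-tensor convention used throughout this file: a scalar int32/int64
// tensor holding -1 denotes a fully unknown shape, while a rank-1 tensor
// lists the dimension sizes, with negative entries marking unknown
// dimensions. A hedged sketch of how a caller might build such a tensor
// (illustrative only, not code from this file):
//
//   Tensor shape_t(DT_INT32, TensorShape({2}));
//   shape_t.vec<int32>()(0) = -1;  // unknown leading dimension
//   shape_t.vec<int32>()(1) = 8;   // fixed second dimension
//   PartialTensorShape shape;
//   TF_RETURN_IF_ERROR(TensorShapeFromTensor(shape_t, &shape));  // -> [?,8]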

Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape) {
  TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape));
  // Check that `element_shape` and `tensor_list.element_shape` are
  // compatible and store the merged shape in `element_shape`.
  PartialTensorShape tmp = *element_shape;
  TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape));
  return OkStatus();
}

Status GetInputList(OpKernelContext* c, int index, const TensorList** list) {
  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
    return errors::InvalidArgument("Input list must be a scalar. Saw: ",
                                   c->input(index).shape().DebugString());
  }
  const TensorList* l = c->input(index).scalar<Variant>()().get<TensorList>();
  if (l == nullptr) {
    return errors::InvalidArgument(
        "Input handle is not a list. Saw: '",
        c->input(index).scalar<Variant>()().DebugString(), "'");
  }
  *list = l;
  return OkStatus();
}
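
// Note on representation: a TensorList handle is a scalar DT_VARIANT tensor
// whose single Variant element holds the TensorList value, which is why
// GetInputList() above first checks for a scalar shape and then unwraps the
// Variant with scalar<Variant>()().get<TensorList>(). Kernels that produce a
// list allocate the same kind of scalar variant tensor (usually pinned to
// host memory via AllocatorAttributes::set_on_host) and move a TensorList
// into it.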

Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index,
                                   int32_t output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list) {
  // Attempt to forward the input tensor to the output if possible.
  std::unique_ptr<Tensor> maybe_output = c->forward_input(
      input_index, output_index, DT_VARIANT, TensorShape{},
      c->input_memory_type(input_index), AllocatorAttributes());
  Tensor* output_tensor;
  if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
      maybe_output->NumElements() == 1) {
    output_tensor = maybe_output.get();
    TensorList* tmp_out = output_tensor->scalar<Variant>()().get<TensorList>();
    if (tmp_out == nullptr) {
      return errors::InvalidArgument(
          "Expected input ", input_index, " to be a TensorList but saw ",
          output_tensor->scalar<Variant>()().TypeName());
    }
    if (tmp_out->RefCountIsOne()) {
      // Woohoo, forwarding succeeded!
      c->set_output(output_index, *output_tensor);
      *output_list = tmp_out;
      return OkStatus();
    }
  }

  // If forwarding is not possible, allocate a new output tensor and copy
  // the `input_list` to it.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(
      c->allocate_output(output_index, {}, &output_tensor, attr));
  output_tensor->scalar<Variant>()() = input_list.Copy();

  *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
  return OkStatus();
}
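
// The helper above gives list-mutating kernels copy-on-write semantics: if
// the input variant tensor can be forwarded and the TensorList inside it has
// a reference count of one, the list is mutated in place; otherwise a copy is
// made first. TensorList::Copy() only duplicates the list metadata and the
// vector of element Tensors, and Tensor copies share the underlying buffers,
// so the fallback copy is shallow and relatively cheap.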

class EmptyTensorList : public OpKernel {
 public:
  explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& max_num_elements_t = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(max_num_elements_t.shape()),
        errors::InvalidArgument(
            "max_num_elements expected to be a scalar ",
            "but got shape: ", max_num_elements_t.shape().DebugString()));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
    TensorList empty;
    empty.element_dtype = element_dtype_;
    empty.max_num_elements = max_num_elements_t.scalar<int32>()();
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
    empty.element_shape = element_shape;
    result->scalar<Variant>()() = std::move(empty);
  }

 private:
  DataType element_dtype_;
};
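
// The EmptyTensorList kernel above backs the "EmptyTensorList" op: it reads
// the requested element_shape (input 0) and max_num_elements (input 1) and
// emits a scalar DT_VARIANT tensor holding an empty TensorList. A
// max_num_elements of -1 means the list is unbounded; TensorListPushBack
// below enforces the bound otherwise. The output is allocated with
// set_on_host(true), keeping the TensorList container in host memory even
// for the GPU registrations that follow (element tensors pushed into the
// list may still live on the device).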

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
                        EmptyTensorList);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

class TensorListPushBack : public OpKernel {
 public:
  explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  ~TensorListPushBack() override {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));

    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
                errors::InvalidArgument(
                    "Tried to append a tensor with incompatible shape to a "
                    "list. Op element shape: ",
                    input.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    if (l->max_num_elements != -1) {
      OP_REQUIRES(
          c, l->tensors().size() < l->max_num_elements,
          errors::InvalidArgument("Tried to push item into a full list",
                                  " list size: ", l->tensors().size(),
                                  " max_num_elements: ", l->max_num_elements));
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors().push_back(input);
  }

 private:
  DataType element_dtype_;
};
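
// TensorListPushBack validates dtype and shape compatibility, enforces
// max_num_elements when the list is bounded, and then appends via
// ForwardInputOrCreateNewList, so the push is in place when the kernel owns
// the only reference to the list and works on a shallow copy otherwise. The
// appended Tensor is stored by value, but Tensor copies share the underlying
// buffer, so no element data is duplicated here.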

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
                        TensorListPushBack);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
                        TensorListPushBack);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_DEFAULT),
                        TensorListPushBack);

class TensorListLength : public OpKernel {
 public:
  explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
  ~TensorListLength() override {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
    result->scalar<int32>()() = l->tensors().size();
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
                        TensorListLength);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
    TensorListLength);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_DEFAULT).HostMemory("length"),
    TensorListLength);

class TensorListElementShape : public OpKernel {
 public:
  explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    if (l->element_shape.unknown_rank()) {
      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
      if (result->dtype() == DT_INT32) {
        result->scalar<int32>()() = -1;
      } else {
        result->scalar<int64_t>()() = -1;
      }
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(
                            0, TensorShape{l->element_shape.dims()}, &result));
      for (int i = 0; i < l->element_shape.dims(); ++i) {
        if (result->dtype() == DT_INT32) {
          result->flat<int32>()(i) = l->element_shape.dim_size(i);
        } else {
          result->flat<int64_t>()(i) = l->element_shape.dim_size(i);
        }
      }
    }
  }
};
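
// TensorListElementShape emits the list's element shape using the same
// encoding that TensorShapeFromTensor accepts: a scalar -1 for an unknown
// rank, otherwise a rank-1 tensor of dimension sizes with -1 for unknown
// dimensions. The output dtype (int32 vs. int64) is not chosen here; the
// kernel simply writes through whichever type the output tensor was
// allocated with, presumably driven by the op's shape_type attribute.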

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
                        TensorListElementShape);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

class TensorListReserve : public OpKernel {
 public:
  explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
    OP_REQUIRES(
        c, TensorShapeUtils::IsScalar(c->input(1).shape()),
        errors::InvalidArgument(
            "The num_elements to reserve must be a scalar tensor, but got ",
            c->input(1).shape()));
    int32_t num_elements = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, num_elements >= 0,
                errors::InvalidArgument("The num_elements to reserve must be "
                                        "a non-negative number, but got ",
                                        num_elements));
    TensorList output;
    output.element_shape = element_shape;
    output.element_dtype = element_dtype_;
    output.tensors().resize(num_elements, Tensor(DT_INVALID));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};
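
// TensorListReserve pre-sizes the list with placeholder Tensors of dtype
// DT_INVALID rather than allocating real element storage. Readers such as
// TensorListStack and TensorListGetItem (see list_kernels.h) treat a
// DT_INVALID entry as "unset" and materialize zeros of the element shape
// instead, and TensorListSetItem simply overwrites the placeholder, so
// reserving N elements costs only the vector resize.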

REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
                        TensorListReserve);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

class TensorListResize : public OpKernel {
 public:
  explicit TensorListResize(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* input_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
    OP_REQUIRES(c, TensorShapeUtils::IsScalar(c->input(1).shape()),
                errors::InvalidArgument("size must be a scalar"));
    int32_t size = c->input(1).scalar<int32>()();
    OP_REQUIRES(
        c, size >= 0,
        errors::InvalidArgument(
            "TensorListResize expects size to be non-negative. Got: ", size));

    std::unique_ptr<Tensor> maybe_result =
        c->forward_input(0, 0, DT_VARIANT, TensorShape{},
                         c->input_memory_type(0), AllocatorAttributes());
    if (maybe_result != nullptr) {
      TensorList* out = maybe_result->scalar<Variant>()().get<TensorList>();
      // Defensive null check: GetInputList has already validated input 0, but
      // only mutate in place if the forwarded variant really holds a list.
      if (out != nullptr && out->RefCountIsOne()) {
        // We are able to forward the input.
        out->tensors().resize(size, Tensor(DT_INVALID));
        c->set_output(0, *maybe_result);
        return;
      }
    }

    // We were not able to forward the input. Will have to resize from scratch.
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    TensorList output_list;
    output_list.element_shape = input_list->element_shape;
    output_list.element_dtype = input_list->element_dtype;
    output_list.max_num_elements = input_list->max_num_elements;
    if (size > input_list->tensors().size()) {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().end());
      // Add DT_INVALID tensors to the end of the list if the requested size
      // is larger than the list length.
      output_list.tensors().resize(size, Tensor(DT_INVALID));
    } else {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().begin() + size);
    }
    result->scalar<Variant>()() = std::move(output_list);
  }
};
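
// Resize semantics: growing pads the tail with DT_INVALID placeholders (the
// same "unset element" convention used by TensorListReserve), while shrinking
// keeps only the first `size` entries and drops the rest. When the input
// variant can be forwarded and is uniquely referenced, the std::vector resize
// happens directly on the forwarded list and no copy is made at all.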

REGISTER_KERNEL_BUILDER(Name("TensorListResize").Device(DEVICE_CPU),
                        TensorListResize);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_GPU).HostMemory("size"),
    TensorListResize);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_DEFAULT).HostMemory("size"),
    TensorListResize);

class TensorListSetItem : public OpKernel {
 public:
  explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32_t index = c->input(1).scalar<int32>()();
    // Make the bounds check explicit: a negative index would otherwise only
    // be rejected through the implicit signed-to-unsigned conversion in the
    // comparison against tensors().size().
    OP_REQUIRES(c, 0 <= index && index < l->tensors().size(),
                errors::InvalidArgument("Trying to modify element ", index,
                                        " in a list with ",
                                        l->tensors().size(), " elements."));
    const Tensor& value = c->input(2);
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
                errors::InvalidArgument(
                    "Tried to set a tensor with incompatible shape at a "
                    "list index. Item element shape: ",
                    value.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors()[index] = value;
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
                        TensorListSetItem);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_GPU(T)                      \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_GPU)                 \
                              .HostMemory("index"),               \
                          TensorListSetItem);

TF_CALL_GPU_ALL_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
REGISTER_TENSOR_LIST_SET_ITEM_GPU(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT(T)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT)             \
                              .HostMemory("index"),               \
                          TensorListSetItem);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT

class TensorListConcatLists : public OpKernel {
 public:
  explicit TensorListConcatLists(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorShape& tl_a_shape = c->input(0).shape();
    const TensorShape& tl_b_shape = c->input(1).shape();
    OP_REQUIRES(
        c, tl_a_shape == tl_b_shape,
        errors::InvalidArgument("Incompatible input TensorList tensor shapes: ",
                                tl_a_shape.DebugString(), " vs. ",
                                tl_b_shape.DebugString()));
    AllocatorAttributes attr;
    std::unique_ptr<Tensor> tl_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tl_a_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    // tl_a may be aliased by tl_alias.
    const Tensor& tl_a = c->input(0);
    const Tensor& tl_b = c->input(1);

    Tensor* output = nullptr;
    bool ok_to_alias = tl_alias != nullptr;
    if (tl_alias && tl_alias->dtype() == DT_VARIANT &&
        tl_alias->NumElements() > 0) {
      auto tl_a_t = tl_alias->flat<Variant>();
      for (int64_t i = 0; i < tl_alias->NumElements(); ++i) {
        TensorList* aliased = tl_a_t(i).get<TensorList>();
        if (aliased == nullptr || !aliased->RefCountIsOne()) {
          ok_to_alias = false;
          break;
        }
      }
      if (ok_to_alias) {
        c->set_output(0, *tl_alias);
        output = tl_alias.get();
      }
    }
    if (!ok_to_alias) {
      // Couldn't alias the entire Tensor. We'll be conservative and not try
      // to alias individual batch entries.
      attr.set_on_host(true);
      OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
    }

    auto output_t = output->flat<Variant>();
    auto tl_a_t = tl_a.flat<Variant>();
    auto tl_b_t = tl_b.flat<Variant>();

    for (int64_t i = 0; i < tl_a.NumElements(); ++i) {
      const TensorList* l_a = tl_a_t(i).get<TensorList>();
      const TensorList* l_b = tl_b_t(i).get<TensorList>();
      OP_REQUIRES(
          c, l_a != nullptr,
          errors::InvalidArgument("input_a is not a TensorList at index ", i,
                                  ". Saw: '", tl_a_t(i).DebugString(), "'"));
      OP_REQUIRES(
          c, l_b != nullptr,
          errors::InvalidArgument("input_b is not a TensorList at index ", i,
                                  ". Saw: '", tl_b_t(i).DebugString(), "'"));
      OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_a[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_a->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_b[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_b->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
                  errors::InvalidArgument(
                      "input_a and input_b TensorList element shapes are not "
                      "identical at index ",
                      i, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
                      l_b->element_shape.DebugString()));
      if (ok_to_alias) {
        TensorList* out = output_t(i).get<TensorList>();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out->tensors()));
      } else {
        TensorList out = l_a->Copy();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out.tensors()));
        output_t(i) = std::move(out);
      }
    }
  }

 private:
  DataType element_dtype_;
};
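
// TensorListConcatLists operates on batches of lists: inputs a and b are
// variant tensors of matching shape, and for each position i the elements of
// b[i] are appended to a[i]. Aliasing is all-or-nothing: the kernel reuses
// the forwarded input only if every batch entry is a TensorList with a single
// reference; otherwise it allocates a fresh output and builds each entry from
// a shallow Copy() of a[i].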

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_CPU),
                        TensorListConcatLists);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_GPU),
                        TensorListConcatLists);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_DEFAULT),
                        TensorListConcatLists);

#define REGISTER_TENSOR_LIST_OPS_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListStack<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGather<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGetItem<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPopBack<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListFromTensor<CPUDevice, T>)              \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListSplit<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPushBackBatch<CPUDevice, T>)

TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_OPS_CPU);
REGISTER_TENSOR_LIST_OPS_CPU(quint8);
REGISTER_TENSOR_LIST_OPS_CPU(qint8);
REGISTER_TENSOR_LIST_OPS_CPU(quint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint32);
REGISTER_TENSOR_LIST_OPS_CPU(Variant);

#undef REGISTER_TENSOR_LIST_OPS_CPU

#define REGISTER_TENSOR_LIST_OPS_CPU(T)

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          TensorList,
                                          TensorListBinaryAdd<CPUDevice>);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, TensorList,
                                         TensorListZerosLike<CPUDevice>);

static Status TensorListDeviceCopy(
    const TensorList& from, TensorList* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  to->element_shape = from.element_shape;
  to->element_dtype = from.element_dtype;
  to->max_num_elements = from.max_num_elements;
  to->tensors().reserve(from.tensors().size());
  for (const Tensor& t : from.tensors()) {
    to->tensors().emplace_back(t.dtype());
    if (t.dtype() != DT_INVALID) {
      TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
    }
  }
  return OkStatus();
}
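
// TensorListDeviceCopy is the element-wise copy functor registered below for
// host<->device and device<->device transfers of a TensorList variant: it
// copies the list metadata directly and invokes the provided copy function
// once per element, skipping DT_INVALID placeholders since they have no
// buffer to move.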

#define REGISTER_LIST_COPY(DIRECTION)                                         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
                                                       TensorListDeviceCopy)

REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);

#define REGISTER_TENSOR_LIST_OPS_DEFAULT(T)                                \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("element_shape")                 \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListStack<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("indices")                       \
                              .HostMemory("element_shape")                 \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListGather<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("lengths")                       \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("leading_dims")                  \
                              .HostMemory("element_shape")                 \
                              .HostMemory("lengths")                       \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("index")                         \
                              .HostMemory("element_shape"),                \
                          TensorListGetItem<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape"),                \
                          TensorListPopBack<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListPushBackBatch<CPUDevice, T>)           \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape"),                \
                          TensorListFromTensor<CPUDevice, T>)              \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape")                 \
                              .HostMemory("indices"),                      \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape")                 \
                              .HostMemory("num_elements")                  \
                              .HostMemory("indices"),                      \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("indices"),                      \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape")                 \
                              .HostMemory("lengths"),                      \
                          TensorListSplit<CPUDevice, T>)

TF_CALL_int32(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_int64(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_bfloat16(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_OPS_DEFAULT);

#undef REGISTER_TENSOR_LIST_OPS_DEFAULT
}  // namespace tensorflow