1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/data_flow_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include <limits> |
21 | #include <vector> |
22 | // TODO(b/31496047): Fix non-standard include order. |
23 | #include <numeric> // clang-format off |
24 | |
25 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
26 | #include "tensorflow/core/framework/bounds_check.h" |
27 | #include "tensorflow/core/framework/op_kernel.h" |
28 | #include "tensorflow/core/framework/register_types.h" |
29 | #include "tensorflow/core/framework/resource_mgr.h" |
30 | #include "tensorflow/core/framework/tensor.h" |
31 | #include "tensorflow/core/framework/tensor_shape.h" |
32 | #include "tensorflow/core/framework/tensor_util.h" |
33 | #include "tensorflow/core/framework/types.h" |
34 | #include "tensorflow/core/kernels/concat_lib.h" |
35 | #include "tensorflow/core/kernels/split_lib.h" |
36 | #include "tensorflow/core/kernels/tensor_array.h" |
37 | #include "tensorflow/core/lib/core/errors.h" |
38 | #include "tensorflow/core/lib/core/refcount.h" |
39 | #include "tensorflow/core/lib/strings/strcat.h" |
40 | #include "tensorflow/core/platform/dynamic_annotations.h" |
41 | #include "tensorflow/core/platform/logging.h" |
42 | #include "tensorflow/core/platform/thread_annotations.h" |
43 | #include "tensorflow/core/platform/types.h" |
44 | #include "tensorflow/core/util/ptr_util.h" |
45 | |
46 | typedef Eigen::ThreadPoolDevice CPUDevice; |
47 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
48 | typedef Eigen::GpuDevice GPUDevice; |
49 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
50 | |
51 | // clang-format on |
52 | |
53 | namespace tensorflow { |
54 | |
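// Extracts the container name and TensorArray handle name from the op's
// first input. The handle is stored as a length-2 DT_STRING vector:
// element 0 is the container (e.g. "_tensor_arrays") and element 1 is the
// per-array unique name.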
55 | Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) { |
56 | { |
57 | Tensor tensor; |
58 | // Assuming that handle is the input at index 0. |
59 | if (IsRefType(ctx->input_dtype(0))) { |
60 | tensor = ctx->mutable_input(0, false); |
61 | } else { |
62 | tensor = ctx->input(0); |
63 | } |
64 | if (tensor.NumElements() != 2) { |
65 | return errors::InvalidArgument( |
66 | "Tensor array handle must be 2-element vector, but had shape: " , |
67 | tensor.shape().DebugString()); |
68 | } |
69 | auto h = tensor.flat<tstring>(); |
70 | *container = h(0); |
71 | *ta_handle = h(1); |
72 | } |
73 | return OkStatus(); |
74 | } |
75 | |
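// Looks up the TensorArray resource addressed by input 0. For DT_RESOURCE
// handles the lookup goes through LookupResource; for string handles the
// resource is found in the per-step container under the key
// container + ta_handle. The caller receives a reference and must Unref it.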
76 | Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { |
77 | string container; |
78 | string ta_handle; |
79 | if (ctx->input_dtype(0) != DT_RESOURCE) { |
80 | TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle)); |
81 | ResourceMgr* rm = ctx->resource_manager(); |
82 | if (rm == nullptr) return errors::Internal("No resource manager." ); |
83 | TF_RETURN_IF_ERROR( |
84 | ctx->step_container()->Lookup(rm, container + ta_handle, tensor_array)); |
85 | return OkStatus(); |
86 | } else { |
87 | return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array); |
88 | } |
89 | } |
90 | |
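// Reads the scalar float "flow_in" input and, if requested, forwards it to
// the "flow_out" output. The flow value carries no data of its own; it only
// threads a data dependency through TensorArray ops so that reads and writes
// are ordered correctly in the dataflow graph.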
91 | Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) { |
92 | const Tensor* flow_control; |
93 | TF_RETURN_IF_ERROR(ctx->input("flow_in" , &flow_control)); |
94 | if (set_output) { |
95 | TF_RETURN_IF_ERROR(ctx->set_output("flow_out" , *flow_control)); |
96 | } |
97 | return OkStatus(); |
98 | } |
99 | |
100 | // CREATION ******************************************************************* |
101 | |
// Abstract base class for the shared behavior of TensorArrayOp and
// TensorArrayGradOp.
104 | class TensorArrayCreationOp : public OpKernel { |
105 | public: |
106 | explicit TensorArrayCreationOp(OpKernelConstruction* context) |
107 | : OpKernel(context), device_type_(context->device_type()) {} |
108 | |
109 | void Compute(OpKernelContext* ctx) override { |
110 | Tensor tensor_array_output_handle; |
111 | |
112 | AllocatorAttributes alloc_attr; |
113 | alloc_attr.set_on_host(true); |
114 | OP_REQUIRES_OK(ctx, ctx->allocate_temp( |
115 | tensorflow::DT_STRING, tensorflow::TensorShape({2}), |
116 | &tensor_array_output_handle, alloc_attr)); |
117 | // Store the handle in a per-step container of the RM. |
118 | ResourceMgr* rm = ctx->resource_manager(); |
119 | OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager." )); |
120 | |
121 | TensorArray* output_tensor_array; |
122 | OP_REQUIRES_OK(ctx, CreateTensorArray(ctx, rm, &tensor_array_output_handle, |
123 | &output_tensor_array)); |
124 | if (IsRefType(ctx->expected_output_dtype(0))) { |
125 | ctx->set_output_ref(0, output_tensor_array->mu(), |
126 | output_tensor_array->handle()); |
127 | } else if (ctx->expected_output_dtype(0) == DT_STRING) { |
128 | ctx->set_output(0, *output_tensor_array->handle()); |
129 | } else { |
130 | Tensor* handle; |
131 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); |
132 | handle->flat<ResourceHandle>()(0) = |
133 | output_tensor_array->resource_handle(ctx); |
134 | } |
135 | if (ctx->num_outputs() == 2) { |
136 | // Create the flow output. |
137 | Tensor* flow; |
138 | OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow)); |
139 | if (device_type_ == DEVICE_CPU) { |
// The value doesn't matter, but this makes msan not complain about
// copying an uninitialized value. Doing this on the GPU would require
// a kernel launch or a host->device memcpy, so we avoid that.
143 | flow->flat<float>()(0) = 0; |
144 | } |
145 | } |
146 | } |
147 | |
148 | protected: |
149 | virtual Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, |
150 | Tensor* tensor_array_output_handle, |
151 | TensorArray** output_tensor_array) = 0; |
152 | |
153 | private: |
154 | const DeviceType device_type_; |
155 | }; |
156 | |
// A per-run local tensor array. The tensor array uses a "per-step" resource
// manager, which ensures correct garbage collection on error or on
// successful completion.
160 | class TensorArrayOp : public TensorArrayCreationOp { |
161 | public: |
162 | explicit TensorArrayOp(OpKernelConstruction* context) |
163 | : TensorArrayCreationOp(context) { |
164 | OP_REQUIRES_OK(context, context->GetAttr("dtype" , &dtype_)); |
165 | OP_REQUIRES_OK(context, context->GetAttr("element_shape" , &element_shape_)); |
166 | OP_REQUIRES_OK(context, context->GetAttr("dynamic_size" , &dynamic_size_)); |
167 | // The HasAttr check is for backwards compatibility with older op |
168 | // versions which do not have this attribute. |
169 | if (context->HasAttr("identical_element_shapes" )) { |
170 | OP_REQUIRES_OK(context, context->GetAttr("identical_element_shapes" , |
171 | &identical_element_shapes_)); |
172 | } else { |
173 | identical_element_shapes_ = false; |
174 | } |
175 | OP_REQUIRES_OK(context, |
176 | context->GetAttr("clear_after_read" , &clear_after_read_)); |
177 | OP_REQUIRES_OK(context, |
178 | context->GetAttr("tensor_array_name" , &tensor_array_name_)); |
179 | if (tensor_array_name_.empty()) tensor_array_name_ = name(); |
180 | } |
181 | |
182 | Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, |
183 | Tensor* tensor_array_output_handle, |
184 | TensorArray** output_tensor_array) override { |
185 | const Tensor* tensor_size; |
186 | TF_RETURN_IF_ERROR(ctx->input("size" , &tensor_size)); |
187 | |
188 | if (!TensorShapeUtils::IsScalar(tensor_size->shape())) { |
189 | return errors::InvalidArgument( |
190 | "TensorArray size must be scalar, but had shape: " , |
191 | tensor_size->shape().DebugString()); |
192 | } |
193 | const int32_t size = tensor_size->scalar<int32>()(); |
194 | if (size < 0) { |
195 | return errors::InvalidArgument("Size should be >= 0." ); |
196 | } |
197 | |
198 | auto handle = tensor_array_output_handle->flat<tstring>(); |
199 | string unique_tensor_array_name = |
200 | strings::StrCat(tensor_array_name_, "_" , |
201 | TensorArray::tensor_array_counter.fetch_add(1)); |
202 | handle(0) = "_tensor_arrays" ; |
203 | handle(1) = unique_tensor_array_name; |
204 | |
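// The resource key is the container string concatenated directly with the
// unique array name (no separator). For example, a node named "TensorArray"
// creating the first array in a process would use the key
// "_tensor_arraysTensorArray_0".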
205 | auto key = strings::StrCat(handle(0), unique_tensor_array_name); |
206 | |
207 | TensorArray* tensor_array = new TensorArray( |
208 | key, dtype_, *tensor_array_output_handle, size, element_shape_, |
209 | identical_element_shapes_, dynamic_size_, |
210 | false /* multiple_writes_aggregate */, false /* is_grad */, |
211 | -1 /* marked_size */, clear_after_read_); |
212 | |
213 | TF_RETURN_IF_ERROR(ctx->step_container()->Create(rm, key, tensor_array)); |
214 | |
215 | *output_tensor_array = tensor_array; |
216 | |
217 | return OkStatus(); |
218 | } |
219 | |
220 | private: |
221 | DataType dtype_; |
222 | PartialTensorShape element_shape_; |
223 | bool identical_element_shapes_; |
224 | bool dynamic_size_; |
225 | bool clear_after_read_; |
226 | string tensor_array_name_; // The name used to create the TensorArray. |
227 | |
228 | TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); |
229 | }; |
230 | |
231 | REGISTER_KERNEL_BUILDER(Name("TensorArray" ).Device(DEVICE_CPU), TensorArrayOp); |
232 | REGISTER_KERNEL_BUILDER(Name("TensorArrayV2" ).Device(DEVICE_CPU), |
233 | TensorArrayOp); |
234 | REGISTER_KERNEL_BUILDER(Name("TensorArrayV3" ).Device(DEVICE_CPU), |
235 | TensorArrayOp); |
236 | |
237 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
238 | |
239 | #define REGISTER_GPU(type) \ |
240 | REGISTER_KERNEL_BUILDER(Name("TensorArray") \ |
241 | .Device(DEVICE_GPU) \ |
242 | .TypeConstraint<type>("dtype") \ |
243 | .HostMemory("size") \ |
244 | .HostMemory("handle"), \ |
245 | TensorArrayOp); \ |
246 | REGISTER_KERNEL_BUILDER(Name("TensorArrayV2") \ |
247 | .Device(DEVICE_GPU) \ |
248 | .TypeConstraint<type>("dtype") \ |
249 | .HostMemory("size") \ |
250 | .HostMemory("handle"), \ |
251 | TensorArrayOp); \ |
252 | REGISTER_KERNEL_BUILDER(Name("TensorArrayV3") \ |
253 | .Device(DEVICE_GPU) \ |
254 | .TypeConstraint<type>("dtype") \ |
255 | .HostMemory("size") \ |
256 | .HostMemory("handle"), \ |
257 | TensorArrayOp); |
258 | |
259 | TF_CALL_int64(REGISTER_GPU); |
260 | TF_CALL_bfloat16(REGISTER_GPU); |
261 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); |
262 | TF_CALL_COMPLEX_TYPES(REGISTER_GPU); |
263 | #undef REGISTER_GPU |
264 | |
265 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
266 | |
267 | // GRADIENT ******************************************************************* |
// Note that this op may have an optional third input. If present, it
// represents a shape value, indicating that the element shape of this
// gradient array is that shape value concatenated with the element shape of
// the original tensor array. See TensorArrayGradWithShape.
272 | class TensorArrayGradOp : public TensorArrayCreationOp { |
273 | public: |
274 | explicit TensorArrayGradOp(OpKernelConstruction* context) |
275 | : TensorArrayCreationOp(context) { |
276 | OP_REQUIRES_OK(context, context->GetAttr("source" , &source_)); |
277 | } |
278 | |
279 | Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, |
280 | Tensor* tensor_array_output_handle, |
281 | TensorArray** output_tensor_array) override { |
282 | string container; |
283 | string tensor_array_name; |
284 | if (ctx->input_dtype(0) != DT_RESOURCE) { |
285 | TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &tensor_array_name)); |
286 | if (container != "_tensor_arrays" ) { |
287 | return errors::InvalidArgument( |
288 | "Input container should be '_tensor_arrays', but received '" , |
289 | container, "'" ); |
290 | } |
291 | } else { |
292 | container = "_tensor_arrays" ; |
293 | const auto& resource = ctx->input(0).flat<ResourceHandle>()(0); |
294 | if (StringPiece(resource.name()).substr(0, container.size()) != |
295 | container) { |
296 | return errors::InvalidArgument("Wrong input container. " , |
297 | resource.name()); |
298 | } |
299 | tensor_array_name = |
300 | string(StringPiece(resource.name()).substr(container.size())); |
301 | } |
302 | |
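// Gradient arrays live in their own "_tensor_array_grads" container and are
// keyed by "<forward_name>@<source>", so each gradient computation (each
// distinct "source" attr) gets its own gradient TensorArray for the same
// forward array.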
303 | auto output_handle = tensor_array_output_handle->flat<tstring>(); |
304 | output_handle(0) = "_tensor_array_grads" ; |
305 | output_handle(1) = strings::StrCat(tensor_array_name, "@" , source_); |
306 | |
307 | TensorArray* tensor_array; |
308 | TF_RETURN_IF_ERROR(ctx->step_container()->Lookup( |
309 | rm, strings::StrCat(container, tensor_array_name), &tensor_array)); |
310 | core::ScopedUnref unref(tensor_array); |
311 | |
312 | // Once gradients are being calculated, the forward TensorArray |
313 | // may no longer be resized by new Writes. |
314 | tensor_array->DisableDynamicSize(); |
315 | |
316 | int32_t array_size = 0; |
317 | int32_t marked_size = 0; |
318 | TF_RETURN_IF_ERROR(tensor_array->Size(&array_size)); |
319 | TF_RETURN_IF_ERROR(tensor_array->MarkedSize(&marked_size)); |
320 | |
321 | if (array_size < 0) { |
322 | return errors::InvalidArgument("ArraySize should be >= 0." ); |
323 | } |
324 | if (!tensor_array->GradientsAllowed()) { |
325 | return errors::InvalidArgument( |
326 | "Unable to create a gradients TensorArray for " , tensor_array_name, |
327 | ". Perhaps you used the multiple_writes_aggregate flag on a " |
328 | "previous write? Gradient calculation is impossible when multiple " |
329 | "writes are performed to the same index." ); |
330 | } |
331 | TensorShape shape_to_prepend; |
332 | auto element_shape = PartialTensorShape(); |
333 | if (ctx->num_inputs() > 2) { |
334 | TF_RETURN_IF_ERROR(tensor::MakeShape(ctx->input(2), &shape_to_prepend)); |
335 | auto ta_element_shape = tensor_array->ElemShape(); |
336 | if (!ta_element_shape.unknown_rank()) { |
337 | std::vector<int64_t> dims; |
338 | for (auto dim : shape_to_prepend) { |
339 | dims.push_back(dim.size); |
340 | } |
341 | for (auto dim : ta_element_shape) { |
342 | dims.push_back(dim.size); |
343 | } |
344 | TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape( |
345 | gtl::ArraySlice<int64_t>(dims), &element_shape)); |
346 | } |
347 | } else { |
348 | element_shape = tensor_array->ElemShape(); |
349 | } |
350 | |
351 | const auto key = strings::StrCat(output_handle(0), output_handle(1)); |
352 | auto creator = [key, tensor_array, array_size, marked_size, element_shape, |
353 | shape_to_prepend, |
354 | tensor_array_output_handle](TensorArray** ret) -> Status { |
355 | *ret = new TensorArray( |
356 | key, tensor_array->ElemType(), *tensor_array_output_handle, |
357 | array_size, element_shape, tensor_array->HasIdenticalElementShapes(), |
358 | false /* dynamic_size */, true /* multiple_writes_aggregate */, |
359 | true /* is_grad */, marked_size /* marked_size */, |
true /* clear_after_read */);
361 | return (*ret)->CopyShapesFrom(tensor_array, &shape_to_prepend); |
362 | }; |
363 | |
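// LookupOrCreate lets several gradient ops with the same source share one
// gradient array: only the first call runs the creator above. The Unref
// below releases the reference handed back to us; the step container keeps
// the array alive until the end of the step.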
364 | Status s = ctx->step_container()->LookupOrCreate<TensorArray>( |
365 | rm, key, output_tensor_array, creator); |
366 | (*output_tensor_array)->Unref(); |
367 | |
368 | return s; |
369 | } |
370 | |
371 | private: |
372 | // The gradient source for creating the given |
373 | // gradient TensorArray. This should be unique to each gradients |
374 | // call. Typical values look like "gradients", "gradients_1", ... |
375 | string source_; |
376 | |
377 | TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayGradOp); |
378 | }; |
379 | |
380 | REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad" ).Device(DEVICE_CPU), |
381 | TensorArrayGradOp); |
382 | REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2" ).Device(DEVICE_CPU), |
383 | TensorArrayGradOp); |
384 | REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3" ).Device(DEVICE_CPU), |
385 | TensorArrayGradOp); |
386 | REGISTER_KERNEL_BUILDER(Name("TensorArrayGradWithShape" ).Device(DEVICE_CPU), |
387 | TensorArrayGradOp); |
388 | REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad" ) |
389 | .Device(DEVICE_GPU) |
390 | .HostMemory("handle" ) |
391 | .HostMemory("grad_handle" ), |
392 | TensorArrayGradOp); |
393 | REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2" ) |
394 | .Device(DEVICE_GPU) |
395 | .HostMemory("handle" ) |
396 | .HostMemory("grad_handle" ), |
397 | TensorArrayGradOp); |
398 | REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3" ) |
399 | .Device(DEVICE_GPU) |
400 | .HostMemory("handle" ) |
401 | .HostMemory("grad_handle" ), |
402 | TensorArrayGradOp); |
403 | REGISTER_KERNEL_BUILDER(Name("TensorArrayGradWithShape" ) |
404 | .Device(DEVICE_GPU) |
405 | .HostMemory("handle" ) |
406 | .HostMemory("shape_to_prepend" ) |
407 | .HostMemory("grad_handle" ), |
408 | TensorArrayGradOp); |
409 | |
410 | // WRITE ********************************************************************** |
411 | |
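// Writes one tensor to the given index of a TensorArray. For ordinary
// (forward) arrays a second write to the same index is an error; gradient
// arrays are created with multiple_writes_aggregate set, in which case
// WriteOrAggregate aggregates (sums) the new value into the existing one.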
412 | template <typename Device, typename T> |
413 | class TensorArrayWriteOp : public OpKernel { |
414 | public: |
415 | explicit TensorArrayWriteOp(OpKernelConstruction* context) |
416 | : OpKernel(context) {} |
417 | |
418 | void Compute(OpKernelContext* ctx) override { |
419 | OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, true)); |
420 | |
421 | const Tensor* tensor_index; |
422 | const Tensor* tensor_value; |
423 | OP_REQUIRES_OK(ctx, ctx->input("index" , &tensor_index)); |
424 | OP_REQUIRES_OK(ctx, ctx->input("value" , &tensor_value)); |
425 | |
426 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_index->shape()), |
427 | errors::InvalidArgument( |
428 | "TensorArray index must be scalar, but had shape: " , |
429 | tensor_index->shape().DebugString())); |
430 | |
431 | TensorArray* tensor_array = nullptr; |
432 | OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array)); |
433 | core::ScopedUnref unref(tensor_array); |
434 | const int32_t index = tensor_index->scalar<int32>()(); |
435 | OP_REQUIRES( |
436 | ctx, tensor_value->dtype() == tensor_array->ElemType(), |
437 | errors::InvalidArgument("TensorArray dtype is " , |
438 | DataTypeString(tensor_array->ElemType()), |
439 | " but Op is trying to write dtype " , |
440 | DataTypeString(tensor_value->dtype()), "." )); |
441 | Status s = |
442 | tensor_array->WriteOrAggregate<Device, T>(ctx, index, tensor_value); |
443 | OP_REQUIRES_OK(ctx, s); |
444 | } |
445 | }; |
446 | |
447 | #define REGISTER_WRITE(type) \ |
448 | REGISTER_KERNEL_BUILDER( \ |
449 | Name("TensorArrayWrite").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
450 | TensorArrayWriteOp<CPUDevice, type>); \ |
451 | REGISTER_KERNEL_BUILDER( \ |
452 | Name("TensorArrayWriteV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
453 | TensorArrayWriteOp<CPUDevice, type>); \ |
454 | REGISTER_KERNEL_BUILDER( \ |
455 | Name("TensorArrayWriteV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
456 | TensorArrayWriteOp<CPUDevice, type>); |
457 | |
458 | TF_CALL_ALL_TYPES(REGISTER_WRITE); |
459 | |
460 | #undef REGISTER_WRITE |
461 | |
462 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
463 | |
464 | #define REGISTER_GPU(type) \ |
465 | REGISTER_KERNEL_BUILDER(Name("TensorArrayWrite") \ |
466 | .Device(DEVICE_GPU) \ |
467 | .TypeConstraint<type>("T") \ |
468 | .HostMemory("handle") \ |
469 | .HostMemory("index"), \ |
470 | TensorArrayWriteOp<GPUDevice, type>); \ |
471 | REGISTER_KERNEL_BUILDER(Name("TensorArrayWriteV2") \ |
472 | .Device(DEVICE_GPU) \ |
473 | .TypeConstraint<type>("T") \ |
474 | .HostMemory("handle") \ |
475 | .HostMemory("index"), \ |
476 | TensorArrayWriteOp<GPUDevice, type>); \ |
477 | REGISTER_KERNEL_BUILDER(Name("TensorArrayWriteV3") \ |
478 | .Device(DEVICE_GPU) \ |
479 | .TypeConstraint<type>("T") \ |
480 | .HostMemory("handle") \ |
481 | .HostMemory("index"), \ |
482 | TensorArrayWriteOp<GPUDevice, type>); |
483 | |
484 | TF_CALL_bfloat16(REGISTER_GPU); |
485 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); |
486 | TF_CALL_COMPLEX_TYPES(REGISTER_GPU); |
487 | #undef REGISTER_GPU |
488 | |
489 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
490 | |
491 | // READ *********************************************************************** |
492 | |
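// Reads the tensor stored at the given index of a TensorArray. The "dtype"
// attr must match the element type of the array.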
493 | template <typename Device, typename T> |
494 | class TensorArrayReadOp : public OpKernel { |
495 | public: |
496 | explicit TensorArrayReadOp(OpKernelConstruction* context) |
497 | : OpKernel(context) { |
498 | OP_REQUIRES_OK(context, context->GetAttr("dtype" , &dtype_)); |
499 | } |
500 | |
501 | void Compute(OpKernelContext* ctx) override { |
502 | OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, false)); |
503 | |
504 | const Tensor* tensor_index; |
505 | OP_REQUIRES_OK(ctx, ctx->input("index" , &tensor_index)); |
506 | |
507 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_index->shape()), |
508 | errors::InvalidArgument( |
509 | "TensorArray index must be scalar, but had shape: " , |
510 | tensor_index->shape().DebugString())); |
511 | |
512 | TensorArray* tensor_array = nullptr; |
513 | OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array)); |
514 | core::ScopedUnref unref(tensor_array); |
515 | |
516 | const int32_t index = tensor_index->scalar<int32>()(); |
517 | OP_REQUIRES( |
518 | ctx, dtype_ == tensor_array->ElemType(), |
519 | errors::InvalidArgument( |
520 | "TensorArray dtype is " , DataTypeString(tensor_array->ElemType()), |
521 | " but Op requested dtype " , DataTypeString(dtype_), "." )); |
522 | Tensor value; |
523 | Status s = tensor_array->Read<Device, T>(ctx, index, &value); |
524 | OP_REQUIRES_OK(ctx, s); |
525 | ctx->set_output(0, value); |
526 | } |
527 | |
528 | private: |
529 | DataType dtype_; |
530 | }; |
531 | |
532 | #define REGISTER_READ(type) \ |
533 | REGISTER_KERNEL_BUILDER(Name("TensorArrayRead") \ |
534 | .Device(DEVICE_CPU) \ |
535 | .TypeConstraint<type>("dtype"), \ |
536 | TensorArrayReadOp<CPUDevice, type>); \ |
537 | REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV2") \ |
538 | .Device(DEVICE_CPU) \ |
539 | .TypeConstraint<type>("dtype"), \ |
540 | TensorArrayReadOp<CPUDevice, type>); \ |
541 | REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3") \ |
542 | .Device(DEVICE_CPU) \ |
543 | .TypeConstraint<type>("dtype"), \ |
544 | TensorArrayReadOp<CPUDevice, type>); |
545 | |
546 | TF_CALL_ALL_TYPES(REGISTER_READ) |
547 | |
548 | #undef REGISTER_READ |
549 | |
550 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
551 | |
552 | #define REGISTER_GPU(type) \ |
553 | REGISTER_KERNEL_BUILDER(Name("TensorArrayRead") \ |
554 | .Device(DEVICE_GPU) \ |
555 | .TypeConstraint<type>("dtype") \ |
556 | .HostMemory("handle") \ |
557 | .HostMemory("index"), \ |
558 | TensorArrayReadOp<GPUDevice, type>); \ |
559 | REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV2") \ |
560 | .Device(DEVICE_GPU) \ |
561 | .TypeConstraint<type>("dtype") \ |
562 | .HostMemory("handle") \ |
563 | .HostMemory("index"), \ |
564 | TensorArrayReadOp<GPUDevice, type>); \ |
565 | REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3") \ |
566 | .Device(DEVICE_GPU) \ |
567 | .TypeConstraint<type>("dtype") \ |
568 | .HostMemory("handle") \ |
569 | .HostMemory("index"), \ |
570 | TensorArrayReadOp<GPUDevice, type>); |
571 | |
572 | TF_CALL_int64(REGISTER_GPU); |
573 | TF_CALL_bfloat16(REGISTER_GPU); |
574 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); |
575 | TF_CALL_COMPLEX_TYPES(REGISTER_GPU); |
576 | #undef REGISTER_GPU |
577 | |
578 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
579 | |
580 | // PACK and GATHER ************************************************************ |
581 | |
// Pack (stack) or gather elements of a TensorArray into a single tensor with
// a new leading dimension. All elements read must be defined and have the
// same shape.
584 | template <typename Device, typename T, bool LEGACY_PACK> |
585 | class TensorArrayPackOrGatherOp : public OpKernel { |
586 | public: |
587 | typedef typename TTypes<T, 2>::ConstMatrix ConstMatrix; |
588 | typedef std::vector<std::unique_ptr<ConstMatrix> > ConstMatrixVector; |
589 | |
590 | explicit TensorArrayPackOrGatherOp(OpKernelConstruction* context) |
591 | : OpKernel(context) { |
592 | OP_REQUIRES_OK(context, context->GetAttr("dtype" , &dtype_)); |
593 | OP_REQUIRES_OK(context, context->GetAttr("element_shape" , &element_shape_)); |
594 | } |
595 | |
596 | void Compute(OpKernelContext* ctx) override { |
597 | OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, false)); |
598 | |
599 | TensorArray* tensor_array = nullptr; |
600 | OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array)); |
601 | |
602 | core::ScopedUnref unref(tensor_array); |
603 | OP_REQUIRES( |
604 | ctx, dtype_ == tensor_array->ElemType(), |
605 | errors::InvalidArgument( |
606 | "TensorArray dtype is " , DataTypeString(tensor_array->ElemType()), |
607 | " but Op requested dtype " , DataTypeString(dtype_), "." )); |
608 | |
609 | // Ensure new element shape is compatible with the one stored in the |
610 | // TensorArray. |
611 | OP_REQUIRES_OK(ctx, tensor_array->SetElemShape(element_shape_)); |
612 | |
613 | int32_t num_indices; |
614 | std::vector<Tensor> values; |
615 | std::vector<int32> indices; |
616 | if (LEGACY_PACK) { |
617 | OP_REQUIRES_OK(ctx, tensor_array->PackOrConcatSize(&num_indices)); |
618 | indices.resize(num_indices); |
619 | std::iota(indices.begin(), indices.end(), 0); |
620 | } else { |
621 | const Tensor* tensor_indices; |
622 | OP_REQUIRES_OK(ctx, ctx->input("indices" , &tensor_indices)); |
623 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(tensor_indices->shape()), |
624 | errors::InvalidArgument( |
625 | "Expected indices to be a vector, but received shape: " , |
626 | tensor_indices->shape().DebugString())); |
627 | const auto indices_t = tensor_indices->vec<int32>(); |
628 | num_indices = tensor_indices->NumElements(); |
629 | indices.resize(num_indices); |
630 | std::copy(indices_t.data(), indices_t.data() + num_indices, |
631 | indices.begin()); |
632 | } |
633 | |
634 | // If there are no elements to return, return a zero-element Tensor with |
635 | // shape [0] + element_shape_ |
636 | if (num_indices == 0) { |
637 | OP_REQUIRES(ctx, element_shape_.IsFullyDefined(), |
638 | errors::Unimplemented( |
639 | "TensorArray has size zero, but element shape " , |
640 | element_shape_.DebugString(), |
641 | " is not fully defined. " |
642 | "Currently only static shapes are supported when packing " |
643 | "zero-size TensorArrays." )); |
644 | TensorShape empty_shape; |
645 | element_shape_.AsTensorShape(&empty_shape); |
646 | empty_shape.InsertDim(0, 0); |
647 | Tensor* empty_unused; |
648 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, empty_shape, &empty_unused)); |
649 | return; |
650 | } |
651 | |
652 | // Read all the Tensors into a vector to keep track of their memory. |
653 | Status s = tensor_array->ReadMany<Device, T>(ctx, indices, &values); |
654 | OP_REQUIRES_OK(ctx, s); |
655 | |
656 | const Tensor* value_0_t = &values[0]; |
657 | |
658 | OP_REQUIRES( |
659 | ctx, element_shape_.IsCompatibleWith(value_0_t->shape()), |
660 | errors::InvalidArgument("TensorArray was passed element_shape " , |
661 | element_shape_.DebugString(), |
662 | " which does not match the Tensor at index 0: " , |
663 | value_0_t->shape().DebugString())); |
664 | |
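// The packed/gathered result has shape [num_indices] + element_shape. For
// example, gathering 3 elements of shape [4, 5] produces a [3, 4, 5] tensor.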
665 | TensorShape output_shape(value_0_t->shape()); |
666 | output_shape.InsertDim(0, num_indices); |
667 | |
668 | Tensor* output_tensor = nullptr; |
669 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor)); |
670 | |
671 | // If output_tensor is empty, there is nothing to concatenate so return it. |
672 | if (output_shape.num_elements() == 0) { |
673 | return; |
674 | } |
675 | |
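// Each element is viewed as a 1 x NumElements matrix and the output as a
// single 1 x total row, so concatenating along the second dimension is
// equivalent to stacking the elements in order.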
676 | ConstMatrixVector input_tensors_flat; |
677 | input_tensors_flat.reserve(num_indices); |
678 | auto output_flat = |
679 | output_tensor->shaped<T, 2>({1, output_shape.num_elements()}); |
680 | |
681 | // Insert the first value |
682 | input_tensors_flat.push_back(MakeUnique<ConstMatrix>( |
683 | value_0_t->shaped<T, 2>({1, value_0_t->NumElements()}))); |
684 | |
685 | for (int i = 1; i < num_indices; ++i) { |
686 | const Tensor* value_t = &values[i]; |
687 | OP_REQUIRES( |
688 | ctx, value_0_t->shape() == value_t->shape(), |
689 | errors::InvalidArgument( |
690 | "TensorArray has inconsistent shapes. Index 0 has shape: " , |
691 | value_0_t->shape().DebugString(), " but index " , i, |
692 | " has shape: " , value_t->shape().DebugString())); |
693 | input_tensors_flat.push_back(MakeUnique<ConstMatrix>( |
694 | value_t->shaped<T, 2>({1, value_t->NumElements()}))); |
695 | } |
696 | |
697 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
698 | if (std::is_same<Device, GPUDevice>::value) { |
699 | ConcatGPU<T>(ctx, input_tensors_flat, output_tensor, &output_flat); |
700 | return; |
701 | } |
702 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
703 | ConcatCPU<T>(ctx->device(), input_tensors_flat, &output_flat); |
704 | } |
705 | |
706 | private: |
707 | DataType dtype_; |
708 | PartialTensorShape element_shape_; |
709 | }; |
710 | |
711 | #define REGISTER_GATHER_AND_PACK(type) \ |
712 | REGISTER_KERNEL_BUILDER( \ |
713 | Name("TensorArrayPack") \ |
714 | .Device(DEVICE_CPU) \ |
715 | .TypeConstraint<type>("dtype"), \ |
716 | TensorArrayPackOrGatherOp<CPUDevice, type, true /* LEGACY_PACK */>); \ |
717 | REGISTER_KERNEL_BUILDER( \ |
718 | Name("TensorArrayGather") \ |
719 | .Device(DEVICE_CPU) \ |
720 | .TypeConstraint<type>("dtype"), \ |
721 | TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); \ |
722 | REGISTER_KERNEL_BUILDER( \ |
723 | Name("TensorArrayGatherV2") \ |
724 | .Device(DEVICE_CPU) \ |
725 | .TypeConstraint<type>("dtype"), \ |
726 | TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); \ |
727 | REGISTER_KERNEL_BUILDER( \ |
728 | Name("TensorArrayGatherV3") \ |
729 | .Device(DEVICE_CPU) \ |
730 | .TypeConstraint<type>("dtype"), \ |
731 | TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); |
732 | |
733 | TF_CALL_POD_STRING_TYPES(REGISTER_GATHER_AND_PACK); |
734 | TF_CALL_variant(REGISTER_GATHER_AND_PACK); |
735 | REGISTER_GATHER_AND_PACK(quint8); |
736 | REGISTER_GATHER_AND_PACK(qint8); |
737 | REGISTER_GATHER_AND_PACK(qint32); |
738 | |
739 | #undef REGISTER_GATHER_AND_PACK |
740 | |
741 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
742 | |
743 | #define REGISTER_GPU(type) \ |
744 | REGISTER_KERNEL_BUILDER( \ |
745 | Name("TensorArrayPack") \ |
746 | .Device(DEVICE_GPU) \ |
747 | .TypeConstraint<type>("dtype") \ |
748 | .HostMemory("handle"), \ |
749 | TensorArrayPackOrGatherOp<GPUDevice, type, true /* LEGACY_PACK */>); \ |
750 | REGISTER_KERNEL_BUILDER( \ |
751 | Name("TensorArrayGather") \ |
752 | .Device(DEVICE_GPU) \ |
753 | .TypeConstraint<type>("dtype") \ |
754 | .HostMemory("indices") \ |
755 | .HostMemory("handle"), \ |
756 | TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); \ |
757 | REGISTER_KERNEL_BUILDER( \ |
758 | Name("TensorArrayGatherV2") \ |
759 | .Device(DEVICE_GPU) \ |
760 | .TypeConstraint<type>("dtype") \ |
761 | .HostMemory("indices") \ |
762 | .HostMemory("handle"), \ |
763 | TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); \ |
764 | REGISTER_KERNEL_BUILDER( \ |
765 | Name("TensorArrayGatherV3") \ |
766 | .Device(DEVICE_GPU) \ |
767 | .TypeConstraint<type>("dtype") \ |
768 | .HostMemory("indices") \ |
769 | .HostMemory("handle"), \ |
770 | TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); |
771 | |
772 | TF_CALL_bfloat16(REGISTER_GPU); |
773 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); |
774 | TF_CALL_COMPLEX_TYPES(REGISTER_GPU); |
775 | #undef REGISTER_GPU |
776 | |
777 | // A special GPU kernel for int32. |
778 | // TODO(b/25387198): Also enable int32 in device memory. This kernel |
779 | // registration requires all int32 inputs and outputs to be in host memory. |
780 | REGISTER_KERNEL_BUILDER( |
781 | Name("TensorArrayGather" ) |
782 | .Device(DEVICE_GPU) |
783 | .TypeConstraint<int32>("dtype" ) |
784 | .HostMemory("indices" ) |
785 | .HostMemory("handle" ), |
786 | TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>); |
787 | REGISTER_KERNEL_BUILDER( |
788 | Name("TensorArrayGatherV2" ) |
789 | .Device(DEVICE_GPU) |
790 | .TypeConstraint<int32>("dtype" ) |
791 | .HostMemory("indices" ) |
792 | .HostMemory("handle" ), |
793 | TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>); |
794 | REGISTER_KERNEL_BUILDER( |
795 | Name("TensorArrayGatherV3" ) |
796 | .Device(DEVICE_GPU) |
797 | .TypeConstraint<int32>("dtype" ) |
798 | .HostMemory("indices" ) |
799 | .HostMemory("handle" ), |
800 | TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>); |
801 | |
802 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
803 | |
804 | // CONCAT ********************************************************************* |
805 | |
806 | // Concatenate the elements in a TensorArray. All elements must be |
807 | // defined and (excepting the first dimension) have the same shape. |
808 | template <typename Device, typename T> |
809 | class TensorArrayConcatOp : public OpKernel { |
810 | public: |
811 | typedef typename TTypes<T, 2>::ConstMatrix ConstMatrix; |
812 | typedef std::vector<std::unique_ptr<ConstMatrix> > ConstMatrixVector; |
813 | |
814 | explicit TensorArrayConcatOp(OpKernelConstruction* context) |
815 | : OpKernel(context) { |
816 | OP_REQUIRES_OK(context, context->GetAttr("dtype" , &dtype_)); |
817 | OP_REQUIRES_OK(context, context->GetAttr("element_shape_except0" , |
818 | &element_shape_except0_)); |
819 | } |
820 | |
821 | void Compute(OpKernelContext* ctx) override { |
822 | OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, false)); |
823 | |
824 | TensorArray* tensor_array = nullptr; |
825 | OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array)); |
826 | core::ScopedUnref unref(tensor_array); |
827 | OP_REQUIRES( |
828 | ctx, dtype_ == tensor_array->ElemType(), |
829 | errors::InvalidArgument( |
830 | "TensorArray dtype is " , DataTypeString(tensor_array->ElemType()), |
831 | " but Op requested dtype " , DataTypeString(dtype_), "." )); |
832 | |
833 | int32_t array_size; |
834 | OP_REQUIRES_OK(ctx, tensor_array->PackOrConcatSize(&array_size)); |
835 | |
836 | // If there are no elements, return a zero-element Tensor with |
837 | // shape [0] + element_shape_except0_ |
838 | if (array_size == 0) { |
839 | OP_REQUIRES( |
840 | ctx, element_shape_except0_.IsFullyDefined(), |
841 | errors::Unimplemented( |
842 | "TensorArray has size zero, but element_shape_except0 " , |
843 | element_shape_except0_.DebugString(), |
844 | " is not fully defined. " |
845 | "Currently only static shapes are supported when concatenating " |
846 | "zero-size TensorArrays." )); |
847 | TensorShape empty_shape; |
848 | element_shape_except0_.AsTensorShape(&empty_shape); |
849 | empty_shape.InsertDim(0, 0); |
850 | Tensor* empty_unused; |
851 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, empty_shape, &empty_unused)); |
852 | OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {0}, &empty_unused)); |
853 | return; |
854 | } |
855 | |
856 | // Read all the Tensors into a vector to keep track of their memory. |
857 | std::vector<Tensor> values; |
858 | std::vector<int32> indices(array_size); |
859 | std::iota(indices.begin(), indices.end(), 0); |
860 | Status s = tensor_array->ReadMany<Device, T>(ctx, indices, &values); |
861 | OP_REQUIRES_OK(ctx, s); |
862 | |
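// Output 1 records each element's length along dimension 0. These lengths
// can be fed to TensorArraySplit to undo the concatenation.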
863 | Tensor* lengths_tensor = nullptr; |
864 | OP_REQUIRES_OK(ctx, |
865 | ctx->allocate_output( |
866 | 1, TensorShape({static_cast<int64_t>(values.size())}), |
867 | &lengths_tensor)); |
868 | auto lengths_tensor_t = lengths_tensor->vec<int64_t>(); |
869 | |
870 | TensorShape output_shape; |
871 | TensorShape output_shape_except0; |
872 | for (std::size_t i = 0; i < values.size(); ++i) { |
873 | TensorShape value_shape_t = values[i].shape(); |
874 | |
875 | OP_REQUIRES( |
876 | ctx, TensorShapeUtils::IsVectorOrHigher(value_shape_t), |
877 | errors::InvalidArgument( |
878 | "Concat saw a scalar shape at index " , i, |
879 | " but requires at least vectors. Did you mean to call pack?" )); |
880 | |
881 | lengths_tensor_t(i) = value_shape_t.dim_size(0); |
882 | |
883 | TensorShape value_shape_t_except0 = value_shape_t; |
884 | value_shape_t_except0.RemoveDim(0); |
885 | if (i == 0) { |
886 | output_shape = value_shape_t; |
887 | output_shape_except0 = value_shape_t_except0; |
888 | OP_REQUIRES( |
889 | ctx, element_shape_except0_.IsCompatibleWith(output_shape_except0), |
890 | errors::InvalidArgument( |
891 | "TensorArray was passed element_shape_except0 " , |
892 | element_shape_except0_.DebugString(), |
893 | " but index 0 has (excepting dimension 0) shape: " , |
894 | value_shape_t_except0.DebugString(), " which does not match." )); |
895 | } else { |
896 | OP_REQUIRES(ctx, output_shape_except0 == value_shape_t_except0, |
897 | errors::InvalidArgument( |
898 | "TensorArray has inconsistent shapes. Index 0 has " |
899 | "(excepting dimension 0) shape: " , |
900 | output_shape_except0.DebugString(), " but index " , i, |
901 | " has (excepting dimension 0) shape: " , |
902 | value_shape_t_except0.DebugString())); |
// Accumulate the total length along dimension 0.
904 | output_shape.set_dim( |
905 | 0, output_shape.dim_size(0) + value_shape_t.dim_size(0)); |
906 | } |
907 | } |
908 | |
909 | Tensor* output_tensor = nullptr; |
910 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor)); |
911 | ConstMatrixVector input_tensors_flat; |
912 | input_tensors_flat.reserve(values.size()); |
913 | for (size_t i = 0; i < values.size(); ++i) { |
914 | const Tensor* value_t = &values[i]; |
915 | if (value_t->NumElements() > 0) { |
916 | input_tensors_flat.push_back(MakeUnique<ConstMatrix>( |
917 | value_t->shaped<T, 2>({1, value_t->NumElements()}))); |
918 | } |
919 | } |
920 | |
921 | if (output_shape.num_elements() > 0) { |
922 | auto output_flat = |
923 | output_tensor->shaped<T, 2>({1, output_shape.num_elements()}); |
924 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
925 | if (std::is_same<Device, GPUDevice>::value) { |
926 | ConcatGPU<T>(ctx, input_tensors_flat, output_tensor, &output_flat); |
927 | return; |
928 | } |
929 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
930 | ConcatCPU<T>(ctx->device(), input_tensors_flat, &output_flat); |
931 | } |
932 | } |
933 | |
934 | private: |
935 | DataType dtype_; |
936 | PartialTensorShape element_shape_except0_; |
937 | }; |
938 | |
939 | #define REGISTER_CONCAT(type) \ |
940 | REGISTER_KERNEL_BUILDER(Name("TensorArrayConcat") \ |
941 | .Device(DEVICE_CPU) \ |
942 | .TypeConstraint<type>("dtype") \ |
943 | .HostMemory("lengths") \ |
944 | .HostMemory("handle"), \ |
945 | TensorArrayConcatOp<CPUDevice, type>); \ |
946 | REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV2") \ |
947 | .Device(DEVICE_CPU) \ |
948 | .TypeConstraint<type>("dtype") \ |
949 | .HostMemory("lengths") \ |
950 | .HostMemory("handle"), \ |
951 | TensorArrayConcatOp<CPUDevice, type>) \ |
952 | REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") \ |
953 | .Device(DEVICE_CPU) \ |
954 | .TypeConstraint<type>("dtype") \ |
955 | .HostMemory("lengths") \ |
956 | .HostMemory("handle"), \ |
957 | TensorArrayConcatOp<CPUDevice, type>) |
958 | |
959 | TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT); |
960 | REGISTER_CONCAT(quint8); |
961 | REGISTER_CONCAT(qint8); |
962 | REGISTER_CONCAT(qint32); |
963 | |
964 | #undef REGISTER_CONCAT |
965 | |
966 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
967 | |
968 | #define REGISTER_GPU(type) \ |
969 | REGISTER_KERNEL_BUILDER(Name("TensorArrayConcat") \ |
970 | .Device(DEVICE_GPU) \ |
971 | .TypeConstraint<type>("dtype") \ |
972 | .HostMemory("lengths") \ |
973 | .HostMemory("handle"), \ |
974 | TensorArrayConcatOp<GPUDevice, type>); \ |
975 | REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV2") \ |
976 | .Device(DEVICE_GPU) \ |
977 | .TypeConstraint<type>("dtype") \ |
978 | .HostMemory("lengths") \ |
979 | .HostMemory("handle"), \ |
980 | TensorArrayConcatOp<GPUDevice, type>) \ |
981 | REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") \ |
982 | .Device(DEVICE_GPU) \ |
983 | .TypeConstraint<type>("dtype") \ |
984 | .HostMemory("lengths") \ |
985 | .HostMemory("handle"), \ |
986 | TensorArrayConcatOp<GPUDevice, type>) |
987 | |
988 | TF_CALL_bfloat16(REGISTER_GPU); |
989 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); |
990 | TF_CALL_COMPLEX_TYPES(REGISTER_GPU); |
991 | #undef REGISTER_GPU |
992 | |
993 | // A special GPU kernel for int32. |
994 | // TODO(b/25387198): Also enable int32 in device memory. This kernel |
995 | // registration requires all int32 inputs and outputs to be in host memory. |
996 | REGISTER_KERNEL_BUILDER(Name("TensorArrayConcat" ) |
997 | .Device(DEVICE_GPU) |
998 | .TypeConstraint<int32>("dtype" ) |
999 | .HostMemory("lengths" ) |
1000 | .HostMemory("handle" ), |
1001 | TensorArrayConcatOp<CPUDevice, int32>); |
1002 | REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV2" ) |
1003 | .Device(DEVICE_GPU) |
1004 | .TypeConstraint<int32>("dtype" ) |
1005 | .HostMemory("lengths" ) |
1006 | .HostMemory("handle" ), |
1007 | TensorArrayConcatOp<CPUDevice, int32>); |
1008 | REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3" ) |
1009 | .Device(DEVICE_GPU) |
1010 | .TypeConstraint<int32>("dtype" ) |
1011 | .HostMemory("lengths" ) |
1012 | .HostMemory("handle" ), |
1013 | TensorArrayConcatOp<CPUDevice, int32>); |
1014 | |
1015 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1016 | |
1017 | // UNPACK and SCATTER ********************************************************* |
1018 | |
1019 | template <typename Device, typename T, bool LEGACY_UNPACK> |
1020 | class TensorArrayUnpackOrScatterOp : public OpKernel { |
1021 | public: |
1022 | explicit TensorArrayUnpackOrScatterOp(OpKernelConstruction* context) |
1023 | : OpKernel(context) {} |
1024 | |
1025 | void Compute(OpKernelContext* ctx) override { |
1026 | OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, true)); |
1027 | |
1028 | TensorArray* tensor_array = nullptr; |
1029 | OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array)); |
1030 | core::ScopedUnref unref(tensor_array); |
1031 | const Tensor* tensor_value; |
1032 | OP_REQUIRES_OK(ctx, ctx->input("value" , &tensor_value)); |
1033 | TensorShape element_shape(tensor_value->shape()); |
1034 | |
1035 | OP_REQUIRES(ctx, |
1036 | FastBoundsCheck(element_shape.dim_size(0), |
1037 | std::numeric_limits<int32>::max()), |
1038 | errors::InvalidArgument("tensor dim0 too large to unpack" )); |
1039 | |
1040 | OP_REQUIRES( |
1041 | ctx, tensor_value->dtype() == tensor_array->ElemType(), |
1042 | errors::InvalidArgument("TensorArray dtype is " , |
1043 | DataTypeString(tensor_array->ElemType()), |
1044 | " but Op is trying to write dtype " , |
1045 | DataTypeString(tensor_value->dtype()), "." )); |
1046 | OP_REQUIRES(ctx, element_shape.dims() > 0, |
1047 | errors::InvalidArgument("Input value for unpack must be at " |
1048 | "least a vector but received shape: " , |
1049 | element_shape.DebugString())); |
1050 | int32_t array_size; |
1051 | OP_REQUIRES_OK(ctx, tensor_array->Size(&array_size)); |
1052 | |
1053 | int32_t max_index; |
1054 | int32_t num_values; |
1055 | std::vector<int32> write_indices; |
1056 | if (LEGACY_UNPACK) { |
1057 | num_values = element_shape.dim_size(0); |
1058 | max_index = num_values - 1; |
1059 | write_indices.resize(num_values); |
1060 | std::iota(write_indices.begin(), write_indices.end(), 0); |
1061 | } else { |
1062 | const Tensor* tensor_indices; |
1063 | OP_REQUIRES_OK(ctx, ctx->input("indices" , &tensor_indices)); |
1064 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(tensor_indices->shape()), |
1065 | errors::InvalidArgument( |
1066 | "Expected indices to be a vector, but received shape: " , |
1067 | tensor_indices->shape().DebugString())); |
1068 | OP_REQUIRES(ctx, |
1069 | tensor_indices->NumElements() == element_shape.dim_size(0), |
1070 | errors::InvalidArgument( |
1071 | "Expected len(indices) == values.shape[0], but saw: " , |
1072 | tensor_indices->NumElements(), " vs. " , |
1073 | element_shape.dim_size(0))); |
1074 | const auto indices_t = tensor_indices->vec<int32>(); |
1075 | num_values = tensor_indices->NumElements(); |
1076 | max_index = (num_values == 0) |
1077 | ? -1 |
1078 | : *std::max_element(indices_t.data(), |
1079 | indices_t.data() + num_values); |
1080 | write_indices.resize(num_values); |
1081 | // Copy into write_indices. |
1082 | std::copy(indices_t.data(), indices_t.data() + num_values, |
1083 | write_indices.begin()); |
1084 | } |
1085 | |
1086 | bool dynamic_size = tensor_array->HasDynamicSize(); |
1087 | |
1088 | // If dynamic size, we may have to resize the TensorArray to fit. |
1089 | if (dynamic_size && array_size < max_index + 1) { |
1090 | array_size = static_cast<int32>(max_index + 1); |
1091 | } |
1092 | |
1093 | if (LEGACY_UNPACK) { |
1094 | OP_REQUIRES( |
1095 | ctx, element_shape.dim_size(0) == array_size, |
1096 | errors::InvalidArgument( |
1097 | "Input value must have first dimension equal to the array size (" , |
1098 | element_shape.dim_size(0), " vs. " , array_size, ")" )); |
1099 | } else { |
1100 | OP_REQUIRES( |
1101 | ctx, max_index < array_size, |
1102 | errors::InvalidArgument("Max scatter index must be < array size (" , |
1103 | max_index, " vs. " , array_size, ")" )); |
1104 | } |
1105 | element_shape.RemoveDim(0); |
1106 | |
1107 | auto tensor_value_t = tensor_value->shaped<T, 3>( |
1108 | {1, num_values, element_shape.num_elements()}); |
1109 | |
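// The input value is viewed as a [1, num_values, element_size] cuboid, where
// element_size is the number of elements in one array entry. Each entry i is
// carved out by the Split functor below as the [1, 1, element_size] slice at
// offset i along the middle dimension.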
1110 | Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0}; |
1111 | Eigen::DSizes<Eigen::DenseIndex, 3> sizes{ |
1112 | 1, 1, static_cast<Eigen::DenseIndex>(element_shape.num_elements())}; |
1113 | |
1114 | std::vector<Tensor> write_values; |
1115 | write_values.reserve(num_values); |
1116 | |
1117 | for (int i = 0; i < num_values; ++i) { |
1118 | Tensor tensor_value_i; |
1119 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensor_array->ElemType(), |
1120 | element_shape, &tensor_value_i)); |
1121 | auto tensor_value_i_t = |
1122 | tensor_value_i.shaped<T, 3>({1, 1, element_shape.num_elements()}); |
1123 | indices[1] = i; |
1124 | |
1125 | if (element_shape.num_elements() > 0) { |
1126 | functor::Split<Device, T, 3>()(ctx->eigen_device<Device>(), |
1127 | tensor_value_i_t, tensor_value_t, |
1128 | indices, sizes); |
1129 | } |
1130 | |
1131 | write_values.push_back(tensor_value_i); |
1132 | } |
1133 | |
1134 | // Record the pack size of the TensorArray. |
1135 | if (LEGACY_UNPACK) { |
1136 | OP_REQUIRES_OK(ctx, tensor_array->SetMarkedSize(array_size)); |
1137 | } |
1138 | |
1139 | Status s = tensor_array->WriteOrAggregateMany<Device, T>(ctx, write_indices, |
1140 | &write_values); |
1141 | OP_REQUIRES_OK(ctx, s); |
1142 | } |
1143 | }; |
1144 | |
1145 | #define REGISTER_SCATTER_AND_UNPACK(type) \ |
1146 | REGISTER_KERNEL_BUILDER( \ |
1147 | Name("TensorArrayUnpack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
1148 | TensorArrayUnpackOrScatterOp<CPUDevice, type, \ |
1149 | true /* LEGACY_UNPACK */>); \ |
1150 | REGISTER_KERNEL_BUILDER( \ |
1151 | Name("TensorArrayScatter").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
1152 | TensorArrayUnpackOrScatterOp<CPUDevice, type, \ |
1153 | false /* LEGACY_UNPACK */>); \ |
1154 | REGISTER_KERNEL_BUILDER( \ |
1155 | Name("TensorArrayScatterV2") \ |
1156 | .Device(DEVICE_CPU) \ |
1157 | .TypeConstraint<type>("T"), \ |
1158 | TensorArrayUnpackOrScatterOp<CPUDevice, type, \ |
1159 | false /* LEGACY_UNPACK */>); \ |
1160 | REGISTER_KERNEL_BUILDER( \ |
1161 | Name("TensorArrayScatterV3") \ |
1162 | .Device(DEVICE_CPU) \ |
1163 | .TypeConstraint<type>("T"), \ |
1164 | TensorArrayUnpackOrScatterOp<CPUDevice, type, \ |
1165 | false /* LEGACY_UNPACK */>); |
1166 | |
1167 | TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK); |
1168 | #undef REGISTER_SCATTER_AND_UNPACK |
1169 | |
1170 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1171 | |
1172 | #define REGISTER_GPU(type) \ |
1173 | REGISTER_KERNEL_BUILDER( \ |
1174 | Name("TensorArrayUnpack") \ |
1175 | .Device(DEVICE_GPU) \ |
1176 | .TypeConstraint<type>("T") \ |
1177 | .HostMemory("handle"), \ |
1178 | TensorArrayUnpackOrScatterOp<GPUDevice, type, \ |
1179 | true /* LEGACY_UNPACK */>); \ |
1180 | REGISTER_KERNEL_BUILDER( \ |
1181 | Name("TensorArrayScatter") \ |
1182 | .Device(DEVICE_GPU) \ |
1183 | .TypeConstraint<type>("T") \ |
1184 | .HostMemory("indices") \ |
1185 | .HostMemory("handle"), \ |
1186 | TensorArrayUnpackOrScatterOp<GPUDevice, type, \ |
1187 | false /* LEGACY_UNPACK */>); \ |
1188 | REGISTER_KERNEL_BUILDER( \ |
1189 | Name("TensorArrayScatterV2") \ |
1190 | .Device(DEVICE_GPU) \ |
1191 | .TypeConstraint<type>("T") \ |
1192 | .HostMemory("indices") \ |
1193 | .HostMemory("handle"), \ |
1194 | TensorArrayUnpackOrScatterOp<GPUDevice, type, \ |
1195 | false /* LEGACY_UNPACK */>); \ |
1196 | REGISTER_KERNEL_BUILDER( \ |
1197 | Name("TensorArrayScatterV3") \ |
1198 | .Device(DEVICE_GPU) \ |
1199 | .TypeConstraint<type>("T") \ |
1200 | .HostMemory("indices") \ |
1201 | .HostMemory("handle"), \ |
1202 | TensorArrayUnpackOrScatterOp<GPUDevice, type, \ |
1203 | false /* LEGACY_UNPACK */>); |
1204 | |
1205 | TF_CALL_int64(REGISTER_GPU); |
1206 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); |
1207 | TF_CALL_COMPLEX_TYPES(REGISTER_GPU); |
1208 | #undef REGISTER_GPU |
1209 | |
1210 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1211 | |
1212 | // SPLIT ********************************************************************* |
1213 | |
1214 | template <typename Device, typename T> |
1215 | class TensorArraySplitOp : public OpKernel { |
1216 | public: |
1217 | explicit TensorArraySplitOp(OpKernelConstruction* context) |
1218 | : OpKernel(context) {} |
1219 | |
1220 | void Compute(OpKernelContext* ctx) override { |
1221 | OP_REQUIRES_OK(ctx, SetupFlowControlInputs(ctx, true)); |
1222 | |
1223 | TensorArray* tensor_array = nullptr; |
1224 | OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array)); |
1225 | core::ScopedUnref unref(tensor_array); |
1226 | const Tensor* tensor_value; |
1227 | OP_REQUIRES_OK(ctx, ctx->input("value" , &tensor_value)); |
1228 | const Tensor* tensor_lengths; |
1229 | OP_REQUIRES_OK(ctx, ctx->input("lengths" , &tensor_lengths)); |
1230 | |
1231 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(tensor_lengths->shape()), |
1232 | errors::InvalidArgument( |
1233 | "Expected lengths to be a vector, received shape: " , |
1234 | tensor_lengths->shape().DebugString())); |
1235 | OP_REQUIRES(ctx, |
1236 | FastBoundsCheck(tensor_lengths->NumElements(), |
1237 | std::numeric_limits<int32>::max()), |
1238 | errors::InvalidArgument( |
1239 | "Expected lengths to have < max int32 entries" )); |
1240 | |
1241 | int32_t num_tensors = static_cast<int32>(tensor_lengths->NumElements()); |
1242 | auto tensor_lengths_t = tensor_lengths->vec<int64_t>(); |
1243 | std::vector<int64_t> cumulative_lengths; |
1244 | cumulative_lengths.reserve(num_tensors); |
1245 | int64_t total_length = 0; |
1246 | for (int i = 0; i < num_tensors; ++i) { |
1247 | total_length += tensor_lengths_t(i); |
1248 | cumulative_lengths.push_back(total_length); |
1249 | } |
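// For example, lengths = [2, 3, 1] gives cumulative_lengths = [2, 5, 6] and
// total_length = 6; element i then covers rows
// [cumulative_lengths[i-1], cumulative_lengths[i]) of the value tensor.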
1250 | |
1251 | OP_REQUIRES( |
1252 | ctx, TensorShapeUtils::IsVectorOrHigher(tensor_value->shape()), |
1253 | errors::InvalidArgument( |
1254 | "Expected value to be at least a vector, but received shape: " , |
1255 | tensor_value->shape().DebugString())); |
1256 | |
1257 | OP_REQUIRES( |
1258 | ctx, total_length == tensor_value->shape().dim_size(0), |
1259 | errors::InvalidArgument("Expected sum of lengths to be equal to " |
1260 | "values.shape[0], but sum of lengths is " , |
1261 | total_length, " and value's shape is: " , |
1262 | tensor_value->shape().DebugString())); |
1263 | int64_t elements_per_row = |
1264 | (total_length == 0) ? 0 : (tensor_value->NumElements() / total_length); |
1265 | |
1266 | int32_t array_size; |
1267 | OP_REQUIRES_OK(ctx, tensor_array->Size(&array_size)); |
1268 | bool dynamic_size = tensor_array->HasDynamicSize(); |
1269 | |
1270 | std::vector<TensorShape> element_shapes(num_tensors, tensor_value->shape()); |
1271 | for (int32_t i = 0; i < num_tensors; ++i) { |
1272 | element_shapes[i].set_dim(0, tensor_lengths_t(i)); |
1273 | } |
1274 | |
1275 | // If dynamic size, we may have to resize the TensorArray to fit. |
1276 | if (dynamic_size && array_size < num_tensors) { |
1277 | array_size = num_tensors; |
1278 | } |
1279 | |
1280 | OP_REQUIRES( |
1281 | ctx, array_size == num_tensors, |
1282 | errors::InvalidArgument( |
1283 | "TensorArray's size is not equal to the size of lengths (" , |
1284 | array_size, " vs. " , num_tensors, "), and the TensorArray is not " , |
1285 | "marked as dynamically resizeable" )); |
1286 | |
1287 | OP_REQUIRES( |
1288 | ctx, tensor_value->dtype() == tensor_array->ElemType(), |
1289 | errors::InvalidArgument("TensorArray dtype is " , |
1290 | DataTypeString(tensor_array->ElemType()), |
1291 | " but Op is trying to write dtype " , |
1292 | DataTypeString(tensor_value->dtype()), "." )); |
1293 | |
1294 | auto tensor_value_t = |
1295 | tensor_value->shaped<T, 3>({1, total_length, elements_per_row}); |
1296 | |
1297 | std::vector<Tensor> write_values; |
1298 | write_values.reserve(array_size); |
1299 | |
1300 | for (int i = 0; i < array_size; ++i) { |
1301 | Tensor tensor_value_i; |
1302 | |
1303 | int64_t previous_length = (i == 0) ? 0 : cumulative_lengths[i - 1]; |
1304 | Eigen::DSizes<Eigen::DenseIndex, 3> indices{ |
1305 | 0, static_cast<Eigen::DenseIndex>(previous_length), 0}; |
1306 | Eigen::DSizes<Eigen::DenseIndex, 3> sizes{ |
1307 | 1, static_cast<Eigen::DenseIndex>(tensor_lengths_t(i)), |
1308 | static_cast<Eigen::DenseIndex>(elements_per_row)}; |
1309 | |
1310 | OP_REQUIRES_OK( |
1311 | ctx, ctx->allocate_temp(tensor_array->ElemType(), element_shapes[i], |
1312 | &tensor_value_i)); |
1313 | |
1314 | if (tensor_lengths_t(i) > 0) { |
1315 | auto tensor_value_i_t = tensor_value_i.shaped<T, 3>( |
1316 | {1, tensor_lengths_t(i), elements_per_row}); |
1317 | |
1318 | functor::Split<Device, T, 3>()(ctx->eigen_device<Device>(), |
1319 | tensor_value_i_t, tensor_value_t, |
1320 | indices, sizes); |
1321 | } |
1322 | |
1323 | write_values.push_back(tensor_value_i); |
1324 | } |
1325 | |
1326 | // Record the concat size of the TensorArray. |
1327 | OP_REQUIRES_OK(ctx, tensor_array->SetMarkedSize(array_size)); |
1328 | |
1329 | std::vector<int32> indices(array_size); |
1330 | std::iota(indices.begin(), indices.end(), 0); |
1331 | |
1332 | Status s = tensor_array->WriteOrAggregateMany<Device, T>(ctx, indices, |
1333 | &write_values); |
1334 | OP_REQUIRES_OK(ctx, s); |
1335 | } |
1336 | }; |
1337 | |
1338 | #define REGISTER_SPLIT(type) \ |
1339 | REGISTER_KERNEL_BUILDER( \ |
1340 | Name("TensorArraySplit").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
1341 | TensorArraySplitOp<CPUDevice, type>); \ |
1342 | REGISTER_KERNEL_BUILDER( \ |
1343 | Name("TensorArraySplitV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
1344 | TensorArraySplitOp<CPUDevice, type>); \ |
1345 | REGISTER_KERNEL_BUILDER( \ |
1346 | Name("TensorArraySplitV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
1347 | TensorArraySplitOp<CPUDevice, type>); |
1348 | |
1349 | TF_CALL_ALL_TYPES(REGISTER_SPLIT); |
1350 | #undef REGISTER_SPLIT |
1351 | |
1352 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1353 | |
1354 | #define REGISTER_GPU(type) \ |
1355 | REGISTER_KERNEL_BUILDER(Name("TensorArraySplit") \ |
1356 | .Device(DEVICE_GPU) \ |
1357 | .TypeConstraint<type>("T") \ |
1358 | .HostMemory("lengths") \ |
1359 | .HostMemory("handle"), \ |
1360 | TensorArraySplitOp<GPUDevice, type>); \ |
1361 | REGISTER_KERNEL_BUILDER(Name("TensorArraySplitV2") \ |
1362 | .Device(DEVICE_GPU) \ |
1363 | .TypeConstraint<type>("T") \ |
1364 | .HostMemory("lengths") \ |
1365 | .HostMemory("handle"), \ |
1366 | TensorArraySplitOp<GPUDevice, type>); \ |
1367 | REGISTER_KERNEL_BUILDER(Name("TensorArraySplitV3") \ |
1368 | .Device(DEVICE_GPU) \ |
1369 | .TypeConstraint<type>("T") \ |
1370 | .HostMemory("lengths") \ |
1371 | .HostMemory("handle"), \ |
1372 | TensorArraySplitOp<GPUDevice, type>); |
1373 | |
1374 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); |
1375 | TF_CALL_COMPLEX_TYPES(REGISTER_GPU); |
1376 | #undef REGISTER_GPU |
1377 | |
1378 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
1379 | |
1380 | // SIZE *********************************************************************** |
1381 | |
1382 | // Get the size of the TensorArray |
1383 | class TensorArraySizeOp : public OpKernel { |
1384 | public: |
1385 | explicit TensorArraySizeOp(OpKernelConstruction* context) |
1386 | : OpKernel(context) {} |
1387 | |
1388 | void Compute(OpKernelContext* ctx) override { |
1389 | TensorArray* tensor_array; |
1390 | OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array)); |
1391 | core::ScopedUnref unref(tensor_array); |
1392 | Tensor* output = nullptr; |
1393 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); |
1394 | OP_REQUIRES_OK(ctx, tensor_array->Size(&(output->scalar<int32>()()))); |
1395 | } |
1396 | }; |
1397 | |
1398 | REGISTER_KERNEL_BUILDER(Name("TensorArraySize" ).Device(DEVICE_CPU), |
1399 | TensorArraySizeOp); |
1400 | REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2" ).Device(DEVICE_CPU), |
1401 | TensorArraySizeOp); |
1402 | REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3" ).Device(DEVICE_CPU), |
1403 | TensorArraySizeOp); |
1404 | |
1405 | REGISTER_KERNEL_BUILDER(Name("TensorArraySize" ) |
1406 | .Device(DEVICE_GPU) |
1407 | .HostMemory("handle" ) |
1408 | .HostMemory("size" ), |
1409 | TensorArraySizeOp); |
1410 | REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2" ) |
1411 | .Device(DEVICE_GPU) |
1412 | .HostMemory("handle" ) |
1413 | .HostMemory("size" ), |
1414 | TensorArraySizeOp); |
1415 | REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3" ) |
1416 | .Device(DEVICE_GPU) |
1417 | .HostMemory("handle" ) |
1418 | .HostMemory("size" ), |
1419 | TensorArraySizeOp); |
1420 | |
// CLOSE **********************************************************************
1423 | |
1424 | // Delete the TensorArray from its resource container. This enables |
1425 | // the user to close and release the resource in the middle of a step/run. |
1426 | // TODO(ebrevdo): decide whether closing the grad op should happen |
1427 | // here or on the python side. |
1428 | class TensorArrayCloseOp : public OpKernel { |
1429 | public: |
1430 | explicit TensorArrayCloseOp(OpKernelConstruction* context) |
1431 | : OpKernel(context) {} |
1432 | |
1433 | void Compute(OpKernelContext* ctx) override { |
1434 | TensorArray* tensor_array; |
1435 | OP_REQUIRES_OK(ctx, GetTensorArray(ctx, &tensor_array)); |
1436 | core::ScopedUnref unref(tensor_array); |
// Instead of deleting this TA from the ResourceManager, we just
// clear it away and mark it as closed. The remaining memory it
// consumes is just its mutex and handle Tensor. This will be
// cleared out at the end of the step anyway, so it's fine to keep
// it around until the end of the step. Further calls to the
// TensorArray will fail because the TensorArray checks internally
// whether it is closed or not.
1444 | tensor_array->ClearAndMarkClosed(); |
1445 | } |
1446 | }; |
1447 | |
1448 | REGISTER_KERNEL_BUILDER(Name("TensorArrayClose" ).Device(DEVICE_CPU), |
1449 | TensorArrayCloseOp); |
1450 | REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV2" ).Device(DEVICE_CPU), |
1451 | TensorArrayCloseOp); |
1452 | REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV3" ).Device(DEVICE_CPU), |
1453 | TensorArrayCloseOp); |
1454 | |
1455 | REGISTER_KERNEL_BUILDER( |
1456 | Name("TensorArrayClose" ).Device(DEVICE_GPU).HostMemory("handle" ), |
1457 | TensorArrayCloseOp); |
1458 | REGISTER_KERNEL_BUILDER( |
1459 | Name("TensorArrayCloseV2" ).Device(DEVICE_GPU).HostMemory("handle" ), |
1460 | TensorArrayCloseOp); |
1461 | REGISTER_KERNEL_BUILDER( |
1462 | Name("TensorArrayCloseV3" ).Device(DEVICE_GPU).HostMemory("handle" ), |
1463 | TensorArrayCloseOp); |
1464 | |
1465 | } // namespace tensorflow |
1466 | |