1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/data_flow_ops.cc. |
17 | |
18 | #include "tensorflow/core/kernels/stack.h" |
19 | |
20 | #include <limits.h> |
21 | #include <atomic> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/core/common_runtime/device.h" |
25 | #include "tensorflow/core/framework/device_base.h" |
26 | #include "tensorflow/core/framework/op_kernel.h" |
27 | #include "tensorflow/core/framework/register_types.h" |
28 | #include "tensorflow/core/framework/resource_mgr.h" |
29 | #include "tensorflow/core/framework/tensor.h" |
30 | #include "tensorflow/core/framework/tensor_shape.h" |
31 | #include "tensorflow/core/framework/types.h" |
32 | #include "tensorflow/core/lib/core/errors.h" |
33 | #include "tensorflow/core/lib/core/refcount.h" |
34 | #include "tensorflow/core/lib/gtl/map_util.h" |
35 | #include "tensorflow/core/platform/logging.h" |
36 | #include "tensorflow/core/platform/macros.h" |
37 | #include "tensorflow/core/platform/mutex.h" |
38 | #include "tensorflow/core/platform/thread_annotations.h" |
39 | #include "tensorflow/core/platform/types.h" |
40 | |
41 | namespace tensorflow { |
42 | |
43 | REGISTER_KERNEL_BUILDER(Name("Stack" ).Device(DEVICE_CPU), StackOp); |
44 | REGISTER_KERNEL_BUILDER( |
45 | Name("Stack" ).Device(DEVICE_DEFAULT).HostMemory("handle" ), StackOp); |
46 | |
47 | REGISTER_KERNEL_BUILDER(Name("StackV2" ).Device(DEVICE_CPU), StackOp); |
48 | REGISTER_KERNEL_BUILDER(Name("StackV2" ) |
49 | .Device(DEVICE_DEFAULT) |
50 | .HostMemory("max_size" ) |
51 | .HostMemory("handle" ), |
52 | StackOp); |
53 | |
54 | REGISTER_KERNEL_BUILDER(Name("StackPush" ).Device(DEVICE_CPU), |
55 | TemplatedStackPushOp</*allow_swapping=*/false>); |
56 | REGISTER_KERNEL_BUILDER(Name("StackPushV2" ).Device(DEVICE_CPU), |
57 | TemplatedStackPushOp</*allow_swapping=*/false>); |
58 | |
59 | REGISTER_KERNEL_BUILDER(Name("StackPop" ).Device(DEVICE_CPU), StackPopOp); |
60 | REGISTER_KERNEL_BUILDER(Name("StackPopV2" ).Device(DEVICE_CPU), StackPopOp); |
61 | |
62 | #define REGISTER_DEFAULT_KERNEL(type) \ |
63 | REGISTER_KERNEL_BUILDER(Name("StackPush") \ |
64 | .Device(DEVICE_DEFAULT) \ |
65 | .HostMemory("handle") \ |
66 | .TypeConstraint<type>("T"), \ |
67 | TemplatedStackPushOp</*allow_swapping=*/true>); \ |
68 | REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ |
69 | .Device(DEVICE_DEFAULT) \ |
70 | .HostMemory("handle") \ |
71 | .TypeConstraint<type>("T"), \ |
72 | TemplatedStackPushOp</*allow_swapping=*/true>); \ |
73 | REGISTER_KERNEL_BUILDER(Name("StackPop") \ |
74 | .Device(DEVICE_DEFAULT) \ |
75 | .HostMemory("handle") \ |
76 | .TypeConstraint<type>("elem_type"), \ |
77 | StackPopOp); \ |
78 | REGISTER_KERNEL_BUILDER(Name("StackPopV2") \ |
79 | .Device(DEVICE_DEFAULT) \ |
80 | .HostMemory("handle") \ |
81 | .TypeConstraint<type>("elem_type"), \ |
82 | StackPopOp); |
83 | |
84 | TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_DEFAULT_KERNEL); |
85 | #undef REGISTER_DEFAULT_KERNEL |
86 | |
87 | // Special GPU kernels for int32 and bool. |
88 | // TODO(b/25387198): Also enable int32 in device memory. This kernel |
89 | // registration requires all int32 inputs and outputs to be in host memory. |
90 | #define REGISTER_DEFAULT_HOST_KERNEL(type) \ |
91 | REGISTER_KERNEL_BUILDER(Name("StackPush") \ |
92 | .Device(DEVICE_DEFAULT) \ |
93 | .HostMemory("handle") \ |
94 | .HostMemory("elem") \ |
95 | .HostMemory("output") \ |
96 | .TypeConstraint<type>("T"), \ |
97 | TemplatedStackPushOp</*allow_swapping=*/true>); \ |
98 | REGISTER_KERNEL_BUILDER(Name("StackPushV2") \ |
99 | .Device(DEVICE_DEFAULT) \ |
100 | .HostMemory("handle") \ |
101 | .HostMemory("elem") \ |
102 | .HostMemory("output") \ |
103 | .TypeConstraint<type>("T"), \ |
104 | TemplatedStackPushOp</*allow_swapping=*/true>); \ |
105 | REGISTER_KERNEL_BUILDER(Name("StackPop") \ |
106 | .Device(DEVICE_DEFAULT) \ |
107 | .HostMemory("handle") \ |
108 | .HostMemory("elem") \ |
109 | .TypeConstraint<type>("elem_type"), \ |
110 | StackPopOp); \ |
111 | REGISTER_KERNEL_BUILDER(Name("StackPopV2") \ |
112 | .Device(DEVICE_DEFAULT) \ |
113 | .HostMemory("handle") \ |
114 | .HostMemory("elem") \ |
115 | .TypeConstraint<type>("elem_type"), \ |
116 | StackPopOp); |
117 | |
118 | REGISTER_DEFAULT_HOST_KERNEL(int32); |
119 | REGISTER_DEFAULT_HOST_KERNEL(bool); |
120 | |
121 | #undef REGISTER_DEFAULT_HOST_KERNEL |
122 | |
123 | REGISTER_KERNEL_BUILDER(Name("StackClose" ).Device(DEVICE_CPU), StackCloseOp); |
124 | REGISTER_KERNEL_BUILDER( |
125 | Name("StackClose" ).Device(DEVICE_DEFAULT).HostMemory("handle" ), |
126 | StackCloseOp); |
127 | REGISTER_KERNEL_BUILDER(Name("StackCloseV2" ).Device(DEVICE_CPU), StackCloseOp); |
128 | REGISTER_KERNEL_BUILDER( |
129 | Name("StackCloseV2" ).Device(DEVICE_DEFAULT).HostMemory("handle" ), |
130 | StackCloseOp); |
131 | |
132 | } // namespace tensorflow |
133 | |