/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

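// EIGEN_USE_THREADS and EIGEN_USE_GPU must be defined before the Eigen
// Tensor headers (pulled in via scan_ops.h) so that the ThreadPoolDevice and
// GpuDevice backends used below are available.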
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/scan_ops.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

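// Implements cumulative reductions along one axis of the input; the Reducer
// selects the op (Cumsum, Cumprod, CumulativeLogsumexp). For example, an
// inclusive cumulative sum of [1, 2, 3] is [1, 3, 6]; with exclusive=true it
// is [0, 1, 3]; with reverse=true it is [6, 5, 3].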
template <typename Device, class T, typename Reducer, typename Tidx>
class ScanOp : public OpKernel {
 public:
  explicit ScanOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    const Tensor& tensor_axis = ctx->input(1);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis.shape()),
                errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
                                        tensor_axis.shape().DebugString()));

    const Tidx axis_arg =
        internal::SubtleMustCopy(tensor_axis.scalar<Tidx>()());
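    // Negative axes index from the end, e.g. axis_arg == -1 with a rank-3
    // input normalizes to axis == 2.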
    const Tidx axis = (axis_arg < 0) ? input.dims() + axis_arg : axis_arg;
    OP_REQUIRES(ctx, FastBoundsCheck(axis, input.dims()),
                errors::InvalidArgument(
                    "ScanOp: Expected scan axis in the range [", -input.dims(),
                    ", ", input.dims(), "), but got ", axis_arg));

    const TensorShape& output_shape = input.shape();
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));

    // Exit early if there's nothing to compute.
    if (output_shape.num_elements() == 0) return;

    const Device& d = ctx->eigen_device<Device>();
    Reducer reducer;

    // Collapse the input to a rank-3 view: the dimensions before the scan
    // axis, the scan axis itself, and the dimensions after it.
    int64_t reduced_shape[3] = {1, 1, 1};
    for (Tidx i = 0; i < axis; ++i) {
      reduced_shape[0] *= input.dim_size(i);
    }
    reduced_shape[1] = input.dim_size(axis);
    for (Tidx i = axis + 1; i < input.dims(); ++i) {
      reduced_shape[2] *= input.dim_size(i);
    }
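    // e.g., a hypothetical input of shape [2, 3, 4, 5] scanned along axis 2
    // collapses to reduced_shape == {6, 4, 5}; the scan then runs along the
    // middle dimension.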

    functor::Scan<Device, Reducer, T>()(d, input.shaped<T, 3>(reduced_shape),
                                        output->shaped<T, 3>(reduced_shape),
                                        reducer, reverse_, exclusive_);
  }

 private:
  bool reverse_;
  bool exclusive_;
};
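
// For reference: the device-specific functor::Scan (defined in scan_ops.h
// and, for GPU, in the separately compiled .cu.cc files) is expected to
// evaluate roughly
//   out.device(d) = in.reverse(dims).scan(1, reducer, exclusive).reverse(dims)
// with dims == {false, reverse, false}, i.e. the scan always runs along the
// collapsed middle dimension.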

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {

// Forward declarations of the GPU functor specializations. Their definitions
// are compiled separately in the CUDA/ROCm translation units; the extern
// template declarations suppress implicit instantiation here.
#define DECLARE(REDUCER, T)                                                 \
  template <>                                                               \
  void Scan<GPUDevice, REDUCER, T>::operator()(                             \
      const GPUDevice& d, TTypes<T, 3>::ConstTensor in,                     \
      TTypes<T, 3>::Tensor out, const REDUCER& reducer, const bool reverse, \
      const bool exclusive);                                                \
  extern template struct Scan<GPUDevice, REDUCER, T>;

#define DECLARE_FOR_ALL_REDUCERS(T)           \
  DECLARE(Eigen::internal::SumReducer<T>, T); \
  DECLARE(Eigen::internal::ProdReducer<T>, T);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
DECLARE_FOR_ALL_REDUCERS(int32);
DECLARE_FOR_ALL_REDUCERS(int64_t);
#undef DECLARE_FOR_ALL_REDUCERS

#define DECLARE_FOR_LOGSUMEXP_REDUCER(T) DECLARE(LogSumExpReducer<T>, T);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_LOGSUMEXP_REDUCER)
#undef DECLARE_FOR_LOGSUMEXP_REDUCER

#undef DECLARE

}  // namespace functor
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Register Cumsum kernels.
#define REGISTER_CPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int32>("Tidx"),                                \
      ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int64_t>("Tidx"),                              \
      ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int64_t>)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
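// On GPU, the axis input is pinned to host memory (HostMemory("axis")):
// Compute() reads the scalar on the host to normalize the axis and build the
// collapsed shape before launching the device scan.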
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_GPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int32>("Tidx")                                 \
          .HostMemory("axis"),                                           \
      ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_GPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int64_t>("Tidx")                               \
          .HostMemory("axis"),                                           \
      ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int64_t>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
REGISTER_GPU_KERNELS(int32);
REGISTER_GPU_KERNELS(int64_t);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Register Cumprod kernels.
#define REGISTER_CPU_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int32>("Tidx"),                                 \
      ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int64_t>("Tidx"),                               \
      ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int64_t>)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_GPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int32>("Tidx")                                  \
          .HostMemory("axis"),                                            \
      ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_GPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int64_t>("Tidx")                                \
          .HostMemory("axis"),                                            \
      ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int64_t>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
REGISTER_GPU_KERNELS(int32);
REGISTER_GPU_KERNELS(int64_t);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Register CumulativeLogsumexp kernels.
#define REGISTER_CUMLOGSUMEXP_KERNEL(device, device_type, type, type_idx) \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("CumulativeLogsumexp")                                          \
          .Device(device)                                                  \
          .TypeConstraint<type>("T")                                       \
          .TypeConstraint<type_idx>("Tidx")                                \
          .HostMemory("axis"),                                             \
      ScanOp<device_type, type, functor::LogSumExpReducer<type>, type_idx>)
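
// Note: HostMemory("axis") is applied for both devices here; it is required
// on GPU and redundant (but harmless) on CPU, where tensors already live in
// host memory.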

#define REGISTER_CPU_KERNELS(type)                                 \
  REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_CPU, CPUDevice, type, int32) \
  REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_CPU, CPUDevice, type, int64_t)

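// LogSumExp is only meaningful for floating-point types, so even the CPU
// registrations below use TF_CALL_GPU_NUMBER_TYPES (the floating-point
// types) rather than TF_CALL_NUMBER_TYPES.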
TF_CALL_GPU_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                                 \
  REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_GPU, GPUDevice, type, int32) \
  REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_GPU, GPUDevice, type, int64_t)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_CUMLOGSUMEXP_KERNEL

}  // namespace tensorflow