/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/scan_ops.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

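// Implements the Cumsum/Cumprod/CumulativeLogsumexp kernels: a cumulative
// scan of the input along a runtime-supplied axis, using `Reducer` as the
// binary operation. The `reverse` attribute scans from the end of the axis
// toward the beginning; `exclusive` shifts the result so that output element
// i does not include input element i.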
template <typename Device, class T, typename Reducer, typename Tidx>
class ScanOp : public OpKernel {
 public:
  explicit ScanOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    const Tensor& tensor_axis = ctx->input(1);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis.shape()),
                errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
                                        tensor_axis.shape().DebugString()));

    const Tidx axis_arg =
        internal::SubtleMustCopy(tensor_axis.scalar<Tidx>()());
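    // A negative axis counts back from the last dimension, as in Python
    // indexing, so it is folded into the [0, dims) range before validation.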
    const Tidx axis = (axis_arg < 0) ? input.dims() + axis_arg : axis_arg;
    OP_REQUIRES(ctx, FastBoundsCheck(axis, input.dims()),
                errors::InvalidArgument(
                    "ScanOp: Expected scan axis in the range [", -input.dims(),
                    ", ", input.dims(), "), but got ", axis));

    const TensorShape& output_shape = input.shape();
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));

    // Exit early if there's nothing to compute
    if (output_shape.num_elements() == 0) return;

    const Device& d = ctx->eigen_device<Device>();
    Reducer reducer;

    // Collapse the input into a 3D shape [pre, scan_axis, post] so the scan
    // can run over the middle dimension regardless of the input rank.
    int64_t reduced_shape[3] = {1, 1, 1};
    for (Tidx i = 0; i < axis; ++i) {
      reduced_shape[0] *= input.dim_size(i);
    }
    reduced_shape[1] = input.dim_size(axis);
    for (Tidx i = axis + 1; i < input.dims(); ++i) {
      reduced_shape[2] *= input.dim_size(i);
    }

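    // Dispatch to the device-specific Scan functor, which performs the
    // (optionally reversed and/or exclusive) cumulative reduction along
    // dimension 1 of the reshaped views.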
    functor::Scan<Device, Reducer, T>()(d, input.shaped<T, 3>(reduced_shape),
                                        output->shaped<T, 3>(reduced_shape),
                                        reducer, reverse_, exclusive_);
  }

 private:
  bool reverse_;
  bool exclusive_;
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {

// Forward declarations of the GPU functor specializations. Their definitions
// live in the GPU (.cu.cc) translation units; `extern template` keeps this
// file from instantiating them here.
#define DECLARE(REDUCER, T)                                                 \
  template <>                                                               \
  void Scan<GPUDevice, REDUCER, T>::operator()(                             \
      const GPUDevice& d, TTypes<T, 3>::ConstTensor in,                     \
      TTypes<T, 3>::Tensor out, const REDUCER& reducer, const bool reverse, \
      const bool exclusive);                                                \
  extern template struct Scan<GPUDevice, REDUCER, T>;

#define DECLARE_FOR_ALL_REDUCERS(T)           \
  DECLARE(Eigen::internal::SumReducer<T>, T); \
  DECLARE(Eigen::internal::ProdReducer<T>, T);

TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
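// TF_CALL_GPU_NUMBER_TYPES only covers floating-point types, so the integer
// instantiations are declared explicitly.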
DECLARE_FOR_ALL_REDUCERS(int32);
DECLARE_FOR_ALL_REDUCERS(int64_t);
#undef DECLARE_FOR_ALL_REDUCERS

#define DECLARE_FOR_LOGSUMEXP_REDUCER(T) DECLARE(LogSumExpReducer<T>, T);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_LOGSUMEXP_REDUCER)
#undef DECLARE_FOR_LOGSUMEXP_REDUCER

#undef DECLARE

}  // namespace functor
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Register Cumsum kernels
#define REGISTER_CPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int32>("Tidx"),                                \
      ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int64_t>("Tidx"),                              \
      ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
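// On GPU the `axis` scalar is pinned to host memory (HostMemory("axis")) so
// Compute can read it directly without a device-to-host copy.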
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_GPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int32>("Tidx")                                 \
          .HostMemory("axis"),                                           \
      ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_GPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int64_t>("Tidx")                               \
          .HostMemory("axis"),                                           \
      ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
REGISTER_GPU_KERNELS(int32);
REGISTER_GPU_KERNELS(int64_t);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Register Cumprod kernels
#define REGISTER_CPU_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int32>("Tidx"),                                 \
      ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int64_t>("Tidx"),                               \
      ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_GPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int32>("Tidx")                                  \
          .HostMemory("axis"),                                            \
      ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_GPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int64_t>("Tidx")                                \
          .HostMemory("axis"),                                            \
      ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
REGISTER_GPU_KERNELS(int32);
REGISTER_GPU_KERNELS(int64_t);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

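// Register CumulativeLogsumexp kernels. The LogSumExpReducer accumulates
// log(sum(exp(x))) in a numerically stable way.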
#define REGISTER_CUMLOGSUMEXP_KERNEL(device, device_type, type, type_idx) \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("CumulativeLogsumexp")                                          \
          .Device(device)                                                  \
          .TypeConstraint<type>("T")                                       \
          .TypeConstraint<type_idx>("Tidx")                                \
          .HostMemory("axis"),                                             \
      ScanOp<device_type, type, functor::LogSumExpReducer<type>, type_idx>)

#define REGISTER_CPU_KERNELS(type)                                 \
  REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_CPU, CPUDevice, type, int32) \
  REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_CPU, CPUDevice, type, int64_t)

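// The logsumexp scan is only defined for floating-point types, so the GPU
// number type list is reused for the CPU registrations as well.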
TF_CALL_GPU_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type)                                 \
  REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_GPU, GPUDevice, type, int32) \
  REGISTER_CUMLOGSUMEXP_KERNEL(DEVICE_GPU, GPUDevice, type, int64_t)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef REGISTER_CUMLOGSUMEXP_KERNEL

}  // namespace tensorflow