1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/math_ops.cc. |
17 | |
18 | #include "tensorflow/core/kernels/sequence_ops.h" |
19 | |
20 | #include <cmath> |
21 | |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/op_requires.h" |
24 | #include "tensorflow/core/framework/register_types.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor_shape.h" |
27 | #include "tensorflow/core/framework/types.h" |
28 | |
29 | namespace tensorflow { |
30 | |
31 | using CPUDevice = Eigen::ThreadPoolDevice; |
32 | using GPUDevice = Eigen::GpuDevice; |
33 | |
34 | namespace functor { |
35 | |
36 | template <typename T> |
37 | struct RangeFunctor<CPUDevice, T> { |
38 | void operator()(OpKernelContext* context, int64_t size, T start, T delta, |
39 | typename TTypes<T>::Flat output) const { |
40 | (void)context; |
41 | T val = start; |
42 | for (int64_t i = 0; i < size; ++i) { |
43 | output(i) = T(val); |
44 | val += delta; |
45 | } |
46 | } |
47 | }; |
48 | |
49 | } // namespace functor |
50 | |
// Kernel for the "Range" op: produces a rank-1 tensor holding the sequence
//   start, start + delta, start + 2*delta, ...
// that stops before reaching `limit`.  All three inputs are scalars (or,
// for legacy callers, length-1 vectors); the element count is computed on
// the host and the fill itself is delegated to RangeFunctor<Device, T>.
template <typename Device, typename T>
class RangeOp : public OpKernel {
 public:
  explicit RangeOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& start_in = context->input(0);
    const Tensor& limit_in = context->input(1);
    const Tensor& delta_in = context->input(2);
    // Each input must be a scalar; a length-1 vector is also tolerated for
    // backward compatibility (see TODO below).  OP_REQUIRES returns early
    // on failure, so the scalar<T>() reads further down are safe.
    // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(start_in.shape()) ||
                    (TensorShapeUtils::IsVector(start_in.shape()) &&
                     start_in.shape().dim_size(0) == 1),
                errors::InvalidArgument("start must be a scalar, not shape " ,
                                        start_in.shape().DebugString()));
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(limit_in.shape()) ||
                    (TensorShapeUtils::IsVector(limit_in.shape()) &&
                     limit_in.shape().dim_size(0) == 1),
                errors::InvalidArgument("limit must be a scalar, not shape " ,
                                        limit_in.shape().DebugString()));
    OP_REQUIRES(context,
                TensorShapeUtils::IsVector(delta_in.shape()) &&
                        delta_in.shape().dim_size(0) == 1 ||
                    TensorShapeUtils::IsScalar(delta_in.shape()),
                errors::InvalidArgument("delta must be a scalar, not shape " ,
                                        delta_in.shape().DebugString()));
    const T start = start_in.scalar<T>()();
    const T limit = limit_in.scalar<T>()();
    const T delta = delta_in.scalar<T>()();
    // Reject configurations that would loop forever or produce a negative
    // element count.
    OP_REQUIRES(context, delta != 0,
                errors::InvalidArgument("Requires delta != 0: " , delta));
    if (delta > 0) {
      OP_REQUIRES(
          context, start <= limit,
          errors::InvalidArgument(
              "Requires start <= limit when delta > 0: " , start, "/" , limit));
    } else {
      OP_REQUIRES(
          context, start >= limit,
          errors::InvalidArgument(
              "Requires start >= limit when delta < 0: " , start, "/" , limit));
    }
    // Element count: ceil(|limit - start| / |delta|).  Integral types use
    // exact integer arithmetic via divup; floating types go through ceil and
    // must be range-checked before the narrowing cast.
    int64_t size;
    if (std::is_integral<T>::value) {
      size = Eigen::divup(Eigen::numext::abs(limit - start),
                          Eigen::numext::abs(delta));
    } else {
      auto size_auto =
          Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta));
      // NOTE(review): int64 max is not exactly representable as double, so
      // this comparison admits values rounded up to 2^63 — presumably
      // acceptable here since such sizes could never be allocated anyway.
      OP_REQUIRES(
          context, size_auto <= std::numeric_limits<int64_t>::max(),
          errors::InvalidArgument("Requires ((limit - start) / delta) <= " ,
                                  std::numeric_limits<int64_t>::max()));
      size = static_cast<int64_t>(size_auto);
    }

    // AddDimWithStatus (rather than AddDim) surfaces an error status if the
    // computed size is invalid for a tensor dimension.
    TensorShape shape;
    OP_REQUIRES_OK(context, shape.AddDimWithStatus(size));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &out));
    if (size == 0) return;  // Nothing to fill; avoid launching the functor.
    auto flat = out->flat<T>();
    functor::RangeFunctor<Device, T>()(context, size, start, delta, flat);
  }
};
118 | |
// Registers RangeOp under op name "Range" for a (device constant, device tag
// type, element type) triple.  start/limit/delta are pinned to host memory
// because Compute() reads them as host-side scalars.
#define REGISTER_KERNEL(DEV, DEV_TYPE, TYPE) \
  REGISTER_KERNEL_BUILDER(Name("Range")      \
                              .Device(DEV)   \
                              .HostMemory("start") \
                              .HostMemory("limit") \
                              .HostMemory("delta") \
                              .TypeConstraint<TYPE>("Tidx"), \
                          RangeOp<DEV_TYPE, TYPE>);

#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, CPUDevice, T)
#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, GPUDevice, T)

// CPU supports float, double, int32 and int64 index types.
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
TF_CALL_int32(REGISTER_CPU_KERNEL);
TF_CALL_int64(REGISTER_CPU_KERNEL);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// GPU builds register float/double/int64 only; int32 is handled by the
// DEVICE_DEFAULT host registration below.
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Special case to execute int32 on the host with host output.
REGISTER_KERNEL_BUILDER(Name("Range" )
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("start" )
                            .HostMemory("limit" )
                            .HostMemory("delta" )
                            .HostMemory("output" )
                            .TypeConstraint<int32_t>("Tidx" ),
                        RangeOp<CPUDevice, int32_t>);

#undef REGISTER_KERNEL
#undef REGISTER_CPU_KERNEL
#undef REGISTER_GPU_KERNEL
157 | |
158 | template <typename T, typename Tnum> |
159 | class LinSpaceOp : public OpKernel { |
160 | public: |
161 | explicit LinSpaceOp(OpKernelConstruction* context) : OpKernel(context) {} |
162 | |
163 | void Compute(OpKernelContext* context) override { |
164 | const Tensor& start_in = context->input(0); |
165 | const Tensor& stop_in = context->input(1); |
166 | const Tensor& num_in = context->input(2); |
167 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(start_in.shape()), |
168 | errors::InvalidArgument("start must be a scalar, not shape " , |
169 | start_in.shape().DebugString())); |
170 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(stop_in.shape()), |
171 | errors::InvalidArgument("stop must be a scalar, not shape " , |
172 | stop_in.shape().DebugString())); |
173 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_in.shape()), |
174 | errors::InvalidArgument("num must be a scalar, not shape " , |
175 | num_in.shape().DebugString())); |
176 | const T start = start_in.scalar<T>()(); |
177 | const T stop = stop_in.scalar<T>()(); |
178 | const Tnum num = num_in.scalar<Tnum>()(); |
179 | OP_REQUIRES(context, num > 0, |
180 | errors::InvalidArgument("Requires num > 0: " , num)); |
181 | Tensor* out = nullptr; |
182 | OP_REQUIRES_OK(context, |
183 | context->allocate_output(0, TensorShape({num}), &out)); |
184 | auto flat = out->flat<T>(); |
185 | flat(0) = start; |
186 | if (num > 1) { |
187 | const T step = (stop - start) / (num - 1); |
188 | for (Tnum i = 1; i < num - 1; ++i) flat(i) = start + step * i; |
189 | // Ensure final value == stop; float arithmetic won't guarantee this. |
190 | flat(num - 1) = stop; |
191 | } |
192 | } |
193 | }; |
194 | |
// Registers LinSpaceOp for a (device, element type T, index type Tidx)
// triple.  Inputs and output are all pinned to host memory: the kernel
// always computes on the CPU, whatever device it is nominally placed on.
#define REGISTER_KERNEL(DEV, T, Tidx) \
  REGISTER_KERNEL_BUILDER(Name("LinSpace") \
                              .Device(DEV) \
                              .TypeConstraint<T>("T") \
                              .TypeConstraint<Tidx>("Tidx") \
                              .HostMemory("start") \
                              .HostMemory("stop") \
                              .HostMemory("num") \
                              .HostMemory("output"), \
                          LinSpaceOp<T, Tidx>);

// Each element type is registered for both int32 and int64 `num` inputs.
#define REGISTER_KERNEL_ALL_NUMS(dev, T) \
  REGISTER_KERNEL(dev, T, int32); \
  REGISTER_KERNEL(dev, T, int64_t)

#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_CPU, T)
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);

// DEVICE_DEFAULT covers non-CPU placements; the computation still runs on
// the host via the HostMemory pins above.
#define REGISTER_DEFAULT_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_DEFAULT, T)
TF_CALL_float(REGISTER_DEFAULT_KERNEL);
TF_CALL_double(REGISTER_DEFAULT_KERNEL);
#undef REGISTER_DEFAULT_KERNEL

#undef REGISTER_CPU_KERNEL
#undef REGISTER_KERNEL_ALL_NUMS
#undef REGISTER_KERNEL
222 | |
223 | } // namespace tensorflow |
224 | |