/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#include "tensorflow/core/kernels/sequence_ops.h"

#include <cmath>
#include <limits>
#include <type_traits>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

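// CPU specialization of RangeFunctor: fills `output` with the length-`size`
// arithmetic sequence start, start + delta, start + 2 * delta, ...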
template <typename T>
struct RangeFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, int64_t size, T start, T delta,
                  typename TTypes<T>::Flat output) const {
    (void)context;  // Unused in the CPU implementation.
    T val = start;
    for (int64_t i = 0; i < size; ++i) {
      output(i) = val;
      val += delta;
    }
  }
};

}  // namespace functor

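// Op that emits the sequence [start, start + delta, start + 2 * delta, ...),
// stopping before `limit`; e.g. start = 3, limit = 18, delta = 3 yields
// [3, 6, 9, 12, 15].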
template <typename Device, typename T>
class RangeOp : public OpKernel {
 public:
  explicit RangeOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& start_in = context->input(0);
    const Tensor& limit_in = context->input(1);
    const Tensor& delta_in = context->input(2);
    // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(start_in.shape()) ||
                    (TensorShapeUtils::IsVector(start_in.shape()) &&
                     start_in.shape().dim_size(0) == 1),
                errors::InvalidArgument("start must be a scalar, not shape ",
                                        start_in.shape().DebugString()));
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(limit_in.shape()) ||
                    (TensorShapeUtils::IsVector(limit_in.shape()) &&
                     limit_in.shape().dim_size(0) == 1),
                errors::InvalidArgument("limit must be a scalar, not shape ",
                                        limit_in.shape().DebugString()));
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(delta_in.shape()) ||
                    (TensorShapeUtils::IsVector(delta_in.shape()) &&
                     delta_in.shape().dim_size(0) == 1),
                errors::InvalidArgument("delta must be a scalar, not shape ",
                                        delta_in.shape().DebugString()));
    // Read via flat(0) rather than scalar() so the legacy length-1 vector
    // inputs accepted above are handled as well.
    const T start = start_in.flat<T>()(0);
    const T limit = limit_in.flat<T>()(0);
    const T delta = delta_in.flat<T>()(0);
    OP_REQUIRES(context, delta != 0,
                errors::InvalidArgument("Requires delta != 0: ", delta));
    if (delta > 0) {
      OP_REQUIRES(
          context, start <= limit,
          errors::InvalidArgument(
              "Requires start <= limit when delta > 0: ", start, "/", limit));
    } else {
      OP_REQUIRES(
          context, start >= limit,
          errors::InvalidArgument(
              "Requires start >= limit when delta < 0: ", start, "/", limit));
    }
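    // size = ceil(|limit - start| / |delta|). Integral types use exact
    // ceiling division; floating-point types take the ceil of the quotient
    // and verify it fits in int64 before casting.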
    int64_t size;
    if (std::is_integral<T>::value) {
      size = Eigen::divup(Eigen::numext::abs(limit - start),
                          Eigen::numext::abs(delta));
    } else {
      auto size_auto =
          Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta));
      OP_REQUIRES(
          context, size_auto <= std::numeric_limits<int64_t>::max(),
          errors::InvalidArgument("Requires ((limit - start) / delta) <= ",
                                  std::numeric_limits<int64_t>::max()));
      size = static_cast<int64_t>(size_auto);
    }

    TensorShape shape;
    OP_REQUIRES_OK(context, shape.AddDimWithStatus(size));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &out));
    if (size == 0) return;
    auto flat = out->flat<T>();
    functor::RangeFunctor<Device, T>()(context, size, start, delta, flat);
  }
};

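// Registers a Range kernel for one (device, index type) pair. All three
// scalar inputs are pinned to host memory so Compute can read their values
// directly, regardless of the device the kernel is placed on.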
#define REGISTER_KERNEL(DEV, DEV_TYPE, TYPE)                 \
  REGISTER_KERNEL_BUILDER(Name("Range")                      \
                              .Device(DEV)                   \
                              .HostMemory("start")           \
                              .HostMemory("limit")           \
                              .HostMemory("delta")           \
                              .TypeConstraint<TYPE>("Tidx"), \
                          RangeOp<DEV_TYPE, TYPE>);

#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, CPUDevice, T)
#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, GPUDevice, T)

TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
TF_CALL_int32(REGISTER_CPU_KERNEL);
TF_CALL_int64(REGISTER_CPU_KERNEL);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Special case to execute int32 on the host with host output: by TensorFlow
// convention, int32 tensors typically hold shape metadata and are kept in
// host memory.
REGISTER_KERNEL_BUILDER(Name("Range")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("start")
                            .HostMemory("limit")
                            .HostMemory("delta")
                            .HostMemory("output")
                            .TypeConstraint<int32_t>("Tidx"),
                        RangeOp<CPUDevice, int32_t>);

#undef REGISTER_KERNEL
#undef REGISTER_CPU_KERNEL
#undef REGISTER_GPU_KERNEL

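// Op that emits `num` evenly spaced values from start to stop inclusive;
// e.g. start = 10.0, stop = 12.0, num = 3 yields [10.0, 11.0, 12.0].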
template <typename T, typename Tnum>
class LinSpaceOp : public OpKernel {
 public:
  explicit LinSpaceOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& start_in = context->input(0);
    const Tensor& stop_in = context->input(1);
    const Tensor& num_in = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(start_in.shape()),
                errors::InvalidArgument("start must be a scalar, not shape ",
                                        start_in.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(stop_in.shape()),
                errors::InvalidArgument("stop must be a scalar, not shape ",
                                        stop_in.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_in.shape()),
                errors::InvalidArgument("num must be a scalar, not shape ",
                                        num_in.shape().DebugString()));
    const T start = start_in.scalar<T>()();
    const T stop = stop_in.scalar<T>()();
    const Tnum num = num_in.scalar<Tnum>()();
    OP_REQUIRES(context, num > 0,
                errors::InvalidArgument("Requires num > 0: ", num));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({num}), &out));
    auto flat = out->flat<T>();
    flat(0) = start;
    if (num > 1) {
      const T step = (stop - start) / (num - 1);
      for (Tnum i = 1; i < num - 1; ++i) flat(i) = start + step * i;
      // Ensure final value == stop; float arithmetic won't guarantee this.
      flat(num - 1) = stop;
    }
  }
};

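// LinSpace pins all inputs and the output to host memory, so the values are
// computed on the host even under the DEVICE_DEFAULT registration below.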
#define REGISTER_KERNEL(DEV, T, Tidx)                       \
  REGISTER_KERNEL_BUILDER(Name("LinSpace")                  \
                              .Device(DEV)                  \
                              .TypeConstraint<T>("T")       \
                              .TypeConstraint<Tidx>("Tidx") \
                              .HostMemory("start")          \
                              .HostMemory("stop")           \
                              .HostMemory("num")            \
                              .HostMemory("output"),        \
                          LinSpaceOp<T, Tidx>);

#define REGISTER_KERNEL_ALL_NUMS(dev, T) \
  REGISTER_KERNEL(dev, T, int32);        \
  REGISTER_KERNEL(dev, T, int64_t)

#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_CPU, T)
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);

#define REGISTER_DEFAULT_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_DEFAULT, T)
TF_CALL_float(REGISTER_DEFAULT_KERNEL);
TF_CALL_double(REGISTER_DEFAULT_KERNEL);
#undef REGISTER_DEFAULT_KERNEL

#undef REGISTER_CPU_KERNEL
#undef REGISTER_KERNEL_ALL_NUMS
#undef REGISTER_KERNEL

}  // namespace tensorflow