1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <cstdint>
16#include <limits>
17#include <memory>
18#include <string>
19#include <vector>
20
21#include "tensorflow/core/framework/op_kernel.h"
22#include "tensorflow/core/framework/register_types.h"
23#include "tensorflow/core/framework/tensor.h"
24#include "tensorflow/core/framework/tensor_shape.h"
25
26namespace tensorflow {
27
28using errors::InvalidArgument;
29
30template <typename T, typename SPLITS_TYPE>
31class RaggedRangeOp : public OpKernel {
32 public:
33 using OpKernel::OpKernel;
34
35 void Compute(OpKernelContext* context) override {
36 const Tensor& starts_in = context->input(0);
37 const Tensor& limits_in = context->input(1);
38 const Tensor& deltas_in = context->input(2);
39
40 // Check input tensor shapes.
41 OP_REQUIRES(context, starts_in.shape().dims() <= 1,
42 InvalidArgument("starts must be a scalar or vector"));
43 OP_REQUIRES(context, limits_in.shape().dims() <= 1,
44 InvalidArgument("limits must be a scalar or vector"));
45 OP_REQUIRES(context, deltas_in.shape().dims() <= 1,
46 InvalidArgument("deltas must be a scalar or vector"));
47
48 // Determine which tensors we need to broadcast.
49 bool broadcast_starts = starts_in.shape().dims() == 0;
50 bool broadcast_limits = limits_in.shape().dims() == 0;
51 bool broadcast_deltas = deltas_in.shape().dims() == 0;
52
53 // nrows (number of output rows) is the size of the non-broadcast inputs,
54 // or 1 if all inputs are scalars.
55 std::vector<int> in_sizes;
56 if (!broadcast_starts) in_sizes.push_back(starts_in.shape().dim_size(0));
57 if (!broadcast_limits) in_sizes.push_back(limits_in.shape().dim_size(0));
58 if (!broadcast_deltas) in_sizes.push_back(deltas_in.shape().dim_size(0));
59 for (int i = 1; i < in_sizes.size(); ++i) {
60 OP_REQUIRES(context, in_sizes[i] == in_sizes[i - 1],
61 InvalidArgument("starts, limits, and deltas must have the "
62 "same shape"));
63 }
64 SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes[0];
65
66 const auto& starts = starts_in.flat<T>();
67 const auto& limits = limits_in.flat<T>();
68 const auto& deltas = deltas_in.flat<T>();
69
70 // Construct the rt_nested_splits tensor.
71 Tensor* rt_nested_splits_out = nullptr;
72 OP_REQUIRES_OK(context,
73 context->allocate_output(0, TensorShape({nrows + 1}),
74 &rt_nested_splits_out));
75 auto rt_nested_splits = rt_nested_splits_out->flat<SPLITS_TYPE>();
76 rt_nested_splits(0) = 0;
77 for (int row = 0; row < nrows; ++row) {
78 T start = broadcast_starts ? starts(0) : starts(row);
79 T limit = broadcast_limits ? limits(0) : limits(row);
80 T delta = broadcast_deltas ? deltas(0) : deltas(row);
81 OP_REQUIRES(context, delta != 0, InvalidArgument("Requires delta != 0"));
82 int64_t size; // The number of elements in the specified range.
83 if (((delta > 0) && (limit < start)) ||
84 ((delta < 0) && (limit > start))) {
85 size = 0;
86 } else if (std::is_integral<T>::value) {
87 // The following is copied from tensorflow::RangeOp::Compute().
88 size = Eigen::divup(Eigen::numext::abs(limit - start),
89 Eigen::numext::abs(delta));
90 } else {
91 // The following is copied from tensorflow::RangeOp::Compute().
92 auto size_auto =
93 Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta));
94 OP_REQUIRES(
95 context, size_auto <= std::numeric_limits<int64_t>::max(),
96 errors::InvalidArgument("Requires ((limit - start) / delta) <= ",
97 std::numeric_limits<int64_t>::max()));
98 size = static_cast<int64_t>(size_auto);
99 }
100 rt_nested_splits(row + 1) = rt_nested_splits(row) + size;
101 }
102 SPLITS_TYPE nvals = rt_nested_splits(nrows);
103
104 // Construct the rt_dense_values tensor.
105 Tensor* rt_dense_values_out = nullptr;
106 OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({nvals}),
107 &rt_dense_values_out));
108 auto rt_dense_values = rt_dense_values_out->flat<T>();
109 int value_index = 0;
110 for (int row = 0; row < nrows; ++row) {
111 SPLITS_TYPE row_size = rt_nested_splits(row + 1) - rt_nested_splits(row);
112 T value = broadcast_starts ? starts(0) : starts(row);
113 T delta = broadcast_deltas ? deltas(0) : deltas(row);
114 for (SPLITS_TYPE i = 0; i < row_size; ++i) {
115 rt_dense_values(value_index++) = T(value);
116 value += delta;
117 }
118 }
119 }
120};
121
// Register a CPU kernel for each supported value type T, once per splits
// dtype (Tsplits = int32 and Tsplits = int64).  No comments may appear inside
// the macro body: the trailing backslashes must be the last character on each
// line.
#define REGISTER_CPU_KERNEL(TYPE)                       \
  REGISTER_KERNEL_BUILDER(Name("RaggedRange")           \
                              .Device(DEVICE_CPU)       \
                              .TypeConstraint<TYPE>("T")      \
                              .TypeConstraint<int32>("Tsplits"), \
                          RaggedRangeOp<TYPE, int32>);  \
  REGISTER_KERNEL_BUILDER(Name("RaggedRange")           \
                              .Device(DEVICE_CPU)       \
                              .TypeConstraint<TYPE>("T")      \
                              .TypeConstraint<int64_t>("Tsplits"), \
                          RaggedRangeOp<TYPE, int64>);
// Instantiate the registrations for the value types RaggedRange supports.
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
TF_CALL_int32(REGISTER_CPU_KERNEL);
TF_CALL_int64(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
138
139} // namespace tensorflow
140