1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <cstdint> |
16 | #include <limits> |
17 | #include <memory> |
18 | #include <string> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/register_types.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_shape.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | using errors::InvalidArgument; |
29 | |
30 | template <typename T, typename SPLITS_TYPE> |
31 | class RaggedRangeOp : public OpKernel { |
32 | public: |
33 | using OpKernel::OpKernel; |
34 | |
35 | void Compute(OpKernelContext* context) override { |
36 | const Tensor& starts_in = context->input(0); |
37 | const Tensor& limits_in = context->input(1); |
38 | const Tensor& deltas_in = context->input(2); |
39 | |
40 | // Check input tensor shapes. |
41 | OP_REQUIRES(context, starts_in.shape().dims() <= 1, |
42 | InvalidArgument("starts must be a scalar or vector" )); |
43 | OP_REQUIRES(context, limits_in.shape().dims() <= 1, |
44 | InvalidArgument("limits must be a scalar or vector" )); |
45 | OP_REQUIRES(context, deltas_in.shape().dims() <= 1, |
46 | InvalidArgument("deltas must be a scalar or vector" )); |
47 | |
48 | // Determine which tensors we need to broadcast. |
49 | bool broadcast_starts = starts_in.shape().dims() == 0; |
50 | bool broadcast_limits = limits_in.shape().dims() == 0; |
51 | bool broadcast_deltas = deltas_in.shape().dims() == 0; |
52 | |
53 | // nrows (number of output rows) is the size of the non-broadcast inputs, |
54 | // or 1 if all inputs are scalars. |
55 | std::vector<int> in_sizes; |
56 | if (!broadcast_starts) in_sizes.push_back(starts_in.shape().dim_size(0)); |
57 | if (!broadcast_limits) in_sizes.push_back(limits_in.shape().dim_size(0)); |
58 | if (!broadcast_deltas) in_sizes.push_back(deltas_in.shape().dim_size(0)); |
59 | for (int i = 1; i < in_sizes.size(); ++i) { |
60 | OP_REQUIRES(context, in_sizes[i] == in_sizes[i - 1], |
61 | InvalidArgument("starts, limits, and deltas must have the " |
62 | "same shape" )); |
63 | } |
64 | SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes[0]; |
65 | |
66 | const auto& starts = starts_in.flat<T>(); |
67 | const auto& limits = limits_in.flat<T>(); |
68 | const auto& deltas = deltas_in.flat<T>(); |
69 | |
70 | // Construct the rt_nested_splits tensor. |
71 | Tensor* rt_nested_splits_out = nullptr; |
72 | OP_REQUIRES_OK(context, |
73 | context->allocate_output(0, TensorShape({nrows + 1}), |
74 | &rt_nested_splits_out)); |
75 | auto rt_nested_splits = rt_nested_splits_out->flat<SPLITS_TYPE>(); |
76 | rt_nested_splits(0) = 0; |
77 | for (int row = 0; row < nrows; ++row) { |
78 | T start = broadcast_starts ? starts(0) : starts(row); |
79 | T limit = broadcast_limits ? limits(0) : limits(row); |
80 | T delta = broadcast_deltas ? deltas(0) : deltas(row); |
81 | OP_REQUIRES(context, delta != 0, InvalidArgument("Requires delta != 0" )); |
82 | int64_t size; // The number of elements in the specified range. |
83 | if (((delta > 0) && (limit < start)) || |
84 | ((delta < 0) && (limit > start))) { |
85 | size = 0; |
86 | } else if (std::is_integral<T>::value) { |
87 | // The following is copied from tensorflow::RangeOp::Compute(). |
88 | size = Eigen::divup(Eigen::numext::abs(limit - start), |
89 | Eigen::numext::abs(delta)); |
90 | } else { |
91 | // The following is copied from tensorflow::RangeOp::Compute(). |
92 | auto size_auto = |
93 | Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta)); |
94 | OP_REQUIRES( |
95 | context, size_auto <= std::numeric_limits<int64_t>::max(), |
96 | errors::InvalidArgument("Requires ((limit - start) / delta) <= " , |
97 | std::numeric_limits<int64_t>::max())); |
98 | size = static_cast<int64_t>(size_auto); |
99 | } |
100 | rt_nested_splits(row + 1) = rt_nested_splits(row) + size; |
101 | } |
102 | SPLITS_TYPE nvals = rt_nested_splits(nrows); |
103 | |
104 | // Construct the rt_dense_values tensor. |
105 | Tensor* rt_dense_values_out = nullptr; |
106 | OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({nvals}), |
107 | &rt_dense_values_out)); |
108 | auto rt_dense_values = rt_dense_values_out->flat<T>(); |
109 | int value_index = 0; |
110 | for (int row = 0; row < nrows; ++row) { |
111 | SPLITS_TYPE row_size = rt_nested_splits(row + 1) - rt_nested_splits(row); |
112 | T value = broadcast_starts ? starts(0) : starts(row); |
113 | T delta = broadcast_deltas ? deltas(0) : deltas(row); |
114 | for (SPLITS_TYPE i = 0; i < row_size; ++i) { |
115 | rt_dense_values(value_index++) = T(value); |
116 | value += delta; |
117 | } |
118 | } |
119 | } |
120 | }; |
121 | |
// Register the RaggedRange kernel on CPU for every supported value type
// (T: float, double, int32, int64) crossed with both row-splits types
// (Tsplits: int32, int64).
#define REGISTER_CPU_KERNEL(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(Name("RaggedRange")                     \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<TYPE>("T")          \
                              .TypeConstraint<int32>("Tsplits"),  \
                          RaggedRangeOp<TYPE, int32>);            \
  REGISTER_KERNEL_BUILDER(Name("RaggedRange")                     \
                              .Device(DEVICE_CPU)                 \
                              .TypeConstraint<TYPE>("T")          \
                              .TypeConstraint<int64_t>("Tsplits"),\
                          RaggedRangeOp<TYPE, int64>);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
TF_CALL_int32(REGISTER_CPU_KERNEL);
TF_CALL_int64(REGISTER_CPU_KERNEL);
// Keep the macro file-local.
#undef REGISTER_CPU_KERNEL
138 | |
139 | } // namespace tensorflow |
140 | |