1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <cstdint> |
17 | |
18 | #include "tensorflow/core/framework/types.pb.h" |
19 | #include "tensorflow/core/platform/types.h" |
20 | #define EIGEN_USE_THREADS |
21 | |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/register_types.h" |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/framework/types.h" |
26 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
27 | |
28 | namespace tensorflow { |
29 | |
namespace {
// Element-wise modulus functor for use with Eigen's binaryExpr; computes
// `a % b` with C++ truncated-division semantics (result has the sign of `a`).
// Returns plain `T` (not `const T`): a top-level const on a by-value return
// is meaningless to callers and inhibits move semantics
// (clang-tidy: readability-const-return-type).
template <typename T>
struct mod_op {
  T operator()(const T& a, const T& b) const { return a % b; }
};
}  // namespace
36 | |
37 | typedef Eigen::ThreadPoolDevice CPUDevice; |
38 | |
39 | template <typename Tidx> |
40 | class UnravelIndexOp : public OpKernel { |
41 | public: |
42 | explicit UnravelIndexOp(OpKernelConstruction* ctx) |
43 | : OpKernel(ctx), dtidx_(DataTypeToEnum<Tidx>::v()) {} |
44 | |
45 | void Compute(OpKernelContext* ctx) override { |
46 | const Tensor& indices_tensor = ctx->input(0); |
47 | OP_REQUIRES(ctx, |
48 | TensorShapeUtils::IsVector(indices_tensor.shape()) || |
49 | TensorShapeUtils::IsScalar(indices_tensor.shape()), |
50 | errors::InvalidArgument( |
51 | "The indices can only be scalar or vector, got \"" , |
52 | indices_tensor.shape().DebugString(), "\"" )); |
53 | |
54 | const Tensor& dims_tensor = ctx->input(1); |
55 | OP_REQUIRES( |
56 | ctx, TensorShapeUtils::IsVector(dims_tensor.shape()), |
57 | errors::InvalidArgument("The indices can only be 1-D, got \"" , |
58 | dims_tensor.shape().DebugString(), "\"" )); |
59 | |
60 | auto dims = dims_tensor.vec<Tidx>(); |
61 | // Make sure dims does not contain a zero |
62 | double prod = 1; |
63 | uint64_t limit; |
64 | if (dtidx_ == DataType::DT_INT64) { |
65 | limit = kint64max; |
66 | } else { |
67 | limit = kint32max; |
68 | } |
69 | |
70 | for (int i = 0; i < dims.size(); i++) { |
71 | OP_REQUIRES( |
72 | ctx, dims(i) != 0, |
73 | errors::InvalidArgument("Input dims cannot contain a dim of zero, " |
74 | "but dims contains zero at index " , |
75 | i)); |
76 | OP_REQUIRES(ctx, dims(i) > 0, |
77 | errors::InvalidArgument( |
78 | "Input dims cannot be negative. Got dim = " , dims(i), |
79 | " at index " , i)); |
80 | // Check interger overflow |
81 | OP_REQUIRES( |
82 | ctx, prod <= limit / dims(i), |
83 | errors::InvalidArgument("Input dims product is causing integer " |
84 | "overflow: (" , |
85 | dims, ")" )); |
86 | prod = (prod * dims(i)); |
87 | } |
88 | |
89 | // Check to make sure indices is not out of boundary |
90 | Eigen::Tensor<Tidx, 0, Eigen::RowMajor> dims_prod_eigen = dims.prod(); |
91 | Tidx dims_prod = dims_prod_eigen(); |
92 | const Tidx* indices = indices_tensor.flat<Tidx>().data(); |
93 | int64_t size = indices_tensor.NumElements(); |
94 | bool check = std::all_of(indices, indices + size, |
95 | [&](Tidx index) { return index < dims_prod; }); |
96 | OP_REQUIRES(ctx, check, |
97 | errors::InvalidArgument("index is out of bound as with dims" )); |
98 | |
99 | Eigen::array<bool, 1> reverse({true}); |
100 | |
101 | Tensor strides_tensor; |
102 | OP_REQUIRES_OK(ctx, |
103 | ctx->allocate_temp(DataTypeToEnum<Tidx>::value, |
104 | TensorShape({dims_tensor.NumElements()}), |
105 | &strides_tensor)); |
106 | |
107 | auto strides = strides_tensor.vec<Tidx>(); |
108 | strides = dims.reverse(reverse) |
109 | .scan(0, Eigen::internal::ProdReducer<Tidx>(), false) |
110 | .reverse(reverse); |
111 | |
112 | Tensor strides_shifted_tensor; |
113 | OP_REQUIRES_OK(ctx, |
114 | ctx->allocate_temp(DataTypeToEnum<Tidx>::value, |
115 | TensorShape({dims_tensor.NumElements()}), |
116 | &strides_shifted_tensor)); |
117 | |
118 | auto strides_shifted = strides_shifted_tensor.vec<Tidx>(); |
119 | strides_shifted = dims.reverse(reverse) |
120 | .scan(0, Eigen::internal::ProdReducer<Tidx>(), true) |
121 | .reverse(reverse); |
122 | |
123 | Tensor* output_tensor = nullptr; |
124 | if (TensorShapeUtils::IsScalar(indices_tensor.shape())) { |
125 | OP_REQUIRES_OK( |
126 | ctx, ctx->allocate_output(0, TensorShape({dims_tensor.NumElements()}), |
127 | &output_tensor)); |
128 | |
129 | auto output = output_tensor->vec<Tidx>(); |
130 | |
131 | output = output.constant(indices_tensor.scalar<Tidx>()()); |
132 | output = output.binaryExpr(strides, mod_op<Tidx>()) / strides_shifted; |
133 | } else { |
134 | OP_REQUIRES_OK( |
135 | ctx, ctx->allocate_output(0, |
136 | TensorShape({dims_tensor.NumElements(), |
137 | indices_tensor.NumElements()}), |
138 | &output_tensor)); |
139 | |
140 | auto output = output_tensor->matrix<Tidx>(); |
141 | |
142 | Eigen::array<Eigen::Index, 2> reshape{ |
143 | {static_cast<Eigen::Index>(dims_tensor.NumElements()), 1}}; |
144 | Eigen::array<Eigen::Index, 2> bcast( |
145 | {1, static_cast<Eigen::Index>(indices_tensor.NumElements())}); |
146 | Eigen::array<Eigen::Index, 2> indices_reshape{ |
147 | {1, static_cast<Eigen::Index>(indices_tensor.NumElements())}}; |
148 | Eigen::array<Eigen::Index, 2> indices_bcast( |
149 | {static_cast<Eigen::Index>(dims_tensor.NumElements()), 1}); |
150 | |
151 | output = indices_tensor.vec<Tidx>() |
152 | .reshape(indices_reshape) |
153 | .broadcast(indices_bcast); |
154 | output = output.binaryExpr(strides.reshape(reshape).broadcast(bcast), |
155 | mod_op<Tidx>()) / |
156 | strides_shifted.reshape(reshape).broadcast(bcast); |
157 | } |
158 | } |
159 | const DataType dtidx_; |
160 | }; |
161 | |
// Register the CPU kernel for the "UnravelIndex" op for each supported
// index type: TF_CALL_int32 / TF_CALL_int64 expand REGISTER_KERNEL for
// int32 and int64 respectively, matching the "Tidx" type constraint.
#define REGISTER_KERNEL(type)                                             \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("UnravelIndex").Device(DEVICE_CPU).TypeConstraint<type>("Tidx"), \
      UnravelIndexOp<type>);
TF_CALL_int32(REGISTER_KERNEL) TF_CALL_int64(REGISTER_KERNEL)
#undef REGISTER_KERNEL
168 | |
169 | } // namespace tensorflow |
170 | |