/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"
#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

namespace {
template <typename T>
struct mod_op {
  const T operator()(const T& a, const T& b) const { return a % b; }
};
}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;

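// Converts an array of flat indices into a tuple of coordinate arrays for a
// tensor of shape `dims`, using row-major (C-style) ordering, analogous to
// np.unravel_index. For example, indices = [2, 5, 7] with dims = [3, 3]
// yields [[0, 1, 2], [2, 2, 1]].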
template <typename Tidx>
class UnravelIndexOp : public OpKernel {
 public:
  explicit UnravelIndexOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), dtidx_(DataTypeToEnum<Tidx>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& indices_tensor = ctx->input(0);
    OP_REQUIRES(ctx,
                TensorShapeUtils::IsVector(indices_tensor.shape()) ||
                    TensorShapeUtils::IsScalar(indices_tensor.shape()),
                errors::InvalidArgument(
                    "The indices can only be scalar or vector, got \"",
                    indices_tensor.shape().DebugString(), "\""));

    const Tensor& dims_tensor = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(dims_tensor.shape()),
        errors::InvalidArgument("The dims can only be 1-D, got \"",
                                dims_tensor.shape().DebugString(), "\""));

    auto dims = dims_tensor.vec<Tidx>();
    // Make sure every dim is positive and that the product of all dims does
    // not overflow the index type.
    double prod = 1;
    uint64_t limit;
    if (dtidx_ == DataType::DT_INT64) {
      limit = kint64max;
    } else {
      limit = kint32max;
    }

    for (int i = 0; i < dims.size(); i++) {
      OP_REQUIRES(
          ctx, dims(i) != 0,
          errors::InvalidArgument("Input dims cannot contain a dim of zero, "
                                  "but dims contains zero at index ",
                                  i));
      OP_REQUIRES(ctx, dims(i) > 0,
                  errors::InvalidArgument(
                      "Input dims cannot be negative. Got dim = ", dims(i),
                      " at index ", i));
      // Check for integer overflow
      OP_REQUIRES(
          ctx, prod <= limit / dims(i),
          errors::InvalidArgument("Input dims product is causing integer "
                                  "overflow: (",
                                  dims, ")"));
      prod = (prod * dims(i));
    }

    // Check that every index is within the range [0, prod(dims)).
    Eigen::Tensor<Tidx, 0, Eigen::RowMajor> dims_prod_eigen = dims.prod();
    Tidx dims_prod = dims_prod_eigen();
    const Tidx* indices = indices_tensor.flat<Tidx>().data();
    int64_t size = indices_tensor.NumElements();
    bool check = std::all_of(indices, indices + size, [&](Tidx index) {
      return index >= 0 && index < dims_prod;
    });
    OP_REQUIRES(
        ctx, check,
        errors::InvalidArgument("Index is out of the range defined by dims"));

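    // For a row-major layout, coordinate i of a flat index is
    //   (index % (dims(i) * ... * dims(n-1))) / (dims(i+1) * ... * dims(n-1)),
    // so two suffix-product stride vectors are built below and the result is
    // computed as an element-wise mod followed by a division.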
    Eigen::array<bool, 1> reverse({true});

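    // strides(i) = dims(i) * dims(i+1) * ... * dims(n-1): the inclusive
    // suffix product, computed by reversing dims, scanning, and reversing
    // the result back.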
    Tensor strides_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<Tidx>::value,
                                      TensorShape({dims_tensor.NumElements()}),
                                      &strides_tensor));

    auto strides = strides_tensor.vec<Tidx>();
    strides = dims.reverse(reverse)
                  .scan(0, Eigen::internal::ProdReducer<Tidx>(), false)
                  .reverse(reverse);

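    // strides_shifted(i) = dims(i+1) * ... * dims(n-1): the exclusive suffix
    // product, i.e. the conventional row-major strides.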
    Tensor strides_shifted_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DataTypeToEnum<Tidx>::value,
                                      TensorShape({dims_tensor.NumElements()}),
                                      &strides_shifted_tensor));

    auto strides_shifted = strides_shifted_tensor.vec<Tidx>();
    strides_shifted = dims.reverse(reverse)
                          .scan(0, Eigen::internal::ProdReducer<Tidx>(), true)
                          .reverse(reverse);

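    // A scalar index produces a 1-D output with one coordinate per dim; a
    // vector of indices produces a [num_dims, num_indices] matrix whose
    // column j holds the coordinates of indices(j).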
    Tensor* output_tensor = nullptr;
    if (TensorShapeUtils::IsScalar(indices_tensor.shape())) {
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(0, TensorShape({dims_tensor.NumElements()}),
                                    &output_tensor));

      auto output = output_tensor->vec<Tidx>();

      output = output.constant(indices_tensor.scalar<Tidx>()());
      output = output.binaryExpr(strides, mod_op<Tidx>()) / strides_shifted;
    } else {
      OP_REQUIRES_OK(
          ctx, ctx->allocate_output(0,
                                    TensorShape({dims_tensor.NumElements(),
                                                 indices_tensor.NumElements()}),
                                    &output_tensor));

      auto output = output_tensor->matrix<Tidx>();

      Eigen::array<Eigen::Index, 2> reshape{
          {static_cast<Eigen::Index>(dims_tensor.NumElements()), 1}};
      Eigen::array<Eigen::Index, 2> bcast(
          {1, static_cast<Eigen::Index>(indices_tensor.NumElements())});
      Eigen::array<Eigen::Index, 2> indices_reshape{
          {1, static_cast<Eigen::Index>(indices_tensor.NumElements())}};
      Eigen::array<Eigen::Index, 2> indices_bcast(
          {static_cast<Eigen::Index>(dims_tensor.NumElements()), 1});

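      // Broadcast the indices across rows and the strides across columns so
      // that output(i, j) = (indices(j) % strides(i)) / strides_shifted(i).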
      output = indices_tensor.vec<Tidx>()
                   .reshape(indices_reshape)
                   .broadcast(indices_bcast);
      output = output.binaryExpr(strides.reshape(reshape).broadcast(bcast),
                                 mod_op<Tidx>()) /
               strides_shifted.reshape(reshape).broadcast(bcast);
    }
  }
  const DataType dtidx_;
};

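// Registers the CPU kernel for the index types the op supports, int32 and
// int64.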
#define REGISTER_KERNEL(type)                                               \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("UnravelIndex").Device(DEVICE_CPU).TypeConstraint<type>("Tidx"), \
      UnravelIndexOp<type>);
TF_CALL_int32(REGISTER_KERNEL) TF_CALL_int64(REGISTER_KERNEL)
#undef REGISTER_KERNEL

}  // namespace tensorflow