1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
17#define TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
18
19#include <limits>
20#include <unordered_set>
21#include <vector>
22
23#include "absl/container/inlined_vector.h"
24#include "tensorflow/core/common_runtime/dma_helper.h"
25#include "tensorflow/core/framework/bounds_check.h"
26#include "tensorflow/core/framework/op_kernel.h"
27#include "tensorflow/core/framework/tensor.h"
28#include "tensorflow/core/framework/tensor_shape.h"
29#include "tensorflow/core/framework/variant_op_registry.h"
30
31namespace tensorflow {
32
33namespace shape_op_helpers {
34inline Status GetShape(OpKernelContext* ctx, int input_index,
35 TensorShape* shape) {
36 *shape = ctx->input(input_index).shape();
37 return OkStatus();
38}
39} // namespace shape_op_helpers
40
41template <typename OutType>
42class ShapeOp : public OpKernel {
43 public:
44 explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
45
46 void Compute(OpKernelContext* ctx) override {
47 TensorShape shape;
48 OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
49 const int rank = shape.dims();
50 Tensor* out = nullptr;
51 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
52 auto vec = out->vec<OutType>();
53 for (int i = 0; i < rank; ++i) {
54 int64_t dim_size = shape.dim_size(i);
55 if (out->dtype() == DT_INT32) {
56 OP_REQUIRES(
57 ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
58 errors::InvalidArgument("Shape output type is 32-bit ", " but dim ",
59 i, " is ", dim_size));
60 }
61 vec(i) = static_cast<OutType>(dim_size);
62 }
63 }
64
65 bool IsExpensive() override { return false; }
66};
67
68template <typename OutType>
69class ShapeNOp : public OpKernel {
70 public:
71 explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
72
73 void Compute(OpKernelContext* ctx) override {
74 for (int i = 0; i < ctx->num_inputs(); ++i) {
75 TensorShape shape;
76 OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, i, &shape));
77 const int dims = shape.dims();
78 Tensor* out = nullptr;
79 OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
80 auto vec = out->vec<OutType>();
81
82 for (int j = 0; j < dims; ++j) {
83 int64_t dim_size = shape.dim_size(j);
84 if (out->dtype() == DT_INT32) {
85 OP_REQUIRES(
86 ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
87 errors::InvalidArgument("ShapeN output type is 32-bit but shape ",
88 i, " dim ", j, " is ", dim_size));
89 }
90 vec(j) = static_cast<OutType>(dim_size);
91 }
92 }
93 }
94
95 bool IsExpensive() override { return false; }
96};
97
98class RankOp : public OpKernel {
99 public:
100 explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
101
102 void Compute(OpKernelContext* ctx) override {
103 TensorShape shape;
104 OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
105 const int rank = shape.dims();
106 Tensor* out = nullptr;
107 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
108 out->scalar<int32>()() = rank;
109 }
110
111 bool IsExpensive() override { return false; }
112};
113
114template <typename OutType>
115class SizeOp : public OpKernel {
116 public:
117 explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
118
119 void Compute(OpKernelContext* ctx) override {
120 TensorShape shape;
121 OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
122 const int64_t size = shape.num_elements();
123 Tensor* out = nullptr;
124 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
125 if (out->dtype() == DT_INT32) {
126 OP_REQUIRES(
127 ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
128 errors::InvalidArgument("Number of elements was larger than "
129 "representable by 32-bit output type"));
130 }
131 out->scalar<OutType>()() = static_cast<OutType>(size);
132 }
133
134 bool IsExpensive() override { return false; }
135};
136
137template <typename Tdim>
138class ExpandDimsOp : public OpKernel {
139 public:
140 explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
141
142 void Compute(OpKernelContext* ctx) override {
143 const Tensor& input_t = ctx->input(0);
144 OP_REQUIRES(ctx, input_t.dtype() != DT_VARIANT,
145 errors::InvalidArgument("ExpandDims on Variant not supported"));
146
147 const Tensor& dim_t = ctx->input(1);
148 OP_REQUIRES(
149 ctx, (dim_t.NumElements() == 1),
150 errors::InvalidArgument("'dim' must be a tensor with a single value"));
151 DCHECK_EQ(dim_t.dtype(), DataTypeToEnum<Tdim>::v());
152 Tdim dim = *static_cast<const Tdim*>(DMAHelper::base(&dim_t));
153 const TensorShape& input_shape = input_t.shape();
154 int input_dims = input_shape.dims();
155 OP_REQUIRES(ctx, dim >= -1 - input_dims && dim <= input_dims,
156 errors::InvalidArgument("Tried to expand dim index ", dim,
157 " for tensor with ", input_dims,
158 " dimensions."));
159
160 // We emulate numpy's interpretation of the dim axis when
161 // -input.dims() >= dim <= input.dims().
162 if (dim < 0) {
163 // Clamp to the end if needed.
164 dim = std::min<Tdim>(dim + input_dims + 1, input_dims);
165 }
166
167 // Compute new shape with an additional dimension.
168 absl::InlinedVector<int64_t, 8> output_shape_vec(input_dims + 1);
169 for (int64_t i = 0; i < dim; ++i) {
170 output_shape_vec[i] = input_shape.dim_size(i);
171 }
172 output_shape_vec[dim] = 1;
173 for (int64_t i = dim + 1; i < input_dims + 1; ++i) {
174 output_shape_vec[i] = input_shape.dim_size(i - 1);
175 }
176 TensorShape output_shape(output_shape_vec);
177
178 Tensor output_t;
179 if (!output_t.CopyFrom(input_t, output_shape)) {
180 // This should never happen, since the sizes of the input and output
181 // should always be the same (we only expand the dimension with 1).
182 ctx->SetStatus(
183 errors::Internal("Could not expand dimension with input shape ",
184 ctx->input(0).shape().DebugString(),
185 " and output shape ", output_shape.DebugString()));
186 }
187 ctx->set_output(0, std::move(output_t));
188 }
189
190 bool IsExpensive() override { return false; }
191};
192
193class SqueezeOp : public OpKernel {
194 public:
195 explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
196 std::vector<int32> squeeze_dims;
197 OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
198 squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
199 }
200
201 void Compute(OpKernelContext* ctx) override {
202 OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
203 errors::InvalidArgument("Squeeze on Variant not supported"));
204
205 auto existing_dims = ctx->input(0).shape().dim_sizes();
206 const int existing_dims_size = static_cast<int>(existing_dims.size());
207 std::vector<int64_t> new_shape;
208
209 std::unordered_set<int32> wrapped_squeeze_dims;
210 wrapped_squeeze_dims.reserve(squeeze_dims_.size());
211 // Validate squeeze dims against the input.
212 for (int32_t dim : squeeze_dims_) {
213 OP_REQUIRES(
214 ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
215 errors::InvalidArgument("Tried to squeeze dim index ", dim,
216 " for tensor with ", ctx->input(0).dims(),
217 " dimensions."));
218 // If dim is < 0, we wrap around (-1 means the last element).
219 if (dim < 0) {
220 dim = existing_dims_size + dim;
221 }
222
223 wrapped_squeeze_dims.insert(dim);
224 }
225
226 for (int i = 0; i < existing_dims_size; ++i) {
227 auto existing_dim = existing_dims[i];
228
229 // If squeeze_set is non-empty, only squeeze those dimensions.
230 if (!wrapped_squeeze_dims.empty()) {
231 if (wrapped_squeeze_dims.count(i) > 0) {
232 OP_REQUIRES(ctx, existing_dim == 1,
233 errors::InvalidArgument(
234 "Can not squeeze dim[", i,
235 "], expected a dimension of 1, got ", existing_dim));
236 } else {
237 // This dimension is not being squeezed.
238 new_shape.push_back(existing_dim);
239 }
240 } else {
241 // Copy over all non-1-length dimensions.
242 if (existing_dim != 1) {
243 new_shape.push_back(existing_dim);
244 }
245 }
246 }
247
248 const TensorShape output_shape(new_shape);
249 Tensor* output = nullptr;
250 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
251 if (!output->CopyFrom(ctx->input(0), output_shape)) {
252 // This should never happen, since the sizes of the input and
253 // output should always be the same.
254 ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
255 ctx->input(0).shape().DebugString(),
256 " and output shape ",
257 output_shape.DebugString()));
258 }
259 }
260
261 bool IsExpensive() override { return false; }
262
263 private:
264 std::unordered_set<int32> squeeze_dims_;
265};
266
267} // namespace tensorflow
268
269#endif // TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
270