/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_

#include <limits>
#include <unordered_set>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant_op_registry.h"

namespace tensorflow {

namespace shape_op_helpers {
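// Copies the shape of the tensor at `input_index` into `*shape`.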
inline Status GetShape(OpKernelContext* ctx, int input_index,
                       TensorShape* shape) {
  *shape = ctx->input(input_index).shape();
  return OkStatus();
}
}  // namespace shape_op_helpers

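// Emits a 1-D tensor of type OutType holding the shape of the input, e.g. a
// float tensor with shape [2, 3] produces the output [2, 3]. OutType is the
// op's output dtype (int32 or int64); for int32 outputs, any dimension too
// large for int32 is reported as an error.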
template <typename OutType>
class ShapeOp : public OpKernel {
 public:
  explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
    const int rank = shape.dims();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
    auto vec = out->vec<OutType>();
    for (int i = 0; i < rank; ++i) {
      int64_t dim_size = shape.dim_size(i);
      if (out->dtype() == DT_INT32) {
        OP_REQUIRES(
            ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
            errors::InvalidArgument("Shape output type is 32-bit but dim ", i,
                                    " is ", dim_size));
      }
      vec(i) = static_cast<OutType>(dim_size);
    }
  }

  bool IsExpensive() override { return false; }
};

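// N-ary variant of ShapeOp: writes the shape of input i to output i for every
// input, applying the same per-dimension int32 overflow check.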
template <typename OutType>
class ShapeNOp : public OpKernel {
 public:
  explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      TensorShape shape;
      OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, i, &shape));
      const int dims = shape.dims();
      Tensor* out = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
      auto vec = out->vec<OutType>();

      for (int j = 0; j < dims; ++j) {
        int64_t dim_size = shape.dim_size(j);
        if (out->dtype() == DT_INT32) {
          OP_REQUIRES(
              ctx,
              FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
              errors::InvalidArgument("ShapeN output type is 32-bit but shape ",
                                      i, " dim ", j, " is ", dim_size));
        }
        vec(j) = static_cast<OutType>(dim_size);
      }
    }
  }

  bool IsExpensive() override { return false; }
};

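// Emits a scalar int32 holding the rank (number of dimensions) of the input.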
class RankOp : public OpKernel {
 public:
  explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
    const int rank = shape.dims();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    out->scalar<int32>()() = rank;
  }

  bool IsExpensive() override { return false; }
};

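// Emits a scalar of type OutType holding the total number of elements in the
// input; for int32 outputs, fails if the count does not fit in 32 bits.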
template <typename OutType>
class SizeOp : public OpKernel {
 public:
  explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
    const int64_t size = shape.num_elements();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    if (out->dtype() == DT_INT32) {
      OP_REQUIRES(
          ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
          errors::InvalidArgument("Number of elements was larger than "
                                  "representable by 32-bit output type"));
    }
    out->scalar<OutType>()() = static_cast<OutType>(size);
  }

  bool IsExpensive() override { return false; }
};

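// Inserts a dimension of size 1 at axis `dim` of the input shape, where `dim`
// may be negative (counting from the end, as in NumPy); e.g. expanding axis 0
// of a [2, 3] tensor yields shape [1, 2, 3]. The output shares the input's
// buffer, so no data is copied.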
template <typename Tdim>
class ExpandDimsOp : public OpKernel {
 public:
  explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_t = ctx->input(0);
    OP_REQUIRES(ctx, input_t.dtype() != DT_VARIANT,
                errors::InvalidArgument("ExpandDims on Variant not supported"));

    const Tensor& dim_t = ctx->input(1);
    OP_REQUIRES(
        ctx, (dim_t.NumElements() == 1),
        errors::InvalidArgument("'dim' must be a tensor with a single value"));
    DCHECK_EQ(dim_t.dtype(), DataTypeToEnum<Tdim>::v());
    Tdim dim = *static_cast<const Tdim*>(DMAHelper::base(&dim_t));
    const TensorShape& input_shape = input_t.shape();
    int input_dims = input_shape.dims();
    OP_REQUIRES(ctx, dim >= -1 - input_dims && dim <= input_dims,
                errors::InvalidArgument("Tried to expand dim index ", dim,
                                        " for tensor with ", input_dims,
                                        " dimensions."));

    // We emulate numpy's interpretation of the dim axis when
    // -input.dims() - 1 <= dim <= input.dims(): negative values count from
    // the end, with -1 placing the new axis last.
    if (dim < 0) {
      // Clamp to the end if needed.
      dim = std::min<Tdim>(dim + input_dims + 1, input_dims);
    }

    // Compute new shape with an additional dimension.
    absl::InlinedVector<int64_t, 8> output_shape_vec(input_dims + 1);
    for (int64_t i = 0; i < dim; ++i) {
      output_shape_vec[i] = input_shape.dim_size(i);
    }
    output_shape_vec[dim] = 1;
    for (int64_t i = dim + 1; i < input_dims + 1; ++i) {
      output_shape_vec[i] = input_shape.dim_size(i - 1);
    }
    TensorShape output_shape(output_shape_vec);

    Tensor output_t;
    if (!output_t.CopyFrom(input_t, output_shape)) {
      // This should never happen, since the sizes of the input and output
      // should always be the same (we only expand the dimension with 1).
      ctx->SetStatus(
          errors::Internal("Could not expand dimension with input shape ",
                           ctx->input(0).shape().DebugString(),
                           " and output shape ", output_shape.DebugString()));
    }
    ctx->set_output(0, std::move(output_t));
  }

  bool IsExpensive() override { return false; }
};

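// Removes size-1 dimensions from the input shape. If the `squeeze_dims`
// attribute is non-empty, only the listed axes are removed and each must have
// size 1; with no explicit dims, squeezing a [1, 2, 1] tensor yields shape
// [2]. As with ExpandDims, the output aliases the input buffer.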
class SqueezeOp : public OpKernel {
 public:
  explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    std::vector<int32> squeeze_dims;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
  }

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
                errors::InvalidArgument("Squeeze on Variant not supported"));

    auto existing_dims = ctx->input(0).shape().dim_sizes();
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64_t> new_shape;

    std::unordered_set<int32> wrapped_squeeze_dims;
    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
    // Validate squeeze dims against the input.
    for (int32_t dim : squeeze_dims_) {
      OP_REQUIRES(
          ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
          errors::InvalidArgument("Tried to squeeze dim index ", dim,
                                  " for tensor with ", ctx->input(0).dims(),
                                  " dimensions."));
      // If dim is negative, wrap around (-1 refers to the last dimension).
      if (dim < 0) {
        dim = existing_dims_size + dim;
      }

      wrapped_squeeze_dims.insert(dim);
    }

    for (int i = 0; i < existing_dims_size; ++i) {
      auto existing_dim = existing_dims[i];

      // If the set of squeeze dims is non-empty, only squeeze those
      // dimensions.
      if (!wrapped_squeeze_dims.empty()) {
        if (wrapped_squeeze_dims.count(i) > 0) {
          OP_REQUIRES(ctx, existing_dim == 1,
                      errors::InvalidArgument(
                          "Can not squeeze dim[", i,
                          "], expected a dimension of 1, got ", existing_dim));
        } else {
          // This dimension is not being squeezed.
          new_shape.push_back(existing_dim);
        }
      } else {
        // Copy over all non-1-length dimensions.
        if (existing_dim != 1) {
          new_shape.push_back(existing_dim);
        }
      }
    }

    const TensorShape output_shape(new_shape);
    Tensor* output = nullptr;
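    // Allocate a placeholder output, then rebind it to the input's buffer
    // with the squeezed shape: CopyFrom shares the underlying data rather
    // than copying it.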
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
    if (!output->CopyFrom(ctx->input(0), output_shape)) {
      // This should never happen, since the sizes of the input and
      // output should always be the same.
      ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
                                      ctx->input(0).shape().DebugString(),
                                      " and output shape ",
                                      output_shape.DebugString()));
    }
  }

  bool IsExpensive() override { return false; }

 private:
  std::unordered_set<int32> squeeze_dims_;
};

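// The kernels above are bound to ops in the corresponding .cc files. As a
// rough sketch of what such a registration looks like (illustrative only;
// the exact device, HostMemory, and type-constraint choices live in
// shape_ops.cc):
//
//   REGISTER_KERNEL_BUILDER(Name("Shape")
//                               .Device(DEVICE_CPU)
//                               .HostMemory("output")
//                               .TypeConstraint<int32>("out_type"),
//                           ShapeOp<int32>);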
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_