1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_ |
18 | |
19 | #include <memory> |
20 | |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/register_types.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_shape.h" |
25 | #include "tensorflow/core/framework/types.h" |
26 | #include "tensorflow/core/lib/core/status.h" |
27 | #include "tensorflow/core/platform/logging.h" |
28 | #include "tensorflow/core/util/overflow.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | // Note that this op is subclassed for QuantizedReshapeOp. |
33 | class ReshapeOp : public OpKernel { |
34 | public: |
35 | explicit ReshapeOp(OpKernelConstruction* context) : OpKernel(context) {} |
36 | |
37 | void Compute(OpKernelContext* context) override { |
38 | const Tensor& input = context->input(0); |
39 | const Tensor& sizes = context->input(1); |
40 | // Preliminary validation of sizes. |
41 | OP_REQUIRES( |
42 | context, |
43 | (TensorShapeUtils::IsVector(sizes.shape()) || |
44 | // TODO(rmlarsen): Disallow legacy use of scalars to represent shape. |
45 | TensorShapeUtils::IsScalar(sizes.shape())), |
46 | errors::InvalidArgument("sizes input must be 1-D, not " , |
47 | sizes.shape().DebugString())); |
48 | OP_REQUIRES( |
49 | context, sizes.NumElements() < TensorShape::MaxDimensions(), |
50 | errors::InvalidArgument("too many dimensions: must be < " , |
51 | TensorShape::MaxDimensions(), ", but received " , |
52 | sizes.NumElements())); |
53 | |
54 | // Compute the output shape. Determine product of specified |
55 | // dimensions, and find the index of the unspecified one. |
56 | TensorShape shape; |
57 | int64_t product = 1; |
58 | int unknown_index = -1; |
59 | bool sizes_has_zero_dim; |
60 | switch (sizes.dtype()) { |
61 | case DT_INT32: |
62 | OP_REQUIRES_OK(context, |
63 | ValidateSizes<int32>(sizes, &product, &unknown_index, |
64 | &shape, &sizes_has_zero_dim)); |
65 | break; |
66 | case DT_INT64: |
67 | OP_REQUIRES_OK(context, |
68 | ValidateSizes<int64_t>(sizes, &product, &unknown_index, |
69 | &shape, &sizes_has_zero_dim)); |
70 | break; |
71 | default: |
72 | context->CtxFailure(errors::InvalidArgument( |
73 | "desired shape must be a DT_INT32 or DT_INT64 vector, not a " , |
74 | DataTypeString(sizes.dtype()))); |
75 | return; |
76 | } |
77 | if (unknown_index != -1) { |
78 | int64_t input_num_elements = 1; |
79 | bool input_has_zero_dim = false; |
80 | for (int dim = 0; dim < input.dims(); dim++) { |
81 | // For zero dimension, we don't count it into `input_num_elements` |
82 | // unless `sizes` has no zero dimension, so we are still able to |
83 | // infer shapes for other dimensions. |
84 | if (input.dim_size(dim) > 0 || !sizes_has_zero_dim) { |
85 | input_num_elements *= input.dim_size(dim); |
86 | } else { |
87 | input_has_zero_dim = true; |
88 | } |
89 | } |
90 | |
91 | const int64_t missing = input_num_elements / product; |
92 | if (!input_has_zero_dim) { |
93 | OP_REQUIRES( |
94 | context, product * missing == input_num_elements, |
95 | errors::InvalidArgument( |
96 | "Input to reshape is a tensor with " , input_num_elements, |
97 | " values, but the requested shape requires a multiple of " , |
98 | product)); |
99 | } |
100 | shape.set_dim(unknown_index, missing); |
101 | } |
102 | OP_REQUIRES(context, shape.num_elements() == input.NumElements(), |
103 | errors::InvalidArgument("Input to reshape is a tensor with " , |
104 | input.NumElements(), |
105 | " values, but the requested shape has " , |
106 | shape.num_elements())); |
107 | |
108 | // Actually produce the reshaped output. |
109 | Tensor output(input.dtype()); |
110 | CHECK(output.CopyFrom(input, shape)); |
111 | context->set_output(0, output); |
112 | } |
113 | |
114 | bool IsExpensive() override { return false; } |
115 | |
116 | private: |
117 | template <typename Tshape> |
118 | Status ValidateSizes(const Tensor& sizes, int64_t* product, |
119 | int* unknown_index, TensorShape* shape, |
120 | bool* has_zero_dim) { |
121 | *product = 1; |
122 | *unknown_index = -1; |
123 | *has_zero_dim = false; |
124 | const int64_t num_dims = sizes.NumElements(); |
125 | auto Svec = sizes.flat<Tshape>(); |
126 | for (int d = 0; d < num_dims; ++d) { |
127 | const Tshape size = Svec(d); |
128 | if (size == -1) { |
129 | if (*unknown_index != -1) { |
130 | return errors::InvalidArgument( |
131 | "Only one input size may be -1, not both " , *unknown_index, |
132 | " and " , d); |
133 | } |
134 | *unknown_index = d; |
135 | shape->AddDim(1); |
136 | } else if (size < 0) { |
137 | return errors::InvalidArgument("Size " , d, |
138 | " must be non-negative, not " , size); |
139 | } else if (size == 0) { |
140 | // We don't include zero-sized dimension in product, so that we can |
141 | // still calculate number of elements for non-zero-sized dimensions and |
142 | // therefore infer their shapes. |
143 | shape->AddDim(size); |
144 | *has_zero_dim = true; |
145 | } else { |
146 | if (MultiplyWithoutOverflow(shape->num_elements(), size) < 0) { |
147 | string msg; |
148 | for (int ii = 0; ii < num_dims; ++ii) { |
149 | if (ii != 0) { |
150 | strings::StrAppend(&msg, ", " ); |
151 | } |
152 | strings::StrAppend(&msg, Svec(ii)); |
153 | } |
154 | return errors::InvalidArgument("Shape [" , msg, |
155 | "] has too many elements" ); |
156 | } |
157 | shape->AddDim(size); |
158 | (*product) *= size; |
159 | } |
160 | } |
161 | return OkStatus(); |
162 | } |
163 | }; |
164 | |
165 | } // namespace tensorflow |
166 | |
167 | #endif // TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_ |
168 | |