1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
17#define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
18
19#include <memory>
20
21#include "tensorflow/core/framework/op_kernel.h"
22#include "tensorflow/core/framework/register_types.h"
23#include "tensorflow/core/framework/tensor.h"
24#include "tensorflow/core/framework/tensor_shape.h"
25#include "tensorflow/core/framework/types.h"
26#include "tensorflow/core/lib/core/status.h"
27#include "tensorflow/core/platform/logging.h"
28#include "tensorflow/core/util/overflow.h"
29
30namespace tensorflow {
31
32// Note that this op is subclassed for QuantizedReshapeOp.
33class ReshapeOp : public OpKernel {
34 public:
35 explicit ReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
36
37 void Compute(OpKernelContext* context) override {
38 const Tensor& input = context->input(0);
39 const Tensor& sizes = context->input(1);
40 // Preliminary validation of sizes.
41 OP_REQUIRES(
42 context,
43 (TensorShapeUtils::IsVector(sizes.shape()) ||
44 // TODO(rmlarsen): Disallow legacy use of scalars to represent shape.
45 TensorShapeUtils::IsScalar(sizes.shape())),
46 errors::InvalidArgument("sizes input must be 1-D, not ",
47 sizes.shape().DebugString()));
48 OP_REQUIRES(
49 context, sizes.NumElements() < TensorShape::MaxDimensions(),
50 errors::InvalidArgument("too many dimensions: must be < ",
51 TensorShape::MaxDimensions(), ", but received ",
52 sizes.NumElements()));
53
54 // Compute the output shape. Determine product of specified
55 // dimensions, and find the index of the unspecified one.
56 TensorShape shape;
57 int64_t product = 1;
58 int unknown_index = -1;
59 bool sizes_has_zero_dim;
60 switch (sizes.dtype()) {
61 case DT_INT32:
62 OP_REQUIRES_OK(context,
63 ValidateSizes<int32>(sizes, &product, &unknown_index,
64 &shape, &sizes_has_zero_dim));
65 break;
66 case DT_INT64:
67 OP_REQUIRES_OK(context,
68 ValidateSizes<int64_t>(sizes, &product, &unknown_index,
69 &shape, &sizes_has_zero_dim));
70 break;
71 default:
72 context->CtxFailure(errors::InvalidArgument(
73 "desired shape must be a DT_INT32 or DT_INT64 vector, not a ",
74 DataTypeString(sizes.dtype())));
75 return;
76 }
77 if (unknown_index != -1) {
78 int64_t input_num_elements = 1;
79 bool input_has_zero_dim = false;
80 for (int dim = 0; dim < input.dims(); dim++) {
81 // For zero dimension, we don't count it into `input_num_elements`
82 // unless `sizes` has no zero dimension, so we are still able to
83 // infer shapes for other dimensions.
84 if (input.dim_size(dim) > 0 || !sizes_has_zero_dim) {
85 input_num_elements *= input.dim_size(dim);
86 } else {
87 input_has_zero_dim = true;
88 }
89 }
90
91 const int64_t missing = input_num_elements / product;
92 if (!input_has_zero_dim) {
93 OP_REQUIRES(
94 context, product * missing == input_num_elements,
95 errors::InvalidArgument(
96 "Input to reshape is a tensor with ", input_num_elements,
97 " values, but the requested shape requires a multiple of ",
98 product));
99 }
100 shape.set_dim(unknown_index, missing);
101 }
102 OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
103 errors::InvalidArgument("Input to reshape is a tensor with ",
104 input.NumElements(),
105 " values, but the requested shape has ",
106 shape.num_elements()));
107
108 // Actually produce the reshaped output.
109 Tensor output(input.dtype());
110 CHECK(output.CopyFrom(input, shape));
111 context->set_output(0, output);
112 }
113
114 bool IsExpensive() override { return false; }
115
116 private:
117 template <typename Tshape>
118 Status ValidateSizes(const Tensor& sizes, int64_t* product,
119 int* unknown_index, TensorShape* shape,
120 bool* has_zero_dim) {
121 *product = 1;
122 *unknown_index = -1;
123 *has_zero_dim = false;
124 const int64_t num_dims = sizes.NumElements();
125 auto Svec = sizes.flat<Tshape>();
126 for (int d = 0; d < num_dims; ++d) {
127 const Tshape size = Svec(d);
128 if (size == -1) {
129 if (*unknown_index != -1) {
130 return errors::InvalidArgument(
131 "Only one input size may be -1, not both ", *unknown_index,
132 " and ", d);
133 }
134 *unknown_index = d;
135 shape->AddDim(1);
136 } else if (size < 0) {
137 return errors::InvalidArgument("Size ", d,
138 " must be non-negative, not ", size);
139 } else if (size == 0) {
140 // We don't include zero-sized dimension in product, so that we can
141 // still calculate number of elements for non-zero-sized dimensions and
142 // therefore infer their shapes.
143 shape->AddDim(size);
144 *has_zero_dim = true;
145 } else {
146 if (MultiplyWithoutOverflow(shape->num_elements(), size) < 0) {
147 string msg;
148 for (int ii = 0; ii < num_dims; ++ii) {
149 if (ii != 0) {
150 strings::StrAppend(&msg, ", ");
151 }
152 strings::StrAppend(&msg, Svec(ii));
153 }
154 return errors::InvalidArgument("Shape [", msg,
155 "] has too many elements");
156 }
157 shape->AddDim(size);
158 (*product) *= size;
159 }
160 }
161 return OkStatus();
162 }
163};
164
165} // namespace tensorflow
166
167#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
168