/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include <memory>
#include <string>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/spacetobatch_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
namespace {

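// Shared implementation for the SpaceToBatch and SpaceToBatchND kernels:
// validates block_shape and paddings, folds trivial block dimensions into the
// batch and depth dimensions, and dispatches to SpaceToBatchFunctor.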
template <typename Device, typename T>
Status SpaceToBatchOpCompute(OpKernelContext* context,
                             const Tensor& orig_input_tensor,
                             const Tensor& orig_block_shape,
                             const Tensor& orig_paddings) {
  const int input_dims = orig_input_tensor.dims();
  if (!TensorShapeUtils::IsVector(orig_block_shape.shape())) {
    return errors::InvalidArgument("block_shape rank should be 1 instead of ",
                                   orig_block_shape.dims());
  }

  const int block_dims = orig_block_shape.dim_size(0);
  if (orig_input_tensor.dims() < 1 + block_dims) {
    return errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
                                   " instead of ", orig_input_tensor.dims());
  }

  if (!(TensorShapeUtils::IsMatrix(orig_paddings.shape()) &&
        block_dims == orig_paddings.dim_size(0) &&
        2 == orig_paddings.dim_size(1))) {
    return errors::InvalidArgument("paddings should have shape [", block_dims,
                                   ", 2] instead of ",
                                   orig_paddings.shape().DebugString());
  }

  // To avoid out-of-bounds access in the case that the block_shape and/or
  // paddings tensors are concurrently modified, we must copy the values.
  gtl::InlinedVector<int64_t, 4> block_shape;
  gtl::InlinedVector<int64_t, 8> paddings;
  internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape);
  internal::spacetobatch::SubtleMustCopyFlat(orig_paddings, &paddings);

  // Determine the length of the prefix of block dims that can be combined
  // into the batch dimension due to having no padding and block_shape=1.
  int removed_prefix_block_dims = 0;
  for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
    const int dim = removed_prefix_block_dims;
    if (paddings[2 * dim] != 0 || paddings[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }
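  // For example, with block_shape = [1, 2, 2] and zero paddings for the first
  // block dimension, removed_prefix_block_dims == 1: that spatial dimension
  // can be folded directly into the batch dimension.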

  // Determine the length of the suffix of block dims that can be combined
  // into the depth dimension due to having no padding and block_shape=1.
  int removed_suffix_block_dims = 0;
  for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims;
       ++removed_suffix_block_dims) {
    const int dim = block_dims - 1 - removed_suffix_block_dims;
    if (paddings[dim * 2] != 0 || paddings[dim * 2 + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }
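  // Similarly, with block_shape = [2, 2, 1] and zero paddings for the last
  // block dimension, removed_suffix_block_dims == 1: that spatial dimension
  // can be folded directly into the depth dimension.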

  // Compute the product of the block_shape values.
  int64_t block_shape_product = 1;
  for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
    if (block_shape[block_dim] < 1) {
      return errors::InvalidArgument(
          "All values in block_shape must be positive, got value ",
          block_shape[block_dim], " at index ", block_dim, ".");
    }
    block_shape_product =
        MultiplyWithoutOverflow(block_shape_product, block_shape[block_dim]);
  }
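  // MultiplyWithoutOverflow returns a negative value on overflow; since every
  // block size was verified to be >= 1 above, a non-positive product here
  // indicates overflow.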
  if (block_shape_product <= 0) {
    return errors::InvalidArgument(
        "Product of block sizes must be positive, got ", block_shape_product);
  }

  const int internal_block_dims =
      block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
  if (internal_block_dims > kMaxSpaceToBatchBlockDims) {
    return errors::InvalidArgument(
        "Number of non-combined block dimensions is ", internal_block_dims,
        " but must not exceed ", kMaxSpaceToBatchBlockDims);
  }

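  // All block dimensions were combined away (block_shape is all ones and
  // there is no padding), so the op is an identity; forward the input
  // unchanged.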
  if (internal_block_dims == 0) {
    context->set_output(0, orig_input_tensor);
    return OkStatus();
  }

  // For the purpose of computing the result, the input will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_input_shape;

  // For the purpose of computing the result, the output will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_output_shape;

  // The actual output shape exposed to callers.
  TensorShape external_output_shape;
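  // For example, an NHWC input of shape [1, 4, 4, 3] with block_shape = [2, 2]
  // and zero paddings gives internal_input_shape = [1, 4, 4, 3],
  // internal_output_shape = [4, 2, 2, 3], and external_output_shape =
  // [4, 2, 2, 3].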

  const int64_t output_shape = MultiplyWithoutOverflow(
      orig_input_tensor.dim_size(0), block_shape_product);
  if (output_shape < 0) {
    return errors::InvalidArgument(
        "Negative output dimension size caused by overflow when multiplying ",
        orig_input_tensor.dim_size(0), " and ", block_shape_product);
  }
  external_output_shape.AddDim(output_shape);

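  // Fold the removed prefix block dimensions into the effective batch size;
  // they pass through to the output unchanged.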
  int64_t input_batch_size = orig_input_tensor.dim_size(0);
  for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
    const int64_t size = orig_input_tensor.dim_size(block_dim + 1);
    input_batch_size *= size;
    external_output_shape.AddDim(size);
  }
  internal_input_shape.AddDim(input_batch_size);
  internal_output_shape.AddDim(input_batch_size * block_shape_product);

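  // For each non-combined block dimension, validate the paddings and check
  // that the padded size is divisible by the block size. For example,
  // input_size = 4 with pad_start = 1 and pad_end = 1 gives padded_size = 6,
  // which block_shape_value = 2 divides into output_size = 3.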
  for (int block_dim = removed_prefix_block_dims;
       block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
    const int64_t pad_start = paddings[2 * block_dim],
                  pad_end = paddings[2 * block_dim + 1];
    if (pad_start < 0 || pad_end < 0) {
      return errors::InvalidArgument("Paddings must be non-negative");
    }
    const int64_t input_size = orig_input_tensor.dim_size(block_dim + 1);
    const int64_t block_shape_value = block_shape[block_dim];
    const int64_t padded_size = input_size + pad_start + pad_end;
    if (padded_size % block_shape_value != 0) {
      return errors::InvalidArgument("padded_shape[", block_dim,
                                     "]=", padded_size,
                                     " is not divisible by block_shape[",
                                     block_dim, "]=", block_shape_value);
    }
    internal_input_shape.AddDim(input_size);
    const int64_t output_size = padded_size / block_shape_value;
    internal_output_shape.AddDim(output_size);
    external_output_shape.AddDim(output_size);
  }

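  // Flatten the remaining (non-block) dimensions, including any removed
  // suffix block dimensions, into a single depth dimension. For a 4-D NHWC
  // input with two block dimensions this is just the channel count.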
  int64_t depth = 1;
  for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims;
       ++dim) {
    const int64_t size = orig_input_tensor.dim_size(dim);
    external_output_shape.AddDim(size);
    depth *= size;
  }
  internal_input_shape.AddDim(depth);
  internal_output_shape.AddDim(depth);

  // Allocate output tensor.
  Tensor* output_tensor = nullptr;
  TF_RETURN_IF_ERROR(
      context->allocate_output(0, external_output_shape, &output_tensor));

  const int64_t* internal_paddings = &paddings[2 * removed_prefix_block_dims];
  const int64_t* internal_block_shape = &block_shape[removed_prefix_block_dims];

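  // SpaceToBatchFunctor is specialized on the number of block dimensions at
  // compile time, so dispatch on internal_block_dims with one case per
  // supported NUM_BLOCK_DIMS value.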
  switch (internal_block_dims) {
#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                   \
  case NUM_BLOCK_DIMS: {                                                  \
    TF_RETURN_IF_ERROR(                                                   \
        functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()( \
            context->eigen_device<Device>(),                              \
            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(              \
                internal_input_shape.dim_sizes()),                        \
            internal_block_shape, internal_paddings,                      \
            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                 \
                internal_output_shape.dim_sizes())));                     \
  } break;                                                                \
    /**/
    TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_SPACETOBATCH_BLOCK_DIMS_CASE)
#undef TF_SPACETOBATCH_BLOCK_DIMS_CASE
  }
  return OkStatus();
}

}  // namespace

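// Kernel for the SpaceToBatchND op. For example, an input of shape
// [1, 2, 2, 1] with block_shape = [2, 2] and zero paddings produces an output
// of shape [4, 1, 1, 1].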
template <typename Device, typename T>
class SpaceToBatchNDOp : public OpKernel {
 public:
  explicit SpaceToBatchNDOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& orig_input_tensor = context->input(0);
    const Tensor& orig_block_shape = context->input(1);
    const Tensor& orig_paddings = context->input(2);
    OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>(
                                context, orig_input_tensor, orig_block_shape,
                                orig_paddings));
  }
};

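// Kernel for the older SpaceToBatch op, which only handles 4-D NHWC inputs
// with a single scalar block_size applied to both spatial dimensions. The
// constructor materializes the equivalent 2-element block_shape tensor, and
// Compute delegates to the shared implementation above.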
template <typename Device, typename T>
class SpaceToBatchOp : public OpKernel {
 public:
  explicit SpaceToBatchOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        context, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
    block_shape_ = Tensor(tensorflow::DT_INT64, TensorShape({2}));
    auto block_shape_vec = block_shape_.vec<int64_t>();
    block_shape_vec(0) = block_size_;
    block_shape_vec(1) = block_size_;
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();

    static const int kRequiredDims = 4;
    OP_REQUIRES(context, kRequiredDims == dims,
                errors::InvalidArgument("Input rank should be: ", kRequiredDims,
                                        " instead of: ", dims));
    OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>(
                                context, in0, block_shape_, in1));
  }

 private:
  int block_size_;
  Tensor block_shape_;
};

#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("SpaceToBatchND")           \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("paddings"),     \
                          SpaceToBatchNDOp<CPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("SpaceToBatch")             \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("paddings"),     \
                          SpaceToBatchOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("SpaceToBatchND")           \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("paddings"),     \
                          SpaceToBatchNDOp<GPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("SpaceToBatch")             \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("paddings"),     \
                          SpaceToBatchOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);
#undef REGISTER
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow