/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
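//
// BatchToSpaceND is the inverse of SpaceToBatchND: it moves factors of the
// batch dimension back into the block (spatial) dimensions and then crops
// them. Illustrative example (shapes only): an input of shape [4, 1, 1, 1]
// with block_shape = [2, 2] and zero crops produces an output of shape
// [1, 2, 2, 1].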

#define EIGEN_USE_THREADS

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
static void BatchToSpaceOpCompute(OpKernelContext* context,
                                  const Tensor& orig_input_tensor,
                                  const Tensor& orig_block_shape,
                                  const Tensor& orig_crops) {
  const int input_dims = orig_input_tensor.dims();
  OP_REQUIRES(
      context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
      errors::InvalidArgument("block_shape rank should be 1 instead of ",
                              orig_block_shape.dims()));

  const int block_dims = orig_block_shape.dim_size(0);
  OP_REQUIRES(
      context, orig_input_tensor.dims() >= 1 + block_dims,
      errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
                              " instead of ", orig_input_tensor.dims()));

  OP_REQUIRES(context,
              TensorShapeUtils::IsMatrix(orig_crops.shape()) &&
                  block_dims == orig_crops.dim_size(0) &&
                  2 == orig_crops.dim_size(1),
              errors::InvalidArgument("crops should have shape [", block_dims,
                                      ", 2] instead of ",
                                      orig_crops.shape().DebugString()));
  // To avoid out-of-bounds access in the case that the block_shape and/or
  // crops tensors are concurrently modified, we must copy the values.
  gtl::InlinedVector<int64_t, 4> block_shape;
  gtl::InlinedVector<int64_t, 8> crops;
  internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape);
  internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops);

  // Determine the length of the prefix of block dims that can be combined
  // into the batch dimension due to having no padding and block_shape=1.
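  // For example (illustrative), block_shape = [1, 2, 2] with zero crops in
  // dimension 0 gives removed_prefix_block_dims = 1.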
  int removed_prefix_block_dims = 0;
  for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
    const int dim = removed_prefix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Determine the length of the suffix of block dims that can be combined
  // into the depth dimension due to having no padding and block_shape=1.
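  // For example (illustrative), block_shape = [2, 2, 1] with zero crops in
  // dimension 2 gives removed_suffix_block_dims = 1.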
  int removed_suffix_block_dims = 0;
  for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims;
       ++removed_suffix_block_dims) {
    const int dim = block_dims - 1 - removed_suffix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Compute the product of the block_shape values.
  int64_t block_shape_product = 1;
  for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
    block_shape_product *= block_shape[block_dim];
  }
  OP_REQUIRES(
      context, block_shape_product > 0,
      errors::InvalidArgument("Product of block sizes must be positive, got ",
                              block_shape_product));

  const int64_t orig_input_batch_size = orig_input_tensor.dim_size(0);
  OP_REQUIRES(
      context, orig_input_batch_size % block_shape_product == 0,
      errors::InvalidArgument("Input batch dimension (", orig_input_batch_size,
                              ") is not divisible by product of block sizes (",
                              block_shape_product, ")"));

  const int internal_block_dims =
      block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
  OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
              errors::InvalidArgument(
                  "Number of non-combined block dimensions is ",
                  internal_block_dims, " but must not exceed ",
                  kMaxSpaceToBatchBlockDims));

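  // If no block dimensions remain after combining, the op is the identity:
  // e.g. (illustrative) block_shape = [1, 1] with zero crops simply forwards
  // the input.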
  if (internal_block_dims == 0) {
    context->set_output(0, orig_input_tensor);
    return;
  }

  // For the purpose of computing the result, the input will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_input_shape;

  // For the purpose of computing the result, the output will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_output_shape;

  // The actual output shape exposed to callers.
  TensorShape external_output_shape;
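  //
  // Worked example (illustrative): an input of shape [4, 2, 2, 1] with
  // block_shape = [2, 2] and zero crops gives
  //   internal_input_shape  = [4, 2, 2, 1]
  //   internal_output_shape = [1, 4, 4, 1]
  //   external_output_shape = [1, 4, 4, 1].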

  external_output_shape.AddDim(orig_input_batch_size / block_shape_product);

  int64_t input_batch_size = orig_input_batch_size;
  for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
    const int64_t size = orig_input_tensor.dim_size(block_dim + 1);
    input_batch_size *= size;
    external_output_shape.AddDim(size);
  }
  internal_input_shape.AddDim(input_batch_size);
  internal_output_shape.AddDim(input_batch_size / block_shape_product);

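  // Each remaining block dimension expands by its block size and is then
  // cropped: cropped_size = input_size * block_shape - crop_start - crop_end.
  // For example (illustrative), input_size = 3, block_shape = 2, and
  // crops = (0, 1) gives cropped_size = 3 * 2 - 0 - 1 = 5.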
  for (int block_dim = removed_prefix_block_dims;
       block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
    const int64_t crop_start = crops[2 * block_dim],
                  crop_end = crops[2 * block_dim + 1];
    OP_REQUIRES(context, crop_start >= 0 && crop_end >= 0,
                errors::InvalidArgument("Crops must be non-negative"));
    const int64_t input_size = orig_input_tensor.dim_size(block_dim + 1);
    const int64_t block_shape_value = block_shape[block_dim];
    const int64_t cropped_size =
        input_size * block_shape_value - crop_start - crop_end;
    OP_REQUIRES(context, cropped_size >= 0,
                errors::InvalidArgument("cropped_shape[", block_dim, "]=",
                                        cropped_size, " must be non-negative"));
    internal_input_shape.AddDim(input_size);
    internal_output_shape.AddDim(cropped_size);
    external_output_shape.AddDim(cropped_size);
  }

  int64_t depth = 1;
  for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims;
       ++dim) {
    const int64_t size = orig_input_tensor.dim_size(dim);
    external_output_shape.AddDim(size);
    depth *= size;
  }
  internal_input_shape.AddDim(depth);
  internal_output_shape.AddDim(depth);

  // Allocate output tensor.
  Tensor* output_tensor = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
                                                   &output_tensor));

  const int64_t* internal_crops = &crops[2 * removed_prefix_block_dims];
  const int64_t* internal_block_shape = &block_shape[removed_prefix_block_dims];

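  // Dispatch on the number of non-combined block dimensions. The final `true`
  // template argument selects the batch-to-space (inverse) direction of
  // SpaceToBatchFunctor.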
  switch (internal_block_dims) {
#define TF_BATCHTOSPACE_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                   \
  case NUM_BLOCK_DIMS: {                                                  \
    OP_REQUIRES_OK(                                                       \
        context,                                                          \
        (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, true>()( \
            context->eigen_device<Device>(),                              \
            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                 \
                internal_output_shape.dim_sizes()),                       \
            internal_block_shape, internal_crops,                         \
            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(              \
                internal_input_shape.dim_sizes()))));                     \
  } break;                                                                \
    /**/
    TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_BATCHTOSPACE_BLOCK_DIMS_CASE)
#undef TF_BATCHTOSPACE_BLOCK_DIMS_CASE
  }
}

template <typename Device, typename T>
class BatchToSpaceNDOp : public OpKernel {
 public:
  explicit BatchToSpaceNDOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& orig_input_tensor = context->input(0);
    const Tensor& orig_block_shape = context->input(1);
    const Tensor& orig_crops = context->input(2);
    BatchToSpaceOpCompute<Device, T>(context, orig_input_tensor,
                                     orig_block_shape, orig_crops);
  }
};

template <typename Device, typename T>
class BatchToSpaceOp : public OpKernel {
 public:
  explicit BatchToSpaceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        context, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
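    // The legacy BatchToSpace op takes a scalar block_size attribute;
    // materialize the equivalent rank-1 block_shape = [block_size, block_size]
    // expected by the shared implementation above.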
    block_shape_ = Tensor(tensorflow::DT_INT64, TensorShape({2}));
    auto block_shape_vec = block_shape_.vec<int64_t>();
    block_shape_vec(0) = block_size_;
    block_shape_vec(1) = block_size_;
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();

    // Check on the input dimensions first.
    // The input is presumed to be [batch, height, width, depth]
    static const int kRequiredDims = 4;
    OP_REQUIRES(context, kRequiredDims == dims,
                errors::InvalidArgument("Input rank should be: ", kRequiredDims,
                                        " instead of: ", dims));
    BatchToSpaceOpCompute<Device, T>(context, in0, block_shape_, in1);
  }

 private:
  int block_size_;
  Tensor block_shape_;
};

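// Note: block_shape and crops are marked HostMemory because the kernel reads
// their values (not just their shapes) to compute the output shape; for the
// GPU registrations below this keeps those small tensors on the host.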
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<CPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<GPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);
#undef REGISTER
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow
286