/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
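//
// BatchToSpaceND is the inverse of SpaceToBatchND: entries are moved from
// the batch dimension back into spatial blocks, and the result is then
// cropped.  For example, an input of shape [4, 1, 1, 1] with
// block_shape = [2, 2] and zero crops yields an output of shape
// [1, 2, 2, 1]: the four batch entries are interleaved into a 2x2 spatial
// grid.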

#define EIGEN_USE_THREADS

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
static void BatchToSpaceOpCompute(OpKernelContext* context,
                                  const Tensor& orig_input_tensor,
                                  const Tensor& orig_block_shape,
                                  const Tensor& orig_crops) {
  const int input_dims = orig_input_tensor.dims();
  OP_REQUIRES(
      context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
      errors::InvalidArgument("block_shape rank should be 1 instead of ",
                              orig_block_shape.dims()));

  const int block_dims = orig_block_shape.dim_size(0);
  OP_REQUIRES(
      context, orig_input_tensor.dims() >= 1 + block_dims,
      errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
                              " instead of ", orig_input_tensor.dims()));

  OP_REQUIRES(context,
              TensorShapeUtils::IsMatrix(orig_crops.shape()) &&
                  block_dims == orig_crops.dim_size(0) &&
                  2 == orig_crops.dim_size(1),
              errors::InvalidArgument("crops should have shape [", block_dims,
                                      ", 2] instead of ",
                                      orig_crops.shape().DebugString()));
  // To avoid out-of-bounds access in the case that the block_shape and/or
  // crops tensors are concurrently modified, we must copy the values.
  gtl::InlinedVector<int64_t, 4> block_shape;
  gtl::InlinedVector<int64_t, 8> crops;
  internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape);
  internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops);

  // Determine the length of the prefix of block dims that can be combined
  // into the batch dimension due to having no cropping and block_shape=1.
  int removed_prefix_block_dims = 0;
  for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
    const int dim = removed_prefix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }

  // Determine the length of the suffix of block dims that can be combined
  // into the depth dimension due to having no cropping and block_shape=1.
  int removed_suffix_block_dims = 0;
  for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims;
       ++removed_suffix_block_dims) {
    const int dim = block_dims - 1 - removed_suffix_block_dims;
    if (crops[2 * dim] != 0 || crops[2 * dim + 1] != 0 ||
        block_shape[dim] != 1) {
      break;
    }
  }
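
  // For example, block_shape = [1, 2, 1] with zero crops in the first and
  // last block dimensions gives removed_prefix_block_dims = 1 and
  // removed_suffix_block_dims = 1, leaving a single internal block dimension.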

  // Compute the product of the block_shape values.
  int64_t block_shape_product = 1;
  for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
    block_shape_product *= block_shape[block_dim];
  }
  OP_REQUIRES(
      context, block_shape_product > 0,
      errors::InvalidArgument("Product of block sizes must be positive, got ",
                              block_shape_product));

  const int64_t orig_input_batch_size = orig_input_tensor.dim_size(0);
  OP_REQUIRES(
      context, orig_input_batch_size % block_shape_product == 0,
      errors::InvalidArgument("Input batch dimension (", orig_input_batch_size,
                              ") is not divisible by product of block sizes (",
                              block_shape_product, ")"));

  const int internal_block_dims =
      block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
  OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
              errors::InvalidArgument(
                  "Number of non-combined block dimensions is ",
                  internal_block_dims, " but must not exceed ",
                  kMaxSpaceToBatchBlockDims));

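  // If every block dimension was combined away, the transform is the
  // identity and the input can be passed through unchanged.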
  if (internal_block_dims == 0) {
    context->set_output(0, orig_input_tensor);
    return;
  }

  // For the purpose of computing the result, the input will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_input_shape;

  // For the purpose of computing the result, the output will be treated as
  // having this shape, of rank 2 + internal_block_dims.
  TensorShape internal_output_shape;

  // The actual output shape exposed to callers.
  TensorShape external_output_shape;
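
  // As a concrete example (all crops zero): an input of shape [4, 5, 2, 2, 3]
  // with block_shape = [1, 2, 2] has removed_prefix_block_dims = 1, so the
  // leading spatial dimension of size 5 folds into the batch, giving
  //   internal_input_shape  = [20, 2, 2, 3]
  //   internal_output_shape = [5, 4, 4, 3]
  //   external_output_shape = [1, 5, 4, 4, 3]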

  external_output_shape.AddDim(orig_input_batch_size / block_shape_product);

  int64_t input_batch_size = orig_input_batch_size;
  for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
    const int64_t size = orig_input_tensor.dim_size(block_dim + 1);
    input_batch_size *= size;
    external_output_shape.AddDim(size);
  }
  internal_input_shape.AddDim(input_batch_size);
  internal_output_shape.AddDim(input_batch_size / block_shape_product);

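  // Each internal block dimension expands by its block factor and is then
  // cropped on both sides:
  //   cropped_size = input_size * block_shape_value - crop_start - crop_end
  // e.g. input_size = 3, block_shape_value = 2, and crops (1, 1) give
  // 3 * 2 - 1 - 1 = 4.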
  for (int block_dim = removed_prefix_block_dims;
       block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
    const int64_t crop_start = crops[2 * block_dim],
                  crop_end = crops[2 * block_dim + 1];
    OP_REQUIRES(context, crop_start >= 0 && crop_end >= 0,
                errors::InvalidArgument("Crops must be non-negative"));
    const int64_t input_size = orig_input_tensor.dim_size(block_dim + 1);
    const int64_t block_shape_value = block_shape[block_dim];
    const int64_t cropped_size =
        input_size * block_shape_value - crop_start - crop_end;
    OP_REQUIRES(context, cropped_size >= 0,
                errors::InvalidArgument("cropped_shape[", block_dim, "]=",
                                        cropped_size, " must be non-negative"));
    internal_input_shape.AddDim(input_size);
    internal_output_shape.AddDim(cropped_size);
    external_output_shape.AddDim(cropped_size);
  }

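  // All remaining (non-block) trailing dimensions collapse into a single
  // depth dimension for the internal computation, e.g. trailing dims of
  // sizes 5 and 3 become one depth dimension of size 15.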
  int64_t depth = 1;
  for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims;
       ++dim) {
    const int64_t size = orig_input_tensor.dim_size(dim);
    external_output_shape.AddDim(size);
    depth *= size;
  }
  internal_input_shape.AddDim(depth);
  internal_output_shape.AddDim(depth);

  // Allocate output tensor.
  Tensor* output_tensor = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
                                                   &output_tensor));

  const int64_t* internal_crops = &crops[2 * removed_prefix_block_dims];
  const int64_t* internal_block_shape =
      &block_shape[removed_prefix_block_dims];

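  // Dispatch to a SpaceToBatchFunctor instantiation specialized for the
  // number of internal block dimensions; the trailing `true` template
  // argument selects the batch-to-space (inverse) direction.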
  switch (internal_block_dims) {
#define TF_BATCHTOSPACE_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                   \
  case NUM_BLOCK_DIMS: {                                                  \
    OP_REQUIRES_OK(                                                       \
        context,                                                          \
        (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, true>()( \
            context->eigen_device<Device>(),                              \
            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                 \
                internal_output_shape.dim_sizes()),                       \
            internal_block_shape, internal_crops,                         \
            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(              \
                internal_input_shape.dim_sizes()))));                     \
  } break;                                                                \
    /**/
    TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_BATCHTOSPACE_BLOCK_DIMS_CASE)
#undef TF_BATCHTOSPACE_BLOCK_DIMS_CASE
  }
}

template <typename Device, typename T>
class BatchToSpaceNDOp : public OpKernel {
 public:
  explicit BatchToSpaceNDOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& orig_input_tensor = context->input(0);
    const Tensor& orig_block_shape = context->input(1);
    const Tensor& orig_crops = context->input(2);
    BatchToSpaceOpCompute<Device, T>(context, orig_input_tensor,
                                     orig_block_shape, orig_crops);
  }
};

template <typename Device, typename T>
class BatchToSpaceOp : public OpKernel {
 public:
  explicit BatchToSpaceOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        context, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
    block_shape_ = Tensor(tensorflow::DT_INT64, TensorShape({2}));
    auto block_shape_vec = block_shape_.vec<int64_t>();
    block_shape_vec(0) = block_size_;
    block_shape_vec(1) = block_size_;
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();

    // Check the input dimensions first.
    // The input is presumed to be [batch, height, width, depth].
    static const int kRequiredDims = 4;
    OP_REQUIRES(context, kRequiredDims == dims,
                errors::InvalidArgument("Input rank should be: ", kRequiredDims,
                                        " instead of: ", dims));
    BatchToSpaceOpCompute<Device, T>(context, in0, block_shape_, in1);
  }

 private:
  int block_size_;
  Tensor block_shape_;
};

#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<CPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_CPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<CPUDevice, T>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(T)                                        \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpaceND")           \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("block_shape")   \
                              .HostMemory("crops"),        \
                          BatchToSpaceNDOp<GPUDevice, T>); \
  REGISTER_KERNEL_BUILDER(Name("BatchToSpace")             \
                              .Device(DEVICE_GPU)          \
                              .TypeConstraint<T>("T")      \
                              .HostMemory("crops"),        \
                          BatchToSpaceOp<GPUDevice, T>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER);
#undef REGISTER
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow
286 | |