1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_
17#define TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_
18
19#include <type_traits>
20
21#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22#include "tensorflow/core/framework/bounds_check.h"
23#include "tensorflow/core/framework/tensor.h"
24#include "tensorflow/core/framework/tensor_types.h"
25#include "tensorflow/core/platform/types.h"
26
27namespace tensorflow {
28
// Upper bound on the number of non-collapsible blocked dimensions handled by
// the {SpaceToBatch,BatchToSpace}ND operation. Raising the limit means changing
// this value AND extending TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS below. That
// macro must list one MACRO invocation per value from 1 up to this limit.
constexpr int kMaxSpaceToBatchBlockDims{4};
34
// Invokes MACRO once for every supported number of blocked dimensions,
// forwarding any extra arguments unchanged. Expands to:
//   MACRO(1, ## __VA_ARGS__)
//   ...
//   MACRO(kMaxSpaceToBatchBlockDims, ## __VA_ARGS__)
// The numbers are written out literally rather than derived from
// kMaxSpaceToBatchBlockDims, so the two must be kept in sync by hand.
//
// Note: The space between the number and the comma is necessary for proper GCC
// comma handling: https://gcc.gnu.org/onlinedocs/cpp/Variadic-Macros.html
// The empty comment `/**/` after each number is what keeps that whitespace in
// place: a comment is replaced by a space during preprocessing. Presumably it
// is used instead of a bare space so that automatic formatting cannot strip it.
// `##__VA_ARGS__` (a GNU extension) drops the preceding comma when no extra
// arguments are passed. The final `/**/` line gives the last backslash
// continuation a line to join.
#define TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(MACRO, ...) \
  MACRO(1 /**/, ##__VA_ARGS__)                              \
  MACRO(2 /**/, ##__VA_ARGS__)                              \
  MACRO(3 /**/, ##__VA_ARGS__)                              \
  MACRO(4 /**/, ##__VA_ARGS__)                              \
  /**/
48
49namespace internal {
50namespace spacetobatch {
51
52template <typename InputType, typename OutputType>
53void SubtleMustCopyFlatHelper(const Tensor& t, OutputType* output) {
54 const int64_t num_elements = t.shape().num_elements();
55 output->resize(num_elements);
56 auto eigen_vec = t.flat<InputType>();
57 for (int64_t i = 0; i < num_elements; ++i) {
58 (*output)[i] = SubtleMustCopy(eigen_vec(i));
59 }
60}
61
62// Copies flat contents of `t` to std::vector-like `*output`, which is resized
63// as needed. `OutputType` may be either `std::vector<int64_t>` or
64// `gtl::InlinedVector<int64_t>`.
65//
66// Precondition: t.dtype() must be either DT_INT32 or DT_INT64.
67template <typename OutputType>
68void SubtleMustCopyFlat(const Tensor& t, OutputType* output) {
69 if (t.dtype() == DT_INT32) {
70 SubtleMustCopyFlatHelper<int32, OutputType>(t, output);
71 } else {
72 SubtleMustCopyFlatHelper<int64_t, OutputType>(t, output);
73 }
74}
75
76} // namespace spacetobatch
77} // namespace internal
78
79namespace functor {
80
81// Functor used by {SpaceToBatch,BatchToSpace}{ND,}Op to do the conversion.
82//
83// If B2S is false, then this performs the space-to-batch conversion. If B2S is
84// true, then this performs the inverse batch-to-space conversion.
85template <typename Device, typename T, int NUM_BLOCK_DIMS, bool B2S = false>
86struct SpaceToBatchFunctor {
87 using InputT = typename std::conditional<B2S, T, const T>::type;
88 using OutputT = typename std::conditional<B2S, const T, T>::type;
89 // Implements the space to batch conversion.
90 //
91 // space_tensor: input tensor of space-to-batch operation. If B2S = false,
92 // then this is the input to the conversion. If B2S = true, then this
93 // is the output of the conversion.
94 // block_size: array of shape [NUM_BLOCK_DIMS] specifying the block sizes for
95 // dimensions 1 through NUM_BLOCK_DIMS.
96 // paddings: row-major array of shape [NUM_BLOCK_DIMS, 2] specifying the
97 // start and end padding for dimensions 1 through NUM_BLOCK_DIMS.
98 // batch_tensor: output tensor of the space-to-batch operation. If
99 // B2S = false, then this is the output of the conversion. If B2S = true,
100 // then this is the input to the conversion.
101 //
102 // The caller must ensure that the dimensions of the tensors are correct.
103 Status operator()(
104 const Device& d,
105 typename TTypes<InputT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor,
106 const int64_t block_shape[NUM_BLOCK_DIMS],
107 const int64_t paddings[NUM_BLOCK_DIMS * 2],
108 typename TTypes<OutputT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor);
109};
110
111} // namespace functor
112} // namespace tensorflow
113
114#endif // TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_
115