1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_ |
18 | |
19 | #include <type_traits> |
20 | |
21 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
22 | #include "tensorflow/core/framework/bounds_check.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_types.h" |
25 | #include "tensorflow/core/platform/types.h" |
26 | |
27 | namespace tensorflow { |
28 | |
// Maximum number of non-collapsible blocked dimensions supported by the
// {SpaceToBatch,BatchToSpace}ND operation. To change the limit, modify this
// constant and the TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS macro definition
// below; the static_assert that follows the macro checks that the two agree.
constexpr int kMaxSpaceToBatchBlockDims = 4;

// Expands to:
//   MACRO(1, ## __VA_ARGS__)
//   ...
//   MACRO(kMaxSpaceToBatchBlockDims, ## __VA_ARGS__)
//
// Note: The space between the number and the comma is necessary for proper GCC
// comma handling: https://gcc.gnu.org/onlinedocs/cpp/Variadic-Macros.html
// The empty `/**/` comment after each number keeps that separation explicit.
#define TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(MACRO, ...) \
  MACRO(1 /**/, ##__VA_ARGS__)                             \
  MACRO(2 /**/, ##__VA_ARGS__)                             \
  MACRO(3 /**/, ##__VA_ARGS__)                             \
  MACRO(4 /**/, ##__VA_ARGS__)                             \
  /**/

// Compile-time guard: the macro above must enumerate exactly
// kMaxSpaceToBatchBlockDims values. Each expansion contributes "+1", so
// updating the constant without the macro (or vice versa) fails the build
// instead of silently going out of sync. The extra `unused` argument avoids
// relying on the zero-variadic-argument comma-elision extension.
#define TF_SPACETOBATCH_COUNT_BLOCK_DIMS_(NUM_BLOCK_DIMS, unused) +1
static_assert((0 TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(
                  TF_SPACETOBATCH_COUNT_BLOCK_DIMS_, unused)) ==
                  kMaxSpaceToBatchBlockDims,
              "TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS must enumerate exactly "
              "kMaxSpaceToBatchBlockDims block dimensions");
#undef TF_SPACETOBATCH_COUNT_BLOCK_DIMS_
48 | |
49 | namespace internal { |
50 | namespace spacetobatch { |
51 | |
52 | template <typename InputType, typename OutputType> |
53 | void SubtleMustCopyFlatHelper(const Tensor& t, OutputType* output) { |
54 | const int64_t num_elements = t.shape().num_elements(); |
55 | output->resize(num_elements); |
56 | auto eigen_vec = t.flat<InputType>(); |
57 | for (int64_t i = 0; i < num_elements; ++i) { |
58 | (*output)[i] = SubtleMustCopy(eigen_vec(i)); |
59 | } |
60 | } |
61 | |
62 | // Copies flat contents of `t` to std::vector-like `*output`, which is resized |
63 | // as needed. `OutputType` may be either `std::vector<int64_t>` or |
64 | // `gtl::InlinedVector<int64_t>`. |
65 | // |
66 | // Precondition: t.dtype() must be either DT_INT32 or DT_INT64. |
67 | template <typename OutputType> |
68 | void SubtleMustCopyFlat(const Tensor& t, OutputType* output) { |
69 | if (t.dtype() == DT_INT32) { |
70 | SubtleMustCopyFlatHelper<int32, OutputType>(t, output); |
71 | } else { |
72 | SubtleMustCopyFlatHelper<int64_t, OutputType>(t, output); |
73 | } |
74 | } |
75 | |
76 | } // namespace spacetobatch |
77 | } // namespace internal |
78 | |
79 | namespace functor { |
80 | |
81 | // Functor used by {SpaceToBatch,BatchToSpace}{ND,}Op to do the conversion. |
82 | // |
83 | // If B2S is false, then this performs the space-to-batch conversion. If B2S is |
84 | // true, then this performs the inverse batch-to-space conversion. |
85 | template <typename Device, typename T, int NUM_BLOCK_DIMS, bool B2S = false> |
86 | struct SpaceToBatchFunctor { |
87 | using InputT = typename std::conditional<B2S, T, const T>::type; |
88 | using OutputT = typename std::conditional<B2S, const T, T>::type; |
89 | // Implements the space to batch conversion. |
90 | // |
91 | // space_tensor: input tensor of space-to-batch operation. If B2S = false, |
92 | // then this is the input to the conversion. If B2S = true, then this |
93 | // is the output of the conversion. |
94 | // block_size: array of shape [NUM_BLOCK_DIMS] specifying the block sizes for |
95 | // dimensions 1 through NUM_BLOCK_DIMS. |
96 | // paddings: row-major array of shape [NUM_BLOCK_DIMS, 2] specifying the |
97 | // start and end padding for dimensions 1 through NUM_BLOCK_DIMS. |
98 | // batch_tensor: output tensor of the space-to-batch operation. If |
99 | // B2S = false, then this is the output of the conversion. If B2S = true, |
100 | // then this is the input to the conversion. |
101 | // |
102 | // The caller must ensure that the dimensions of the tensors are correct. |
103 | Status operator()( |
104 | const Device& d, |
105 | typename TTypes<InputT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor, |
106 | const int64_t block_shape[NUM_BLOCK_DIMS], |
107 | const int64_t paddings[NUM_BLOCK_DIMS * 2], |
108 | typename TTypes<OutputT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor); |
109 | }; |
110 | |
111 | } // namespace functor |
112 | } // namespace tensorflow |
113 | |
114 | #endif // TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_ |
115 | |