/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Specialization of SpaceToBatchFunctor for a CPUDevice.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/spacetobatch_functor.h"

#include <type_traits>

#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

namespace {

// Implementation of the nested loops for SpaceToBatchOpFunctor.
//
// `B2S == false` performs the space-to-batch transform; `B2S == true`
// performs the inverse batch-to-space transform. Because constexpr if is
// not available, both the input and output pointers are non-const so that
// one template covers both directions.
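//
// As a concrete illustration (assuming zero padding): with one block
// dimension, block_shape == {2}, and a space tensor of shape [1, 4, depth],
// the batch tensor has shape [2, 2, depth]; batch entry 0 reads space
// positions {0, 2} and batch entry 1 reads space positions {1, 3}.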
template <int N, bool B2S>
struct SpaceToBatchHelper {
  template <typename T>
  static void run(T* space_tensor_ptr, const int64_t* space_tensor_shape,
                  const int64_t* space_tensor_strides,
                  const int64_t* block_shape, const int64_t* pad_start,
                  const int64_t* block_offsets,
                  const int64_t* batch_tensor_shape,
                  const int64_t* batch_tensor_strides, T* batch_tensor_ptr) {
    for (int64_t batch_tensor_pos = 0;
         batch_tensor_pos < batch_tensor_shape[0]; ++batch_tensor_pos) {
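      // Map this position in the batch tensor back to a position in the
      // space tensor: scale by the block shape, add this batch entry's
      // offset within the block, and undo the leading padding.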
      const int64_t space_tensor_pos =
          batch_tensor_pos * block_shape[0] + block_offsets[0] - pad_start[0];
      if (space_tensor_pos >= 0 && space_tensor_pos < space_tensor_shape[0]) {
        SpaceToBatchHelper<N - 1, B2S>::run(
            space_tensor_ptr + space_tensor_pos * space_tensor_strides[0],
            space_tensor_shape + 1, space_tensor_strides + 1, block_shape + 1,
            pad_start + 1, block_offsets + 1, batch_tensor_shape + 1,
            batch_tensor_strides + 1, batch_tensor_ptr);
      } else {
        if (B2S == false) {
          // Copy in padding.
          for (int64_t i = 0; i < batch_tensor_strides[0]; ++i) {
            batch_tensor_ptr[i] = static_cast<T>(0);
          }
        }
      }
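      // Whether this position was copied, zero-filled, or (in the B2S case)
      // left untouched, advance past the batch_tensor_strides[0] elements
      // it spans.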
      batch_tensor_ptr += batch_tensor_strides[0];
    }
  }
};

template <bool B2S>
struct SpaceToBatchHelper<0, B2S> {
  template <typename T>
  static void run(T* space_tensor_ptr, const int64_t* space_tensor_shape,
                  const int64_t* space_tensor_strides,
                  const int64_t* block_shape, const int64_t* pad_start,
                  const int64_t* block_offsets,
                  const int64_t* batch_tensor_shape,
                  const int64_t* batch_tensor_strides, T* batch_tensor_ptr) {
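    // Base case: copy one contiguous run of depth elements.
    // batch_tensor_strides[-1] is in range because the caller passes a
    // pointer one element into the strides array; it is the stride of the
    // innermost block dimension, i.e. the number of remaining contiguous
    // elements.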
    for (int64_t i = 0; i < batch_tensor_strides[-1]; ++i) {
      if (B2S == false) {
        batch_tensor_ptr[i] = space_tensor_ptr[i];
      } else {
        space_tensor_ptr[i] = batch_tensor_ptr[i];
      }
    }
  }
};

}  // namespace

template <typename T, int NUM_BLOCK_DIMS, bool B2S>
struct SpaceToBatchFunctor<CPUDevice, T, NUM_BLOCK_DIMS, B2S> {
  using SpaceT = typename std::conditional<B2S, T, const T>::type;
  using BatchT = typename std::conditional<B2S, const T, T>::type;
  Status operator()(
      const CPUDevice& d,
      typename TTypes<SpaceT, NUM_BLOCK_DIMS + 2>::Tensor space_tensor,
      const int64_t block_shape_tensor[NUM_BLOCK_DIMS],
      const int64_t paddings_tensor[NUM_BLOCK_DIMS * 2],
      typename TTypes<BatchT, NUM_BLOCK_DIMS + 2>::Tensor batch_tensor) {
    const int64_t batch_tensor_batch = batch_tensor.dimension(0);

    const int64_t space_tensor_batch = space_tensor.dimension(0);
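    // By construction of the op, batch_tensor_batch equals
    // space_tensor_batch * prod(block_shape).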

    // Copy the parameters into local arrays so that the compiler can keep
    // the values in registers.
    int64_t pad_start[NUM_BLOCK_DIMS];
    int64_t block_shape[NUM_BLOCK_DIMS];
    int64_t space_tensor_shape[NUM_BLOCK_DIMS],
        batch_tensor_shape[NUM_BLOCK_DIMS];
    for (int block_dim = 0; block_dim < NUM_BLOCK_DIMS; ++block_dim) {
      pad_start[block_dim] = paddings_tensor[block_dim * 2];
      block_shape[block_dim] = block_shape_tensor[block_dim];
      space_tensor_shape[block_dim] = space_tensor.dimension(block_dim + 1);
      batch_tensor_shape[block_dim] = batch_tensor.dimension(block_dim + 1);
    }

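    // Compute row-major strides for both tensors: the innermost dimension
    // has stride 1, and each outer dimension's stride is the product of all
    // inner dimension sizes. For example, a [B, H, W, C] tensor has strides
    // {H * W * C, W * C, C, 1}.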
    int64_t space_tensor_strides[NUM_BLOCK_DIMS + 2],
        batch_tensor_strides[NUM_BLOCK_DIMS + 2];
    space_tensor_strides[NUM_BLOCK_DIMS + 1] =
        batch_tensor_strides[NUM_BLOCK_DIMS + 1] = 1;
    for (int dim = NUM_BLOCK_DIMS; dim >= 0; --dim) {
      space_tensor_strides[dim] =
          space_tensor_strides[dim + 1] * space_tensor.dimension(dim + 1);
      batch_tensor_strides[dim] =
          batch_tensor_strides[dim + 1] * batch_tensor.dimension(dim + 1);
    }

    // Use non-const pointers for both input and output to simplify the
    // template implementation, given the lack of constexpr if.
    T* space_tensor_ptr = const_cast<T*>(space_tensor.data());
    T* batch_tensor_ptr = const_cast<T*>(batch_tensor.data());

    for (int64_t batch_tensor_b = 0; batch_tensor_b < batch_tensor_batch;
         ++batch_tensor_b) {
      const int64_t space_tensor_b = batch_tensor_b % space_tensor_batch;
      int64_t block_index = batch_tensor_b / space_tensor_batch;
      int64_t block_offsets[NUM_BLOCK_DIMS];
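      // Decompose block_index into per-dimension offsets within the block,
      // innermost dimension first.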
      for (int block_dim = NUM_BLOCK_DIMS - 1; block_dim >= 0; --block_dim) {
        // Skip unnecessary remainder operation for block_dim == 0.
        block_offsets[block_dim] =
            block_dim > 0 ? block_index % block_shape[block_dim] : block_index;
        block_index /= block_shape[block_dim];
      }

      // The compiler should inline the nested loops generated by this
      // template.
      SpaceToBatchHelper<NUM_BLOCK_DIMS, B2S>::run(
          space_tensor_ptr + space_tensor_b * space_tensor_strides[0],
          space_tensor_shape, &space_tensor_strides[1], block_shape, pad_start,
          block_offsets, batch_tensor_shape, &batch_tensor_strides[1],
          batch_tensor_ptr + batch_tensor_b * batch_tensor_strides[0]);
    }
    return OkStatus();
  }
};
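
// A sketch of how the functor above might be invoked for a 4-D float tensor
// with two block dimensions (space-to-batch direction); actual kernels obtain
// the device, tensors, and paddings from op inputs:
//
//   SpaceToBatchFunctor<CPUDevice, float, 2, false> functor;
//   const int64_t block_shape[2] = {2, 2};
//   const int64_t paddings[4] = {0, 0, 0, 0};
//   TF_RETURN_IF_ERROR(functor(device, space.tensor<float, 4>(), block_shape,
//                              paddings, batch.tensor<float, 4>()));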

// Instantiate for each supported number of block dimensions, in both
// directions.
#define INSTANTIATE(NUM_BLOCK_DIMS, T)                                      \
  template struct SpaceToBatchFunctor<CPUDevice, T, NUM_BLOCK_DIMS, false>; \
  template struct SpaceToBatchFunctor<CPUDevice, T, NUM_BLOCK_DIMS, true>;  \
  /**/

#define INSTANTIATE_FOR_T(T) \
  TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(INSTANTIATE, T)

TF_CALL_REAL_NUMBER_TYPES(INSTANTIATE_FOR_T)

#undef INSTANTIATE_FOR_T
#undef INSTANTIATE

}  // namespace functor
}  // namespace tensorflow