1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/array_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "tensorflow/core/kernels/spacetodepth_op.h" |
21 | |
22 | #include <memory> |
23 | #include <string> |
24 | #include <utility> |
25 | |
26 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
27 | #include "tensorflow/core/framework/op.h" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/register_types.h" |
30 | #include "tensorflow/core/framework/tensor.h" |
31 | #include "tensorflow/core/framework/tensor_shape.h" |
32 | #include "tensorflow/core/framework/tensor_types.h" |
33 | #include "tensorflow/core/framework/types.h" |
34 | #include "tensorflow/core/platform/logging.h" |
35 | #include "tensorflow/core/platform/types.h" |
36 | #include "tensorflow/core/util/tensor_format.h" |
37 | |
38 | namespace tensorflow { |
39 | |
40 | namespace { |
41 | template <typename T> |
42 | struct RawType { |
43 | using type = T; |
44 | }; |
45 | |
46 | template <> |
47 | struct RawType<qint8> { |
  // spacetodepth_op_gpu.cu.cc does not instantiate SpaceToDepthOpFunctor for
  // int8, so we map qint8 to uint8. Instantiating int8 would slow down
  // compilation, and the generated code is almost identical to that for uint8.
51 | using type = uint8; |
52 | }; |
53 | } // namespace |
54 | |
55 | typedef Eigen::ThreadPoolDevice CPUDevice; |
56 | typedef Eigen::GpuDevice GPUDevice; |
57 | |
58 | template <typename Device, typename T> |
59 | class SpaceToDepthOp : public OpKernel { |
60 | public: |
61 | explicit SpaceToDepthOp(OpKernelConstruction* context) : OpKernel(context) { |
62 | string data_format_str; |
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
66 | |
    OP_REQUIRES_OK(context, context->GetAttr("block_size", &block_size_));
    OP_REQUIRES(context, block_size_ > 1,
                errors::InvalidArgument("Block size should be > 1, but was: ",
                                        block_size_));
71 | |
72 | if (std::is_same<Device, CPUDevice>::value) { |
73 | OP_REQUIRES( |
74 | context, data_format_ == FORMAT_NHWC, |
75 | errors::InvalidArgument( |
76 | "Only NHWC data_format supported on CPU. Got " , data_format_str)); |
77 | } |
78 | } |
79 | |
80 | void Compute(OpKernelContext* context) override { |
81 | const Tensor& input = context->input(0); |
82 | const int dims = input.dims(); |
83 | |
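    // In NCHW_VECT_C, the innermost dimension packs 4 int8 values per stored
    // element, so the logical channel count is 4x the stored 'C' dimension.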
84 | const bool is_int8x4 = (data_format_ == FORMAT_NCHW_VECT_C); |
85 | const int vect = is_int8x4 ? 4 : 1; |
86 | if (is_int8x4) { |
87 | OP_REQUIRES( |
88 | context, dims == 5, |
89 | errors::InvalidArgument("Input rank should be 5 instead of " , dims)); |
90 | } else { |
91 | OP_REQUIRES( |
92 | context, dims == 4, |
93 | errors::InvalidArgument("Input rank should be 4 instead of " , dims)); |
94 | } |
95 | |
96 | constexpr int kNumSpatialDims = 2; |
97 | const int batch_size = |
98 | input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'N')); |
99 | const int height = |
100 | input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'H')); |
101 | const int width = |
102 | input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'W')); |
103 | const int input_depth = |
104 | input.dim_size(GetTensorDimIndex<kNumSpatialDims>(data_format_, 'C')) * |
105 | vect; |
106 | |
107 | // Both width and height must be divisible by block_size. |
108 | OP_REQUIRES(context, |
109 | (width % block_size_) == 0 && (height % block_size_) == 0, |
110 | errors::InvalidArgument( |
111 | "Image width " , width, " and height " , height, |
112 | " should be divisible by block_size: " , block_size_)); |
113 | |
    // Each 'spatial' block of size block_size_ x block_size_ is moved into
    // the depth dimension.
116 | const int output_depth = input_depth * block_size_ * block_size_; |
117 | const int output_width = width / block_size_; |
118 | const int output_height = height / block_size_; |
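    // E.g., an NHWC input of shape [1, 4, 4, 3] with block_size_ == 2 yields
    // an output of shape [1, 2, 2, 12].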
119 | |
120 | // Allocate output tensor. |
121 | Tensor* outputs_tensor = nullptr; |
122 | OP_REQUIRES_OK(context, |
123 | context->allocate_output( |
124 | 0, |
125 | ShapeFromFormat(data_format_, batch_size, output_height, |
126 | output_width, output_depth), |
127 | &outputs_tensor)); |
128 | |
129 | if (std::is_same<Device, GPUDevice>::value) { |
130 | using RT = typename RawType<T>::type; |
131 | if (data_format_ == FORMAT_NCHW_VECT_C) { |
132 | // NCHW_VECT_C with 4 x qint8 can be treated as NCHW int32. |
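        // E.g., a qint8 tensor of shape [N, C/4, H, W, 4] is viewed as an
        // int32 tensor of shape [N, C/4, H, W].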
133 | auto Tinput_v = input.template reinterpret_last_dimension<int32, 4>(); |
134 | auto Toutput_v = outputs_tensor->reinterpret_last_dimension<int32, 4>(); |
135 | functor::SpaceToDepthOpFunctor<Device, int32, FORMAT_NCHW> functor; |
136 | functor(context->eigen_device<Device>(), Tinput_v, block_size_, |
137 | Toutput_v); |
138 | } else if (data_format_ == FORMAT_NCHW) { |
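        // RT differs from T only for qint8, which is expected to take the
        // NCHW_VECT_C path above, so T == RT here.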
139 | CHECK((std::is_same<T, RT>::value)); |
140 | functor::SpaceToDepthOpFunctor<Device, RT, FORMAT_NCHW> functor; |
141 | functor(context->eigen_device<Device>(), input.tensor<RT, 4>(), |
142 | block_size_, outputs_tensor->tensor<RT, 4>()); |
143 | } else { |
144 | CHECK((std::is_same<T, RT>::value)); |
145 | functor::SpaceToDepthOpFunctor<Device, RT, FORMAT_NHWC> functor; |
146 | functor(context->eigen_device<Device>(), input.tensor<RT, 4>(), |
147 | block_size_, outputs_tensor->tensor<RT, 4>()); |
148 | } |
149 | } else { |
150 | // NOTE: Assumes data_format_ == FORMAT_NHWC here, since we have rejected |
151 | // (CPU && data_format_ != FORMAT_NHWC) in the constructor. |
152 | functor::SpaceToDepthOpFunctor<Device, T, FORMAT_NHWC> functor; |
153 | functor(context->eigen_device<Device>(), input.tensor<T, 4>(), |
154 | block_size_, outputs_tensor->tensor<T, 4>()); |
155 | } |
156 | }; |
157 | |
158 | private: |
159 | int block_size_; |
160 | TensorFormat data_format_; |
161 | }; |
162 | |
163 | // Partial specialization of SpaceToDepthOpFunctor for a CPUDevice. |
164 | namespace functor { |
165 | template <typename T> |
166 | struct SpaceToDepthOpFunctor<CPUDevice, T, FORMAT_NHWC> { |
167 | void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input, |
168 | int block_size, typename TTypes<T, 4>::Tensor output) { |
169 | const int batch_size = output.dimension(0); |
170 | const int input_height = input.dimension(1); |
171 | const int input_width = input.dimension(2); |
172 | const int input_depth = input.dimension(3); |
173 | |
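    // Each block_size x block_size spatial block at depth d is flattened into
    // output depth (offset_h * block_size + offset_w) * input_depth + d.
    // E.g., with block_size == 2, input element (b, 3, 5, d) maps to
    // output element (b, 1, 2, d + 3 * input_depth).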
174 | for (int b = 0; b < batch_size; ++b) { |
175 | for (int h = 0; h < input_height; ++h) { |
176 | const int out_h = h / block_size; |
177 | const int offset_h = (h % block_size); |
178 | for (int w = 0; w < input_width; ++w) { |
179 | const int out_w = w / block_size; |
180 | const int offset_w = (w % block_size); |
181 | const int offset_d = (offset_h * block_size + offset_w) * input_depth; |
182 | for (int d = 0; d < input_depth; ++d) { |
183 | const int out_d = d + offset_d; |
184 | output(b, out_h, out_w, out_d) = input(b, h, w, d); |
185 | } |
186 | } |
187 | } |
188 | } |
189 | } |
190 | }; |
191 | } // namespace functor |
192 | |
193 | #define REGISTER(type) \ |
194 | REGISTER_KERNEL_BUILDER(Name("SpaceToDepth") \ |
195 | .Device(DEVICE_CPU) \ |
196 | .TypeConstraint<type>("T") \ |
197 | .AttrConstraint("data_format", "NHWC"), \ |
198 | SpaceToDepthOp<CPUDevice, type>); |
199 | |
200 | TF_CALL_ALL_TYPES(REGISTER); |
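// qint8 is not covered by TF_CALL_ALL_TYPES, so register it separately.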
201 | TF_CALL_qint8(REGISTER); |
202 | #undef REGISTER |
203 | |
204 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
REGISTER_KERNEL_BUILDER(
    Name("SpaceToDepth").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    SpaceToDepthOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("SpaceToDepth").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    SpaceToDepthOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("SpaceToDepth").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
    SpaceToDepthOp<GPUDevice, qint8>);
REGISTER_KERNEL_BUILDER(
    Name("SpaceToDepth").Device(DEVICE_GPU).TypeConstraint<uint8>("T"),
    SpaceToDepthOp<GPUDevice, uint8>);
217 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
218 | |
219 | } // end namespace tensorflow |
220 | |