1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/nn_ops.cc.
17
18#define EIGEN_USE_THREADS
19
20#include "tensorflow/core/kernels/data_format_ops.h"
21
22#include <map>
23
24#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25#include "tensorflow/core/framework/op_kernel.h"
26#include "tensorflow/core/framework/register_types.h"
27#include "tensorflow/core/framework/tensor.h"
28#include "tensorflow/core/platform/errors.h"
29
30namespace tensorflow {
31
32typedef Eigen::ThreadPoolDevice CPUDevice;
33typedef Eigen::GpuDevice GPUDevice;
34
// Ensure that `src` and `dst` define a valid permutation.
// Ops defined in this file assume that user specifies a permutation via two
// string attributes. This check validates that these attributes properly define
// it to prevent security vulnerabilities.
static bool IsValidPermutation(const std::string& src, const std::string& dst) {
  if (src.size() != dst.size()) {
    return false;
  }

  // Flags indexed by byte value. A fixed-size table avoids the per-character
  // node allocations of a std::map and gives O(1) lookups.
  bool seen[256] = {};

  // Every character in `src` must be present only once.
  for (const char c : src) {
    const unsigned char idx = static_cast<unsigned char>(c);
    if (seen[idx]) {
      return false;
    }
    seen[idx] = true;
  }

  // Every character in `dst` must show up in `src` exactly once.
  for (const char c : dst) {
    const unsigned char idx = static_cast<unsigned char>(c);
    if (!seen[idx]) {
      return false;
    }
    seen[idx] = false;
  }

  // At this point, every flag has been switched to true and back to false
  // exactly once for all characters in `src` (and `dst`), so we have a valid
  // permutation.
  return true;
}
66
67template <typename Device, typename T>
68class DataFormatDimMapOp : public OpKernel {
69 public:
70 explicit DataFormatDimMapOp(OpKernelConstruction* context)
71 : OpKernel(context) {
72 string src_format;
73 OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
74 string dst_format;
75 OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
76 OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
77 errors::InvalidArgument(
78 "Source format must be of length 4 or 5, received "
79 "src_format = ",
80 src_format));
81 OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
82 errors::InvalidArgument("Destination format must be of length "
83 "4 or 5, received dst_format = ",
84 dst_format));
85 OP_REQUIRES(
86 context, IsValidPermutation(src_format, dst_format),
87 errors::InvalidArgument(
88 "Destination and source format must determine a permutation, got ",
89 src_format, " and ", dst_format));
90 dst_idx_ = Tensor(DT_INT32, {static_cast<int64_t>(src_format.size())});
91 for (int i = 0; i < src_format.size(); ++i) {
92 for (int j = 0; j < dst_format.size(); ++j) {
93 if (dst_format[j] == src_format[i]) {
94 dst_idx_.vec<int>()(i) = j;
95 break;
96 }
97 }
98 }
99 }
100
101 void Compute(OpKernelContext* context) override {
102 const Tensor& input = context->input(0);
103 Tensor* output;
104 OP_REQUIRES_OK(context,
105 context->allocate_output(0, input.shape(), &output));
106 functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
107 input.flat<T>(), output->flat<T>(),
108 dst_idx_.vec<int>());
109 }
110
111 Tensor dst_idx_;
112};
113
114template <typename Device, typename T>
115class DataFormatVecPermuteOp : public OpKernel {
116 public:
117 explicit DataFormatVecPermuteOp(OpKernelConstruction* context)
118 : OpKernel(context) {
119 string src_format;
120 OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
121 OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
122 errors::InvalidArgument(
123 "Source format must be of length 4 or 5, received "
124 "src_format = ",
125 src_format));
126 string dst_format;
127 OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
128 OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
129 errors::InvalidArgument("Destination format must be of length "
130 "4 or 5, received dst_format = ",
131 dst_format));
132 OP_REQUIRES(
133 context, IsValidPermutation(src_format, dst_format),
134 errors::InvalidArgument(
135 "Destination and source format must determine a permutation, got ",
136 src_format, " and ", dst_format));
137 src_format_ = src_format;
138 dst_format_ = dst_format;
139 }
140
141 void Compute(OpKernelContext* context) override {
142 const Tensor& input = context->input(0);
143 OP_REQUIRES(context, input.dims() == 1 || input.dims() == 2,
144 errors::InvalidArgument(
145 "input must be a vector or 2D tensor, but got shape ",
146 input.shape().DebugString()));
147
148 const int full_dim_count = src_format_.size();
149 const int spatial_dim_count = full_dim_count - 2;
150
151 if (input.dims() == 1) {
152 OP_REQUIRES(context,
153 input.NumElements() == spatial_dim_count ||
154 input.NumElements() == full_dim_count,
155 errors::InvalidArgument("1D input must be of size ",
156 spatial_dim_count, " or ",
157 full_dim_count, ", but got shape ",
158 input.shape().DebugString()));
159 } else if (input.dims() == 2) {
160 OP_REQUIRES(context,
161 input.dim_size(0) == spatial_dim_count ||
162 input.dim_size(0) == full_dim_count,
163 errors::InvalidArgument("First dimension of 2D input must be "
164 "of size ",
165 spatial_dim_count, " or ",
166 full_dim_count, ", but got shape ",
167 input.shape().DebugString()));
168 OP_REQUIRES(
169 context, input.dim_size(1) == 2,
170 errors::InvalidArgument(
171 "Second dimension of 2D input must be of size 2, but got shape ",
172 input.shape().DebugString()));
173 }
174
175 Tensor* output = nullptr;
176 OP_REQUIRES_OK(context,
177 context->allocate_output(0, input.shape(), &output));
178 // Support 1D and 2D cases.
179 Eigen::DSizes<Eigen::DenseIndex, 10> dst_idx;
180 string src_format_str = src_format_;
181 string dst_format_str = dst_format_;
182 if (input.dim_size(0) == spatial_dim_count) {
183 // If the input is a vector of size spatial_dim_count, treat the elements
184 // as spatial dimensions.
185 auto keep_only_spatial_dimensions =
186 [spatial_dim_count](string* format_str) -> void {
187 auto new_end =
188 std::remove_if(format_str->begin(), format_str->end(),
189 [spatial_dim_count](const char dim) {
190 return dim != 'H' && dim != 'W' &&
191 (spatial_dim_count == 2 || dim != 'D');
192 });
193 format_str->erase(new_end, format_str->end());
194 };
195 keep_only_spatial_dimensions(&src_format_str);
196 keep_only_spatial_dimensions(&dst_format_str);
197 if (spatial_dim_count == 3) {
198 OP_REQUIRES(
199 context, src_format_str.size() == 3 && dst_format_str.size() == 3,
200 errors::InvalidArgument(
201 "Format specifier must contain D, H and W for 2D case"));
202 } else {
203 DCHECK(spatial_dim_count == 2);
204 OP_REQUIRES(context,
205 src_format_str.size() == 2 && dst_format_str.size() == 2,
206 errors::InvalidArgument(
207 "Format specifier must contain H and W for 2D case"));
208 }
209 }
210 ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);
211
212 functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
213 input.flat<T>(),
214 output->flat<T>(), dst_idx);
215 }
216
217 private:
218 // Finds out the destination index. Support 1D and 2D cases.
219 // Example: HWNC --> NHWC
220 // 1D: dst = [1, 2, 0, 3],
221 // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7]
222 static void ComputeDstIndex(const string& src_format_str,
223 const string& dst_format_str, int num_dim,
224 Eigen::DSizes<Eigen::DenseIndex, 10>* dst) {
225 for (int i = 0; i < src_format_str.size(); ++i) {
226 for (int j = 0; j < dst_format_str.size(); ++j) {
227 if (dst_format_str[j] != src_format_str[i]) continue;
228 // Found the dst index. Set output based on the number of dims.
229 for (int k = 0; k < num_dim; ++k) {
230 (*dst)[i * num_dim + k] = j * num_dim + k;
231 }
232 }
233 }
234 }
235
236 string src_format_;
237 string dst_format_;
238};
239
// CPU kernel registrations for DataFormatDimMap (int32/int64 index tensors).
#define REGISTER_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DataFormatDimMap").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL

// CPU kernel registrations for DataFormatVecPermute (int32/int64 tensors).
#define REGISTER_KERNEL(T)                                                    \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("DataFormatVecPermute").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL

// Same CPU kernels registered again under the "host" label, so callers can
// select them explicitly via the kernel label attribute.
#define REGISTER_KERNEL(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("DataFormatDimMap")           \
                              .Device(DEVICE_CPU)            \
                              .Label("host")                 \
                              .TypeConstraint<T>("T"),       \
                          DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#define REGISTER_KERNEL(T)                                       \
  REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute")           \
                              .Device(DEVICE_CPU)                \
                              .Label("host")                     \
                              .TypeConstraint<T>("T"),           \
                          DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL
275
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU; the
// definitions are compiled separately in the .cu.cc file.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                    \
  template <>                                                  \
  void DataFormatDimMap<GPUDevice, T>::operator()(             \
      const GPUDevice& d, typename TTypes<T>::ConstFlat x,     \
      typename TTypes<T>::Flat y, const TTypes<int>::Vec dst); \
  extern template struct DataFormatDimMap<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
// Undefine both helper macros between groups (previously DECLARE_GPU_SPECS
// was redefined without an #undef and never cleaned up at the end).
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC

#define DECLARE_GPU_SPEC(T)                                 \
  template <>                                               \
  void DataFormatVecPermute<GPUDevice, T>::operator()(      \
      const GPUDevice& d, typename TTypes<T>::ConstFlat x,  \
      typename TTypes<T>::Vec y,                            \
      const Eigen::DSizes<Eigen::DenseIndex, 10>& dst_idx); \
  extern template struct DataFormatVecPermute<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DataFormatDimMap").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DataFormatDimMapOp<GPUDevice, T>);
TF_CALL_int32(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#define REGISTER_GPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("DataFormatVecPermute").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DataFormatVecPermuteOp<GPUDevice, T>);
TF_CALL_int32(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
320
// Registration of the DEVICE_DEFAULT implementations.
// Inputs and outputs are pinned to host memory (HostMemory("x")/("y")), so the
// CPU kernel implementation is reused under the "host" label.
#define REGISTER_DEVICE_DEFAULT_KERNEL(T)                    \
  REGISTER_KERNEL_BUILDER(Name("DataFormatDimMap")           \
                              .Device(DEVICE_DEFAULT)        \
                              .HostMemory("x")               \
                              .HostMemory("y")               \
                              .Label("host")                 \
                              .TypeConstraint<T>("T"),       \
                          DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_DEVICE_DEFAULT_KERNEL);
TF_CALL_int64(REGISTER_DEVICE_DEFAULT_KERNEL);
#undef REGISTER_DEVICE_DEFAULT_KERNEL

#define REGISTER_DEVICE_DEFAULT_KERNEL(T)                        \
  REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute")           \
                              .Device(DEVICE_DEFAULT)            \
                              .HostMemory("x")                   \
                              .HostMemory("y")                   \
                              .Label("host")                     \
                              .TypeConstraint<T>("T"),           \
                          DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_DEVICE_DEFAULT_KERNEL);
TF_CALL_int64(REGISTER_DEVICE_DEFAULT_KERNEL);
#undef REGISTER_DEVICE_DEFAULT_KERNEL
345
346} // namespace tensorflow
347