1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/nn_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "tensorflow/core/kernels/data_format_ops.h" |
21 | |
22 | #include <map> |
23 | |
24 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/register_types.h" |
27 | #include "tensorflow/core/framework/tensor.h" |
28 | #include "tensorflow/core/platform/errors.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | typedef Eigen::ThreadPoolDevice CPUDevice; |
33 | typedef Eigen::GpuDevice GPUDevice; |
34 | |
35 | // Ensure that `src` and `dst` define a valid permutation. |
36 | // Ops defined in this file assume that user specifies a permutation via two |
37 | // string attributes. This check validates that these attributes properly define |
38 | // it to prevent security vulnerabilities. |
// Ensure that `src` and `dst` define a valid permutation.
// Ops defined in this file assume that user specifies a permutation via two
// string attributes. This check validates that these attributes properly define
// it to prevent security vulnerabilities.
static bool IsValidPermutation(const std::string& src, const std::string& dst) {
  if (src.size() != dst.size()) {
    return false;
  }

  // Tracks which characters of `src` have not yet been matched in `dst`.
  std::map<char, bool> characters;

  // Every character in `src` must be present only once.
  for (const char c : src) {
    // insert() returns {iterator, false} if the key already exists, so a
    // failed insertion means `c` is duplicated in `src`.
    if (!characters.insert({c, true}).second) {
      return false;
    }
  }

  // Every character in `dst` must show up in `src` exactly once. Using
  // find() (instead of operator[]) avoids silently inserting entries for
  // characters that never appeared in `src`.
  for (const char c : dst) {
    auto it = characters.find(c);
    if (it == characters.end() || !it->second) {
      return false;
    }
    it->second = false;
  }

  // At this point every character of `src` has been matched exactly once by
  // a character of `dst`, so the two strings define a valid permutation.
  return true;
}
66 | |
67 | template <typename Device, typename T> |
68 | class DataFormatDimMapOp : public OpKernel { |
69 | public: |
70 | explicit DataFormatDimMapOp(OpKernelConstruction* context) |
71 | : OpKernel(context) { |
72 | string src_format; |
73 | OP_REQUIRES_OK(context, context->GetAttr("src_format" , &src_format)); |
74 | string dst_format; |
75 | OP_REQUIRES_OK(context, context->GetAttr("dst_format" , &dst_format)); |
76 | OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5, |
77 | errors::InvalidArgument( |
78 | "Source format must be of length 4 or 5, received " |
79 | "src_format = " , |
80 | src_format)); |
81 | OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5, |
82 | errors::InvalidArgument("Destination format must be of length " |
83 | "4 or 5, received dst_format = " , |
84 | dst_format)); |
85 | OP_REQUIRES( |
86 | context, IsValidPermutation(src_format, dst_format), |
87 | errors::InvalidArgument( |
88 | "Destination and source format must determine a permutation, got " , |
89 | src_format, " and " , dst_format)); |
90 | dst_idx_ = Tensor(DT_INT32, {static_cast<int64_t>(src_format.size())}); |
91 | for (int i = 0; i < src_format.size(); ++i) { |
92 | for (int j = 0; j < dst_format.size(); ++j) { |
93 | if (dst_format[j] == src_format[i]) { |
94 | dst_idx_.vec<int>()(i) = j; |
95 | break; |
96 | } |
97 | } |
98 | } |
99 | } |
100 | |
101 | void Compute(OpKernelContext* context) override { |
102 | const Tensor& input = context->input(0); |
103 | Tensor* output; |
104 | OP_REQUIRES_OK(context, |
105 | context->allocate_output(0, input.shape(), &output)); |
106 | functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(), |
107 | input.flat<T>(), output->flat<T>(), |
108 | dst_idx_.vec<int>()); |
109 | } |
110 | |
111 | Tensor dst_idx_; |
112 | }; |
113 | |
114 | template <typename Device, typename T> |
115 | class DataFormatVecPermuteOp : public OpKernel { |
116 | public: |
117 | explicit DataFormatVecPermuteOp(OpKernelConstruction* context) |
118 | : OpKernel(context) { |
119 | string src_format; |
120 | OP_REQUIRES_OK(context, context->GetAttr("src_format" , &src_format)); |
121 | OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5, |
122 | errors::InvalidArgument( |
123 | "Source format must be of length 4 or 5, received " |
124 | "src_format = " , |
125 | src_format)); |
126 | string dst_format; |
127 | OP_REQUIRES_OK(context, context->GetAttr("dst_format" , &dst_format)); |
128 | OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5, |
129 | errors::InvalidArgument("Destination format must be of length " |
130 | "4 or 5, received dst_format = " , |
131 | dst_format)); |
132 | OP_REQUIRES( |
133 | context, IsValidPermutation(src_format, dst_format), |
134 | errors::InvalidArgument( |
135 | "Destination and source format must determine a permutation, got " , |
136 | src_format, " and " , dst_format)); |
137 | src_format_ = src_format; |
138 | dst_format_ = dst_format; |
139 | } |
140 | |
141 | void Compute(OpKernelContext* context) override { |
142 | const Tensor& input = context->input(0); |
143 | OP_REQUIRES(context, input.dims() == 1 || input.dims() == 2, |
144 | errors::InvalidArgument( |
145 | "input must be a vector or 2D tensor, but got shape " , |
146 | input.shape().DebugString())); |
147 | |
148 | const int full_dim_count = src_format_.size(); |
149 | const int spatial_dim_count = full_dim_count - 2; |
150 | |
151 | if (input.dims() == 1) { |
152 | OP_REQUIRES(context, |
153 | input.NumElements() == spatial_dim_count || |
154 | input.NumElements() == full_dim_count, |
155 | errors::InvalidArgument("1D input must be of size " , |
156 | spatial_dim_count, " or " , |
157 | full_dim_count, ", but got shape " , |
158 | input.shape().DebugString())); |
159 | } else if (input.dims() == 2) { |
160 | OP_REQUIRES(context, |
161 | input.dim_size(0) == spatial_dim_count || |
162 | input.dim_size(0) == full_dim_count, |
163 | errors::InvalidArgument("First dimension of 2D input must be " |
164 | "of size " , |
165 | spatial_dim_count, " or " , |
166 | full_dim_count, ", but got shape " , |
167 | input.shape().DebugString())); |
168 | OP_REQUIRES( |
169 | context, input.dim_size(1) == 2, |
170 | errors::InvalidArgument( |
171 | "Second dimension of 2D input must be of size 2, but got shape " , |
172 | input.shape().DebugString())); |
173 | } |
174 | |
175 | Tensor* output = nullptr; |
176 | OP_REQUIRES_OK(context, |
177 | context->allocate_output(0, input.shape(), &output)); |
178 | // Support 1D and 2D cases. |
179 | Eigen::DSizes<Eigen::DenseIndex, 10> dst_idx; |
180 | string src_format_str = src_format_; |
181 | string dst_format_str = dst_format_; |
182 | if (input.dim_size(0) == spatial_dim_count) { |
183 | // If the input is a vector of size spatial_dim_count, treat the elements |
184 | // as spatial dimensions. |
185 | auto keep_only_spatial_dimensions = |
186 | [spatial_dim_count](string* format_str) -> void { |
187 | auto new_end = |
188 | std::remove_if(format_str->begin(), format_str->end(), |
189 | [spatial_dim_count](const char dim) { |
190 | return dim != 'H' && dim != 'W' && |
191 | (spatial_dim_count == 2 || dim != 'D'); |
192 | }); |
193 | format_str->erase(new_end, format_str->end()); |
194 | }; |
195 | keep_only_spatial_dimensions(&src_format_str); |
196 | keep_only_spatial_dimensions(&dst_format_str); |
197 | if (spatial_dim_count == 3) { |
198 | OP_REQUIRES( |
199 | context, src_format_str.size() == 3 && dst_format_str.size() == 3, |
200 | errors::InvalidArgument( |
201 | "Format specifier must contain D, H and W for 2D case" )); |
202 | } else { |
203 | DCHECK(spatial_dim_count == 2); |
204 | OP_REQUIRES(context, |
205 | src_format_str.size() == 2 && dst_format_str.size() == 2, |
206 | errors::InvalidArgument( |
207 | "Format specifier must contain H and W for 2D case" )); |
208 | } |
209 | } |
210 | ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx); |
211 | |
212 | functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(), |
213 | input.flat<T>(), |
214 | output->flat<T>(), dst_idx); |
215 | } |
216 | |
217 | private: |
218 | // Finds out the destination index. Support 1D and 2D cases. |
219 | // Example: HWNC --> NHWC |
220 | // 1D: dst = [1, 2, 0, 3], |
221 | // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7] |
222 | static void ComputeDstIndex(const string& src_format_str, |
223 | const string& dst_format_str, int num_dim, |
224 | Eigen::DSizes<Eigen::DenseIndex, 10>* dst) { |
225 | for (int i = 0; i < src_format_str.size(); ++i) { |
226 | for (int j = 0; j < dst_format_str.size(); ++j) { |
227 | if (dst_format_str[j] != src_format_str[i]) continue; |
228 | // Found the dst index. Set output based on the number of dims. |
229 | for (int k = 0; k < num_dim; ++k) { |
230 | (*dst)[i * num_dim + k] = j * num_dim + k; |
231 | } |
232 | } |
233 | } |
234 | } |
235 | |
236 | string src_format_; |
237 | string dst_format_; |
238 | }; |
239 | |
// Registration of the CPU implementations for int32 and int64 index types.
#define REGISTER_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DataFormatDimMap").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#define REGISTER_KERNEL(T)                                                    \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("DataFormatVecPermute").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL

// "host"-labeled variants: the same CPU kernels, selected when the op is
// explicitly placed with the "host" kernel label.
#define REGISTER_KERNEL(T)                             \
  REGISTER_KERNEL_BUILDER(Name("DataFormatDimMap")     \
                              .Device(DEVICE_CPU)      \
                              .Label("host")           \
                              .TypeConstraint<T>("T"), \
                          DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#define REGISTER_KERNEL(T)                             \
  REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute") \
                              .Device(DEVICE_CPU)      \
                              .Label("host")           \
                              .TypeConstraint<T>("T"), \
                          DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL
275 | |
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU. The
// definitions live in the .cu.cc translation unit; `extern template`
// suppresses implicit instantiation here.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                    \
  template <>                                                  \
  void DataFormatDimMap<GPUDevice, T>::operator()(             \
      const GPUDevice& d, typename TTypes<T>::ConstFlat x,     \
      typename TTypes<T>::Flat y, const TTypes<int>::Vec dst); \
  extern template struct DataFormatDimMap<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
// Undefine both helper macros so neither leaks into the rest of the file.
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC

#define DECLARE_GPU_SPEC(T)                                 \
  template <>                                               \
  void DataFormatVecPermute<GPUDevice, T>::operator()(      \
      const GPUDevice& d, typename TTypes<T>::ConstFlat x,  \
      typename TTypes<T>::Vec y,                            \
      const Eigen::DSizes<Eigen::DenseIndex, 10>& dst_idx); \
  extern template struct DataFormatVecPermute<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);
TF_CALL_int64(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DataFormatDimMap").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DataFormatDimMapOp<GPUDevice, T>);
TF_CALL_int32(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#define REGISTER_GPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("DataFormatVecPermute").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DataFormatVecPermuteOp<GPUDevice, T>);
TF_CALL_int32(REGISTER_GPU_KERNEL);
TF_CALL_int64(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
320 | |
// Registration of the DEVICE_DEFAULT implementations. These reuse the CPU
// kernels and pin both the input ("x") and output ("y") to host memory, so
// ops placed on other devices still execute on the host under the "host"
// kernel label.
#define REGISTER_DEVICE_DEFAULT_KERNEL(T)              \
  REGISTER_KERNEL_BUILDER(Name("DataFormatDimMap")     \
                              .Device(DEVICE_DEFAULT)  \
                              .HostMemory("x")         \
                              .HostMemory("y")         \
                              .Label("host")           \
                              .TypeConstraint<T>("T"), \
                          DataFormatDimMapOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_DEVICE_DEFAULT_KERNEL);
TF_CALL_int64(REGISTER_DEVICE_DEFAULT_KERNEL);
#undef REGISTER_DEVICE_DEFAULT_KERNEL

#define REGISTER_DEVICE_DEFAULT_KERNEL(T)              \
  REGISTER_KERNEL_BUILDER(Name("DataFormatVecPermute") \
                              .Device(DEVICE_DEFAULT)  \
                              .HostMemory("x")         \
                              .HostMemory("y")         \
                              .Label("host")           \
                              .TypeConstraint<T>("T"), \
                          DataFormatVecPermuteOp<CPUDevice, T>);
TF_CALL_int32(REGISTER_DEVICE_DEFAULT_KERNEL);
TF_CALL_int64(REGISTER_DEVICE_DEFAULT_KERNEL);
#undef REGISTER_DEVICE_DEFAULT_KERNEL
345 | |
346 | } // namespace tensorflow |
347 | |