1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/math_ops.cc. |
17 | #define EIGEN_USE_THREADS |
18 | |
19 | #include <algorithm> |
20 | #include <cmath> |
21 | |
22 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
23 | #include "tensorflow/core/framework/op_kernel.h" |
24 | #include "tensorflow/core/framework/register_types.h" |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor_shape.h" |
27 | #include "tensorflow/core/framework/tensor_types.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/kernels/cross_op.h" |
30 | #include "tensorflow/core/lib/core/status.h" |
31 | #include "tensorflow/core/platform/logging.h" |
32 | #include "tensorflow/core/platform/types.h" |
33 | |
34 | namespace tensorflow { |
35 | |
36 | typedef Eigen::ThreadPoolDevice CPUDevice; |
37 | typedef Eigen::GpuDevice GPUDevice; |
38 | |
39 | template <typename Device, typename Type> |
40 | class CrossOp : public OpKernel { |
41 | public: |
42 | explicit CrossOp(OpKernelConstruction* context) : OpKernel(context) {} |
43 | |
44 | void Compute(OpKernelContext* context) override { |
45 | const Tensor& in0 = context->input(0); |
46 | const Tensor& in1 = context->input(1); |
47 | OP_REQUIRES(context, in0.shape() == in1.shape(), |
48 | errors::InvalidArgument("Both inputs must be of same shape: " , |
49 | in0.shape().DebugString(), " vs. " , |
50 | in1.shape().DebugString())); |
51 | OP_REQUIRES(context, in0.dims() >= 1, |
52 | errors::InvalidArgument("Input must be at least 1D" , |
53 | in0.shape().DebugString())); |
54 | |
55 | // Cross-products only really make sense for three and |
56 | // seven dimensions, and the latter is very obscure. If there is |
57 | // demand, we could perhaps allow 2D vectors where the last |
58 | // element is taken to be zero, but for now, we simply require |
59 | // that all are 3D. |
60 | auto inner_dim = in0.dim_size(in0.dims() - 1); |
61 | OP_REQUIRES(context, inner_dim == 3, |
62 | errors::FailedPrecondition( |
63 | "Cross-products are only defined for 3-element vectors." )); |
64 | |
65 | // Create the output Tensor with the same dimensions as the input Tensors. |
66 | Tensor* output = nullptr; |
67 | OP_REQUIRES_OK(context, context->allocate_output(0, in0.shape(), &output)); |
68 | |
69 | // Make a canonical tensor, maintaining the last (3-vector) dimension, |
70 | // while flattening all others do give the functor easy to work with data. |
71 | typename TTypes<Type, 2>::ConstTensor in0_data = |
72 | in0.flat_inner_dims<Type>(); |
73 | typename TTypes<Type, 2>::ConstTensor in1_data = |
74 | in1.flat_inner_dims<Type>(); |
75 | typename TTypes<Type, 2>::Tensor output_data = |
76 | output->flat_inner_dims<Type>(); |
77 | |
78 | functor::Cross<Device, Type>()(context->eigen_device<Device>(), in0_data, |
79 | in1_data, output_data); |
80 | } |
81 | }; |
82 | |
83 | #define REGISTER_CPU_KERNEL(type) \ |
84 | REGISTER_KERNEL_BUILDER( \ |
85 | Name("Cross").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
86 | CrossOp<CPUDevice, type>); |
87 | TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNEL); |
88 | #undef REGISTER_CPU_KERNEL |
89 | |
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// GPU support. The functor::Cross specializations for GPUDevice are only
// declared here, not defined, so this translation unit does not build the GPU
// versions; their definitions come from compiling the separate _gpu.cu.cc
// source.
namespace functor {
#define DECLARE_CROSS_GPU_SPEC(scalar_type)                               \
  template <>                                                             \
  void Cross<GPUDevice, scalar_type>::operator()(                         \
      const GPUDevice& d, TTypes<scalar_type, 2>::ConstTensor in0_data,   \
      TTypes<scalar_type, 2>::ConstTensor in1_data,                       \
      TTypes<scalar_type, 2>::Tensor output_data);                        \
  extern template struct Cross<GPUDevice, scalar_type>;
TF_CALL_REAL_NUMBER_TYPES(DECLARE_CROSS_GPU_SPEC);
#undef DECLARE_CROSS_GPU_SPEC
}  // namespace functor

// Register the GPU implementation of "Cross" for the same set of real number
// types that were declared above.
#define REGISTER_CROSS_GPU(scalar_type)                                    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("Cross").Device(DEVICE_GPU).TypeConstraint<scalar_type>("T"),   \
      CrossOp<GPUDevice, scalar_type>);

TF_CALL_REAL_NUMBER_TYPES(REGISTER_CROSS_GPU);
#undef REGISTER_CROSS_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
112 | |
113 | } // namespace tensorflow |
114 | |