1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/nn_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | // TODO(b/31098934): Figure out why this is necessary here but not in |
20 | // any other place, e.g., the cwise lgamma ops. |
21 | #define EIGEN_HAS_C99_MATH 1 |
22 | |
23 | #include "tensorflow/core/kernels/betainc_op.h" |
24 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
25 | #include "tensorflow/core/framework/numeric_op.h" |
26 | #include "tensorflow/core/framework/op_kernel.h" |
27 | #include "tensorflow/core/framework/register_types.h" |
28 | #include "tensorflow/core/framework/tensor.h" |
29 | #include "tensorflow/core/lib/core/errors.h" |
30 | #include "tensorflow/core/util/bcast.h" |
31 | |
namespace tensorflow {

// Shorthand aliases for the Eigen device types this kernel is templated on.
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
36 | |
37 | template <typename Device, typename T> |
38 | class BetaincOp : public OpKernel { |
39 | public: |
40 | explicit BetaincOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
41 | |
42 | void Compute(OpKernelContext* ctx) override { |
43 | const Tensor& a = ctx->input(0); |
44 | const Tensor& b = ctx->input(1); |
45 | const Tensor& x = ctx->input(2); |
46 | |
47 | const TensorShape& a_shape = a.shape(); |
48 | const TensorShape& b_shape = b.shape(); |
49 | const TensorShape& x_shape = x.shape(); |
50 | if (a_shape.dims() > 0 && b_shape.dims() > 0) { |
51 | OP_REQUIRES(ctx, a_shape == b_shape, |
52 | errors::InvalidArgument( |
53 | "Shapes of a and b are inconsistent: " , |
54 | a_shape.DebugString(), " vs. " , b_shape.DebugString())); |
55 | } |
56 | if (a_shape.dims() > 0 && x_shape.dims() > 0) { |
57 | OP_REQUIRES(ctx, a_shape == x_shape, |
58 | errors::InvalidArgument( |
59 | "Shapes of a and x are inconsistent: " , |
60 | a_shape.DebugString(), " vs. " , x_shape.DebugString())); |
61 | } |
62 | if (b_shape.dims() > 0 && x_shape.dims() > 0) { |
63 | OP_REQUIRES(ctx, b_shape == x_shape, |
64 | errors::InvalidArgument( |
65 | "Shapes of b and x are inconsistent: " , |
66 | b_shape.DebugString(), " vs. " , x_shape.DebugString())); |
67 | } |
68 | |
69 | TensorShape merged_shape(a_shape); |
70 | if (b_shape.dims() > 0) merged_shape = b_shape; |
71 | if (x_shape.dims() > 0) merged_shape = x_shape; |
72 | |
73 | Tensor* output = nullptr; |
74 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, merged_shape, &output)); |
75 | |
76 | if (a_shape == b_shape && a_shape == x_shape) { |
77 | functor::Betainc<Device, T, 1> functor; |
78 | functor(ctx->eigen_device<Device>(), a.flat<T>(), b.flat<T>(), |
79 | x.flat<T>(), output->flat<T>()); |
80 | return; |
81 | } |
82 | |
83 | auto merged_shape_vec = BCast::FromShape(merged_shape); |
84 | BCast a_shaper(BCast::FromShape(a_shape), merged_shape_vec); |
85 | BCast b_shaper(BCast::FromShape(b_shape), merged_shape_vec); |
86 | BCast x_shaper(BCast::FromShape(x_shape), merged_shape_vec); |
87 | |
88 | int ndims = static_cast<int>(a_shaper.x_reshape().size()); |
89 | |
90 | switch (ndims) { |
91 | #define CASE(NDIM) \ |
92 | case NDIM: { \ |
93 | functor::Betainc<Device, T, NDIM> functor; \ |
94 | auto a_value = a.shaped<T, NDIM>(a_shaper.x_reshape()); \ |
95 | auto b_value = b.shaped<T, NDIM>(b_shaper.x_reshape()); \ |
96 | auto x_value = x.shaped<T, NDIM>(x_shaper.x_reshape()); \ |
97 | functor.BCast(ctx->eigen_device<Device>(), a_value, \ |
98 | BCast::ToIndexArray<NDIM>(a_shaper.x_bcast()), b_value, \ |
99 | BCast::ToIndexArray<NDIM>(b_shaper.x_bcast()), x_value, \ |
100 | BCast::ToIndexArray<NDIM>(x_shaper.x_bcast()), \ |
101 | output->shaped<T, NDIM>(a_shaper.y_reshape())); \ |
102 | return; \ |
103 | } |
104 | |
105 | CASE(1); |
106 | CASE(2); |
107 | default: { |
108 | ctx->SetStatus(errors::InvalidArgument( |
109 | "Broadcasting rank not supported: " , ndims)); |
110 | return; |
111 | } |
112 | } |
113 | } |
114 | }; |
115 | |
// Registration of the CPU implementations.
#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Betainc").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BetaincOp<CPUDevice, type>);

REGISTER_KERNELS(float);
REGISTER_KERNELS(double);
#undef REGISTER_KERNELS
124 | |
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// Forward declarations of the functor specializations for GPU.
// NOTE(review): the `extern template` lines suppress instantiation here; the
// definitions are presumably compiled in the corresponding .cu.cc file —
// confirm against the build target.
namespace functor {
#define DECLARE_GPU_SPEC_NDIM(T, NDIM)                                  \
  template <>                                                           \
  void Betainc<GPUDevice, T, NDIM>::operator()(                         \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor a,      \
      typename TTypes<T, NDIM>::ConstTensor b,                          \
      typename TTypes<T, NDIM>::ConstTensor x,                          \
      typename TTypes<T, NDIM>::Tensor output);                         \
  template <>                                                           \
  void Betainc<GPUDevice, T, NDIM>::BCast(                              \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor a,      \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_a,    \
      typename TTypes<T, NDIM>::ConstTensor b,                          \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_b,    \
      typename TTypes<T, NDIM>::ConstTensor x,                          \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_x,    \
      typename TTypes<T, NDIM>::Tensor output);                         \
  extern template struct Betainc<GPUDevice, T, NDIM>;

// Only ranks 1 and 2 are declared, matching the CASE dispatch in Compute().
#define DECLARE_GPU_SPEC(T)   \
  DECLARE_GPU_SPEC_NDIM(T, 1) \
  DECLARE_GPU_SPEC_NDIM(T, 2)

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);

#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_NDIM
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNELS(type)                                  \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Betainc").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BetaincOp<GPUDevice, type>);

REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
169 | |
170 | } // namespace tensorflow |
171 | |