/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS
// TODO(b/31098934): Figure out why this is necessary here but not in
// any other place, e.g., the cwise lgamma ops.
#define EIGEN_HAS_C99_MATH 1

#include "tensorflow/core/kernels/betainc_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

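// Computes the regularized incomplete beta function I_x(a, b) elementwise
// over three inputs a, b, and x. Scalar inputs are broadcast against the
// non-scalar ones; any two non-scalar inputs must have identical shapes.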
template <typename Device, typename T>
class BetaincOp : public OpKernel {
 public:
  explicit BetaincOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
    const Tensor& x = ctx->input(2);

    const TensorShape& a_shape = a.shape();
    const TensorShape& b_shape = b.shape();
    const TensorShape& x_shape = x.shape();
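    // Pairwise consistency: any two inputs that are both non-scalar must
    // have identical shapes. Scalar inputs are exempt; they are broadcast
    // to the merged shape below.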
    if (a_shape.dims() > 0 && b_shape.dims() > 0) {
      OP_REQUIRES(ctx, a_shape == b_shape,
                  errors::InvalidArgument(
                      "Shapes of a and b are inconsistent: ",
                      a_shape.DebugString(), " vs. ", b_shape.DebugString()));
    }
    if (a_shape.dims() > 0 && x_shape.dims() > 0) {
      OP_REQUIRES(ctx, a_shape == x_shape,
                  errors::InvalidArgument(
                      "Shapes of a and x are inconsistent: ",
                      a_shape.DebugString(), " vs. ", x_shape.DebugString()));
    }
    if (b_shape.dims() > 0 && x_shape.dims() > 0) {
      OP_REQUIRES(ctx, b_shape == x_shape,
                  errors::InvalidArgument(
                      "Shapes of b and x are inconsistent: ",
                      b_shape.DebugString(), " vs. ", x_shape.DebugString()));
    }

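    // The output takes the shape of whichever input is non-scalar; the
    // checks above guarantee that all non-scalar shapes are identical. If
    // every input is a scalar, the output is a scalar as well.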
    TensorShape merged_shape(a_shape);
    if (b_shape.dims() > 0) merged_shape = b_shape;
    if (x_shape.dims() > 0) merged_shape = x_shape;

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, merged_shape, &output));

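    // Fast path: all three shapes match exactly, so no broadcasting is
    // needed and the inputs can be evaluated as flat rank-1 tensors.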
    if (a_shape == b_shape && a_shape == x_shape) {
      functor::Betainc<Device, T, 1> functor;
      functor(ctx->eigen_device<Device>(), a.flat<T>(), b.flat<T>(),
              x.flat<T>(), output->flat<T>());
      return;
    }

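    // Otherwise at least one input is a scalar. BCast computes, for each
    // input, a reshape and a broadcast multiplier relative to the merged
    // output shape, collapsing adjacent dimensions where possible.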
    auto merged_shape_vec = BCast::FromShape(merged_shape);
    BCast a_shaper(BCast::FromShape(a_shape), merged_shape_vec);
    BCast b_shaper(BCast::FromShape(b_shape), merged_shape_vec);
    BCast x_shaper(BCast::FromShape(x_shape), merged_shape_vec);

    int ndims = static_cast<int>(a_shaper.x_reshape().size());

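    // Only collapsed ranks 1 and 2 are handled; these cases must stay in
    // sync with the GPU functor specializations declared further below.
    // Any other rank is rejected with InvalidArgument.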
    switch (ndims) {
#define CASE(NDIM)                                                        \
  case NDIM: {                                                            \
    functor::Betainc<Device, T, NDIM> functor;                            \
    auto a_value = a.shaped<T, NDIM>(a_shaper.x_reshape());               \
    auto b_value = b.shaped<T, NDIM>(b_shaper.x_reshape());               \
    auto x_value = x.shaped<T, NDIM>(x_shaper.x_reshape());               \
    functor.BCast(ctx->eigen_device<Device>(), a_value,                   \
                  BCast::ToIndexArray<NDIM>(a_shaper.x_bcast()), b_value, \
                  BCast::ToIndexArray<NDIM>(b_shaper.x_bcast()), x_value, \
                  BCast::ToIndexArray<NDIM>(x_shaper.x_bcast()),          \
                  output->shaped<T, NDIM>(a_shaper.y_reshape()));         \
    return;                                                               \
  }

      CASE(1);
      CASE(2);
      default: {
        ctx->SetStatus(errors::InvalidArgument(
            "Broadcasting rank not supported: ", ndims));
        return;
      }
    }
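#undef CASE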
  }
};

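// Registration of the CPU implementations.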
#define REGISTER_KERNELS(type)                                      \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Betainc").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BetaincOp<CPUDevice, type>);

REGISTER_KERNELS(float);
REGISTER_KERNELS(double);
#undef REGISTER_KERNELS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// Forward declarations of the functor specializations for GPU.
namespace functor {
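// Specializations are declared only for NDIM 1 and 2, matching the
// CASE(1) and CASE(2) arms in BetaincOp::Compute above.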
#define DECLARE_GPU_SPEC_NDIM(T, NDIM)                               \
  template <>                                                        \
  void Betainc<GPUDevice, T, NDIM>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor a,   \
      typename TTypes<T, NDIM>::ConstTensor b,                       \
      typename TTypes<T, NDIM>::ConstTensor x,                       \
      typename TTypes<T, NDIM>::Tensor output);                      \
  template <>                                                        \
  void Betainc<GPUDevice, T, NDIM>::BCast(                           \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor a,   \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_a, \
      typename TTypes<T, NDIM>::ConstTensor b,                       \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_b, \
      typename TTypes<T, NDIM>::ConstTensor x,                       \
      const typename Eigen::array<Eigen::DenseIndex, NDIM>& bcast_x, \
      typename TTypes<T, NDIM>::Tensor output);                      \
  extern template struct Betainc<GPUDevice, T, NDIM>;

#define DECLARE_GPU_SPEC(T)   \
  DECLARE_GPU_SPEC_NDIM(T, 1) \
  DECLARE_GPU_SPEC_NDIM(T, 2)

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);

#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPEC_NDIM
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNELS(type)                                  \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Betainc").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BetaincOp<GPUDevice, type>);

REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow