1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_ |
18 | |
19 | #include "tensorflow/core/framework/tensor_types.h" |
20 | #include "tensorflow/core/lib/random/random_distributions.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | class OpKernelContext; |
25 | |
26 | namespace functor { |
27 | |
28 | // Sample a binomial random variable, with probs and counts for each batch. |
29 | // Uses binomial inversion and a transformed rejection sampling method as |
30 | // described in |
31 | // https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf. |
32 | // Two different algorithms are employed, depending on the size of |
33 | // counts * probs (or counts * (1 - probs) if probs > 0.5. |
34 | // If counts * probs < 10, we simply sum up Geometric random variables until |
35 | // they exceed count, and the number we used is binomially distributed. |
36 | // In expectation, this will take O(counts * probs) time, and requiring in |
37 | // expectation the same number of random variates. |
38 | // This can be much cheaper than summing bernoulli random variates, as we |
39 | // will always need O(counts) bernoulli random variates (so this requires fewer |
40 | // uniform r.v.s as well as can be faster). |
41 | // |
42 | // If counts * probs > 10, we use a transformed-rejection algorithm based on |
43 | // pairs of uniform random variates due to Hormann. |
44 | // https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf |
45 | // This algorithm has higher acceptance rates for counts * probs large, as the |
46 | // proposal distribution becomes quite tight, requiring approximately two |
47 | // uniform random variates as counts * probs becomes large. |
48 | template <typename Device, typename T, typename U> |
49 | struct RandomBinomialFunctor { |
50 | void operator()(OpKernelContext* ctx, const Device& d, int64_t num_batches, |
51 | int64_t samples_per_batch, int64_t num_elements, |
52 | typename TTypes<T>::ConstFlat counts, |
53 | typename TTypes<T>::ConstFlat probs, |
54 | const random::PhiloxRandom& gen, |
55 | typename TTypes<U>::Flat output); |
56 | }; |
57 | |
58 | } // namespace functor |
59 | } // namespace tensorflow |
60 | |
61 | #endif // TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_ |
62 | |