1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/random_ops.cc.
17
18#define EIGEN_USE_THREADS
19
20#include "tensorflow/core/kernels/random_poisson_op.h"
21
22#include <algorithm>
23#include <cmath>
24#include <limits>
25#include <memory>
26
27#include "tensorflow/core/framework/op_kernel.h"
28#include "tensorflow/core/framework/register_types.h"
29#include "tensorflow/core/framework/tensor.h"
30#include "tensorflow/core/framework/tensor_shape.h"
31#include "tensorflow/core/framework/tensor_util.h"
32#include "tensorflow/core/lib/random/random_distributions.h"
33#include "tensorflow/core/lib/random/simple_philox.h"
34#include "tensorflow/core/util/guarded_philox_random.h"
35#include "tensorflow/core/util/work_sharder.h"
36
// Bracket a region of code in which GCC's -Wfloat-equal diagnostic is
// silenced (diagnostic push + ignore ... diagnostic pop). The pragmas are only
// emitted when EIGEN_COMP_GNUC is set and the translation unit is C++11 or
// newer (the first standard with _Pragma); otherwise both expand to nothing.
// NOTE(review): neither macro is referenced elsewhere in this file.
#if !(EIGEN_COMP_GNUC && __cplusplus > 199711L)
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#else
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
  _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#endif
46
// Declares a local `CT X` holding the next uniform variate.
//
// Relies on these names being in scope at the expansion site: `Uniform` (the
// distribution type), `uniform` (an instance of it), `gen` (the generator it
// draws from), `uniform_result` (the most recent batch of variates) and
// `uniform_remaining` (how many entries of that batch are still unused).
// A new batch of Uniform::kResultElementCount values is drawn only once the
// current one is exhausted; entries are handed out from the back.
// Expands to several statements (not a single do/while), because X must
// remain visible after the macro.
#define UNIFORM(X)                                    \
  if (!uniform_remaining) {                           \
    uniform_result = uniform(&gen);                   \
    uniform_remaining = Uniform::kResultElementCount; \
  }                                                   \
  CT X = uniform_result[--uniform_remaining]
54
55namespace tensorflow {
56namespace {
57
58static constexpr int kReservedSamplesPerOutput = 256;
59
60typedef Eigen::ThreadPoolDevice CPUDevice;
61
62template <typename T>
63struct PoissonComputeType {
64 typedef double ComputeType;
65};
66
67} // namespace
68
69namespace functor {
70
71template <typename T, typename U>
72struct PoissonFunctor<CPUDevice, T, U> {
73 void operator()(OpKernelContext* ctx, const CPUDevice& d, const T* rate_flat,
74 int num_rate, int num_samples,
75 const random::PhiloxRandom& rng, U* samples_flat) {
76 // Two different algorithms are employed, depending on the size of
77 // rate.
78 // If rate < 10, we use an algorithm attributed to Knuth:
79 // Seminumerical Algorithms. Art of Computer Programming, Volume 2.
80 //
81 // This algorithm runs in O(rate) time, and will require O(rate)
82 // uniform variates.
83 //
84 // If rate >= 10 we use a transformation-rejection algorithm from
85 // pairs of uniform random variables due to Hormann.
86 // http://www.sciencedirect.com/science/article/pii/0167668793909974
87 //
88 // The algorithm has an acceptance rate of ~89% for the smallest rate
89 // (~10),
90 // and higher accept rates for higher rate, so runtime is
91 // O(NumRate * NumSamples * k) with k ~ 1 / 0.89.
92 //
93 // We partition work first across rates then across
94 // samples-per-rate to
95 // avoid a couple flops which can be done on a per-rate basis.
96
97 typedef random::UniformDistribution<random::PhiloxRandom, CT> Uniform;
98
99 auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat](
100 int64_t start_output, int64_t limit_output) {
101 // Capturing "rng" by value would only make a copy for the _shared_
102 // lambda. Since we want to let each worker have its own copy, we pass
103 // "rng" by reference and explicitly do a copy assignment.
104
105 Uniform uniform;
106 typename Uniform::ResultType uniform_result;
107 for (int64_t output_idx = start_output; output_idx < limit_output;
108 /* output_idx incremented within inner loop below */) {
109 const int64_t rate_idx = output_idx / num_samples;
110
111 // Several calculations can be done on a per-rate basis.
112 const CT rate = CT(rate_flat[rate_idx]);
113
114 auto samples_rate_output = samples_flat + rate_idx;
115
116 if (rate < CT(10)) {
117 // Knuth's algorithm for generating Poisson random variates.
118 // Given a Poisson process, the time between events is exponentially
119 // distributed. If we have a Poisson process with rate lambda, then,
120 // the time between events is distributed Exp(lambda). If X ~
121 // Uniform(0, 1), then Y ~ Exp(lambda), where Y = -log(X) / lambda.
122 // Thus to simulate a Poisson draw, we can draw X_i ~ Exp(lambda),
123 // and N ~ Poisson(lambda), where N is the least number such that
124 // \sum_i^N X_i > 1.
125 const CT exp_neg_rate = Eigen::numext::exp(-rate);
126
127 // Compute the rest of the samples for the current rate value.
128 for (int64_t sample_idx = output_idx % num_samples;
129 sample_idx < num_samples && output_idx < limit_output;
130 sample_idx++, output_idx++) {
131 random::PhiloxRandom gen = rng;
132 gen.Skip(kReservedSamplesPerOutput * output_idx);
133 int16_t uniform_remaining = 0;
134
135 CT prod = 1;
136 CT x = 0;
137
138 // Keep trying until we surpass e^(-rate). This will take
139 // expected time proportional to rate.
140 while (true) {
141 UNIFORM(u);
142 prod = prod * u;
143 if (prod <= exp_neg_rate &&
144 x <= CT(Eigen::NumTraits<U>::highest())) {
145 samples_rate_output[sample_idx * num_rate] = U(x);
146 break;
147 }
148 x += 1;
149 }
150 }
151 continue;
152 }
153 if (Eigen::numext::isinf(rate) && rate > CT(0)) {
154 // Fill the rest of the samples for the current rate value.
155 for (int64_t sample_idx = output_idx % num_samples;
156 sample_idx < num_samples && output_idx < limit_output;
157 sample_idx++, output_idx++) {
158 U k = Eigen::NumTraits<U>::infinity();
159 samples_rate_output[sample_idx * num_rate] = k;
160 }
161 continue;
162 }
163 // Transformed rejection due to Hormann.
164 //
165 // Given a CDF F(x), and G(x), a dominating distribution chosen such
166 // that it is close to the inverse CDF F^-1(x), compute the following
167 // steps:
168 //
169 // 1) Generate U and V, two independent random variates. Set U = U - 0.5
170 // (this step isn't strictly necessary, but is done to make some
171 // calculations symmetric and convenient. Henceforth, G is defined on
172 // [-0.5, 0.5]).
173 //
174 // 2) If V <= alpha * F'(G(U)) * G'(U), return floor(G(U)), else return
175 // to step 1. alpha is the acceptance probability of the rejection
176 // algorithm.
177 //
178 // For more details on transformed rejection, see:
179 // http://citeseer.ist.psu.edu/viewdoc/citations;jsessionid=1BEB35946CC807879F55D42512E5490C?doi=10.1.1.48.3054.
180 //
181 // The dominating distribution in this case:
182 //
183 // G(u) = (2 * a / (2 - |u|) + b) * u + c
184
185 using Eigen::numext::log;
186 const CT log_rate = log(rate);
187
188 // Constants used to define the dominating distribution. Names taken
189 // from Hormann's paper. Constants were chosen to define the tightest
190 // G(u) for the inverse Poisson CDF.
191 const CT b = CT(0.931) + CT(2.53) * Eigen::numext::sqrt(rate);
192 const CT a = CT(-0.059) + CT(0.02483) * b;
193
194 // This is the inverse acceptance rate. At a minimum (when rate = 10),
195 // this corresponds to ~75% acceptance. As the rate becomes larger, this
196 // approaches ~89%.
197 const CT inv_alpha = CT(1.1239) + CT(1.1328) / (b - CT(3.4));
198
199 // Compute the rest of the samples for the current rate value.
200 for (int64_t sample_idx = output_idx % num_samples;
201 sample_idx < num_samples && output_idx < limit_output;
202 sample_idx++, output_idx++) {
203 random::PhiloxRandom gen = rng;
204 gen.Skip(kReservedSamplesPerOutput * output_idx);
205 int16_t uniform_remaining = 0;
206
207 while (true) {
208 UNIFORM(u);
209 u -= CT(0.5);
210 UNIFORM(v);
211
212 CT u_shifted = CT(0.5) - Eigen::numext::abs(u);
213 CT k = Eigen::numext::floor((CT(2) * a / u_shifted + b) * u + rate +
214 CT(0.43));
215
216 if (k > CT(Eigen::NumTraits<U>::highest())) {
217 // retry in case of overflow.
218 continue;
219 }
220
221 // When alpha * f(G(U)) * G'(U) is close to 1, it is possible to
222 // find a rectangle (-u_r, u_r) x (0, v_r) under the curve, such
223 // that if v <= v_r and |u| <= u_r, then we can accept.
224 // Here v_r = 0.9227 - 3.6224 / (b - 2) and u_r = 0.43.
225 if (u_shifted >= CT(0.07) &&
226 v <= CT(0.9277) - CT(3.6224) / (b - CT(2))) {
227 samples_rate_output[sample_idx * num_rate] = U(k);
228 break;
229 }
230
231 if (k < 0 || (u_shifted < CT(0.013) && v > u_shifted)) {
232 continue;
233 }
234
235 // The expression below is equivalent to the computation of step 2)
236 // in transformed rejection (v <= alpha * F'(G(u)) * G'(u)).
237 CT s = log(v * inv_alpha / (a / (u_shifted * u_shifted) + b));
238 CT t = -rate + k * log_rate - Eigen::numext::lgamma(k + 1);
239 if (s <= t) {
240 samples_rate_output[sample_idx * num_rate] = U(k);
241 break;
242 }
243 }
244 }
245 }
246 };
247
248 // This will depend on rate.
249 // For rate < 10, on average, O(rate) calls to uniform are
250 // needed, with that
251 // many multiplies. ~10 uniform calls on average with ~25 cost op calls.
252 //
253 // Very roughly, for rate >= 10, the single call to log + call to
254 // lgamma
255 // occur for ~60 percent of samples.
256 // 2 x 100 (64-bit cycles per log) * 0.62 = ~124
257 // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each:
258 // 40 * .62 = ~25.
259 //
260 // Finally, there are several other ops that are done every loop along with
261 // 2 uniform generations along with 5 other ops at 3-6 cycles each.
262 // ~15 / .89 = ~16
263 //
264 // In total this should be ~165 + 2 * Uniform::kElementCost.
265 // We assume that half the tensor has rate < 10, so on average 6
266 // uniform's
267 // will be needed. We will upper bound the other op cost by the one for
268 // rate > 10.
269 static const int kElementCost = 165 + 6 * Uniform::kElementCost +
270 6 * random::PhiloxRandom::kElementCost;
271 auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
272 Shard(worker_threads.num_threads, worker_threads.workers,
273 num_rate * num_samples, kElementCost, DoWork);
274 }
275
276 private:
277 typedef typename PoissonComputeType<T>::ComputeType CT;
278};
279
280} // namespace functor
281
282namespace {
283
284// Samples from one or more Poisson distributions.
285template <typename T, typename U>
286class RandomPoissonOp : public OpKernel {
287 public:
288 explicit RandomPoissonOp(OpKernelConstruction* context) : OpKernel(context) {
289 OP_REQUIRES_OK(context, generator_.Init(context));
290 }
291
292 void Compute(OpKernelContext* ctx) override {
293 const Tensor& shape_t = ctx->input(0);
294 const Tensor& rate_t = ctx->input(1);
295
296 TensorShape samples_shape;
297 OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_t, &samples_shape));
298 const int64_t num_samples = samples_shape.num_elements();
299 OP_REQUIRES_OK(ctx, samples_shape.AppendShapeWithStatus(rate_t.shape()));
300
301 // Allocate output samples.
302 Tensor* samples_t = nullptr;
303 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));
304 if (num_samples == 0) return;
305
306 const auto rate_flat = rate_t.flat<T>().data();
307 const int64_t num_rate = rate_t.NumElements();
308 auto samples_flat = samples_t->flat<U>().data();
309 random::PhiloxRandom rng = generator_.ReserveRandomOutputs(
310 num_samples * num_rate, kReservedSamplesPerOutput);
311
312 functor::PoissonFunctor<CPUDevice, T, U>()(
313 ctx, ctx->eigen_device<CPUDevice>(), rate_flat, num_rate, num_samples,
314 rng, samples_flat);
315 }
316
317 private:
318 GuardedPhiloxRandom generator_;
319
320 TF_DISALLOW_COPY_AND_ASSIGN(RandomPoissonOp);
321};
322} // namespace
323
324#undef UNIFORM
325
326#define REGISTER(TYPE) \
327 REGISTER_KERNEL_BUILDER( \
328 Name("RandomPoisson").Device(DEVICE_CPU).TypeConstraint<TYPE>("dtype"), \
329 RandomPoissonOp<TYPE, TYPE>);
330
331TF_CALL_half(REGISTER);
332TF_CALL_float(REGISTER);
333TF_CALL_double(REGISTER);
334
335#define REGISTER_V2(RTYPE, OTYPE) \
336 template struct functor::PoissonFunctor<CPUDevice, RTYPE, OTYPE>; \
337 REGISTER_KERNEL_BUILDER(Name("RandomPoissonV2") \
338 .Device(DEVICE_CPU) \
339 .TypeConstraint<RTYPE>("R") \
340 .TypeConstraint<OTYPE>("dtype"), \
341 RandomPoissonOp<RTYPE, OTYPE>);
342
343#define REGISTER_ALL(RTYPE) \
344 REGISTER_V2(RTYPE, Eigen::half); \
345 REGISTER_V2(RTYPE, float); \
346 REGISTER_V2(RTYPE, double); \
347 REGISTER_V2(RTYPE, int32); \
348 REGISTER_V2(RTYPE, int64_t);
349
350REGISTER_ALL(Eigen::half);
351REGISTER_ALL(float);
352REGISTER_ALL(double);
353REGISTER_ALL(int32);
354REGISTER_ALL(int64_t);
355
356#undef REGISTER_ALL
357#undef REGISTER_V2
358#undef REGISTER
359
360} // end namespace tensorflow
361