/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.
// NOTE: If the algorithm is changed, please run the test
// .../python/kernel_tests/random:random_binomial_test
// with the "tf.set_random_seed(seed)" lines commented out and with the
// "--runs-per-test=1000" flag; this checks the statistical correctness of the
// op results.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/random_binomial_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

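// Draws one double from a cached batch of Philox-generated uniforms. The
// distribution returns Uniform::kResultElementCount values per call, so the
// macro refills the cache only when it is exhausted and then pops a single
// value into a new local variable named X.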
#define UNIFORM(X)                                    \
  if (uniform_remaining == 0) {                       \
    uniform_remaining = Uniform::kResultElementCount; \
    uniform_result = uniform(gen);                    \
  }                                                   \
  uniform_remaining--;                                \
  double X = uniform_result[uniform_remaining]

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;

// Binomial inversion: given prob, sum geometric random variables until the
// running sum exceeds count. The number of variables drawn before that happens
// is binomially distributed (with mean count * prob), so this is equivalent to
// inverting the binomial CDF.
double binomial_inversion(double count, double prob,
                          random::PhiloxRandom* gen) {
  using Eigen::numext::ceil;
  using Eigen::numext::log;
  using Eigen::numext::log1p;

  double geom_sum = 0;
  int num_geom = 0;

  Uniform uniform;
  typename Uniform::ResultType uniform_result;
  int16_t uniform_remaining = 0;

  while (true) {
    UNIFORM(u);
    double geom = ceil(log(u) / log1p(-prob));
    geom_sum += geom;
    if (geom_sum > count) {
      break;
    }
    ++num_geom;
  }
  return num_geom;
}

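// Returns the tail delta(k) of the Stirling series for log(k!), i.e. the
// correction term in
//   log(k!) = (k + 1/2) * log(k + 1) - (k + 1) + 0.5 * log(2 * pi) + delta(k).
// Exact values are tabulated for k <= 9; for larger k the truncated asymptotic
// series 1/(12(k+1)) - 1/(360(k+1)^3) + 1/(1260(k+1)^5) is used.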
inline double stirling_approx_tail(double k) {
  static double kTailValues[] = {0.0810614667953272,  0.0413406959554092,
                                 0.0276779256849983,  0.02079067210376509,
                                 0.0166446911898211,  0.0138761288230707,
                                 0.0118967099458917,  0.0104112652619720,
                                 0.00925546218271273, 0.00833056343336287};
  if (k <= 9) {
    return kTailValues[static_cast<int>(k)];
  }
  double kp1sq = (k + 1) * (k + 1);
  return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
}

// We use a transformed rejection algorithm (Hormann's BTRS), driven by pairs
// of uniform random variables.
// https://www.tandfonline.com/doi/abs/10.1080/00949659308811496
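// At a high level, each iteration of the loop below draws a pair of uniforms
// (u, v), transforms u into an integer candidate k, and accepts k right away
// when (u, v) lands in the cheap "squeeze" region. Otherwise it falls through
// to the full acceptance test, which compares a transformed log(v) against a
// bound on the log binomial pmf ratio computed with Stirling-series tail
// corrections (stirling_approx_tail).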
inline double btrs(double count, double prob, random::PhiloxRandom* gen) {
  using Eigen::numext::abs;
  using Eigen::numext::floor;
  using Eigen::numext::log;
  using Eigen::numext::log1p;
  using Eigen::numext::sqrt;

  // This is spq in the paper.
  const double stddev = sqrt(count * prob * (1 - prob));

  // Other coefficients for Transformed Rejection sampling.
  const double b = 1.15 + 2.53 * stddev;
  const double a = -0.0873 + 0.0248 * b + 0.01 * prob;
  const double c = count * prob + 0.5;
  const double v_r = 0.92 - 4.2 / b;
  const double r = prob / (1 - prob);

  const double alpha = (2.83 + 5.1 / b) * stddev;
  const double m = floor((count + 1) * prob);

  Uniform uniform;
  typename Uniform::ResultType uniform_result;
  int16_t uniform_remaining = 0;

  while (true) {
    UNIFORM(u);
    UNIFORM(v);
    u = u - 0.5;
    double us = 0.5 - abs(u);
    double k = floor((2 * a / us + b) * u + c);

    // Region for which the box is tight, and we can return our calculated
    // value. This should happen with probability 0.86 * v_r. In the limit as
    // n * p gets large, the acceptance rate converges to ~79% (and in the
    // lower regime it is ~24%).
    if (us >= 0.07 && v <= v_r) {
      return k;
    }
    // Reject nonsensical answers.
    if (k < 0 || k > count) {
      continue;
    }

    // This deviates from Hormann's BTRS algorithm, as there is a log missing.
    // For all (u, v) pairs outside of the bounding box, this calculates the
    // transformed-rejection ratio.
    v = log(v * alpha / (a / (us * us) + b));
    double upperbound =
        ((m + 0.5) * log((m + 1) / (r * (count - m + 1))) +
         (count + 1) * log((count - m + 1) / (count - k + 1)) +
         (k + 0.5) * log(r * (count - k + 1) / (k + 1)) +
         stirling_approx_tail(m) + stirling_approx_tail(count - m) -
         stirling_approx_tail(k) - stirling_approx_tail(count - k));
    if (v <= upperbound) {
      return k;
    }
  }
}

}  // namespace

namespace functor {

template <typename T, typename U>
struct RandomBinomialFunctor<CPUDevice, T, U> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  int64_t num_batches, int64_t samples_per_batch,
                  int64_t num_elements, const BCast& bcast,
                  typename TTypes<T>::ConstFlat counts,
                  typename TTypes<T>::ConstFlat probs,
                  const random::PhiloxRandom& gen,
                  typename TTypes<U>::Flat output) {
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    // The output layout is [B1, ... Bk, H1, ... Hm]: [B1, ... Bk] is the
    // sample shape and [H1, ... Hm] is the batch shape of the samples. We need
    // B1 * ... * Bk samples for every batch member.
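    // With that layout, the flattened output index of a sample is
    // sample_idx * num_batches + batch_idx, which is why the writes below
    // offset by batch_idx and stride by num_batches.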
    auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs,
                   &gen, &output](int64_t start_output, int64_t limit_output) {
      // Vectorized intermediate calculations for uniform rejection sampling.
      // We always generate at most 4 samples.
      Eigen::array<T, 4> z;
      Eigen::array<T, 4> g;
      const bool should_bcast = bcast.IsBroadcastingRequired();
      const auto& counts_batch_indices = bcast.x_batch_indices();
      const auto& probs_batch_indices = bcast.y_batch_indices();
      auto output_flat = output.data();

      // We partition work across batches (count, prob) and then across samples
      // per batch member, to avoid extra work.
      for (int64_t output_idx = start_output; output_idx < limit_output;
           // output_idx is incremented with the inner loops below.
      ) {
        int64_t batch_idx = output_idx / samples_per_batch;
        U* const output_batch_offset = output_flat + batch_idx;
        // Generate batch counts from BCast, as it has the right indices to
        // loop over.
        T count, prob;
        if (should_bcast) {
          count = counts(counts_batch_indices[batch_idx]);
          prob = probs(probs_batch_indices[batch_idx]);
        } else {
          count = counts(batch_idx);
          prob = probs(batch_idx);
        }

        // Determine which sampling method to use, based on count and prob.
        double dcount = static_cast<double>(count);
        if (dcount <= 0.0 || prob <= T(0.0)) {
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;
               ++sample_idx, ++output_idx) {
            output_batch_offset[sample_idx * num_batches] = static_cast<U>(0.0);
          }
        } else if (prob >= T(1.0)) {
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;
               ++sample_idx, ++output_idx) {
            output_batch_offset[sample_idx * num_batches] =
                static_cast<U>(dcount);
          }
        } else if (prob <= T(0.5)) {
          double dp = static_cast<double>(prob);
          if (count * prob >= T(10)) {
            for (int64_t sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
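              // Skip ahead so that each output element consumes its own,
              // disjoint slice of the Philox stream; 256 results per element
              // is a conservative reservation, since BTRS draws a variable
              // number of uniforms per accepted sample.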
              random::PhiloxRandom gen_copy = gen;
              gen_copy.Skip(256 * output_idx);
              output_batch_offset[sample_idx * num_batches] =
                  static_cast<U>(btrs(dcount, dp, &gen_copy));
            }
          } else {
            for (int64_t sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
              // For binomial inversion, we have mean <= 10 and variance <= 10,
              // so on average at most 10 samples are needed; 42 samples cover
              // 10 standard deviations. We reserve that much.
              gen_copy.Skip(42 * output_idx);
              output_batch_offset[sample_idx * num_batches] =
                  static_cast<U>(binomial_inversion(dcount, dp, &gen_copy));
            }
          }
        } else if (prob > T(0.5)) {
          T q = T(1) - prob;
          double dq = static_cast<double>(q);
          if (count * q >= T(10)) {
            for (int64_t sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
              gen_copy.Skip(256 * output_idx);
              output_batch_offset[sample_idx * num_batches] =
                  static_cast<U>(dcount - btrs(dcount, dq, &gen_copy));
            }
          } else {
            for (int64_t sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
              // For binomial inversion, we have mean <= 10 and variance <= 10,
              // so on average at most 10 samples are needed; 42 samples cover
              // 10 standard deviations. We reserve that much.
              gen_copy.Skip(42 * output_idx);
              output_batch_offset[sample_idx * num_batches] = static_cast<U>(
                  dcount - binomial_inversion(dcount, dq, &gen_copy));
            }
          }
        } else {  // prob is NaN
          // TODO(srvasude): What should happen if prob is NaN but the output
          // type is an integer (which doesn't have a sentinel for NaN)? Fail
          // the whole batch sample? Return a specialized sentinel like -1?
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;
               ++sample_idx, ++output_idx) {
            output_batch_offset[sample_idx * num_batches] = static_cast<U>(NAN);
          }
        }
      }
    };

    // The cost depends on count * prob (or count * q).
    //
    // For count * prob < 10 (binomial inversion), on average O(count * prob)
    // uniform draws are needed, with that many multiplies: roughly 10 uniform
    // calls on average, plus ~200 cycles of other ops.
    //
    // For rate >= 10 (BTRS), very roughly, the four calls to log occur for
    // ~72 percent of samples:
    //   4 * 100 (64-bit cycles per log) * 0.72 = ~288 cycles.
    // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each:
    //   40 * 0.72 = ~25 cycles.
    // Finally, every loop iteration performs 2 uniform generations plus ~5
    // other ops at 3-6 cycles each: ~15 / 0.89 = ~16 cycles.
    //
    // In total, the rate >= 10 case should cost ~329 cycles plus
    // 2 * Uniform::kElementCost. We assume that half the tensor has rate < 10,
    // so on average 6 uniform draws are needed, and we upper-bound the other
    // op cost by the rate >= 10 case.
    static const int kElementCost = 329 + 6 * Uniform::kElementCost +
                                    6 * random::PhiloxRandom::kElementCost;
    Shard(worker_threads.num_threads, worker_threads.workers, num_elements,
          kElementCost, DoWork);
  }
};

}  // namespace functor

namespace {

// Samples from a binomial distribution, using the given parameters. This is
// the stateful variant: the Philox RNG state is read from, and written back
// to, the resource variable passed as the op's first input.
template <typename Device, typename T, typename U>
class RandomBinomialOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static constexpr int32_t kDesiredBatchSize = 100;

 public:
  explicit RandomBinomialOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& alg_tensor = ctx->input(1);
    const Tensor& shape_tensor = ctx->input(2);
    const Tensor& counts_tensor = ctx->input(3);
    const Tensor& probs_tensor = ctx->input(4);

    tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(),
                            probs_tensor.shape().dim_sizes(),
                            /*fewer_dims_optimization=*/false,
                            /*return_flattened_batch_indices=*/true);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "counts and probs must have compatible batch dimensions: ",
                    counts_tensor.shape().DebugString(), " vs. ",
                    probs_tensor.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    OP_REQUIRES(ctx,
                (shape_tensor.dtype() == DataType::DT_INT32 ||
                 shape_tensor.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "Input shape should have dtype {int32, int64}."));

    // Let's check that the shape tensor dominates the broadcasted tensor.
    TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
    TensorShape output_shape;
    if (shape_tensor.dtype() == DataType::DT_INT32) {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                              shape_tensor.vec<int32>(), &output_shape));
    } else {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                              shape_tensor.vec<int64_t>(), &output_shape));
    }
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));
    // Now that we have a guarantee, we can get the additional dimensions added
    // by sampling.
    OP_REQUIRES(ctx, alg_tensor.dims() == 0,
                errors::InvalidArgument("algorithm must be of shape [], not ",
                                        alg_tensor.shape().DebugString()));
    Algorithm alg = Algorithm(alg_tensor.flat<int64_t>()(0));

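    // Example (illustrative): if shape = [S1, S2, B1, B2] and the broadcast of
    // counts/probs has shape [B1, B2], then num_sample_dims = 2,
    // samples_per_batch = S1 * S2, and num_batches = B1 * B2.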
    int64_t samples_per_batch = 1;
    const int64_t num_sample_dims =
        (shape_tensor.dim_size(0) - bcast.output_shape().size());
    for (int64_t i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= shape_tensor.flat<int32>()(i);
    }
    int64_t num_batches = 1;
    for (int64_t i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= shape_tensor.flat<int32>()(i);
    }
    const int64_t num_elements = num_batches * samples_per_batch;

    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, output_shape, &samples_tensor));

    core::RefCountPtr<Var> var;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));

    Tensor* var_tensor = var->tensor();
    OP_REQUIRES(
        ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE,
        errors::InvalidArgument("dtype of RNG state variable must be ",
                                DataTypeString(STATE_ELEMENT_DTYPE), ", not ",
                                DataTypeString(var_tensor->dtype())));
    OP_REQUIRES(ctx, var_tensor->dims() == 1,
                errors::InvalidArgument(
                    "RNG state must have one and only one dimension, not ",
                    var_tensor->dims()));
    auto var_tensor_flat = var_tensor->flat<StateElementType>();
    OP_REQUIRES(ctx, alg == RNG_ALG_PHILOX,
                errors::InvalidArgument("Unsupported algorithm id: ", alg));
    static_assert(std::is_same<StateElementType, int64_t>::value,
                  "StateElementType must be int64");
    static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
                  "PhiloxRandom::ResultElementType must be uint32");
    OP_REQUIRES(ctx, var_tensor_flat.size() >= PHILOX_MIN_STATE_SIZE,
                errors::InvalidArgument(
                    "For Philox algorithm, the size of state must be at least ",
                    PHILOX_MIN_STATE_SIZE, "; got ", var_tensor_flat.size()));

    OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>(
                            ctx, var_tensor, var->copy_on_read_mode.load()));
    auto var_data = var_tensor_flat.data();
    auto philox = GetPhiloxRandomFromMem(var_data);
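    // Write the RNG state back to the variable, advanced by a reservation that
    // is intended to cover the Philox outputs this op will consume, so that
    // subsequent stateful ops draw from fresh parts of the stream.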
    UpdateMemWithPhiloxRandom(
        philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data);

    auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
    binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
                     samples_per_batch, num_elements, bcast,
                     counts_tensor.flat<T>(), probs_tensor.flat<T>(), philox,
                     samples_tensor->flat<U>());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomBinomialOp);
};

// Samples from a binomial distribution, using the given parameters. This is
// the stateless variant: the Philox key and counter are derived entirely from
// the seed input, so the same seed always yields the same samples.
template <typename Device, typename T, typename U>
class StatelessRandomBinomialOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static constexpr int32_t kDesiredBatchSize = 100;

 public:
  explicit StatelessRandomBinomialOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& seed_tensor = ctx->input(1);
    const Tensor& counts_tensor = ctx->input(2);
    const Tensor& probs_tensor = ctx->input(3);

    OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_tensor.shape().DebugString()));

    tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(),
                            probs_tensor.shape().dim_sizes(),
                            /*fewer_dims_optimization=*/false,
                            /*return_flattened_batch_indices=*/true);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "counts and probs must have compatible batch dimensions: ",
                    counts_tensor.shape().DebugString(), " vs. ",
                    probs_tensor.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    OP_REQUIRES(ctx,
                (shape_tensor.dtype() == DataType::DT_INT32 ||
                 shape_tensor.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "Input shape should have dtype {int32, int64}."));

    // Let's check that the shape tensor dominates the broadcasted tensor.
    TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
    TensorShape output_shape;
    if (shape_tensor.dtype() == DataType::DT_INT32) {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                              shape_tensor.vec<int32>(), &output_shape));
    } else {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                              shape_tensor.vec<int64_t>(), &output_shape));
    }
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));
    // Now that we have a guarantee, we can get the additional dimensions added
    // by sampling.
    int64_t samples_per_batch = 1;
    const int64_t num_sample_dims =
        (shape_tensor.dim_size(0) - bcast.output_shape().size());
    for (int64_t i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= shape_tensor.flat<int32>()(i);
    }
    int64_t num_batches = 1;
    for (int64_t i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= shape_tensor.flat<int32>()(i);
    }
    const int64_t num_elements = num_batches * samples_per_batch;

    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, output_shape, &samples_tensor));
    if (output_shape.num_elements() == 0) return;

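    // Derive the Philox key and counter deterministically from the two seed
    // values, so the op produces identical samples for identical seeds.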
    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter));

    auto philox = random::PhiloxRandom(counter, key);
    auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
    binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
                     samples_per_batch, num_elements, bcast,
                     counts_tensor.flat<T>(), probs_tensor.flat<T>(), philox,
                     samples_tensor->flat<U>());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomBinomialOp);
};

}  // namespace

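// Registers CPU kernels for StatefulRandomBinomial and StatelessRandomBinomial
// for every combination of output dtype RTYPE (half, float, double, int32,
// int64) and counts/probs dtype TYPE (half, float, double).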
#define REGISTER(RTYPE, TYPE)                                        \
  REGISTER_KERNEL_BUILDER(Name("StatefulRandomBinomial")             \
                              .Device(DEVICE_CPU)                    \
                              .HostMemory("resource")                \
                              .HostMemory("algorithm")               \
                              .HostMemory("shape")                   \
                              .HostMemory("counts")                  \
                              .HostMemory("probs")                   \
                              .TypeConstraint<RTYPE>("dtype")        \
                              .TypeConstraint<TYPE>("T"),            \
                          RandomBinomialOp<CPUDevice, TYPE, RTYPE>); \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomBinomial")            \
                              .Device(DEVICE_CPU)                    \
                              .HostMemory("shape")                   \
                              .HostMemory("seed")                    \
                              .HostMemory("counts")                  \
                              .HostMemory("probs")                   \
                              .TypeConstraint<RTYPE>("dtype")        \
                              .TypeConstraint<TYPE>("T"),            \
                          StatelessRandomBinomialOp<CPUDevice, TYPE, RTYPE>)

#define REGISTER_ALL(RTYPE)     \
  REGISTER(RTYPE, Eigen::half); \
  REGISTER(RTYPE, float);       \
  REGISTER(RTYPE, double);

REGISTER_ALL(Eigen::half);
REGISTER_ALL(float);
REGISTER_ALL(double);
REGISTER_ALL(int32);
REGISTER_ALL(int64_t);

#undef REGISTER
#undef REGISTER_ALL

}  // end namespace tensorflow