1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/random_ops.cc. |
// NOTE: If the algorithm is changed, please run the test
// .../python/kernel_tests/random:random_binomial_test
// with the "tf.set_random_seed(seed)" lines commented out and with the
// "--runs-per-test=1000" flag, to check the statistical correctness of the
// op's results.
22 | |
23 | #define EIGEN_USE_THREADS |
24 | |
25 | #include "tensorflow/core/kernels/random_binomial_op.h" |
26 | |
27 | #include <algorithm> |
28 | #include <cmath> |
29 | #include <memory> |
30 | |
31 | #include "tensorflow/core/framework/op_kernel.h" |
32 | #include "tensorflow/core/framework/register_types.h" |
33 | #include "tensorflow/core/framework/rng_alg.h" |
34 | #include "tensorflow/core/framework/tensor.h" |
35 | #include "tensorflow/core/framework/tensor_shape.h" |
36 | #include "tensorflow/core/kernels/random_ops_util.h" |
37 | #include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h" |
38 | #include "tensorflow/core/kernels/stateless_random_ops.h" |
39 | #include "tensorflow/core/kernels/training_op_helpers.h" |
40 | #include "tensorflow/core/lib/core/refcount.h" |
41 | #include "tensorflow/core/lib/random/random_distributions.h" |
42 | #include "tensorflow/core/platform/logging.h" |
43 | #include "tensorflow/core/util/bcast.h" |
44 | #include "tensorflow/core/util/guarded_philox_random.h" |
45 | #include "tensorflow/core/util/work_sharder.h" |
46 | |
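// Helper macro: draws one double from the current batch of Philox uniforms,
// refilling the batch via uniform(gen) once it is exhausted. It expects
// `gen`, `uniform`, `uniform_result`, and `uniform_remaining` to be in scope
// at the expansion site.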
47 | #define UNIFORM(X) \ |
48 | if (uniform_remaining == 0) { \ |
49 | uniform_remaining = Uniform::kResultElementCount; \ |
50 | uniform_result = uniform(gen); \ |
51 | } \ |
52 | uniform_remaining--; \ |
53 | double X = uniform_result[uniform_remaining] |
54 | |
55 | namespace tensorflow { |
56 | |
57 | typedef Eigen::ThreadPoolDevice CPUDevice; |
58 | typedef Eigen::GpuDevice GPUDevice; |
59 | |
60 | namespace { |
61 | |
62 | typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform; |
63 | |
// Binomial inversion: given prob, sum geometric random variables until the
// running sum exceeds count. The number of variables summed is then
// binomially distributed; this is equivalent to inverting the binomial CDF.
68 | double binomial_inversion(double count, double prob, |
69 | random::PhiloxRandom* gen) { |
70 | using Eigen::numext::ceil; |
71 | using Eigen::numext::log; |
72 | using Eigen::numext::log1p; |
73 | |
74 | double geom_sum = 0; |
75 | int num_geom = 0; |
76 | |
77 | Uniform uniform; |
78 | typename Uniform::ResultType uniform_result; |
79 | int16_t uniform_remaining = 0; |
80 | |
81 | while (true) { |
82 | UNIFORM(u); |
83 | double geom = ceil(log(u) / log1p(-prob)); |
84 | geom_sum += geom; |
85 | if (geom_sum > count) { |
86 | break; |
87 | } |
88 | ++num_geom; |
89 | } |
90 | return num_geom; |
91 | } |
92 | |
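// Approximates the Stirling tail term
//   f(k) = log(k!) - ((k + 1/2) * log(k + 1) - (k + 1) + log(2 * pi) / 2),
// using exact precomputed values for k <= 9 and the first three terms of the
// asymptotic series in 1 / (k + 1) otherwise.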
93 | inline double stirling_approx_tail(double k) { |
94 | static double kTailValues[] = {0.0810614667953272, 0.0413406959554092, |
95 | 0.0276779256849983, 0.02079067210376509, |
96 | 0.0166446911898211, 0.0138761288230707, |
97 | 0.0118967099458917, 0.0104112652619720, |
98 | 0.00925546218271273, 0.00833056343336287}; |
99 | if (k <= 9) { |
100 | return kTailValues[static_cast<int>(k)]; |
101 | } |
102 | double kp1sq = (k + 1) * (k + 1); |
103 | return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1); |
104 | } |
105 | |
// We use the transformed-rejection algorithm (BTRS) due to Hormann, which
// generates a binomial sample from pairs of uniform random variables.
108 | // https://www.tandfonline.com/doi/abs/10.1080/00949659308811496 |
109 | inline double btrs(double count, double prob, random::PhiloxRandom* gen) { |
110 | using Eigen::numext::abs; |
111 | using Eigen::numext::floor; |
112 | using Eigen::numext::log; |
113 | using Eigen::numext::log1p; |
114 | using Eigen::numext::sqrt; |
115 | |
116 | // This is spq in the paper. |
117 | const double stddev = sqrt(count * prob * (1 - prob)); |
118 | |
119 | // Other coefficients for Transformed Rejection sampling. |
120 | const double b = 1.15 + 2.53 * stddev; |
121 | const double a = -0.0873 + 0.0248 * b + 0.01 * prob; |
122 | const double c = count * prob + 0.5; |
123 | const double v_r = 0.92 - 4.2 / b; |
124 | const double r = prob / (1 - prob); |
125 | |
126 | const double alpha = (2.83 + 5.1 / b) * stddev; |
127 | const double m = floor((count + 1) * prob); |
128 | |
129 | Uniform uniform; |
130 | typename Uniform::ResultType uniform_result; |
131 | int16_t uniform_remaining = 0; |
132 | |
133 | while (true) { |
134 | UNIFORM(u); |
135 | UNIFORM(v); |
136 | u = u - 0.5; |
137 | double us = 0.5 - abs(u); |
138 | double k = floor((2 * a / us + b) * u + c); |
139 | |
    // Region for which the bounding box is tight, so we can accept and
    // return the calculated value immediately; this should happen about
    // 0.86 * v_r of the time. In the limit as n * p grows large, the
    // acceptance rate converges to ~79% (and in the lower regime it is
    // ~24%).
145 | if (us >= 0.07 && v <= v_r) { |
146 | return k; |
147 | } |
    // Reject candidates outside the valid range [0, count].
149 | if (k < 0 || k > count) { |
150 | continue; |
151 | } |
152 | |
    // This deviates from Hormann's BTRS algorithm, as there is a log missing.
154 | // For all (u, v) pairs outside of the bounding box, this calculates the |
155 | // transformed-reject ratio. |
156 | v = log(v * alpha / (a / (us * us) + b)); |
157 | double upperbound = |
158 | ((m + 0.5) * log((m + 1) / (r * (count - m + 1))) + |
159 | (count + 1) * log((count - m + 1) / (count - k + 1)) + |
160 | (k + 0.5) * log(r * (count - k + 1) / (k + 1)) + |
161 | stirling_approx_tail(m) + stirling_approx_tail(count - m) - |
162 | stirling_approx_tail(k) - stirling_approx_tail(count - k)); |
163 | if (v <= upperbound) { |
164 | return k; |
165 | } |
166 | } |
167 | } |
168 | |
169 | } // namespace |
170 | |
171 | namespace functor { |
172 | |
173 | template <typename T, typename U> |
174 | struct RandomBinomialFunctor<CPUDevice, T, U> { |
175 | void operator()(OpKernelContext* ctx, const CPUDevice& d, int64_t num_batches, |
176 | int64_t samples_per_batch, int64_t num_elements, |
177 | const BCast& bcast, typename TTypes<T>::ConstFlat counts, |
178 | typename TTypes<T>::ConstFlat probs, |
179 | const random::PhiloxRandom& gen, |
180 | typename TTypes<U>::Flat output) { |
181 | auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); |
182 | |
    // The output layout is [B1, ..., Bk, H1, ..., Hm]: [B1, ..., Bk] is the
    // sample shape and [H1, ..., Hm] is the batch shape of the samples,
    // i.e., we draw B1 * ... * Bk samples for each batch member.
186 | auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs, |
187 | &gen, &output](int64_t start_output, int64_t limit_output) { |
192 | const bool should_bcast = bcast.IsBroadcastingRequired(); |
193 | const auto& counts_batch_indices = bcast.x_batch_indices(); |
194 | const auto& probs_batch_indices = bcast.y_batch_indices(); |
195 | auto output_flat = output.data(); |
196 | |
      // We partition work first across batches (each with its own
      // (count, prob) pair) and then across samples within each batch
      // member, to avoid redundant parameter lookups.
199 | for (int64_t output_idx = start_output; output_idx < limit_output; |
200 | // output_idx is incremented with the inner loops below. |
201 | ) { |
202 | int64_t batch_idx = output_idx / samples_per_batch; |
203 | U* const output_batch_offset = output_flat + batch_idx; |
204 | // Generate batch counts from BCast, as it has the right indices to loop |
205 | // over. |
206 | T count, prob; |
207 | if (should_bcast) { |
208 | count = counts(counts_batch_indices[batch_idx]); |
209 | prob = probs(probs_batch_indices[batch_idx]); |
210 | } else { |
211 | count = counts(batch_idx); |
212 | prob = probs(batch_idx); |
213 | } |
214 | |
        // Select the sampling method based on count and prob.
217 | double dcount = static_cast<double>(count); |
218 | if (dcount <= 0.0 || prob <= T(0.0)) { |
219 | for (int64_t sample_idx = output_idx % samples_per_batch; |
220 | sample_idx < samples_per_batch && output_idx < limit_output; |
221 | ++sample_idx, ++output_idx) { |
222 | output_batch_offset[sample_idx * num_batches] = static_cast<U>(0.0); |
223 | } |
224 | } else if (prob >= T(1.0)) { |
225 | for (int64_t sample_idx = output_idx % samples_per_batch; |
226 | sample_idx < samples_per_batch && output_idx < limit_output; |
227 | ++sample_idx, ++output_idx) { |
228 | output_batch_offset[sample_idx * num_batches] = |
229 | static_cast<U>(dcount); |
230 | } |
231 | } else if (prob <= T(0.5)) { |
232 | double dp = static_cast<double>(prob); |
233 | if (count * prob >= T(10)) { |
234 | for (int64_t sample_idx = output_idx % samples_per_batch; |
235 | sample_idx < samples_per_batch && output_idx < limit_output; |
236 | ++sample_idx, ++output_idx) { |
237 | random::PhiloxRandom gen_copy = gen; |
238 | gen_copy.Skip(256 * output_idx); |
239 | output_batch_offset[sample_idx * num_batches] = |
240 | static_cast<U>(btrs(dcount, dp, &gen_copy)); |
241 | } |
242 | } else { |
243 | for (int64_t sample_idx = output_idx % samples_per_batch; |
244 | sample_idx < samples_per_batch && output_idx < limit_output; |
245 | ++sample_idx, ++output_idx) { |
246 | random::PhiloxRandom gen_copy = gen; |
            // For binomial inversion we have mean <= 10 and variance <= 10,
            // so on average at most ~10 uniform samples are needed, and
            // covering 10 standard deviations requires ~42; we reserve that
            // many states per output element.
251 | gen_copy.Skip(42 * output_idx); |
252 | output_batch_offset[sample_idx * num_batches] = |
253 | static_cast<U>(binomial_inversion(dcount, dp, &gen_copy)); |
254 | } |
255 | } |
256 | } else if (prob > T(0.5)) { |
257 | T q = T(1) - prob; |
258 | double dq = static_cast<double>(q); |
259 | if (count * q >= T(10)) { |
260 | for (int64_t sample_idx = output_idx % samples_per_batch; |
261 | sample_idx < samples_per_batch && output_idx < limit_output; |
262 | ++sample_idx, ++output_idx) { |
263 | random::PhiloxRandom gen_copy = gen; |
264 | gen_copy.Skip(256 * output_idx); |
265 | output_batch_offset[sample_idx * num_batches] = |
266 | static_cast<U>(dcount - btrs(dcount, dq, &gen_copy)); |
267 | } |
268 | } else { |
269 | for (int64_t sample_idx = output_idx % samples_per_batch; |
270 | sample_idx < samples_per_batch && output_idx < limit_output; |
271 | ++sample_idx, ++output_idx) { |
272 | random::PhiloxRandom gen_copy = gen; |
            // For binomial inversion we have mean <= 10 and variance <= 10,
            // so on average at most ~10 uniform samples are needed, and
            // covering 10 standard deviations requires ~42; we reserve that
            // many states per output element.
277 | gen_copy.Skip(42 * output_idx); |
278 | output_batch_offset[sample_idx * num_batches] = static_cast<U>( |
279 | dcount - binomial_inversion(dcount, dq, &gen_copy)); |
280 | } |
281 | } |
282 | } else { // prob is NaN |
283 | // TODO(srvasude): What should happen if prob is NaN but the output |
284 | // type is an integer (which doesn't have a sentinel for NaN)? Fail |
285 | // the whole batch sample? Return a specialized sentinel like -1? |
286 | for (int64_t sample_idx = output_idx % samples_per_batch; |
287 | sample_idx < samples_per_batch && output_idx < limit_output; |
288 | ++sample_idx, ++output_idx) { |
289 | output_batch_offset[sample_idx * num_batches] = static_cast<U>(NAN); |
290 | } |
291 | } |
292 | } |
293 | }; |
294 | |
    // The per-element cost depends on count * p (or count * q).
    // For n * p < 10, on average O(n * p) uniform draws are needed, with
    // about as many multiplies: ~10 uniform calls on average, with ~200
    // cycles of other op cost.
    //
    // Very roughly, for rate >= 10, the four calls to log occur for ~72
    // percent of samples:
    //   4 x 100 (cycles per 64-bit log) * 0.72 = ~288.
    // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles
    // each: 40 * 0.72 = ~25.
    //
    // Finally, several ops run on every loop iteration, along with 2
    // uniform generations and 5 other ops at 3-6 cycles each:
    //   ~15 / 0.89 = ~16.
    //
    // In total, the rate >= 10 case should cost
    // ~329 + 2 * Uniform::kElementCost. We assume half the tensor has
    // rate < 10, so on average ~6 uniform draws are needed, and we upper
    // bound the other op cost by the rate >= 10 case.
315 | static const int kElementCost = 329 + 6 * Uniform::kElementCost + |
316 | 6 * random::PhiloxRandom::kElementCost; |
317 | Shard(worker_threads.num_threads, worker_threads.workers, num_elements, |
318 | kElementCost, DoWork); |
319 | } |
320 | }; |
321 | |
322 | } // namespace functor |
323 | |
324 | namespace { |
325 | |
326 | // Samples from a binomial distribution, using the given parameters. |
327 | template <typename Device, typename T, typename U> |
328 | class RandomBinomialOp : public OpKernel { |
329 | // Reshape batches so each batch is this size if possible. |
330 | static constexpr int32_t kDesiredBatchSize = 100; |
331 | |
332 | public: |
333 | explicit RandomBinomialOp(OpKernelConstruction* context) |
334 | : OpKernel(context) {} |
335 | |
336 | void Compute(OpKernelContext* ctx) override { |
337 | const Tensor& alg_tensor = ctx->input(1); |
338 | const Tensor& shape_tensor = ctx->input(2); |
339 | const Tensor& counts_tensor = ctx->input(3); |
340 | const Tensor& probs_tensor = ctx->input(4); |
341 | |
342 | tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(), |
343 | probs_tensor.shape().dim_sizes(), |
344 | /*fewer_dims_optimization=*/false, |
345 | /*return_flattened_batch_indices=*/true); |
346 | OP_REQUIRES(ctx, bcast.IsValid(), |
347 | errors::InvalidArgument( |
348 | "counts and probs must have compatible batch dimensions: " , |
349 | counts_tensor.shape().DebugString(), " vs. " , |
350 | probs_tensor.shape().DebugString())); |
351 | OP_REQUIRES( |
352 | ctx, TensorShapeUtils::IsVector(shape_tensor.shape()), |
353 | errors::InvalidArgument("Input shape should be a vector, got shape: " , |
354 | shape_tensor.shape().DebugString())); |
355 | OP_REQUIRES(ctx, |
356 | (shape_tensor.dtype() == DataType::DT_INT32 || |
357 | shape_tensor.dtype() == DataType::DT_INT64), |
358 | errors::InvalidArgument( |
359 | "Input shape should have dtype {int32, int64}." )); |
360 | |
361 | // Let's check that the shape tensor dominates the broadcasted tensor. |
362 | TensorShape bcast_shape = BCast::ToShape(bcast.output_shape()); |
363 | TensorShape output_shape; |
364 | if (shape_tensor.dtype() == DataType::DT_INT32) { |
365 | OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(), |
366 | &output_shape)); |
367 | } else { |
368 | OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape( |
369 | shape_tensor.vec<int64_t>(), &output_shape)); |
370 | } |
371 | OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape), |
372 | errors::InvalidArgument( |
373 | "Shape passed in must end with broadcasted shape." )); |
374 | // Now that we have a guarantee, we can get the additional dimensions added |
375 | // by sampling. |
376 | OP_REQUIRES(ctx, alg_tensor.dims() == 0, |
377 | errors::InvalidArgument("algorithm must be of shape [], not " , |
378 | alg_tensor.shape().DebugString())); |
379 | Algorithm alg = Algorithm(alg_tensor.flat<int64_t>()(0)); |
380 | |
381 | int64_t samples_per_batch = 1; |
382 | const int64_t num_sample_dims = |
383 | (shape_tensor.dim_size(0) - bcast.output_shape().size()); |
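    // For example, shape = [10, 2, 3] with broadcast batch shape [2, 3]
    // yields num_sample_dims = 1, samples_per_batch = 10, num_batches = 6.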
    // Read dimension sizes from the already-validated output_shape rather
    // than from shape_tensor, whose dtype may be either int32 or int64.
    for (int64_t i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= output_shape.dim_size(i);
    }
    int64_t num_batches = 1;
    for (int64_t i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= output_shape.dim_size(i);
    }
391 | const int64_t num_elements = num_batches * samples_per_batch; |
392 | |
393 | Tensor* samples_tensor; |
394 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor)); |
395 | |
396 | core::RefCountPtr<Var> var; |
397 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var)); |
398 | |
399 | Tensor* var_tensor = var->tensor(); |
400 | OP_REQUIRES( |
401 | ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE, |
402 | errors::InvalidArgument("dtype of RNG state variable must be " , |
403 | DataTypeString(STATE_ELEMENT_DTYPE), ", not " , |
404 | DataTypeString(var_tensor->dtype()))); |
405 | OP_REQUIRES(ctx, var_tensor->dims() == 1, |
406 | errors::InvalidArgument( |
407 | "RNG state must have one and only one dimension, not " , |
408 | var_tensor->dims())); |
409 | auto var_tensor_flat = var_tensor->flat<StateElementType>(); |
410 | OP_REQUIRES(ctx, alg == RNG_ALG_PHILOX, |
411 | errors::InvalidArgument("Unsupported algorithm id: " , alg)); |
412 | static_assert(std::is_same<StateElementType, int64_t>::value, |
413 | "StateElementType must be int64" ); |
414 | static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value, |
415 | "PhiloxRandom::ResultElementType must be uint32" ); |
416 | OP_REQUIRES(ctx, var_tensor_flat.size() >= PHILOX_MIN_STATE_SIZE, |
417 | errors::InvalidArgument( |
418 | "For Philox algorithm, the size of state must be at least " , |
419 | PHILOX_MIN_STATE_SIZE, "; got " , var_tensor_flat.size())); |
420 | |
421 | OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>( |
422 | ctx, var_tensor, var->copy_on_read_mode.load())); |
423 | auto var_data = var_tensor_flat.data(); |
424 | auto philox = GetPhiloxRandomFromMem(var_data); |
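    // Advance the variable's stored Philox state past the counters this op
    // call is expected to consume, so that subsequent stateful ops draw from
    // fresh, non-overlapping streams.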
425 | UpdateMemWithPhiloxRandom( |
426 | philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data); |
427 | |
428 | auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>(); |
429 | binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches, |
430 | samples_per_batch, num_elements, bcast, |
431 | counts_tensor.flat<T>(), probs_tensor.flat<T>(), philox, |
432 | samples_tensor->flat<U>()); |
433 | } |
434 | |
435 | private: |
436 | TF_DISALLOW_COPY_AND_ASSIGN(RandomBinomialOp); |
437 | }; |
438 | |
439 | // Samples from a binomial distribution, using the given parameters. |
440 | template <typename Device, typename T, typename U> |
441 | class StatelessRandomBinomialOp : public OpKernel { |
442 | // Reshape batches so each batch is this size if possible. |
443 | static constexpr int32_t kDesiredBatchSize = 100; |
444 | |
445 | public: |
446 | explicit StatelessRandomBinomialOp(OpKernelConstruction* context) |
447 | : OpKernel(context) {} |
448 | |
449 | void Compute(OpKernelContext* ctx) override { |
450 | const Tensor& shape_tensor = ctx->input(0); |
451 | const Tensor& seed_tensor = ctx->input(1); |
452 | const Tensor& counts_tensor = ctx->input(2); |
453 | const Tensor& probs_tensor = ctx->input(3); |
454 | |
455 | OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2, |
456 | errors::InvalidArgument("seed must have shape [2], not " , |
457 | seed_tensor.shape().DebugString())); |
458 | |
459 | tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(), |
460 | probs_tensor.shape().dim_sizes(), |
461 | /*fewer_dims_optimization=*/false, |
462 | /*return_flattened_batch_indices=*/true); |
463 | OP_REQUIRES(ctx, bcast.IsValid(), |
464 | errors::InvalidArgument( |
465 | "counts and probs must have compatible batch dimensions: " , |
466 | counts_tensor.shape().DebugString(), " vs. " , |
467 | probs_tensor.shape().DebugString())); |
468 | OP_REQUIRES( |
469 | ctx, TensorShapeUtils::IsVector(shape_tensor.shape()), |
470 | errors::InvalidArgument("Input shape should be a vector, got shape: " , |
471 | shape_tensor.shape().DebugString())); |
472 | OP_REQUIRES(ctx, |
473 | (shape_tensor.dtype() == DataType::DT_INT32 || |
474 | shape_tensor.dtype() == DataType::DT_INT64), |
475 | errors::InvalidArgument( |
476 | "Input shape should have dtype {int32, int64}." )); |
477 | |
478 | // Let's check that the shape tensor dominates the broadcasted tensor. |
479 | TensorShape bcast_shape = BCast::ToShape(bcast.output_shape()); |
480 | TensorShape output_shape; |
481 | if (shape_tensor.dtype() == DataType::DT_INT32) { |
482 | OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(), |
483 | &output_shape)); |
484 | } else { |
485 | OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape( |
486 | shape_tensor.vec<int64_t>(), &output_shape)); |
487 | } |
488 | OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape), |
489 | errors::InvalidArgument( |
490 | "Shape passed in must end with broadcasted shape." )); |
491 | // Now that we have a guarantee, we can get the additional dimensions added |
492 | // by sampling. |
493 | int64_t samples_per_batch = 1; |
494 | const int64_t num_sample_dims = |
495 | (shape_tensor.dim_size(0) - bcast.output_shape().size()); |
    // Read dimension sizes from the already-validated output_shape rather
    // than from shape_tensor, whose dtype may be either int32 or int64.
    for (int64_t i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= output_shape.dim_size(i);
    }
    int64_t num_batches = 1;
    for (int64_t i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= output_shape.dim_size(i);
    }
503 | const int64_t num_elements = num_batches * samples_per_batch; |
504 | |
505 | Tensor* samples_tensor; |
506 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor)); |
507 | if (output_shape.num_elements() == 0) return; |
508 | |
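    // Derive a deterministic Philox key/counter pair from the two seed
    // values; identical seeds therefore reproduce identical samples.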
509 | random::PhiloxRandom::Key key; |
510 | random::PhiloxRandom::ResultType counter; |
511 | OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter)); |
512 | |
513 | auto philox = random::PhiloxRandom(counter, key); |
514 | auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>(); |
515 | binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches, |
516 | samples_per_batch, num_elements, bcast, |
517 | counts_tensor.flat<T>(), probs_tensor.flat<T>(), philox, |
518 | samples_tensor->flat<U>()); |
519 | } |
520 | |
521 | private: |
522 | TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomBinomialOp); |
523 | }; |
524 | |
525 | } // namespace |
526 | |
527 | #define REGISTER(RTYPE, TYPE) \ |
528 | REGISTER_KERNEL_BUILDER(Name("StatefulRandomBinomial") \ |
529 | .Device(DEVICE_CPU) \ |
530 | .HostMemory("resource") \ |
531 | .HostMemory("algorithm") \ |
532 | .HostMemory("shape") \ |
533 | .HostMemory("counts") \ |
534 | .HostMemory("probs") \ |
535 | .TypeConstraint<RTYPE>("dtype") \ |
536 | .TypeConstraint<TYPE>("T"), \ |
537 | RandomBinomialOp<CPUDevice, TYPE, RTYPE>); \ |
538 | REGISTER_KERNEL_BUILDER(Name("StatelessRandomBinomial") \ |
539 | .Device(DEVICE_CPU) \ |
540 | .HostMemory("shape") \ |
541 | .HostMemory("seed") \ |
542 | .HostMemory("counts") \ |
543 | .HostMemory("probs") \ |
544 | .TypeConstraint<RTYPE>("dtype") \ |
545 | .TypeConstraint<TYPE>("T"), \ |
546 | StatelessRandomBinomialOp<CPUDevice, TYPE, RTYPE>) |
547 | |
548 | #define REGISTER_ALL(RTYPE) \ |
549 | REGISTER(RTYPE, Eigen::half); \ |
550 | REGISTER(RTYPE, float); \ |
551 | REGISTER(RTYPE, double); |
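
// REGISTER_ALL(RTYPE) registers both the stateful and stateless kernels
// producing output dtype RTYPE, for each of the half, float, and double
// input types.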
552 | |
553 | REGISTER_ALL(Eigen::half); |
554 | REGISTER_ALL(float); |
555 | REGISTER_ALL(double); |
556 | REGISTER_ALL(int32); |
557 | REGISTER_ALL(int64_t); |
558 | |
559 | #undef REGISTER |
560 | #undef REGISTER_ALL |
561 | |
562 | } // end namespace tensorflow |
563 | |