1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
17#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
18
19#include "tensorflow/core/framework/op_kernel.h"
20#include "tensorflow/core/lib/core/status.h"
21#include "tensorflow/core/lib/random/philox_random.h"
22
23namespace tensorflow {
24
25namespace functor {
26
27template <typename Device, typename T>
28struct StatelessRandomGammaFunctor {
29 static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
30 int64_t num_samples, int64_t num_alphas,
31 int64_t samples_per_alpha,
32 const random::PhiloxRandom& random, T* samples_flat);
33};
34
35} // namespace functor
36
37// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns
38// a single sample from this buffer. If the buffer is empty, it first generates
39// new samples using the provided distribution.
40//
41// If the call to Distribution::operator() returns samples[0...N-1], then this
42// class returns samples in the following order:
43//
44// samples[N-1], samples[N-2],..., samples[1], samples[0]
45//
46// For comparison, random::SingleSampleAdapter returns samples in
47// the following order:
48//
49// samples[0], samples[1],...,samples[N-2], samples[N-1].
50//
51template <class Distribution>
52class RandomSampleBuffer {
53 public:
54 typedef typename Distribution::ResultElementType ResultElementType;
55
56 PHILOX_DEVICE_INLINE
57 explicit RandomSampleBuffer(Distribution* distribution)
58 : distribution_(distribution), remaining_numbers_(0) {}
59
60 PHILOX_DEVICE_INLINE
61 ResultElementType operator()(random::PhiloxRandom* random) {
62 if (remaining_numbers_ == 0) {
63 results_ = (*distribution_)(random);
64 remaining_numbers_ = Distribution::kResultElementCount;
65 }
66
67 remaining_numbers_--;
68 return results_[remaining_numbers_];
69 }
70
71 // Mark this buffer as empty. The next call to operator() will fill it
72 // with new random numbers.
73 PHILOX_DEVICE_INLINE
74 void Clear() { remaining_numbers_ = 0; }
75
76 private:
77 typedef typename Distribution::ResultType ResultType;
78
79 Distribution* distribution_;
80 ResultType results_;
81 int remaining_numbers_;
82};
83
84} // namespace tensorflow
85
86#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_
87