/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/stateless_random_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/stateless_random_gamma_op.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/work_sharder.h"

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif

namespace tensorflow {

namespace {

// Each attempt to generate a new draw from the Gamma distribution is 95+%
// successful, and requires 1-2 normal + 1 uniform sample.
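// Reserving a fixed block of kReservedSamplesPerOutput Philox results per
// output lets each output skip to its own offset in the random stream, so
// results do not depend on how the work is sharded across threads; 256 is
// presumably a generous upper bound on what one rejection loop consumes.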
static constexpr int kReservedSamplesPerOutput = 256;

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

}  // namespace

namespace functor {

template <typename T>
struct StatelessRandomGammaFunctor<CPUDevice, T> {
  static Status Fill(OpKernelContext* ctx, const T* alpha_flat,
                     int64_t num_samples, int64_t num_alphas,
                     int64_t samples_per_alpha,
                     const random::PhiloxRandom& random, T* samples_flat) {
    typedef random::NormalDistribution<random::PhiloxRandom, double> Normal;
    typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;

    // We partition work first across alphas and then across samples-per-alpha
    // so that the few computations that depend only on alpha are done once
    // per alpha rather than once per sample.

    auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
                   alpha_flat](int64_t start_output, int64_t limit_output) {
      // Capturing "random" by-value would only make a copy for the _shared_
      // lambda. Since we want to let each worker have its own copy, we pass
      // "random" by reference and explicitly do a copy assignment.

      using Eigen::numext::exp;
      using Eigen::numext::log;
      using Eigen::numext::log1p;
      using Eigen::numext::pow;

      Normal normal;
      Uniform uniform;

      RandomSampleBuffer<Normal> normal_buffer(&normal);
      RandomSampleBuffer<Uniform> uniform_buffer(&uniform);

      for (int64_t output_idx = start_output; output_idx < limit_output;
           /* output_idx incremented within inner loop below */) {
        int64_t alpha_idx = output_idx / samples_per_alpha;

        // Instead of +alpha_idx for each sample, we offset the pointer once.
        T* const samples_alpha_offset = samples_flat + alpha_idx;

        // Several calculations can be done on a per-alpha basis.
        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

        DISABLE_FLOAT_EQUALITY_WARNING
        if (alpha == 1.0) {
          ENABLE_FLOAT_EQUALITY_WARNING
          // Sample from an exponential distribution.
          for (int64_t sample_idx = output_idx % samples_per_alpha;
               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // As we want data stable regardless of sharding, we skip on a
            // per-sample basis.
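            // PhiloxRandom is counter-based, so Skip() below is effectively
            // O(1): it advances the internal counter to this output's
            // reserved block rather than generating and discarding values.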
            random::PhiloxRandom gen = random;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            double u = uniform(&gen)[Uniform::kResultElementCount - 1];
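            // Gamma(alpha = 1) is the Exponential(1) distribution, so apply
            // the inverse-CDF transform -log(1 - u); log1p(-u) keeps
            // precision when u is close to zero.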
            const double res = -log1p(-u);
            samples_alpha_offset[sample_idx * num_alphas] =
                static_cast<T>(res);
          }  // for (sample_idx)
        } else {  // if alpha != 1.0
          // Transformation-rejection from pairs of uniform and normal random
          // variables. http://dl.acm.org/citation.cfm?id=358414
          //
          // The algorithm has an acceptance rate of ~95% for small alpha (~1),
          // and higher accept rates for higher alpha, so runtime is
          // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
          //
          // For alpha<1, we add one to d=alpha-1/3, and multiply the final
          // result by uniform()^(1/alpha).
          const bool alpha_less_than_one = alpha < 1.0;
          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
          const double c = 1.0 / 3 / sqrt(d);
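          // These are the constants of Marsaglia & Tsang's method (the paper
          // cited above): d = alpha - 1/3 (shifted up by 1 when alpha < 1)
          // and c = 1/sqrt(9*d), written here as 1/3/sqrt(d).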

          // Compute the rest of the samples for the current alpha value.
          for (int64_t sample_idx = output_idx % samples_per_alpha;
               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // Since each sample may use a variable number of normal/uniform
            // samples, and we want data stable regardless of sharding, we
            // skip on a per-sample basis.
            random::PhiloxRandom gen = random;
            gen.Skip(kReservedSamplesPerOutput * output_idx);

            // To prevent overwriting SampleBuffer's underlying array with
            // zeros (in tensorflow::random::Array constructor), we just mark
            // the buffer as empty instead of initializing a new SampleBuffer
            // object here. The next call to operator() will fill the buffer
            // with new numbers.
            normal_buffer.Clear();
            uniform_buffer.Clear();

            // Keep trying until we don't reject a sample. In practice, we
            // will only reject ~5% at worst, for low alpha near 1.
            while (true) {
              const double x = normal_buffer(&gen);
              double v = 1 + c * x;
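              // v must be positive since d * v is the candidate Gamma draw;
              // resample when the normal draw x lands at or below -1/c.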
              if (v <= 0) {
                continue;
              }
              v = v * v * v;
              double u = uniform_buffer(&gen);
              // The first option in the if is a "squeeze" short-circuit to
              // dodge the two logs. Magic constant sourced from the paper
              // linked above. Upward of .91 of the area covered by the log
              // inequality is covered by the squeeze as well (larger coverage
              // for smaller values of alpha).
              if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
                  (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
                double res = d * v;
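                // When alpha < 1 the accepted d * v is a Gamma(alpha + 1)
                // draw; Gamma(alpha) == Gamma(alpha + 1) * U^(1/alpha), so
                // boost with one extra uniform.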
                if (alpha_less_than_one) {
                  double b = uniform_buffer(&gen);
                  res *= pow(b, 1 / alpha);
                }
                samples_alpha_offset[sample_idx * num_alphas] =
                    static_cast<T>(res);
                break;
              }
            }  // while: true
          }    // for: sample_idx
        }      // if (alpha == 1.0)
      }        // for: output_idx
    };         // DoWork

    // Two calls to log only occur for ~10% of samples reaching the log line.
    // 2 x 100 (64-bit cycles per log) x 0.10 = ~20.
    // Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
    // each = ~60.
    // All of this /0.95 (expected value of geometric distribution is 1/p) due
    // to the rejection possibility = ~85.
    static const int kElementCost = 85 + 2 * Normal::kElementCost +
                                    Uniform::kElementCost +
                                    3 * random::PhiloxRandom::kElementCost;
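    // Shard() splits the [0, num_samples) output range across the intra-op
    // thread pool; kElementCost is its rough per-output cycle estimate, used
    // to pick a sensible block size per worker.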
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, num_samples,
          kElementCost, DoWork);
    return OkStatus();
  }
};

}  // namespace functor

namespace {

template <typename Device, typename T>
class StatelessRandomGammaOp : public OpKernel {
 public:
  explicit StatelessRandomGammaOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Sanitize input
    const Tensor& shape_t = context->input(0);
    const Tensor& seed_t = context->input(1);
    TensorShape shape;
    OP_REQUIRES_OK(context, tensor::MakeShape(shape_t, &shape));
    OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) return;

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter));
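    // GenerateKey (from stateless_random_ops.h) deterministically derives a
    // Philox key and initial counter from the [2]-element seed, so the same
    // seed always reproduces the same stream.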

    // Fill in the random numbers
    Fill(context, random::PhiloxRandom(counter, key), output);
  }

 private:
  void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
            Tensor* output) {
    const Tensor& alpha_t = ctx->input(2);

    TensorShape samples_shape = output->shape();
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, alpha_t.shape()),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));

    const int64_t num_alphas = alpha_t.NumElements();
    OP_REQUIRES(ctx, num_alphas > 0,
                errors::InvalidArgument(
                    "Input alpha should have non-zero element count, got: ",
                    num_alphas));

    const int64_t num_samples = samples_shape.num_elements();
    const int64_t samples_per_alpha = num_samples / num_alphas;
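    // The EndsWith() check above guarantees num_samples is an exact multiple
    // of num_alphas, so this integer division has no remainder. In the
    // flattened output, consecutive samples for one alpha are num_alphas
    // elements apart, which is the stride used by the functor.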
    const auto alpha_flat = alpha_t.flat<T>().data();
    auto samples_flat = output->flat<T>().data();

    OP_REQUIRES_OK(ctx, functor::StatelessRandomGammaFunctor<Device, T>::Fill(
                            ctx, alpha_flat, num_samples, num_alphas,
                            samples_per_alpha, random, samples_flat));
  }

  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomGammaOp);
};

// Register CPU kernels for stateless gamma op.
#define REGISTER_GAMMA_CPU(TYPE)                              \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")      \
                              .Device(DEVICE_CPU)             \
                              .HostMemory("shape")            \
                              .HostMemory("seed")             \
                              .HostMemory("alpha")            \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomGammaOp<CPUDevice, TYPE>)

TF_CALL_half(REGISTER_GAMMA_CPU);
TF_CALL_bfloat16(REGISTER_GAMMA_CPU);
TF_CALL_float(REGISTER_GAMMA_CPU);
TF_CALL_double(REGISTER_GAMMA_CPU);

#undef REGISTER_GAMMA_CPU

// Register GPU kernels for stateless gamma op.
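// Unlike the CPU registration above, "alpha" is not pinned to host memory
// here, so the device functor can read the alpha values directly on the GPU.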
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GAMMA_GPU(TYPE)                              \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2")      \
                              .Device(DEVICE_GPU)             \
                              .HostMemory("shape")            \
                              .HostMemory("seed")             \
                              .TypeConstraint<TYPE>("dtype"), \
                          StatelessRandomGammaOp<GPUDevice, TYPE>)

TF_CALL_half(REGISTER_GAMMA_GPU);
TF_CALL_bfloat16(REGISTER_GAMMA_GPU);
TF_CALL_float(REGISTER_GAMMA_GPU);
TF_CALL_double(REGISTER_GAMMA_GPU);

#undef REGISTER_GAMMA_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace
}  // namespace tensorflow
