/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
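
// DISABLE_FLOAT_EQUALITY_WARNING / ENABLE_FLOAT_EQUALITY_WARNING bracket an
// intentional exact floating-point comparison (alpha == 1.0) in RandomGammaOp
// below, so -Wfloat-equal stays quiet there without being disabled globally.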

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
                                      int index, Tensor** output) {
  TensorShape tensor_shape;
  TF_RETURN_IF_ERROR(tensor::MakeShape(shape, &tensor_shape));
  return ctx->allocate_output(index, tensor_shape, output);
}

// For now, use the same interface as RandomOp, so we can choose either one
// at run time.
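//
// PhiloxRandomOp fills its output with i.i.d. draws from Distribution using a
// guarded Philox generator; it backs the "RandomUniform",
// "RandomStandardNormal", and "TruncatedNormal" kernels registered at the
// bottom of this file.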
template <typename Device, class Distribution>
class PhiloxRandomOp : public OpKernel {
 public:
  typedef typename Distribution::ResultElementType T;
  explicit PhiloxRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    Tensor* output;
    OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
    auto output_flat = output->flat<T>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
        // Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
        // it here without also changing it there.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
        output_flat.data(), output_flat.size(), Distribution());
  }

 private:
  GuardedPhiloxRandom generator_;
};

template <typename Device, class IntType>
class RandomUniformIntOp : public OpKernel {
 public:
  explicit RandomUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    const Tensor& minval = ctx->input(1);
    const Tensor& maxval = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Allocate the output, and return early if it is empty.
    Tensor* output;
    OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
    if (output->NumElements() == 0) return;

    // Verify that minval < maxval. This check intentionally happens after the
    // early exit for an empty output: drawing zero samples from an invalid
    // range is fine.
    IntType lo = minval.scalar<IntType>()();
    IntType hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        ctx, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build the distribution; generated values lie in [lo, hi).
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    auto output_flat = output->flat<IntType>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
        // Multiplier 256 is the same as in FillPhiloxRandomTask; do not change
        // it here without also changing it there.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
        output_flat.data(), output_flat.size(), dist);
  }

 private:
  GuardedPhiloxRandom generator_;
};
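
// Note: the stateful tf.random Python APIs are expected to lower to the
// kernels in this file; for instance, tf.random.uniform(shape, maxval=10,
// dtype=tf.int32) typically dispatches to the "RandomUniformInt" op
// implemented above (the exact lowering may vary across TensorFlow versions).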

// Samples from one or more gamma distributions. All internal computations are
// done with double precision for numerical stability.
template <typename T>
class RandomGammaOp : public OpKernel {
 public:
  explicit RandomGammaOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_t = ctx->input(0);
    const Tensor& alpha_t = ctx->input(1);

    OP_REQUIRES(ctx,
                TensorShapeUtils::IsVector(shape_t.shape()) &&
                    (shape_t.dtype() == DataType::DT_INT32 ||
                     shape_t.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "shape must be a vector of {int32,int64}, got shape: ",
                    shape_t.DebugString()));
    TensorShape samples_shape;
    if (shape_t.dtype() == DataType::DT_INT32) {
      auto vec = shape_t.flat<int32>();
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &samples_shape));
    } else if (shape_t.dtype() == DataType::DT_INT64) {
      auto vec = shape_t.flat<int64_t>();
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &samples_shape));
    }
    const int64_t samples_per_alpha = samples_shape.num_elements();

    OP_REQUIRES_OK(ctx, samples_shape.AppendShapeWithStatus(alpha_t.shape()));
    // Allocate output samples.
    Tensor* samples_t = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));

    if (samples_shape.num_elements() == 0) return;

    using random::PhiloxRandom;

    typedef random::NormalDistribution<PhiloxRandom, double> Normal;
    typedef random::UniformDistribution<PhiloxRandom, double> Uniform;
#define UNIFORM(X)                                    \
  if (uniform_remaining == 0) {                       \
    uniform_remaining = Uniform::kResultElementCount; \
    uniform_result = uniform(&gen);                   \
  }                                                   \
  uniform_remaining--;                                \
  double X = uniform_result[uniform_remaining]
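
// UNIFORM(X) declares a local double X holding the next uniform variate: when
// the cached batch is exhausted it draws Uniform::kResultElementCount fresh
// variates from the per-sample Philox generator "gen", then pops one value
// off the batch.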

    // Each attempt is 95+% successful, and requires 1-2 normal + 1 uniform
    // sample.
    static constexpr int kReservedSamplesPerOutput = 256;

    const auto alpha_flat = alpha_t.flat<T>().data();
    const int64_t num_alphas = alpha_t.NumElements();
    OP_REQUIRES(ctx, num_alphas > 0,
                errors::InvalidArgument(
                    "Input alpha should have non-zero element count, got: ",
                    num_alphas));
    auto samples_flat = samples_t->flat<T>().data();
    PhiloxRandom rng = generator_.ReserveRandomOutputs(
        samples_per_alpha * num_alphas, kReservedSamplesPerOutput);

    // We partition work first across alphas then across samples-per-alpha so
    // that a few computations (e.g. the constants d and c below) are done
    // once per alpha rather than once per sample.

    auto DoWork = [samples_per_alpha, num_alphas, &rng, samples_flat,
                   alpha_flat](int64_t start_output, int64_t limit_output) {
      using Eigen::numext::exp;
      using Eigen::numext::log;
      using Eigen::numext::log1p;
      using Eigen::numext::pow;

      // Capturing "rng" by-value would only make a copy for the _shared_
      // lambda. Since we want to let each worker have its own copy, we pass
      // "rng" by reference and explicitly do a copy assignment.

      Normal normal;
      Uniform uniform;
      typename Normal::ResultType norm_result;
      typename Uniform::ResultType uniform_result;
      for (int64_t output_idx = start_output; output_idx < limit_output;
           /* output_idx incremented within inner loop below */) {
        int64_t alpha_idx = output_idx / samples_per_alpha;

        // Instead of +alpha_idx for each sample, we offset the pointer once.
        T* const samples_alpha_offset = samples_flat + alpha_idx;

        // Several calculations can be done on a per-alpha basis.
        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

        DISABLE_FLOAT_EQUALITY_WARNING
        if (alpha == static_cast<double>(1.0)) {
          ENABLE_FLOAT_EQUALITY_WARNING
          // Sample from an exponential distribution.
          for (int64_t sample_idx = output_idx % samples_per_alpha;
               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // As we want data stable regardless of sharding
            // (including eventually on GPU), we skip on a per-sample basis.
            PhiloxRandom gen = rng;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            int16_t uniform_remaining = 0;
            UNIFORM(u);
            const double res = -log1p(-u);
            samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
          }  // for (sample_idx)
        } else {  // if alpha != 1.0
          // Transformation-rejection from pairs of uniform and normal random
          // variables. http://dl.acm.org/citation.cfm?id=358414
          //
          // The algorithm has an acceptance rate of ~95% for small alpha (~1),
          // and higher accept rates for higher alpha, so runtime is
          // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
          //
          // For alpha<1, we add one to d=alpha-1/3, and multiply the final
          // result by uniform()^(1/alpha)
          const bool alpha_less_than_one = alpha < 1;
          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
          const double c = 1.0 / 3 / sqrt(d);
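
          // Concretely, with x ~ Normal(0, 1) and u ~ Uniform(0, 1), the
          // candidate v = (1 + c * x)^3 is accepted when either the squeeze
          // u < 1 - 0.0331 * x^4 holds or
          // log(u) < 0.5 * x^2 + d * (1 - v + log(v)); the accepted sample is
          // d * v (times uniform()^(1/alpha) when alpha < 1). See the loop
          // below.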

          // Compute the rest of the samples for the current alpha value.
          for (int64_t sample_idx = output_idx % samples_per_alpha;
               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // Since each sample may use a variable number of normal/uniform
            // samples, and we want data stable regardless of sharding
            // (including eventually on GPU), we skip on a per-sample basis.
            PhiloxRandom gen = rng;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            int16_t norm_remaining = 0;
            int16_t uniform_remaining = 0;

            // Keep trying until we don't reject a sample. In practice, we will
            // only reject ~5% at worst, for low alpha near 1.
            while (true) {
              if (norm_remaining == 0) {
                norm_remaining = Normal::kResultElementCount;
                norm_result = normal(&gen);
              }
              norm_remaining--;
              const double x = norm_result[norm_remaining];
              double v = 1 + c * x;
              if (v <= 0) {
                continue;
              }
              v = v * v * v;
              UNIFORM(u);
              // The first option in the if is a "squeeze" short-circuit to
              // dodge the two logs. Magic constant sourced from the paper
              // linked above. Upward of .91 of the area covered by the log
              // inequality is covered by the squeeze as well (larger coverage
              // for smaller values of alpha).
              if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
                  (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
                double res = d * v;
                if (alpha_less_than_one) {
                  UNIFORM(b);
                  res *= pow(b, 1 / alpha);
                }
                samples_alpha_offset[sample_idx * num_alphas] =
                    static_cast<T>(res);
                break;
              }
            }  // while: true
          }    // for: sample_idx
        }      // if (alpha == 1.0)
      }        // for: output_idx
    };         // DoWork
#undef UNIFORM
    // Two calls to log only occur for ~10% of samples reaching the log line.
    // 2 x 100 (64-bit cycles per log) x 0.10 = ~20.
    // Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
    // each = ~60.
    // All of this /0.95 due to the rejection possibility = ~85.
    static const int kElementCost = 85 + 2 * Normal::kElementCost +
                                    Uniform::kElementCost +
                                    3 * PhiloxRandom::kElementCost;
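    // Shard the output range [0, num_alphas * samples_per_alpha) across the
    // CPU worker pool, using kElementCost as the estimated cost of producing
    // one output element.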
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
          num_alphas * samples_per_alpha, kElementCost, DoWork);
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(RandomGammaOp);
};

}  // namespace

#define REGISTER(TYPE)                                                         \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>;     \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice, random::NormalDistribution<random::PhiloxRandom, TYPE>>;      \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice,                                                               \
      random::TruncatedNormalDistribution<                                     \
          random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>;           \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomUniform")                                                    \
          .Device(DEVICE_CPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<CPUDevice, random::UniformDistribution<                   \
                                    random::PhiloxRandom, TYPE>>);             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomStandardNormal")                                             \
          .Device(DEVICE_CPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<CPUDevice,                                                \
                     random::NormalDistribution<random::PhiloxRandom, TYPE>>); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("TruncatedNormal")                                                  \
          .Device(DEVICE_CPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<                                                          \
          CPUDevice,                                                           \
          random::TruncatedNormalDistribution<                                 \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);      \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
      RandomGammaOp<TYPE>)
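
// For example, REGISTER(float) instantiates the three FillPhiloxRandom functor
// specializations above for float and registers the float CPU kernels for
// "RandomUniform", "RandomStandardNormal", "TruncatedNormal", and
// "RandomGamma".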

#define REGISTER_FULL_INT(IntType)           \
  template struct functor::FillPhiloxRandom< \
      CPUDevice,                             \
      random::UniformFullIntDistribution<random::PhiloxRandom, IntType>>

#define REGISTER_INT(IntType)                                                 \
  REGISTER_FULL_INT(IntType);                                                 \
  template struct functor::FillPhiloxRandom<                                  \
      CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")                            \
                              .Device(DEVICE_CPU)                             \
                              .HostMemory("shape")                            \
                              .HostMemory("minval")                           \
                              .HostMemory("maxval")                           \
                              .TypeConstraint<IntType>("Tout"),               \
                          RandomUniformIntOp<CPUDevice, IntType>);

TF_CALL_half(REGISTER);
TF_CALL_bfloat16(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);
TF_CALL_uint32(REGISTER_FULL_INT);
TF_CALL_uint64(REGISTER_FULL_INT);

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_FULL_INT

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER(TYPE)                                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomUniform")                                                    \
          .Device(DEVICE_GPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<int32>("T")                                          \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<GPUDevice, random::UniformDistribution<                   \
                                    random::PhiloxRandom, TYPE>>);             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomStandardNormal")                                             \
          .Device(DEVICE_GPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<int32>("T")                                          \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<GPUDevice,                                                \
                     random::NormalDistribution<random::PhiloxRandom, TYPE>>); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("TruncatedNormal")                                                  \
          .Device(DEVICE_GPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<int32>("T")                                          \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<                                                          \
          GPUDevice,                                                           \
          random::TruncatedNormalDistribution<                                 \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);

#define REGISTER_FULL_INT(IntType)           \
  template struct functor::FillPhiloxRandom< \
      GPUDevice,                             \
      random::UniformFullIntDistribution<random::PhiloxRandom, IntType>>

#define REGISTER_INT(IntType)                                                 \
  REGISTER_FULL_INT(IntType);                                                 \
  template struct functor::FillPhiloxRandom<                                  \
      GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")                            \
                              .Device(DEVICE_GPU)                             \
                              .HostMemory("shape")                            \
                              .HostMemory("minval")                           \
                              .HostMemory("maxval")                           \
                              .TypeConstraint<int32>("T")                     \
                              .TypeConstraint<IntType>("Tout"),               \
                          RandomUniformIntOp<GPUDevice, IntType>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);
TF_CALL_uint32(REGISTER_FULL_INT);
TF_CALL_uint64(REGISTER_FULL_INT);

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_FULL_INT

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow