/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.

#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/random_op_cpu.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

static Status AllocateOutputWithShape(OpKernelContext* ctx,
                                      const Tensor& shape, int index,
                                      Tensor** output) {
  TensorShape tensor_shape;
  TF_RETURN_IF_ERROR(tensor::MakeShape(shape, &tensor_shape));
  return ctx->allocate_output(index, tensor_shape, output);
}
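
// For example, a 1-D int32 "shape" input holding {2, 3} results in a
// [2, 3] output tensor being allocated at the given index.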

// For now, use the same interface as RandomOp, so we can choose either one
// at run time.
template <typename Device, class Distribution>
class PhiloxRandomOp : public OpKernel {
 public:
  typedef typename Distribution::ResultElementType T;
  explicit PhiloxRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    Tensor* output;
    OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
    auto output_flat = output->flat<T>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
        // Multiplier 256 is the same as in FillPhiloxRandomTask; do not
        // change it here without also changing it there.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
        output_flat.data(), output_flat.size(), Distribution());
  }

 private:
  GuardedPhiloxRandom generator_;
};
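
// Illustrative note: under the stateful tf.random API, a call such as
// tf.random.uniform([2, 3], seed=42) lowers to a RandomUniform node, which
// the registrations at the bottom of this file route to a PhiloxRandomOp
// with a UniformDistribution for the chosen device.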

template <typename Device, class IntType>
class RandomUniformIntOp : public OpKernel {
 public:
  explicit RandomUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    const Tensor& minval = ctx->input(1);
    const Tensor& maxval = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Allocate output, and exit early if possible.
    Tensor* output;
    OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
    if (output->NumElements() == 0) return;

    // Verify that minval < maxval. This check intentionally happens after
    // the early exit for empty output: drawing zero samples from an invalid
    // range is fine.
    IntType lo = minval.scalar<IntType>()();
    IntType hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        ctx, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build distribution.
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);

    auto output_flat = output->flat<IntType>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(), /*key=*/nullptr, /*counter=*/nullptr,
        // Multiplier 256 is the same as in FillPhiloxRandomTask; do not
        // change it here without also changing it there.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
        output_flat.data(), output_flat.size(), dist);
  }

 private:
  GuardedPhiloxRandom generator_;
};
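
// Note: UniformDistribution<random::PhiloxRandom, IntType> constructed with
// (lo, hi) draws from the half-open range [lo, hi); e.g. minval=0,
// maxval=10 can yield 0 through 9 but never 10.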

// Samples from one or more gamma distributions. All internal computations are
// done with double precision for numerical stability.
template <typename T>
class RandomGammaOp : public OpKernel {
 public:
  explicit RandomGammaOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_t = ctx->input(0);
    const Tensor& alpha_t = ctx->input(1);

    OP_REQUIRES(ctx,
                TensorShapeUtils::IsVector(shape_t.shape()) &&
                    (shape_t.dtype() == DataType::DT_INT32 ||
                     shape_t.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "shape must be a vector of {int32,int64}, got shape: ",
                    shape_t.DebugString()));
    TensorShape samples_shape;
    if (shape_t.dtype() == DataType::DT_INT32) {
      auto vec = shape_t.flat<int32>();
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &samples_shape));
    } else if (shape_t.dtype() == DataType::DT_INT64) {
      auto vec = shape_t.flat<int64_t>();
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &samples_shape));
    }
    const int64_t samples_per_alpha = samples_shape.num_elements();

    OP_REQUIRES_OK(ctx, samples_shape.AppendShapeWithStatus(alpha_t.shape()));
    // Allocate output samples.
    Tensor* samples_t = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));

    if (samples_shape.num_elements() == 0) return;

    using random::PhiloxRandom;

    typedef random::NormalDistribution<PhiloxRandom, double> Normal;
    typedef random::UniformDistribution<PhiloxRandom, double> Uniform;
#define UNIFORM(X)                                    \
  if (uniform_remaining == 0) {                       \
    uniform_remaining = Uniform::kResultElementCount; \
    uniform_result = uniform(&gen);                   \
  }                                                   \
  uniform_remaining--;                                \
  double X = uniform_result[uniform_remaining]

    // Each rejection-sampling attempt succeeds 95+% of the time and consumes
    // 1-2 normal samples plus 1 uniform sample.
    static constexpr int kReservedSamplesPerOutput = 256;
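    // Reserving a fixed budget of kReservedSamplesPerOutput Philox samples
    // per output lets each sample below jump to its own deterministic
    // counter offset via gen.Skip(kReservedSamplesPerOutput * output_idx),
    // making results independent of how work is sharded across threads;
    // 256 is a generous bound given the ~95% per-attempt acceptance rate.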

    const auto alpha_flat = alpha_t.flat<T>().data();
    const int64_t num_alphas = alpha_t.NumElements();
    OP_REQUIRES(ctx, num_alphas > 0,
                errors::InvalidArgument(
                    "Input alpha should have non-zero element count, got: ",
                    num_alphas));
    auto samples_flat = samples_t->flat<T>().data();
    PhiloxRandom rng = generator_.ReserveRandomOutputs(
        samples_per_alpha * num_alphas, kReservedSamplesPerOutput);

    // We partition work first across alphas and then across samples per
    // alpha, so that the few flops that depend only on alpha are computed
    // once per alpha rather than once per sample.

    auto DoWork = [samples_per_alpha, num_alphas, &rng, samples_flat,
                   alpha_flat](int64_t start_output, int64_t limit_output) {
      using Eigen::numext::exp;
      using Eigen::numext::log;
      using Eigen::numext::log1p;
      using Eigen::numext::pow;

      // Capturing "rng" by value would make only one copy for the single,
      // shared lambda. Since we want each worker to have its own copy, we
      // pass "rng" by reference and copy-assign it explicitly per sample.

      Normal normal;
      Uniform uniform;
      typename Normal::ResultType norm_result;
      typename Uniform::ResultType uniform_result;
      for (int64_t output_idx = start_output; output_idx < limit_output;
           /* output_idx incremented within inner loop below */) {
        int64_t alpha_idx = output_idx / samples_per_alpha;

        // Instead of +alpha_idx for each sample, we offset the pointer once.
        T* const samples_alpha_offset = samples_flat + alpha_idx;

        // Several calculations can be done on a per-alpha basis.
        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

        DISABLE_FLOAT_EQUALITY_WARNING
        if (alpha == static_cast<double>(1.0)) {
          ENABLE_FLOAT_EQUALITY_WARNING
          // Sample from an exponential distribution.
          for (int64_t sample_idx = output_idx % samples_per_alpha;
               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // Since we want the output to be stable regardless of sharding
            // (including, eventually, on GPU), we skip the generator on a
            // per-sample basis.
            PhiloxRandom gen = rng;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            int16_t uniform_remaining = 0;
            UNIFORM(u);
            const double res = -log1p(-u);
            samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
          }  // for (sample_idx)
        } else {  // if alpha != 1.0
          // Transformation-rejection from pairs of uniform and normal random
          // variables. http://dl.acm.org/citation.cfm?id=358414
          //
          // The algorithm has an acceptance rate of ~95% for small alpha
          // (~1), and higher acceptance rates for larger alpha, so runtime
          // is O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
          //
          // For alpha < 1, we add one to d = alpha - 1/3, and multiply the
          // final result by uniform()^(1/alpha).
          const bool alpha_less_than_one = alpha < 1;
          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
          const double c = 1.0 / 3 / sqrt(d);
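          // Marsaglia & Tsang's acceptance step, as implemented in the loop
          // below: draw x ~ N(0, 1), set v = (1 + c*x)^3, and accept
          // res = d*v when
          //   log(u) < 0.5*x^2 + d*(1 - v + log(v)),
          // retrying otherwise. For alpha < 1, the accepted res is further
          // scaled by b^(1/alpha) for a fresh uniform sample b.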

          // Compute the rest of the samples for the current alpha value.
          for (int64_t sample_idx = output_idx % samples_per_alpha;
               sample_idx < samples_per_alpha && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // Since each sample may use a variable number of normal/uniform
            // samples, and we want the output to be stable regardless of
            // sharding (including, eventually, on GPU), we skip the
            // generator on a per-sample basis.
            PhiloxRandom gen = rng;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            int16_t norm_remaining = 0;
            int16_t uniform_remaining = 0;

            // Keep trying until we don't reject a sample. In practice, we
            // will only reject ~5% at worst, for low alpha near 1.
            while (true) {
              if (norm_remaining == 0) {
                norm_remaining = Normal::kResultElementCount;
                norm_result = normal(&gen);
              }
              norm_remaining--;
              const double x = norm_result[norm_remaining];
              double v = 1 + c * x;
              if (v <= 0) {
                continue;
              }
              v = v * v * v;
              UNIFORM(u);
              // The first condition in the if is a "squeeze" short-circuit
              // that dodges the two log calls. The magic constant comes from
              // the paper linked above. Upwards of 91% of the area covered
              // by the log inequality is also covered by the squeeze (with
              // larger coverage for smaller values of alpha).
              if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
                  (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
                double res = d * v;
                if (alpha_less_than_one) {
                  UNIFORM(b);
                  res *= pow(b, 1 / alpha);
                }
                samples_alpha_offset[sample_idx * num_alphas] =
                    static_cast<T>(res);
                break;
              }
            }  // while: true
          }    // for: sample_idx
        }      // if (alpha == 1.0)
      }        // for: output_idx
    };         // DoWork
#undef UNIFORM
    // The two log calls only occur for the ~10% of samples that reach the
    // log line: 2 x 100 (64-bit cycles per log) x 0.10 = ~20 cycles.
    // Other ops (sqrt, +, *, /, %, ...): roughly 15 of them at 3-6 cycles
    // each = ~60 cycles.
    // Divide by the 0.95 acceptance rate = ~85 cycles total.
    static const int kElementCost = 85 + 2 * Normal::kElementCost +
                                    Uniform::kElementCost +
                                    3 * PhiloxRandom::kElementCost;
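    // Shard splits the flat index range [0, num_alphas * samples_per_alpha)
    // across the CPU worker pool, using kElementCost to size the chunks;
    // the per-sample Skip in DoWork keeps the output independent of that
    // split.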
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
          num_alphas * samples_per_alpha, kElementCost, DoWork);
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(RandomGammaOp);
};

}  // namespace

#define REGISTER(TYPE)                                                        \
  template struct functor::FillPhiloxRandom<                                  \
      CPUDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>;    \
  template struct functor::FillPhiloxRandom<                                  \
      CPUDevice, random::NormalDistribution<random::PhiloxRandom, TYPE>>;     \
  template struct functor::FillPhiloxRandom<                                  \
      CPUDevice,                                                              \
      random::TruncatedNormalDistribution<                                    \
          random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>;          \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("RandomUniform")                                                   \
          .Device(DEVICE_CPU)                                                 \
          .HostMemory("shape")                                                \
          .TypeConstraint<TYPE>("dtype"),                                     \
      PhiloxRandomOp<CPUDevice, random::UniformDistribution<                  \
                                    random::PhiloxRandom, TYPE>>);            \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("RandomStandardNormal")                                            \
          .Device(DEVICE_CPU)                                                 \
          .HostMemory("shape")                                                \
          .TypeConstraint<TYPE>("dtype"),                                     \
      PhiloxRandomOp<CPUDevice,                                               \
                     random::NormalDistribution<random::PhiloxRandom,         \
                                                TYPE>>);                      \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("TruncatedNormal")                                                 \
          .Device(DEVICE_CPU)                                                 \
          .HostMemory("shape")                                                \
          .TypeConstraint<TYPE>("dtype"),                                     \
      PhiloxRandomOp<                                                         \
          CPUDevice,                                                          \
          random::TruncatedNormalDistribution<                                \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);     \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),       \
      RandomGammaOp<TYPE>)
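
// For example, REGISTER(float) instantiates the three CPU FillPhiloxRandom
// functors for float and registers the RandomUniform, RandomStandardNormal,
// TruncatedNormal, and RandomGamma CPU kernels for that type.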

#define REGISTER_FULL_INT(IntType)           \
  template struct functor::FillPhiloxRandom< \
      CPUDevice,                             \
      random::UniformFullIntDistribution<random::PhiloxRandom, IntType>>

#define REGISTER_INT(IntType)                                                 \
  REGISTER_FULL_INT(IntType);                                                 \
  template struct functor::FillPhiloxRandom<                                  \
      CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")                            \
                              .Device(DEVICE_CPU)                             \
                              .HostMemory("shape")                            \
                              .HostMemory("minval")                           \
                              .HostMemory("maxval")                           \
                              .TypeConstraint<IntType>("Tout"),               \
                          RandomUniformIntOp<CPUDevice, IntType>);

TF_CALL_half(REGISTER);
TF_CALL_bfloat16(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);
TF_CALL_uint32(REGISTER_FULL_INT);
TF_CALL_uint64(REGISTER_FULL_INT);

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_FULL_INT

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER(TYPE)                                                       \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("RandomUniform")                                                  \
          .Device(DEVICE_GPU)                                                \
          .HostMemory("shape")                                               \
          .TypeConstraint<int32>("T")                                        \
          .TypeConstraint<TYPE>("dtype"),                                    \
      PhiloxRandomOp<GPUDevice, random::UniformDistribution<                 \
                                    random::PhiloxRandom, TYPE>>);           \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("RandomStandardNormal")                                           \
          .Device(DEVICE_GPU)                                                \
          .HostMemory("shape")                                               \
          .TypeConstraint<int32>("T")                                        \
          .TypeConstraint<TYPE>("dtype"),                                    \
      PhiloxRandomOp<GPUDevice,                                              \
                     random::NormalDistribution<random::PhiloxRandom,        \
                                                TYPE>>);                     \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("TruncatedNormal")                                                \
          .Device(DEVICE_GPU)                                                \
          .HostMemory("shape")                                               \
          .TypeConstraint<int32>("T")                                        \
          .TypeConstraint<TYPE>("dtype"),                                    \
      PhiloxRandomOp<                                                        \
          GPUDevice,                                                         \
          random::TruncatedNormalDistribution<                               \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);
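
// Note: even on GPU, the "shape" input is pinned to host memory
// (HostMemory("shape")) because the kernel must read it on the CPU to
// allocate the output before any device work is launched.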

#define REGISTER_FULL_INT(IntType)           \
  template struct functor::FillPhiloxRandom< \
      GPUDevice,                             \
      random::UniformFullIntDistribution<random::PhiloxRandom, IntType>>

#define REGISTER_INT(IntType)                                                 \
  REGISTER_FULL_INT(IntType);                                                 \
  template struct functor::FillPhiloxRandom<                                  \
      GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")                            \
                              .Device(DEVICE_GPU)                             \
                              .HostMemory("shape")                            \
                              .HostMemory("minval")                           \
                              .HostMemory("maxval")                           \
                              .TypeConstraint<int32>("T")                     \
                              .TypeConstraint<IntType>("Tout"),               \
                          RandomUniformIntOp<GPUDevice, IntType>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);
TF_CALL_uint32(REGISTER_FULL_INT);
TF_CALL_uint64(REGISTER_FULL_INT);

#undef REGISTER
#undef REGISTER_INT
#undef REGISTER_FULL_INT

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow