1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
// See docs in ../ops/stateless_random_ops.cc.
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "tensorflow/core/kernels/stateless_random_gamma_op.h" |
21 | |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/register_types.h" |
24 | #include "tensorflow/core/framework/tensor_util.h" |
25 | #include "tensorflow/core/kernels/stateless_random_ops.h" |
26 | #include "tensorflow/core/lib/random/philox_random.h" |
27 | #include "tensorflow/core/lib/random/random_distributions.h" |
28 | #include "tensorflow/core/platform/errors.h" |
29 | #include "tensorflow/core/util/work_sharder.h" |
30 | |
31 | #if EIGEN_COMP_GNUC && __cplusplus > 199711L |
32 | #define DISABLE_FLOAT_EQUALITY_WARNING \ |
33 | _Pragma("GCC diagnostic push") \ |
34 | _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") |
35 | #define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") |
36 | #else |
37 | #define DISABLE_FLOAT_EQUALITY_WARNING |
38 | #define ENABLE_FLOAT_EQUALITY_WARNING |
39 | #endif |
40 | |
41 | namespace tensorflow { |
42 | |
43 | namespace { |
44 | |
// Each attempt to generate a new draw from the Gamma distribution succeeds
// 95+% of the time, and requires 1-2 normal samples plus 1 uniform sample.
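// With a ~95% acceptance rate, the number of attempts per output is
// geometric with mean 1/0.95 ~= 1.05, so reserving 256 Philox samples per
// output leaves generous headroom for rejection streaks while keeping the
// sub-streams used for distinct outputs disjoint (see the Skip() calls
// below).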
47 | static constexpr int kReservedSamplesPerOutput = 256; |
48 | |
49 | typedef Eigen::ThreadPoolDevice CPUDevice; |
50 | typedef Eigen::GpuDevice GPUDevice; |
51 | |
52 | }; // namespace |
53 | |
54 | namespace functor { |
55 | |
56 | template <typename T> |
57 | struct StatelessRandomGammaFunctor<CPUDevice, T> { |
58 | static Status Fill(OpKernelContext* ctx, const T* alpha_flat, |
59 | int64_t num_samples, int64_t num_alphas, |
60 | int64_t samples_per_alpha, |
61 | const random::PhiloxRandom& random, T* samples_flat) { |
62 | typedef random::NormalDistribution<random::PhiloxRandom, double> Normal; |
63 | typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform; |
64 | |
    // We partition work first across alphas and then across samples-per-alpha
    // so that the few computations that depend only on alpha (the alpha == 1
    // test and the constants d and c below) are done once per alpha rather
    // than once per sample.
67 | |
68 | auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat, |
69 | alpha_flat](int64_t start_output, int64_t limit_output) { |
      // Capturing "random" by value would only make one copy for the _shared_
      // lambda. Since we want each worker to have its own copy, we capture
      // "random" by reference and explicitly copy it for each output below.
73 | |
74 | using Eigen::numext::exp; |
75 | using Eigen::numext::log; |
76 | using Eigen::numext::log1p; |
77 | using Eigen::numext::pow; |
78 | |
79 | Normal normal; |
80 | Uniform uniform; |
81 | |
82 | RandomSampleBuffer<Normal> normal_buffer(&normal); |
83 | RandomSampleBuffer<Uniform> uniform_buffer(&uniform); |
84 | |
85 | for (int64_t output_idx = start_output; output_idx < limit_output; |
86 | /* output_idx incremented within inner loop below */) { |
87 | int64_t alpha_idx = output_idx / samples_per_alpha; |
88 | |
89 | // Instead of +alpha_idx for each sample, we offset the pointer once. |
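        // The flattened output has layout [samples_per_alpha, num_alphas], so
        // sample s for this alpha lives at index s * num_alphas + alpha_idx.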
90 | T* const samples_alpha_offset = samples_flat + alpha_idx; |
91 | |
92 | // Several calculations can be done on a per-alpha basis. |
93 | const double alpha = static_cast<double>(alpha_flat[alpha_idx]); |
94 | |
95 | DISABLE_FLOAT_EQUALITY_WARNING |
96 | if (alpha == 1.0) { |
97 | ENABLE_FLOAT_EQUALITY_WARNING |
98 | // Sample from an exponential distribution. |
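          // Gamma(1, 1) is the standard exponential distribution, which can
          // be sampled by inversion as -log(1 - u) for u ~ Uniform(0, 1);
          // log1p(-u) keeps the result accurate when u is close to zero.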
99 | for (int64_t sample_idx = output_idx % samples_per_alpha; |
100 | sample_idx < samples_per_alpha && output_idx < limit_output; |
101 | sample_idx++, output_idx++) { |
102 | // As we want data stable regardless of sharding, we skip on a |
103 | // per-sample basis. |
104 | random::PhiloxRandom gen = random; |
105 | gen.Skip(kReservedSamplesPerOutput * output_idx); |
106 | double u = uniform(&gen)[Uniform::kResultElementCount - 1]; |
107 | const double res = -log1p(-u); |
108 | samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res); |
109 | } // for (sample_idx) |
110 | } else { // if alpha != 1.0 |
111 | // Transformation-rejection from pairs of uniform and normal random |
112 | // variables. http://dl.acm.org/citation.cfm?id=358414 |
113 | // |
114 | // The algorithm has an acceptance rate of ~95% for small alpha (~1), |
115 | // and higher accept rates for higher alpha, so runtime is |
116 | // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95. |
117 | // |
          // For alpha < 1, we add one to d = alpha - 1/3 (i.e. use
          // d = alpha + 2/3), and multiply the final result by
          // uniform()^(1/alpha).
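          //
          // Concretely (Marsaglia & Tsang, 2000): with d = alpha - 1/3 and
          // c = 1 / sqrt(9 * d), draw x ~ N(0, 1) and u ~ Uniform(0, 1), set
          // v = (1 + c * x)^3, and accept the sample d * v whenever
          // log(u) < x^2 / 2 + d * (1 - v + log(v)).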
120 | const bool alpha_less_than_one = alpha < 1.0; |
121 | const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3); |
122 | const double c = 1.0 / 3 / sqrt(d); |
123 | |
124 | // Compute the rest of the samples for the current alpha value. |
125 | for (int64_t sample_idx = output_idx % samples_per_alpha; |
126 | sample_idx < samples_per_alpha && output_idx < limit_output; |
127 | sample_idx++, output_idx++) { |
128 | // Since each sample may use a variable number of normal/uniform |
129 | // samples, and we want data stable regardless of sharding, we skip |
130 | // on a per-sample basis. |
131 | random::PhiloxRandom gen = random; |
132 | gen.Skip(kReservedSamplesPerOutput * output_idx); |
133 | |
134 | // To prevent overwriting SampleBuffer's underlying array with |
135 | // zeros (in tensorflow::random::Array constructor), we just mark |
136 | // the buffer as empty instead of initializing a new SampleBuffer |
137 | // object here. The next call to operator() will fill the buffer |
138 | // with new numbers. |
139 | normal_buffer.Clear(); |
140 | uniform_buffer.Clear(); |
141 | |
142 | // Keep trying until we don't reject a sample. In practice, we will |
143 | // only reject ~5% at worst, for low alpha near 1. |
144 | while (true) { |
145 | const double x = normal_buffer(&gen); |
146 | double v = 1 + c * x; |
147 | if (v <= 0) { |
148 | continue; |
149 | } |
150 | v = v * v * v; |
151 | double u = uniform_buffer(&gen); |
              // The first branch of the "if" is a "squeeze" short-circuit
              // that dodges the two log calls; the magic constant is sourced
              // from the paper linked above. Upwards of 91% of the area
              // covered by the log inequality is covered by the squeeze as
              // well (with larger coverage for smaller values of alpha).
157 | if ((u < 1 - 0.0331 * (x * x) * (x * x)) || |
158 | (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) { |
159 | double res = d * v; |
160 | if (alpha_less_than_one) { |
161 | double b = uniform_buffer(&gen); |
162 | res *= pow(b, 1 / alpha); |
163 | } |
164 | samples_alpha_offset[sample_idx * num_alphas] = |
165 | static_cast<T>(res); |
166 | break; |
167 | } |
168 | } // while: true |
169 | } // for: sample_idx |
170 | } // if (alpha == 1.0) |
171 | } // for: output_idx |
172 | }; // DoWork |
173 | |
    // The two log calls are only reached for ~10% of samples:
    //   2 x 100 (64-bit cycles per log) x 0.10 = ~20.
    // Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6
    //   cycles each = ~60.
    // All of this is divided by 0.95 (the expected value of a geometric
    //   distribution is 1/p) to account for rejections = ~85.
180 | static const int kElementCost = 85 + 2 * Normal::kElementCost + |
181 | Uniform::kElementCost + |
182 | 3 * random::PhiloxRandom::kElementCost; |
183 | auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); |
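    // Shard splits the range [0, num_samples) across the worker pool, using
    // kElementCost to decide how finely to subdivide the work.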
184 | Shard(worker_threads.num_threads, worker_threads.workers, num_samples, |
185 | kElementCost, DoWork); |
186 | return OkStatus(); |
187 | } |
188 | }; |
189 | |
190 | } // namespace functor |
191 | |
192 | namespace { |
193 | |
194 | template <typename Device, typename T> |
195 | class StatelessRandomGammaOp : public OpKernel { |
196 | public: |
197 | explicit StatelessRandomGammaOp(OpKernelConstruction* context) |
198 | : OpKernel(context) {} |
199 | |
200 | void Compute(OpKernelContext* context) override { |
201 | // Sanitize input |
202 | const Tensor& shape_t = context->input(0); |
203 | const Tensor& seed_t = context->input(1); |
204 | TensorShape shape; |
205 | OP_REQUIRES_OK(context, tensor::MakeShape(shape_t, &shape)); |
206 | OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2, |
207 | errors::InvalidArgument("seed must have shape [2], not " , |
208 | seed_t.shape().DebugString())); |
209 | |
210 | // Allocate output |
211 | Tensor* output; |
212 | OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output)); |
213 | if (shape.num_elements() == 0) return; |
214 | |
215 | random::PhiloxRandom::Key key; |
216 | random::PhiloxRandom::ResultType counter; |
217 | OP_REQUIRES_OK(context, GenerateKey(seed_t, &key, &counter)); |
218 | |
219 | // Fill in the random numbers |
220 | Fill(context, random::PhiloxRandom(counter, key), output); |
221 | } |
222 | |
223 | private: |
  void Fill(OpKernelContext* ctx, random::PhiloxRandom random,
            Tensor* output) {
225 | const Tensor& alpha_t = ctx->input(2); |
226 | |
227 | TensorShape samples_shape = output->shape(); |
228 | OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(samples_shape, alpha_t.shape()), |
229 | errors::InvalidArgument( |
230 | "Shape passed in must end with broadcasted shape." )); |
231 | |
232 | const int64_t num_alphas = alpha_t.NumElements(); |
233 | OP_REQUIRES(ctx, num_alphas > 0, |
234 | errors::InvalidArgument( |
235 | "Input alpha should have non-zero element count, got: " , |
236 | num_alphas)); |
237 | |
238 | const int64_t num_samples = samples_shape.num_elements(); |
239 | const int64_t samples_per_alpha = num_samples / num_alphas; |
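    // For example, an output shape of [10, 2] with alpha of shape [2] gives
    // num_samples = 20, num_alphas = 2, and samples_per_alpha = 10.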
240 | const auto alpha_flat = alpha_t.flat<T>().data(); |
241 | auto samples_flat = output->flat<T>().data(); |
242 | |
243 | OP_REQUIRES_OK(ctx, functor::StatelessRandomGammaFunctor<Device, T>::Fill( |
244 | ctx, alpha_flat, num_samples, num_alphas, |
245 | samples_per_alpha, random, samples_flat)); |
246 | } |
247 | |
248 | TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomGammaOp); |
249 | }; |
250 | |
251 | // Register CPU kernels for stateless gamma op. |
252 | #define REGISTER_GAMMA_CPU(TYPE) \ |
253 | REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \ |
254 | .Device(DEVICE_CPU) \ |
255 | .HostMemory("shape") \ |
256 | .HostMemory("seed") \ |
257 | .HostMemory("alpha") \ |
258 | .TypeConstraint<TYPE>("dtype"), \ |
259 | StatelessRandomGammaOp<CPUDevice, TYPE>) |
260 | |
261 | TF_CALL_half(REGISTER_GAMMA_CPU); |
262 | TF_CALL_bfloat16(REGISTER_GAMMA_CPU); |
263 | TF_CALL_float(REGISTER_GAMMA_CPU); |
264 | TF_CALL_double(REGISTER_GAMMA_CPU); |
265 | |
266 | #undef REGISTER_GAMMA_CPU |
267 | |
268 | // Register GPU kernels for stateless gamma op. |
269 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
270 | |
271 | #define REGISTER_GAMMA_GPU(TYPE) \ |
272 | REGISTER_KERNEL_BUILDER(Name("StatelessRandomGammaV2") \ |
273 | .Device(DEVICE_GPU) \ |
274 | .HostMemory("shape") \ |
275 | .HostMemory("seed") \ |
276 | .TypeConstraint<TYPE>("dtype"), \ |
277 | StatelessRandomGammaOp<GPUDevice, TYPE>) |
278 | |
279 | TF_CALL_half(REGISTER_GAMMA_GPU); |
280 | TF_CALL_bfloat16(REGISTER_GAMMA_GPU); |
281 | TF_CALL_float(REGISTER_GAMMA_GPU); |
282 | TF_CALL_double(REGISTER_GAMMA_GPU); |
283 | |
284 | #undef REGISTER_GAMMA_GPU |
285 | |
286 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
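
// For reference, the "StatelessRandomGammaV2" op registered above is exposed
// in Python as tf.random.stateless_gamma. An illustrative call (the Python
// wrapper also accepts additional arguments such as dtype):
//
//   samples = tf.random.stateless_gamma(
//       shape=[10, 2], seed=[12, 34], alpha=[0.5, 1.5])
//
// The requested shape must end with the broadcast shape of alpha (matching
// the EndsWith check in Fill() above), and the same (shape, seed, alpha)
// triple always produces the same samples.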
287 | |
288 | } // namespace |
289 | } // namespace tensorflow |
290 | |