/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.
// NOTE: If the algorithm is changed, please run the test
// .../python/kernel_tests:parameterized_truncated_normal_op_test
// with the "tf.set_random_seed(seed)" lines commented out and with the
// "--runs-per-test=1000" flag. This tests the statistical correctness of the
// op results.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
using random::PhiloxRandom;

static constexpr int kMaxIterations = 1000;

template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  int64_t num_batches, int64_t samples_per_batch,
                  int64_t num_elements, typename TTypes<T>::ConstFlat means,
                  typename TTypes<T>::ConstFlat stddevs,
                  typename TTypes<T>::ConstFlat minvals,
                  typename TTypes<T>::ConstFlat maxvals,
                  const random::PhiloxRandom& gen,
                  typename TTypes<T>::Flat output) {
    // The randn rejection sampler is used when the mean is inside the bounds
    // and the bounds extend at least this many standard deviations from the
    // mean on one side.
    // The uniform proposal samplers become less efficient as the bounds move
    // further from the mean; the reverse is true for the randn sampler.
    // This number was chosen by empirical benchmarking. If modified, the
    // benchmarks in parameterized_truncated_normal_op_test should also be
    // changed.
    const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    auto do_work = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
                    &minvals, &maxvals, &gen, &output,
                    kStdDevsInsideBoundsToUseRandnSampler](
                       int64_t start_batch, int64_t limit_batch) {
      // Capturing "gen" by-value would only make a copy for the _shared_
      // lambda. Since we want to let each worker have its own copy, we pass
      // "gen" by reference and explicitly do a copy assignment here.
      random::PhiloxRandom gen_copy = gen;
      // Skip takes units of 128 bits. The +3 over-allocates slightly, so
      // rounding never leads to different batches reusing the same state.
      // Each sampling iteration consumes up to 2 random numbers per sample.
      gen_copy.Skip(start_batch * 2 * kMaxIterations * (samples_per_batch + 3) /
                    4);
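      // For illustration (hypothetical numbers, not from the original
      // comments): with samples_per_batch = 100 a batch consumes at most
      // 2 * 1000 * 100 = 200000 random values, i.e. 50000 128-bit blocks of
      // 4 values each. The expression above skips
      // start_batch * 2 * 1000 * 103 / 4 = start_batch * 51500 blocks, a
      // slight over-allocation that guarantees batches never overlap.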
      using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
      Uniform dist;
      using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
      Normal normal_dist;

      // Vectorized intermediate calculations for uniform rejection sampling.
      // Each call to the distribution generates at most 4 samples.
      Eigen::array<T, 4> z;
      Eigen::array<T, 4> g;

      for (int64_t b = start_batch; b < limit_batch; ++b) {
        // We are passed a flat array for each of the parameter tensors.
        // The input is either a scalar broadcast to all batches or a vector
        // of length num_batches; a scalar arrives here as an array of
        // length 1.
        T mean = means((means.dimension(0) == 1) ? 0 : b);
        T stddev = stddevs((stddevs.dimension(0) == 1) ? 0 : b);
        T minval = minvals((minvals.dimension(0) == 1) ? 0 : b);
        T maxval = maxvals((maxvals.dimension(0) == 1) ? 0 : b);

        // The last batch can be short if we adjusted num_batches and
        // samples_per_batch.
        const int64_t limit_sample =
            std::min((b + 1) * samples_per_batch, num_elements);
        int64_t sample = b * samples_per_batch;

        // On GPU, this check will just fill samples with NaN if it fails.
        OP_REQUIRES(ctx,
                    stddev > T(0) && minval < maxval &&
                        (Eigen::numext::isfinite(minval) ||
                         Eigen::numext::isfinite(maxval)),
                    errors::InvalidArgument("Invalid parameters"));

        int num_iterations = 0;

        // If possible, make the one-sided bound the lower bound, or make both
        // bounds positive. Otherwise, the bounds are on either side of the
        // mean.
        if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
          // Reverse all calculations. normMin and normMax will be flipped.
          std::swap(minval, maxval);
          stddev = -stddev;
        }

        // Calculate normalized samples, then convert them.
        const T normMin = (minval - mean) / stddev;
        const T normMax = (maxval - mean) / stddev;

        // Determine the method to use.
        const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
        const T cutoff =
            T(2) *
            Eigen::numext::exp(T(0.5) +
                               (normMin * (normMin - sqrtFactor)) / T(4)) /
            (normMin + sqrtFactor);
        const T diff = normMax - normMin;
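        // A sketch of how the three samplers below are chosen (the cutoff
        // analysis follows the classic truncated-normal rejection scheme,
        // cf. Robert, "Simulation of truncated normal variables", 1995):
        //   1. If the mean is inside the bounds and at least
        //      kStdDevsInsideBoundsToUseRandnSampler standard deviations lie
        //      on one side, plain normal proposals are accepted often enough
        //      to be cheapest.
        //   2. Otherwise, if the interval is short (diff < cutoff), uniform
        //      proposals on [normMin, normMax] are rejected rarely enough.
        //   3. Otherwise the interval is a long one-sided tail, where a
        //      translated exponential proposal keeps the acceptance rate
        //      bounded away from zero; cutoff estimates the break-even point
        //      between (2) and (3).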

        if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMax >= T(0.))) ||
            ((normMax > kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMin <= T(0.)))) {
          // If the bounds are at least kStdDevsInsideBoundsToUseRandnSampler
          // standard deviations from the mean on at least one side, then we
          // rejection sample by sampling from the normal distribution and
          // rejecting samples outside the bounds.
          // Under this condition the acceptance rate per iteration should
          // always be ~ 50%. This sampler is more efficient (and more
          // numerically stable) when one or both bounds are far from the mean.

          while (sample < limit_sample) {
            const auto randn_sample = normal_dist(&gen_copy);
            const int size = randn_sample.size();

            for (int i = 0; i < size; i++) {
              if ((randn_sample[i] >= normMin) &&
                  (randn_sample[i] <= normMax)) {
                output(sample) = randn_sample[i] * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
                if (num_iterations > kMaxIterations) {
                  // This should never occur: by the selection criteria above,
                  // this sampler is used only if at least
                  // kStdDevsInsideBoundsToUseRandnSampler standard deviations
                  // on one side of the distribution are within the limits, so
                  // the acceptance probability per iteration is >~ 1/2.
                  LOG(ERROR) << "TruncatedNormal randn rejection sampler "
                             << "exceeded maximum iterations for "
                             << "normMin=" << normMin << " normMax=" << normMax
                             << " kMaxIterations=" << kMaxIterations;
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal randn rejection sampler failed to "
                      "accept a sample."));
                  return;
                }
              }
            }
          }
        } else if (diff < cutoff) {
          // Sample from a uniform distribution on [normMin, normMax].

          const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
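          // A sketch of the acceptance test used below (standard uniform
          // rejection against the normal density): a candidate
          // z ~ U(normMin, normMax) is accepted with probability
          //   exp(-z^2 / 2) / max_{t in [normMin, normMax]} exp(-t^2 / 2).
          // After the earlier swap the interval either contains 0 (the max
          // is 1, so plusFactor = 0) or has 0 <= normMin (the max is
          // exp(-normMin^2 / 2), so plusFactor = normMin^2); either way the
          // test reduces to u <= exp((plusFactor - z^2) / 2) = exp(g[i]).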

          while (sample < limit_sample) {
            const auto rand = dist(&gen_copy);
            const int size = rand.size();
            // NOTE(ringwalt): These loops seem to only generate packed AVX
            // instructions for float32.
            for (int i = 0; i < size; i++) {
              z[i] = rand[i] * diff + normMin;
            }
            for (int i = 0; i < size; i++) {
              g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
            }

            const auto u = dist(&gen_copy);
            for (int i = 0; i < size; i++) {
              auto accept = u[i] <= Eigen::numext::exp(g[i]);
              if (accept || num_iterations + 1 >= kMaxIterations) {
                // Accept the sample z.
                // If we run out of iterations, just use the current uniform
                // sample, but emit a warning.
                // TODO(jjhunt) For small entropies (relative to the bounds),
                // this sampler is poor and may take many iterations since
                // the proposal distribution is the uniform distribution
                // U(lower_bound, upper_bound).
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal uniform rejection sampler "
                             << "exceeded max iterations. Sample may contain "
                             << "outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal uniform rejection sampler failed to "
                      "accept a sample."));
                  return;
                }
                output(sample) = z[i] * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
              }
            }
          }
        } else {
          // Sample from an exponential distribution with alpha maximizing
          // acceptance probability, offset by normMin from the origin.
          // Accept only if less than normMax.
          const T alpha =
              (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
              T(2);
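          // A sketch of where alpha comes from (assuming the usual tail
          // analysis, cf. Robert 1995): the proposal is
          // z = normMin + E / alpha with E ~ Exp(1), and the expected
          // acceptance rate against the normal tail is maximized by the
          // positive root of alpha^2 - normMin * alpha - 1 = 0, i.e.
          // alpha = (normMin + sqrt(normMin^2 + 4)) / 2.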
          while (sample < limit_sample) {
            auto rand = dist(&gen_copy);
            const int size = rand.size();
            int i = 0;
            while (i < size) {
              const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
              i++;
              const T x = normMin < alpha ? alpha - z : normMin - alpha;
              const T g = Eigen::numext::exp(-x * x / T(2.0));
              const T u = rand[i];
              i++;
              auto accept = (u <= g && z < normMax);
              if (accept || num_iterations + 1 >= kMaxIterations) {
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal exponential distribution "
                             << "rejection sampler exceeded max iterations. "
                             << "Sample may contain outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal exponential distribution rejection"
                      " sampler failed to accept a sample."));
                  return;
                }
                output(sample) = z * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
              }
            }
          }
        }
      }
    };
    // The cost of the initial calculations for the batch.
    const int64_t batchInitCost =
        // normMin, normMax
        (Eigen::TensorOpCost::AddCost<T>() +
         Eigen::TensorOpCost::MulCost<T>()) *
            2
        // sqrtFactor
        + Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_sqrt_op<T>>::Cost
        // cutoff
        + Eigen::TensorOpCost::MulCost<T>() * 4 +
        Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
        // diff
        + Eigen::TensorOpCost::AddCost<T>();
    const int64_t uniformSampleCost =
        random::PhiloxRandom::kElementCost +
        random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
    // The cost of a single uniform sampling round.
    const int64_t uniformRejectionSamplingCost =
        uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
        Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() * 2 +
        Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_exp_op<T>>::Cost +
        Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
    // Estimate the cost for an entire batch.
    // Assume we use uniform sampling, and accept the 2nd sample on average.
    const int64_t batchCost =
        batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch;
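    // For intuition (illustrative numbers only): treating each add/mul as ~1
    // unit, a uniform rejection round costs on the order of tens of units, so
    // a batch of 100 samples is estimated at batchInitCost plus 2 * 100
    // rounds. Shard uses this per-batch estimate to split the num_batches
    // range across worker threads.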
    Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
          batchCost, do_work);
  }
};

template <typename T>
struct TruncatedNormalFunctorV2<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  int64_t num_batches, int64_t samples_per_batch,
                  int64_t num_elements, const BCastList<4>& bcast,
                  typename TTypes<T>::ConstFlat means,
                  typename TTypes<T>::ConstFlat stddevs,
                  typename TTypes<T>::ConstFlat minvals,
                  typename TTypes<T>::ConstFlat maxvals,
                  const random::PhiloxRandom& gen,
                  typename TTypes<T>::Flat output) {
    // The randn rejection sampler is used when the mean is inside the bounds
    // and the bounds extend at least this many standard deviations from the
    // mean on one side.
    // The uniform proposal samplers become less efficient as the bounds move
    // further from the mean; the reverse is true for the randn sampler.
    // This number was chosen by empirical benchmarking. If modified, the
    // benchmarks in parameterized_truncated_normal_op_test should also be
    // changed.
    const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    auto do_work = [num_batches, samples_per_batch, &ctx, &bcast, &means,
                    &stddevs, &minvals, &maxvals, &gen, &output,
                    kStdDevsInsideBoundsToUseRandnSampler](
                       int64_t start_output, int64_t limit_output) {
      // Capturing "gen" by-value would only make a copy for the _shared_
      // lambda. Since we want to let each worker have its own copy, we pass
      // "gen" by reference and explicitly do a copy assignment here.
      random::PhiloxRandom gen_copy = gen;
      using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
      Uniform dist;
      using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
      Normal normal_dist;
      // Skip takes units of 128 bits. Adding Uniform::kResultElementCount - 1
      // rounds the division up, so different workloads never reuse the same
      // state. Each sampling iteration consumes up to 2 random numbers per
      // output element.
      gen_copy.Skip((start_output * 2 * kMaxIterations +
                     Uniform::kResultElementCount - 1) /
                    Uniform::kResultElementCount);
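      // For illustration (hypothetical numbers): with
      // Uniform::kResultElementCount == 4 and start_output == 10, the worker
      // skips ceil(10 * 2 * 1000 / 4) = 5000 blocks of 128 bits, the maximum
      // any earlier output element could have consumed.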

      // Vectorized intermediate calculations for uniform rejection sampling.
      // Each call to the distribution generates at most
      // Uniform::kResultElementCount samples.
      Eigen::array<T, Uniform::kResultElementCount> z;
      Eigen::array<T, Uniform::kResultElementCount> g;

      const bool should_bcast = bcast.IsBroadcastingRequired();
      const auto& means_batch_indices = bcast.batch_indices(0);
      const auto& stddevs_batch_indices = bcast.batch_indices(1);
      const auto& minvals_batch_indices = bcast.batch_indices(2);
      const auto& maxvals_batch_indices = bcast.batch_indices(3);
      auto output_flat = output.data();

      // We partition work across batches and then across samples per batch
      // member, to avoid extra work.
      for (int64_t output_idx = start_output; output_idx < limit_output;
           // output_idx is incremented with the inner loops below.
      ) {
        int64_t batch_idx = output_idx / samples_per_batch;
        // The output layout is [samples_per_batch, num_batches], so the
        // address of sample sample_idx in batch batch_idx is
        // output_flat + sample_idx * num_batches + batch_idx. The code below
        // indexes as output_batch_offset[sample_idx * num_batches] to match.
        T* const output_batch_offset = output_flat + batch_idx;
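        // A worked example of the layout (illustrative values): with
        // num_batches = 3 and samples_per_batch = 2 the flat output holds
        // [b0s0, b1s0, b2s0, b0s1, b1s1, b2s1], so batch 1 writes to
        // output_flat[1] and output_flat[1 + 3].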
        // Look up this batch's parameters via the BCast batch indices, which
        // map output batches to positions in each (possibly broadcast) input.
        T mean, stddev, minval, maxval;
        if (should_bcast) {
          mean = means(means_batch_indices[batch_idx]);
          stddev = stddevs(stddevs_batch_indices[batch_idx]);
          minval = minvals(minvals_batch_indices[batch_idx]);
          maxval = maxvals(maxvals_batch_indices[batch_idx]);
        } else {
          mean = means(batch_idx);
          stddev = stddevs(batch_idx);
          minval = minvals(batch_idx);
          maxval = maxvals(batch_idx);
        }

        // On GPU, this check will just fill samples with NaN if it fails.
        OP_REQUIRES(ctx,
                    stddev > T(0) && minval < maxval &&
                        (Eigen::numext::isfinite(minval) ||
                         Eigen::numext::isfinite(maxval)),
                    errors::InvalidArgument("Invalid parameters"));

        int num_iterations = 0;

        // If possible, make the one-sided bound the lower bound, or make both
        // bounds positive. Otherwise, the bounds are on either side of the
        // mean.
        if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
          // Reverse all calculations. normMin and normMax will be flipped.
          std::swap(minval, maxval);
          stddev = -stddev;
        }

        // Calculate normalized samples, then convert them.
        const T normMin = (minval - mean) / stddev;
        const T normMax = (maxval - mean) / stddev;

        // Determine the method to use (see the sampler-selection sketch in
        // TruncatedNormalFunctor above).
        const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
        const T cutoff =
            T(2) *
            Eigen::numext::exp(T(0.5) +
                               (normMin * (normMin - sqrtFactor)) / T(4)) /
            (normMin + sqrtFactor);
        const T diff = normMax - normMin;

        if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMax >= T(0.))) ||
            ((normMax > kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMin <= T(0.)))) {
          // If the bounds are at least kStdDevsInsideBoundsToUseRandnSampler
          // standard deviations from the mean on at least one side, then we
          // rejection sample by sampling from the normal distribution and
          // rejecting samples outside the bounds.
          // Under this condition the acceptance rate per iteration should
          // always be ~ 50%. This sampler is more efficient (and more
          // numerically stable) when one or both bounds are far from the mean.
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;) {
            const auto randn_sample = normal_dist(&gen_copy);
            const int size = randn_sample.size();
            for (int i = 0; i < size; ++i) {
              if ((randn_sample[i] >= normMin) &&
                  (randn_sample[i] <= normMax)) {
                output_batch_offset[sample_idx * num_batches] =
                    randn_sample[i] * stddev + mean;
                ++sample_idx;
                ++output_idx;
                if (sample_idx >= samples_per_batch ||
                    output_idx >= limit_output) {
                  break;
                }
                num_iterations = 0;
              } else {
                ++num_iterations;
                if (num_iterations > kMaxIterations) {
                  // This should never occur: by the selection criteria above,
                  // this sampler is used only if at least
                  // kStdDevsInsideBoundsToUseRandnSampler standard deviations
                  // on one side of the distribution are within the limits, so
                  // the acceptance probability per iteration is >~ 1/2.
                  LOG(ERROR) << "TruncatedNormal randn rejection sampler "
                             << "exceeded maximum iterations for "
                             << "normMin=" << normMin << " normMax=" << normMax
                             << " kMaxIterations=" << kMaxIterations;
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal randn rejection sampler failed to "
                      "accept a sample."));
                  return;
                }
              }
            }
          }
        } else if (diff < cutoff) {
          // Sample from a uniform distribution on [normMin, normMax].

          const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;

          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;) {
            const auto rand = dist(&gen_copy);
            const int size = rand.size();
            // NOTE(ringwalt): These loops seem to only generate packed AVX
            // instructions for float32.
            for (int i = 0; i < size; i++) {
              z[i] = rand[i] * diff + normMin;
              g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
            }

            const auto u = dist(&gen_copy);
            for (int i = 0; i < size; i++) {
              auto accept = u[i] <= Eigen::numext::exp(g[i]);
              if (accept || num_iterations + 1 >= kMaxIterations) {
                // Accept the sample z.
                // If we run out of iterations, just use the current uniform
                // sample, but emit a warning.
                // TODO(jjhunt) For small entropies (relative to the bounds),
                // this sampler is poor and may take many iterations since
                // the proposal distribution is the uniform distribution
                // U(lower_bound, upper_bound).
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal uniform rejection sampler "
                             << "exceeded max iterations. Sample may contain "
                             << "outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal uniform rejection sampler failed to "
                      "accept a sample."));
                  return;
                }
                output_batch_offset[sample_idx * num_batches] =
                    z[i] * stddev + mean;
                ++sample_idx;
                ++output_idx;
                if (sample_idx >= samples_per_batch ||
                    output_idx >= limit_output) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
              }
            }
          }
        } else {
          // Sample from an exponential distribution with alpha maximizing
          // acceptance probability, offset by normMin from the origin.
          // Accept only if less than normMax.
          const T alpha =
              (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
              T(2);
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;) {
            auto rand = dist(&gen_copy);
            const int size = rand.size();
            int i = 0;
            while (i < size) {
              const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
              i++;
              const T x = normMin < alpha ? alpha - z : normMin - alpha;
              const T g = Eigen::numext::exp(-x * x / T(2.0));
              const T u = rand[i];
              i++;
              auto accept = (u <= g && z < normMax);
              if (accept || num_iterations + 1 >= kMaxIterations) {
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal exponential distribution "
                             << "rejection sampler exceeded max iterations. "
                             << "Sample may contain outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal exponential distribution rejection"
                      " sampler failed to accept a sample."));
                  return;
                }
                output_batch_offset[sample_idx * num_batches] =
                    z * stddev + mean;
                ++sample_idx;
                ++output_idx;
                if (sample_idx >= samples_per_batch ||
                    output_idx >= limit_output) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
              }
            }
          }
        }
      }
    };
    // The cost of the initial calculations for the batch.
    const int64_t batchInitCost =
        // normMin, normMax
        (Eigen::TensorOpCost::AddCost<T>() +
         Eigen::TensorOpCost::MulCost<T>()) *
            2
        // sqrtFactor
        + Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_sqrt_op<T>>::Cost
        // cutoff
        + Eigen::TensorOpCost::MulCost<T>() * 4 +
        Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
        // diff
        + Eigen::TensorOpCost::AddCost<T>();
    const int64_t uniformSampleCost =
        random::PhiloxRandom::kElementCost +
        random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
    // The cost of a single uniform sampling round.
    const int64_t uniformRejectionSamplingCost =
        uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
        Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() * 2 +
        Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_exp_op<T>>::Cost +
        Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
    // Estimate the cost for an entire batch.
    // Assume we use uniform sampling, and accept the 2nd sample on average.
    const int64_t batchCost = batchInitCost + uniformRejectionSamplingCost * 2;
    Shard(worker_threads.num_threads, worker_threads.workers, num_elements,
          batchCost, do_work);
  }
};

}  // namespace functor

namespace {

// Samples from a truncated normal distribution, using the given parameters.
template <typename Device, typename T>
class ParameterizedTruncatedNormalOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static constexpr int32_t kDesiredBatchSize = 100;

 public:
  explicit ParameterizedTruncatedNormalOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& means_tensor = ctx->input(1);
    const Tensor& stddevs_tensor = ctx->input(2);
    const Tensor& minvals_tensor = ctx->input(3);
    const Tensor& maxvals_tensor = ctx->input(4);

    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, shape_tensor.NumElements() > 0,
                errors::InvalidArgument("Shape tensor must not be empty, got ",
                                        shape_tensor.DebugString()));
    TensorShape tensor_shape;
    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &tensor_shape));

    int32_t num_batches = tensor_shape.dim_size(0);
    int32_t samples_per_batch = 1;
    const int32_t num_dims = tensor_shape.dims();
    for (int32_t i = 1; i < num_dims; i++) {
      samples_per_batch *= tensor_shape.dim_size(i);
    }
    const int32_t num_elements = num_batches * samples_per_batch;

    // Allocate the output before fudging num_batches and samples_per_batch.
    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensor_shape, &samples_tensor));

    // Parameters must be 0-d or 1-d.
    OP_REQUIRES(ctx, means_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input means should be a scalar or vector, got shape: ",
                    means_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, stddevs_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input stddevs should be a scalar or vector, got shape: ",
                    stddevs_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, minvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input minvals should be a scalar or vector, got shape: ",
                    minvals_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, maxvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input maxvals should be a scalar or vector, got shape: ",
                    maxvals_tensor.shape().DebugString()));

    if ((means_tensor.dims() == 0 || means_tensor.dim_size(0) == 1) &&
        (stddevs_tensor.dims() == 0 || stddevs_tensor.dim_size(0) == 1) &&
        minvals_tensor.dims() == 0 && maxvals_tensor.dims() == 0) {
      // All batches have the same parameters, so we can update the batch size
      // to a reasonable value to improve parallelism (ensure enough batches,
      // and no very small batches which have high overhead).
      int32_t size = num_batches * samples_per_batch;
      int32_t adjusted_samples = kDesiredBatchSize;
      // Ensure adjusted_batches * adjusted_samples >= size.
      int32_t adjusted_batches = Eigen::divup(size, adjusted_samples);
      num_batches = adjusted_batches;
      samples_per_batch = adjusted_samples;
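      // A worked example (illustrative values): for shape [1000, 1] we get
      // size = 1000, so adjusted_batches = divup(1000, 100) = 10 batches of
      // 100 samples each instead of 1000 batches of 1. Any excess from
      // rounding up is handled by the functor, which clamps each batch at
      // num_elements via limit_sample.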
    } else {
      // Parameters must be broadcastable to the shape [num_batches].
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(means_tensor.shape()) ||
              means_tensor.dim_size(0) == 1 ||
              means_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input means should have length 1 or shape[0], got shape: ",
              means_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(stddevs_tensor.shape()) ||
              stddevs_tensor.dim_size(0) == 1 ||
              stddevs_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input stddevs should have length 1 or shape[0], got shape: ",
              stddevs_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(minvals_tensor.shape()) ||
              minvals_tensor.dim_size(0) == 1 ||
              minvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input minvals should have length 1 or shape[0], got shape: ",
              minvals_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(maxvals_tensor.shape()) ||
              maxvals_tensor.dim_size(0) == 1 ||
              maxvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input maxvals should have length 1 or shape[0], got shape: ",
              maxvals_tensor.shape().DebugString()));
    }

    auto truncFunctor = functor::TruncatedNormalFunctor<Device, T>();
    // The functor's workers skip ahead based on the (possibly fudged)
    // samples_per_batch, so reserve matching generator state here.
    random::PhiloxRandom rng =
        generator_.ReserveSamples128(num_batches * 2 * functor::kMaxIterations *
                                     (samples_per_batch + 3) / 4);
    truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
                 samples_per_batch, num_elements, means_tensor.flat<T>(),
                 stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
                 maxvals_tensor.flat<T>(), rng, samples_tensor->flat<T>());
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp);
};

// Samples from a truncated normal distribution, using the given parameters.
template <typename Device, typename T>
class StatelessParameterizedTruncatedNormal : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static const int32_t kDesiredBatchSize = 100;

 public:
  explicit StatelessParameterizedTruncatedNormal(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& seed_tensor = ctx->input(1);
    const Tensor& means_tensor = ctx->input(2);
    const Tensor& stddevs_tensor = ctx->input(3);
    const Tensor& minvals_tensor = ctx->input(4);
    const Tensor& maxvals_tensor = ctx->input(5);

    OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_tensor.shape().DebugString()));

    tensorflow::BCastList<4> bcast(
        {means_tensor.shape().dim_sizes(), stddevs_tensor.shape().dim_sizes(),
         minvals_tensor.shape().dim_sizes(),
         maxvals_tensor.shape().dim_sizes()},
        /*fewer_dims_optimization=*/false,
        /*return_flattened_batch_indices=*/true);

    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "means, stddevs, minvals, maxvals must have compatible "
                    "batch dimensions: ",
                    means_tensor.shape().DebugString(), " vs. ",
                    stddevs_tensor.shape().DebugString(), " vs. ",
                    minvals_tensor.shape().DebugString(), " vs. ",
                    maxvals_tensor.shape().DebugString()));

    // Check that the requested output shape dominates (i.e. ends with) the
    // broadcast shape of the parameters.
    TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    TensorShape output_shape;
    if (shape_tensor.dtype() == DataType::DT_INT32) {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
                                                      &output_shape));
    } else {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                              shape_tensor.vec<int64_t>(), &output_shape));
    }
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));

    int64_t samples_per_batch = 1;
    const int64_t num_sample_dims =
        (shape_tensor.dim_size(0) - bcast.output_shape().size());
    for (int64_t i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= output_shape.dim_size(i);
    }
    int64_t num_batches = 1;
    for (int64_t i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= output_shape.dim_size(i);
    }
    const int64_t num_elements = num_batches * samples_per_batch;
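    // A worked example (illustrative values): for output_shape = [5, 3, 2]
    // where the parameters broadcast to shape [3, 2], the single leading
    // dimension is a sample dimension, giving samples_per_batch = 5,
    // num_batches = 3 * 2 = 6, and num_elements = 30.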

    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, output_shape, &samples_tensor));

    auto truncFunctor = functor::TruncatedNormalFunctorV2<Device, T>();
    // Derive the Philox key and counter from the seed; each worker in the
    // functor makes its own copy of the generator and skips ahead as needed.
    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter));

    auto philox = random::PhiloxRandom(counter, key);

    truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
                 samples_per_batch, num_elements, bcast,
                 means_tensor.flat<T>(), stddevs_tensor.flat<T>(),
                 minvals_tensor.flat<T>(), maxvals_tensor.flat<T>(), philox,
                 samples_tensor->flat<T>());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StatelessParameterizedTruncatedNormal);
};

}  // namespace

#define REGISTER(TYPE)                                                     \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal")             \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<TYPE>("dtype"),              \
                          ParameterizedTruncatedNormalOp<CPUDevice, TYPE>) \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("StatelessParameterizedTruncatedNormal")                        \
          .HostMemory("shape")                                             \
          .HostMemory("seed")                                              \
          .HostMemory("means")                                             \
          .HostMemory("stddevs")                                           \
          .HostMemory("minvals")                                           \
          .HostMemory("maxvals")                                           \
          .Device(DEVICE_CPU)                                              \
          .TypeConstraint<TYPE>("dtype"),                                  \
      StatelessParameterizedTruncatedNormal<CPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER(TYPE)                                                    \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal")            \
                              .Device(DEVICE_GPU)                         \
                              .HostMemory("shape")                        \
                              .TypeConstraint<TYPE>("dtype"),             \
                          ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow