/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.
// NOTE: If the algorithm is changed, please run the test
// .../python/kernel_tests:parameterized_truncated_normal_op_test
// with the "tf.set_random_seed(seed)" lines commented out and with the
// "--runs-per-test=1000" flag; this checks the statistical correctness of
// the op results.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
using random::PhiloxRandom;

static constexpr int kMaxIterations = 1000;

template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  int64_t num_batches, int64_t samples_per_batch,
                  int64_t num_elements, typename TTypes<T>::ConstFlat means,
                  typename TTypes<T>::ConstFlat stddevs,
                  typename TTypes<T>::ConstFlat minvals,
                  typename TTypes<T>::ConstFlat maxvals,
                  const random::PhiloxRandom& gen,
                  typename TTypes<T>::Flat output) {
    // The randn rejection sampler is used when the mean and at least this
    // many standard deviations are inside the bounds.
    // The uniform proposal samplers become less efficient as the bounds move
    // further from the mean; the reverse is true for the randn sampler.
    // This number was chosen by empirical benchmarking. If it is modified,
    // the benchmarks in parameterized_truncated_normal_op_test should also
    // be updated.
    const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    auto do_work = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
                    &minvals, &maxvals, &gen, &output,
                    kStdDevsInsideBoundsToUseRandnSampler](
                       int64_t start_batch, int64_t limit_batch) {
      // Capturing "gen" by-value would only make a copy for the _shared_
      // lambda. Since we want to let each worker have its own copy, we pass
      // "gen" by reference and explicitly do a copy assignment here.
      random::PhiloxRandom gen_copy = gen;
      // Skip takes units of 128 bits. The +3 is so rounding doesn't lead to
      // us using the same state in different batches.
      // The sample from each iteration uses 2 random numbers.
      gen_copy.Skip(start_batch * 2 * kMaxIterations * (samples_per_batch + 3) /
                    4);
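      // Illustration of the skip arithmetic above (hypothetical numbers): each
      // sample may consume up to 2 * kMaxIterations = 2000 random values in
      // the worst case, and each 128-bit Philox block yields four 32-bit
      // values, so with samples_per_batch = 100 every batch gets a stride of
      // 2 * 1000 * (100 + 3) / 4 = 51500 blocks of generator state.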
      using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
      Uniform dist;
      using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
      Normal normal_dist;

      // Vectorized intermediate calculations for uniform rejection sampling.
      // We always generate at most 4 samples.
      Eigen::array<T, 4> z;
      Eigen::array<T, 4> g;

      for (int64_t b = start_batch; b < limit_batch; ++b) {
        // We are passed a flat array for each of the parameter tensors.
        // The input is either a scalar broadcasted to all batches or a vector
        // with length num_batches, but the scalar becomes an array of length 1.
        T mean = means((means.dimension(0) == 1) ? 0 : b);
        T stddev = stddevs((stddevs.dimension(0) == 1) ? 0 : b);
        T minval = minvals((minvals.dimension(0) == 1) ? 0 : b);
        T maxval = maxvals((maxvals.dimension(0) == 1) ? 0 : b);

        // The last batch can be short, if we adjusted num_batches and
        // samples_per_batch.
        const int64_t limit_sample =
            std::min((b + 1) * samples_per_batch, num_elements);
        int64_t sample = b * samples_per_batch;

        // On GPU, this check will just fill samples with NAN if it fails.
        OP_REQUIRES(ctx,
                    stddev > T(0) && minval < maxval &&
                        (Eigen::numext::isfinite(minval) ||
                         Eigen::numext::isfinite(maxval)),
                    errors::InvalidArgument("Invalid parameters"));

        int num_iterations = 0;

        // If possible, make the one-sided bound the lower bound, or make both
        // bounds positive. Otherwise, the bounds are on either side of the
        // mean.
        if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
          // Reverse all calculations. normMin and normMax will be flipped.
          std::swap(minval, maxval);
          stddev = -stddev;
        }

        // Calculate normalized samples, then convert them.
        const T normMin = (minval - mean) / stddev;
        const T normMax = (maxval - mean) / stddev;

        // Determine the method to use.
        const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
        const T cutoff =
            T(2) *
            Eigen::numext::exp(T(0.5) +
                               (normMin * (normMin - sqrtFactor)) / T(4)) /
            (normMin + sqrtFactor);
        const T diff = normMax - normMin;
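        // Interpretation of the cutoff (sketch): for narrow truncation
        // intervals the simple uniform rejection sampler below is cheap, but
        // its acceptance rate deteriorates as diff grows; past the cutoff the
        // exponential-proposal sampler is expected to need fewer iterations.
        // The expression appears to match the classical threshold derived by
        // Robert (1995) for truncated normal rejection sampling.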
        if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMax >= T(0.))) ||
            ((normMax > kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMin <= T(0.)))) {
          // If the bounds span the mean and extend at least
          // kStdDevsInsideBoundsToUseRandnSampler standard deviations beyond
          // it on one side, then we rejection sample by drawing from the
          // normal distribution and rejecting samples outside the bounds.
          // Under this condition the acceptance rate per iteration should
          // always be at least ~40%. This sampler is more efficient, and more
          // numerically stable, when one or both bounds are far from the mean.

          while (sample < limit_sample) {
            const auto randn_sample = normal_dist(&gen_copy);
            const int size = randn_sample.size();

            for (int i = 0; i < size; i++) {
              if ((randn_sample[i] >= normMin) &&
                  (randn_sample[i] <= normMax)) {
                output(sample) = randn_sample[i] * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
                if (num_iterations > kMaxIterations) {
                  // This should never occur because this sampler is only
                  // selected (by the criteria above) when the bounds contain
                  // the mean and extend well beyond it on at least one side,
                  // so the acceptance probability per iteration is >~ 40%.
                  LOG(ERROR) << "TruncatedNormal randn rejection sampler "
                             << "exceeded maximum iterations for "
                             << "normMin=" << normMin << " normMax=" << normMax
                             << " kMaxIterations=" << kMaxIterations;
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal randn rejection sampler failed to accept"
                      " a sample."));
                  return;
                }
              }
            }
          }
        } else if (diff < cutoff) {
          // Sample from a uniform distribution on [normMin, normMax].

          const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
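          // Sketch of the acceptance test below: a proposal z drawn uniformly
          // from [normMin, normMax] is kept with probability
          // exp((plusFactor - z * z) / 2), i.e. the standard normal density
          // at z divided by its maximum over the interval (plusFactor is 0
          // when normMin < 0, and normMin * normMin otherwise).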

          while (sample < limit_sample) {
            const auto rand = dist(&gen_copy);
            const int size = rand.size();
            // NOTE(ringwalt): These loops seem to only generate packed AVX
            // instructions for float32.
            for (int i = 0; i < size; i++) {
              z[i] = rand[i] * diff + normMin;
            }
            for (int i = 0; i < size; i++) {
              g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
            }

            const auto u = dist(&gen_copy);
            for (int i = 0; i < size; i++) {
              auto accept = u[i] <= Eigen::numext::exp(g[i]);
              if (accept || num_iterations + 1 >= kMaxIterations) {
                // Accept the sample z.
                // If we run out of iterations, just use the current uniform
                // sample, but emit a warning.
                // TODO(jjhunt) For small entropies (relative to the bounds),
                // this sampler is poor and may take many iterations since
                // the proposal distribution is the uniform distribution
                // U(lower_bound, upper_bound).
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal uniform rejection sampler "
                             << "exceeded max iterations. Sample may contain "
                             << "outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal uniform rejection sampler failed to "
                      "accept a sample."));
                  return;
                }
                output(sample) = z[i] * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
              }
            }
          }
        } else {
          // Sample from an exponential distribution with alpha maximizing
          // acceptance probability, offset by normMin from the origin.
          // Accept only if less than normMax.
          const T alpha =
              (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
              T(2);
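          // Sketch of the scheme below: each proposal is
          //   z = normMin + Exponential(alpha),
          // obtained from a uniform value via the inverse transform
          // (-log(u) / alpha), and is accepted with probability
          // exp(-(z - alpha)^2 / 2) provided z < normMax. The alpha above is
          // the rate that maximizes the average acceptance probability for a
          // lower bound of normMin.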
          while (sample < limit_sample) {
            auto rand = dist(&gen_copy);
            const int size = rand.size();
            int i = 0;
            while (i < size) {
              const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
              i++;
              const T x = normMin < alpha ? alpha - z : normMin - alpha;
              const T g = Eigen::numext::exp(-x * x / T(2.0));
              const T u = rand[i];
              i++;
              auto accept = (u <= g && z < normMax);
              if (accept || num_iterations + 1 >= kMaxIterations) {
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal exponential distribution "
                             << "rejection sampler exceeded max iterations. "
                             << "Sample may contain outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal exponential distribution rejection"
                      " sampler failed to accept a sample."));
                  return;
                }
                output(sample) = z * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
              }
            }
          }
        }
      }
    };
    // The cost of the initial calculations for the batch.
    const int64_t batchInitCost =
        // normMin, normMax
        (Eigen::TensorOpCost::AddCost<T>() +
         Eigen::TensorOpCost::MulCost<T>()) *
            2
        // sqrtFactor
        + Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_sqrt_op<T>>::Cost
        // cutoff
        + Eigen::TensorOpCost::MulCost<T>() * 4 +
        Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
        // diff
        + Eigen::TensorOpCost::AddCost<T>();
    const int64_t uniformSampleCost =
        random::PhiloxRandom::kElementCost +
        random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
    // The cost of a single uniform sampling round.
    const int64_t uniformRejectionSamplingCost =
        uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
        Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() * 2 +
        Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_exp_op<T>>::Cost +
        Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
    // Estimate the cost for an entire batch.
    // Assume we use uniform sampling, and accept the 2nd sample on average.
    const int64_t batchCost =
        batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch;
    Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
          batchCost, do_work);
  }
};

template <typename T>
struct TruncatedNormalFunctorV2<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  int64_t num_batches, int64_t samples_per_batch,
                  int64_t num_elements, const BCastList<4>& bcast,
                  typename TTypes<T>::ConstFlat means,
                  typename TTypes<T>::ConstFlat stddevs,
                  typename TTypes<T>::ConstFlat minvals,
                  typename TTypes<T>::ConstFlat maxvals,
                  const random::PhiloxRandom& gen,
                  typename TTypes<T>::Flat output) {
    // The randn rejection sampler is used when the mean and at least this
    // many standard deviations are inside the bounds.
    // The uniform proposal samplers become less efficient as the bounds move
    // further from the mean; the reverse is true for the randn sampler.
    // This number was chosen by empirical benchmarking. If it is modified,
    // the benchmarks in parameterized_truncated_normal_op_test should also
    // be updated.
    const T kStdDevsInsideBoundsToUseRandnSampler = T(1.3);
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    auto do_work = [num_batches, samples_per_batch, &ctx, &bcast, &means,
                    &stddevs, &minvals, &maxvals, &gen, &output,
                    kStdDevsInsideBoundsToUseRandnSampler](
                       int64_t start_output, int64_t limit_output) {
      // Capturing "gen" by-value would only make a copy for the _shared_
      // lambda. Since we want to let each worker have its own copy, we pass
      // "gen" by reference and explicitly do a copy assignment here.
      random::PhiloxRandom gen_copy = gen;
      using Uniform = random::UniformDistribution<random::PhiloxRandom, T>;
      Uniform dist;
      using Normal = random::NormalDistribution<random::PhiloxRandom, T>;
      Normal normal_dist;
      // Skip takes units of 128 bits. Adding Uniform::kResultElementCount - 1
      // before dividing rounds up, so that rounding doesn't lead to us using
      // the same state in different workloads.
      // The sample from each iteration uses 2 random numbers.
      gen_copy.Skip((start_output * 2 * kMaxIterations +
                     Uniform::kResultElementCount - 1) /
                    Uniform::kResultElementCount);
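      // Illustration: each output element may consume up to
      // 2 * kMaxIterations = 2000 uniform values in the worst case, so a
      // worker starting at start_output skips that many values per preceding
      // element, rounded up to whole draws of Uniform::kResultElementCount.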

      // Vectorized intermediate calculations for uniform rejection sampling.
      // We always generate at most 4 samples.
      Eigen::array<T, Uniform::kResultElementCount> z;
      Eigen::array<T, Uniform::kResultElementCount> g;

      const bool should_bcast = bcast.IsBroadcastingRequired();
      const auto& means_batch_indices = bcast.batch_indices(0);
      const auto& stddevs_batch_indices = bcast.batch_indices(1);
      const auto& minvals_batch_indices = bcast.batch_indices(2);
      const auto& maxvals_batch_indices = bcast.batch_indices(3);
      auto output_flat = output.data();

      // We partition work across batches and then across samples
      // per batch member, to avoid extra work.
      for (int64_t output_idx = start_output; output_idx < limit_output;
           // output_idx is incremented with the inner loops below.
      ) {
        int64_t batch_idx = output_idx / samples_per_batch;
        // The output layout is [samples_per_batch, num_batches]. Thus
        // the output address is sample_idx * num_batches + batch_idx.
        // Below, code will index at output_batch_offset[sample_idx *
        // num_batches] matching this.
        T* const output_batch_offset = output_flat + batch_idx;
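        // Layout example (illustrative values): with num_batches = 3 and
        // samples_per_batch = 4, the value for batch 1, sample 2 is written
        // to flat index 2 * 3 + 1 = 7.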
        // Look up this batch's parameters through BCast, which carries the
        // right (possibly broadcast) indices to use.
        T mean, stddev, minval, maxval;
        if (should_bcast) {
          mean = means(means_batch_indices[batch_idx]);
          stddev = stddevs(stddevs_batch_indices[batch_idx]);
          minval = minvals(minvals_batch_indices[batch_idx]);
          maxval = maxvals(maxvals_batch_indices[batch_idx]);
        } else {
          mean = means(batch_idx);
          stddev = stddevs(batch_idx);
          minval = minvals(batch_idx);
          maxval = maxvals(batch_idx);
        }

        // On GPU, this check will just fill samples with NAN if it fails.
        OP_REQUIRES(ctx,
                    stddev > T(0) && minval < maxval &&
                        (Eigen::numext::isfinite(minval) ||
                         Eigen::numext::isfinite(maxval)),
                    errors::InvalidArgument("Invalid parameters"));

        int num_iterations = 0;

        // If possible, make the one-sided bound the lower bound, or make both
        // bounds positive. Otherwise, the bounds are on either side of the
        // mean.
        if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
          // Reverse all calculations. normMin and normMax will be flipped.
          std::swap(minval, maxval);
          stddev = -stddev;
        }

        // Calculate normalized samples, then convert them.
        const T normMin = (minval - mean) / stddev;
        const T normMax = (maxval - mean) / stddev;

        // Determine the method to use.
        const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
        const T cutoff =
            T(2) *
            Eigen::numext::exp(T(0.5) +
                               (normMin * (normMin - sqrtFactor)) / T(4)) /
            (normMin + sqrtFactor);
        const T diff = normMax - normMin;

        if (((normMin < -kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMax >= T(0.))) ||
            ((normMax > kStdDevsInsideBoundsToUseRandnSampler) &&
             (normMin <= T(0.)))) {
          // If the bounds span the mean and extend at least
          // kStdDevsInsideBoundsToUseRandnSampler standard deviations beyond
          // it on one side, then we rejection sample by drawing from the
          // normal distribution and rejecting samples outside the bounds.
          // Under this condition the acceptance rate per iteration should
          // always be at least ~40%. This sampler is more efficient, and more
          // numerically stable, when one or both bounds are far from the mean.
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;) {
            const auto randn_sample = normal_dist(&gen_copy);
            const int size = randn_sample.size();
            for (int i = 0; i < size; ++i) {
              if ((randn_sample[i] >= normMin) &&
                  (randn_sample[i] <= normMax)) {
                output_batch_offset[sample_idx * num_batches] =
                    randn_sample[i] * stddev + mean;
                ++sample_idx;
                ++output_idx;
                if (sample_idx >= samples_per_batch ||
                    output_idx >= limit_output) {
                  break;
                }
                num_iterations = 0;
              } else {
                ++num_iterations;
                if (num_iterations > kMaxIterations) {
                  // This should never occur because this sampler is only
                  // selected (by the criteria above) when the bounds contain
                  // the mean and extend well beyond it on at least one side,
                  // so the acceptance probability per iteration is >~ 40%.
                  LOG(ERROR) << "TruncatedNormal randn rejection sampler "
                             << "exceeded maximum iterations for "
                             << "normMin=" << normMin << " normMax=" << normMax
                             << " kMaxIterations=" << kMaxIterations;
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal randn rejection sampler failed to accept"
                      " a sample."));
                  return;
                }
              }
            }
          }
        } else if (diff < cutoff) {
          // Sample from a uniform distribution on [normMin, normMax].

          const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;

          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;) {
            const auto rand = dist(&gen_copy);
            const int size = rand.size();
            // NOTE(ringwalt): These loops seem to only generate packed AVX
            // instructions for float32.
            for (int i = 0; i < size; i++) {
              z[i] = rand[i] * diff + normMin;
              g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
            }

            const auto u = dist(&gen_copy);
            for (int i = 0; i < size; i++) {
              auto accept = u[i] <= Eigen::numext::exp(g[i]);
              if (accept || num_iterations + 1 >= kMaxIterations) {
                // Accept the sample z.
                // If we run out of iterations, just use the current uniform
                // sample, but emit a warning.
                // TODO(jjhunt) For small entropies (relative to the bounds),
                // this sampler is poor and may take many iterations since
                // the proposal distribution is the uniform distribution
                // U(lower_bound, upper_bound).
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal uniform rejection sampler "
                             << "exceeded max iterations. Sample may contain "
                             << "outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal uniform rejection sampler failed to "
                      "accept a sample."));
                  return;
                }
                output_batch_offset[sample_idx * num_batches] =
                    z[i] * stddev + mean;
                ++sample_idx;
                ++output_idx;
                if (sample_idx >= samples_per_batch ||
                    output_idx >= limit_output) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
              }
            }
          }
        } else {
          // Sample from an exponential distribution with alpha maximizing
          // acceptance probability, offset by normMin from the origin.
          // Accept only if less than normMax.
          const T alpha =
              (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
              T(2);
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;) {
            auto rand = dist(&gen_copy);
            const int size = rand.size();
            int i = 0;
            while (i < size) {
              const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
              i++;
              const T x = normMin < alpha ? alpha - z : normMin - alpha;
              const T g = Eigen::numext::exp(-x * x / T(2.0));
              const T u = rand[i];
              i++;
              auto accept = (u <= g && z < normMax);
              if (accept || num_iterations + 1 >= kMaxIterations) {
                if (!accept) {
                  LOG(ERROR) << "TruncatedNormal exponential distribution "
                             << "rejection sampler exceeded max iterations. "
                             << "Sample may contain outliers.";
                  ctx->SetStatus(errors::Internal(
                      "TruncatedNormal exponential distribution rejection"
                      " sampler failed to accept a sample."));
                  return;
                }
                output_batch_offset[sample_idx * num_batches] =
                    z * stddev + mean;
                ++sample_idx;
                ++output_idx;
                if (sample_idx >= samples_per_batch ||
                    output_idx >= limit_output) {
                  break;
                }
                num_iterations = 0;
              } else {
                num_iterations++;
              }
            }
          }
        }
      }
    };
    // The cost of the initial calculations for the batch.
    const int64_t batchInitCost =
        // normMin, normMax
        (Eigen::TensorOpCost::AddCost<T>() +
         Eigen::TensorOpCost::MulCost<T>()) *
            2
        // sqrtFactor
        + Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_sqrt_op<T>>::Cost
        // cutoff
        + Eigen::TensorOpCost::MulCost<T>() * 4 +
        Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
        // diff
        + Eigen::TensorOpCost::AddCost<T>();
    const int64_t uniformSampleCost =
        random::PhiloxRandom::kElementCost +
        random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
    // The cost of a single uniform sampling round.
    const int64_t uniformRejectionSamplingCost =
        uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
        Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() * 2 +
        Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_exp_op<T>>::Cost +
        Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
    // Estimate the cost per output element (Shard is called with
    // num_elements work units below).
    // Assume we use uniform sampling, and accept the 2nd sample on average.
    const int64_t batchCost = batchInitCost + uniformRejectionSamplingCost * 2;
    Shard(worker_threads.num_threads, worker_threads.workers, num_elements,
          batchCost, do_work);
  }
};

}  // namespace functor

namespace {

// Samples from a truncated normal distribution, using the given parameters.
template <typename Device, typename T>
class ParameterizedTruncatedNormalOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static constexpr int32_t kDesiredBatchSize = 100;

 public:
  explicit ParameterizedTruncatedNormalOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& means_tensor = ctx->input(1);
    const Tensor& stddevs_tensor = ctx->input(2);
    const Tensor& minvals_tensor = ctx->input(3);
    const Tensor& maxvals_tensor = ctx->input(4);

    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, shape_tensor.NumElements() > 0,
                errors::InvalidArgument("Shape tensor must not be empty, got ",
                                        shape_tensor.DebugString()));
    TensorShape tensor_shape;
    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &tensor_shape));

    int32_t num_batches = tensor_shape.dim_size(0);
    int32_t samples_per_batch = 1;
    const int32_t num_dims = tensor_shape.dims();
    for (int32_t i = 1; i < num_dims; i++) {
      samples_per_batch *= tensor_shape.dim_size(i);
    }
    const int32_t num_elements = num_batches * samples_per_batch;

    // Allocate the output before fudging num_batches and samples_per_batch.
    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor));

    // Parameters must be 0-d or 1-d.
    OP_REQUIRES(ctx, means_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input means should be a scalar or vector, got shape: ",
                    means_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, stddevs_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input stddevs should be a scalar or vector, got shape: ",
                    stddevs_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, minvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input minvals should be a scalar or vector, got shape: ",
                    minvals_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, maxvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input maxvals should be a scalar or vector, got shape: ",
                    maxvals_tensor.shape().DebugString()));

    if ((means_tensor.dims() == 0 || means_tensor.dim_size(0) == 1) &&
        (stddevs_tensor.dims() == 0 || stddevs_tensor.dim_size(0) == 1) &&
        minvals_tensor.dims() == 0 && maxvals_tensor.dims() == 0) {
      // All batches have the same parameters, so we can update the batch size
      // to a reasonable value to improve parallelism (ensure enough batches,
      // and no very small batches which have high overhead).
      int32_t size = num_batches * samples_per_batch;
      int32_t adjusted_samples = kDesiredBatchSize;
      // Ensure adjusted_batches * adjusted_samples >= size.
      int32_t adjusted_batches = Eigen::divup(size, adjusted_samples);
      num_batches = adjusted_batches;
      samples_per_batch = adjusted_samples;
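      // For example (illustrative values): a requested shape of [3, 700] with
      // scalar parameters gives size = 2100, which is re-partitioned into
      // Eigen::divup(2100, 100) = 21 batches of 100 samples each.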
    } else {
      // Parameters must be broadcastable to the shape [num_batches].
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(means_tensor.shape()) ||
              means_tensor.dim_size(0) == 1 ||
              means_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input means should have length 1 or shape[0], got shape: ",
              means_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(stddevs_tensor.shape()) ||
              stddevs_tensor.dim_size(0) == 1 ||
              stddevs_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input stddevs should have length 1 or shape[0], got shape: ",
              stddevs_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(minvals_tensor.shape()) ||
              minvals_tensor.dim_size(0) == 1 ||
              minvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input minvals should have length 1 or shape[0], got shape: ",
              minvals_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(maxvals_tensor.shape()) ||
              maxvals_tensor.dim_size(0) == 1 ||
              maxvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input maxvals should have length 1 or shape[0], got shape: ",
              maxvals_tensor.shape().DebugString()));
    }

    auto truncFunctor = functor::TruncatedNormalFunctor<Device, T>();
    // The functor's workers skip the generator state using the same "+3"
    // fudge factor on samples_per_batch, so reserve a matching number of
    // 128-bit samples here.
    random::PhiloxRandom rng =
        generator_.ReserveSamples128(num_batches * 2 * functor::kMaxIterations *
                                     (samples_per_batch + 3) / 4);
    truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
                 samples_per_batch, num_elements, means_tensor.flat<T>(),
                 stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
                 maxvals_tensor.flat<T>(), rng, samples_tensor->flat<T>());
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp);
};

// Samples from a truncated normal distribution, using the given parameters.
template <typename Device, typename T>
class StatelessParameterizedTruncatedNormal : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static const int32_t kDesiredBatchSize = 100;

 public:
  explicit StatelessParameterizedTruncatedNormal(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& seed_tensor = ctx->input(1);
    const Tensor& means_tensor = ctx->input(2);
    const Tensor& stddevs_tensor = ctx->input(3);
    const Tensor& minvals_tensor = ctx->input(4);
    const Tensor& maxvals_tensor = ctx->input(5);

    OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_tensor.shape().DebugString()));

    tensorflow::BCastList<4> bcast(
        {means_tensor.shape().dim_sizes(), stddevs_tensor.shape().dim_sizes(),
         minvals_tensor.shape().dim_sizes(),
         maxvals_tensor.shape().dim_sizes()},
        /*fewer_dims_optimization=*/false,
        /*return_flattened_batch_indices=*/true);

    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "means, stddevs, minvals, maxvals must have compatible "
                    "batch dimensions: ",
                    means_tensor.shape().DebugString(), " vs. ",
                    stddevs_tensor.shape().DebugString(), " vs. ",
                    minvals_tensor.shape().DebugString(), " vs. ",
                    maxvals_tensor.shape().DebugString()));

    // Check that the requested output shape ends with the broadcast shape of
    // the parameters.
    TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    TensorShape output_shape;
    if (shape_tensor.dtype() == DataType::DT_INT32) {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
                                                      &output_shape));
    } else {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                              shape_tensor.vec<int64_t>(), &output_shape));
    }
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));

    // The leading output dimensions (those not covered by the broadcast
    // shape) are sample dimensions; the trailing ones are batch dimensions.
    int64_t samples_per_batch = 1;
    const int64_t num_sample_dims =
        (shape_tensor.dim_size(0) - bcast.output_shape().size());
    for (int64_t i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= output_shape.dim_size(i);
    }
    int64_t num_batches = 1;
    for (int64_t i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= output_shape.dim_size(i);
    }
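    // Example (illustrative): for an output shape of [10, 2, 3] with a
    // parameter broadcast shape of [2, 3], num_sample_dims = 1,
    // samples_per_batch = 10, and num_batches = 2 * 3 = 6.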
    const int64_t num_elements = num_batches * samples_per_batch;

    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor));

    auto truncFunctor = functor::TruncatedNormalFunctorV2<Device, T>();
    // Build the Philox generator from the seed; each worker in the functor
    // skips ahead independently to get its own stream.
    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter));

    auto philox = random::PhiloxRandom(counter, key);

    truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
                 samples_per_batch, num_elements, bcast, means_tensor.flat<T>(),
                 stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
                 maxvals_tensor.flat<T>(), philox, samples_tensor->flat<T>());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StatelessParameterizedTruncatedNormal);
};

}  // namespace

#define REGISTER(TYPE)                                                      \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal")              \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<TYPE>("dtype"),               \
                          ParameterizedTruncatedNormalOp<CPUDevice, TYPE>)  \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("StatelessParameterizedTruncatedNormal")                         \
          .HostMemory("shape")                                              \
          .HostMemory("seed")                                               \
          .HostMemory("means")                                              \
          .HostMemory("stddevs")                                            \
          .HostMemory("minvals")                                            \
          .HostMemory("maxvals")                                            \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<TYPE>("dtype"),                                   \
      StatelessParameterizedTruncatedNormal<CPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER(TYPE)                                         \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("shape")             \
                              .TypeConstraint<TYPE>("dtype"),  \
                          ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow