1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Based on "Notes on generating Sobol sequences. August 2008" by Joe and Kuo.
17// [1] https://web.maths.unsw.edu.au/~fkuo/sobol/joe-kuo-notes.pdf
18#include <algorithm>
19#include <cmath>
20#include <cstdint>
21#include <limits>
22
23#include "third_party/eigen3/Eigen/Core"
24#include "sobol_data.h" // from @sobol_data
25#include "tensorflow/core/framework/device_base.h"
26#include "tensorflow/core/framework/op_kernel.h"
27#include "tensorflow/core/framework/tensor_shape.h"
28#include "tensorflow/core/lib/core/threadpool.h"
29#include "tensorflow/core/platform/platform_strings.h"
30
31namespace tensorflow {
32
33// Embed the platform strings in this binary.
34TF_PLATFORM_STRINGS()
35
36typedef Eigen::ThreadPoolDevice CPUDevice;
37
38namespace {
39
// Lower bound on the block size used when the requested points are split
// across worker threads: each block covers at least this many points.
constexpr int kMinBlockSize{512};
42
// Returns the number of digits in the binary representation of n, i.e. the
// 1-based position of its highest set bit.
// Example: n=13. Binary representation is 1101. NumBinaryDigits(13) -> 4.
//
// Non-positive inputs yield 0. The count is done with integer shifts rather
// than a floating-point logarithm so the result is exact and well defined for
// every int (log2 of zero or of a negative value cannot be converted to int).
int NumBinaryDigits(int n) {
  int num_digits = 0;
  // Drop one binary digit per iteration until no digits remain.
  for (; n > 0; n >>= 1) {
    ++num_digits;
  }
  return num_digits;
}
48
// Returns the zero-based position of the rightmost zero digit in the binary
// representation of n, which equals the number of trailing one bits.
// Example: n=13. Binary representation is 1101. RightmostZeroBit(13) -> 1.
//
// The scan runs on the unsigned bit pattern of n so that it always terminates:
// shifting an unsigned value feeds in zeros from the left, whereas a signed
// right shift of -1 may keep producing ones forever. For n == -1 (every bit
// set) the result is the width of `int` in bits.
int RightmostZeroBit(int n) {
  unsigned int bits = static_cast<unsigned int>(n);
  int position = 0;
  while ((bits & 1u) != 0u) {
    bits >>= 1;
    ++position;
  }
  return position;
}
59
60// Returns an integer representation of point `i` in the Sobol sequence of
61// dimension `dim` using the given direction numbers.
62Eigen::VectorXi GetFirstPoint(int i, int dim,
63 const Eigen::MatrixXi& direction_numbers) {
64 // Index variables used in this function, consistent with notation in [1].
65 // i - point in the Sobol sequence
66 // j - dimension
67 // k - binary digit
68 Eigen::VectorXi integer_sequence = Eigen::VectorXi::Zero(dim);
69 // go/wiki/Sobol_sequence#A_fast_algorithm_for_the_construction_of_Sobol_sequences
70 int gray_code = i ^ (i >> 1);
71 int num_digits = NumBinaryDigits(i);
72 for (int j = 0; j < dim; ++j) {
73 for (int k = 0; k < num_digits; ++k) {
74 if ((gray_code >> k) & 1) integer_sequence(j) ^= direction_numbers(j, k);
75 }
76 }
77 return integer_sequence;
78}
79
80// Calculates `num_results` Sobol points of dimension `dim` starting at the
81// point `start_point + skip` and writes them into `output` starting at point
82// `start_point`.
83template <typename T>
84void CalculateSobolSample(int32_t dim, int32_t num_results, int32_t skip,
85 int32_t start_point,
86 typename TTypes<T>::Flat output) {
87 // Index variables used in this function, consistent with notation in [1].
88 // i - point in the Sobol sequence
89 // j - dimension
90 // k - binary digit
91 const int num_digits =
92 NumBinaryDigits(skip + start_point + num_results + 1);
93 Eigen::MatrixXi direction_numbers(dim, num_digits);
94
95 // Shift things so we can use integers everywhere. Before we write to output,
96 // divide by constant to convert back to floats.
97 const T normalizing_constant = 1./(1 << num_digits);
98 for (int j = 0; j < dim; ++j) {
99 for (int k = 0; k < num_digits; ++k) {
100 direction_numbers(j, k) = sobol_data::kDirectionNumbers[j][k]
101 << (num_digits - k - 1);
102 }
103 }
104
105 // If needed, skip ahead to the appropriate point in the sequence. Otherwise
106 // we start with the first column of direction numbers.
107 Eigen::VectorXi integer_sequence =
108 (skip + start_point > 0)
109 ? GetFirstPoint(skip + start_point + 1, dim, direction_numbers)
110 : direction_numbers.col(0);
111
112 for (int j = 0; j < dim; ++j) {
113 output(start_point * dim + j) = integer_sequence(j) * normalizing_constant;
114 }
115 // go/wiki/Sobol_sequence#A_fast_algorithm_for_the_construction_of_Sobol_sequences
116 for (int i = start_point + 1; i < num_results + start_point; ++i) {
117 // The Gray code for the current point differs from the preceding one by
118 // just a single bit -- the rightmost bit.
119 int k = RightmostZeroBit(i + skip);
120 // Update the current point from the preceding one with a single XOR
121 // operation per dimension.
122 for (int j = 0; j < dim; ++j) {
123 integer_sequence(j) ^= direction_numbers(j, k);
124 output(i * dim + j) = integer_sequence(j) * normalizing_constant;
125 }
126 }
127}
128
129} // namespace
130
131template <typename Device, typename T>
132class SobolSampleOp : public OpKernel {
133 public:
134 explicit SobolSampleOp(OpKernelConstruction* context)
135 : OpKernel(context) {}
136
137 void Compute(OpKernelContext* context) override {
138 OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(0).shape()),
139 errors::InvalidArgument("dim must be a scalar"));
140 int32_t dim = context->input(0).scalar<int32_t>()();
141 OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(1).shape()),
142 errors::InvalidArgument("num_results must be a scalar"));
143 int32_t num_results = context->input(1).scalar<int32_t>()();
144 OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(2).shape()),
145 errors::InvalidArgument("skip must be a scalar"));
146 int32_t skip = context->input(2).scalar<int32_t>()();
147
148 OP_REQUIRES(context, dim >= 1,
149 errors::InvalidArgument("dim must be at least one"));
150 OP_REQUIRES(context, dim <= sobol_data::kMaxSobolDim,
151 errors::InvalidArgument("dim must be at most ",
152 sobol_data::kMaxSobolDim));
153 OP_REQUIRES(context, num_results >= 1,
154 errors::InvalidArgument("num_results must be at least one"));
155 OP_REQUIRES(context, skip >= 0,
156 errors::InvalidArgument("skip must be non-negative"));
157 OP_REQUIRES(context,
158 num_results < std::numeric_limits<int32_t>::max() - skip,
159 errors::InvalidArgument("num_results+skip must be less than ",
160 std::numeric_limits<int32_t>::max()));
161
162 Tensor* output = nullptr;
163 OP_REQUIRES_OK(context,
164 context->allocate_output(
165 0, TensorShape({num_results, dim}), &output));
166 auto output_flat = output->flat<T>();
167 const DeviceBase::CpuWorkerThreads& worker_threads =
168 *(context->device()->tensorflow_cpu_worker_threads());
169 int num_threads = worker_threads.num_threads;
170 int block_size = std::max(
171 kMinBlockSize, static_cast<int>(std::ceil(
172 static_cast<float>(num_results) / num_threads)));
173 worker_threads.workers->TransformRangeConcurrently(
174 block_size, num_results /* total */,
175 [&dim, &skip, &output_flat](const int start, const int end) {
176 CalculateSobolSample<T>(dim, end - start /* num_results */, skip,
177 start, output_flat);
178 });
179 }
180};
181
182REGISTER_KERNEL_BUILDER(
183 Name("SobolSample").Device(DEVICE_CPU).TypeConstraint<double>("dtype"),
184 SobolSampleOp<CPUDevice, double>);
185REGISTER_KERNEL_BUILDER(
186 Name("SobolSample").Device(DEVICE_CPU).TypeConstraint<float>("dtype"),
187 SobolSampleOp<CPUDevice, float>);
188
189} // namespace tensorflow
190