1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/range_sampler.h"
17
18#include <cmath>
19#include <unordered_set>
20#include <vector>
21
22#include "tensorflow/core/lib/core/errors.h"
23#include "tensorflow/core/lib/gtl/map_util.h"
24#include "tensorflow/core/lib/io/inputbuffer.h"
25#include "tensorflow/core/lib/strings/numbers.h"
26#include "tensorflow/core/lib/strings/str_util.h"
27#include "tensorflow/core/platform/logging.h"
28#include "tensorflow/core/platform/mutex.h"
29#include "tensorflow/core/platform/types.h"
30
31namespace tensorflow {
32
33using gtl::ArraySlice;
34using gtl::MutableArraySlice;
35
36RangeSampler::~RangeSampler() {}
37
38void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
39 gtl::MutableArraySlice<int64_t> batch) const {
40 SampleBatchGetExpectedCount(
41 rnd, unique, batch, gtl::MutableArraySlice<float>(),
42 gtl::ArraySlice<int64_t>(), gtl::MutableArraySlice<float>());
43}
44
45void RangeSampler::SampleBatchGetExpectedCount(
46 random::SimplePhilox* rnd, bool unique,
47 gtl::MutableArraySlice<int64_t> batch,
48 gtl::MutableArraySlice<float> batch_expected_count,
49 gtl::ArraySlice<int64_t> extras,
50 gtl::MutableArraySlice<float> extras_expected_count) const {
51 SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
52 extras, extras_expected_count,
53 gtl::ArraySlice<int64_t>());
54}
55
56namespace {
57
// Approximates the expected count of a value in the output of SampleBatch.
//
//   p          - probability of drawing the value in a single try.
//   batch_size - number of values in the batch.
//   num_tries  - observed number of draws it took to fill the batch.
//
// When num_tries == batch_size (always the case for unique=false, where every
// draw is kept) the result is simply p * batch_size.
//
// Otherwise we assume (falsely) that filling a batch of batch_size distinct
// values _always_ takes exactly num_tries draws. Under that assumption the
// probability that the value appears in the batch is 1 - (1-p)^num_tries.
static float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  const bool had_rejections = (num_tries != batch_size);
  if (had_rejections) {
    // Numerically stable evaluation of 1 - (1-p)^num_tries:
    // (1-p)^num_tries == exp(num_tries * log1p(-p)).
    const float log_miss_once = std::log1p(-p);
    return -std::expm1(num_tries * log_miss_once);
  }
  return p * batch_size;
}
76
77} // namespace
78
79void RangeSampler::SampleBatchGetExpectedCountAvoid(
80 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch,
81 MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> extras,
82 MutableArraySlice<float> extras_expected_count,
83 ArraySlice<int64_t> avoided_values) const {
84 const int batch_size = batch.size();
85 int num_tries;
86
87 if (unique) {
88 CHECK_LE(static_cast<int64_t>(batch_size + avoided_values.size()), range_);
89 std::unordered_set<int64_t> used(batch_size);
90 used.insert(avoided_values.begin(), avoided_values.end());
91 int num_picked = 0;
92 num_tries = 0;
93 while (num_picked < batch_size) {
94 num_tries++;
95 CHECK_LT(num_tries, kint32max);
96 int64_t value = Sample(rnd);
97 if (gtl::InsertIfNotPresent(&used, value)) {
98 batch[num_picked++] = value;
99 }
100 }
101 } else {
102 CHECK_EQ(avoided_values.size(), size_t{0})
103 << "avoided_values only supported with unique=true";
104 for (int i = 0; i < batch_size; i++) {
105 batch[i] = Sample(rnd);
106 }
107 num_tries = batch_size;
108 }
109 // Compute the expected counts of the batch and the extra values
110 if (!batch_expected_count.empty()) {
111 CHECK_EQ(batch_size, batch_expected_count.size());
112 for (int i = 0; i < batch_size; i++) {
113 batch_expected_count[i] =
114 ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
115 }
116 }
117 CHECK_EQ(extras.size(), extras_expected_count.size());
118 for (size_t i = 0; i < extras.size(); i++) {
119 extras_expected_count[i] =
120 ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
121 }
122}
123
124AllSampler::AllSampler(int64_t range) : RangeSampler(range) {}
125
126void AllSampler::SampleBatchGetExpectedCountAvoid(
127 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch,
128 MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> extras,
129 MutableArraySlice<float> extras_expected_count,
130 ArraySlice<int64_t> avoided_values) const {
131 const int batch_size = batch.size();
132 CHECK_EQ(range_, batch_size);
133 for (int i = 0; i < batch_size; i++) {
134 batch[i] = i;
135 }
136 if (!batch_expected_count.empty()) {
137 CHECK_EQ(batch_size, batch_expected_count.size());
138 for (int i = 0; i < batch_size; i++) {
139 batch_expected_count[i] = 1;
140 }
141 }
142 CHECK_EQ(size_t{0}, avoided_values.size());
143 CHECK_EQ(extras.size(), extras_expected_count.size());
144 for (size_t i = 0; i < extras.size(); i++) {
145 extras_expected_count[i] = 1;
146 }
147}
148
149UniformSampler::UniformSampler(int64_t range)
150 : RangeSampler(range), inv_range_(1.0 / range) {}
151
152int64_t UniformSampler::Sample(random::SimplePhilox* rnd) const {
153 return rnd->Uniform64(range_);
154}
155
156float UniformSampler::Probability(int64_t value) const { return inv_range_; }
157
158LogUniformSampler::LogUniformSampler(int64_t range)
159 : RangeSampler(range), log_range_(log1p(range)) {}
160
161int64_t LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
162 const int64_t value =
163 static_cast<int64_t>(exp(rnd->RandDouble() * log_range_)) - 1;
164 DCHECK_GE(value, 0);
165 // Mathematically, value should be <= range_, but might not be due to some
166 // floating point roundoff, so we mod by range_. In practice this case
167 // happens never regardless of the value of range_, including and up to
168 // DBL_MAX. But we include it as a guarantee of the function's output.
169 return value % range_;
170}
171
172float LogUniformSampler::Probability(int64_t value) const {
173 // value is returned iff the call to UniformDouble(log_range_) in the
174 // Sample() function returns a value between log(value + 1)
175 // and log(value + 2). The probability of this is:
176 // (log(value + 2) - log(value + 1)) / log_range
177 // To avoid two calls to log(), we compute this as follows:
178 return (log((value + 2.0) / (value + 1.0))) / log_range_;
179}
180
181ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64_t range)
182 : RangeSampler(range), picker_(range) {
183 CHECK_LT(range, kint32max);
184}
185
186int64_t ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
187 return picker_.Pick(rnd);
188}
189
190float ThreadUnsafeUnigramSampler::Probability(int64_t value) const {
191 return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
192}
193
194void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64_t> values) {
195 int num_updates = std::min(static_cast<int>(values.size()),
196 kint32max - picker_.total_weight());
197 for (int i = 0; i < num_updates; i++) {
198 const int64_t value = values[i];
199 picker_.set_weight(value, picker_.get_weight(value) + 1);
200 }
201}
202
203// Thread-safe unigram sampler
204UnigramSampler::UnigramSampler(int64_t range)
205 : RangeSampler(range), unsafe_sampler_(range) {
206 CHECK_LT(range, kint32max);
207}
208
209int64_t UnigramSampler::Sample(random::SimplePhilox* rnd) const {
210 tf_shared_lock lock(mu_);
211 return unsafe_sampler_.Sample(rnd);
212}
213
214float UnigramSampler::Probability(int64_t value) const {
215 tf_shared_lock lock(mu_);
216 return unsafe_sampler_.Probability(value);
217}
218
219// Overriding at a high level results in far fewer lock acquisitions.
220void UnigramSampler::SampleBatchGetExpectedCountAvoid(
221 random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch,
222 MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> extras,
223 MutableArraySlice<float> extras_expected_count,
224 ArraySlice<int64_t> avoided_values) const {
225 tf_shared_lock lock(mu_);
226 unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
227 rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
228 avoided_values);
229}
230
231void UnigramSampler::Update(ArraySlice<int64_t> values) {
232 mutex_lock lock(mu_);
233 unsafe_sampler_.Update(values);
234}
235
236FixedUnigramSampler::FixedUnigramSampler(Env* env, int64_t range,
237 const string& vocab_file,
238 float distortion,
239 int32_t num_reserved_ids,
240 int32_t num_shards, int32_t shard)
241 : RangeSampler(range),
242 total_weight_(0.0),
243 num_shards_(num_shards),
244 shard_(shard) {
245 FillReservedIds(num_reserved_ids);
246 // TODO(vanhoucke): make this non-crashing.
247 TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
248 CHECK_EQ(range, weights_.size());
249 dist_sampler_.reset(new random::DistributionSampler(weights_));
250}
251
252FixedUnigramSampler::FixedUnigramSampler(int64_t range,
253 const std::vector<float>& unigrams,
254 float distortion,
255 int32_t num_reserved_ids,
256 int32_t num_shards, int32_t shard)
257 : RangeSampler(range),
258 total_weight_(0.0),
259 num_shards_(num_shards),
260 shard_(shard) {
261 FillReservedIds(num_reserved_ids);
262 LoadFromUnigrams(unigrams, distortion);
263 // TODO(vanhoucke): make this non-crashing.
264 CHECK_EQ(range, weights_.size());
265 dist_sampler_.reset(new random::DistributionSampler(weights_));
266}
267
268float FixedUnigramSampler::Probability(int64_t value) const {
269 if (value < 0 || static_cast<size_t>(value) >= weights_.size()) {
270 return 0.0;
271 }
272 return weights_.at(value) / total_weight_;
273}
274
275int64_t FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
276 return dist_sampler_->Sample(rnd);
277}
278
279void FixedUnigramSampler::FillReservedIds(int32_t num_reserved_ids) {
280 for (int32_t word_id = 0; word_id < num_reserved_ids; ++word_id) {
281 if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
282 }
283}
284
285Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
286 float distortion) {
287 std::unique_ptr<RandomAccessFile> file;
288 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
289
290 io::InputBuffer in(file.get(), 262144 /*bytes*/);
291 string line;
292 int32_t word_id = weights_.size();
293 while (in.ReadLine(&line).ok()) {
294 // The vocabulary file should be in csv like format, with the last
295 // field the weight associated with the word.
296 std::vector<string> cols = str_util::Split(line, ',');
297 if (cols.empty()) continue;
298 // Skip entries that do not belong to this shard.
299 if (word_id % num_shards_ == shard_) {
300 float w = 0.0;
301 if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) {
302 return errors::InvalidArgument("Wrong vocabulary format at line: ",
303 line);
304 }
305 w = std::pow(w, distortion);
306 total_weight_ += w;
307 weights_.push_back(w);
308 }
309 ++word_id;
310 }
311 return OkStatus();
312}
313
314void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
315 float distortion) {
316 int32_t word_id = weights_.size();
317 for (float w : unigrams) {
318 // Skip entries that do not belong to this shard.
319 if (word_id % num_shards_ == shard_) {
320 w = std::pow(w, distortion);
321 total_weight_ += w;
322 weights_.push_back(w);
323 }
324 ++word_id;
325 }
326}
327
328} // namespace tensorflow
329