1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/kernels/range_sampler.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
30 | |
31 | namespace tensorflow { |
32 | |
33 | using gtl::ArraySlice; |
34 | using gtl::MutableArraySlice; |
35 | |
36 | RangeSampler::~RangeSampler() {} |
37 | |
38 | void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique, |
39 | gtl::MutableArraySlice<int64_t> batch) const { |
40 | SampleBatchGetExpectedCount( |
41 | rnd, unique, batch, gtl::MutableArraySlice<float>(), |
42 | gtl::ArraySlice<int64_t>(), gtl::MutableArraySlice<float>()); |
43 | } |
44 | |
45 | void RangeSampler::SampleBatchGetExpectedCount( |
46 | random::SimplePhilox* rnd, bool unique, |
47 | gtl::MutableArraySlice<int64_t> batch, |
48 | gtl::MutableArraySlice<float> batch_expected_count, |
49 | gtl::ArraySlice<int64_t> , |
50 | gtl::MutableArraySlice<float> ) const { |
51 | SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count, |
52 | extras, extras_expected_count, |
53 | gtl::ArraySlice<int64_t>()); |
54 | } |
55 | |
namespace {

// Approximates how many times a value of probability `p` appears in the output
// of SampleBatch.
//
// `num_tries` is the observed number of Sample() calls it took to produce
// `batch_size` values. When no draw was rejected (always the case for
// unique=false), num_tries == batch_size and the expected count is simply
// p * batch_size.
//
// Otherwise we (falsely) assume every batch takes exactly `num_tries` draws,
// so the chance that the value appears at all is 1 - (1 - p)^num_tries.
float ExpectedCountHelper(float p, int batch_size, int num_tries) {
  if (num_tries != batch_size) {
    // 1 - (1-p)^num_tries via log1p/expm1, which stays accurate for small p.
    return -std::expm1(num_tries * std::log1p(-p));
  }
  return p * batch_size;
}

}  // namespace
78 | |
79 | void RangeSampler::SampleBatchGetExpectedCountAvoid( |
80 | random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch, |
81 | MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> , |
82 | MutableArraySlice<float> , |
83 | ArraySlice<int64_t> avoided_values) const { |
84 | const int batch_size = batch.size(); |
85 | int num_tries; |
86 | |
87 | if (unique) { |
88 | CHECK_LE(static_cast<int64_t>(batch_size + avoided_values.size()), range_); |
89 | std::unordered_set<int64_t> used(batch_size); |
90 | used.insert(avoided_values.begin(), avoided_values.end()); |
91 | int num_picked = 0; |
92 | num_tries = 0; |
93 | while (num_picked < batch_size) { |
94 | num_tries++; |
95 | CHECK_LT(num_tries, kint32max); |
96 | int64_t value = Sample(rnd); |
97 | if (gtl::InsertIfNotPresent(&used, value)) { |
98 | batch[num_picked++] = value; |
99 | } |
100 | } |
101 | } else { |
102 | CHECK_EQ(avoided_values.size(), size_t{0}) |
103 | << "avoided_values only supported with unique=true" ; |
104 | for (int i = 0; i < batch_size; i++) { |
105 | batch[i] = Sample(rnd); |
106 | } |
107 | num_tries = batch_size; |
108 | } |
109 | // Compute the expected counts of the batch and the extra values |
110 | if (!batch_expected_count.empty()) { |
111 | CHECK_EQ(batch_size, batch_expected_count.size()); |
112 | for (int i = 0; i < batch_size; i++) { |
113 | batch_expected_count[i] = |
114 | ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries); |
115 | } |
116 | } |
117 | CHECK_EQ(extras.size(), extras_expected_count.size()); |
118 | for (size_t i = 0; i < extras.size(); i++) { |
119 | extras_expected_count[i] = |
120 | ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries); |
121 | } |
122 | } |
123 | |
124 | AllSampler::AllSampler(int64_t range) : RangeSampler(range) {} |
125 | |
126 | void AllSampler::SampleBatchGetExpectedCountAvoid( |
127 | random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch, |
128 | MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> , |
129 | MutableArraySlice<float> , |
130 | ArraySlice<int64_t> avoided_values) const { |
131 | const int batch_size = batch.size(); |
132 | CHECK_EQ(range_, batch_size); |
133 | for (int i = 0; i < batch_size; i++) { |
134 | batch[i] = i; |
135 | } |
136 | if (!batch_expected_count.empty()) { |
137 | CHECK_EQ(batch_size, batch_expected_count.size()); |
138 | for (int i = 0; i < batch_size; i++) { |
139 | batch_expected_count[i] = 1; |
140 | } |
141 | } |
142 | CHECK_EQ(size_t{0}, avoided_values.size()); |
143 | CHECK_EQ(extras.size(), extras_expected_count.size()); |
144 | for (size_t i = 0; i < extras.size(); i++) { |
145 | extras_expected_count[i] = 1; |
146 | } |
147 | } |
148 | |
149 | UniformSampler::UniformSampler(int64_t range) |
150 | : RangeSampler(range), inv_range_(1.0 / range) {} |
151 | |
152 | int64_t UniformSampler::Sample(random::SimplePhilox* rnd) const { |
153 | return rnd->Uniform64(range_); |
154 | } |
155 | |
156 | float UniformSampler::Probability(int64_t value) const { return inv_range_; } |
157 | |
158 | LogUniformSampler::LogUniformSampler(int64_t range) |
159 | : RangeSampler(range), log_range_(log1p(range)) {} |
160 | |
161 | int64_t LogUniformSampler::Sample(random::SimplePhilox* rnd) const { |
162 | const int64_t value = |
163 | static_cast<int64_t>(exp(rnd->RandDouble() * log_range_)) - 1; |
164 | DCHECK_GE(value, 0); |
165 | // Mathematically, value should be <= range_, but might not be due to some |
166 | // floating point roundoff, so we mod by range_. In practice this case |
167 | // happens never regardless of the value of range_, including and up to |
168 | // DBL_MAX. But we include it as a guarantee of the function's output. |
169 | return value % range_; |
170 | } |
171 | |
172 | float LogUniformSampler::Probability(int64_t value) const { |
173 | // value is returned iff the call to UniformDouble(log_range_) in the |
174 | // Sample() function returns a value between log(value + 1) |
175 | // and log(value + 2). The probability of this is: |
176 | // (log(value + 2) - log(value + 1)) / log_range |
177 | // To avoid two calls to log(), we compute this as follows: |
178 | return (log((value + 2.0) / (value + 1.0))) / log_range_; |
179 | } |
180 | |
181 | ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64_t range) |
182 | : RangeSampler(range), picker_(range) { |
183 | CHECK_LT(range, kint32max); |
184 | } |
185 | |
186 | int64_t ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const { |
187 | return picker_.Pick(rnd); |
188 | } |
189 | |
190 | float ThreadUnsafeUnigramSampler::Probability(int64_t value) const { |
191 | return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight(); |
192 | } |
193 | |
194 | void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64_t> values) { |
195 | int num_updates = std::min(static_cast<int>(values.size()), |
196 | kint32max - picker_.total_weight()); |
197 | for (int i = 0; i < num_updates; i++) { |
198 | const int64_t value = values[i]; |
199 | picker_.set_weight(value, picker_.get_weight(value) + 1); |
200 | } |
201 | } |
202 | |
203 | // Thread-safe unigram sampler |
204 | UnigramSampler::UnigramSampler(int64_t range) |
205 | : RangeSampler(range), unsafe_sampler_(range) { |
206 | CHECK_LT(range, kint32max); |
207 | } |
208 | |
209 | int64_t UnigramSampler::Sample(random::SimplePhilox* rnd) const { |
210 | tf_shared_lock lock(mu_); |
211 | return unsafe_sampler_.Sample(rnd); |
212 | } |
213 | |
214 | float UnigramSampler::Probability(int64_t value) const { |
215 | tf_shared_lock lock(mu_); |
216 | return unsafe_sampler_.Probability(value); |
217 | } |
218 | |
219 | // Overriding at a high level results in far fewer lock acquisitions. |
220 | void UnigramSampler::SampleBatchGetExpectedCountAvoid( |
221 | random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64_t> batch, |
222 | MutableArraySlice<float> batch_expected_count, ArraySlice<int64_t> , |
223 | MutableArraySlice<float> , |
224 | ArraySlice<int64_t> avoided_values) const { |
225 | tf_shared_lock lock(mu_); |
226 | unsafe_sampler_.SampleBatchGetExpectedCountAvoid( |
227 | rnd, unique, batch, batch_expected_count, extras, extras_expected_count, |
228 | avoided_values); |
229 | } |
230 | |
231 | void UnigramSampler::Update(ArraySlice<int64_t> values) { |
232 | mutex_lock lock(mu_); |
233 | unsafe_sampler_.Update(values); |
234 | } |
235 | |
236 | FixedUnigramSampler::FixedUnigramSampler(Env* env, int64_t range, |
237 | const string& vocab_file, |
238 | float distortion, |
239 | int32_t num_reserved_ids, |
240 | int32_t num_shards, int32_t shard) |
241 | : RangeSampler(range), |
242 | total_weight_(0.0), |
243 | num_shards_(num_shards), |
244 | shard_(shard) { |
245 | FillReservedIds(num_reserved_ids); |
246 | // TODO(vanhoucke): make this non-crashing. |
247 | TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion)); |
248 | CHECK_EQ(range, weights_.size()); |
249 | dist_sampler_.reset(new random::DistributionSampler(weights_)); |
250 | } |
251 | |
252 | FixedUnigramSampler::FixedUnigramSampler(int64_t range, |
253 | const std::vector<float>& unigrams, |
254 | float distortion, |
255 | int32_t num_reserved_ids, |
256 | int32_t num_shards, int32_t shard) |
257 | : RangeSampler(range), |
258 | total_weight_(0.0), |
259 | num_shards_(num_shards), |
260 | shard_(shard) { |
261 | FillReservedIds(num_reserved_ids); |
262 | LoadFromUnigrams(unigrams, distortion); |
263 | // TODO(vanhoucke): make this non-crashing. |
264 | CHECK_EQ(range, weights_.size()); |
265 | dist_sampler_.reset(new random::DistributionSampler(weights_)); |
266 | } |
267 | |
268 | float FixedUnigramSampler::Probability(int64_t value) const { |
269 | if (value < 0 || static_cast<size_t>(value) >= weights_.size()) { |
270 | return 0.0; |
271 | } |
272 | return weights_.at(value) / total_weight_; |
273 | } |
274 | |
275 | int64_t FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const { |
276 | return dist_sampler_->Sample(rnd); |
277 | } |
278 | |
279 | void FixedUnigramSampler::FillReservedIds(int32_t num_reserved_ids) { |
280 | for (int32_t word_id = 0; word_id < num_reserved_ids; ++word_id) { |
281 | if (word_id % num_shards_ == shard_) weights_.push_back(0.0); |
282 | } |
283 | } |
284 | |
285 | Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file, |
286 | float distortion) { |
287 | std::unique_ptr<RandomAccessFile> file; |
288 | TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file)); |
289 | |
290 | io::InputBuffer in(file.get(), 262144 /*bytes*/); |
291 | string line; |
292 | int32_t word_id = weights_.size(); |
293 | while (in.ReadLine(&line).ok()) { |
294 | // The vocabulary file should be in csv like format, with the last |
295 | // field the weight associated with the word. |
296 | std::vector<string> cols = str_util::Split(line, ','); |
297 | if (cols.empty()) continue; |
298 | // Skip entries that do not belong to this shard. |
299 | if (word_id % num_shards_ == shard_) { |
300 | float w = 0.0; |
301 | if (!strings::safe_strtof(cols.at(cols.size() - 1), &w)) { |
302 | return errors::InvalidArgument("Wrong vocabulary format at line: " , |
303 | line); |
304 | } |
305 | w = std::pow(w, distortion); |
306 | total_weight_ += w; |
307 | weights_.push_back(w); |
308 | } |
309 | ++word_id; |
310 | } |
311 | return OkStatus(); |
312 | } |
313 | |
314 | void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams, |
315 | float distortion) { |
316 | int32_t word_id = weights_.size(); |
317 | for (float w : unigrams) { |
318 | // Skip entries that do not belong to this shard. |
319 | if (word_id % num_shards_ == shard_) { |
320 | w = std::pow(w, distortion); |
321 | total_weight_ += w; |
322 | weights_.push_back(w); |
323 | } |
324 | ++word_id; |
325 | } |
326 | } |
327 | |
328 | } // namespace tensorflow |
329 | |