1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
17#define TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
18
19#include <vector>
20
21#include "tensorflow/core/lib/core/status.h"
22#include "tensorflow/core/lib/gtl/array_slice.h"
23#include "tensorflow/core/lib/random/distribution_sampler.h"
24#include "tensorflow/core/lib/random/random_distributions.h"
25#include "tensorflow/core/lib/random/weighted_picker.h"
26#include "tensorflow/core/platform/logging.h"
27#include "tensorflow/core/platform/mutex.h"
28#include "tensorflow/core/platform/thread_annotations.h"
29#include "tensorflow/core/platform/types.h"
30
31namespace tsl {
32class Env;
33} // namespace tsl
34namespace tensorflow {
35using Env = tsl::Env;
36
37// Abstract subclass for sampling from the set of non-negative integers
38// [0, range)
39class RangeSampler {
40 public:
41 explicit RangeSampler(int64_t range) : range_(range) { CHECK_GT(range_, 0); }
42 virtual ~RangeSampler();
43
44 // Sample a single value
45 virtual int64_t Sample(random::SimplePhilox* rnd) const = 0;
46
47 // The probability that a single call to Sample() returns the given value.
48 // Assumes that value is in [0, range). No range checking is done.
49 virtual float Probability(int64_t value) const = 0;
50
51 // Fill "batch" with samples from the distribution.
52 // If unique=true, then we re-pick each element until we get a
53 // value distinct from all previously picked values in the batch.
54 void SampleBatch(random::SimplePhilox* rnd, bool unique,
55 gtl::MutableArraySlice<int64_t> batch) const;
56
57 // Fill "batch" with samples from the distribution, and report
58 // "expected counts".
59 //
60 // The "expected count" of a value is an estimate of the expected
61 // number of occurrences of the value in the batch returned by a
62 // call to this function with the given parameters. If unique=true,
63 // the expected count is an inclusion probability. For details on
64 // this estimation, see the comment to "ExpectedCountHelper" in the
65 // .cc file.
66 //
67 // Expected counts for the elements of the returned "batch" are reported
68 // in the aligned array "batch_expected_count".
69 //
70 // The user can optionally provide "extras", containing values in the range.
71 // The expected counts for the extras are reported in the aligned array
72 // "extras_expected_count".
73 //
74 // "batch_expected_count" must have size equal to 0 or to the size of "batch".
75 // "extras" and "extras_expected_count" must have equal size.
76 void SampleBatchGetExpectedCount(
77 random::SimplePhilox* rnd, bool unique,
78 gtl::MutableArraySlice<int64_t> batch,
79 gtl::MutableArraySlice<float> batch_expected_count,
80 gtl::ArraySlice<int64_t> extras,
81 gtl::MutableArraySlice<float> extras_expected_count) const;
82
83 // Same as SampleBatchGetExpectedCount (see above), but with avoided values.
84 // We repick to avoid all of the values in "avoided_values".
85 // "avoided_values" is only supported with unique=true. If
86 // unique=false, then avoided_values must be empty.
87 virtual void SampleBatchGetExpectedCountAvoid(
88 random::SimplePhilox* rnd, bool unique,
89 gtl::MutableArraySlice<int64_t> batch,
90 gtl::MutableArraySlice<float> batch_expected_count,
91 gtl::ArraySlice<int64_t> extras,
92 gtl::MutableArraySlice<float> extras_expected_count,
93 gtl::ArraySlice<int64_t> avoided_values) const;
94
95 // Does this sampler need to be updated with values, e.g. UnigramSampler
96 virtual bool NeedsUpdates() const { return false; }
97
98 // Updates the underlying distribution
99 virtual void Update(gtl::ArraySlice<int64_t> values) {
100 LOG(FATAL) << "Update not supported for this sampler type.";
101 }
102
103 int64_t range() { return range_; }
104
105 protected:
106 const int64_t range_;
107};
108
109// An AllSampler only samples batches of size equal to range.
110// It returns the entire range.
111// It cannot sample single values.
112class AllSampler : public RangeSampler {
113 public:
114 explicit AllSampler(int64_t range);
115
116 ~AllSampler() override {}
117
118 int64_t Sample(random::SimplePhilox* rnd) const override {
119 LOG(FATAL) << "Should not be called";
120 return 0;
121 }
122
123 float Probability(int64_t value) const override {
124 LOG(FATAL) << "Should not be called";
125 return 0;
126 }
127
128 void SampleBatchGetExpectedCountAvoid(
129 random::SimplePhilox* rnd, bool unique,
130 gtl::MutableArraySlice<int64_t> batch,
131 gtl::MutableArraySlice<float> batch_expected_count,
132 gtl::ArraySlice<int64_t> extras,
133 gtl::MutableArraySlice<float> extras_expected_count,
134 gtl::ArraySlice<int64_t> avoided_values) const override;
135};
136
137class UniformSampler : public RangeSampler {
138 public:
139 explicit UniformSampler(int64_t range);
140
141 ~UniformSampler() override {}
142
143 int64_t Sample(random::SimplePhilox* rnd) const override;
144
145 float Probability(int64_t value) const override;
146
147 private:
148 const float inv_range_;
149};
150
151class LogUniformSampler : public RangeSampler {
152 public:
153 explicit LogUniformSampler(int64_t range);
154
155 ~LogUniformSampler() override {}
156
157 int64_t Sample(random::SimplePhilox* rnd) const override;
158
159 float Probability(int64_t value) const override;
160
161 private:
162 const double log_range_;
163};
164
165// Thread-unsafe unigram sampler
166class ThreadUnsafeUnigramSampler : public RangeSampler {
167 public:
168 explicit ThreadUnsafeUnigramSampler(int64_t range);
169 ~ThreadUnsafeUnigramSampler() override {}
170
171 int64_t Sample(random::SimplePhilox* rnd) const override;
172
173 float Probability(int64_t value) const override;
174
175 bool NeedsUpdates() const override { return true; }
176 void Update(gtl::ArraySlice<int64_t> values) override;
177
178 private:
179 random::WeightedPicker picker_;
180};
181
182// Thread-safe unigram sampler
183class UnigramSampler : public RangeSampler {
184 public:
185 explicit UnigramSampler(int64_t range);
186 ~UnigramSampler() override {}
187
188 int64_t Sample(random::SimplePhilox* rnd) const override;
189
190 float Probability(int64_t value) const override;
191
192 // Overriding at a high level results in far fewer lock acquisitions.
193 void SampleBatchGetExpectedCountAvoid(
194 random::SimplePhilox* rnd, bool unique,
195 gtl::MutableArraySlice<int64_t> batch,
196 gtl::MutableArraySlice<float> batch_expected_count,
197 gtl::ArraySlice<int64_t> extras,
198 gtl::MutableArraySlice<float> extras_expected_count,
199 gtl::ArraySlice<int64_t> avoided_values) const override;
200
201 bool NeedsUpdates() const override { return true; }
202 void Update(gtl::ArraySlice<int64_t> values) override;
203
204 private:
205 ThreadUnsafeUnigramSampler unsafe_sampler_ TF_GUARDED_BY(mu_);
206 mutable mutex mu_;
207};
208
209// A unigram sampler that uses a fixed unigram distribution read from a
210// file or passed in as an in-memory array instead of building up the
211// distribution from data on the fly. There is also an option to skew the
212// distribution by applying a distortion power to the weights.
213class FixedUnigramSampler : public RangeSampler {
214 public:
215 // The vocab_file is assumed to be a CSV, with the last entry of each row a
216 // value representing the counts or probabilities for the corresponding ID.
217 FixedUnigramSampler(Env* env, int64_t range, const string& vocab_file,
218 float distortion, int32_t num_reserved_ids,
219 int32_t num_shards, int32_t shard);
220
221 FixedUnigramSampler(int64_t range, const std::vector<float>& unigrams,
222 float distortion, int32_t num_reserved_ids,
223 int32_t num_shards, int32_t shard);
224
225 float Probability(int64_t value) const override;
226
227 int64_t Sample(random::SimplePhilox* rnd) const override;
228
229 private:
230 // Underlying distribution sampler.
231 std::unique_ptr<random::DistributionSampler> dist_sampler_;
232 // Weights for individual samples. The probability of a sample i is defined
233 // as weights_.at(i) / total_weight_.
234 std::vector<float> weights_;
235 // The total weights of all samples.
236 float total_weight_;
237 // Sharding information of the sampler. The whole vocabulary is sharded
238 // into num_shards_ smaller ranges and each sampler is responsible for one
239 // such smaller range, identified by the shard number.
240 int32 num_shards_;
241 int32 shard_;
242
243 // Fill the sampler with the appropriate number of reserved IDs.
244 void FillReservedIds(int32_t num_reserved_ids);
245 // Load IDs to sample from a CSV file. It is assumed that the last item of
246 // each row contains a count or probability for the corresponding ID.
247 Status LoadFromFile(Env* env, const string& vocab_file, float distortion);
248 // Load from an in-memory array.
249 void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion);
250};
251
252} // namespace tensorflow
253
254#endif // TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
255