1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_ |
18 | |
#include <cstdint>
#include <memory>
#include <vector>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/weighted_picker.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
30 | |
// Forward declaration of tsl::Env: this header only ever handles Env through
// pointers (Env*), so the complete definition is not needed here.
namespace tsl { class Env; }  // namespace tsl
34 | namespace tensorflow { |
// Alias so that declarations in the tensorflow namespace can name tsl::Env
// unqualified as Env.
using Env = tsl::Env;
36 | |
37 | // Abstract subclass for sampling from the set of non-negative integers |
38 | // [0, range) |
39 | class RangeSampler { |
40 | public: |
41 | explicit RangeSampler(int64_t range) : range_(range) { CHECK_GT(range_, 0); } |
42 | virtual ~RangeSampler(); |
43 | |
44 | // Sample a single value |
45 | virtual int64_t Sample(random::SimplePhilox* rnd) const = 0; |
46 | |
47 | // The probability that a single call to Sample() returns the given value. |
48 | // Assumes that value is in [0, range). No range checking is done. |
49 | virtual float Probability(int64_t value) const = 0; |
50 | |
51 | // Fill "batch" with samples from the distribution. |
52 | // If unique=true, then we re-pick each element until we get a |
53 | // value distinct from all previously picked values in the batch. |
54 | void SampleBatch(random::SimplePhilox* rnd, bool unique, |
55 | gtl::MutableArraySlice<int64_t> batch) const; |
56 | |
57 | // Fill "batch" with samples from the distribution, and report |
58 | // "expected counts". |
59 | // |
60 | // The "expected count" of a value is an estimate of the expected |
61 | // number of occurrences of the value in the batch returned by a |
62 | // call to this function with the given parameters. If unique=true, |
63 | // the expected count is an inclusion probability. For details on |
64 | // this estimation, see the comment to "ExpectedCountHelper" in the |
65 | // .cc file. |
66 | // |
67 | // Expected counts for the elements of the returned "batch" are reported |
68 | // in the aligned array "batch_expected_count". |
69 | // |
70 | // The user can optionally provide "extras", containing values in the range. |
71 | // The expected counts for the extras are reported in the aligned array |
72 | // "extras_expected_count". |
73 | // |
74 | // "batch_expected_count" must have size equal to 0 or to the size of "batch". |
75 | // "extras" and "extras_expected_count" must have equal size. |
76 | void SampleBatchGetExpectedCount( |
77 | random::SimplePhilox* rnd, bool unique, |
78 | gtl::MutableArraySlice<int64_t> batch, |
79 | gtl::MutableArraySlice<float> batch_expected_count, |
80 | gtl::ArraySlice<int64_t> , |
81 | gtl::MutableArraySlice<float> ) const; |
82 | |
83 | // Same as SampleBatchGetExpectedCount (see above), but with avoided values. |
84 | // We repick to avoid all of the values in "avoided_values". |
85 | // "avoided_values" is only supported with unique=true. If |
86 | // unique=false, then avoided_values must be empty. |
87 | virtual void SampleBatchGetExpectedCountAvoid( |
88 | random::SimplePhilox* rnd, bool unique, |
89 | gtl::MutableArraySlice<int64_t> batch, |
90 | gtl::MutableArraySlice<float> batch_expected_count, |
91 | gtl::ArraySlice<int64_t> , |
92 | gtl::MutableArraySlice<float> , |
93 | gtl::ArraySlice<int64_t> avoided_values) const; |
94 | |
95 | // Does this sampler need to be updated with values, e.g. UnigramSampler |
96 | virtual bool NeedsUpdates() const { return false; } |
97 | |
98 | // Updates the underlying distribution |
99 | virtual void Update(gtl::ArraySlice<int64_t> values) { |
100 | LOG(FATAL) << "Update not supported for this sampler type." ; |
101 | } |
102 | |
103 | int64_t range() { return range_; } |
104 | |
105 | protected: |
106 | const int64_t range_; |
107 | }; |
108 | |
109 | // An AllSampler only samples batches of size equal to range. |
110 | // It returns the entire range. |
111 | // It cannot sample single values. |
112 | class AllSampler : public RangeSampler { |
113 | public: |
114 | explicit AllSampler(int64_t range); |
115 | |
116 | ~AllSampler() override {} |
117 | |
118 | int64_t Sample(random::SimplePhilox* rnd) const override { |
119 | LOG(FATAL) << "Should not be called" ; |
120 | return 0; |
121 | } |
122 | |
123 | float Probability(int64_t value) const override { |
124 | LOG(FATAL) << "Should not be called" ; |
125 | return 0; |
126 | } |
127 | |
128 | void SampleBatchGetExpectedCountAvoid( |
129 | random::SimplePhilox* rnd, bool unique, |
130 | gtl::MutableArraySlice<int64_t> batch, |
131 | gtl::MutableArraySlice<float> batch_expected_count, |
132 | gtl::ArraySlice<int64_t> , |
133 | gtl::MutableArraySlice<float> , |
134 | gtl::ArraySlice<int64_t> avoided_values) const override; |
135 | }; |
136 | |
137 | class UniformSampler : public RangeSampler { |
138 | public: |
139 | explicit UniformSampler(int64_t range); |
140 | |
141 | ~UniformSampler() override {} |
142 | |
143 | int64_t Sample(random::SimplePhilox* rnd) const override; |
144 | |
145 | float Probability(int64_t value) const override; |
146 | |
147 | private: |
148 | const float inv_range_; |
149 | }; |
150 | |
151 | class LogUniformSampler : public RangeSampler { |
152 | public: |
153 | explicit LogUniformSampler(int64_t range); |
154 | |
155 | ~LogUniformSampler() override {} |
156 | |
157 | int64_t Sample(random::SimplePhilox* rnd) const override; |
158 | |
159 | float Probability(int64_t value) const override; |
160 | |
161 | private: |
162 | const double log_range_; |
163 | }; |
164 | |
165 | // Thread-unsafe unigram sampler |
166 | class ThreadUnsafeUnigramSampler : public RangeSampler { |
167 | public: |
168 | explicit ThreadUnsafeUnigramSampler(int64_t range); |
169 | ~ThreadUnsafeUnigramSampler() override {} |
170 | |
171 | int64_t Sample(random::SimplePhilox* rnd) const override; |
172 | |
173 | float Probability(int64_t value) const override; |
174 | |
175 | bool NeedsUpdates() const override { return true; } |
176 | void Update(gtl::ArraySlice<int64_t> values) override; |
177 | |
178 | private: |
179 | random::WeightedPicker picker_; |
180 | }; |
181 | |
182 | // Thread-safe unigram sampler |
183 | class UnigramSampler : public RangeSampler { |
184 | public: |
185 | explicit UnigramSampler(int64_t range); |
186 | ~UnigramSampler() override {} |
187 | |
188 | int64_t Sample(random::SimplePhilox* rnd) const override; |
189 | |
190 | float Probability(int64_t value) const override; |
191 | |
192 | // Overriding at a high level results in far fewer lock acquisitions. |
193 | void SampleBatchGetExpectedCountAvoid( |
194 | random::SimplePhilox* rnd, bool unique, |
195 | gtl::MutableArraySlice<int64_t> batch, |
196 | gtl::MutableArraySlice<float> batch_expected_count, |
197 | gtl::ArraySlice<int64_t> , |
198 | gtl::MutableArraySlice<float> , |
199 | gtl::ArraySlice<int64_t> avoided_values) const override; |
200 | |
201 | bool NeedsUpdates() const override { return true; } |
202 | void Update(gtl::ArraySlice<int64_t> values) override; |
203 | |
204 | private: |
205 | ThreadUnsafeUnigramSampler unsafe_sampler_ TF_GUARDED_BY(mu_); |
206 | mutable mutex mu_; |
207 | }; |
208 | |
209 | // A unigram sampler that uses a fixed unigram distribution read from a |
210 | // file or passed in as an in-memory array instead of building up the |
211 | // distribution from data on the fly. There is also an option to skew the |
212 | // distribution by applying a distortion power to the weights. |
213 | class FixedUnigramSampler : public RangeSampler { |
214 | public: |
215 | // The vocab_file is assumed to be a CSV, with the last entry of each row a |
216 | // value representing the counts or probabilities for the corresponding ID. |
217 | FixedUnigramSampler(Env* env, int64_t range, const string& vocab_file, |
218 | float distortion, int32_t num_reserved_ids, |
219 | int32_t num_shards, int32_t shard); |
220 | |
221 | FixedUnigramSampler(int64_t range, const std::vector<float>& unigrams, |
222 | float distortion, int32_t num_reserved_ids, |
223 | int32_t num_shards, int32_t shard); |
224 | |
225 | float Probability(int64_t value) const override; |
226 | |
227 | int64_t Sample(random::SimplePhilox* rnd) const override; |
228 | |
229 | private: |
230 | // Underlying distribution sampler. |
231 | std::unique_ptr<random::DistributionSampler> dist_sampler_; |
232 | // Weights for individual samples. The probability of a sample i is defined |
233 | // as weights_.at(i) / total_weight_. |
234 | std::vector<float> weights_; |
235 | // The total weights of all samples. |
236 | float total_weight_; |
237 | // Sharding information of the sampler. The whole vocabulary is sharded |
238 | // into num_shards_ smaller ranges and each sampler is responsible for one |
239 | // such smaller range, identified by the shard number. |
240 | int32 num_shards_; |
241 | int32 shard_; |
242 | |
243 | // Fill the sampler with the appropriate number of reserved IDs. |
244 | void FillReservedIds(int32_t num_reserved_ids); |
245 | // Load IDs to sample from a CSV file. It is assumed that the last item of |
246 | // each row contains a count or probability for the corresponding ID. |
247 | Status LoadFromFile(Env* env, const string& vocab_file, float distortion); |
248 | // Load from an in-memory array. |
249 | void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion); |
250 | }; |
251 | |
252 | } // namespace tensorflow |
253 | |
254 | #endif // TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_ |
255 | |