1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op.h" |
17 | #include "tensorflow/core/framework/op_kernel.h" |
18 | #include "tensorflow/core/lib/core/stringpiece.h" |
19 | #include "tensorflow/core/lib/gtl/map_util.h" |
20 | #include "tensorflow/core/lib/random/distribution_sampler.h" |
21 | #include "tensorflow/core/lib/random/philox_random.h" |
22 | #include "tensorflow/core/lib/random/simple_philox.h" |
23 | #include "tensorflow/core/lib/strings/str_util.h" |
24 | #include "tensorflow/core/platform/thread_annotations.h" |
25 | #include "tensorflow/core/util/guarded_philox_random.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | // Number of examples to precalculate. |
30 | const int kPrecalc = 3000; |
31 | // Number of words to read into a sentence before processing. |
32 | const int kSentenceSize = 1000; |
33 | |
34 | namespace { |
35 | |
36 | bool ScanWord(StringPiece* input, string* word) { |
37 | str_util::RemoveLeadingWhitespace(input); |
38 | StringPiece tmp; |
39 | if (str_util::ConsumeNonWhitespace(input, &tmp)) { |
40 | word->assign(tmp.data(), tmp.size()); |
41 | return true; |
42 | } else { |
43 | return false; |
44 | } |
45 | } |
46 | |
47 | } // end namespace |
48 | |
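// Produces skip-gram training pairs from a whitespace-separated text file.
// Init() builds the vocabulary and encodes the corpus as word ids; each
// Compute() call emits a batch of (example, label) center/context id pairs
// plus the vocabulary, word frequencies, and progress counters (see the
// set_output calls for the exact output order).
//
// A rough usage sketch from Python (names are illustrative, following the
// word2vec tutorial):
//   words, counts, words_per_epoch, epoch, total, examples, labels = (
//       word2vec.skipgram(filename=..., batch_size=..., window_size=...,
//                         min_count=..., subsample=...))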
49 | class SkipgramOp : public OpKernel { |
50 | public: |
51 | explicit SkipgramOp(OpKernelConstruction* ctx) |
52 | : OpKernel(ctx), rng_(&philox_) { |
53 | string filename; |
54 | OP_REQUIRES_OK(ctx, ctx->GetAttr("filename" , &filename)); |
55 | OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size" , &batch_size_)); |
56 | OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size" , &window_size_)); |
57 | OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count" , &min_count_)); |
58 | OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample" , &subsample_)); |
59 | OP_REQUIRES_OK(ctx, Init(ctx->env(), filename)); |
60 | |
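    // Start every cursor at its limit so the first NextExample() call wraps
    // into epoch 0 and fills the sentence buffer before emitting anything.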
61 | mutex_lock l(mu_); |
62 | example_pos_ = corpus_size_; |
63 | label_pos_ = corpus_size_; |
64 | label_limit_ = corpus_size_; |
65 | sentence_index_ = kSentenceSize; |
66 | for (int i = 0; i < kPrecalc; ++i) { |
67 | NextExample(&precalc_examples_[i].input, &precalc_examples_[i].label); |
68 | } |
69 | } |
70 | |
71 | void Compute(OpKernelContext* ctx) override { |
72 | Tensor words_per_epoch(DT_INT64, TensorShape({})); |
73 | Tensor current_epoch(DT_INT32, TensorShape({})); |
74 | Tensor total_words_processed(DT_INT64, TensorShape({})); |
75 | Tensor examples(DT_INT32, TensorShape({batch_size_})); |
76 | auto Texamples = examples.flat<int32>(); |
77 | Tensor labels(DT_INT32, TensorShape({batch_size_})); |
78 | auto Tlabels = labels.flat<int32>(); |
79 | { |
80 | mutex_lock l(mu_); |
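      // Drain the precomputed examples; once the buffer is exhausted,
      // regenerate all kPrecalc examples in one pass while still holding
      // the lock.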
81 | for (int i = 0; i < batch_size_; ++i) { |
82 | Texamples(i) = precalc_examples_[precalc_index_].input; |
83 | Tlabels(i) = precalc_examples_[precalc_index_].label; |
84 | precalc_index_++; |
85 | if (precalc_index_ >= kPrecalc) { |
86 | precalc_index_ = 0; |
87 | for (int j = 0; j < kPrecalc; ++j) { |
88 | NextExample(&precalc_examples_[j].input, |
89 | &precalc_examples_[j].label); |
90 | } |
91 | } |
92 | } |
93 | words_per_epoch.scalar<int64_t>()() = corpus_size_; |
94 | current_epoch.scalar<int32>()() = current_epoch_; |
95 | total_words_processed.scalar<int64_t>()() = total_words_processed_; |
96 | } |
97 | ctx->set_output(0, word_); |
98 | ctx->set_output(1, freq_); |
99 | ctx->set_output(2, words_per_epoch); |
100 | ctx->set_output(3, current_epoch); |
101 | ctx->set_output(4, total_words_processed); |
102 | ctx->set_output(5, examples); |
103 | ctx->set_output(6, labels); |
104 | } |
105 | |
106 | private: |
107 | struct Example { |
108 | int32 input; |
109 | int32 label; |
110 | }; |
111 | |
112 | int32 batch_size_ = 0; |
113 | int32 window_size_ = 5; |
114 | float subsample_ = 1e-3; |
115 | int min_count_ = 5; |
116 | int32 vocab_size_ = 0; |
117 | Tensor word_; |
118 | Tensor freq_; |
119 | int64_t corpus_size_ = 0; |
120 | std::vector<int32> corpus_; |
121 | std::vector<Example> precalc_examples_; |
122 | int precalc_index_ = 0; |
123 | std::vector<int32> sentence_; |
124 | int sentence_index_ = 0; |
125 | |
126 | mutex mu_; |
127 | random::PhiloxRandom philox_ TF_GUARDED_BY(mu_); |
128 | random::SimplePhilox rng_ TF_GUARDED_BY(mu_); |
129 | int32 current_epoch_ TF_GUARDED_BY(mu_) = -1; |
130 | int64_t total_words_processed_ TF_GUARDED_BY(mu_) = 0; |
131 | int32 example_pos_ TF_GUARDED_BY(mu_); |
132 | int32 label_pos_ TF_GUARDED_BY(mu_); |
133 | int32 label_limit_ TF_GUARDED_BY(mu_); |
134 | |
  // {example_pos_, label_pos_} is the cursor for the next example.
  // example_pos_ wraps around at the end of corpus_. For each example, we
  // randomly pick a window and generate labels in [label_pos_, label_limit_).
139 | void NextExample(int32* example, int32* label) |
140 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
141 | while (true) { |
142 | if (label_pos_ >= label_limit_) { |
143 | ++total_words_processed_; |
144 | ++sentence_index_; |
145 | if (sentence_index_ >= kSentenceSize) { |
146 | sentence_index_ = 0; |
147 | for (int i = 0; i < kSentenceSize; ++i, ++example_pos_) { |
148 | if (example_pos_ >= corpus_size_) { |
149 | ++current_epoch_; |
150 | example_pos_ = 0; |
151 | } |
152 | if (subsample_ > 0) { |
153 | int32_t word_freq = freq_.flat<int32>()(corpus_[example_pos_]); |
154 | // See Eq. 5 in http://arxiv.org/abs/1310.4546 |
155 | float keep_prob = |
156 | (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) * |
157 | (subsample_ * corpus_size_) / word_freq; |
158 | if (rng_.RandFloat() > keep_prob) { |
159 | i--; |
160 | continue; |
161 | } |
162 | } |
163 | sentence_[i] = corpus_[example_pos_]; |
164 | } |
165 | } |
166 | const int32_t skip = 1 + rng_.Uniform(window_size_); |
167 | label_pos_ = std::max<int32>(0, sentence_index_ - skip); |
168 | label_limit_ = |
169 | std::min<int32>(kSentenceSize, sentence_index_ + skip + 1); |
170 | } |
171 | if (sentence_index_ != label_pos_) { |
172 | break; |
173 | } |
174 | ++label_pos_; |
175 | } |
176 | *example = sentence_[sentence_index_]; |
177 | *label = sentence_[label_pos_++]; |
178 | } |
179 | |
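  // Scans the corpus twice: the first pass counts word frequencies and
  // builds the vocabulary (words seen fewer than min_count_ times map to
  // UNK); the second pass re-encodes every token as its vocabulary id.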
180 | Status Init(Env* env, const string& filename) { |
181 | string data; |
182 | TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data)); |
183 | StringPiece input = data; |
184 | string w; |
185 | corpus_size_ = 0; |
186 | std::unordered_map<string, int32> word_freq; |
187 | while (ScanWord(&input, &w)) { |
188 | ++(word_freq[w]); |
189 | ++corpus_size_; |
190 | } |
191 | if (corpus_size_ < window_size_ * 10) { |
      return errors::InvalidArgument(
          "The text file ", filename,
          " contains too little data: ", corpus_size_, " words");
195 | } |
196 | typedef std::pair<string, int32> WordFreq; |
197 | std::vector<WordFreq> ordered; |
198 | for (const auto& p : word_freq) { |
199 | if (p.second >= min_count_) ordered.push_back(p); |
200 | } |
201 | LOG(INFO) << "Data file: " << filename << " contains " << data.size() |
202 | << " bytes, " << corpus_size_ << " words, " << word_freq.size() |
203 | << " unique words, " << ordered.size() |
204 | << " unique frequent words." ; |
205 | word_freq.clear(); |
206 | std::sort(ordered.begin(), ordered.end(), |
207 | [](const WordFreq& x, const WordFreq& y) { |
208 | return x.second > y.second; |
209 | }); |
210 | vocab_size_ = static_cast<int32>(1 + ordered.size()); |
211 | Tensor word(DT_STRING, TensorShape({vocab_size_})); |
212 | Tensor freq(DT_INT32, TensorShape({vocab_size_})); |
213 | word.flat<tstring>()(0) = "UNK" ; |
214 | static const int32_t kUnkId = 0; |
215 | std::unordered_map<string, int32> word_id; |
216 | int64_t total_counted = 0; |
217 | for (std::size_t i = 0; i < ordered.size(); ++i) { |
218 | const auto& w = ordered[i].first; |
219 | auto id = i + 1; |
220 | word.flat<tstring>()(id) = w; |
221 | auto word_count = ordered[i].second; |
222 | freq.flat<int32>()(id) = word_count; |
223 | total_counted += word_count; |
224 | word_id[w] = id; |
225 | } |
226 | freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted; |
227 | word_ = word; |
228 | freq_ = freq; |
229 | corpus_.reserve(corpus_size_); |
230 | input = data; |
231 | while (ScanWord(&input, &w)) { |
232 | corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId)); |
233 | } |
234 | precalc_examples_.resize(kPrecalc); |
235 | sentence_.resize(kSentenceSize); |
236 | return OkStatus(); |
237 | } |
238 | }; |
239 | |
240 | REGISTER_KERNEL_BUILDER(Name("Skipgram" ).Device(DEVICE_CPU), SkipgramOp); |
241 | |
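// Applies one SGD step of word2vec's negative-sampling objective: for each
// (example, label) pair it takes a gradient step on the positive pair and
// on num_negative_samples sampled negative words, updating the input (w_in)
// and output (w_out) embeddings in place.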
242 | class NegTrainOp : public OpKernel { |
243 | public: |
244 | explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
245 | base_.Init(0, 0); |
246 | |
247 | OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples" , &num_samples_)); |
248 | |
249 | std::vector<int32> vocab_count; |
250 | OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count" , &vocab_count)); |
251 | |
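    // Negative samples are drawn from the unigram distribution raised to
    // the 3/4 power, the distortion used in the word2vec paper.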
252 | std::vector<float> vocab_weights; |
253 | vocab_weights.reserve(vocab_count.size()); |
254 | for (const auto& f : vocab_count) { |
255 | float r = std::pow(static_cast<float>(f), 0.75f); |
256 | vocab_weights.push_back(r); |
257 | } |
258 | sampler_ = new random::DistributionSampler(vocab_weights); |
259 | } |
260 | |
261 | ~NegTrainOp() override { delete sampler_; } |
262 | |
263 | void Compute(OpKernelContext* ctx) override { |
    Tensor w_in = ctx->mutable_input(0, false);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
                errors::InvalidArgument("Must be a matrix"));
    Tensor w_out = ctx->mutable_input(1, false);
    OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
                errors::InvalidArgument("w_in.shape == w_out.shape"));
    const Tensor& examples = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
                errors::InvalidArgument("Must be a vector"));
    const Tensor& labels = ctx->input(3);
    OP_REQUIRES(ctx, examples.shape() == labels.shape(),
                errors::InvalidArgument("examples.shape == labels.shape"));
    const Tensor& learning_rate = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
                errors::InvalidArgument("Must be a scalar"));
279 | |
280 | auto Tw_in = w_in.matrix<float>(); |
281 | auto Tw_out = w_out.matrix<float>(); |
282 | auto Texamples = examples.flat<int32>(); |
283 | auto Tlabels = labels.flat<int32>(); |
284 | auto lr = learning_rate.scalar<float>()(); |
285 | const int64_t vocab_size = w_in.dim_size(0); |
286 | const int64_t dims = w_in.dim_size(1); |
287 | const int64_t batch_size = examples.dim_size(0); |
    OP_REQUIRES(ctx, vocab_size == sampler_->num(),
                errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
                                        " vs. ", sampler_->num()));
291 | |
292 | // Gradient accumulator for v_in. |
293 | Tensor buf(DT_FLOAT, TensorShape({dims})); |
294 | auto Tbuf = buf.flat<float>(); |
295 | |
296 | // Scalar buffer to hold sigmoid(+/- dot). |
297 | Tensor g_buf(DT_FLOAT, TensorShape({})); |
298 | auto g = g_buf.scalar<float>(); |
299 | |
300 | // The following loop needs 2 random 32-bit values per negative |
301 | // sample. We reserve 8 values per sample just in case the |
302 | // underlying implementation changes. |
303 | auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8); |
304 | random::SimplePhilox srnd(&rnd); |
305 | |
306 | for (int64_t i = 0; i < batch_size; ++i) { |
307 | const int32_t example = Texamples(i); |
308 | DCHECK(0 <= example && example < vocab_size) << example; |
309 | const int32_t label = Tlabels(i); |
310 | DCHECK(0 <= label && label < vocab_size) << label; |
311 | auto v_in = Tw_in.chip<0>(example); |
312 | |
313 | // Positive: example predicts label. |
314 | // forward: x = v_in' * v_out |
315 | // l = log(sigmoid(x)) |
316 | // backward: dl/dx = g = sigmoid(-x) |
317 | // dl/d(v_in) = g * v_out' |
318 | // dl/d(v_out) = v_in' * g |
319 | { |
320 | auto v_out = Tw_out.chip<0>(label); |
321 | auto dot = (v_in * v_out).sum(); |
322 | g = (dot.exp() + 1.f).inverse(); |
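        // Defer the v_in update: accumulate it into Tbuf so the negative
        // samples below are scored against the original v_in.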
323 | Tbuf = v_out * (g() * lr); |
324 | v_out += v_in * (g() * lr); |
325 | } |
326 | |
327 | // Negative samples: |
328 | // forward: x = v_in' * v_sample |
329 | // l = log(sigmoid(-x)) |
330 | // backward: dl/dx = g = -sigmoid(x) |
331 | // dl/d(v_in) = g * v_out' |
332 | // dl/d(v_out) = v_in' * g |
333 | for (int j = 0; j < num_samples_; ++j) { |
334 | const int sample = sampler_->Sample(&srnd); |
335 | if (sample == label) continue; // Skip. |
336 | auto v_sample = Tw_out.chip<0>(sample); |
337 | auto dot = (v_in * v_sample).sum(); |
338 | g = -((-dot).exp() + 1.f).inverse(); |
339 | Tbuf += v_sample * (g() * lr); |
340 | v_sample += v_in * (g() * lr); |
341 | } |
342 | |
      // Apply the accumulated positive and negative gradients to v_in.
344 | v_in += Tbuf; |
345 | } |
346 | } |
347 | |
348 | private: |
349 | int32 num_samples_ = 0; |
350 | random::DistributionSampler* sampler_ = nullptr; |
351 | GuardedPhiloxRandom base_; |
352 | }; |
353 | |
354 | REGISTER_KERNEL_BUILDER(Name("NegTrain" ).Device(DEVICE_CPU), NegTrainOp); |
355 | |
356 | } // end namespace tensorflow |
357 | |