/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/guarded_philox_random.h"

namespace tensorflow {

// Number of examples to precalculate.
const int kPrecalc = 3000;
// Number of words to read into a sentence before processing.
const int kSentenceSize = 1000;

namespace {
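// Scans the next whitespace-delimited word from *input into *word,
// advancing *input past it. Returns false once *input is exhausted.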
bool ScanWord(StringPiece* input, string* word) {
  str_util::RemoveLeadingWhitespace(input);
  StringPiece tmp;
  if (str_util::ConsumeNonWhitespace(input, &tmp)) {
    word->assign(tmp.data(), tmp.size());
    return true;
  } else {
    return false;
  }
}

}  // end namespace

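// Produces batches of skip-gram (example, label) pairs from a text file.
// The kernel keeps a cursor into the corpus and, per Compute call, emits
// the vocabulary, word frequencies, epoch counters, and one batch of
// center/context word ids.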
class SkipgramOp : public OpKernel {
 public:
  explicit SkipgramOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), rng_(&philox_) {
    string filename;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_));
    OP_REQUIRES_OK(ctx, Init(ctx->env(), filename));

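    // Start the cursors at the end of the corpus so the first call to
    // NextExample() rolls over into epoch 0 and refills the sentence buffer.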
    mutex_lock l(mu_);
    example_pos_ = corpus_size_;
    label_pos_ = corpus_size_;
    label_limit_ = corpus_size_;
    sentence_index_ = kSentenceSize;
    for (int i = 0; i < kPrecalc; ++i) {
      NextExample(&precalc_examples_[i].input, &precalc_examples_[i].label);
    }
  }

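  // Emits seven outputs: the vocabulary words, their frequencies, the number
  // of words per epoch, the current epoch, the total words processed so far,
  // and one batch of example (center) and label (context) word ids.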
  void Compute(OpKernelContext* ctx) override {
    Tensor words_per_epoch(DT_INT64, TensorShape({}));
    Tensor current_epoch(DT_INT32, TensorShape({}));
    Tensor total_words_processed(DT_INT64, TensorShape({}));
    Tensor examples(DT_INT32, TensorShape({batch_size_}));
    auto Texamples = examples.flat<int32>();
    Tensor labels(DT_INT32, TensorShape({batch_size_}));
    auto Tlabels = labels.flat<int32>();
    {
      mutex_lock l(mu_);
      for (int i = 0; i < batch_size_; ++i) {
        Texamples(i) = precalc_examples_[precalc_index_].input;
        Tlabels(i) = precalc_examples_[precalc_index_].label;
        precalc_index_++;
        if (precalc_index_ >= kPrecalc) {
          precalc_index_ = 0;
          for (int j = 0; j < kPrecalc; ++j) {
            NextExample(&precalc_examples_[j].input,
                        &precalc_examples_[j].label);
          }
        }
      }
      words_per_epoch.scalar<int64_t>()() = corpus_size_;
      current_epoch.scalar<int32>()() = current_epoch_;
      total_words_processed.scalar<int64_t>()() = total_words_processed_;
    }
    ctx->set_output(0, word_);
    ctx->set_output(1, freq_);
    ctx->set_output(2, words_per_epoch);
    ctx->set_output(3, current_epoch);
    ctx->set_output(4, total_words_processed);
    ctx->set_output(5, examples);
    ctx->set_output(6, labels);
  }


 private:
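  // A single precalculated training pair: center word id and context word id.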
  struct Example {
    int32 input;
    int32 label;
  };

  int32 batch_size_ = 0;
  int32 window_size_ = 5;
  float subsample_ = 1e-3;
  int min_count_ = 5;
  int32 vocab_size_ = 0;
  Tensor word_;
  Tensor freq_;
  int64_t corpus_size_ = 0;
  std::vector<int32> corpus_;
  std::vector<Example> precalc_examples_;
  int precalc_index_ = 0;
  std::vector<int32> sentence_;
  int sentence_index_ = 0;

  mutex mu_;
  random::PhiloxRandom philox_ TF_GUARDED_BY(mu_);
  random::SimplePhilox rng_ TF_GUARDED_BY(mu_);
  int32 current_epoch_ TF_GUARDED_BY(mu_) = -1;
  int64_t total_words_processed_ TF_GUARDED_BY(mu_) = 0;
  int32 example_pos_ TF_GUARDED_BY(mu_);
  int32 label_pos_ TF_GUARDED_BY(mu_);
  int32 label_limit_ TF_GUARDED_BY(mu_);

  // {example_pos_, label_pos_} is the cursor for the next example.
  // example_pos_ wraps around at the end of corpus_. For each
  // example, we randomly generate labels from the window
  // [label_pos_, label_limit_).
  void NextExample(int32* example, int32* label)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    while (true) {
      if (label_pos_ >= label_limit_) {
        ++total_words_processed_;
        ++sentence_index_;
        if (sentence_index_ >= kSentenceSize) {
          sentence_index_ = 0;
          for (int i = 0; i < kSentenceSize; ++i, ++example_pos_) {
            if (example_pos_ >= corpus_size_) {
              ++current_epoch_;
              example_pos_ = 0;
            }
            if (subsample_ > 0) {
              int32_t word_freq = freq_.flat<int32>()(corpus_[example_pos_]);
              // See Eq. 5 in http://arxiv.org/abs/1310.4546
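              // keep_prob = (sqrt(f / (t * N)) + 1) * (t * N) / f, where
              // f is the word's corpus frequency, t = subsample_, and
              // N = corpus_size_; frequent words are dropped more often.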
              float keep_prob =
                  (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) *
                  (subsample_ * corpus_size_) / word_freq;
              if (rng_.RandFloat() > keep_prob) {
                i--;
                continue;
              }
            }
            sentence_[i] = corpus_[example_pos_];
          }
        }
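        // Pick a random half-window size in [1, window_size_] and emit
        // labels from the surrounding words in the sentence buffer.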
        const int32_t skip = 1 + rng_.Uniform(window_size_);
        label_pos_ = std::max<int32>(0, sentence_index_ - skip);
        label_limit_ =
            std::min<int32>(kSentenceSize, sentence_index_ + skip + 1);
      }
      if (sentence_index_ != label_pos_) {
        break;
      }
      ++label_pos_;
    }
    *example = sentence_[sentence_index_];
    *label = sentence_[label_pos_++];
  }

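  // Reads the corpus from `filename`, builds the vocabulary (words occurring
  // at least min_count_ times, plus "UNK" at id 0 for everything else), and
  // re-encodes the corpus as a vector of word ids.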
  Status Init(Env* env, const string& filename) {
    string data;
    TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data));
    StringPiece input = data;
    string w;
    corpus_size_ = 0;
    std::unordered_map<string, int32> word_freq;
    while (ScanWord(&input, &w)) {
      ++(word_freq[w]);
      ++corpus_size_;
    }
    if (corpus_size_ < window_size_ * 10) {
      return errors::InvalidArgument(
          "The text file ", filename,
          " contains too little data: ", corpus_size_, " words");
    }
    typedef std::pair<string, int32> WordFreq;
    std::vector<WordFreq> ordered;
    for (const auto& p : word_freq) {
      if (p.second >= min_count_) ordered.push_back(p);
    }
    LOG(INFO) << "Data file: " << filename << " contains " << data.size()
              << " bytes, " << corpus_size_ << " words, " << word_freq.size()
              << " unique words, " << ordered.size()
              << " unique frequent words.";
    word_freq.clear();
    std::sort(ordered.begin(), ordered.end(),
              [](const WordFreq& x, const WordFreq& y) {
                return x.second > y.second;
              });
    vocab_size_ = static_cast<int32>(1 + ordered.size());
    Tensor word(DT_STRING, TensorShape({vocab_size_}));
    Tensor freq(DT_INT32, TensorShape({vocab_size_}));
    word.flat<tstring>()(0) = "UNK";
    static const int32_t kUnkId = 0;
    std::unordered_map<string, int32> word_id;
    int64_t total_counted = 0;
    for (std::size_t i = 0; i < ordered.size(); ++i) {
      const auto& w = ordered[i].first;
      auto id = i + 1;
      word.flat<tstring>()(id) = w;
      auto word_count = ordered[i].second;
      freq.flat<int32>()(id) = word_count;
      total_counted += word_count;
      word_id[w] = id;
    }
    freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted;
    word_ = word;
    freq_ = freq;
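    // Second pass over the text: map every corpus word to its id,
    // falling back to kUnkId for out-of-vocabulary words.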
    corpus_.reserve(corpus_size_);
    input = data;
    while (ScanWord(&input, &w)) {
      corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId));
    }
    precalc_examples_.resize(kPrecalc);
    sentence_.resize(kSentenceSize);
    return OkStatus();
  }
};

REGISTER_KERNEL_BUILDER(Name("Skipgram").Device(DEVICE_CPU), SkipgramOp);

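// Applies one SGD step of word2vec negative-sampling training: a positive
// update for each (example, label) pair plus num_negative_samples updates
// against sampled words.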
class NegTrainOp : public OpKernel {
 public:
  explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    base_.Init(0, 0);

    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_));

    std::vector<int32> vocab_count;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count));

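    // Negative samples are drawn proportionally to count^0.75, following
    // Mikolov et al. (http://arxiv.org/abs/1310.4546).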
    std::vector<float> vocab_weights;
    vocab_weights.reserve(vocab_count.size());
    for (const auto& f : vocab_count) {
      float r = std::pow(static_cast<float>(f), 0.75f);
      vocab_weights.push_back(r);
    }
    sampler_ = new random::DistributionSampler(vocab_weights);
  }

  ~NegTrainOp() override { delete sampler_; }

  void Compute(OpKernelContext* ctx) override {
    Tensor w_in = ctx->mutable_input(0, false);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
                errors::InvalidArgument("Must be a matrix"));
    Tensor w_out = ctx->mutable_input(1, false);
    OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
                errors::InvalidArgument("w_in.shape == w_out.shape"));
    const Tensor& examples = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
                errors::InvalidArgument("Must be a vector"));
    const Tensor& labels = ctx->input(3);
    OP_REQUIRES(ctx, examples.shape() == labels.shape(),
                errors::InvalidArgument("examples.shape == labels.shape"));
    const Tensor& learning_rate = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
                errors::InvalidArgument("Must be a scalar"));

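    // Typed views over the input/output embeddings and the training batch.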
    auto Tw_in = w_in.matrix<float>();
    auto Tw_out = w_out.matrix<float>();
    auto Texamples = examples.flat<int32>();
    auto Tlabels = labels.flat<int32>();
    auto lr = learning_rate.scalar<float>()();
    const int64_t vocab_size = w_in.dim_size(0);
    const int64_t dims = w_in.dim_size(1);
    const int64_t batch_size = examples.dim_size(0);
    OP_REQUIRES(ctx, vocab_size == sampler_->num(),
                errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
                                        " vs. ", sampler_->num()));

    // Gradient accumulator for v_in.
    Tensor buf(DT_FLOAT, TensorShape({dims}));
    auto Tbuf = buf.flat<float>();

    // Scalar buffer to hold sigmoid(+/- dot).
    Tensor g_buf(DT_FLOAT, TensorShape({}));
    auto g = g_buf.scalar<float>();

    // The following loop needs 2 random 32-bit values per negative
    // sample.  We reserve 8 values per sample just in case the
    // underlying implementation changes.
    auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8);
    random::SimplePhilox srnd(&rnd);

    for (int64_t i = 0; i < batch_size; ++i) {
      const int32_t example = Texamples(i);
      DCHECK(0 <= example && example < vocab_size) << example;
      const int32_t label = Tlabels(i);
      DCHECK(0 <= label && label < vocab_size) << label;
      auto v_in = Tw_in.chip<0>(example);

      // Positive: example predicts label.
      //   forward: x = v_in' * v_out
      //            l = log(sigmoid(x))
      //   backward: dl/dx = g = sigmoid(-x)
      //             dl/d(v_in) = g * v_out'
      //             dl/d(v_out) = v_in' * g
      {
        auto v_out = Tw_out.chip<0>(label);
        auto dot = (v_in * v_out).sum();
        g = (dot.exp() + 1.f).inverse();
        Tbuf = v_out * (g() * lr);
        v_out += v_in * (g() * lr);
      }

      // Negative samples:
      //   forward: x = v_in' * v_sample
      //            l = log(sigmoid(-x))
      //   backward: dl/dx = g = -sigmoid(x)
      //             dl/d(v_in) = g * v_out'
      //             dl/d(v_out) = v_in' * g
      for (int j = 0; j < num_samples_; ++j) {
        const int sample = sampler_->Sample(&srnd);
        if (sample == label) continue;  // Skip.
        auto v_sample = Tw_out.chip<0>(sample);
        auto dot = (v_in * v_sample).sum();
        g = -((-dot).exp() + 1.f).inverse();
        Tbuf += v_sample * (g() * lr);
        v_sample += v_in * (g() * lr);
      }

      // Applies the gradient on v_in.
      v_in += Tbuf;
    }
  }

 private:
  int32 num_samples_ = 0;
  random::DistributionSampler* sampler_ = nullptr;
  GuardedPhiloxRandom base_;
};

REGISTER_KERNEL_BUILDER(Name("NegTrain").Device(DEVICE_CPU), NegTrainOp);

}  // end namespace tensorflow