1 | // This file is MACHINE GENERATED! Do not edit. |
2 | |
3 | #ifndef TENSORFLOW_CC_OPS_CANDIDATE_SAMPLING_OPS_H_ |
4 | #define TENSORFLOW_CC_OPS_CANDIDATE_SAMPLING_OPS_H_ |
5 | |
6 | // This file is MACHINE GENERATED! Do not edit. |
7 | |
8 | #include "tensorflow/cc/framework/ops.h" |
9 | #include "tensorflow/cc/framework/scope.h" |
10 | #include "tensorflow/core/framework/tensor.h" |
11 | #include "tensorflow/core/framework/tensor_shape.h" |
12 | #include "tensorflow/core/framework/types.h" |
13 | #include "tensorflow/core/lib/gtl/array_slice.h" |
14 | |
15 | namespace tensorflow { |
16 | namespace ops { |
17 | |
18 | /// @defgroup candidate_sampling_ops Candidate Sampling Ops |
19 | /// @{ |
20 | |
21 | /// Generates labels for candidate sampling with a learned unigram distribution. |
22 | /// |
23 | /// See explanations of candidate sampling and the data formats at |
24 | /// go/candidate-sampling. |
25 | /// |
26 | /// For each batch, this op picks a single set of sampled candidate labels. |
27 | /// |
28 | /// The advantages of sampling candidates per-batch are simplicity and the |
29 | /// possibility of efficient dense matrix multiplication. The disadvantage is that |
30 | /// the sampled candidates must be chosen independently of the context and of the |
31 | /// true labels. |
32 | /// |
33 | /// Args: |
34 | /// * scope: A Scope object |
35 | /// * true_classes: A batch_size * num_true matrix, in which each row contains the |
36 | /// IDs of the num_true target_classes in the corresponding original label. |
37 | /// * num_true: Number of true labels per context. |
38 | /// * num_sampled: Number of candidates to produce. |
39 | /// * unique: If unique is true, we sample with rejection, so that all sampled |
40 | /// candidates in a batch are unique. This requires some approximation to |
41 | /// estimate the post-rejection sampling probabilities. |
42 | /// |
43 | /// Optional attributes (see `Attrs`): |
44 | /// * seed: If either seed or seed2 are set to be non-zero, the random number |
45 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
46 | /// random seed. |
/// * seed2: A second seed to avoid seed collision.
48 | /// |
49 | /// Returns: |
50 | /// * `Output` sampled_candidates: A vector of length num_sampled, in which each element is |
51 | /// the ID of a sampled candidate. |
52 | /// * `Output` true_expected_count: A batch_size * num_true matrix, representing |
53 | /// the number of times each candidate is expected to occur in a batch |
54 | /// of sampled candidates. If unique=true, then this is a probability. |
55 | /// * `Output` sampled_expected_count: A vector of length num_sampled, for each sampled |
56 | /// candidate representing the number of times the candidate is expected |
57 | /// to occur in a batch of sampled candidates. If unique=true, then this is a |
58 | /// probability. |
59 | class AllCandidateSampler { |
60 | public: |
61 | /// Optional attribute setters for AllCandidateSampler |
62 | struct Attrs { |
63 | /// If either seed or seed2 are set to be non-zero, the random number |
64 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
65 | /// random seed. |
66 | /// |
67 | /// Defaults to 0 |
68 | TF_MUST_USE_RESULT Attrs Seed(int64 x) { |
69 | Attrs ret = *this; |
70 | ret.seed_ = x; |
71 | return ret; |
72 | } |
73 | |
74 | /// An second seed to avoid seed collision. |
75 | /// |
76 | /// Defaults to 0 |
77 | TF_MUST_USE_RESULT Attrs Seed2(int64 x) { |
78 | Attrs ret = *this; |
79 | ret.seed2_ = x; |
80 | return ret; |
81 | } |
82 | |
83 | int64 seed_ = 0; |
84 | int64 seed2_ = 0; |
85 | }; |
86 | AllCandidateSampler(const ::tensorflow::Scope& scope, ::tensorflow::Input |
87 | true_classes, int64 num_true, int64 num_sampled, bool |
88 | unique); |
89 | AllCandidateSampler(const ::tensorflow::Scope& scope, ::tensorflow::Input |
90 | true_classes, int64 num_true, int64 num_sampled, bool |
91 | unique, const AllCandidateSampler::Attrs& attrs); |
92 | |
93 | static Attrs Seed(int64 x) { |
94 | return Attrs().Seed(x); |
95 | } |
96 | static Attrs Seed2(int64 x) { |
97 | return Attrs().Seed2(x); |
98 | } |
99 | |
100 | Operation operation; |
101 | ::tensorflow::Output sampled_candidates; |
102 | ::tensorflow::Output true_expected_count; |
103 | ::tensorflow::Output sampled_expected_count; |
104 | }; |
105 | |
106 | /// Computes the ids of the positions in sampled_candidates that match true_labels. |
107 | /// |
108 | /// When doing log-odds NCE, the result of this op should be passed through a |
109 | /// SparseToDense op, then added to the logits of the sampled candidates. This has |
110 | /// the effect of 'removing' the sampled labels that match the true labels by |
111 | /// making the classifier sure that they are sampled labels. |
112 | /// |
113 | /// Args: |
114 | /// * scope: A Scope object |
115 | /// * true_classes: The true_classes output of UnpackSparseLabels. |
116 | /// * sampled_candidates: The sampled_candidates output of CandidateSampler. |
117 | /// * num_true: Number of true labels per context. |
118 | /// |
119 | /// Optional attributes (see `Attrs`): |
120 | /// * seed: If either seed or seed2 are set to be non-zero, the random number |
121 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
122 | /// random seed. |
/// * seed2: A second seed to avoid seed collision.
124 | /// |
125 | /// Returns: |
/// * `Output` indices: A vector of indices corresponding to rows of true_classes.
127 | /// * `Output` ids: A vector of IDs of positions in sampled_candidates that match a true_label |
128 | /// for the row with the corresponding index in indices. |
129 | /// * `Output` weights: A vector of the same length as indices and ids, in which each element |
130 | /// is -FLOAT_MAX. |
131 | class ComputeAccidentalHits { |
132 | public: |
133 | /// Optional attribute setters for ComputeAccidentalHits |
134 | struct Attrs { |
135 | /// If either seed or seed2 are set to be non-zero, the random number |
136 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
137 | /// random seed. |
138 | /// |
139 | /// Defaults to 0 |
140 | TF_MUST_USE_RESULT Attrs Seed(int64 x) { |
141 | Attrs ret = *this; |
142 | ret.seed_ = x; |
143 | return ret; |
144 | } |
145 | |
146 | /// An second seed to avoid seed collision. |
147 | /// |
148 | /// Defaults to 0 |
149 | TF_MUST_USE_RESULT Attrs Seed2(int64 x) { |
150 | Attrs ret = *this; |
151 | ret.seed2_ = x; |
152 | return ret; |
153 | } |
154 | |
155 | int64 seed_ = 0; |
156 | int64 seed2_ = 0; |
157 | }; |
158 | ComputeAccidentalHits(const ::tensorflow::Scope& scope, ::tensorflow::Input |
159 | true_classes, ::tensorflow::Input sampled_candidates, |
160 | int64 num_true); |
161 | ComputeAccidentalHits(const ::tensorflow::Scope& scope, ::tensorflow::Input |
162 | true_classes, ::tensorflow::Input sampled_candidates, |
163 | int64 num_true, const ComputeAccidentalHits::Attrs& |
164 | attrs); |
165 | |
166 | static Attrs Seed(int64 x) { |
167 | return Attrs().Seed(x); |
168 | } |
169 | static Attrs Seed2(int64 x) { |
170 | return Attrs().Seed2(x); |
171 | } |
172 | |
173 | Operation operation; |
174 | ::tensorflow::Output indices; |
175 | ::tensorflow::Output ids; |
176 | ::tensorflow::Output weights; |
177 | }; |
178 | |
/// Generates labels for candidate sampling with a fixed unigram distribution.
180 | /// |
181 | /// A unigram sampler could use a fixed unigram distribution read from a |
182 | /// file or passed in as an in-memory array instead of building up the distribution |
183 | /// from data on the fly. There is also an option to skew the distribution by |
184 | /// applying a distortion power to the weights. |
185 | /// |
186 | /// The vocabulary file should be in CSV-like format, with the last field |
187 | /// being the weight associated with the word. |
188 | /// |
189 | /// For each batch, this op picks a single set of sampled candidate labels. |
190 | /// |
191 | /// The advantages of sampling candidates per-batch are simplicity and the |
192 | /// possibility of efficient dense matrix multiplication. The disadvantage is that |
193 | /// the sampled candidates must be chosen independently of the context and of the |
194 | /// true labels. |
195 | /// |
196 | /// Args: |
197 | /// * scope: A Scope object |
198 | /// * true_classes: A batch_size * num_true matrix, in which each row contains the |
199 | /// IDs of the num_true target_classes in the corresponding original label. |
200 | /// * num_true: Number of true labels per context. |
201 | /// * num_sampled: Number of candidates to randomly sample. |
202 | /// * unique: If unique is true, we sample with rejection, so that all sampled |
203 | /// candidates in a batch are unique. This requires some approximation to |
204 | /// estimate the post-rejection sampling probabilities. |
205 | /// * range_max: The sampler will sample integers from the interval [0, range_max). |
206 | /// |
207 | /// Optional attributes (see `Attrs`): |
208 | /// * vocab_file: Each valid line in this file (which should have a CSV-like format) |
209 | /// corresponds to a valid word ID. IDs are in sequential order, starting from |
210 | /// num_reserved_ids. The last entry in each line is expected to be a value |
211 | /// corresponding to the count or relative probability. Exactly one of vocab_file |
212 | /// and unigrams needs to be passed to this op. |
213 | /// * distortion: The distortion is used to skew the unigram probability distribution. |
214 | /// Each weight is first raised to the distortion's power before adding to the |
215 | /// internal unigram distribution. As a result, distortion = 1.0 gives regular |
216 | /// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives |
217 | /// a uniform distribution. |
218 | /// * num_reserved_ids: Optionally some reserved IDs can be added in the range [0, |
219 | /// ..., num_reserved_ids) by the users. One use case is that a special unknown |
220 | /// word token is used as ID 0. These IDs will have a sampling probability of 0. |
221 | /// * num_shards: A sampler can be used to sample from a subset of the original range |
222 | /// in order to speed up the whole computation through parallelism. This parameter |
223 | /// (together with 'shard') indicates the number of partitions that are being |
224 | /// used in the overall computation. |
225 | /// * shard: A sampler can be used to sample from a subset of the original range |
226 | /// in order to speed up the whole computation through parallelism. This parameter |
227 | /// (together with 'num_shards') indicates the particular partition number of a |
228 | /// sampler op, when partitioning is being used. |
229 | /// * unigrams: A list of unigram counts or probabilities, one per ID in sequential |
230 | /// order. Exactly one of vocab_file and unigrams should be passed to this op. |
231 | /// * seed: If either seed or seed2 are set to be non-zero, the random number |
232 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
233 | /// random seed. |
/// * seed2: A second seed to avoid seed collision.
235 | /// |
236 | /// Returns: |
237 | /// * `Output` sampled_candidates: A vector of length num_sampled, in which each element is |
238 | /// the ID of a sampled candidate. |
239 | /// * `Output` true_expected_count: A batch_size * num_true matrix, representing |
240 | /// the number of times each candidate is expected to occur in a batch |
241 | /// of sampled candidates. If unique=true, then this is a probability. |
242 | /// * `Output` sampled_expected_count: A vector of length num_sampled, for each sampled |
243 | /// candidate representing the number of times the candidate is expected |
244 | /// to occur in a batch of sampled candidates. If unique=true, then this is a |
245 | /// probability. |
246 | class FixedUnigramCandidateSampler { |
247 | public: |
248 | /// Optional attribute setters for FixedUnigramCandidateSampler |
249 | struct Attrs { |
250 | /// Each valid line in this file (which should have a CSV-like format) |
251 | /// corresponds to a valid word ID. IDs are in sequential order, starting from |
252 | /// num_reserved_ids. The last entry in each line is expected to be a value |
253 | /// corresponding to the count or relative probability. Exactly one of vocab_file |
254 | /// and unigrams needs to be passed to this op. |
255 | /// |
256 | /// Defaults to "" |
257 | TF_MUST_USE_RESULT Attrs VocabFile(StringPiece x) { |
258 | Attrs ret = *this; |
259 | ret.vocab_file_ = x; |
260 | return ret; |
261 | } |
262 | |
263 | /// The distortion is used to skew the unigram probability distribution. |
264 | /// Each weight is first raised to the distortion's power before adding to the |
265 | /// internal unigram distribution. As a result, distortion = 1.0 gives regular |
266 | /// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives |
267 | /// a uniform distribution. |
268 | /// |
269 | /// Defaults to 1 |
270 | TF_MUST_USE_RESULT Attrs Distortion(float x) { |
271 | Attrs ret = *this; |
272 | ret.distortion_ = x; |
273 | return ret; |
274 | } |
275 | |
276 | /// Optionally some reserved IDs can be added in the range [0, |
277 | /// ..., num_reserved_ids) by the users. One use case is that a special unknown |
278 | /// word token is used as ID 0. These IDs will have a sampling probability of 0. |
279 | /// |
280 | /// Defaults to 0 |
281 | TF_MUST_USE_RESULT Attrs NumReservedIds(int64 x) { |
282 | Attrs ret = *this; |
283 | ret.num_reserved_ids_ = x; |
284 | return ret; |
285 | } |
286 | |
287 | /// A sampler can be used to sample from a subset of the original range |
288 | /// in order to speed up the whole computation through parallelism. This parameter |
289 | /// (together with 'shard') indicates the number of partitions that are being |
290 | /// used in the overall computation. |
291 | /// |
292 | /// Defaults to 1 |
293 | TF_MUST_USE_RESULT Attrs NumShards(int64 x) { |
294 | Attrs ret = *this; |
295 | ret.num_shards_ = x; |
296 | return ret; |
297 | } |
298 | |
299 | /// A sampler can be used to sample from a subset of the original range |
300 | /// in order to speed up the whole computation through parallelism. This parameter |
301 | /// (together with 'num_shards') indicates the particular partition number of a |
302 | /// sampler op, when partitioning is being used. |
303 | /// |
304 | /// Defaults to 0 |
305 | TF_MUST_USE_RESULT Attrs Shard(int64 x) { |
306 | Attrs ret = *this; |
307 | ret.shard_ = x; |
308 | return ret; |
309 | } |
310 | |
311 | /// A list of unigram counts or probabilities, one per ID in sequential |
312 | /// order. Exactly one of vocab_file and unigrams should be passed to this op. |
313 | /// |
314 | /// Defaults to [] |
315 | TF_MUST_USE_RESULT Attrs Unigrams(const gtl::ArraySlice<float>& x) { |
316 | Attrs ret = *this; |
317 | ret.unigrams_ = x; |
318 | return ret; |
319 | } |
320 | |
321 | /// If either seed or seed2 are set to be non-zero, the random number |
322 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
323 | /// random seed. |
324 | /// |
325 | /// Defaults to 0 |
326 | TF_MUST_USE_RESULT Attrs Seed(int64 x) { |
327 | Attrs ret = *this; |
328 | ret.seed_ = x; |
329 | return ret; |
330 | } |
331 | |
332 | /// An second seed to avoid seed collision. |
333 | /// |
334 | /// Defaults to 0 |
335 | TF_MUST_USE_RESULT Attrs Seed2(int64 x) { |
336 | Attrs ret = *this; |
337 | ret.seed2_ = x; |
338 | return ret; |
339 | } |
340 | |
341 | StringPiece vocab_file_ = "" ; |
342 | float distortion_ = 1.0f; |
343 | int64 num_reserved_ids_ = 0; |
344 | int64 num_shards_ = 1; |
345 | int64 shard_ = 0; |
346 | gtl::ArraySlice<float> unigrams_ = {}; |
347 | int64 seed_ = 0; |
348 | int64 seed2_ = 0; |
349 | }; |
350 | FixedUnigramCandidateSampler(const ::tensorflow::Scope& scope, |
351 | ::tensorflow::Input true_classes, int64 num_true, |
352 | int64 num_sampled, bool unique, int64 range_max); |
353 | FixedUnigramCandidateSampler(const ::tensorflow::Scope& scope, |
354 | ::tensorflow::Input true_classes, int64 num_true, |
355 | int64 num_sampled, bool unique, int64 range_max, |
356 | const FixedUnigramCandidateSampler::Attrs& attrs); |
357 | |
358 | static Attrs VocabFile(StringPiece x) { |
359 | return Attrs().VocabFile(x); |
360 | } |
361 | static Attrs Distortion(float x) { |
362 | return Attrs().Distortion(x); |
363 | } |
364 | static Attrs NumReservedIds(int64 x) { |
365 | return Attrs().NumReservedIds(x); |
366 | } |
367 | static Attrs NumShards(int64 x) { |
368 | return Attrs().NumShards(x); |
369 | } |
370 | static Attrs Shard(int64 x) { |
371 | return Attrs().Shard(x); |
372 | } |
373 | static Attrs Unigrams(const gtl::ArraySlice<float>& x) { |
374 | return Attrs().Unigrams(x); |
375 | } |
376 | static Attrs Seed(int64 x) { |
377 | return Attrs().Seed(x); |
378 | } |
379 | static Attrs Seed2(int64 x) { |
380 | return Attrs().Seed2(x); |
381 | } |
382 | |
383 | Operation operation; |
384 | ::tensorflow::Output sampled_candidates; |
385 | ::tensorflow::Output true_expected_count; |
386 | ::tensorflow::Output sampled_expected_count; |
387 | }; |
388 | |
389 | /// Generates labels for candidate sampling with a learned unigram distribution. |
390 | /// |
391 | /// See explanations of candidate sampling and the data formats at |
392 | /// go/candidate-sampling. |
393 | /// |
394 | /// For each batch, this op picks a single set of sampled candidate labels. |
395 | /// |
396 | /// The advantages of sampling candidates per-batch are simplicity and the |
397 | /// possibility of efficient dense matrix multiplication. The disadvantage is that |
398 | /// the sampled candidates must be chosen independently of the context and of the |
399 | /// true labels. |
400 | /// |
401 | /// Args: |
402 | /// * scope: A Scope object |
403 | /// * true_classes: A batch_size * num_true matrix, in which each row contains the |
404 | /// IDs of the num_true target_classes in the corresponding original label. |
405 | /// * num_true: Number of true labels per context. |
406 | /// * num_sampled: Number of candidates to randomly sample. |
407 | /// * unique: If unique is true, we sample with rejection, so that all sampled |
408 | /// candidates in a batch are unique. This requires some approximation to |
409 | /// estimate the post-rejection sampling probabilities. |
410 | /// * range_max: The sampler will sample integers from the interval [0, range_max). |
411 | /// |
412 | /// Optional attributes (see `Attrs`): |
413 | /// * seed: If either seed or seed2 are set to be non-zero, the random number |
414 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
415 | /// random seed. |
/// * seed2: A second seed to avoid seed collision.
417 | /// |
418 | /// Returns: |
419 | /// * `Output` sampled_candidates: A vector of length num_sampled, in which each element is |
420 | /// the ID of a sampled candidate. |
421 | /// * `Output` true_expected_count: A batch_size * num_true matrix, representing |
422 | /// the number of times each candidate is expected to occur in a batch |
423 | /// of sampled candidates. If unique=true, then this is a probability. |
424 | /// * `Output` sampled_expected_count: A vector of length num_sampled, for each sampled |
425 | /// candidate representing the number of times the candidate is expected |
426 | /// to occur in a batch of sampled candidates. If unique=true, then this is a |
427 | /// probability. |
428 | class LearnedUnigramCandidateSampler { |
429 | public: |
430 | /// Optional attribute setters for LearnedUnigramCandidateSampler |
431 | struct Attrs { |
432 | /// If either seed or seed2 are set to be non-zero, the random number |
433 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
434 | /// random seed. |
435 | /// |
436 | /// Defaults to 0 |
437 | TF_MUST_USE_RESULT Attrs Seed(int64 x) { |
438 | Attrs ret = *this; |
439 | ret.seed_ = x; |
440 | return ret; |
441 | } |
442 | |
443 | /// An second seed to avoid seed collision. |
444 | /// |
445 | /// Defaults to 0 |
446 | TF_MUST_USE_RESULT Attrs Seed2(int64 x) { |
447 | Attrs ret = *this; |
448 | ret.seed2_ = x; |
449 | return ret; |
450 | } |
451 | |
452 | int64 seed_ = 0; |
453 | int64 seed2_ = 0; |
454 | }; |
455 | LearnedUnigramCandidateSampler(const ::tensorflow::Scope& scope, |
456 | ::tensorflow::Input true_classes, int64 |
457 | num_true, int64 num_sampled, bool unique, int64 |
458 | range_max); |
459 | LearnedUnigramCandidateSampler(const ::tensorflow::Scope& scope, |
460 | ::tensorflow::Input true_classes, int64 |
461 | num_true, int64 num_sampled, bool unique, int64 |
462 | range_max, const |
463 | LearnedUnigramCandidateSampler::Attrs& attrs); |
464 | |
465 | static Attrs Seed(int64 x) { |
466 | return Attrs().Seed(x); |
467 | } |
468 | static Attrs Seed2(int64 x) { |
469 | return Attrs().Seed2(x); |
470 | } |
471 | |
472 | Operation operation; |
473 | ::tensorflow::Output sampled_candidates; |
474 | ::tensorflow::Output true_expected_count; |
475 | ::tensorflow::Output sampled_expected_count; |
476 | }; |
477 | |
478 | /// Generates labels for candidate sampling with a log-uniform distribution. |
479 | /// |
480 | /// See explanations of candidate sampling and the data formats at |
481 | /// go/candidate-sampling. |
482 | /// |
483 | /// For each batch, this op picks a single set of sampled candidate labels. |
484 | /// |
485 | /// The advantages of sampling candidates per-batch are simplicity and the |
486 | /// possibility of efficient dense matrix multiplication. The disadvantage is that |
487 | /// the sampled candidates must be chosen independently of the context and of the |
488 | /// true labels. |
489 | /// |
490 | /// Args: |
491 | /// * scope: A Scope object |
492 | /// * true_classes: A batch_size * num_true matrix, in which each row contains the |
493 | /// IDs of the num_true target_classes in the corresponding original label. |
494 | /// * num_true: Number of true labels per context. |
495 | /// * num_sampled: Number of candidates to randomly sample. |
496 | /// * unique: If unique is true, we sample with rejection, so that all sampled |
497 | /// candidates in a batch are unique. This requires some approximation to |
498 | /// estimate the post-rejection sampling probabilities. |
499 | /// * range_max: The sampler will sample integers from the interval [0, range_max). |
500 | /// |
501 | /// Optional attributes (see `Attrs`): |
502 | /// * seed: If either seed or seed2 are set to be non-zero, the random number |
503 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
504 | /// random seed. |
/// * seed2: A second seed to avoid seed collision.
506 | /// |
507 | /// Returns: |
508 | /// * `Output` sampled_candidates: A vector of length num_sampled, in which each element is |
509 | /// the ID of a sampled candidate. |
510 | /// * `Output` true_expected_count: A batch_size * num_true matrix, representing |
511 | /// the number of times each candidate is expected to occur in a batch |
512 | /// of sampled candidates. If unique=true, then this is a probability. |
513 | /// * `Output` sampled_expected_count: A vector of length num_sampled, for each sampled |
514 | /// candidate representing the number of times the candidate is expected |
515 | /// to occur in a batch of sampled candidates. If unique=true, then this is a |
516 | /// probability. |
517 | class LogUniformCandidateSampler { |
518 | public: |
519 | /// Optional attribute setters for LogUniformCandidateSampler |
520 | struct Attrs { |
521 | /// If either seed or seed2 are set to be non-zero, the random number |
522 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
523 | /// random seed. |
524 | /// |
525 | /// Defaults to 0 |
526 | TF_MUST_USE_RESULT Attrs Seed(int64 x) { |
527 | Attrs ret = *this; |
528 | ret.seed_ = x; |
529 | return ret; |
530 | } |
531 | |
532 | /// An second seed to avoid seed collision. |
533 | /// |
534 | /// Defaults to 0 |
535 | TF_MUST_USE_RESULT Attrs Seed2(int64 x) { |
536 | Attrs ret = *this; |
537 | ret.seed2_ = x; |
538 | return ret; |
539 | } |
540 | |
541 | int64 seed_ = 0; |
542 | int64 seed2_ = 0; |
543 | }; |
544 | LogUniformCandidateSampler(const ::tensorflow::Scope& scope, |
545 | ::tensorflow::Input true_classes, int64 num_true, |
546 | int64 num_sampled, bool unique, int64 range_max); |
547 | LogUniformCandidateSampler(const ::tensorflow::Scope& scope, |
548 | ::tensorflow::Input true_classes, int64 num_true, |
549 | int64 num_sampled, bool unique, int64 range_max, |
550 | const LogUniformCandidateSampler::Attrs& attrs); |
551 | |
552 | static Attrs Seed(int64 x) { |
553 | return Attrs().Seed(x); |
554 | } |
555 | static Attrs Seed2(int64 x) { |
556 | return Attrs().Seed2(x); |
557 | } |
558 | |
559 | Operation operation; |
560 | ::tensorflow::Output sampled_candidates; |
561 | ::tensorflow::Output true_expected_count; |
562 | ::tensorflow::Output sampled_expected_count; |
563 | }; |
564 | |
565 | /// Generates labels for candidate sampling with a uniform distribution. |
566 | /// |
567 | /// See explanations of candidate sampling and the data formats at |
568 | /// go/candidate-sampling. |
569 | /// |
570 | /// For each batch, this op picks a single set of sampled candidate labels. |
571 | /// |
572 | /// The advantages of sampling candidates per-batch are simplicity and the |
573 | /// possibility of efficient dense matrix multiplication. The disadvantage is that |
574 | /// the sampled candidates must be chosen independently of the context and of the |
575 | /// true labels. |
576 | /// |
577 | /// Args: |
578 | /// * scope: A Scope object |
579 | /// * true_classes: A batch_size * num_true matrix, in which each row contains the |
580 | /// IDs of the num_true target_classes in the corresponding original label. |
581 | /// * num_true: Number of true labels per context. |
582 | /// * num_sampled: Number of candidates to randomly sample. |
583 | /// * unique: If unique is true, we sample with rejection, so that all sampled |
584 | /// candidates in a batch are unique. This requires some approximation to |
585 | /// estimate the post-rejection sampling probabilities. |
586 | /// * range_max: The sampler will sample integers from the interval [0, range_max). |
587 | /// |
588 | /// Optional attributes (see `Attrs`): |
589 | /// * seed: If either seed or seed2 are set to be non-zero, the random number |
590 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
591 | /// random seed. |
/// * seed2: A second seed to avoid seed collision.
593 | /// |
594 | /// Returns: |
595 | /// * `Output` sampled_candidates: A vector of length num_sampled, in which each element is |
596 | /// the ID of a sampled candidate. |
597 | /// * `Output` true_expected_count: A batch_size * num_true matrix, representing |
598 | /// the number of times each candidate is expected to occur in a batch |
599 | /// of sampled candidates. If unique=true, then this is a probability. |
600 | /// * `Output` sampled_expected_count: A vector of length num_sampled, for each sampled |
601 | /// candidate representing the number of times the candidate is expected |
602 | /// to occur in a batch of sampled candidates. If unique=true, then this is a |
603 | /// probability. |
604 | class UniformCandidateSampler { |
605 | public: |
606 | /// Optional attribute setters for UniformCandidateSampler |
607 | struct Attrs { |
608 | /// If either seed or seed2 are set to be non-zero, the random number |
609 | /// generator is seeded by the given seed. Otherwise, it is seeded by a |
610 | /// random seed. |
611 | /// |
612 | /// Defaults to 0 |
613 | TF_MUST_USE_RESULT Attrs Seed(int64 x) { |
614 | Attrs ret = *this; |
615 | ret.seed_ = x; |
616 | return ret; |
617 | } |
618 | |
619 | /// An second seed to avoid seed collision. |
620 | /// |
621 | /// Defaults to 0 |
622 | TF_MUST_USE_RESULT Attrs Seed2(int64 x) { |
623 | Attrs ret = *this; |
624 | ret.seed2_ = x; |
625 | return ret; |
626 | } |
627 | |
628 | int64 seed_ = 0; |
629 | int64 seed2_ = 0; |
630 | }; |
631 | UniformCandidateSampler(const ::tensorflow::Scope& scope, ::tensorflow::Input |
632 | true_classes, int64 num_true, int64 num_sampled, bool |
633 | unique, int64 range_max); |
634 | UniformCandidateSampler(const ::tensorflow::Scope& scope, ::tensorflow::Input |
635 | true_classes, int64 num_true, int64 num_sampled, bool |
636 | unique, int64 range_max, const |
637 | UniformCandidateSampler::Attrs& attrs); |
638 | |
639 | static Attrs Seed(int64 x) { |
640 | return Attrs().Seed(x); |
641 | } |
642 | static Attrs Seed2(int64 x) { |
643 | return Attrs().Seed2(x); |
644 | } |
645 | |
646 | Operation operation; |
647 | ::tensorflow::Output sampled_candidates; |
648 | ::tensorflow::Output true_expected_count; |
649 | ::tensorflow::Output sampled_expected_count; |
650 | }; |
651 | |
652 | /// @} |
653 | |
654 | } // namespace ops |
655 | } // namespace tensorflow |
656 | |
657 | #endif // TENSORFLOW_CC_OPS_CANDIDATE_SAMPLING_OPS_H_ |
658 | |