/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include "tensorflow/core/framework/op.h"
17#include "tensorflow/core/framework/shape_inference.h"
18
19namespace tensorflow {
20
21using shape_inference::DimensionHandle;
22using shape_inference::InferenceContext;
23using shape_inference::ShapeHandle;
24
25namespace {
26
27Status CandidateSamplerShapeFn(InferenceContext* c) {
28 int64_t num_sampled;
29 TF_RETURN_IF_ERROR(c->GetAttr("num_sampled", &num_sampled));
30 int64_t num_true;
31 TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
32
33 ShapeHandle true_classes_shape;
34 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes_shape));
35 DimensionHandle batch_size = c->Dim(true_classes_shape, 0);
36
37 ShapeHandle num_sampled_v = c->Vector(num_sampled);
38 c->set_output(0, num_sampled_v);
39 c->set_output(1, c->Matrix(batch_size, num_true));
40 c->set_output(2, num_sampled_v);
41 return OkStatus();
42}
43
44} // namespace
45
46REGISTER_OP("UniformCandidateSampler")
47 .Input("true_classes: int64")
48 .Output("sampled_candidates: int64")
49 .Output("true_expected_count: float")
50 .Output("sampled_expected_count: float")
51 .Attr("num_true: int >= 1")
52 .Attr("num_sampled: int >= 1")
53 .Attr("unique: bool")
54 .Attr("range_max: int >= 1")
55 .Attr("seed: int = 0")
56 .Attr("seed2: int = 0")
57 .SetShapeFn(CandidateSamplerShapeFn)
58 .SetIsStateful();
59
60REGISTER_OP("LogUniformCandidateSampler")
61 .Input("true_classes: int64")
62 .Output("sampled_candidates: int64")
63 .Output("true_expected_count: float")
64 .Output("sampled_expected_count: float")
65 .Attr("num_true: int >= 1")
66 .Attr("num_sampled: int >= 1")
67 .Attr("unique: bool")
68 .Attr("range_max: int >= 1")
69 .Attr("seed: int = 0")
70 .Attr("seed2: int = 0")
71 .SetShapeFn(CandidateSamplerShapeFn)
72 .SetIsStateful();
73
74REGISTER_OP("LearnedUnigramCandidateSampler")
75 .Input("true_classes: int64")
76 .Output("sampled_candidates: int64")
77 .Output("true_expected_count: float")
78 .Output("sampled_expected_count: float")
79 .Attr("num_true: int >= 1")
80 .Attr("num_sampled: int >= 1")
81 .Attr("unique: bool")
82 .Attr("range_max: int >= 1")
83 .Attr("seed: int = 0")
84 .Attr("seed2: int = 0")
85 .SetShapeFn(CandidateSamplerShapeFn)
86 .SetIsStateful();
87
88REGISTER_OP("ThreadUnsafeUnigramCandidateSampler")
89 .Input("true_classes: int64")
90 .Output("sampled_candidates: int64")
91 .Output("true_expected_count: float")
92 .Output("sampled_expected_count: float")
93 .Attr("num_true: int >= 1")
94 .Attr("num_sampled: int >= 1")
95 .Attr("unique: bool")
96 .Attr("range_max: int >= 1")
97 .Attr("seed: int = 0")
98 .Attr("seed2: int = 0")
99 .SetShapeFn(CandidateSamplerShapeFn)
100 .SetIsStateful();
101
102REGISTER_OP("FixedUnigramCandidateSampler")
103 .Input("true_classes: int64")
104 .Output("sampled_candidates: int64")
105 .Output("true_expected_count: float")
106 .Output("sampled_expected_count: float")
107 .Attr("num_true: int >= 1")
108 .Attr("num_sampled: int >= 1")
109 .Attr("unique: bool")
110 .Attr("range_max: int >= 1")
111 .Attr("vocab_file: string = ''")
112 .Attr("distortion: float = 1.0")
113 .Attr("num_reserved_ids: int = 0")
114 .Attr("num_shards: int >= 1 = 1")
115 .Attr("shard: int >= 0 = 0")
116 .Attr("unigrams: list(float) = []")
117 .Attr("seed: int = 0")
118 .Attr("seed2: int = 0")
119 .SetShapeFn(CandidateSamplerShapeFn)
120 .SetIsStateful();
121
122REGISTER_OP("AllCandidateSampler")
123 .Input("true_classes: int64")
124 .Output("sampled_candidates: int64")
125 .Output("true_expected_count: float")
126 .Output("sampled_expected_count: float")
127 .Attr("num_true: int >= 1")
128 .Attr("num_sampled: int >= 1")
129 .Attr("unique: bool")
130 .Attr("seed: int = 0")
131 .Attr("seed2: int = 0")
132 .SetShapeFn(CandidateSamplerShapeFn)
133 .SetIsStateful();
134
135REGISTER_OP("ComputeAccidentalHits")
136 .Input("true_classes: int64")
137 .Input("sampled_candidates: int64")
138 .Output("indices: int32")
139 .Output("ids: int64")
140 .Output("weights: float")
141 .Attr("num_true: int")
142 .Attr("seed: int = 0")
143 .Attr("seed2: int = 0")
144 .SetShapeFn([](InferenceContext* c) {
145 int64_t num_true;
146 TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
147
148 // Validate true_classes, must be a matrix.
149 ShapeHandle true_classes;
150 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes));
151 DimensionHandle unused;
152 TF_RETURN_IF_ERROR(
153 c->WithValue(c->Dim(true_classes, 1), num_true, &unused));
154 // Validate sampled_candidates, must be a vector.
155 ShapeHandle sampled_candidates;
156 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sampled_candidates));
157
158 // All three outputs are the same shape.
159 ShapeHandle v = c->Vector(InferenceContext::kUnknownDim);
160 c->set_output(0, v);
161 c->set_output(1, v);
162 c->set_output(2, v);
163 return OkStatus();
164 });
165
166} // namespace tensorflow
167