1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/errors.h"
18 | |
19 | namespace tensorflow { |
20 | |
21 | using shape_inference::DimensionHandle; |
22 | using shape_inference::InferenceContext; |
23 | using shape_inference::ShapeHandle; |
24 | |
25 | namespace { |
26 | |
27 | Status CandidateSamplerShapeFn(InferenceContext* c) { |
28 | int64_t num_sampled; |
29 | TF_RETURN_IF_ERROR(c->GetAttr("num_sampled" , &num_sampled)); |
30 | int64_t num_true; |
31 | TF_RETURN_IF_ERROR(c->GetAttr("num_true" , &num_true)); |
32 | |
33 | ShapeHandle true_classes_shape; |
34 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes_shape)); |
35 | DimensionHandle batch_size = c->Dim(true_classes_shape, 0); |
36 | |
37 | ShapeHandle num_sampled_v = c->Vector(num_sampled); |
38 | c->set_output(0, num_sampled_v); |
39 | c->set_output(1, c->Matrix(batch_size, num_true)); |
40 | c->set_output(2, num_sampled_v); |
41 | return OkStatus(); |
42 | } |
43 | |
44 | } // namespace |
45 | |
// Registers UniformCandidateSampler.
//
// Input true_classes (int64) must be rank 2 (enforced by
// CandidateSamplerShapeFn). Outputs: sampled_candidates (int64,
// [num_sampled]), true_expected_count (float, [rows of true_classes,
// num_true]) and sampled_expected_count (float, [num_sampled]).
// NOTE(review): the name suggests candidates are drawn uniformly from
// [0, range_max); the sampling itself lives in the kernel, not in this file —
// confirm there.
REGISTER_OP("UniformCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    // Both seeds default to 0; how they are combined is up to the kernel.
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // Stateful, presumably because each execution draws fresh random samples.
    .SetIsStateful();
59 | |
// Registers LogUniformCandidateSampler: same input, outputs, attrs and shape
// function (CandidateSamplerShapeFn) as UniformCandidateSampler.
// NOTE(review): the name suggests a log-uniform distribution over
// [0, range_max); the distribution is implemented in the kernel — confirm
// there.
REGISTER_OP("LogUniformCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // Stateful, presumably because each execution draws fresh random samples.
    .SetIsStateful();
73 | |
// Registers LearnedUnigramCandidateSampler: same input, outputs, attrs and
// shape function (CandidateSamplerShapeFn) as UniformCandidateSampler.
// NOTE(review): the name suggests the unigram distribution is learned from the
// true classes observed so far; that logic lives in the kernel — confirm there.
REGISTER_OP("LearnedUnigramCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // Stateful, presumably because the sampler's state changes across runs.
    .SetIsStateful();
87 | |
// Registers ThreadUnsafeUnigramCandidateSampler: same input, outputs, attrs
// and shape function (CandidateSamplerShapeFn) as
// LearnedUnigramCandidateSampler.
// NOTE(review): the name suggests a variant of the learned unigram sampler
// without internal synchronization; confirm the thread-safety contract in the
// kernel before relying on it.
REGISTER_OP("ThreadUnsafeUnigramCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // Stateful, presumably because the sampler's state changes across runs.
    .SetIsStateful();
101 | |
// Registers FixedUnigramCandidateSampler. Input, outputs and shape function
// (CandidateSamplerShapeFn) match the other *CandidateSampler ops in this
// file. It additionally takes attrs that configure a fixed unigram
// distribution.
// NOTE(review): vocab_file and unigrams both default to empty. Which source
// takes precedence, and how distortion, num_reserved_ids, num_shards and shard
// are applied, is decided by the kernel — confirm there.
REGISTER_OP("FixedUnigramCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("range_max: int >= 1")
    .Attr("vocab_file: string = ''")
    .Attr("distortion: float = 1.0")
    .Attr("num_reserved_ids: int = 0")
    // Sharding attrs: at least one shard; the shard index defaults to 0.
    .Attr("num_shards: int >= 1 = 1")
    .Attr("shard: int >= 0 = 0")
    .Attr("unigrams: list(float) = []")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    // Stateful, presumably because each execution draws fresh random samples.
    .SetIsStateful();
121 | |
// Registers AllCandidateSampler. Same input, outputs and shape function
// (CandidateSamplerShapeFn) as the other *CandidateSampler ops, but it
// declares no range_max attr.
// NOTE(review): the name suggests it emits all candidate ids rather than a
// random subset; confirm against the kernel.
REGISTER_OP("AllCandidateSampler")
    .Input("true_classes: int64")
    .Output("sampled_candidates: int64")
    .Output("true_expected_count: float")
    .Output("sampled_expected_count: float")
    .Attr("num_true: int >= 1")
    .Attr("num_sampled: int >= 1")
    .Attr("unique: bool")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .SetShapeFn(CandidateSamplerShapeFn)
    .SetIsStateful();
134 | |
135 | REGISTER_OP("ComputeAccidentalHits" ) |
136 | .Input("true_classes: int64" ) |
137 | .Input("sampled_candidates: int64" ) |
138 | .Output("indices: int32" ) |
139 | .Output("ids: int64" ) |
140 | .Output("weights: float" ) |
141 | .Attr("num_true: int" ) |
142 | .Attr("seed: int = 0" ) |
143 | .Attr("seed2: int = 0" ) |
144 | .SetShapeFn([](InferenceContext* c) { |
145 | int64_t num_true; |
146 | TF_RETURN_IF_ERROR(c->GetAttr("num_true" , &num_true)); |
147 | |
148 | // Validate true_classes, must be a matrix. |
149 | ShapeHandle true_classes; |
150 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes)); |
151 | DimensionHandle unused; |
152 | TF_RETURN_IF_ERROR( |
153 | c->WithValue(c->Dim(true_classes, 1), num_true, &unused)); |
154 | // Validate sampled_candidates, must be a vector. |
155 | ShapeHandle sampled_candidates; |
156 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sampled_candidates)); |
157 | |
158 | // All three outputs are the same shape. |
159 | ShapeHandle v = c->Vector(InferenceContext::kUnknownDim); |
160 | c->set_output(0, v); |
161 | c->set_output(1, v); |
162 | c->set_output(2, v); |
163 | return OkStatus(); |
164 | }); |
165 | |
166 | } // namespace tensorflow |
167 | |