1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/candidate_sampling_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include <cfloat> |
21 | #include <unordered_map> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/core/framework/op_kernel.h" |
25 | #include "tensorflow/core/framework/tensor_shape.h" |
26 | #include "tensorflow/core/kernels/range_sampler.h" |
27 | #include "tensorflow/core/platform/logging.h" |
28 | #include "tensorflow/core/util/guarded_philox_random.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | class BaseCandidateSamplerOp : public OpKernel { |
33 | public: |
34 | explicit BaseCandidateSamplerOp(OpKernelConstruction* context) |
35 | : OpKernel(context) { |
36 | OP_REQUIRES_OK(context, context->GetAttr("num_sampled" , &num_sampled_)); |
37 | OP_REQUIRES_OK(context, context->GetAttr("num_true" , &num_true_)); |
38 | OP_REQUIRES_OK(context, context->GetAttr("unique" , &unique_)); |
39 | OP_REQUIRES_OK(context, generator_.Init(context)); |
40 | } |
41 | |
42 | void Compute(OpKernelContext* context) override { |
43 | const Tensor& true_classes = context->input(0); |
44 | OP_REQUIRES(context, true_classes.dims() == 2, |
45 | errors::InvalidArgument("true_classes must be a matrix" )); |
46 | const int32_t batch_size = true_classes.dim_size(0); |
47 | OP_REQUIRES( |
48 | context, true_classes.dim_size(1) == num_true_, |
49 | errors::InvalidArgument("true_classes must have " |
50 | "num_true columns, expected: " , |
51 | true_classes.dim_size(1), " was: " , num_true_)); |
52 | CHECK(sampler_) << "CandidateSamplerOp did not set sampler_" ; |
53 | |
54 | if (unique_) { |
55 | OP_REQUIRES(context, num_sampled_ <= sampler_->range(), |
56 | errors::InvalidArgument("Sampler's range is too small." )); |
57 | } |
58 | |
59 | // Output candidates and expected_count. |
60 | Tensor* out_sampled_candidates = nullptr; |
61 | OP_REQUIRES_OK(context, |
62 | context->allocate_output(0, TensorShape({num_sampled_}), |
63 | &out_sampled_candidates)); |
64 | |
65 | Tensor* out_true_expected_count = nullptr; |
66 | OP_REQUIRES_OK(context, context->allocate_output( |
67 | 1, TensorShape({batch_size, num_true_}), |
68 | &out_true_expected_count)); |
69 | Tensor* out_sampled_expected_count = nullptr; |
70 | OP_REQUIRES_OK(context, |
71 | context->allocate_output(2, TensorShape({num_sampled_}), |
72 | &out_sampled_expected_count)); |
73 | |
74 | gtl::ArraySlice<int64_t> true_candidate( |
75 | true_classes.matrix<int64_t>().data(), batch_size * num_true_); |
76 | gtl::MutableArraySlice<int64_t> sampled_candidate( |
77 | out_sampled_candidates->vec<int64_t>().data(), num_sampled_); |
78 | gtl::MutableArraySlice<float> true_expected_count( |
79 | out_true_expected_count->matrix<float>().data(), |
80 | batch_size * num_true_); |
81 | gtl::MutableArraySlice<float> sampled_expected_count( |
82 | out_sampled_expected_count->vec<float>().data(), num_sampled_); |
83 | |
84 | // Approximately conservatively estimate the number of samples required. |
85 | // In cases where rejection sampling is used we may occasionally use more |
86 | // samples than expected, which will result in reused random bits. |
87 | const int64_t samples32 = 2048 * num_sampled_; |
88 | |
89 | // Pick sampled candidates. |
90 | auto local_gen = generator_.ReserveSamples32(samples32); |
91 | random::SimplePhilox random(&local_gen); |
92 | sampler_->SampleBatchGetExpectedCount(&random, unique_, sampled_candidate, |
93 | sampled_expected_count, |
94 | true_candidate, true_expected_count); |
95 | |
96 | if (sampler_->NeedsUpdates()) { |
97 | sampler_->Update(true_candidate); |
98 | } |
99 | } |
100 | |
101 | protected: |
102 | void set_sampler(RangeSampler* sampler) { sampler_.reset(sampler); } |
103 | |
104 | private: |
105 | int32 num_true_; |
106 | int32 num_sampled_; |
107 | bool unique_; |
108 | std::unique_ptr<RangeSampler> sampler_; |
109 | GuardedPhiloxRandom generator_; |
110 | }; |
111 | |
112 | template <class RangeSamplerType> |
113 | class SimpleCandidateSamplerOp : public BaseCandidateSamplerOp { |
114 | public: |
115 | explicit SimpleCandidateSamplerOp(OpKernelConstruction* context) |
116 | : BaseCandidateSamplerOp(context) { |
117 | int64_t range_max; |
118 | OP_REQUIRES_OK(context, context->GetAttr("range_max" , &range_max)); |
119 | set_sampler(new RangeSamplerType(range_max)); |
120 | } |
121 | }; |
122 | |
123 | REGISTER_KERNEL_BUILDER(Name("UniformCandidateSampler" ).Device(DEVICE_CPU), |
124 | SimpleCandidateSamplerOp<UniformSampler>); |
125 | |
126 | REGISTER_KERNEL_BUILDER(Name("LogUniformCandidateSampler" ).Device(DEVICE_CPU), |
127 | SimpleCandidateSamplerOp<LogUniformSampler>); |
128 | |
129 | REGISTER_KERNEL_BUILDER( |
130 | Name("LearnedUnigramCandidateSampler" ).Device(DEVICE_CPU), |
131 | SimpleCandidateSamplerOp<UnigramSampler>); |
132 | |
133 | REGISTER_KERNEL_BUILDER( |
134 | Name("ThreadUnsafeUnigramCandidateSampler" ).Device(DEVICE_CPU), |
135 | SimpleCandidateSamplerOp<ThreadUnsafeUnigramSampler>); |
136 | |
137 | class AllCandidateSamplerOp : public BaseCandidateSamplerOp { |
138 | public: |
139 | explicit AllCandidateSamplerOp(OpKernelConstruction* context) |
140 | : BaseCandidateSamplerOp(context) { |
141 | int64_t range_max; |
142 | OP_REQUIRES_OK(context, context->GetAttr("num_sampled" , &range_max)); |
143 | set_sampler(new AllSampler(range_max)); |
144 | } |
145 | }; |
146 | |
147 | REGISTER_KERNEL_BUILDER(Name("AllCandidateSampler" ).Device(DEVICE_CPU), |
148 | AllCandidateSamplerOp); |
149 | |
150 | class FixedUnigramCandidateSamplerOp : public BaseCandidateSamplerOp { |
151 | public: |
152 | explicit FixedUnigramCandidateSamplerOp(OpKernelConstruction* context) |
153 | : BaseCandidateSamplerOp(context) { |
154 | int64_t range_max; |
155 | OP_REQUIRES_OK(context, context->GetAttr("range_max" , &range_max)); |
156 | string vocab_file; |
157 | OP_REQUIRES_OK(context, context->GetAttr("vocab_file" , &vocab_file)); |
158 | std::vector<float> unigrams; |
159 | OP_REQUIRES_OK(context, context->GetAttr("unigrams" , &unigrams)); |
160 | OP_REQUIRES( |
161 | context, !vocab_file.empty() || !unigrams.empty(), |
162 | errors::InvalidArgument("Must provide either vocab_file or unigrams." )); |
163 | OP_REQUIRES(context, vocab_file.empty() || unigrams.empty(), |
164 | errors::InvalidArgument( |
165 | "Must only provide one of vocab_file and unigrams." )); |
166 | float distortion; |
167 | OP_REQUIRES_OK(context, context->GetAttr("distortion" , &distortion)); |
168 | int64_t num_reserved_ids; |
169 | OP_REQUIRES_OK(context, |
170 | context->GetAttr("num_reserved_ids" , &num_reserved_ids)); |
171 | int64_t num_shards; |
172 | OP_REQUIRES_OK(context, context->GetAttr("num_shards" , &num_shards)); |
173 | int64_t shard; |
174 | OP_REQUIRES_OK(context, context->GetAttr("shard" , &shard)); |
175 | |
176 | if (!vocab_file.empty()) { |
177 | set_sampler(new FixedUnigramSampler(context->env(), range_max, vocab_file, |
178 | distortion, num_reserved_ids, |
179 | num_shards, shard)); |
180 | } else { |
181 | set_sampler(new FixedUnigramSampler(range_max, unigrams, distortion, |
182 | num_reserved_ids, num_shards, shard)); |
183 | } |
184 | } |
185 | }; |
186 | |
187 | REGISTER_KERNEL_BUILDER(Name("FixedUnigramCandidateSampler" ).Device(DEVICE_CPU), |
188 | FixedUnigramCandidateSamplerOp); |
189 | |
190 | class ComputeAccidentalHitsOp : public OpKernel { |
191 | public: |
192 | explicit ComputeAccidentalHitsOp(OpKernelConstruction* context) |
193 | : OpKernel(context) { |
194 | OP_REQUIRES_OK(context, context->GetAttr("num_true" , &num_true_)); |
195 | } |
196 | |
197 | void Compute(OpKernelContext* context) override { |
198 | const Tensor& in_true_candidates = context->input(0); |
199 | const TensorShape& in_true_candidates_shape = in_true_candidates.shape(); |
200 | OP_REQUIRES(context, |
201 | TensorShapeUtils::IsMatrix(in_true_candidates_shape) && |
202 | in_true_candidates_shape.dim_size(1) == num_true_, |
203 | errors::InvalidArgument( |
204 | "true_candidates must be a batch_size * num_true matrix" )); |
205 | |
206 | const int64_t batch_size = in_true_candidates_shape.dim_size(0); |
207 | |
208 | const Tensor& in_sampled_candidates = context->input(1); |
209 | OP_REQUIRES(context, |
210 | TensorShapeUtils::IsVector(in_sampled_candidates.shape()), |
211 | errors::InvalidArgument( |
212 | "sampled_candidates must be a vector, which is typically " |
213 | "an output from CandidateSampler" )); |
214 | |
215 | std::unordered_map<int64_t, int> sampled_candidate_to_pos; |
216 | for (int64_t i = 0; i < in_sampled_candidates.dim_size(0); ++i) { |
217 | sampled_candidate_to_pos[in_sampled_candidates.vec<int64_t>()(i)] = i; |
218 | } |
219 | |
220 | // Produce output in the same format as UnpackSparseFeatures. |
221 | std::vector<int> indices; |
222 | std::vector<int64_t> ids; |
223 | std::vector<float> weights; |
224 | |
225 | for (int64_t i = 0; i < batch_size; ++i) { |
226 | for (int64_t j = 0; j < num_true_; ++j) { |
227 | const int64_t true_candidate = |
228 | in_true_candidates.matrix<int64_t>()(i, j); |
229 | const auto look = sampled_candidate_to_pos.find(true_candidate); |
230 | if (look != sampled_candidate_to_pos.end()) { |
231 | indices.push_back(i); |
232 | ids.push_back(look->second); |
233 | weights.push_back(-FLT_MAX); |
234 | } |
235 | } |
236 | } |
237 | |
238 | Tensor* out_indices = nullptr; |
239 | OP_REQUIRES_OK( |
240 | context, |
241 | context->allocate_output( |
242 | 0, TensorShape({static_cast<int>(indices.size())}), &out_indices)); |
243 | Tensor* out_ids = nullptr; |
244 | OP_REQUIRES_OK( |
245 | context, context->allocate_output( |
246 | 1, TensorShape({static_cast<int>(ids.size())}), &out_ids)); |
247 | Tensor* out_weights = nullptr; |
248 | OP_REQUIRES_OK( |
249 | context, |
250 | context->allocate_output( |
251 | 2, TensorShape({static_cast<int>(weights.size())}), &out_weights)); |
252 | |
253 | for (size_t i = 0; i < indices.size(); ++i) { |
254 | out_indices->vec<int32>()(i) = indices[i]; |
255 | out_ids->vec<int64_t>()(i) = ids[i]; |
256 | out_weights->vec<float>()(i) = weights[i]; |
257 | } |
258 | } |
259 | |
260 | private: |
261 | int64_t num_true_; |
262 | }; |
263 | |
264 | REGISTER_KERNEL_BUILDER(Name("ComputeAccidentalHits" ).Device(DEVICE_CPU), |
265 | ComputeAccidentalHitsOp); |
266 | |
267 | } // namespace tensorflow |
268 | |