1 | // Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); you may not |
4 | // use this file except in compliance with the License. You may obtain a copy |
5 | // of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
11 | // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
12 | // License for the specific language governing permissions and limitations under |
13 | // the License. |
14 | // ============================================================================== |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include <algorithm> |
19 | #include <memory> |
20 | #include <numeric> |
21 | #include <tuple> |
22 | #include <unordered_set> |
23 | #include <vector> |
24 | |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/tensor.h" |
27 | #include "tensorflow/core/framework/tensor_shape.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/lib/core/errors.h" |
30 | #include "tensorflow/core/lib/core/threadpool.h" |
31 | #include "tensorflow/core/lib/gtl/top_n.h" |
32 | #include "tensorflow/core/lib/random/philox_random.h" |
33 | #include "tensorflow/core/lib/random/simple_philox.h" |
34 | #include "tensorflow/core/platform/blocking_counter.h" |
35 | #include "tensorflow/core/platform/byte_order.h" |
36 | #include "tensorflow/core/platform/cpu_info.h" |
37 | #include "tensorflow/core/platform/logging.h" |
38 | |
39 | namespace tensorflow { |
40 | namespace { |
41 | using errors::InvalidArgument; |
42 | |
43 | template <typename Scalar> |
44 | using RowMajorMatrix = |
45 | Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; |
46 | |
47 | using MatrixXfRowMajor = RowMajorMatrix<float>; |
48 | using MatrixXi64RowMajor = RowMajorMatrix<int64_t>; |
49 | |
// Per-CPU share of L3 cache assumed by the NearestNeighbors sharding logic.
// Ideally this would be the L3 cache size divided by the number of physical
// CPUs, but there is no portable way to query that, so a conservative 1 MiB
// estimate is used instead.
constexpr int64_t kDefaultL3CachePerCpu = int64_t{1} << 20;

// Block sizes for NearestNeighborsOp, chosen by a parameter sweep on the
// NearestNeighborsOp benchmark. Centers are processed in blocks of at most
// kNearestNeighborsCentersMaxBlockSize rows, and a shard of points never has
// fewer than kNearestNeighborsPointsMinBlockSize rows.
constexpr int64_t kNearestNeighborsCentersMaxBlockSize = 1024;
constexpr int64_t kNearestNeighborsPointsMinBlockSize = 16;
59 | |
// Returns the smallest multiple of `a` that is not smaller than `b`.
// Requires a > 0. Any `b` is accepted, including negative values: C++ `%`
// truncates toward zero, so a negative remainder is first normalized into
// [0, a) before rounding up.
int64_t NextMultiple(int64_t a, int64_t b) {
  int64_t remainder = b % a;
  if (remainder < 0) remainder += a;
  return remainder == 0 ? b : b + (a - remainder);
}

// Returns a / b rounded up to the next higher integer. Requires b > 0.
// Computed as quotient plus carry rather than (a + b - 1) / b, so it cannot
// overflow for large `a` and rounds correctly for negative `a`.
int64_t CeilOfRatio(int64_t a, int64_t b) {
  return a / b + (a % b > 0 ? 1 : 0);
}
68 | |
69 | } // namespace |
70 | |
71 | // Implementation of K-means++ initialization. Samples points iteratively in |
72 | // proportion to the squared distances from selected points. |
73 | // TODO(ands): Add support for other distance metrics. |
74 | class KmeansPlusPlusInitializationOp : public OpKernel { |
75 | public: |
76 | explicit KmeansPlusPlusInitializationOp(OpKernelConstruction* context) |
77 | : OpKernel(context) { |
78 | OP_REQUIRES_OK(context, |
79 | context->MatchSignature( |
80 | {DT_FLOAT, DT_INT64, DT_INT64, DT_INT64}, {DT_FLOAT})); |
81 | } |
82 | |
83 | void Compute(OpKernelContext* context) override { |
84 | const Tensor& points_tensor = context->input(0); |
85 | const Tensor& num_to_sample_tensor = context->input(1); |
86 | const Tensor& seed_tensor = context->input(2); |
87 | const Tensor& num_retries_per_sample_tensor = context->input(3); |
88 | |
89 | OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()), |
90 | InvalidArgument("Input points should be a matrix." )); |
91 | OP_REQUIRES(context, |
92 | TensorShapeUtils::IsScalar(num_to_sample_tensor.shape()), |
93 | InvalidArgument("Input num_to_sample should be a scalar." )); |
94 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()), |
95 | InvalidArgument("Input seed should be a scalar." )); |
96 | OP_REQUIRES( |
97 | context, |
98 | TensorShapeUtils::IsScalar(num_retries_per_sample_tensor.shape()), |
99 | InvalidArgument("Input num_retries_per_sample should be a scalar." )); |
100 | |
101 | const int64_t num_points = points_tensor.dim_size(0); |
102 | const int64_t point_dimensions = points_tensor.dim_size(1); |
103 | const int64_t num_to_sample = num_to_sample_tensor.scalar<int64_t>()(); |
104 | const int64_t seed = seed_tensor.scalar<int64_t>()(); |
105 | const int64_t num_retries_per_sample = [&]() { |
106 | const int64_t value = num_retries_per_sample_tensor.scalar<int64_t>()(); |
107 | return value >= 0 ? value |
108 | : 2 + static_cast<int64_t>(std::log(num_to_sample)); |
109 | }(); |
110 | |
111 | OP_REQUIRES(context, num_points > 0, |
112 | InvalidArgument("Expected points.rows() > 0." )); |
113 | OP_REQUIRES(context, num_to_sample > 0, |
114 | InvalidArgument("Expected num_to_sample > 0." )); |
115 | OP_REQUIRES(context, num_to_sample <= num_points, |
116 | InvalidArgument("Expected num_to_sample <= points.rows(). " , |
117 | num_to_sample, " vs " , num_points, "." )); |
118 | |
119 | Tensor* output_sampled_points_tensor; |
120 | OP_REQUIRES_OK(context, |
121 | context->allocate_output( |
122 | 0, TensorShape({num_to_sample, point_dimensions}), |
123 | &output_sampled_points_tensor)); |
124 | |
125 | const Eigen::Map<const MatrixXfRowMajor> points( |
126 | points_tensor.matrix<float>().data(), num_points, point_dimensions); |
127 | const Eigen::VectorXf points_half_squared_norm = |
128 | 0.5 * points.rowwise().squaredNorm(); |
129 | |
130 | Eigen::Map<MatrixXfRowMajor> sampled_points( |
131 | output_sampled_points_tensor->matrix<float>().data(), num_to_sample, |
132 | point_dimensions); |
133 | std::unordered_set<int64_t> sampled_indices; |
134 | |
135 | random::PhiloxRandom random(seed); |
136 | random::SimplePhilox rng(&random); |
137 | |
138 | auto add_one_point = [&](int64_t from, int64_t to) { |
139 | from = std::min(from, num_points - 1); |
140 | sampled_points.row(to) = points.row(from); |
141 | sampled_indices.insert(from); |
142 | }; |
143 | |
144 | // Distances from all points to nearest selected point. Initialize with |
145 | // distances to first selected point. |
146 | Eigen::VectorXf min_distances(num_points); |
147 | min_distances.fill(std::numeric_limits<float>::infinity()); |
148 | Eigen::VectorXf min_distances_cumsum(num_points); |
149 | |
150 | auto draw_one_sample = [&]() -> int64 { |
151 | if (sampled_indices.empty()) return rng.Uniform64(num_points); |
152 | int64_t index = 0; |
153 | do { |
154 | // If v is drawn from Uniform[0, distances.sum()), then |
155 | // Prob[cumsum(distances)(i - 1) <= v < cumsum(distances)(i)] is |
156 | // proportional to distances(i). |
157 | index = std::upper_bound( |
158 | min_distances_cumsum.data(), |
159 | min_distances_cumsum.data() + num_points, |
160 | rng.RandFloat() * min_distances_cumsum(num_points - 1)) - |
161 | min_distances_cumsum.data(); |
162 | } while (sampled_indices.find(index) != sampled_indices.end()); |
163 | return index; |
164 | }; |
165 | |
166 | auto sample_one_point = [&]() { |
167 | const int64_t sampled_index = draw_one_sample(); |
168 | min_distances = min_distances.cwiseMin(GetHalfSquaredDistancesToY( |
169 | points, points_half_squared_norm, points.row(sampled_index), |
170 | points_half_squared_norm(sampled_index))); |
171 | return sampled_index; |
172 | }; |
173 | |
174 | auto sample_one_point_with_retries = [&]() { |
175 | Eigen::VectorXf best_new_min_distances(num_points); |
176 | float best_potential = std::numeric_limits<float>::infinity(); |
177 | int64_t best_sampled_index = 0; |
178 | for (int i = 1 + num_retries_per_sample; i > 0; --i) { |
179 | const int64_t sampled_index = draw_one_sample(); |
180 | Eigen::VectorXf new_min_distances = |
181 | min_distances.cwiseMin(GetHalfSquaredDistancesToY( |
182 | points, points_half_squared_norm, points.row(sampled_index), |
183 | points_half_squared_norm(sampled_index))); |
184 | const float potential = new_min_distances.sum(); |
185 | if (potential < best_potential) { |
186 | best_potential = potential; |
187 | best_sampled_index = sampled_index; |
188 | best_new_min_distances.swap(new_min_distances); |
189 | } |
190 | } |
191 | min_distances.swap(best_new_min_distances); |
192 | return best_sampled_index; |
193 | }; |
194 | |
195 | for (int64_t i = 0; i < num_to_sample; ++i) { |
196 | if (i > 0) { |
197 | std::partial_sum(min_distances.data(), |
198 | min_distances.data() + num_points, |
199 | min_distances_cumsum.data()); |
200 | } |
201 | int64_t next = num_retries_per_sample == 0 |
202 | ? sample_one_point() |
203 | : sample_one_point_with_retries(); |
204 | add_one_point(next, i); |
205 | } |
206 | } |
207 | |
208 | private: |
209 | // Returns a column vector with the i-th element set to half the squared |
210 | // euclidean distance between the i-th row of xs, and y. Precomputed norms for |
211 | // each row of xs and y must be provided for efficiency. |
212 | // TODO(ands): Parallelize this for large xs. |
213 | static Eigen::VectorXf GetHalfSquaredDistancesToY( |
214 | const Eigen::Ref<const MatrixXfRowMajor>& xs, |
215 | const Eigen::Ref<const Eigen::VectorXf>& xs_half_squared_norm, |
216 | const Eigen::Ref<const Eigen::RowVectorXf>& y, |
217 | float y_half_squared_norm) { |
218 | // Squared distance between points xs_i and y is: |
219 | // || xs_i ||^2 - 2 <xs_i, y> + || y ||^2 |
220 | return (xs_half_squared_norm - xs * y.transpose()).array() + |
221 | y_half_squared_norm; |
222 | } |
223 | }; |
224 | |
225 | REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization" ).Device(DEVICE_CPU), |
226 | KmeansPlusPlusInitializationOp); |
227 | |
228 | // Implementation of one single Markov Chain for the k-MC^2 algorithm |
229 | class KMC2ChainInitializationOp : public OpKernel { |
230 | public: |
231 | explicit KMC2ChainInitializationOp(OpKernelConstruction* context) |
232 | : OpKernel(context) { |
233 | OP_REQUIRES_OK(context, |
234 | context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64})); |
235 | } |
236 | |
237 | void Compute(OpKernelContext* context) override { |
238 | const Tensor& distances_tensor = context->input(0); |
239 | const Tensor& seed_tensor = context->input(1); |
240 | OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()), |
241 | InvalidArgument("Input distances should be a vector." )); |
242 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()), |
243 | InvalidArgument("Input seed should be a scalar." )); |
244 | const int64_t num_points = distances_tensor.dim_size(0); |
245 | const int64_t seed = seed_tensor.scalar<int64_t>()(); |
246 | OP_REQUIRES(context, num_points > 0, |
247 | InvalidArgument("Expected distances_tensor.size() > 0." )); |
248 | |
249 | random::PhiloxRandom random(seed); |
250 | random::SimplePhilox rng(&random); |
251 | |
252 | auto distances = distances_tensor.flat<float>(); |
253 | // Set the initial state of the Markov chain to be the first candidate. |
254 | int64_t selected_index = 0; |
255 | float selected_distance = distances(selected_index); |
256 | // Build a Markov chain of length num_points. |
257 | for (int64_t i = 1; i < num_points; ++i) { |
258 | const float candidate_distance = distances(i); |
259 | // Set the next state of the Markov chain to be the candidate with |
260 | // probability min(1, candidate_distance/selected_distance). |
261 | if (candidate_distance > rng.RandFloat() * selected_distance) { |
262 | selected_index = i; |
263 | selected_distance = candidate_distance; |
264 | } |
265 | } |
266 | |
267 | Tensor* output_sampled_index_tensor; |
268 | OP_REQUIRES_OK(context, |
269 | context->allocate_output(0, TensorShape({}), |
270 | &output_sampled_index_tensor)); |
271 | auto output = output_sampled_index_tensor->scalar<int64_t>(); |
272 | // Return the last state of the Markov chain as the new center. |
273 | output() = selected_index; |
274 | } |
275 | }; |
276 | |
277 | REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization" ).Device(DEVICE_CPU), |
278 | KMC2ChainInitializationOp); |
279 | |
280 | // Operator for computing the nearest neighbors for a set of points. |
281 | class NearestNeighborsOp : public OpKernel { |
282 | public: |
283 | explicit NearestNeighborsOp(OpKernelConstruction* context) |
284 | : OpKernel(context) { |
285 | OP_REQUIRES_OK(context, |
286 | context->MatchSignature({DT_FLOAT, DT_FLOAT, DT_INT64}, |
287 | {DT_INT64, DT_FLOAT})); |
288 | } |
289 | |
290 | void Compute(OpKernelContext* context) override { |
291 | const Tensor& points_tensor = context->input(0); |
292 | const Tensor& centers_tensor = context->input(1); |
293 | const Tensor& k_tensor = context->input(2); |
294 | |
295 | OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()), |
296 | InvalidArgument("Input points should be a matrix." )); |
297 | OP_REQUIRES(context, TensorShapeUtils::IsMatrix(centers_tensor.shape()), |
298 | InvalidArgument("Input centers should be a matrix." )); |
299 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()), |
300 | InvalidArgument("Input k should be a scalar." )); |
301 | |
302 | const int64_t num_points = points_tensor.dim_size(0); |
303 | const int64_t point_dimensions = points_tensor.dim_size(1); |
304 | const int64_t num_centers = centers_tensor.dim_size(0); |
305 | const int64_t center_dimensions = centers_tensor.dim_size(1); |
306 | |
307 | OP_REQUIRES(context, num_points > 0, |
308 | InvalidArgument("Expected points.rows() > 0." )); |
309 | OP_REQUIRES( |
310 | context, point_dimensions == center_dimensions, |
311 | InvalidArgument("Expected point_dimensions == center_dimensions: " , |
312 | point_dimensions, " vs " , center_dimensions, "." )); |
313 | |
314 | const Eigen::Map<const MatrixXfRowMajor> points( |
315 | points_tensor.matrix<float>().data(), num_points, point_dimensions); |
316 | const Eigen::Map<const MatrixXfRowMajor> centers( |
317 | centers_tensor.matrix<float>().data(), num_centers, center_dimensions); |
318 | const int64_t k = |
319 | std::min<int64_t>(num_centers, k_tensor.scalar<int64_t>()()); |
320 | |
321 | Tensor* output_nearest_center_indices_tensor; |
322 | Tensor* output_nearest_center_distances_tensor; |
323 | OP_REQUIRES_OK(context, context->allocate_output( |
324 | 0, TensorShape({num_points, k}), |
325 | &output_nearest_center_indices_tensor)); |
326 | OP_REQUIRES_OK(context, context->allocate_output( |
327 | 1, TensorShape({num_points, k}), |
328 | &output_nearest_center_distances_tensor)); |
329 | |
330 | if (k == 0) return; |
331 | |
332 | Eigen::Map<MatrixXi64RowMajor> nearest_center_indices( |
333 | output_nearest_center_indices_tensor->matrix<int64_t>().data(), |
334 | num_points, k); |
335 | Eigen::Map<MatrixXfRowMajor> nearest_center_distances( |
336 | output_nearest_center_distances_tensor->matrix<float>().data(), |
337 | num_points, k); |
338 | |
339 | const Eigen::VectorXf centers_half_squared_norm = |
340 | 0.5 * centers.rowwise().squaredNorm(); |
341 | |
342 | // The distance computation is sharded to take advantage of multiple cores |
343 | // and to allow intermediate values to reside in L3 cache. This is done by |
344 | // sharding the points and centers as follows: |
345 | // |
346 | // 1. Centers are sharded such that each block of centers has at most |
347 | // kNearestNeighborsCentersMaxBlockSize rows. |
348 | // 2. Points are sharded, and each block of points is multiplied with each |
349 | // block of centers. The block size of points is chosen such that the |
350 | // point coordinates (point_dimensions) and the matrix of distances to |
351 | // each center in one block -- the intermediate data -- fits in L3 cache. |
352 | // 3. After performing each block-block distance computation, the results |
353 | // are reduced to a set of k nearest centers as soon as possible. This |
354 | // decreases total memory I/O. |
355 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
356 | const int64_t num_threads = worker_threads.num_threads; |
357 | // This kernel might be configured to use fewer than the total number of |
358 | // available CPUs on the host machine. To avoid destructive interference |
359 | // with other jobs running on the host machine, we must only use a fraction |
360 | // of total available L3 cache. Unfortunately, we cannot query the host |
361 | // machine to get the number of physical CPUs. So, we use a fixed per-CPU |
362 | // budget and scale it by the number of CPUs available to this operation. |
363 | const int64_t total_memory_budget = |
364 | kDefaultL3CachePerCpu * port::NumSchedulableCPUs(); |
365 | // Compute the number of blocks into which rows of points must be split so |
366 | // that the distance matrix and the block of points can fit in cache. One |
367 | // row of points will yield a vector of distances to each center in a block. |
368 | const int64_t bytes_per_row = |
369 | (std::min(kNearestNeighborsCentersMaxBlockSize, |
370 | num_centers) /* centers in a block */ |
371 | + point_dimensions /* coordinates of one point */) * |
372 | sizeof(float); |
373 | // The memory needed for storing the centers being processed. This is shared |
374 | // by all workers. Adding slack to the number of threads to avoid incorrect |
375 | // cache eviction when a new block of centers is loaded. |
376 | const int64_t bytes_for_centers = |
377 | std::min(num_centers, |
378 | (num_threads + 2) * kNearestNeighborsCentersMaxBlockSize) * |
379 | point_dimensions * sizeof(float); |
380 | // The memory budget available for workers to store their distance matrices. |
381 | const int64_t available_memory_budget = |
382 | total_memory_budget - bytes_for_centers; |
383 | // That memory budget is shared by all threads. |
384 | const int64_t rows_per_block = std::max<int64_t>( |
385 | kNearestNeighborsPointsMinBlockSize, |
386 | available_memory_budget / num_threads / bytes_per_row); |
387 | // Divide rows into almost uniformly-sized units of work that are small |
388 | // enough for the memory budget (rows_per_block). Round up to a multiple of |
389 | // the number of threads. |
390 | const int64_t num_units = |
391 | NextMultiple(num_threads, CeilOfRatio(num_points, rows_per_block)); |
392 | auto work = [&](int64_t start, int64_t limit) { |
393 | for (; start < limit; ++start) { |
394 | const int64_t start_row = num_points * start / num_units; |
395 | const int64_t limit_row = num_points * (start + 1) / num_units; |
396 | DCHECK_LE(limit_row, num_points); |
397 | const int64_t num_rows = limit_row - start_row; |
398 | auto points_shard = points.middleRows(start_row, num_rows); |
399 | const Eigen::VectorXf points_half_squared_norm = |
400 | 0.5 * points_shard.rowwise().squaredNorm(); |
401 | auto nearest_center_indices_shard = |
402 | nearest_center_indices.middleRows(start_row, num_rows); |
403 | auto nearest_center_distances_shard = |
404 | nearest_center_distances.middleRows(start_row, num_rows); |
405 | FindKNearestCenters(k, points_shard, points_half_squared_norm, centers, |
406 | centers_half_squared_norm, |
407 | nearest_center_indices_shard, |
408 | nearest_center_distances_shard); |
409 | } |
410 | }; |
411 | |
412 | const int64_t units_per_thread = num_units / num_threads; |
413 | BlockingCounter counter(num_threads - 1); |
414 | for (int64_t i = 1; i < num_threads; ++i) { |
415 | const int64_t start = i * units_per_thread; |
416 | const int64_t limit = start + units_per_thread; |
417 | worker_threads.workers->Schedule([work, &counter, start, limit]() { |
418 | work(start, limit); |
419 | counter.DecrementCount(); |
420 | }); |
421 | } |
422 | work(0, units_per_thread); |
423 | counter.Wait(); |
424 | } |
425 | |
426 | private: |
427 | static void FindKNearestCenters( |
428 | int64_t k, const Eigen::Ref<const MatrixXfRowMajor>& points, |
429 | const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm, |
430 | const Eigen::Ref<const MatrixXfRowMajor>& centers, |
431 | const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm, |
432 | const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices, |
433 | const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) { |
434 | DCHECK_LE(k, centers.rows()); |
435 | if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) { |
436 | FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers, |
437 | centers_half_squared_norm, |
438 | nearest_center_indices, |
439 | nearest_center_distances); |
440 | } else { |
441 | FindKNearestCentersBlockwise(k, points, points_half_squared_norm, centers, |
442 | centers_half_squared_norm, |
443 | nearest_center_indices, |
444 | nearest_center_distances); |
445 | } |
446 | } |
447 | |
448 | static void FindKNearestCentersOneBlock( |
449 | int64_t k, const Eigen::Ref<const MatrixXfRowMajor>& points, |
450 | const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm, |
451 | const Eigen::Ref<const MatrixXfRowMajor>& centers, |
452 | const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm, |
453 | Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices, |
454 | Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) { |
455 | DCHECK_LE(k, centers.rows()); |
456 | const int64_t num_points = points.rows(); |
457 | const MatrixXfRowMajor inner_product = points * centers.transpose(); |
458 | // Find nearest neighbors. |
459 | if (k == 1) { |
460 | for (int i = 0; i < num_points; ++i) { |
461 | int64_t index; |
462 | nearest_center_distances(i, 0) = |
463 | 2.0 * |
464 | (points_half_squared_norm(i) + |
465 | (centers_half_squared_norm.transpose() - inner_product.row(i)) |
466 | .minCoeff(&index)); |
467 | nearest_center_indices(i, 0) = index; |
468 | } |
469 | } else { |
470 | // Select k nearest centers for each point. |
471 | using Center = std::pair<float, int64_t>; |
472 | const int64_t num_centers = centers.rows(); |
473 | gtl::TopN<Center, std::less<Center>> selector(k); |
474 | std::unique_ptr<std::vector<Center>> nearest_centers; |
475 | for (int i = 0; i < num_points; ++i) { |
476 | selector.reserve(num_centers); |
477 | for (int j = 0; j < num_centers; ++j) { |
478 | const float partial_distance = |
479 | centers_half_squared_norm(j) - inner_product(i, j); |
480 | selector.push(Center(partial_distance, j)); |
481 | } |
482 | nearest_centers.reset(selector.Extract()); |
483 | selector.Reset(); |
484 | const float point_half_squared_norm = points_half_squared_norm(i); |
485 | for (int j = 0; j < k; ++j) { |
486 | const Center& center = (*nearest_centers)[j]; |
487 | nearest_center_distances(i, j) = |
488 | 2.0 * (point_half_squared_norm + center.first); |
489 | nearest_center_indices(i, j) = center.second; |
490 | } |
491 | } |
492 | } |
493 | } |
494 | |
495 | static void FindKNearestCentersBlockwise( |
496 | int64_t k, const Eigen::Ref<const MatrixXfRowMajor>& points, |
497 | const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm, |
498 | const Eigen::Ref<const MatrixXfRowMajor>& centers, |
499 | const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm, |
500 | Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices, |
501 | Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) { |
502 | const int64_t num_points = points.rows(); |
503 | const int64_t num_centers = centers.rows(); |
504 | DCHECK_LE(k, num_centers); |
505 | DCHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize); |
506 | // Store nearest neighbors with first block of centers directly into the |
507 | // output matrices. |
508 | int64_t out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize); |
509 | FindKNearestCentersOneBlock( |
510 | out_k, points, points_half_squared_norm, |
511 | centers.topRows(kNearestNeighborsCentersMaxBlockSize), |
512 | centers_half_squared_norm.head(kNearestNeighborsCentersMaxBlockSize), |
513 | nearest_center_indices, nearest_center_distances); |
514 | // Iteratively compute nearest neighbors with other blocks of centers, and |
515 | // update the output matrices. |
516 | MatrixXi64RowMajor block_nearest_center_indices(num_points, k); |
517 | MatrixXfRowMajor block_nearest_center_distances(num_points, k); |
518 | Eigen::Matrix<int64_t, 1, Eigen::Dynamic> merged_indices(k); |
519 | Eigen::Matrix<float, 1, Eigen::Dynamic> merged_distances(k); |
520 | for (int64_t centers_start = kNearestNeighborsCentersMaxBlockSize; |
521 | centers_start < num_centers; |
522 | centers_start += kNearestNeighborsCentersMaxBlockSize) { |
523 | const int64_t centers_block_size = std::min( |
524 | kNearestNeighborsCentersMaxBlockSize, num_centers - centers_start); |
525 | const int64_t block_k = std::min(k, centers_block_size); |
526 | FindKNearestCentersOneBlock( |
527 | block_k, points, points_half_squared_norm, |
528 | centers.middleRows(centers_start, centers_block_size), |
529 | centers_half_squared_norm.segment(centers_start, centers_block_size), |
530 | block_nearest_center_indices, block_nearest_center_distances); |
531 | if (k == 1) { |
532 | for (int i = 0; i < num_points; ++i) { |
533 | if (block_nearest_center_distances(i, 0) < |
534 | nearest_center_distances(i, 0)) { |
535 | nearest_center_indices(i, 0) = |
536 | block_nearest_center_indices(i, 0) + centers_start; |
537 | nearest_center_distances(i, 0) = |
538 | block_nearest_center_distances(i, 0); |
539 | } |
540 | } |
541 | } else { |
542 | for (int i = 0; i < num_points; ++i) { |
543 | // Merge and accumulate top-k list from block_nearest_center_indices |
544 | // into nearest_center_indices. |
545 | for (int64_t j_out = 0, j_block = 0, j_merged = 0; |
546 | (j_out < out_k || j_block < block_k) && j_merged < k; |
547 | ++j_merged) { |
548 | const float distance_out = |
549 | j_out < out_k ? nearest_center_distances(i, j_out) |
550 | : std::numeric_limits<float>::infinity(); |
551 | const float distance_block = |
552 | j_block < block_k ? block_nearest_center_distances(i, j_block) |
553 | : std::numeric_limits<float>::infinity(); |
554 | if (distance_out <= distance_block) { |
555 | merged_indices(j_merged) = nearest_center_indices(i, j_out); |
556 | merged_distances(j_merged) = distance_out; |
557 | ++j_out; |
558 | } else { |
559 | merged_indices(j_merged) = |
560 | block_nearest_center_indices(i, j_block) + centers_start; |
561 | merged_distances(j_merged) = distance_block; |
562 | ++j_block; |
563 | } |
564 | } |
565 | nearest_center_indices.row(i) = merged_indices; |
566 | nearest_center_distances.row(i) = merged_distances; |
567 | out_k = std::min(k, out_k + block_k); |
568 | } |
569 | } |
570 | } |
571 | } |
572 | }; |
573 | |
574 | REGISTER_KERNEL_BUILDER(Name("NearestNeighbors" ).Device(DEVICE_CPU), |
575 | NearestNeighborsOp); |
576 | |
577 | } // namespace tensorflow |
578 | |