// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy
// of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
// ==============================================================================

#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <numeric>
#include <tuple>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {
using errors::InvalidArgument;

template <typename Scalar>
using RowMajorMatrix =
    Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

using MatrixXfRowMajor = RowMajorMatrix<float>;
using MatrixXi64RowMajor = RowMajorMatrix<int64_t>;

// Ideally this should be computed by dividing L3 cache size by the number of
// physical CPUs. Since there isn't a portable method to do this, we are using
// a conservative estimate here.
const int64_t kDefaultL3CachePerCpu = 1 << 20;

// These values were determined by performing a parameter sweep on the
// NearestNeighborsOp benchmark.
const int64_t kNearestNeighborsCentersMaxBlockSize = 1024;
const int64_t kNearestNeighborsPointsMinBlockSize = 16;

// Returns the smallest multiple of a that is not smaller than b.
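// For example, NextMultiple(8, 21) returns 24.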
int64_t NextMultiple(int64_t a, int64_t b) {
  const int64_t remainder = b % a;
  return remainder == 0 ? b : (b + a - remainder);
}

// Returns a / b rounded up to the nearest integer.
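// For example, CeilOfRatio(10, 3) returns 4.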
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

}  // namespace

// Implementation of K-means++ initialization. Samples points iteratively in
// proportion to their squared distance to the nearest previously selected
// point.
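// For example (illustrative numbers, ignoring retries), if the current squared
// distances to the nearest selected point are {0, 4, 1}, the next point is
// drawn with probabilities {0, 0.8, 0.2}.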
// TODO(ands): Add support for other distance metrics.
class KmeansPlusPlusInitializationOp : public OpKernel {
 public:
  explicit KmeansPlusPlusInitializationOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature(
                       {DT_FLOAT, DT_INT64, DT_INT64, DT_INT64}, {DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& points_tensor = context->input(0);
    const Tensor& num_to_sample_tensor = context->input(1);
    const Tensor& seed_tensor = context->input(2);
    const Tensor& num_retries_per_sample_tensor = context->input(3);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
                InvalidArgument("Input points should be a matrix."));
    OP_REQUIRES(context,
                TensorShapeUtils::IsScalar(num_to_sample_tensor.shape()),
                InvalidArgument("Input num_to_sample should be a scalar."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
                InvalidArgument("Input seed should be a scalar."));
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsScalar(num_retries_per_sample_tensor.shape()),
        InvalidArgument("Input num_retries_per_sample should be a scalar."));

    const int64_t num_points = points_tensor.dim_size(0);
    const int64_t point_dimensions = points_tensor.dim_size(1);
    const int64_t num_to_sample = num_to_sample_tensor.scalar<int64_t>()();
    const int64_t seed = seed_tensor.scalar<int64_t>()();
    const int64_t num_retries_per_sample = [&]() {
      const int64_t value = num_retries_per_sample_tensor.scalar<int64_t>()();
      return value >= 0 ? value
                        : 2 + static_cast<int64_t>(std::log(num_to_sample));
    }();

    OP_REQUIRES(context, num_points > 0,
                InvalidArgument("Expected points.rows() > 0."));
    OP_REQUIRES(context, num_to_sample > 0,
                InvalidArgument("Expected num_to_sample > 0."));
    OP_REQUIRES(context, num_to_sample <= num_points,
                InvalidArgument("Expected num_to_sample <= points.rows(). ",
                                num_to_sample, " vs ", num_points, "."));

    Tensor* output_sampled_points_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({num_to_sample, point_dimensions}),
                       &output_sampled_points_tensor));

    const Eigen::Map<const MatrixXfRowMajor> points(
        points_tensor.matrix<float>().data(), num_points, point_dimensions);
    const Eigen::VectorXf points_half_squared_norm =
        0.5 * points.rowwise().squaredNorm();

    Eigen::Map<MatrixXfRowMajor> sampled_points(
        output_sampled_points_tensor->matrix<float>().data(), num_to_sample,
        point_dimensions);
    std::unordered_set<int64_t> sampled_indices;

    random::PhiloxRandom random(seed);
    random::SimplePhilox rng(&random);

    auto add_one_point = [&](int64_t from, int64_t to) {
      from = std::min(from, num_points - 1);
      sampled_points.row(to) = points.row(from);
      sampled_indices.insert(from);
    };

    // Distances from all points to the nearest selected point. Initialized to
    // infinity so that the first update fills it with the distances to the
    // first selected point.
    Eigen::VectorXf min_distances(num_points);
    min_distances.fill(std::numeric_limits<float>::infinity());
    Eigen::VectorXf min_distances_cumsum(num_points);

    auto draw_one_sample = [&]() -> int64_t {
      if (sampled_indices.empty()) return rng.Uniform64(num_points);
      int64_t index = 0;
      do {
        // If v is drawn from Uniform[0, distances.sum()), then
        // Prob[cumsum(distances)(i - 1) <= v < cumsum(distances)(i)] is
        // proportional to distances(i).
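        // For example (illustrative numbers), if distances are {1, 3, 2}, the
        // cumulative sums are {1, 4, 6}; a uniform draw v in [1, 4) selects
        // index 1, which happens with probability 3/6.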
        index = std::upper_bound(
                    min_distances_cumsum.data(),
                    min_distances_cumsum.data() + num_points,
                    rng.RandFloat() * min_distances_cumsum(num_points - 1)) -
                min_distances_cumsum.data();
      } while (sampled_indices.find(index) != sampled_indices.end());
      return index;
    };

    auto sample_one_point = [&]() {
      const int64_t sampled_index = draw_one_sample();
      min_distances = min_distances.cwiseMin(GetHalfSquaredDistancesToY(
          points, points_half_squared_norm, points.row(sampled_index),
          points_half_squared_norm(sampled_index)));
      return sampled_index;
    };

    auto sample_one_point_with_retries = [&]() {
      Eigen::VectorXf best_new_min_distances(num_points);
      float best_potential = std::numeric_limits<float>::infinity();
      int64_t best_sampled_index = 0;
      for (int i = 1 + num_retries_per_sample; i > 0; --i) {
        const int64_t sampled_index = draw_one_sample();
        Eigen::VectorXf new_min_distances =
            min_distances.cwiseMin(GetHalfSquaredDistancesToY(
                points, points_half_squared_norm, points.row(sampled_index),
                points_half_squared_norm(sampled_index)));
        const float potential = new_min_distances.sum();
        if (potential < best_potential) {
          best_potential = potential;
          best_sampled_index = sampled_index;
          best_new_min_distances.swap(new_min_distances);
        }
      }
      min_distances.swap(best_new_min_distances);
      return best_sampled_index;
    };

    for (int64_t i = 0; i < num_to_sample; ++i) {
      if (i > 0) {
        std::partial_sum(min_distances.data(),
                         min_distances.data() + num_points,
                         min_distances_cumsum.data());
      }
      int64_t next = num_retries_per_sample == 0
                         ? sample_one_point()
                         : sample_one_point_with_retries();
      add_one_point(next, i);
    }
  }

 private:
  // Returns a column vector with the i-th element set to half the squared
  // Euclidean distance between the i-th row of xs and y. Precomputed half
  // squared norms for each row of xs and for y must be provided for
  // efficiency.
  // TODO(ands): Parallelize this for large xs.
  static Eigen::VectorXf GetHalfSquaredDistancesToY(
      const Eigen::Ref<const MatrixXfRowMajor>& xs,
      const Eigen::Ref<const Eigen::VectorXf>& xs_half_squared_norm,
      const Eigen::Ref<const Eigen::RowVectorXf>& y,
      float y_half_squared_norm) {
    // The squared distance between points xs_i and y is
    //   || xs_i ||^2 - 2 <xs_i, y> + || y ||^2.
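    // Halving this gives the expression computed below:
    //   0.5 * || xs_i ||^2 - <xs_i, y> + 0.5 * || y ||^2,
    // which is why the precomputed half squared norms can be used directly.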
    return (xs_half_squared_norm - xs * y.transpose()).array() +
           y_half_squared_norm;
  }
};

REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
                        KmeansPlusPlusInitializationOp);

// Implementation of a single Markov chain for the k-MC^2 algorithm.
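// The input is a vector of seeding distances for a batch of candidate points;
// the op returns the index of the candidate selected as the next center.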
class KMC2ChainInitializationOp : public OpKernel {
 public:
  explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& distances_tensor = context->input(0);
    const Tensor& seed_tensor = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
                InvalidArgument("Input distances should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
                InvalidArgument("Input seed should be a scalar."));
    const int64_t num_points = distances_tensor.dim_size(0);
    const int64_t seed = seed_tensor.scalar<int64_t>()();
    OP_REQUIRES(context, num_points > 0,
                InvalidArgument("Expected distances_tensor.size() > 0."));

    random::PhiloxRandom random(seed);
    random::SimplePhilox rng(&random);

    auto distances = distances_tensor.flat<float>();
    // Set the initial state of the Markov chain to be the first candidate.
    int64_t selected_index = 0;
    float selected_distance = distances(selected_index);
    // Build a Markov chain of length num_points.
    for (int64_t i = 1; i < num_points; ++i) {
      const float candidate_distance = distances(i);
      // Set the next state of the Markov chain to be the candidate with
      // probability min(1, candidate_distance/selected_distance).
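      // For example (illustrative numbers), a candidate with distance 2.0
      // replaces a selected state with distance 8.0 with probability 0.25,
      // while a candidate with a larger distance is always accepted.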
      if (candidate_distance > rng.RandFloat() * selected_distance) {
        selected_index = i;
        selected_distance = candidate_distance;
      }
    }

    Tensor* output_sampled_index_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}),
                                            &output_sampled_index_tensor));
    auto output = output_sampled_index_tensor->scalar<int64_t>();
    // Return the last state of the Markov chain as the new center.
    output() = selected_index;
  }
};

REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
                        KMC2ChainInitializationOp);

// Operator for computing the nearest neighbors for a set of points.
class NearestNeighborsOp : public OpKernel {
 public:
  explicit NearestNeighborsOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->MatchSignature({DT_FLOAT, DT_FLOAT, DT_INT64},
                                           {DT_INT64, DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& points_tensor = context->input(0);
    const Tensor& centers_tensor = context->input(1);
    const Tensor& k_tensor = context->input(2);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
                InvalidArgument("Input points should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(centers_tensor.shape()),
                InvalidArgument("Input centers should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
                InvalidArgument("Input k should be a scalar."));

    const int64_t num_points = points_tensor.dim_size(0);
    const int64_t point_dimensions = points_tensor.dim_size(1);
    const int64_t num_centers = centers_tensor.dim_size(0);
    const int64_t center_dimensions = centers_tensor.dim_size(1);

    OP_REQUIRES(context, num_points > 0,
                InvalidArgument("Expected points.rows() > 0."));
    OP_REQUIRES(
        context, point_dimensions == center_dimensions,
        InvalidArgument("Expected point_dimensions == center_dimensions: ",
                        point_dimensions, " vs ", center_dimensions, "."));

    const Eigen::Map<const MatrixXfRowMajor> points(
        points_tensor.matrix<float>().data(), num_points, point_dimensions);
    const Eigen::Map<const MatrixXfRowMajor> centers(
        centers_tensor.matrix<float>().data(), num_centers, center_dimensions);
    const int64_t k =
        std::min<int64_t>(num_centers, k_tensor.scalar<int64_t>()());

    Tensor* output_nearest_center_indices_tensor;
    Tensor* output_nearest_center_distances_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, TensorShape({num_points, k}),
                                &output_nearest_center_indices_tensor));
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, TensorShape({num_points, k}),
                                &output_nearest_center_distances_tensor));

    if (k == 0) return;

    Eigen::Map<MatrixXi64RowMajor> nearest_center_indices(
        output_nearest_center_indices_tensor->matrix<int64_t>().data(),
        num_points, k);
    Eigen::Map<MatrixXfRowMajor> nearest_center_distances(
        output_nearest_center_distances_tensor->matrix<float>().data(),
        num_points, k);

    const Eigen::VectorXf centers_half_squared_norm =
        0.5 * centers.rowwise().squaredNorm();

    // The distance computation is sharded to take advantage of multiple cores
    // and to allow intermediate values to reside in L3 cache. This is done by
    // sharding the points and centers as follows:
    //
    // 1. Centers are sharded such that each block of centers has at most
    //    kNearestNeighborsCentersMaxBlockSize rows.
    // 2. Points are sharded, and each block of points is multiplied with each
    //    block of centers. The block size of points is chosen such that the
    //    point coordinates (point_dimensions) and the matrix of distances to
    //    each center in one block -- the intermediate data -- fit in L3 cache.
    // 3. After performing each block-block distance computation, the results
    //    are reduced to a set of k nearest centers as soon as possible. This
    //    decreases total memory I/O.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    const int64_t num_threads = worker_threads.num_threads;
    // This kernel might be configured to use fewer than the total number of
    // available CPUs on the host machine. To avoid destructive interference
    // with other jobs running on the host machine, we must only use a fraction
    // of the total available L3 cache. Unfortunately, we cannot query the host
    // machine to get the number of physical CPUs. So, we use a fixed per-CPU
    // budget and scale it by the number of CPUs available to this operation.
    const int64_t total_memory_budget =
        kDefaultL3CachePerCpu * port::NumSchedulableCPUs();
    // Compute how many rows of points fit in one block, so that a block of
    // points and its matrix of distances to one block of centers fit in cache.
    // One row of points yields a vector of distances to each center in a
    // block.
    const int64_t bytes_per_row =
        (std::min(kNearestNeighborsCentersMaxBlockSize,
                  num_centers) /* centers in a block */
         + point_dimensions /* coordinates of one point */) *
        sizeof(float);
    // The memory needed for storing the centers being processed. This is
    // shared by all workers. Slack is added to the number of threads to avoid
    // incorrect cache eviction when a new block of centers is loaded.
    const int64_t bytes_for_centers =
        std::min(num_centers,
                 (num_threads + 2) * kNearestNeighborsCentersMaxBlockSize) *
        point_dimensions * sizeof(float);
    // The memory budget available for workers to store their distance
    // matrices.
    const int64_t available_memory_budget =
        total_memory_budget - bytes_for_centers;
    // That memory budget is shared by all threads.
    const int64_t rows_per_block = std::max<int64_t>(
        kNearestNeighborsPointsMinBlockSize,
        available_memory_budget / num_threads / bytes_per_row);
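    // Illustrative example (hypothetical sizes): with 4 threads, 4096 centers
    // of dimension 64, and a 4 MiB total budget, bytes_per_row is
    // (1024 + 64) * 4 = 4352 bytes, bytes_for_centers is 4096 * 64 * 4 = 1 MiB,
    // and each thread can hold roughly 3 MiB / 4 / 4352 B, i.e. about 180 rows
    // of distances per block.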
    // Divide rows into almost uniformly-sized units of work that are small
    // enough for the memory budget (rows_per_block). Round up to a multiple of
    // the number of threads.
    const int64_t num_units =
        NextMultiple(num_threads, CeilOfRatio(num_points, rows_per_block));
    auto work = [&](int64_t start, int64_t limit) {
      for (; start < limit; ++start) {
        const int64_t start_row = num_points * start / num_units;
        const int64_t limit_row = num_points * (start + 1) / num_units;
        DCHECK_LE(limit_row, num_points);
        const int64_t num_rows = limit_row - start_row;
        auto points_shard = points.middleRows(start_row, num_rows);
        const Eigen::VectorXf points_half_squared_norm =
            0.5 * points_shard.rowwise().squaredNorm();
        auto nearest_center_indices_shard =
            nearest_center_indices.middleRows(start_row, num_rows);
        auto nearest_center_distances_shard =
            nearest_center_distances.middleRows(start_row, num_rows);
        FindKNearestCenters(k, points_shard, points_half_squared_norm, centers,
                            centers_half_squared_norm,
                            nearest_center_indices_shard,
                            nearest_center_distances_shard);
      }
    };

    const int64_t units_per_thread = num_units / num_threads;
    BlockingCounter counter(num_threads - 1);
    for (int64_t i = 1; i < num_threads; ++i) {
      const int64_t start = i * units_per_thread;
      const int64_t limit = start + units_per_thread;
      worker_threads.workers->Schedule([work, &counter, start, limit]() {
        work(start, limit);
        counter.DecrementCount();
      });
    }
    work(0, units_per_thread);
    counter.Wait();
  }

 private:
  static void FindKNearestCenters(
      int64_t k, const Eigen::Ref<const MatrixXfRowMajor>& points,
      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
      const Eigen::Ref<const MatrixXfRowMajor>& centers,
      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
      const Eigen::Ref<MatrixXi64RowMajor>& nearest_center_indices,
      const Eigen::Ref<MatrixXfRowMajor>& nearest_center_distances) {
    DCHECK_LE(k, centers.rows());
    if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
      FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
                                  centers_half_squared_norm,
                                  nearest_center_indices,
                                  nearest_center_distances);
    } else {
      FindKNearestCentersBlockwise(k, points, points_half_squared_norm,
                                   centers, centers_half_squared_norm,
                                   nearest_center_indices,
                                   nearest_center_distances);
    }
  }

  static void FindKNearestCentersOneBlock(
      int64_t k, const Eigen::Ref<const MatrixXfRowMajor>& points,
      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
      const Eigen::Ref<const MatrixXfRowMajor>& centers,
      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
    DCHECK_LE(k, centers.rows());
    const int64_t num_points = points.rows();
    const MatrixXfRowMajor inner_product = points * centers.transpose();
    // Find nearest neighbors.
    if (k == 1) {
      for (int i = 0; i < num_points; ++i) {
        int64_t index;
        nearest_center_distances(i, 0) =
            2.0 *
            (points_half_squared_norm(i) +
             (centers_half_squared_norm.transpose() - inner_product.row(i))
                 .minCoeff(&index));
        nearest_center_indices(i, 0) = index;
      }
    } else {
      // Select k nearest centers for each point.
      using Center = std::pair<float, int64_t>;
      const int64_t num_centers = centers.rows();
      gtl::TopN<Center, std::less<Center>> selector(k);
      std::unique_ptr<std::vector<Center>> nearest_centers;
      for (int i = 0; i < num_points; ++i) {
        selector.reserve(num_centers);
        for (int j = 0; j < num_centers; ++j) {
          const float partial_distance =
              centers_half_squared_norm(j) - inner_product(i, j);
          selector.push(Center(partial_distance, j));
        }
        nearest_centers.reset(selector.Extract());
        selector.Reset();
        const float point_half_squared_norm = points_half_squared_norm(i);
        for (int j = 0; j < k; ++j) {
          const Center& center = (*nearest_centers)[j];
          nearest_center_distances(i, j) =
              2.0 * (point_half_squared_norm + center.first);
          nearest_center_indices(i, j) = center.second;
        }
      }
    }
  }

  static void FindKNearestCentersBlockwise(
      int64_t k, const Eigen::Ref<const MatrixXfRowMajor>& points,
      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
      const Eigen::Ref<const MatrixXfRowMajor>& centers,
      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
    const int64_t num_points = points.rows();
    const int64_t num_centers = centers.rows();
    DCHECK_LE(k, num_centers);
    DCHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
    // Compute the nearest neighbors within the first block of centers and
    // store them directly in the output matrices.
    int64_t out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
    FindKNearestCentersOneBlock(
        out_k, points, points_half_squared_norm,
        centers.topRows(kNearestNeighborsCentersMaxBlockSize),
        centers_half_squared_norm.head(kNearestNeighborsCentersMaxBlockSize),
        nearest_center_indices, nearest_center_distances);
    // Iteratively compute the nearest neighbors for the remaining blocks of
    // centers, and merge the results into the output matrices.
    MatrixXi64RowMajor block_nearest_center_indices(num_points, k);
    MatrixXfRowMajor block_nearest_center_distances(num_points, k);
    Eigen::Matrix<int64_t, 1, Eigen::Dynamic> merged_indices(k);
    Eigen::Matrix<float, 1, Eigen::Dynamic> merged_distances(k);
    for (int64_t centers_start = kNearestNeighborsCentersMaxBlockSize;
         centers_start < num_centers;
         centers_start += kNearestNeighborsCentersMaxBlockSize) {
      const int64_t centers_block_size = std::min(
          kNearestNeighborsCentersMaxBlockSize, num_centers - centers_start);
      const int64_t block_k = std::min(k, centers_block_size);
      FindKNearestCentersOneBlock(
          block_k, points, points_half_squared_norm,
          centers.middleRows(centers_start, centers_block_size),
          centers_half_squared_norm.segment(centers_start, centers_block_size),
          block_nearest_center_indices, block_nearest_center_distances);
      if (k == 1) {
        for (int i = 0; i < num_points; ++i) {
          if (block_nearest_center_distances(i, 0) <
              nearest_center_distances(i, 0)) {
            nearest_center_indices(i, 0) =
                block_nearest_center_indices(i, 0) + centers_start;
            nearest_center_distances(i, 0) =
                block_nearest_center_distances(i, 0);
          }
        }
      } else {
        for (int i = 0; i < num_points; ++i) {
          // Merge and accumulate top-k list from block_nearest_center_indices
          // into nearest_center_indices.
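          // This is a standard merge of two ascending lists. For example
          // (illustrative numbers), with k = 3, existing distances {1, 4, 7},
          // and block distances {2, 3}, the merged row becomes {1, 2, 3}.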
          for (int64_t j_out = 0, j_block = 0, j_merged = 0;
               (j_out < out_k || j_block < block_k) && j_merged < k;
               ++j_merged) {
            const float distance_out =
                j_out < out_k ? nearest_center_distances(i, j_out)
                              : std::numeric_limits<float>::infinity();
            const float distance_block =
                j_block < block_k ? block_nearest_center_distances(i, j_block)
                                  : std::numeric_limits<float>::infinity();
            if (distance_out <= distance_block) {
              merged_indices(j_merged) = nearest_center_indices(i, j_out);
              merged_distances(j_merged) = distance_out;
              ++j_out;
            } else {
              merged_indices(j_merged) =
                  block_nearest_center_indices(i, j_block) + centers_start;
              merged_distances(j_merged) = distance_block;
              ++j_block;
            }
          }
          nearest_center_indices.row(i) = merged_indices;
          nearest_center_distances.row(i) = merged_distances;
          out_k = std::min(k, out_k + block_k);
        }
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("NearestNeighbors").Device(DEVICE_CPU),
                        NearestNeighborsOp);

}  // namespace tensorflow