Skip to content

Commit 3c637bd

Browse files
committed
feat(linfa-clustering): Add Hamerly's accelerated K-means algorithm
Implement Hamerly's triangle-inequality optimization as an alternative to Lloyd's algorithm for K-means clustering. For each observation the algorithm maintains upper/lower distance bounds and skips centroid comparisons that cannot change the assignment, yielding the same results as Lloyd's algorithm but with significantly fewer distance computations when clusters are well separated.

Key changes:
- Add the new Hamerly K-means algorithm
- Add KMeansAlgorithm enum (Lloyd | Hamerly) and .algorithm() builder method
- Reject Hamerly for incremental fit_with
- Comprehensive tests
1 parent 1abc88f commit 3c637bd

5 files changed

Lines changed: 1006 additions & 39 deletions

File tree

algorithms/linfa-clustering/benches/k_means.rs

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use criterion::{
55
use linfa::benchmarks::config;
66
use linfa::prelude::*;
77
use linfa::DatasetBase;
8-
use linfa_clustering::{IncrKMeansError, KMeans, KMeansInit};
8+
use linfa_clustering::{IncrKMeansError, KMeans, KMeansAlgorithm, KMeansInit};
99
use linfa_datasets::generate;
1010
use ndarray::Array2;
1111
use ndarray_rand::RandomExt;
@@ -36,33 +36,38 @@ impl Drop for Stats {
3636
fn k_means_bench(c: &mut Criterion) {
3737
let mut rng = Xoshiro256Plus::seed_from_u64(40);
3838
let cluster_sizes = [(100, 4), (400, 10), (3000, 10)];
39+
let algorithms = [KMeansAlgorithm::Lloyd, KMeansAlgorithm::Hamerly];
3940
let n_features = 3;
4041

41-
let mut benchmark = c.benchmark_group("naive_k_means");
42+
let mut benchmark = c.benchmark_group("k_means");
4243
config::set_default_benchmark_configs(&mut benchmark);
4344
benchmark.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
4445

45-
for &(cluster_size, n_clusters) in &cluster_sizes {
46-
let rng = &mut rng;
47-
let centroids =
48-
Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), rng);
49-
let dataset = DatasetBase::from(generate::blobs(cluster_size, &centroids, rng));
50-
let mut stats = Stats::default();
46+
for &algorithm in &algorithms {
47+
for &(cluster_size, n_clusters) in &cluster_sizes {
48+
let rng = &mut rng;
49+
let centroids =
50+
Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), rng);
51+
let dataset = DatasetBase::from(generate::blobs(cluster_size, &centroids, rng));
52+
let mut stats = Stats::default();
5153

52-
benchmark.bench_function(
53-
BenchmarkId::new("naive_k_means", format!("{n_clusters}x{cluster_size}")),
54-
|bencher| {
55-
bencher.iter(|| {
56-
let m = KMeans::params_with_rng(black_box(n_clusters), black_box(rng.clone()))
57-
.init_method(KMeansInit::KMeansPlusPlus)
58-
.max_n_iterations(black_box(1000))
59-
.tolerance(black_box(1e-3))
60-
.fit(&dataset)
61-
.unwrap();
62-
stats.add(m.inertia());
63-
});
64-
},
65-
);
54+
benchmark.bench_function(
55+
BenchmarkId::new("k_means", format!("{algorithm:?}:{n_clusters}x{cluster_size}")),
56+
|bencher| {
57+
bencher.iter(|| {
58+
let m =
59+
KMeans::params_with_rng(black_box(n_clusters), black_box(rng.clone()))
60+
.init_method(KMeansInit::KMeansPlusPlus)
61+
.algorithm(algorithm)
62+
.max_n_iterations(black_box(1000))
63+
.tolerance(black_box(1e-3))
64+
.fit(&dataset)
65+
.unwrap();
66+
stats.add(m.inertia());
67+
});
68+
},
69+
);
70+
}
6671
}
6772

6873
benchmark.finish();

0 commit comments

Comments (0)