Skip to content

Commit 006c7a7

Browse files
committed
fix: format code
1 parent 3c637bd commit 006c7a7

2 files changed

Lines changed: 32 additions & 26 deletions

File tree

algorithms/linfa-clustering/benches/k_means.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ fn k_means_bench(c: &mut Criterion) {
5252
let mut stats = Stats::default();
5353

5454
benchmark.bench_function(
55-
BenchmarkId::new("k_means", format!("{algorithm:?}:{n_clusters}x{cluster_size}")),
55+
BenchmarkId::new(
56+
"k_means",
57+
format!("{algorithm:?}:{n_clusters}x{cluster_size}"),
58+
),
5659
|bencher| {
5760
bencher.iter(|| {
5861
let m =

algorithms/linfa-clustering/src/k_means/algorithm.rs

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use std::cmp::Ordering;
22
use std::fmt::Debug;
33

44
use crate::k_means::{KMeansParams, KMeansValidParams};
5-
use crate::{IncrKMeansError, KMeansAlgorithm, KMeansParamsError};
65
use crate::{k_means::errors::KMeansError, KMeansInit};
6+
use crate::{IncrKMeansError, KMeansAlgorithm, KMeansParamsError};
77
use linfa::{prelude::*, DatasetBase, Float};
88
use linfa_nn::distance::{Distance, L2Dist};
99
use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, DataMut, Ix1, Ix2, Zip};
@@ -256,11 +256,10 @@ impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
256256
let mut best_memberships = None;
257257

258258
for _ in 0..self.n_runs() {
259-
let centroids = self
260-
.init_method()
261-
.run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
262-
let mut hamerly =
263-
HamerlyAlgorithm::new(self.dist_fn(), observations, centroids);
259+
let centroids =
260+
self.init_method()
261+
.run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
262+
let mut hamerly = HamerlyAlgorithm::new(self.dist_fn(), observations, centroids);
264263

265264
let mut n_iter = 0;
266265
let inertia = loop {
@@ -272,9 +271,7 @@ impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
272271

273272
let update = hamerly.recompute_centroids();
274273

275-
if update.convergence_dist < self.tolerance()
276-
|| n_iter == self.max_n_iterations()
277-
{
274+
if update.convergence_dist < self.tolerance() || n_iter == self.max_n_iterations() {
278275
break hamerly.inertia();
279276
}
280277

@@ -289,8 +286,7 @@ impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
289286
}
290287
}
291288

292-
let memberships =
293-
best_memberships.unwrap_or_else(|| Array1::zeros(dataset.nsamples()));
289+
let memberships = best_memberships.unwrap_or_else(|| Array1::zeros(dataset.nsamples()));
294290
self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships)
295291
}
296292

@@ -484,12 +480,10 @@ impl<'a, F: Float, D: Distance<F>> HamerlyAlgorithm<'a, F, D> {
484480
.par_for_each(|obs, membership, upper, lower, prev_slot| {
485481
let current = *membership;
486482
*prev_slot = current;
487-
let threshold =
488-
F::max(nearest_center_dists[current] / F::cast(2), *lower);
483+
let threshold = F::max(nearest_center_dists[current] / F::cast(2), *lower);
489484

490485
if *upper > threshold {
491-
*upper =
492-
dist_fn.distance(obs.view(), centroids.row(current).view());
486+
*upper = dist_fn.distance(obs.view(), centroids.row(current).view());
493487

494488
if *upper > threshold {
495489
let (idx, closest_dist, second_dist) =
@@ -548,8 +542,7 @@ impl<'a, F: Float, D: Distance<F>> HamerlyAlgorithm<'a, F, D> {
548542
}
549543

550544
fn update_bounds(&mut self, distances_moved: &Array1<F>) {
551-
let (farthest_moved_idx, second_farthest_moved_idx) =
552-
two_farthest_indices(distances_moved);
545+
let (farthest_moved_idx, second_farthest_moved_idx) = two_farthest_indices(distances_moved);
553546
Zip::from(&self.memberships)
554547
.and(&mut self.upper_bounds)
555548
.and(&mut self.lower_bounds)
@@ -1291,7 +1284,11 @@ mod tests {
12911284
.expect("Hamerly fitted");
12921285

12931286
assert_eq!(model_lloyd.centroids().nrows(), 6);
1294-
assert_abs_diff_eq!(model_lloyd.inertia(), model_hamerly.inertia(), epsilon = 1e-4);
1287+
assert_abs_diff_eq!(
1288+
model_lloyd.inertia(),
1289+
model_hamerly.inertia(),
1290+
epsilon = 1e-4
1291+
);
12951292
assert_abs_diff_eq!(
12961293
sort_centroids(model_lloyd.centroids()),
12971294
sort_centroids(model_hamerly.centroids()),
@@ -1314,8 +1311,7 @@ mod tests {
13141311
// runs. Pre-compute centroids deterministically and pass them as Precomputed so
13151312
// both Lloyd and Hamerly start from the same initial centroids.
13161313
let mut rng = Xoshiro256Plus::seed_from_u64(99);
1317-
let xt =
1318-
Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
1314+
let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1));
13191315
let yt = function_test_1d(&xt);
13201316
let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
13211317
let dataset = DatasetBase::from(data);
@@ -1590,8 +1586,12 @@ mod tests {
15901586
fn test_hamerly_precomputed_centroids() {
15911587
let rng = Xoshiro256Plus::seed_from_u64(42);
15921588
let data = array![
1593-
[0.0, 0.0], [1.0, 0.0], [0.0, 1.0],
1594-
[10.0, 10.0], [11.0, 10.0], [10.0, 11.0]
1589+
[0.0, 0.0],
1590+
[1.0, 0.0],
1591+
[0.0, 1.0],
1592+
[10.0, 10.0],
1593+
[11.0, 10.0],
1594+
[10.0, 11.0]
15951595
];
15961596
let init_centroids = array![[0.0, 0.0], [10.0, 10.0]];
15971597
let dataset = DatasetBase::from(data);
@@ -1614,7 +1614,11 @@ mod tests {
16141614
model_hamerly.centroids(),
16151615
epsilon = 1e-1
16161616
);
1617-
assert_abs_diff_eq!(model_lloyd.inertia(), model_hamerly.inertia(), epsilon = 1e-1);
1617+
assert_abs_diff_eq!(
1618+
model_lloyd.inertia(),
1619+
model_hamerly.inertia(),
1620+
epsilon = 1e-1
1621+
);
16181622
}
16191623

16201624
#[test]
@@ -1671,8 +1675,7 @@ mod tests {
16711675
#[test]
16721676
fn test_hamerly_high_dimensionality() {
16731677
let mut rng = Xoshiro256Plus::seed_from_u64(42);
1674-
let data: Array2<f64> =
1675-
Array::random_using((200, 50), Uniform::new(-100., 100.), &mut rng);
1678+
let data: Array2<f64> = Array::random_using((200, 50), Uniform::new(-100., 100.), &mut rng);
16761679
let dataset = DatasetBase::from(data);
16771680

16781681
let model_lloyd = KMeans::params_with(5, rng.clone(), L2Dist)

0 commit comments

Comments
 (0)