-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkmeans.py
More file actions
172 lines (145 loc) · 5.8 KB
/
kmeans.py
File metadata and controls
172 lines (145 loc) · 5.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# kmeans.py
import csv
import sys
import os
import random
import math
import matplotlib.pyplot as plt
from rich import print
from rich.console import Console
from rich.prompt import Prompt, IntPrompt, FloatPrompt
from rich.panel import Panel
# Module-level rich console; the functions in this file print their
# status and error messages through it.
console = Console()
def load_data(path, header=True):
    """Load the numeric columns of a CSV file as lists of floats.

    The first non-empty data row decides which columns are numeric: every
    column whose value parses with ``float`` is kept, and the same column
    indices are then read from every later row. Rows that cannot be
    converted (non-numeric value or too few columns) are reported on the
    console and skipped.

    Args:
        path: Path of the CSV file to read.
        header: When true, the first line is consumed as column headers.

    Returns:
        A tuple ``(X, num_col_names, text_attr)``. ``X`` is the list of rows
        as lists of floats. ``num_col_names`` holds the lower-cased header
        of each numeric column, or ``feature_<index>`` when no header name
        is available for it. ``text_attr`` is the lower-cased header of the
        first non-numeric column of the first data row, or ``"punto"`` when
        there is none. A file without data rows yields ``([], [], "punto")``.

    Raises:
        SystemExit: If the first data row contains no numeric column.
    """
    # newline="" is what the csv module asks for so that quoted fields
    # containing line breaks are parsed correctly.
    with open(path, newline="") as f:
        reader = csv.reader(f)
        # The None default keeps an empty file from raising StopIteration.
        headers = next(reader, None) if header else None

        def header_at(index):
            # Header text for a column, or None when the header row is
            # missing or shorter than the data row.
            if headers and index < len(headers):
                return headers[index].lower()
            return None

        numeric_indices = None
        # Bound up front so a file without data rows still returns cleanly.
        num_col_names = []
        text_attr = None
        X = []
        for row in reader:
            if not row:
                continue
            if numeric_indices is None:
                # First data row: detect which columns hold numbers.
                numeric_indices = []
                for i, v in enumerate(row):
                    try:
                        float(v)
                    except ValueError:
                        if text_attr is None:
                            text_attr = header_at(i)
                    else:
                        numeric_indices.append(i)
                if not numeric_indices:
                    console.print(f"[bold red]Error:[/] No se detectaron columnas numéricas en {row}")
                    sys.exit(1)
                num_col_names = [
                    header_at(i) or f"feature_{i}" for i in numeric_indices
                ]
            try:
                numeric_vals = [float(row[i]) for i in numeric_indices]
            except (ValueError, IndexError):
                console.print(f"[yellow]Fila inválida (omitida):[/] {row}")
            else:
                X.append(numeric_vals)
    return X, num_col_names, text_attr or "punto"
def euclidean(a, b):
    """Return the Euclidean distance between the vectors ``a`` and ``b``.

    Coordinates are paired with ``zip``, so when the vectors differ in
    length the extra coordinates of the longer one are ignored.
    """
    squared_diffs = [(left - right) ** 2 for left, right in zip(a, b)]
    total = sum(squared_diffs)
    return math.sqrt(total)
def initialize_centroids(X, k):
    """Choose ``k`` starting centroids by sampling ``X`` without replacement.

    The returned list holds the selected elements of ``X`` themselves, not
    copies. ``random.sample`` raises ``ValueError`` when ``k`` is negative
    or larger than ``len(X)``.
    """
    seeds = random.sample(X, k)
    return seeds
def assign_clusters(X, centroids):
    """Group every point of ``X`` under its nearest centroid.

    Returns one list per centroid, in centroid order, holding the points
    whose Euclidean distance to that centroid is smallest. When several
    centroids are equally close, the one with the lowest index wins.
    """
    groups = [list() for _ in centroids]
    for point in X:
        dists = [euclidean(point, center) for center in centroids]
        # min() keeps the first of equal keys, matching index(min(...)).
        nearest = min(range(len(dists)), key=dists.__getitem__)
        groups[nearest].append(point)
    return groups
def recompute_centroids(clusters):
    """Return the coordinate-wise mean of every cluster.

    The result has one entry per cluster, in the same order. An empty
    cluster yields ``None``. The number of coordinates is taken from the
    first point of each cluster.
    """
    updated = []
    for members in clusters:
        if members:
            count = len(members)
            dims = range(len(members[0]))
            centre = [sum([pt[d] for pt in members]) / count for d in dims]
            updated.append(centre)
        else:
            # No members: the caller decides how to replace this centroid.
            updated.append(None)
    return updated
def kmeans(X, k, max_iter=100, tol=1e-4):
    """Cluster the points of ``X`` into ``k`` groups with k-means.

    Starting from ``k`` randomly chosen points, the function alternates
    between assigning every point to its nearest centroid and moving each
    centroid to the mean of its points. It stops when the largest centroid
    movement drops below ``tol`` or after ``max_iter`` rounds. A centroid
    whose cluster became empty is replaced by a random point of ``X``.

    Args:
        X: List of points, each a list of floats.
        k: Number of clusters; must satisfy ``1 <= k <= len(X)``.
        max_iter: Maximum number of update rounds.
        tol: Convergence threshold on the largest centroid shift.

    Returns:
        A tuple ``(centroids, clusters)``. ``clusters[i]`` holds the points
        assigned to ``centroids[i]``.

    Raises:
        ValueError: If ``k`` is smaller than 1 or larger than ``len(X)``.
    """
    # Fail early with a clear message instead of an unrelated error from
    # random.sample or from min() over an empty distance list.
    if not 1 <= k <= len(X):
        raise ValueError(
            f"k must be between 1 and the number of points ({len(X)}), got {k}"
        )

    centroids = initialize_centroids(X, k)
    for i in range(max_iter):
        clusters = assign_clusters(X, centroids)
        # Re-seed the centroid of any empty cluster with a random point.
        new_centroids = [
            c if c is not None else random.choice(X)
            for c in recompute_centroids(clusters)
        ]
        shifts = [euclidean(c, nc) for c, nc in zip(centroids, new_centroids)]
        centroids = new_centroids
        if max(shifts) < tol:
            console.print(f"[green]🎉 Convergió en iteración {i}[/green]")
            break

    # Assign once more against the centroids actually returned, so the two
    # results agree. This also defines the clusters when max_iter <= 0.
    clusters = assign_clusters(X, centroids)
    return centroids, clusters
def plot_clusters_2d(X, clusters, centroids):
    """Show a scatter plot of the clusters and their centroids.

    Only the first two coordinates of each point and centroid are drawn.
    Cluster ``i`` and centroid ``i`` are both coloured with entry ``i`` of
    the ``tab10`` colormap. ``X`` is accepted but not used here.
    """
    palette = plt.get_cmap('tab10')
    plt.figure()

    # Points first, one scatter call (and legend entry) per cluster.
    for label, members in enumerate(clusters):
        first_coords = [pt[0] for pt in members]
        second_coords = [pt[1] for pt in members]
        plt.scatter(
            first_coords,
            second_coords,
            s=30,
            color=palette(label),
            label=f'Cluster {label}',
        )

    # Centroids afterwards so their markers are drawn on top of the points.
    for label, center in enumerate(centroids):
        plt.scatter(
            center[0],
            center[1],
            marker='X',
            s=200,
            color=palette(label),
            edgecolor='black',
            linewidth=2,
        )

    plt.title('K-means Clustering (2D)')
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.legend()
    plt.grid(True)
    plt.show()
def main():
    """Run the interactive k-means session on the console.

    Asks for a CSV path (tried as given, then inside ``data``), loads it,
    prompts for the clustering parameters, prints the final centroids and
    cluster sizes, plots them when the data has exactly two numeric
    columns, and finally offers to classify new points against the
    centroids.

    Exits with status 1 when the file cannot be found or holds no valid
    rows.
    """
    # Title
    console.rule("[bold cyan]📊 K-means Interactivo[/bold cyan]")

    raw = Prompt.ask("📂 Ruta al archivo CSV (o nombre dentro de 'data')").strip()
    default = os.path.join('data', raw)
    if os.path.exists(raw):
        path = raw
    elif os.path.exists(default):
        path = default
    else:
        console.print(f"[red]Archivo no encontrado ni en '{raw}' ni en '{default}'[/red]")
        sys.exit(1)

    X, col_names, text_attr = load_data(path)
    if not X:
        console.print("[bold red]No hay datos válidos en el archivo.[/bold red]")
        sys.exit(1)
    dim = len(col_names)

    # Re-ask until k is usable: k-means needs at least one cluster and
    # cannot seed more clusters than there are points.
    n_points = len(X)
    while True:
        k = IntPrompt.ask("🔢 Número de clusters (k)")
        if 1 <= k <= n_points:
            break
        console.print(f"[red]k debe estar entre 1 y {n_points}.[/red]")

    # At least one round is required for the clustering to produce a result.
    while True:
        max_iter = IntPrompt.ask("🔁 Máximo iteraciones", default=100)
        if max_iter >= 1:
            break
        console.print("[red]El máximo de iteraciones debe ser al menos 1.[/red]")

    tol = FloatPrompt.ask("🎯 Tolerancia convergencia", default=1e-4)

    console.rule("[bold green]📈 Resultados[/bold green]")
    centroids, clusters = kmeans(X, k, max_iter, tol)
    console.print("[bold magenta]Centroides finales:[/bold magenta]")
    for idx, c in enumerate(centroids):
        console.print(f"[cyan] Cluster {idx}[/cyan]: {c}")
    sizes = [len(cl) for cl in clusters]
    console.print(f"[bold blue]Tamaño de cada cluster:[/bold blue] {sizes}")
    if dim == 2:
        plot_clusters_2d(X, clusters, centroids)
    else:
        console.print(f"[yellow]⚠️ Atención:[/] no se grafica, dimensión = {dim} ≠ 2.")
    console.print("[green]✅ Proceso completado.[/green]")

    # Classification of new points
    while True:
        ans = Prompt.ask(f"¿Deseas clasificar nuevo {text_attr}?", choices=["s","n"], default="n")
        if ans != 's':
            break
        vals = Prompt.ask(f"✏️ Ingresa [bold]{', '.join(col_names)}[/bold] separados por coma").split(',')
        # Only the parsing can fail on bad input; catching ValueError alone
        # lets Ctrl+C and SystemExit propagate instead of being swallowed.
        try:
            punto = [float(v) for v in vals]
        except ValueError:
            punto = None
        if punto is None or len(punto) != dim:
            console.print("[red]Entrada inválida.[/red]")
            continue
        distancias = [euclidean(punto, c) for c in centroids]
        idx_c = distancias.index(min(distancias))
        console.print(Panel.fit(
            f"➡️ El {text_attr} pertenece al [bold green]cluster {idx_c}[/bold green]\n"
            f"[dim]Centroide: {centroids[idx_c]}[/dim]",
            title="Clasificación", border_style="green"
        ))
# Start the interactive session only when the file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()