Skip to content

Commit 49c5019

Browse files
committed
Command for initial similarity-based clustering
1 parent fd8c967 commit 49c5019

12 files changed

Lines changed: 486 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ server = [
5151
"djangorestframework~=3.15.1",
5252
"google-cloud-bigquery",
5353
"pyyaml",
54+
"scikit-learn>=1.3.0",
55+
"sentence-transformers>=2.2.0",
5456
"whitenoise~=6.6.0",
5557
]
5658

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# This Source Code Form is subject to the terms of the Mozilla Public
2+
# License, v. 2.0. If a copy of the MPL was not distributed with this
3+
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
4+
import json
5+
from collections import defaultdict
6+
from datetime import timedelta
7+
from typing import Any
8+
from urllib.parse import urlsplit
9+
10+
from django.db import transaction
11+
from django.db.models import QuerySet
12+
from django.utils import timezone
13+
14+
from reportmanager.Clustering.SBERTClusterer import SBERTClusterer
15+
from reportmanager.models import Bucket, BucketHit, Cluster, ReportEntry
16+
from reportmanager.utils import preprocess_text
17+
18+
19+
class ClusteringConfig:
20+
HIGH_VOLUME_WINDOW_DAYS = 14
21+
HIGH_VOLUME_THRESHOLD = 20 # reports per week
22+
HIGH_VOLUME_DISTANCE_THRESHOLD = 0.30
23+
NORMAL_VOLUME_DISTANCE_THRESHOLD = 0.38
24+
BATCH_SIZE = 500
25+
CLUSTER_BUCKET_IDENTIFIER = "[Cluster"
26+
DEFAULT_BUCKET_PRIORITY = 0
27+
28+
29+
def batch_update_in_chunks(
30+
queryset: QuerySet,
31+
ids: list[int],
32+
batch_size: int = ClusteringConfig.BATCH_SIZE,
33+
**update_fields: Any,
34+
) -> int:
35+
total_updated = 0
36+
for i in range(0, len(ids), batch_size):
37+
batch_ids = ids[i : i + batch_size]
38+
count = queryset.filter(id__in=batch_ids).update(**update_fields)
39+
total_updated += count
40+
return total_updated
41+
42+
43+
def deduplicate_reports(reports: list[dict[str, Any]]) -> list[dict[str, Any]]:
44+
"""Remove exact word-for-word duplicates within each cluster."""
45+
46+
deduped = []
47+
48+
seen_texts = set()
49+
for report in reports:
50+
if report["text"] not in seen_texts:
51+
seen_texts.add(report["text"])
52+
deduped.append(report)
53+
54+
return deduped
55+
56+
57+
class ClusterBucketManager:
58+
def __init__(self, clusterer: SBERTClusterer | None = None) -> None:
59+
self.clusterer = clusterer or SBERTClusterer()
60+
61+
def fetch_reports(self) -> list[dict[str, Any]]:
62+
reports_qs = ReportEntry.objects.exclude(comments="").filter(
63+
ml_valid_probability__gt=0.03
64+
)
65+
66+
all_reports = list(
67+
reports_qs.values(
68+
"id",
69+
"comments",
70+
"comments_translated",
71+
"ml_valid_probability",
72+
"reported_at",
73+
"url",
74+
)
75+
)
76+
77+
return all_reports
78+
79+
def group_reports_by_domain(self, reports: list[dict]) -> dict[str, list[dict]]:
80+
reports_by_domain = defaultdict(list)
81+
82+
for report in reports:
83+
text = report["comments_translated"] or report["comments"]
84+
85+
if text and text.strip():
86+
try:
87+
parsed_url = urlsplit(report["url"])
88+
domain = parsed_url.hostname or "unknown"
89+
except Exception:
90+
domain = "unknown"
91+
92+
report["text"] = preprocess_text(text)
93+
report["domain"] = domain
94+
reports_by_domain[domain].append(report)
95+
96+
return reports_by_domain
97+
98+
def is_high_volume_domain(self, reports: list[dict]) -> bool:
99+
"""Determine if a domain is high-volume based on average weekly reports."""
100+
101+
report_count = len(reports)
102+
dates = [r["reported_at"] for r in reports]
103+
min_date = min(dates)
104+
max_date = max(dates)
105+
days_span = (max_date - min_date).days + 1
106+
avg_weekly_reports = (report_count / days_span) * 7
107+
return avg_weekly_reports > ClusteringConfig.HIGH_VOLUME_THRESHOLD
108+
109+
def filter_recent_reports(self, reports: list[dict], days: int) -> list[dict]:
110+
cutoff_date = timezone.now() - timedelta(days=days)
111+
return [r for r in reports if r["reported_at"] >= cutoff_date]
112+
113+
def group_reports_by_label(
114+
self, reports: list[dict], labels: list[int], embeddings: list
115+
) -> dict[int, dict[str, list]]:
116+
clusters_dict = defaultdict(lambda: {"reports": [], "embeddings": []})
117+
for label, report, embedding in zip(labels, reports, embeddings):
118+
clusters_dict[label]["reports"].append(report)
119+
clusters_dict[label]["embeddings"].append(embedding)
120+
return clusters_dict
121+
122+
def build_clusters(
123+
self,
124+
clusters_dict: dict[int, dict[str, list]],
125+
domain: str,
126+
) -> list[dict]:
127+
"""Create cluster objects with centroids and deduplicated reports."""
128+
129+
clusters = []
130+
for cluster_data in clusters_dict.values():
131+
centroid_id = self.clusterer.find_centroid_for_cluster(
132+
cluster_data["reports"], cluster_data["embeddings"]
133+
)
134+
clusters.append(
135+
{
136+
"centroid_id": centroid_id,
137+
"reports": deduplicate_reports(cluster_data["reports"]),
138+
"domain": domain,
139+
}
140+
)
141+
return clusters
142+
143+
def cluster_domain_reports(
144+
self,
145+
domain: str,
146+
reports: list[dict],
147+
) -> list[dict]:
148+
"""Cluster reports for a single domain."""
149+
150+
if len(reports) == 0:
151+
return []
152+
153+
# Calculate if this is a high-volume domain
154+
# and if so, only use reports in the last 14 days
155+
is_high_volume = self.is_high_volume_domain(reports)
156+
157+
if is_high_volume:
158+
reports = self.filter_recent_reports(
159+
reports, ClusteringConfig.HIGH_VOLUME_WINDOW_DAYS
160+
)
161+
162+
if len(reports) == 0:
163+
return []
164+
165+
# Use different thresholds for high vs normal volume
166+
threshold = (
167+
ClusteringConfig.HIGH_VOLUME_DISTANCE_THRESHOLD
168+
if is_high_volume
169+
else ClusteringConfig.NORMAL_VOLUME_DISTANCE_THRESHOLD
170+
)
171+
172+
labels, embeddings = self.clusterer.cluster(reports, threshold)
173+
174+
clusters_dict = self.group_reports_by_label(reports, labels, embeddings)
175+
clusters = self.build_clusters(clusters_dict, domain)
176+
177+
return clusters
178+
179+
def save_clusters(self, clusters: list[dict]) -> list[dict]:
180+
"""Save clusters to db and add cluster DB IDs to cluster dicts."""
181+
182+
with transaction.atomic():
183+
for cluster in clusters:
184+
cluster_obj = Cluster.objects.create(
185+
domain=cluster["domain"],
186+
centroid_id=cluster["centroid_id"],
187+
)
188+
189+
cluster["cluster_id"] = cluster_obj.id
190+
191+
report_ids_in_cluster = [r["id"] for r in cluster["reports"]]
192+
batch_update_in_chunks(
193+
ReportEntry.objects, report_ids_in_cluster, cluster=cluster_obj
194+
)
195+
196+
return clusters
197+
198+
def delete_existing_clusters(self) -> int:
199+
cluster_count = Cluster.objects.count()
200+
Cluster.objects.all().delete()
201+
return cluster_count
202+
203+
def delete_cluster_buckets(self) -> int:
204+
old_cluster_buckets = Bucket.objects.filter(
205+
description__contains=ClusteringConfig.CLUSTER_BUCKET_IDENTIFIER
206+
)
207+
208+
bucket_count = old_cluster_buckets.count()
209+
210+
# Unassign reports from these buckets (to avoid CASCADE delete)
211+
ReportEntry.objects.filter(bucket__in=old_cluster_buckets).update(bucket=None)
212+
213+
# Besides clusters this would delete related BucketHit and BucketWatch records
214+
old_cluster_buckets.delete()
215+
216+
return bucket_count
217+
218+
def create_cluster_bucket_signature(self, domain: str, cluster_id: int) -> str:
219+
"""Create a signature JSON for a cluster bucket."""
220+
221+
signature = {
222+
"symptoms": [
223+
{"type": "url", "part": "hostname", "value": domain},
224+
{"type": "cluster_id", "value": str(cluster_id)},
225+
]
226+
}
227+
return json.dumps(signature, sort_keys=True)
228+
229+
def update_bucket_hits(self, reports_to_move, new_bucket_id: int):
230+
"""Update BucketHit counts when moving reports to a new bucket."""
231+
232+
for report in reports_to_move.values("reported_at", "bucket_id"):
233+
if report["bucket_id"]:
234+
BucketHit.decrement_count(report["bucket_id"], report["reported_at"])
235+
BucketHit.increment_count(new_bucket_id, report["reported_at"])
236+
237+
def create_bucket_for_cluster(
238+
self, domain: str, cluster_id: int, report_ids: list[int]
239+
) -> None:
240+
"""Create a new bucket for a cluster and reassign reports."""
241+
242+
signature = self.create_cluster_bucket_signature(domain, cluster_id)
243+
244+
with transaction.atomic():
245+
new_bucket = Bucket.objects.create(
246+
description=f"{domain} {ClusteringConfig.CLUSTER_BUCKET_IDENTIFIER} {cluster_id}]", # noqa
247+
signature=signature,
248+
priority=ClusteringConfig.DEFAULT_BUCKET_PRIORITY,
249+
color=None,
250+
bug=None,
251+
domain=domain,
252+
)
253+
254+
# Reassign reports to new bucket
255+
reports_to_move = ReportEntry.objects.filter(id__in=report_ids)
256+
self.update_bucket_hits(reports_to_move, new_bucket.id)
257+
reports_to_move.update(bucket=new_bucket)
258+
259+
def create_buckets_from_clusters(self, all_clusters: list[dict]) -> int:
260+
buckets_created = 0
261+
for cluster_data in all_clusters:
262+
report_ids = [r["id"] for r in cluster_data["reports"]]
263+
264+
if not report_ids:
265+
continue
266+
267+
self.create_bucket_for_cluster(
268+
cluster_data["domain"], cluster_data["cluster_id"], report_ids
269+
)
270+
buckets_created += 1
271+
272+
return buckets_created
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# This Source Code Form is subject to the terms of the Mozilla Public
2+
# License, v. 2.0. If a copy of the MPL was not distributed with this
3+
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
4+
from collections.abc import Sequence
5+
from typing import Any
6+
7+
from sentence_transformers import SentenceTransformer
8+
from sklearn.cluster import AgglomerativeClustering
9+
from sklearn.metrics.pairwise import cosine_similarity
10+
11+
12+
class SBERTClusterer:
13+
"""Clusters similar reports using SBERT embeddings."""
14+
15+
def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5") -> None:
16+
self.model: SentenceTransformer = SentenceTransformer(model_name)
17+
18+
def cluster(
19+
self, reports: list[dict[str, Any]], distance_threshold: float = 0.38
20+
) -> tuple[Any, Any]:
21+
"""Cluster reports using SBERT embeddings."""
22+
23+
if len(reports) < 2:
24+
return [0], []
25+
26+
texts = [report["text"] for report in reports]
27+
embeddings = self.model.encode(
28+
texts, show_progress_bar=False, normalize_embeddings=True
29+
)
30+
31+
similarities = cosine_similarity(embeddings)
32+
distances = 1 - similarities
33+
34+
clustering = AgglomerativeClustering(
35+
n_clusters=None,
36+
distance_threshold=distance_threshold,
37+
metric="precomputed",
38+
linkage="average",
39+
)
40+
41+
labels = clustering.fit_predict(distances)
42+
43+
return labels, embeddings
44+
45+
def euclidean_distance(self, emb1: Sequence[float], emb2: Sequence[float]) -> float:
46+
"""Calculate Euclidean distance between two embeddings."""
47+
return sum((a - b) ** 2 for a, b in zip(emb1, emb2)) ** 0.5
48+
49+
def find_centroid_for_cluster(
50+
self, reports: list[dict[str, Any]], embeddings: Any
51+
) -> int:
52+
"""Find the report closest to the centroid."""
53+
54+
# Early return for single-report clusters
55+
if len(reports) == 1:
56+
return reports[0]["id"]
57+
58+
report_embeddings = list(embeddings)
59+
embedding_dim = len(report_embeddings[0])
60+
centroid = [
61+
sum(emb[j] for emb in report_embeddings) / len(report_embeddings)
62+
for j in range(embedding_dim)
63+
]
64+
65+
# Find report closest to centroid
66+
distances = [
67+
self.euclidean_distance(emb, centroid) for emb in report_embeddings
68+
]
69+
closest_idx = distances.index(min(distances))
70+
71+
return reports[closest_idx]["id"]

server/reportmanager/clustering/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)