Skip to content

Commit 000ce40

Browse files
committed
Clustering tests
1 parent dc762ce commit 000ce40

2 files changed

Lines changed: 405 additions & 0 deletions

File tree

tests/test_clustering.py

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
"""Tests for clustering functionality."""
2+
3+
import json
4+
from datetime import timedelta
5+
from unittest.mock import Mock, patch
6+
7+
import numpy as np
8+
import pytest
9+
from django.utils import timezone
10+
11+
from reportmanager.clustering.ClusterBucketManager import (
12+
ClusterBucketManager,
13+
ClusterGroup,
14+
ClusterReport,
15+
)
16+
from reportmanager.clustering.SBERTClusterer import SBERTClusterer
17+
18+
19+
@pytest.fixture
20+
def mock_clusterer():
21+
"""Create a mock SBERTClusterer for testing."""
22+
clusterer = Mock(spec=SBERTClusterer)
23+
return clusterer
24+
25+
26+
@pytest.fixture
27+
def manager(mock_clusterer):
28+
"""Create a ClusterBucketManager with a mock clusterer."""
29+
with patch(
30+
"reportmanager.clustering.ClusterBucketManager.SBERTClusterer",
31+
return_value=mock_clusterer,
32+
):
33+
return ClusterBucketManager()
34+
35+
36+
@pytest.fixture
37+
def sample_reports():
38+
"""Create sample ClusterReport objects for testing."""
39+
now = timezone.now()
40+
return [
41+
ClusterReport(
42+
id=1,
43+
ml_valid_probability=0.8,
44+
reported_at=now,
45+
url="https://example.com/page1",
46+
bucket_id=None,
47+
text="Page loading issue",
48+
domain="example.com",
49+
ok_to_cluster=True,
50+
),
51+
ClusterReport(
52+
id=2,
53+
ml_valid_probability=0.7,
54+
reported_at=now - timedelta(days=1),
55+
url="https://example.com/page2",
56+
bucket_id=None,
57+
text="Page not loading correctly",
58+
domain="example.com",
59+
ok_to_cluster=True,
60+
),
61+
ClusterReport(
62+
id=3,
63+
ml_valid_probability=0.9,
64+
reported_at=now - timedelta(days=2),
65+
url="https://other.com/page",
66+
bucket_id=None,
67+
text="Different issue",
68+
domain="other.com",
69+
ok_to_cluster=True,
70+
),
71+
]
72+
73+
74+
class TestClusterBucketManager:
75+
"""Tests for ClusterBucketManager."""
76+
77+
def test_ok_to_cluster(self):
78+
"""Test ok_to_cluster for various scenarios."""
79+
assert ClusterBucketManager.ok_to_cluster("Some text", 0.05) is True
80+
assert ClusterBucketManager.ok_to_cluster("Some text", 0.5) is True
81+
assert ClusterBucketManager.ok_to_cluster("", 0.5) is False
82+
assert ClusterBucketManager.ok_to_cluster(" ", 0.5) is False
83+
assert ClusterBucketManager.ok_to_cluster("Some text", 0.03) is False
84+
85+
def test_group_reports_by_domain(self, manager, sample_reports):
86+
"""Test grouping reports by domain."""
87+
reports_by_domain = manager.group_reports_by_domain(sample_reports)
88+
89+
assert "example.com" in reports_by_domain
90+
assert "other.com" in reports_by_domain
91+
assert len(reports_by_domain["example.com"]) == 2
92+
assert len(reports_by_domain["other.com"]) == 1
93+
94+
def test_group_reports_by_domain_filters_domains(self, manager, sample_reports):
95+
"""Test grouping reports with domain filter."""
96+
reports_by_domain = manager.group_reports_by_domain(
97+
sample_reports, domains=["example.com"]
98+
)
99+
100+
assert "example.com" in reports_by_domain
101+
assert "other.com" not in reports_by_domain
102+
103+
def test_group_reports_by_domain_skips_not_ok_to_cluster(self, manager):
104+
"""Test grouping skips reports that are not ok to cluster."""
105+
reports = [
106+
ClusterReport(
107+
id=1,
108+
ml_valid_probability=0.8,
109+
reported_at=timezone.now(),
110+
url="https://example.com",
111+
bucket_id=None,
112+
text="Valid text",
113+
domain="example.com",
114+
ok_to_cluster=True,
115+
),
116+
ClusterReport(
117+
id=2,
118+
ml_valid_probability=0.02,
119+
reported_at=timezone.now(),
120+
url="https://example.com",
121+
bucket_id=None,
122+
text="",
123+
domain="example.com",
124+
ok_to_cluster=False,
125+
),
126+
]
127+
128+
reports_by_domain = manager.group_reports_by_domain(reports)
129+
assert len(reports_by_domain["example.com"]) == 1
130+
131+
def test_is_high_volume_domain_with_high_volume(self, manager):
132+
"""Test detecting high-volume domains."""
133+
now = timezone.now()
134+
# Create 30 reports over 7 days = 30 reports/week (> 20 threshold)
135+
reports = [
136+
ClusterReport(
137+
id=i,
138+
ml_valid_probability=0.8,
139+
reported_at=now - timedelta(days=i % 7),
140+
url="https://example.com",
141+
bucket_id=None,
142+
text=f"Report {i}",
143+
domain="example.com",
144+
ok_to_cluster=True,
145+
)
146+
for i in range(30)
147+
]
148+
149+
assert manager.is_high_volume_domain(reports) is True
150+
151+
def test_is_high_volume_domain_with_normal_volume(self, manager):
152+
"""Test detecting normal-volume domains."""
153+
now = timezone.now()
154+
# Create 10 reports over 7 days = 10 reports/week (< 20 threshold)
155+
reports = [
156+
ClusterReport(
157+
id=i,
158+
ml_valid_probability=0.8,
159+
reported_at=now - timedelta(days=i % 7),
160+
url="https://example.com",
161+
bucket_id=None,
162+
text=f"Report {i}",
163+
domain="example.com",
164+
ok_to_cluster=True,
165+
)
166+
for i in range(10)
167+
]
168+
169+
assert manager.is_high_volume_domain(reports) is False
170+
171+
def test_filter_recent_reports(self, manager):
172+
"""Test filtering reports by date."""
173+
now = timezone.now()
174+
reports = [
175+
ClusterReport(
176+
id=1,
177+
ml_valid_probability=0.8,
178+
reported_at=now - timedelta(days=5),
179+
url="https://example.com",
180+
bucket_id=None,
181+
text="Recent",
182+
domain="example.com",
183+
ok_to_cluster=True,
184+
),
185+
ClusterReport(
186+
id=2,
187+
ml_valid_probability=0.8,
188+
reported_at=now - timedelta(days=20),
189+
url="https://example.com",
190+
bucket_id=None,
191+
text="Old",
192+
domain="example.com",
193+
ok_to_cluster=True,
194+
),
195+
]
196+
197+
recent = manager.filter_recent_reports(reports, days=14)
198+
assert len(recent) == 1
199+
assert recent[0].id == 1
200+
201+
def test_group_reports_by_label(self, manager, sample_reports):
202+
"""Test grouping reports by cluster labels."""
203+
labels = np.array([0, 0, 1])
204+
embeddings = np.array([[0.1, 0.2], [0.15, 0.25], [0.9, 0.8]])
205+
206+
cluster_groups = manager.group_reports_by_label(
207+
sample_reports, labels, embeddings
208+
)
209+
210+
assert len(cluster_groups) == 2
211+
groups_by_size = sorted(
212+
cluster_groups, key=lambda g: len(g.reports), reverse=True
213+
)
214+
assert len(groups_by_size[0].reports) == 2
215+
assert len(groups_by_size[0].embeddings) == 2
216+
assert len(groups_by_size[1].reports) == 1
217+
218+
def test_build_clusters_skips_low_probability_single_reports(
219+
self, manager, mock_clusterer
220+
):
221+
"""Test that single-report clusters with low probability are skipped."""
222+
mock_clusterer.find_centroid_index.return_value = 0
223+
224+
cluster_groups = [
225+
ClusterGroup(
226+
reports=[
227+
ClusterReport(
228+
id=1,
229+
ml_valid_probability=0.3, # Below threshold of 0.60
230+
reported_at=timezone.now(),
231+
url="https://example.com",
232+
bucket_id=None,
233+
text="Low quality",
234+
domain="example.com",
235+
ok_to_cluster=True,
236+
)
237+
],
238+
embeddings=np.array([[0.1, 0.2]]),
239+
)
240+
]
241+
242+
clusters = manager.build_clusters(cluster_groups, "example.com")
243+
assert len(clusters) == 0
244+
245+
def test_build_clusters_keeps_high_probability_single_reports(
246+
self, manager, mock_clusterer
247+
):
248+
"""Test that single-report clusters with high probability are kept."""
249+
mock_clusterer.find_centroid_index.return_value = 0
250+
251+
cluster_groups = [
252+
ClusterGroup(
253+
reports=[
254+
ClusterReport(
255+
id=1,
256+
ml_valid_probability=0.9, # Above threshold of 0.60
257+
reported_at=timezone.now(),
258+
url="https://example.com",
259+
bucket_id=None,
260+
text="High quality",
261+
domain="example.com",
262+
ok_to_cluster=True,
263+
)
264+
],
265+
embeddings=np.array([[0.1, 0.2]]),
266+
)
267+
]
268+
269+
clusters = manager.build_clusters(cluster_groups, "example.com")
270+
assert len(clusters) == 1
271+
assert clusters[0].centroid_id == 1
272+
273+
def test_build_clusters_keeps_multi_report_clusters(self, manager, mock_clusterer):
274+
"""Test that multi-report clusters are kept regardless of probability."""
275+
mock_clusterer.find_centroid_index.return_value = 1
276+
277+
cluster_groups = [
278+
ClusterGroup(
279+
reports=[
280+
ClusterReport(
281+
id=1,
282+
ml_valid_probability=0.3,
283+
reported_at=timezone.now(),
284+
url="https://example.com",
285+
bucket_id=None,
286+
text="Report 1",
287+
domain="example.com",
288+
ok_to_cluster=True,
289+
),
290+
ClusterReport(
291+
id=2,
292+
ml_valid_probability=0.4,
293+
reported_at=timezone.now(),
294+
url="https://example.com",
295+
bucket_id=None,
296+
text="Report 2",
297+
domain="example.com",
298+
ok_to_cluster=True,
299+
),
300+
],
301+
embeddings=np.array([[0.1, 0.2], [0.15, 0.25]]),
302+
)
303+
]
304+
305+
clusters = manager.build_clusters(cluster_groups, "example.com")
306+
assert len(clusters) == 1
307+
assert clusters[0].centroid_id == 2
308+
309+
def test_build_cluster_bucket_signature(self, manager):
310+
"""Test building bucket signature for a cluster."""
311+
signature_str = manager.build_cluster_bucket_signature("example.com", 123)
312+
signature = json.loads(signature_str)
313+
314+
assert "symptoms" in signature
315+
assert len(signature["symptoms"]) == 2
316+
assert signature["symptoms"][0]["type"] == "url"
317+
assert signature["symptoms"][0]["part"] == "hostname"
318+
assert signature["symptoms"][0]["value"] == "example.com"
319+
assert signature["symptoms"][1]["type"] == "cluster_id"
320+
assert signature["symptoms"][1]["value"] == "123"

0 commit comments

Comments
 (0)