- name: Wait for Qdrant
  run: |
    # Poll Qdrant's HTTP health probe until it answers. The Kubernetes-style
    # probes (/healthz, /livez, /readyz) are what Qdrant documents; "/health"
    # is not one of them, so `curl -f` against it would never succeed.
    echo "Waiting for Qdrant to be ready..."
    for i in {1..20}; do
      if curl -s -f http://localhost:6333/healthz > /dev/null 2>&1; then
        echo "Qdrant is ready!"
        # Print the probe response once so it shows up in the job log.
        curl -s http://localhost:6333/healthz
        exit 0
      fi
      echo "Waiting for Qdrant... attempt $i/20"
      sleep 3
    done

    # Only reached when every attempt failed.
    echo "ERROR: Qdrant failed to become ready"
    exit 1
+ +Sandi Metz Principles: +- Single Responsibility: Performance measurement +- Small methods: Each benchmark isolated +- Clear naming: Descriptive method names +""" + +import asyncio +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List + +from app.models.qdrant_point import QdrantPoint +from app.repositories.qdrant_repository import QdrantRepository +from app.utils.logger import get_logger + +logger = get_logger(__name__) + + +@dataclass +class BenchmarkMetrics: + """Metrics collected during benchmark.""" + + operation: str + total_operations: int + total_time: float + operations_per_second: float + avg_latency_ms: float + min_latency_ms: float + max_latency_ms: float + p50_latency_ms: float + p95_latency_ms: float + p99_latency_ms: float + success_count: int + error_count: int + metadata: Dict[str, Any] = field(default_factory=dict) + + def __str__(self) -> str: + """Format metrics as readable string.""" + return ( + f"\n{self.operation} Benchmark Results:\n" + f" Total Operations: {self.total_operations}\n" + f" Total Time: {self.total_time:.2f}s\n" + f" Throughput: {self.operations_per_second:.2f} ops/sec\n" + f" Avg Latency: {self.avg_latency_ms:.2f}ms\n" + f" Min Latency: {self.min_latency_ms:.2f}ms\n" + f" Max Latency: {self.max_latency_ms:.2f}ms\n" + f" P50 Latency: {self.p50_latency_ms:.2f}ms\n" + f" P95 Latency: {self.p95_latency_ms:.2f}ms\n" + f" P99 Latency: {self.p99_latency_ms:.2f}ms\n" + f" Success: {self.success_count} | Errors: {self.error_count}" + ) + + +@dataclass +class BenchmarkResult: + """Complete benchmark results.""" + + benchmark_name: str + metrics: List[BenchmarkMetrics] + duration: float + timestamp: float + metadata: Dict[str, Any] = field(default_factory=dict) + + def summary(self) -> str: + """Generate summary report.""" + lines = [ + f"\n{'=' * 60}", + f"Benchmark: {self.benchmark_name}", + f"Duration: {self.duration:.2f}s", + f"Timestamp: {self.timestamp}", + f"{'=' * 60}", + ] + + 
class QdrantBenchmark:
    """
    Performance benchmarking for Qdrant operations.

    Measures throughput and latency of repository calls. Durations are taken
    with ``time.perf_counter()`` so they are not affected by wall-clock
    adjustments; only the suite timestamp uses ``time.time()``.
    """

    def __init__(self, repository: QdrantRepository):
        """
        Initialize benchmark.

        Args:
            repository: Qdrant repository instance
        """
        self._repository = repository
        self._logger = logger

    @staticmethod
    def _percentile(sorted_latencies: List[float], fraction: float) -> float:
        """
        Pick a percentile from latencies that are already sorted ascending.

        Args:
            sorted_latencies: Latencies in ascending order
            fraction: Percentile expressed as a fraction (0.95 for P95)

        Returns:
            Latency at that percentile, or 0.0 when there are no samples
        """
        if not sorted_latencies:
            return 0.0
        last_index = len(sorted_latencies) - 1
        # Clamp so a fraction of 1.0 cannot index past the end.
        index = min(int(len(sorted_latencies) * fraction), last_index)
        return sorted_latencies[index]

    def _build_metrics(
        self,
        operation: str,
        total_operations: int,
        total_time: float,
        latencies: List[float],
        success_count: int,
        error_count: int,
        metadata: Dict[str, Any],
    ) -> BenchmarkMetrics:
        """
        Aggregate raw latency samples into a BenchmarkMetrics record.

        Every statistic falls back to 0 when there are no samples or no
        elapsed time, so zero-sized runs do not raise.

        Args:
            operation: Name reported for the benchmark
            total_operations: Number of operations attempted
            total_time: Elapsed time for the whole run, in seconds
            latencies: Per-operation latencies in milliseconds
            success_count: Operations that completed
            error_count: Operations that raised
            metadata: Extra context stored on the result

        Returns:
            BenchmarkMetrics with results
        """
        ordered = sorted(latencies)
        throughput = total_operations / total_time if total_time > 0 else 0.0
        average = sum(latencies) / len(latencies) if latencies else 0.0

        return BenchmarkMetrics(
            operation=operation,
            total_operations=total_operations,
            total_time=total_time,
            operations_per_second=throughput,
            avg_latency_ms=average,
            min_latency_ms=ordered[0] if ordered else 0.0,
            max_latency_ms=ordered[-1] if ordered else 0.0,
            p50_latency_ms=self._percentile(ordered, 0.50),
            p95_latency_ms=self._percentile(ordered, 0.95),
            p99_latency_ms=self._percentile(ordered, 0.99),
            success_count=success_count,
            error_count=error_count,
            metadata=metadata,
        )

    async def benchmark_operation(
        self,
        operation_name: str,
        operation_func: Callable,
        iterations: int = 100,
        **kwargs: Any,
    ) -> BenchmarkMetrics:
        """
        Benchmark a single operation.

        Failures are logged and counted; they do not stop the run, and the
        latency of a failed call is still recorded.

        Args:
            operation_name: Name of operation
            operation_func: Async function to benchmark
            iterations: Number of iterations
            **kwargs: Additional operation arguments

        Returns:
            BenchmarkMetrics with results
        """
        latencies: List[float] = []
        success_count = 0
        error_count = 0

        self._logger.info(
            "Starting benchmark",
            operation=operation_name,
            iterations=iterations,
        )

        start_time = time.perf_counter()

        for i in range(iterations):
            op_start = time.perf_counter()
            try:
                await operation_func(**kwargs)
                success_count += 1
            except Exception as e:
                error_count += 1
                self._logger.warning(
                    "Operation failed",
                    operation=operation_name,
                    iteration=i,
                    error=str(e),
                )

            latencies.append((time.perf_counter() - op_start) * 1000)  # ms

        total_time = time.perf_counter() - start_time

        metrics = self._build_metrics(
            operation=operation_name,
            total_operations=iterations,
            total_time=total_time,
            latencies=latencies,
            success_count=success_count,
            error_count=error_count,
            metadata={},
        )

        self._logger.info("Benchmark completed", operation=operation_name)
        return metrics

    async def benchmark_insert(
        self, num_points: int = 1000, vector_dim: int = 384
    ) -> BenchmarkMetrics:
        """
        Benchmark point insertion.

        Points are stored one at a time; any repository error propagates to
        the caller and aborts the run.

        Args:
            num_points: Number of points to insert
            vector_dim: Vector dimensions

        Returns:
            BenchmarkMetrics for insertions
        """

        async def insert_point(point_id: str, vector: List[float]) -> None:
            point = QdrantPoint(
                id=point_id,
                vector=vector,
                payload={"benchmark": True, "index": point_id},
            )
            await self._repository.store_point(point)

        # Generate test data up front so it is not part of the timed region
        test_vectors = [[0.1 * (i % 100)] * vector_dim for i in range(num_points)]

        latencies: List[float] = []
        start_time = time.perf_counter()

        for i, vector in enumerate(test_vectors):
            op_start = time.perf_counter()
            await insert_point(f"bench_insert_{i}", vector)
            latencies.append((time.perf_counter() - op_start) * 1000)

        total_time = time.perf_counter() - start_time

        return self._build_metrics(
            operation="insert",
            total_operations=len(test_vectors),
            total_time=total_time,
            latencies=latencies,
            success_count=len(test_vectors),
            error_count=0,
            metadata={"vector_dim": vector_dim},
        )

    async def benchmark_batch_insert(
        self,
        num_points: int = 1000,
        batch_size: int = 100,
        vector_dim: int = 384,
    ) -> BenchmarkMetrics:
        """
        Benchmark batch insertion.

        One "operation" in the returned metrics is one batch, not one point.

        Args:
            num_points: Total points to insert
            batch_size: Points per batch
            vector_dim: Vector dimensions

        Returns:
            BenchmarkMetrics for batch insertions

        Raises:
            ValueError: If batch_size is smaller than 1
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        # Ceiling division; never negative even for a negative num_points
        num_batches = max(0, (num_points + batch_size - 1) // batch_size)
        latencies: List[float] = []
        start_time = time.perf_counter()

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, num_points)
            batch_points = [
                QdrantPoint(
                    id=f"bench_batch_{i}",
                    vector=[0.1 * (i % 100)] * vector_dim,
                    payload={"benchmark": True, "batch": batch_idx},
                )
                for i in range(start_idx, end_idx)
            ]

            # Only the repository call is timed per batch
            op_start = time.perf_counter()
            await self._repository.store_points(batch_points)
            latencies.append((time.perf_counter() - op_start) * 1000)

        total_time = time.perf_counter() - start_time

        return self._build_metrics(
            operation="batch_insert",
            total_operations=num_batches,
            total_time=total_time,
            latencies=latencies,
            success_count=num_batches,
            error_count=0,
            metadata={
                "total_points": num_points,
                "batch_size": batch_size,
                "vector_dim": vector_dim,
            },
        )

    async def benchmark_search(
        self,
        num_searches: int = 100,
        vector_dim: int = 384,
        limit: int = 10,
    ) -> BenchmarkMetrics:
        """
        Benchmark similarity search.

        The same constant query vector is used for every search.

        Args:
            num_searches: Number of searches to perform
            vector_dim: Vector dimensions
            limit: Results per search

        Returns:
            BenchmarkMetrics for searches
        """
        query_vector = [0.1] * vector_dim
        latencies: List[float] = []
        start_time = time.perf_counter()

        for _ in range(num_searches):
            op_start = time.perf_counter()
            await self._repository.search_similar(query_vector, limit=limit)
            latencies.append((time.perf_counter() - op_start) * 1000)

        total_time = time.perf_counter() - start_time

        return self._build_metrics(
            operation="search",
            total_operations=num_searches,
            total_time=total_time,
            latencies=latencies,
            success_count=len(latencies),
            error_count=0,
            metadata={"vector_dim": vector_dim, "result_limit": limit},
        )

    async def benchmark_concurrent_operations(
        self,
        num_operations: int = 100,
        concurrency: int = 10,
        vector_dim: int = 384,
    ) -> BenchmarkMetrics:
        """
        Benchmark concurrent operations.

        Inserts run in waves of ``concurrency`` tasks; a wave must finish
        before the next one starts. Individual latencies are not measured,
        so min/max/percentile fields are reported as 0 and the average is
        total time divided by the number of operations.

        Args:
            num_operations: Total operations
            concurrency: Concurrent tasks
            vector_dim: Vector dimensions

        Returns:
            BenchmarkMetrics for concurrent ops

        Raises:
            ValueError: If concurrency is smaller than 1
        """
        if concurrency < 1:
            raise ValueError("concurrency must be at least 1")

        async def concurrent_insert(idx: int) -> None:
            point = QdrantPoint(
                id=f"bench_concurrent_{idx}",
                vector=[0.1 * idx] * vector_dim,
                payload={"benchmark": True, "index": idx},
            )
            await self._repository.store_point(point)

        success_count = 0
        error_count = 0
        start_time = time.perf_counter()

        # Execute in batches of 'concurrency'. Coroutines are created per
        # batch so none is left un-awaited if the run stops early.
        for batch_start in range(0, num_operations, concurrency):
            indices = range(batch_start, min(batch_start + concurrency, num_operations))
            outcomes = await asyncio.gather(
                *(concurrent_insert(i) for i in indices),
                return_exceptions=True,
            )

            # gather() hands failures back as exception objects; count them
            # instead of reporting every operation as a success.
            for idx, outcome in zip(indices, outcomes):
                if isinstance(outcome, BaseException):
                    error_count += 1
                    self._logger.warning(
                        "Operation failed",
                        operation="concurrent_insert",
                        iteration=idx,
                        error=str(outcome),
                    )
                else:
                    success_count += 1

        total_time = time.perf_counter() - start_time
        has_work = num_operations > 0

        return BenchmarkMetrics(
            operation="concurrent_insert",
            total_operations=num_operations,
            total_time=total_time,
            operations_per_second=(
                num_operations / total_time if total_time > 0 else 0.0
            ),
            avg_latency_ms=(total_time / num_operations) * 1000 if has_work else 0.0,
            min_latency_ms=0,
            max_latency_ms=0,
            p50_latency_ms=0,
            p95_latency_ms=0,
            p99_latency_ms=0,
            success_count=success_count,
            error_count=error_count,
            metadata={"concurrency": concurrency, "vector_dim": vector_dim},
        )

    async def run_full_benchmark(
        self,
        small_dataset: bool = True,
        medium_dataset: bool = False,
        large_dataset: bool = False,
    ) -> BenchmarkResult:
        """
        Run comprehensive benchmark suite.

        Args:
            small_dataset: Run small dataset tests (1K points)
            medium_dataset: Run medium dataset tests (10K points)
            large_dataset: Run large dataset tests (100K points)

        Returns:
            BenchmarkResult with all metrics
        """
        started_at = time.time()  # wall-clock timestamp for the report
        start_time = time.perf_counter()  # monotonic clock for the duration
        all_metrics: List[BenchmarkMetrics] = []

        self._logger.info("Starting full benchmark suite")

        if small_dataset:
            self._logger.info("Running small dataset benchmarks (1K points)")
            all_metrics.append(await self.benchmark_insert(1000, 384))
            all_metrics.append(await self.benchmark_batch_insert(1000, 100, 384))
            all_metrics.append(await self.benchmark_search(100, 384, 10))
            all_metrics.append(await self.benchmark_concurrent_operations(100, 10, 384))

        if medium_dataset:
            self._logger.info("Running medium dataset benchmarks (10K points)")
            all_metrics.append(await self.benchmark_insert(10000, 384))
            all_metrics.append(await self.benchmark_batch_insert(10000, 500, 384))
            all_metrics.append(await self.benchmark_search(200, 384, 10))

        if large_dataset:
            self._logger.info("Running large dataset benchmarks (100K points)")
            all_metrics.append(await self.benchmark_batch_insert(100000, 1000, 384))
            all_metrics.append(await self.benchmark_search(500, 384, 10))

        duration = time.perf_counter() - start_time

        result = BenchmarkResult(
            benchmark_name="Qdrant Full Suite",
            metrics=all_metrics,
            duration=duration,
            timestamp=started_at,
            metadata={
                "small_dataset": small_dataset,
                "medium_dataset": medium_dataset,
                "large_dataset": large_dataset,
            },
        )

        self._logger.info(
            "Benchmark suite completed",
            duration=duration,
            num_metrics=len(all_metrics),
        )

        return result
class QdrantBackup:
    """
    Backup and restore Qdrant collections.

    Handles export/import of collection data through the injected
    repository. Backup files are read and written as UTF-8.
    """

    def __init__(self, repository: QdrantRepository):
        """
        Initialize backup manager.

        Args:
            repository: Qdrant repository instance
        """
        self._repository = repository

    @staticmethod
    def _ensure_supported_format(format: str) -> None:
        """
        Reject backup formats this class cannot read or write.

        Args:
            format: Backup format identifier

        Raises:
            ValueError: If the format is neither JSON nor JSONL
        """
        if format not in (BackupFormat.JSON, BackupFormat.JSONL):
            raise ValueError(f"Unsupported format: {format}")

    async def backup_to_file(
        self,
        file_path: str,
        format: str = BackupFormat.JSONL,
        batch_size: int = 100,
    ) -> bool:
        """
        Backup collection to file.

        All points are collected in memory before the file is written.
        Any failure is logged and reported as False rather than raised.

        Args:
            file_path: Path to backup file
            format: Backup format (json or jsonl)
            batch_size: Batch size for scrolling

        Returns:
            True if successful
        """
        try:
            # Fail fast: do not scroll the whole collection only to find out
            # afterwards that the requested format cannot be written.
            self._ensure_supported_format(format)

            path = Path(file_path)
            path.parent.mkdir(parents=True, exist_ok=True)

            logger.info(
                "Starting collection backup",
                file=file_path,
                format=format,
            )

            # Get all points using pagination
            all_points: List[Dict] = []
            offset: Optional[str] = None

            while True:
                points, next_offset = await self._repository.scroll_points(
                    limit=batch_size,
                    offset=offset,
                    with_vectors=True,
                )

                if not points:
                    break

                # Convert points to dict format
                all_points.extend(
                    {
                        "id": point.id,
                        "vector": point.vector,
                        "payload": point.payload,
                    }
                    for point in points
                )

                logger.debug(f"Backed up {len(points)} points")

                if next_offset is None:
                    break

                # NOTE(review): the offset is stringified before being passed
                # back to scroll_points; confirm the repository accepts string
                # offsets when point IDs are integers.
                offset = str(next_offset)

            # Write to file
            if format == BackupFormat.JSON:
                await self._write_json(path, all_points)
            else:
                await self._write_jsonl(path, all_points)

            logger.info(
                "Collection backup completed",
                file=file_path,
                points_count=len(all_points),
            )
            return True

        except Exception as e:
            logger.error("Backup failed", error=str(e))
            return False

    async def restore_from_file(
        self,
        file_path: str,
        format: str = BackupFormat.JSONL,
        batch_size: int = 100,
        clear_existing: bool = False,
    ) -> bool:
        """
        Restore collection from file.

        The backup is fully read and converted to points before any existing
        data is touched, so an unreadable file, an unsupported format, a
        malformed record or an invalid batch size cannot leave the collection
        emptied. Any failure is logged and reported as False.

        Args:
            file_path: Path to backup file
            format: Backup format (json or jsonl)
            batch_size: Batch size for uploading
            clear_existing: Whether to clear existing data

        Returns:
            True if successful
        """
        try:
            path = Path(file_path)
            if not path.exists():
                logger.error("Backup file not found", file=file_path)
                return False

            self._ensure_supported_format(format)
            if batch_size < 1:
                raise ValueError("batch_size must be at least 1")

            logger.info(
                "Starting collection restore",
                file=file_path,
                format=format,
                clear_existing=clear_existing,
            )

            # Read from file
            if format == BackupFormat.JSON:
                points_data = await self._read_json(path)
            else:
                points_data = await self._read_jsonl(path)

            # Convert to QdrantPoint objects before the destructive step so a
            # record missing a key fails while the old data still exists.
            points = [
                QdrantPoint(
                    id=p["id"],
                    vector=p["vector"],
                    payload=p["payload"],
                )
                for p in points_data
            ]

            # Clear existing data if requested
            if clear_existing:
                await self._repository.delete_collection()
                await self._repository.create_collection()

            # Upload in batches
            total_restored = 0
            for i in range(0, len(points), batch_size):
                count = await self._repository.store_points(points[i : i + batch_size])
                total_restored += count
                logger.debug(f"Restored {count} points")

            logger.info(
                "Collection restore completed",
                file=file_path,
                points_count=total_restored,
            )
            return True

        except Exception as e:
            logger.error("Restore failed", error=str(e))
            return False

    async def _write_json(self, path: Path, data: List[Dict]) -> None:
        """
        Write data as JSON array.

        Args:
            path: File path
            data: Data to write
        """
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    async def _write_jsonl(self, path: Path, data: List[Dict]) -> None:
        """
        Write data as JSON Lines.

        Args:
            path: File path
            data: Data to write
        """
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(item) + "\n" for item in data)

    async def _read_json(self, path: Path) -> List[Dict]:
        """
        Read data from JSON array.

        Args:
            path: File path

        Returns:
            List of data items
        """
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    async def _read_jsonl(self, path: Path) -> List[Dict]:
        """
        Read data from JSON Lines.

        Blank or whitespace-only lines are skipped.

        Args:
            path: File path

        Returns:
            List of data items
        """
        with open(path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    async def get_backup_info(self, file_path: str) -> Optional[Dict]:
        """
        Get information about backup file.

        The format is inferred from the file suffix (".jsonl" means JSON
        Lines, anything else is treated as a JSON array) and the whole file
        is parsed to count its points.

        Args:
            file_path: Path to backup file

        Returns:
            Backup information dictionary, or None if the file is missing or
            cannot be read
        """
        try:
            path = Path(file_path)
            if not path.exists():
                return None

            # Determine format
            detected_format = (
                BackupFormat.JSONL if path.suffix == ".jsonl" else BackupFormat.JSON
            )

            # Read data
            if detected_format == BackupFormat.JSON:
                data = await self._read_json(path)
            else:
                data = await self._read_jsonl(path)

            # One stat() call so size and mtime describe the same file state
            file_stat = path.stat()

            return {
                "file_path": str(path),
                "file_size": file_stat.st_size,
                "format": detected_format,
                "points_count": len(data),
                "modified_time": file_stat.st_mtime,
            }

        except Exception as e:
            logger.error("Get backup info failed", error=str(e))
            return None
async def list_snapshots(self) -> List[Dict]:
    """
    List collection snapshots.

    Asks the client for the snapshots of this manager's collection and
    reduces each one to its name, creation time and size. Any failure is
    logged and reported as an empty list.

    Returns:
        List of snapshot information
    """
    summaries: List[Dict] = []

    try:
        found = await self._client.list_snapshots(
            collection_name=self._collection_name
        )

        for snapshot in found:
            summaries.append(
                {
                    "name": snapshot.name,
                    "creation_time": snapshot.creation_time,
                    "size": snapshot.size,
                }
            )
    except Exception as e:
        logger.error("List snapshots failed", error=str(e))
        return []

    return summaries
async def create_qdrant_client() -> AsyncQdrantClient:
    """
    Create Qdrant async client connection.

    Builds a client from ``config.qdrant_host`` / ``config.qdrant_port`` and
    verifies it by listing collections. If that probe fails, the
    half-initialised client is closed before the error is raised so it is
    not leaked.

    Returns:
        Qdrant async client

    Raises:
        ConnectionError: If connection fails (chained to the original error)
    """
    client: Optional[AsyncQdrantClient] = None

    try:
        client = AsyncQdrantClient(
            host=config.qdrant_host,
            port=config.qdrant_port,
            timeout=30,
        )

        # Test connection
        await client.get_collections()

        logger.info(
            "Qdrant client connected",
            host=config.qdrant_host,
            port=config.qdrant_port,
        )

        return client

    except Exception as e:
        logger.error("Qdrant connection failed", error=str(e))

        if client is not None:
            # Best-effort cleanup; a failing close must not mask the
            # connection error the caller needs to see.
            try:
                await client.close()
            except Exception as close_error:
                logger.error(
                    "Failed to close Qdrant client", error=str(close_error)
                )

        raise ConnectionError(f"Failed to connect to Qdrant: {e}") from e
class QdrantCollectionManager:
    """
    Manages Qdrant collection initialization.

    Ensures collection exists and is properly configured. All work is
    delegated to the injected repository.
    """

    def __init__(self, repository: QdrantRepository):
        """
        Initialize collection manager.

        Args:
            repository: Qdrant repository
        """
        self._repository = repository

    async def initialize(
        self, distance: Distance = Distance.COSINE, recreate: bool = False
    ) -> bool:
        """
        Initialize collection for vector storage.

        Exceptions are logged and reported as False rather than raised.

        Args:
            distance: Distance metric for similarity
            recreate: Whether to recreate existing collection

        Returns:
            True if initialized successfully
        """
        try:
            if recreate:
                # Propagate the outcome of the re-creation instead of
                # reporting success unconditionally.
                return await self._recreate_collection(distance)

            return await self._ensure_collection_exists(distance)

        except Exception as e:
            logger.error("Collection initialization failed", error=str(e))
            return False

    async def _ensure_collection_exists(self, distance: Distance) -> bool:
        """
        Ensure collection exists.

        Creates the collection only when the repository reports that it is
        missing.

        Args:
            distance: Distance metric

        Returns:
            True if exists or created
        """
        exists = await self._repository.collection_exists()

        if exists:
            logger.info("Collection verified")
            return True

        return await self._repository.create_collection(distance)

    async def _recreate_collection(self, distance: Distance) -> bool:
        """
        Recreate collection (delete and create).

        Args:
            distance: Distance metric

        Returns:
            True if recreated successfully
        """
        logger.warning("Recreating collection - all data will be lost")

        # Delete if exists
        exists = await self._repository.collection_exists()
        if exists:
            await self._repository.delete_collection()

        # Create new collection
        return await self._repository.create_collection(distance)

    async def validate_collection(self) -> dict[str, bool]:
        """
        Validate collection configuration.

        Checks run in order (exists, accessible, configured) and stop at the
        first one that fails, leaving the remaining flags False. Exceptions
        are logged and the flags gathered so far are returned.

        Returns:
            Validation results dict
        """
        results = {
            "exists": False,
            "accessible": False,
            "configured": False,
        }

        try:
            # Check existence
            results["exists"] = await self._repository.collection_exists()
            if not results["exists"]:
                return results

            # Check accessibility
            results["accessible"] = await self._repository.ping()
            if not results["accessible"]:
                return results

            # Check configuration
            info = await self._repository.get_collection_info()
            results["configured"] = info is not None

            return results

        except Exception as e:
            logger.error("Collection validation failed", error=str(e))
            return results

    async def get_status(self) -> Optional[dict]:
        """
        Get collection status and statistics.

        The returned dict always has a "status" key: "not_initialized" when
        the collection is missing, "error" when its info cannot be obtained
        or an exception occurs, and "ready" together with the collection
        statistics otherwise.

        Returns:
            Status dict if successful
        """
        try:
            validation = await self.validate_collection()
            if not validation["exists"]:
                return {
                    "status": "not_initialized",
                    "message": "Collection does not exist",
                }

            info = await self._repository.get_collection_info()
            if not info:
                return {
                    "status": "error",
                    "message": "Failed to get collection info",
                }

            return {
                "status": "ready",
                "vectors_count": info["vectors_count"],
                "points_count": info["points_count"],
                "collection_status": info["status"],
                "config": info["config"],
            }

        except Exception as e:
            logger.error("Get status failed", error=str(e))
            return {
                "status": "error",
                "message": str(e),
            }
class QdrantError(Exception):
    """Root of the Qdrant exception hierarchy."""

    def __init__(self, message: str, cause: Optional[Exception] = None):
        """
        Build the error.

        Args:
            message: Human-readable description of the failure
            cause: Underlying exception, if any, kept on ``self.cause``
        """
        super().__init__(message)
        self.message = message
        self.cause = cause

    def __str__(self) -> str:
        """Return the message, with the cause appended when one is set."""
        if not self.cause:
            return self.message
        return f"{self.message} (caused by: {self.cause!s})"


class QdrantConnectionError(QdrantError):
    """Connecting to Qdrant failed."""


class QdrantCollectionError(QdrantError):
    """A collection-level operation failed."""


class QdrantCollectionNotFoundError(QdrantCollectionError):
    """The requested collection does not exist."""


class QdrantCollectionExistsError(QdrantCollectionError):
    """The collection to be created already exists."""


class QdrantPointError(QdrantError):
    """A point-level operation failed."""


class QdrantPointNotFoundError(QdrantPointError):
    """The requested point does not exist."""


class QdrantSearchError(QdrantError):
    """A search operation failed."""


class QdrantValidationError(QdrantError):
    """Validation failed."""


class QdrantTimeoutError(QdrantError):
    """An operation timed out."""


class QdrantCapacityError(QdrantError):
    """Storage capacity was exceeded."""


class QdrantIndexError(QdrantError):
    """An index operation failed."""
+ """ + Map Qdrant exceptions to custom exceptions. + + Args: + error: Original exception + operation: Operation that failed + + Returns: + Custom Qdrant exception + """ + error_msg = str(error) + error_type = type(error).__name__ + + # Connection errors + if "connect" in error_msg.lower() or "connection" in error_msg.lower(): + logger.error(f"Connection error during {operation}", error=error_msg) + return QdrantConnectionError( + f"Failed to connect to Qdrant during {operation}", cause=error + ) + + # Timeout errors + if "timeout" in error_msg.lower(): + logger.error(f"Timeout during {operation}", error=error_msg) + return QdrantTimeoutError( + f"Operation {operation} timeout exceeded", cause=error + ) + + # Collection errors + if "collection" in error_msg.lower(): + if "not found" in error_msg.lower() or "does not exist" in error_msg.lower(): + logger.error(f"Collection not found during {operation}", error=error_msg) + return QdrantCollectionNotFoundError( + f"Collection not found during {operation}", cause=error + ) + if "already exists" in error_msg.lower(): + logger.error(f"Collection exists during {operation}", error=error_msg) + return QdrantCollectionExistsError( + f"Collection already exists during {operation}", cause=error + ) + logger.error(f"Collection error during {operation}", error=error_msg) + return QdrantCollectionError( + f"Collection operation failed during {operation}", cause=error + ) + + # Point errors + if "point" in error_msg.lower(): + if "not found" in error_msg.lower(): + logger.error(f"Point not found during {operation}", error=error_msg) + return QdrantPointNotFoundError( + f"Point not found during {operation}", cause=error + ) + logger.error(f"Point error during {operation}", error=error_msg) + return QdrantPointError( + f"Point operation failed during {operation}", cause=error + ) + + # Search errors + if "search" in error_msg.lower() or "query" in error_msg.lower(): + logger.error(f"Search error during {operation}", error=error_msg) 
class ErrorContext:
    """
    Context manager for Qdrant error handling.

    Exceptions raised inside the ``with`` body are mapped to the custom
    Qdrant exception types via ``handle_qdrant_error``. Two kinds of
    exception are deliberately left untouched:

    - non-``Exception`` ``BaseException``s (``asyncio.CancelledError``,
      ``KeyboardInterrupt``, ``SystemExit``, ``GeneratorExit``), because
      wrapping them would break task cancellation and interpreter shutdown;
    - exceptions that already are a ``QdrantError``, because re-mapping them
      by message keywords could replace a precise type with a different one.
    """

    def __init__(self, operation: str):
        """
        Initialize error context.

        Args:
            operation: Operation name for error messages
        """
        self.operation = operation

    def __enter__(self) -> "ErrorContext":
        """Enter context."""
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Any,
    ) -> None:
        """
        Exit context and translate the in-flight exception, if any.

        Args:
            exc_type: Exception type
            exc_val: Exception value
            exc_tb: Exception traceback

        Raises:
            QdrantError: Mapped from any ordinary exception that is not
                already a QdrantError
        """
        # Nothing raised, or a control-flow BaseException: returning None
        # lets Python propagate the original (if any) unchanged.
        if exc_val is None or not isinstance(exc_val, Exception):
            return None

        # Already one of ours - keep its exact type and message.
        if isinstance(exc_val, QdrantError):
            return None

        raise handle_qdrant_error(exc_val, self.operation) from exc_val
def range_field(
    self,
    field: str,
    gte: Optional[float] = None,
    gt: Optional[float] = None,
    lte: Optional[float] = None,
    lt: Optional[float] = None,
) -> "QdrantFilterBuilder":
    """
    Require a numeric field to fall inside the given bounds.

    Args:
        field: Field name for range
        gte: Inclusive lower bound
        gt: Exclusive lower bound
        lte: Inclusive upper bound
        lt: Exclusive upper bound

    Returns:
        Self for chaining
    """
    bounds = Range(gte=gte, gt=gt, lte=lte, lt=lt)
    self._must.append(FieldCondition(key=field, range=bounds))
    return self
def build(self) -> Optional[Filter]:
    """
    Assemble the accumulated conditions into a Filter.

    Returns:
        Filter object if at least one condition was added, None otherwise
    """
    must, should, must_not = self._must, self._should, self._must_not

    if not must and not should and not must_not:
        return None

    # Empty clause lists are passed as None rather than [].
    built = Filter(
        must=must or None,
        should=should or None,
        must_not=must_not or None,
    )

    logger.debug(
        "Filter built",
        must_count=len(must),
        should_count=len(should),
        must_not_count=len(must_not),
    )

    return built
async def check_health(self) -> Dict[str, Any]:
    """
    Run the full health check: connection, collection, statistics.

    Returns:
        Dictionary with "status", per-check booleans under "checks",
        and error/warning/statistics under "details"
    """
    checks: Dict[str, Any] = {}
    details: Dict[str, Any] = {}
    report: Dict[str, Any] = {
        "status": HealthStatus.HEALTHY.value,
        "checks": checks,
        "details": details,
    }

    # Without a connection nothing else can be checked.
    checks["connection"] = await self._check_connection()
    if not checks["connection"]:
        report["status"] = HealthStatus.UNHEALTHY.value
        details["error"] = "Cannot connect to Qdrant"
        return report

    # A misconfigured collection degrades the service but is not fatal.
    checks["collection"] = await self._check_collection()
    if not checks["collection"]:
        report["status"] = HealthStatus.DEGRADED.value
        details["warning"] = "Collection not properly configured"

    details["statistics"] = await self._get_collection_stats()

    logger.info("Health check completed", status=report["status"])
    return report
async def is_ready(self) -> bool:
    """
    Check if service is ready to handle requests.

    The collection is only probed when the connection check succeeds,
    mirroring check_health(): probing an unreachable server just costs a
    second failing round-trip and a duplicate error log.

    Returns:
        True if both the connection and the collection checks pass
    """
    if not await self._check_connection():
        return False
    return await self._check_collection()
class IndexOptimizationConfig:
    """
    Value holder for index optimization settings.

    Every setting is optional; the object only stores what it is given.
    """

    def __init__(
        self,
        m: Optional[int] = None,
        ef_construct: Optional[int] = None,
        full_scan_threshold: Optional[int] = None,
        max_indexing_threads: Optional[int] = None,
        on_disk: Optional[bool] = None,
    ):
        """
        Store the optimization settings.

        Args:
            m: Number of edges per node in graph (4-64, default: 16)
            ef_construct: Size of dynamic candidate list (default: 100)
            full_scan_threshold: Threshold for full scan vs HNSW (default: 10000)
            max_indexing_threads: Max threads for indexing (default: 0 = auto)
            on_disk: Store index on disk vs memory (default: False)
        """
        # Graph shape
        self.m, self.ef_construct = m, ef_construct
        # Search / build behaviour
        self.full_scan_threshold, self.max_indexing_threads = (
            full_scan_threshold,
            max_indexing_threads,
        )
        # Storage location
        self.on_disk = on_disk
async def optimize_hnsw(self, config: IndexOptimizationConfig) -> bool:
    """
    Optimize HNSW index configuration.

    All five settings carried by the config are forwarded, including
    max_indexing_threads, which every predefined profile sets.

    Args:
        config: Optimization configuration

    Returns:
        True if successful, False if the update raised (the error is logged)
    """
    try:
        hnsw_config = HnswConfigDiff(
            m=config.m,
            ef_construct=config.ef_construct,
            full_scan_threshold=config.full_scan_threshold,
            max_indexing_threads=config.max_indexing_threads,
            on_disk=config.on_disk,
        )

        await self._client.update_collection(
            collection_name=self._collection_name,
            hnsw_config=hnsw_config,
        )

        logger.info(
            "HNSW index optimized",
            collection=self._collection_name,
            m=config.m,
            ef_construct=config.ef_construct,
        )
        return True

    except Exception as e:
        # Best-effort by design: callers get a boolean, not an exception.
        logger.error("HNSW optimization failed", error=str(e))
        return False
async def get_index_stats(self) -> Optional[Dict]:
    """
    Fetch collection info and report a fixed set of its attributes.

    Returns:
        Dictionary of the reported attributes, or None if the lookup
        failed (the error is logged)
    """
    reported = (
        "vectors_count",
        "points_count",
        "segments_count",
        "status",
        "optimizer_status",
    )

    try:
        info = await self._client.get_collection(
            collection_name=self._collection_name
        )
        return {name: getattr(info, name) for name in reported}

    except Exception as exc:
        logger.error("Get index stats failed", error=str(exc))
        return None
@staticmethod
def recommend_config(
    collection_size: int,
    memory_available_gb: float = 8.0,
) -> IndexOptimizationConfig:
    """
    Pick an index configuration from the collection's size tier.

    Args:
        collection_size: Number of vectors in collection
        memory_available_gb: Available memory in GB

    Returns:
        Recommended configuration
    """
    # (exclusive size limit, m, ef_construct, full_scan_threshold,
    #  memory limit in GB below which the index goes on disk; None = never)
    tiers = (
        (10_000, 8, 64, 5000, None),  # small
        (100_000, 16, 100, 10000, 4.0),  # medium
        (1_000_000, 32, 128, 20000, 8.0),  # large
    )

    for size_limit, m, ef_construct, scan_threshold, disk_below_gb in tiers:
        if collection_size < size_limit:
            on_disk = (
                False if disk_below_gb is None else memory_available_gb < disk_below_gb
            )
            return IndexOptimizationConfig(
                m=m,
                ef_construct=ef_construct,
                full_scan_threshold=scan_threshold,
                on_disk=on_disk,
            )

    # Very large collections (1M vectors and up) always go on disk.
    return IndexOptimizationConfig(
        m=48,
        ef_construct=150,
        full_scan_threshold=50000,
        on_disk=True,
    )
async def optimize_collection(
    client: AsyncQdrantClient,
    collection_name: str,
    profile: Optional[IndexOptimizationConfig] = None,
) -> bool:
    """
    Optimize collection with recommended or custom profile.

    Args:
        client: Qdrant client
        collection_name: Collection to optimize
        profile: Custom profile (uses auto-tuned if None)

    Returns:
        True if the profile was applied; False if applying failed or the
        collection statistics needed for auto-tuning were unavailable
    """
    optimizer = QdrantIndexOptimizer(client, collection_name)

    if profile is None:
        # Auto-tune based on collection size
        stats = await optimizer.get_index_stats()
        if not stats:
            return False

        # get_index_stats() always sets the "vectors_count" key, so a
        # .get() default never applies; the value itself may be None.
        # Fall back to points_count, then 0, so the size comparison in
        # recommend_config() cannot hit None.
        size = stats.get("vectors_count") or stats.get("points_count") or 0
        profile = IndexTuner.recommend_config(size)
        logger.info(
            "Auto-tuned index configuration",
            collection=collection_name,
            size=size,
        )

    return await optimizer.apply_profile(profile)
@staticmethod
def create_from_cache_entry(entry: CacheEntry) -> Dict[str, Any]:
    """
    Create metadata payload from cache entry.

    Args:
        entry: Cache entry

    Returns:
        Metadata dictionary with the entry's fields plus created/cached
        timestamps
    """
    # Read the clock once so both timestamps of a single write are identical.
    now = time.time()

    return {
        QdrantSchema.FIELD_QUERY_HASH: entry.query_hash,
        QdrantSchema.FIELD_ORIGINAL_QUERY: entry.original_query,
        QdrantSchema.FIELD_RESPONSE: entry.response,
        QdrantSchema.FIELD_PROVIDER: entry.provider,
        QdrantSchema.FIELD_MODEL: entry.model,
        QdrantSchema.FIELD_PROMPT_TOKENS: entry.prompt_tokens,
        QdrantSchema.FIELD_COMPLETION_TOKENS: entry.completion_tokens,
        QdrantSchema.FIELD_CREATED_AT: now,
        QdrantSchema.FIELD_CACHED_AT: now,
    }
@staticmethod
def add_tags(payload: Dict[str, Any], tags: List[str]) -> Dict[str, Any]:
    """
    Add tags to payload.

    The payload is updated in place and also returned. Duplicates are
    dropped and first-seen order is kept (existing tags first), so the
    result is deterministic across runs.

    Args:
        payload: Existing payload
        tags: Tags to add

    Returns:
        Updated payload (the same object that was passed in)
    """
    # "or ()" tolerates a stored None; unpacking tolerates tuples as well
    # as lists on either side.
    existing_tags = payload.get(QdrantSchema.FIELD_TAGS) or ()
    payload[QdrantSchema.FIELD_TAGS] = list(dict.fromkeys([*existing_tags, *tags]))
    return payload
@staticmethod
def merge_payloads(base: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge two payloads with conflict resolution.

    Plain fields from ``updates`` override those in ``base``. Tags are
    combined (deduplicated, base tags first, order preserved) and custom
    metadata dictionaries are merged key-by-key when both payloads carry
    them. Neither input is modified.

    Args:
        base: Base payload
        updates: Update payload

    Returns:
        Merged payload (a new dictionary)
    """
    merged = {**base, **updates}

    # Tags: union instead of overwrite. dict.fromkeys keeps insertion order,
    # so the result does not depend on set iteration order.
    tags_key = QdrantSchema.FIELD_TAGS
    if tags_key in base and tags_key in updates:
        merged[tags_key] = list(
            dict.fromkeys([*base[tags_key], *updates[tags_key]])
        )

    # Custom metadata: shallow dict merge, updates win on key conflicts.
    metadata_key = QdrantSchema.FIELD_METADATA
    if metadata_key in base and metadata_key in updates:
        merged[metadata_key] = {**base[metadata_key], **updates[metadata_key]}

    return merged
class PooledConnection:
    """
    Bookkeeping record for one client held by the pool.

    Stores the client together with creation time, last-use time, an
    in-use flag and a use counter. Times come from the event loop clock.
    """

    def __init__(self, client: AsyncQdrantClient):
        """
        Wrap a client and stamp its creation time.

        Args:
            client: Qdrant client instance
        """
        self.client = client
        self.created_at = self._now()
        self.last_used = self.created_at
        self.in_use = False
        self.use_count = 0

    @staticmethod
    def _now() -> float:
        """Return the current event loop time."""
        return asyncio.get_event_loop().time()

    def mark_used(self) -> None:
        """Flag the connection as checked out and count the use."""
        self.in_use = True
        self.use_count += 1
        self.last_used = self._now()

    def mark_released(self) -> None:
        """Flag the connection as returned and refresh its last-use time."""
        self.in_use = False
        self.last_used = self._now()

    def is_expired(self, max_lifetime: float) -> bool:
        """
        Tell whether the connection is older than max_lifetime.

        Args:
            max_lifetime: Maximum lifetime in seconds

        Returns:
            True if expired
        """
        return self._now() - self.created_at > max_lifetime

    def is_idle_expired(self, idle_timeout: float) -> bool:
        """
        Tell whether an unused connection has sat idle past idle_timeout.

        Args:
            idle_timeout: Idle timeout in seconds

        Returns:
            True if not in use and idle too long
        """
        idle_for = self._now() - self.last_used
        if self.in_use:
            return False
        return idle_for > idle_timeout
+ """ + + def __init__(self, config: Optional[PoolConfig] = None): + """ + Initialize connection pool. + + Args: + config: Pool configuration + """ + self._config = config or PoolConfig() + self._pool: List[PooledConnection] = [] + self._lock = asyncio.Lock() + self._closed = False + self._cleanup_task: Optional[asyncio.Task] = None + + async def initialize(self) -> None: + """Initialize pool with minimum connections.""" + async with self._lock: + if self._closed: + raise QdrantConnectionError("Pool is closed") + + # Create minimum connections + for _ in range(self._config.min_size): + await self._create_connection() + + # Start cleanup task + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + + logger.info( + "Connection pool initialized", + min_size=self._config.min_size, + max_size=self._config.max_size, + ) + + async def acquire(self) -> AsyncQdrantClient: + """ + Acquire a connection from the pool. + + Returns: + Qdrant client + + Raises: + QdrantConnectionError: If unable to acquire connection + """ + try: + return await asyncio.wait_for( + self._acquire_internal(), + timeout=self._config.acquire_timeout, + ) + except asyncio.TimeoutError: + raise QdrantConnectionError( + f"Timeout acquiring connection after {self._config.acquire_timeout}s" + ) + + async def _acquire_internal(self) -> AsyncQdrantClient: + """ + Internal acquire logic. 
+ + Returns: + Qdrant client + """ + while True: + async with self._lock: + if self._closed: + raise QdrantConnectionError("Pool is closed") + + # Find available connection + for conn in self._pool: + if not conn.in_use: + # Check if expired + if conn.is_expired(self._config.max_lifetime): + await self._remove_connection(conn) + continue + + conn.mark_used() + logger.debug( + "Connection acquired from pool", + pool_size=len(self._pool), + use_count=conn.use_count, + ) + return conn.client + + # Create new connection if below max + if len(self._pool) < self._config.max_size: + conn = await self._create_connection() + conn.mark_used() + logger.debug( + "New connection created and acquired", + pool_size=len(self._pool), + ) + return conn.client + + # Wait briefly before retrying + await asyncio.sleep(0.1) + + async def release(self, client: AsyncQdrantClient) -> None: + """ + Release a connection back to the pool. + + Args: + client: Qdrant client to release + """ + async with self._lock: + for conn in self._pool: + if conn.client is client: + conn.mark_released() + logger.debug( + "Connection released to pool", + pool_size=len(self._pool), + in_use_count=sum(1 for c in self._pool if c.in_use), + ) + return + + async def close(self) -> None: + """Close all connections in the pool.""" + async with self._lock: + if self._closed: + return + + self._closed = True + + # Cancel cleanup task + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + + # Close all connections + for conn in self._pool[:]: + await self._remove_connection(conn) + + logger.info("Connection pool closed") + + async def _create_connection(self) -> PooledConnection: + """ + Create a new pooled connection. 
+ + Returns: + Pooled connection + """ + client = await create_qdrant_client() + conn = PooledConnection(client) + self._pool.append(conn) + return conn + + async def _remove_connection(self, conn: PooledConnection) -> None: + """ + Remove and close a connection. + + Args: + conn: Connection to remove + """ + try: + await conn.client.close() + except Exception as e: + logger.error("Error closing connection", error=str(e)) + finally: + if conn in self._pool: + self._pool.remove(conn) + + async def _cleanup_loop(self) -> None: + """Background task to cleanup expired connections.""" + while not self._closed: + try: + await asyncio.sleep(60) # Cleanup every minute + await self._cleanup_expired() + except asyncio.CancelledError: + break + except Exception as e: + logger.error("Cleanup loop error", error=str(e)) + + async def _cleanup_expired(self) -> None: + """Remove expired and idle connections.""" + async with self._lock: + expired = [] + + for conn in self._pool: + # Skip connections in use + if conn.in_use: + continue + + # Check lifetime + if conn.is_expired(self._config.max_lifetime): + expired.append(conn) + continue + + # Check idle timeout (keep minimum connections) + if len(self._pool) > self._config.min_size: + if conn.is_idle_expired(self._config.idle_timeout): + expired.append(conn) + + # Remove expired connections + for conn in expired: + await self._remove_connection(conn) + + if expired: + logger.info( + "Cleaned up expired connections", + removed=len(expired), + remaining=len(self._pool), + ) + + def get_stats(self) -> Dict[str, int]: + """ + Get pool statistics. 
+ + Returns: + Statistics dictionary + """ + return { + "total": len(self._pool), + "in_use": sum(1 for conn in self._pool if conn.in_use), + "available": sum(1 for conn in self._pool if not conn.in_use), + "min_size": self._config.min_size, + "max_size": self._config.max_size, + } + + +# Global pool instance +_global_pool: Optional[QdrantConnectionPool] = None + + +async def get_pool() -> QdrantConnectionPool: + """ + Get or create global connection pool. + + Returns: + Connection pool instance + """ + global _global_pool + + if _global_pool is None: + _global_pool = QdrantConnectionPool() + await _global_pool.initialize() + + return _global_pool + + +async def close_pool() -> None: + """Close global connection pool.""" + global _global_pool + + if _global_pool is not None: + await _global_pool.close() + _global_pool = None diff --git a/app/cache/qdrant_retry.py b/app/cache/qdrant_retry.py new file mode 100644 index 0000000..426e740 --- /dev/null +++ b/app/cache/qdrant_retry.py @@ -0,0 +1,239 @@ +""" +Retry mechanism for Qdrant operations. + +Sandi Metz Principles: +- Single Responsibility: Retry logic +- Small methods: Each retry strategy isolated +- Clear naming: Descriptive function names +""" + +import asyncio +from functools import wraps +from typing import Any, Callable, Optional, TypeVar + +from app.cache.qdrant_errors import is_retryable_error +from app.utils.logger import get_logger + +logger = get_logger(__name__) + +T = TypeVar("T") + + +class RetryConfig: + """ + Configuration for retry behavior. + + Defines retry parameters and backoff strategy. + """ + + def __init__( + self, + max_attempts: int = 3, + initial_delay: float = 1.0, + max_delay: float = 60.0, + exponential_base: float = 2.0, + jitter: bool = True, + ): + """ + Initialize retry configuration. 
+ + Args: + max_attempts: Maximum number of retry attempts + initial_delay: Initial delay in seconds + max_delay: Maximum delay in seconds + exponential_base: Base for exponential backoff + jitter: Whether to add random jitter to delays + """ + self.max_attempts = max_attempts + self.initial_delay = initial_delay + self.max_delay = max_delay + self.exponential_base = exponential_base + self.jitter = jitter + + def get_delay(self, attempt: int) -> float: + """ + Calculate delay for given attempt. + + Args: + attempt: Attempt number (0-indexed) + + Returns: + Delay in seconds + """ + import random + + # Exponential backoff + delay = self.initial_delay * (self.exponential_base**attempt) + + # Cap at max delay + delay = min(delay, self.max_delay) + + # Add jitter if enabled + if self.jitter: + delay = delay * (0.5 + random.random()) + + return delay + + +def retry_on_error( + config: Optional[RetryConfig] = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Decorator to retry async functions on transient errors. 
+ + Args: + config: Retry configuration + + Returns: + Decorator function + """ + if config is None: + config = RetryConfig() + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + last_error: Optional[Exception] = None + + for attempt in range(config.max_attempts): + try: + return await func(*args, **kwargs) + except Exception as e: + last_error = e + + # Check if error is retryable + if not is_retryable_error(e): + logger.warning( + f"Non-retryable error in {func.__name__}", + error=str(e), + ) + raise + + # Check if this was the last attempt + if attempt == config.max_attempts - 1: + logger.error( + f"Max retries exceeded for {func.__name__}", + attempts=config.max_attempts, + error=str(e), + ) + raise + + # Calculate delay and wait + delay = config.get_delay(attempt) + logger.warning( + f"Retrying {func.__name__} after error", + attempt=attempt + 1, + max_attempts=config.max_attempts, + delay=delay, + error=str(e), + ) + await asyncio.sleep(delay) + + # Should never reach here, but just in case + if last_error: + raise last_error + raise RuntimeError(f"Retry logic failed for {func.__name__}") + + return wrapper + + return decorator + + +async def retry_async( + func: Callable[..., Any], + *args: Any, + config: Optional[RetryConfig] = None, + **kwargs: Any, +) -> Any: + """ + Retry an async function with exponential backoff. 
+ + Args: + func: Async function to retry + *args: Positional arguments for function + config: Retry configuration + **kwargs: Keyword arguments for function + + Returns: + Function result + + Raises: + Last exception if all retries fail + """ + if config is None: + config = RetryConfig() + + last_error: Optional[Exception] = None + + for attempt in range(config.max_attempts): + try: + return await func(*args, **kwargs) + except Exception as e: + last_error = e + + if not is_retryable_error(e): + raise + + if attempt == config.max_attempts - 1: + logger.error( + "Max retries exceeded", + function=func.__name__, + attempts=config.max_attempts, + error=str(e), + ) + raise + + delay = config.get_delay(attempt) + logger.warning( + "Retrying after error", + function=func.__name__, + attempt=attempt + 1, + max_attempts=config.max_attempts, + delay=delay, + error=str(e), + ) + await asyncio.sleep(delay) + + if last_error: + raise last_error + raise RuntimeError("Retry logic failed") + + +class RetryPolicy: + """ + Retry policy for different operation types. + + Provides predefined retry configurations. + """ + + # Quick operations (search, get) + QUICK = RetryConfig( + max_attempts=2, + initial_delay=0.5, + max_delay=2.0, + exponential_base=2.0, + ) + + # Standard operations (upsert, delete) + STANDARD = RetryConfig( + max_attempts=3, + initial_delay=1.0, + max_delay=10.0, + exponential_base=2.0, + ) + + # Long operations (batch, collection create) + LONG = RetryConfig( + max_attempts=5, + initial_delay=2.0, + max_delay=60.0, + exponential_base=2.0, + ) + + # Critical operations (health check, ping) + CRITICAL = RetryConfig( + max_attempts=3, + initial_delay=0.1, + max_delay=1.0, + exponential_base=1.5, + ) diff --git a/app/models/qdrant_metrics.py b/app/models/qdrant_metrics.py new file mode 100644 index 0000000..1854b73 --- /dev/null +++ b/app/models/qdrant_metrics.py @@ -0,0 +1,159 @@ +""" +Qdrant metrics models. 
+ +Sandi Metz Principles: +- Single Responsibility: Metrics modeling +- Small class: Focused data structures +- Clear naming: Descriptive field names +""" + +from typing import Dict, Optional + +from pydantic import BaseModel, Field + + +class QdrantMetrics(BaseModel): + """ + Qdrant collection metrics. + + Tracks operational statistics. + """ + + # Collection stats + total_points: int = Field(default=0, ge=0, description="Total points") + total_vectors: int = Field(default=0, ge=0, description="Total vectors") + + # Operation counters + searches_performed: int = Field(default=0, ge=0, description="Search count") + points_added: int = Field(default=0, ge=0, description="Points added") + points_updated: int = Field(default=0, ge=0, description="Points updated") + points_deleted: int = Field(default=0, ge=0, description="Points deleted") + + # Performance metrics + avg_search_time_ms: float = Field( + default=0.0, ge=0.0, description="Avg search time" + ) + avg_upload_time_ms: float = Field( + default=0.0, ge=0.0, description="Avg upload time" + ) + + # Cache metrics + semantic_hits: int = Field(default=0, ge=0, description="Semantic cache hits") + semantic_misses: int = Field(default=0, ge=0, description="Semantic cache misses") + + # Error tracking + errors_count: int = Field(default=0, ge=0, description="Error count") + last_error: Optional[str] = Field(None, description="Last error message") + + @property + def cache_hit_rate(self) -> float: + """Calculate cache hit rate.""" + total = self.semantic_hits + self.semantic_misses + if total == 0: + return 0.0 + return self.semantic_hits / total + + @property + def total_operations(self) -> int: + """Calculate total operations.""" + return ( + self.searches_performed + + self.points_added + + self.points_updated + + self.points_deleted + ) + + +class OperationMetrics(BaseModel): + """ + Metrics for a specific operation. + + Tracks timing and success rate. 
+ """ + + operation_name: str = Field(..., description="Operation name") + total_count: int = Field(default=0, ge=0, description="Total executions") + success_count: int = Field(default=0, ge=0, description="Successful executions") + failure_count: int = Field(default=0, ge=0, description="Failed executions") + total_time_ms: float = Field( + default=0.0, ge=0.0, description="Total execution time" + ) + min_time_ms: float = Field(default=0.0, ge=0.0, description="Minimum time") + max_time_ms: float = Field(default=0.0, ge=0.0, description="Maximum time") + + @property + def success_rate(self) -> float: + """Calculate success rate.""" + if self.total_count == 0: + return 0.0 + return self.success_count / self.total_count + + @property + def avg_time_ms(self) -> float: + """Calculate average execution time.""" + if self.total_count == 0: + return 0.0 + return self.total_time_ms / self.total_count + + +class SearchMetrics(BaseModel): + """ + Semantic search specific metrics. + + Tracks search performance and quality. 
+ """ + + total_searches: int = Field(default=0, ge=0, description="Total searches") + avg_results_per_search: float = Field( + default=0.0, ge=0.0, description="Avg results" + ) + avg_search_time_ms: float = Field( + default=0.0, ge=0.0, description="Avg search time" + ) + avg_similarity_score: float = Field( + default=0.0, ge=0.0, le=1.0, description="Avg similarity" + ) + + # Score distribution + high_quality_matches: int = Field(default=0, ge=0, description="Score >= 0.9") + medium_quality_matches: int = Field( + default=0, ge=0, description="0.7 <= Score < 0.9" + ) + low_quality_matches: int = Field(default=0, ge=0, description="Score < 0.7") + + @property + def high_quality_rate(self) -> float: + """Calculate high quality match rate.""" + total = ( + self.high_quality_matches + + self.medium_quality_matches + + self.low_quality_matches + ) + if total == 0: + return 0.0 + return self.high_quality_matches / total + + +class MetricsSummary(BaseModel): + """ + Complete metrics summary. + + Aggregates all metric types. + """ + + collection_metrics: QdrantMetrics + search_metrics: SearchMetrics + operation_metrics: Dict[str, OperationMetrics] = Field(default_factory=dict) + uptime_seconds: float = Field(default=0.0, ge=0.0, description="Service uptime") + + def to_dict(self) -> Dict: + """Convert to dictionary for export.""" + return { + "collection": self.collection_metrics.model_dump(), + "search": self.search_metrics.model_dump(), + "operations": { + name: metrics.model_dump() + for name, metrics in self.operation_metrics.items() + }, + "uptime_seconds": self.uptime_seconds, + } diff --git a/app/models/qdrant_point.py b/app/models/qdrant_point.py new file mode 100644 index 0000000..2d8ea56 --- /dev/null +++ b/app/models/qdrant_point.py @@ -0,0 +1,197 @@ +""" +Qdrant point models for vector storage. 
+ +Sandi Metz Principles: +- Single Responsibility: Point data modeling +- Small class: Focused on point representation +- Clear naming: Descriptive field names +""" + +import time +from datetime import datetime +from typing import Any, Dict, List, Optional +from uuid import uuid4 + +from pydantic import BaseModel, Field, ValidationError +from qdrant_client.models import PointStruct + +from app.models.cache_entry import CacheEntry + + +class QdrantPoint(BaseModel): + """ + Represents a point (vector + payload) in Qdrant. + + Combines embedding vector with metadata. + """ + + id: str = Field(default_factory=lambda: str(uuid4()), description="Point ID") + vector: List[float] = Field(..., description="Embedding vector") + payload: Dict[str, Any] = Field(default_factory=dict, description="Metadata") + + @classmethod + def from_cache_entry( + cls, entry: CacheEntry, embedding: List[float] + ) -> "QdrantPoint": + """ + Create point from cache entry. + + Args: + entry: Cache entry with query and response + embedding: Vector embedding of query + + Returns: + QdrantPoint instance + """ + payload = { + "query_hash": entry.query_hash, + "original_query": entry.original_query, + "response": entry.response, + "provider": entry.provider, + "model": entry.model, + "prompt_tokens": entry.prompt_tokens, + "completion_tokens": entry.completion_tokens, + "created_at": entry.created_at.timestamp(), + "cached_at": time.time(), + } + + # Add optional fields + if entry.embedding is not None: + payload["has_embedding"] = True + + return cls(vector=embedding, payload=payload) + + def to_qdrant_point(self) -> PointStruct: + """ + Convert to Qdrant PointStruct. + + Returns: + PointStruct for Qdrant API + """ + return PointStruct(id=self.id, vector=self.vector, payload=self.payload) + + @classmethod + def from_qdrant_point( + cls, point_id: str, vector: List[float], payload: Dict + ) -> "QdrantPoint": + """ + Create from Qdrant point data. 
+ + Args: + point_id: Point ID + vector: Embedding vector + payload: Metadata dict + + Returns: + QdrantPoint instance + """ + return cls(id=point_id, vector=vector, payload=payload) + + +class SearchResult(BaseModel): + """ + Result from vector similarity search. + + Contains matched point and similarity score. + """ + + point_id: str = Field(..., description="Matched point ID") + score: float = Field(..., ge=0.0, le=1.0, description="Similarity score") + vector: Optional[List[float]] = Field(None, description="Embedding vector") + payload: Dict[str, Any] = Field(default_factory=dict, description="Metadata") + + @property + def query_hash(self) -> Optional[str]: + """Get query hash from payload.""" + return self.payload.get("query_hash") + + @property + def original_query(self) -> Optional[str]: + """Get original query from payload.""" + return self.payload.get("original_query") + + @property + def response(self) -> Optional[str]: + """Get response from payload.""" + return self.payload.get("response") + + @property + def provider(self) -> Optional[str]: + """Get provider from payload.""" + return self.payload.get("provider") + + @property + def model(self) -> Optional[str]: + """Get model from payload.""" + return self.payload.get("model") + + def to_cache_entry(self) -> Optional[CacheEntry]: + """ + Convert to cache entry. 
+ + Returns: + CacheEntry if payload is valid, None otherwise + """ + try: + # Build kwargs for CacheEntry + kwargs: Dict[str, Any] = { + "query_hash": self.payload["query_hash"], + "original_query": self.payload["original_query"], + "response": self.payload["response"], + "provider": self.payload["provider"], + "model": self.payload["model"], + "prompt_tokens": self.payload.get("prompt_tokens", 0), + "completion_tokens": self.payload.get("completion_tokens", 0), + "embedding": self.vector, + } + + # Convert timestamp back to datetime if present + if "created_at" in self.payload: + kwargs["created_at"] = datetime.fromtimestamp( + self.payload["created_at"] + ) + + return CacheEntry(**kwargs) + except (KeyError, ValidationError, ValueError): + # KeyError: missing required field + # ValidationError: pydantic validation failed + # ValueError: invalid timestamp + return None + + +class BatchUploadResult(BaseModel): + """ + Result from batch upload operation. + + Tracks success and failure counts. + """ + + total: int = Field(..., ge=0, description="Total points") + successful: int = Field(..., ge=0, description="Successfully uploaded") + failed: int = Field(..., ge=0, description="Failed uploads") + point_ids: List[str] = Field(default_factory=list, description="Uploaded IDs") + errors: List[str] = Field(default_factory=list, description="Error messages") + + @property + def success_rate(self) -> float: + """Calculate success rate.""" + if self.total == 0: + return 0.0 + return self.successful / self.total + + @property + def has_failures(self) -> bool: + """Check if there were any failures.""" + return self.failed > 0 + + +class DeleteResult(BaseModel): + """ + Result from delete operation. + + Tracks deletion status. 
+ """ + + deleted_count: int = Field(..., ge=0, description="Number deleted") + success: bool = Field(..., description="Operation success") + message: Optional[str] = Field(None, description="Status message") diff --git a/app/models/qdrant_schema.py b/app/models/qdrant_schema.py new file mode 100644 index 0000000..b9a8f2b --- /dev/null +++ b/app/models/qdrant_schema.py @@ -0,0 +1,195 @@ +""" +Qdrant collection schema definitions. + +Sandi Metz Principles: +- Single Responsibility: Schema configuration +- Small class: Clear schema structure +- Dependency Injection: Configuration injected +""" + +from enum import Enum +from typing import Dict, List, Optional + +from pydantic import BaseModel, Field +from qdrant_client.models import Distance + + +class QdrantDistanceMetric(str, Enum): + """ + Distance metrics for vector similarity. + + COSINE: Cosine similarity (default for most text embeddings) + EUCLID: Euclidean distance + DOT: Dot product similarity + """ + + COSINE = "Cosine" + EUCLID = "Euclid" + DOT = "Dot" + + def to_qdrant_distance(self) -> Distance: + """Convert to Qdrant Distance enum.""" + mapping = { + self.COSINE: Distance.COSINE, + self.EUCLID: Distance.EUCLID, + self.DOT: Distance.DOT, + } + return mapping[self] + + +class VectorPayloadSchema(BaseModel): + """ + Schema for vector point payload. + + Defines metadata stored with each vector. 
+ """ + + # Required fields + query_hash: str = Field(..., description="Hash of the original query") + original_query: str = Field(..., description="Original query text") + response: str = Field(..., description="Cached response") + + # Provider info + provider: str = Field(..., description="LLM provider used") + model: str = Field(..., description="LLM model used") + + # Token usage + prompt_tokens: int = Field(..., ge=0, description="Tokens in prompt") + completion_tokens: int = Field(..., ge=0, description="Tokens in completion") + + # Timestamps + created_at: Optional[float] = Field(None, description="Creation timestamp") + cached_at: Optional[float] = Field(None, description="Cache timestamp") + + # Additional metadata + tags: Optional[List[str]] = Field(None, description="Optional tags") + metadata: Optional[Dict[str, str]] = Field(None, description="Extra metadata") + + +class CollectionConfig(BaseModel): + """ + Configuration for Qdrant collection. + + Defines collection parameters. + """ + + name: str = Field(..., description="Collection name") + vector_size: int = Field(..., ge=1, description="Vector dimension size") + distance: QdrantDistanceMetric = Field( + default=QdrantDistanceMetric.COSINE, description="Distance metric" + ) + on_disk_payload: bool = Field(default=False, description="Store payload on disk") + hnsw_config: Optional[Dict[str, int]] = Field( + None, description="HNSW index configuration" + ) + + @property + def qdrant_distance(self) -> Distance: + """Get Qdrant Distance enum.""" + return self.distance.to_qdrant_distance() + + +class SearchConfig(BaseModel): + """ + Configuration for vector search. + + Defines search parameters. 
+ """ + + limit: int = Field(default=5, ge=1, le=100, description="Max results") + score_threshold: float = Field( + default=0.85, ge=0.0, le=1.0, description="Minimum similarity score" + ) + exact: bool = Field(default=False, description="Exact search (no HNSW)") + with_payload: bool = Field(default=True, description="Include payload") + with_vectors: bool = Field(default=False, description="Include vectors") + + +class PointMetadata(BaseModel): + """ + Metadata for a vector point. + + Minimal info for identification. + """ + + point_id: str = Field(..., description="Unique point ID") + query_hash: str = Field(..., description="Query hash") + score: Optional[float] = Field(None, description="Similarity score") + created_at: Optional[float] = Field(None, description="Creation timestamp") + + +class IndexConfig(BaseModel): + """ + HNSW index configuration. + + Controls index performance and accuracy. + """ + + m: int = Field(default=16, ge=4, le=64, description="Links per node") + ef_construct: int = Field( + default=100, ge=4, description="Construction time accuracy" + ) + full_scan_threshold: int = Field( + default=10000, ge=0, description="Threshold for full scan" + ) + on_disk: bool = Field(default=False, description="Store index on disk") + + +class QdrantSchema: + """ + Central schema configuration. + + Provides schema constants and defaults. 
+ """ + + # Payload field names + FIELD_QUERY_HASH = "query_hash" + FIELD_ORIGINAL_QUERY = "original_query" + FIELD_RESPONSE = "response" + FIELD_PROVIDER = "provider" + FIELD_MODEL = "model" + FIELD_PROMPT_TOKENS = "prompt_tokens" + FIELD_COMPLETION_TOKENS = "completion_tokens" + FIELD_CREATED_AT = "created_at" + FIELD_CACHED_AT = "cached_at" + FIELD_TAGS = "tags" + FIELD_METADATA = "metadata" + + # Default configurations + DEFAULT_DISTANCE = QdrantDistanceMetric.COSINE + DEFAULT_VECTOR_SIZE = 384 # sentence-transformers/all-MiniLM-L6-v2 + DEFAULT_SEARCH_LIMIT = 5 + DEFAULT_SCORE_THRESHOLD = 0.85 + + @staticmethod + def get_indexed_fields() -> List[str]: + """ + Get fields that should be indexed for filtering. + + Returns: + List of field names to index + """ + return [ + QdrantSchema.FIELD_QUERY_HASH, + QdrantSchema.FIELD_PROVIDER, + QdrantSchema.FIELD_MODEL, + QdrantSchema.FIELD_CREATED_AT, + ] + + @staticmethod + def get_required_fields() -> List[str]: + """ + Get required payload fields. + + Returns: + List of required field names + """ + return [ + QdrantSchema.FIELD_QUERY_HASH, + QdrantSchema.FIELD_ORIGINAL_QUERY, + QdrantSchema.FIELD_RESPONSE, + QdrantSchema.FIELD_PROVIDER, + QdrantSchema.FIELD_MODEL, + QdrantSchema.FIELD_PROMPT_TOKENS, + QdrantSchema.FIELD_COMPLETION_TOKENS, + ] diff --git a/app/repositories/qdrant_repository.py b/app/repositories/qdrant_repository.py new file mode 100644 index 0000000..a3be3a2 --- /dev/null +++ b/app/repositories/qdrant_repository.py @@ -0,0 +1,869 @@ +""" +Qdrant repository for vector storage and search. 
+ +Sandi Metz Principles: +- Single Responsibility: Qdrant data access +- Small methods: Each operation isolated +- Dependency Injection: Client injected +""" + +from typing import Any, Dict, List, Optional, Union + +from qdrant_client import AsyncQdrantClient +from qdrant_client.models import Distance, Filter, PointStruct, VectorParams + +from app.cache.qdrant_errors import ErrorContext, handle_qdrant_error +from app.cache.qdrant_retry import RetryPolicy, retry_on_error +from app.config import config +from app.models.qdrant_point import ( + BatchUploadResult, + DeleteResult, + QdrantPoint, + SearchResult, +) +from app.utils.logger import get_logger + +logger = get_logger(__name__) + + +class QdrantRepository: + """ + Repository for Qdrant vector operations. + + Handles low-level Qdrant interactions. + """ + + def __init__(self, client: AsyncQdrantClient): + """ + Initialize repository. + + Args: + client: Qdrant async client + """ + self._client = client + self._collection_name = config.qdrant_collection_name + self._vector_size = config.qdrant_vector_size + + async def collection_exists(self) -> bool: + """ + Check if collection exists. + + Returns: + True if exists, False otherwise + """ + try: + collections = await self._client.get_collections() + return any( + col.name == self._collection_name for col in collections.collections + ) + except Exception as e: + logger.error("Collection check failed", error=str(e)) + return False + + async def create_collection(self, distance: Distance = Distance.COSINE) -> bool: + """ + Create collection if not exists. 
+ + Args: + distance: Distance metric (COSINE, EUCLID, DOT) + + Returns: + True if created or exists, False on error + """ + try: + exists = await self.collection_exists() + if exists: + logger.info("Collection already exists", name=self._collection_name) + return True + + await self._client.create_collection( + collection_name=self._collection_name, + vectors_config=VectorParams(size=self._vector_size, distance=distance), + ) + + logger.info( + "Collection created", + name=self._collection_name, + vector_size=self._vector_size, + distance=distance, + ) + return True + + except Exception as e: + logger.error("Collection creation failed", error=str(e)) + return False + + async def delete_collection(self) -> bool: + """ + Delete collection. + + Returns: + True if deleted successfully + """ + try: + await self._client.delete_collection(collection_name=self._collection_name) + logger.info("Collection deleted", name=self._collection_name) + return True + except Exception as e: + logger.error("Collection deletion failed", error=str(e)) + return False + + @retry_on_error(config=RetryPolicy.CRITICAL) + async def ping(self) -> bool: + """ + Ping Qdrant server. + + Returns: + True if connected, False otherwise + """ + try: + with ErrorContext("ping"): + await self._client.get_collections() + return True + except Exception as e: + logger.error("Qdrant ping failed", error=str(e)) + return False + + async def get_collection_info(self) -> Optional[dict]: + """ + Get collection information. 
+ + Returns: + Collection info dict if successful + """ + try: + info = await self._client.get_collection( + collection_name=self._collection_name + ) + return { + "vectors_count": info.vectors_count, + "points_count": info.points_count, + "status": info.status, + "config": { + "vector_size": self._vector_size, + "distance": ( + info.config.params.vectors.distance + if isinstance(info.config.params.vectors, VectorParams) + else None + ), + }, + } + except Exception as e: + logger.error("Get collection info failed", error=str(e)) + return None + + @retry_on_error(config=RetryPolicy.STANDARD) + async def store_point(self, point: QdrantPoint) -> bool: + """ + Store a single vector point. + + Args: + point: QdrantPoint to store + + Returns: + True if stored successfully + """ + try: + with ErrorContext("store_point"): + await self._client.upsert( + collection_name=self._collection_name, + points=[point.to_qdrant_point()], + ) + + logger.info( + "Point stored", + point_id=point.id, + query_hash=point.payload.get("query_hash"), + ) + return True + + except Exception as e: + mapped_error = handle_qdrant_error(e, "store_point") + logger.error( + "Point store failed", + point_id=point.id, + error=str(mapped_error), + ) + return False + + async def store_points(self, points: List[QdrantPoint]) -> int: + """ + Store multiple vector points. + + Args: + points: List of QdrantPoints to store + + Returns: + Number of points stored successfully + """ + if not points: + return 0 + + try: + qdrant_points = [p.to_qdrant_point() for p in points] + await self._client.upsert( + collection_name=self._collection_name, points=qdrant_points + ) + + logger.info("Points stored", count=len(points)) + return len(points) + + except Exception as e: + logger.error("Batch store failed", count=len(points), error=str(e)) + return 0 + + async def point_exists(self, point_id: str) -> bool: + """ + Check if point exists by ID. 
+ + Args: + point_id: Point ID to check + + Returns: + True if exists, False otherwise + """ + try: + points = await self._client.retrieve( + collection_name=self._collection_name, ids=[point_id] + ) + return len(points) > 0 + + except Exception as e: + logger.error("Point exists check failed", point_id=point_id, error=str(e)) + return False + + async def get_point(self, point_id: str) -> Optional[QdrantPoint]: + """ + Retrieve point by ID. + + Args: + point_id: Point ID to retrieve + + Returns: + QdrantPoint if found, None otherwise + """ + try: + points = await self._client.retrieve( + collection_name=self._collection_name, + ids=[point_id], + with_vectors=True, + with_payload=True, + ) + + if not points: + return None + + point = points[0] + # Extract vector safely + vector = point.vector if isinstance(point.vector, list) else [] + return QdrantPoint.from_qdrant_point( + point_id=str(point.id), + vector=vector, # type: ignore[arg-type] + payload=point.payload or {}, + ) + + except Exception as e: + logger.error("Get point failed", point_id=point_id, error=str(e)) + return None + + async def search_similar( + self, + query_vector: List[float], + limit: int = 5, + score_threshold: Optional[float] = None, + filter_condition: Optional[Filter] = None, + ) -> List[SearchResult]: + """ + Search for similar vectors. 
+ + Args: + query_vector: Query embedding vector + limit: Maximum number of results + score_threshold: Minimum similarity score + filter_condition: Optional filter for search + + Returns: + List of SearchResult objects + """ + try: + results = await self._client.search( + collection_name=self._collection_name, + query_vector=query_vector, + limit=limit, + score_threshold=score_threshold, + query_filter=filter_condition, + with_payload=True, + with_vectors=False, + ) + + search_results = [ + SearchResult( + point_id=str(result.id), + score=result.score, + vector=( + result.vector # type: ignore[arg-type] + if result.vector + and isinstance(result.vector, list) + and all(isinstance(x, (int, float)) for x in result.vector) + else None + ), + payload=result.payload if result.payload else {}, + ) + for result in results + ] + + logger.info( + "Similarity search completed", + results_count=len(search_results), + threshold=score_threshold, + ) + + return search_results + + except Exception as e: + logger.error("Similarity search failed", error=str(e)) + return [] + + async def search_similar_with_vectors( + self, + query_vector: List[float], + limit: int = 5, + score_threshold: Optional[float] = None, + ) -> List[SearchResult]: + """ + Search for similar vectors including vector data. 
+ + Args: + query_vector: Query embedding vector + limit: Maximum number of results + score_threshold: Minimum similarity score + + Returns: + List of SearchResult objects with vectors + """ + try: + results = await self._client.search( + collection_name=self._collection_name, + query_vector=query_vector, + limit=limit, + score_threshold=score_threshold, + with_payload=True, + with_vectors=True, + ) + + search_results = [ + SearchResult( + point_id=str(result.id), + score=result.score, + vector=( + result.vector # type: ignore[arg-type] + if result.vector + and isinstance(result.vector, list) + and all(isinstance(x, (int, float)) for x in result.vector) + else None + ), + payload=result.payload if result.payload else {}, + ) + for result in results + ] + + logger.info( + "Similarity search with vectors completed", + results_count=len(search_results), + ) + + return search_results + + except Exception as e: + logger.error("Similarity search with vectors failed", error=str(e)) + return [] + + async def batch_upload( + self, points: List[QdrantPoint], batch_size: int = 100 + ) -> BatchUploadResult: + """ + Upload points in batches with progress tracking. 
+ + Args: + points: List of QdrantPoints to upload + batch_size: Number of points per batch + + Returns: + BatchUploadResult with statistics + """ + if not points: + return BatchUploadResult( + total=0, successful=0, failed=0, point_ids=[], errors=[] + ) + + total = len(points) + successful = 0 + failed = 0 + uploaded_ids = [] + errors = [] + + try: + # Process in batches + for i in range(0, total, batch_size): + batch = points[i : i + batch_size] + + try: + qdrant_points = [p.to_qdrant_point() for p in batch] + await self._client.upsert( + collection_name=self._collection_name, points=qdrant_points + ) + + # Track success + successful += len(batch) + uploaded_ids.extend([p.id for p in batch]) + + logger.info( + "Batch uploaded", + batch_num=i // batch_size + 1, + batch_size=len(batch), + progress=f"{successful}/{total}", + ) + + except Exception as batch_error: + # Track failure + failed += len(batch) + error_msg = ( + f"Batch {i // batch_size + 1} failed: {str(batch_error)}" + ) + errors.append(error_msg) + logger.error("Batch upload failed", error=error_msg) + + result = BatchUploadResult( + total=total, + successful=successful, + failed=failed, + point_ids=uploaded_ids, + errors=errors, + ) + + logger.info( + "Batch upload completed", + total=total, + successful=successful, + failed=failed, + success_rate=result.success_rate, + ) + + return result + + except Exception as e: + logger.error("Batch upload fatal error", error=str(e)) + return BatchUploadResult( + total=total, + successful=successful, + failed=total - successful, + point_ids=uploaded_ids, + errors=errors + [f"Fatal error: {str(e)}"], + ) + + async def batch_upload_with_retry( + self, + points: List[QdrantPoint], + batch_size: int = 100, + max_retries: int = 3, + ) -> BatchUploadResult: + """ + Upload points with automatic retry on failure. 
+ + Args: + points: List of QdrantPoints to upload + batch_size: Number of points per batch + max_retries: Maximum retry attempts + + Returns: + BatchUploadResult with statistics + """ + retry_count = 0 + last_result = None + + while retry_count <= max_retries: + result = await self.batch_upload(points, batch_size) + + if not result.has_failures: + return result + + # Retry failed batches + if retry_count < max_retries: + logger.warning( + "Retrying failed uploads", + retry=retry_count + 1, + max_retries=max_retries, + failed=result.failed, + ) + retry_count += 1 + last_result = result + else: + logger.error("Max retries exceeded", failed=result.failed) + return result + + return last_result or BatchUploadResult( + total=len(points), successful=0, failed=len(points), errors=[] + ) + + async def delete_point(self, point_id: str) -> DeleteResult: + """ + Delete a single point by ID. + + Args: + point_id: Point ID to delete + + Returns: + DeleteResult with operation status + """ + try: + await self._client.delete( + collection_name=self._collection_name, points_selector=[point_id] + ) + + logger.info("Point deleted", point_id=point_id) + return DeleteResult( + deleted_count=1, success=True, message=f"Point {point_id} deleted" + ) + + except Exception as e: + logger.error("Point deletion failed", point_id=point_id, error=str(e)) + return DeleteResult( + deleted_count=0, success=False, message=f"Deletion failed: {str(e)}" + ) + + async def delete_points(self, point_ids: List[str]) -> DeleteResult: + """ + Delete multiple points by IDs. 
+ + Args: + point_ids: List of point IDs to delete + + Returns: + DeleteResult with operation status + """ + if not point_ids: + return DeleteResult( + deleted_count=0, success=True, message="No points to delete" + ) + + try: + await self._client.delete( + collection_name=self._collection_name, points_selector=point_ids + ) + + logger.info("Points deleted", count=len(point_ids)) + return DeleteResult( + deleted_count=len(point_ids), + success=True, + message=f"Deleted {len(point_ids)} points", + ) + + except Exception as e: + logger.error("Batch deletion failed", count=len(point_ids), error=str(e)) + return DeleteResult( + deleted_count=0, success=False, message=f"Deletion failed: {str(e)}" + ) + + async def delete_by_filter(self, filter_condition: Filter) -> DeleteResult: + """ + Delete points matching filter condition. + + Args: + filter_condition: Filter to match points + + Returns: + DeleteResult with operation status + """ + try: + # Note: Qdrant doesn't return count for filter-based deletion + await self._client.delete( + collection_name=self._collection_name, + points_selector=filter_condition, + ) + + logger.info("Points deleted by filter") + return DeleteResult( + deleted_count=-1, # Unknown count + success=True, + message="Points deleted by filter", + ) + + except Exception as e: + logger.error("Filter deletion failed", error=str(e)) + return DeleteResult( + deleted_count=0, success=False, message=f"Deletion failed: {str(e)}" + ) + + async def delete_by_query_hash(self, query_hash: str) -> DeleteResult: + """ + Delete points by query hash. 
+ + Args: + query_hash: Query hash to match + + Returns: + DeleteResult with operation status + """ + from app.cache.qdrant_filter import create_filter + + filter_obj = create_filter().with_query_hash(query_hash).build() + + if not filter_obj: + return DeleteResult( + deleted_count=0, success=False, message="Failed to create filter" + ) + + return await self.delete_by_filter(filter_obj) + + async def update_point_payload( + self, point_id: str, payload: Dict[str, Any] + ) -> bool: + """ + Update point payload metadata. + + Args: + point_id: Point ID to update + payload: New payload data + + Returns: + True if updated successfully + """ + try: + await self._client.set_payload( + collection_name=self._collection_name, + payload=payload, + points=[point_id], + ) + + logger.info("Point payload updated", point_id=point_id) + return True + + except Exception as e: + logger.error("Payload update failed", point_id=point_id, error=str(e)) + return False + + async def update_point_vector(self, point_id: str, vector: List[float]) -> bool: + """ + Update point vector. + + Args: + point_id: Point ID to update + vector: New vector data + + Returns: + True if updated successfully + """ + try: + # Note: update_vectors API usage - using upsert instead + await self._client.upsert( + collection_name=self._collection_name, + points=[PointStruct(id=point_id, vector=vector, payload={})], + ) + + logger.info("Point vector updated", point_id=point_id) + return True + + except Exception as e: + logger.error("Vector update failed", point_id=point_id, error=str(e)) + return False + + async def update_point(self, point: QdrantPoint) -> bool: + """ + Update complete point (vector + payload). 
+ + Args: + point: QdrantPoint with updated data + + Returns: + True if updated successfully + """ + try: + # Upsert replaces the point completely + await self._client.upsert( + collection_name=self._collection_name, + points=[point.to_qdrant_point()], + ) + + logger.info("Point updated", point_id=point.id) + return True + + except Exception as e: + logger.error("Point update failed", point_id=point.id, error=str(e)) + return False + + async def partial_update_payload( + self, point_id: str, updates: Dict[str, Any] + ) -> bool: + """ + Partially update payload fields. + + Args: + point_id: Point ID to update + updates: Fields to update + + Returns: + True if updated successfully + """ + try: + await self._client.set_payload( + collection_name=self._collection_name, + payload=updates, + points=[point_id], + ) + + logger.info( + "Partial payload update", point_id=point_id, fields=list(updates.keys()) + ) + return True + + except Exception as e: + logger.error("Partial update failed", point_id=point_id, error=str(e)) + return False + + async def delete_payload_fields( + self, point_id: str, field_names: List[str] + ) -> bool: + """ + Delete specific payload fields. + + Args: + point_id: Point ID + field_names: Fields to delete + + Returns: + True if deleted successfully + """ + try: + await self._client.delete_payload( + collection_name=self._collection_name, + keys=field_names, + points=[point_id], + ) + + logger.info("Payload fields deleted", point_id=point_id, fields=field_names) + return True + + except Exception as e: + logger.error("Field deletion failed", point_id=point_id, error=str(e)) + return False + + async def scroll_points( + self, + limit: int = 100, + offset: Optional[str] = None, + filter_condition: Optional[Filter] = None, + with_vectors: bool = False, + ) -> tuple[List[QdrantPoint], Optional[Union[int, str]]]: + """ + Scroll through points with pagination. 
+ + Args: + limit: Number of points per page + offset: Offset ID for pagination + filter_condition: Optional filter + with_vectors: Include vectors in results + + Returns: + Tuple of (points, next_offset) + """ + try: + result = await self._client.scroll( + collection_name=self._collection_name, + limit=limit, + offset=offset, + scroll_filter=filter_condition, + with_payload=True, + with_vectors=with_vectors, + ) + + points = [ + QdrantPoint.from_qdrant_point( + point_id=str(point.id), + vector=( + point.vector # type: ignore[arg-type] + if isinstance(point.vector, list) + and all(isinstance(x, (int, float)) for x in point.vector) + else [] + ), + payload=point.payload if point.payload else {}, + ) + for point in result[0] + ] + + next_offset = result[1] # Next offset for pagination + + logger.info( + "Scroll completed", + returned=len(points), + has_next=next_offset is not None, + ) + + return points, next_offset + + except Exception as e: + logger.error("Scroll failed", error=str(e)) + return [], None + + async def count_points(self, filter_condition: Optional[Filter] = None) -> int: + """ + Count points in collection. + + Args: + filter_condition: Optional filter + + Returns: + Number of points + """ + try: + result = await self._client.count( + collection_name=self._collection_name, + count_filter=filter_condition, + exact=True, + ) + + count = result.count + logger.info("Point count", count=count) + return count + + except Exception as e: + logger.error("Count failed", error=str(e)) + return 0 + + async def get_all_points( + self, batch_size: int = 100, filter_condition: Optional[Filter] = None + ) -> List[QdrantPoint]: + """ + Get all points using scroll pagination. 
+ + Args: + batch_size: Points per batch + filter_condition: Optional filter + + Returns: + List of all points + """ + all_points = [] + offset = None + + try: + while True: + points, next_offset = await self.scroll_points( + limit=batch_size, + offset=offset, + filter_condition=filter_condition, + with_vectors=False, + ) + + all_points.extend(points) + + if next_offset is None: + break + + offset = next_offset + + logger.info("Retrieved all points", total=len(all_points)) + return all_points + + except Exception as e: + logger.error("Get all points failed", error=str(e)) + return all_points # Return what we got so far diff --git a/app/similarity/__init__.py b/app/similarity/__init__.py index e69de29..1531e74 100644 --- a/app/similarity/__init__.py +++ b/app/similarity/__init__.py @@ -0,0 +1,65 @@ +""" +Similarity calculation and threshold tuning utilities. + +This module provides tools for vector similarity computation, +score interpretation, and threshold optimization. +""" + +from app.similarity.score_calculator import ( + ScoreCalculator, + ScoreInterpretation, + SimilarityLevel, + SimilarityScoreCalculator, + cosine_similarity, + euclidean_distance, + interpret_cosine_score, +) +from app.similarity.threshold_tuner import ( + ThresholdMetrics, + ThresholdOptimizationGoal, + ThresholdRecommendation, + ThresholdTuner, + UseCase, + evaluate_threshold_quality, + get_cache_threshold, + get_exact_match_threshold, + tune_threshold, +) +from app.similarity.vector_normalizer import ( + NormalizationType, + VectorNormalizer, + clip_vector, + l1_normalize, + l2_normalize, + max_normalize, + standardize_vector, +) + +__all__ = [ + # Score calculator + "SimilarityScoreCalculator", + "ScoreCalculator", + "SimilarityLevel", + "ScoreInterpretation", + "cosine_similarity", + "euclidean_distance", + "interpret_cosine_score", + # Threshold tuner + "ThresholdTuner", + "ThresholdMetrics", + "ThresholdRecommendation", + "ThresholdOptimizationGoal", + "UseCase", + "tune_threshold", + 
"get_cache_threshold", + "get_exact_match_threshold", + "evaluate_threshold_quality", + # Vector normalizer + "VectorNormalizer", + "NormalizationType", + "l1_normalize", + "l2_normalize", + "max_normalize", + "standardize_vector", + "clip_vector", +] diff --git a/app/similarity/score_calculator.py b/app/similarity/score_calculator.py new file mode 100644 index 0000000..4c4dd6a --- /dev/null +++ b/app/similarity/score_calculator.py @@ -0,0 +1,329 @@ +""" +Similarity score calculation and interpretation. + +Sandi Metz Principles: +- Single Responsibility: Score calculation +- Small methods: Each calculation isolated +- Clear naming: Descriptive method names +""" + +import math +from enum import Enum +from typing import List + +from app.utils.logger import get_logger + +logger = get_logger(__name__) + + +class SimilarityLevel(str, Enum): + """ + Semantic similarity quality levels. + + Helps interpret similarity scores. + """ + + EXACT = "exact" # > 0.95 + VERY_HIGH = "very_high" # 0.90 - 0.95 + HIGH = "high" # 0.80 - 0.90 + MODERATE = "moderate" # 0.60 - 0.80 + LOW = "low" # 0.40 - 0.60 + VERY_LOW = "very_low" # < 0.40 + + +class SimilarityScoreCalculator: + """ + Calculator for similarity scores. + + Provides score calculation and interpretation. + """ + + # Threshold definitions + EXACT_THRESHOLD = 0.95 + VERY_HIGH_THRESHOLD = 0.90 + HIGH_THRESHOLD = 0.80 + MODERATE_THRESHOLD = 0.60 + LOW_THRESHOLD = 0.40 + + @staticmethod + def cosine_similarity(vec1: List[float], vec2: List[float]) -> float: + """ + Calculate cosine similarity between vectors. 
+ + Args: + vec1: First vector + vec2: Second vector + + Returns: + Cosine similarity score (-1.0 to 1.0) + + Raises: + ValueError: If vectors have different dimensions + """ + if len(vec1) != len(vec2): + raise ValueError( + f"Vectors must have same dimensions: {len(vec1)} != {len(vec2)}" + ) + + if len(vec1) == 0: + raise ValueError("Vectors must have same dimensions: 0 != 0") + + try: + dot_product = sum(a * b for a, b in zip(vec1, vec2)) + magnitude1 = math.sqrt(sum(a * a for a in vec1)) + magnitude2 = math.sqrt(sum(b * b for b in vec2)) + + if magnitude1 == 0 or magnitude2 == 0: + return 0.0 + + similarity = dot_product / (magnitude1 * magnitude2) + return similarity + + except Exception as e: + logger.error("Cosine similarity calculation failed", error=str(e)) + return 0.0 + + @staticmethod + def euclidean_distance(vec1: List[float], vec2: List[float]) -> float: + """ + Calculate Euclidean distance between vectors. + + Args: + vec1: First vector + vec2: Second vector + + Returns: + Euclidean distance + + Raises: + ValueError: If vectors have different dimensions + """ + if len(vec1) != len(vec2): + raise ValueError( + f"Vectors must have same dimensions: {len(vec1)} != {len(vec2)}" + ) + + try: + return math.sqrt(sum((a - b) ** 2 for a, b in zip(vec1, vec2))) + except Exception as e: + logger.error("Euclidean distance calculation failed", error=str(e)) + return float("inf") + + @staticmethod + def euclidean_to_similarity(distance: float, max_distance: float = 2.0) -> float: + """ + Convert Euclidean distance to similarity score. + + Args: + distance: Euclidean distance + max_distance: Maximum expected distance + + Returns: + Similarity score (0.0 to 1.0) + """ + if distance >= max_distance: + return 0.0 + + return 1.0 - (distance / max_distance) + + @staticmethod + def dot_product(vec1: List[float], vec2: List[float]) -> float: + """ + Calculate dot product of vectors. 
+ + Args: + vec1: First vector + vec2: Second vector + + Returns: + Dot product value + """ + if len(vec1) != len(vec2): + logger.error("Vector size mismatch", v1=len(vec1), v2=len(vec2)) + return 0.0 + + try: + return sum(a * b for a, b in zip(vec1, vec2)) + except Exception as e: + logger.error("Dot product calculation failed", error=str(e)) + return 0.0 + + @classmethod + def interpret_score(cls, score: float) -> SimilarityLevel: + """ + Interpret similarity score quality. + + Args: + score: Similarity score (0.0 to 1.0) + + Returns: + SimilarityLevel enum + """ + if score > cls.EXACT_THRESHOLD: + return SimilarityLevel.EXACT + elif score > cls.VERY_HIGH_THRESHOLD: + return SimilarityLevel.VERY_HIGH + elif score > cls.HIGH_THRESHOLD: + return SimilarityLevel.HIGH + elif score >= cls.MODERATE_THRESHOLD: + return SimilarityLevel.MODERATE + elif score >= cls.LOW_THRESHOLD: + return SimilarityLevel.LOW + else: + return SimilarityLevel.VERY_LOW + + @classmethod + def should_cache_hit(cls, score: float, threshold: float = 0.85) -> bool: + """ + Determine if score qualifies as cache hit. + + Args: + score: Similarity score + threshold: Minimum acceptable score + + Returns: + True if score meets threshold + """ + return score >= threshold + + @classmethod + def get_confidence_level(cls, score: float) -> str: + """ + Get human-readable confidence level. 
+ + Args: + score: Similarity score + + Returns: + Confidence description + """ + level = cls.interpret_score(score) + + descriptions = { + SimilarityLevel.EXACT: "Exact match - virtually identical", + SimilarityLevel.VERY_HIGH: "Very high confidence - strong match", + SimilarityLevel.HIGH: "High confidence - good match", + SimilarityLevel.MODERATE: "Moderate confidence - acceptable match", + SimilarityLevel.LOW: "Low confidence - weak match", + SimilarityLevel.VERY_LOW: "Very low confidence - poor match", + } + + return descriptions.get(level, "Unknown confidence") + + @classmethod + def calculate( + cls, vec1: List[float], vec2: List[float], metric: str = "cosine" + ) -> float: + """ + Calculate similarity using specified metric. + + Args: + vec1: First vector + vec2: Second vector + metric: Similarity metric (cosine, euclidean) + + Returns: + Similarity score + + Raises: + ValueError: If unknown metric specified + """ + if metric == "cosine": + return cls.cosine_similarity(vec1, vec2) + elif metric == "euclidean": + return cls.euclidean_distance(vec1, vec2) + else: + raise ValueError(f"Unknown metric: {metric}") + + @staticmethod + def is_match(score: float, threshold: float = 0.85) -> bool: + """ + Check if score meets threshold for cache hit. + + Args: + score: Similarity score + threshold: Minimum threshold + + Returns: + True if score meets or exceeds threshold + """ + return score >= threshold + + @staticmethod + def normalize_score(score: float, metric: str = "cosine") -> float: + """ + Normalize score to 0-1 range. 
+ + Args: + score: Raw score + metric: Metric type (cosine, euclidean) + + Returns: + Normalized score (0-1) + """ + if metric == "cosine": + # Cosine already in [0, 1] for our use case (non-negative similarity) + return score + elif metric == "euclidean": + # Convert distance to similarity (0 distance = 1.0 similarity) + # Using inverse relationship + max_distance = 10.0 # Configurable max + if score == 0.0: + return 1.0 + if score >= max_distance: + return 0.0 + return 1.0 - (score / max_distance) + return score + + @classmethod + def get_interpretation(cls, score: float) -> SimilarityLevel: + """ + Get score interpretation. + + Args: + score: Similarity score + + Returns: + SimilarityLevel enum + """ + return cls.interpret_score(score) + + @staticmethod + def calculate_match_quality(score: float) -> dict: + """ + Calculate detailed match quality metrics. + + Args: + score: Similarity score + + Returns: + Dict with quality metrics + """ + return { + "score": round(score, 4), + "percentage": round(score * 100, 2), + "level": SimilarityScoreCalculator.interpret_score(score).value, + "confidence": SimilarityScoreCalculator.get_confidence_level(score), + "is_cache_hit": SimilarityScoreCalculator.should_cache_hit(score), + } + + +# Convenience aliases for easier imports +ScoreCalculator = SimilarityScoreCalculator +ScoreInterpretation = SimilarityLevel + + +# Standalone functions for convenience +def cosine_similarity(vec1: List[float], vec2: List[float]) -> float: + """Calculate cosine similarity between vectors.""" + return SimilarityScoreCalculator.cosine_similarity(vec1, vec2) + + +def euclidean_distance(vec1: List[float], vec2: List[float]) -> float: + """Calculate Euclidean distance between vectors.""" + return SimilarityScoreCalculator.euclidean_distance(vec1, vec2) + + +def interpret_cosine_score(score: float) -> SimilarityLevel: + """Interpret cosine similarity score.""" + return SimilarityScoreCalculator.interpret_score(score) diff --git 
a/app/similarity/threshold_tuner.py b/app/similarity/threshold_tuner.py new file mode 100644 index 0000000..b3d2217 --- /dev/null +++ b/app/similarity/threshold_tuner.py @@ -0,0 +1,469 @@ +""" +Semantic similarity threshold tuning utilities. + +Sandi Metz Principles: +- Single Responsibility: Threshold optimization +- Small methods: Each analysis isolated +- Clear naming: Descriptive method names +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Tuple + +from app.utils.logger import get_logger + +logger = get_logger(__name__) + + +class ThresholdOptimizationGoal(str, Enum): + """Optimization goals for threshold tuning.""" + + PRECISION = "precision" # Minimize false positives + RECALL = "recall" # Minimize false negatives + F1_SCORE = "f1_score" # Balance precision and recall + BALANCED = "balanced" # Equal weight to precision and recall + + +class UseCase(str, Enum): + """Common use cases with different threshold requirements.""" + + EXACT_MATCH = "exact_match" # Strict matching (high threshold) + CACHE_HIT = "cache_hit" # Balance between hits and accuracy + SIMILAR_CONTENT = "similar_content" # Broad matching (lower threshold) + DEDUPLICATION = "deduplication" # Avoid duplicates (moderate threshold) + + +@dataclass +class ThresholdMetrics: + """Metrics for threshold evaluation.""" + + threshold: float + true_positives: int + false_positives: int + true_negatives: int + false_negatives: int + + @property + def precision(self) -> float: + """Calculate precision (TP / (TP + FP)).""" + denominator = self.true_positives + self.false_positives + if denominator == 0: + return 0.0 + return self.true_positives / denominator + + @property + def recall(self) -> float: + """Calculate recall (TP / (TP + FN)).""" + denominator = self.true_positives + self.false_negatives + if denominator == 0: + return 0.0 + return self.true_positives / denominator + + @property + def f1_score(self) -> float: + """Calculate F1 
score (harmonic mean of precision and recall).""" + if self.precision + self.recall == 0: + return 0.0 + return 2 * (self.precision * self.recall) / (self.precision + self.recall) + + @property + def accuracy(self) -> float: + """Calculate accuracy ((TP + TN) / Total).""" + total = ( + self.true_positives + + self.false_positives + + self.true_negatives + + self.false_negatives + ) + if total == 0: + return 0.0 + return (self.true_positives + self.true_negatives) / total + + def __str__(self) -> str: + """Format metrics as readable string.""" + return ( + f"Threshold: {self.threshold:.3f}\n" + f" Precision: {self.precision:.3f}\n" + f" Recall: {self.recall:.3f}\n" + f" F1 Score: {self.f1_score:.3f}\n" + f" Accuracy: {self.accuracy:.3f}\n" + f" TP: {self.true_positives} | FP: {self.false_positives}\n" + f" TN: {self.true_negatives} | FN: {self.false_negatives}" + ) + + +@dataclass +class ThresholdRecommendation: + """Recommended threshold with justification.""" + + threshold: float + use_case: UseCase + metrics: ThresholdMetrics + reasoning: str + confidence: float + alternative_thresholds: Dict[str, float] = field(default_factory=dict) + + def summary(self) -> str: + """Generate recommendation summary.""" + lines = [ + f"\nThreshold Recommendation for {self.use_case.value}:", + f" Recommended Threshold: {self.threshold:.3f}", + f" Confidence: {self.confidence:.1%}", + "\nReasoning:", + f" {self.reasoning}", + "\nExpected Performance:", + f"{self.metrics}", + ] + + if self.alternative_thresholds: + lines.append("\nAlternative Thresholds:") + for name, value in self.alternative_thresholds.items(): + lines.append(f" {name}: {value:.3f}") + + return "\n".join(lines) + + +class ThresholdTuner: + """ + Utilities for tuning similarity thresholds. + + Analyzes test data to recommend optimal thresholds. 
+ """ + + # Default recommended thresholds by use case + DEFAULT_THRESHOLDS = { + UseCase.EXACT_MATCH: 0.95, + UseCase.CACHE_HIT: 0.85, + UseCase.SIMILAR_CONTENT: 0.70, + UseCase.DEDUPLICATION: 0.80, + } + + @staticmethod + def evaluate_threshold( + scores: List[float], + ground_truth: List[bool], + threshold: float, + ) -> ThresholdMetrics: + """ + Evaluate threshold performance. + + Args: + scores: Similarity scores + ground_truth: True labels (True = should match) + threshold: Threshold to evaluate + + Returns: + ThresholdMetrics with performance data + + Raises: + ValueError: If scores and ground_truth have different lengths + """ + if len(scores) != len(ground_truth): + raise ValueError( + f"Scores and ground_truth must have same length: " + f"{len(scores)} != {len(ground_truth)}" + ) + + tp = fp = tn = fn = 0 + + for score, is_match in zip(scores, ground_truth): + predicted_match = score >= threshold + + if predicted_match and is_match: + tp += 1 + elif predicted_match and not is_match: + fp += 1 + elif not predicted_match and not is_match: + tn += 1 + else: + fn += 1 + + return ThresholdMetrics( + threshold=threshold, + true_positives=tp, + false_positives=fp, + true_negatives=tn, + false_negatives=fn, + ) + + @classmethod + def find_optimal_threshold( + cls, + scores: List[float], + ground_truth: List[bool], + goal: ThresholdOptimizationGoal = ThresholdOptimizationGoal.F1_SCORE, + min_threshold: float = 0.5, + max_threshold: float = 0.99, + step: float = 0.01, + ) -> Tuple[float, ThresholdMetrics]: + """ + Find optimal threshold using grid search. 
+ + Args: + scores: Similarity scores + ground_truth: True labels + goal: Optimization goal + min_threshold: Minimum threshold to test + max_threshold: Maximum threshold to test + step: Step size for grid search + + Returns: + Tuple of (optimal_threshold, metrics) + """ + # Initialize with first threshold + best_threshold = min_threshold + best_metrics = cls.evaluate_threshold(scores, ground_truth, min_threshold) + + if goal == ThresholdOptimizationGoal.PRECISION: + best_score = best_metrics.precision + elif goal == ThresholdOptimizationGoal.RECALL: + best_score = best_metrics.recall + elif goal == ThresholdOptimizationGoal.F1_SCORE: + best_score = best_metrics.f1_score + else: # BALANCED + best_score = (best_metrics.precision + best_metrics.recall) / 2 + + current = min_threshold + step + while current <= max_threshold: + metrics = cls.evaluate_threshold(scores, ground_truth, current) + + if goal == ThresholdOptimizationGoal.PRECISION: + score = metrics.precision + elif goal == ThresholdOptimizationGoal.RECALL: + score = metrics.recall + elif goal == ThresholdOptimizationGoal.F1_SCORE: + score = metrics.f1_score + else: # BALANCED + score = (metrics.precision + metrics.recall) / 2 + + if score > best_score: + best_score = score + best_threshold = current + best_metrics = metrics + + current += step + + logger.info( + "Optimal threshold found", + threshold=best_threshold, + goal=goal.value, + score=best_score, + ) + + return best_threshold, best_metrics + + @classmethod + def recommend_threshold( + cls, + use_case: UseCase, + scores: Optional[List[float]] = None, + ground_truth: Optional[List[bool]] = None, + ) -> ThresholdRecommendation: + """ + Recommend threshold for specific use case. 
+ + Args: + use_case: Target use case + scores: Optional test scores for tuning + ground_truth: Optional ground truth labels + + Returns: + ThresholdRecommendation with detailed analysis + """ + if scores is not None and ground_truth is not None: + # Tune based on provided data + goal_map = { + UseCase.EXACT_MATCH: ThresholdOptimizationGoal.PRECISION, + UseCase.CACHE_HIT: ThresholdOptimizationGoal.F1_SCORE, + UseCase.SIMILAR_CONTENT: ThresholdOptimizationGoal.RECALL, + UseCase.DEDUPLICATION: ThresholdOptimizationGoal.F1_SCORE, + } + + goal = goal_map.get(use_case, ThresholdOptimizationGoal.F1_SCORE) + threshold, metrics = cls.find_optimal_threshold(scores, ground_truth, goal) + + # Calculate alternatives + alt_thresholds = {} + for alt_goal in ThresholdOptimizationGoal: + if alt_goal != goal: + alt_thresh, _ = cls.find_optimal_threshold( + scores, ground_truth, alt_goal + ) + alt_thresholds[alt_goal.value] = alt_thresh + + confidence = metrics.f1_score + else: + # Use default threshold + threshold = cls.DEFAULT_THRESHOLDS[use_case] + metrics = ThresholdMetrics( + threshold=threshold, + true_positives=0, + false_positives=0, + true_negatives=0, + false_negatives=0, + ) + alt_thresholds = { + name.value: thresh + for name, thresh in cls.DEFAULT_THRESHOLDS.items() + if name != use_case + } + confidence = 0.8 + + reasoning = cls._generate_reasoning(use_case, threshold) + + return ThresholdRecommendation( + threshold=threshold, + use_case=use_case, + metrics=metrics, + reasoning=reasoning, + confidence=confidence, + alternative_thresholds=alt_thresholds, + ) + + @staticmethod + def _generate_reasoning(use_case: UseCase, threshold: float) -> str: + """Generate reasoning for threshold recommendation.""" + if use_case == UseCase.EXACT_MATCH: + return ( + f"For exact matching, a high threshold ({threshold:.3f}) minimizes " + f"false positives, ensuring only nearly identical queries match." 
+ ) + elif use_case == UseCase.CACHE_HIT: + return ( + f"For cache hits, threshold {threshold:.3f} balances between " + f"cache effectiveness and accuracy, optimizing for F1 score." + ) + elif use_case == UseCase.SIMILAR_CONTENT: + return ( + f"For similar content detection, a moderate threshold " + f"({threshold:.3f}) allows broader matching while " + f"maintaining relevance." + ) + elif use_case == UseCase.DEDUPLICATION: + return ( + f"For deduplication, threshold {threshold:.3f} prevents duplicates " + f"while avoiding false merges of distinct content." + ) + return f"Recommended threshold: {threshold:.3f}" + + @classmethod + def analyze_threshold_range( + cls, + scores: List[float], + ground_truth: List[bool], + start: float = 0.5, + end: float = 0.99, + step: float = 0.05, + ) -> List[ThresholdMetrics]: + """ + Analyze performance across threshold range. + + Args: + scores: Similarity scores + ground_truth: True labels + start: Start threshold + end: End threshold + step: Step size + + Returns: + List of ThresholdMetrics for each threshold + """ + results = [] + current = start + + while current <= end: + metrics = cls.evaluate_threshold(scores, ground_truth, current) + results.append(metrics) + current += step + + return results + + @staticmethod + def format_analysis_report(metrics_list: List[ThresholdMetrics]) -> str: + """ + Format analysis report for threshold range. 
+ + Args: + metrics_list: List of metrics to report + + Returns: + Formatted report string + """ + lines = [ + "\nThreshold Analysis Report", + "=" * 80, + f"{'Threshold':<12} {'Precision':<12} {'Recall':<12} " + f"{'F1 Score':<12} {'Accuracy':<12}", + "-" * 80, + ] + + for metrics in metrics_list: + lines.append( + f"{metrics.threshold:<12.3f} {metrics.precision:<12.3f} " + f"{metrics.recall:<12.3f} {metrics.f1_score:<12.3f} " + f"{metrics.accuracy:<12.3f}" + ) + + return "\n".join(lines) + + @classmethod + def get_threshold_for_use_case(cls, use_case: UseCase) -> float: + """ + Get recommended threshold for use case. + + Args: + use_case: Target use case + + Returns: + Recommended threshold value + """ + return cls.DEFAULT_THRESHOLDS.get(use_case, 0.85) + + +# Convenience functions +def tune_threshold( + scores: List[float], + ground_truth: List[bool], + goal: ThresholdOptimizationGoal = ThresholdOptimizationGoal.F1_SCORE, +) -> float: + """ + Quick threshold tuning. + + Args: + scores: Similarity scores + ground_truth: True labels + goal: Optimization goal + + Returns: + Optimal threshold + """ + threshold, _ = ThresholdTuner.find_optimal_threshold(scores, ground_truth, goal) + return threshold + + +def get_cache_threshold() -> float: + """Get recommended threshold for cache hits.""" + return ThresholdTuner.get_threshold_for_use_case(UseCase.CACHE_HIT) + + +def get_exact_match_threshold() -> float: + """Get recommended threshold for exact matching.""" + return ThresholdTuner.get_threshold_for_use_case(UseCase.EXACT_MATCH) + + +def evaluate_threshold_quality( + scores: List[float], ground_truth: List[bool], threshold: float +) -> ThresholdMetrics: + """ + Evaluate threshold quality. 
+ + Args: + scores: Similarity scores + ground_truth: True labels + threshold: Threshold to evaluate + + Returns: + ThresholdMetrics with performance data + """ + return ThresholdTuner.evaluate_threshold(scores, ground_truth, threshold) diff --git a/app/similarity/vector_normalizer.py b/app/similarity/vector_normalizer.py new file mode 100644 index 0000000..52e72fe --- /dev/null +++ b/app/similarity/vector_normalizer.py @@ -0,0 +1,336 @@ +""" +Vector normalization utilities. + +Sandi Metz Principles: +- Single Responsibility: Vector normalization +- Small methods: Each operation isolated +- Clear naming: Descriptive method names +""" + +import math +from enum import Enum +from typing import List + +from app.utils.logger import get_logger + +logger = get_logger(__name__) + + +class NormalizationType(str, Enum): + """Normalization type options.""" + + L1 = "l1" + L2 = "l2" + MAX = "max" + + +class VectorNormalizer: + """ + Utilities for vector normalization. + + Ensures vectors are properly normalized for distance calculations. + """ + + @staticmethod + def l2_normalize(vector: List[float]) -> List[float]: + """ + Normalize vector using L2 (Euclidean) norm. + + Args: + vector: Input vector + + Returns: + Normalized vector with unit length + + Raises: + ValueError: If vector is empty or normalization fails + """ + if not vector: + raise ValueError("Cannot normalize empty vector") + + try: + magnitude = math.sqrt(sum(x * x for x in vector)) + + if magnitude == 0: + logger.warning("Cannot normalize zero vector, returning as-is") + return vector + + return [x / magnitude for x in vector] + + except Exception as e: + logger.error("L2 normalization failed", error=str(e)) + raise ValueError(f"L2 normalization failed: {e}") from e + + @staticmethod + def l1_normalize(vector: List[float]) -> List[float]: + """ + Normalize vector using L1 (Manhattan) norm. 
+ + Args: + vector: Input vector + + Returns: + L1 normalized vector + + Raises: + ValueError: If vector is empty or normalization fails + """ + if not vector: + raise ValueError("Cannot normalize empty vector") + + try: + total = sum(abs(x) for x in vector) + + if total == 0: + logger.warning("Cannot normalize zero vector, returning as-is") + return vector + + return [x / total for x in vector] + + except Exception as e: + logger.error("L1 normalization failed", error=str(e)) + raise ValueError(f"L1 normalization failed: {e}") from e + + @staticmethod + def max_normalize(vector: List[float]) -> List[float]: + """ + Normalize vector by dividing by maximum absolute value. + + Args: + vector: Input vector + + Returns: + Max normalized vector + + Raises: + ValueError: If vector is empty or normalization fails + """ + if not vector: + raise ValueError("Cannot normalize empty vector") + + try: + max_val = max(abs(x) for x in vector) + + if max_val == 0: + logger.warning("Cannot normalize zero vector, returning as-is") + return vector + + return [x / max_val for x in vector] + + except Exception as e: + logger.error("Max normalization failed", error=str(e)) + raise ValueError(f"Max normalization failed: {e}") from e + + @staticmethod + def magnitude(vector: List[float]) -> float: + """ + Calculate vector magnitude (L2 norm). + + Args: + vector: Input vector + + Returns: + Vector magnitude + """ + try: + return math.sqrt(sum(x * x for x in vector)) + except Exception as e: + logger.error("Magnitude calculation failed", error=str(e)) + return 0.0 + + @staticmethod + def zero_center(vector: List[float]) -> List[float]: + """ + Center vector around zero by subtracting mean. 
+ + Args: + vector: Input vector + + Returns: + Zero-centered vector + + Raises: + ValueError: If vector is empty or operation fails + """ + if not vector: + raise ValueError("Cannot center empty vector") + + try: + mean = sum(vector) / len(vector) + return [x - mean for x in vector] + except Exception as e: + logger.error("Zero centering failed", error=str(e)) + raise ValueError(f"Zero centering failed: {e}") from e + + @staticmethod + def standardize(vector: List[float]) -> List[float]: + """ + Standardize vector (zero mean, unit variance). + + Args: + vector: Input vector + + Returns: + Standardized vector + + Raises: + ValueError: If vector is empty or operation fails + """ + if not vector: + raise ValueError("Cannot standardize empty vector") + + try: + mean = sum(vector) / len(vector) + variance = sum((x - mean) ** 2 for x in vector) / len(vector) + + if variance == 0: + logger.warning( + "Cannot standardize constant vector, returning zero vector" + ) + return [0.0 for _ in vector] + + std_dev = math.sqrt(variance) + return [(x - mean) / std_dev for x in vector] + + except Exception as e: + logger.error("Standardization failed", error=str(e)) + raise ValueError(f"Standardization failed: {e}") from e + + @staticmethod + def clip( + vector: List[float], min_val: float = -1.0, max_val: float = 1.0 + ) -> List[float]: + """ + Clip vector values to range. + + Args: + vector: Input vector + min_val: Minimum value + max_val: Maximum value + + Returns: + Clipped vector + """ + try: + return [max(min_val, min(max_val, x)) for x in vector] + except Exception as e: + logger.error("Vector clipping failed", error=str(e)) + return vector + + @classmethod + def normalize( + cls, vector: List[float], norm_type: NormalizationType = NormalizationType.L2 + ) -> List[float]: + """ + Normalize vector using specified normalization type. 
+ + Args: + vector: Input vector + norm_type: Type of normalization to apply + + Returns: + Normalized vector + """ + if norm_type == NormalizationType.L1: + return cls.l1_normalize(vector) + elif norm_type == NormalizationType.L2: + return cls.l2_normalize(vector) + elif norm_type == NormalizationType.MAX: + return cls.max_normalize(vector) + else: + raise ValueError(f"Unknown normalization type: {norm_type}") + + @classmethod + def batch_normalize( + cls, + vectors: List[List[float]], + norm_type: NormalizationType = NormalizationType.L2, + ) -> List[List[float]]: + """ + Normalize multiple vectors. + + Args: + vectors: List of input vectors + norm_type: Type of normalization to apply + + Returns: + List of normalized vectors + """ + return [cls.normalize(vec, norm_type) for vec in vectors] + + @classmethod + def is_normalized( + cls, + vector: List[float], + norm_type: NormalizationType = NormalizationType.L2, + tolerance: float = 1e-6, + ) -> bool: + """ + Check if vector is normalized according to the given norm type. + + Args: + vector: Vector to check + norm_type: Type of normalization to check + tolerance: Tolerance for check + + Returns: + True if vector is normalized + """ + if norm_type == NormalizationType.L2: + magnitude = cls.magnitude(vector) + return abs(magnitude - 1.0) < tolerance + elif norm_type == NormalizationType.L1: + l1_norm = sum(abs(x) for x in vector) + return abs(l1_norm - 1.0) < tolerance + elif norm_type == NormalizationType.MAX: + max_val = max(abs(x) for x in vector) if vector else 0.0 + return abs(max_val - 1.0) < tolerance + return False + + @classmethod + def safe_normalize( + cls, vector: List[float], norm_type: NormalizationType = NormalizationType.L2 + ) -> List[float]: + """ + Safely normalize vector, returning original if normalization fails. 
+ + Args: + vector: Input vector + norm_type: Type of normalization to apply + + Returns: + Normalized vector or original if normalization fails + """ + try: + return cls.normalize(vector, norm_type) + except Exception as e: + logger.error("Safe normalization failed", error=str(e)) + return vector + + +# Convenience standalone functions +def l1_normalize(vector: List[float]) -> List[float]: + """Normalize vector using L1 norm.""" + return VectorNormalizer.l1_normalize(vector) + + +def l2_normalize(vector: List[float]) -> List[float]: + """Normalize vector using L2 norm.""" + return VectorNormalizer.l2_normalize(vector) + + +def max_normalize(vector: List[float]) -> List[float]: + """Normalize vector by max value.""" + return VectorNormalizer.max_normalize(vector) + + +def standardize_vector(vector: List[float]) -> List[float]: + """Standardize vector (zero mean, unit variance).""" + return VectorNormalizer.standardize(vector) + + +def clip_vector( + vector: List[float], min_val: float = -1.0, max_val: float = 1.0 +) -> List[float]: + """Clip vector values to range.""" + return VectorNormalizer.clip(vector, min_val, max_val) diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 0000000..3edf4dc --- /dev/null +++ b/tests/benchmarks/__init__.py @@ -0,0 +1,6 @@ +""" +Performance benchmark tests. + +These tests measure performance characteristics and should be +run separately from regular unit/integration tests. +""" diff --git a/tests/benchmarks/test_qdrant_performance.py b/tests/benchmarks/test_qdrant_performance.py new file mode 100644 index 0000000..e78180a --- /dev/null +++ b/tests/benchmarks/test_qdrant_performance.py @@ -0,0 +1,361 @@ +""" +Performance benchmarks for Qdrant operations. + +Run with: pytest tests/benchmarks/ -v -m benchmark + +These tests measure performance and are not part of regular CI. 
+""" + +import pytest +import pytest_asyncio +from qdrant_client.models import Distance + +from app.benchmarks.qdrant_benchmark import QdrantBenchmark +from app.cache.qdrant_client import create_qdrant_client +from app.models.qdrant_point import QdrantPoint +from app.repositories.qdrant_repository import QdrantRepository + + +@pytest_asyncio.fixture +async def qdrant_client(): + """Create Qdrant client for benchmarking.""" + client = await create_qdrant_client() + yield client + await client.close() + + +@pytest_asyncio.fixture +async def qdrant_repository(qdrant_client): + """Create Qdrant repository for benchmarking.""" + repository = QdrantRepository(qdrant_client) + + # Ensure collection exists + await repository.create_collection(distance=Distance.COSINE) + + yield repository + + # Clean up + try: + await repository.delete_collection() + except Exception: + pass + + +@pytest_asyncio.fixture +async def benchmark(qdrant_repository): + """Create benchmark instance.""" + return QdrantBenchmark(qdrant_repository) + + +@pytest.mark.benchmark +class TestQdrantPerformance: + """Performance benchmark tests for Qdrant operations.""" + + @pytest.mark.asyncio + async def test_benchmark_single_insert(self, benchmark): + """Benchmark single point insertion.""" + metrics = await benchmark.benchmark_insert(num_points=100, vector_dim=384) + + assert metrics.total_operations == 100 + assert metrics.success_count == 100 + assert metrics.error_count == 0 + assert metrics.operations_per_second > 0 + + print(metrics) + + @pytest.mark.asyncio + async def test_benchmark_batch_insert(self, benchmark): + """Benchmark batch insertion.""" + metrics = await benchmark.benchmark_batch_insert( + num_points=1000, batch_size=100, vector_dim=384 + ) + + assert metrics.total_operations == 10 # 10 batches + assert metrics.success_count == 10 + assert metrics.error_count == 0 + assert metrics.operations_per_second > 0 + + print(metrics) + + @pytest.mark.asyncio + async def 
test_benchmark_search(self, benchmark, qdrant_repository): + """Benchmark similarity search.""" + # First, insert some test data + points = [ + QdrantPoint( + id=f"search_bench_{i}", + vector=[0.1 * i] * 384, + payload={"index": i}, + ) + for i in range(100) + ] + await qdrant_repository.store_points(points) + + metrics = await benchmark.benchmark_search( + num_searches=50, vector_dim=384, limit=10 + ) + + assert metrics.total_operations == 50 + assert metrics.success_count == 50 + assert metrics.error_count == 0 + assert metrics.operations_per_second > 0 + assert metrics.avg_latency_ms > 0 + + print(metrics) + + @pytest.mark.asyncio + async def test_benchmark_concurrent_operations(self, benchmark): + """Benchmark concurrent insertions.""" + metrics = await benchmark.benchmark_concurrent_operations( + num_operations=50, concurrency=10, vector_dim=384 + ) + + assert metrics.total_operations == 50 + assert metrics.success_count == 50 + assert metrics.error_count == 0 + assert metrics.operations_per_second > 0 + + print(metrics) + + @pytest.mark.asyncio + @pytest.mark.slow + async def test_benchmark_small_dataset(self, benchmark): + """Run small dataset benchmark suite.""" + result = await benchmark.run_full_benchmark( + small_dataset=True, medium_dataset=False, large_dataset=False + ) + + assert len(result.metrics) >= 4 + assert result.duration > 0 + assert all(m.error_count == 0 for m in result.metrics) + + print(result.summary()) + + @pytest.mark.asyncio + @pytest.mark.slow + async def test_benchmark_medium_dataset(self, benchmark): + """Run medium dataset benchmark suite.""" + result = await benchmark.run_full_benchmark( + small_dataset=False, medium_dataset=True, large_dataset=False + ) + + assert len(result.metrics) >= 3 + assert result.duration > 0 + assert all(m.error_count == 0 for m in result.metrics) + + print(result.summary()) + + @pytest.mark.asyncio + async def test_benchmark_insert_different_vector_dims(self, benchmark): + """Benchmark insertions with 
different vector dimensions.""" + dims = [128, 384, 768, 1536] + results = {} + + for dim in dims: + metrics = await benchmark.benchmark_insert(num_points=50, vector_dim=dim) + results[dim] = metrics + assert metrics.success_count == 50 + + # Print comparison + print("\n\nVector Dimension Performance Comparison:") + print(f"{'Dimension':<12} {'Ops/sec':<12} {'Avg Latency (ms)':<20}") + print("-" * 44) + for dim, metrics in results.items(): + print( + f"{dim:<12} {metrics.operations_per_second:<12.2f} " + f"{metrics.avg_latency_ms:<20.2f}" + ) + + @pytest.mark.asyncio + async def test_benchmark_batch_sizes(self, benchmark): + """Benchmark different batch sizes.""" + batch_sizes = [10, 50, 100, 500] + results = {} + + for batch_size in batch_sizes: + metrics = await benchmark.benchmark_batch_insert( + num_points=1000, batch_size=batch_size, vector_dim=384 + ) + results[batch_size] = metrics + + # Print comparison + print("\n\nBatch Size Performance Comparison:") + print( + f"{'Batch Size':<12} {'Batches':<12} {'Ops/sec':<12} " + f"{'Avg Latency (ms)':<20}" + ) + print("-" * 56) + for batch_size, metrics in results.items(): + print( + f"{batch_size:<12} {metrics.total_operations:<12} " + f"{metrics.operations_per_second:<12.2f} " + f"{metrics.avg_latency_ms:<20.2f}" + ) + + @pytest.mark.asyncio + async def test_benchmark_search_result_limits(self, benchmark, qdrant_repository): + """Benchmark search with different result limits.""" + # Insert test data + points = [ + QdrantPoint( + id=f"limit_bench_{i}", + vector=[0.1 * i] * 384, + payload={"index": i}, + ) + for i in range(200) + ] + await qdrant_repository.store_points(points) + + limits = [1, 5, 10, 50, 100] + results = {} + + for limit in limits: + metrics = await benchmark.benchmark_search( + num_searches=30, vector_dim=384, limit=limit + ) + results[limit] = metrics + + # Print comparison + print("\n\nSearch Result Limit Performance Comparison:") + print(f"{'Limit':<12} {'Ops/sec':<12} {'Avg Latency (ms)':<20}") 
+ print("-" * 44) + for limit, metrics in results.items(): + print( + f"{limit:<12} {metrics.operations_per_second:<12.2f} " + f"{metrics.avg_latency_ms:<20.2f}" + ) + + @pytest.mark.asyncio + async def test_benchmark_concurrent_levels(self, benchmark): + """Benchmark different concurrency levels.""" + concurrency_levels = [1, 5, 10, 20, 50] + results = {} + + for concurrency in concurrency_levels: + metrics = await benchmark.benchmark_concurrent_operations( + num_operations=100, concurrency=concurrency, vector_dim=384 + ) + results[concurrency] = metrics + + # Print comparison + print("\n\nConcurrency Level Performance Comparison:") + print(f"{'Concurrency':<14} {'Ops/sec':<12} {'Avg Latency (ms)':<20}") + print("-" * 46) + for concurrency, metrics in results.items(): + print( + f"{concurrency:<14} {metrics.operations_per_second:<12.2f} " + f"{metrics.avg_latency_ms:<20.2f}" + ) + + @pytest.mark.asyncio + async def test_benchmark_latency_percentiles(self, benchmark, qdrant_repository): + """Test latency percentiles for search operations.""" + # Insert test data + points = [ + QdrantPoint( + id=f"percentile_bench_{i}", + vector=[0.1 * i] * 384, + payload={"index": i}, + ) + for i in range(100) + ] + await qdrant_repository.store_points(points) + + metrics = await benchmark.benchmark_search( + num_searches=100, vector_dim=384, limit=10 + ) + + # Verify percentile ordering + assert metrics.min_latency_ms <= metrics.p50_latency_ms + assert metrics.p50_latency_ms <= metrics.p95_latency_ms + assert metrics.p95_latency_ms <= metrics.p99_latency_ms + assert metrics.p99_latency_ms <= metrics.max_latency_ms + + print("\n\nLatency Percentile Distribution:") + print(f" Min: {metrics.min_latency_ms:.2f}ms") + print(f" P50: {metrics.p50_latency_ms:.2f}ms") + print(f" P95: {metrics.p95_latency_ms:.2f}ms") + print(f" P99: {metrics.p99_latency_ms:.2f}ms") + print(f" Max: {metrics.max_latency_ms:.2f}ms") + + +@pytest.mark.benchmark +@pytest.mark.slow +class TestQdrantScalability: + 
"""Scalability tests for large datasets.""" + + @pytest.mark.asyncio + async def test_scalability_insert(self, benchmark): + """Test insert scalability across different dataset sizes.""" + sizes = [100, 500, 1000] + results = {} + + for size in sizes: + metrics = await benchmark.benchmark_insert(num_points=size, vector_dim=384) + results[size] = metrics + + print("\n\nInsert Scalability Test:") + print(f"{'Dataset Size':<14} {'Ops/sec':<12} {'Total Time (s)':<16}") + print("-" * 42) + for size, metrics in results.items(): + print( + f"{size:<14} {metrics.operations_per_second:<12.2f} " + f"{metrics.total_time:<16.2f}" + ) + + @pytest.mark.asyncio + async def test_scalability_batch_insert(self, benchmark): + """Test batch insert scalability.""" + sizes = [1000, 5000, 10000] + batch_size = 100 + results = {} + + for size in sizes: + metrics = await benchmark.benchmark_batch_insert( + num_points=size, batch_size=batch_size, vector_dim=384 + ) + results[size] = metrics + + print("\n\nBatch Insert Scalability Test:") + print(f"{'Dataset Size':<14} {'Ops/sec':<12} {'Total Time (s)':<16}") + print("-" * 42) + for size, metrics in results.items(): + print( + f"{size:<14} {metrics.operations_per_second:<12.2f} " + f"{metrics.total_time:<16.2f}" + ) + + @pytest.mark.asyncio + async def test_search_performance_with_dataset_growth( + self, benchmark, qdrant_repository + ): + """Test search performance as dataset grows.""" + results = {} + dataset_sizes = [100, 500, 1000] + + for size in dataset_sizes: + # Insert data + points = [ + QdrantPoint( + id=f"growth_bench_{size}_{i}", + vector=[0.1 * i] * 384, + payload={"size": size, "index": i}, + ) + for i in range(size) + ] + await qdrant_repository.store_points(points) + + # Benchmark search + metrics = await benchmark.benchmark_search( + num_searches=30, vector_dim=384, limit=10 + ) + results[size] = metrics + + print("\n\nSearch Performance vs Dataset Size:") + print(f"{'Dataset Size':<14} {'Ops/sec':<12} {'Avg Latency 
(ms)':<20}") + print("-" * 46) + for size, metrics in results.items(): + print( + f"{size:<14} {metrics.operations_per_second:<12.2f} " + f"{metrics.avg_latency_ms:<20.2f}" + ) diff --git a/tests/integration/test_qdrant_cache.py b/tests/integration/test_qdrant_cache.py new file mode 100644 index 0000000..04cb888 --- /dev/null +++ b/tests/integration/test_qdrant_cache.py @@ -0,0 +1,253 @@ +""" +Integration tests for Qdrant cache. + +These tests require a running Qdrant instance. +""" + +import pytest +import pytest_asyncio +from qdrant_client.models import Distance + +from app.cache.qdrant_client import create_qdrant_client +from app.models.qdrant_point import QdrantPoint +from app.repositories.qdrant_repository import QdrantRepository + + +@pytest_asyncio.fixture +async def qdrant_client(): + """Create Qdrant client for testing.""" + client = await create_qdrant_client() + yield client + await client.close() + + +@pytest_asyncio.fixture +async def qdrant_repository(qdrant_client): + """Create Qdrant repository for testing.""" + repository = QdrantRepository(qdrant_client) + + # Ensure collection exists + await repository.create_collection(distance=Distance.COSINE) + + yield repository + + # Clean up: delete all points after tests + try: + await repository.delete_collection() + await repository.create_collection(distance=Distance.COSINE) + except Exception: + pass + + +@pytest.fixture +def sample_point(): + """Create sample Qdrant point.""" + return QdrantPoint( + id="test_integration_001", + vector=[0.1, 0.2, 0.3, 0.4, 0.5] * 77, # 385 dims (close to 384) + payload={ + "query_hash": "test_integration_hash", + "query": "What is integration testing?", + "response": "Integration testing tests the complete flow", + "provider": "openai", + "model": "gpt-3.5-turbo", + }, + ) + + +@pytest.mark.integration +class TestQdrantIntegration: + """Integration tests for Qdrant operations.""" + + @pytest.mark.asyncio + async def test_collection_creation(self, qdrant_repository): + 
"""Test collection creation.""" + exists = await qdrant_repository.collection_exists() + assert exists is True + + @pytest.mark.asyncio + async def test_store_and_retrieve_point(self, qdrant_repository, sample_point): + """Test storing and retrieving a point.""" + # Store point + result = await qdrant_repository.store_point(sample_point) + assert result is True + + # Retrieve point + retrieved = await qdrant_repository.get_point(sample_point.id) + assert retrieved is not None + assert retrieved.id == sample_point.id + assert retrieved.payload["query_hash"] == "test_integration_hash" + + @pytest.mark.asyncio + async def test_batch_store_points(self, qdrant_repository): + """Test batch storing multiple points.""" + points = [ + QdrantPoint( + id=f"test_batch_{i}", + vector=[0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i, 0.5 * i] * 77, + payload={ + "query_hash": f"hash_{i}", + "index": i, + }, + ) + for i in range(5) + ] + + count = await qdrant_repository.store_points(points) + assert count == 5 + + # Verify all points were stored + for point in points: + retrieved = await qdrant_repository.get_point(point.id) + assert retrieved is not None + + @pytest.mark.asyncio + async def test_similarity_search(self, qdrant_repository): + """Test similarity search.""" + # Store multiple points with similar vectors + points = [ + QdrantPoint( + id=f"test_search_{i}", + vector=[0.1 + i * 0.01] * 385, + payload={"index": i}, + ) + for i in range(3) + ] + await qdrant_repository.store_points(points) + + # Search for similar points + query_vector = [0.1] * 385 + results = await qdrant_repository.search_similar(query_vector, limit=2) + + assert len(results) > 0 + assert results[0].score > 0.8 # High similarity expected + + @pytest.mark.asyncio + async def test_delete_point(self, qdrant_repository, sample_point): + """Test point deletion.""" + # Store point + await qdrant_repository.store_point(sample_point) + + # Delete point + result = await qdrant_repository.delete_point(sample_point.id) + 
assert result.success is True + assert result.deleted_count == 1 + + # Verify point is deleted + retrieved = await qdrant_repository.get_point(sample_point.id) + assert retrieved is None + + @pytest.mark.asyncio + async def test_count_points(self, qdrant_repository): + """Test counting points in collection.""" + # Store some points + points = [ + QdrantPoint( + id=f"test_count_{i}", + vector=[0.1] * 385, + payload={"index": i}, + ) + for i in range(3) + ] + await qdrant_repository.store_points(points) + + # Count points + count = await qdrant_repository.count_points() + assert count >= 3 + + @pytest.mark.asyncio + async def test_pagination(self, qdrant_repository): + """Test pagination of points.""" + # Store multiple points + points = [ + QdrantPoint( + id=f"test_page_{i}", + vector=[0.1] * 385, + payload={"index": i}, + ) + for i in range(10) + ] + await qdrant_repository.store_points(points) + + # Get first page + page1, offset1 = await qdrant_repository.scroll_points( + limit=5, with_vectors=False + ) + assert len(page1) == 5 + + # Get second page + if offset1: + page2, offset2 = await qdrant_repository.scroll_points( + limit=5, offset=str(offset1), with_vectors=False + ) + assert len(page2) <= 5 + + @pytest.mark.asyncio + async def test_update_point(self, qdrant_repository, sample_point): + """Test updating point payload.""" + # Store original point + await qdrant_repository.store_point(sample_point) + + # Update payload + new_payload = { + **sample_point.payload, + "updated": True, + "new_field": "new_value", + } + result = await qdrant_repository.update_point(sample_point.id, new_payload) + assert result is True + + # Verify update + retrieved = await qdrant_repository.get_point(sample_point.id) + assert retrieved is not None + assert retrieved.payload.get("updated") is True + assert retrieved.payload.get("new_field") == "new_value" + + @pytest.mark.asyncio + async def test_filter_search(self, qdrant_repository): + """Test search with filters.""" + from 
app.cache.qdrant_filter import create_filter + + # Store points with different providers + points = [ + QdrantPoint( + id=f"test_filter_{i}", + vector=[0.1] * 385, + payload={ + "provider": "openai" if i % 2 == 0 else "anthropic", + "index": i, + }, + ) + for i in range(4) + ] + await qdrant_repository.store_points(points) + + # Search with filter + query_vector = [0.1] * 385 + filter_builder = create_filter().with_provider("openai") + + results = await qdrant_repository.search_similar( + query_vector, + limit=10, + filter_condition=filter_builder.build(), + ) + + # All results should be from openai + assert len(results) > 0 + for result in results: + assert result.payload.get("provider") == "openai" + + @pytest.mark.asyncio + async def test_connection_health(self, qdrant_repository): + """Test connection health check.""" + result = await qdrant_repository.ping() + assert result is True + + @pytest.mark.asyncio + async def test_collection_info(self, qdrant_repository): + """Test getting collection information.""" + info = await qdrant_repository.get_collection_info() + assert info is not None + assert "vectors_count" in info + assert "points_count" in info + assert "status" in info diff --git a/tests/unit/cache/test_qdrant_errors.py b/tests/unit/cache/test_qdrant_errors.py new file mode 100644 index 0000000..f10aab2 --- /dev/null +++ b/tests/unit/cache/test_qdrant_errors.py @@ -0,0 +1,197 @@ +"""Unit tests for Qdrant error handling.""" + +import pytest + +from app.cache.qdrant_errors import ( + ErrorContext, + QdrantCollectionError, + QdrantCollectionExistsError, + QdrantCollectionNotFoundError, + QdrantConnectionError, + QdrantError, + QdrantPointError, + QdrantPointNotFoundError, + QdrantSearchError, + QdrantTimeoutError, + QdrantValidationError, + handle_qdrant_error, + is_retryable_error, +) + + +class TestQdrantError: + """Tests for QdrantError base class.""" + + def test_error_creation(self): + """Test QdrantError creation.""" + error = QdrantError("Test 
error") + + assert str(error) == "Test error" + assert error.message == "Test error" + assert error.cause is None + + def test_error_with_cause(self): + """Test QdrantError with cause.""" + cause = ValueError("Original error") + error = QdrantError("Test error", cause=cause) + + assert error.cause is cause + assert "caused by" in str(error) + assert "Original error" in str(error) + + +class TestErrorMapping: + """Tests for handle_qdrant_error function.""" + + def test_connection_error_mapping(self): + """Test mapping connection errors.""" + error = Exception("Failed to connect to server") + result = handle_qdrant_error(error, "test_operation") + + assert isinstance(result, QdrantConnectionError) + assert result.cause is error + assert "test_operation" in result.message + + def test_timeout_error_mapping(self): + """Test mapping timeout errors.""" + error = Exception("Operation timeout exceeded") + result = handle_qdrant_error(error, "search") + + assert isinstance(result, QdrantTimeoutError) + assert "timeout" in result.message.lower() + + def test_collection_not_found_mapping(self): + """Test mapping collection not found errors.""" + error = Exception("Collection not found") + result = handle_qdrant_error(error, "get_collection") + + assert isinstance(result, QdrantCollectionNotFoundError) + + def test_collection_exists_mapping(self): + """Test mapping collection exists errors.""" + error = Exception("Collection already exists") + result = handle_qdrant_error(error, "create_collection") + + assert isinstance(result, QdrantCollectionExistsError) + + def test_collection_error_mapping(self): + """Test mapping generic collection errors.""" + error = Exception("Collection operation failed") + result = handle_qdrant_error(error, "update_collection") + + assert isinstance(result, QdrantCollectionError) + + def test_point_not_found_mapping(self): + """Test mapping point not found errors.""" + error = Exception("Point not found") + result = handle_qdrant_error(error, 
"get_point") + + assert isinstance(result, QdrantPointNotFoundError) + + def test_point_error_mapping(self): + """Test mapping generic point errors.""" + error = Exception("Point operation failed") + result = handle_qdrant_error(error, "upsert_point") + + assert isinstance(result, QdrantPointError) + + def test_search_error_mapping(self): + """Test mapping search errors.""" + error = Exception("Search query failed") + result = handle_qdrant_error(error, "search") + + assert isinstance(result, QdrantSearchError) + + def test_validation_error_mapping(self): + """Test mapping validation errors.""" + error = Exception("Invalid vector dimension") + result = handle_qdrant_error(error, "validate") + + assert isinstance(result, QdrantValidationError) + + def test_generic_error_mapping(self): + """Test mapping generic errors.""" + error = Exception("Unknown error") + result = handle_qdrant_error(error, "unknown_op") + + assert isinstance(result, QdrantError) + assert result.cause is error + + +class TestRetryableErrors: + """Tests for is_retryable_error function.""" + + def test_connection_error_retryable(self): + """Test connection errors are retryable.""" + error = QdrantConnectionError("Connection failed") + + assert is_retryable_error(error) is True + + def test_timeout_error_retryable(self): + """Test timeout errors are retryable.""" + error = QdrantTimeoutError("Operation timed out") + + assert is_retryable_error(error) is True + + def test_validation_error_not_retryable(self): + """Test validation errors are not retryable.""" + error = QdrantValidationError("Invalid input") + + assert is_retryable_error(error) is False + + def test_generic_timeout_retryable(self): + """Test generic timeout errors are retryable.""" + error = Exception("Request timeout") + + assert is_retryable_error(error) is True + + def test_generic_connection_retryable(self): + """Test generic connection errors are retryable.""" + error = Exception("Network connection lost") + + assert 
is_retryable_error(error) is True + + def test_generic_unavailable_retryable(self): + """Test unavailable errors are retryable.""" + error = Exception("Service unavailable") + + assert is_retryable_error(error) is True + + def test_non_retryable_error(self): + """Test non-retryable errors.""" + error = Exception("Invalid operation") + + assert is_retryable_error(error) is False + + +class TestErrorContext: + """Tests for ErrorContext context manager.""" + + def test_error_context_no_error(self): + """Test ErrorContext with no errors.""" + with ErrorContext("test_operation"): + pass # No error should occur + + def test_error_context_maps_error(self): + """Test ErrorContext maps exceptions.""" + with pytest.raises(QdrantConnectionError): + with ErrorContext("test_operation"): + raise Exception("Connection failed") + + def test_error_context_preserves_operation(self): + """Test ErrorContext preserves operation name.""" + try: + with ErrorContext("my_operation"): + raise Exception("Test error") + except QdrantError as e: + assert "my_operation" in e.message + + def test_error_context_chains_exceptions(self): + """Test ErrorContext chains exceptions properly.""" + original = ValueError("Original error") + + try: + with ErrorContext("test_op"): + raise original + except QdrantError as e: + assert e.cause is original diff --git a/tests/unit/cache/test_qdrant_filter.py b/tests/unit/cache/test_qdrant_filter.py new file mode 100644 index 0000000..1dee115 --- /dev/null +++ b/tests/unit/cache/test_qdrant_filter.py @@ -0,0 +1,211 @@ +"""Unit tests for Qdrant filter builder.""" + +from qdrant_client.models import ( + FieldCondition, + IsEmptyCondition, + MatchAny, + MatchValue, + PayloadField, + Range, +) + +from app.cache.qdrant_filter import QdrantFilterBuilder, create_filter +from app.models.qdrant_schema import QdrantSchema + + +class TestQdrantFilterBuilder: + """Tests for QdrantFilterBuilder class.""" + + def test_match_field(self): + """Test match_field adds exact match 
condition.""" + builder = QdrantFilterBuilder() + result = builder.match_field("provider", "openai") + + assert result is builder # Test fluent API + assert len(builder._must) == 1 + condition = builder._must[0] + assert isinstance(condition, FieldCondition) + assert condition.key == "provider" + assert isinstance(condition.match, MatchValue) + assert condition.match.value == "openai" + + def test_match_any(self): + """Test match_any adds match any condition.""" + builder = QdrantFilterBuilder() + values = ["openai", "anthropic", "cohere"] + result = builder.match_any("provider", values) + + assert result is builder + assert len(builder._must) == 1 + condition = builder._must[0] + assert isinstance(condition, FieldCondition) + assert isinstance(condition.match, MatchAny) + assert condition.match.any == values + + def test_range_field_gte(self): + """Test range_field with gte parameter.""" + builder = QdrantFilterBuilder() + result = builder.range_field("created_at", gte=1000.0) + + assert result is builder + assert len(builder._must) == 1 + condition = builder._must[0] + assert isinstance(condition, FieldCondition) + assert isinstance(condition.range, Range) + assert condition.range.gte == 1000.0 + + def test_range_field_between(self): + """Test range_field with gte and lte parameters.""" + builder = QdrantFilterBuilder() + result = builder.range_field("created_at", gte=1000.0, lte=2000.0) + + assert result is builder + condition = builder._must[0] + assert condition.range.gte == 1000.0 + assert condition.range.lte == 2000.0 + + def test_is_empty(self): + """Test is_empty adds is empty condition.""" + builder = QdrantFilterBuilder() + result = builder.is_empty("tags") + + assert result is builder + assert len(builder._must) == 1 + condition = builder._must[0] + assert isinstance(condition, IsEmptyCondition) + assert isinstance(condition.is_empty, PayloadField) + assert condition.is_empty.key == "tags" + + def test_is_not_empty(self): + """Test is_not_empty adds is 
not empty condition.""" + builder = QdrantFilterBuilder() + result = builder.is_not_empty("tags") + + assert result is builder + assert len(builder._must_not) == 1 + condition = builder._must_not[0] + assert isinstance(condition, IsEmptyCondition) + + def test_with_provider(self): + """Test with_provider convenience method.""" + builder = QdrantFilterBuilder() + result = builder.with_provider("openai") + + assert result is builder + assert len(builder._must) == 1 + condition = builder._must[0] + assert condition.key == QdrantSchema.FIELD_PROVIDER + assert condition.match.value == "openai" + + def test_with_model(self): + """Test with_model convenience method.""" + builder = QdrantFilterBuilder() + result = builder.with_model("gpt-4") + + assert result is builder + condition = builder._must[0] + assert condition.key == QdrantSchema.FIELD_MODEL + assert condition.match.value == "gpt-4" + + def test_with_query_hash(self): + """Test with_query_hash convenience method.""" + builder = QdrantFilterBuilder() + result = builder.with_query_hash("abc123") + + assert result is builder + condition = builder._must[0] + assert condition.key == QdrantSchema.FIELD_QUERY_HASH + assert condition.match.value == "abc123" + + def test_created_after(self): + """Test created_after convenience method.""" + builder = QdrantFilterBuilder() + timestamp = 1234567890.0 + result = builder.created_after(timestamp) + + assert result is builder + condition = builder._must[0] + assert condition.key == QdrantSchema.FIELD_CREATED_AT + assert condition.range.gte == timestamp + + def test_created_before(self): + """Test created_before convenience method.""" + builder = QdrantFilterBuilder() + timestamp = 1234567890.0 + result = builder.created_before(timestamp) + + assert result is builder + condition = builder._must[0] + assert condition.range.lte == timestamp + + def test_created_between(self): + """Test created_between convenience method.""" + builder = QdrantFilterBuilder() + start = 1000.0 + end = 
2000.0 + result = builder.created_between(start, end) + + assert result is builder + condition = builder._must[0] + assert condition.range.gte == start + assert condition.range.lte == end + + def test_with_tags(self): + """Test with_tags convenience method.""" + builder = QdrantFilterBuilder() + tags = ["production", "cache"] + result = builder.with_tags(tags) + + assert result is builder + condition = builder._must[0] + assert condition.key == QdrantSchema.FIELD_TAGS + assert condition.match.any == tags + + def test_build_with_conditions(self): + """Test build creates Filter with conditions.""" + builder = QdrantFilterBuilder() + builder.match_field("provider", "openai") + builder.match_field("model", "gpt-4") + + filter_obj = builder.build() + + assert filter_obj is not None + assert filter_obj.must is not None + assert len(filter_obj.must) == 2 + + def test_build_empty(self): + """Test build returns None when no conditions.""" + builder = QdrantFilterBuilder() + filter_obj = builder.build() + + assert filter_obj is None + + def test_reset(self): + """Test reset clears all conditions.""" + builder = QdrantFilterBuilder() + builder.match_field("provider", "openai") + builder.match_field("model", "gpt-4") + + result = builder.reset() + + assert result is builder + assert len(builder._must) == 0 + assert len(builder._should) == 0 + assert len(builder._must_not) == 0 + + def test_chaining(self): + """Test method chaining works correctly.""" + builder = QdrantFilterBuilder() + result = ( + builder.with_provider("openai").with_model("gpt-4").created_after(1000.0) + ) + + assert result is builder + assert len(builder._must) == 3 + + def test_create_filter_function(self): + """Test create_filter factory function.""" + builder = create_filter() + + assert isinstance(builder, QdrantFilterBuilder) + assert len(builder._must) == 0 diff --git a/tests/unit/repositories/test_qdrant_repository.py b/tests/unit/repositories/test_qdrant_repository.py new file mode 100644 index 
0000000..bd7b302 --- /dev/null +++ b/tests/unit/repositories/test_qdrant_repository.py @@ -0,0 +1,248 @@ +"""Unit tests for Qdrant repository.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from qdrant_client.models import Distance, PointStruct, ScoredPoint + +from app.models.qdrant_point import QdrantPoint, SearchResult +from app.repositories.qdrant_repository import QdrantRepository + + +@pytest.fixture +def mock_client(): + """Create mock Qdrant client.""" + client = AsyncMock() + client.get_collections.return_value = MagicMock(collections=[]) + return client + + +@pytest.fixture +def repository(mock_client): + """Create repository instance.""" + return QdrantRepository(mock_client) + + +class TestQdrantRepository: + """Tests for QdrantRepository class.""" + + @pytest.mark.asyncio + async def test_collection_exists_true(self, mock_client): + """Test collection_exists returns True when collection exists.""" + mock_collection = MagicMock() + mock_collection.name = "test_cache" + mock_client.get_collections.return_value = MagicMock( + collections=[mock_collection] + ) + + with patch("app.repositories.qdrant_repository.config") as mock_config: + mock_config.qdrant_collection_name = "test_cache" + mock_config.qdrant_vector_size = 384 + repository = QdrantRepository(mock_client) + result = await repository.collection_exists() + + assert result is True + mock_client.get_collections.assert_called_once() + + @pytest.mark.asyncio + async def test_collection_exists_false(self, repository, mock_client): + """Test collection_exists returns False when collection doesn't exist.""" + mock_client.get_collections.return_value = MagicMock(collections=[]) + + result = await repository.collection_exists() + + assert result is False + + @pytest.mark.asyncio + async def test_collection_exists_error(self, repository, mock_client): + """Test collection_exists handles errors gracefully.""" + mock_client.get_collections.side_effect = Exception("Connection 
failed") + + result = await repository.collection_exists() + + assert result is False + + @pytest.mark.asyncio + async def test_create_collection_success(self, repository, mock_client): + """Test successful collection creation.""" + mock_client.get_collections.return_value = MagicMock(collections=[]) + + result = await repository.create_collection(distance=Distance.COSINE) + + assert result is True + mock_client.create_collection.assert_called_once() + + @pytest.mark.asyncio + async def test_create_collection_already_exists(self, mock_client): + """Test collection creation when already exists.""" + mock_collection = MagicMock() + mock_collection.name = "test_cache" + mock_client.get_collections.return_value = MagicMock( + collections=[mock_collection] + ) + + with patch("app.repositories.qdrant_repository.config") as mock_config: + mock_config.qdrant_collection_name = "test_cache" + mock_config.qdrant_vector_size = 384 + repository = QdrantRepository(mock_client) + result = await repository.create_collection() + + assert result is True + mock_client.create_collection.assert_not_called() + + @pytest.mark.asyncio + async def test_delete_collection_success(self, repository, mock_client): + """Test successful collection deletion.""" + result = await repository.delete_collection() + + assert result is True + mock_client.delete_collection.assert_called_once() + + @pytest.mark.asyncio + async def test_ping_success(self, repository, mock_client): + """Test successful ping.""" + with ( + patch("app.repositories.qdrant_repository.RetryPolicy"), + patch("app.repositories.qdrant_repository.ErrorContext"), + ): + result = await repository.ping() + + assert result is True + mock_client.get_collections.assert_called() + + @pytest.mark.asyncio + async def test_store_point_success(self, repository, mock_client): + """Test successful point storage.""" + point = QdrantPoint( + id="test-123", + vector=[0.1, 0.2, 0.3], + payload={"query_hash": "abc123", "response": "test response"}, + 
) + + with ( + patch("app.repositories.qdrant_repository.RetryPolicy"), + patch("app.repositories.qdrant_repository.ErrorContext"), + ): + result = await repository.store_point(point) + + assert result is True + mock_client.upsert.assert_called_once() + + @pytest.mark.asyncio + async def test_store_points_success(self, repository, mock_client): + """Test successful multiple points storage.""" + points = [ + QdrantPoint( + id=f"test-{i}", + vector=[0.1 * i, 0.2 * i, 0.3 * i], + payload={"query_hash": f"hash{i}"}, + ) + for i in range(3) + ] + + result = await repository.store_points(points) + + assert result == 3 + mock_client.upsert.assert_called_once() + + @pytest.mark.asyncio + async def test_store_points_empty(self, repository, mock_client): + """Test storing empty points list.""" + result = await repository.store_points([]) + + assert result == 0 + mock_client.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_search_similar_success(self, repository, mock_client): + """Test successful similarity search.""" + query_vector = [0.1, 0.2, 0.3] + mock_scored = ScoredPoint( + id="test-123", + version=1, + score=0.95, + payload={"query_hash": "abc123", "response": "test"}, + vector=[0.1, 0.2, 0.3], + ) + mock_client.search.return_value = [mock_scored] + + results = await repository.search_similar(query_vector, limit=5) + + assert len(results) == 1 + assert isinstance(results[0], SearchResult) + assert results[0].score == 0.95 + mock_client.search.assert_called_once() + + @pytest.mark.asyncio + async def test_search_similar_no_results(self, repository, mock_client): + """Test similarity search with no results.""" + query_vector = [0.1, 0.2, 0.3] + mock_client.search.return_value = [] + + results = await repository.search_similar(query_vector) + + assert results == [] + + @pytest.mark.asyncio + async def test_get_point_success(self, repository, mock_client): + """Test successful point retrieval by ID.""" + mock_point = PointStruct( + id="test-123", + 
vector=[0.1, 0.2, 0.3], + payload={"query_hash": "abc123"}, + ) + mock_client.retrieve.return_value = [mock_point] + + point = await repository.get_point("test-123") + + assert point is not None + assert point.id == "test-123" + mock_client.retrieve.assert_called_once() + + @pytest.mark.asyncio + async def test_get_point_not_found(self, repository, mock_client): + """Test point retrieval when not found.""" + mock_client.retrieve.return_value = [] + + point = await repository.get_point("nonexistent") + + assert point is None + + @pytest.mark.asyncio + async def test_delete_point_success(self, repository, mock_client): + """Test successful point deletion.""" + result = await repository.delete_point("test-123") + + assert result.success is True + assert result.deleted_count == 1 + mock_client.delete.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_points_success(self, repository, mock_client): + """Test successful multiple points deletion.""" + point_ids = ["test-1", "test-2", "test-3"] + + result = await repository.delete_points(point_ids) + + assert result.success is True + assert result.deleted_count == 3 + mock_client.delete.assert_called_once() + + @pytest.mark.asyncio + async def test_count_points_success(self, repository, mock_client): + """Test successful points counting.""" + mock_client.count.return_value = MagicMock(count=42) + + count = await repository.count_points() + + assert count == 42 + mock_client.count.assert_called_once() + + @pytest.mark.asyncio + async def test_count_points_error(self, repository, mock_client): + """Test points counting handles errors.""" + mock_client.count.side_effect = Exception("Count failed") + + count = await repository.count_points() + + assert count == 0 diff --git a/tests/unit/similarity/test_score_calculator.py b/tests/unit/similarity/test_score_calculator.py new file mode 100644 index 0000000..dbd13c2 --- /dev/null +++ b/tests/unit/similarity/test_score_calculator.py @@ -0,0 +1,217 @@ +"""Unit 
tests for similarity score calculator.""" + +import math + +import pytest + +from app.similarity.score_calculator import ( + ScoreCalculator, + ScoreInterpretation, + cosine_similarity, + euclidean_distance, + interpret_cosine_score, +) + + +class TestCosineSimilarity: + """Tests for cosine_similarity function.""" + + def test_identical_vectors(self): + """Test cosine similarity of identical vectors is 1.0.""" + vec1 = [1.0, 2.0, 3.0] + vec2 = [1.0, 2.0, 3.0] + + score = cosine_similarity(vec1, vec2) + + assert abs(score - 1.0) < 1e-6 + + def test_opposite_vectors(self): + """Test cosine similarity of opposite vectors is -1.0.""" + vec1 = [1.0, 2.0, 3.0] + vec2 = [-1.0, -2.0, -3.0] + + score = cosine_similarity(vec1, vec2) + + assert abs(score - (-1.0)) < 1e-6 + + def test_orthogonal_vectors(self): + """Test cosine similarity of orthogonal vectors is 0.0.""" + vec1 = [1.0, 0.0, 0.0] + vec2 = [0.0, 1.0, 0.0] + + score = cosine_similarity(vec1, vec2) + + assert abs(score - 0.0) < 1e-6 + + def test_zero_vector(self): + """Test cosine similarity with zero vector is 0.0.""" + vec1 = [1.0, 2.0, 3.0] + vec2 = [0.0, 0.0, 0.0] + + score = cosine_similarity(vec1, vec2) + + assert score == 0.0 + + def test_mismatched_dimensions(self): + """Test cosine similarity with mismatched dimensions raises error.""" + vec1 = [1.0, 2.0, 3.0] + vec2 = [1.0, 2.0] + + with pytest.raises(ValueError, match="must have same dimensions"): + cosine_similarity(vec1, vec2) + + def test_empty_vectors(self): + """Test cosine similarity with empty vectors raises error.""" + vec1 = [] + vec2 = [] + + with pytest.raises(ValueError, match="must have same dimensions"): + cosine_similarity(vec1, vec2) + + +class TestEuclideanDistance: + """Tests for euclidean_distance function.""" + + def test_identical_vectors(self): + """Test euclidean distance of identical vectors is 0.0.""" + vec1 = [1.0, 2.0, 3.0] + vec2 = [1.0, 2.0, 3.0] + + dist = euclidean_distance(vec1, vec2) + + assert abs(dist - 0.0) < 1e-6 + + 
def test_known_distance(self): + """Test euclidean distance with known values.""" + vec1 = [0.0, 0.0, 0.0] + vec2 = [3.0, 4.0, 0.0] + + dist = euclidean_distance(vec1, vec2) + + assert abs(dist - 5.0) < 1e-6 # 3-4-5 triangle + + def test_unit_distance(self): + """Test euclidean distance with unit vectors.""" + vec1 = [1.0, 0.0, 0.0] + vec2 = [0.0, 1.0, 0.0] + + dist = euclidean_distance(vec1, vec2) + + assert abs(dist - math.sqrt(2)) < 1e-6 + + def test_mismatched_dimensions(self): + """Test euclidean distance with mismatched dimensions raises error.""" + vec1 = [1.0, 2.0, 3.0] + vec2 = [1.0, 2.0] + + with pytest.raises(ValueError, match="must have same dimensions"): + euclidean_distance(vec1, vec2) + + +class TestScoreInterpretation: + """Tests for interpret_cosine_score function.""" + + def test_exact_match(self): + """Test interpretation of exact match score.""" + result = interpret_cosine_score(1.0) + + assert result == ScoreInterpretation.EXACT + + def test_very_high_match(self): + """Test interpretation of very high match score.""" + result = interpret_cosine_score(0.95) + + assert result == ScoreInterpretation.VERY_HIGH + + def test_high_match(self): + """Test interpretation of high match score.""" + result = interpret_cosine_score(0.88) + + assert result == ScoreInterpretation.HIGH + + def test_moderate_match(self): + """Test interpretation of moderate match score.""" + result = interpret_cosine_score(0.75) + + assert result == ScoreInterpretation.MODERATE + + def test_low_match(self): + """Test interpretation of low match score.""" + result = interpret_cosine_score(0.55) + + assert result == ScoreInterpretation.LOW + + def test_very_low_match(self): + """Test interpretation of very low match score.""" + result = interpret_cosine_score(0.25) + + assert result == ScoreInterpretation.VERY_LOW + + +class TestScoreCalculator: + """Tests for ScoreCalculator class.""" + + def test_calculate_cosine_similarity(self): + """Test calculate method with cosine 
similarity.""" + vec1 = [1.0, 0.0, 0.0] + vec2 = [1.0, 0.0, 0.0] + + score = ScoreCalculator.calculate(vec1, vec2, metric="cosine") + + assert abs(score - 1.0) < 1e-6 + + def test_calculate_euclidean_distance(self): + """Test calculate method with euclidean distance.""" + vec1 = [0.0, 0.0, 0.0] + vec2 = [1.0, 0.0, 0.0] + + score = ScoreCalculator.calculate(vec1, vec2, metric="euclidean") + + assert abs(score - 1.0) < 1e-6 + + def test_calculate_invalid_metric(self): + """Test calculate method with invalid metric raises error.""" + vec1 = [1.0, 2.0, 3.0] + vec2 = [1.0, 2.0, 3.0] + + with pytest.raises(ValueError, match="Unknown metric"): + ScoreCalculator.calculate(vec1, vec2, metric="invalid") + + def test_is_match_above_threshold(self): + """Test is_match returns True above threshold.""" + result = ScoreCalculator.is_match(0.90, threshold=0.85) + + assert result is True + + def test_is_match_below_threshold(self): + """Test is_match returns False below threshold.""" + result = ScoreCalculator.is_match(0.80, threshold=0.85) + + assert result is False + + def test_is_match_exact_threshold(self): + """Test is_match returns True at exact threshold.""" + result = ScoreCalculator.is_match(0.85, threshold=0.85) + + assert result is True + + def test_normalize_cosine_score(self): + """Test normalize_score with cosine similarity.""" + # Cosine already in 0-1 range + normalized = ScoreCalculator.normalize_score(0.85, metric="cosine") + + assert abs(normalized - 0.85) < 1e-6 + + def test_normalize_euclidean_score(self): + """Test normalize_score with euclidean distance.""" + # Euclidean converted to similarity + normalized = ScoreCalculator.normalize_score(0.0, metric="euclidean") + + assert abs(normalized - 1.0) < 1e-6 + + def test_get_interpretation(self): + """Test get_interpretation method.""" + interpretation = ScoreCalculator.get_interpretation(0.95) + + assert interpretation == ScoreInterpretation.VERY_HIGH + assert interpretation.value == "very_high" diff --git 
a/tests/unit/similarity/test_threshold_tuner.py b/tests/unit/similarity/test_threshold_tuner.py new file mode 100644 index 0000000..6c5dd0a --- /dev/null +++ b/tests/unit/similarity/test_threshold_tuner.py @@ -0,0 +1,412 @@ +"""Unit tests for threshold tuner.""" + +import pytest + +from app.similarity.threshold_tuner import ( + ThresholdMetrics, + ThresholdOptimizationGoal, + ThresholdRecommendation, + ThresholdTuner, + UseCase, + evaluate_threshold_quality, + get_cache_threshold, + get_exact_match_threshold, + tune_threshold, +) + + +class TestThresholdMetrics: + """Tests for ThresholdMetrics class.""" + + def test_metrics_precision_calculation(self): + """Test precision calculation.""" + metrics = ThresholdMetrics( + threshold=0.85, + true_positives=80, + false_positives=20, + true_negatives=70, + false_negatives=30, + ) + + assert metrics.precision == pytest.approx(0.8, abs=0.01) # 80 / (80 + 20) + + def test_metrics_recall_calculation(self): + """Test recall calculation.""" + metrics = ThresholdMetrics( + threshold=0.85, + true_positives=80, + false_positives=20, + true_negatives=70, + false_negatives=30, + ) + + assert metrics.recall == pytest.approx(0.727, abs=0.01) # 80 / (80 + 30) + + def test_metrics_f1_score_calculation(self): + """Test F1 score calculation.""" + metrics = ThresholdMetrics( + threshold=0.85, + true_positives=80, + false_positives=20, + true_negatives=70, + false_negatives=30, + ) + + # F1 = 2 * (0.8 * 0.727) / (0.8 + 0.727) ≈ 0.762 + assert metrics.f1_score == pytest.approx(0.762, abs=0.01) + + def test_metrics_accuracy_calculation(self): + """Test accuracy calculation.""" + metrics = ThresholdMetrics( + threshold=0.85, + true_positives=80, + false_positives=20, + true_negatives=70, + false_negatives=30, + ) + + assert metrics.accuracy == pytest.approx(0.75, abs=0.01) # (80 + 70) / 200 + + def test_metrics_zero_division_handling(self): + """Test handling of zero division.""" + metrics = ThresholdMetrics( + threshold=0.85, + 
true_positives=0, + false_positives=0, + true_negatives=0, + false_negatives=0, + ) + + assert metrics.precision == 0.0 + assert metrics.recall == 0.0 + assert metrics.f1_score == 0.0 + assert metrics.accuracy == 0.0 + + def test_metrics_string_representation(self): + """Test string representation.""" + metrics = ThresholdMetrics( + threshold=0.85, + true_positives=80, + false_positives=20, + true_negatives=70, + false_negatives=30, + ) + + string_repr = str(metrics) + assert "Threshold: 0.850" in string_repr + assert "Precision:" in string_repr + assert "Recall:" in string_repr + assert "F1 Score:" in string_repr + + +class TestThresholdTuner: + """Tests for ThresholdTuner class.""" + + def test_evaluate_threshold_perfect_classification(self): + """Test threshold evaluation with perfect classification.""" + scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2] + ground_truth = [True, True, True, True, False, False, False, False] + threshold = 0.55 + + metrics = ThresholdTuner.evaluate_threshold(scores, ground_truth, threshold) + + assert metrics.true_positives == 4 + assert metrics.false_positives == 0 + assert metrics.true_negatives == 4 + assert metrics.false_negatives == 0 + assert metrics.precision == 1.0 + assert metrics.recall == 1.0 + + def test_evaluate_threshold_with_errors(self): + """Test threshold evaluation with classification errors.""" + scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4] + ground_truth = [True, False, True, False, True, False] + threshold = 0.65 + + metrics = ThresholdTuner.evaluate_threshold(scores, ground_truth, threshold) + + assert metrics.true_positives == 2 # 0.9 and 0.7 are True + assert metrics.false_positives == 1 # 0.8 is False but predicted True + assert metrics.true_negatives == 2 # 0.4 and 0.6 are False + assert metrics.false_negatives == 1 # 0.5 is True but predicted False + + def test_evaluate_threshold_length_mismatch(self): + """Test error handling for mismatched lengths.""" + scores = [0.9, 0.8, 0.7] + ground_truth = [True, 
False] + + with pytest.raises(ValueError, match="same length"): + ThresholdTuner.evaluate_threshold(scores, ground_truth, 0.5) + + def test_find_optimal_threshold_for_f1(self): + """Test finding optimal threshold for F1 score.""" + # Create test data where 0.7 is optimal + scores = [0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25] + ground_truth = [True, True, True, True, False, False, False, False] + + threshold, metrics = ThresholdTuner.find_optimal_threshold( + scores, ground_truth, ThresholdOptimizationGoal.F1_SCORE, 0.5, 0.9, 0.05 + ) + + # Optimal should be around 0.60-0.70 + assert 0.55 <= threshold <= 0.75 + assert metrics.f1_score > 0.8 + + def test_find_optimal_threshold_for_precision(self): + """Test finding optimal threshold for precision.""" + scores = [0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25] + ground_truth = [True, True, True, True, False, False, False, False] + + threshold, metrics = ThresholdTuner.find_optimal_threshold( + scores, ground_truth, ThresholdOptimizationGoal.PRECISION, 0.5, 0.9, 0.05 + ) + + # Higher threshold for precision optimization + assert threshold >= 0.60 + assert metrics.precision >= 0.8 + + def test_find_optimal_threshold_for_recall(self): + """Test finding optimal threshold for recall.""" + scores = [0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25] + ground_truth = [True, True, True, True, False, False, False, False] + + threshold, metrics = ThresholdTuner.find_optimal_threshold( + scores, ground_truth, ThresholdOptimizationGoal.RECALL, 0.5, 0.9, 0.05 + ) + + # Lower threshold for recall optimization + assert threshold <= 0.70 + assert metrics.recall >= 0.9 + + def test_recommend_threshold_with_data(self): + """Test threshold recommendation with test data.""" + scores = [0.95, 0.85, 0.75, 0.65, 0.55, 0.45] + ground_truth = [True, True, True, False, False, False] + + recommendation = ThresholdTuner.recommend_threshold( + UseCase.CACHE_HIT, scores, ground_truth + ) + + assert isinstance(recommendation, 
ThresholdRecommendation) + assert 0.5 <= recommendation.threshold <= 0.99 + assert recommendation.use_case == UseCase.CACHE_HIT + assert recommendation.confidence > 0 + assert len(recommendation.reasoning) > 0 + + def test_recommend_threshold_without_data(self): + """Test threshold recommendation using defaults.""" + recommendation = ThresholdTuner.recommend_threshold(UseCase.EXACT_MATCH) + + assert ( + recommendation.threshold + == ThresholdTuner.DEFAULT_THRESHOLDS[UseCase.EXACT_MATCH] + ) + assert recommendation.use_case == UseCase.EXACT_MATCH + assert len(recommendation.alternative_thresholds) > 0 + + def test_default_thresholds_ordering(self): + """Test that default thresholds follow expected ordering.""" + exact = ThresholdTuner.DEFAULT_THRESHOLDS[UseCase.EXACT_MATCH] + cache = ThresholdTuner.DEFAULT_THRESHOLDS[UseCase.CACHE_HIT] + dedup = ThresholdTuner.DEFAULT_THRESHOLDS[UseCase.DEDUPLICATION] + similar = ThresholdTuner.DEFAULT_THRESHOLDS[UseCase.SIMILAR_CONTENT] + + # Exact match should be highest + assert exact >= cache + assert exact >= dedup + assert exact >= similar + + # Similar content should be lowest + assert similar <= cache + assert similar <= dedup + + def test_analyze_threshold_range(self): + """Test threshold range analysis.""" + scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4] + ground_truth = [True, True, True, False, False, False] + + results = ThresholdTuner.analyze_threshold_range( + scores, ground_truth, start=0.5, end=0.8, step=0.1 + ) + + assert len(results) == 4 # 0.5, 0.6, 0.7, 0.8 + assert all(isinstance(m, ThresholdMetrics) for m in results) + assert results[0].threshold == pytest.approx(0.5, abs=0.01) + assert results[-1].threshold == pytest.approx(0.8, abs=0.01) + + def test_format_analysis_report(self): + """Test report formatting.""" + metrics_list = [ + ThresholdMetrics( + threshold=0.7, + true_positives=80, + false_positives=20, + true_negatives=70, + false_negatives=30, + ), + ThresholdMetrics( + threshold=0.8, + true_positives=70, + 
false_positives=10, + true_negatives=80, + false_negatives=40, + ), + ] + + report = ThresholdTuner.format_analysis_report(metrics_list) + + assert "Threshold Analysis Report" in report + assert "0.700" in report + assert "0.800" in report + assert "Precision" in report + assert "Recall" in report + + def test_get_threshold_for_use_case(self): + """Test getting threshold by use case.""" + cache_threshold = ThresholdTuner.get_threshold_for_use_case(UseCase.CACHE_HIT) + exact_threshold = ThresholdTuner.get_threshold_for_use_case(UseCase.EXACT_MATCH) + + assert cache_threshold == 0.85 + assert exact_threshold == 0.95 + assert exact_threshold > cache_threshold + + +class TestConvenienceFunctions: + """Tests for convenience functions.""" + + def test_tune_threshold_function(self): + """Test tune_threshold convenience function.""" + scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4] + ground_truth = [True, True, True, False, False, False] + + threshold = tune_threshold(scores, ground_truth) + + assert 0.5 <= threshold <= 0.99 + + def test_get_cache_threshold(self): + """Test get_cache_threshold function.""" + threshold = get_cache_threshold() + assert threshold == 0.85 + + def test_get_exact_match_threshold(self): + """Test get_exact_match_threshold function.""" + threshold = get_exact_match_threshold() + assert threshold == 0.95 + + def test_evaluate_threshold_quality(self): + """Test evaluate_threshold_quality function.""" + scores = [0.9, 0.8, 0.7, 0.6] + ground_truth = [True, True, False, False] + + metrics = evaluate_threshold_quality(scores, ground_truth, 0.75) + + assert isinstance(metrics, ThresholdMetrics) + assert metrics.threshold == 0.75 + + +class TestThresholdRecommendation: + """Tests for ThresholdRecommendation class.""" + + def test_recommendation_summary(self): + """Test recommendation summary formatting.""" + metrics = ThresholdMetrics( + threshold=0.85, + true_positives=80, + false_positives=20, + true_negatives=70, + false_negatives=30, + ) + + recommendation = 
ThresholdRecommendation( + threshold=0.85, + use_case=UseCase.CACHE_HIT, + metrics=metrics, + reasoning="Test reasoning", + confidence=0.9, + alternative_thresholds={"precision": 0.90, "recall": 0.80}, + ) + + summary = recommendation.summary() + + assert "Threshold Recommendation" in summary + assert "0.850" in summary + assert "90.0%" in summary + assert "Test reasoning" in summary + assert "Alternative Thresholds" in summary + + +class TestUseCaseThresholds: + """Tests for use case specific thresholds.""" + + def test_exact_match_use_case(self): + """Test exact match use case has high threshold.""" + threshold = ThresholdTuner.get_threshold_for_use_case(UseCase.EXACT_MATCH) + assert threshold >= 0.90 + + def test_cache_hit_use_case(self): + """Test cache hit use case has balanced threshold.""" + threshold = ThresholdTuner.get_threshold_for_use_case(UseCase.CACHE_HIT) + assert 0.80 <= threshold <= 0.90 + + def test_similar_content_use_case(self): + """Test similar content use case has moderate threshold.""" + threshold = ThresholdTuner.get_threshold_for_use_case(UseCase.SIMILAR_CONTENT) + assert 0.60 <= threshold <= 0.80 + + def test_deduplication_use_case(self): + """Test deduplication use case has moderate-high threshold.""" + threshold = ThresholdTuner.get_threshold_for_use_case(UseCase.DEDUPLICATION) + assert 0.75 <= threshold <= 0.85 + + +class TestEdgeCases: + """Tests for edge cases.""" + + def test_all_true_positives(self): + """Test when all predictions are true positives.""" + scores = [0.9, 0.8, 0.7] + ground_truth = [True, True, True] + + metrics = ThresholdTuner.evaluate_threshold(scores, ground_truth, 0.6) + + assert metrics.true_positives == 3 + assert metrics.false_positives == 0 + assert metrics.true_negatives == 0 + assert metrics.false_negatives == 0 + assert metrics.precision == 1.0 + assert metrics.recall == 1.0 + + def test_all_true_negatives(self): + """Test when all predictions are true negatives.""" + scores = [0.3, 0.2, 0.1] + 
ground_truth = [False, False, False] + + metrics = ThresholdTuner.evaluate_threshold(scores, ground_truth, 0.5) + + assert metrics.true_positives == 0 + assert metrics.false_positives == 0 + assert metrics.true_negatives == 3 + assert metrics.false_negatives == 0 + + def test_empty_datasets(self): + """Test with empty datasets.""" + scores = [] + ground_truth = [] + + metrics = ThresholdTuner.evaluate_threshold(scores, ground_truth, 0.5) + + assert metrics.true_positives == 0 + assert metrics.false_positives == 0 + assert metrics.true_negatives == 0 + assert metrics.false_negatives == 0 + + def test_single_data_point(self): + """Test with single data point.""" + scores = [0.85] + ground_truth = [True] + + metrics = ThresholdTuner.evaluate_threshold(scores, ground_truth, 0.80) + + assert metrics.true_positives == 1 + assert metrics.precision == 1.0 + assert metrics.recall == 1.0 diff --git a/tests/unit/similarity/test_vector_normalizer.py b/tests/unit/similarity/test_vector_normalizer.py new file mode 100644 index 0000000..398ee68 --- /dev/null +++ b/tests/unit/similarity/test_vector_normalizer.py @@ -0,0 +1,250 @@ +"""Unit tests for vector normalizer.""" + +import math + +import pytest + +from app.similarity.vector_normalizer import ( + NormalizationType, + VectorNormalizer, + clip_vector, + l1_normalize, + l2_normalize, + max_normalize, + standardize_vector, +) + + +class TestL2Normalize: + """Tests for L2 normalization.""" + + def test_l2_normalize_unit_vector(self): + """Test L2 normalization of unit vector.""" + vector = [1.0, 0.0, 0.0] + normalized = l2_normalize(vector) + + assert abs(normalized[0] - 1.0) < 1e-6 + assert abs(normalized[1] - 0.0) < 1e-6 + assert abs(normalized[2] - 0.0) < 1e-6 + + def test_l2_normalize_regular_vector(self): + """Test L2 normalization produces unit magnitude.""" + vector = [3.0, 4.0, 0.0] + normalized = l2_normalize(vector) + + # Should be [0.6, 0.8, 0.0] (3-4-5 triangle) + assert abs(normalized[0] - 0.6) < 1e-6 + assert 
abs(normalized[1] - 0.8) < 1e-6 + + # Check magnitude is 1.0 + magnitude = math.sqrt(sum(x**2 for x in normalized)) + assert abs(magnitude - 1.0) < 1e-6 + + def test_l2_normalize_zero_vector(self): + """Test L2 normalization of zero vector returns zero vector.""" + vector = [0.0, 0.0, 0.0] + normalized = l2_normalize(vector) + + assert all(x == 0.0 for x in normalized) + + def test_l2_normalize_preserves_direction(self): + """Test L2 normalization preserves vector direction.""" + vector = [2.0, 2.0, 2.0] + normalized = l2_normalize(vector) + + # All components should be equal (same direction) + assert abs(normalized[0] - normalized[1]) < 1e-6 + assert abs(normalized[1] - normalized[2]) < 1e-6 + + +class TestL1Normalize: + """Tests for L1 normalization.""" + + def test_l1_normalize_unit_vector(self): + """Test L1 normalization of unit vector.""" + vector = [1.0, 0.0, 0.0] + normalized = l1_normalize(vector) + + assert abs(normalized[0] - 1.0) < 1e-6 + assert abs(sum(abs(x) for x in normalized) - 1.0) < 1e-6 + + def test_l1_normalize_regular_vector(self): + """Test L1 normalization produces unit L1 norm.""" + vector = [1.0, 2.0, 3.0] + normalized = l1_normalize(vector) + + # L1 norm should be 1.0 + l1_norm = sum(abs(x) for x in normalized) + assert abs(l1_norm - 1.0) < 1e-6 + + def test_l1_normalize_zero_vector(self): + """Test L1 normalization of zero vector returns zero vector.""" + vector = [0.0, 0.0, 0.0] + normalized = l1_normalize(vector) + + assert all(x == 0.0 for x in normalized) + + +class TestMaxNormalize: + """Tests for max normalization.""" + + def test_max_normalize_regular_vector(self): + """Test max normalization scales by max absolute value.""" + vector = [1.0, 2.0, 4.0] + normalized = max_normalize(vector) + + # Should be scaled by 1/4 + assert abs(normalized[0] - 0.25) < 1e-6 + assert abs(normalized[1] - 0.5) < 1e-6 + assert abs(normalized[2] - 1.0) < 1e-6 + + def test_max_normalize_negative_values(self): + """Test max normalization with negative 
values.""" + vector = [-4.0, 2.0, 1.0] + normalized = max_normalize(vector) + + # Should be scaled by 1/4 (abs max is 4) + assert abs(normalized[0] - (-1.0)) < 1e-6 + assert abs(normalized[1] - 0.5) < 1e-6 + + def test_max_normalize_zero_vector(self): + """Test max normalization of zero vector returns zero vector.""" + vector = [0.0, 0.0, 0.0] + normalized = max_normalize(vector) + + assert all(x == 0.0 for x in normalized) + + +class TestStandardize: + """Tests for vector standardization.""" + + def test_standardize_regular_vector(self): + """Test standardization produces zero mean and unit variance.""" + vector = [1.0, 2.0, 3.0, 4.0, 5.0] + standardized = standardize_vector(vector) + + # Mean should be close to 0 + mean = sum(standardized) / len(standardized) + assert abs(mean) < 1e-6 + + # Variance should be close to 1 + variance = sum((x - mean) ** 2 for x in standardized) / len(standardized) + assert abs(variance - 1.0) < 1e-6 + + def test_standardize_constant_vector(self): + """Test standardization of constant vector returns zero vector.""" + vector = [5.0, 5.0, 5.0] + standardized = standardize_vector(vector) + + assert all(x == 0.0 for x in standardized) + + +class TestClipVector: + """Tests for vector clipping.""" + + def test_clip_vector_within_range(self): + """Test clipping vector already within range.""" + vector = [0.5, 0.3, -0.2] + clipped = clip_vector(vector, min_val=-1.0, max_val=1.0) + + assert clipped == vector + + def test_clip_vector_exceeds_max(self): + """Test clipping vector exceeding max value.""" + vector = [0.5, 1.5, -0.2] + clipped = clip_vector(vector, min_val=-1.0, max_val=1.0) + + assert clipped[0] == 0.5 + assert clipped[1] == 1.0 # Clipped to max + assert clipped[2] == -0.2 + + def test_clip_vector_below_min(self): + """Test clipping vector below min value.""" + vector = [0.5, -1.5, -0.2] + clipped = clip_vector(vector, min_val=-1.0, max_val=1.0) + + assert clipped[0] == 0.5 + assert clipped[1] == -1.0 # Clipped to min + assert 
clipped[2] == -0.2 + + +class TestVectorNormalizer: + """Tests for VectorNormalizer class.""" + + def test_normalize_l2(self): + """Test normalize with L2 normalization.""" + vector = [3.0, 4.0, 0.0] + normalized = VectorNormalizer.normalize(vector, norm_type=NormalizationType.L2) + + magnitude = math.sqrt(sum(x**2 for x in normalized)) + assert abs(magnitude - 1.0) < 1e-6 + + def test_normalize_l1(self): + """Test normalize with L1 normalization.""" + vector = [1.0, 2.0, 3.0] + normalized = VectorNormalizer.normalize(vector, norm_type=NormalizationType.L1) + + l1_norm = sum(abs(x) for x in normalized) + assert abs(l1_norm - 1.0) < 1e-6 + + def test_normalize_max(self): + """Test normalize with max normalization.""" + vector = [1.0, 2.0, 4.0] + normalized = VectorNormalizer.normalize(vector, norm_type=NormalizationType.MAX) + + assert abs(max(abs(x) for x in normalized) - 1.0) < 1e-6 + + def test_normalize_invalid_type(self): + """Test normalize with invalid type raises error.""" + vector = [1.0, 2.0, 3.0] + + with pytest.raises(ValueError, match="Unknown normalization type"): + VectorNormalizer.normalize(vector, norm_type="invalid") # type: ignore + + def test_batch_normalize(self): + """Test batch normalization.""" + vectors = [ + [3.0, 4.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 5.0, 12.0], + ] + + normalized = VectorNormalizer.batch_normalize( + vectors, norm_type=NormalizationType.L2 + ) + + assert len(normalized) == 3 + for vec in normalized: + magnitude = math.sqrt(sum(x**2 for x in vec)) + assert abs(magnitude - 1.0) < 1e-6 or all(x == 0.0 for x in vec) + + def test_is_normalized_true(self): + """Test is_normalized returns True for normalized vector.""" + vector = [0.6, 0.8, 0.0] # Already L2 normalized + + assert VectorNormalizer.is_normalized(vector, norm_type=NormalizationType.L2) + + def test_is_normalized_false(self): + """Test is_normalized returns False for non-normalized vector.""" + vector = [3.0, 4.0, 0.0] # Not normalized + + assert not 
VectorNormalizer.is_normalized( + vector, norm_type=NormalizationType.L2 + ) + + def test_safe_normalize(self): + """Test safe_normalize handles zero vectors.""" + zero_vector = [0.0, 0.0, 0.0] + normalized = VectorNormalizer.safe_normalize( + zero_vector, norm_type=NormalizationType.L2 + ) + + assert all(x == 0.0 for x in normalized) + + regular_vector = [3.0, 4.0, 0.0] + normalized = VectorNormalizer.safe_normalize( + regular_vector, norm_type=NormalizationType.L2 + ) + + magnitude = math.sqrt(sum(x**2 for x in normalized)) + assert abs(magnitude - 1.0) < 1e-6