diff --git a/docs/integrations/storage/azure_blob.md b/docs/integrations/storage/azure_blob.md new file mode 100644 index 000000000..baf55a5fd --- /dev/null +++ b/docs/integrations/storage/azure_blob.md @@ -0,0 +1,125 @@ +# Azure Blob Storage + +`AzureBlobLoader` fetches blobs from an Azure Blob Storage container and returns them as +[`Chunk`](../../rag/vector_stores/vector_store_info.md) objects containing +UTF-8 decoded content plus source metadata (`source`, `account_url`, `container`, +`blob_name`). + +## Installation + +=== "pip" + + ```bash + pip install railtracks[azure-blob] + ``` + +=== "uv" + + ```bash + uv add railtracks[azure-blob] + ``` + +## Authentication + +Authentication defaults to **`DefaultAzureCredential`**, which automatically resolves +credentials from the following sources (in order): + +1. Environment variables (`AZURE_CLIENT_ID`, `AZURE_TENANT_ID`, `AZURE_CLIENT_SECRET`) +2. Workload identity (Kubernetes) +3. Managed identity (Azure-hosted compute) +4. Azure CLI (`az login`) +5. Azure PowerShell / Visual Studio / IntelliJ + +Pass an explicit `credential` to override. + +!!! tip "Prefer managed identity over connection strings" + Managed identity is the recommended authentication method for Azure-hosted + workloads — it requires no secrets and rotates automatically. Avoid + embedding storage account keys or SAS tokens in source code; store them + in Azure Key Vault or environment variables instead. + +## Basic usage + +```python +--8<-- "docs/scripts/storage_loaders.py:azure_basic" +``` + +## Load by prefix + +```python +--8<-- "docs/scripts/storage_loaders.py:azure_prefix" +``` + +## Load specific blobs + +```python +--8<-- "docs/scripts/storage_loaders.py:azure_load_keys" +``` + +## Async usage + +```python +--8<-- "docs/scripts/storage_loaders.py:azure_async" +``` + +!!! note "Async is thread-backed" + `aload()` and `aload_keys()` run the synchronous `azure-storage-blob` + client on a thread-pool thread via `asyncio.to_thread()`. 
This is sufficient + for most workloads; for very high concurrency consider the async Azure SDK + (`azure.storage.blob.aio`). + +## Override credentials + +**SAS token** + +```python +--8<-- "docs/scripts/storage_loaders.py:azure_sas" +``` + +**System-assigned or user-assigned managed identity** + +```python +--8<-- "docs/scripts/storage_loaders.py:azure_managed_identity" +``` + +## Chunk metadata + +Each returned `Chunk` carries: + +| Key | Value | +|---|---| +| `source` | Full blob URL: `https://<account>.blob.core.windows.net/<container>/<blob_name>` | +| `account_url` | Storage account URL | +| `container` | Container name | +| `blob_name` | Blob name (path within the container) | + +## Full RAG pipeline example + +```python +--8<-- "docs/scripts/storage_loaders.py:pipeline_azure_to_rag" +``` + +--- + +## Writing to Azure Blob Storage + +`AzureBlobWriter` uploads text content to a blob container. Existing blobs with +the same name are overwritten. + +### Basic write + +```python +--8<-- "docs/scripts/storage_writers.py:azure_write_basic" +``` + +### SAS token credential + +```python +--8<-- "docs/scripts/storage_writers.py:azure_write_sas" +``` + +### Async write + +```python +--8<-- "docs/scripts/storage_writers.py:azure_write_async" +``` diff --git a/docs/integrations/storage/gcs.md b/docs/integrations/storage/gcs.md new file mode 100644 index 000000000..5da636c01 --- /dev/null +++ b/docs/integrations/storage/gcs.md @@ -0,0 +1,110 @@ +# Google Cloud Storage + +`GCSLoader` fetches objects from a GCS bucket and returns them as +[`Chunk`](../../rag/vector_stores/vector_store_info.md) objects containing +UTF-8 decoded content plus source metadata (`source`, `bucket`, `name`). + +## Installation + +=== "pip" + + ```bash + pip install railtracks[gcp] + ``` + +=== "uv" + + ```bash + uv add railtracks[gcp] + ``` + +## Authentication + +Authentication uses **Application Default Credentials (ADC)** by default: + +1. 
`GOOGLE_APPLICATION_CREDENTIALS` environment variable (path to a service-account JSON) +2. `gcloud auth application-default login` (developer workstation) +3. Workload Identity / attached service account (GCE, GKE, Cloud Run, Cloud Functions …) + +Pass explicit `credentials` to override ADC. + +!!! tip "Prefer Workload Identity over service-account key files" + Service-account JSON key files are long-lived credentials that require + manual rotation. On GCP-hosted compute, Workload Identity or attached + service accounts are more secure and require zero key management. + +## Basic usage + +```python +--8<-- "docs/scripts/storage_loaders.py:gcs_basic" +``` + +## Load by prefix + +```python +--8<-- "docs/scripts/storage_loaders.py:gcs_prefix" +``` + +## Load specific objects + +```python +--8<-- "docs/scripts/storage_loaders.py:gcs_load_keys" +``` + +## Async usage + +```python +--8<-- "docs/scripts/storage_loaders.py:gcs_async" +``` + +!!! note "Async is thread-backed" + `aload()` and `aload_keys()` run the synchronous `google-cloud-storage` + client on a thread-pool thread via `asyncio.to_thread()`. This is sufficient + for most workloads. + +## Override credentials (service account key file) + +```python +--8<-- "docs/scripts/storage_loaders.py:gcs_service_account" +``` + +## Chunk metadata + +Each returned `Chunk` carries: + +| Key | Value | +|---|---| +| `source` | `gs://<bucket>/<name>` | +| `bucket` | GCS bucket name | +| `name` | Object name (path within the bucket) | + +## Full RAG pipeline example + +```python +--8<-- "docs/scripts/storage_loaders.py:pipeline_gcs_to_rag" +``` + +--- + +## Writing to GCS + +`GCSWriter` uploads text content to a GCS bucket. Existing objects with the +same name are overwritten. 
+ +### Basic write + +```python +--8<-- "docs/scripts/storage_writers.py:gcs_write_basic" +``` + +### Service account credentials + +```python +--8<-- "docs/scripts/storage_writers.py:gcs_write_service_account" +``` + +### Async write + +```python +--8<-- "docs/scripts/storage_writers.py:gcs_write_async" +``` diff --git a/docs/integrations/storage/overview.md b/docs/integrations/storage/overview.md new file mode 100644 index 000000000..4cd476157 --- /dev/null +++ b/docs/integrations/storage/overview.md @@ -0,0 +1,136 @@ +# Cloud Storage & Database Loaders / Writers + +Railtracks ships first-class **loaders** and **writers** for popular cloud +storage providers and relational databases. + +- **Loaders** fetch documents and return them as + [`Chunk`](../../rag/vector_stores/vector_store_info.md) objects — pipe remote + data straight into a vector store or agent without any glue code. +- **Writers** persist `Chunk` objects (or raw text) back to the same providers — + close the loop by saving AI-generated content to storage. 
+ +## Supported providers + +| Provider | Loader | Writer | Install extra | +|---|---|---|---| +| AWS S3 | `S3Loader` | `S3Writer` | `railtracks[aws]` | +| Azure Blob Storage | `AzureBlobLoader` | `AzureBlobWriter` | `railtracks[azure-blob]` | +| Google Cloud Storage | `GCSLoader` | `GCSWriter` | `railtracks[gcp]` | +| SQL (PostgreSQL, Supabase, MySQL, SQLite …) | `SQLLoader` | `SQLWriter` | `railtracks[sql]` | + +Install any combination: + +=== "pip" + + ```bash + pip install "railtracks[aws,gcp,azure-blob,sql]" + ``` + +=== "uv" + + ```bash + uv add "railtracks[aws,gcp,azure-blob,sql]" + ``` + +## Loading — quick examples + +=== "AWS S3" + + ```python + --8<-- "docs/scripts/storage_loaders.py:s3_basic" + ``` + +=== "Azure Blob" + + ```python + --8<-- "docs/scripts/storage_loaders.py:azure_basic" + ``` + +=== "Google Cloud Storage" + + ```python + --8<-- "docs/scripts/storage_loaders.py:gcs_basic" + ``` + +=== "SQL / Database" + + ```python + --8<-- "docs/scripts/storage_loaders.py:sql_basic_postgres" + ``` + +## Writing — quick examples + +=== "AWS S3" + + ```python + --8<-- "docs/scripts/storage_writers.py:s3_write_basic" + ``` + +=== "Azure Blob" + + ```python + --8<-- "docs/scripts/storage_writers.py:azure_write_basic" + ``` + +=== "Google Cloud Storage" + + ```python + --8<-- "docs/scripts/storage_writers.py:gcs_write_basic" + ``` + +=== "SQL / Database" + + ```python + --8<-- "docs/scripts/storage_writers.py:sql_write_basic" + ``` + +## Feeding chunks into a RAG pipeline + +All loaders return the same `Chunk` type that `ChromaVectorStore.upsert()` accepts, +making it trivial to build a full load → index → retrieve → answer pipeline: + +```python +--8<-- "docs/scripts/storage_loaders.py:pipeline_s3_to_rag" +``` + +## Load → Generate → Write back + +Writers make it easy to persist AI-generated content alongside source data: + +```python +--8<-- "docs/scripts/storage_writers.py:pipeline_generate_and_write" +``` + +## Async support + +Every loader and writer 
exposes async variants (`aload`, `aload_keys`, `awrite`, +`awrite_key`) that are safe to use in `async` agent pipelines: + +```python +chunks = await loader.aload(prefix="reports/2024/") +uris = await writer.awrite(chunks, prefix="summaries/") +``` + +The async methods delegate to `asyncio.to_thread()`, so they are non-blocking +from the caller's perspective while the underlying SDK call runs on a thread-pool +thread. + +## Key derivation for writers + +When writing `Chunk` objects, the storage key (S3 key, GCS object name, blob +name, SQL id) is derived in this order: + +1. Return value of `key_fn(chunk)` — if `key_fn` is provided +2. `chunk.id` — if set +3. `chunk.document` — if set +4. A freshly generated UUID4 — as a last resort + +Pass `key_fn` to take full control of the naming scheme: + +```python +writer = S3Writer("my-bucket", key_fn=lambda c: f"docs/{c.id}.txt") +``` + +!!! tip "Next steps" + - [AWS S3](s3.md) · [Azure Blob Storage](azure_blob.md) · [Google Cloud Storage](gcs.md) · [SQL](sql.md) + - [Cloud Storage Loaders Tutorial](../../tutorials/walkthroughs/storage_loaders_tutorial.md) diff --git a/docs/integrations/storage/s3.md b/docs/integrations/storage/s3.md new file mode 100644 index 000000000..be532db92 --- /dev/null +++ b/docs/integrations/storage/s3.md @@ -0,0 +1,126 @@ +# AWS S3 + +`S3Loader` fetches objects from an S3 bucket and returns them as +[`Chunk`](../../rag/vector_stores/vector_store_info.md) objects containing +UTF-8 decoded content plus source metadata (`source`, `bucket`, `key`). + +## Installation + +=== "pip" + + ```bash + pip install railtracks[aws] + ``` + +=== "uv" + + ```bash + uv add railtracks[aws] + ``` + +## Authentication + +Credentials follow **boto3's standard resolution chain** — no explicit configuration +needed in most environments: + +1. Environment variables (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN`) +2. Shared credentials file (`~/.aws/credentials`) +3. 
AWS config file (`~/.aws/config`) +4. IAM role attached to an EC2 instance / ECS task / Lambda function + +Pass explicit credentials to the constructor to override the chain. + +!!! tip "Prefer IAM roles and environment variables over hard-coded credentials" + Never embed AWS keys directly in source code. Use environment variables, + AWS Secrets Manager, or an IAM instance profile wherever possible. + +## Basic usage + +```python +--8<-- "docs/scripts/storage_loaders.py:s3_basic" +``` + +## Load by prefix + +```python +--8<-- "docs/scripts/storage_loaders.py:s3_prefix" +``` + +## Load specific keys + +```python +--8<-- "docs/scripts/storage_loaders.py:s3_load_keys" +``` + +## Async usage + +```python +--8<-- "docs/scripts/storage_loaders.py:s3_async" +``` + +!!! note "Async is thread-backed" + `aload()` and `aload_keys()` run the synchronous boto3 client on a + thread-pool thread via `asyncio.to_thread()`. This is sufficient for most + workloads; for very high concurrency consider `aioboto3`. + +## Override credentials + +```python +--8<-- "docs/scripts/storage_loaders.py:s3_explicit_creds" +``` + +## S3-compatible services (MinIO, LocalStack …) + +```python +--8<-- "docs/scripts/storage_loaders.py:s3_minio" +``` + +## Chunk metadata + +Each returned `Chunk` carries: + +| Key | Value | +|---|---| +| `source` | `s3://<bucket>/<key>` | +| `bucket` | S3 bucket name | +| `key` | Object key (path) | + +## Full RAG pipeline example + +```python +--8<-- "docs/scripts/storage_loaders.py:pipeline_s3_to_rag" +``` + +--- + +## Writing to S3 + +`S3Writer` uploads text content to an S3 bucket. Existing objects at the same +key are silently overwritten. + +### Basic write + +```python +--8<-- "docs/scripts/storage_writers.py:s3_write_basic" +``` + +### Custom key derivation + +By default the object key is taken from `chunk.id` (falling back to `chunk.document`, then a generated UUID4). 
Pass a `key_fn` to +build any key structure you need: + +```python +--8<-- "docs/scripts/storage_writers.py:s3_write_key_fn" +``` + +### Async write + +```python +--8<-- "docs/scripts/storage_writers.py:s3_write_async" +``` + +### End-to-end: load → generate → write back + +```python +--8<-- "docs/scripts/storage_writers.py:pipeline_generate_and_write" +``` diff --git a/docs/integrations/storage/sql.md b/docs/integrations/storage/sql.md new file mode 100644 index 000000000..396285491 --- /dev/null +++ b/docs/integrations/storage/sql.md @@ -0,0 +1,293 @@ +# SQL / Relational Databases + +`SQLLoader` reads rows from any **SQLAlchemy-compatible relational database** and +returns them as [`Chunk`](../../rag/vector_stores/vector_store_info.md) objects. +Works with PostgreSQL, Supabase, MySQL, SQLite, and more. + +## Installation + +=== "pip" + + ```bash + pip install railtracks[sql] + ``` + +=== "uv" + + ```bash + uv add railtracks[sql] + ``` + +For PostgreSQL / Supabase you also need a driver: + +=== "pip" + + ```bash + pip install psycopg2-binary # PostgreSQL (most common) + ``` + +=== "uv" + + ```bash + uv add psycopg2-binary + ``` + +## Connecting + +Pass a [SQLAlchemy database URL](https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls): + +| Database | URL format | +|---|---| +| PostgreSQL | `postgresql+psycopg2://user:pass@host/db` | +| Supabase | `postgresql+psycopg2://postgres:pass@db.<project-ref>.supabase.co:5432/postgres` | +| MySQL | `mysql+pymysql://user:pass@host/db` | +| SQLite (file) | `sqlite:///path/to/file.db` | +| SQLite (memory) | `sqlite:///:memory:` | + +## Basic usage — PostgreSQL + +```python +--8<-- "docs/scripts/storage_loaders.py:sql_basic_postgres" +``` + +## Supabase + +```python +--8<-- "docs/scripts/storage_loaders.py:sql_supabase" +``` + +## Raw SQL query + +Pass any `SELECT` statement instead of a table name for filtering, joining, or +transforming data before it reaches the loader: + +```python +--8<-- 
"docs/scripts/storage_loaders.py:sql_raw_query" +``` + +!!! warning "CTE (`WITH`) queries are not supported directly" + `table_or_query` is detected as a raw query only when the string starts with + `SELECT`. Queries beginning with `WITH` (Common Table Expressions) are + treated as table names and will cause a database error. + + **Workaround** — wrap your CTE in a subquery: + + ```python + loader = SQLLoader( + connection_string, + table_or_query=""" + SELECT * FROM ( + WITH ranked AS ( + SELECT id, body, ROW_NUMBER() OVER (ORDER BY created_at DESC) AS rn + FROM docs + ) + SELECT id, body FROM ranked WHERE rn <= 100 + ) AS t + """, + content_column="body", + id_column="id", + ) + ``` + +## Load specific rows by ID + +```python +--8<-- "docs/scripts/storage_loaders.py:sql_load_keys" +``` + +!!! note + `load_keys()` requires `id_column` to be set when constructing the loader. + +## Reuse an existing engine + +When you already have a configured `sqlalchemy.Engine` (custom pool size, SSL +certificates, read replicas, etc.) pass it directly via the `engine` parameter: + +```python +--8<-- "docs/scripts/storage_loaders.py:sql_existing_engine" +``` + +!!! tip "Engine ownership" + When you supply your own `engine`, the loader does **not** dispose it on + `close()`. You remain responsible for its lifecycle. When the loader + creates its own engine (the default), `close()` disposes it for you. + +## Engine lifecycle — close and context manager + +For long-lived applications or scripts that create many loaders, explicitly +releasing the connection pool avoids resource leaks: + +```python +# Explicit close +loader = SQLLoader(connection_string, "documents", "body") +try: + chunks = loader.load() +finally: + loader.close() + +# Context-manager (preferred) +with SQLLoader(connection_string, "documents", "body") as loader: + chunks = loader.load() +``` + +## Async usage + +```python +--8<-- "docs/scripts/storage_loaders.py:sql_async" +``` + +!!! 
note "Async is thread-backed" + `aload()` and `aload_keys()` run the synchronous SQLAlchemy driver on a + thread-pool thread via `asyncio.to_thread()`. This works correctly but + occupies a thread for the full duration of the query. For very + high-concurrency workloads consider wiring up a true async engine + (e.g. `asyncpg` with `sqlalchemy.ext.asyncio`) and passing it via the + `engine` parameter. + +## Chunk metadata + +Each returned `Chunk` carries: + +| Key | Value | +|---|---| +| `source` | The `table_or_query` string used to construct the loader | +| _any `metadata_columns`_ | One key per column listed in `metadata_columns` | + +When `metadata_columns` is `None`, all columns except `content_column` and +`id_column` are included automatically. + +## Full RAG pipeline example + +```python +--8<-- "docs/scripts/storage_loaders.py:pipeline_sql_to_rag" +``` + +--- + +## Writing to SQL databases + +`SQLWriter` inserts or upserts rows into any SQLAlchemy-compatible database. +The table must already exist. + +### Basic write — PostgreSQL + +```python +--8<-- "docs/scripts/storage_writers.py:sql_write_basic" +``` + +### Supabase + +```python +--8<-- "docs/scripts/storage_writers.py:sql_write_supabase" +``` + +### Insert vs upsert modes + +```python +--8<-- "docs/scripts/storage_writers.py:sql_write_modes" +``` + +!!! note "Upsert mechanics" + In `"upsert"` mode the writer issues a `DELETE` followed by an `INSERT` + inside a single transaction, which works across all SQLAlchemy-compatible + databases including SQLite, PostgreSQL, and MySQL. This is not the same + as a native `ON CONFLICT DO UPDATE` — it acquires a delete lock for the + duration of the insert, which may affect concurrent write throughput under + heavy load. + +!!! warning "All-or-nothing batch writes" + All chunks passed to a single `write()` call are committed inside **one + transaction**. If any individual row fails (constraint violation, type + mismatch, etc.) 
the **entire batch is rolled back** and no rows are + written. The exception from the failing row is re-raised so you can + inspect it. + + To tolerate partial failures, call `write_key()` per chunk in a loop: + + ```python + writer = SQLWriter(connection_string, "documents", "body", id_column="id") + written, failed = [], [] + for chunk in chunks: + try: + uri = writer.write_key(chunk.id, chunk.content) + written.append(uri) + except Exception as exc: + failed.append((chunk, exc)) + ``` + +### Engine lifecycle — close and context manager + +```python +# Context-manager (preferred) +with SQLWriter(connection_string, "documents", "body", id_column="id") as writer: + writer.write(chunks) + +# Explicit close +writer = SQLWriter(connection_string, "documents", "body", id_column="id") +try: + writer.write(chunks) +finally: + writer.close() +``` + +### Reuse an existing engine + +```python +--8<-- "docs/scripts/storage_writers.py:sql_write_existing_engine" +``` + +### Async write + +```python +--8<-- "docs/scripts/storage_writers.py:sql_write_async" +``` + +!!! note "Async is thread-backed" + `awrite()` and `awrite_key()` delegate to `asyncio.to_thread()`, the same + as the loader. See the async note in the loader section above. + +### Chunk-to-row mapping + +| Chunk field | SQL column | +|---|---| +| `chunk.content` | `content_column` (required) | +| `chunk.id` | `id_column` (when set) | +| `chunk.document` | `document_column` (when set) | +| `chunk.metadata[col]` | Each column in `metadata_columns` | + +--- + +## Security considerations + +!!! danger "Never pass user-controlled strings as identifiers" + `table_or_query`, `content_column`, `id_column`, `document_column`, and + `metadata_columns` are interpolated directly into SQL as structural + identifiers (table and column names). SQLAlchemy cannot parameterise these + the way it can parameterise values. 
+ + Both `SQLLoader` and `SQLWriter` validate every identifier against a strict + allowlist (`[A-Za-z_][A-Za-z0-9_$]*`) at construction time and raise + `ValueError` on any value that contains SQL metacharacters. This catches + misconfiguration early, but **the best protection is to use only + hard-coded, developer-controlled strings** — never values derived from + user input or LLM output. + + For dynamic row filtering, use a parameterised `SELECT` query: + + ```python + # Safe: user_id is a bound parameter, not an identifier + loader = SQLLoader( + connection_string, + table_or_query="SELECT id, body FROM documents WHERE user_id = :uid", + content_column="body", + ) + # Execute with bound parameter via your engine directly, then pass chunks as needed. + ``` + + For connection strings, prefer environment variables or a secrets manager + over hard-coded passwords: + + ```python + import os + loader = SQLLoader(os.environ["DATABASE_URL"], "documents", "body") + ``` diff --git a/docs/scripts/storage_loaders.py b/docs/scripts/storage_loaders.py new file mode 100644 index 000000000..50b5d56d5 --- /dev/null +++ b/docs/scripts/storage_loaders.py @@ -0,0 +1,514 @@ +""" +Cloud storage loader examples for use in documentation via --8<-- includes. 
+ +These snippets assume the relevant extras are installed: + pip install "railtracks[aws,azure-blob]" # for S3Loader, AzureBlobLoader + pip install "railtracks[gcp,sql]" # for GCSLoader, SQLLoader +""" + +# --------------------------------------------------------------------------- +# Shared setup used across sections +# --------------------------------------------------------------------------- +# --8<-- [start:shared_embedding] +from railtracks.rag.embedding_service import EmbeddingService + +embedding_function = EmbeddingService().embed +# --8<-- [end:shared_embedding] + + +# =========================================================================== +# AWS S3 +# =========================================================================== + +# --8<-- [start:s3_basic] +from railtracks.loaders import S3Loader + +loader = S3Loader("my-bucket", region_name="us-east-1") + +# Load every object in the bucket +chunks = loader.load() + +for chunk in chunks: + print(chunk.metadata["source"], "→", chunk.content[:80]) +# --8<-- [end:s3_basic] + + +# --8<-- [start:s3_prefix] +from railtracks.loaders import S3Loader + +loader = S3Loader("my-bucket", region_name="us-east-1") + +# Load only objects under the "knowledge-base/" prefix +chunks = loader.load(prefix="knowledge-base/") +# --8<-- [end:s3_prefix] + + +# --8<-- [start:s3_load_keys] +from railtracks.loaders import S3Loader + +loader = S3Loader("my-bucket") + +# Load a specific set of objects by key +chunks = loader.load_keys([ + "policy.txt", + "faq.txt", + "onboarding/welcome.txt", +]) +# --8<-- [end:s3_load_keys] + + +# --8<-- [start:s3_explicit_creds] +from railtracks.loaders import S3Loader + +loader = S3Loader( + "my-bucket", + aws_access_key_id="AKIA...", + aws_secret_access_key="...", + region_name="eu-west-1", +) +chunks = loader.load() +# --8<-- [end:s3_explicit_creds] + + +# --8<-- [start:s3_minio] +from railtracks.loaders import S3Loader + +# Works with any S3-compatible service (MinIO, LocalStack, Ceph …) +loader = S3Loader( + "my-bucket", + 
endpoint_url="http://localhost:9000", + aws_access_key_id="minioadmin", + aws_secret_access_key="minioadmin", +) +chunks = loader.load() +# --8<-- [end:s3_minio] + + +# --8<-- [start:s3_async] +import asyncio +from railtracks.loaders import S3Loader + +async def load_s3_documents(): + loader = S3Loader("my-bucket", region_name="us-east-1") + + # Both methods have async equivalents + all_chunks = await loader.aload(prefix="docs/") + specific_chunks = await loader.aload_keys(["readme.txt", "faq.txt"]) + return all_chunks + specific_chunks + +chunks = asyncio.run(load_s3_documents()) +# --8<-- [end:s3_async] + + +# =========================================================================== +# Azure Blob Storage +# =========================================================================== + +# --8<-- [start:azure_basic] +from railtracks.loaders import AzureBlobLoader + +# DefaultAzureCredential resolves credentials automatically +# (env vars, managed identity, Azure CLI, …) +loader = AzureBlobLoader( + "https://myaccount.blob.core.windows.net", + "my-container", +) + +chunks = loader.load() + +for chunk in chunks: + print(chunk.metadata["source"], "→", chunk.content[:80]) +# --8<-- [end:azure_basic] + + +# --8<-- [start:azure_prefix] +from railtracks.loaders import AzureBlobLoader + +loader = AzureBlobLoader( + "https://myaccount.blob.core.windows.net", + "my-container", +) + +# Load only blobs whose names begin with "reports/2025/" +chunks = loader.load(prefix="reports/2025/") +# --8<-- [end:azure_prefix] + + +# --8<-- [start:azure_load_keys] +from railtracks.loaders import AzureBlobLoader + +loader = AzureBlobLoader( + "https://myaccount.blob.core.windows.net", + "my-container", +) + +chunks = loader.load_keys([ + "policy.txt", + "faq.txt", + "onboarding/welcome.txt", +]) +# --8<-- [end:azure_load_keys] + + +# --8<-- [start:azure_sas] +from azure.core.credentials import AzureSasCredential +from railtracks.loaders import AzureBlobLoader + +loader = AzureBlobLoader( + 
"https://myaccount.blob.core.windows.net", + "my-container", + credential=AzureSasCredential(""), +) +chunks = loader.load() +# --8<-- [end:azure_sas] + + +# --8<-- [start:azure_managed_identity] +from azure.identity import ManagedIdentityCredential +from railtracks.loaders import AzureBlobLoader + +# Pin to a specific user-assigned managed identity via its client ID +loader = AzureBlobLoader( + "https://myaccount.blob.core.windows.net", + "my-container", + credential=ManagedIdentityCredential(client_id=""), +) +chunks = loader.load() +# --8<-- [end:azure_managed_identity] + + +# --8<-- [start:azure_async] +import asyncio +from railtracks.loaders import AzureBlobLoader + +async def load_azure_documents(): + loader = AzureBlobLoader( + "https://myaccount.blob.core.windows.net", + "my-container", + ) + + all_chunks = await loader.aload(prefix="reports/") + named_chunks = await loader.aload_keys(["readme.txt", "faq.txt"]) + return all_chunks + named_chunks + +chunks = asyncio.run(load_azure_documents()) +# --8<-- [end:azure_async] + + +# =========================================================================== +# Feeding loaded chunks into a RAG pipeline +# =========================================================================== + +# --8<-- [start:pipeline_s3_to_rag] +import railtracks as rt +from railtracks.loaders import S3Loader +from railtracks.vector_stores import ChromaVectorStore +from railtracks.rag.embedding_service import EmbeddingService + +# 1. Load documents from S3 +loader = S3Loader("my-knowledge-bucket", region_name="us-east-1") +chunks = loader.load(prefix="docs/") + +# 2. Create a vector store and embed the chunks +embedding_fn = EmbeddingService().embed +store = ChromaVectorStore("knowledge-base", embedding_function=embedding_fn) +store.upsert(chunks) + +# 3. 
Expose retrieval as an agent tool +@rt.function_node +def search_knowledge_base(query: str) -> str: + """Search the internal knowledge base for relevant information.""" + results = store.search(query, top_k=5) + return "\n\n".join(r.content for r in results) + +# 4. Build the agent +agent = rt.agent_node( + name="KnowledgeAgent", + llm=rt.llm.OpenAILLM("gpt-4o"), + system_message="You are a helpful assistant. Use the knowledge base to answer questions.", + tool_nodes=[search_knowledge_base], +) + +flow = rt.Flow("knowledge-flow", entry_point=agent) +response = flow.invoke("What is our remote work policy?") +# --8<-- [end:pipeline_s3_to_rag] + + +# =========================================================================== +# Google Cloud Storage +# =========================================================================== + +# --8<-- [start:gcs_basic] +from railtracks.loaders import GCSLoader + +# Application Default Credentials resolve automatically +# (GOOGLE_APPLICATION_CREDENTIALS, gcloud auth, Workload Identity …) +loader = GCSLoader("my-bucket", project="my-gcp-project") + +chunks = loader.load() + +for chunk in chunks: + print(chunk.metadata["source"], "→", chunk.content[:80]) +# --8<-- [end:gcs_basic] + + +# --8<-- [start:gcs_prefix] +from railtracks.loaders import GCSLoader + +loader = GCSLoader("my-bucket") +chunks = loader.load(prefix="knowledge-base/") +# --8<-- [end:gcs_prefix] + + +# --8<-- [start:gcs_load_keys] +from railtracks.loaders import GCSLoader + +loader = GCSLoader("my-bucket") +chunks = loader.load_keys([ + "policy.txt", + "faq.txt", + "onboarding/welcome.txt", +]) +# --8<-- [end:gcs_load_keys] + + +# --8<-- [start:gcs_service_account] +from google.oauth2 import service_account +from railtracks.loaders import GCSLoader + +credentials = service_account.Credentials.from_service_account_file( + "/path/to/service-account.json", + scopes=["https://www.googleapis.com/auth/cloud-platform"], +) +loader = GCSLoader("my-bucket", 
credentials=credentials) +chunks = loader.load() +# --8<-- [end:gcs_service_account] + + +# --8<-- [start:gcs_async] +import asyncio +from railtracks.loaders import GCSLoader + +async def load_gcs_documents(): + loader = GCSLoader("my-bucket", project="my-gcp-project") + all_chunks = await loader.aload(prefix="docs/") + named_chunks = await loader.aload_keys(["readme.txt", "faq.txt"]) + return all_chunks + named_chunks + +chunks = asyncio.run(load_gcs_documents()) +# --8<-- [end:gcs_async] + + +# --8<-- [start:pipeline_gcs_to_rag] +import railtracks as rt +from railtracks.loaders import GCSLoader +from railtracks.vector_stores import ChromaVectorStore +from railtracks.rag.embedding_service import EmbeddingService + +loader = GCSLoader("my-knowledge-bucket", project="my-gcp-project") +chunks = loader.load(prefix="docs/") + +embedding_fn = EmbeddingService().embed +store = ChromaVectorStore("knowledge-base", embedding_function=embedding_fn) +store.upsert(chunks) + +@rt.function_node +def search_knowledge_base(query: str) -> str: + """Search the internal knowledge base for relevant information.""" + results = store.search(query, top_k=5) + return "\n\n".join(r.content for r in results) + +agent = rt.agent_node( + name="KnowledgeAgent", + llm=rt.llm.OpenAILLM("gpt-4o"), + system_message="Answer questions using the knowledge base.", + tool_nodes=[search_knowledge_base], +) + +flow = rt.Flow("knowledge-flow", entry_point=agent) +response = flow.invoke("What is our remote work policy?") +# --8<-- [end:pipeline_gcs_to_rag] + + +# =========================================================================== +# SQL / Relational Database +# =========================================================================== + +# --8<-- [start:sql_basic_postgres] +from railtracks.loaders import SQLLoader + +loader = SQLLoader( + "postgresql+psycopg2://user:pass@db.example.com:5432/mydb", + table_or_query="documents", + content_column="body", + metadata_columns=["title", "author", 
"created_at"], + id_column="id", +) +chunks = loader.load() + +for chunk in chunks: + print(chunk.metadata["title"], "→", chunk.content[:80]) +# --8<-- [end:sql_basic_postgres] + + +# --8<-- [start:sql_supabase] +import os +from railtracks.loaders import SQLLoader + +# Supabase exposes a standard PostgreSQL connection string +loader = SQLLoader( + os.environ["SUPABASE_DB_URL"], # postgresql+psycopg2://... + table_or_query="knowledge_base", + content_column="content", + metadata_columns=["title", "category", "updated_at"], + id_column="id", + document_column="title", +) +chunks = loader.load() +# --8<-- [end:sql_supabase] + + +# --8<-- [start:sql_raw_query] +from railtracks.loaders import SQLLoader + +loader = SQLLoader( + "postgresql+psycopg2://user:pass@host/db", + table_or_query=( + "SELECT id, title, body " + "FROM articles " + "WHERE published = true AND category = 'policy'" + ), + content_column="body", + id_column="id", + document_column="title", +) +chunks = loader.load() +# --8<-- [end:sql_raw_query] + + +# --8<-- [start:sql_load_keys] +from railtracks.loaders import SQLLoader + +loader = SQLLoader( + "postgresql+psycopg2://user:pass@host/db", + table_or_query="documents", + content_column="body", + id_column="id", +) +# Fetch only specific rows by their id column value +chunks = loader.load_keys(["doc-001", "doc-002", "doc-003"]) +# --8<-- [end:sql_load_keys] + + +# --8<-- [start:sql_existing_engine] +import sqlalchemy as sa +from railtracks.loaders import SQLLoader + +# Reuse an engine you already have configured (custom pool, SSL, etc.) 
+engine = sa.create_engine( + "postgresql+psycopg2://user:pass@host/db", + pool_size=5, + max_overflow=10, +) +loader = SQLLoader( + "", # ignored when engine= is provided + table_or_query="documents", + content_column="body", + engine=engine, +) +chunks = loader.load() +# --8<-- [end:sql_existing_engine] + + +# --8<-- [start:sql_async] +import asyncio +from railtracks.loaders import SQLLoader + +async def load_sql_documents(): + loader = SQLLoader( + "postgresql+psycopg2://user:pass@host/db", + table_or_query="documents", + content_column="body", + id_column="id", + ) + all_chunks = await loader.aload() + specific_chunks = await loader.aload_keys(["doc-001", "doc-002"]) + return all_chunks + specific_chunks + +chunks = asyncio.run(load_sql_documents()) +# --8<-- [end:sql_async] + + +# --8<-- [start:pipeline_sql_to_rag] +import railtracks as rt +from railtracks.loaders import SQLLoader +from railtracks.vector_stores import ChromaVectorStore +from railtracks.rag.embedding_service import EmbeddingService + +loader = SQLLoader( + "postgresql+psycopg2://user:pass@db.example.com/mydb", + table_or_query="knowledge_base", + content_column="content", + metadata_columns=["title", "category"], + id_column="id", +) +chunks = loader.load() + +embedding_fn = EmbeddingService().embed +store = ChromaVectorStore("sql-knowledge", embedding_function=embedding_fn) +store.upsert(chunks) + +@rt.function_node +def search_database(query: str) -> str: + """Search the knowledge base for information relevant to the query.""" + results = store.search(query, top_k=5) + return "\n\n".join(r.content for r in results) + +agent = rt.agent_node( + name="DatabaseAgent", + llm=rt.llm.OpenAILLM("gpt-4o"), + system_message="Answer questions using only information from the database.", + tool_nodes=[search_database], +) + +flow = rt.Flow("db-knowledge-flow", entry_point=agent) +response = flow.invoke("What is our refund policy?") +# --8<-- [end:pipeline_sql_to_rag] + + +# --8<-- [start:pipeline_azure_to_rag] +import 
railtracks as rt +from railtracks.loaders import AzureBlobLoader +from railtracks.vector_stores import ChromaVectorStore +from railtracks.rag.embedding_service import EmbeddingService + +# 1. Load documents from Azure Blob Storage +loader = AzureBlobLoader( + "https://myaccount.blob.core.windows.net", + "company-docs", +) +chunks = loader.load(prefix="hr/") + +# 2. Build a vector store +embedding_fn = EmbeddingService().embed +store = ChromaVectorStore("hr-docs", embedding_function=embedding_fn) +store.upsert(chunks) + +# 3. Expose retrieval as a tool +@rt.function_node +def search_hr_docs(query: str) -> str: + """Search HR documentation for policies and procedures.""" + results = store.search(query, top_k=5) + return "\n\n".join(r.content for r in results) + +# 4. Build the agent +agent = rt.agent_node( + name="HRAgent", + llm=rt.llm.OpenAILLM("gpt-4o"), + system_message="You are an HR assistant. Answer questions based on company policies.", + tool_nodes=[search_hr_docs], +) + +flow = rt.Flow("hr-flow", entry_point=agent) +response = flow.invoke("How many vacation days do I get?") +# --8<-- [end:pipeline_azure_to_rag] diff --git a/docs/scripts/storage_writers.py b/docs/scripts/storage_writers.py new file mode 100644 index 000000000..6a452a342 --- /dev/null +++ b/docs/scripts/storage_writers.py @@ -0,0 +1,307 @@ +""" +Cloud storage writer examples for use in documentation via --8<-- includes. 
+ +These snippets assume the relevant extras are installed: + pip install railtracks[aws] # for S3Writer + pip install railtracks[azure-blob] # for AzureBlobWriter + pip install railtracks[gcp] # for GCSWriter + pip install railtracks[sql] # for SQLWriter +""" + +# =========================================================================== +# AWS S3 +# =========================================================================== + +# --8<-- [start:s3_write_basic] +from railtracks.writers import S3Writer + +writer = S3Writer("my-bucket", region_name="us-east-1") + +# Write raw text at an explicit key +uri = writer.write_key("reports/summary.txt", "Today's executive summary ...") +print(uri) # s3://my-bucket/reports/summary.txt + +# Write a list of Chunk objects -- key is derived from chunk.id +uris = writer.write(chunks, prefix="generated/") +# --8<-- [end:s3_write_basic] + + +# --8<-- [start:s3_write_key_fn] +from railtracks.writers import S3Writer + +# Use a custom function to derive the storage key from each chunk +writer = S3Writer( + "my-bucket", + key_fn=lambda chunk: f"{chunk.metadata.get('category', 'misc')}/{chunk.id}.txt", +) +uris = writer.write(chunks) +# --8<-- [end:s3_write_key_fn] + + +# --8<-- [start:s3_write_explicit_creds] +from railtracks.writers import S3Writer + +writer = S3Writer( + "my-bucket", + aws_access_key_id="AKIA...", + aws_secret_access_key="...", + region_name="eu-west-1", + content_type="application/json", +) +uri = writer.write_key("data/result.json", '{"status": "ok"}') +# --8<-- [end:s3_write_explicit_creds] + + +# --8<-- [start:s3_write_async] +import asyncio +from railtracks.writers import S3Writer + +async def write_s3_documents(): + writer = S3Writer("my-bucket", region_name="us-east-1") + + # Write a batch of chunks asynchronously + uris = await writer.awrite(chunks, prefix="output/") + + # Write a single object asynchronously + uri = await writer.awrite_key("output/summary.txt", "Summary text ...") + return uris + 
+asyncio.run(write_s3_documents()) +# --8<-- [end:s3_write_async] + + +# =========================================================================== +# Azure Blob Storage +# =========================================================================== + +# --8<-- [start:azure_write_basic] +from railtracks.writers import AzureBlobWriter + +# DefaultAzureCredential resolves credentials automatically +writer = AzureBlobWriter( + "https://myaccount.blob.core.windows.net", + "my-container", +) + +# Write raw text at an explicit blob name +uri = writer.write_key("reports/summary.txt", "Today's executive summary ...") +print(uri) # https://myaccount.blob.core.windows.net/my-container/reports/summary.txt + +# Write a list of Chunk objects +uris = writer.write(chunks, prefix="generated/") +# --8<-- [end:azure_write_basic] + + +# --8<-- [start:azure_write_sas] +from azure.core.credentials import AzureSasCredential +from railtracks.writers import AzureBlobWriter + +writer = AzureBlobWriter( + "https://myaccount.blob.core.windows.net", + "my-container", + credential=AzureSasCredential("<sas-token>"), +) +uris = writer.write(chunks) +# --8<-- [end:azure_write_sas] + + +# --8<-- [start:azure_write_async] +import asyncio +from railtracks.writers import AzureBlobWriter + +async def write_azure_documents(): + writer = AzureBlobWriter( + "https://myaccount.blob.core.windows.net", + "my-container", + ) + uris = await writer.awrite(chunks, prefix="output/") + uri = await writer.awrite_key("output/summary.txt", "Summary ...") + return uris + +asyncio.run(write_azure_documents()) +# --8<-- [end:azure_write_async] + + +# =========================================================================== +# Google Cloud Storage +# =========================================================================== + +# --8<-- [start:gcs_write_basic] +from railtracks.writers import GCSWriter + +# Application Default Credentials resolve automatically +writer = GCSWriter("my-bucket", project="my-gcp-project") + +# Write 
raw text at an explicit object name +uri = writer.write_key("reports/summary.txt", "Today's executive summary ...") +print(uri) # gs://my-bucket/reports/summary.txt + +# Write a list of Chunk objects +uris = writer.write(chunks, prefix="generated/") +# --8<-- [end:gcs_write_basic] + + +# --8<-- [start:gcs_write_service_account] +from google.oauth2 import service_account +from railtracks.writers import GCSWriter + +credentials = service_account.Credentials.from_service_account_file( + "/path/to/service-account.json", + scopes=["https://www.googleapis.com/auth/cloud-platform"], +) +writer = GCSWriter("my-bucket", credentials=credentials) +uris = writer.write(chunks) +# --8<-- [end:gcs_write_service_account] + + +# --8<-- [start:gcs_write_async] +import asyncio +from railtracks.writers import GCSWriter + +async def write_gcs_documents(): + writer = GCSWriter("my-bucket", project="my-gcp-project") + uris = await writer.awrite(chunks, prefix="output/") + uri = await writer.awrite_key("output/summary.txt", "Summary ...") + return uris + +asyncio.run(write_gcs_documents()) +# --8<-- [end:gcs_write_async] + + +# =========================================================================== +# SQL / Relational Database +# =========================================================================== + +# --8<-- [start:sql_write_basic] +from railtracks.writers import SQLWriter + +writer = SQLWriter( + "postgresql+psycopg2://user:pass@db.example.com:5432/mydb", + table="documents", + content_column="body", + id_column="id", + metadata_columns=["title", "category"], +) + +# Write (upsert) a list of Chunk objects +uris = writer.write(chunks) +# Returns: ["sql://documents/<id>", ...] 
+ +# Write raw content at an explicit id +uri = writer.write_key("doc-42", "Revised policy text ...") +# Returns: "sql://documents/doc-42" +# --8<-- [end:sql_write_basic] + + +# --8<-- [start:sql_write_supabase] +import os +from railtracks.writers import SQLWriter + +writer = SQLWriter( + os.environ["SUPABASE_DB_URL"], # postgresql+psycopg2://... + table="knowledge_base", + content_column="content", + id_column="id", + document_column="title", + metadata_columns=["category", "updated_at"], +) +uris = writer.write(chunks) +# --8<-- [end:sql_write_supabase] + + +# --8<-- [start:sql_write_modes] +from railtracks.writers import SQLWriter + +# Default: "upsert" -- existing rows with the same id are replaced +writer = SQLWriter( + "postgresql+psycopg2://user:pass@host/db", + table="documents", + content_column="body", + id_column="id", + mode="upsert", # safe to call repeatedly +) +writer.write(chunks) + +# "insert" mode -- rows are appended without conflict handling +append_writer = SQLWriter( + "postgresql+psycopg2://user:pass@host/db", + table="audit_log", + content_column="message", + mode="insert", +) +append_writer.write(chunks) +# --8<-- [end:sql_write_modes] + + +# --8<-- [start:sql_write_existing_engine] +import sqlalchemy as sa +from railtracks.writers import SQLWriter + +engine = sa.create_engine( + "postgresql+psycopg2://user:pass@host/db", + pool_size=5, + max_overflow=10, +) +writer = SQLWriter( + "", # ignored when engine= is provided + table="documents", + content_column="body", + id_column="id", + engine=engine, +) +uris = writer.write(chunks) +# --8<-- [end:sql_write_existing_engine] + + +# --8<-- [start:sql_write_async] +import asyncio +from railtracks.writers import SQLWriter + +async def write_sql_documents(): + writer = SQLWriter( + "postgresql+psycopg2://user:pass@host/db", + table="documents", + content_column="body", + id_column="id", + ) + uris = await writer.awrite(chunks) + uri = await writer.awrite_key("doc-99", "New content ...") + return 
uris + +asyncio.run(write_sql_documents()) +# --8<-- [end:sql_write_async] + + +# =========================================================================== +# Load -> Generate -> Write back (end-to-end) +# =========================================================================== + +# --8<-- [start:pipeline_generate_and_write] +import railtracks as rt +from railtracks.loaders import S3Loader +from railtracks.writers import S3Writer +from railtracks.vector_stores.chunking.base_chunker import Chunk + +# 1. Load source documents +loader = S3Loader("source-bucket", region_name="us-east-1") +source_chunks = loader.load(prefix="raw/") + +# 2. Run an agent to generate a summary for each document +summariser = rt.agent_node( + name="Summariser", + llm=rt.llm.OpenAILLM("gpt-4o-mini"), + system_message="Summarise the provided document in 2-3 sentences.", +) + +writer = S3Writer("output-bucket", region_name="us-east-1") + +for chunk in source_chunks: + result = summariser.invoke(chunk.content) + summary_chunk = Chunk( + content=result.content, + id=f"summary-{chunk.id}", + metadata={"source": chunk.metadata["source"], "type": "summary"}, + ) + uri = writer.write_key(f"summaries/{chunk.id}.txt", summary_chunk.content) + print(f"Saved summary -> {uri}") +# --8<-- [end:pipeline_generate_and_write] diff --git a/docs/tutorials/notebooks/file_embedding.md b/docs/tutorials/notebooks/file_embedding.md new file mode 100644 index 000000000..9f91a834b --- /dev/null +++ b/docs/tutorials/notebooks/file_embedding.md @@ -0,0 +1,55 @@ +## Embedding Local Files + +Learn how to load documents from your local filesystem, chunk and embed them, and build a RAG-powered agent that can answer questions from your own files. + +
+ +
+
+ Local File Embedding +
+ +
+ Run this tutorial interactively in Google Colab. +
+
+ + +
+ +## Embedding Remote Files + +Learn how to fetch documents from remote URLs, embed them into a vector store, and give your agent grounded, retrieval-augmented answers from any web-accessible source. + +
+ +
+
+ Remote File Embedding +
+ +
+ Run this tutorial interactively in Google Colab. +
+
+ + +
diff --git a/docs/tutorials/walkthroughs/rag_tutorial.md b/docs/tutorials/walkthroughs/rag_tutorial.md index 61a2cd9c2..d9ce35434 100644 --- a/docs/tutorials/walkthroughs/rag_tutorial.md +++ b/docs/tutorials/walkthroughs/rag_tutorial.md @@ -28,4 +28,5 @@ You can also use our pre-configured RAG node that automatically collects context !!! success "Next Steps" - Check out the [RAG Reference Documentation](../../rag/RAG.md) to learn how to build RAG applications in Railtracks. - - Explore the [Agent Design Documentation](../../documentation/agent_design/overview.md) for integrating any type of tool. \ No newline at end of file + - Explore the [Agent Design Documentation](../../documentation/agent_design/overview.md) for integrating any type of tool. + - Try the hands-on Colab notebooks: [Local File Embedding](../notebooks/file_embedding.md) and [Remote File Embedding](../notebooks/file_embedding.md#embedding-remote-files). \ No newline at end of file diff --git a/docs/tutorials/walkthroughs/storage_loaders_tutorial.md b/docs/tutorials/walkthroughs/storage_loaders_tutorial.md new file mode 100644 index 000000000..1f4652aeb --- /dev/null +++ b/docs/tutorials/walkthroughs/storage_loaders_tutorial.md @@ -0,0 +1,116 @@ +# Loading Documents from Cloud Storage + +This tutorial walks you through loading documents from AWS S3 or Azure Blob Storage +and connecting them to a RAG-powered agent. + +!!! tip "Prerequisites" + - You should be comfortable with the [RAG concepts](../../rag/RAG.md) and have read + the [vector store guide](../../rag/vector_stores/vector_store_info.md). 
+ - Install the required extras before starting: + + ```bash + pip install railtracks[aws] # for S3 + pip install railtracks[azure-blob] # for Azure Blob Storage + pip install railtracks[chroma] # for the vector store + ``` + +--- + +## Step 1 — Load your documents + +Pick the provider that matches your storage: + +=== "AWS S3" + + ```python + --8<-- "docs/scripts/storage_loaders.py:s3_prefix" + ``` + +=== "Azure Blob Storage" + + ```python + --8<-- "docs/scripts/storage_loaders.py:azure_prefix" + ``` + +Each loader returns a list of +[`Chunk`](../../rag/vector_stores/vector_store_info.md) objects. Every chunk carries: + +- **`content`** — the UTF-8 text of the file +- **`document`** — the key or blob name used as an identifier +- **`metadata`** — provider-specific fields including a `source` URL for citation + +--- + +## Step 2 — Index the chunks in a vector store + +Pass the chunks straight to `ChromaVectorStore.upsert()` — no conversion needed: + +```python +--8<-- "docs/scripts/storage_loaders.py:shared_embedding" +``` + +```python +from railtracks.vector_stores import ChromaVectorStore + +store = ChromaVectorStore("my-knowledge-base", embedding_function=embedding_function) +store.upsert(chunks) +``` + +--- + +## Step 3 — Build a retrieval tool + +Wrap the store's `search` method in a `function_node` so the agent can call it: + +```python +import railtracks as rt + +@rt.function_node +def search_docs(query: str) -> str: + """Search the knowledge base and return relevant excerpts.""" + results = store.search(query, top_k=5) + return "\n\n".join(r.content for r in results) +``` + +--- + +## Step 4 — Connect to an agent + +```python +agent = rt.agent_node( + name="KnowledgeAgent", + llm=rt.llm.OpenAILLM("gpt-4o"), + system_message="Answer questions using only the provided knowledge base.", + tool_nodes=[search_docs], +) + +flow = rt.Flow("knowledge-flow", entry_point=agent) +response = flow.invoke("What is our remote work policy?") +print(response) +``` + +--- + 
+## Putting it all together + +### S3 + +```python +--8<-- "docs/scripts/storage_loaders.py:pipeline_s3_to_rag" +``` + +### Azure Blob Storage + +```python +--8<-- "docs/scripts/storage_loaders.py:pipeline_azure_to_rag" +``` + +--- + +!!! success "Next steps" + - Explore [filtering](../../rag/vector_stores/filtering.md) to scope retrieval by + metadata (e.g., `container`, `bucket`, `key`). + - Add [guardrails](../../documentation/advanced/guardrails/overview.md) to validate + agent responses before returning them to users. + - Use [async loading](../../integrations/storage/s3.md#async-usage) (`aload` / + `aload_keys`) when integrating with async frameworks such as FastAPI. diff --git a/mkdocs.yml b/mkdocs.yml index 5c1c1c034..2cdd5f511 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -160,6 +160,12 @@ nav: - Github: integrations/mcps/github.md - Slack: integrations/mcps/slack.md - Email: integrations/mcps/notion.md + - Storage & Databases: + - Overview: integrations/storage/overview.md + - AWS S3: integrations/storage/s3.md + - Azure Blob Storage: integrations/storage/azure_blob.md + - Google Cloud Storage: integrations/storage/gcs.md + - SQL Databases: integrations/storage/sql.md - Other: - Python Sandbox: integrations/other/python_sandbox.md - Shell: integrations/other/shell_bash.md @@ -175,6 +181,7 @@ nav: - Flows: tutorials/walkthroughs/flows.md - RAG system: tutorials/walkthroughs/rag_tutorial.md - Vector Stores: tutorials/walkthroughs/vector_store_tutorial.md + - Cloud Storage Loaders: tutorials/walkthroughs/storage_loaders_tutorial.md - FastAPI Integration: tutorials/walkthroughs/fastapi.md - Videos: - Multiagent Systems: tutorials/videos/multiagent.md @@ -184,6 +191,7 @@ nav: - Notebooks: - Agent Architectures: tutorials/notebooks/architectures.md - FastAPI Integration: tutorials/notebooks/fastapi.md + - File Embedding (RAG): tutorials/notebooks/file_embedding.md - Concepts: - Core: - Agents: tutorials/concepts/agents.md diff --git 
a/packages/railtracks/pyproject.toml b/packages/railtracks/pyproject.toml index aadc16f60..73f578bf3 100644 --- a/packages/railtracks/pyproject.toml +++ b/packages/railtracks/pyproject.toml @@ -72,11 +72,24 @@ chroma = [ ] # the integrations submodule will be list a of the above deps portkey = ["portkey_ai >= 2.0.2"] -integrations = ["railtracks[portkey]", "railtracks[chroma]"] +aws = ["boto3 >= 1.26.0"] +azure-blob = [ + "azure-storage-blob >= 12.0.0", + "azure-identity >= 1.15.0", +] +gcp = ["google-cloud-storage >= 2.0.0"] +sql = ["sqlalchemy >= 2.0.0"] stores-vector = ["asyncpg>=0.29", "pgvector>=0.3"] stores-chroma = ["chromadb>1.3.0"] stores-all = ["railtracks[stores-vector]", "railtracks[stores-chroma]"] - +integrations = [ + "railtracks[portkey]", + "railtracks[chroma]", + "railtracks[aws]", + "railtracks[azure-blob]", + "railtracks[gcp]", + "railtracks[sql]", +] all = ["railtracks[visual]", "railtracks[integrations]"] [project.scripts] diff --git a/packages/railtracks/src/railtracks/__init__.py b/packages/railtracks/src/railtracks/__init__.py index 8f616645a..8c2951cbf 100644 --- a/packages/railtracks/src/railtracks/__init__.py +++ b/packages/railtracks/src/railtracks/__init__.py @@ -42,6 +42,8 @@ "evaluations", "vector_stores", "rag", + "loaders", + "writers", "RagConfig", "Flow", "enable_logging", @@ -59,9 +61,11 @@ guardrails, integrations, llm, + loaders, prebuilt, rag, vector_stores, + writers, ) from ._session import ExecutionInfo, Session, session from .context.central import session_id, set_config diff --git a/packages/railtracks/src/railtracks/loaders/__init__.py b/packages/railtracks/src/railtracks/loaders/__init__.py new file mode 100644 index 000000000..03c7e29fe --- /dev/null +++ b/packages/railtracks/src/railtracks/loaders/__init__.py @@ -0,0 +1,19 @@ +# ------------------------------------------------------------- +# Copyright (c) Railtown AI. All rights reserved. +# Licensed under the MIT License. 
See LICENSE in project root for information. +# ------------------------------------------------------------- +"""Remote cloud storage document loaders for Railtracks.""" + +from .azure_blob import AzureBlobLoader +from .base import BaseStorageLoader +from .gcs import GCSLoader +from .s3 import S3Loader +from .sql import SQLLoader + +__all__ = [ + "BaseStorageLoader", + "S3Loader", + "AzureBlobLoader", + "GCSLoader", + "SQLLoader", +] diff --git a/packages/railtracks/src/railtracks/loaders/azure_blob.py b/packages/railtracks/src/railtracks/loaders/azure_blob.py new file mode 100644 index 000000000..29e1005ce --- /dev/null +++ b/packages/railtracks/src/railtracks/loaders/azure_blob.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import Any, Optional + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .base import BaseStorageLoader + + +class AzureBlobLoader(BaseStorageLoader): + """Document loader for Azure Blob Storage. + + Fetches blobs from a container and returns them as + :class:`~railtracks.vector_stores.chunking.base_chunker.Chunk` objects + with UTF-8 decoded content and source metadata. + + Authentication defaults to ``DefaultAzureCredential``, which resolves + credentials from environment variables, managed identity, the Azure CLI, + and other standard sources automatically. Pass an explicit ``credential`` + to override. + + Requires the ``azure-blob`` extra: ``pip install railtracks[azure-blob]``. + + Args: + account_url: Azure Storage account URL, e.g. + ``"https://<account>.blob.core.windows.net"``. + container_name: Name of the blob container. + credential: Azure credential used for authentication. Accepts any + ``azure.core.credentials.TokenCredential``, + ``azure.core.credentials.AzureSasCredential``, or an account key + string. Defaults to ``DefaultAzureCredential()`` when ``None``. + encoding: Text encoding used to decode blob bytes. Defaults to ``"utf-8"``. 
+ + Raises: + ImportError: If ``azure-storage-blob`` or ``azure-identity`` are not installed. + + Example:: + + # Default credential (env vars, managed identity, Azure CLI …) + loader = AzureBlobLoader( + "https://myaccount.blob.core.windows.net", + "my-container", + ) + chunks = loader.load(prefix="reports/") + + # Explicit SAS token + from azure.core.credentials import AzureSasCredential + loader = AzureBlobLoader( + "https://myaccount.blob.core.windows.net", + "my-container", + credential=AzureSasCredential("<sas-token>"), + ) + + # Load specific blobs + chunks = loader.load_keys(["readme.txt", "data/report.txt"]) + + # Async usage + chunks = await loader.aload(prefix="reports/") + """ + + def __init__( + self, + account_url: str, + container_name: str, + *, + credential: Optional[Any] = None, + encoding: str = "utf-8", + ) -> None: + try: + from azure.storage.blob import ContainerClient + except ImportError: + raise ImportError( + "azure-storage-blob is required for Azure Blob loading. " + "Install it via `pip install railtracks[azure-blob]` or `uv add railtracks[azure-blob]`." + ) + + if credential is None: + try: + from azure.identity import DefaultAzureCredential + + credential = DefaultAzureCredential() + except ImportError: + raise ImportError( + "azure-identity is required for default credential resolution. " + "Install it via `pip install railtracks[azure-blob]` or `uv add railtracks[azure-blob]`." + ) + + self._account_url = account_url.rstrip("/") + self._container_name = container_name + self._encoding = encoding + self._container_client = ContainerClient( + account_url=self._account_url, + container_name=container_name, + credential=credential, + ) + + def __repr__(self) -> str: + return ( + f"AzureBlobLoader(" + f"account_url={self._account_url!r}, " + f"container_name={self._container_name!r})" + ) + + def load(self, prefix: Optional[str] = None) -> list[Chunk]: + """Load all blobs from the container, optionally filtered by prefix. 
+ + Args: + prefix: Optional blob name prefix. Only blobs whose names start + with this string are loaded. + + Returns: + list[Chunk]: All matching blobs as Chunk objects. + """ + list_kwargs: dict[str, Any] = {} + if prefix is not None: + list_kwargs["name_starts_with"] = prefix + + blob_names = [ + blob.name + for blob in self._container_client.list_blobs(**list_kwargs) + ] + return self.load_keys(blob_names) + + def load_keys(self, keys: list[str]) -> list[Chunk]: + """Load specific blobs by name. + + Args: + keys: List of blob names to load. + + Returns: + list[Chunk]: Specified blobs as Chunk objects. + """ + chunks: list[Chunk] = [] + for blob_name in keys: + blob_client = self._container_client.get_blob_client(blob_name) + data = blob_client.download_blob().readall() + content = data.decode(self._encoding) + chunks.append( + Chunk( + content=content, + document=blob_name, + metadata={ + "source": f"{self._account_url}/{self._container_name}/{blob_name}", + "account_url": self._account_url, + "container": self._container_name, + "blob_name": blob_name, + }, + ) + ) + return chunks diff --git a/packages/railtracks/src/railtracks/loaders/base.py b/packages/railtracks/src/railtracks/loaders/base.py new file mode 100644 index 000000000..9e5d3a2ce --- /dev/null +++ b/packages/railtracks/src/railtracks/loaders/base.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from typing import Optional + +from railtracks.vector_stores.chunking.base_chunker import Chunk + + +class BaseStorageLoader(ABC): + """Abstract base class for remote cloud storage document loaders. + + Subclasses fetch objects from a specific cloud provider and return them + as :class:`~railtracks.vector_stores.chunking.base_chunker.Chunk` objects + containing UTF-8 decoded content and provider-specific metadata. 
+ """ + + @abstractmethod + def load(self, prefix: Optional[str] = None) -> list[Chunk]: + """Load all documents from the storage container. + + Args: + prefix: Optional key/path prefix to filter objects. + + Returns: + list[Chunk]: Documents as Chunk objects with content and metadata. + """ + + @abstractmethod + def load_keys(self, keys: list[str]) -> list[Chunk]: + """Load specific documents by key or blob name. + + Args: + keys: List of object keys or blob names to load. + + Returns: + list[Chunk]: Documents as Chunk objects with content and metadata. + """ + + async def aload(self, prefix: Optional[str] = None) -> list[Chunk]: + """Async wrapper around :meth:`load`. + + Args: + prefix: Optional key/path prefix to filter objects. + + Returns: + list[Chunk]: Documents as Chunk objects with content and metadata. + """ + return await asyncio.to_thread(self.load, prefix) + + async def aload_keys(self, keys: list[str]) -> list[Chunk]: + """Async wrapper around :meth:`load_keys`. + + Args: + keys: List of object keys or blob names to load. + + Returns: + list[Chunk]: Documents as Chunk objects with content and metadata. + """ + return await asyncio.to_thread(self.load_keys, keys) diff --git a/packages/railtracks/src/railtracks/loaders/gcs.py b/packages/railtracks/src/railtracks/loaders/gcs.py new file mode 100644 index 000000000..abc81fbcf --- /dev/null +++ b/packages/railtracks/src/railtracks/loaders/gcs.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import Any, Optional + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .base import BaseStorageLoader + + +class GCSLoader(BaseStorageLoader): + """Document loader for Google Cloud Storage (GCS). + + Fetches objects from a GCS bucket and returns them as + :class:`~railtracks.vector_stores.chunking.base_chunker.Chunk` objects + with UTF-8 decoded content and source metadata. 
+ + Authentication uses **Application Default Credentials (ADC)** by default, + which resolves credentials from the following sources (in order): + + 1. ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable (path to a + service-account JSON key file) + 2. ``gcloud auth application-default login`` (developer workstation) + 3. Workload Identity / attached service account (GCE, GKE, Cloud Run …) + + Pass explicit ``credentials`` to override ADC. + + Requires the ``gcp`` extra: ``pip install railtracks[gcp]``. + + Args: + bucket_name: GCS bucket name. + project: Google Cloud project ID. Inferred from ADC when ``None``. + credentials: Explicit Google credential object (e.g. + ``google.oauth2.service_account.Credentials``). Defaults to ADC. + encoding: Text encoding used to decode object bytes. Defaults to ``"utf-8"``. + + Raises: + ImportError: If ``google-cloud-storage`` is not installed. + + Example:: + + loader = GCSLoader("my-bucket", project="my-gcp-project") + + # Load all objects under a prefix + chunks = loader.load(prefix="documents/") + + # Load specific object names + chunks = loader.load_keys(["readme.txt", "data/report.txt"]) + + # Async usage + chunks = await loader.aload(prefix="documents/") + """ + + def __init__( + self, + bucket_name: str, + *, + project: Optional[str] = None, + credentials: Optional[Any] = None, + encoding: str = "utf-8", + ) -> None: + try: + from google.cloud import storage # type: ignore[import] + except ImportError: + raise ImportError( + "google-cloud-storage is required for GCS loading. " + "Install it via `pip install railtracks[gcp]` or `uv add railtracks[gcp]`." + ) + + self._bucket_name = bucket_name + self._encoding = encoding + self._client = storage.Client(project=project, credentials=credentials) + + def __repr__(self) -> str: + return f"GCSLoader(bucket_name={self._bucket_name!r})" + + def load(self, prefix: Optional[str] = None) -> list[Chunk]: + """Load all objects from the bucket, optionally filtered by prefix. 
+ + Args: + prefix: Optional object name prefix. Only objects whose names start + with this string are loaded. + + Returns: + list[Chunk]: All matching objects as Chunk objects. + """ + kwargs: dict[str, Any] = {} + if prefix is not None: + kwargs["prefix"] = prefix + + blob_names = [ + blob.name + for blob in self._client.list_blobs(self._bucket_name, **kwargs) + ] + return self.load_keys(blob_names) + + def load_keys(self, keys: list[str]) -> list[Chunk]: + """Load specific objects from the bucket by name. + + Args: + keys: List of GCS object names to load. + + Returns: + list[Chunk]: Specified objects as Chunk objects. + """ + bucket = self._client.bucket(self._bucket_name) + chunks: list[Chunk] = [] + for name in keys: + blob = bucket.blob(name) + content = blob.download_as_bytes().decode(self._encoding) + chunks.append( + Chunk( + content=content, + document=name, + metadata={ + "source": f"gs://{self._bucket_name}/{name}", + "bucket": self._bucket_name, + "name": name, + }, + ) + ) + return chunks diff --git a/packages/railtracks/src/railtracks/loaders/s3.py b/packages/railtracks/src/railtracks/loaders/s3.py new file mode 100644 index 000000000..ca8e5b164 --- /dev/null +++ b/packages/railtracks/src/railtracks/loaders/s3.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from typing import Any, Optional + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .base import BaseStorageLoader + + +class S3Loader(BaseStorageLoader): + """Document loader for AWS S3. + + Fetches objects from an S3 bucket and returns them as + :class:`~railtracks.vector_stores.chunking.base_chunker.Chunk` objects + with UTF-8 decoded content and source metadata. + + Credentials follow boto3's standard resolution chain: environment variables + (``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY``), ``~/.aws/credentials``, + IAM instance profiles, etc. Pass explicit credentials to override the chain. 
+ + Requires the ``aws`` extra: ``pip install railtracks[aws]``. + + Args: + bucket: S3 bucket name. + region_name: AWS region (optional). + aws_access_key_id: Explicit access key ID (optional). + aws_secret_access_key: Explicit secret access key (optional). + aws_session_token: Explicit session token for temporary credentials (optional). + endpoint_url: Custom endpoint URL for S3-compatible services such as MinIO (optional). + encoding: Text encoding used to decode object bytes. Defaults to ``"utf-8"``. + + Raises: + ImportError: If ``boto3`` is not installed. + + Example:: + + loader = S3Loader("my-bucket", region_name="us-west-2") + + # Load all objects under a prefix + chunks = loader.load(prefix="documents/") + + # Load specific keys + chunks = loader.load_keys(["readme.txt", "data/report.txt"]) + + # Async usage + chunks = await loader.aload(prefix="documents/") + """ + + def __init__( + self, + bucket: str, + *, + region_name: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + endpoint_url: Optional[str] = None, + encoding: str = "utf-8", + ) -> None: + try: + import boto3 + except ImportError: + raise ImportError( + "boto3 is required for S3 loading. " + "Install it via `pip install railtracks[aws]` or `uv add railtracks[aws]`." + ) + + self._bucket = bucket + self._encoding = encoding + self._client = boto3.client( + "s3", + region_name=region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + endpoint_url=endpoint_url, + ) + + def __repr__(self) -> str: + return f"S3Loader(bucket={self._bucket!r})" + + def load(self, prefix: Optional[str] = None) -> list[Chunk]: + """Load all objects from the bucket, optionally filtered by prefix. + + Uses the S3 list-objects paginator so buckets with more than 1 000 + objects are handled correctly. 
+ + Args: + prefix: Optional S3 key prefix. Only objects whose keys start with + this string are loaded. + + Returns: + list[Chunk]: All matching objects as Chunk objects. + """ + kwargs: dict[str, Any] = {"Bucket": self._bucket} + if prefix is not None: + kwargs["Prefix"] = prefix + + keys: list[str] = [] + paginator = self._client.get_paginator("list_objects_v2") + for page in paginator.paginate(**kwargs): + for obj in page.get("Contents", []): + keys.append(obj["Key"]) + + return self.load_keys(keys) + + def load_keys(self, keys: list[str]) -> list[Chunk]: + """Load specific objects from the bucket by key. + + Args: + keys: List of S3 object keys to load. + + Returns: + list[Chunk]: Specified objects as Chunk objects. + """ + chunks: list[Chunk] = [] + for key in keys: + response = self._client.get_object(Bucket=self._bucket, Key=key) + content = response["Body"].read().decode(self._encoding) + chunks.append( + Chunk( + content=content, + document=key, + metadata={ + "source": f"s3://{self._bucket}/{key}", + "bucket": self._bucket, + "key": key, + }, + ) + ) + return chunks diff --git a/packages/railtracks/src/railtracks/loaders/sql.py b/packages/railtracks/src/railtracks/loaders/sql.py new file mode 100644 index 000000000..653278c16 --- /dev/null +++ b/packages/railtracks/src/railtracks/loaders/sql.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +import re +import warnings +from typing import TYPE_CHECKING, Any, Optional + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .base import BaseStorageLoader + +if TYPE_CHECKING: + from sqlalchemy import Engine + +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*(\.[A-Za-z_][A-Za-z0-9_$]*)?$") + + +def _looks_like_query(s: str) -> bool: + """Return True if *s* appears to be a raw SQL statement rather than a table name.""" + return s.strip().upper().startswith("SELECT") + + +def _validate_identifier(value: str, label: str) -> None: + """Raise ValueError if *value* is not a safe SQL 
identifier. + + Allows simple names (``my_table``) and schema-qualified names + (``public.documents``). Rejects anything containing SQL metacharacters + or whitespace that could enable injection. + """ + if not _IDENT_RE.match(value): + raise ValueError( + f"Invalid SQL identifier for {label!r}: {value!r}. " + "Identifiers must start with a letter or underscore and may only " + "contain letters, digits, underscores, and dollar signs. " + "Use 'schema.table' notation for schema-qualified names. " + "Never pass user-controlled strings as identifiers." + ) + + +class SQLLoader(BaseStorageLoader): + """Document loader for relational databases via SQLAlchemy. + + Reads rows from a table (or arbitrary ``SELECT`` query) and converts each + row into a :class:`~railtracks.vector_stores.chunking.base_chunker.Chunk`. + + Works with any SQLAlchemy-compatible database including PostgreSQL, + Supabase (PostgreSQL), MySQL, SQLite, and more. For SQLite no extra + driver is needed; for PostgreSQL install ``psycopg2`` or ``asyncpg``. + + Requires the ``sql`` extra: ``pip install railtracks[sql]``. + + Args: + connection_string: SQLAlchemy database URL, e.g. + ``"postgresql+psycopg2://user:pass@host/db"`` or + ``"sqlite:///:memory:"``. + table_or_query: Either a **table name** (e.g. ``"documents"``) or a + full ``SELECT`` statement (e.g. ``"SELECT id, body FROM docs"``). + Table names produce ``SELECT * FROM ``. + + .. note:: + Table names are validated against a strict allowlist + (``[A-Za-z_][A-Za-z0-9_$]*``) at construction time. Never + pass user-controlled strings here. Use raw ``SELECT`` queries + with bound parameters for dynamic filtering. + + ``WITH`` (CTE) queries are not supported as ``table_or_query``. + Wrap them in a subquery or use the CTE inline in a ``SELECT``: + ``"SELECT * FROM (WITH … SELECT …) AS t"`` — or pass the CTE + body as a regular SQL query where applicable. + + content_column: Name of the column whose value becomes + :attr:`Chunk.content`. 
+ metadata_columns: Column names to include in :attr:`Chunk.metadata`. + When ``None`` all columns except ``content_column`` and + ``id_column`` are included. + id_column: Column to use as :attr:`Chunk.id`. When ``None`` a UUID + is auto-generated for every chunk. + document_column: Column to use as :attr:`Chunk.document`. Falls back + to ``id_column`` when ``None``. + engine: An existing ``sqlalchemy.Engine`` instance. When provided, + ``connection_string`` and ``engine_kwargs`` are ignored. Useful + when you already have a configured engine (custom pool, SSL, etc.) + or for testing with an in-memory database. The caller is + responsible for disposing this engine; :meth:`close` will not + touch it. + engine_kwargs: Extra keyword arguments forwarded to + ``sqlalchemy.create_engine()``. + + Raises: + ImportError: If ``sqlalchemy`` is not installed. + ValueError: If an identifier argument contains unsafe characters, or + if ``load_keys()`` is called without an ``id_column``. + + Example:: + + # PostgreSQL / Supabase + loader = SQLLoader( + "postgresql+psycopg2://user:pass@db.supabase.co:5432/postgres", + table_or_query="documents", + content_column="body", + metadata_columns=["title", "author", "created_at"], + id_column="id", + ) + chunks = loader.load() + + # SQLite (great for testing / local dev) + loader = SQLLoader( + "sqlite:///my_db.sqlite", + table_or_query="knowledge", + content_column="text", + ) + chunks = loader.load() + + # Raw query + loader = SQLLoader( + connection_string, + table_or_query="SELECT id, body FROM docs WHERE published = 1", + content_column="body", + id_column="id", + ) + chunks = loader.load() + + # Load specific rows by id_column value + chunks = loader.load_keys(["doc-001", "doc-002"]) + + # Context-manager — engine is disposed automatically + with SQLLoader(connection_string, "documents", "body") as loader: + chunks = loader.load() + + # Async usage + chunks = await loader.aload() + """ + + def __init__( + self, + connection_string: 
str, + table_or_query: str, + content_column: str, + *, + metadata_columns: Optional[list[str]] = None, + id_column: Optional[str] = None, + document_column: Optional[str] = None, + engine: Optional[Engine] = None, + engine_kwargs: Optional[dict[str, Any]] = None, + ) -> None: + try: + import sqlalchemy # noqa: F401 + except ImportError: + raise ImportError( + "sqlalchemy is required for SQL loading. " + "Install it via `pip install railtracks[sql]` or `uv add railtracks[sql]`." + ) + + import sqlalchemy as sa + + self._content_column = content_column + self._metadata_columns = metadata_columns + self._id_column = id_column + self._document_column = document_column + self._table_or_query = table_or_query + self._is_raw_query = _looks_like_query(table_or_query) + + # Validate all structural identifiers at construction time. + # Values are developer-supplied config, not end-user input — but we + # enforce the allowlist to catch misconfiguration and make injection + # harder if a value ever originates from an untrusted source. 
def close(self) -> None:
    """Release the pooled connections of an internally created engine.

    An engine handed in through the ``engine`` constructor argument is not
    owned by this loader, so it is left untouched for the caller to dispose.
    """
    if not self._owns_engine:
        return
    self._engine.dispose()
def _fetch_by_ids(self, keys: list[str]) -> list[dict[str, Any]]:
    """Fetch the rows whose ``id_column`` value is one of *keys*.

    For a table source the filter is applied to the table directly. For a
    raw ``SELECT`` source the statement is wrapped as a derived table
    (``SELECT * FROM (<query>) AS _sub``) and filtered from the outside.

    Args:
        keys: ``id_column`` values to look up. They are passed as an
            expanding bound parameter, never interpolated into the SQL text.

    Returns:
        list[dict[str, Any]]: One ``column -> value`` mapping per matching row.
    """
    import sqlalchemy as sa

    if self._is_raw_query:
        # Remove a trailing statement terminator together with any whitespace
        # around it ("...;\n", "... ; "). A plain rstrip(";") leaves the
        # semicolon in place whenever whitespace follows it, and a terminator
        # inside the derived-table parentheses is a syntax error.
        base = self._table_or_query.rstrip("; \t\r\n\f\v")
        source = f"({base}) AS _sub"
    else:
        source = self._table_or_query

    # Identifiers were validated at construction time; only the key values
    # vary per call and those travel as bound parameters.
    stmt = sa.text(
        f"SELECT * FROM {source} "  # noqa: S608
        f"WHERE {self._id_column} IN :keys"
    ).bindparams(sa.bindparam("keys", expanding=True))

    with self._engine.connect() as conn:
        result = conn.execute(stmt, {"keys": keys})
        return [dict(row._mapping) for row in result]
def load_keys(self, keys: list[str]) -> list[Chunk]:
    """Fetch only the rows whose ``id_column`` value appears in *keys*.

    Args:
        keys: ``id_column`` values identifying the rows to load.

    Returns:
        list[Chunk]: The rows found, converted to Chunk objects. An empty
        *keys* list yields an empty result.

    Raises:
        ValueError: If the loader was created without an ``id_column``.
    """
    id_column = self._id_column
    if not id_column:
        raise ValueError(
            "load_keys() requires an 'id_column' to be set on the loader."
        )
    # An empty key list short-circuits before any row lookup is attempted.
    return self._rows_to_chunks(self._fetch_by_ids(keys)) if keys else []
+# ------------------------------------------------------------- +"""Remote cloud storage document writers for Railtracks.""" + +from .azure_blob import AzureBlobWriter +from .base import BaseStorageWriter +from .gcs import GCSWriter +from .s3 import S3Writer +from .sql import SQLWriter + +__all__ = [ + "BaseStorageWriter", + "S3Writer", + "AzureBlobWriter", + "GCSWriter", + "SQLWriter", +] diff --git a/packages/railtracks/src/railtracks/writers/azure_blob.py b/packages/railtracks/src/railtracks/writers/azure_blob.py new file mode 100644 index 000000000..6b3a337b4 --- /dev/null +++ b/packages/railtracks/src/railtracks/writers/azure_blob.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from typing import Any, Callable, Optional + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .base import BaseStorageWriter + + +class AzureBlobWriter(BaseStorageWriter): + """Document writer for Azure Blob Storage. + + Uploads text content to a blob container, encoding each blob with the + configured text encoding and returning the full ``https://`` URI of every + blob written. Existing blobs at the same name are overwritten. + + Authentication defaults to ``DefaultAzureCredential``, which resolves + credentials from environment variables, managed identity, the Azure CLI, + and other standard sources automatically. Pass an explicit ``credential`` + to override. + + Requires the ``azure-blob`` extra: ``pip install railtracks[azure-blob]``. + + Args: + account_url: Azure Storage account URL, e.g. + ``"https://.blob.core.windows.net"``. + container_name: Name of the blob container. + credential: Azure credential used for authentication. Accepts any + ``azure.core.credentials.TokenCredential``, + ``azure.core.credentials.AzureSasCredential``, or an account key + string. Defaults to ``DefaultAzureCredential()`` when ``None``. + encoding: Text encoding used when converting content to bytes. Defaults to ``"utf-8"``. 
+ content_type: MIME type set on the uploaded blob. Defaults to + ``"text/plain; charset=utf-8"``. + key_fn: Optional callable ``(chunk) -> str`` that derives a blob name + from a :class:`Chunk`. When ``None`` the name falls back to + ``chunk.id``, then ``chunk.document``, then a random UUID. + + Raises: + ImportError: If ``azure-storage-blob`` or ``azure-identity`` are not installed. + + Example:: + + writer = AzureBlobWriter( + "https://myaccount.blob.core.windows.net", + "my-container", + ) + + # Write a list of Chunk objects + uris = writer.write(chunks, prefix="generated/") + + # Write raw text at an explicit blob name + uri = writer.write_key("reports/summary.txt", "Today's summary ...") + + # Async usage + uris = await writer.awrite(chunks, prefix="generated/") + """ + + def __init__( + self, + account_url: str, + container_name: str, + *, + credential: Optional[Any] = None, + encoding: str = "utf-8", + content_type: str = "text/plain; charset=utf-8", + key_fn: Optional[Callable[[Chunk], str]] = None, + ) -> None: + try: + from azure.storage.blob import ContainerClient + except ImportError: + raise ImportError( + "azure-storage-blob is required for Azure Blob writing. " + "Install it via `pip install railtracks[azure-blob]` or `uv add railtracks[azure-blob]`." + ) + + if credential is None: + try: + from azure.identity import DefaultAzureCredential + + credential = DefaultAzureCredential() + except ImportError: + raise ImportError( + "azure-identity is required for default credential resolution. " + "Install it via `pip install railtracks[azure-blob]` or `uv add railtracks[azure-blob]`." 
def write(self, chunks: list[Chunk], prefix: Optional[str] = None) -> list[str]:
    """Upload each chunk as its own blob and report where it landed.

    The blob name for a chunk comes from ``_derive_key`` (using the
    configured ``key_fn`` and the given ``prefix``); the upload itself is
    performed by :meth:`write_key`.

    Args:
        chunks: Chunk objects to persist, processed in order.
        prefix: Optional blob name prefix handed to the key derivation.

    Returns:
        list[str]: The value returned by :meth:`write_key` for every chunk,
        in the same order as *chunks*.
    """
    return [
        self.write_key(
            self._derive_key(chunk, prefix, self._key_fn),
            chunk.content,
        )
        for chunk in chunks
    ]
class BaseStorageWriter(ABC):
    """Common contract for writers that persist text to remote storage.

    A concrete writer implements :meth:`write` (for :class:`Chunk` objects)
    and :meth:`write_key` (for raw strings at an explicit key). Both report
    the URI or key of what was written. Thread-backed async variants and the
    shared key-derivation rule are supplied here.
    """

    @abstractmethod
    def write(self, chunks: list[Chunk], prefix: Optional[str] = None) -> list[str]:
        """Persist a batch of chunks.

        Args:
            chunks: Chunk objects to persist.
            prefix: Optional prefix prepended to auto-derived object keys.
                Ignored by providers where it has no meaning (e.g. SQL).

        Returns:
            list[str]: URIs or keys of the objects written, one per chunk.
        """

    @abstractmethod
    def write_key(self, key: str, content: str) -> str:
        """Persist raw text under a caller-chosen key.

        Args:
            key: The storage key, path, or ID (relative to bucket/container/table).
            content: Text content to write.

        Returns:
            str: The full URI or key of the written object.
        """

    async def awrite(
        self, chunks: list[Chunk], prefix: Optional[str] = None
    ) -> list[str]:
        """Run :meth:`write` through ``asyncio.to_thread`` and await its result."""
        written = await asyncio.to_thread(self.write, chunks, prefix)
        return written

    async def awrite_key(self, key: str, content: str) -> str:
        """Run :meth:`write_key` through ``asyncio.to_thread`` and await its result."""
        written = await asyncio.to_thread(self.write_key, key, content)
        return written

    # ------------------------------------------------------------------
    # Protected helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _derive_key(
        chunk: Chunk,
        prefix: Optional[str],
        key_fn: Optional[Callable[[Chunk], str]],
    ) -> str:
        """Pick the storage key for *chunk*.

        A supplied ``key_fn`` always wins. Otherwise the first truthy value
        among ``chunk.id``, ``chunk.document`` and a freshly generated UUID4
        string is used. A truthy ``prefix`` is placed in front of the result.
        """
        if key_fn is not None:
            base = key_fn(chunk)
        else:
            # ``or`` walks the fallbacks in priority order and stops at the
            # first truthy one, so the UUID is generated only when needed.
            base = chunk.id or chunk.document or str(uuid.uuid4())

        if prefix:
            return f"{prefix}{base}"
        return base
+ + Authentication uses **Application Default Credentials (ADC)** by default, + which resolves credentials from the following sources (in order): + + 1. ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable + 2. ``gcloud auth application-default login`` (developer workstation) + 3. Workload Identity / attached service account (GCE, GKE, Cloud Run ...) + + Pass explicit ``credentials`` to override ADC. + + Requires the ``gcp`` extra: ``pip install railtracks[gcp]``. + + Args: + bucket_name: GCS bucket name. + project: Google Cloud project ID. Inferred from ADC when ``None``. + credentials: Explicit Google credential object. Defaults to ADC. + encoding: Text encoding used when converting content to bytes. Defaults to ``"utf-8"``. + content_type: MIME type set on the uploaded object. Defaults to + ``"text/plain; charset=utf-8"``. + key_fn: Optional callable ``(chunk) -> str`` that derives a storage key + from a :class:`Chunk`. When ``None`` the name falls back to + ``chunk.id``, then ``chunk.document``, then a random UUID. + + Raises: + ImportError: If ``google-cloud-storage`` is not installed. + + Example:: + + writer = GCSWriter("my-bucket", project="my-gcp-project") + + # Write a list of Chunk objects + uris = writer.write(chunks, prefix="generated/") + + # Write raw text at an explicit object name + uri = writer.write_key("reports/summary.txt", "Today's summary ...") + + # Async usage + uris = await writer.awrite(chunks, prefix="generated/") + """ + + def __init__( + self, + bucket_name: str, + *, + project: Optional[str] = None, + credentials: Optional[Any] = None, + encoding: str = "utf-8", + content_type: str = "text/plain; charset=utf-8", + key_fn: Optional[Callable[[Chunk], str]] = None, + ) -> None: + try: + from google.cloud import storage # type: ignore[import] + except ImportError: + raise ImportError( + "google-cloud-storage is required for GCS writing. " + "Install it via `pip install railtracks[gcp]` or `uv add railtracks[gcp]`." 
def write_key(self, key: str, content: str) -> str:
    """Upload raw text to the bucket under the object name *key*.

    The text is encoded with the writer's configured encoding and uploaded
    with the writer's configured content type.

    Args:
        key: GCS object name (path within the bucket).
        content: Text content to upload.

    Returns:
        str: ``gs://bucket/name`` URI of the written object.
    """
    target = self._client.bucket(self._bucket_name).blob(key)
    payload = content.encode(self._encoding)
    target.upload_from_string(payload, content_type=self._content_type)
    return f"gs://{self._bucket_name}/{key}"
+ + Credentials follow boto3's standard resolution chain: environment variables + (``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY``), ``~/.aws/credentials``, + IAM instance profiles, etc. Pass explicit credentials to override. + + Requires the ``aws`` extra: ``pip install railtracks[aws]``. + + Args: + bucket: S3 bucket name. + region_name: AWS region (optional). + aws_access_key_id: Explicit access key ID (optional). + aws_secret_access_key: Explicit secret access key (optional). + aws_session_token: Explicit session token for temporary credentials (optional). + endpoint_url: Custom endpoint URL for S3-compatible services such as MinIO (optional). + encoding: Text encoding used when converting content to bytes. Defaults to ``"utf-8"``. + content_type: MIME type set on the uploaded object. Defaults to + ``"text/plain; charset=utf-8"``. + key_fn: Optional callable ``(chunk) -> str`` that derives a storage key + from a :class:`Chunk`. When ``None`` the key falls back to + ``chunk.id``, then ``chunk.document``, then a random UUID. + + Raises: + ImportError: If ``boto3`` is not installed. + + Example:: + + writer = S3Writer("my-bucket", region_name="us-east-1") + + # Write a list of Chunk objects + uris = writer.write(chunks, prefix="generated/") + + # Write raw text at an explicit key + uri = writer.write_key("reports/summary.txt", "Today's summary ...") + + # Async usage + uris = await writer.awrite(chunks, prefix="generated/") + """ + + def __init__( + self, + bucket: str, + *, + region_name: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + endpoint_url: Optional[str] = None, + encoding: str = "utf-8", + content_type: str = "text/plain; charset=utf-8", + key_fn: Optional[Callable[[Chunk], str]] = None, + ) -> None: + try: + import boto3 + except ImportError: + raise ImportError( + "boto3 is required for S3 writing. 
" + "Install it via `pip install railtracks[aws]` or `uv add railtracks[aws]`." + ) + + self._bucket = bucket + self._encoding = encoding + self._content_type = content_type + self._key_fn = key_fn + self._client = boto3.client( + "s3", + region_name=region_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + endpoint_url=endpoint_url, + ) + + def __repr__(self) -> str: + return f"S3Writer(bucket={self._bucket!r})" + + def write(self, chunks: list[Chunk], prefix: Optional[str] = None) -> list[str]: + """Write chunks to S3, one object per chunk. + + Args: + chunks: Chunk objects to persist. + prefix: Optional S3 key prefix prepended to each derived key. + + Returns: + list[str]: ``s3://bucket/key`` URIs of every object written. + """ + uris: list[str] = [] + for chunk in chunks: + key = self._derive_key(chunk, prefix, self._key_fn) + uri = self.write_key(key, chunk.content) + uris.append(uri) + return uris + + def write_key(self, key: str, content: str) -> str: + """Write raw text to S3 at an explicit key. + + Args: + key: S3 object key (path within the bucket). + content: Text content to upload. + + Returns: + str: ``s3://bucket/key`` URI of the written object. 
+ """ + self._client.put_object( + Bucket=self._bucket, + Key=key, + Body=content.encode(self._encoding), + ContentType=self._content_type, + ) + return f"s3://{self._bucket}/{key}" diff --git a/packages/railtracks/src/railtracks/writers/sql.py b/packages/railtracks/src/railtracks/writers/sql.py new file mode 100644 index 000000000..1838cd251 --- /dev/null +++ b/packages/railtracks/src/railtracks/writers/sql.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +import re +import warnings +from typing import TYPE_CHECKING, Any, Literal, Optional + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .base import BaseStorageWriter + +if TYPE_CHECKING: + from sqlalchemy import Engine + +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*(\.[A-Za-z_][A-Za-z0-9_$]*)?$") + + +def _validate_identifier(value: str, label: str) -> None: + """Raise ValueError if *value* is not a safe SQL identifier.""" + if not _IDENT_RE.match(value): + raise ValueError( + f"Invalid SQL identifier for {label!r}: {value!r}. " + "Identifiers must start with a letter or underscore and may only " + "contain letters, digits, underscores, and dollar signs. " + "Use 'schema.table' notation for schema-qualified names. " + "Never pass user-controlled strings as identifiers." + ) + + +class SQLWriter(BaseStorageWriter): + """Document writer for relational databases via SQLAlchemy. + + Converts :class:`~railtracks.vector_stores.chunking.base_chunker.Chunk` + objects into table rows and writes them to any SQLAlchemy-compatible + database (PostgreSQL, Supabase, MySQL, SQLite, ...). + + In ``"upsert"`` mode (default) an existing row with the same id is replaced + before the new row is inserted, giving safe idempotent writes. In + ``"insert"`` mode rows are appended without any conflict handling. + + .. note:: **All-or-nothing writes** + + All chunks in a single :meth:`write` call are persisted inside one + database transaction. 
If any individual row fails (constraint + violation, type mismatch, etc.) the **entire batch is rolled back** and + no rows are written. The exception from the failing row is re-raised + so you can inspect which chunk caused the error. To write chunks + individually and tolerate partial failures, call :meth:`write_key` in + a loop with your own error handling. + + Requires the ``sql`` extra: ``pip install railtracks[sql]``. + + Args: + connection_string: SQLAlchemy database URL, e.g. + ``"postgresql+psycopg2://user:pass@host/db"`` or + ``"sqlite:///my.db"``. + table: Target table name. The table must already exist. + + .. note:: + The table name is validated against a strict allowlist at + construction time. Never pass user-controlled strings here. + + content_column: Column that receives :attr:`Chunk.content`. + id_column: Column that receives :attr:`Chunk.id`. Also used as the + conflict key for upserts. When ``None`` no id is written and + :meth:`write_key` raises :exc:`ValueError`. + document_column: Column that receives :attr:`Chunk.document`. + Ignored when ``None``. + metadata_columns: Metadata keys to persist as individual columns. + Each key in this list is read from ``chunk.metadata`` and written + to the column of the same name. When ``None`` no metadata columns + are written. + mode: ``"upsert"`` (default) deletes any existing row matching + ``id_column`` before inserting; ``"insert"`` performs a plain + ``INSERT``. + engine: An existing ``sqlalchemy.Engine`` instance. When provided, + ``connection_string`` and ``engine_kwargs`` are ignored. The + caller is responsible for disposing this engine; :meth:`close` + will not touch it. + engine_kwargs: Extra keyword arguments forwarded to + ``sqlalchemy.create_engine()``. + + Raises: + ImportError: If ``sqlalchemy`` is not installed. + ValueError: If an identifier argument contains unsafe characters, or + if :meth:`write_key` is called without ``id_column``. 
def __init__(
    self,
    connection_string: str,
    table: str,
    content_column: str,
    *,
    id_column: Optional[str] = None,
    document_column: Optional[str] = None,
    metadata_columns: Optional[list[str]] = None,
    mode: Literal["insert", "upsert"] = "upsert",
    engine: Optional[Engine] = None,
    engine_kwargs: Optional[dict[str, Any]] = None,
) -> None:
    """Validate the configuration and bind the writer to a SQLAlchemy engine.

    Args:
        connection_string: SQLAlchemy database URL. Only used when ``engine``
            is not supplied.
        table: Target table name; validated as a SQL identifier.
        content_column: Column that receives the chunk content; validated.
        id_column: Column that receives the chunk id; validated when given.
        document_column: Column that receives the chunk document; validated
            when given.
        metadata_columns: Metadata keys written to same-named columns; each
            entry is validated. The list is copied so that later mutation by
            the caller cannot add names that were never validated.
        mode: ``"upsert"`` or ``"insert"``.
        engine: Existing engine to use as-is. The writer records that it does
            not own it, so :meth:`close` leaves it alone.
        engine_kwargs: Extra keyword arguments forwarded to
            ``sqlalchemy.create_engine()`` when the engine is created here.

    Raises:
        ImportError: If ``sqlalchemy`` is not installed.
        ValueError: If an identifier is unsafe or ``mode`` is not one of
            ``"insert"`` / ``"upsert"``.
    """
    try:
        import sqlalchemy as sa
    except ImportError as err:
        raise ImportError(
            "sqlalchemy is required for SQL writing. "
            "Install it via `pip install railtracks[sql]` or `uv add railtracks[sql]`."
        ) from err

    # An unknown mode would otherwise silently behave like "insert", because
    # the write path only special-cases the exact string "upsert".
    if mode not in ("insert", "upsert"):
        raise ValueError(
            f"Invalid mode {mode!r}: expected 'insert' or 'upsert'."
        )

    # Validate all structural identifiers at construction time; they are
    # interpolated directly into SQL text later on.
    _validate_identifier(table, "table")
    _validate_identifier(content_column, "content_column")
    if id_column is not None:
        _validate_identifier(id_column, "id_column")
    if document_column is not None:
        _validate_identifier(document_column, "document_column")
    if metadata_columns is not None:
        # Snapshot first, then validate the snapshot that is actually stored.
        metadata_columns = list(metadata_columns)
        for col in metadata_columns:
            _validate_identifier(col, f"metadata_columns[{col!r}]")

    self._table = table
    self._content_column = content_column
    self._id_column = id_column
    self._document_column = document_column
    self._metadata_columns = metadata_columns
    self._mode = mode

    # Only engines created here are disposed by close().
    self._owns_engine = engine is None
    if engine is not None:
        self._engine = engine
    else:
        self._engine = sa.create_engine(
            connection_string,
            **(engine_kwargs or {}),
        )
def _write_row(self, conn: Any, row: dict[str, Any]) -> None:
    """Persist one row on *conn*, replacing an existing row in upsert mode.

    In ``"upsert"`` mode, when an id column is configured and present in
    *row*, any existing row with that id is deleted first. The row is then
    inserted. Both statements run on the connection supplied by the caller,
    so they share the caller's transaction.

    Args:
        conn: Open SQLAlchemy connection to execute on.
        row: Mapping of column name to value; column names are the
            identifiers validated at construction time.
    """
    import sqlalchemy as sa

    if self._mode == "upsert" and self._id_column and self._id_column in row:
        conn.execute(
            sa.text(
                f"DELETE FROM {self._table} WHERE {self._id_column} = :_id"  # noqa: S608
            ),
            {"_id": row[self._id_column]},
        )

    cols = list(row)
    # Bind names are positional (p0, p1, ...) rather than the column names:
    # valid identifiers may contain ``$`` or ``.``, which are not legal inside
    # a ``:name`` bind marker and would break parameter substitution.
    bind_names = [f"p{i}" for i in range(len(cols))]
    col_str = ", ".join(cols)
    param_str = ", ".join(f":{name}" for name in bind_names)
    conn.execute(
        sa.text(
            f"INSERT INTO {self._table} ({col_str}) VALUES ({param_str})"  # noqa: S608
        ),
        dict(zip(bind_names, row.values())),
    )
def write_key(self, key: str, content: str) -> str:
    """Store *content* as one row whose id column holds *key*.

    The row contains only the id column and the content column and is
    written through ``_write_row`` inside its own ``engine.begin()``
    transaction.

    Args:
        key: Value to store in ``id_column``.
        content: Text stored in ``content_column``.

    Returns:
        str: ``sql://table/key`` URI of the written row.

    Raises:
        ValueError: If the writer was created without an ``id_column``.
    """
    id_col = self._id_column
    if id_col:
        record: dict[str, Any] = {id_col: key}
        record[self._content_column] = content
        with self._engine.begin() as txn:
            self._write_row(txn, record)
        return f"sql://{self._table}/{key}"

    raise ValueError(
        "write_key() requires an 'id_column' to be set on the writer."
    )
def make_container_client(blobs: dict[str, str], encoding: str = "utf-8") -> MagicMock:
    """Build a stand-in ContainerClient backed by *blobs* (name → text).

    ``list_blobs()`` returns one entry per key of *blobs*, each exposing the
    key as ``.name``. ``get_blob_client(name).download_blob().readall()``
    returns the stored text encoded with *encoding*; a name that is not in
    *blobs* yields empty bytes.
    """

    def _listing_entry(blob_name: str) -> MagicMock:
        # ``name`` is assigned after construction: MagicMock(name=...) would
        # set the mock's own repr name rather than a ``.name`` attribute.
        entry = MagicMock()
        entry.name = blob_name
        return entry

    def _client_for(blob_name: str) -> MagicMock:
        stub = MagicMock()
        payload = blobs.get(blob_name, "").encode(encoding)
        stub.download_blob.return_value.readall.return_value = payload
        return stub

    fake_container = MagicMock()
    fake_container.list_blobs.return_value = [_listing_entry(n) for n in blobs]
    fake_container.get_blob_client.side_effect = _client_for
    return fake_container
def make_sqlite_engine(rows: list[dict[str, Any]], table: str = "documents") -> Any:
    """Create an in-memory SQLite engine pre-populated with *rows*.

    Uses ``StaticPool`` so the same underlying connection is reused across all
    threads — necessary for ``asyncio.to_thread`` calls in async loader tests.
    The table is created dynamically from the keys of the first row, and every
    column is declared as TEXT for simplicity.

    When *rows* is empty there are no keys to derive columns from, so the
    engine is returned without creating any table.

    Args:
        rows: Row dicts to insert; the first row's keys define the columns.
        table: Name of the table to create and fill.

    Returns:
        The SQLAlchemy engine; pass it directly to SQLLoader via ``engine=``.
    """
    import sqlalchemy as sa
    from sqlalchemy.pool import StaticPool

    engine = sa.create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    if not rows:
        return engine

    columns = list(rows[0])
    col_defs = ", ".join(f"{c} TEXT" for c in columns)
    # The INSERT text depends only on the column list, so build it once
    # instead of once per row.
    placeholders = ", ".join(f":{c}" for c in columns)
    insert_stmt = sa.text(
        f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"  # noqa: S608
    )
    with engine.begin() as conn:
        conn.execute(sa.text(f"CREATE TABLE IF NOT EXISTS {table} ({col_defs})"))  # noqa: S608
        for row in rows:
            conn.execute(insert_stmt, row)
    return engine
{"azure.storage.blob": None}): + with pytest.raises(ImportError, match="railtracks\\[azure-blob\\]"): + AzureBlobLoader(ACCOUNT_URL, CONTAINER) + + def test_trailing_slash_stripped_from_account_url(self) -> None: + container_client = make_container_client({}) + with patch_azure(container_client): + loader = AzureBlobLoader(ACCOUNT_URL + "/", CONTAINER) + assert loader._account_url == ACCOUNT_URL + + def test_explicit_credential_bypasses_default_azure_credential(self) -> None: + container_client = make_container_client({}) + explicit_cred = MagicMock() + + with patch_azure(container_client): + import azure.identity as ai # type: ignore[import] + + loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER, credential=explicit_cred) + ai.DefaultAzureCredential.assert_not_called() + + +class TestAzureBlobLoaderLoadKeys: + def test_returns_chunk_for_each_blob(self) -> None: + blobs = {"doc1.txt": "hello", "doc2.txt": "world"} + container_client = make_container_client(blobs) + + with patch_azure(container_client): + loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER) + + chunks = loader.load_keys(list(blobs.keys())) + + assert len(chunks) == 2 + assert all(isinstance(c, Chunk) for c in chunks) + + def test_chunk_content_matches_blob_body(self) -> None: + container_client = make_container_client({"readme.txt": "important content"}) + + with patch_azure(container_client): + loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER) + + (chunk,) = loader.load_keys(["readme.txt"]) + assert chunk.content == "important content" + + def test_chunk_document_is_blob_name(self) -> None: + container_client = make_container_client({"path/to/file.txt": "data"}) + + with patch_azure(container_client): + loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER) + + (chunk,) = loader.load_keys(["path/to/file.txt"]) + assert chunk.document == "path/to/file.txt" + + def test_chunk_metadata_source_url(self) -> None: + container_client = make_container_client({"file.txt": "data"}) + + with patch_azure(container_client): 
+ loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER) + + (chunk,) = loader.load_keys(["file.txt"]) + assert chunk.metadata["source"] == f"{ACCOUNT_URL}/{CONTAINER}/file.txt" + assert chunk.metadata["container"] == CONTAINER + assert chunk.metadata["blob_name"] == "file.txt" + assert chunk.metadata["account_url"] == ACCOUNT_URL + + def test_custom_encoding_is_used(self) -> None: + text = "café" + container_client = make_container_client({"file.txt": text}, encoding="latin-1") + + with patch_azure(container_client): + loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER, encoding="latin-1") + + (chunk,) = loader.load_keys(["file.txt"]) + assert chunk.content == text + + def test_empty_keys_returns_empty_list(self) -> None: + with patch_azure(make_container_client({})): + loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER) + + assert loader.load_keys([]) == [] + + +class TestAzureBlobLoaderLoad: + def test_load_returns_all_blobs(self) -> None: + blobs = {"a.txt": "aaa", "b.txt": "bbb"} + container_client = make_container_client(blobs) + + with patch_azure(container_client): + loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER) + + chunks = loader.load() + + assert len(chunks) == 2 + assert {c.content for c in chunks} == {"aaa", "bbb"} + + def test_load_passes_prefix_to_list_blobs(self) -> None: + container_client = make_container_client({"docs/a.txt": "aaa"}) + + with patch_azure(container_client): + loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER) + + loader.load(prefix="docs/") + + container_client.list_blobs.assert_called_once_with(name_starts_with="docs/") + + def test_load_without_prefix_omits_name_starts_with(self) -> None: + container_client = make_container_client({"a.txt": "aaa"}) + + with patch_azure(container_client): + loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER) + + loader.load() + + container_client.list_blobs.assert_called_once_with() + + def test_empty_container_returns_empty_list(self) -> None: + container_client = make_container_client({}) + 
@pytest.mark.asyncio
class TestAzureBlobLoaderAsync:
    """The async entry points must agree with their synchronous twins."""

    async def test_aload_returns_same_chunks_as_load(self) -> None:
        fake_container = make_container_client({"file.txt": "hello"})

        # Only construction needs the patched Azure modules; the loader keeps
        # the fake container client afterwards.
        with patch_azure(fake_container):
            blob_loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER)

        expected = blob_loader.load()
        actual = await blob_loader.aload()

        assert len(actual) == len(expected)
        assert actual[0].content == expected[0].content

    async def test_aload_keys_returns_same_chunks_as_load_keys(self) -> None:
        fake_container = make_container_client({"file.txt": "hello"})

        with patch_azure(fake_container):
            blob_loader = AzureBlobLoader(ACCOUNT_URL, CONTAINER)

        wanted = ["file.txt"]
        expected = blob_loader.load_keys(wanted)
        actual = await blob_loader.aload_keys(["file.txt"])

        assert len(actual) == len(expected)
        assert actual[0].content == expected[0].content
{"google.cloud.storage": None, "google.cloud": None}): + with pytest.raises(ImportError, match="railtracks\\[gcp\\]"): + GCSLoader(BUCKET) + + def test_explicit_credentials_forwarded_to_client(self) -> None: + from unittest.mock import MagicMock + + fake_creds = MagicMock() + gcs_client = make_gcs_client({}) + + with patch_gcs(gcs_client) as mock_storage: + GCSLoader(BUCKET, credentials=fake_creds) + mock_storage.Client.assert_called_once_with( + project=None, credentials=fake_creds + ) + + def test_project_forwarded_to_client(self) -> None: + gcs_client = make_gcs_client({}) + + with patch_gcs(gcs_client) as mock_storage: + GCSLoader(BUCKET, project="my-project") + mock_storage.Client.assert_called_once_with( + project="my-project", credentials=None + ) + + +class TestGCSLoaderLoadKeys: + def test_returns_chunk_for_each_key(self) -> None: + objects = {"doc1.txt": "hello", "doc2.txt": "world"} + gcs_client = make_gcs_client(objects) + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET) + + chunks = loader.load_keys(list(objects.keys())) + + assert len(chunks) == 2 + assert all(isinstance(c, Chunk) for c in chunks) + + def test_chunk_content_matches_object_body(self) -> None: + gcs_client = make_gcs_client({"notes.txt": "some text content"}) + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET) + + (chunk,) = loader.load_keys(["notes.txt"]) + assert chunk.content == "some text content" + + def test_chunk_document_is_object_name(self) -> None: + gcs_client = make_gcs_client({"path/to/file.txt": "data"}) + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET) + + (chunk,) = loader.load_keys(["path/to/file.txt"]) + assert chunk.document == "path/to/file.txt" + + def test_chunk_metadata_source_url(self) -> None: + gcs_client = make_gcs_client({"file.txt": "data"}) + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET) + + (chunk,) = loader.load_keys(["file.txt"]) + assert chunk.metadata["source"] == f"gs://{BUCKET}/file.txt" + assert 
chunk.metadata["bucket"] == BUCKET + assert chunk.metadata["name"] == "file.txt" + + def test_custom_encoding_is_used(self) -> None: + text = "café" + gcs_client = make_gcs_client({"file.txt": text}, encoding="latin-1") + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET, encoding="latin-1") + + (chunk,) = loader.load_keys(["file.txt"]) + assert chunk.content == text + + def test_empty_keys_returns_empty_list(self) -> None: + with patch_gcs(make_gcs_client({})): + loader = GCSLoader(BUCKET) + + assert loader.load_keys([]) == [] + + +class TestGCSLoaderLoad: + def test_load_returns_all_objects(self) -> None: + objects = {"a.txt": "aaa", "b.txt": "bbb", "c.txt": "ccc"} + gcs_client = make_gcs_client(objects) + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET) + + chunks = loader.load() + + assert len(chunks) == 3 + assert {c.content for c in chunks} == {"aaa", "bbb", "ccc"} + + def test_load_passes_prefix_to_list_blobs(self) -> None: + gcs_client = make_gcs_client({"docs/a.txt": "aaa"}) + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET) + + loader.load(prefix="docs/") + + gcs_client.list_blobs.assert_called_once_with(BUCKET, prefix="docs/") + + def test_load_without_prefix_omits_prefix_kwarg(self) -> None: + gcs_client = make_gcs_client({"a.txt": "aaa"}) + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET) + + loader.load() + + gcs_client.list_blobs.assert_called_once_with(BUCKET) + + def test_empty_bucket_returns_empty_list(self) -> None: + gcs_client = make_gcs_client({}) + gcs_client.list_blobs.return_value = [] + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET) + + assert loader.load() == [] + + +@pytest.mark.asyncio +class TestGCSLoaderAsync: + async def test_aload_returns_same_chunks_as_load(self) -> None: + gcs_client = make_gcs_client({"file.txt": "hello"}) + + with patch_gcs(gcs_client): + loader = GCSLoader(BUCKET) + + sync_chunks = loader.load() + async_chunks = await loader.aload() + + assert len(async_chunks) 
class TestS3LoaderInit:
    """Constructor behaviour when the optional boto3 dependency is absent."""

    def test_raises_import_error_when_boto3_missing(self) -> None:
        # A ``None`` entry in sys.modules makes ``import boto3`` fail.
        hidden = {"boto3": None}
        with patch.dict(sys.modules, hidden), pytest.raises(
            ImportError, match="boto3"
        ):
            S3Loader("bucket")

    def test_import_error_message_mentions_extra(self) -> None:
        # The message should point the user at the installable extra.
        hidden = {"boto3": None}
        with patch.dict(sys.modules, hidden), pytest.raises(
            ImportError, match="railtracks\\[aws\\]"
        ):
            S3Loader("bucket")
loader = S3Loader("my-bucket") + + (chunk,) = loader.load_keys(["notes.txt"]) + assert chunk.content == "some text content" + + def test_chunk_document_is_object_key(self) -> None: + mock_client = make_s3_client({"path/to/file.txt": "data"}) + + with patch("boto3.client", return_value=mock_client): + loader = S3Loader("my-bucket") + + (chunk,) = loader.load_keys(["path/to/file.txt"]) + assert chunk.document == "path/to/file.txt" + + def test_chunk_metadata_contains_source_url(self) -> None: + mock_client = make_s3_client({"file.txt": "data"}) + + with patch("boto3.client", return_value=mock_client): + loader = S3Loader("my-bucket") + + (chunk,) = loader.load_keys(["file.txt"]) + assert chunk.metadata["source"] == "s3://my-bucket/file.txt" + assert chunk.metadata["bucket"] == "my-bucket" + assert chunk.metadata["key"] == "file.txt" + + def test_custom_encoding_is_used(self) -> None: + text = "café" + mock_client = make_s3_client({"file.txt": text}, encoding="latin-1") + + with patch("boto3.client", return_value=mock_client): + loader = S3Loader("my-bucket", encoding="latin-1") + + (chunk,) = loader.load_keys(["file.txt"]) + assert chunk.content == text + + def test_empty_keys_returns_empty_list(self) -> None: + mock_client = make_s3_client({}) + + with patch("boto3.client", return_value=mock_client): + loader = S3Loader("my-bucket") + + assert loader.load_keys([]) == [] + + +class TestS3LoaderLoad: + def test_load_without_prefix_returns_all_objects(self) -> None: + objects = {"a.txt": "aaa", "b.txt": "bbb", "c.txt": "ccc"} + mock_client = make_s3_client(objects) + + with patch("boto3.client", return_value=mock_client): + loader = S3Loader("my-bucket") + + chunks = loader.load() + + assert len(chunks) == 3 + assert {c.content for c in chunks} == {"aaa", "bbb", "ccc"} + + def test_load_passes_prefix_to_paginator(self) -> None: + mock_client = make_s3_client({"docs/a.txt": "aaa", "docs/b.txt": "bbb"}) + + with patch("boto3.client", return_value=mock_client): + loader = 
S3Loader("my-bucket") + + loader.load(prefix="docs/") + + paginator = mock_client.get_paginator.return_value + paginator.paginate.assert_called_once_with(Bucket="my-bucket", Prefix="docs/") + + def test_load_without_prefix_omits_prefix_kwarg(self) -> None: + mock_client = make_s3_client({}) + + with patch("boto3.client", return_value=mock_client): + loader = S3Loader("my-bucket") + + loader.load() + + call_kwargs = mock_client.get_paginator.return_value.paginate.call_args.kwargs + assert "Prefix" not in call_kwargs + + def test_empty_bucket_returns_empty_list(self) -> None: + from unittest.mock import MagicMock + + client = MagicMock() + paginator = MagicMock() + paginator.paginate.return_value = [{}] # no "Contents" key + client.get_paginator.return_value = paginator + + with patch("boto3.client", return_value=client): + loader = S3Loader("empty-bucket") + + assert loader.load() == [] + + +@pytest.mark.asyncio +class TestS3LoaderAsync: + async def test_aload_returns_same_chunks_as_load(self) -> None: + mock_client = make_s3_client({"file.txt": "hello"}) + + with patch("boto3.client", return_value=mock_client): + loader = S3Loader("my-bucket") + + sync_chunks = loader.load() + async_chunks = await loader.aload() + + assert len(async_chunks) == len(sync_chunks) + assert async_chunks[0].content == sync_chunks[0].content + + async def test_aload_keys_returns_same_chunks_as_load_keys(self) -> None: + mock_client = make_s3_client({"file.txt": "hello"}) + + with patch("boto3.client", return_value=mock_client): + loader = S3Loader("my-bucket") + + sync_chunks = loader.load_keys(["file.txt"]) + async_chunks = await loader.aload_keys(["file.txt"]) + + assert len(async_chunks) == len(sync_chunks) + assert async_chunks[0].content == sync_chunks[0].content diff --git a/packages/railtracks/tests/unit_tests/loaders/test_sql_loader.py b/packages/railtracks/tests/unit_tests/loaders/test_sql_loader.py new file mode 100644 index 000000000..01d6d0f0c --- /dev/null +++ 
b/packages/railtracks/tests/unit_tests/loaders/test_sql_loader.py @@ -0,0 +1,219 @@ +""" +SQLLoader tests use a real in-memory SQLite database (no mocking). +SQLite is part of Python's stdlib so these tests run with zero extra +cloud credentials or installed services. +""" + +import sys + +import pytest + +from railtracks.loaders.sql import SQLLoader +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .conftest import make_sqlite_engine + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +ROWS = [ + {"id": "doc-1", "title": "First Doc", "body": "Content of first document."}, + {"id": "doc-2", "title": "Second Doc", "body": "Content of second document."}, + {"id": "doc-3", "title": "Third Doc", "body": "Content of third document."}, +] + + +@pytest.fixture() +def sqlite_engine(): + """Spin up an in-memory SQLite engine pre-populated with three rows. + + We pass the engine object directly into SQLLoader (via the `engine` + kwarg) so that both share the same in-memory connection pool — SQLite + in-memory databases are connection-scoped and would be empty if SQLLoader + opened a new connection via a URL. 
+ """ + return make_sqlite_engine(ROWS, table="documents") + + +@pytest.fixture() +def loader(sqlite_engine) -> SQLLoader: + return SQLLoader( + "", # unused when engine= is provided + table_or_query="documents", + content_column="body", + id_column="id", + document_column="title", + metadata_columns=["title"], + engine=sqlite_engine, + ) + + +# --------------------------------------------------------------------------- +# Init +# --------------------------------------------------------------------------- + + +class TestSQLLoaderInit: + def test_raises_import_error_when_sqlalchemy_missing(self) -> None: + with pytest.raises(ImportError, match="sqlalchemy"): + # Patch sqlalchemy out of the import system then force a re-init + original = sys.modules.pop("sqlalchemy", None) + try: + from unittest.mock import patch + with patch.dict(sys.modules, {"sqlalchemy": None}): + SQLLoader("sqlite:///:memory:", "docs", "body") + finally: + if original is not None: + sys.modules["sqlalchemy"] = original + + def test_import_error_message_mentions_extra(self) -> None: + from unittest.mock import patch + with patch.dict(sys.modules, {"sqlalchemy": None}): + with pytest.raises(ImportError, match="railtracks\\[sql\\]"): + SQLLoader("sqlite:///:memory:", "docs", "body") + + +# --------------------------------------------------------------------------- +# load() — table name +# --------------------------------------------------------------------------- + + +class TestSQLLoaderLoad: + def test_returns_all_rows(self, loader: SQLLoader) -> None: + chunks = loader.load() + assert len(chunks) == 3 + + def test_all_chunks_are_chunk_instances(self, loader: SQLLoader) -> None: + assert all(isinstance(c, Chunk) for c in loader.load()) + + def test_content_column_maps_to_chunk_content(self, loader: SQLLoader) -> None: + contents = {c.content for c in loader.load()} + assert contents == {r["body"] for r in ROWS} + + def test_id_column_maps_to_chunk_id(self, loader: SQLLoader) -> None: + ids = 
{c.id for c in loader.load()} + assert ids == {"doc-1", "doc-2", "doc-3"} + + def test_document_column_maps_to_chunk_document(self, loader: SQLLoader) -> None: + docs = {c.document for c in loader.load()} + assert docs == {"First Doc", "Second Doc", "Third Doc"} + + def test_metadata_contains_source(self, loader: SQLLoader) -> None: + for chunk in loader.load(): + assert "source" in chunk.metadata + + def test_metadata_columns_included(self, loader: SQLLoader) -> None: + for chunk in loader.load(): + assert "title" in chunk.metadata + + def test_content_column_excluded_from_metadata(self, loader: SQLLoader) -> None: + for chunk in loader.load(): + assert "body" not in chunk.metadata + + def test_empty_table_returns_empty_list(self, sqlite_engine) -> None: + import sqlalchemy as sa + with sqlite_engine.begin() as conn: + conn.execute(sa.text("DELETE FROM documents")) + loader = SQLLoader("", "documents", "body", engine=sqlite_engine) + assert loader.load() == [] + + def test_prefix_argument_is_ignored(self, loader: SQLLoader) -> None: + all_chunks = loader.load() + prefixed_chunks = loader.load(prefix="anything") + assert len(all_chunks) == len(prefixed_chunks) + + +# --------------------------------------------------------------------------- +# load() — raw SQL query +# --------------------------------------------------------------------------- + + +class TestSQLLoaderRawQuery: + def test_raw_select_returns_correct_rows(self, sqlite_engine) -> None: + loader = SQLLoader( + "", + table_or_query="SELECT id, body FROM documents WHERE id = 'doc-1'", + content_column="body", + id_column="id", + engine=sqlite_engine, + ) + chunks = loader.load() + assert len(chunks) == 1 + assert chunks[0].content == "Content of first document." 
+ + def test_raw_query_detected_case_insensitive(self, sqlite_engine) -> None: + loader = SQLLoader( + "", + table_or_query="select id, body from documents", + content_column="body", + engine=sqlite_engine, + ) + chunks = loader.load() + assert len(chunks) == 3 + + +# --------------------------------------------------------------------------- +# load_keys() +# --------------------------------------------------------------------------- + + +class TestSQLLoaderLoadKeys: + def test_returns_only_requested_rows(self, loader: SQLLoader) -> None: + chunks = loader.load_keys(["doc-1", "doc-3"]) + assert len(chunks) == 2 + ids = {c.id for c in chunks} + assert ids == {"doc-1", "doc-3"} + + def test_empty_keys_returns_empty_list(self, loader: SQLLoader) -> None: + assert loader.load_keys([]) == [] + + def test_raises_without_id_column(self, sqlite_engine) -> None: + loader_no_id = SQLLoader("", "documents", "body", engine=sqlite_engine) + with pytest.raises(ValueError, match="id_column"): + loader_no_id.load_keys(["doc-1"]) + + def test_unknown_key_returns_empty_list(self, loader: SQLLoader) -> None: + chunks = loader.load_keys(["does-not-exist"]) + assert chunks == [] + + +# --------------------------------------------------------------------------- +# Auto metadata — no metadata_columns specified +# --------------------------------------------------------------------------- + + +class TestSQLLoaderAutoMetadata: + def test_all_non_content_columns_in_metadata(self, sqlite_engine) -> None: + loader = SQLLoader( + "", + "documents", + content_column="body", + id_column="id", + engine=sqlite_engine, + ) + for chunk in loader.load(): + # id excluded because it's id_column; body excluded because content_column + assert "body" not in chunk.metadata + assert "id" not in chunk.metadata + assert "title" in chunk.metadata + + +# --------------------------------------------------------------------------- +# Async +# 
--------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestSQLLoaderAsync: + async def test_aload_matches_sync_load(self, loader: SQLLoader) -> None: + sync = loader.load() + async_ = await loader.aload() + assert len(async_) == len(sync) + assert {c.content for c in async_} == {c.content for c in sync} + + async def test_aload_keys_matches_sync_load_keys(self, loader: SQLLoader) -> None: + sync = loader.load_keys(["doc-2"]) + async_ = await loader.aload_keys(["doc-2"]) + assert len(async_) == len(sync) + assert async_[0].content == sync[0].content diff --git a/packages/railtracks/tests/unit_tests/writers/__init__.py b/packages/railtracks/tests/unit_tests/writers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/railtracks/tests/unit_tests/writers/conftest.py b/packages/railtracks/tests/unit_tests/writers/conftest.py new file mode 100644 index 000000000..f9cd89e57 --- /dev/null +++ b/packages/railtracks/tests/unit_tests/writers/conftest.py @@ -0,0 +1,198 @@ +"""Shared fixtures and helpers for storage writer unit tests.""" + +import sys +from contextlib import contextmanager +from typing import Any +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# S3 helpers +# --------------------------------------------------------------------------- + + +def make_s3_client() -> MagicMock: + """Return a mock boto3 S3 client that records put_object calls.""" + client = MagicMock() + client.put_object.return_value = {} + return client + + +@contextmanager +def patch_s3(s3_client: MagicMock): + """Patch boto3.client so it returns *s3_client*.""" + import boto3 + from unittest.mock import patch + + with patch.object(boto3, "client", return_value=s3_client): + yield s3_client + + +# --------------------------------------------------------------------------- +# Azure Blob helpers +# 
--------------------------------------------------------------------------- + + +def make_container_client() -> MagicMock: + """Return a mock Azure ContainerClient that records upload_blob calls. + + Exposes ``container_client._written`` (name → bytes) and + ``container_client._content_types`` (name → content_type str) so tests + can verify what was written without chasing side_effect return values. + """ + container_client = MagicMock() + written: dict[str, bytes] = {} + content_types: dict[str, str] = {} + + def _get_blob_client(name: str) -> MagicMock: + blob_client = MagicMock() + + def _upload(data: bytes, *, overwrite: bool = False, content_settings=None) -> None: + written[name] = data + if content_settings is not None: + # ContentSettings is a mock; grab content_type if set + ct = getattr(content_settings, "content_type", None) + if ct: + content_types[name] = ct + + blob_client.upload_blob.side_effect = _upload + return blob_client + + container_client.get_blob_client.side_effect = _get_blob_client + container_client._written = written + container_client._content_types = content_types + return container_client + + +@contextmanager +def patch_azure(container_client: MagicMock): + """Patch azure.storage.blob and azure.identity.""" + from unittest.mock import patch + + mock_storage = MagicMock() + mock_storage.ContainerClient.return_value = container_client + mock_storage.ContentSettings.return_value = MagicMock() + + mock_identity = MagicMock() + + with patch.dict( + sys.modules, + { + "azure.storage.blob": mock_storage, + "azure.storage": MagicMock(blob=mock_storage), + "azure": MagicMock( + storage=MagicMock(blob=mock_storage), + identity=mock_identity, + ), + "azure.identity": mock_identity, + }, + ): + yield mock_storage, mock_identity + + +# --------------------------------------------------------------------------- +# GCS helpers +# --------------------------------------------------------------------------- + + +def make_gcs_client() -> MagicMock: + 
"""Return a mock GCS client that records upload_from_string calls. + + Exposes ``client._written`` (name → bytes) and + ``client._content_types`` (name → content_type str). + """ + client = MagicMock() + + written: dict[str, bytes] = {} + content_types: dict[str, str] = {} + + def _bucket(bucket_name: str) -> MagicMock: + bucket = MagicMock() + + def _blob(name: str) -> MagicMock: + b = MagicMock() + b._name = name + + def _upload(data: bytes, content_type: str = "text/plain") -> None: + written[name] = data + content_types[name] = content_type + + b.upload_from_string.side_effect = _upload + return b + + bucket.blob.side_effect = _blob + return bucket + + client.bucket.side_effect = _bucket + client._written = written + client._content_types = content_types + return client + + +@contextmanager +def patch_gcs(gcs_client: MagicMock): + """Patch google.cloud.storage.""" + from unittest.mock import patch + + mock_storage = MagicMock() + mock_storage.Client.return_value = gcs_client + + with patch.dict( + sys.modules, + { + "google": MagicMock(cloud=MagicMock(storage=mock_storage)), + "google.cloud": MagicMock(storage=mock_storage), + "google.cloud.storage": mock_storage, + }, + ): + yield mock_storage + + +# --------------------------------------------------------------------------- +# SQL helpers (real in-memory SQLite) +# --------------------------------------------------------------------------- + + +def make_sqlite_engine(table: str = "documents", columns: list[str] | None = None) -> Any: + """Create an empty in-memory SQLite engine with *columns*. + + Uses ``StaticPool`` so all threads share the same connection — required for + ``asyncio.to_thread`` tests. *columns* defaults to + ``["id", "title", "body"]``. 
+ """ + import sqlalchemy as sa + from sqlalchemy.pool import StaticPool + + if columns is None: + columns = ["id", "title", "body"] + + engine = sa.create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + col_defs = ", ".join(f"{c} TEXT" for c in columns) + with engine.begin() as conn: + conn.execute(sa.text(f"CREATE TABLE {table} ({col_defs})")) # noqa: S608 + return engine + + +def read_all_rows(engine: Any, table: str = "documents") -> list[dict[str, Any]]: + """Read every row from *table* and return as a list of dicts.""" + import sqlalchemy as sa + + with engine.connect() as conn: + result = conn.execute(sa.text(f"SELECT * FROM {table}")) # noqa: S608 + return [dict(row._mapping) for row in result] + + +@pytest.fixture +def sqlite_engine(): + return make_sqlite_engine() + + +@pytest.fixture +def sqlite_engine_factory(): + return make_sqlite_engine diff --git a/packages/railtracks/tests/unit_tests/writers/test_azure_blob_writer.py b/packages/railtracks/tests/unit_tests/writers/test_azure_blob_writer.py new file mode 100644 index 000000000..e03be01d8 --- /dev/null +++ b/packages/railtracks/tests/unit_tests/writers/test_azure_blob_writer.py @@ -0,0 +1,208 @@ +"""Unit tests for AzureBlobWriter.""" + +import pytest + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .conftest import make_container_client, patch_azure + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_chunk(id_=None, document=None, content="hello world"): + return Chunk(content=content, id=id_, document=document) + + +# --------------------------------------------------------------------------- +# Init +# --------------------------------------------------------------------------- + + +class TestInit: + def test_missing_azure_raises(self, monkeypatch): + import sys + + 
monkeypatch.setitem(sys.modules, "azure.storage.blob", None) + with pytest.raises(ImportError, match="azure-storage-blob"): + from railtracks.writers.azure_blob import AzureBlobWriter + + AzureBlobWriter("https://acc.blob.core.windows.net", "container") + + def test_default_credential_used(self): + cc = make_container_client() + with patch_azure(cc) as (mock_storage, mock_identity): + from railtracks.writers.azure_blob import AzureBlobWriter + + AzureBlobWriter("https://acc.blob.core.windows.net", "container") + + mock_identity.DefaultAzureCredential.assert_called_once() + + def test_explicit_credential_skips_default(self): + cc = make_container_client() + with patch_azure(cc) as (mock_storage, mock_identity): + from railtracks.writers.azure_blob import AzureBlobWriter + + AzureBlobWriter( + "https://acc.blob.core.windows.net", + "container", + credential="my-key", + ) + mock_identity.DefaultAzureCredential.assert_not_called() + + def test_trailing_slash_stripped(self): + cc = make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter( + "https://acc.blob.core.windows.net/", + "container", + ) + assert writer._account_url == "https://acc.blob.core.windows.net" + + +# --------------------------------------------------------------------------- +# write_key +# --------------------------------------------------------------------------- + + +class TestWriteKey: + def test_uploads_blob(self): + cc = make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter("https://acc.blob.core.windows.net", "ctr") + uri = writer.write_key("reports/hello.txt", "hello world") + + assert uri == "https://acc.blob.core.windows.net/ctr/reports/hello.txt" + cc.get_blob_client.assert_called_once_with("reports/hello.txt") + assert cc._written["reports/hello.txt"] == b"hello world" + + def test_custom_encoding(self): + cc = 
make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter( + "https://acc.blob.core.windows.net", "ctr", encoding="latin-1" + ) + writer.write_key("f.txt", "caf\xe9") + + assert cc._written["f.txt"] == "caf\xe9".encode("latin-1") + + +# --------------------------------------------------------------------------- +# write +# --------------------------------------------------------------------------- + + +class TestWrite: + def test_writes_all_chunks(self): + chunks = [make_chunk(id_=f"blob-{i}") for i in range(3)] + cc = make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter("https://acc.blob.core.windows.net", "ctr") + uris = writer.write(chunks) + + assert len(uris) == 3 + assert cc.get_blob_client.call_count == 3 + + def test_key_from_chunk_id(self): + chunk = make_chunk(id_="my-id", document="my-doc") + cc = make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter("https://acc.blob.core.windows.net", "ctr") + uris = writer.write([chunk]) + + cc.get_blob_client.assert_called_once_with("my-id") + assert uris[0] == "https://acc.blob.core.windows.net/ctr/my-id" + + def test_key_uses_auto_uuid_when_no_explicit_id(self): + # Chunk always auto-generates a UUID id in __post_init__ + chunk = make_chunk(id_=None, document="ignored") + cc = make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter("https://acc.blob.core.windows.net", "ctr") + uris = writer.write([chunk]) + + # chunk.id was auto-assigned a UUID by Chunk.__post_init__ + name = cc.get_blob_client.call_args.args[0] + assert len(name) == 36 # UUID4 format + + def test_prefix_prepended(self): + chunk = make_chunk(id_="report.txt") + cc = make_container_client() + with patch_azure(cc): + from 
railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter("https://acc.blob.core.windows.net", "ctr") + uris = writer.write([chunk], prefix="output/") + + cc.get_blob_client.assert_called_once_with("output/report.txt") + assert uris[0] == "https://acc.blob.core.windows.net/ctr/output/report.txt" + + def test_custom_key_fn(self): + chunk = make_chunk(id_="ignore") + cc = make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter( + "https://acc.blob.core.windows.net", + "ctr", + key_fn=lambda c: "overridden.txt", + ) + writer.write([chunk]) + + cc.get_blob_client.assert_called_once_with("overridden.txt") + + def test_empty_chunks(self): + cc = make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter("https://acc.blob.core.windows.net", "ctr") + uris = writer.write([]) + + assert uris == [] + cc.get_blob_client.assert_not_called() + + +# --------------------------------------------------------------------------- +# Async +# --------------------------------------------------------------------------- + + +class TestAsync: + @pytest.mark.asyncio + async def test_awrite(self): + chunks = [make_chunk(id_="z")] + cc = make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter("https://acc.blob.core.windows.net", "ctr") + uris = await writer.awrite(chunks) + + assert uris == ["https://acc.blob.core.windows.net/ctr/z"] + + @pytest.mark.asyncio + async def test_awrite_key(self): + cc = make_container_client() + with patch_azure(cc): + from railtracks.writers.azure_blob import AzureBlobWriter + + writer = AzureBlobWriter("https://acc.blob.core.windows.net", "ctr") + uri = await writer.awrite_key("k.txt", "data") + + assert uri == "https://acc.blob.core.windows.net/ctr/k.txt" diff --git 
a/packages/railtracks/tests/unit_tests/writers/test_gcs_writer.py b/packages/railtracks/tests/unit_tests/writers/test_gcs_writer.py new file mode 100644 index 000000000..6507feba7 --- /dev/null +++ b/packages/railtracks/tests/unit_tests/writers/test_gcs_writer.py @@ -0,0 +1,204 @@ +"""Unit tests for GCSWriter.""" + +import pytest + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .conftest import make_gcs_client, patch_gcs + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_chunk(id_=None, document=None, content="hello world"): + return Chunk(content=content, id=id_, document=document) + + +# --------------------------------------------------------------------------- +# Init +# --------------------------------------------------------------------------- + + +class TestInit: + def test_missing_gcs_raises(self, monkeypatch): + import sys + + monkeypatch.setitem(sys.modules, "google.cloud.storage", None) + monkeypatch.setitem(sys.modules, "google.cloud", None) + with pytest.raises(ImportError, match="google-cloud-storage"): + from railtracks.writers.gcs import GCSWriter + + GCSWriter("my-bucket") + + def test_credentials_forwarded(self): + client = make_gcs_client() + with patch_gcs(client) as mock_storage: + from railtracks.writers.gcs import GCSWriter + + creds = object() + GCSWriter("my-bucket", credentials=creds, project="proj") + + mock_storage.Client.assert_called_once_with( + project="proj", credentials=creds + ) + + +# --------------------------------------------------------------------------- +# write_key +# --------------------------------------------------------------------------- + + +class TestWriteKey: + def test_uploads_object(self): + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket") + uri = 
writer.write_key("docs/hello.txt", "hello world") + + assert uri == "gs://my-bucket/docs/hello.txt" + assert client._written["docs/hello.txt"] == b"hello world" + + def test_custom_encoding(self): + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket", encoding="latin-1") + writer.write_key("f.txt", "caf\xe9") + + assert client._written["f.txt"] == "caf\xe9".encode("latin-1") + + def test_content_type_passed(self): + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket", content_type="application/json") + writer.write_key("data.json", "{}") + + assert client._content_types["data.json"] == "application/json" + + +# --------------------------------------------------------------------------- +# write +# --------------------------------------------------------------------------- + + +class TestWrite: + def test_writes_all_chunks(self): + chunks = [make_chunk(id_=f"doc-{i}") for i in range(3)] + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket") + uris = writer.write(chunks) + + assert len(uris) == 3 + assert len(client._written) == 3 + + def test_key_from_chunk_id(self): + chunk = make_chunk(id_="my-id", document="my-doc") + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket") + uris = writer.write([chunk]) + + assert "my-id" in client._written + assert uris[0] == "gs://my-bucket/my-id" + + def test_key_uses_auto_uuid_when_no_explicit_id(self): + # Chunk always auto-generates a UUID id in __post_init__ + chunk = make_chunk(id_=None, document="ignored") + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket") + writer.write([chunk]) + + key = 
list(client._written.keys())[0] + assert len(key) == 36 # UUID4 format + + def test_prefix_prepended(self): + chunk = make_chunk(id_="report.txt") + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket") + uris = writer.write([chunk], prefix="output/") + + assert "output/report.txt" in client._written + assert uris[0] == "gs://my-bucket/output/report.txt" + + def test_custom_key_fn(self): + chunk = make_chunk(id_="ignore") + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket", key_fn=lambda c: "custom.txt") + writer.write([chunk]) + + assert "custom.txt" in client._written + + def test_empty_chunks(self): + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket") + uris = writer.write([]) + + assert uris == [] + assert client._written == {} + + def test_content_written_correctly(self): + chunk = make_chunk(id_="doc", content="the quick brown fox") + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket") + writer.write([chunk]) + + assert client._written["doc"] == b"the quick brown fox" + + +# --------------------------------------------------------------------------- +# Async +# --------------------------------------------------------------------------- + + +class TestAsync: + @pytest.mark.asyncio + async def test_awrite(self): + chunks = [make_chunk(id_="z", content="data")] + client = make_gcs_client() + with patch_gcs(client): + from railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket") + uris = await writer.awrite(chunks) + + assert uris == ["gs://my-bucket/z"] + assert client._written["z"] == b"data" + + @pytest.mark.asyncio + async def test_awrite_key(self): + client = make_gcs_client() + with patch_gcs(client): + from 
railtracks.writers.gcs import GCSWriter + + writer = GCSWriter("my-bucket") + uri = await writer.awrite_key("k.txt", "content") + + assert uri == "gs://my-bucket/k.txt" + assert client._written["k.txt"] == b"content" diff --git a/packages/railtracks/tests/unit_tests/writers/test_s3_writer.py b/packages/railtracks/tests/unit_tests/writers/test_s3_writer.py new file mode 100644 index 000000000..a9dee2e8f --- /dev/null +++ b/packages/railtracks/tests/unit_tests/writers/test_s3_writer.py @@ -0,0 +1,196 @@ +"""Unit tests for S3Writer.""" + +import pytest + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .conftest import make_s3_client, patch_s3 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_chunk(id_=None, document=None, content="hello world"): + return Chunk(content=content, id=id_, document=document) + + +# --------------------------------------------------------------------------- +# Init +# --------------------------------------------------------------------------- + + +class TestInit: + def test_missing_boto3_raises(self, monkeypatch): + import sys + + monkeypatch.setitem(sys.modules, "boto3", None) + with pytest.raises(ImportError, match="boto3"): + from railtracks.writers.s3 import S3Writer + + S3Writer("my-bucket") + + +# --------------------------------------------------------------------------- +# write_key +# --------------------------------------------------------------------------- + + +class TestWriteKey: + def test_puts_object_at_key(self): + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket") + uri = writer.write_key("docs/hello.txt", "hello world") + + assert uri == "s3://my-bucket/docs/hello.txt" + client.put_object.assert_called_once_with( + Bucket="my-bucket", + Key="docs/hello.txt", + Body=b"hello world", + 
ContentType="text/plain; charset=utf-8", + ) + + def test_custom_encoding(self): + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket", encoding="latin-1") + writer.write_key("file.txt", "caf\xe9") + + _call = client.put_object.call_args + assert _call.kwargs["Body"] == "caf\xe9".encode("latin-1") + + def test_custom_content_type(self): + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket", content_type="application/json") + writer.write_key("data.json", "{}") + + assert client.put_object.call_args.kwargs["ContentType"] == "application/json" + + +# --------------------------------------------------------------------------- +# write +# --------------------------------------------------------------------------- + + +class TestWrite: + def test_writes_all_chunks(self): + chunks = [make_chunk(id_=f"doc-{i}", content=f"content {i}") for i in range(3)] + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket") + uris = writer.write(chunks) + + assert len(uris) == 3 + assert client.put_object.call_count == 3 + + def test_key_from_chunk_id(self): + chunk = make_chunk(id_="my-id", document="my-doc") + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket") + uris = writer.write([chunk]) + + assert uris[0] == "s3://my-bucket/my-id" + client.put_object.assert_called_once() + assert client.put_object.call_args.kwargs["Key"] == "my-id" + + def test_key_uses_auto_uuid_when_no_explicit_id(self): + # Chunk always auto-generates a UUID id in __post_init__ + chunk = make_chunk(id_=None, document="ignored") + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket") + writer.write([chunk]) + + key = 
client.put_object.call_args.kwargs["Key"] + assert len(key) == 36 # UUID4 format + + def test_prefix_prepended_to_key(self): + chunk = make_chunk(id_="doc.txt") + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket") + uris = writer.write([chunk], prefix="generated/") + + assert client.put_object.call_args.kwargs["Key"] == "generated/doc.txt" + assert uris[0] == "s3://my-bucket/generated/doc.txt" + + def test_custom_key_fn(self): + chunk = make_chunk(id_="ignore-this") + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket", key_fn=lambda c: "custom-key.txt") + writer.write([chunk]) + + assert client.put_object.call_args.kwargs["Key"] == "custom-key.txt" + + def test_returns_correct_uris(self): + chunks = [make_chunk(id_="a"), make_chunk(id_="b")] + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("bucket-x") + uris = writer.write(chunks) + + assert uris == ["s3://bucket-x/a", "s3://bucket-x/b"] + + def test_empty_chunks(self): + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket") + uris = writer.write([]) + + assert uris == [] + client.put_object.assert_not_called() + + +# --------------------------------------------------------------------------- +# Async +# --------------------------------------------------------------------------- + + +class TestAsync: + @pytest.mark.asyncio + async def test_awrite_delegates_to_write(self): + chunks = [make_chunk(id_="x")] + client = make_s3_client() + with patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket") + uris = await writer.awrite(chunks) + + assert uris == ["s3://my-bucket/x"] + + @pytest.mark.asyncio + async def test_awrite_key_delegates(self): + client = make_s3_client() + with 
patch_s3(client): + from railtracks.writers.s3 import S3Writer + + writer = S3Writer("my-bucket") + uri = await writer.awrite_key("k.txt", "data") + + assert uri == "s3://my-bucket/k.txt" diff --git a/packages/railtracks/tests/unit_tests/writers/test_sql_writer.py b/packages/railtracks/tests/unit_tests/writers/test_sql_writer.py new file mode 100644 index 000000000..d17989d13 --- /dev/null +++ b/packages/railtracks/tests/unit_tests/writers/test_sql_writer.py @@ -0,0 +1,293 @@ +"""Unit tests for SQLWriter (real in-memory SQLite — no mocking needed).""" + +import pytest + +from railtracks.vector_stores.chunking.base_chunker import Chunk + +from .conftest import make_sqlite_engine, read_all_rows + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_chunk(id_=None, document=None, content="hello", metadata=None): + return Chunk( + content=content, + id=id_, + document=document, + metadata=metadata or {}, + ) + + +# --------------------------------------------------------------------------- +# Init +# --------------------------------------------------------------------------- + + +class TestInit: + def test_missing_sqlalchemy_raises(self, monkeypatch): + import sys + + monkeypatch.setitem(sys.modules, "sqlalchemy", None) + with pytest.raises(ImportError, match="sqlalchemy"): + from railtracks.writers.sql import SQLWriter + + SQLWriter("", "docs", "body") + + +# --------------------------------------------------------------------------- +# write_key +# --------------------------------------------------------------------------- + + +class TestWriteKey: + def test_inserts_row(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + uri = writer.write_key("doc-1", "some content") + + rows = read_all_rows(engine) + 
assert len(rows) == 1 + assert rows[0]["id"] == "doc-1" + assert rows[0]["body"] == "some content" + assert uri == "sql://documents/doc-1" + + def test_write_key_without_id_column_raises(self): + engine = make_sqlite_engine(columns=["body"]) + from railtracks.writers.sql import SQLWriter + + writer = SQLWriter("", "documents", "body", engine=engine) + with pytest.raises(ValueError, match="id_column"): + writer.write_key("k", "content") + + def test_upsert_replaces_existing_row(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + writer.write_key("doc-1", "first version") + writer.write_key("doc-1", "second version") + + rows = read_all_rows(engine) + assert len(rows) == 1 + assert rows[0]["body"] == "second version" + + def test_insert_mode_allows_duplicate_ids(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + writer = SQLWriter( + "", "documents", "body", id_column="id", mode="insert", engine=engine + ) + writer.write_key("doc-1", "v1") + writer.write_key("doc-1", "v2") + + rows = read_all_rows(engine) + assert len(rows) == 2 + + +# --------------------------------------------------------------------------- +# write +# --------------------------------------------------------------------------- + + +class TestWrite: + def test_inserts_all_chunks(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + chunks = [make_chunk(id_=f"doc-{i}", content=f"content {i}") for i in range(3)] + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + uris = writer.write(chunks) + + rows = read_all_rows(engine) + assert len(rows) == 3 + assert len(uris) == 3 + + def test_content_column_written(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + chunk = 
make_chunk(id_="x", content="the quick brown fox") + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + writer.write([chunk]) + + rows = read_all_rows(engine) + assert rows[0]["body"] == "the quick brown fox" + + def test_id_column_written(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + chunk = make_chunk(id_="my-id", content="x") + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + writer.write([chunk]) + + rows = read_all_rows(engine) + assert rows[0]["id"] == "my-id" + + def test_document_column_written(self): + engine = make_sqlite_engine(columns=["id", "title", "body"]) + from railtracks.writers.sql import SQLWriter + + chunk = make_chunk(id_="1", document="My Title", content="body text") + writer = SQLWriter( + "", + "documents", + "body", + id_column="id", + document_column="title", + engine=engine, + ) + writer.write([chunk]) + + rows = read_all_rows(engine) + assert rows[0]["title"] == "My Title" + + def test_metadata_columns_written(self): + engine = make_sqlite_engine(columns=["id", "body", "category"]) + from railtracks.writers.sql import SQLWriter + + chunk = make_chunk( + id_="1", + content="body", + metadata={"category": "HR", "ignored": "value"}, + ) + writer = SQLWriter( + "", + "documents", + "body", + id_column="id", + metadata_columns=["category"], + engine=engine, + ) + writer.write([chunk]) + + rows = read_all_rows(engine) + assert rows[0]["category"] == "HR" + + def test_missing_metadata_key_skipped(self): + engine = make_sqlite_engine(columns=["id", "body", "category"]) + from railtracks.writers.sql import SQLWriter + + chunk = make_chunk(id_="1", content="body", metadata={}) + writer = SQLWriter( + "", + "documents", + "body", + id_column="id", + metadata_columns=["category"], + engine=engine, + ) + # Should not raise even when the metadata key is absent + writer.write([chunk]) + + rows = read_all_rows(engine) + assert 
rows[0]["category"] is None + + def test_upsert_mode_replaces_row(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + writer.write([make_chunk(id_="doc-1", content="v1")]) + writer.write([make_chunk(id_="doc-1", content="v2")]) + + rows = read_all_rows(engine) + assert len(rows) == 1 + assert rows[0]["body"] == "v2" + + def test_insert_mode_appends(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + writer = SQLWriter( + "", "documents", "body", id_column="id", mode="insert", engine=engine + ) + writer.write([make_chunk(id_="doc-1", content="v1")]) + writer.write([make_chunk(id_="doc-1", content="v2")]) + + rows = read_all_rows(engine) + assert len(rows) == 2 + + def test_returns_uri_with_id(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + chunk = make_chunk(id_="abc", content="x") + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + uris = writer.write([chunk]) + + assert uris == ["sql://documents/abc"] + + def test_returns_uri_without_id(self): + engine = make_sqlite_engine(columns=["body"]) + from railtracks.writers.sql import SQLWriter + + chunk = make_chunk(content="x") + writer = SQLWriter("", "documents", "body", engine=engine) + uris = writer.write([chunk]) + + assert uris == ["sql://documents"] + + def test_empty_chunks(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + uris = writer.write([]) + + assert uris == [] + assert read_all_rows(engine) == [] + + def test_chunk_auto_uuid_written_to_id_column(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + # Chunk always auto-assigns a 
UUID in __post_init__, so id_column + # always receives the auto-generated UUID when id_column is configured. + chunk = make_chunk(id_=None, content="x") + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + writer.write([chunk]) + + rows = read_all_rows(engine) + assert rows[0]["body"] == "x" + assert rows[0]["id"] is not None + assert len(rows[0]["id"]) == 36 # UUID4 + + +# --------------------------------------------------------------------------- +# Async +# --------------------------------------------------------------------------- + + +class TestAsync: + @pytest.mark.asyncio + async def test_awrite(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + chunks = [make_chunk(id_="x", content="async content")] + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + uris = await writer.awrite(chunks) + + assert uris == ["sql://documents/x"] + rows = read_all_rows(engine) + assert rows[0]["body"] == "async content" + + @pytest.mark.asyncio + async def test_awrite_key(self): + engine = make_sqlite_engine(columns=["id", "body"]) + from railtracks.writers.sql import SQLWriter + + writer = SQLWriter("", "documents", "body", id_column="id", engine=engine) + uri = await writer.awrite_key("k", "data") + + assert uri == "sql://documents/k" + rows = read_all_rows(engine) + assert rows[0]["id"] == "k" + assert rows[0]["body"] == "data" diff --git a/pyproject.toml b/pyproject.toml index 1bb1524e3..0f51ce79c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,9 @@ version = "0.0.0" description = "Root workspace for dev tools and workspace packages" readme = "README.md" requires-python = ">=3.10" -dependencies = ["railtracks"] +dependencies = [ + "railtracks", +] [tool.uv.sources] @@ -13,7 +15,7 @@ railtracks = { workspace = true } [tool.uv.workspace] members = [ - "packages/*", + "packages/railtracks", ]