Skip to content

Commit e6b13ca

Browse files
committed
Text-to-SQL: Add Anthropic provider
1 parent d0d0a7b commit e6b13ca

7 files changed

Lines changed: 91 additions & 15 deletions

File tree

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ jobs:
3636
OS: ${{ matrix.os }}
3737
PYTHON: ${{ matrix.python-version }}
3838
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
39+
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
3940
# Do not tear down Testcontainers
4041
TC_KEEPALIVE: true
4142
# https://docs.github.com/en/actions/using-containerized-services/about-service-containers

cratedb_toolkit/query/nlsql/api.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,11 @@ def __post_init__(self):
5151

5252
def setup(self):
5353
"""Configure database connection and query engine."""
54-
logger.info("Connecting to CrateDB")
54+
from cratedb_toolkit.query.nlsql.util import configure_llm
5555

5656
# Configure model.
5757
logger.info("Configuring LLM model")
58-
llm: LLM
59-
from cratedb_toolkit.query.nlsql.util import configure_llm
60-
61-
llm = configure_llm(self.model)
58+
llm: LLM = configure_llm(self.model)
6259

6360
# Configure query engine.
6461
logger.info("Creating query engine")

cratedb_toolkit/query/nlsql/model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class ModelProvider(Enum):
1010
"""Model provider choices."""
1111

1212
OPENAI = "openai"
13+
ANTHROPIC = "anthropic"
1314
AZURE = "azure"
1415
OLLAMA = "ollama"
1516

@@ -41,6 +42,8 @@ def from_options(
4142
llm_name = "gpt-4.1"
4243
elif provider in [ModelProvider.OLLAMA]:
4344
llm_name = "gemma3:1b"
45+
elif provider in [ModelProvider.ANTHROPIC]:
46+
llm_name = "claude-sonnet-4-0"
4447
else:
4548
raise ValueError("LLM completion model not defined")
4649
if not llm_api_key:
@@ -50,6 +53,12 @@ def from_options(
5053
raise ValueError(
5154
"LLM API key not defined. Use either API option or OPENAI_API_KEY environment variable."
5255
)
56+
elif provider in [ModelProvider.ANTHROPIC]:
57+
llm_api_key = os.getenv("ANTHROPIC_API_KEY")
58+
if not llm_api_key:
59+
raise ValueError(
60+
"LLM API key not defined. Use either API option or ANTHROPIC_API_KEY environment variable."
61+
)
5362
return cls(
5463
provider=provider,
5564
endpoint=llm_endpoint,

cratedb_toolkit/query/nlsql/util.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
1-
import llama_index.core
1+
from typing import Optional
2+
3+
from llama_index.core import MockEmbedding, set_global_handler, settings
4+
from llama_index.core.base.embeddings.base import BaseEmbedding
5+
from llama_index.core.callbacks import CallbackManager
6+
from llama_index.core.embeddings import utils
7+
from llama_index.core.embeddings.utils import EmbedType
28
from llama_index.core.llms import LLM
9+
from llama_index.llms.anthropic import Anthropic
310
from llama_index.llms.azure_openai import AzureOpenAI
411
from llama_index.llms.ollama import Ollama
512
from llama_index.llms.openai import OpenAI
613

714
from cratedb_toolkit.query.nlsql.model import ModelInfo, ModelProvider
815

916

17+
def resolve_embed_model(
18+
embed_model: Optional[EmbedType] = None,
19+
callback_manager: Optional[CallbackManager] = None,
20+
) -> BaseEmbedding:
21+
"""Stub function for disabling embeddings without the `print` and other side effects."""
22+
return MockEmbedding(embed_dim=1)
23+
24+
1025
def configure_llm(info: ModelInfo, debug: bool = False) -> LLM:
1126
"""
1227
Configure LLM access and model types. Use either vanilla Open AI, Azure Open AI, or Ollama.
@@ -16,14 +31,18 @@ def configure_llm(info: ModelInfo, debug: bool = False) -> LLM:
1631

1732
completion_model = info.name
1833

34+
# Disable embeddings.
35+
utils.resolve_embed_model = resolve_embed_model # ty: ignore[invalid-assignment]
36+
settings.resolve_embed_model = resolve_embed_model # ty: ignore[invalid-assignment]
37+
1938
if not info.provider:
2039
raise ValueError("LLM model provider not defined")
2140
if not completion_model:
2241
raise ValueError("LLM model name not defined")
2342

2443
# https://docs.llamaindex.ai/en/stable/understanding/tracing_and_debugging/tracing_and_debugging/
2544
if debug:
26-
llama_index.core.set_global_handler("simple")
45+
set_global_handler("simple")
2746

2847
# Select completions model.
2948
if info.provider is ModelProvider.OPENAI:
@@ -53,7 +72,14 @@ def configure_llm(info: ModelInfo, debug: bool = False) -> LLM:
5372
request_timeout=120.0,
5473
keep_alive=-1,
5574
)
75+
elif info.provider is ModelProvider.ANTHROPIC:
76+
llm = Anthropic(
77+
model=completion_model,
78+
temperature=0.0,
79+
base_url=info.endpoint,
80+
api_key=info.api_key,
81+
)
5682
else:
57-
raise ValueError(f"LLM model provider not found: {info.provider}")
83+
raise ValueError(f"LLM model provider not implemented: {info.provider}")
5884

5985
return llm

doc/query/nlsql/index.md

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ export LLM_PROVIDER=openai
1818
export OPENAI_API_KEY=<YOUR_OPENAI_API_KEY>
1919
```
2020

21+
```shell
22+
export CRATEDB_CLUSTER_URL=crate://localhost/
23+
export LLM_PROVIDER=anthropic
24+
export ANTHROPIC_API_KEY=<YOUR_ANTHROPIC_API_KEY>
25+
```
26+
2127
```shell
2228
export CRATEDB_CLUSTER_URL=crate://localhost/
2329
export LLM_PROVIDER=ollama
@@ -38,16 +44,25 @@ engine = sa.create_engine("crate://")
3844
schema = "doc"
3945

4046
# Use Open AI GPT-4.
41-
dq = DataQuery(
47+
dataquery = DataQuery(
4248
db=DatabaseInfo(engine=engine, schema=schema),
4349
model=ModelInfo(provider=ModelProvider.OPENAI, name="gpt-4.1"),
4450
)
4551

46-
# Use Gemma3 via Ollama.
47-
dq = DataQuery(
52+
# Use Anthropic Claude Sonnet.
53+
dataquery = DataQuery(
54+
db=DatabaseInfo(engine=engine, schema=schema),
55+
model=ModelInfo(provider=ModelProvider.ANTHROPIC, name="claude-sonnet-4-0"),
56+
)
57+
58+
# Use Google Gemma3 via Ollama.
59+
dataquery = DataQuery(
4860
db=DatabaseInfo(engine=engine, schema=schema),
4961
model=ModelInfo(provider=ModelProvider.OLLAMA, name="gemma3:1b"),
5062
)
63+
64+
response = dataquery.ask("What is the average value for sensor 1?")
65+
print(response)
5166
```
5267

5368
## Example

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ optional-dependencies.mongodb = [
225225
"undatum<1.2",
226226
]
227227
optional-dependencies.nlsql = [
228+
"llama-index-llms-anthropic<0.12; python_version>='3.10'",
228229
"llama-index-llms-azure-openai<0.6; python_version>='3.10'",
229230
"llama-index-llms-ollama<0.11; python_version>='3.10'",
230231
"llama-index-llms-openai<0.8; python_version>='3.10'",

tests/query/test_nlsql.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
@pytest.fixture
1515
def provision_db(cratedb):
1616
sql_ddl = """
17-
CREATE TABLE IF NOT EXISTS time_series_data (
17+
CREATE TABLE IF NOT EXISTS testdrive.time_series_data (
1818
timestamp TIMESTAMP,
1919
value DOUBLE,
2020
location STRING,
2121
sensor_id INT
2222
);
2323
"""
2424
sql_dml = """
25-
INSERT INTO time_series_data (timestamp, value, location, sensor_id)
25+
INSERT INTO testdrive.time_series_data (timestamp, value, location, sensor_id)
2626
VALUES
2727
('2023-09-14T00:00:00', 10.5, 'Sensor A', 1),
2828
('2023-09-14T01:00:00', 15.2, 'Sensor A', 1),
@@ -42,14 +42,15 @@ def provision_db(cratedb):
4242

4343

4444
@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set")
45-
def test_query_llm(cratedb, provision_db):
45+
def test_query_nlsql_openai(cratedb, provision_db):
4646
"""
47-
Verify `ctk query nlsql ...`.
47+
Verify `ctk query nlsql ...` with Open AI.
4848
"""
4949

5050
runner = CliRunner(
5151
env={
5252
"CRATEDB_CLUSTER_URL": cratedb.get_connection_url(),
53+
"CRATEDB_SCHEMA": "testdrive",
5354
"LLM_PROVIDER": "openai",
5455
}
5556
)
@@ -64,3 +65,29 @@ def test_query_llm(cratedb, provision_db):
6465
output = json.loads(result.output)
6566
assert output["answer"] == "The average value for sensor 1 is approximately 17.03."
6667
assert output["sql_query"] == "SELECT AVG(value) FROM time_series_data WHERE sensor_id = 1"
68+
69+
70+
@pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set")
71+
def test_query_nlsql_anthropic(cratedb, provision_db):
72+
"""
73+
Verify `ctk query nlsql ...` with Anthropic.
74+
"""
75+
76+
runner = CliRunner(
77+
env={
78+
"CRATEDB_CLUSTER_URL": cratedb.get_connection_url(),
79+
"CRATEDB_SCHEMA": "testdrive",
80+
"LLM_PROVIDER": "anthropic",
81+
}
82+
)
83+
84+
result = runner.invoke(
85+
cli,
86+
input="What is the average value for sensor 1?",
87+
args="nlsql -",
88+
catch_exceptions=False,
89+
)
90+
assert result.exit_code == 0, result.output
91+
output = json.loads(result.output)
92+
assert "the average value for sensor 1 is approximately **17.03**" in output["answer"]
93+
assert output["sql_query"] == "SELECT AVG(value) as average_value FROM time_series_data WHERE sensor_id = 1;"

0 commit comments

Comments
 (0)