Skip to content

Commit 2200453

Browse files
committed
Text-to-SQL: Copy editing. Suggestions by CodeRabbit.
1 parent e6b13ca commit 2200453

4 files changed

Lines changed: 26 additions & 10 deletions

File tree

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ jobs:
3535
env:
3636
OS: ${{ matrix.os }}
3737
PYTHON: ${{ matrix.python-version }}
38-
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
39-
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
4038
# Do not tear down Testcontainers
4139
TC_KEEPALIVE: true
4240
# https://docs.github.com/en/actions/using-containerized-services/about-service-containers
@@ -76,6 +74,8 @@ jobs:
7674
env:
7775
KAGGLE_USERNAME: ${{ secrets.KAGGLE_USERNAME }}
7876
KAGGLE_KEY: ${{ secrets.KAGGLE_KEY }}
77+
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
78+
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
7979
run: |
8080
poe check
8181

cratedb_toolkit/query/nlsql/api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111

1212
logger = logging.getLogger(__name__)
1313

14+
llama_index_import_error: Optional[ImportError] = None
1415

1516
try:
1617
from llama_index.core.base.response.schema import RESPONSE_TYPE
1718
from llama_index.core.llms import LLM
1819
from llama_index.core.query_engine import NLSQLTableQueryEngine
1920
from llama_index.core.utilities.sql_wrapper import SQLDatabase
20-
except ImportError:
21-
pass
21+
except ImportError as exc:
22+
llama_index_import_error = exc
2223

2324

2425
@dataclasses.dataclass
@@ -53,6 +54,11 @@ def setup(self):
5354
"""Configure database connection and query engine."""
5455
from cratedb_toolkit.query.nlsql.util import configure_llm
5556

57+
if llama_index_import_error:
58+
raise ImportError(
59+
"NLSQL support requires installing `cratedb-toolkit[nlsql]`"
60+
) from llama_index_import_error
61+
5662
# Configure model.
5763
logger.info("Configuring LLM model")
5864
llm: LLM = configure_llm(self.model)

cratedb_toolkit/query/nlsql/model.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,27 @@ def from_options(
4747
else:
4848
raise ValueError("LLM completion model not defined")
4949
if not llm_api_key:
50-
if provider in [ModelProvider.OPENAI, ModelProvider.AZURE]:
50+
if provider is ModelProvider.OPENAI:
5151
llm_api_key = os.getenv("OPENAI_API_KEY")
5252
if not llm_api_key:
5353
raise ValueError(
54-
"LLM API key not defined. Use either API option or OPENAI_API_KEY environment variable."
54+
"LLM API key not defined. Use either CLI/API parameter or OPENAI_API_KEY environment variable."
5555
)
56-
elif provider in [ModelProvider.ANTHROPIC]:
56+
elif provider is ModelProvider.AZURE:
57+
llm_endpoint = llm_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
58+
llm_api_key = os.getenv("AZURE_OPENAI_API_KEY")
59+
llm_api_version = llm_api_version or os.getenv("OPENAI_API_VERSION")
60+
if not llm_api_key:
61+
raise ValueError(
62+
"LLM API key not defined. Use either CLI/API parameter or "
63+
"AZURE_OPENAI_API_KEY environment variable."
64+
)
65+
elif provider is ModelProvider.ANTHROPIC:
5766
llm_api_key = os.getenv("ANTHROPIC_API_KEY")
5867
if not llm_api_key:
5968
raise ValueError(
60-
"LLM API key not defined. Use either API option or ANTHROPIC_API_KEY environment variable."
69+
"LLM API key not defined. Use either CLI/API parameter or "
70+
"ANTHROPIC_API_KEY environment variable."
6171
)
6272
return cls(
6373
provider=provider,

tests/query/test_nlsql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
from cratedb_toolkit.query.cli import cli
99

1010
if sys.version_info < (3, 10):
11-
pytest.skip("Only available for Python 3.10+", allow_module_level=True)
11+
pytest.skip("Only available for Python 3.10+", allow_module_level=True) # ty: ignore[invalid-argument-type,too-many-positional-arguments]
1212

1313

1414
@pytest.fixture
1515
def provision_db(cratedb):
1616
sql_ddl = """
17-
CREATE TABLE IF NOT EXISTS testdrive.time_series_data (
17+
CREATE TABLE testdrive.time_series_data (
1818
timestamp TIMESTAMP,
1919
value DOUBLE,
2020
location STRING,

0 commit comments

Comments
 (0)