From e94b1bfb203d09465a3e0e7f0dcd530b90cdffdf Mon Sep 17 00:00:00 2001 From: nranudeep <10704725+nranudeep@users.noreply.github.com> Date: Wed, 9 Jul 2025 16:13:30 +0000 Subject: [PATCH] Azure Open AI support. api-version flag addition to client args on LLM Client initialization --- DOCS.md | 14 +++++ synthetic_data_kit/cli.py | 35 ++++++++++-- synthetic_data_kit/config.yaml | 1 + synthetic_data_kit/models/llm_client.py | 37 ++++++++++-- synthetic_data_kit/utils/config.py | 1 + tests/conftest.py | 23 +++++++- tests/functional/test_cli.py | 24 +++++++- tests/unit/test_error_handling.py | 6 +- tests/unit/test_llm_client.py | 76 ++++++++++++++++++++++++- 9 files changed, 200 insertions(+), 17 deletions(-) diff --git a/DOCS.md b/DOCS.md index 3d658bad..f8aa7b98 100644 --- a/DOCS.md +++ b/DOCS.md @@ -539,6 +539,11 @@ paths: cleaned: "data/cleaned" final: "data/final" +# LLM Provider configuration +llm: + # Provider selection: "vllm" or "api-endpoint" + provider: "api-endpoint" + # vllm: Configure VLLM server settings vllm: api_base: "http://localhost:8000/v1" @@ -547,6 +552,15 @@ vllm: max_retries: 3 retry_delay: 1.0 +# API endpoint configuration +api-endpoint: + api_base: "https://api.llama.com/v1" # Optional base URL for API endpoint (null for default API) + api_key: "llama-api-key" # API key for API endpoint or compatible service (can use env var instead) + model: "Llama-4-Maverick-17B-128E-Instruct-FP8" # Default model to use + azure_api_version: "2024-06-01" # API version needed for Azure OpenAI endpoints. Make it Null for other providers + max_retries: 3 # Number of retries for API calls + retry_delay: 1.0 # Initial delay between retries (seconds) + # generation: Content generation parameters generation: temperature: 0.7 diff --git a/synthetic_data_kit/cli.py b/synthetic_data_kit/cli.py index 5915fbcd..df2665d6 100644 --- a/synthetic_data_kit/cli.py +++ b/synthetic_data_kit/cli.py @@ -87,6 +87,15 @@ def system_check( console.print(f"API key source: {'Environment variable' if api_endpoint_key else 'Config file'}") model = api_endpoint_config.get("model") + + # Check for azure api version in environment variables + azure_api_version_key = os.environ.get('AZURE_API_VERSION') + console.print(f"AZURE_API_VERSION_KEY environment variable: {'Found' if azure_api_version_key else 'Not found'}") + + # Set API VERSION key with priority: env var > config + azure_api_version = azure_api_version_key or api_endpoint_config.get("azure_api_version") + if azure_api_version: + console.print(f"API version source: {'Environment variable' if azure_api_version_key else 'Config file'}") # Check API endpoint access with console.status(f"Checking API endpoint access..."): @@ -94,6 +103,7 @@ def system_check( # Try to import OpenAI try: from openai import OpenAI + from openai import AzureOpenAI except ImportError: console.print("L API endpoint package not installed", style="red") console.print("Install with: pip install openai>=1.0.0", style="yellow") @@ -105,13 +115,26 @@ def system_check( client_kwargs['api_key'] = api_key if api_base: client_kwargs['base_url'] = api_base + if azure_api_version: + client_kwargs['api_version'] = azure_api_version # Check API access try: - client = OpenAI(**client_kwargs) - # Try a simple models list request to check connectivity - models = client.models.list() - console.print(f" API endpoint access confirmed", style="green") + if azure_api_version: + client = AzureOpenAI(**client_kwargs) + resp = client.chat.completions.create( + model=model, + messages=[{"role":"user", "content":"Hello World!"}], + max_tokens=1 + ) + console.print("API endpoint access confirmed.. API Responded with: ", + resp.choices[0].message.content, style="green") + else: + client = OpenAI(**client_kwargs) + # Try a simple models list request to check connectivity + models = client.models.list() + console.print(f" API endpoint access confirmed", style="green") + if api_base: console.print(f"Using custom API base: {api_base}", style="green") console.print(f"Default model: {model}", style="green") @@ -120,8 +143,10 @@ def system_check( console.print(f"L Error connecting to API endpoint: {str(e)}", style="red") if api_base: console.print(f"Using custom API base: {api_base}", style="yellow") - if not api_key and not api_base: + if not api_key and api_base: console.print("API key is required. Set in config.yaml or as API_ENDPOINT_KEY env var", style="yellow") + if not azure_api_version and api_base: + console.print("Azure API version is required. Set in config.yaml or as AZURE_API_VERSION env var", style="yellow") return 1 except Exception as e: console.print(f"L Error: {str(e)}", style="red") diff --git a/synthetic_data_kit/config.yaml b/synthetic_data_kit/config.yaml index 70dc7fce..48f05e2a 100644 --- a/synthetic_data_kit/config.yaml +++ b/synthetic_data_kit/config.yaml @@ -30,6 +30,7 @@ api-endpoint: api_base: "https://api.llama.com/v1" # Optional base URL for API endpoint (null for default API) api_key: "llama-api-key" # API key for API endpoint or compatible service (can use env var instead) model: "Llama-4-Maverick-17B-128E-Instruct-FP8" # Default model to use + azure_api_version: "2024-06-01" # API version needed for Azure OpenAI endpoints max_retries: 3 # Number of retries for API calls retry_delay: 1.0 # Initial delay between retries (seconds) diff --git a/synthetic_data_kit/models/llm_client.py b/synthetic_data_kit/models/llm_client.py index 708f5692..5d7a0de1 100644 --- a/synthetic_data_kit/models/llm_client.py +++ b/synthetic_data_kit/models/llm_client.py @@ -22,6 +22,7 @@ # Try to import OpenAI, but handle case where it's not installed try: from openai import OpenAI + from openai import AzureOpenAI from openai.types.chat import ChatCompletion OPENAI_AVAILABLE = True except ImportError: @@ -35,6 +36,7 @@ def __init__(self, api_base: Optional[str] = None, api_key: Optional[str] = None, model_name: Optional[str] = None, + azure_api_version: Optional[str] = None, max_retries: Optional[int] = None, retry_delay: Optional[float] = None): """Initialize an LLM client that supports multiple providers @@ -45,6 +47,7 @@ def __init__(self, api_base: Override API base URL from config api_key: Override API key for API endpoint (only needed for 'api-endpoint' provider) model_name: Override model name from config + azure_api_version: Override azure api version from config. Needed for Azure OpenAI endpoints max_retries: Override max retries from config retry_delay: Override retry delay from config """ @@ -74,7 +77,16 @@ def __init__(self, if not self.api_key and not self.api_base: # Only require API key for official API raise ValueError("API key is required for API endpoint provider. Set in config or API_ENDPOINT_KEY env var.") - + + # Check for azure api version in environment variables + azure_api_version_key = os.environ.get('AZURE_API_VERSION') + print(f"AZURE_API_VERSION_KEY environment variable: {'Found' if azure_api_version_key else 'Not found'}") + + # Set API VERSION key with priority: CLI arg > env var > config + self.azure_api_version = azure_api_version or azure_api_version_key or api_endpoint_config.get("azure_api_version") + if self.azure_api_version: + print(f"API version source: {'Environment variable' if azure_api_version_key else 'Config file'}") + self.model = model_name or api_endpoint_config.get('model') self.max_retries = max_retries or api_endpoint_config.get('max_retries') self.retry_delay = retry_delay or api_endpoint_config.get('retry_delay') @@ -113,8 +125,16 @@ def _init_openai_client(self): if self.api_base: print(f"Using API base URL: {self.api_base}") client_kwargs['base_url'] = self.api_base - - self.openai_client = OpenAI(**client_kwargs) + + # Add Azure api version if provided (Needed for Azure OpenAI APIs) + print(f"Using API VERSION: {self.azure_api_version}") + if self.azure_api_version: + print(f"Using API base URL: {self.api_base}") + client_kwargs['api_version'] = self.azure_api_version + # OpenAI library differs for AzureOpenAI support + self.openai_client = AzureOpenAI(**client_kwargs) + else: + self.openai_client = OpenAI(**client_kwargs) def _check_vllm_server(self) -> tuple: """Check if the VLLM server is running and accessible""" @@ -351,6 +371,7 @@ async def _process_message_async(self, """Process a single message set asynchronously using the OpenAI API""" try: from openai import AsyncOpenAI + from openai import AsyncAzureOpenAI except ImportError: raise ImportError("The 'openai' package is required for this functionality. Please install it using 'pip install openai>=1.0.0'.") @@ -360,8 +381,14 @@ async def _process_message_async(self, client_kwargs['api_key'] = self.api_key if self.api_base: client_kwargs['base_url'] = self.api_base - - async_client = AsyncOpenAI(**client_kwargs) + + # Add Azure api version if provided (Needed for Azure OpenAI APIs) + if self.azure_api_version: + client_kwargs['api_version'] = self.azure_api_version + # OpenAI library differs for Azure OpenAI support + async_client = AsyncAzureOpenAI(**client_kwargs) + else: + async_client = AsyncOpenAI(**client_kwargs) for attempt in range(self.max_retries): try: diff --git a/synthetic_data_kit/utils/config.py b/synthetic_data_kit/utils/config.py index e17600f7..3bc2bf27 100644 --- a/synthetic_data_kit/utils/config.py +++ b/synthetic_data_kit/utils/config.py @@ -104,6 +104,7 @@ def get_openai_config(config: Dict[str, Any]) -> Dict[str, Any]: 'api_base': None, # None means use default API base URL 'api_key': None, # None means use environment variables 'model': 'gpt-4o', + 'azure_api_version': None, 'max_retries': 3, 'retry_delay': 1.0 }) diff --git a/tests/conftest.py b/tests/conftest.py index 25a384d9..044dc4f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -162,6 +162,7 @@ def create_api_config(provider="api-endpoint", api_key="mock-key", model="mock-m "api_base": "https://api.together.xyz/v1", "api_key": api_key, "model": model, + "azure_api_version": "2024-06-01", "max_retries": 3, "retry_delay": 1, }, @@ -232,8 +233,28 @@ def test_env(): @pytest.fixture def patch_config(config_factory): """Patch the config loader to return a mock configuration.""" + mock_config = config_factory.create_api_config() + with patch("synthetic_data_kit.utils.config.load_config") as mock_load_config: - mock_load_config.return_value = config_factory.create_api_config() + mock_load_config.return_value = mock_config + yield mock_load_config + +@pytest.fixture +def patch_llm_client_config(config_factory): + """Patch the config loader to return a mock configuration.""" + mock_config = config_factory.create_api_config() + + with patch("synthetic_data_kit.models.llm_client.load_config") as mock_load_config: + mock_load_config.return_value = mock_config + yield mock_load_config + +@pytest.fixture +def patch_cli_config(config_factory): + """Patch the config loader to return a mock configuration.""" + mock_config = config_factory.create_api_config() + + with patch("synthetic_data_kit.cli.load_config") as mock_load_config: + mock_load_config.return_value = mock_config yield mock_load_config diff --git a/tests/functional/test_cli.py b/tests/functional/test_cli.py index 3f9b5d68..48a6a138 100644 --- a/tests/functional/test_cli.py +++ b/tests/functional/test_cli.py @@ -32,7 +32,7 @@ def test_system_check_command_vllm(patch_config): @pytest.mark.functional -def test_system_check_command_api_endpoint(patch_config, test_env): +def test_system_check_command_api_endpoint(patch_cli_config, test_env): """Test the system-check command with API endpoint provider.""" runner = CliRunner() @@ -42,12 +42,30 @@ def test_system_check_command_api_endpoint(patch_config, test_env): mock_client.models.list.return_value = ["mock-model"] mock_openai.return_value = mock_client + # Get the mock config from the fixture + config = patch_cli_config.return_value + config["api-endpoint"]["azure_api_version"] = None + result = runner.invoke(app, ["system-check", "--provider", "api-endpoint"]) - - # Just check exit code, not specific message since it varies assert result.exit_code == 0 mock_openai.assert_called_once() +@pytest.mark.functional +def test_system_check_command_azure_api_endpoint(patch_cli_config, test_env): + """Test the system-check command with API endpoint provider.""" + runner = CliRunner() + + # Mock Azure OpenAI client + with patch("openai.AzureOpenAI") as mock_azure_openai: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="Hello World!"))] + ) + mock_azure_openai.return_value = mock_client + + result = runner.invoke(app, ["system-check", "--provider", "api-endpoint"]) + assert result.exit_code == 0 + mock_azure_openai.assert_called_once() @pytest.mark.functional def test_ingest_command(patch_config): diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py index 21289a82..cc2d7919 100644 --- a/tests/unit/test_error_handling.py +++ b/tests/unit/test_error_handling.py @@ -39,12 +39,16 @@ def test_parse_qa_pairs_invalid_json(): @pytest.mark.unit -def test_llm_client_error_handling(patch_config, test_env): +def test_llm_client_error_handling(patch_llm_client_config, test_env): """Test error handling in LLM client.""" with patch("synthetic_data_kit.models.llm_client.OpenAI") as mock_openai: # Setup mock to raise an exception mock_openai.side_effect = Exception("API Error") + # Get the mock config from the fixture + config = patch_llm_client_config.return_value + config["api-endpoint"]["azure_api_version"] = None + # Should handle the exception gracefully with pytest.raises(Exception) as excinfo: LLMClient(provider="api-endpoint") diff --git a/tests/unit/test_llm_client.py b/tests/unit/test_llm_client.py index d030f652..00ed4012 100644 --- a/tests/unit/test_llm_client.py +++ b/tests/unit/test_llm_client.py @@ -8,12 +8,34 @@ @pytest.mark.unit -def test_llm_client_initialization(patch_config, test_env): +def test_llm_client_initialization(patch_llm_client_config, test_env): """Test LLM client initialization with API endpoint provider.""" with patch("synthetic_data_kit.models.llm_client.OpenAI") as mock_openai: mock_client = MagicMock() mock_openai.return_value = mock_client + # Get the mock config from the fixture + config = patch_llm_client_config.return_value + config["api-endpoint"]["azure_api_version"] = None + + # Initialize client + client = LLMClient(provider="api-endpoint") + + # Check that the client was initialized correctly + assert client.provider == "api-endpoint" + assert client.api_base is not None + assert client.model is not None + # Check that OpenAI client was initialized + assert mock_openai.called + + +@pytest.mark.unit +def test_azure_llm_client_initialization(patch_llm_client_config, test_env): + """Test LLM client initialization with API endpoint provider.""" + with patch("synthetic_data_kit.models.llm_client.AzureOpenAI") as mock_openai: + mock_client = MagicMock() + mock_openai.return_value = mock_client + # Initialize client client = LLMClient(provider="api-endpoint") @@ -21,6 +43,7 @@ def test_llm_client_initialization(patch_config, test_env): assert client.provider == "api-endpoint" assert client.api_base is not None assert client.model is not None + assert client.azure_api_version is not None # Check that OpenAI client was initialized assert mock_openai.called @@ -46,7 +69,7 @@ def test_llm_client_vllm_initialization(patch_config, test_env): @pytest.mark.unit -def test_llm_client_chat_completion(patch_config, test_env): +def test_llm_client_chat_completion(patch_llm_client_config, test_env): """Test LLM client chat completion with API endpoint provider.""" with patch("synthetic_data_kit.models.llm_client.OpenAI") as mock_openai: # Create a proper mock chain for OpenAI client @@ -73,6 +96,10 @@ def test_llm_client_chat_completion(patch_config, test_env): # Connect the create function to return our mock response mock_create.return_value = mock_response + # Get the mock config from the fixture + config = patch_llm_client_config.return_value + config["api-endpoint"]["azure_api_version"] = None + # Initialize client client = LLMClient(provider="api-endpoint") @@ -90,6 +117,51 @@ def test_llm_client_chat_completion(patch_config, test_env): assert mock_create.called +@pytest.mark.unit +def test_azure_llm_client_chat_completion(patch_llm_client_config, test_env): + """Test LLM client chat completion with API endpoint provider.""" + with patch("synthetic_data_kit.models.llm_client.AzureOpenAI") as mock_openai: + # Create a proper mock chain for OpenAI client + mock_client = MagicMock() + mock_chat = MagicMock() + mock_completions = MagicMock() + mock_create = MagicMock() + + # Setup the nested mock structure + mock_openai.return_value = mock_client + mock_client.chat = mock_chat + mock_chat.completions = mock_completions + mock_completions.create = mock_create + + # Setup mock response + mock_response = MagicMock() + mock_choice = MagicMock() + mock_message = MagicMock() + + mock_message.content = "This is a test response" + mock_choice.message = mock_message + mock_response.choices = [mock_choice] + + # Connect the create function to return our mock response + mock_create.return_value = mock_response + + # Initialize client + client = LLMClient(provider="api-endpoint") + + # Test chat completion + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is synthetic data?"}, + ] + + response = client.chat_completion(messages, temperature=0.7) + + # Check that the response is correct + assert response == "This is a test response" + # Check that AzureOpenAI client was called + assert mock_create.called + + @pytest.mark.unit def test_llm_client_vllm_chat_completion(patch_config, test_env): """Test LLM client chat completion with vLLM provider."""