diff --git a/synthetic_data_kit/cli.py b/synthetic_data_kit/cli.py
index e88dbb59..53e67b4b 100644
--- a/synthetic_data_kit/cli.py
+++ b/synthetic_data_kit/cli.py
@@ -99,14 +99,41 @@ def system_check(
             console.print("Install with: pip install openai>=1.0.0", style="yellow")
             return 1
-        # Create client
+        # If no API key is provided but an api_base is configured (common for
+        # local Ollama or proxies), try a direct HTTP request first so the
+        # CLI doesn't require the OpenAI SDK or OPENAI_API_KEY.
+        if api_base and not api_key:
+            try:
+                messages = [{"role": "user", "content": "Hello"}]
+                payload = {
+                    "model": model,
+                    "messages": messages,
+                    "temperature": 0.1,
+                }
+                resp = requests.post(f"{api_base}/chat/completions", json=payload, timeout=10)
+                if resp.status_code == 200:
+                    console.print(f"✅ API endpoint access confirmed via HTTP: {api_base}", style="green")
+                    console.print(f"Default model: {model}", style="green")
+                    # Try to parse response content if present
+                    try:
+                        j = resp.json()
+                        console.print(f"Response from model: {j.get('choices', [{}])[0].get('message', {}).get('content')}", style="green")
+                    except Exception:
+                        pass
+                    return 0
+                # If not 200, fall through to the SDK-based attempt for a more detailed error
+            except requests.exceptions.RequestException:
+                # Fall back to attempting with the SDK for a clearer error message
+                pass
+
+        # Create client (SDK path)
         client_kwargs = {}
         if api_key:
             client_kwargs['api_key'] = api_key
         if api_base:
             client_kwargs['base_url'] = api_base
-        
-        # Check API access
+
+        # Check API access using the OpenAI SDK
         try:
             client = OpenAI(**client_kwargs)
             # Try a simple models request to check connectivity
@@ -134,6 +161,65 @@
         except Exception as e:
             console.print(f"❌ Error: {str(e)}", style="red")
             return 1
+    elif selected_provider == "ollama":
+        # Get Ollama config
+        from synthetic_data_kit.utils.config import get_ollama_config
+        ollama_config = get_ollama_config(ctx.config)
+        api_base = api_base or ollama_config.get("api_base")
+        model = ollama_config.get("model")
+
+        with console.status(f"Checking Ollama server at {api_base}..."):
+            try:
+                # Check if Ollama server is running by hitting the tags endpoint
+                response = requests.get(f"{api_base}/api/tags", timeout=5)
+                if response.status_code == 200:
+                    console.print(f"✅ Ollama server is running at {api_base}", style="green")
+                    tags_data = response.json()
+                    if 'models' in tags_data and tags_data['models']:
+                        console.print(f"Available models: {[m['name'] for m in tags_data['models']]}")
+                    else:
+                        console.print("No models found. Install a model with: ollama pull <model_name>", style="yellow")
+
+                    # Test generation with a simple prompt
+                    console.print(f"Testing generation with model: {model}", style="blue")
+                    test_payload = {
+                        "model": model,
+                        "prompt": "Hello",
+                        "stream": False,
+                        "options": {
+                            "temperature": 0.1,
+                            "num_predict": 50
+                        }
+                    }
+
+                    try:
+                        gen_response = requests.post(f"{api_base}/api/generate", json=test_payload, timeout=30)
+                        if gen_response.status_code == 200:
+                            gen_data = gen_response.json()
+                            if 'response' in gen_data:
+                                console.print(f"Response from model: {gen_data['response'][:100]}...", style="green")
+                                console.print("✅ Ollama generation test successful", style="green")
+                                return 0
+                        else:
+                            console.print(f"Generation test failed with status: {gen_response.status_code}", style="yellow")
+                            console.print(f"Response: {gen_response.text}", style="yellow")
+                    except requests.exceptions.RequestException as e:
+                        console.print(f"Generation test failed: {str(e)}", style="yellow")
+
+                    # Server is running but generation failed - still consider it partially successful
+                    return 0
+                else:
+                    console.print(f"❌ Ollama server is not available at {api_base}", style="red")
+                    console.print(f"Error: Server returned status code: {response.status_code}")
+            except requests.exceptions.RequestException as e:
+                console.print(f"❌ Ollama server is not available at {api_base}", style="red")
+                console.print(f"Error: {str(e)}")
+
+        # Show instruction to start the server
+        console.print("\nTo start Ollama, run:", style="yellow")
+        console.print("ollama serve", style="bold blue")
+        console.print(f"Then install a model: ollama pull {model}", style="bold blue")
+        return 1
     else:
         # Default to vLLM
         # Get vLLM server details
@@ -338,6 +424,26 @@ def create(
         api_base = api_base or api_endpoint_config.get("api_base")
         model = model or api_endpoint_config.get("model")
         # No server check needed for API endpoint
+    elif provider == "ollama":
+        # Use Ollama config
+        from synthetic_data_kit.utils.config import get_ollama_config
+        ollama_config = get_ollama_config(ctx.config)
+        api_base = api_base or ollama_config.get("api_base")
+        model = model or ollama_config.get("model")
+
+        # Check Ollama server availability
+        try:
+            response = requests.get(f"{api_base}/api/tags", timeout=2)
+            if response.status_code != 200:
+                console.print(f"❌ Error: Ollama server not available at {api_base}", style="red")
+                console.print("Please start the Ollama server with:", style="yellow")
+                console.print("ollama serve", style="bold blue")
+                return 1
+        except requests.exceptions.RequestException:
+            console.print(f"❌ Error: Ollama server not available at {api_base}", style="red")
+            console.print("Please start the Ollama server with:", style="yellow")
+            console.print("ollama serve", style="bold blue")
+            return 1
     else:
         # Use vLLM config
         vllm_config = get_vllm_config(ctx.config)
@@ -498,6 +604,26 @@ def curate(
         api_base = api_base or api_endpoint_config.get("api_base")
         model = model or api_endpoint_config.get("model")
         # No server check needed for API endpoint
+    elif provider == "ollama":
+        # Use Ollama config
+        from synthetic_data_kit.utils.config import get_ollama_config
+        ollama_config = get_ollama_config(ctx.config)
+        api_base = api_base or ollama_config.get("api_base")
+        model = model or ollama_config.get("model")
+
+        # Check Ollama server availability
+        try:
+            response = requests.get(f"{api_base}/api/tags", timeout=2)
+            if response.status_code != 200:
+                console.print(f"❌ Error: Ollama server not available at {api_base}", style="red")
+                console.print("Please start the Ollama server with:", style="yellow")
+                console.print("ollama serve", style="bold blue")
+                return 1
+        except requests.exceptions.RequestException:
+            console.print(f"❌ Error: Ollama server not available at {api_base}", style="red")
+            console.print("Please start the Ollama server with:", style="yellow")
+            console.print("ollama serve", style="bold blue")
+            return 1
     else:
         # Use vLLM config
         vllm_config = get_vllm_config(ctx.config)
@@ -517,6 +643,11 @@
             console.print("Please start the VLLM server with:", style="yellow")
             console.print(f"vllm serve {model}", style="bold blue")
             return 1
+    except requests.exceptions.RequestException:
+        console.print(f"❌ Error: VLLM server not available at {api_base}", style="red")
+        console.print("Please start the VLLM server with:", style="yellow")
+        console.print(f"vllm serve {model}", style="bold blue")
+        return 1
 
     try:
         # Check if input is a directory
diff --git a/synthetic_data_kit/config.yaml b/synthetic_data_kit/config.yaml
index 2690697d..1355dc7c 100644
--- a/synthetic_data_kit/config.yaml
+++ b/synthetic_data_kit/config.yaml
@@ -14,8 +14,8 @@ paths:
 
 # LLM Provider configuration
 llm:
-  # Provider selection: "vllm" or "api-endpoint"
-  provider: "api-endpoint"
+  # Provider selection: "vllm", "api-endpoint", or "ollama"
+  provider: "ollama"
 
 # VLLM server configuration
 vllm:
@@ -33,6 +33,14 @@ api-endpoint:
   max_retries: 3 # Number of retries for API calls
   retry_delay: 1.0 # Initial delay between retries (seconds)
 
+# Ollama configuration
+ollama:
+  api_base: "http://localhost:11434" # Base URL for Ollama API
+  model: "phi4:latest" # Default model to use
+  max_retries: 3 # Number of retries for API calls
+  retry_delay: 1.0 # Initial delay between retries (seconds)
+  sleep_time: 0.1 # Sleep time between batch requests
+
 # Ingest configuration
 ingest:
   default_format: "txt" # Default output format for parsed files
diff --git a/synthetic_data_kit/models/llm_client.py b/synthetic_data_kit/models/llm_client.py
index 4f964100..751cf810 100644
--- a/synthetic_data_kit/models/llm_client.py
+++ b/synthetic_data_kit/models/llm_client.py
@@ -13,7 +13,7 @@
 import asyncio
 from pathlib import Path
 
-from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_openai_config, get_llm_provider
+from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_openai_config, get_ollama_config, get_llm_provider
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -55,33 +55,62 @@ def __init__(self,
         self.provider = provider or get_llm_provider(self.config)
 
         if self.provider == 'api-endpoint':
-            if not OPENAI_AVAILABLE:
-                raise ImportError("OpenAI package is not installed. 
Install with 'pip install openai>=1.0.0'") - # Load API endpoint configuration api_endpoint_config = get_openai_config(self.config) - + # Set parameters, with CLI overrides taking precedence self.api_base = api_base or api_endpoint_config.get('api_base') - - # Check for environment variables + + # Check for environment variables (support multiple names for compatibility) api_endpoint_key = os.environ.get('API_ENDPOINT_KEY') + openai_key = os.environ.get('OPENAI_API_KEY') print(f"API_ENDPOINT_KEY from environment: {'Found' if api_endpoint_key else 'Not found'}") - - # Set API key with priority: CLI arg > env var > config - self.api_key = api_key or api_endpoint_key or api_endpoint_config.get('api_key') - print(f"Using API key: {'From CLI' if api_key else 'From env var' if api_endpoint_key else 'From config' if api_endpoint_config.get('api_key') else 'None'}") - - if not self.api_key and not self.api_base: # Only require API key for official API - raise ValueError("API key is required for API endpoint provider. Set in config or API_ENDPOINT_KEY env var.") - + print(f"OPENAI_API_KEY from environment: {'Found' if openai_key else 'Not found'}") + + # Set API key with priority: CLI arg > env var (API_ENDPOINT_KEY) > OPENAI_API_KEY > config + self.api_key = api_key or api_endpoint_key or openai_key or api_endpoint_config.get('api_key') + print(f"Using API key: {'From CLI' if api_key else 'From env var (API_ENDPOINT_KEY)' if api_endpoint_key else 'From env var (OPENAI_API_KEY)' if openai_key else 'From config' if api_endpoint_config.get('api_key') else 'None'}") + + # Set other parameters self.model = model_name or api_endpoint_config.get('model') self.max_retries = max_retries or api_endpoint_config.get('max_retries') self.retry_delay = retry_delay or api_endpoint_config.get('retry_delay') - self.sleep_time = api_endpoint_config.get('sleep_time',0.5) + self.sleep_time = api_endpoint_config.get('sleep_time', 0.5) + + # Decide whether we can/should use the OpenAI SDK client. If the SDK is installed and + # we have an API key (either provided or via OPENAI_API_KEY), prefer the SDK. Otherwise + # fall back to a direct requests-based implementation which works with local servers + # such as Ollama or local OpenAI-compatible proxies that don't require keys. + self.use_openai_client = False + if OPENAI_AVAILABLE and (self.api_key or os.environ.get('OPENAI_API_KEY')): + try: + # Try initializing the SDK client; if it fails we'll fall back + self._init_openai_client() + self.use_openai_client = True + except Exception as e: + logger.warning(f"OpenAI SDK client initialization failed, falling back to HTTP requests: {e}") + self.use_openai_client = False + else: + # No SDK or no API key — fall back to requests for local endpoints (Ollama, etc.) 
+ if not OPENAI_AVAILABLE: + logger.info("OpenAI SDK not available; using direct HTTP requests for API endpoint provider") + else: + logger.info("OpenAI SDK available but no API key found; using direct HTTP requests for API endpoint provider") + elif self.provider == 'ollama': + # Load Ollama configuration + ollama_config = get_ollama_config(self.config) + + # Set parameters, with CLI overrides taking precedence + self.api_base = api_base or ollama_config.get('api_base') + self.model = model_name or ollama_config.get('model') + self.max_retries = max_retries or ollama_config.get('max_retries') + self.retry_delay = retry_delay or ollama_config.get('retry_delay') + self.sleep_time = ollama_config.get('sleep_time', 0.1) - # Initialize OpenAI client - self._init_openai_client() + # Verify Ollama server is running + available, info = self._check_ollama_server() + if not available: + raise ConnectionError(f"Ollama server not available at {self.api_base}: {info}") else: # Default to vLLM # Load vLLM configuration vllm_config = get_vllm_config(self.config) @@ -128,6 +157,16 @@ def _check_vllm_server(self) -> tuple: except requests.exceptions.RequestException as e: return False, f"Server connection error: {str(e)}" + def _check_ollama_server(self) -> tuple: + """Check if the Ollama server is running and accessible""" + try: + response = requests.get(f"{self.api_base}/api/tags", timeout=5) + if response.status_code == 200: + return True, response.json() + return False, f"Server returned status code: {response.status_code}" + except Exception as e: + return False, f"Server connection error: {str(e)}" + def chat_completion(self, messages: List[Dict[str, str]], temperature: float = None, @@ -154,6 +193,8 @@ def chat_completion(self, if self.provider == 'api-endpoint': return self._openai_chat_completion(messages, temperature, max_tokens, top_p, verbose) + elif self.provider == 'ollama': + return self._ollama_chat_completion(messages, temperature, max_tokens, top_p, verbose) else: # Default to vLLM return self._vllm_chat_completion(messages, temperature, max_tokens, top_p, verbose) @@ -164,6 +205,12 @@ def _openai_chat_completion(self, top_p: float, verbose: bool) -> str: """Generate a chat completion using the OpenAI API or compatible APIs""" + # If we couldn't initialize the OpenAI SDK client (e.g., local Ollama server + # without an API key), fall back to a requests-based implementation that + # posts directly to the configured API base. 
+ if not getattr(self, 'use_openai_client', True): + return self._api_endpoint_chat_completion(messages, temperature, max_tokens, top_p, verbose) + debug_mode = os.environ.get('SDK_DEBUG', 'false').lower() == 'true' if verbose: logger.info(f"Sending request to {self.provider} model {self.model}...") @@ -278,6 +325,149 @@ def _openai_chat_completion(self, raise Exception(f"Failed to get {self.provider} completion after {self.max_retries} attempts: {str(e)}") time.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff + + def _api_endpoint_chat_completion(self, + messages: List[Dict[str, str]], + temperature: float, + max_tokens: int, + top_p: float, + verbose: bool) -> str: + """Fallback HTTP implementation for OpenAI-compatible API endpoints (including local Ollama).""" + data = { + "model": self.model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + } + + headers = {"Content-Type": "application/json"} + # If an API key is available, add an Authorization header (for proxies that expect it) + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + for attempt in range(self.max_retries): + try: + if verbose: + logger.info(f"Sending HTTP request to API endpoint {self.api_base} for model {self.model}...") + + response = requests.post( + f"{self.api_base}/chat/completions", + headers=headers, + data=json.dumps(data), + timeout=180, + ) + + if verbose: + logger.info(f"Received response with status code: {response.status_code}") + + response.raise_for_status() + resp_json = response.json() + + # Try standard OpenAI-compatible response shape + try: + return resp_json["choices"][0]["message"]["content"] + except Exception: + # Try Llama/Ollama-like shapes + if "completion_message" in resp_json: + comp = resp_json["completion_message"] + if isinstance(comp, dict) and "content" in comp: + content = comp["content"] + if isinstance(content, dict) and "text" in content: + return content["text"] + elif isinstance(content, str): + return content + + # If we couldn't extract, raise to trigger retry/debugging + raise ValueError("Could not extract content from API endpoint response") + + except (requests.exceptions.RequestException, KeyError, IndexError, ValueError) as e: + if verbose: + logger.error(f"API endpoint error (attempt {attempt+1}/{self.max_retries}): {e}") + if attempt == self.max_retries - 1: + raise Exception(f"Failed to get API endpoint completion after {self.max_retries} attempts: {str(e)}") + time.sleep(self.retry_delay * (attempt + 1)) + + def _ollama_chat_completion(self, + messages: List[Dict[str, str]], + temperature: float, + max_tokens: int, + top_p: float, + verbose: bool) -> str: + """Generate a chat completion using the Ollama API""" + # Convert OpenAI-style messages to Ollama format + # Ollama expects a single prompt string, so we need to format the messages + prompt = self._format_messages_for_ollama(messages) + + data = { + "model": self.model, + "prompt": prompt, + "options": { + "temperature": temperature, + "num_predict": max_tokens, + "top_p": top_p + }, + "stream": False + } + + for attempt in range(self.max_retries): + try: + if verbose: + logger.info(f"Sending request to Ollama model {self.model}...") + + response = requests.post( + f"{self.api_base}/api/generate", + headers={"Content-Type": "application/json"}, + data=json.dumps(data), + timeout=180 + ) + + if verbose: + logger.info(f"Received response with status code: {response.status_code}") + + response.raise_for_status() + response_json = 
response.json() + + # Ollama returns the response in the 'response' field + if 'response' in response_json: + return response_json['response'] + else: + raise ValueError("No 'response' field in Ollama response") + + except (requests.exceptions.RequestException, KeyError, ValueError) as e: + if verbose: + logger.error(f"Ollama API error (attempt {attempt+1}/{self.max_retries}): {e}") + if attempt == self.max_retries - 1: + raise Exception(f"Failed to get Ollama completion after {self.max_retries} attempts: {str(e)}") + time.sleep(self.retry_delay * (attempt + 1)) + except Exception as e: + if verbose: + logger.error(f"Ollama API error (attempt {attempt+1}/{self.max_retries}): {e}") + if attempt == self.max_retries - 1: + raise Exception(f"Failed to get Ollama completion after {self.max_retries} attempts: {str(e)}") + time.sleep(self.retry_delay * (attempt + 1)) + + def _format_messages_for_ollama(self, messages: List[Dict[str, str]]) -> str: + """Convert OpenAI-style messages to a single prompt string for Ollama""" + formatted_parts = [] + + for message in messages: + role = message.get('role', 'user') + content = message.get('content', '') + + if role == 'system': + formatted_parts.append(f"System: {content}") + elif role == 'user': + formatted_parts.append(f"User: {content}") + elif role == 'assistant': + formatted_parts.append(f"Assistant: {content}") + else: + formatted_parts.append(f"{role.title()}: {content}") + + # Add a final "Assistant:" to prompt for a response + formatted_parts.append("Assistant:") + + return "\n\n".join(formatted_parts) def _vllm_chat_completion(self, messages: List[Dict[str, str]], @@ -340,6 +530,8 @@ def batch_completion(self, if self.provider == 'api-endpoint': return self._openai_batch_completion(message_batches, temperature, max_tokens, top_p, batch_size, verbose) + elif self.provider == 'ollama': + return self._ollama_batch_completion(message_batches, temperature, max_tokens, top_p, batch_size, verbose) else: # Default to vLLM return self._vllm_batch_completion(message_batches, temperature, max_tokens, top_p, batch_size, verbose) @@ -523,8 +715,16 @@ async def process_batch(): # Process all messages in the batch concurrently return await asyncio.gather(*tasks) - # Run the async batch processing - batch_results = asyncio.run(process_batch()) + # If we're not using the OpenAI SDK client, run a synchronous requests-based + # batch here instead of the async SDK-based processing. 
+ if not getattr(self, 'use_openai_client', True): + # Process each messages set synchronously + batch_results = [] + for messages in batch_chunk: + batch_results.append(self._api_endpoint_chat_completion(messages, temperature, max_tokens, top_p, verbose)) + else: + # Run the async batch processing using the SDK + batch_results = asyncio.run(process_batch()) results.extend(batch_results) # Small delay between batches to avoid rate limits @@ -533,6 +733,40 @@ async def process_batch(): return results + def _ollama_batch_completion(self, + message_batches: List[List[Dict[str, str]]], + temperature: float, + max_tokens: int, + top_p: float, + batch_size: int, + verbose: bool) -> List[str]: + """Process multiple message sets in batches using Ollama's API""" + results = [] + + # Process message batches in chunks to avoid overloading the server + for i in range(0, len(message_batches), batch_size): + batch_chunk = message_batches[i:i+batch_size] + if verbose: + logger.info(f"Processing Ollama batch {i//batch_size + 1}/{(len(message_batches) + batch_size - 1) // batch_size} with {len(batch_chunk)} requests") + + try: + # Process each request sequentially for Ollama (it typically doesn't handle concurrent well) + batch_results = [] + for messages in batch_chunk: + result = self._ollama_chat_completion(messages, temperature, max_tokens, top_p, verbose) + batch_results.append(result) + + results.extend(batch_results) + + except Exception as e: + raise Exception(f"Failed to process Ollama batch: {str(e)}") + + # Small delay between batches + if i + batch_size < len(message_batches): + time.sleep(self.sleep_time) + + return results + def _vllm_batch_completion(self, message_batches: List[List[Dict[str, str]]], temperature: float, diff --git a/synthetic_data_kit/utils/config.py b/synthetic_data_kit/utils/config.py index e17600f7..f8f26654 100644 --- a/synthetic_data_kit/utils/config.py +++ b/synthetic_data_kit/utils/config.py @@ -79,13 +79,14 @@ def get_llm_provider(config: Dict[str, Any]) -> str: """Get the selected LLM provider Returns: - String with provider name: 'vllm' or 'api-endpoint' + String with provider name: 'vllm', 'api-endpoint', or 'ollama' """ llm_config = config.get('llm', {}) provider = llm_config.get('provider', 'vllm') print(f"get_llm_provider returning: {provider}") - if provider != 'api-endpoint' and 'llm' in config and 'provider' in config['llm'] and config['llm']['provider'] == 'api-endpoint': - print(f"WARNING: Config has 'api-endpoint' but returning '{provider}'") + if provider not in ['vllm', 'api-endpoint', 'ollama']: + print(f"WARNING: Unknown provider '{provider}', falling back to 'vllm'") + return 'vllm' return provider def get_vllm_config(config: Dict[str, Any]) -> Dict[str, Any]: @@ -108,6 +109,16 @@ def get_openai_config(config: Dict[str, Any]) -> Dict[str, Any]: 'retry_delay': 1.0 }) +def get_ollama_config(config: Dict[str, Any]) -> Dict[str, Any]: + """Get Ollama configuration""" + return config.get('ollama', { + 'api_base': 'http://localhost:11434', + 'model': 'llama3.2', + 'max_retries': 3, + 'retry_delay': 1.0, + 'sleep_time': 0.1 + }) + def get_generation_config(config: Dict[str, Any]) -> Dict[str, Any]: """Get generation configuration""" return config.get('generation', { diff --git a/tests/integration/test_ollama_integration.py b/tests/integration/test_ollama_integration.py new file mode 100644 index 00000000..2bff747c --- /dev/null +++ b/tests/integration/test_ollama_integration.py @@ -0,0 +1,204 @@ +"""Integration tests for Ollama provider.""" + +import json 
+import tempfile +from unittest.mock import MagicMock, patch + +import pytest + +from synthetic_data_kit.models.llm_client import LLMClient + + +@pytest.mark.integration +def test_ollama_client_integration(patch_config, test_env): + """Test Ollama client integration with mocked server responses.""" + with patch("requests.get") as mock_get, patch("requests.post") as mock_post: + # Mock Ollama server check + mock_check_response = MagicMock() + mock_check_response.status_code = 200 + mock_check_response.json.return_value = { + "models": [ + { + "name": "llama3.2:latest", + "modified_at": "2024-01-01T00:00:00Z", + "size": 4684876794 + } + ] + } + mock_get.return_value = mock_check_response + + # Mock Ollama generation response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "model": "llama3.2", + "created_at": "2024-01-01T00:00:00Z", + "response": "Synthetic data is artificially generated data that mimics real-world data patterns.", + "done": True + } + mock_post.return_value = mock_response + + # Initialize Ollama client + client = LLMClient(provider="ollama") + + # Test single chat completion + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is synthetic data?"} + ] + + response = client.chat_completion(messages, temperature=0.7, max_tokens=100) + + # Verify response + assert isinstance(response, str) + assert len(response) > 0 + assert "synthetic data" in response.lower() + + # Verify API calls + assert mock_get.called # Server check + assert mock_post.called # Generation call + + # Verify the generation endpoint was called + post_call_args = mock_post.call_args + assert "/api/generate" in post_call_args[0][0] + + # Verify request payload structure + request_data = json.loads(post_call_args[1]['data']) + assert 'model' in request_data + assert 'prompt' in request_data + assert 'options' in request_data + assert request_data['stream'] is False + + +@pytest.mark.integration +def test_ollama_batch_processing(patch_config, test_env): + """Test Ollama client batch processing.""" + with patch("requests.get") as mock_get, patch("requests.post") as mock_post: + # Mock Ollama server check + mock_check_response = MagicMock() + mock_check_response.status_code = 200 + mock_check_response.json.return_value = {"models": [{"name": "llama3.2"}]} + mock_get.return_value = mock_check_response + + # Mock Ollama generation responses + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.side_effect = [ + {"response": "Response to question 1 about synthetic data."}, + {"response": "Response to question 2 about machine learning."}, + {"response": "Response to question 3 about data privacy."} + ] + mock_post.return_value = mock_response + + # Initialize Ollama client + client = LLMClient(provider="ollama") + + # Test batch processing + message_batches = [ + [{"role": "user", "content": "What is synthetic data?"}], + [{"role": "user", "content": "How is machine learning used?"}], + [{"role": "user", "content": "Why is data privacy important?"}] + ] + + responses = client.batch_completion( + message_batches, + temperature=0.7, + max_tokens=100, + batch_size=2 + ) + + # Verify responses + assert len(responses) == 3 + for i, response in enumerate(responses): + assert isinstance(response, str) + assert len(response) > 0 + assert f"Response to question {i+1}" in response + + # Verify API calls + assert mock_get.called # Server check + assert mock_post.call_count == 3 # 
Three generation calls + + +@pytest.mark.integration +def test_ollama_message_formatting_edge_cases(patch_config, test_env): + """Test Ollama message formatting with various edge cases.""" + with patch("requests.get") as mock_get: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"models": [{"name": "llama3.2"}]} + mock_get.return_value = mock_response + + client = LLMClient(provider="ollama") + + # Test empty messages + messages = [] + formatted = client._format_messages_for_ollama(messages) + assert formatted == "Assistant:" + + # Test single system message + messages = [{"role": "system", "content": "You are helpful."}] + formatted = client._format_messages_for_ollama(messages) + expected = "System: You are helpful.\n\nAssistant:" + assert formatted == expected + + # Test unknown role + messages = [{"role": "custom", "content": "Custom content"}] + formatted = client._format_messages_for_ollama(messages) + expected = "Custom: Custom content\n\nAssistant:" + assert formatted == expected + + # Test empty content + messages = [{"role": "user", "content": ""}] + formatted = client._format_messages_for_ollama(messages) + expected = "User: \n\nAssistant:" + assert formatted == expected + + +@pytest.mark.integration +def test_ollama_error_handling(patch_config, test_env): + """Test Ollama client error handling.""" + with patch("requests.get") as mock_get, patch("requests.post") as mock_post: + # Mock Ollama server check + mock_check_response = MagicMock() + mock_check_response.status_code = 200 + mock_check_response.json.return_value = {"models": [{"name": "llama3.2"}]} + mock_get.return_value = mock_check_response + + # Initialize Ollama client + client = LLMClient(provider="ollama") + + # Test HTTP error handling + mock_post.side_effect = Exception("Connection error") + + messages = [{"role": "user", "content": "Test message"}] + + with pytest.raises(Exception) as exc_info: + client.chat_completion(messages) + + assert "Failed to get Ollama completion" in str(exc_info.value) + + # Test malformed response + mock_post.side_effect = None + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"invalid": "response"} # Missing 'response' field + mock_post.return_value = mock_response + + with pytest.raises(Exception) as exc_info: + client.chat_completion(messages) + + assert "Failed to get Ollama completion" in str(exc_info.value) + + +@pytest.mark.integration +def test_ollama_server_unavailable(patch_config, test_env): + """Test Ollama client when server is unavailable.""" + with patch("requests.get") as mock_get: + # Mock server unavailable + mock_get.side_effect = Exception("Connection refused") + + # Should raise ConnectionError during initialization + with pytest.raises(ConnectionError) as exc_info: + LLMClient(provider="ollama") + + assert "Ollama server not available" in str(exc_info.value) \ No newline at end of file diff --git a/tests/unit/test_llm_client.py b/tests/unit/test_llm_client.py index d030f652..cd93a9f1 100644 --- a/tests/unit/test_llm_client.py +++ b/tests/unit/test_llm_client.py @@ -123,3 +123,129 @@ def test_llm_client_vllm_chat_completion(patch_config, test_env): assert response == "This is a test response" # Check that vLLM API was called assert mock_post.called + + +@pytest.mark.unit +def test_llm_client_ollama_initialization(patch_config, test_env): + """Test LLM client initialization with Ollama provider.""" + with patch("requests.get") as mock_get: + mock_response = 
MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"models": [{"name": "llama3.2"}]} + mock_get.return_value = mock_response + + # Initialize client + client = LLMClient(provider="ollama") + + # Check that the client was initialized correctly + assert client.provider == "ollama" + assert client.api_base is not None + assert client.model is not None + # Check that Ollama server was checked + assert mock_get.called + + +@pytest.mark.unit +def test_llm_client_ollama_chat_completion(patch_config, test_env): + """Test LLM client chat completion with Ollama provider.""" + with patch("requests.get") as mock_get, patch("requests.post") as mock_post: + # Mock Ollama server check + mock_check_response = MagicMock() + mock_check_response.status_code = 200 + mock_check_response.json.return_value = {"models": [{"name": "llama3.2"}]} + mock_get.return_value = mock_check_response + + # Mock Ollama API response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "response": "This is a test response from Ollama" + } + mock_post.return_value = mock_response + + # Initialize client + client = LLMClient(provider="ollama") + + # Test chat completion + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is synthetic data?"}, + ] + + response = client.chat_completion(messages, temperature=0.7) + + # Check that the response is correct + assert response == "This is a test response from Ollama" + # Check that Ollama API was called + assert mock_post.called + + # Check that the correct endpoint was called + call_args = mock_post.call_args + assert "/api/generate" in call_args[0][0] + + +@pytest.mark.unit +def test_ollama_message_formatting(): + """Test Ollama message formatting helper function.""" + with patch("requests.get") as mock_get: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"models": [{"name": "llama3.2"}]} + mock_get.return_value = mock_response + + client = LLMClient(provider="ollama") + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is synthetic data?"}, + {"role": "assistant", "content": "Synthetic data is artificially generated data."}, + {"role": "user", "content": "Can you explain more?"} + ] + + formatted = client._format_messages_for_ollama(messages) + + expected = ("System: You are a helpful assistant.\n\n" + "User: What is synthetic data?\n\n" + "Assistant: Synthetic data is artificially generated data.\n\n" + "User: Can you explain more?\n\n" + "Assistant:") + + assert formatted == expected + + +@pytest.mark.unit +def test_llm_client_ollama_batch_completion(patch_config, test_env): + """Test LLM client batch completion with Ollama provider.""" + with patch("requests.get") as mock_get, patch("requests.post") as mock_post: + # Mock Ollama server check + mock_check_response = MagicMock() + mock_check_response.status_code = 200 + mock_check_response.json.return_value = {"models": [{"name": "llama3.2"}]} + mock_get.return_value = mock_check_response + + # Mock Ollama API responses + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.side_effect = [ + {"response": "Response 1"}, + {"response": "Response 2"} + ] + mock_post.return_value = mock_response + + # Initialize client + client = LLMClient(provider="ollama") + + # Test batch completion + message_batches = [ + [{"role": "user", "content": "Hello 
1"}], + [{"role": "user", "content": "Hello 2"}] + ] + + responses = client.batch_completion(message_batches, temperature=0.7, batch_size=1) + + # Check that the responses are correct + assert len(responses) == 2 + assert responses[0] == "Response 1" + assert responses[1] == "Response 2" + # Check that Ollama API was called twice + assert mock_post.call_count == 2