64 changes: 39 additions & 25 deletions synthetic_data_kit/models/llm_client.py
@@ -12,6 +12,7 @@
 import logging
 import asyncio
 from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_openai_config, get_llm_provider

@@ -79,7 +80,7 @@ def __init__(self,
             self.max_retries = max_retries or api_endpoint_config.get('max_retries')
             self.retry_delay = retry_delay or api_endpoint_config.get('retry_delay')
             self.sleep_time = api_endpoint_config.get('sleep_time',0.5)
-
+
             # Initialize OpenAI client
             self._init_openai_client()
         else: # Default to vLLM
@@ -92,6 +93,9 @@
             self.max_retries = max_retries or vllm_config.get('max_retries')
             self.retry_delay = retry_delay or vllm_config.get('retry_delay')
             self.sleep_time = vllm_config.get('sleep_time',0.1)
+            self.max_concurrency = int(
+                os.environ.get("SDK_VLLM_MAX_CONCURRENCY", vllm_config.get("max_concurrency", 8))
+            )
 
             # No client to initialize for vLLM as we use requests directly
             # Verify server is running
@@ -533,6 +537,22 @@ async def process_batch():
 
         return results
 
+    def _vllm_post_once(self, request_data, verbose: bool) -> str:
+        """One /chat/completions call; returns the content string or raises."""
+        if verbose:
+            logger.info(f"Sending batch request to vLLM model {self.model}...")
+        response = requests.post(
+            f"{self.api_base}/chat/completions",
+            headers={"Content-Type": "application/json"},
+            data=json.dumps(request_data),
+            timeout=180,
+        )
+        if verbose:
+            logger.info(f"Received response with status code: {response.status_code}")
+        response.raise_for_status()
+        data = response.json()
+        return data["choices"][0]["message"]["content"]
+
     def _vllm_batch_completion(self,
                                message_batches: List[List[Dict[str, str]]],
                                temperature: float,
@@ -560,33 +580,27 @@ def _vllm_batch_completion(self,
                     "top_p": top_p
                 })
 
+            # Run requests concurrently while preserving order and isolating errors
+            max_workers = max(1, min(self.max_concurrency, len(batch_requests)))
+            batch_results = [""] * len(batch_requests)
             try:
-                # For now, we run these in parallel with multiple requests
-                batch_results = []
-                for request_data in batch_requests:
-                    # Only print if verbose mode is enabled
-                    if verbose:
-                        logger.info(f"Sending batch request to vLLM model {self.model}...")
-
-                    response = requests.post(
-                        f"{self.api_base}/chat/completions",
-                        headers={"Content-Type": "application/json"},
-                        data=json.dumps(request_data),
-                        timeout=180  # Increased timeout for batch processing
-                    )
-
-                    if verbose:
-                        logger.info(f"Received response with status code: {response.status_code}")
-
-                    response.raise_for_status()
-                    content = response.json()["choices"][0]["message"]["content"]
-                    batch_results.append(content)
-
+                with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                    future_to_index = {
+                        executor.submit(self._vllm_post_once, req, verbose): idx
+                        for idx, req in enumerate(batch_requests)
+                    }
+                    for future in as_completed(future_to_index):
+                        idx = future_to_index[future]
+                        try:
+                            batch_results[idx] = future.result()
+                        except Exception as e:
+                            # Don't fail the whole batch; record error in-place
+                            batch_results[idx] = f"ERROR: {e}"
                 results.extend(batch_results)
-            except (requests.exceptions.RequestException, KeyError, IndexError) as e:
+            except Exception as e:
+                # If the executor itself failed for unexpected reasons
                 raise Exception(f"Failed to process vLLM batch: {str(e)}")
 
             # Small delay between batches
             if i + batch_size < len(message_batches):
                 time.sleep(self.sleep_time)
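
A minimal usage sketch of the new concurrent batch path (not part of the diff). The endpoint URL, model name, message contents, and concurrency value are assumptions borrowed from the tests below; the method and parameter names come from the change itself:

import os

from synthetic_data_kit.models.llm_client import LLMClient

# SDK_VLLM_MAX_CONCURRENCY is read in __init__, so set it before constructing
# the client; it overrides the vllm.max_concurrency config key (default 8).
os.environ["SDK_VLLM_MAX_CONCURRENCY"] = "4"

client = LLMClient(
    provider="vllm",
    api_base="http://localhost:8000/v1",  # assumed local vLLM server
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # must match the server's model
)

batches = [[{"role": "user", "content": f"Summarize document {i}."}] for i in range(10)]
out = client._vllm_batch_completion(
    message_batches=batches,
    temperature=0.0,
    max_tokens=64,
    top_p=0.95,
    batch_size=5,  # two batches of five; at most 4 requests in flight within each
    verbose=False,
)
# Order is preserved; a failed request yields an "ERROR: ..." string at its index.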
191 changes: 191 additions & 0 deletions tests/unit/test_llm_client_hybrid.py
@@ -0,0 +1,191 @@
import os
import time
import json
import pytest
import requests

from synthetic_data_kit.models.llm_client import LLMClient

# ---------- Live server detection helpers ----------

def _live_base():
return os.getenv("LIVE_VLLM_BASE", "http://localhost:8000/v1")

def _vllm_is_live(base: str, timeout=1.0) -> bool:
try:
r = requests.get(f"{base}/models", timeout=timeout)
return r.status_code == 200
except Exception:
return False

@pytest.fixture(scope="session")
def live_info():
base = _live_base()
return {"base": base, "up": _vllm_is_live(base)}

# ---------- Live path (no mocks) ----------

@pytest.mark.integration
def test_live_vllm_smoke(live_info):
"""
If a real vLLM server is reachable, run a real call.
Otherwise skip this test automatically.
"""
if not live_info["up"]:
pytest.skip(f"No live vLLM at {_live_base()} — skipping live smoke.")

client = LLMClient(
provider="vllm",
api_base=live_info["base"],
model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", # must match the server's model
max_retries=2,
retry_delay=0.2,
)
out = client.chat_completion(
messages=[{"role": "user", "content": "Say 'pong' and nothing else."}],
temperature=0.0,
max_tokens=16,
)
assert isinstance(out, str)
assert "pong" in out.lower()

# ---------- Mock path (runs when live is NOT available) ----------

@pytest.fixture
def patch_config(monkeypatch):
"""
Force config -> provider=vllm and point to the default base.
"""
def fake_load_config(_path=None):
return {
"llm": {"provider": "vllm"},
"vllm": {
"api_base": _live_base(),
"model": "fake-model",
"max_retries": 2,
"retry_delay": 0.05,
"sleep_time": 0.01,
"max_concurrency": 4,
},
"generation": {"temperature": 0.0, "max_tokens": 32, "top_p": 0.95, "batch_size": 8},
}
monkeypatch.setattr("synthetic_data_kit.models.llm_client.load_config", fake_load_config)
return fake_load_config

@pytest.fixture
def test_env(monkeypatch, live_info):
"""
If live vLLM is DOWN, we stub _check_vllm_server and requests.post.
If live is UP, we do nothing (so the live smoke test can run).
"""
if live_info["up"]:
# Live server exists; let the live test handle it.
return

# --- stub health check to "available" ---
def fake_check_vllm_server(self):
return True, {"models": [{"id": "fake-model"}]}
monkeypatch.setattr("synthetic_data_kit.models.llm_client.LLMClient._check_vllm_server", fake_check_vllm_server)

# --- stub POST /chat/completions to return deterministic JSON ---
def fake_post(url, headers=None, data=None, timeout=None):
body = json.loads(data or "{}")
# Simulate a small stagger so futures may complete out-of-order
time.sleep(0.01 if "slow" not in (headers or {}) else 0.02)

class Resp:
status_code = 200
def raise_for_status(self): pass
def json(self):
# Echo back the last user message content
msgs = body.get("messages", [])
content = next((m["content"] for m in reversed(msgs) if m.get("role") == "user"), "")
return {"choices": [{"message": {"content": f"echo:{content}"}}]}
return Resp()

monkeypatch.setattr("synthetic_data_kit.models.llm_client.requests.post", fake_post)


# --- Tests that run under mocks when live server is not available ---

@pytest.mark.unit
def test_vllm_batch_completion_preserves_order(patch_config, test_env, live_info):
if live_info["up"]:
pytest.skip("Live vLLM detected; order test is only for the mock path.")

client = LLMClient(provider="vllm") # config fixture supplies base/model
batches = [
[{"role": "user", "content": f"u{i}"}]
for i in range(5)
]
out = client._vllm_batch_completion(
message_batches=batches,
temperature=0.0,
max_tokens=16,
top_p=0.95,
batch_size=5,
verbose=False,
)
assert out == [f"echo:u{i}" for i in range(5)]

@pytest.mark.unit
def test_vllm_batch_completion_isolates_errors(monkeypatch, patch_config, test_env, live_info):
if live_info["up"]:
pytest.skip("Live vLLM detected; error-isolation test is only for the mock path.")

    # Patch one POST to raise while delegating the rest; note that under the
    # mock path, requests.post here is already the fake_post installed by test_env
    real_post = requests.post

def flaky_post(url, headers=None, data=None, timeout=None):
body = json.loads(data or "{}")
msgs = body.get("messages", [])
content = next((m["content"] for m in reversed(msgs) if m.get("role") == "user"), "")
if content == "boom":
raise requests.exceptions.ConnectionError("simulated failure")
return real_post(url, headers=headers, data=data, timeout=timeout)

monkeypatch.setattr("synthetic_data_kit.models.llm_client.requests.post", flaky_post)

client = LLMClient(provider="vllm")
batches = [
[{"role": "user", "content": "ok1"}],
[{"role": "user", "content": "boom"}], # will fail
[{"role": "user", "content": "ok2"}],
]
out = client._vllm_batch_completion(
message_batches=batches,
temperature=0.0,
max_tokens=16,
top_p=0.95,
batch_size=3,
verbose=False,
)
assert out[0].startswith("echo:ok1")
assert out[1].startswith("ERROR:")
assert out[2].startswith("echo:ok2")

@pytest.mark.unit
def test_vllm_batch_respects_max_concurrency(monkeypatch, patch_config, test_env, live_info):
if live_info["up"]:
pytest.skip("Live vLLM detected; concurrency test is only for the mock path.")

monkeypatch.setenv("SDK_VLLM_MAX_CONCURRENCY", "2")
client = LLMClient(provider="vllm")
assert client.max_concurrency == 2

@pytest.mark.unit
def test_vllm_post_once_happy_path(patch_config, test_env, live_info):
if live_info["up"]:
pytest.skip("Live vLLM detected; post-once test is only for the mock path.")

client = LLMClient(provider="vllm")
payload = {
"model": "fake-model",
"messages": [{"role": "user", "content": "xyz"}],
"temperature": 0.0,
"max_tokens": 8,
"top_p": 0.95,
}
out = client._vllm_post_once(payload, verbose=False)
assert out == "echo:xyz"
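
How to run (a suggested invocation, assuming the unit and integration markers are registered in the project's pytest configuration, which this diff does not show): pytest -m unit tests/unit/test_llm_client_hybrid.py exercises the mock path. Exporting LIVE_VLLM_BASE (default http://localhost:8000/v1) and starting a vLLM server there lets pytest -m integration run the live smoke test, while the mock-path tests detect the live server and skip themselves automatically.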