Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions hindsight-api-slim/hindsight_api/engine/providers/gemini_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,16 +258,13 @@ async def call(
else:
gemini_contents.append(genai_types.Content(role="user", parts=[genai_types.Part(text=content)]))

# Add the JSON schema as a textual hint in the system_instruction (matching
# the normal uncached path). Structured output is still enforced via
# response_schema regardless; this is just guidance text.
if response_format is not None and hasattr(response_format, "model_json_schema"):
def _system_instruction_with_schema() -> str:
schema = response_format.model_json_schema()
schema_msg = f"\n\nYou must respond with valid JSON matching this schema:\n{json.dumps(schema, indent=2, ensure_ascii=False)}"
if system_instruction:
system_instruction += schema_msg
else:
system_instruction = schema_msg
schema_msg = (
f"\n\nYou must respond with valid JSON matching this schema:\n"
f"{json.dumps(schema, indent=2, ensure_ascii=False)}"
)
return (system_instruction + schema_msg) if system_instruction else schema_msg

# Apply safety settings: context var (per-request bank override) takes precedence over instance default
effective_safety_settings = _safety_settings_ctx.get()
Expand All @@ -287,9 +284,15 @@ def _build_generation_config(use_cache: bool) -> "genai_types.GenerateContentCon
self._apply_service_tier(config_kwargs)
if use_cache:
config_kwargs["cached_content"] = cached_prefix
elif (
use_schema_prompt_fallback
and response_format is not None
and hasattr(response_format, "model_json_schema")
):
config_kwargs["system_instruction"] = _system_instruction_with_schema()
elif system_instruction:
config_kwargs["system_instruction"] = system_instruction
if response_format is not None:
if response_format is not None and not use_schema_prompt_fallback:
config_kwargs["response_mime_type"] = "application/json"
config_kwargs["response_schema"] = response_format
if temperature is not None:
Expand All @@ -307,6 +310,7 @@ def _build_generation_config(use_cache: bool) -> "genai_types.GenerateContentCon
return genai_types.GenerateContentConfig(**config_kwargs) if config_kwargs else None

cache_active = using_cache
use_schema_prompt_fallback = False
generation_config = _build_generation_config(cache_active)

last_exception = None
Expand Down Expand Up @@ -430,6 +434,19 @@ def _build_generation_config(use_cache: bool) -> "genai_types.GenerateContentCon

except json.JSONDecodeError as e:
last_exception = e
if (
attempt < max_retries
and response_format is not None
and hasattr(response_format, "model_json_schema")
and not cache_active
and not use_schema_prompt_fallback
):
logger.warning("Gemini returned invalid JSON, retrying with prompt-side schema guidance...")
cache_active = False
use_schema_prompt_fallback = True
generation_config = _build_generation_config(cache_active)
await asyncio.sleep(min(initial_backoff * (2**attempt), max_backoff))
continue
if attempt < max_retries:
logger.warning("Gemini returned invalid JSON, retrying...")
backoff = min(initial_backoff * (2**attempt), max_backoff)
Expand Down
140 changes: 140 additions & 0 deletions hindsight-api-slim/tests/test_llm_extra_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,146 @@ async def test_gemini_extra_body_service_tier_takes_precedence():
assert provider._extra_body["http_options"]["extra_body"]["service_tier"] == "standard"


@pytest.mark.asyncio
async def test_gemini_structured_call_uses_native_schema_without_prompt_duplicate():
    """A structured call carries the model in response_schema and leaves the system prompt as-is.

    The mocked client returns valid JSON, and the config captured from the
    most recent ``generate_content`` call is inspected: it must request JSON
    output with the Pydantic model as schema, while the system instruction
    stays exactly the system message that was passed in.
    """
    from pydantic import BaseModel

    class StructuredAnswer(BaseModel):
        answer: str

    provider = _make_gemini_provider()
    fake_response = _fake_gemini_response()
    fake_response.text = '{"answer": "ok"}'
    generate_mock = AsyncMock(return_value=fake_response)
    provider._client.aio.models.generate_content = generate_mock

    conversation = [
        {"role": "system", "content": "Return concise JSON."},
        {"role": "user", "content": "hello"},
    ]
    result = await provider.call(
        messages=conversation,
        response_format=StructuredAnswer,
        scope="test",
    )

    sent_config = generate_mock.call_args.kwargs.get("config")

    # Parsed result comes back as the structured model.
    assert result.answer == "ok"
    # Schema travels through the native structured-output fields...
    assert sent_config.response_mime_type == "application/json"
    assert sent_config.response_schema is StructuredAnswer
    # ...and no schema guidance text is appended to the system prompt.
    assert sent_config.system_instruction == "Return concise JSON."
    assert "valid JSON matching this schema" not in sent_config.system_instruction


@pytest.mark.asyncio
async def test_gemini_cached_structured_call_keeps_native_schema():
    """With a cached prefix, the request still carries response_schema.

    The captured config must reference the cached content, have no
    system_instruction of its own, and still request JSON output with the
    Pydantic model as schema.
    """
    from pydantic import BaseModel

    class StructuredAnswer(BaseModel):
        answer: str

    provider = _make_gemini_provider()
    fake_response = _fake_gemini_response()
    fake_response.text = '{"answer": "ok"}'
    generate_mock = AsyncMock(return_value=fake_response)
    provider._client.aio.models.generate_content = generate_mock

    conversation = [
        {"role": "system", "content": "Return concise JSON."},
        {"role": "user", "content": "hello"},
    ]
    result = await provider.call(
        messages=conversation,
        response_format=StructuredAnswer,
        cached_prefix="cachedContents/test",
        scope="test",
    )

    sent_config = generate_mock.call_args.kwargs.get("config")

    assert result.answer == "ok"
    # The cache reference is forwarded and no system instruction is set in the config.
    assert sent_config.cached_content == "cachedContents/test"
    assert sent_config.system_instruction is None
    # Structured output fields are still set on the request itself.
    assert sent_config.response_mime_type == "application/json"
    assert sent_config.response_schema is StructuredAnswer


@pytest.mark.asyncio
async def test_gemini_structured_parse_failure_falls_back_to_prompt_schema():
    """Unparseable output on the first attempt triggers a prompt-schema retry.

    The mock yields non-JSON text first and valid JSON second. The first
    request must use the native response_schema; the second must drop
    response_schema/response_mime_type and instead append the schema
    guidance text to the original system instruction.
    """
    from pydantic import BaseModel

    class StructuredAnswer(BaseModel):
        answer: str

    provider = _make_gemini_provider()
    malformed = _fake_gemini_response()
    malformed.text = "not json"
    well_formed = _fake_gemini_response()
    well_formed.text = '{"answer": "ok"}'
    generate_mock = AsyncMock(side_effect=[malformed, well_formed])
    provider._client.aio.models.generate_content = generate_mock

    conversation = [
        {"role": "system", "content": "Return concise JSON."},
        {"role": "user", "content": "hello"},
    ]
    # One retry allowed, with zero backoff so the test does not wait.
    retry_options = {"max_retries": 1, "initial_backoff": 0, "max_backoff": 0}
    result = await provider.call(
        messages=conversation,
        response_format=StructuredAnswer,
        scope="test",
        **retry_options,
    )

    recorded_calls = generate_mock.call_args_list
    first_config = recorded_calls[0].kwargs["config"]
    fallback_config = recorded_calls[1].kwargs["config"]

    assert result.answer == "ok"
    # Attempt 1: native schema, untouched system prompt.
    assert first_config.response_schema is StructuredAnswer
    assert first_config.system_instruction == "Return concise JSON."
    # Attempt 2: native schema fields removed, schema described in the prompt instead.
    assert fallback_config.response_schema is None
    assert fallback_config.response_mime_type is None
    assert fallback_config.system_instruction.startswith("Return concise JSON.")
    assert "valid JSON matching this schema" in fallback_config.system_instruction
    assert '"answer"' in fallback_config.system_instruction


@pytest.mark.asyncio
async def test_gemini_cached_parse_retry_keeps_cached_native_schema():
    """A parse-failure retry on a cached call keeps the cache and the native schema.

    The mock yields non-JSON text first and valid JSON second. Both requests
    must reference the cached content, and the retry must still use
    response_schema/response_mime_type with no system_instruction set.
    """
    from pydantic import BaseModel

    class StructuredAnswer(BaseModel):
        answer: str

    provider = _make_gemini_provider()
    malformed = _fake_gemini_response()
    malformed.text = "not json"
    well_formed = _fake_gemini_response()
    well_formed.text = '{"answer": "ok"}'
    generate_mock = AsyncMock(side_effect=[malformed, well_formed])
    provider._client.aio.models.generate_content = generate_mock

    conversation = [
        {"role": "system", "content": "Return concise JSON."},
        {"role": "user", "content": "hello"},
    ]
    # One retry allowed, with zero backoff so the test does not wait.
    retry_options = {"max_retries": 1, "initial_backoff": 0, "max_backoff": 0}
    result = await provider.call(
        messages=conversation,
        response_format=StructuredAnswer,
        cached_prefix="cachedContents/test",
        scope="test",
        **retry_options,
    )

    recorded_calls = generate_mock.call_args_list
    first_config = recorded_calls[0].kwargs["config"]
    retry_config = recorded_calls[1].kwargs["config"]

    assert result.answer == "ok"
    # Cache reference is present on both the initial request and the retry.
    assert first_config.cached_content == "cachedContents/test"
    assert retry_config.cached_content == "cachedContents/test"
    # The retry does not switch to prompt-side schema guidance.
    assert retry_config.response_schema is StructuredAnswer
    assert retry_config.response_mime_type == "application/json"
    assert retry_config.system_instruction is None


# ─── LiteLLM ──────────────────────────────────────────────────────────────────


Expand Down