Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 38 additions & 42 deletions generators/python/core_utilities/shared/pydantic_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import logging
from collections import defaultdict
from dataclasses import asdict
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -209,6 +208,30 @@ def _is_string_type(type_: Type[Any]) -> bool:
return False


# Memo table for _get_sse_type_info: maps id(type_) to the
# (discriminator, variants) tuple returned by _get_discriminator_and_variants.
_sse_type_info_cache: Dict[int, Tuple[Optional[str], Optional[List[Type[Any]]]]] = {}


def _get_sse_type_info(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]:
    """Return the (discriminator, variants) pair for ``type_``, memoized.

    The first call for a given type delegates to
    ``_get_discriminator_and_variants`` and stores the result in
    ``_sse_type_info_cache``; later calls return the stored tuple.

    NOTE(review): the cache key is ``id(type_)``, which is only unique among
    live objects. This assumes the types passed here stay alive for the
    lifetime of the process -- confirm.
    """
    cache_key = id(type_)
    cached = _sse_type_info_cache.get(cache_key)
    if cached is not None:
        return cached

    resolved = _get_discriminator_and_variants(type_)
    _sse_type_info_cache[cache_key] = resolved
    return resolved


def _validate_sse_data(type_: Type[T], data: Any) -> T:
    """Validate already-decoded SSE payload data against ``type_``.

    This deliberately skips convert_and_respect_annotation_metadata: SSE data
    is JSON straight off the wire and therefore already uses wire-format
    keys, so the TypedDict dealiasing pass is unnecessary and extremely
    expensive for large discriminated unions.

    Under Pydantic v2 the adapter from ``_get_type_adapter`` is used;
    otherwise validation goes through ``pydantic.parse_obj_as``.
    """
    if not IS_PYDANTIC_V2:
        return pydantic.parse_obj_as(type_, data)

    adapter = _get_type_adapter(type_)
    return adapter.validate_python(data)  # type: ignore[no-any-return]


def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
"""
Parse a ServerSentEvent into the appropriate type.
Expand All @@ -219,93 +242,66 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
The union describes the data content, not the SSE envelope.
-> Returns: json.loads(data) parsed into the type

Example: ChatStreamResponse with discriminator='type'
Input: ServerSentEvent(event="message", data='{"type": "content-delta", ...}', id="")
Output: ContentDeltaEvent (parsed from data, SSE envelope stripped)

2. Event-level discrimination: The discriminator (e.g., 'event') is at the SSE event level.
The union describes the full SSE event structure.
-> Returns: SSE envelope with 'data' field JSON-parsed only if the variant expects non-string

Example: JobStreamResponse with discriminator='event'
Input: ServerSentEvent(event="ERROR", data='{"code": "FAILED", ...}', id="123")
Output: JobStreamResponse_Error with data as ErrorData object

But for variants where data is str (like STATUS_UPDATE):
Input: ServerSentEvent(event="STATUS_UPDATE", data='{"status": "processing"}', id="1")
Output: JobStreamResponse_StatusUpdate with data as string (not parsed)

Args:
sse: The ServerSentEvent object to parse
type_: The target discriminated union type

Returns:
The parsed object of type T

Note:
This function is only available in SDK contexts where http_sse module exists.
"""
sse_event = asdict(sse)
discriminator, variants = _get_discriminator_and_variants(type_)
data_value = sse.data
discriminator, variants = _get_sse_type_info(type_)

if discriminator is None or variants is None:
# Not a discriminated union - parse the data field as JSON
data_value = sse_event.get("data")
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
return _validate_sse_data(type_, json.loads(data_value))
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)
return _validate_sse_data(type_, {"event": sse.event, "data": data_value, "id": sse.id, "retry": sse.retry})

data_value = sse_event.get("data")
sse_fields = {"event": sse.event, "data": data_value, "id": sse.id, "retry": sse.retry}

# Check if discriminator is at the top level (event-level discrimination)
if discriminator in sse_event:
# Case 2: Event-level discrimination
# Find the matching variant to check if 'data' field needs JSON parsing
disc_value = sse_event.get(discriminator)
if discriminator in sse_fields:
# Event-level discrimination
disc_value = sse_fields.get(discriminator)
matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value)

if matching_variant is not None:
# Check what type the variant expects for 'data'
data_type = _get_field_annotation(matching_variant, "data")
if data_type is not None and not _is_string_type(data_type):
# Variant expects non-string data - parse JSON
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
new_object = dict(sse_event)
new_object["data"] = parsed_data
return parse_obj_as(type_, new_object)
sse_fields["data"] = json.loads(data_value)
return _validate_sse_data(type_, sse_fields)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
# Either no matching variant, data is string type, or JSON parse failed
return parse_obj_as(type_, sse_event)
return _validate_sse_data(type_, sse_fields)

else:
# Case 1: Data-level discrimination
# The discriminator is inside the data payload - extract and parse data only
# Data-level discrimination
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
return _validate_sse_data(type_, json.loads(data_value))
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)
return _validate_sse_data(type_, sse_fields)


# NOTE(review): presumably the memo table behind _get_type_adapter, keyed by
# id(type) -- its consumer is not visible in this hunk; confirm.
_type_adapter_cache: Dict[int, Any] = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import logging
from collections import defaultdict
from dataclasses import asdict
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -207,6 +206,30 @@ def _is_string_type(type_: Type[Any]) -> bool:
return False


# Memo table for _get_sse_type_info: maps id(type_) to the
# (discriminator, variants) tuple returned by _get_discriminator_and_variants.
_sse_type_info_cache: Dict[int, Tuple[Optional[str], Optional[List[Type[Any]]]]] = {}


def _get_sse_type_info(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]:
    """Return the (discriminator, variants) pair for ``type_``, memoized.

    The first call for a given type delegates to
    ``_get_discriminator_and_variants`` and stores the result in
    ``_sse_type_info_cache``; later calls return the stored tuple.

    NOTE(review): the cache key is ``id(type_)``, which is only unique among
    live objects. This assumes the types passed here stay alive for the
    lifetime of the process -- confirm.
    """
    cache_key = id(type_)
    cached = _sse_type_info_cache.get(cache_key)
    if cached is not None:
        return cached

    resolved = _get_discriminator_and_variants(type_)
    _sse_type_info_cache[cache_key] = resolved
    return resolved


def _validate_sse_data(type_: Type[T], data: Any) -> T:
    """Validate already-decoded SSE payload data against ``type_``.

    This deliberately skips convert_and_respect_annotation_metadata: SSE data
    is JSON straight off the wire and therefore already uses wire-format
    keys, so the TypedDict dealiasing pass is unnecessary and extremely
    expensive for large discriminated unions.

    Under Pydantic v2 the adapter from ``_get_type_adapter`` is used;
    otherwise validation goes through ``pydantic.parse_obj_as``.
    """
    if not IS_PYDANTIC_V2:
        return pydantic.parse_obj_as(type_, data)

    adapter = _get_type_adapter(type_)
    return adapter.validate_python(data)  # type: ignore[no-any-return]


def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
"""
Parse a ServerSentEvent into the appropriate type.
Expand All @@ -225,73 +248,58 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
sse: The ServerSentEvent object to parse
type_: The target discriminated union type

Returns:
The parsed object of type T

Note:
This function is only available in SDK contexts where http_sse module exists.
"""
sse_event = asdict(sse)
discriminator, variants = _get_discriminator_and_variants(type_)
data_value = sse.data
discriminator, variants = _get_sse_type_info(type_)

if discriminator is None or variants is None:
# Not a discriminated union - parse the data field as JSON
data_value = sse_event.get("data")
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
return _validate_sse_data(type_, json.loads(data_value))
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)
return _validate_sse_data(type_, {"event": sse.event, "data": data_value, "id": sse.id, "retry": sse.retry})

data_value = sse_event.get("data")
sse_fields = {"event": sse.event, "data": data_value, "id": sse.id, "retry": sse.retry}

# Check if discriminator is at the top level (event-level discrimination)
if discriminator in sse_event:
# Case 2: Event-level discrimination
# Find the matching variant to check if 'data' field needs JSON parsing
disc_value = sse_event.get(discriminator)
if discriminator in sse_fields:
# Event-level discrimination
disc_value = sse_fields.get(discriminator)
matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value)

if matching_variant is not None:
# Check what type the variant expects for 'data'
data_type = _get_field_annotation(matching_variant, "data")
if data_type is not None and not _is_string_type(data_type):
# Variant expects non-string data - parse JSON
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
new_object = dict(sse_event)
new_object["data"] = parsed_data
return parse_obj_as(type_, new_object)
sse_fields["data"] = json.loads(data_value)
return _validate_sse_data(type_, sse_fields)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
# Either no matching variant, data is string type, or JSON parse failed
return parse_obj_as(type_, sse_event)
return _validate_sse_data(type_, sse_fields)

else:
# Case 1: Data-level discrimination
# The discriminator is inside the data payload - extract and parse data only
# Data-level discrimination
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
return _validate_sse_data(type_, json.loads(data_value))
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)
return _validate_sse_data(type_, sse_fields)


# NOTE(review): presumably the memo table behind _get_type_adapter, keyed by
# id(type) -- its consumer is not visible in this hunk; confirm.
_type_adapter_cache: Dict[int, Any] = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import logging
import warnings
from collections import defaultdict
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union, cast

import pydantic
Expand Down Expand Up @@ -125,6 +124,28 @@ def _is_string_type(type_: Type[Any]) -> bool:
return False


# Memo table for _get_sse_type_info: maps id(type_) to the
# (discriminator, variants) tuple returned by _get_discriminator_and_variants.
_sse_type_info_cache: Dict[int, Tuple[Optional[str], Optional[List[Type[Any]]]]] = {}


def _get_sse_type_info(type_: Type[Any]) -> Tuple[Optional[str], Optional[List[Type[Any]]]]:
    """Return the (discriminator, variants) pair for ``type_``, memoized.

    The first call for a given type delegates to
    ``_get_discriminator_and_variants`` and stores the result in
    ``_sse_type_info_cache``; later calls return the stored tuple.

    NOTE(review): the cache key is ``id(type_)``, which is only unique among
    live objects. This assumes the types passed here stay alive for the
    lifetime of the process -- confirm.
    """
    cache_key = id(type_)
    cached = _sse_type_info_cache.get(cache_key)
    if cached is not None:
        return cached

    resolved = _get_discriminator_and_variants(type_)
    _sse_type_info_cache[cache_key] = resolved
    return resolved


def _validate_sse_data(type_: Type[T], data: Any) -> T:
    """Validate already-decoded SSE payload data against ``type_``.

    This deliberately skips convert_and_respect_annotation_metadata: SSE data
    is JSON straight off the wire and therefore already uses wire-format
    keys, so the TypedDict dealiasing pass is unnecessary and extremely
    expensive for large discriminated unions.

    Validation is delegated to ``pydantic.v1.parse_obj_as``.
    """
    validated = pydantic.v1.parse_obj_as(type_, data)
    return validated


def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
"""
Parse a ServerSentEvent into the appropriate type.
Expand All @@ -143,73 +164,58 @@ def parse_sse_obj(sse: "ServerSentEvent", type_: Type[T]) -> T:
sse: The ServerSentEvent object to parse
type_: The target discriminated union type

Returns:
The parsed object of type T

Note:
This function is only available in SDK contexts where http_sse module exists.
"""
sse_event = asdict(sse)
discriminator, variants = _get_discriminator_and_variants(type_)
data_value = sse.data
discriminator, variants = _get_sse_type_info(type_)

if discriminator is None or variants is None:
# Not a discriminated union - parse the data field as JSON
data_value = sse_event.get("data")
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
return _validate_sse_data(type_, json.loads(data_value))
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)
return _validate_sse_data(type_, {"event": sse.event, "data": data_value, "id": sse.id, "retry": sse.retry})

data_value = sse_event.get("data")
sse_fields = {"event": sse.event, "data": data_value, "id": sse.id, "retry": sse.retry}

# Check if discriminator is at the top level (event-level discrimination)
if discriminator in sse_event:
# Case 2: Event-level discrimination
# Find the matching variant to check if 'data' field needs JSON parsing
disc_value = sse_event.get(discriminator)
if discriminator in sse_fields:
# Event-level discrimination
disc_value = sse_fields.get(discriminator)
matching_variant = _find_variant_by_discriminator(variants, discriminator, disc_value)

if matching_variant is not None:
# Check what type the variant expects for 'data'
data_type = _get_field_annotation_v1(matching_variant, "data")
if data_type is not None and not _is_string_type(data_type):
# Variant expects non-string data - parse JSON
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
new_object = dict(sse_event)
new_object["data"] = parsed_data
return parse_obj_as(type_, new_object)
sse_fields["data"] = json.loads(data_value)
return _validate_sse_data(type_, sse_fields)
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for event-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
# Either no matching variant, data is string type, or JSON parse failed
return parse_obj_as(type_, sse_event)
return _validate_sse_data(type_, sse_fields)

else:
# Case 1: Data-level discrimination
# The discriminator is inside the data payload - extract and parse data only
# Data-level discrimination
if isinstance(data_value, str) and data_value:
try:
parsed_data = json.loads(data_value)
return parse_obj_as(type_, parsed_data)
return _validate_sse_data(type_, json.loads(data_value))
except json.JSONDecodeError as e:
_logger.warning(
"Failed to parse SSE data field as JSON for data-level discrimination: %s, data: %s",
e,
data_value[:100] if len(data_value) > 100 else data_value,
)
return parse_obj_as(type_, sse_event)
return _validate_sse_data(type_, sse_fields)


def parse_obj_as(type_: Type[T], object_: Any) -> T:
Expand Down
Loading