Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/fory_compiler/ir/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __repr__(self) -> str:

@dataclass
class NamedType:
"""A reference to a user-defined type (message or enum)."""
"""A reference to a user-defined type (message, enum or union)."""

name: str
location: Optional[SourceLocation] = None
Expand Down
47 changes: 34 additions & 13 deletions compiler/fory_compiler/ir/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,15 @@
from fory_compiler.ir.type_id import compute_registered_type_id

INVALID_MAP_KEY_KINDS = {
PrimitiveKind.ANY,
PrimitiveKind.BYTES,
PrimitiveKind.FLOAT16,
PrimitiveKind.BFLOAT16,
PrimitiveKind.FLOAT32,
PrimitiveKind.FLOAT64,
PrimitiveKind.DECIMAL,
}
INVALID_MAP_KEY_MESSAGE = (
"map keys do not support binary, float, decimal, list, map, or array types"
)
INVALID_MAP_KEY_MESSAGE = "map keys do not support any, binary, float, decimal, message, union, list, map, or array types"


@dataclass
Expand Down Expand Up @@ -547,12 +546,30 @@ def check_message_fields(
check_field(f, None)

def _check_collection_type_rules(self) -> None:
def invalid_map_key(field_type: FieldType) -> bool:
def invalid_map_key(
field_type: FieldType,
enclosing_messages: Optional[List[Message]],
) -> bool:
if isinstance(field_type, PrimitiveType):
return field_type.kind in INVALID_MAP_KEY_KINDS
if isinstance(field_type, NamedType):
if enclosing_messages is not None:
resolved = self._resolve_named_type(
field_type.name, enclosing_messages
)
else:
resolved = self._find_top_level_type(field_type.name)
return isinstance(
resolved, (Message, Union)
) # message and union cannot be used as map key types.
return isinstance(field_type, (ListType, ArrayType, MapType))

def check_type(field_type: FieldType, field: Field, in_map_key: bool = False):
def check_type(
field_type: FieldType,
field: Field,
enclosing_messages: Optional[List[Message]] = None,
in_map_key: bool = False,
):
if isinstance(field_type, ArrayType):
if in_map_key:
self._error(INVALID_MAP_KEY_MESSAGE, field.location)
Expand Down Expand Up @@ -580,26 +597,30 @@ def check_type(field_type: FieldType, field: Field, in_map_key: bool = False):
if in_map_key:
self._error(INVALID_MAP_KEY_MESSAGE, field.location)
return
check_type(field_type.element_type, field)
check_type(field_type.element_type, field, enclosing_messages)
elif isinstance(field_type, MapType):
if in_map_key:
self._error(INVALID_MAP_KEY_MESSAGE, field.location)
return
key_type = field_type.key_type
if invalid_map_key(key_type):
if invalid_map_key(key_type, enclosing_messages):
self._error(INVALID_MAP_KEY_MESSAGE, field.location)
else:
check_type(key_type, field, in_map_key=True)
check_type(field_type.value_type, field)
check_type(key_type, field, enclosing_messages, in_map_key=True)
check_type(field_type.value_type, field, enclosing_messages)

def check_message_fields(message: Message) -> None:
def check_message_fields(
message: Message,
enclosing_messages: Optional[List[Message]] = None,
) -> None:
lineage = (enclosing_messages or []) + [message]
for f in message.fields:
check_type(f.field_type, f)
check_type(f.field_type, f, lineage)
for nested_msg in message.nested_messages:
check_message_fields(nested_msg)
check_message_fields(nested_msg, lineage)
for nested_union in message.nested_unions:
for f in nested_union.fields:
check_type(f.field_type, f)
check_type(f.field_type, f, lineage)

for message in self.schema.messages:
check_message_fields(message)
Expand Down
2 changes: 1 addition & 1 deletion compiler/fory_compiler/tests/test_weak_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_list_and_map_ref_options_preserve_explicit_opt_out():
message Holder {
list<ref Foo> foos = 1;
list<ref(weak=true, thread_safe=false) Bar> bars = 2;
map<Foo, ref(weak=true, thread_safe=false) Bar> bar_map = 3;
map<string, ref(weak=true, thread_safe=false) Bar> bar_map = 3;
}
"""
schema = parse_schema(source)
Expand Down
129 changes: 128 additions & 1 deletion compiler/fory_compiler/tests/test_xlang_type_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def test_array_rejects_optional_or_ref_elements_at_parse_time():
@pytest.mark.parametrize(
"key_type",
[
"any",
"bytes",
"float16",
"bfloat16",
Expand All @@ -188,12 +189,138 @@ def test_map_rejects_non_portable_key_types(key_type):

assert not ok
assert any(
"map keys do not support binary, float, decimal, list, map, or array types"
"map keys do not support any, binary, float, decimal, message, union, list, map, or array types"
in err.message
for err in validator.errors
)


@pytest.mark.parametrize(
"source",
[
"""
message Key {
string id = 1;
}

message InvalidMap {
map<Key, string> values = 1;
}
""",
"""
union Choice {
string text = 1;
}

message InvalidMap {
map<Choice, string> values = 1;
}
""",
"""
message Key {
string id = 1;
}

union InvalidUnion {
map<Key, string> values = 1;
}
""",
"""
message Outer {
message Key {
string id = 1;
}

map<Key, string> values = 1;
}
""",
"""
message Outer {
message Key {
string id = 1;
}
}

message InvalidMap {
map<Outer.Key, string> values = 1;
}
""",
"""
message Outer {
union Choice {
string text = 1;
}

map<Choice, string> values = 1;
}
""",
"""
message Outer {
union Choice {
string text = 1;
}
}

message InvalidMap {
map<Outer.Choice, string> values = 1;
}
""",
],
)
def test_map_rejects_message_and_union_key_types(source):
_schema, validator, ok = validate_schema(source)

assert not ok
assert any(
"map keys do not support any, binary, float, decimal, message, union, list, map, or array types"
in err.message
for err in validator.errors
)


@pytest.mark.parametrize(
"source",
[
"""
enum Status {
UNKNOWN = 0;
READY = 1;
}

message Holder {
map<Status, string> values = 1;
}
""",
"""
message Outer {
enum Status {
UNKNOWN = 0;
READY = 1;
}

map<Status, string> values = 1;
}
""",
"""
message Outer {
enum Status {
UNKNOWN = 0;
READY = 1;
}
}

message Holder {
map<Outer.Status, string> values = 1;
}
""",
],
)
def test_map_accepts_enum_key_types(source):
_schema, validator, ok = validate_schema(source)

assert ok, validator.errors


def test_proto_repeated_fields_remain_list_type():
schema = ProtoFrontend().parse(
"""
Expand Down
6 changes: 3 additions & 3 deletions docs/compiler/schema-idl.md
Original file line number Diff line number Diff line change
Expand Up @@ -1497,9 +1497,9 @@ message Config {
- Temporal scalar types (`date`, `timestamp`, `duration`)
- Enums

Map keys do not support binary `bytes`, floating-point types, `decimal`, `list<T>`, `array<T>`,
or nested `map<K, V>` types. Put those types in map values or wrap them in a message with a
portable scalar or enum key.
Map keys do not support `any`, binary `bytes`, floating-point types, `decimal`, message types,
union types, `list<T>`, `array<T>`, or nested `map<K, V>` types. Put those types in map values or
wrap them in a message with a portable scalar or enum key.

### Type Compatibility Matrix

Expand Down
Loading