Skip to content

Commit f833eb7

Browse files
committed
Split rule regex into predicate/component
1 parent 909a240 commit f833eb7

1 file changed

Lines changed: 7 additions & 8 deletions

File tree

pyreason/scripts/utils/rule_parser.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
1111
from pyreason.scripts.threshold.threshold import Threshold
1212

13-
_IDENTIFIER_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_.\-]*$')
13+
_PREDICATE_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_.\-]*$')
14+
_COMPONENT_RE = re.compile(r'^[a-zA-Z0-9_][a-zA-Z0-9_.@\-]*$')
1415

1516

1617
def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, dict], infer_edges: bool = False, set_static: bool = False, weights: Union[None, np.ndarray] = None) -> rule.Rule:
@@ -482,19 +483,17 @@ def _parse_head_arguments(head_args_str):
482483

483484

484485
def _validate_predicate_name(pred, context):
485-
"""Validate that a predicate name matches ^[a-zA-Z_][a-zA-Z0-9_]*$."""
486-
if not _IDENTIFIER_RE.match(pred):
486+
"""Validate that a predicate name matches ^[a-zA-Z_][a-zA-Z0-9_.\\-]*$."""
487+
if not _PREDICATE_RE.match(pred):
487488
if pred and pred[0].isdigit():
488489
raise ValueError(f"{context} predicate name '{pred}' cannot start with a digit")
489490
raise ValueError(f"{context} predicate name '{pred}' contains invalid characters. Must match [a-zA-Z_][a-zA-Z0-9_.\\-]*")
490491

491492

492493
def _validate_component_name(var, context):
493-
"""Validate that a variable name matches ^[a-zA-Z_][a-zA-Z0-9_]*$."""
494-
if not _IDENTIFIER_RE.match(var):
495-
if var and var[0].isdigit():
496-
raise ValueError(f"{context} component name '{var}' cannot start with a digit")
497-
raise ValueError(f"{context} component name '{var}' contains invalid characters. Must match [a-zA-Z_][a-zA-Z0-9_.\\-]*")
494+
"""Validate that a component name matches ^[a-zA-Z0-9_][a-zA-Z0-9_.@\\-]*$."""
495+
if not _COMPONENT_RE.match(var):
496+
raise ValueError(f"{context} component name '{var}' contains invalid characters. Must match [a-zA-Z0-9_][a-zA-Z0-9_.@\\-]*")
498497

499498

500499
def _str_bound_to_bound(str_bound):

0 commit comments

Comments
 (0)