Skip to content

Commit 4c0511d

Browse files
authored
Merge branch 'main' into add-rule-parser-validation
2 parents b3cf9b0 + 377940c commit 4c0511d

7 files changed

Lines changed: 98 additions & 15 deletions

File tree

pyreason/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
print('PyReason initialized!')
3838
print()
3939

40-
# Update cache status
41-
cache_status['initialized'] = True
42-
with open(cache_status_path, 'w') as file:
43-
yaml.dump(cache_status, file)
40+
# Update cache status (skip under test runners to keep repo file clean)
41+
import sys
42+
if 'pytest' not in sys.modules and 'unittest' not in sys.modules:
43+
cache_status['initialized'] = True
44+
with open(cache_status_path, 'w') as file:
45+
yaml.dump(cache_status, file)

pyreason/scripts/interpretation/interpretation.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1973,8 +1973,22 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node
19731973
@numba.njit(cache=True)
19741974
def float_to_str(value):
19751975
number = int(value)
1976-
decimal = int(value % 1 * 1000)
1977-
float_str = f'{number}.{decimal}'
1976+
decimal = int(round(abs(value) % 1 * 1000))
1977+
1978+
# Manual zero-padding (numba may not support :03d in f-strings)
1979+
if decimal < 10:
1980+
decimal_str = f'00{decimal}'
1981+
elif decimal < 100:
1982+
decimal_str = f'0{decimal}'
1983+
else:
1984+
decimal_str = f'{decimal}'
1985+
1986+
# Handle negative values where int() truncates to 0 (e.g., -0.123)
1987+
if value < 0 and number == 0:
1988+
float_str = f'-{number}.{decimal_str}'
1989+
else:
1990+
float_str = f'{number}.{decimal_str}'
1991+
19781992
return float_str
19791993

19801994

pyreason/scripts/interpretation/interpretation_fp.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2091,8 +2091,22 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node
20912091
@numba.njit(cache=True)
20922092
def float_to_str(value):
20932093
number = int(value)
2094-
decimal = int(value % 1 * 1000)
2095-
float_str = f'{number}.{decimal}'
2094+
decimal = int(round(abs(value) % 1 * 1000))
2095+
2096+
# Manual zero-padding (numba may not support :03d in f-strings)
2097+
if decimal < 10:
2098+
decimal_str = f'00{decimal}'
2099+
elif decimal < 100:
2100+
decimal_str = f'0{decimal}'
2101+
else:
2102+
decimal_str = f'{decimal}'
2103+
2104+
# Handle negative values where int() truncates to 0 (e.g., -0.123)
2105+
if value < 0 and number == 0:
2106+
float_str = f'-{number}.{decimal_str}'
2107+
else:
2108+
float_str = f'{number}.{decimal_str}'
2109+
20962110
return float_str
20972111

20982112

pyreason/scripts/interpretation/interpretation_parallel.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1973,8 +1973,22 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node
19731973
@numba.njit(cache=True)
19741974
def float_to_str(value):
19751975
number = int(value)
1976-
decimal = int(value % 1 * 1000)
1977-
float_str = f'{number}.{decimal}'
1976+
decimal = int(round(abs(value) % 1 * 1000))
1977+
1978+
# Manual zero-padding (numba may not support :03d in f-strings)
1979+
if decimal < 10:
1980+
decimal_str = f'00{decimal}'
1981+
elif decimal < 100:
1982+
decimal_str = f'0{decimal}'
1983+
else:
1984+
decimal_str = f'{decimal}'
1985+
1986+
# Handle negative values where int() truncates to 0 (e.g., -0.123)
1987+
if value < 0 and number == 0:
1988+
float_str = f'-{number}.{decimal_str}'
1989+
else:
1990+
float_str = f'{number}.{decimal_str}'
1991+
19781992
return float_str
19791993

19801994

pyreason/scripts/interval/interval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def intersection(self, interval):
6464
lower = max(self.lower, interval.lower)
6565
upper = min(self.upper, interval.upper)
6666
if lower > upper:
67-
lower = np.float32(0)
68-
upper = np.float32(1)
67+
lower = np.float64(0)
68+
upper = np.float64(1)
6969
return Interval(lower, upper, False, self.lower, self.upper)
7070

7171
def to_str(self):

pyreason/scripts/numba_wrapper/numba_types/interval_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def impl(self, interval):
5858
lower = max(self.lower, interval.lower)
5959
upper = min(self.upper, interval.upper)
6060
if lower > upper:
61-
lower = np.float32(0)
62-
upper = np.float32(1)
61+
lower = np.float64(0)
62+
upper = np.float64(1)
6363
return Interval(lower, upper, False, self.prev_lower, self.prev_upper)
6464

6565
return impl

tests/unit/disable_jit/interpretations/test_interpretation_common.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,50 @@ def foo(ann, wts):
579579

580580
def test_float_to_str_and_str_to_int():
581581
assert float_to_str(12.345) == "12.345"
582-
assert float_to_str(3.0) == "3.0"
582+
assert float_to_str(3.0) == "3.000"
583583
assert str_to_int("123") == 123
584584
assert str_to_int("-45") == -45
585585

586586

587+
@pytest.mark.parametrize(
588+
"value,expected",
589+
[
590+
(-3.456, "-3.456"),
591+
(-0.123, "-0.123"),
592+
(-1.0, "-1.000"),
593+
],
594+
)
595+
def test_float_to_str_negative_values(value, expected):
596+
"""BUG-102: negative floats must preserve correct decimal digits."""
597+
assert float_to_str(value) == expected
598+
599+
600+
@pytest.mark.parametrize(
601+
"value,expected",
602+
[
603+
(3.001, "3.001"),
604+
(0.009, "0.009"),
605+
(5.050, "5.050"),
606+
(0.0, "0.000"),
607+
],
608+
)
609+
def test_float_to_str_zero_padding(value, expected):
610+
"""BUG-103: leading zeros in fractional part must be preserved."""
611+
assert float_to_str(value) == expected
612+
613+
614+
@pytest.mark.parametrize(
615+
"value,expected",
616+
[
617+
(-0.001, "-0.001"),
618+
(-3.010, "-3.010"),
619+
],
620+
)
621+
def test_float_to_str_negative_with_zero_padding(value, expected):
622+
"""BUG-102 + BUG-103 combined: negative values with leading-zero decimals."""
623+
assert float_to_str(value) == expected
624+
625+
587626
@pytest.mark.parametrize(
588627
"s,expected",
589628
[("3.14", 3.14), ("42", 42.0), ("-2.5", -2.5)],

0 commit comments

Comments
 (0)