Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 99 additions & 4 deletions src/parseval/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
Callable,
Generator,
)
from collections import defaultdict
import random, logging, os
from collections import defaultdict, deque
import re, random, logging, os
from contextlib import contextmanager
import time, threading
import atexit
Expand Down Expand Up @@ -348,9 +348,98 @@ def _check() -> bool:

return records

def create_tables(self, *ddls: str) -> None:
"""Execute one or more DDL statements (CREATE TABLE …)."""
@staticmethod
def _topological_sort_ddls(ddls: tuple[str, ...]) -> list[str]:
"""Sort CREATE TABLE DDLs so that referenced tables are created first."""
_CREATE_RE = re.compile(
r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?["\']?(\w+)["\']?',
re.IGNORECASE,
)
_FK_REF_RE = re.compile(
r'REFERENCES\s+["\']?(\w+)["\']?',
re.IGNORECASE,
)

table_to_ddl: dict[str, str] = {}
deps: dict[str, set[str]] = {}
ordered_non_create: list[str] = []

for ddl in ddls:
stripped = ddl.strip()
if not stripped:
continue
m = _CREATE_RE.search(stripped)
if not m:
ordered_non_create.append(stripped)
continue
tbl = m.group(1)
table_to_ddl[tbl] = stripped
refs = set(_FK_REF_RE.findall(stripped))
refs.discard(tbl)
deps[tbl] = refs

for tbl, refs in list(deps.items()):
for ref in refs:
if ref not in deps:
deps[ref] = set()

in_degree: dict[str, int] = {t: 0 for t in deps}
for tbl, refs in deps.items():
for ref in refs:
in_degree[ref] = in_degree.get(ref, 0)
in_degree[tbl] += 1

queue = deque(t for t, d in in_degree.items() if d == 0)
sorted_tables: list[str] = []
while queue:
t = queue.popleft()
sorted_tables.append(t)
for tbl, refs in deps.items():
if t in refs:
in_degree[tbl] -= 1
if in_degree[tbl] == 0:
queue.append(tbl)

for tbl in deps:
if tbl not in sorted_tables:
sorted_tables.append(tbl)

result = [table_to_ddl[t] for t in sorted_tables if t in table_to_ddl]
result.extend(ordered_non_create)
return result

@staticmethod
def _strip_fk_constraints(ddl: str) -> str:
"""Remove FOREIGN KEY clauses from a CREATE TABLE DDL.

The internal catalog already tracks FK relationships, so they are
not needed in the physical database and can cause errors when the
referenced column lacks a UNIQUE/PK constraint.
"""
ddl = re.sub(
r',?\s*FOREIGN\s+KEY\s*\([^)]*\)\s*REFERENCES\s+[^(]*\([^)]*\)',
'',
ddl,
flags=re.IGNORECASE,
)
ddl = re.sub(r',\s*\)', ')', ddl)
return ddl

def create_tables(self, *ddls: str) -> None:
    """Execute one or more DDL statements (CREATE TABLE …), topologically sorted by FK deps.

    Each statement is rewritten before execution:

    * table-level FOREIGN KEY constraints are stripped;
    * ``char(n)`` / ``character(n)`` columns become ``character varying``;
    * ``CREATE TABLE`` becomes ``CREATE TABLE IF NOT EXISTS`` (a clause
      that is already present is kept, not duplicated).

    Cached metadata is invalidated once all statements have run.

    Args:
        *ddls: DDL statements, in any order.
    """
    for ddl in self._topological_sort_ddls(ddls):
        ddl = self._strip_fk_constraints(ddl)
        # Convert fixed-length char types to varchar to avoid Postgres truncation/padding errors.
        ddl = re.sub(
            r'\bchar(?:acter)?\s*\(\s*\d+\s*\)',
            'character varying',
            ddl,
            flags=re.IGNORECASE,
        )
        # Use IF NOT EXISTS to be idempotent — safe when zombie threads
        # from timed-out runs are concurrently using the same database.
        # The optional group consumes an existing clause so the rewrite
        # never produces "IF NOT EXISTS IF NOT EXISTS".
        ddl = re.sub(
            r'CREATE\s+TABLE\b(?:\s+IF\s+NOT\s+EXISTS\b)?',
            'CREATE TABLE IF NOT EXISTS',
            ddl, count=1, flags=re.IGNORECASE,
        )
        self.execute(ddl, fetch=None)
    self._invalidate_metadata()

Expand Down Expand Up @@ -608,6 +697,12 @@ def _create_engine(
)
else:
connect_args = {"connect_timeout": connect_timeout}
# FIBEN/Postgres: schema lives outside `public`. Set search_path via env var
# (e.g. PARSEVAL_PG_SEARCH_PATH=fiben,public) so qualified-or-not lookups resolve.
if dialect == "postgres":
search_path = os.environ.get("PARSEVAL_PG_SEARCH_PATH")
if search_path:
connect_args["options"] = f"-c search_path={search_path}"
return create_engine(
conn_url,
pool_size=pool_size,
Expand Down
26 changes: 25 additions & 1 deletion src/parseval/disprover.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import re
import tempfile
import threading
import time
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
dialect: str = "sqlite",
config: DisproverConfig | None = None,
existing_dbs: Optional[Sequence[Sequence[object]]] = None,
schema_db_id: Optional[str] = None,
) -> None:
self.q1 = q1
self.q2 = q2
Expand All @@ -66,6 +68,11 @@ def __init__(
db_id="default",
)

# Original db_id (before any run_id suffix) used to strip schema qualifiers
# from queries, e.g. FIBEN."TABLE" -> "TABLE". Falls back to config.db_id
# when callers don't separate schema from per-run identifiers.
self._schema_db_id = schema_db_id or self.config.db_id

self.stop_event = threading.Event()
self._lock = threading.Lock()

Expand Down Expand Up @@ -177,7 +184,18 @@ def _check_syntax(self) -> RunResult | None:
)

pair = self._execute_pair(syntax_context)
if pair.q1.error_msg or pair.q2.error_msg:

# Only treat errors as syntax failures if they are NOT about missing
# tables/relations — those are expected when the syntax-check DB is empty.
def _is_syntax_error(err: str) -> bool:
if not err:
return False
lower = err.lower()
if "does not exist" in lower or "no such table" in lower:
return False
return True

if _is_syntax_error(pair.q1.error_msg) or _is_syntax_error(pair.q2.error_msg):
return self._build_result(
state="SYN",
q1_result=pair.q1,
Expand All @@ -189,6 +207,12 @@ def _check_syntax(self) -> RunResult | None:

def _execute_query(self, query: str, context: DatabaseContext) -> ExecutionResult:
database_name = self._normalize_database_name(context.database)
# Strip schema qualifiers matching the schema db_id (e.g. FIBEN."T" -> "T")
# but leave alias.column references (e.g. alias."col") intact. Needed when
# the per-run database has a different name than the schema namespace.
if self._schema_db_id:
schema_name = re.escape(self._schema_db_id)
query = re.sub(rf'\b{schema_name}\.\s*', '', query, flags=re.IGNORECASE)
started_at = time.monotonic()
with DBManager().get_connection(
host_or_path=context.host_or_path,
Expand Down
64 changes: 49 additions & 15 deletions src/parseval/faker/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,10 +691,23 @@ def step(self, value: ValueType, direction: int) -> ValueType:
class IntGenerator(ValueGenerator[int]):

def satisfying_value(self, op, value):
if op == "BETWEEN":
lo, hi = value
if isinstance(lo, str) or isinstance(hi, str):
return lo
return random.choice([lo + 1, hi - 1])
if op == "IN":
return random.choice(value) if isinstance(value, list) and value else value
# Defensive int() for scalar ops — symbolic aggregate aliases
# (e.g. 'count_HASACCOUNTNUMBER_0') can leak in from speculative tracing.
try:
value = int(value)
except (ValueError, TypeError):
return self._default()
if op == "EQ":
return int(value)
return value
if op == "NEQ":
return int(value) + 1
return value + 1
if op == "GT":
return value + 1
if op == "GTE":
Expand All @@ -703,13 +716,6 @@ def satisfying_value(self, op, value):
return value - 1
if op == "LTE":
return value
if op == "BETWEEN":
lo, hi = value
if isinstance(lo, str) or isinstance(hi, str):
return lo
return random.choice([lo + 1, hi - 1])
if op == "IN":
return random.choice(value) if isinstance(value, list) and value else value
return value

def negating_value(self, op, value):
Expand All @@ -718,10 +724,16 @@ def negating_value(self, op, value):
return self.satisfying_value(neg_op, value)
if op == "BETWEEN":
lo, _ = value
return int(lo) - 1
try:
return int(float(lo)) - 1
except (ValueError, TypeError):
return self._default()
if op == "IN":
first = value[0] if value else None
return (int(first) + 1) if first is not None else None
try:
return (int(float(first)) + 1) if first is not None else None
except (ValueError, TypeError):
return self._default()
return self.generate()

def propagate_constraints(self):
Expand Down Expand Up @@ -866,17 +878,33 @@ def propagate_constraints(self):
elif isinstance(c, exp.GT):
self.length = lv + 1

def _max_length_from_datatype(self) -> Optional[int]:
"""Extract max length from the column datatype (e.g. character(1), varchar(50))."""
dt = self.pool.domain.datatype
if dt and dt.expressions:
try:
return int(dt.expressions[0].this.this)
except (ValueError, TypeError, AttributeError):
pass
return None

def generate(self, skips) -> str:
self.propagate_constraints()
alphabet = string.ascii_letters + string.digits + " "
max_len = self._max_length_from_datatype()
if self._in_values:
candidates = [v for v in self._in_values if skips is None or v not in skips]
if candidates:
return random.choice(candidates)
for _ in range(_MAX_UNIQUE_ATTEMPTS):
if self.fixed_value is not None and self.validate(self.fixed_value, skips):
return self.fixed_value
length = self.length if self.length is not None else random.randint(5, 15)
return self.fixed_value[:max_len] if max_len else self.fixed_value
if self.length is not None:
length = self.length
elif max_len is not None:
length = max_len
else:
length = random.randint(5, 15)
pattern = self.pattern or ("_" * length)
result = ""
for ch in pattern:
Expand All @@ -888,6 +916,8 @@ def generate(self, skips) -> str:
result += random.choice(alphabet)
else:
result += ch
if max_len is not None:
result = result[:max_len]
if self.validate(result, skips):
return result
raise RuntimeError("StringGenerator could not produce a valid value")
Expand Down Expand Up @@ -1374,8 +1404,12 @@ def get_or_create_pool(
alias = f"{table}.{column}"
if alias in self._pools:
return self._pools[alias]
table = normalize_name(table, is_table=True).name
column = normalize_name(column).name
_tbl = normalize_name(table, is_table=True)
_col = normalize_name(column)
if _tbl is None or _col is None:
raise KeyError(f"Cannot normalize table={table!r} column={column!r}")
table = _tbl.name
column = _col.name
qualified_name = f"{table}.{column}"
domain = self._domains.get(qualified_name)
if not domain:
Expand Down
5 changes: 4 additions & 1 deletion src/parseval/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,10 @@ def to_concrete(value, datatype=None):
return 1
return int(value)
elif datatype.is_type(*DataType.REAL_TYPES):
return float(value)
try:
return float(value)
except (ValueError, TypeError):
return 0.0
elif datatype.is_type(DataType.Type.BOOLEAN):
return bool(value)
elif datatype.is_type(*DataType.TEMPORAL_TYPES):
Expand Down
Loading