diff --git a/src/lighteval/main_tasks.py b/src/lighteval/main_tasks.py
index fb71b1b77..b4b6e49f5 100644
--- a/src/lighteval/main_tasks.py
+++ b/src/lighteval/main_tasks.py
@@ -107,3 +107,53 @@ def dump(
     modules_data = registry.get_tasks_dump()
 
     print(json.dumps(modules_data, indent=2, default=str))
+
+
+@app.command()
+def lint(
+    custom_task_file: Annotated[str, Argument(help="Path to the custom task Python file to lint")],
+):
+    """Statically validate custom tasks for structural and module export errors
+
+    Note: the file is imported to inspect its exports, so its module-level code is executed.
+    """
+    import importlib.util
+    import os
+    import sys
+
+    from rich import print
+    from rich.markup import escape
+
+    from lighteval.tasks.linter import validate_task_module
+
+    custom_tasks_path = os.path.abspath(custom_task_file)
+
+    # isfile (not exists) so that a directory is reported here instead of failing obscurely at import time.
+    if not os.path.isfile(custom_tasks_path):
+        print(f"[red]Error: Custom task file not found at {escape(custom_tasks_path)}[/red]")
+        sys.exit(1)
+
+    print(f"Linting custom tasks file: [bold]{escape(custom_tasks_path)}[/bold]")
+
+    try:
+        spec = importlib.util.spec_from_file_location("custom_task_module", custom_tasks_path)
+        # spec_from_file_location returns None when no loader handles the file (e.g. a non-.py suffix).
+        if spec is None or spec.loader is None:
+            raise ImportError(f"no Python loader available for {custom_tasks_path}")
+        module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(module)
+    except Exception as e:
+        # Broad on purpose: user code may raise anything at import time and must be reported, not crash the CLI.
+        print(f"[red]Failed to import module: {escape(str(e))}[/red]")
+        sys.exit(1)
+
+    errors = validate_task_module(module)
+
+    if errors:
+        print(f"\n[red]Found {len(errors)} structural error(s) in custom tasks:[/red]")
+        for err in errors:
+            # Escape so bracketed text inside a message is printed literally rather than parsed as rich markup.
+            print(f" - {escape(err)}")
+        sys.exit(1)
+    else:
+        print("\n[green]Linting passed! Custom tasks are structurally valid and correctly exported.[/green]")
diff --git a/src/lighteval/tasks/linter.py b/src/lighteval/tasks/linter.py
new file mode 100644
index 000000000..ea3b04bee
--- /dev/null
+++ b/src/lighteval/tasks/linter.py
@@ -0,0 +1,141 @@
+import inspect
+import logging
+import types
+from typing import List
+
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
+
+
+logger = logging.getLogger(__name__)
+
+# Return annotations of a prompt function that do not trigger a warning.
+_EXPECTED_RETURN_ANNOTATIONS = ("Doc", "lighteval.tasks.requests.Doc")
+
+
+def _validate_base_types(config: LightevalTaskConfig) -> List[str]:
+    """Check that the name, hf_repo and metrics fields of a task config are filled in.
+
+    Args:
+        config: The task config to inspect.
+
+    Returns:
+        One error message per invalid field; empty when all three fields are valid.
+    """
+    errors = []
+    # repr (!r) keeps empty strings and non-string values visible in the message.
+    if not isinstance(config.name, str) or not config.name:
+        errors.append(f"Task name must be a non-empty string. Got: {config.name!r}")
+    if not isinstance(config.hf_repo, str) or not config.hf_repo:
+        errors.append(f"Hugging Face repo (hf_repo) must be a non-empty string. Got: {config.hf_repo!r}")
+    if not config.metrics:
+        errors.append("Task config must define at least one metric in the 'metrics' list.")
+    return errors
+
+
+def _validate_splits(config: LightevalTaskConfig) -> List[str]:
+    """Check that every split the task uses is declared in hf_avail_splits.
+
+    Args:
+        config: The task config to inspect.
+
+    Returns:
+        One error message per evaluation or few-shot split missing from hf_avail_splits.
+    """
+    # Keep the declared order for messages: formatting a set would make the output vary between runs.
+    declared_splits = list(config.hf_avail_splits or [])
+    known_splits = set(declared_splits)
+
+    errors = [
+        f"evaluation_split '{eval_split}' is not declared in hf_avail_splits {declared_splits}."
+        for eval_split in (config.evaluation_splits or [])
+        if eval_split not in known_splits
+    ]
+    few_shots_split = config.few_shots_split
+    if few_shots_split is not None and few_shots_split not in known_splits:
+        errors.append(f"few_shots_split '{few_shots_split}' is not declared in hf_avail_splits {declared_splits}.")
+    return errors
+
+
+def _validate_prompt_function(config: LightevalTaskConfig) -> List[str]:
+    """Check that prompt_function is a callable accepting at least one parameter.
+
+    An unexpected return annotation is only logged as a warning, never reported as an error.
+
+    Args:
+        config: The task config to inspect.
+
+    Returns:
+        A list of error messages; empty when no problem was found.
+    """
+    prompt_function = config.prompt_function
+    if not callable(prompt_function):
+        return [f"prompt_function must be a callable (function). Got: {type(prompt_function)}"]
+
+    try:
+        sig = inspect.signature(prompt_function)
+    except (ValueError, TypeError) as e:
+        # Some callables (e.g. certain builtins) expose no signature; skip the check rather than fail the lint.
+        logger.debug("Could not inspect the signature of prompt_function: %s", e)
+        return []
+
+    errors = []
+    if not sig.parameters:
+        errors.append("prompt_function must accept at least one parameter (the dataset row dict).")
+
+    annotation = sig.return_annotation
+    if annotation is not inspect.Signature.empty:
+        # Classes expose __name__; string annotations do not and are compared as written.
+        annotation_name = getattr(annotation, "__name__", str(annotation))
+        if annotation_name not in _EXPECTED_RETURN_ANNOTATIONS:
+            logger.warning("prompt_function return annotation is %s, expected 'Doc'.", annotation)
+    return errors
+
+
+def validate_task_config(config: LightevalTaskConfig) -> List[str]:
+    """Run every structural check on a single LightevalTaskConfig.
+
+    The checks only read attributes of the config and inspect the prompt function signature.
+
+    Args:
+        config: The task config to validate.
+
+    Returns:
+        A list of error strings. If the list is empty, the config is structurally valid.
+    """
+    return [
+        *_validate_base_types(config),
+        *_validate_splits(config),
+        *_validate_prompt_function(config),
+    ]
+
+
+def validate_task_module(module: types.ModuleType) -> List[str]:
+    """Validate that a module exports its tasks through a well-formed TASKS_TABLE.
+
+    This prevents silent failures where a config is written but not correctly exported.
+
+    Args:
+        module: The already imported custom task module.
+
+    Returns:
+        A list of error strings; errors of an individual task are prefixed with "[Task: <name>]".
+        If the list is empty, the module is structurally valid.
+    """
+    missing = object()
+    tasks_table = getattr(module, "TASKS_TABLE", missing)
+    if tasks_table is missing:
+        return ["Module is missing the required 'TASKS_TABLE' export list."]
+    if not isinstance(tasks_table, list):
+        return [f"'TASKS_TABLE' must be a list. Got: {type(tasks_table)}"]
+
+    errors: List[str] = []
+    if not tasks_table:
+        errors.append("'TASKS_TABLE' is empty. At least one task must be exported.")
+
+    for idx, task_config in enumerate(tasks_table):
+        if not isinstance(task_config, LightevalTaskConfig):
+            errors.append(f"Item at index {idx} in 'TASKS_TABLE' is not a LightevalTaskConfig object.")
+            continue
+        errors.extend(f"[Task: {task_config.name}] {err}" for err in validate_task_config(task_config))
+
+    return errors
diff --git a/tests/tasks/test_linter.py b/tests/tasks/test_linter.py
new file mode 100644
index 000000000..db5e24fb0
--- /dev/null
+++ b/tests/tasks/test_linter.py
@@ -0,0 +1,82 @@
+import types
+
+from lighteval.metrics.metrics import Metrics
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.linter import validate_task_module
+from lighteval.tasks.requests import Doc
+
+
+def test_linter_catches_missing_tasks_table():
+    # Simulate a user who wrote the config but forgot the TASKS_TABLE export
+    mock_module = types.ModuleType("mock_custom_task")
+
+    errors = validate_task_module(mock_module)
+    assert len(errors) == 1
+    assert "missing the required 'TASKS_TABLE'" in errors[0]
+
+
+def test_linter_catches_invalid_tasks_table_type():
+    # Simulate a user who incorrectly defined TASKS_TABLE as a dict instead of a list
+    mock_module = types.ModuleType("mock_custom_task")
+    mock_module.TASKS_TABLE = {"my_task_name": "config_object"}
+
+    errors = validate_task_module(mock_module)
+    assert len(errors) == 1
+    assert "'TASKS_TABLE' must be a list" in errors[0]
+
+
+def test_linter_catches_invalid_config_inside_table():
+    # Simulate a user who exported the list, but it contains a string instead of a LightevalTaskConfig
+    mock_module = types.ModuleType("mock_custom_task")
+    mock_module.TASKS_TABLE = ["This is not a LightevalTaskConfig"]
+
+    errors = validate_task_module(mock_module)
+    assert len(errors) == 1
+    assert "is not a LightevalTaskConfig object" in errors[0]
+
+
+def test_linter_validates_perfect_module_export():
+    # Simulate a flawless, production-ready module export
+    def mock_prompt_fn(line: dict) -> Doc:
+        return Doc(task_name="test", query="q", choices=[], instruction="", target_for_fewshot_context="a")
+
+    perfect_config = LightevalTaskConfig(
+        name="test_task",
+        hf_repo="huggingface/mock_repo",
+        hf_subset="default",
+        metrics=[Metrics.exact_match],
+        prompt_function=mock_prompt_fn,
+        hf_avail_splits=["train", "validation", "test"],
+        evaluation_splits=["test"],
+        few_shots_split="train",
+    )
+
+    mock_module = types.ModuleType("mock_custom_task")
+    mock_module.TASKS_TABLE = [perfect_config]
+
+    errors = validate_task_module(mock_module)
+    # The linter should find zero structural or namespace errors
+    assert len(errors) == 0
+
+
+def test_linter_catches_split_mismatch_inside_module():
+    # Simulate a user requesting a split that doesn't exist on the Hugging Face repo
+    def mock_prompt_fn(line: dict) -> Doc:
+        return Doc(task_name="test", query="q", choices=[], instruction="", target_for_fewshot_context="a")
+
+    broken_config = LightevalTaskConfig(
+        name="broken_task",
+        hf_repo="huggingface/mock_repo",
+        hf_subset="default",
+        metrics=[Metrics.exact_match],
+        prompt_function=mock_prompt_fn,
+        hf_avail_splits=["train", "test"],
+        evaluation_splits=["validation"],  # "validation" is not in hf_avail_splits!
+    )
+
+    mock_module = types.ModuleType("mock_custom_task")
+    mock_module.TASKS_TABLE = [broken_config]
+
+    errors = validate_task_module(mock_module)
+    assert len(errors) == 1
+    assert "evaluation_split 'validation' is not declared in hf_avail_splits" in errors[0]