Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/lighteval/main_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,43 @@ def dump(
modules_data = registry.get_tasks_dump()

print(json.dumps(modules_data, indent=2, default=str))


@app.command()
def lint(
    custom_task_file: Annotated[str, Argument(help="Path to the custom task Python file to lint")],
):
    """Statically validate custom tasks for structural and module export errors"""
    # Flow: resolve the path, import the file as a throwaway module, run the
    # linter's module-level validation on it, then print a report.
    # Exits the process with status 1 when the file is missing, cannot be
    # imported, or fails validation; returns normally when linting passes.
    import importlib.util
    import os
    import sys

    from rich import print
    from rich.markup import escape

    from lighteval.tasks.linter import validate_task_module

    custom_tasks_path = os.path.abspath(custom_task_file)
    # Paths, exception text and linter messages are arbitrary strings; escape
    # them so square brackets inside them are never parsed as Rich markup.
    shown_path = escape(custom_tasks_path)

    if not os.path.exists(custom_tasks_path):
        print(f"[red]Error: Custom task file not found at {shown_path}[/red]")
        sys.exit(1)

    print(f"Linting custom tasks file: [bold]{shown_path}[/bold]")

    # spec_from_file_location returns None when no loader matches the path
    # (e.g. a directory or a file without a Python suffix). Report that
    # directly instead of letting it surface as an AttributeError on None.
    spec = importlib.util.spec_from_file_location("custom_task_module", custom_tasks_path)
    if spec is None or spec.loader is None:
        print(f"[red]Failed to import module: {shown_path} is not an importable Python source file[/red]")
        sys.exit(1)

    try:
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    except Exception as e:
        # Top-level boundary: any error raised while executing the user's file
        # is reported as an import failure rather than as a traceback.
        print(f"[red]Failed to import module: {escape(str(e))}[/red]")
        sys.exit(1)

    errors = validate_task_module(module)

    if errors:
        print(f"\n[red]Found {len(errors)} structural error(s) in custom tasks:[/red]")
        for err in errors:
            print(f" - {escape(err)}")
        sys.exit(1)

    print("\n[green]Linting passed! Custom tasks are structurally valid and correctly exported.[/green]")
98 changes: 98 additions & 0 deletions src/lighteval/tasks/linter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import inspect
import logging
import types
from typing import List

from lighteval.tasks.lighteval_task import LightevalTaskConfig


logger = logging.getLogger(__name__)


def _validate_base_types(config: LightevalTaskConfig) -> List[str]:
    """Check the name, Hugging Face repo and metric list of a task config.

    Args:
        config: Task configuration to inspect. Only ``name``, ``hf_repo`` and
            ``metrics`` are read.

    Returns:
        A list of human-readable error messages, empty when every check passes.
    """
    errors = []
    # The offending value is rendered with repr() so that an empty string
    # ('') and None remain distinguishable in the report.
    if not isinstance(config.name, str) or not config.name:
        errors.append(f"Task name must be a non-empty string. Got: {config.name!r}")
    if not isinstance(config.hf_repo, str) or not config.hf_repo:
        errors.append(f"Hugging Face repo (hf_repo) must be a non-empty string. Got: {config.hf_repo!r}")
    # Truthiness already covers both None and an empty collection.
    if not config.metrics:
        errors.append("Task config must define at least one metric in the 'metrics' list.")
    return errors


def _validate_splits(config: LightevalTaskConfig) -> List[str]:
    """Check that every split the task uses is declared in ``hf_avail_splits``.

    Args:
        config: Task configuration to inspect. Reads ``hf_avail_splits``,
            ``evaluation_splits`` and ``few_shots_split``.

    Returns:
        One error message per evaluation split (and for the few-shot split,
        when set) that is not listed in ``hf_avail_splits``; empty otherwise.
    """
    # De-duplicate while keeping the declared order, so the list shown in the
    # messages is stable between runs (iterating a set of strings is not).
    declared = list(dict.fromkeys(config.hf_avail_splits or ()))
    known = set(declared)

    errors = []
    # A missing (None) evaluation_splits is treated as "no splits to check"
    # rather than raising TypeError from inside the linter.
    for eval_split in config.evaluation_splits or ():
        if eval_split not in known:
            errors.append(f"evaluation_split '{eval_split}' is not declared in hf_avail_splits {declared}.")

    few_shots_split = config.few_shots_split
    if few_shots_split is not None and few_shots_split not in known:
        errors.append(f"few_shots_split '{few_shots_split}' is not declared in hf_avail_splits {declared}.")
    return errors


def _validate_prompt_function(config: LightevalTaskConfig) -> List[str]:
    """Check that ``prompt_function`` is a callable taking at least one parameter.

    A return annotation that is present but not named ``Doc`` is only logged as
    a warning; it never produces an error.

    Args:
        config: Task configuration to inspect. Only ``prompt_function`` is read.

    Returns:
        A list of error messages, empty when the prompt function looks usable
        or when its signature cannot be introspected.
    """
    prompt_fn = config.prompt_function
    if not callable(prompt_fn):
        return [f"prompt_function must be a callable (function). Got: {type(prompt_fn)}"]

    # Only the introspection call is guarded. inspect.signature raises
    # ValueError when no signature is available and TypeError for unsupported
    # objects; either way the remaining checks are skipped (best effort)
    # instead of crashing the linter.
    try:
        sig = inspect.signature(prompt_fn)
    except (ValueError, TypeError) as exc:
        logger.debug("Skipping signature checks for prompt_function %r: %s", prompt_fn, exc)
        return []

    errors = []
    if not sig.parameters:
        errors.append("prompt_function must accept at least one parameter (the dataset row dict).")

    annotation = sig.return_annotation
    if annotation is not inspect.Signature.empty:
        # String annotations have no __name__, so fall back to their text.
        annotation_name = getattr(annotation, "__name__", str(annotation))
        if annotation_name not in ("Doc", "lighteval.tasks.requests.Doc"):
            logger.warning("prompt_function return annotation is %s, expected 'Doc'.", annotation)
    return errors


def validate_task_config(config: LightevalTaskConfig) -> List[str]:
    """
    Run every static, pure-Python check against a single LightevalTaskConfig.

    The checks run in a fixed order (base types, splits, prompt function) and
    their messages are concatenated in that order. An empty result means the
    config passed all checks.
    """
    checks = (_validate_base_types, _validate_splits, _validate_prompt_function)
    collected: List[str] = []
    for check in checks:
        collected += check(config)
    return collected


def validate_task_module(module: types.ModuleType) -> List[str]:
    """
    Check that a custom-task module exports its tasks through 'TASKS_TABLE'.

    The module must define 'TASKS_TABLE' as a non-empty list whose items are
    all LightevalTaskConfig objects; each such item is additionally passed to
    validate_task_config and its messages are prefixed with the task name.
    Returns the collected error strings (empty when nothing is wrong).
    """
    # Without the export, or with an export of the wrong type, nothing else can be checked.
    if not hasattr(module, "TASKS_TABLE"):
        return ["Module is missing the required 'TASKS_TABLE' export list."]

    exported = module.TASKS_TABLE
    if not isinstance(exported, list):
        return [f"'TASKS_TABLE' must be a list. Got: {type(exported)}"]

    problems: List[str] = []
    if not exported:
        problems.append("'TASKS_TABLE' is empty. At least one task must be exported.")

    for position, entry in enumerate(exported):
        if isinstance(entry, LightevalTaskConfig):
            issues = validate_task_config(entry)
            problems.extend(f"[Task: {getattr(entry, 'name', 'Unknown')}] {issue}" for issue in issues)
        else:
            problems.append(f"Item at index {position} in 'TASKS_TABLE' is not a LightevalTaskConfig object.")

    return problems
82 changes: 82 additions & 0 deletions tests/tasks/test_linter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import types

from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.linter import validate_task_module
from lighteval.tasks.requests import Doc


def test_linter_catches_missing_tasks_table():
    # A module that never assigns TASKS_TABLE, as when the export was forgotten.
    module_without_export = types.ModuleType("mock_custom_task")

    reported = validate_task_module(module_without_export)

    # Exactly one error, and it names the missing export.
    assert len(reported) == 1
    assert "missing the required 'TASKS_TABLE'" in reported[0]


def test_linter_catches_invalid_tasks_table_type():
    # TASKS_TABLE is exported, but as a dict rather than the expected list.
    module_with_dict_export = types.ModuleType("mock_custom_task")
    module_with_dict_export.TASKS_TABLE = {"my_task_name": "config_object"}

    reported = validate_task_module(module_with_dict_export)

    # The type problem is the only thing reported.
    assert len(reported) == 1
    assert "'TASKS_TABLE' must be a list" in reported[0]


def test_linter_catches_invalid_config_inside_table():
    # The export is a list, but its single item is a plain string, not a task config.
    module_with_bad_item = types.ModuleType("mock_custom_task")
    module_with_bad_item.TASKS_TABLE = ["This is not a LightevalTaskConfig"]

    reported = validate_task_module(module_with_bad_item)

    # One error, pointing at the item of the wrong type.
    assert len(reported) == 1
    assert "is not a LightevalTaskConfig object" in reported[0]


def test_linter_validates_perfect_module_export():
    # Happy path: a correctly exported module whose single task config is consistent.
    def build_doc(line: dict) -> Doc:
        return Doc(task_name="test", query="q", choices=[], instruction="", target_for_fewshot_context="a")

    valid_module = types.ModuleType("mock_custom_task")
    valid_module.TASKS_TABLE = [
        LightevalTaskConfig(
            name="test_task",
            hf_repo="huggingface/mock_repo",
            hf_subset="default",
            metrics=[Metrics.exact_match],
            prompt_function=build_doc,
            hf_avail_splits=["train", "validation", "test"],
            evaluation_splits=["test"],
            few_shots_split="train",
        )
    ]

    reported = validate_task_module(valid_module)

    # Nothing should be flagged for a well-formed export.
    assert len(reported) == 0


def test_linter_catches_split_mismatch_inside_module():
    # The task evaluates on a split that it does not list in hf_avail_splits.
    def build_doc(line: dict) -> Doc:
        return Doc(task_name="test", query="q", choices=[], instruction="", target_for_fewshot_context="a")

    mismatched_module = types.ModuleType("mock_custom_task")
    mismatched_module.TASKS_TABLE = [
        LightevalTaskConfig(
            name="broken_task",
            hf_repo="huggingface/mock_repo",
            hf_subset="default",
            metrics=[Metrics.exact_match],
            prompt_function=build_doc,
            hf_avail_splits=["train", "test"],
            # "validation" is deliberately absent from hf_avail_splits above.
            evaluation_splits=["validation"],
        )
    ]

    reported = validate_task_module(mismatched_module)

    # Only the undeclared evaluation split is reported.
    assert len(reported) == 1
    assert "evaluation_split 'validation' is not declared in hf_avail_splits" in reported[0]