Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 128 additions & 91 deletions every_eval_ever/converters/helm/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Union, cast

_HELM_IMPORT_ERROR: Exception | None = None
try:
Expand All @@ -29,14 +29,18 @@
Exception
) as ex: # pragma: no cover - exercised only when optional deps missing
_HELM_IMPORT_ERROR = ex
DaciteConfig = from_dict = None # type: ignore[assignment]
PerInstanceStats = AdapterSpec = RequestState = ScenarioState = Stat = (
RunSpec
) = Any # type: ignore[assignment]
get_model_deployment = register_builtin_configs_from_helm_package = (
from_json
) = None # type: ignore[assignment]
ModelDeploymentNotFoundError = Exception # type: ignore[assignment]
DaciteConfig = cast(Any, None)
from_dict = cast(Any, None)
PerInstanceStats = cast(Any, None)
AdapterSpec = cast(Any, None)
RequestState = cast(Any, None)
ScenarioState = cast(Any, None)
Stat = cast(Any, None)
RunSpec = cast(Any, None)
get_model_deployment = cast(Any, None)
register_builtin_configs_from_helm_package = cast(Any, None)
from_json = cast(Any, None)
ModelDeploymentNotFoundError = cast(Any, Exception)

from every_eval_ever.converters import SCHEMA_VERSION
from every_eval_ever.converters.common.adapter import (
Expand All @@ -45,8 +49,11 @@
SupportedLibrary,
)
from every_eval_ever.converters.common.utils import sha256_file
from every_eval_ever.converters.helm.metrics import is_core_metric
from every_eval_ever.converters.helm.instance_level_adapter import (
HELMInstanceLevelDataAdapter,
_evaluation_result_id,
_score_from_stat,
)
from every_eval_ever.converters.helm.utils import extract_reasoning
from every_eval_ever.eval_types import (
Expand All @@ -67,7 +74,6 @@
SourceType,
Uncertainty,
)
from every_eval_ever.instance_level_types import InstanceLevelEvaluationLog


def _require_helm_dependencies() -> None:
Expand Down Expand Up @@ -106,6 +112,7 @@ def metadata(self) -> AdapterMetadata:
return AdapterMetadata(
name='HELMAdapter',
version='0.0.1',
supported_library_versions=['helm'],
description='HELM adapter with dynamic metrics and unified JSONL instance logging',
)

Expand All @@ -129,7 +136,8 @@ def _split_model_id(self, model_id: str | None) -> tuple[str, str]:
if not model_id:
return ('unknown', 'unknown')
if '/' in model_id:
return tuple(model_id.split('/', 1))
developer, name = model_id.split('/', 1)
return (developer, name)
return ('unknown', model_id)

def _extract_model_info(self, adapter_spec: AdapterSpec) -> ModelInfo:
Expand Down Expand Up @@ -191,6 +199,7 @@ def _load_file_if_exists(self, dir_path, file_name) -> Any:
return None

def _load_evaluation_run_logfiles(self, dir_path) -> Dict:
"""Load the HELM files needed for aggregate and detail conversion."""
scenario_state_dict = self._load_file_if_exists(
dir_path, self.SCENARIO_STATE_FILE
)
Expand All @@ -210,14 +219,18 @@ def _load_evaluation_run_logfiles(self, dir_path) -> Dict:
}

def transform_from_directory(
self, dir_path: str, output_path: str, metadata_args: Dict[str, Any]
):
self,
dir_path: str | Path,
metadata_args: Dict[str, Any] | None = None,
output_path: str | None = None,
) -> List[EvaluationLog]:
"""
Transforms HELM results into one aggregate EvaluationLog and one
instance-level JSONL file containing all samples.
"""
# all_instance_logs: List[InstanceLevelEvaluationLog] = []
aggregate_logs: List[EvaluationLog] = []
metadata_args = metadata_args or {}
dir_path = str(dir_path)

file_uuids = metadata_args.get('file_uuids')

Expand Down Expand Up @@ -260,11 +273,6 @@ def transform_from_directory(
aggregate_logs.append(agg)
converted_idx += 1

# # Write all consolidated instance logs to JSONL
# with open(output_path, 'w', encoding='utf-8') as f:
# for log in all_instance_logs:
# f.write(json.dumps(log.model_dump(), ensure_ascii=False) + '\n')

return aggregate_logs

def _extract_generation_args(
Expand Down Expand Up @@ -318,6 +326,7 @@ def _extract_evaluation_time(
def _extract_dataset_name(
self, run_spec_name: str, scenario_name: str | None
) -> str:
"""Prefer scenario metadata, falling back to HELM run-spec names."""
if scenario_name:
return scenario_name

Expand All @@ -332,20 +341,16 @@ def _extract_dataset_name(

return run_spec_name.split(':')[0]

def _extract_metric_names(self, run_spec: RunSpec) -> List[str]:
metric_names = []
for metric_spec in run_spec.metric_specs:
names = metric_spec.args.get('names')
if names:
metric_names.extend(names)
else:
metric_names.append(metric_spec.class_name.split('.')[-1])

return metric_names

def _transform_single(
self, raw_data: Dict, metadata_args: Dict[str, Any]
) -> Tuple[EvaluationLog, List[InstanceLevelEvaluationLog]]:
) -> EvaluationLog:
"""Convert one HELM run into aggregate JSON plus detail JSONL.

The aggregate ``evaluation_result_id`` values are generated from
core metrics in ``stats.json`` with the same helper used by the
instance converter so every metric-specific detail row can join
back to an aggregate result.
"""
run_spec = from_dict(data_class=RunSpec, data=raw_data['run_spec_dict'])
# cast=[str] coerces int instance IDs to str; newer HELM versions
# (e.g. long-context suite) store instance.id as int in the JSON.
Expand Down Expand Up @@ -402,80 +407,112 @@ def _transform_single(

evaluation_id = f'{source_data.dataset_name}/{model_info.id.replace("/", "_")}/{evaluation_timestamp}'

metric_names = self._extract_metric_names(run_spec)

# Build aggregate results from core HELM stats themselves, not
# only from run_spec.metric_specs. The instance-level converter emits
# one row per core per-instance stat, so aggregate IDs must cover
# the same core namespace for detailed rows to be joinable.
# TODO: Consider promoting bookkeeping telemetry into structured
# fields such as token_usage, performance, metadata, or
# additional_details in a separate follow-up.
evaluation_results: List[EvaluationResult] = []
seen_evaluation_result_ids: set[str] = set()

for stat in stats_raw:
# The ID helper mirrors the instance-level converter. This is the
# key invariant: detail rows should never introduce metric IDs that
# are absent from aggregate evaluation_results.
metric_name = getattr(getattr(stat, 'name', None), 'name', None)
if not is_core_metric(metric_name):
continue
score = _score_from_stat(stat)
if metric_name is None or score is None:
continue

stat_count = getattr(stat, 'count', None)

evaluation_result_id = _evaluation_result_id(
metric_name,
getattr(stat.name, 'split', None),
getattr(stat.name, 'perturbation', None),
)
if evaluation_result_id is None:
continue
if evaluation_result_id in seen_evaluation_result_ids:
continue
seen_evaluation_result_ids.add(evaluation_result_id)

for metric_name in set(metric_names):
metric_config = MetricConfig(
evaluation_description=metric_name,
lower_is_better=False, # TODO schema.json check
score_type=ScoreType.continuous,
min_score=0,
max_score=1,
max_score=1.0,
)

matching_stats = [
s
for s in stats_raw
if s.name.name == metric_name and not s.name.perturbation
]

for stat in matching_stats:
evaluation_name = (
f'{metric_name} on {source_data.dataset_name}'
if not stat.name.split
else f'{metric_name} {stat.name.split} on {source_data.dataset_name}'
)
split = getattr(stat.name, 'split', None)
perturbation = getattr(stat.name, 'perturbation', None)
name_parts = [metric_name]
if split:
name_parts.append(str(split))
if perturbation:
name_parts.append(str(perturbation))
evaluation_name = (
f'{" ".join(name_parts)} on {source_data.dataset_name}'
)

evaluation_results.append(
EvaluationResult(
evaluation_name=evaluation_name,
source_data=source_data,
evaluation_timestamp=evaluation_timestamp,
metric_config=metric_config,
score_details=ScoreDetails(
score=stat.mean
or (stat.sum / stat.count if stat.count else 0.0),
uncertainty=Uncertainty(
standard_deviation=stat.stddev,
num_samples=adapter_spec.max_eval_instances
or len(request_states),
evaluation_results.append(
EvaluationResult(
evaluation_result_id=evaluation_result_id,
evaluation_name=evaluation_name,
source_data=source_data,
evaluation_timestamp=evaluation_timestamp,
metric_config=metric_config,
score_details=ScoreDetails(
score=score,
uncertainty=Uncertainty(
standard_deviation=getattr(stat, 'stddev', None),
# Split-specific HELM stats may cover fewer
# examples than the full run, so use the stat's
# own count when it is available.
num_samples=(
stat_count
if stat_count is not None
else adapter_spec.max_eval_instances
or len(request_states)
),
details={
'count': str(stat.count),
'split': str(stat.name.split)
if stat.name.split
else '',
'perturbation': str(stat.name.perturbation)
if stat.name.perturbation
else '',
},
),
generation_config=GenerationConfig(
generation_args=self._extract_generation_args(
adapter_spec=adapter_spec,
request_state=request_states[0],
),
additional_details={
'stop_sequences': json.dumps(
request_states[0].request.stop_sequences
)
if request_states[0].request.stop_sequences
else '[]',
'presence_penalty': str(
request_states[0].request.presence_penalty
),
'frequency_penalty': str(
request_states[0].request.frequency_penalty
),
'num_completions': str(
request_states[0].request.num_completions
),
},
details={
'count': str(getattr(stat, 'count', '')),
'split': str(split) if split else '',
'perturbation': str(perturbation)
if perturbation
else '',
},
),
generation_config=GenerationConfig(
generation_args=self._extract_generation_args(
adapter_spec=adapter_spec,
request_state=request_states[0],
),
)
additional_details={
'stop_sequences': json.dumps(
request_states[0].request.stop_sequences
)
if request_states[0].request.stop_sequences
else '[]',
'presence_penalty': str(
request_states[0].request.presence_penalty
),
'frequency_penalty': str(
request_states[0].request.frequency_penalty
),
'num_completions': str(
request_states[0].request.num_completions
),
},
),
)
)

if request_states:
parent_eval_output_dir = metadata_args.get('parent_eval_output_dir')
Expand Down
Loading