Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 93 additions & 18 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,20 +162,30 @@ def build_model_from_config(cfg, model_name, use_vllm=False):
return model


def build_dataset_from_config(cfg, dataset_name):
def build_dataset_from_config(cfg, dataset_name, *, strict=False, extra_kwargs=None):
import inspect

import vlmeval.dataset
config = cp.deepcopy(cfg[dataset_name])
if config == {}:
return supported_video_datasets[dataset_name]()
assert 'class' in config
if dataset_name not in supported_video_datasets:
raise ValueError(f'Empty dataset config {dataset_name} is not a supported video dataset shortcut')
return supported_video_datasets[dataset_name](**(extra_kwargs or {}))
if 'class' not in config:
raise ValueError(f'`class` must be set for dataset config {dataset_name}')
cls_name = config.pop('class')
if extra_kwargs:
for k, v in extra_kwargs.items():
config.setdefault(k, v)
if hasattr(vlmeval.dataset, cls_name):
cls = getattr(vlmeval.dataset, cls_name)
sig = inspect.signature(cls.__init__)
unknown_params = sorted(k for k in config if k not in sig.parameters)
if strict and unknown_params:
unknown = ', '.join(unknown_params)
raise ValueError(f'Unsupported parameter(s) for dataset class {cls_name}: {unknown}')
valid_params = {k: v for k, v in config.items() if k in sig.parameters}
if cls.MODALITY == 'VIDEO':
if getattr(cls, 'MODALITY', None) == 'VIDEO':
if valid_params.get('fps', 0) > 0 and valid_params.get('nframe', 0) > 0:
raise ValueError('fps and nframe should not be set at the same time')
if valid_params.get('fps', 0) <= 0 and valid_params.get('nframe', 0) <= 0:
Expand All @@ -199,6 +209,65 @@ def apply_supported_vlm_cli_overrides(args):
supported_VLM[k] = v


def load_data_config(data_config):
    """Parse and validate the value of ``--data-config``.

    Args:
        data_config: ``None``, a JSON string encoding a dict that maps custom
            dataset names to per-dataset config dicts, or an already-parsed
            dict of that shape. Accepting a dict makes the function safe to
            apply to its own result.

    Returns:
        dict: The validated mapping. ``None`` and blank strings yield ``{}``.

    Raises:
        ValueError: If the input is of an unsupported type, is not valid JSON,
            is not a dict at the top level, has a non-string key, or has an
            entry that is not a dict or whose ``class`` / ``dataset`` field is
            not a string.
    """
    if data_config is None:
        return {}

    if isinstance(data_config, dict):
        # Already parsed (e.g. a previously normalised value): skip the JSON
        # step but still run the structural validation below.
        config = data_config
    else:
        if not isinstance(data_config, (str, bytes, bytearray)):
            # Fail with ValueError (not AttributeError on `.strip()`) so the
            # caller's `except ValueError` handling reports it cleanly.
            raise ValueError('--data-config must be a JSON dict')

        raw = data_config.strip()
        if raw == '':
            return {}

        try:
            config = json.loads(raw)
        except json.JSONDecodeError as e:
            raise ValueError('--data-config must be a valid JSON dict string') from e

    if not isinstance(config, dict):
        raise ValueError('--data-config must be a JSON dict')

    for name, value in config.items():
        # JSON object keys are always strings; this check matters for dicts
        # that were passed in directly.
        if not isinstance(name, str):
            raise ValueError('--data-config keys must be strings')
        if not isinstance(value, dict):
            raise ValueError(f'--data-config value for {name} must be a JSON dict')
        if 'class' in value and not isinstance(value['class'], str):
            raise ValueError(f'--data-config class for {name} must be a string')
        if 'dataset' in value and not isinstance(value['dataset'], str):
            raise ValueError(f'--data-config dataset for {name} must be a string')
    return config


def get_data_config_dataset_name(dataset_name, data_config):
    """Return the dataset name that a ``--data`` entry resolves to.

    Args:
        dataset_name: Name as passed on the command line.
        data_config: Mapping of custom dataset names to config dicts, or
            ``None`` / empty when no custom configs were supplied.

    Returns:
        The entry's ``dataset`` field when ``dataset_name`` has an entry in
        ``data_config`` that defines one; otherwise ``dataset_name`` itself.
    """
    # A missing or empty config means there is nothing to override.
    if not data_config or dataset_name not in data_config:
        return dataset_name
    return data_config[dataset_name].get('dataset', dataset_name)


def get_judge_dataset_name(dataset_name, data_config):
    """Build the dataset name string handed to the judge-kwargs lookup.

    When ``dataset_name`` resolves to itself, it is returned unchanged.
    Otherwise the result is a space-separated string of the custom name, the
    resolved base name, and the base name with ``-`` replaced by ``_``.
    NOTE(review): presumably the lookup matches on substrings of this value;
    confirm against ``get_judge_kwargs``.
    """
    resolved = get_data_config_dataset_name(dataset_name, data_config)
    if resolved != dataset_name:
        underscored = resolved.replace("-", "_")
        return f'{dataset_name} {resolved} {underscored}'
    return dataset_name


def get_dataset_build_kwargs(dataset_name, model_name, data_config):
    """Return the extra keyword arguments needed to build a dataset.

    The name is first resolved through ``data_config``. For the datasets
    listed below the result is ``{'model': model_name}``; for every other
    dataset it is an empty dict.
    """
    needs_model = ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']
    resolved = get_data_config_dataset_name(dataset_name, data_config)
    return {'model': model_name} if resolved in needs_model else {}


def build_dataset_from_cli(dataset_name, data_config, dataset_kwargs):
    """Build a dataset named on the command line.

    Names without an entry in ``data_config`` go through ``build_dataset``
    with ``dataset_kwargs``. Names with an entry are built by
    ``build_dataset_from_config`` in strict mode, with ``dataset_kwargs``
    passed as ``extra_kwargs``.
    """
    if dataset_name not in data_config:
        return build_dataset(dataset_name, **dataset_kwargs)

    custom_build_options = dict(strict=True, extra_kwargs=dataset_kwargs)
    return build_dataset_from_config(data_config, dataset_name, **custom_build_options)


def build_model_from_base_url(args):
"""Build LMDeployAPI model kwargs from command-line arguments.

Expand Down Expand Up @@ -379,6 +448,10 @@ def parse_args():
or you can check the output of the command `vlmutil dlist all` in the terminal.
To find all supported video dataset default settings, please refer to the \
`vlmeval/dataset/video_dataset_config.py` file.
You can also pass --data-config to define custom dataset names used by --data:
--data Video-MME-custom \
--data-config '{"Video-MME-custom": {"class": "VideoMME", "dataset": "Video-MME", "nframe": 16}}'
The value of --data-config must be a JSON dict string so evaluation parameters are recorded in argv.

--config:
Launch the evaluation by specifying the path to the config json file. Sample Json Content:
Expand Down Expand Up @@ -446,6 +519,8 @@ def parse_args():
parser.add_argument('--data', type=str, nargs='+', help='Names of Datasets')
parser.add_argument('--model', type=str, nargs='+', help='Names of Models')
parser.add_argument('--config', type=str, help='Path to the Config Json File')
parser.add_argument('--data-config', type=str, default=None,
help='Custom dataset configs as a JSON dict string. Keys must match names passed to --data.')

# Work Dir & Mode
parser.add_argument('--work-dir', type=str, default='./outputs', help='select the output directory')
Expand Down Expand Up @@ -520,6 +595,10 @@ def parse_args():
help='Debug mode: run evaluation in main process')

args = parser.parse_args()
try:
args.data_config = load_data_config(args.data_config)
except ValueError as e:
parser.error(str(e))
if args.ignore:
logger.warning('[Deprecated] the `--ignore` flag is deprecated since it is '
'the default behavior, use `--keep-failed` to disable it.')
Expand All @@ -531,6 +610,7 @@ def run_local_mode(args):
use_config, cfg = False, None
if args.config is not None:
assert args.data is None and args.model is None, '--data and --model should not be set when using --config'
assert not args.data_config, '--data-config should not be set when using --config'
use_config, cfg = True, load(args.config)
args.model = list(cfg['model'].keys())
args.data = list(cfg['data'].keys())
Expand Down Expand Up @@ -649,17 +729,15 @@ def run_local_mode(args):
)
continue
else:
dataset_kwargs = {}
if dataset_name in ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']:
dataset_kwargs['model'] = model_name
dataset_kwargs = get_dataset_build_kwargs(dataset_name, model_name, args.data_config)

# If distributed, first build the dataset on the main process for doing preparation works
if WORLD_SIZE > 1:
if RANK == 0:
dataset = build_dataset(dataset_name, **dataset_kwargs)
dataset = build_dataset_from_cli(dataset_name, args.data_config, dataset_kwargs)
dist.barrier()

dataset = build_dataset(dataset_name, **dataset_kwargs)
dataset = build_dataset_from_cli(dataset_name, args.data_config, dataset_kwargs)
if dataset is None:
logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
if RANK == 0:
Expand All @@ -672,7 +750,8 @@ def run_local_mode(args):
)
continue

judge_kwargs = get_judge_kwargs(dataset_name, dataset.TYPE, args)
judge_dataset_name = get_judge_dataset_name(dataset_name, args.data_config)
judge_kwargs = get_judge_kwargs(judge_dataset_name, dataset.TYPE, args)
judge_model = judge_kwargs.get('model', '')

if RANK == 0:
Expand Down Expand Up @@ -1036,13 +1115,8 @@ def run_api_mode(args):
logger.info(f'-------------------- {ds_name} --------------------')

try:
dataset_kwargs = {}
if ds_name in [
'MMLongBench_DOC', 'DUDE', 'DUDE_MINI',
'SLIDEVQA', 'SLIDEVQA_MINI',
]:
dataset_kwargs['model'] = model_name
dataset = build_dataset(ds_name, **dataset_kwargs)
dataset_kwargs = get_dataset_build_kwargs(ds_name, model_name, args.data_config)
dataset = build_dataset_from_cli(ds_name, args.data_config, dataset_kwargs)

if dataset is None:
logger.error(f'Dataset {ds_name} is not valid, will be skipped.')
Expand All @@ -1060,7 +1134,8 @@ def run_api_mode(args):
logger.info(f'{ds_name} requires special handling, skipped in pipeline.')
continue

judge_kwargs = get_judge_kwargs(ds_name, dataset.TYPE, args)
judge_dataset_name = get_judge_dataset_name(ds_name, args.data_config)
judge_kwargs = get_judge_kwargs(judge_dataset_name, dataset.TYPE, args)
judge_model = judge_kwargs.get('model', '')
logger.info(f'Judge kwargs: {judge_kwargs}')

Expand Down
Loading