diff --git a/train.py b/train.py index 6a2d2b03..f87a5b29 100644 --- a/train.py +++ b/train.py @@ -103,10 +103,27 @@ def main(): save_model_config_callback = ModelConfigEmbedderCallback(model_config) if args.val_dataset_config: - demo_callback = create_demo_callback_from_config(model_config, demo_dl=val_dl) + demo_dl = create_dataloader_from_config( + val_dataset_config, + batch_size=args.batch_size, + num_workers=args.num_workers, + sample_rate=model_config["sample_rate"], + sample_size=model_config["sample_size"], + audio_channels=model_config.get("audio_channels", 2), + shuffle=False + ) else: - demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl) - + demo_dl = create_dataloader_from_config( + dataset_config, + batch_size=args.batch_size, + num_workers=args.num_workers, + sample_rate=model_config["sample_rate"], + sample_size=model_config["sample_size"], + audio_channels=model_config.get("audio_channels", 2), + shuffle=False + ) + demo_callback = create_demo_callback_from_config(model_config, demo_dl=train_dl) + #Combine args and config dicts args_dict = vars(args) args_dict.update({"model_config": model_config})