Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/ltxv_trainer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,22 @@ class CheckpointsConfig(ConfigBaseModel):
)


class HubConfig(ConfigBaseModel):
"""Configuration for Hugging Face Hub integration"""

push_to_hub: bool = Field(
default=False,
description="Whether to push the model weights to the Hugging Face Hub"
)
hub_model_id: str = Field(
default=None,
description="Hugging Face Hub repository ID (e.g., 'username/repo-name')"
)
hub_token: str = Field(
default=None,
description="Hugging Face token. If None, will use the token from the Hugging Face CLI"
)

class FlowMatchingConfig(ConfigBaseModel):
"""Configuration for flow matching training"""

Expand All @@ -259,6 +275,7 @@ class LtxvTrainerConfig(ConfigBaseModel):
data: DataConfig = Field(default_factory=DataConfig)
validation: ValidationConfig = Field(default_factory=ValidationConfig)
checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
hub: HubConfig = Field(default_factory=HubConfig)
flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)

# General configuration
Expand Down
10 changes: 10 additions & 0 deletions src/ltxv_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from accelerate.utils import set_seed
from diffusers import LTXPipeline
from diffusers.utils import export_to_video
from huggingface_hub import create_repo, upload_folder
from loguru import logger
from peft import LoraConfig, get_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
Expand Down Expand Up @@ -286,6 +287,15 @@ def train( # noqa: PLR0912, PLR0915
if self._accelerator.is_main_process:
saved_path = self._save_checkpoint()

# Upload artifacts to hub if enabled
if cfg.hub.push_to_hub:
repo_id = cfg.hub.hub_model_id or Path(cfg.output_dir).name
repo_id = create_repo(token=cfg.hub.hub_token, repo_id=repo_id, exist_ok=True)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the base model parameter here.
You can either point to the general LTXV repo or go finegrain amd allow the user to set a specific version. Either works for me.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a model card now :), wdyt?

upload_folder(
repo_id=repo_id,
folder_path=Path(self._config.output_dir),
)

# Log the training statistics
self._log_training_stats(stats)

Expand Down