Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion documentation/OPTIONS.es.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Donde `foo` es tu entorno de configuración; o simplemente usa `config/config.js

### `--musubi_blocks_to_swap`

- **Qué**: Intercambio de bloques Musubi para LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image y HunyuanVideo: mantiene los últimos N bloques del transformer en CPU y transmite pesos por bloque durante el forward.
- **Qué**: Intercambio de bloques Musubi para LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image, HunyuanVideo y Krea 2: mantiene los últimos N bloques del transformer en CPU y transmite pesos por bloque durante el forward.
- **Predeterminado**: `0` (desactivado)
- **Notas**: Offload de pesos estilo Musubi; reduce VRAM con coste en rendimiento y se omite cuando los gradientes están habilitados.

Expand Down
2 changes: 1 addition & 1 deletion documentation/OPTIONS.hi.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ simpletuner configure config/foo/config.json

### `--musubi_blocks_to_swap`

- **What**: LongCat‑Video, Wan, LTXVideo, Kandinsky5‑Video, Qwen‑Image, Flux, Flux.2, zlab i1, Cosmos2Image, और HunyuanVideo के लिए Musubi block swap — आख़िरी N transformer blocks को CPU पर रखें और forward के दौरान प्रति block weights stream करें।
- **What**: LongCat‑Video, Wan, LTXVideo, Kandinsky5‑Video, Qwen‑Image, Flux, Flux.2, zlab i1, Cosmos2Image, HunyuanVideo, और Krea 2 के लिए Musubi block swap — आख़िरी N transformer blocks को CPU पर रखें और forward के दौरान प्रति block weights stream करें।
- **Default**: `0` (disabled)
- **Notes**: Musubi‑style weight offload; throughput लागत पर VRAM कम करता है और gradients सक्षम होने पर skip हो जाता है।

Expand Down
2 changes: 1 addition & 1 deletion documentation/OPTIONS.ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ simpletuner configure config/foo/config.json

### `--musubi_blocks_to_swap`

- **内容**: LongCat-Video、Wan、LTXVideo、Kandinsky5-Video、Qwen-Image、Flux、Flux.2、zlab i1、Cosmos2Image、HunyuanVideo の Musubi ブロックスワップ。最後の N 個の Transformer ブロックを CPU に置き、forward 中にブロック単位で重みをストリーミングします。
- **内容**: LongCat-Video、Wan、LTXVideo、Kandinsky5-Video、Qwen-Image、Flux、Flux.2、zlab i1、Cosmos2Image、HunyuanVideo、Krea 2 の Musubi ブロックスワップ。最後の N 個の Transformer ブロックを CPU に置き、forward 中にブロック単位で重みをストリーミングします。
- **既定**: `0`(無効)
- **注記**: Musubi 方式の重みオフロードで、スループット低下と引き換えに VRAM を削減します。勾配が有効な場合はスキップされます。
- **注記**: Musubi 方式の重みオフロードで、スループット低下と引き換えに VRAM を削減します。勾配が有効な場合はスキップされます。
Expand Down
2 changes: 1 addition & 1 deletion documentation/OPTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Where `foo` is your config environment - or just use `config/config.json` if you

### `--musubi_blocks_to_swap`

- **What**: Musubi block swap for LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image, and HunyuanVideo — keep the last N transformer blocks on CPU and stream weights per block during forward.
- **What**: Musubi block swap for LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image, HunyuanVideo, and Krea 2 — keep the last N transformer blocks on CPU and stream weights per block during forward.
- **Default**: `0` (disabled)
- **Notes**: Musubi-style weight offload; reduces VRAM at a throughput cost and is skipped when gradients are enabled.

Expand Down
2 changes: 1 addition & 1 deletion documentation/OPTIONS.pt-BR.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Onde `foo` e seu ambiente de config — ou use `config/config.json` se nao estiv

### `--musubi_blocks_to_swap`

- **O que**: Musubi block swap para LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image e HunyuanVideo — mantem os ultimos N blocos transformer na CPU e faz streaming de pesos por bloco durante o forward.
- **O que**: Musubi block swap para LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image, HunyuanVideo e Krea 2 — mantem os ultimos N blocos transformer na CPU e faz streaming de pesos por bloco durante o forward.
- **Padrao**: `0` (desabilitado)
- **Notas**: Offload de pesos estilo Musubi; reduz VRAM com custo de throughput e e ignorado quando gradientes estao habilitados.

Expand Down
2 changes: 1 addition & 1 deletion documentation/OPTIONS.zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ simpletuner configure config/foo/config.json

### `--musubi_blocks_to_swap`

- **内容**:为 LongCat-Video、Wan、LTXVideo、Kandinsky5-Video、Qwen-Image、Flux、Flux.2、zlab i1、Cosmos2Image、HunyuanVideo 提供 Musubi 块交换。将最后 N 个 Transformer 块保留在 CPU,并在前向中按块流式加载权重。
- **内容**:为 LongCat-Video、Wan、LTXVideo、Kandinsky5-Video、Qwen-Image、Flux、Flux.2、zlab i1、Cosmos2Image、HunyuanVideo、Krea 2 提供 Musubi 块交换。将最后 N 个 Transformer 块保留在 CPU,并在前向中按块流式加载权重。
- **默认**:`0`(禁用)
- **说明**:Musubi 风格权重卸载,会降低吞吐以换取显存节省,且在启用梯度时会跳过。
- **说明**:Musubi 风格权重卸载,通过降低吞吐换取显存节省;启用梯度时会跳过。
Expand Down
61 changes: 60 additions & 1 deletion simpletuner/helpers/models/krea2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from simpletuner.helpers.models.krea2.transformer import Krea2Transformer2DModel
from simpletuner.helpers.models.registry import ModelRegistry
from simpletuner.helpers.models.tae.types import VideoTAESpec
from simpletuner.helpers.musubi_block_swap import apply_musubi_pretrained_defaults
from simpletuner.helpers.training.state_tracker import StateTracker

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -63,6 +64,26 @@ class Krea2(ImageModelFoundation):
DEFAULT_LORA_TARGET = ["to_k", "to_q", "to_v", "to_out.0"]
FUSED_LORA_TARGET = ["to_qkv", "to_out.0"]

def tread_init(self):
from simpletuner.helpers.training.tread import TREADRouter

tread_cfg = getattr(self.config, "tread_config", None)
if not isinstance(tread_cfg, dict) or tread_cfg == {} or tread_cfg.get("routes") is None:
logger.error("TREAD training requires you to configure the routes in the TREAD config")
import sys

sys.exit(1)

self.unwrap_model(model=self.model).set_router(
TREADRouter(
seed=getattr(self.config, "seed", None) or 42,
device=self.accelerator.device,
),
tread_cfg["routes"],
)

logger.info("TREAD training is enabled for Krea 2")

@classmethod
def max_swappable_blocks(cls, config=None) -> Optional[int]:
return 27
Expand All @@ -78,6 +99,24 @@ def _uses_reference_latents(self) -> bool:
def supports_conditioning_dataset(self) -> bool:
return True

def supports_crepa_self_flow(self) -> bool:
return True

def _prepare_crepa_self_flow_batch(self, batch: dict, state: dict) -> dict:
return self._prepare_image_crepa_self_flow_batch(batch, state, patch_size=self._patch_size())

def _select_crepa_hidden_states(self, prepared_batch: dict, hidden_states_buffer):
if hidden_states_buffer is None:
return None
crepa = getattr(self, "crepa_regularizer", None)
capture_layer = prepared_batch.get(
"crepa_capture_block_index",
getattr(crepa, "block_index", None),
)
if capture_layer is None:
return None
return hidden_states_buffer.get(f"layer_{int(capture_layer)}")

def requires_conditioning_dataset(self) -> bool:
return self._uses_reference_latents()

Expand Down Expand Up @@ -419,6 +458,11 @@ def _prepare_model_predict_timesteps(self, raw_timesteps, batch_size: int) -> to
raise ValueError(f"Krea 2 expected scalar or 1D timesteps, got shape {tuple(timesteps.shape)}.")
return timesteps / 1000.0

def _apply_krea2_experimental_timestep_kwargs(self, call_kwargs: dict, prepared_batch: dict) -> None:
if getattr(self.config, "twinflow_enabled", False):
call_kwargs["timestep_sign"] = prepared_batch.get("twinflow_time_sign")
self._apply_flowmap_r_timestep_kwargs(call_kwargs, prepared_batch)

def _prepare_reference_latents(self, prepared_batch: dict, batch_size: int, channels: int, height: int, width: int):
reference_latents = prepared_batch.get("conditioning_latents")
if isinstance(reference_latents, list):
Expand Down Expand Up @@ -477,20 +521,35 @@ def model_predict(self, prepared_batch):
timesteps = self._prepare_model_predict_timesteps(prepared_batch["timesteps"], batch_size)
position_ids = self._position_ids_for_grids(prompt_embeds.shape[1], grids, self.accelerator.device)

hidden_states_buffer = self._new_hidden_state_buffer()
call_kwargs = {}
if hidden_states_buffer is not None:
call_kwargs["hidden_states_buffer"] = hidden_states_buffer
self._apply_krea2_experimental_timestep_kwargs(call_kwargs, prepared_batch)

noise_pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=prompt_embeds,
timestep=timesteps,
position_ids=position_ids,
encoder_attention_mask=prompt_embeds_mask,
**call_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, :target_token_count]
noise_pred = self._unpack_latents(noise_pred, latent_height, latent_width)

if target_ndim == 5:
noise_pred = noise_pred.unsqueeze(2)
return {"model_prediction": noise_pred}
return {
"model_prediction": noise_pred,
"crepa_hidden_states": self._select_crepa_hidden_states(prepared_batch, hidden_states_buffer),
"hidden_states_buffer": hidden_states_buffer,
}

def pretrained_load_args(self, pretrained_load_args: dict) -> dict:
args = super().pretrained_load_args(pretrained_load_args)
return apply_musubi_pretrained_defaults(self.config, args)


ModelRegistry.register("krea2", Krea2)
51 changes: 50 additions & 1 deletion simpletuner/helpers/models/krea2/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@
"""


@torch.amp.autocast(
"mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu",
dtype=torch.float32,
)
def optimized_scale(positive_flat: torch.Tensor, negative_flat: torch.Tensor):
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
return dot_product / squared_norm


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
Expand Down Expand Up @@ -578,6 +588,13 @@ def __call__(
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
attention_kwargs: dict[str, Any] | None = None,
max_sequence_length: int = 512,
skip_guidance_layers: list[int] | None = None,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_scale: float = 2.8,
use_cfg_zero_star: bool = True,
zero_steps: int = 1,
use_zero_init: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -796,7 +813,39 @@ def __call__(
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, :target_image_seq_len]
noise_pred = noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
pos_noise_pred = noise_pred
if use_cfg_zero_star:
pos_flat = pos_noise_pred.view(pos_noise_pred.shape[0], -1)
neg_flat = neg_noise_pred.view(neg_noise_pred.shape[0], -1)
alpha = optimized_scale(pos_flat, neg_flat).view(
pos_noise_pred.shape[0], *([1] * (pos_noise_pred.dim() - 1))
)
if i <= zero_steps and use_zero_init:
noise_pred = pos_noise_pred * 0.0
else:
noise_pred = neg_noise_pred * alpha + guidance_scale * (pos_noise_pred - neg_noise_pred * alpha)
else:
noise_pred = pos_noise_pred + guidance_scale * (pos_noise_pred - neg_noise_pred)

should_skip_layers = (
skip_guidance_layers is not None
and len(skip_guidance_layers) > 0
and i > num_inference_steps * skip_layer_guidance_start
and i < num_inference_steps * skip_layer_guidance_stop
)
if should_skip_layers:
skip_noise_pred = self.transformer(
hidden_states=transformer_latents,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
position_ids=position_ids,
encoder_attention_mask=prompt_embeds_mask,
attention_kwargs=self.attention_kwargs,
skip_layers=skip_guidance_layers,
return_dict=False,
)[0]
skip_noise_pred = skip_noise_pred[:, :target_image_seq_len]
noise_pred = noise_pred + (pos_noise_pred - skip_noise_pred) * skip_layer_guidance_scale

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
Expand Down
Loading
Loading