diff --git a/documentation/OPTIONS.es.md b/documentation/OPTIONS.es.md index 59e332411..06827b17b 100644 --- a/documentation/OPTIONS.es.md +++ b/documentation/OPTIONS.es.md @@ -129,7 +129,7 @@ Donde `foo` es tu entorno de configuración; o simplemente usa `config/config.js ### `--musubi_blocks_to_swap` -- **Qué**: Intercambio de bloques Musubi para LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image y HunyuanVideo: mantiene los últimos N bloques del transformer en CPU y transmite pesos por bloque durante el forward. +- **Qué**: Intercambio de bloques Musubi para LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image, HunyuanVideo y Krea 2: mantiene los últimos N bloques del transformer en CPU y transmite pesos por bloque durante el forward. - **Predeterminado**: `0` (desactivado) - **Notas**: Offload de pesos estilo Musubi; reduce VRAM con coste en rendimiento y se omite cuando los gradientes están habilitados. 
diff --git a/documentation/OPTIONS.hi.md b/documentation/OPTIONS.hi.md index 3c7da2e07..cf48d21e7 100644 --- a/documentation/OPTIONS.hi.md +++ b/documentation/OPTIONS.hi.md @@ -129,7 +129,7 @@ simpletuner configure config/foo/config.json ### `--musubi_blocks_to_swap` -- **What**: LongCat‑Video, Wan, LTXVideo, Kandinsky5‑Video, Qwen‑Image, Flux, Flux.2, zlab i1, Cosmos2Image, और HunyuanVideo के लिए Musubi block swap — आख़िरी N transformer blocks को CPU पर रखें और forward के दौरान प्रति block weights stream करें। +- **What**: LongCat‑Video, Wan, LTXVideo, Kandinsky5‑Video, Qwen‑Image, Flux, Flux.2, zlab i1, Cosmos2Image, HunyuanVideo, और Krea 2 के लिए Musubi block swap — आख़िरी N transformer blocks को CPU पर रखें और forward के दौरान प्रति block weights stream करें। - **Default**: `0` (disabled) - **Notes**: Musubi‑style weight offload; throughput लागत पर VRAM कम करता है और gradients सक्षम होने पर skip हो जाता है। diff --git a/documentation/OPTIONS.ja.md b/documentation/OPTIONS.ja.md index 8974e04f8..f300fb469 100644 --- a/documentation/OPTIONS.ja.md +++ b/documentation/OPTIONS.ja.md @@ -129,7 +129,7 @@ simpletuner configure config/foo/config.json ### `--musubi_blocks_to_swap` -- **内容**: LongCat-Video、Wan、LTXVideo、Kandinsky5-Video、Qwen-Image、Flux、Flux.2、zlab i1、Cosmos2Image、HunyuanVideo の Musubi ブロックスワップ。最後の N 個の Transformer ブロックを CPU に置き、forward 中にブロック単位で重みをストリーミングします。 +- **内容**: LongCat-Video、Wan、LTXVideo、Kandinsky5-Video、Qwen-Image、Flux、Flux.2、zlab i1、Cosmos2Image、HunyuanVideo、Krea 2 の Musubi ブロックスワップ。最後の N 個の Transformer ブロックを CPU に置き、forward 中にブロック単位で重みをストリーミングします。 - **既定**: `0`(無効) - **注記**: Musubi 方式の重みオフロードで、スループット低下と引き換えに VRAM を削減します。勾配が有効な場合はスキップされます。 - **注記**: Musubi 方式の重みオフロードで、スループット低下と引き換えに VRAM を削減します。勾配が有効な場合はスキップされます。 diff --git a/documentation/OPTIONS.md b/documentation/OPTIONS.md index 4c1a6c700..f33e75755 100644 --- a/documentation/OPTIONS.md +++ b/documentation/OPTIONS.md @@ -129,7 +129,7 @@ Where `foo` is your config environment - or just use 
`config/config.json` if you ### `--musubi_blocks_to_swap` -- **What**: Musubi block swap for LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image, and HunyuanVideo — keep the last N transformer blocks on CPU and stream weights per block during forward. +- **What**: Musubi block swap for LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image, HunyuanVideo, and Krea 2 — keep the last N transformer blocks on CPU and stream weights per block during forward. - **Default**: `0` (disabled) - **Notes**: Musubi-style weight offload; reduces VRAM at a throughput cost and is skipped when gradients are enabled. diff --git a/documentation/OPTIONS.pt-BR.md b/documentation/OPTIONS.pt-BR.md index f9827f1f7..b2f1dd369 100644 --- a/documentation/OPTIONS.pt-BR.md +++ b/documentation/OPTIONS.pt-BR.md @@ -129,7 +129,7 @@ Onde `foo` e seu ambiente de config — ou use `config/config.json` se nao estiv ### `--musubi_blocks_to_swap` -- **O que**: Musubi block swap para LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image e HunyuanVideo — mantem os ultimos N blocos transformer na CPU e faz streaming de pesos por bloco durante o forward. +- **O que**: Musubi block swap para LongCat-Video, Wan, LTXVideo, Kandinsky5-Video, Qwen-Image, Flux, Flux.2, zlab i1, Cosmos2Image, HunyuanVideo e Krea 2 — mantem os ultimos N blocos transformer na CPU e faz streaming de pesos por bloco durante o forward. - **Padrao**: `0` (desabilitado) - **Notas**: Offload de pesos estilo Musubi; reduz VRAM com custo de throughput e e ignorado quando gradientes estao habilitados. 
diff --git a/documentation/OPTIONS.zh.md b/documentation/OPTIONS.zh.md index 71f29240b..c4b89744e 100644 --- a/documentation/OPTIONS.zh.md +++ b/documentation/OPTIONS.zh.md @@ -129,7 +129,7 @@ simpletuner configure config/foo/config.json ### `--musubi_blocks_to_swap` -- **内容**:为 LongCat-Video、Wan、LTXVideo、Kandinsky5-Video、Qwen-Image、Flux、Flux.2、zlab i1、Cosmos2Image、HunyuanVideo 提供 Musubi 块交换。将最后 N 个 Transformer 块保留在 CPU,并在前向中按块流式加载权重。 +- **内容**:为 LongCat-Video、Wan、LTXVideo、Kandinsky5-Video、Qwen-Image、Flux、Flux.2、zlab i1、Cosmos2Image、HunyuanVideo、Krea 2 提供 Musubi 块交换。将最后 N 个 Transformer 块保留在 CPU,并在前向中按块流式加载权重。 - **默认**:`0`(禁用) - **说明**:Musubi 风格权重卸载,会降低吞吐以换取显存节省,且在启用梯度时会跳过。 - **说明**:Musubi 风格权重卸载,通过降低吞吐换取显存节省;启用梯度时会跳过。 diff --git a/simpletuner/helpers/models/krea2/model.py b/simpletuner/helpers/models/krea2/model.py index e1e92718a..4d4e803b9 100644 --- a/simpletuner/helpers/models/krea2/model.py +++ b/simpletuner/helpers/models/krea2/model.py @@ -18,6 +18,7 @@ from simpletuner.helpers.models.krea2.transformer import Krea2Transformer2DModel from simpletuner.helpers.models.registry import ModelRegistry from simpletuner.helpers.models.tae.types import VideoTAESpec +from simpletuner.helpers.musubi_block_swap import apply_musubi_pretrained_defaults from simpletuner.helpers.training.state_tracker import StateTracker logger = logging.getLogger(__name__) @@ -63,6 +64,26 @@ class Krea2(ImageModelFoundation): DEFAULT_LORA_TARGET = ["to_k", "to_q", "to_v", "to_out.0"] FUSED_LORA_TARGET = ["to_qkv", "to_out.0"] + def tread_init(self): + from simpletuner.helpers.training.tread import TREADRouter + + tread_cfg = getattr(self.config, "tread_config", None) + if not isinstance(tread_cfg, dict) or tread_cfg == {} or tread_cfg.get("routes") is None: + logger.error("TREAD training requires you to configure the routes in the TREAD config") + import sys + + sys.exit(1) + + self.unwrap_model(model=self.model).set_router( + TREADRouter( + seed=getattr(self.config, "seed", None) or 42, + 
device=self.accelerator.device, + ), + tread_cfg["routes"], + ) + + logger.info("TREAD training is enabled for Krea 2") + @classmethod def max_swappable_blocks(cls, config=None) -> Optional[int]: return 27 @@ -78,6 +99,24 @@ def _uses_reference_latents(self) -> bool: def supports_conditioning_dataset(self) -> bool: return True + def supports_crepa_self_flow(self) -> bool: + return True + + def _prepare_crepa_self_flow_batch(self, batch: dict, state: dict) -> dict: + return self._prepare_image_crepa_self_flow_batch(batch, state, patch_size=self._patch_size()) + + def _select_crepa_hidden_states(self, prepared_batch: dict, hidden_states_buffer): + if hidden_states_buffer is None: + return None + crepa = getattr(self, "crepa_regularizer", None) + capture_layer = prepared_batch.get( + "crepa_capture_block_index", + getattr(crepa, "block_index", None), + ) + if capture_layer is None: + return None + return hidden_states_buffer.get(f"layer_{int(capture_layer)}") + def requires_conditioning_dataset(self) -> bool: return self._uses_reference_latents() @@ -419,6 +458,11 @@ def _prepare_model_predict_timesteps(self, raw_timesteps, batch_size: int) -> to raise ValueError(f"Krea 2 expected scalar or 1D timesteps, got shape {tuple(timesteps.shape)}.") return timesteps / 1000.0 + def _apply_krea2_experimental_timestep_kwargs(self, call_kwargs: dict, prepared_batch: dict) -> None: + if getattr(self.config, "twinflow_enabled", False): + call_kwargs["timestep_sign"] = prepared_batch.get("twinflow_time_sign") + self._apply_flowmap_r_timestep_kwargs(call_kwargs, prepared_batch) + def _prepare_reference_latents(self, prepared_batch: dict, batch_size: int, channels: int, height: int, width: int): reference_latents = prepared_batch.get("conditioning_latents") if isinstance(reference_latents, list): @@ -477,12 +521,19 @@ def model_predict(self, prepared_batch): timesteps = self._prepare_model_predict_timesteps(prepared_batch["timesteps"], batch_size) position_ids = 
self._position_ids_for_grids(prompt_embeds.shape[1], grids, self.accelerator.device) + hidden_states_buffer = self._new_hidden_state_buffer() + call_kwargs = {} + if hidden_states_buffer is not None: + call_kwargs["hidden_states_buffer"] = hidden_states_buffer + self._apply_krea2_experimental_timestep_kwargs(call_kwargs, prepared_batch) + noise_pred = self.model( hidden_states=hidden_states, encoder_hidden_states=prompt_embeds, timestep=timesteps, position_ids=position_ids, encoder_attention_mask=prompt_embeds_mask, + **call_kwargs, return_dict=False, )[0] noise_pred = noise_pred[:, :target_token_count] @@ -490,7 +541,15 @@ def model_predict(self, prepared_batch): if target_ndim == 5: noise_pred = noise_pred.unsqueeze(2) - return {"model_prediction": noise_pred} + return { + "model_prediction": noise_pred, + "crepa_hidden_states": self._select_crepa_hidden_states(prepared_batch, hidden_states_buffer), + "hidden_states_buffer": hidden_states_buffer, + } + + def pretrained_load_args(self, pretrained_load_args: dict) -> dict: + args = super().pretrained_load_args(pretrained_load_args) + return apply_musubi_pretrained_defaults(self.config, args) ModelRegistry.register("krea2", Krea2) diff --git a/simpletuner/helpers/models/krea2/pipeline.py b/simpletuner/helpers/models/krea2/pipeline.py index a2b53292d..4832ad05c 100644 --- a/simpletuner/helpers/models/krea2/pipeline.py +++ b/simpletuner/helpers/models/krea2/pipeline.py @@ -59,6 +59,16 @@ """ +@torch.amp.autocast( + "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu", + dtype=torch.float32, +) +def optimized_scale(positive_flat: torch.Tensor, negative_flat: torch.Tensor): + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + return dot_product / squared_norm + + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, @@ -578,6 
+588,13 @@ def __call__( callback_on_step_end_tensor_inputs: list[str] = ["latents"], attention_kwargs: dict[str, Any] | None = None, max_sequence_length: int = 512, + skip_guidance_layers: list[int] | None = None, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_scale: float = 2.8, + use_cfg_zero_star: bool = True, + zero_steps: int = 1, + use_zero_init: bool = True, ): r""" Function invoked when calling the pipeline for generation. @@ -796,7 +813,39 @@ def __call__( return_dict=False, )[0] neg_noise_pred = neg_noise_pred[:, :target_image_seq_len] - noise_pred = noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + pos_noise_pred = noise_pred + if use_cfg_zero_star: + pos_flat = pos_noise_pred.view(pos_noise_pred.shape[0], -1) + neg_flat = neg_noise_pred.view(neg_noise_pred.shape[0], -1) + alpha = optimized_scale(pos_flat, neg_flat).view( + pos_noise_pred.shape[0], *([1] * (pos_noise_pred.dim() - 1)) + ) + if i <= zero_steps and use_zero_init: + noise_pred = pos_noise_pred * 0.0 + else: + noise_pred = neg_noise_pred * alpha + guidance_scale * (pos_noise_pred - neg_noise_pred * alpha) + else: + noise_pred = pos_noise_pred + guidance_scale * (pos_noise_pred - neg_noise_pred) + + should_skip_layers = ( + skip_guidance_layers is not None + and len(skip_guidance_layers) > 0 + and i > num_inference_steps * skip_layer_guidance_start + and i < num_inference_steps * skip_layer_guidance_stop + ) + if should_skip_layers: + skip_noise_pred = self.transformer( + hidden_states=transformer_latents, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + position_ids=position_ids, + encoder_attention_mask=prompt_embeds_mask, + attention_kwargs=self.attention_kwargs, + skip_layers=skip_guidance_layers, + return_dict=False, + )[0] + skip_noise_pred = skip_noise_pred[:, :target_image_seq_len] + noise_pred = noise_pred + (pos_noise_pred - skip_noise_pred) * skip_layer_guidance_scale # compute the previous noisy 
sample x_t -> x_t-1 latents_dtype = latents.dtype diff --git a/simpletuner/helpers/models/krea2/transformer.py b/simpletuner/helpers/models/krea2/transformer.py index 1a29303d8..0d77cdbce 100644 --- a/simpletuner/helpers/models/krea2/transformer.py +++ b/simpletuner/helpers/models/krea2/transformer.py @@ -29,6 +29,17 @@ from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import apply_lora_scale, logging +from simpletuner.helpers.models.flowmap import ( + blend_flowmap_embeddings, + clone_flowmap_embedder, + prepare_flowmap_delta_timestep, + register_flowmap_config, + set_flowmap_gate, + validate_flowmap_deltatime_type, +) +from simpletuner.helpers.musubi_block_swap import MusubiBlockSwapManager +from simpletuner.helpers.training.tread import TREADRouter + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -36,6 +47,15 @@ def _krea2_rope_freqs_dtype(device: torch.device) -> torch.dtype: return torch.float32 if device.type in {"mps", "neuron", "npu"} else torch.float64 +def _store_hidden_state(buffer, key: str, hidden_states: torch.Tensor, image_tokens_start: int | None = None) -> None: + if buffer is None: + return + capture = hidden_states + if image_tokens_start is not None: + capture = capture[:, image_tokens_start:] + buffer[key] = capture + + class Krea2RMSNorm(nn.Module): """RMSNorm with a zero-centered scale: the effective multiplier is `1 + weight`, matching the Krea 2 checkpoint format. The activations are upcast so the normalization runs in float32; the scale weight is kept in float32 by the @@ -281,18 +301,75 @@ class Krea2TimestepEmbedding(nn.Module): Keeps the sequence dimension at size 1 so the per-block modulations broadcast over tokens. 
""" - def __init__(self, embed_dim: int, hidden_size: int) -> None: + def __init__(self, embed_dim: int, hidden_size: int, enable_time_sign_embed: bool = False) -> None: super().__init__() self.embed_dim = embed_dim self.linear_1 = nn.Linear(embed_dim, hidden_size, bias=True) self.linear_2 = nn.Linear(hidden_size, hidden_size, bias=True) - - def forward(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + self.delta_timestep_embedder: nn.Module | None = None + self.flowmap_deltatime_type: str | None = None + self.register_buffer("flowmap_delta_emb_gate", torch.tensor([0.25], dtype=torch.float32), persistent=False) + self.time_sign_embed: nn.Embedding | None = None + if enable_time_sign_embed: + self.time_sign_embed = nn.Embedding(2, hidden_size) + nn.init.zeros_(self.time_sign_embed.weight) + + def _embed_timestep(self, timestep: torch.Tensor, dtype: torch.dtype, embedder: nn.Module | None = None) -> torch.Tensor: + embedder = embedder or self half = self.embed_dim // 2 freqs = torch.exp(-math.log(1e4) * torch.arange(half, dtype=torch.float32, device=timestep.device) / half) args = (timestep.float() * 1e3)[:, None, None] * freqs emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype) - return self.linear_2(F.gelu(self.linear_1(emb), approximate="tanh")) + return embedder.linear_2(F.gelu(embedder.linear_1(emb), approximate="tanh")) + + def enable_flowmap_time_conditioning(self, gate_value: float = 0.25, deltatime_type: str = "r") -> None: + self.flowmap_deltatime_type = validate_flowmap_deltatime_type(deltatime_type, model_name="Krea 2") + if self.delta_timestep_embedder is None: + self.delta_timestep_embedder = clone_flowmap_embedder(self) + set_flowmap_gate(self, gate_value) + + def forward( + self, + timestep: torch.Tensor, + dtype: torch.dtype, + timestep_sign: torch.Tensor | None = None, + r_timestep: torch.Tensor | None = None, + ) -> torch.Tensor: + temb = self._embed_timestep(timestep, dtype) + if r_timestep is not None: + if 
self.delta_timestep_embedder is None or self.flowmap_deltatime_type is None: + raise ValueError( + "Krea 2 FlowMap conditioning requires `enable_flowmap_time_conditioning()` before training." + ) + delta_timestep = prepare_flowmap_delta_timestep( + timestep, + r_timestep, + self.flowmap_deltatime_type, + model_name="Krea 2", + ) + delta_emb = self._embed_timestep(delta_timestep, dtype, embedder=self.delta_timestep_embedder) + temb = blend_flowmap_embeddings(temb, delta_emb, self.flowmap_delta_emb_gate) + if timestep_sign is not None: + if self.time_sign_embed is None: + raise ValueError( + "timestep_sign was provided but the model was loaded without `enable_time_sign_embed=True`. " + "Enable TwinFlow (or load a TwinFlow-compatible checkpoint) to use signed-timestep conditioning." + ) + sign_tensor = timestep_sign.to(device=temb.device) + if sign_tensor.ndim == 0: + sign_tensor = sign_tensor.expand(temb.shape[0]) + elif sign_tensor.ndim == 1: + if sign_tensor.shape[0] == 1: + sign_tensor = sign_tensor.expand(temb.shape[0]) + elif sign_tensor.shape[0] != temb.shape[0]: + raise ValueError( + f"Krea 2 timestep_sign expected 1 or {temb.shape[0]} batch values, got {sign_tensor.shape[0]}." 
+ ) + else: + raise ValueError(f"Krea 2 timestep_sign expected scalar or 1D batch tensor, got {tuple(sign_tensor.shape)}.") + sign_idx = (sign_tensor.view(-1) < 0).long().to(device=temb.device) + temb = temb + self.time_sign_embed(sign_idx).to(dtype=temb.dtype, device=temb.device)[:, None, :] + return temb class Krea2TextProjection(nn.Module): @@ -428,6 +505,11 @@ def __init__( axes_dims_rope: tuple[int, int, int] = (32, 48, 48), rope_theta: float = 1000.0, norm_eps: float = 1e-5, + enable_time_sign_embed: bool = False, + musubi_blocks_to_swap: int = 0, + musubi_block_swap_device: str = "cpu", + gate_value: float | None = None, + deltatime_type: str | None = None, ) -> None: super().__init__() @@ -441,7 +523,14 @@ def __init__( self.gradient_checkpointing = False self.img_in = nn.Linear(in_channels, hidden_size, bias=True) - self.time_embed = Krea2TimestepEmbedding(timestep_embed_dim, hidden_size) + self.time_embed = Krea2TimestepEmbedding( + timestep_embed_dim, hidden_size, enable_time_sign_embed=enable_time_sign_embed + ) + if deltatime_type is not None: + self.enable_flowmap_time_conditioning( + gate_value=0.25 if gate_value is None else float(gate_value), + deltatime_type=deltatime_type, + ) self.time_mod_proj = nn.Linear(hidden_size, 6 * hidden_size, bias=True) self.text_fusion = Krea2TextFusion( num_text_layers=num_text_layers, @@ -470,6 +559,22 @@ def __init__( ) self.final_layer = Krea2FinalLayer(hidden_size, out_channels=in_channels, eps=norm_eps) + self._musubi_block_swap = MusubiBlockSwapManager.build( + depth=len(self.transformer_blocks), + blocks_to_swap=musubi_blocks_to_swap, + swap_device=musubi_block_swap_device, + logger=logger, + ) + self._tread_router: TREADRouter | None = None + self._tread_routes: list[dict[str, Any]] | None = None + + def set_router(self, router: TREADRouter, routes: list[dict[str, Any]] | None = None) -> None: + self._tread_router = router + self._tread_routes = routes + + def enable_flowmap_time_conditioning(self, 
gate_value: float = 0.25, deltatime_type: str = "r") -> None: + self.time_embed.enable_flowmap_time_conditioning(gate_value=gate_value, deltatime_type=deltatime_type) + register_flowmap_config(self, gate_value, deltatime_type) def fuse_qkv_projections(self, preferred_backend: str | None = None) -> None: del preferred_backend @@ -491,6 +596,10 @@ def forward( position_ids: torch.Tensor, encoder_attention_mask: torch.Tensor | None = None, attention_kwargs: dict[str, Any] | None = None, + hidden_states_buffer: dict | None = None, + timestep_sign: torch.Tensor | None = None, + r_timestep: torch.Tensor | None = None, + skip_layers: list[int] | None = None, return_dict: bool = True, ) -> Transformer2DModelOutput | tuple[torch.Tensor]: r""" @@ -524,7 +633,7 @@ def forward( batch_size, image_seq_len, _ = hidden_states.shape text_seq_len = encoder_hidden_states.shape[1] - temb = self.time_embed(timestep, dtype=hidden_states.dtype) + temb = self.time_embed(timestep, dtype=hidden_states.dtype, timestep_sign=timestep_sign, r_timestep=r_timestep) temb_mod = self.time_mod_proj(F.gelu(temb, approximate="tanh")) text_attention_mask = None @@ -545,13 +654,89 @@ def forward( image_rotary_emb = self.rotary_emb(position_ids) - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( - block, hidden_states, temb_mod, image_rotary_emb, attention_mask + routes = self._tread_routes or [] + router = self._tread_router + use_routing = self.training and len(routes) > 0 and torch.is_grad_enabled() + if use_routing and router is None: + raise ValueError("TREAD routing requested but no router has been configured. 
Call set_router before training.") + if routes: + total_layers = len(self.transformer_blocks) + + def _to_pos(idx): + return idx if idx >= 0 else total_layers + idx + + routes = [ + { + **route, + "start_layer_idx": _to_pos(route["start_layer_idx"]), + "end_layer_idx": _to_pos(route["end_layer_idx"]), + } + for route in routes + ] + + route_ptr = 0 + routing_now = False + tread_mask_info = None + saved_tokens = None + saved_attention_mask = None + saved_rotary_cos = None + saved_rotary_sin = None + + skip_set = set(skip_layers) if skip_layers is not None else set() + combined_blocks = list(self.transformer_blocks) + musubi_manager = self._musubi_block_swap + musubi_offload_active = False + if musubi_manager is not None: + musubi_offload_active = musubi_manager.activate(combined_blocks, hidden_states.device, torch.is_grad_enabled()) + + for idx, block in enumerate(self.transformer_blocks): + if musubi_offload_active and musubi_manager.is_managed_block(idx): + musubi_manager.stream_in(block, hidden_states.device) + + if use_routing and route_ptr < len(routes) and idx == routes[route_ptr]["start_layer_idx"]: + mask_ratio = routes[route_ptr]["selection_ratio"] + force_keep = torch.zeros( + (batch_size, hidden_states.shape[1]), + dtype=torch.bool, + device=hidden_states.device, ) + force_keep[:, :text_seq_len] = True + tread_mask_info = router.get_mask(hidden_states, mask_ratio=mask_ratio, force_keep=force_keep) + saved_tokens = hidden_states.clone() + saved_attention_mask = attention_mask + saved_rotary_cos, saved_rotary_sin = image_rotary_emb + hidden_states = router.start_route(hidden_states, tread_mask_info) + image_rotary_emb = ( + router.start_route(saved_rotary_cos.unsqueeze(0).expand(batch_size, -1, -1), tread_mask_info)[0], + router.start_route(saved_rotary_sin.unsqueeze(0).expand(batch_size, -1, -1), tread_mask_info)[0], + ) + attention_mask = None + routing_now = True + + if torch.is_grad_enabled() and self.gradient_checkpointing: + if idx in skip_set: + 
block_output = hidden_states + else: + block_output = self._gradient_checkpointing_func( + block, hidden_states, temb_mod, image_rotary_emb, attention_mask + ) else: - hidden_states = block(hidden_states, temb_mod, image_rotary_emb, attention_mask) + block_output = ( + hidden_states if idx in skip_set else block(hidden_states, temb_mod, image_rotary_emb, attention_mask) + ) + hidden_states = block_output + + if routing_now and route_ptr < len(routes) and idx == routes[route_ptr]["end_layer_idx"]: + hidden_states = router.end_route(hidden_states, tread_mask_info, original_x=saved_tokens) + image_rotary_emb = (saved_rotary_cos, saved_rotary_sin) + attention_mask = saved_attention_mask + routing_now = False + route_ptr += 1 + + _store_hidden_state(hidden_states_buffer, f"layer_{idx}", hidden_states, image_tokens_start=text_seq_len) + + if musubi_offload_active and musubi_manager.is_managed_block(idx): + musubi_manager.stream_out(block) hidden_states = hidden_states[:, text_seq_len:] output = self.final_layer(hidden_states, temb) diff --git a/tests/test_krea2_model.py b/tests/test_krea2_model.py index 11eadf486..64b0287f1 100644 --- a/tests/test_krea2_model.py +++ b/tests/test_krea2_model.py @@ -8,6 +8,7 @@ from simpletuner.helpers.models.krea2 import Krea2, Krea2LoraLoaderMixin, Krea2Pipeline, Krea2Transformer2DModel from simpletuner.helpers.models.krea2.transformer import Krea2Attention, _krea2_rope_freqs_dtype from simpletuner.helpers.models.registry import ModelRegistry +from simpletuner.helpers.training.tread import TREADRouter class FakeKrea2Transformer: @@ -81,6 +82,12 @@ def test_pipeline_accepts_reference_image_for_validation(self): parameters = inspect.signature(Krea2Pipeline.__call__).parameters self.assertIn("reference_image", parameters) + def test_pipeline_accepts_cfg_zero_star_and_skip_layer_guidance(self): + parameters = inspect.signature(Krea2Pipeline.__call__).parameters + self.assertIn("use_cfg_zero_star", parameters) + 
self.assertIn("skip_guidance_layers", parameters) + self.assertIn("skip_layer_guidance_scale", parameters) + def test_reference_latents_enable_reference_dataset_hooks(self): model = Krea2.__new__(Krea2) model.config = SimpleNamespace(krea2_reference_latents=True) @@ -136,6 +143,162 @@ def test_krea2_attention_fused_projection_matches_unfused_path(self): self.assertFalse(hasattr(attention, "to_qkv")) self.assertFalse(attention.fused_projections) + def test_model_supports_crepa_self_flow(self): + model = Krea2.__new__(Krea2) + + self.assertTrue(model.supports_crepa_self_flow()) + + def test_pretrained_load_args_enable_twinflow_and_musubi(self): + model = Krea2.__new__(Krea2) + model.config = SimpleNamespace( + twinflow_enabled=True, + musubi_blocks_to_swap=3, + musubi_block_swap_device="cpu", + ) + + args = model.pretrained_load_args({}) + + self.assertTrue(args["enable_time_sign_embed"]) + self.assertEqual(args["musubi_blocks_to_swap"], 3) + self.assertEqual(args["musubi_block_swap_device"], "cpu") + + def test_transformer_captures_hidden_states_and_accepts_skip_layers(self): + transformer = Krea2Transformer2DModel( + in_channels=4, + num_layers=2, + attention_head_dim=6, + num_attention_heads=1, + num_key_value_heads=1, + intermediate_size=8, + timestep_embed_dim=8, + text_hidden_dim=6, + num_text_layers=2, + text_num_attention_heads=1, + text_num_key_value_heads=1, + text_intermediate_size=8, + num_layerwise_text_blocks=1, + num_refiner_text_blocks=1, + axes_dims_rope=(2, 2, 2), + ) + hidden_states_buffer = {} + + output = transformer( + hidden_states=torch.randn(1, 4, 4), + encoder_hidden_states=torch.randn(1, 3, 2, 6), + timestep=torch.tensor([0.5]), + position_ids=torch.zeros(7, 3, dtype=torch.long), + encoder_attention_mask=torch.ones(1, 3, dtype=torch.int64), + skip_layers=[1], + hidden_states_buffer=hidden_states_buffer, + return_dict=False, + ) + + self.assertEqual(tuple(output[0].shape), (1, 4, 4)) + self.assertIn("layer_0", hidden_states_buffer) + 
self.assertIn("layer_1", hidden_states_buffer) + self.assertEqual(tuple(hidden_states_buffer["layer_0"].shape[:2]), (1, 4)) + + def test_transformer_accepts_tread_routing(self): + transformer = Krea2Transformer2DModel( + in_channels=4, + num_layers=2, + attention_head_dim=6, + num_attention_heads=1, + num_key_value_heads=1, + intermediate_size=8, + timestep_embed_dim=8, + text_hidden_dim=6, + num_text_layers=2, + text_num_attention_heads=1, + text_num_key_value_heads=1, + text_intermediate_size=8, + num_layerwise_text_blocks=1, + num_refiner_text_blocks=1, + axes_dims_rope=(2, 2, 2), + ) + transformer.train() + transformer.set_router( + TREADRouter(seed=1, device=torch.device("cpu")), + [{"start_layer_idx": 0, "end_layer_idx": 0, "selection_ratio": 0.5}], + ) + + with torch.enable_grad(): + output = transformer( + hidden_states=torch.randn(1, 4, 4), + encoder_hidden_states=torch.randn(1, 3, 2, 6), + timestep=torch.tensor([0.5]), + position_ids=torch.zeros(7, 3, dtype=torch.long), + encoder_attention_mask=torch.ones(1, 3, dtype=torch.int64), + return_dict=False, + ) + + self.assertEqual(tuple(output[0].shape), (1, 4, 4)) + + def test_transformer_supports_twinflow_and_anyflow_timestep_conditioning(self): + transformer = Krea2Transformer2DModel( + in_channels=4, + num_layers=1, + attention_head_dim=6, + num_attention_heads=1, + num_key_value_heads=1, + intermediate_size=8, + timestep_embed_dim=8, + text_hidden_dim=6, + num_text_layers=2, + text_num_attention_heads=1, + text_num_key_value_heads=1, + text_intermediate_size=8, + num_layerwise_text_blocks=1, + num_refiner_text_blocks=1, + axes_dims_rope=(2, 2, 2), + enable_time_sign_embed=True, + ) + transformer.enable_flowmap_time_conditioning(gate_value=0.25, deltatime_type="r") + kwargs = { + "hidden_states": torch.randn(1, 4, 4), + "encoder_hidden_states": torch.randn(1, 3, 2, 6), + "timestep": torch.tensor([0.5]), + "position_ids": torch.zeros(7, 3, dtype=torch.long), + "encoder_attention_mask": torch.ones(1, 3, 
dtype=torch.int64), + "return_dict": False, + } + + output = transformer(**kwargs, timestep_sign=torch.ones(1), r_timestep=torch.zeros(1)) + + self.assertEqual(tuple(output[0].shape), (1, 4, 4)) + + def test_transformer_requires_initialization_for_timestep_conditioning(self): + transformer = Krea2Transformer2DModel( + in_channels=4, + num_layers=1, + attention_head_dim=6, + num_attention_heads=1, + num_key_value_heads=1, + intermediate_size=8, + timestep_embed_dim=8, + text_hidden_dim=6, + num_text_layers=2, + text_num_attention_heads=1, + text_num_key_value_heads=1, + text_intermediate_size=8, + num_layerwise_text_blocks=1, + num_refiner_text_blocks=1, + axes_dims_rope=(2, 2, 2), + ) + kwargs = { + "hidden_states": torch.randn(1, 4, 4), + "encoder_hidden_states": torch.randn(1, 3, 2, 6), + "timestep": torch.tensor([0.5]), + "position_ids": torch.zeros(7, 3, dtype=torch.long), + "encoder_attention_mask": torch.ones(1, 3, dtype=torch.int64), + "return_dict": False, + } + + with self.assertRaisesRegex(ValueError, "enable_time_sign_embed"): + transformer(**kwargs, timestep_sign=torch.ones(1)) + with self.assertRaisesRegex(ValueError, "FlowMap"): + transformer(**kwargs, r_timestep=torch.zeros(1)) + def test_vae_encode_hooks_use_qwen_image_vae_rank_and_normalization(self): model = Krea2.__new__(Krea2) vae = SimpleNamespace(config=SimpleNamespace(latents_mean=[1.0, 2.0], latents_std=[2.0, 4.0], z_dim=2))
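
Reviewer note: the guidance arithmetic added to `Krea2Pipeline.__call__` in this diff can be checked in isolation. The sketch below restates the CFG-Zero* combination and the skip-layer guidance step window using plain Python lists instead of tensors so it runs without dependencies. The function names `cfg_zero_star` and `in_skip_layer_window` are hypothetical helpers for illustration and are not part of the patch. Only `optimized_scale`, the `1e-8` epsilon, and the default values are taken from the diff.

```python
def optimized_scale(positive, negative, eps=1e-8):
    """Projection coefficient <positive, negative> / (|negative|^2 + eps).

    Mirrors the per-sample scalar computed by `optimized_scale` in the diff,
    but for a single flattened sample given as a list of floats.
    """
    dot_product = sum(p * n for p, n in zip(positive, negative))
    squared_norm = sum(n * n for n in negative) + eps
    return dot_product / squared_norm


def cfg_zero_star(positive, negative, guidance_scale, step, zero_steps=1, use_zero_init=True):
    """Hypothetical helper: combine predictions as the `use_cfg_zero_star` branch does."""
    # Early steps are zeroed out when zero-init is enabled (note the inclusive `<=`).
    if use_zero_init and step <= zero_steps:
        return [0.0 for _ in positive]
    alpha = optimized_scale(positive, negative)
    # The negative prediction is rescaled by alpha before the usual CFG extrapolation.
    return [n * alpha + guidance_scale * (p - n * alpha) for p, n in zip(positive, negative)]


def in_skip_layer_window(step, num_inference_steps, start=0.01, stop=0.2):
    """Hypothetical helper: True when skip-layer guidance would run at `step`.

    Both bounds are exclusive, matching the `i > ... and i < ...` test in the diff.
    """
    return num_inference_steps * start < step < num_inference_steps * stop


if __name__ == "__main__":
    pos, neg = [2.0, 0.0], [1.0, 0.0]
    print(cfg_zero_star(pos, neg, guidance_scale=3.0, step=0))
    print(cfg_zero_star(pos, neg, guidance_scale=3.0, step=5))
    print([i for i in range(28) if in_skip_layer_window(i, 28)])
```

With the defaults in this diff (`zero_steps=1`, `use_zero_init=True`), both step 0 and step 1 produce an all-zero prediction, and skip-layer guidance only applies to a short band of early steps when `skip_guidance_layers` is non-empty.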