diff --git a/documentation/OPTIONS.es.md b/documentation/OPTIONS.es.md index 8af20da7c..59e332411 100644 --- a/documentation/OPTIONS.es.md +++ b/documentation/OPTIONS.es.md @@ -891,6 +891,13 @@ Diferentes modelos esperan diferentes datos de conditioning: - **Nota**: Al usar `combined`, no puedes definir `captions` separados en datasets de condicionamiento; se usan las captions del dataset fuente. - **Ver también**: [DATALOADER.md](DATALOADER.md#conditioning_data) para configurar múltiples datasets de condicionamiento. +### `--krea2_reference_latents` {#--krea2_reference_latents} + +- **Qué**: Habilita entrenamiento de Krea 2 con reference dataset. +- **Por qué**: Cuando está habilitado, Krea 2 usa la imagen de condicionamiento emparejada al cachear los prompt embeddings de Qwen3VL, y añade los latentes VAE limpios de esa imagen al flujo de tokens del transformer durante el entrenamiento. +- **Dataset**: Configura el dataset principal de imágenes con `conditioning_data` apuntando a un dataset de condicionamiento emparejado. Los nombres de archivo deben coincidir entre las imágenes target y reference. +- **Alcance**: Esta es una opción del lado del modelo Krea 2. No genera datasets de condicionamiento; usa los ajustes normales de reference dataset del dataloader. + ### Opciones de condicionamiento de LTX-2 Estos son ajustes avanzados opcionales para entrenamiento LTX-2. Úsalos en archivos JSON/TOML con los nombres indicados, o mediante los flags CLI equivalentes, por ejemplo `--ltx2_first_frame_conditioning_probability`. 
diff --git a/documentation/OPTIONS.hi.md b/documentation/OPTIONS.hi.md index 22c643811..3c7da2e07 100644 --- a/documentation/OPTIONS.hi.md +++ b/documentation/OPTIONS.hi.md @@ -889,6 +889,13 @@ Different models different conditioning data expect करते हैं: - **Note**: `combined` उपयोग करने पर आप conditioning datasets पर अलग `captions` परिभाषित नहीं कर सकते; source dataset के captions ही उपयोग होते हैं। - **See also**: multiple conditioning datasets कॉन्फ़िगर करने के लिए [DATALOADER.md](DATALOADER.md#conditioning_data) देखें। +### `--krea2_reference_latents` {#--krea2_reference_latents} + +- **What**: Krea 2 reference dataset training enable करें। +- **Why**: Enable होने पर Krea 2 Qwen3VL prompt embeddings cache करते समय paired conditioning image use करता है, और training के दौरान उस conditioning image के clean VAE latents को transformer token stream में append करता है। +- **Dataset setup**: Main image dataset में `conditioning_data` को paired conditioning dataset की ओर point करें। Target और reference images के filenames match होने चाहिए। +- **Scope**: यह Krea 2 model-side option है। यह conditioning datasets generate नहीं करता; उसके लिए normal dataloader reference-dataset settings use करें। + ### LTX-2 conditioning options ये LTX-2 training के optional advanced settings हैं। इन्हें JSON/TOML config files में नीचे दिए गए names से सेट करें, या matching CLI flags जैसे `--ltx2_first_frame_conditioning_probability` से पास करें। diff --git a/documentation/OPTIONS.ja.md b/documentation/OPTIONS.ja.md index b3431f3f2..8974e04f8 100644 --- a/documentation/OPTIONS.ja.md +++ b/documentation/OPTIONS.ja.md @@ -891,6 +891,13 @@ Flux Kontext の検証もこのコンディショニングベースの経路を - **注記**: `combined` を使う場合、条件データセットに個別の `captions` は定義できません。ソースデータセットのキャプションが使用されます。 - **参照**: 複数条件データセットの設定は [DATALOADER.md](DATALOADER.md#conditioning_data) を参照してください。 +### `--krea2_reference_latents` {#--krea2_reference_latents} + +- **内容**: Krea 2 の reference dataset training を有効化します。 +- **理由**: 有効にすると、Krea 2 は Qwen3VL 
prompt embeddings をキャッシュするときにペアの conditioning image を使い、training 時にはその conditioning image の clean VAE latents を transformer token sequence に追加します。 +- **データセット設定**: main image dataset の `conditioning_data` をペアの conditioning dataset に向けます。target image と reference image のファイル名は一致している必要があります。 +- **範囲**: これは Krea 2 の model-side option です。conditioning datasets は生成しません。通常の dataloader reference-dataset 設定を使ってください。 + ### LTX-2 conditioning options これは LTX-2 training 用の任意の advanced setting です。下記の名前で JSON/TOML config に書くか、`--ltx2_first_frame_conditioning_probability` のような対応する CLI flag で指定します。 diff --git a/documentation/OPTIONS.md b/documentation/OPTIONS.md index e8644ed2a..4c1a6c700 100644 --- a/documentation/OPTIONS.md +++ b/documentation/OPTIONS.md @@ -896,6 +896,13 @@ Different models expect different conditioning data: - **Note**: When using `combined`, you cannot define separate `captions` on conditioning datasets; the source dataset's captions are used instead. - **See also**: [DATALOADER.md](DATALOADER.md#conditioning_data) for configuring multiple conditioning datasets. +### `--krea2_reference_latents` {#--krea2_reference_latents} + +- **What**: Enable Krea 2 reference-dataset training. +- **Why**: When enabled, Krea 2 uses the paired conditioning image while caching Qwen3VL prompt embeddings, and appends the clean VAE latents from that conditioning image to the transformer token stream during training. +- **Dataset setup**: Configure the main image dataset with `conditioning_data` pointing at a paired conditioning dataset. Filenames must match between target and reference images. +- **Scope**: This is a Krea 2 model-side option. It does not generate conditioning datasets; use the normal dataloader reference-dataset settings for that. + ### LTX-2 conditioning options These are optional advanced settings for LTX-2 training. Set them in JSON/TOML config files with the names below, or pass the matching CLI flags such as `--ltx2_first_frame_conditioning_probability`. 
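Reviewer note on the `--krea2_reference_latents` entry above: the "Dataset setup" bullet requires that filenames match between target and reference images. A minimal pre-flight sketch for checking that is shown below. It is not part of SimpleTuner; the helper name `unpaired_filenames`, the suffix list, and the assumption that both datasets are flat directories compared by exact filename (extension included) are all hypothetical and may need adjusting to the actual dataloader layout.

```python
from pathlib import Path

# Assumption: these are the image extensions in use; extend as needed.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}


def unpaired_filenames(target_dir, reference_dir):
    """Return (targets lacking a reference, references lacking a target).

    Hypothetical helper: compares exact filenames in two flat directories.
    """

    def image_names(directory):
        # Collect only image files directly inside the directory.
        return {
            path.name
            for path in Path(directory).iterdir()
            if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
        }

    targets = image_names(target_dir)
    references = image_names(reference_dir)
    return sorted(targets - references), sorted(references - targets)
```

Calling `unpaired_filenames("path/to/targets", "path/to/references")` (both paths are placeholders) returns two sorted lists; if both are empty, every target image has a same-named reference image under these assumptions.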
diff --git a/documentation/OPTIONS.pt-BR.md b/documentation/OPTIONS.pt-BR.md index 51999ba1d..f9827f1f7 100644 --- a/documentation/OPTIONS.pt-BR.md +++ b/documentation/OPTIONS.pt-BR.md @@ -887,6 +887,13 @@ Diferentes modelos esperam diferentes dados de conditioning: - **Nota**: Ao usar `combined`, voce nao pode definir `captions` separadas nos datasets de condicionamento; as captions do dataset de origem sao usadas. - **Veja tambem**: [DATALOADER.md](DATALOADER.md#conditioning_data) para configurar multiplos datasets de condicionamento. +### `--krea2_reference_latents` {#--krea2_reference_latents} + +- **O que**: Habilita treino Krea 2 com reference dataset. +- **Por que**: Quando habilitado, Krea 2 usa a imagem de condicionamento pareada ao cachear os prompt embeddings Qwen3VL, e adiciona os latentes VAE limpos dessa imagem ao fluxo de tokens do transformer durante o treino. +- **Dataset**: Configure o dataset principal de imagens com `conditioning_data` apontando para um dataset de conditioning pareado. Os nomes dos arquivos devem coincidir entre imagens target e reference. +- **Escopo**: Esta e uma opcao do lado do modelo Krea 2. Ela nao gera conditioning datasets; use as configuracoes normais de reference dataset do dataloader. + ### Opcoes de condicionamento do LTX-2 Estas sao configuracoes avancadas opcionais para treino LTX-2. Defina-as em JSON/TOML com os nomes abaixo, ou passe os flags CLI correspondentes, como `--ltx2_first_frame_conditioning_probability`. 
diff --git a/documentation/OPTIONS.zh.md b/documentation/OPTIONS.zh.md index 33910033c..71f29240b 100644 --- a/documentation/OPTIONS.zh.md +++ b/documentation/OPTIONS.zh.md @@ -893,6 +893,13 @@ Flux Kontext 的验证也始终走这条基于条件的路径。使用 `--eval_d - **说明**:使用 `combined` 时不能在条件数据集中定义单独的 `captions`,会使用源数据集的字幕。 - **另见**:[DATALOADER.md](DATALOADER.md#conditioning_data) 获取多条件数据集配置说明。 +### `--krea2_reference_latents` {#--krea2_reference_latents} + +- **内容**:启用 Krea 2 reference dataset 训练。 +- **原因**:启用后,Krea 2 在缓存 Qwen3VL prompt embeddings 时使用配对的条件图像,并在训练时把该条件图像的 clean VAE latents 追加到 transformer token 序列中。 +- **数据集设置**:主 image dataset 的 `conditioning_data` 应指向配对的 conditioning dataset。目标图像和参考图像的文件名必须匹配。 +- **范围**:这是 Krea 2 的模型侧选项。它不会生成 conditioning datasets;请使用常规 dataloader reference-dataset 设置。 + ### LTX-2 条件选项 这些是 LTX-2 训练的可选高级设置。可以在 JSON/TOML 配置中使用下面的名称,也可以使用对应的 CLI flag,例如 `--ltx2_first_frame_conditioning_probability`。 diff --git a/documentation/QUICKSTART.es.md b/documentation/QUICKSTART.es.md index 9747c7a4b..4df1367cc 100644 --- a/documentation/QUICKSTART.es.md +++ b/documentation/QUICKSTART.es.md @@ -18,6 +18,7 @@ Para la matriz de funciones completa y más precisa, consulta el [README princip | Flux.2 | 32B | ✓ | ✓ | ✓* | int8/fp8/nf4 opcional | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ opt | ✗ | ✓ | [FLUX2.md](quickstart/FLUX2.md) | | Flux Kontext | 8B–12B | ✓ | ✓ | ✓* | int8/fp8/nf4 opcional | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ req | ✓ | ✓ | [FLUX_KONTEXT.md](quickstart/FLUX_KONTEXT.md) | | Z-Image Turbo | 6B | ✓ | ✗ | ✓* | int8 opcional | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZIMAGE.md](quickstart/ZIMAGE.md) | +| Krea2 | - | ✓ | ✗ | ✓* | int8 opcional | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✓ opt | ✗ | ✓ | [KREA2.md](quickstart/KREA2.es.md) | | Boogu-Image 0.1 | - | ✓ | ✓ | ✓* | fp8 opcional | bf16 | ✓ | ✓ | ✗ | ✗ | ✗ | ✓ edit | ✗ | ✓ | [BOOGU_IMAGE.md](quickstart/BOOGU_IMAGE.es.md) | | zlab i1 | 3B | ✓ | ✓ | ✓ | int8 opcional | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | 
[ZLAB_i1.md](quickstart/ZLAB_i1.es.md) | | Ideogram 4 | 9B | ✓ | ✓ | ✓* | fp8 predeterminado, nf4 opcional | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✓ | [IDEOGRAM4.md](quickstart/IDEOGRAM4.es.md) | diff --git a/documentation/QUICKSTART.hi.md b/documentation/QUICKSTART.hi.md index 3d4bf1fd5..2398850f5 100644 --- a/documentation/QUICKSTART.hi.md +++ b/documentation/QUICKSTART.hi.md @@ -18,6 +18,7 @@ | Flux.2 | 32B | ✓ | ✓ | ✓* | int8/fp8/nf4 वैकल्पिक | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ opt | ✗ | ✓ | [FLUX2.md](quickstart/FLUX2.md) | | Flux Kontext | 8B–12B | ✓ | ✓ | ✓* | int8/fp8/nf4 वैकल्पिक | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ req | ✓ | ✓ | [FLUX_KONTEXT.md](quickstart/FLUX_KONTEXT.md) | | Z-Image Turbo | 6B | ✓ | ✗ | ✓* | int8 वैकल्पिक | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZIMAGE.md](quickstart/ZIMAGE.md) | +| Krea2 | - | ✓ | ✗ | ✓* | int8 वैकल्पिक | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✓ opt | ✗ | ✓ | [KREA2.md](quickstart/KREA2.hi.md) | | Boogu-Image 0.1 | - | ✓ | ✓ | ✓* | fp8 वैकल्पिक | bf16 | ✓ | ✓ | ✗ | ✗ | ✗ | ✓ edit | ✗ | ✓ | [BOOGU_IMAGE.md](quickstart/BOOGU_IMAGE.hi.md) | | zlab i1 | 3B | ✓ | ✓ | ✓ | int8 वैकल्पिक | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZLAB_i1.md](quickstart/ZLAB_i1.hi.md) | | Ideogram 4 | 9B | ✓ | ✓ | ✓* | fp8 डिफ़ॉल्ट, nf4 वैकल्पिक | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✓ | [IDEOGRAM4.md](quickstart/IDEOGRAM4.hi.md) | diff --git a/documentation/QUICKSTART.ja.md b/documentation/QUICKSTART.ja.md index ceb0d14cd..d910c33fd 100644 --- a/documentation/QUICKSTART.ja.md +++ b/documentation/QUICKSTART.ja.md @@ -18,6 +18,7 @@ | Flux.2 | 32B | ✓ | ✓ | ✓* | int8/fp8/nf4 オプション | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ opt | ✗ | ✓ | [FLUX2.md](quickstart/FLUX2.md) | | Flux Kontext | 8B–12B | ✓ | ✓ | ✓* | int8/fp8/nf4 オプション | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ req | ✓ | ✓ | [FLUX_KONTEXT.md](quickstart/FLUX_KONTEXT.md) | | Z-Image Turbo | 6B | ✓ | ✗ | ✓* | int8 オプション | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZIMAGE.md](quickstart/ZIMAGE.md) | +| Krea2 | - | ✓ | ✗ | ✓* | int8 オプション 
| bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✓ opt | ✗ | ✓ | [KREA2.md](quickstart/KREA2.ja.md) | | Boogu-Image 0.1 | - | ✓ | ✓ | ✓* | fp8 オプション | bf16 | ✓ | ✓ | ✗ | ✗ | ✗ | ✓ edit | ✗ | ✓ | [BOOGU_IMAGE.md](quickstart/BOOGU_IMAGE.ja.md) | | zlab i1 | 3B | ✓ | ✓ | ✓ | int8 オプション | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZLAB_i1.md](quickstart/ZLAB_i1.ja.md) | | Ideogram 4 | 9B | ✓ | ✓ | ✓* | fp8 デフォルト、nf4 オプション | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✓ | [IDEOGRAM4.md](quickstart/IDEOGRAM4.ja.md) | diff --git a/documentation/QUICKSTART.md b/documentation/QUICKSTART.md index 8974c6e9f..b586e66e7 100644 --- a/documentation/QUICKSTART.md +++ b/documentation/QUICKSTART.md @@ -18,6 +18,7 @@ For the complete and most accurate feature matrix, refer to the [main README](ht | Flux.2 | 32B | ✓ | ✓* | int8/fp8/nf4 optional | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ opt | ✗ | ✓ | [FLUX2.md](/documentation/quickstart/FLUX2.md) | | Flux Kontext | 8B–12B | ✓ | ✓* | int8/fp8/nf4 optional | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ req | ✓ | ✓ | [FLUX_KONTEXT.md](/documentation/quickstart/FLUX_KONTEXT.md) | | Z-Image Turbo | 6B | ✓ | ✓* | int8 optional | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZIMAGE.md](/documentation/quickstart/ZIMAGE.md) | +| Krea2 | - | ✓ | ✓* | int8 optional | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✓ opt | ✗ | ✓ | [KREA2.md](/documentation/quickstart/KREA2.md) | | Boogu-Image 0.1 | - | ✓ | ✓* | fp8 optional | bf16 | ✓ | ✓ | ✗ | ✗ | ✗ | ✓ edit | ✗ | ✓ | [BOOGU_IMAGE.md](/documentation/quickstart/BOOGU_IMAGE.md) | | zlab i1 | 3B | ✓ | ✓ | int8 optional | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZLAB_i1.md](/documentation/quickstart/ZLAB_i1.md) | | Ideogram 4 | 9B | ✓ | ✓* | fp8 default, nf4 optional | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✓ | [IDEOGRAM4.md](/documentation/quickstart/IDEOGRAM4.md) | diff --git a/documentation/QUICKSTART.pt-BR.md b/documentation/QUICKSTART.pt-BR.md index 655f34eb7..2fed8aca0 100644 --- a/documentation/QUICKSTART.pt-BR.md +++ b/documentation/QUICKSTART.pt-BR.md @@ -18,6 +18,7 @@ Para 
a matriz completa e mais precisa de recursos, consulte o [README principal] | Flux.2 | 32B | ✓ | ✓ | ✓* | int8/fp8/nf4 opcional | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ opt | ✗ | ✓ | [FLUX2.md](quickstart/FLUX2.md) | | Flux Kontext | 8B–12B | ✓ | ✓ | ✓* | int8/fp8/nf4 opcional | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ req | ✓ | ✓ | [FLUX_KONTEXT.md](quickstart/FLUX_KONTEXT.md) | | Z-Image Turbo | 6B | ✓ | ✗ | ✓* | int8 opcional | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZIMAGE.md](quickstart/ZIMAGE.md) | +| Krea2 | - | ✓ | ✗ | ✓* | int8 opcional | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✓ opt | ✗ | ✓ | [KREA2.md](quickstart/KREA2.pt-BR.md) | | Boogu-Image 0.1 | - | ✓ | ✓ | ✓* | fp8 opcional | bf16 | ✓ | ✓ | ✗ | ✗ | ✗ | ✓ edit | ✗ | ✓ | [BOOGU_IMAGE.md](quickstart/BOOGU_IMAGE.pt-BR.md) | | zlab i1 | 3B | ✓ | ✓ | ✓ | int8 opcional | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZLAB_i1.md](quickstart/ZLAB_i1.pt-BR.md) | | Ideogram 4 | 9B | ✓ | ✓ | ✓* | fp8 padrão, nf4 opcional | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✓ | [IDEOGRAM4.md](quickstart/IDEOGRAM4.pt-BR.md) | diff --git a/documentation/QUICKSTART.zh.md b/documentation/QUICKSTART.zh.md index a4d278611..655181fba 100644 --- a/documentation/QUICKSTART.zh.md +++ b/documentation/QUICKSTART.zh.md @@ -18,6 +18,7 @@ | Flux.2 | 32B | ✓ | ✓ | ✓* | int8/fp8/nf4 可选 | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ opt | ✗ | ✓ | [FLUX2.md](quickstart/FLUX2.md) | | Flux Kontext | 8B–12B | ✓ | ✓ | ✓* | int8/fp8/nf4 可选 | bf16 | ✓+ | ✓ | ✓ | ✓ | ✓ | ✓ req | ✓ | ✓ | [FLUX_KONTEXT.md](quickstart/FLUX_KONTEXT.md) | | Z-Image Turbo | 6B | ✓ | ✗ | ✓* | int8 可选 | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | [ZIMAGE.md](quickstart/ZIMAGE.md) | +| Krea2 | - | ✓ | ✗ | ✓* | int8 可选 | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✓ opt | ✗ | ✓ | [KREA2.md](quickstart/KREA2.zh.md) | | Boogu-Image 0.1 | - | ✓ | ✓ | ✓* | fp8 可选 | bf16 | ✓ | ✓ | ✗ | ✗ | ✗ | ✓ edit | ✗ | ✓ | [BOOGU_IMAGE.md](quickstart/BOOGU_IMAGE.zh.md) | | zlab i1 | 3B | ✓ | ✓ | ✓ | int8 可选 | bf16 | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ | ✓ | 
[ZLAB_i1.md](quickstart/ZLAB_i1.zh.md) | | Ideogram 4 | 9B | ✓ | ✓ | ✓* | fp8 默认,nf4 可选 | bf16 | ✓+ | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✓ | [IDEOGRAM4.md](quickstart/IDEOGRAM4.zh.md) | diff --git a/documentation/quickstart/KREA2.es.md b/documentation/quickstart/KREA2.es.md new file mode 100644 index 000000000..eca99b06b --- /dev/null +++ b/documentation/quickstart/KREA2.es.md @@ -0,0 +1,176 @@ +# Guía rápida de Krea2 + +Esta guía cubre el entrenamiento LoRA de Krea2 en SimpleTuner. Krea2 es un transformer de imágenes grande con flow matching, acondicionamiento de texto estilo Qwen y el VAE de Qwen Image. Funciona mejor en GPUs NVIDIA con mucha memoria. + +El ejemplo inicial está en: + +```bash +simpletuner/examples/krea2.peft-lora/config.json +``` + +## Punto de partida recomendado + +Para la primera ejecución, usa la configuración de ejemplo y mantén el modelo conservador: + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "train_batch_size": 1, + "base_model_precision": "no_change" +} +``` + +Krea2 es nativo de 1024px, pero 512px y 768px son útiles para iterar rápido y revisar datasets. Usa un dataloader de 1024px cuando la ejecución ya sea estable. + +## Notas de hardware + +Krea2 puede entrenar en bf16 en una H100 de 80GB con batch 1 a 1024px. En nuestras pruebas, batches más grandes caben sin compile, pero compile agrega suficiente memoria de grafo/cudagraph como para provocar OOM en muchas configuraciones grandes. + +TorchAO int8 weight-only reduce mucho la VRAM, pero no fue más rápido que bf16 en la ruta de entrenamiento probada. Úsalo cuando la capacidad de memoria sea más importante que el tiempo por paso. + +Recomendaciones: + +- Usa `bf16` cuando el modelo quepa. +- Usa `int8-torchao` cuando necesites margen de memoria. +- Mantén `gradient_checkpointing=true`. 
+- Mantén `fuse_qkv_projections=true`. +- Usa `dynamo_backend=inductor`, `dynamo_mode=reduce-overhead` y `dynamo_use_regional_compilation=true` solo después de confirmar que el batch/resolución cabe. + +## Valores de configuración clave + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "base_model_precision": "no_change", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "optimizer": "optimi-lion", + "learning_rate": 1e-4, + "lora_rank": 64, + "train_batch_size": 1, + "resolution": 1024, + "validation_resolution": "1024x1024" +} +``` + +Para TorchAO int8: + +```json +{ + "base_model_precision": "int8-torchao", + "quantize_via": "cpu" +} +``` + +Para compile reduce-overhead: + +```json +{ + "dynamo_backend": "inductor", + "dynamo_mode": "reduce-overhead", + "dynamo_use_regional_compilation": true +} +``` + +## Entrenamiento con imagen de referencia + +Krea2 soporta acondicionamiento opcional con latentes de referencia para datasets de edición. Actívalo cuando tu dataloader proporcione imágenes de referencia emparejadas o latentes de referencia en caché: + +```json +{ + "krea2_reference_latents": true +} +``` + +Los latentes de referencia deben coincidir con la forma de los latentes objetivo. + +## Configuración del dataloader + +Krea2 usa la estructura general de dataloader de imagen de otros modelos transformer. La resolución real de entrenamiento viene del JSON del dataloader, no solo de `resolution` en el config principal. Para entrenar a 1024px, asegúrate de que `resolution`, `maximum_image_size` y `target_downsample_size` también sean 1024 en el dataloader. + +Un dataset de 512px es útil para pruebas rápidas, revisar captions y detectar crops rotos. El run final suele necesitar 1024px para dar una señal real de calidad. 
+ +Para datasets locales, usa `type: local`, define `instance_data_dir`, y elige una estrategia de caption. Para un sujeto pequeño, `caption_strategy=instanceprompt` suele ser suficiente al inicio. Para estilos, filenames o captions completos suelen funcionar mejor. + +## Validación + +La validación de Krea2 es costosa, así que empieza con pocos prompts. Un prompt único puede ocultar overfitting o memorización; cuando el run sea estable, añade una pequeña prompt library. + +Ejemplo: + +```json +{ + "validation_prompt": "a studio portrait of , soft directional light, detailed fabric texture", + "validation_negative_prompt": "ugly, cropped, blurry, low-quality, mediocre average", + "validation_num_inference_steps": 28, + "validation_guidance": 4.5, + "validation_resolution": "1024x1024" +} +``` + +## Notas de cuantización + +`int8-torchao` almacena los pesos base del transformer en int8 y entrena pesos LoRA bf16 encima. En H100 redujo mucho la VRAM, pero fue más lento que bf16 en esta ruta de entrenamiento. Es una opción de capacidad, no una garantía de throughput. + +## Resultados de benchmark + +Estas mediciones se tomaron en una NVIDIA H100 de 80GB usando el trainer real de SimpleTuner, Krea2 LoRA, QKV fusionado, gradient checkpointing y un dataset pequeño de Domokun. La VRAM se muestreó externamente con `nvidia-smi`. Tómalas solo como guía comparativa; versiones distintas de PyTorch, CUDA, drivers, datasets, rank LoRA, optimizadores, backends de atención y GPUs pueden cambiar los resultados. 
+ +### QKV fusionado + checkpointing, compile desactivado + +| Precisión | Resolución | Batch | s/paso estable | Pico VRAM | +| --- | ---: | ---: | ---: | ---: | +| bf16 | 512 | 1 | 0.353 | 31.10 GiB | +| bf16 | 512 | 4 | 1.230 | 39.31 GiB | +| bf16 | 512 | 8 | 2.430 | 50.32 GiB | +| bf16 | 1024 | 1 | 0.990 | 33.28 GiB | +| bf16 | 1024 | 4 | 3.850 | 48.35 GiB | +| bf16 | 1024 | 8 | 7.690 | 67.88 GiB | +| int8-torchao | 512 | 1 | 0.535 | 18.10 GiB | +| int8-torchao | 512 | 4 | 1.690 | 27.46 GiB | +| int8-torchao | 512 | 8 | 3.220 | 40.52 GiB | +| int8-torchao | 1024 | 1 | 1.330 | 20.35 GiB | +| int8-torchao | 1024 | 4 | 4.850 | 36.99 GiB | +| int8-torchao | 1024 | 8 | 9.520 | 58.84 GiB | + +### QKV fusionado + checkpointing + compile reduce-overhead + +| Precisión | Resolución | Batch | Estado | s/paso estable | Pico VRAM | +| --- | ---: | ---: | --- | ---: | ---: | +| bf16 | 512 | 1 | ok | 0.260 | 41.20 GiB | +| bf16 | 512 | 4 | OOM | - | 79.07 GiB | +| bf16 | 512 | 8 | OOM | - | 79.10 GiB | +| bf16 | 1024 | 1 | ok | 0.704 | 63.71 GiB | +| bf16 | 1024 | 4 | OOM | - | 79.11 GiB | +| bf16 | 1024 | 8 | OOM | - | 78.40 GiB | +| int8-torchao | 512 | 1 | ok | 0.410 | 30.93 GiB | +| int8-torchao | 512 | 4 | ok | 1.300 | 78.60 GiB | +| int8-torchao | 512 | 8 | OOM | - | 79.12 GiB | +| int8-torchao | 1024 | 1 | ok | 0.990 | 58.68 GiB | +| int8-torchao | 1024 | 4 | OOM | - | 78.92 GiB | +| int8-torchao | 1024 | 8 | OOM | - | 78.09 GiB | + +## Guía práctica + +- Para iterar más rápido en una H100, usa bf16, QKV fusionado, checkpointing, compile activado y batch 1. +- Para batches efectivos más grandes, usa bf16 sin compile y sube `train_batch_size` hasta que la VRAM sea el límite. +- Para ejecuciones con poca memoria, usa `int8-torchao`; espera menos VRAM pero pasos más lentos. +- Compile ayuda en batch 1, pero puede consumir suficiente VRAM para hacer fallar batches mayores. 
+ +## Problemas comunes + +- Si esperabas 1024px pero el log muestra 512px, revisa el dataloader JSON. +- Si compile OOM pero el run sin compile entra, baja batch size o desactiva compile. +- Si int8 usa menos VRAM pero es más lento, coincide con nuestras pruebas H100. +- Si la imagen de referencia no influye en validación, confirma que `krea2_reference_latents=true` y que el dataset de validación usa pares de referencia. +- Si overfitea rápido, baja learning rate, reduce pasos o amplía la variedad del dataset. diff --git a/documentation/quickstart/KREA2.hi.md b/documentation/quickstart/KREA2.hi.md new file mode 100644 index 000000000..9ab81effb --- /dev/null +++ b/documentation/quickstart/KREA2.hi.md @@ -0,0 +1,174 @@ +# Krea2 क्विकस्टार्ट + +यह गाइड SimpleTuner में Krea2 LoRA training के लिए है। Krea2 एक बड़ा flow-matching image transformer है जो Qwen-style text conditioning और Qwen Image VAE का उपयोग करता है। यह high-memory NVIDIA GPUs पर सबसे व्यावहारिक है। + +शुरुआती example यहां है: + +```bash +simpletuner/examples/krea2.peft-lora/config.json +``` + +## अनुशंसित शुरुआत + +पहली run के लिए example config से शुरू करें और settings conservative रखें: + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "train_batch_size": 1, + "base_model_precision": "no_change" +} +``` + +Krea2 1024px-native image model है, लेकिन 512px और 768px fast iteration और dataset checks के लिए उपयोगी हैं। Run stable होने के बाद 1024px dataloader इस्तेमाल करें। + +## Hardware Notes + +हमारे परीक्षण में Krea2 80GB H100 पर bf16, 1024px, batch 1 में train हो सका। Compile off होने पर बड़े batches भी fit हुए, लेकिन compile graph/cudagraph memory बढ़ाता है और कई बड़े batch settings OOM हो जाते हैं। + +TorchAO int8 weight-only VRAM को काफी कम करता है, लेकिन tested SimpleTuner training path में यह bf16 से तेज 
नहीं था। इसे तब उपयोग करें जब memory capacity step time से अधिक महत्वपूर्ण हो। + +Recommendations: + +- Model fit हो तो `bf16` उपयोग करें। +- Memory headroom चाहिए तो `int8-torchao` उपयोग करें। +- `gradient_checkpointing=true` रखें। +- `fuse_qkv_projections=true` रखें। +- `dynamo_backend=inductor`, `dynamo_mode=reduce-overhead`, और `dynamo_use_regional_compilation=true` केवल batch/resolution fit होने की पुष्टि के बाद enable करें। + +## मुख्य Config Values + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "base_model_precision": "no_change", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "optimizer": "optimi-lion", + "learning_rate": 1e-4, + "lora_rank": 64, + "train_batch_size": 1, + "resolution": 1024, + "validation_resolution": "1024x1024" +} +``` + +TorchAO int8 के लिए: + +```json +{ + "base_model_precision": "int8-torchao", + "quantize_via": "cpu" +} +``` + +Reduce-overhead compile के लिए: + +```json +{ + "dynamo_backend": "inductor", + "dynamo_mode": "reduce-overhead", + "dynamo_use_regional_compilation": true +} +``` + +## Reference Image Training + +Krea2 edit-style datasets के लिए optional reference-latent conditioning support करता है। जब dataloader paired reference images या cached reference latents देता हो, इसे enable करें: + +```json +{ + "krea2_reference_latents": true +} +``` + +Reference latents का shape target latents से match होना चाहिए। + +## Dataloader Configuration + +Krea2 वही general image dataloader structure उपयोग करता है जो दूसरे image transformer models करते हैं। वास्तविक training resolution dataloader JSON से आती है, केवल main config के `resolution` से नहीं। 1024px training के लिए dataloader में `resolution`, `maximum_image_size`, और `target_downsample_size` भी 1024 होने चाहिए। + +512px datasets fast tests, captions जांचने और crop समस्याएं पकड़ने के लिए उपयोगी हैं। Final quality signal के लिए 
1024px run अधिक उपयोगी होता है। + +Local datasets के लिए `type: local`, `instance_data_dir`, और caption strategy set करें। छोटे subject LoRA के लिए `caption_strategy=instanceprompt` अच्छी शुरुआत है। Style LoRA में filenames या full captions बेहतर हो सकते हैं। + +## Validation + +Krea2 validation महंगी है, इसलिए tuning के समय कम prompts रखें। एक prompt overfit या memorisation छिपा सकता है। Run stable होने के बाद छोटी prompt library जोड़ें। + +```json +{ + "validation_prompt": "a studio portrait of , soft directional light, detailed fabric texture", + "validation_negative_prompt": "ugly, cropped, blurry, low-quality, mediocre average", + "validation_num_inference_steps": 28, + "validation_guidance": 4.5, + "validation_resolution": "1024x1024" +} +``` + +## Quantisation Notes + +`int8-torchao` transformer base weights को int8 में store करता है और ऊपर bf16 LoRA weights train करता है। H100 पर इससे VRAM काफी कम हुई, लेकिन tested training path में यह bf16 से धीमा था। इसे speed guarantee नहीं, capacity option समझें। + +## Benchmark Results + +ये measurements single NVIDIA H100 80GB पर लिए गए: real SimpleTuner trainer, Krea2 LoRA, fused QKV projections, gradient checkpointing, और छोटा Domokun dataset। VRAM को `nvidia-smi` से externally sample किया गया। इन्हें केवल comparative guidance मानें; PyTorch, CUDA, driver, dataset, LoRA rank, optimizer, attention backend और GPU बदलने से values बदल सकती हैं। + +### Fused QKV + Checkpointing, Compile Off + +| Precision | Resolution | Batch | Steady s/step | Peak VRAM | +| --- | ---: | ---: | ---: | ---: | +| bf16 | 512 | 1 | 0.353 | 31.10 GiB | +| bf16 | 512 | 4 | 1.230 | 39.31 GiB | +| bf16 | 512 | 8 | 2.430 | 50.32 GiB | +| bf16 | 1024 | 1 | 0.990 | 33.28 GiB | +| bf16 | 1024 | 4 | 3.850 | 48.35 GiB | +| bf16 | 1024 | 8 | 7.690 | 67.88 GiB | +| int8-torchao | 512 | 1 | 0.535 | 18.10 GiB | +| int8-torchao | 512 | 4 | 1.690 | 27.46 GiB | +| int8-torchao | 512 | 8 | 3.220 | 40.52 GiB | +| int8-torchao | 1024 | 1 | 1.330 | 20.35 GiB | 
+| int8-torchao | 1024 | 4 | 4.850 | 36.99 GiB | +| int8-torchao | 1024 | 8 | 9.520 | 58.84 GiB | + +### Fused QKV + Checkpointing + Reduce-Overhead Compile + +| Precision | Resolution | Batch | Status | Steady s/step | Peak VRAM | +| --- | ---: | ---: | --- | ---: | ---: | +| bf16 | 512 | 1 | ok | 0.260 | 41.20 GiB | +| bf16 | 512 | 4 | OOM | - | 79.07 GiB | +| bf16 | 512 | 8 | OOM | - | 79.10 GiB | +| bf16 | 1024 | 1 | ok | 0.704 | 63.71 GiB | +| bf16 | 1024 | 4 | OOM | - | 79.11 GiB | +| bf16 | 1024 | 8 | OOM | - | 78.40 GiB | +| int8-torchao | 512 | 1 | ok | 0.410 | 30.93 GiB | +| int8-torchao | 512 | 4 | ok | 1.300 | 78.60 GiB | +| int8-torchao | 512 | 8 | OOM | - | 79.12 GiB | +| int8-torchao | 1024 | 1 | ok | 0.990 | 58.68 GiB | +| int8-torchao | 1024 | 4 | OOM | - | 78.92 GiB | +| int8-torchao | 1024 | 8 | OOM | - | 78.09 GiB | + +## Practical Guidance + +- H100 single-GPU fast iteration के लिए bf16, fused QKV, checkpointing, compile on, batch 1 उपयोग करें। +- बड़े effective batches के लिए uncompiled bf16 बेहतर है; `train_batch_size` को VRAM limit तक बढ़ाएं। +- Memory-constrained runs के लिए `int8-torchao` उपयोग करें; VRAM कम होगी, लेकिन steps धीमे हो सकते हैं। +- Compile batch 1 में उपयोगी है, लेकिन VRAM बहुत बढ़ाकर बड़े batches fail कर सकता है। + +## Common Issues + +- यदि आपने 1024px अपेक्षित किया लेकिन log 512px दिखाता है, dataloader JSON जांचें। +- यदि compile OOM करता है लेकिन uncompiled run fit होता है, batch size घटाएं या compile बंद करें। +- यदि int8 कम VRAM इस्तेमाल करता है लेकिन धीमा है, यह हमारे H100 tests से मेल खाता है। +- यदि reference image validation को प्रभावित नहीं कर रही है, `krea2_reference_latents=true` और paired validation dataset जांचें। +- यदि model जल्दी overfit करता है, learning rate घटाएं, steps घटाएं, या dataset variety बढ़ाएं। diff --git a/documentation/quickstart/KREA2.ja.md b/documentation/quickstart/KREA2.ja.md new file mode 100644 index 000000000..fc161186e --- /dev/null +++ b/documentation/quickstart/KREA2.ja.md @@ -0,0 
+1,174 @@ +# Krea2 クイックスタート + +このガイドでは、SimpleTuner で Krea2 の LoRA をトレーニングする方法を扱います。Krea2 は Qwen 系のテキスト条件付けと Qwen Image VAE を使う、大規模な flow matching 画像 transformer です。高メモリの NVIDIA GPU で使うのが現実的です。 + +スターター例: + +```bash +simpletuner/examples/krea2.peft-lora/config.json +``` + +## 推奨スタート設定 + +最初の実行では、例の設定をベースにして保守的に始めてください。 + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "train_batch_size": 1, + "base_model_precision": "no_change" +} +``` + +Krea2 は 1024px ネイティブの画像モデルですが、512px や 768px は高速な確認に便利です。実行が安定してから 1024px の dataloader に移行してください。 + +## ハードウェアメモ + +Krea2 は 80GB H100 で、1024px・batch 1 の bf16 トレーニングが可能でした。compile なしなら大きめの batch も入りましたが、compile は graph/cudagraph 用のメモリを大きく増やすため、多くの大きな batch 設定で OOM になります。 + +TorchAO int8 weight-only は VRAM を大きく下げますが、今回テストした SimpleTuner のトレーニング経路では bf16 より高速ではありませんでした。速度よりメモリ余裕が重要な場合に使ってください。 + +推奨: + +- 入るなら `bf16` を使う。 +- メモリが足りない場合は `int8-torchao` を使う。 +- `gradient_checkpointing=true` を維持する。 +- `fuse_qkv_projections=true` を維持する。 +- `dynamo_backend=inductor`、`dynamo_mode=reduce-overhead`、`dynamo_use_regional_compilation=true` は、batch/resolution が入ることを確認してから使う。 + +## 主要設定 + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "base_model_precision": "no_change", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "optimizer": "optimi-lion", + "learning_rate": 1e-4, + "lora_rank": 64, + "train_batch_size": 1, + "resolution": 1024, + "validation_resolution": "1024x1024" +} +``` + +TorchAO int8: + +```json +{ + "base_model_precision": "int8-torchao", + "quantize_via": "cpu" +} +``` + +reduce-overhead compile: + +```json +{ + "dynamo_backend": "inductor", + "dynamo_mode": "reduce-overhead", + 
"dynamo_use_regional_compilation": true +} +``` + +## 参照画像トレーニング + +Krea2 は編集系 dataset 向けに、任意の参照 latent 条件付けをサポートします。ペアの参照画像またはキャッシュ済み参照 latent を dataloader が提供する場合に有効化します。 + +```json +{ + "krea2_reference_latents": true +} +``` + +参照 latent はターゲット latent と同じ shape である必要があります。 + +## Dataloader 設定 + +Krea2 は他の画像 transformer と同じ基本 dataloader 形式を使います。実際の学習解像度はトップレベルの `resolution` だけではなく、dataloader JSON の `resolution`、`maximum_image_size`、`target_downsample_size` で決まります。1024px で学習する場合は、dataloader 側も 1024 にしてください。 + +512px dataset は、caption、crop、learning rate の問題を素早く見つけるのに便利です。最終品質を確認するには 1024px の run がより信頼できます。 + +ローカル dataset では `type: local`、`instance_data_dir`、caption strategy を設定します。小さな subject LoRA は `caption_strategy=instanceprompt` から始められます。style LoRA では filenames や通常 caption の方が向くことがあります。 + +## 検証 + +Krea2 の validation は重いので、調整中は prompt を少なくしてください。1つの prompt だけでは overfit や暗記を見逃すことがあります。安定したら小さな prompt library を追加します。 + +```json +{ + "validation_prompt": "a studio portrait of , soft directional light, detailed fabric texture", + "validation_negative_prompt": "ugly, cropped, blurry, low-quality, mediocre average", + "validation_num_inference_steps": 28, + "validation_guidance": 4.5, + "validation_resolution": "1024x1024" +} +``` + +## 量子化メモ + +`int8-torchao` は transformer のベース重みを int8 に保存し、その上で bf16 LoRA 重みを学習します。H100 では VRAM を大きく削減しましたが、この training path では bf16 より高速ではありませんでした。速度ではなく容量のための選択肢として考えてください。 + +## ベンチマーク結果 + +以下は、単一の NVIDIA H100 80GB、SimpleTuner の実トレーナー、Krea2 LoRA、QKV fusion、gradient checkpointing、小さな Domokun dataset で測定した結果です。VRAM は `nvidia-smi` で外部サンプリングしました。比較用の目安として扱ってください。PyTorch、CUDA、driver、dataset、LoRA rank、optimizer、attention backend、GPU が変わると結果も変わります。 + +### QKV fusion + checkpointing、compile オフ + +| 精度 | 解像度 | Batch | 安定時 s/step | Peak VRAM | +| --- | ---: | ---: | ---: | ---: | +| bf16 | 512 | 1 | 0.353 | 31.10 GiB | +| bf16 | 512 | 4 | 1.230 | 39.31 GiB | +| bf16 | 512 | 8 | 2.430 | 50.32 GiB | +| bf16 | 1024 | 1 | 0.990 | 33.28 GiB | +| 
bf16 | 1024 | 4 | 3.850 | 48.35 GiB | +| bf16 | 1024 | 8 | 7.690 | 67.88 GiB | +| int8-torchao | 512 | 1 | 0.535 | 18.10 GiB | +| int8-torchao | 512 | 4 | 1.690 | 27.46 GiB | +| int8-torchao | 512 | 8 | 3.220 | 40.52 GiB | +| int8-torchao | 1024 | 1 | 1.330 | 20.35 GiB | +| int8-torchao | 1024 | 4 | 4.850 | 36.99 GiB | +| int8-torchao | 1024 | 8 | 9.520 | 58.84 GiB | + +### QKV fusion + checkpointing + reduce-overhead compile + +| 精度 | 解像度 | Batch | 状態 | 安定時 s/step | Peak VRAM | +| --- | ---: | ---: | --- | ---: | ---: | +| bf16 | 512 | 1 | ok | 0.260 | 41.20 GiB | +| bf16 | 512 | 4 | OOM | - | 79.07 GiB | +| bf16 | 512 | 8 | OOM | - | 79.10 GiB | +| bf16 | 1024 | 1 | ok | 0.704 | 63.71 GiB | +| bf16 | 1024 | 4 | OOM | - | 79.11 GiB | +| bf16 | 1024 | 8 | OOM | - | 78.40 GiB | +| int8-torchao | 512 | 1 | ok | 0.410 | 30.93 GiB | +| int8-torchao | 512 | 4 | ok | 1.300 | 78.60 GiB | +| int8-torchao | 512 | 8 | OOM | - | 79.12 GiB | +| int8-torchao | 1024 | 1 | ok | 0.990 | 58.68 GiB | +| int8-torchao | 1024 | 4 | OOM | - | 78.92 GiB | +| int8-torchao | 1024 | 8 | OOM | - | 78.09 GiB | + +## 実用上の指針 + +- H100 単一 GPU で最速に試すなら、bf16、QKV fusion、checkpointing、compile オン、batch 1。 +- 大きな effective batch が必要なら、compile なしの bf16 で `train_batch_size` を上げる。 +- メモリ制約が強い場合は `int8-torchao` を使う。ただし step は遅くなる可能性がある。 +- compile は batch 1 では有効ですが、VRAM を大きく増やし、大きな batch を失敗させることがあります。 + +## よくある問題 + +- 1024px のつもりで log が 512px の場合は、dataloader JSON を確認してください。 +- compile で OOM し、compile なしで入る場合は、batch size を下げるか compile を無効にしてください。 +- int8 が低 VRAM でも遅い場合、それは今回の H100 測定と一致します。 +- 参照画像が validation に効かない場合は、`krea2_reference_latents=true` と paired reference dataset を確認してください。 +- すぐ overfit する場合は、learning rate、step 数、dataset の多様性を見直してください。 diff --git a/documentation/quickstart/KREA2.md b/documentation/quickstart/KREA2.md new file mode 100644 index 000000000..572be456b --- /dev/null +++ b/documentation/quickstart/KREA2.md @@ -0,0 +1,304 @@ +# Krea2 Quickstart + +In this example, we'll be 
training a Krea2 LoRA. + +Krea2 is a large flow-matching image transformer using Qwen-style text conditioning and the Qwen Image VAE. In SimpleTuner, the default target is PEFT LoRA training on the transformer. The text encoder and VAE are used for caching and validation, then moved out of the way before the training loop. + +The starter example lives here: + +```bash +simpletuner/examples/krea2.peft-lora/config.json +``` + +## Hardware requirements + +Krea2 is much heavier than SDXL-style UNets. It is happiest on a high-memory CUDA GPU, and the 1024px configuration should be treated as an H100/A100-80G/L40S-class workload unless you are using quantisation or offload. + +On a single H100 80GB, the practical starting points are: + +- **bf16, 512px, batch 1** for fast smoke tests and dataset checks +- **bf16, 1024px, batch 1** for a realistic full-resolution run +- **int8-torchao, 1024px, batch 1-4** when VRAM headroom is more important than step speed +- **compile reduce-overhead** only after the uncompiled run is stable, because compile can add a large amount of VRAM + +You will need: + +- **the realistic minimum**: an NVIDIA GPU with at least 24GB VRAM for reduced-resolution or quantised experiments +- **recommended**: 48GB or more for comfortable 512px work +- **ideal**: 80GB H100/A100-class cards for 1024px and compile experiments + +Apple GPUs are not currently a recommended target for Krea2 training. + +## Prerequisites + +Make sure Python is installed. SimpleTuner supports Python 3.10 through 3.13. 
+ +```bash +python --version +``` + +If Python 3.13 is not installed on Ubuntu: + +```bash +apt -y install python3.13 python3.13-venv +``` + +### Container image dependencies + +For Vast, RunPod, TensorDock, and similar CUDA images, install CUDA toolkit headers when you need packages that compile CUDA extensions: + +```bash +apt -y install nvidia-cuda-toolkit +``` + +CUDA 13 users should use an image whose CUDA runtime, CUDA headers, NVRTC, and PyTorch wheel are from the same generation. Transformer-style training paths can become very sensitive to mismatched CUDA packages. + +## Installation + +Install SimpleTuner via pip: + +```bash +pip install 'simpletuner[cuda]' + +# CUDA 13 / Blackwell users +pip install 'simpletuner[cuda13]' --extra-index-url https://download.pytorch.org/whl/cu130 +``` + +For a development checkout, follow the [installation documentation](../INSTALL.md). + +## Setting up the environment + +### Web interface method + +The SimpleTuner WebUI can create and edit the config for you: + +```bash +simpletuner server +``` + +The server listens on port 8001 by default. + +### Manual / command-line method + +Copy the example config and edit it: + +```bash +cp simpletuner/examples/krea2.peft-lora/config.json config/config.json +``` + +The important values are: + +- `model_type` - set this to `lora`. +- `model_family` - set this to `krea2`. +- `model_flavour` - set this to `raw`. +- `pretrained_model_name_or_path` - set this to `krea/Krea-2-Raw`. +- `mixed_precision` - keep this at `bf16` on modern NVIDIA GPUs. +- `gradient_checkpointing` - keep this enabled unless you are deliberately measuring memory. +- `fuse_qkv_projections` - keep this enabled. Krea2 supports permanent QKV fusion for the attention projections, and the LoRA target changes to the fused projection. +- `train_batch_size` - start at 1. Increase after the run is stable. 
+- `resolution` - the top-level value is less important than the dataloader's own `resolution`; make sure the dataloader is actually set to the resolution you intend to test. +- `validation_resolution` - use `1024x1024` for full-resolution validation. +- `base_model_precision` - use `no_change` for bf16 or `int8-torchao` for TorchAO int8 weight-only training. +- `quantize_via` - use `cpu` for TorchAO int8 when startup GPU memory is tight. + +A conservative bf16 starting point: + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "base_model_precision": "no_change", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "optimizer": "optimi-lion", + "learning_rate": 1e-4, + "lora_rank": 64, + "train_batch_size": 1, + "resolution": 1024, + "validation_resolution": "1024x1024" +} +``` + +For TorchAO int8: + +```json +{ + "base_model_precision": "int8-torchao", + "quantize_via": "cpu" +} +``` + +For reduce-overhead compile: + +```json +{ + "dynamo_backend": "inductor", + "dynamo_mode": "reduce-overhead", + "dynamo_use_regional_compilation": true +} +``` + +Compile should be considered a batch-size-1 performance option first. It can make larger batches OOM even when the same batch fits without compile. + +## Dataloader configuration + +Krea2 uses the same general image dataloader structure as the other image transformer models. 
The example config uses a small Domokun dataset: + +```json +[ + { + "id": "dreambooth-1024", + "type": "huggingface", + "dataset_name": "RareConcepts/Domokun", + "crop": true, + "crop_style": "random", + "crop_aspect": "square", + "minimum_image_size": 128, + "maximum_image_size": 1024, + "target_downsample_size": 1024, + "resolution": 1024, + "resolution_type": "pixel", + "metadata_backend": "huggingface", + "caption_strategy": "instanceprompt", + "instance_prompt": "the name of your subject goes here", + "cache_dir_vae": "cache/vae/krea2/dreambooth-1024" + }, + { + "id": "alt-embed-cache", + "dataset_type": "text_embeds", + "default": true, + "type": "local", + "cache_dir": "cache/text/krea2" + } +] +``` + +For your own dataset, switch `type` to `local`, set `instance_data_dir`, and choose a caption strategy. Subject LoRAs commonly start with `caption_strategy=instanceprompt`; style LoRAs usually do better with captions or filenames. + +512px and 1024px datasets can both be useful. 512px runs are much faster and are good for catching bad captions, broken crops, or learning-rate mistakes. 1024px runs are the better signal for final quality. + +## Validation prompts + +Krea2 validation is expensive enough that it is worth keeping the prompt set small while tuning. Start with one or two prompts that clearly show whether the subject or style is being learned. Once the run is stable, add a prompt library. + +Example: + +```json +{ + "validation_prompt": "a studio portrait of , soft directional light, detailed fabric texture", + "validation_negative_prompt": "ugly, cropped, blurry, low-quality, mediocre average", + "validation_num_inference_steps": 28, + "validation_guidance": 4.5, + "validation_resolution": "1024x1024" +} +``` + +Krea2 can overfit small datasets quickly. Do not rely on a single validation prompt; a small prompt library makes collapse or prompt memorisation much easier to spot. 
+ +## Reference image training + +Krea2 supports optional reference-latent conditioning for edit-style datasets: + +```json +{ + "krea2_reference_latents": true +} +``` + +This mode expects paired reference data. The reference latents must match the target latent shape. It is intended for Qwen Edit-style paired conditioning, not for generic SDEdit/img2img training. + +Use this only when the dataset is built around paired examples. For ordinary subject/style LoRA training, leave it disabled. + +## Quantisation notes + +`int8-torchao` stores the base transformer weights in int8 and trains bf16 LoRA weights on top. On H100 it reduced peak VRAM substantially, but the tested path was slower than bf16 at the same resolution and batch size. + +That tradeoff is still useful: + +- use bf16 when the run fits and speed matters +- use int8 when the run otherwise does not fit +- expect int8 startup to take longer because the model is quantised before training +- remeasure if you change PyTorch, TorchAO, CUDA, or the attention backend + +## Performance notes + +The following results were measured on a single NVIDIA H100 80GB using the real SimpleTuner trainer, Krea2 LoRA, fused QKV projections, gradient checkpointing, and a small Domokun dataset. VRAM was sampled externally with `nvidia-smi`. + +These numbers are not hardware guarantees. Treat them as comparative data showing how this recipe behaved on one H100 system. Different drivers, CUDA builds, PyTorch builds, dataloaders, optimizers, LoRA ranks, and attention backends can move the numbers. 
+ +### Fused QKV + checkpointing, compile off + +| Precision | Resolution | Batch | Steady s/step | Peak VRAM | +| --- | ---: | ---: | ---: | ---: | +| bf16 | 512 | 1 | 0.353 | 31.10 GiB | +| bf16 | 512 | 4 | 1.230 | 39.31 GiB | +| bf16 | 512 | 8 | 2.430 | 50.32 GiB | +| bf16 | 1024 | 1 | 0.990 | 33.28 GiB | +| bf16 | 1024 | 4 | 3.850 | 48.35 GiB | +| bf16 | 1024 | 8 | 7.690 | 67.88 GiB | +| int8-torchao | 512 | 1 | 0.535 | 18.10 GiB | +| int8-torchao | 512 | 4 | 1.690 | 27.46 GiB | +| int8-torchao | 512 | 8 | 3.220 | 40.52 GiB | +| int8-torchao | 1024 | 1 | 1.330 | 20.35 GiB | +| int8-torchao | 1024 | 4 | 4.850 | 36.99 GiB | +| int8-torchao | 1024 | 8 | 9.520 | 58.84 GiB | + +### Fused QKV + checkpointing + reduce-overhead compile + +| Precision | Resolution | Batch | Status | Steady s/step | Peak VRAM | +| --- | ---: | ---: | --- | ---: | ---: | +| bf16 | 512 | 1 | ok | 0.260 | 41.20 GiB | +| bf16 | 512 | 4 | OOM | - | 79.07 GiB | +| bf16 | 512 | 8 | OOM | - | 79.10 GiB | +| bf16 | 1024 | 1 | ok | 0.704 | 63.71 GiB | +| bf16 | 1024 | 4 | OOM | - | 79.11 GiB | +| bf16 | 1024 | 8 | OOM | - | 78.40 GiB | +| int8-torchao | 512 | 1 | ok | 0.410 | 30.93 GiB | +| int8-torchao | 512 | 4 | ok | 1.300 | 78.60 GiB | +| int8-torchao | 512 | 8 | OOM | - | 79.12 GiB | +| int8-torchao | 1024 | 1 | ok | 0.990 | 58.68 GiB | +| int8-torchao | 1024 | 4 | OOM | - | 78.92 GiB | +| int8-torchao | 1024 | 8 | OOM | - | 78.09 GiB | + +The useful conclusion is straightforward: on this H100, compile was a strong batch-size-1 speedup, but it increased VRAM enough to make most larger batches fail. Uncompiled bf16 was the best general-purpose choice when the model fit. Int8 was the memory-saving choice, not the speed choice. + +## Executing the training run + +From the SimpleTuner directory: + +```bash +simpletuner train +``` + +or, from a development checkout: + +```bash +.venv/bin/python -m simpletuner.cli train +``` + +Training begins by caching text embeddings and VAE latents. 
If you change captions, image resolution, crop settings, or reference-image settings, clear the relevant cache or use a new cache directory. + +## Troubleshooting + +### The run is using 512px even though `resolution` says 1024 + +The dataloader has its own `resolution`, `maximum_image_size`, and `target_downsample_size`. Those values decide the actual training image size. Update the dataloader JSON, not just the top-level model config. + +### Compile OOMs but uncompiled training fits + +This is expected for Krea2 on 80GB cards. Compile reduce-overhead can reserve much more memory for graphs/cudagraphs. Lower the batch size or disable compile. + +### Int8 uses less VRAM but trains slower + +That matched the H100 measurements above. TorchAO int8 is useful for capacity; it is not automatically a throughput improvement for this training path. + +### The validation image is not following the reference image + +Confirm that `krea2_reference_latents=true` is enabled and that the validation dataset is using paired reference data. Plain validation prompts do not exercise the reference-latent path. + +### The model overfits quickly + +Use fewer steps, lower the learning rate, add more prompts for validation, or increase dataset variety. Krea2 is large enough to memorise small subject datasets quickly. diff --git a/documentation/quickstart/KREA2.pt-BR.md b/documentation/quickstart/KREA2.pt-BR.md new file mode 100644 index 000000000..facff1408 --- /dev/null +++ b/documentation/quickstart/KREA2.pt-BR.md @@ -0,0 +1,174 @@ +# Guia Rápido do Krea2 + +Este guia cobre treinamento LoRA do Krea2 no SimpleTuner. Krea2 é um transformer grande de imagem com flow matching, condicionamento de texto estilo Qwen e VAE do Qwen Image. Ele é mais confortável em GPUs NVIDIA com muita memória. 
+ +O exemplo inicial está em: + +```bash +simpletuner/examples/krea2.peft-lora/config.json +``` + +## Ponto de Partida Recomendado + +Para a primeira execução, use a configuração de exemplo e mantenha o modelo conservador: + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "train_batch_size": 1, + "base_model_precision": "no_change" +} +``` + +Krea2 é nativo em 1024px, mas 512px e 768px são úteis para iteração rápida e checagem de dataset. Use um dataloader de 1024px depois que a execução estiver estável. + +## Notas de Hardware + +Krea2 pode treinar em bf16 em uma H100 de 80GB com batch 1 em 1024px. Batches maiores couberam sem compile nos nossos testes, mas compile adiciona memória de grafo/cudagraph suficiente para causar OOM em muitas configurações maiores. + +TorchAO int8 weight-only reduz bastante a VRAM, mas não foi mais rápido que bf16 no caminho de treinamento testado. Use quando memória for mais importante que tempo por passo. + +Recomendações: + +- Use `bf16` quando couber. +- Use `int8-torchao` quando precisar de folga de memória. +- Mantenha `gradient_checkpointing=true`. +- Mantenha `fuse_qkv_projections=true`. +- Use `dynamo_backend=inductor`, `dynamo_mode=reduce-overhead` e `dynamo_use_regional_compilation=true` apenas depois de confirmar que batch/resolução cabem. 
+ +## Valores Principais de Configuração + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "base_model_precision": "no_change", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "optimizer": "optimi-lion", + "learning_rate": 1e-4, + "lora_rank": 64, + "train_batch_size": 1, + "resolution": 1024, + "validation_resolution": "1024x1024" +} +``` + +Para TorchAO int8: + +```json +{ + "base_model_precision": "int8-torchao", + "quantize_via": "cpu" +} +``` + +Para compile reduce-overhead: + +```json +{ + "dynamo_backend": "inductor", + "dynamo_mode": "reduce-overhead", + "dynamo_use_regional_compilation": true +} +``` + +## Treinamento com Imagem de Referência + +Krea2 suporta condicionamento opcional por latentes de referência para datasets de edição. Ative quando o dataloader fornecer imagens de referência pareadas ou latentes de referência em cache: + +```json +{ + "krea2_reference_latents": true +} +``` + +Os latentes de referência devem ter a mesma forma dos latentes alvo. + +## Configuração do Dataloader + +Krea2 usa a estrutura geral de dataloader de imagem dos outros modelos transformer. A resolução real de treino vem do JSON do dataloader, não apenas de `resolution` no config principal. Para treinar em 1024px, confirme que `resolution`, `maximum_image_size` e `target_downsample_size` também são 1024 no dataloader. + +Datasets em 512px são úteis para testes rápidos, checar captions e encontrar crops ruins. Para sinal de qualidade final, 1024px costuma ser mais representativo. + +Para datasets locais, use `type: local`, defina `instance_data_dir` e escolha uma estratégia de caption. Para subject LoRA pequeno, `caption_strategy=instanceprompt` é um bom começo. Para estilos, filenames ou captions completas tendem a funcionar melhor. + +## Validação + +Validação do Krea2 é cara, então comece com poucos prompts. 
Um único prompt pode esconder overfit ou memorização. Depois que o run estiver estável, adicione uma pequena prompt library. + +```json +{ + "validation_prompt": "a studio portrait of , soft directional light, detailed fabric texture", + "validation_negative_prompt": "ugly, cropped, blurry, low-quality, mediocre average", + "validation_num_inference_steps": 28, + "validation_guidance": 4.5, + "validation_resolution": "1024x1024" +} +``` + +## Notas de Quantização + +`int8-torchao` armazena os pesos base do transformer em int8 e treina pesos LoRA bf16 por cima. Na H100 reduziu bastante a VRAM, mas foi mais lento que bf16 neste caminho de treinamento. Pense nele como uma opção de capacidade, não como garantia de throughput. + +## Resultados de Benchmark + +As medições abaixo foram feitas em uma NVIDIA H100 80GB usando o trainer real do SimpleTuner, Krea2 LoRA, QKV fusionado, gradient checkpointing e um dataset pequeno Domokun. A VRAM foi amostrada externamente com `nvidia-smi`. Use estes valores apenas como comparação; versões diferentes de PyTorch, CUDA, driver, dataset, rank LoRA, otimizador, backend de atenção e GPU podem mudar os resultados. 
+ +### QKV Fusionado + Checkpointing, Compile Desligado + +| Precisão | Resolução | Batch | s/passo estável | Pico VRAM | +| --- | ---: | ---: | ---: | ---: | +| bf16 | 512 | 1 | 0.353 | 31.10 GiB | +| bf16 | 512 | 4 | 1.230 | 39.31 GiB | +| bf16 | 512 | 8 | 2.430 | 50.32 GiB | +| bf16 | 1024 | 1 | 0.990 | 33.28 GiB | +| bf16 | 1024 | 4 | 3.850 | 48.35 GiB | +| bf16 | 1024 | 8 | 7.690 | 67.88 GiB | +| int8-torchao | 512 | 1 | 0.535 | 18.10 GiB | +| int8-torchao | 512 | 4 | 1.690 | 27.46 GiB | +| int8-torchao | 512 | 8 | 3.220 | 40.52 GiB | +| int8-torchao | 1024 | 1 | 1.330 | 20.35 GiB | +| int8-torchao | 1024 | 4 | 4.850 | 36.99 GiB | +| int8-torchao | 1024 | 8 | 9.520 | 58.84 GiB | + +### QKV Fusionado + Checkpointing + Compile Reduce-Overhead + +| Precisão | Resolução | Batch | Estado | s/passo estável | Pico VRAM | +| --- | ---: | ---: | --- | ---: | ---: | +| bf16 | 512 | 1 | ok | 0.260 | 41.20 GiB | +| bf16 | 512 | 4 | OOM | - | 79.07 GiB | +| bf16 | 512 | 8 | OOM | - | 79.10 GiB | +| bf16 | 1024 | 1 | ok | 0.704 | 63.71 GiB | +| bf16 | 1024 | 4 | OOM | - | 79.11 GiB | +| bf16 | 1024 | 8 | OOM | - | 78.40 GiB | +| int8-torchao | 512 | 1 | ok | 0.410 | 30.93 GiB | +| int8-torchao | 512 | 4 | ok | 1.300 | 78.60 GiB | +| int8-torchao | 512 | 8 | OOM | - | 79.12 GiB | +| int8-torchao | 1024 | 1 | ok | 0.990 | 58.68 GiB | +| int8-torchao | 1024 | 4 | OOM | - | 78.92 GiB | +| int8-torchao | 1024 | 8 | OOM | - | 78.09 GiB | + +## Orientação Prática + +- Para iterar mais rápido em uma H100, use bf16, QKV fusionado, checkpointing, compile ligado e batch 1. +- Para batches efetivos maiores, prefira bf16 sem compile e aumente `train_batch_size` até a VRAM virar o limite. +- Para execuções com pouca memória, use `int8-torchao`; espere menos VRAM, mas passos mais lentos. +- Compile ajuda em batch 1, mas pode consumir VRAM suficiente para fazer batches maiores falharem. + +## Problemas Comuns + +- Se você esperava 1024px mas o log mostra 512px, revise o JSON do dataloader. 
+- Se compile causa OOM mas o run sem compile cabe, reduza batch size ou desligue compile. +- Se int8 usa menos VRAM mas é mais lento, isso corresponde aos nossos testes H100. +- Se a imagem de referência não afeta a validação, confirme `krea2_reference_latents=true` e dataset pareado. +- Se overfitar rápido, reduza learning rate, reduza steps ou aumente a variedade do dataset. diff --git a/documentation/quickstart/KREA2.zh.md b/documentation/quickstart/KREA2.zh.md new file mode 100644 index 000000000..29dc1077d --- /dev/null +++ b/documentation/quickstart/KREA2.zh.md @@ -0,0 +1,174 @@ +# Krea2 快速开始 + +本指南介绍在 SimpleTuner 中训练 Krea2 LoRA。Krea2 是大型 flow-matching 图像 transformer,使用 Qwen 风格文本条件和 Qwen Image VAE。它更适合高显存 NVIDIA GPU。 + +起始示例位于: + +```bash +simpletuner/examples/krea2.peft-lora/config.json +``` + +## 推荐起点 + +第一次运行时,建议从示例配置开始,并保持保守设置: + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "mixed_precision": "bf16", + "gradient_checkpointing": true, + "fuse_qkv_projections": true, + "train_batch_size": 1, + "base_model_precision": "no_change" +} +``` + +Krea2 是 1024px 原生图像模型,但 512px 和 768px 适合快速迭代和检查数据集。运行稳定后再切换到 1024px dataloader。 + +## 硬件说明 + +在我们的测试中,Krea2 可以在 80GB H100 上以 bf16、1024px、batch 1 训练。未启用 compile 时更大的 batch 也能放下,但 compile 会增加 graph/cudagraph 内存,很多大 batch 设置会 OOM。 + +TorchAO int8 weight-only 可以显著降低 VRAM,但在测试的 SimpleTuner 训练路径中并不比 bf16 更快。内存容量比速度更重要时再使用。 + +推荐: + +- 能放下时使用 `bf16`。 +- 需要显存余量时使用 `int8-torchao`。 +- 保持 `gradient_checkpointing=true`。 +- 保持 `fuse_qkv_projections=true`。 +- 只有在确认 batch/resolution 能放下后,才启用 `dynamo_backend=inductor`、`dynamo_mode=reduce-overhead` 和 `dynamo_use_regional_compilation=true`。 + +## 关键配置值 + +```json +{ + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "base_model_precision": "no_change", + "mixed_precision": "bf16", + "gradient_checkpointing": 
true, + "fuse_qkv_projections": true, + "optimizer": "optimi-lion", + "learning_rate": 1e-4, + "lora_rank": 64, + "train_batch_size": 1, + "resolution": 1024, + "validation_resolution": "1024x1024" +} +``` + +TorchAO int8: + +```json +{ + "base_model_precision": "int8-torchao", + "quantize_via": "cpu" +} +``` + +reduce-overhead compile: + +```json +{ + "dynamo_backend": "inductor", + "dynamo_mode": "reduce-overhead", + "dynamo_use_regional_compilation": true +} +``` + +## 参考图像训练 + +Krea2 支持面向编辑数据集的可选参考 latent 条件。若 dataloader 提供成对参考图像或缓存的参考 latents,可启用: + +```json +{ + "krea2_reference_latents": true +} +``` + +参考 latents 必须与目标 latents 的 shape 匹配。 + +## Dataloader 配置 + +Krea2 使用与其他图像 transformer 类似的 dataloader 结构。真实训练分辨率由 dataloader JSON 决定,而不只是主配置里的 `resolution`。如果要训练 1024px,请确保 dataloader 中的 `resolution`、`maximum_image_size` 和 `target_downsample_size` 也是 1024。 + +512px 数据集适合快速测试、检查 caption 和发现裁剪问题。最终质量判断通常需要 1024px run。 + +本地数据集使用 `type: local`,设置 `instance_data_dir`,并选择 caption strategy。小型 subject LoRA 可以从 `caption_strategy=instanceprompt` 开始;风格 LoRA 通常更适合 filenames 或完整 captions。 + +## 验证 + +Krea2 验证成本较高,调参时先使用少量 prompts。单个 prompt 可能掩盖 overfit 或记忆问题;运行稳定后再加入小型 prompt library。 + +```json +{ + "validation_prompt": "a studio portrait of , soft directional light, detailed fabric texture", + "validation_negative_prompt": "ugly, cropped, blurry, low-quality, mediocre average", + "validation_num_inference_steps": 28, + "validation_guidance": 4.5, + "validation_resolution": "1024x1024" +} +``` + +## 量化说明 + +`int8-torchao` 将 transformer 基础权重以 int8 存储,并在其上训练 bf16 LoRA 权重。在 H100 上它显著降低了 VRAM,但在该训练路径中比 bf16 慢。它主要是容量选项,不是吞吐保证。 + +## Benchmark 结果 + +以下数据来自单张 NVIDIA H100 80GB,使用 SimpleTuner 真实 trainer、Krea2 LoRA、QKV fusion、gradient checkpointing 和小型 Domokun 数据集。VRAM 通过 `nvidia-smi` 外部采样。请只将这些数值作为比较参考;不同 PyTorch、CUDA、驱动、数据集、LoRA rank、优化器、attention backend 和 GPU 都可能改变结果。 + +### QKV Fusion + Checkpointing,关闭 Compile + +| 精度 | 分辨率 | Batch | 稳定 s/step | 峰值 VRAM | +| --- | ---: | 
---: | ---: | ---: | +| bf16 | 512 | 1 | 0.353 | 31.10 GiB | +| bf16 | 512 | 4 | 1.230 | 39.31 GiB | +| bf16 | 512 | 8 | 2.430 | 50.32 GiB | +| bf16 | 1024 | 1 | 0.990 | 33.28 GiB | +| bf16 | 1024 | 4 | 3.850 | 48.35 GiB | +| bf16 | 1024 | 8 | 7.690 | 67.88 GiB | +| int8-torchao | 512 | 1 | 0.535 | 18.10 GiB | +| int8-torchao | 512 | 4 | 1.690 | 27.46 GiB | +| int8-torchao | 512 | 8 | 3.220 | 40.52 GiB | +| int8-torchao | 1024 | 1 | 1.330 | 20.35 GiB | +| int8-torchao | 1024 | 4 | 4.850 | 36.99 GiB | +| int8-torchao | 1024 | 8 | 9.520 | 58.84 GiB | + +### QKV Fusion + Checkpointing + Reduce-Overhead Compile + +| 精度 | 分辨率 | Batch | 状态 | 稳定 s/step | 峰值 VRAM | +| --- | ---: | ---: | --- | ---: | ---: | +| bf16 | 512 | 1 | ok | 0.260 | 41.20 GiB | +| bf16 | 512 | 4 | OOM | - | 79.07 GiB | +| bf16 | 512 | 8 | OOM | - | 79.10 GiB | +| bf16 | 1024 | 1 | ok | 0.704 | 63.71 GiB | +| bf16 | 1024 | 4 | OOM | - | 79.11 GiB | +| bf16 | 1024 | 8 | OOM | - | 78.40 GiB | +| int8-torchao | 512 | 1 | ok | 0.410 | 30.93 GiB | +| int8-torchao | 512 | 4 | ok | 1.300 | 78.60 GiB | +| int8-torchao | 512 | 8 | OOM | - | 79.12 GiB | +| int8-torchao | 1024 | 1 | ok | 0.990 | 58.68 GiB | +| int8-torchao | 1024 | 4 | OOM | - | 78.92 GiB | +| int8-torchao | 1024 | 8 | OOM | - | 78.09 GiB | + +## 实用建议 + +- 在单张 H100 上最快迭代:bf16、QKV fusion、checkpointing、开启 compile、batch 1。 +- 若需要更大的有效 batch,优先使用未 compile 的 bf16,并逐步提高 `train_batch_size` 直到 VRAM 成为限制。 +- 内存受限时使用 `int8-torchao`;VRAM 更低,但 step 更慢。 +- compile 对 batch 1 有用,但可能消耗大量 VRAM,导致更大 batch 失败。 + +## 常见问题 + +- 如果期望 1024px 但日志显示 512px,请检查 dataloader JSON。 +- 如果 compile OOM 但非 compile 能运行,请降低 batch size 或关闭 compile。 +- 如果 int8 显存更低但更慢,这与我们的 H100 测试一致。 +- 如果参考图像没有影响验证结果,请确认 `krea2_reference_latents=true` 且验证 dataset 使用成对参考数据。 +- 如果很快 overfit,请降低 learning rate、减少 steps 或增加数据集多样性。 diff --git a/documentation/quickstart/index.es.md b/documentation/quickstart/index.es.md index 4bd7c542e..6052c7ccf 100644 --- a/documentation/quickstart/index.es.md +++ 
b/documentation/quickstart/index.es.md @@ -18,6 +18,7 @@ Guías paso a paso para entrenar cada arquitectura de modelo compatible. | **Lumina2** | 2B | [Guía de Lumina2](LUMINA2.md) | | **HiDream** | 17B MoE | [Guía de HiDream](HIDREAM.md) | | **Z-Image** | - | [Guía de Z-Image](ZIMAGE.md) | +| **Krea2** | - | [Guía de Krea2](KREA2.es.md) | | **Boogu-Image** | - | [Guía de Boogu-Image](BOOGU_IMAGE.es.md) | | **zlab i1** | 3B | [Guía de zlab i1](ZLAB_i1.es.md) | | **Ideogram 4** | 9B | [Guía de Ideogram 4](IDEOGRAM4.es.md) | diff --git a/documentation/quickstart/index.hi.md b/documentation/quickstart/index.hi.md index 941da170d..d55f311f4 100644 --- a/documentation/quickstart/index.hi.md +++ b/documentation/quickstart/index.hi.md @@ -18,6 +18,7 @@ | **Lumina2** | 2B | [Lumina2 गाइड](LUMINA2.md) | | **HiDream** | 17B MoE | [HiDream गाइड](HIDREAM.md) | | **Z-Image** | - | [Z-Image गाइड](ZIMAGE.md) | +| **Krea2** | - | [Krea2 गाइड](KREA2.hi.md) | | **Boogu-Image** | - | [Boogu-Image गाइड](BOOGU_IMAGE.hi.md) | | **zlab i1** | 3B | [zlab i1 गाइड](ZLAB_i1.hi.md) | | **Ideogram 4** | 9B | [Ideogram 4 गाइड](IDEOGRAM4.hi.md) | diff --git a/documentation/quickstart/index.ja.md b/documentation/quickstart/index.ja.md index 25cdbc045..29b423342 100644 --- a/documentation/quickstart/index.ja.md +++ b/documentation/quickstart/index.ja.md @@ -18,6 +18,7 @@ | **Lumina2** | 2B | [Lumina2 ガイド](LUMINA2.md) | | **HiDream** | 17B MoE | [HiDream ガイド](HIDREAM.md) | | **Z-Image** | - | [Z-Image ガイド](ZIMAGE.md) | +| **Krea2** | - | [Krea2 ガイド](KREA2.ja.md) | | **Boogu-Image** | - | [Boogu-Image ガイド](BOOGU_IMAGE.ja.md) | | **zlab i1** | 3B | [zlab i1 ガイド](ZLAB_i1.ja.md) | | **Ideogram 4** | 9B | [Ideogram 4 ガイド](IDEOGRAM4.ja.md) | diff --git a/documentation/quickstart/index.md b/documentation/quickstart/index.md index f7678b47a..ecc897fda 100644 --- a/documentation/quickstart/index.md +++ b/documentation/quickstart/index.md @@ -18,6 +18,7 @@ Step-by-step guides for training each supported model 
architecture. | **Lumina2** | 2B | [Lumina2 Guide](LUMINA2.md) | | **HiDream** | 17B MoE | [HiDream Guide](HIDREAM.md) | | **Z-Image** | - | [Z-Image Guide](ZIMAGE.md) | +| **Krea2** | - | [Krea2 Guide](KREA2.md) | | **Boogu-Image** | - | [Boogu-Image Guide](BOOGU_IMAGE.md) | | **zlab i1** | 3B | [zlab i1 Guide](ZLAB_i1.md) | | **Ideogram 4** | 9B | [Ideogram 4 Guide](IDEOGRAM4.md) | diff --git a/documentation/quickstart/index.pt-BR.md b/documentation/quickstart/index.pt-BR.md index 3e0d802ce..59d8dfe4e 100644 --- a/documentation/quickstart/index.pt-BR.md +++ b/documentation/quickstart/index.pt-BR.md @@ -18,6 +18,7 @@ Guias passo a passo para treinar cada arquitetura de modelo suportada. | **Lumina2** | 2B | [Guia Lumina2](LUMINA2.md) | | **HiDream** | 17B MoE | [Guia HiDream](HIDREAM.md) | | **Z-Image** | - | [Guia Z-Image](ZIMAGE.md) | +| **Krea2** | - | [Guia Krea2](KREA2.pt-BR.md) | | **Boogu-Image** | - | [Guia Boogu-Image](BOOGU_IMAGE.pt-BR.md) | | **zlab i1** | 3B | [Guia zlab i1](ZLAB_i1.pt-BR.md) | | **Ideogram 4** | 9B | [Guia Ideogram 4](IDEOGRAM4.pt-BR.md) | diff --git a/documentation/quickstart/index.zh.md b/documentation/quickstart/index.zh.md index 1bff22fbd..3bb9cdef5 100644 --- a/documentation/quickstart/index.zh.md +++ b/documentation/quickstart/index.zh.md @@ -18,6 +18,7 @@ | **Lumina2** | 2B | [Lumina2 指南](LUMINA2.md) | | **HiDream** | 17B MoE | [HiDream 指南](HIDREAM.md) | | **Z-Image** | - | [Z-Image 指南](ZIMAGE.md) | +| **Krea2** | - | [Krea2 指南](KREA2.zh.md) | | **Boogu-Image** | - | [Boogu-Image 指南](BOOGU_IMAGE.zh.md) | | **zlab i1** | 3B | [zlab i1 指南](ZLAB_i1.zh.md) | | **Ideogram 4** | 9B | [Ideogram 4 指南](IDEOGRAM4.zh.md) | diff --git a/mkdocs.yml b/mkdocs.yml index 7dc98f145..87d0f6cf4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -253,6 +253,7 @@ nav: - Cosmos2: quickstart/COSMOS2IMAGE.md - OmniGen: quickstart/OMNIGEN.md - Z-Image: quickstart/ZIMAGE.md + - Krea2: quickstart/KREA2.md - Boogu-Image: quickstart/BOOGU_IMAGE.md - zlab i1: 
quickstart/ZLAB_i1.md - Ideogram 4: quickstart/IDEOGRAM4.md diff --git a/simpletuner/examples/krea2.peft-lora/config.json b/simpletuner/examples/krea2.peft-lora/config.json new file mode 100644 index 000000000..85f564511 --- /dev/null +++ b/simpletuner/examples/krea2.peft-lora/config.json @@ -0,0 +1,58 @@ +{ + "base_model_precision": "int8-torchao", + "caption_dropout_probability": 0.0, + "checkpoint_step_interval": 50, + "checkpoints_total_limit": 5, + "compress_disk_cache": false, + "data_backend_config": "config/examples/multidatabackend-small-dreambooth-512px.json", + "disable_bucket_pruning": true, + "attention_mechanism": "diffusers", + "dynamo_backend": "inductor", + "dynamo_mode": "reduce-overhead", + "dynamo_use_regional_compilation": true, + "fuse_qkv_projections": true, + "gradient_checkpointing": true, + "hub_model_id": "simpletuner-example-krea2-peft-lora", + "ignore_final_epochs": true, + "learning_rate": 1e-4, + "lora_rank": 64, + "lora_type": "standard", + "lr_scheduler": "constant_with_warmup", + "lr_warmup_steps": 100, + "max_grad_norm": 0.01, + "max_train_steps": 1000, + "minimum_image_size": 0, + "mixed_precision": "bf16", + "model_family": "krea2", + "model_flavour": "raw", + "model_type": "lora", + "num_eval_images": 25, + "num_train_epochs": 0, + "optimizer": "optimi-lion", + "output_dir": "output/examples/krea2.peft-lora-compile", + "pretrained_model_name_or_path": "krea/Krea-2-Raw", + "push_checkpoints_to_hub": false, + "push_to_hub": false, + "quantize_via": "cpu", + "report_to": "none", + "resolution": 1024, + "resolution_type": "pixel_area", + "resume_from_checkpoint": null, + "seed": 42, + "skip_file_discovery": false, + "tracker_project_name": "lora-training", + "tracker_run_name": "krea2-domokun-example", + "train_batch_size": 1, + "use_ema": false, + "vae_batch_size": 1, + "validation_adapter_strength": 1.0, + "validation_guidance": 4.5, + "validation_guidance_rescale": 0.0, + "validation_negative_prompt": "ugly, cropped, blurry, 
low-quality, mediocre average", + "validation_num_inference_steps": 28, + "validation_prompt": "🟫 is holding a sign that says hello world from krea2", + "validation_prompt_library": false, + "validation_resolution": "1024x1024", + "validation_seed": 42, + "validation_steps": 50 +} diff --git a/simpletuner/helpers/models/common.py b/simpletuner/helpers/models/common.py index 7303a4e5e..308feb63a 100644 --- a/simpletuner/helpers/models/common.py +++ b/simpletuner/helpers/models/common.py @@ -105,6 +105,7 @@ def _is_hf_repo_id(path: str) -> bool: "z_image_omni", "zlab_i1", "ideogram", + "krea2", ] upstream_config_sources = { "sdxl": "stabilityai/stable-diffusion-xl-base-1.0", @@ -121,6 +122,7 @@ def _is_hf_repo_id(path: str) -> bool: "wan": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", "hunyuanvideo": "tencent/HunyuanVideo-1.5", "ideogram": "ideogram-ai/ideogram-4-fp8", + "krea2": "krea/Krea-2-Raw", } diff --git a/simpletuner/helpers/models/krea2/__init__.py b/simpletuner/helpers/models/krea2/__init__.py new file mode 100644 index 000000000..7c0a65ceb --- /dev/null +++ b/simpletuner/helpers/models/krea2/__init__.py @@ -0,0 +1,6 @@ +from simpletuner.helpers.models.krea2.lora_pipeline import Krea2LoraLoaderMixin +from simpletuner.helpers.models.krea2.model import Krea2 +from simpletuner.helpers.models.krea2.pipeline import Krea2Pipeline +from simpletuner.helpers.models.krea2.transformer import Krea2Transformer2DModel + +__all__ = ["Krea2", "Krea2LoraLoaderMixin", "Krea2Pipeline", "Krea2Transformer2DModel"] diff --git a/simpletuner/helpers/models/krea2/lora_pipeline.py b/simpletuner/helpers/models/krea2/lora_pipeline.py new file mode 100644 index 000000000..c8adbce86 --- /dev/null +++ b/simpletuner/helpers/models/krea2/lora_pipeline.py @@ -0,0 +1,231 @@ +"""Krea 2 LoRA loader vendored from huggingface/diffusers#14046.""" + +import os +from typing import Callable + +import torch +from diffusers.loaders.lora_base import LoraBaseMixin, _fetch_state_dict +from diffusers.utils import ( 
+ USE_PEFT_BACKEND, + is_peft_available, + is_peft_version, + is_torch_version, + is_transformers_available, + is_transformers_version, + logging, +) +from huggingface_hub.utils import validate_hf_hub_args + +_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False +if is_torch_version(">=", "1.9.0"): + if ( + is_peft_available() + and is_peft_version(">=", "0.13.1") + and is_transformers_available() + and is_transformers_version(">", "4.45.2") + ): + _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True + +TRANSFORMER_NAME = "transformer" + +logger = logging.get_logger(__name__) + + +class Krea2LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Krea2Transformer2DModel`]. Specific to [`Krea2Pipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. 
+ cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + adapter_name: str | None = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Krea2Transformer2DModel + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: str | os.PathLike, + transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata: dict | None = None, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. 
+ """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: list[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: list[str] | None = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. 
+ """ + super().unfuse_lora(components=components, **kwargs) diff --git a/simpletuner/helpers/models/krea2/model.py b/simpletuner/helpers/models/krea2/model.py new file mode 100644 index 000000000..e1e92718a --- /dev/null +++ b/simpletuner/helpers/models/krea2/model.py @@ -0,0 +1,496 @@ +import logging +from typing import Optional + +import numpy as np +import torch +from diffusers import AutoencoderKLQwenImage +from PIL import Image +from transformers import AutoProcessor, AutoTokenizer, Qwen3VLModel + +from simpletuner.helpers.models.common import ( + ImageModelFoundation, + ModelTypes, + PipelineTypes, + PredictionTypes, + TextEmbedCacheKey, +) +from simpletuner.helpers.models.krea2.pipeline import Krea2Pipeline +from simpletuner.helpers.models.krea2.transformer import Krea2Transformer2DModel +from simpletuner.helpers.models.registry import ModelRegistry +from simpletuner.helpers.models.tae.types import VideoTAESpec +from simpletuner.helpers.training.state_tracker import StateTracker + +logger = logging.getLogger(__name__) + + +class Krea2(ImageModelFoundation): + SUPPORTS_MUON_CLIP = True + NAME = "Krea 2" + MODEL_DESCRIPTION = "Krea 2 flow-matching transformer" + ENABLED_IN_WIZARD = True + PREDICTION_TYPE = PredictionTypes.FLOW_MATCHING + MODEL_TYPE = ModelTypes.TRANSFORMER + USES_DYNAMIC_SHIFT = True + AUTOENCODER_CLASS = AutoencoderKLQwenImage + AUTOENCODER_SCALING_FACTOR = 1.0 + LATENT_CHANNEL_COUNT = 16 + VALIDATION_PREVIEW_SPEC = VideoTAESpec(filename="taew2_1.pth", description="Wan 2.1 VAE compatible") + + MODEL_CLASS = Krea2Transformer2DModel + MODEL_SUBFOLDER = "transformer" + PIPELINE_CLASSES = { + PipelineTypes.TEXT2IMG: Krea2Pipeline, + } + DEFAULT_MODEL_FLAVOUR = "raw" + HUGGINGFACE_PATHS = { + "raw": "krea/Krea-2-Raw", + "turbo": "krea/Krea-2-Turbo", + } + MODEL_LICENSE = "other" + + TEXT_ENCODER_CONFIGURATION = { + "text_encoder": { + "name": "Qwen3VL", + "tokenizer": AutoTokenizer, + "tokenizer_subfolder": "tokenizer", + "model": Qwen3VLModel, + 
"subfolder": "text_encoder", + }, + } + PROCESSOR_CLASS = AutoProcessor + PROCESSOR_PATH = "Qwen/Qwen3-VL-4B-Instruct" + PROCESSOR_SUBFOLDER = None + DEFAULT_LORA_TARGET = ["to_k", "to_q", "to_v", "to_out.0"] + FUSED_LORA_TARGET = ["to_qkv", "to_out.0"] + + @classmethod + def max_swappable_blocks(cls, config=None) -> Optional[int]: + return 27 + + def __init__(self, config, accelerator): + super().__init__(config, accelerator) + self.processor = None + self.vae_scale_factor = 8 + + def _uses_reference_latents(self) -> bool: + return bool(getattr(self.config, "krea2_reference_latents", False)) + + def supports_conditioning_dataset(self) -> bool: + return True + + def requires_conditioning_dataset(self) -> bool: + return self._uses_reference_latents() + + def requires_conditioning_validation_inputs(self) -> bool: + return self._uses_reference_latents() + + def requires_validation_edit_captions(self) -> bool: + return self._uses_reference_latents() + + def should_precompute_validation_negative_prompt(self) -> bool: + return not self._uses_reference_latents() + + def text_embed_cache_key(self) -> TextEmbedCacheKey: + if self._uses_reference_latents(): + return TextEmbedCacheKey.DATASET_AND_FILENAME + return super().text_embed_cache_key() + + def requires_text_embed_image_context(self) -> bool: + return self._uses_reference_latents() + + def requires_conditioning_latents(self) -> bool: + return self._uses_reference_latents() + + def update_pipeline_call_kwargs(self, pipeline_kwargs): + if self._uses_reference_latents() and "image" in pipeline_kwargs and "reference_image" not in pipeline_kwargs: + pipeline_kwargs["reference_image"] = pipeline_kwargs.pop("image") + return pipeline_kwargs + + def get_lora_target_layers(self): + if getattr(self.config, "fuse_qkv_projections", False): + return self.FUSED_LORA_TARGET + return super().get_lora_target_layers() + + def pre_vae_encode_transform_sample(self, sample): + if sample.dim() == 4: + sample = sample.unsqueeze(2) + return 
sample + + def post_vae_encode_transform_sample(self, sample): + vae = self.get_vae() + if vae is None: + raise ValueError("Cannot normalize Krea 2 latents without a loaded VAE.") + + sample_latents = sample.latent_dist.sample() + if sample_latents.dim() == 5: + sample_latents = sample_latents.squeeze(2) + + latents_mean = ( + torch.tensor(vae.config.latents_mean) + .view(1, vae.config.z_dim, 1, 1) + .to(sample_latents.device, sample_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1).to( + sample_latents.device, sample_latents.dtype + ) + return (sample_latents - latents_mean) * latents_std + + def _load_processor_for_pipeline(self): + if self.processor is not None: + return self.processor + + processor_path = getattr(self.config, "processor_pretrained_model_name_or_path", None) or self.PROCESSOR_PATH + processor_subfolder = getattr(self.config, "processor_subfolder", self.PROCESSOR_SUBFOLDER) + processor_revision = getattr(self.config, "processor_revision", getattr(self.config, "revision", None)) + + processor_kwargs = {"pretrained_model_name_or_path": processor_path} + if processor_subfolder: + processor_kwargs["subfolder"] = processor_subfolder + if processor_revision is not None: + processor_kwargs["revision"] = processor_revision + if getattr(self.config, "local_files_only", False): + processor_kwargs["local_files_only"] = True + + self.processor = self.PROCESSOR_CLASS.from_pretrained(**processor_kwargs) + return self.processor + + def _encode_prompts(self, prompts: list, is_negative_prompt: bool = False): + if self.text_encoders is None or len(self.text_encoders) == 0: + self.load_text_encoder() + + text_encoder = self.text_encoders[0] + if text_encoder.device != self.accelerator.device: + text_encoder.to(self.accelerator.device) + + pipeline = self.get_pipeline(PipelineTypes.TEXT2IMG, load_base_model=False) + prompt_contexts = getattr(self, "_current_prompt_contexts", None) + encode_kwargs = { + 
"device": self.accelerator.device, + "num_images_per_prompt": 1, + } + if self.requires_text_embed_image_context(): + if not prompt_contexts or len(prompt_contexts) != len(prompts): + raise ValueError("Krea 2 reference text encoding requires image context for each caption.") + reference_images = self._prepare_prompt_image_batch(prompt_contexts, len(prompts)) + if reference_images is None: + raise ValueError("Failed to resolve reference images for Krea 2 text encoding.") + encode_kwargs["images"] = reference_images + encode_kwargs["processor"] = self._load_processor_for_pipeline() + + return pipeline.encode_prompt(prompts, **encode_kwargs) + + def _prepare_prompt_image_batch(self, prompt_contexts, batch_size: int): + if not prompt_contexts or len(prompt_contexts) != batch_size: + return None + images = [] + for idx, context in enumerate(prompt_contexts): + extracted = self._extract_prompt_image_from_context(context) + if extracted is None: + logger.warning("Failed to extract Krea 2 reference image tensor from context %s: %s", idx, context) + return None + if isinstance(extracted, list): + if len(extracted) != 1: + raise ValueError("Krea 2 reference text encoding expects exactly one reference image per caption.") + extracted = extracted[0] + images.append(self._tensor_to_pil(extracted)) + return images + + def _extract_prompt_image_from_context(self, context: dict): + if not isinstance(context, dict): + return None + tensor = self._coerce_prompt_tensor(context.get("conditioning_pixel_values")) + if tensor is not None: + return tensor + return self._load_prompt_image_from_backend(context) + + def _coerce_prompt_tensor(self, tensor): + if tensor is None: + return None + if isinstance(tensor, Image.Image): + return tensor + if isinstance(tensor, np.ndarray): + tensor = torch.from_numpy(tensor) + if not torch.is_tensor(tensor): + return None + if tensor.dim() == 4 and tensor.size(0) == 1: + tensor = tensor.squeeze(0) + if tensor.dim() != 3: + return None + return 
tensor.to(device=self.accelerator.device, dtype=self.config.weight_dtype) + + def _load_prompt_image_from_backend(self, context: dict): + image_paths = context.get("image_paths") + data_backend_ids = context.get("data_backend_ids") + if isinstance(image_paths, (list, tuple)) and image_paths: + image_path = image_paths[0] + if isinstance(data_backend_ids, (list, tuple)) and data_backend_ids: + data_backend_id = data_backend_ids[0] + else: + data_backend_id = context.get("data_backend_id") + else: + image_path = context.get("image_path") + data_backend_id = context.get("data_backend_id") + if not image_path or not data_backend_id: + return None + backend_entry = StateTracker.get_data_backend(data_backend_id) + if backend_entry is None: + return None + data_backend = backend_entry.get("data_backend") + if data_backend is None: + return None + image = data_backend.read_image(image_path) + return self._convert_image_to_tensor(image) + + def _convert_image_to_tensor(self, image): + if isinstance(image, Image.Image): + array = np.array(image.convert("RGB"), copy=True) + tensor = torch.from_numpy(array) + elif isinstance(image, np.ndarray): + array = image[0] if image.ndim == 4 else image + if array.ndim == 3 and array.shape[2] == 4: + array = array[:, :, :3] + tensor = torch.from_numpy(array) + elif torch.is_tensor(image): + tensor = image.clone().detach() + else: + return None + if tensor.dim() == 3 and tensor.shape[0] not in (1, 3): + tensor = tensor.permute(2, 0, 1) + elif tensor.dim() == 4 and tensor.size(0) == 1: + tensor = tensor.squeeze(0) + tensor = tensor.to(dtype=torch.float32) + if tensor.max() > 1.0 or tensor.min() < 0.0: + tensor = tensor / 255.0 + return tensor.clamp_(0.0, 1.0).to(device=self.accelerator.device, dtype=self.config.weight_dtype) + + def _tensor_to_pil(self, tensor: torch.Tensor | np.ndarray | Image.Image): + if isinstance(tensor, Image.Image): + return tensor + if isinstance(tensor, np.ndarray): + tensor = torch.from_numpy(tensor) + if not 
torch.is_tensor(tensor): + raise ValueError(f"Unsupported Krea 2 reference image type: {type(tensor)}") + converted = tensor.detach().float().cpu() + if converted.dim() == 4 and converted.size(0) == 1: + converted = converted.squeeze(0) + if converted.dim() != 3: + raise ValueError(f"Expected Krea 2 reference tensor with shape (C, H, W); received {tuple(converted.shape)}.") + if converted.max().item() > 1.0 or converted.min().item() < 0.0: + converted = (converted + 1.0) / 2.0 + converted = converted.clamp_(0.0, 1.0) + array = (converted.permute(1, 2, 0).numpy() * 255.0).round().astype(np.uint8) + if array.shape[2] == 1: + array = np.repeat(array, 3, axis=2) + return Image.fromarray(array) + + def _format_text_embedding(self, text_embedding: torch.Tensor): + prompt_embeds, prompt_embeds_mask = text_embedding + return { + "prompt_embeds": prompt_embeds, + "attention_masks": prompt_embeds_mask, + } + + def convert_text_embed_for_pipeline(self, text_embedding: dict) -> dict: + attention_mask = text_embedding.get("attention_masks", None) + if attention_mask is not None and attention_mask.dim() == 1: + attention_mask = attention_mask.unsqueeze(0) + prompt_embeds = text_embedding["prompt_embeds"] + if prompt_embeds.dim() == 3: + prompt_embeds = prompt_embeds.unsqueeze(0) + return { + "prompt_embeds": prompt_embeds, + "prompt_embeds_mask": attention_mask.to(dtype=torch.int64) if attention_mask is not None else None, + } + + def convert_negative_text_embed_for_pipeline(self, text_embedding: dict) -> dict: + attention_mask = text_embedding.get("attention_masks", None) + if attention_mask is not None and attention_mask.dim() == 1: + attention_mask = attention_mask.unsqueeze(0) + prompt_embeds = text_embedding["prompt_embeds"] + if prompt_embeds.dim() == 3: + prompt_embeds = prompt_embeds.unsqueeze(0) + return { + "negative_prompt_embeds": prompt_embeds, + "negative_prompt_embeds_mask": attention_mask.to(dtype=torch.int64) if attention_mask is not None else None, + } + + def 
collate_prompt_embeds(self, text_encoder_output: list) -> dict: + if not text_encoder_output: + return {} + embeds = [] + masks = [] + for entry in text_encoder_output: + embed = entry["prompt_embeds"] + mask = entry.get("attention_masks") + if embed.dim() == 3: + embed = embed.unsqueeze(0) + embeds.append(embed) + if mask is not None: + if mask.dim() == 1: + mask = mask.unsqueeze(0) + masks.append(mask) + + max_seq_len = max(embed.shape[1] for embed in embeds) + padded_embeds = [] + padded_masks = [] + for idx, embed in enumerate(embeds): + if embed.shape[1] < max_seq_len: + pad_len = max_seq_len - embed.shape[1] + pad_shape = (embed.shape[0], pad_len, embed.shape[2], embed.shape[3]) + embed = torch.cat( + [embed, torch.zeros(pad_shape, dtype=embed.dtype, device=embed.device)], + dim=1, + ) + padded_embeds.append(embed) + if masks: + mask = masks[idx] + if mask.shape[1] < max_seq_len: + pad_len = max_seq_len - mask.shape[1] + mask = torch.cat( + [mask, torch.zeros((mask.shape[0], pad_len), dtype=mask.dtype, device=mask.device)], + dim=1, + ) + padded_masks.append(mask) + + return { + "prompt_embeds": torch.cat(padded_embeds, dim=0), + "attention_masks": torch.cat(padded_masks, dim=0) if padded_masks else None, + } + + def _patch_size(self) -> int: + transformer = self.unwrap_model(self.model) if getattr(self, "model", None) is not None else None + config = getattr(transformer, "config", None) + return int(max(getattr(config, "patch_size", 2), 1)) + + def _pack_latents(self, latents: torch.Tensor) -> tuple[torch.Tensor, int, int]: + patch_size = self._patch_size() + batch_size, channels, height, width = latents.shape + if height % patch_size != 0 or width % patch_size != 0: + raise ValueError(f"Krea 2 latent dimensions must be divisible by patch_size={patch_size}: {height}x{width}.") + packed = latents.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size) + packed = packed.permute(0, 2, 4, 1, 3, 5) + packed = packed.reshape( 
+ batch_size, (height // patch_size) * (width // patch_size), channels * patch_size * patch_size + ) + return packed, height // patch_size, width // patch_size + + def _unpack_latents(self, latents: torch.Tensor, latent_height: int, latent_width: int) -> torch.Tensor: + patch_size = self._patch_size() + batch_size, _, channels = latents.shape + unpacked = latents.view( + batch_size, + latent_height // patch_size, + latent_width // patch_size, + channels // (patch_size * patch_size), + patch_size, + patch_size, + ) + unpacked = unpacked.permute(0, 3, 1, 4, 2, 5) + return unpacked.reshape(batch_size, channels // (patch_size * patch_size), latent_height, latent_width) + + @staticmethod + def _position_ids_for_grids(text_seq_len: int, grids: list[tuple[int, int]], device: torch.device): + text_ids = torch.zeros(text_seq_len, 3, device=device) + image_ids = [] + for grid_height, grid_width in grids: + ids = torch.zeros(grid_height, grid_width, 3, device=device) + ids[..., 1] = torch.arange(grid_height, device=device)[:, None] + ids[..., 2] = torch.arange(grid_width, device=device)[None, :] + image_ids.append(ids.reshape(grid_height * grid_width, 3)) + if image_ids: + return torch.cat([text_ids, *image_ids], dim=0) + return text_ids + + def _prepare_model_predict_timesteps(self, raw_timesteps, batch_size: int) -> torch.Tensor: + if not torch.is_tensor(raw_timesteps): + timesteps = torch.tensor(raw_timesteps, device=self.accelerator.device, dtype=torch.float32) + else: + timesteps = raw_timesteps.to(device=self.accelerator.device, dtype=torch.float32) + if timesteps.ndim == 0: + timesteps = timesteps.expand(batch_size) + elif timesteps.ndim == 1: + if timesteps.shape[0] == 1: + timesteps = timesteps.expand(batch_size) + elif timesteps.shape[0] != batch_size: + raise ValueError( + f"Krea 2 expected 1 timestep or {batch_size} per-batch timesteps, got {timesteps.shape[0]}." 
+ ) + else: + raise ValueError(f"Krea 2 expected scalar or 1D timesteps, got shape {tuple(timesteps.shape)}.") + return timesteps / 1000.0 + + def _prepare_reference_latents(self, prepared_batch: dict, batch_size: int, channels: int, height: int, width: int): + reference_latents = prepared_batch.get("conditioning_latents") + if isinstance(reference_latents, list): + reference_latents = reference_latents[0] if reference_latents else None + if reference_latents is None: + raise ValueError("Krea 2 reference-latent training requires conditioning_latents in the batch.") + if reference_latents.dim() == 5: + if reference_latents.shape[2] != 1: + raise ValueError(f"Krea 2 reference latents must have a single frame, got {tuple(reference_latents.shape)}.") + reference_latents = reference_latents.squeeze(2) + if reference_latents.shape != (batch_size, channels, height, width): + raise ValueError( + "Krea 2 reference latents must match target latent shape. " + f"Got reference {tuple(reference_latents.shape)} vs target {(batch_size, channels, height, width)}." + ) + return reference_latents.to(device=self.accelerator.device, dtype=self.config.weight_dtype) + + def model_predict(self, prepared_batch): + latent_model_input = prepared_batch["noisy_latents"] + target_latents = prepared_batch["latents"] + target_ndim = target_latents.dim() + + if latent_model_input.dim() == 5: + if latent_model_input.shape[2] != 1: + raise ValueError( + f"Krea 2 image training expects a single latent frame, got {tuple(latent_model_input.shape)}." 
+ ) + latent_model_input = latent_model_input.squeeze(2) + batch_size, channels, latent_height, latent_width = latent_model_input.shape + + hidden_states, grid_height, grid_width = self._pack_latents( + latent_model_input.to(device=self.accelerator.device, dtype=self.config.weight_dtype) + ) + target_token_count = hidden_states.shape[1] + grids = [(grid_height, grid_width)] + + if self._uses_reference_latents(): + reference_latents = self._prepare_reference_latents( + prepared_batch, + batch_size=batch_size, + channels=channels, + height=latent_height, + width=latent_width, + ) + packed_reference, ref_grid_height, ref_grid_width = self._pack_latents(reference_latents) + hidden_states = torch.cat([hidden_states, packed_reference], dim=1) + grids.append((ref_grid_height, ref_grid_width)) + + prompt_embeds = prepared_batch["prompt_embeds"].to(device=self.accelerator.device, dtype=self.config.weight_dtype) + prompt_embeds_mask = prepared_batch.get("encoder_attention_mask") + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.to(self.accelerator.device, dtype=torch.int64) + if prompt_embeds_mask.dim() == 3 and prompt_embeds_mask.size(1) == 1: + prompt_embeds_mask = prompt_embeds_mask.squeeze(1) + + timesteps = self._prepare_model_predict_timesteps(prepared_batch["timesteps"], batch_size) + position_ids = self._position_ids_for_grids(prompt_embeds.shape[1], grids, self.accelerator.device) + + noise_pred = self.model( + hidden_states=hidden_states, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + position_ids=position_ids, + encoder_attention_mask=prompt_embeds_mask, + return_dict=False, + )[0] + noise_pred = noise_pred[:, :target_token_count] + noise_pred = self._unpack_latents(noise_pred, latent_height, latent_width) + + if target_ndim == 5: + noise_pred = noise_pred.unsqueeze(2) + return {"model_prediction": noise_pred} + + +ModelRegistry.register("krea2", Krea2) diff --git a/simpletuner/helpers/models/krea2/pipeline.py 
b/simpletuner/helpers/models/krea2/pipeline.py new file mode 100644 index 000000000..a2b53292d --- /dev/null +++ b/simpletuner/helpers/models/krea2/pipeline.py @@ -0,0 +1,852 @@ +# Vendored from huggingface/diffusers#14046 at 3993de59e37344d92aa24ec25bdc39413157b744. +# Copyright 2026 Krea AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKLQwenImage +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from PIL import Image +from transformers import AutoTokenizer, Qwen3VLModel + +from simpletuner.helpers.models.krea2.lora_pipeline import Krea2LoraLoaderMixin +from simpletuner.helpers.models.krea2.pipeline_output import Krea2PipelineOutput +from simpletuner.helpers.models.krea2.transformer import Krea2Transformer2DModel + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import 
Krea2Pipeline + + >>> # Load from a local directory produced by the Krea 2 conversion (no hub repo yet). + >>> pipe = Krea2Pipeline.from_pretrained("path/to/krea2-diffusers", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "a fox in the snow" + >>> # Base (midtrain) checkpoint defaults. For the few-step distilled (TDM) checkpoint use + >>> # `num_inference_steps=8, guidance_scale=0.0` instead. + >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.5).images[0] + >>> image.save("krea2.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
+ sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + if hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + if hasattr(encoder_output, "latents"): + return encoder_output.latents + raise AttributeError("Could not access latents of provided encoder_output") + + +class Krea2Pipeline(DiffusionPipeline, Krea2LoraLoaderMixin): + r""" + The Krea 2 pipeline for text-to-image generation. + + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + Euler flow-matching scheduler. The Krea 2 sigma schedule is the resolution-aware exponential time shift, so + the scheduler config is expected to set `use_dynamic_shifting=True` together with the Krea 2 shift + parameters (`base_shift=0.5`, `max_shift=1.15`, `base_image_seq_len=256`, `max_image_seq_len=6400`). + vae ([`AutoencoderKLQwenImage`]): + The Qwen-Image variational auto-encoder (f8, 16 latent channels) used to decode latents to images. + text_encoder ([`~transformers.PreTrainedModel`]): + A Qwen3-VL model (e.g. `Qwen3VLModel` of `Qwen/Qwen3-VL-4B-Instruct`). The pipeline consumes a stack of + hidden states tapped from several decoder layers rather than the last hidden state. + tokenizer ([`~transformers.AutoTokenizer`]): + The tokenizer paired with the text encoder. 
+ transformer ([`Krea2Transformer2DModel`]): + The Krea 2 single-stream MMDiT that predicts the flow-matching velocity. + text_encoder_select_layers (`tuple[int, ...]`, *optional*): + Indices into the text encoder's `hidden_states` tuple (0 is the embedding output) whose states are stacked + per token as the transformer's text conditioning. Must have `transformer.config.num_text_layers` entries. + is_distilled (`bool`, *optional*, defaults to `False`): + Whether the transformer is the few-step distilled (TDM/turbo) checkpoint. When `True` a fixed timestep + shift `mu=1.15` is used; otherwise `mu` is computed from the image resolution. + patch_size (`int`, *optional*, defaults to 2): + Side length of the square patches the latents are packed into before entering the transformer. The + effective pixel-to-token downsampling factor is `vae_scale_factor * patch_size`. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen3VLModel, + tokenizer: AutoTokenizer, + transformer: Krea2Transformer2DModel, + text_encoder_select_layers: tuple[int, ...] | list[int] | None = None, + is_distilled: bool = False, + patch_size: int = 2, + ): + super().__init__() + + self.register_modules( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + # Indices into the text encoder's `hidden_states` tuple (0 is the embedding output) whose states are stacked + # per token and fed to the transformer's text fusion stage. `None` selects the Krea 2 (Qwen3-VL-4B) taps. 
+ if text_encoder_select_layers is None: + text_encoder_select_layers = (2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35) + self.register_to_config(text_encoder_select_layers=tuple(text_encoder_select_layers)) + self.text_encoder_select_layers = tuple(text_encoder_select_layers) + # The few-step distilled (TDM/turbo) checkpoint uses a fixed timestep-shift `mu=1.15`; the base (midtrain) + # checkpoint computes `mu` from the image resolution. Encoded here so each checkpoint carries the right schedule. + self.register_to_config(is_distilled=is_distilled) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # Latents are packed into `patch_size`-square patches before entering the transformer, so the effective + # pixel-to-token downsampling factor is vae_scale_factor * patch_size. + self.register_to_config(patch_size=patch_size) + self.patch_size = patch_size + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * self.patch_size) + + # Text conditioning uses the Qwen-Image chat template, tokenized as a fixed-length block: the prompt is padded + # to a fixed length first and the assistant suffix is appended after the padding (matching how the model was + # sampled at training time). The first `prompt_template_encode_start_idx` (system prefix) tokens are dropped + # from the encoder outputs. 
+ self.prompt_template_encode_prefix = ( + "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, " + "spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + ) + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.prompt_template_encode_num_suffix_tokens = 5 + + def get_text_hidden_states( + self, + prompt: str | list[str], + max_sequence_length: int = 512, + device: torch.device | None = None, + images: Image.Image | list[Image.Image] | None = None, + processor: Any | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Tokenize `prompt` into the fixed-length Krea 2 layout and tap the selected encoder hidden states. + + Returns a `(hidden_states, attention_mask)` tuple of shapes `(batch_size, text_seq_len, num_text_layers, + text_hidden_dim)` and `(batch_size, text_seq_len)` (bool). + """ + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + prefix_idx = self.prompt_template_encode_start_idx + + if images is not None: + if processor is None: + raise ValueError("Krea 2 image-context text encoding requires a Qwen3VL processor.") + images = [images] if isinstance(images, Image.Image) else images + if len(images) != len(prompt): + raise ValueError(f"Expected {len(prompt)} reference images for Krea 2 prompt encoding, got {len(images)}.") + + text = [ + self.prompt_template_encode_prefix + + "<|vision_start|><|image_pad|><|vision_end|>\n" + + entry + + self.prompt_template_encode_suffix + for entry in prompt + ] + text_tokens = processor(text=text, images=images, padding=True, return_tensors="pt").to(device) + outputs = self.text_encoder( + **text_tokens, + output_hidden_states=True, + ) + hidden_states = torch.stack([outputs.hidden_states[i] for i in self.text_encoder_select_layers], dim=2) + attention_mask = text_tokens.attention_mask.bool() + return 
hidden_states[:, prefix_idx:], attention_mask[:, prefix_idx:] + + text = [self.prompt_template_encode_prefix + e for e in prompt] + text_tokens = self.tokenizer( + text, + truncation=True, + padding="max_length", + max_length=max_sequence_length + prefix_idx - self.prompt_template_encode_num_suffix_tokens, + return_tensors="pt", + ).to(device) + suffix_tokens = self.tokenizer([self.prompt_template_encode_suffix] * len(text), return_tensors="pt").to(device) + + input_ids = torch.cat([text_tokens.input_ids, suffix_tokens.input_ids], dim=1) + attention_mask = torch.cat([text_tokens.attention_mask, suffix_tokens.attention_mask], dim=1).bool() + + # Krea 2 pads in the middle of the template (`[prefix | prompt | PAD | suffix]`), so the suffix tokens sit + # downstream of the padding. The text features must use positions that count only real tokens (padding does + # not consume a position) to match how the model was trained; otherwise the suffix gets a shifted mRoPE phase. + # `Qwen3VLModel`'s default raw-index positions would place the suffix at ~max_length instead. Build the + # cumulative-valid-token positions explicitly and broadcast across the 3 mRoPE axes (T/H/W are equal for text). 
+ position_ids = (attention_mask.long().cumsum(dim=-1) - 1).clamp(min=0) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=True, + ) + hidden_states = torch.stack([outputs.hidden_states[i] for i in self.text_encoder_select_layers], dim=2) + + hidden_states = hidden_states[:, prefix_idx:] + attention_mask = attention_mask[:, prefix_idx:] + return hidden_states, attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 512, + images: Image.Image | list[Image.Image] | None = None, + processor: Any | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to encode. Ignored when `prompt_embeds` is passed. + device (`torch.device`, *optional*): + The torch device to encode on. Defaults to the pipeline's execution device. + num_images_per_prompt (`int`): + The number of images to generate per prompt; embeddings and masks are repeated accordingly. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings of shape `(batch_size, text_seq_len, num_text_layers, text_hidden_dim)`. + Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will + be generated from `prompt` input argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated boolean mask marking valid text tokens, of shape `(batch_size, text_seq_len)`. Required + when `prompt_embeds` is passed. + max_sequence_length (`int`, defaults to 512): + Fixed text sequence length consumed by the transformer; prompts are padded or truncated to it. + images (`PIL.Image.Image` or `list[PIL.Image.Image]`, *optional*): + Reference images, one per prompt, encoded together with the prompt text. Requires `processor`. + processor (*optional*): + Qwen3VL processor used to tokenize the prompt together with `images`. Required when `images` is passed. 
+ """ + device = device or self._execution_device + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self.get_text_hidden_states( + prompt, + max_sequence_length, + device, + images=images, + processor=processor, + ) + + batch_size, seq_len, num_text_layers, dim = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, num_text_layers, dim) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + multiple = self.vae_scale_factor * self.patch_size + if height % multiple != 0 or width % multiple != 0: + raise ValueError(f"`height` and `width` must be divisible by {multiple} but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` is provided, `prompt_embeds_mask` must also be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` is provided, `negative_prompt_embeds_mask` must also be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length <= 0: + raise ValueError(f"`max_sequence_length` must be a positive integer but is {max_sequence_length}") + + def _pack_latents(self, latents, batch_size, num_channels_latents, height, width): + p = self.patch_size + latents = latents.view(batch_size, num_channels_latents, height // p, p, width // p, p) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // p) * (width // p), num_channels_latents * p * p) + + return latents + + def _unpack_latents(self, latents, height, width): + batch_size, _, channels = latents.shape + p = self.patch_size + + # The VAE applies `vae_scale_factor`x compression, and latents are packed into `p`-square patches, so latent + # height and width must be divisible by `p`. 
+ height = p * (int(height) // (self.vae_scale_factor * p)) + width = p * (int(width) // (self.vae_scale_factor * p)) + + latents = latents.view(batch_size, height // p, width // p, channels // (p * p), p, p) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (p * p), 1, height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator | list[torch.Generator] | None = None): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + return (image_latents - latents_mean) / latents_std + + def prepare_reference_latents( + self, + reference_image, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + ): + if reference_image is None: + return None, None + + reference_image = self.image_processor.preprocess(reference_image, height=height, width=width) + reference_image = reference_image.to(device=device, dtype=dtype) + reference_latents = self._encode_vae_image(reference_image, generator=generator) + if batch_size > reference_latents.shape[0] and batch_size % reference_latents.shape[0] == 0: + reference_latents = torch.cat([reference_latents] * (batch_size // reference_latents.shape[0]), dim=0) + elif batch_size > reference_latents.shape[0] and batch_size % reference_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate 
`reference_image` of batch size {reference_latents.shape[0]} to {batch_size} prompts." + ) + reference_latents = reference_latents.squeeze(2) + latent_height, latent_width = reference_latents.shape[2:] + packed_reference = self._pack_latents( + reference_latents, + batch_size, + num_channels_latents, + latent_height, + latent_width, + ) + return packed_reference, (latent_height // self.patch_size, latent_width // self.patch_size) + + @staticmethod + def prepare_position_ids(text_seq_len: int, grid_height: int, grid_width: int, device: torch.device): + """Build the `(text_seq_len + grid_height * grid_width, 3)` rotary coordinates for the combined sequence: + text tokens sit at the origin, image tokens carry their `(0, h, w)` latent-grid coordinates.""" + text_ids = torch.zeros(text_seq_len, 3, device=device) + image_ids = torch.zeros(grid_height, grid_width, 3, device=device) + image_ids[..., 1] = torch.arange(grid_height, device=device)[:, None] + image_ids[..., 2] = torch.arange(grid_width, device=device)[None, :] + image_ids = image_ids.reshape(grid_height * grid_width, 3) + return torch.cat([text_ids, image_ids], dim=0) + + @staticmethod + def prepare_position_ids_for_grids(text_seq_len: int, grids: list[tuple[int, int]], device: torch.device): + text_ids = torch.zeros(text_seq_len, 3, device=device) + image_ids = [] + for grid_height, grid_width in grids: + ids = torch.zeros(grid_height, grid_width, 3, device=device) + ids[..., 1] = torch.arange(grid_height, device=device)[:, None] + ids[..., 2] = torch.arange(grid_width, device=device)[None, :] + image_ids.append(ids.reshape(grid_height * grid_width, 3)) + return torch.cat([text_ids, *image_ids], dim=0) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + latent_height = height // self.vae_scale_factor + latent_width = width // 
self.vae_scale_factor + shape = (batch_size, num_channels_latents, latent_height, latent_width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, latent_height, latent_width) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 0 + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + reference_image: Any | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + attention_kwargs: dict[str, Any] | None = None, + 
max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when `guidance_scale <= 0`; defaults + to an empty prompt when guidance is enabled. + height (`int`, defaults to 1024): + The height in pixels of the generated image. Rounded up to a multiple of 16 if needed. + width (`int`, defaults to 1024): + The width in pixels of the generated image. Rounded up to a multiple of 16 if needed. + num_inference_steps (`int`, defaults to 28): + The number of denoising steps. Use 28 for the base (midtrain) checkpoint and 8 for the few-step + distilled (TDM) checkpoint. + sigmas (`list[float]`, *optional*): + Custom sigmas for the scheduler. If not defined, the default `linspace(1.0, 1/num_inference_steps, + num_inference_steps)` grid is used (the resolution-aware shift is applied inside the scheduler). + guidance_scale (`float`, defaults to 4.5): + Classifier-free guidance scale, following the Krea 2 convention: the velocity is computed as `cond + + guidance_scale * (cond - uncond)` and guidance is enabled whenever `guidance_scale > 0` (this equals + the usual CFG formulation with scale `1 + guidance_scale`). Set to `0.0` to disable (e.g. for the TDM + checkpoint). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to + make generation deterministic. 
+ latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents in packed form `(batch_size, image_seq_len, in_channels)`, sampled from a + Gaussian distribution, to be used as inputs for image generation. + reference_image (`PIL.Image.Image`, list of images, or tensor, *optional*): + Clean reference image to append as packed VAE latent tokens during denoising. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings of shape `(batch_size, text_seq_len, num_text_layers, text_hidden_dim)`. + If not provided, embeddings are generated from `prompt`. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Boolean mask for `prompt_embeds`; required when `prompt_embeds` is passed. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings; same layout as `prompt_embeds`. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Boolean mask for `negative_prompt_embeds`; required when `negative_prompt_embeds` is passed. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `"pil"`, `"np"`, `"pt"` or `"latent"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.krea2.Krea2PipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step with `callback_on_step_end(self, step, + timestep, callback_kwargs)`. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*, defaults to `["latents"]`): + The list of tensor inputs for the `callback_on_step_end` function. Must be a subset of + `._callback_tensor_inputs`. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ max_sequence_length (`int`, defaults to 512): + Fixed text sequence length consumed by the transformer; prompts are padded or truncated to it. + + Examples: + + Returns: + [`~pipelines.krea2.Krea2PipelineOutput`] or `tuple`: [`~pipelines.krea2.Krea2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`, whose first element is a list with the generated images. + """ + multiple = self.vae_scale_factor * self.patch_size + if height % multiple != 0 or width % multiple != 0: + rounded_height = ((height + multiple - 1) // multiple) * multiple + rounded_width = ((width + multiple - 1) // multiple) * multiple + logger.warning( + f"`height` and `width` must be multiples of {multiple}; rounding up from {height}x{width} to" + f" {rounded_height}x{rounded_width}." + ) + height, width = rounded_height, rounded_width + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. 
Encode the prompts + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + max_sequence_length=max_sequence_length, + ) + if self.do_classifier_free_guidance: + if negative_prompt is None and negative_prompt_embeds is None: + negative_prompt = "" + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latents and position ids + num_channels_latents = self.transformer.config.in_channels // (self.patch_size**2) + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + grid_height = height // (self.vae_scale_factor * self.patch_size) + grid_width = width // (self.vae_scale_factor * self.patch_size) + target_image_seq_len = latents.shape[1] + reference_latents, reference_grid = self.prepare_reference_latents( + reference_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + grids = [(grid_height, grid_width)] + if reference_grid is not None: + grids.append(reference_grid) + position_ids = self.prepare_position_ids_for_grids(prompt_embeds.shape[1], grids, device) + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = target_image_seq_len + if self.config.is_distilled: + mu = 1.15 + else: + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 6400), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = (t / self.scheduler.config.num_train_timesteps).expand(latents.shape[0]).to(latents.dtype) + transformer_latents = ( + torch.cat([latents, reference_latents], dim=1) if reference_latents is not None else latents + ) + + noise_pred = self.transformer( + hidden_states=transformer_latents, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + position_ids=position_ids, + encoder_attention_mask=prompt_embeds_mask, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, :target_image_seq_len] + + if self.do_classifier_free_guidance: + neg_noise_pred = self.transformer( + hidden_states=transformer_latents, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + position_ids=position_ids, + encoder_attention_mask=negative_prompt_embeds_mask, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, :target_image_seq_len] + 
noise_pred = noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + # 7. 
Decode latents + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Krea2PipelineOutput(images=image) diff --git a/simpletuner/helpers/models/krea2/pipeline_output.py b/simpletuner/helpers/models/krea2/pipeline_output.py new file mode 100644 index 000000000..db36bc458 --- /dev/null +++ b/simpletuner/helpers/models/krea2/pipeline_output.py @@ -0,0 +1,34 @@ +# Vendored from huggingface/diffusers#14046 at 3993de59e37344d92aa24ec25bdc39413157b744. +# Copyright 2026 Krea AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image +from diffusers.utils import BaseOutput + + +@dataclass +class Krea2PipelineOutput(BaseOutput): + """ + Output class for the Krea 2 pipeline. 
+ + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`): + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/simpletuner/helpers/models/krea2/transformer.py b/simpletuner/helpers/models/krea2/transformer.py new file mode 100644 index 000000000..1a29303d8 --- /dev/null +++ b/simpletuner/helpers/models/krea2/transformer.py @@ -0,0 +1,561 @@ +# Vendored from huggingface/diffusers#14046 at 3993de59e37344d92aa24ec25bdc39413157b744. +# Copyright 2026 Krea AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import math +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.models.attention import AttentionMixin, AttentionModuleMixin +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import apply_lora_scale, logging + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _krea2_rope_freqs_dtype(device: torch.device) -> torch.dtype: + return torch.float32 if device.type in {"mps", "neuron", "npu"} else torch.float64 + + +class Krea2RMSNorm(nn.Module): + """RMSNorm with a zero-centered scale: the effective multiplier is `1 + weight`, matching the Krea 2 checkpoint + format. 
The activations are upcast so the normalization runs in float32; the scale weight is kept in float32 by the + model's `_keep_in_fp32_modules`.""" + + def __init__(self, dim: int, eps: float = 1e-5) -> None: + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + dtype = hidden_states.dtype + hidden_states = F.rms_norm(hidden_states.float(), (self.dim,), weight=self.weight + 1.0, eps=self.eps) + return hidden_states.to(dtype) + + +class Krea2AttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "Krea2Attention", + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + if getattr(attn, "fused_projections", False): + q_dim = attn.head_dim * attn.num_heads + kv_dim = attn.head_dim * attn.num_kv_heads + query, key, value = attn.to_qkv(hidden_states).split((q_dim, kv_dim, kv_dim), dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + query = query.unflatten(-1, (attn.num_heads, attn.head_dim)) + key = key.unflatten(-1, (attn.num_kv_heads, attn.head_dim)) + value = value.unflatten(-1, (attn.num_kv_heads, attn.head_dim)) + gate = attn.to_gate(hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + enable_gqa=attn.num_heads != attn.num_kv_heads, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states * torch.sigmoid(gate) + return attn.to_out[0](hidden_states) 
+ + +class Krea2Attention(nn.Module, AttentionModuleMixin): + """Self-attention with grouped-query projections, q/k RMSNorm, rotary embeddings and a sigmoid output gate.""" + + _default_processor_cls = Krea2AttnProcessor + _available_processors = [Krea2AttnProcessor] + + def __init__( + self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None, eps: float = 1e-5, processor=None + ) -> None: + super().__init__() + if hidden_size % num_heads != 0: + raise ValueError(f"hidden_size={hidden_size} must be divisible by num_heads={num_heads}") + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.head_dim = hidden_size // num_heads + self.use_bias = False + + self.to_q = nn.Linear(hidden_size, self.head_dim * self.num_heads, bias=False) + self.to_k = nn.Linear(hidden_size, self.head_dim * self.num_kv_heads, bias=False) + self.to_v = nn.Linear(hidden_size, self.head_dim * self.num_kv_heads, bias=False) + self.to_gate = nn.Linear(hidden_size, hidden_size, bias=False) + self.norm_q = Krea2RMSNorm(self.head_dim, eps=eps) + self.norm_k = Krea2RMSNorm(self.head_dim, eps=eps) + self.to_out = nn.ModuleList([nn.Linear(hidden_size, hidden_size, bias=False), nn.Dropout(0.0)]) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + self.fused_projections = False + + @torch.no_grad() + def fuse_projections(self) -> None: + if self.fused_projections: + return + + device = self.to_q.weight.device + dtype = self.to_q.weight.dtype + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + self.to_qkv = nn.Linear( + concatenated_weights.shape[1], + concatenated_weights.shape[0], + bias=False, + device=device, + dtype=dtype, + ) + self.to_qkv.weight.copy_(concatenated_weights) + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self) -> None: + if not 
self.fused_projections: + return + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + self.fused_projections = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k in kwargs if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +class Krea2SwiGLU(nn.Module): + """SwiGLU feed-forward network.""" + + def __init__(self, dim: int, hidden_dim: int) -> None: + super().__init__() + self.gate = nn.Linear(dim, hidden_dim, bias=False) + self.up = nn.Linear(dim, hidden_dim, bias=False) + self.down = nn.Linear(hidden_dim, dim, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.down(F.silu(self.gate(hidden_states)) * self.up(hidden_states)) + + +class Krea2TextFusionBlock(nn.Module): + """Pre-norm transformer block (no rotary embeddings, no time modulation) used by the text fusion stage.""" + + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, intermediate_size: int, eps: float) -> None: + super().__init__() + self.norm1 = Krea2RMSNorm(dim, eps=eps) + self.norm2 = Krea2RMSNorm(dim, eps=eps) + self.attn = Krea2Attention(dim, num_heads, num_kv_heads, eps=eps) + self.ff = Krea2SwiGLU(dim, intermediate_size) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = hidden_states + self.attn(self.norm1(hidden_states), attention_mask=attention_mask) + hidden_states = hidden_states + 
self.ff(self.norm2(hidden_states)) + return hidden_states + + +class Krea2TextFusion(nn.Module): + """Fuses the stack of tapped text-encoder hidden states into a single sequence of text features. + + Two `layerwise_blocks` attend across the `num_text_layers` axis independently for every token, a linear `projector` + collapses that axis, and two `refiner_blocks` attend across the token sequence. + """ + + def __init__( + self, + num_text_layers: int, + dim: int, + num_heads: int, + num_kv_heads: int, + intermediate_size: int, + num_layerwise_blocks: int, + num_refiner_blocks: int, + eps: float, + ) -> None: + super().__init__() + self.layerwise_blocks = nn.ModuleList( + [Krea2TextFusionBlock(dim, num_heads, num_kv_heads, intermediate_size, eps) for _ in range(num_layerwise_blocks)] + ) + self.projector = nn.Linear(num_text_layers, 1, bias=False) + self.refiner_blocks = nn.ModuleList( + [Krea2TextFusionBlock(dim, num_heads, num_kv_heads, intermediate_size, eps) for _ in range(num_refiner_blocks)] + ) + + def forward(self, encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: + batch_size, seq_len, num_text_layers, dim = encoder_hidden_states.shape + + hidden_states = encoder_hidden_states.reshape(batch_size * seq_len, num_text_layers, dim) + for block in self.layerwise_blocks: + hidden_states = block(hidden_states.contiguous()) + + hidden_states = hidden_states.reshape(batch_size, seq_len, num_text_layers, dim).permute(0, 1, 3, 2) + hidden_states = self.projector(hidden_states).squeeze(-1) + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask) + + return hidden_states + + +class Krea2TransformerBlock(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, num_kv_heads: int, norm_eps: float) -> None: + super().__init__() + self.scale_shift_table = nn.Parameter(torch.zeros(6, hidden_size)) + self.norm1 = Krea2RMSNorm(hidden_size, 
eps=norm_eps) + self.norm2 = Krea2RMSNorm(hidden_size, eps=norm_eps) + self.attn = Krea2Attention(hidden_size, num_heads, num_kv_heads, eps=norm_eps) + self.ff = Krea2SwiGLU(hidden_size, intermediate_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + # temb: (B, 1, 6 * hidden_size), shared across all blocks; each block only learns an additive table. + modulation = temb.unflatten(-1, (6, -1)) + self.scale_shift_table + prescale, preshift, pregate, postscale, postshift, postgate = modulation.unbind(-2) + + attn_out = self.attn( + (1.0 + prescale) * self.norm1(hidden_states) + preshift, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + pregate * attn_out + ff_out = self.ff((1.0 + postscale) * self.norm2(hidden_states) + postshift) + hidden_states = hidden_states + postgate * ff_out + return hidden_states + + +class Krea2TimestepEmbedding(nn.Module): + """Sinusoidal flow-time embedding (cos-first, input scaled by 1000) followed by a two-layer MLP. + + Keeps the sequence dimension at size 1 so the per-block modulations broadcast over tokens. 
+ """ + + def __init__(self, embed_dim: int, hidden_size: int) -> None: + super().__init__() + self.embed_dim = embed_dim + self.linear_1 = nn.Linear(embed_dim, hidden_size, bias=True) + self.linear_2 = nn.Linear(hidden_size, hidden_size, bias=True) + + def forward(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + half = self.embed_dim // 2 + freqs = torch.exp(-math.log(1e4) * torch.arange(half, dtype=torch.float32, device=timestep.device) / half) + args = (timestep.float() * 1e3)[:, None, None] * freqs + emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype) + return self.linear_2(F.gelu(self.linear_1(emb), approximate="tanh")) + + +class Krea2TextProjection(nn.Module): + """Projects the fused text features into the transformer width.""" + + def __init__(self, text_dim: int, hidden_size: int, eps: float) -> None: + super().__init__() + self.norm = Krea2RMSNorm(text_dim, eps=eps) + self.linear_1 = nn.Linear(text_dim, hidden_size, bias=True) + self.linear_2 = nn.Linear(hidden_size, hidden_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(self.norm(hidden_states)) + return self.linear_2(F.gelu(hidden_states, approximate="tanh")) + + +class Krea2FinalLayer(nn.Module): + """Final adaptive RMSNorm and output projection. 
Kept as one module (and in `_no_split_modules`) so the learned + modulation table, norm and projection stay co-located under device-mapped inference.""" + + def __init__(self, hidden_size: int, out_channels: int, eps: float) -> None: + super().__init__() + self.scale_shift_table = nn.Parameter(torch.zeros(2, hidden_size)) + self.norm = Krea2RMSNorm(hidden_size, eps=eps) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: + modulation = temb + self.scale_shift_table + scale, shift = modulation.chunk(2, dim=1) + hidden_states = (1.0 + scale) * self.norm(hidden_states) + shift + return self.linear(hidden_states) + + +# Copied from diffusers.models.transformers.transformer_flux.FluxPosEmbed with FluxPosEmbed->Krea2RotaryPosEmbed +class Krea2RotaryPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: list[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + freqs_dtype = _krea2_rope_freqs_dtype(ids.device) + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class Krea2Transformer2DModel(ModelMixin, ConfigMixin, AttentionMixin, PeftAdapterMixin): + r""" + The single-stream MMDiT flow-matching backbone used by the Krea 2 pipeline. + + Text conditioning enters as a stack of hidden states tapped from several layers of a multimodal text encoder. 
A + small text-fusion transformer collapses the layer axis and refines the token sequence; the result is concatenated + with the patchified image latents into a single `[text, image]` sequence processed by the transformer blocks. The + timestep conditions every block through one shared modulation vector plus per-block learned tables. + + Args: + in_channels (`int`, defaults to 64): + Latent channel count after patchification (`vae_channels * patch_size ** 2`). + num_layers (`int`, defaults to 28): + Number of transformer blocks. + attention_head_dim (`int`, defaults to 128): + Dimension of each attention head; the total hidden size is `attention_head_dim * num_attention_heads`. + num_attention_heads (`int`, defaults to 48): + Number of query heads. + num_key_value_heads (`int`, defaults to 12): + Number of key/value heads for grouped-query attention. + intermediate_size (`int`, defaults to 16384): + Feed-forward hidden size of the SwiGLU MLP inside each block. + timestep_embed_dim (`int`, defaults to 256): + Width of the sinusoidal timestep embedding before its MLP. + text_hidden_dim (`int`, defaults to 2560): + Hidden size of the text encoder whose hidden states are consumed. + num_text_layers (`int`, defaults to 12): + Number of tapped text-encoder hidden states stacked per token. + text_num_attention_heads (`int`, defaults to 20): + Number of query heads in the text fusion blocks. + text_num_key_value_heads (`int`, defaults to 20): + Number of key/value heads in the text fusion blocks. + text_intermediate_size (`int`, defaults to 6912): + Feed-forward hidden size of the SwiGLU MLP inside the text fusion blocks. + num_layerwise_text_blocks (`int`, defaults to 2): + Number of text fusion blocks applied across the tapped-layer axis (per token). + num_refiner_text_blocks (`int`, defaults to 2): + Number of text fusion blocks applied across the token sequence. 
+ axes_dims_rope (`tuple[int, int, int]`, defaults to `(32, 48, 48)`): + Head-dim split across the (t, h, w) rotary position axes. + rope_theta (`float`, defaults to 1000.0): + Base used by the rotary position embedding. + norm_eps (`float`, defaults to 1e-5): + Epsilon used by all RMSNorm modules. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["Krea2TransformerBlock", "Krea2TextFusionBlock", "Krea2FinalLayer"] + _repeated_blocks = ["Krea2TransformerBlock"] + _keep_in_fp32_modules = ["norm", "norm1", "norm2", "norm_q", "norm_k"] + _skip_layerwise_casting_patterns = ["time_embed", "norm"] + + @register_to_config + def __init__( + self, + in_channels: int = 64, + num_layers: int = 28, + attention_head_dim: int = 128, + num_attention_heads: int = 48, + num_key_value_heads: int = 12, + intermediate_size: int = 16384, + timestep_embed_dim: int = 256, + text_hidden_dim: int = 2560, + num_text_layers: int = 12, + text_num_attention_heads: int = 20, + text_num_key_value_heads: int = 20, + text_intermediate_size: int = 6912, + num_layerwise_text_blocks: int = 2, + num_refiner_text_blocks: int = 2, + axes_dims_rope: tuple[int, int, int] = (32, 48, 48), + rope_theta: float = 1000.0, + norm_eps: float = 1e-5, + ) -> None: + super().__init__() + + hidden_size = attention_head_dim * num_attention_heads + if sum(axes_dims_rope) != attention_head_dim: + raise ValueError(f"sum(axes_dims_rope)={sum(axes_dims_rope)} must equal attention_head_dim={attention_head_dim}") + + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.gradient_checkpointing = False + + self.img_in = nn.Linear(in_channels, hidden_size, bias=True) + self.time_embed = Krea2TimestepEmbedding(timestep_embed_dim, hidden_size) + self.time_mod_proj = nn.Linear(hidden_size, 6 * hidden_size, bias=True) + self.text_fusion = Krea2TextFusion( + num_text_layers=num_text_layers, + dim=text_hidden_dim, + num_heads=text_num_attention_heads, + 
num_kv_heads=text_num_key_value_heads, + intermediate_size=text_intermediate_size, + num_layerwise_blocks=num_layerwise_text_blocks, + num_refiner_blocks=num_refiner_text_blocks, + eps=norm_eps, + ) + self.txt_in = Krea2TextProjection(text_hidden_dim, hidden_size, eps=norm_eps) + self.rotary_emb = Krea2RotaryPosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope)) + + self.transformer_blocks = nn.ModuleList( + [ + Krea2TransformerBlock( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_heads=num_attention_heads, + num_kv_heads=num_key_value_heads, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + + self.final_layer = Krea2FinalLayer(hidden_size, out_channels=in_channels, eps=norm_eps) + + def fuse_qkv_projections(self, preferred_backend: str | None = None) -> None: + del preferred_backend + for module in self.modules(): + if isinstance(module, Krea2Attention): + module.fuse_projections() + + def unfuse_qkv_projections(self) -> None: + for module in self.modules(): + if isinstance(module, Krea2Attention): + module.unfuse_projections() + + @apply_lora_scale("attention_kwargs") + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + position_ids: torch.Tensor, + encoder_attention_mask: torch.Tensor | None = None, + attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> Transformer2DModelOutput | tuple[torch.Tensor]: + r""" + Predict the flow-matching velocity for the image tokens. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_seq_len, in_channels)`): + Packed (patchified) noisy image latents. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_seq_len, num_text_layers, text_hidden_dim)`): + Stack of tapped text-encoder hidden states per token. + timestep (`torch.Tensor` of shape `(batch_size,)`): + Flow-matching time in `[0, 1]` (1 is pure noise, 0 is clean data). 
+ position_ids (`torch.Tensor` of shape `(text_seq_len + image_seq_len, 3)`): + `(t, h, w)` rotary coordinates for the combined sequence. Text rows are all-zero; image rows hold the + latent-grid coordinates. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, text_seq_len)`, *optional*): + Boolean mask marking valid text tokens. Pass `None` when every text token is valid. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that, when it contains a `scale` entry, sets the LoRA scale applied to this + transformer's adapters for the duration of the forward pass. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple. + + Returns: + [`~models.modeling_outputs.Transformer2DModelOutput`] or a `tuple` whose first element is the velocity + tensor of shape `(batch_size, image_seq_len, in_channels)`. + """ + if position_ids.ndim != 2 or position_ids.shape[-1] != 3: + raise ValueError(f"`position_ids` must have shape (sequence_length, 3), got {tuple(position_ids.shape)}.") + + batch_size, image_seq_len, _ = hidden_states.shape + text_seq_len = encoder_hidden_states.shape[1] + + temb = self.time_embed(timestep, dtype=hidden_states.dtype) + temb_mod = self.time_mod_proj(F.gelu(temb, approximate="tanh")) + + text_attention_mask = None + attention_mask = None + if encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask.bool() + # Key-padding masks of shape (B, 1, 1, L): padded text tokens are excluded as attention keys everywhere; + # their own (garbage) lanes are never read back and are dropped at the output slice. 
+ text_attention_mask = encoder_attention_mask[:, None, None, :] + image_mask = encoder_attention_mask.new_ones((batch_size, image_seq_len)) + attention_mask = torch.cat([encoder_attention_mask, image_mask], dim=1)[:, None, None, :] + + encoder_hidden_states = self.text_fusion(encoder_hidden_states, attention_mask=text_attention_mask) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + hidden_states = self.img_in(hidden_states) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + image_rotary_emb = self.rotary_emb(position_ids) + + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, temb_mod, image_rotary_emb, attention_mask + ) + else: + hidden_states = block(hidden_states, temb_mod, image_rotary_emb, attention_mask) + + hidden_states = hidden_states[:, text_seq_len:] + output = self.final_layer(hidden_states, temb) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/simpletuner/helpers/models/model_metadata.json b/simpletuner/helpers/models/model_metadata.json index 95c9efcd5..a5fe1a97e 100644 --- a/simpletuner/helpers/models/model_metadata.json +++ b/simpletuner/helpers/models/model_metadata.json @@ -43,6 +43,16 @@ "klein-9b" ] }, + "krea2": { + "class_name": "Krea2", + "module_path": "simpletuner.helpers.models.krea2.model", + "name": "Krea 2", + "prediction_type": "flow_matching", + "flavour_choices": [ + "raw", + "turbo" + ] + }, "ltxvideo2": { "class_name": "LTXVideo2", "module_path": "simpletuner.helpers.models.ltxvideo2.model", diff --git a/simpletuner/simpletuner_sdk/server/services/field_registry/sections/logging_fields.py b/simpletuner/simpletuner_sdk/server/services/field_registry/sections/logging_fields.py index 7997ed5f3..74bcdb44f 100644 --- a/simpletuner/simpletuner_sdk/server/services/field_registry/sections/logging_fields.py +++ 
b/simpletuner/simpletuner_sdk/server/services/field_registry/sections/logging_fields.py @@ -660,6 +660,24 @@ def register_logging_fields(registry: "FieldRegistry") -> None: ) ) + registry._add_field( + ConfigField( + name="krea2_reference_latents", + arg_name="--krea2_reference_latents", + ui_label="Krea 2 Reference Latents", + field_type=FieldType.CHECKBOX, + tab="model", + section="model_specific", + model_specific=["krea2"], + default_value=False, + help_text="Enable Krea 2 reference-dataset training with image-context prompt embeds and clean reference latents.", + tooltip="When enabled, Krea 2 requires paired conditioning data, encodes prompts with the reference image through Qwen3VL, and appends the clean reference latents to the transformer input.", + importance=ImportanceLevel.ADVANCED, + order=36, + documentation="OPTIONS.md#--krea2_reference_latents", + ) + ) + # Offload Parameter Path registry._add_field( ConfigField( diff --git a/tests/test_krea2_model.py b/tests/test_krea2_model.py new file mode 100644 index 000000000..11eadf486 --- /dev/null +++ b/tests/test_krea2_model.py @@ -0,0 +1,210 @@ +import inspect +import unittest +from types import SimpleNamespace + +import torch + +from simpletuner.helpers.models.common import TextEmbedCacheKey +from simpletuner.helpers.models.krea2 import Krea2, Krea2LoraLoaderMixin, Krea2Pipeline, Krea2Transformer2DModel +from simpletuner.helpers.models.krea2.transformer import Krea2Attention, _krea2_rope_freqs_dtype +from simpletuner.helpers.models.registry import ModelRegistry + + +class FakeKrea2Transformer: + def __init__(self): + self.config = SimpleNamespace(patch_size=2) + self.last_call = None + + def __call__(self, **kwargs): + self.last_call = kwargs + return (torch.zeros_like(kwargs["hidden_states"]),) + + +class FakeLatentDistribution: + def __init__(self, latents): + self.latents = latents + + def sample(self): + return self.latents + + +class Krea2VendoredModelTests(unittest.TestCase): + def 
test_model_registry_resolves_krea2(self):
+        registry_entry = ModelRegistry.get("krea2")
+        model_class = registry_entry.get_real_class() if hasattr(registry_entry, "get_real_class") else registry_entry
+        self.assertIs(model_class, Krea2)
+
+    def test_krea2_reference_latents_config_field_is_parseable(self):
+        from simpletuner.helpers.configuration.cmd_args import get_argument_parser
+
+        parser = get_argument_parser()
+        args = parser.parse_args(
+            [
+                "--model_family",
+                "krea2",
+                "--output_dir",
+                "/tmp/simpletuner-test",
+                "--model_type",
+                "lora",
+                "--optimizer",
+                "adamw_bf16",
+                "--data_backend_config",
+                "/tmp/backend.json",
+                "--krea2_reference_latents",
+                "true",
+            ]
+        )
+
+        self.assertTrue(args.krea2_reference_latents)
+
+    def test_model_components_are_local_simpletuner_classes(self):
+        self.assertEqual(Krea2Pipeline.__module__, "simpletuner.helpers.models.krea2.pipeline")
+        self.assertEqual(Krea2Transformer2DModel.__module__, "simpletuner.helpers.models.krea2.transformer")
+        self.assertEqual(Krea2LoraLoaderMixin.__module__, "simpletuner.helpers.models.krea2.lora_pipeline")
+        self.assertEqual(Krea2.PROCESSOR_PATH, "Qwen/Qwen3-VL-4B-Instruct")
+        self.assertIsNone(Krea2.PROCESSOR_SUBFOLDER)
+
+    def test_rope_frequency_dtype_matches_pypi_diffusers_device_support(self):
+        self.assertEqual(_krea2_rope_freqs_dtype(torch.device("cpu")), torch.float64)
+        self.assertEqual(_krea2_rope_freqs_dtype(torch.device("cuda")), torch.float64)
+        self.assertEqual(_krea2_rope_freqs_dtype(torch.device("mps")), torch.float32)
+
+    def test_lora_loader_targets_transformer(self):
+        self.assertEqual(Krea2LoraLoaderMixin._lora_loadable_modules, ["transformer"])
+        self.assertEqual(Krea2LoraLoaderMixin.transformer_name, "transformer")
+
+    def test_lora_save_accepts_transformer_metadata(self):
+        parameters = inspect.signature(Krea2LoraLoaderMixin.save_lora_weights).parameters
+        self.assertIn("transformer_lora_adapter_metadata", parameters)
+
+    def test_pipeline_accepts_reference_image_for_validation(self):
+        parameters = inspect.signature(Krea2Pipeline.__call__).parameters
+        self.assertIn("reference_image", parameters)
+
+    def test_reference_latents_enable_reference_dataset_hooks(self):
+        model = Krea2.__new__(Krea2)
+        model.config = SimpleNamespace(krea2_reference_latents=True)
+
+        self.assertTrue(model.supports_conditioning_dataset())
+        self.assertTrue(model.requires_conditioning_dataset())
+        self.assertTrue(model.requires_conditioning_validation_inputs())
+        self.assertTrue(model.requires_validation_edit_captions())
+        self.assertTrue(model.requires_text_embed_image_context())
+        self.assertTrue(model.requires_conditioning_latents())
+        self.assertFalse(model.should_precompute_validation_negative_prompt())
+        self.assertEqual(model.text_embed_cache_key(), TextEmbedCacheKey.DATASET_AND_FILENAME)
+
+    def test_reference_latents_maps_validation_image_to_reference_image(self):
+        model = Krea2.__new__(Krea2)
+        model.config = SimpleNamespace(krea2_reference_latents=True)
+
+        kwargs = model.update_pipeline_call_kwargs({"image": "reference"})
+
+        self.assertEqual(kwargs, {"reference_image": "reference"})
+
+    def test_reference_latents_disabled_keeps_text_to_image_hooks(self):
+        model = Krea2.__new__(Krea2)
+        model.config = SimpleNamespace(krea2_reference_latents=False, control=False, controlnet=False)
+
+        self.assertTrue(model.supports_conditioning_dataset())
+        self.assertFalse(model.requires_conditioning_dataset())
+        self.assertFalse(model.requires_conditioning_validation_inputs())
+        self.assertFalse(model.requires_validation_edit_captions())
+        self.assertFalse(model.requires_text_embed_image_context())
+        self.assertFalse(model.requires_conditioning_latents())
+        self.assertTrue(model.should_precompute_validation_negative_prompt())
+        self.assertEqual(model.text_embed_cache_key(), TextEmbedCacheKey.CAPTION)
+
+    def test_fused_qkv_lora_targets_use_fused_projection(self):
+        model = Krea2.__new__(Krea2)
+        model.config = SimpleNamespace(fuse_qkv_projections=True)
+
+        self.assertEqual(model.get_lora_target_layers(), ["to_qkv", "to_out.0"])
+
+    def test_krea2_attention_fused_projection_matches_unfused_path(self):
+        attention = Krea2Attention(hidden_size=8, num_heads=2, num_kv_heads=1)
+        hidden_states = torch.randn(2, 5, 8)
+
+        unfused = attention(hidden_states)
+        attention.fuse_projections()
+        fused = attention(hidden_states)
+
+        self.assertTrue(torch.allclose(fused, unfused, atol=1e-6, rtol=1e-6))
+        self.assertTrue(hasattr(attention, "to_qkv"))
+
+        attention.unfuse_projections()
+        self.assertFalse(hasattr(attention, "to_qkv"))
+        self.assertFalse(attention.fused_projections)
+
+    def test_vae_encode_hooks_use_qwen_image_vae_rank_and_normalization(self):
+        model = Krea2.__new__(Krea2)
+        vae = SimpleNamespace(config=SimpleNamespace(latents_mean=[1.0, 2.0], latents_std=[2.0, 4.0], z_dim=2))
+        model.get_vae = lambda: vae
+
+        image_batch = torch.zeros(1, 3, 64, 64)
+        preprocessed = model.pre_vae_encode_transform_sample(image_batch)
+        self.assertEqual(tuple(preprocessed.shape), (1, 3, 1, 64, 64))
+
+        vae_output = SimpleNamespace(latent_dist=FakeLatentDistribution(torch.ones(1, 2, 1, 8, 8) * 5.0))
+        latents = model.post_vae_encode_transform_sample(vae_output)
+
+        self.assertEqual(tuple(latents.shape), (1, 2, 8, 8))
+        self.assertTrue(torch.allclose(latents[:, 0], torch.full((1, 8, 8), 2.0)))
+        self.assertTrue(torch.allclose(latents[:, 1], torch.full((1, 8, 8), 0.75)))
+
+    def test_transformer_accepts_cached_int64_encoder_attention_mask(self):
+        transformer = Krea2Transformer2DModel(
+            in_channels=4,
+            num_layers=1,
+            attention_head_dim=6,
+            num_attention_heads=1,
+            num_key_value_heads=1,
+            intermediate_size=8,
+            timestep_embed_dim=8,
+            text_hidden_dim=6,
+            num_text_layers=2,
+            text_num_attention_heads=1,
+            text_num_key_value_heads=1,
+            text_intermediate_size=8,
+            num_layerwise_text_blocks=1,
+            num_refiner_text_blocks=1,
+            axes_dims_rope=(2, 2, 2),
+        )
+
+        output = transformer(
+            hidden_states=torch.randn(1, 4, 4),
+            encoder_hidden_states=torch.randn(1, 3, 2, 6),
+            timestep=torch.tensor([0.5]),
+            position_ids=torch.zeros(7, 3, dtype=torch.long),
+            encoder_attention_mask=torch.ones(1, 3, dtype=torch.int64),
+            return_dict=False,
+        )
+
+        self.assertEqual(tuple(output[0].shape), (1, 4, 4))
+
+    def test_model_predict_appends_reference_latent_tokens_and_crops_output(self):
+        model = Krea2.__new__(Krea2)
+        model.config = SimpleNamespace(krea2_reference_latents=True, weight_dtype=torch.float32)
+        model.accelerator = SimpleNamespace(device=torch.device("cpu"))
+        model.model = FakeKrea2Transformer()
+        model.unwrap_model = lambda model: model
+
+        batch = {
+            "noisy_latents": torch.ones(1, 16, 4, 4),
+            "latents": torch.zeros(1, 16, 4, 4),
+            "conditioning_latents": torch.full((1, 16, 4, 4), 2.0),
+            "timesteps": torch.tensor([500.0]),
+            "prompt_embeds": torch.zeros(1, 3, 2, 4),
+            "encoder_attention_mask": torch.ones(1, 3, dtype=torch.int64),
+        }
+
+        result = model.model_predict(batch)
+
+        self.assertEqual(tuple(result["model_prediction"].shape), (1, 16, 4, 4))
+        self.assertEqual(tuple(model.model.last_call["hidden_states"].shape), (1, 8, 64))
+        self.assertEqual(tuple(model.model.last_call["position_ids"].shape), (11, 3))
+        self.assertTrue(torch.equal(model.model.last_call["timestep"], torch.tensor([0.5])))
+
+
+if __name__ == "__main__":
+    unittest.main()