diff --git a/.gitignore b/.gitignore index 3e6aee68..64808557 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,9 @@ cython_debug/ *.ckpt *.wav -wandb/* \ No newline at end of file +wandb/* + +# Dataset folders and outputs +/rawfiles +/outputs +/pre_encoded diff --git a/TRAINING_FOR_NOOBS.md b/TRAINING_FOR_NOOBS.md new file mode 100644 index 00000000..71a4b379 --- /dev/null +++ b/TRAINING_FOR_NOOBS.md @@ -0,0 +1,123 @@ +# How to train your own finetune of Stable Audio Open 1.0 or Stable Audio Open Small +I did all of this on a ROG Z13 ACRNM with a mobile 4070 GPU with 8 GiB of VRAM, or an Eluktronics Mech-17 GP2 with a mobile 4090 GPU with 16 GiB of VRAM. I did everything with Python 3.11.9 and CUDA 12.9 on Windows 11 Pro. Ymmv. + +## Raw files for your training dataset +* `/rawfiles` is what the example scripts and config files will use for finding raw audio files to pre-encode and use for training the model. You can change this and use another directory if you wish. +* Everything will automatically be converted to a 44100 Hz sampling rate using `torchaudio`. You can perform your own SRC (sample rate conversion) ahead of time if you wish. +* The SAO-small model can handle up to about 11.89 seconds of audio. If your files are longer than this, they will be implicitly truncated. If your files are shorter than this, they will be padded to fit. +* The full SAO-1.0 model can handle up to about 47.55 seconds of audio. If your files are longer than this, they will be implicitly truncated. If your files are shorter than this, they will be padded to fit. +* A good dataset will have at least ~ 5000 files. Less than this, and the model is likely to overfit even with a low number of training steps. More examples to show to the model = the better the results. +* A good dataset will have lots of different sounds. You can intentionally overfit by using only a few sounds, but this will make the model pretty much useless at doing anything other than spitting out the few sounds you fed it, with very little "variation" or "creativity". +* You can feed the model a ton of examples of something specific, like 808 bass one-shots, set all of the conditioning prompts for all of the files in your dataset to "808 bass one-shot" instead of using more complex logic in your custom metadata function(s) to create a distinct prompt for each audio file, and the model will learn many examples of what an "808 bass one-shot" is. +* It's better to feed the model "too many examples" and then refine things later, than it is to feed the model a tiny dataset and end up overfitting. + +I strongly recommend using a lossless file format, such as `.WAV` or `.AIFF`, because lossy formats like `.MP3` change the audio in ways that can be disastrous for the audio quality. + +Case in point: + +image + +### Sample size, sample rate, latent length, audio file duration in seconds, etc. +* These models use a 44100 Hz `sample_rate`. Do not mess with this setting! +* `sample_size` is a confusing misnomer. It would be more accurate if it had been named `audio_input_length_in_samples` or `segment_size_in_samples` or something similar, because that's what it actually is. `sample_size` / 44100 = how long the audio inputs and outputs will be in seconds. +* "Latent" = what your audio gets turned into before the model trains on the data. +* `downsampling_ratio` = the ratio by which your raw audio inputs will be downsampled when they are converted into latents during pre-encoding. This value is 2048 for both SAO-1.0 and SAO-small. 
Do not mess with this setting! + +#### Handy math examples +Latent length of 64 * downsampling ratio of 2048 = 131072 `sample_size` + +131072 `sample_size` / `sample_rate` of 44100 = 2.97 seconds + +"Segment size" = `sample_size` = latent length * `downsampling_ratio`. + +Ergo the SAO-1.0 `model.ckpt` uses a `sample_size` of 2097152, since the model was trained using a latent length of 1024 and a `downsampling_ratio` of 2048. + +Thus we arrive at 2097152 samples / 44100 samples per second = 47.55 seconds of audio. + +Same values and math for SAO-small, but with a latent length of 256 instead of 1024, ergo 11.89 seconds of audio. + +All pre-encoded latents derived from raw audio files will be silence-padded using a mask in order to fit the appropriate latent length for a given model during pre-encoding. + +Ergo there is no point in using `latent_crop_length` when pre-encoding raw audio files which have a length in samples which is less than the model's native segment size, e.g., 2097152 with SAO-1.0 and 524288 with SAO-small. + +`latent_crop_length` can be used to set the pre-encoded latent sizes to a consistent size. To give a practical example, you could pre-encode your data with a `sample_size` of 2097152 (SAO 1.0 length), then have two separate pre-encoded dataset configs with different `latent_crop_length` (1024 for SAO 1.0, 256 for SAO Small), both reading from the same pre-encoded directory. + +## Pre-encode the latents based on your raw audio files +Technically this is optional, but there is no reason not to pre-encode the latents. +* `pe_dataset_config.json` + * "PE" stands for "pre-encoding". This file contains instructions for the `DataLoader` to read your raw audio files, such as `.WAV` or `.AIFF` files, which will be used to pre-encode the latents. +* `paths_md_pre_encode.py` + * "PE" stands for "pre-encoding". This script provides the model with conditioning parameters during training. The only conditioning you need to handle is the `prompt`. The other one or two conditioning parameters, `seconds_start` and `seconds_total`, will be determined by other settings. +* `pre-encode.bat` + * You can skip over this file and run the `pre_encode.py` command manually with your own settings if you wish. + * If you run out of memory, lower the `batch_size`. If you still get OOM errors even with the minimum `batch_size` of 1, you probably need to buy a GPU with more VRAM, or you need to run this stuff on a remote hosting platform such as AWS EC2, RunPod, or Google Colab. + * The example settings should work fine on a GPU with 8 GiB of VRAM. + +## Configure your dataset +* `dataset_config.json` +* `paths_md.py` + +## Configure the model +* `/sao_small/base_model_config.json` + * Use this model config for the SAO-small model. +* `model_config.json` + * Use this model config for the full SAO-1.0 model. + +## Train +`train.bat` + +You can skip over this file and run the `training.py` command manually with your own settings if you wish. + +If you run out of memory, lower the `batch_size`. If you still get OOM errors even with the minimum `batch_size` of 1, you probably need to buy a GPU with more VRAM, or you need to run this stuff on a remote hosting platform such as AWS EC2, RunPod, or Google Colab. + +The example settings for SAO-small should work fine on a GPU with at least 8 GiB of VRAM. + +If you want to train the full-sized SAO-1.0 model, you will need at least 24 GiB of VRAM. + +### When is training done? +Whenever you feel like stopping. 
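+
+If you want a rough sense of how many Steps "long enough" might be before you commit to a run, the short back-of-the-envelope sketch below may help. It is only an illustration, not part of the repo: the helper names are made up for this example, the file count, Batch Size, and Epoch count are placeholder values rather than recommendations, and all the script does is repeat the arithmetic from the "Handy math examples" section above and the Terminology section below.
+
+```python
+# Planning arithmetic only -- nothing here touches the training code.
+SAMPLE_RATE = 44100          # both SAO models use 44100 Hz
+DOWNSAMPLING_RATIO = 2048    # both SAO models use a ratio of 2048
+
+def seconds_for_latent_length(latent_length: int) -> float:
+    """Latent length -> sample_size -> seconds, as in the Handy math examples above."""
+    sample_size = latent_length * DOWNSAMPLING_RATIO
+    return sample_size / SAMPLE_RATE
+
+def steps_per_epoch(num_files: int, batch_size: int, grad_accum: int = 1) -> int:
+    """How many optimizer Steps one pass over the whole dataset takes."""
+    return num_files // (batch_size * grad_accum)
+
+print(f"SAO-small window: {seconds_for_latent_length(256):.2f} s")   # ~11.89 s
+print(f"SAO-1.0 window:   {seconds_for_latent_length(1024):.2f} s")  # ~47.55 s
+
+# Hypothetical run: 1280 files, Batch Size 8, 10 Epochs.
+steps = steps_per_epoch(num_files=1280, batch_size=8)
+print(f"{steps} Steps per Epoch, {steps * 10} Steps for 10 Epochs")  # 160 and 1600
+```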
+
+Listen to the demos and decide when it sounds like the model has learned enough about your dataset, then kill the training process with Ctrl-C.
+
+If you have a small number of files in the dataset, like only 100 .WAV files, then you will probably start to overfit after about 2000 steps. "Overfitting" means that the model is getting to a point where it will basically be hyper-optimized for recreating the exact audio you used for training whenever you generate new outputs during inference using the same or similar prompts that you used during training.
+
+## WARNING: DO NOT MODIFY OR USE THESE FILES
+Do not mess with these files:
+* `/sao_small/model_config.json`
+    * This is the config for the ARC post-trained `model.ckpt` of SAO-small, which you should not attempt to train.
+* `/vae_model_config.json`
+    * This is the config for the VAE model (the auto-encoder), which you should not mess with unless you know exactly what you are doing and why you are doing it.
+
+# Terminology for noobs
+* Epoch = one pass over all files in the dataset. If you have 1280 files in the dataset, it will take 1 Epoch to "show all 1280 files to the model".
+* Batch = one chunk of the dataset. If your Batch Size is 32, it will take 1280 / 32 = 40 Steps to complete 1 Epoch.
+* Step = one iteration of the training process, in which 1 Batch of latents (derived from your training dataset files) is "shown to the model so it can learn from them". If you have 1280 files in your dataset, and you use a Batch Size of 8, it will take 1280 / 8 = 160 Steps to complete 1 Epoch.
+* Gradient Accumulation = increases the effective Batch Size when your hardware can't handle a larger Batch Size. Effective Batch Size = Gradient Accumulation * Batch Size. Gradient Accumulation 4 * Batch Size 8 = Effective Batch Size of 32. Instead of actually "showing 32 latents or files to the model" at once, you end up "showing 8 latents or files to the model" 4 times. This results in lower VRAM usage, but longer training times. It's usually better to use the largest Batch Size you can without running out of VRAM, and to skip Gradient Accumulation unless you have no other option.
+* Learning Rate = how much the model learns from each Batch (in each Step), as a function of time. The Learning Rate can be constant, or it can change over time. It is usually a small value between 0 and 1, with 0 meaning "learn nothing" and 1 meaning "study what you are exposed to in 100% depth, and let this experience influence you to the utmost".
+    * Learning rate too low = takes longer to train, and the model seems to not have learned anything (underfits).
+    * Learning rate too high = takes far less time to train than you probably expected, and the model probably overfits within 1 Epoch.
+    * Small dataset = try a larger value for Learning Rate, such as 1e-2 (0.01). Not many examples to learn from, but you learn a lot from each example.
+    * Large dataset = try a smaller value for Learning Rate, such as 5e-4 (0.0005). Learn just a bit from each example, but have a lot of examples.
+* Weight decay = a regularization setting that shrinks the model's weights a tiny bit at every Step, which helps keep the model from overfitting. Despite the similar name, it is not the Learning Rate decaying over time; changing the Learning Rate over time is the Scheduler's job (see below). Typical values are small, like 0.01 or 0.001.
+* Optimizers = algorithms that decide exactly how the model's weights get updated at each Step, based on the gradients from that Batch. Usually combined with a Learning Rate Scheduler. A typical choice is `AdamW`. `AdamW8bit` is a viable option for saving VRAM.
Many other options exist, like Lion and Prodigy, but you should stick with `AdamW` or another `Adam` derivative unless you want to navigate unexplored territory and perform experiments. +* Learning Rate Schedulers = algorithms for changing the Learning Rate over time. https://machinelearningmastery.com/a-gentle-introduction-to-learning-rate-schedulers/ NOT TO BE CONFUSED WITH NOISE SCHEDULERS, AKA SAMPLERS! Stable Audio Open uses a custom `InverseLR` Scheduler. Another good option is `CosineAnnealing`. +* Noise Schedulers, aka Samplers = algorithms for adding or removing noise: https://huggingface.co/docs/diffusers/en/api/schedulers/overview https://civitai.com/articles/7484/understanding-stable-diffusion-samplers-beyond-image-comparisons + +## What values should I use? +* You should train for at least 1 Epoch, or else the model won't "see" all of your dataset. + * Too many Epochs (and similarly, too many total Steps) = the model is likely to overfit. Imagine someone going to "normal" school up to the 4th grade, and then being sent to a specialized school where they only learned about how to play modern jazz trumpet: they'd probably not be very good at many "normal" tasks, while excelling at modern jazz trumpet, and they'd be likely to interpret everything they experienced after graduation in the context of modern jazz trumpet. + * Too few Epochs (and similarly, too few total Steps) = the model is likely to underfit. Imagine someone going to "normal" school up to the 4th grade, and then being sent to a specialized school where they only learned about how to play modern jazz trumpet, but then you pull them out of school after one week: they'd probably not suffer from "forgetting" everything from "normal" school, but they'd also have learned so little about modern jazz trumpet that they might not be much better than their peers who never studied modern jazz trumpet. +* You should use the largest Batch Size you can fit into VRAM, as a general rule. + * Try to not use extremely small Batch Size values, such as 1, because the model is more likely to learn well from Batch Sizes of about 8 to 32. + * Try to not use extremely large Batch Size values, such as 512, because the model is more likely to learn well from Batch Sizes of about 8 to 32. + * Try to use only Batch Size, and to not use Gradient Accumulation, whenever feasible. +* Some Optimizer + Scheduler combinations can figure out the appropriate Learning Rate for you. Even better: some Optimizer + Scheduler combinations can figure out the appropriate Learning Rate and the best way to adjust the Learning Rate over time, so you don't have a constant Learning Rate. + +### I NEED SPECIFIC MAGICAL NUMBERS!!! +Training an AI/ML model is as much of an art as it is a science. Each scenario is unique. You will have to experiment in order to figure out whether training SAO-small on 500 drum one-shots for 2 Epochs with a Batch Size of 8 and a Learning Rate of 5e-3 (0.005) produces better results than training SAO-small on the same 500 drum one-shots for 20 Epochs with a Batch Size of 32 and a Learning Rate of 1e-5 (0.00001). + +# HELP!!! +* Static-y whine or drone = you probably used an unwrapped model instead of a wrapped one, or vice versa; or you used `--pretrained-ckpt-path` instead of `--ckpt-path`, or vice versa. 
+* If you need a pre-compiled wheel for `flash-attention`, I gotchu fam: https://github.com/sskalnik/flash_attn_wheels +* `RuntimeError: Given groups=1, weight of size [128, 2, 7], expected input[1, 64, 645] to have 2 channels, but got 64 channels instead` = you need to make sure `pre_encoded` is set to `True` in the model config JSON file you're using for training. +* `UserWarning: At least one mel filterbank has all zero values. The value for n_mels (128) may be set too high. Or, the value for n_freqs (513) may be set too low.` = You can ignore this. diff --git a/acid_dataset_config.json b/acid_dataset_config.json new file mode 100644 index 00000000..9eabf5df --- /dev/null +++ b/acid_dataset_config.json @@ -0,0 +1,8 @@ +{ + "dataset_type": "pre_encoded", + "datasets": [{ + "id": "audio_pre_encoded", + "path": "pre_encoded", + "custom_metadata_module": "acid_paths_md.py" + }] +} diff --git a/acid_paths_md.py b/acid_paths_md.py new file mode 100644 index 00000000..9a986587 --- /dev/null +++ b/acid_paths_md.py @@ -0,0 +1,58 @@ +import os +import re + + +def get_custom_metadata(info, audio): + # Get filename without extension + file_name = os.path.basename(info["relpath"]) + file_name_without_extension = os.path.splitext(file_name)[0] + + # Replace non-alphanumeric characters with spaces, and remove leading/trailing spaces + #cleaned_file_name = re.sub('[^0-9a-zA-Z]+', ' ', file_name_without_extension).strip() + #cleaned_file_name = re.match('', cleaned_file_name).groups()[0] + + # Get parent directory name (without the full path) + dir_name = os.path.dirname(info["relpath"]) + prompt = os.path.split(dir_name)[1] + + # Use the filename instead of parent directory if the filename has relevant info + if 'BPM' in prompt: + prompt = file_name_without_extension + + # Translate X beats of Y notes per bar from XbYn to "normal" time signature notation + # 4b4n = 4/4 + # 3b16n = 3/16 + # 69b420n = 69/420 + prompt = re.sub(r'(\d+)b(\d+)n', r'\1/\2', prompt) + + # Instrument123 => Instrument + # Acid1 => Acid + prompt = re.sub(r'^(\w+)\d+', r'\1', prompt, count=1) + + # Acid DistSplinterFat 120BPM 4/4 4bars + # Acid Distorted 120BPM 4/4 4bars + prompt = re.sub(r'Dist\w+ ', r' Distorted ', prompt, count=1) + + # Acid Distorted 120BPM 4/4 4bars + # Acid Distorted 120 BPM 4/4 4bars + prompt = re.sub(r'(\d+)BPM', r'\1 BPM', prompt, count=1) + + # Am = A minor + # G#m = G# minor + prompt = re.sub(r'( [ABCDEFG][#♭♮♯]?)m ', r'\1 minor ', prompt) + # AMajor = A Major + # F#Phrygian = F# Phrygian + prompt = re.sub(r'(^\w+ [ABCDEFG][#♭♮♯]?)([a-zA-Z]+)', r'\1 \2', prompt, count=1) + # TODO: obviate this hack + prompt = re.sub(r' D istorted ', r' Distorted ', prompt, count=1) + + # 4bars = 4 bars + prompt = re.sub(r'(\d+)bars', r'\1 bars', prompt) + + # Remove (1), (2), etc. 
+ prompt = re.sub(r'\(\d+\)$', r'', prompt) + + # Sanity check + print(f'{info["relpath"]} => {prompt}') + + return {"prompt": prompt} diff --git a/acid_paths_md_pre_encode.py b/acid_paths_md_pre_encode.py new file mode 100644 index 00000000..9a986587 --- /dev/null +++ b/acid_paths_md_pre_encode.py @@ -0,0 +1,58 @@ +import os +import re + + +def get_custom_metadata(info, audio): + # Get filename without extension + file_name = os.path.basename(info["relpath"]) + file_name_without_extension = os.path.splitext(file_name)[0] + + # Replace non-alphanumeric characters with spaces, and remove leading/trailing spaces + #cleaned_file_name = re.sub('[^0-9a-zA-Z]+', ' ', file_name_without_extension).strip() + #cleaned_file_name = re.match('', cleaned_file_name).groups()[0] + + # Get parent directory name (without the full path) + dir_name = os.path.dirname(info["relpath"]) + prompt = os.path.split(dir_name)[1] + + # Use the filename instead of parent directory if the filename has relevant info + if 'BPM' in prompt: + prompt = file_name_without_extension + + # Translate X beats of Y notes per bar from XbYn to "normal" time signature notation + # 4b4n = 4/4 + # 3b16n = 3/16 + # 69b420n = 69/420 + prompt = re.sub(r'(\d+)b(\d+)n', r'\1/\2', prompt) + + # Instrument123 => Instrument + # Acid1 => Acid + prompt = re.sub(r'^(\w+)\d+', r'\1', prompt, count=1) + + # Acid DistSplinterFat 120BPM 4/4 4bars + # Acid Distorted 120BPM 4/4 4bars + prompt = re.sub(r'Dist\w+ ', r' Distorted ', prompt, count=1) + + # Acid Distorted 120BPM 4/4 4bars + # Acid Distorted 120 BPM 4/4 4bars + prompt = re.sub(r'(\d+)BPM', r'\1 BPM', prompt, count=1) + + # Am = A minor + # G#m = G# minor + prompt = re.sub(r'( [ABCDEFG][#♭♮♯]?)m ', r'\1 minor ', prompt) + # AMajor = A Major + # F#Phrygian = F# Phrygian + prompt = re.sub(r'(^\w+ [ABCDEFG][#♭♮♯]?)([a-zA-Z]+)', r'\1 \2', prompt, count=1) + # TODO: obviate this hack + prompt = re.sub(r' D istorted ', r' Distorted ', prompt, count=1) + + # 4bars = 4 bars + prompt = re.sub(r'(\d+)bars', r'\1 bars', prompt) + + # Remove (1), (2), etc. 
+ prompt = re.sub(r'\(\d+\)$', r'', prompt) + + # Sanity check + print(f'{info["relpath"]} => {prompt}') + + return {"prompt": prompt} diff --git a/acid_pe_dataset_config.json b/acid_pe_dataset_config.json new file mode 100644 index 00000000..ecfa420f --- /dev/null +++ b/acid_pe_dataset_config.json @@ -0,0 +1,11 @@ +{ + "dataset_type": "audio_dir", + "datasets": [{ + "id": "audio", + "path": "./rawfiles", + "custom_metadata_module": "./acid_paths_md_pre_encode.py", + "drop_last": false + }], + "drop_last": false, + "random_crop": false +} diff --git a/dataset_config.json b/dataset_config.json new file mode 100644 index 00000000..e402ec9e --- /dev/null +++ b/dataset_config.json @@ -0,0 +1,8 @@ +{ + "dataset_type": "pre_encoded", + "datasets": [{ + "id": "audio_pre_encoded", + "path": "pre_encoded", + "custom_metadata_module": "paths_md.py" + }] +} diff --git a/model_config.json b/model_config.json new file mode 100644 index 00000000..a8c3b090 --- /dev/null +++ b/model_config.json @@ -0,0 +1,123 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 131072, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "prompt", + "type": "t5", + "config": { + "t5_model_name": "t5-base", + "max_length": 128 + } + }, + { + "id": "seconds_start", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + }, + { + "id": "seconds_total", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], + "global_cond_ids": ["seconds_start", "seconds_total"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "global_cond_dim": 1536, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + "training": { + "pre_encoded": true, + "use_ema": true, + "log_loss_info": false, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 50, + "demo_steps": 50, + "num_demos": 1, + "demo_cond": [ + {"prompt": "Amen break 174 BPM", "seconds_start": 0, "seconds_total": 16} + ], + "demo_cfg_scales": [1, 2, 4, 8] + } + } +} diff --git a/paths_md.py b/paths_md.py new file mode 100644 index 00000000..5ead4a86 --- /dev/null +++ b/paths_md.py @@ -0,0 +1,17 @@ +import os +import re + + +def get_custom_metadata(info, audio): + # Get filename without extension + file_name = os.path.basename(info["relpath"]) + file_name_without_extension = os.path.splitext(file_name)[0] + + # Replace non-alphanumeric 
characters with spaces, and remove leading/trailing spaces + cleaned_file_name = re.sub('[^0-9a-zA-Z]+', ' ', file_name_without_extension).strip() + #cleaned_file_name = re.match('', cleaned_file_name).groups()[0] + + # Sanity check + print(f'{info["relpath"]} => {cleaned_file_name}') + + return {"prompt": cleaned_file_name} diff --git a/paths_md_pre_encode.py b/paths_md_pre_encode.py new file mode 100644 index 00000000..5ead4a86 --- /dev/null +++ b/paths_md_pre_encode.py @@ -0,0 +1,17 @@ +import os +import re + + +def get_custom_metadata(info, audio): + # Get filename without extension + file_name = os.path.basename(info["relpath"]) + file_name_without_extension = os.path.splitext(file_name)[0] + + # Replace non-alphanumeric characters with spaces, and remove leading/trailing spaces + cleaned_file_name = re.sub('[^0-9a-zA-Z]+', ' ', file_name_without_extension).strip() + #cleaned_file_name = re.match('', cleaned_file_name).groups()[0] + + # Sanity check + print(f'{info["relpath"]} => {cleaned_file_name}') + + return {"prompt": cleaned_file_name} diff --git a/pe_dataset_config.json b/pe_dataset_config.json new file mode 100644 index 00000000..9e6194ae --- /dev/null +++ b/pe_dataset_config.json @@ -0,0 +1,11 @@ +{ + "dataset_type": "audio_dir", + "datasets": [{ + "id": "audio", + "path": "./rawfiles", + "custom_metadata_module": "./paths_md_pre_encode.py", + "drop_last": false + }], + "drop_last": false, + "random_crop": false +} diff --git a/pre-encode.bat b/pre-encode.bat new file mode 100644 index 00000000..ef843f7b --- /dev/null +++ b/pre-encode.bat @@ -0,0 +1,9 @@ +python ./pre_encode.py ^ + --ckpt-path ./vae_model.ckpt ^ + --model-config ./vae_model_config.json ^ + --batch-size 4 ^ + --dataset-config am_pe_dataset_config.json ^ + --output-path ./pre_encoded ^ + --model-half ^ + --sample-size 131072 ^ + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..a63a05bd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +aiohttp==3.12.15 +altair==5.5.0 +# Required for AdamW8bit optimizer: +bitsandbytes==0.47.0 +Brotli==1.1.0 +# Only needed if compiling source code locally: +# cmake==4.1.0 +# Resolves pickling issues on Windows 11, possibly other platforms as well: +dill==0.4.0 +# If on Windows, use this for Flash Attention 2.8.0.post2 with CUDA 12.9 and Pytorch 2.8.0 with Python 3.11 on Windows 11: +# Required as of 2025 Sept 20 because no pre-compiled official wheels are available. 
+# flash-attn @ https://github.com/sskalnik/flash_attn_wheels/blob/main/flash_attn-2.8.0.post2%2Bcu129torch2.8.0cxx11abiTRUE-cp311-cp311-win_amd64.whl +# Only required for development: +# pipdeptree==2.28.0 +pytorch_optimizer==3.1.2 +stable-audio-tools @ https://github.com/sskalnik/stable-audio-tools.git +# Required for StatefulDataLoader: +torchdata==0.11.0 +typing-inspection==0.4.1 +wheel==0.45.1 diff --git a/sao_small/acid_v3_base_model_config.json b/sao_small/acid_v3_base_model_config.json new file mode 100644 index 00000000..beb6a77b --- /dev/null +++ b/sao_small/acid_v3_base_model_config.json @@ -0,0 +1,134 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 354304, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": false, + "model_half": true, + "config": { + "encoder": { + "type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "prompt", + "type": "t5", + "config": { + "t5_model_name": "google/t5gemma-b-b-ul2", + "max_length": 128 + } + }, + { + "id": "seconds_total", + "type": "number", + "config": { + "min_val": 0, + "max_val": 256 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["prompt", "seconds_total"], + "global_cond_ids": ["seconds_total"], + "diffusion_objective": "rectified_flow", + "distribution_shift_options": { + "min_length": 256, + "max_length": 4096 + }, + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1024, + "depth": 16, + "num_heads": 8, + "cond_token_dim": 768, + "global_cond_dim": 768, + "transformer_type": "continuous_transformer", + "attn_kwargs": { + "qk_norm": "ln" + } + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "pre_encoded": true, + "timestep_sampler": "trunc_logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW8bit", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "eps": 1e-8, + "weight_decay": 0.01, + "percentile_clipping": 100, + "block_wise": true + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.995 + } + } + } + }, + "demo": { + "demo_every": 768, + "demo_steps": 100, + "num_demos": 8, + "demo_cond": [ + {"prompt": "acid lead in F# minor, extremely fast tempo of 200 BPM, 3/4 time signature, 4-bar loop, analog distortion", "seconds_total": 8}, + {"prompt": "drum breaks 174 BPM", "seconds_total": 6}, + {"prompt": "A short, beautiful piano riff in C minor", "seconds_total": 6}, + {"prompt": "Tight Snare Drum", "seconds_total": 1}, + {"prompt": "Glitchy bass design, I used Serum for this", "seconds_total": 4}, + {"prompt": "Synth pluck arp with reverb and delay, 128 BPM", "seconds_total": 6}, + {"prompt": "Acid A minor 120 BPM 4/4 4 bar loop", "seconds_total": 8}, + {"prompt": "Electronic, with a synthesized, futuristic tone. It has a steady, rhythmic pattern and a slightly retro, 1980s-inspired sound. 
The bass is prominent, giving the track a pulsing, driving feel. The music sets a mood that is upbeat, energetic, and slightly playful. It suggests a setting that could be related to technology, gaming, or a lighthearted, futuristic scenario.", "seconds_total": 8} + ], + "demo_cfg_scales": [0.5, 1, 1.5, 8] + } + } +} diff --git a/sao_small/base_model_config.json b/sao_small/base_model_config.json new file mode 100644 index 00000000..d8938759 --- /dev/null +++ b/sao_small/base_model_config.json @@ -0,0 +1,132 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 131072, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": false, + "model_half": true, + "config": { + "encoder": { + "type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "prompt", + "type": "t5", + "config": { + "t5_model_name": "t5-base", + "max_length": 64 + } + }, + { + "id": "seconds_total", + "type": "number", + "config": { + "min_val": 0, + "max_val": 256 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["prompt", "seconds_total"], + "global_cond_ids": ["seconds_total"], + "diffusion_objective": "rectified_flow", + "distribution_shift_options": { + "min_length": 256, + "max_length": 4096 + }, + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1024, + "depth": 16, + "num_heads": 8, + "cond_token_dim": 768, + "global_cond_dim": 768, + "transformer_type": "continuous_transformer", + "attn_kwargs": { + "qk_norm": "ln" + } + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "pre_encoded": true, + "timestep_sampler": "trunc_logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 2e-4, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 0.01, + "foreach": true + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.995 + } + } + } + }, + "demo": { + "demo_every": 100, + "demo_steps": 50, + "num_demos": 7, + "demo_cond": [ + {"prompt": "Amen break 174 BPM", "seconds_total": 6}, + {"prompt": "drum breaks 174 BPM", "seconds_total": 6}, + {"prompt": "A short, beautiful piano riff in C minor", "seconds_total": 6}, + {"prompt": "Tight Snare Drum", "seconds_total": 1}, + {"prompt": "Glitchy bass design", "seconds_total": 4}, + {"prompt": "Glitchy bass design, I used Serum for this", "seconds_total": 4}, + {"prompt": "Synth pluck arp with reverb and delay, 128 BPM", "seconds_total": 6} + ], + "demo_cfg_scales": [1, 2, 4, 8] + } + } +} diff --git a/sao_small/model_config.json b/sao_small/model_config.json new file mode 100644 index 00000000..34dfe14f --- /dev/null +++ b/sao_small/model_config.json @@ -0,0 +1,164 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 131072, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": false, + 
"model_half": true, + "config": { + "encoder": { + "type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "prompt", + "type": "t5", + "config": { + "t5_model_name": "t5-base", + "max_length": 64 + } + }, + { + "id": "seconds_total", + "type": "number", + "config": { + "min_val": 0, + "max_val": 256 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["prompt", "seconds_total"], + "global_cond_ids": ["seconds_total"], + "diffusion_objective": "rf_denoiser", + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1024, + "depth": 16, + "num_heads": 8, + "cond_token_dim": 768, + "global_cond_dim": 768, + "transformer_type": "continuous_transformer", + "attn_kwargs": { + "qk_norm": "ln" + } + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "pre_encoded": true, + "log_loss_info": false, + "timestep_sampler": "trunc_logit_normal", + "clip_grad_norm": 1.0, + "cfg_dropout_prob": 0.0, + "arc": { + "ode_warmup": { + "warmup_steps": 0, + "refresh_rate": 10, + "sampling_steps": 25, + "cfg": 4 + }, + "noise_dist": { + "generator": "logsnr_uniform", + "discriminator": "logit_normal" + }, + "use_model_as_discriminator": true, + "discriminator_base_ckpt": "/path/to/base/rf/model", + "discriminator": { + "type": "convnext", + "dit_hidden_layer": 12, + "weights": { + "generator": 1.0, + "discriminator": 1.0 + }, + "loss_type": "relativistic", + "config": { + "channels": 128, + "strides": [1, 2, 2], + "c_mults": [1, 2, 4], + "num_blocks": [2, 2, 2] + }, + "contrastive": true, + "include_grad_penalties": true + } + }, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-6, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 0.01, + "foreach": true + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-6, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 0.01, + "foreach": true + } + } + } + }, + "demo": { + "demo_every": 100, + "demo_steps": 8, + "num_demos": 6, + "demo_cond": [ + {"prompt": "drum breaks 174 BPM", "seconds_total": 6}, + {"prompt": "A short, beautiful piano riff in C minor", "seconds_total": 6}, + {"prompt": "Tight Snare Drum", "seconds_total": 1}, + {"prompt": "Glitchy bass design", "seconds_total": 4}, + {"prompt": "Synth pluck arp with reverb and delay, 128 BPM", "seconds_total": 6}, + {"prompt": "Birds singing in the forest", "seconds_total": 10} + ], + "demo_cfg_scales": [1] + } + } +} diff --git a/setup.py b/setup.py index f96f3bc1..7974a350 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ 'local-attention==1.8.6', 'pandas==2.0.2', 'prefigure==0.0.9', - 'pytorch_lightning==2.1.0', + 'pytorch_lightning==2.4.0', 'PyWavelets==1.4.1', 'safetensors', 'sentencepiece==0.1.99', diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py index 7543ac17..7e02b6d6 100644 --- a/stable_audio_tools/data/dataset.py +++ b/stable_audio_tools/data/dataset.py @@ -1,3 +1,4 @@ +import 
dill import importlib import numpy as np import io @@ -19,7 +20,10 @@ from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, VolumeNorm -AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus") +from torchdata.stateful_dataloader import StatefulDataLoader + + +AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus", "aiff", "aif") # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py @@ -94,7 +98,7 @@ def keyword_scandir( def get_audio_filenames( paths: list, # directories in which to search keywords=None, - exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'] + exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus', '.aif', '.aiff'] ): "recursively get a list of audio filenames" filenames = [] @@ -178,7 +182,7 @@ def __init__( self.root_paths.append(config.path) self.filenames.extend(get_audio_filenames(config.path, keywords)) if config.custom_metadata_fn is not None: - self.custom_metadata_fns[config.path] = config.custom_metadata_fn + self.custom_metadata_fns[config.path] = dill.dumps(config.custom_metadata_fn) print(f'Found {len(self.filenames)} files') @@ -238,8 +242,8 @@ def __getitem__(self, idx): for custom_md_path in self.custom_metadata_fns.keys(): if custom_md_path in audio_filename: - custom_metadata_fn = self.custom_metadata_fns[custom_md_path] - custom_metadata = custom_metadata_fn(info, audio) + custom_metadata_fn_deserialized = dill.loads(self.custom_metadata_fns[custom_md_path]) + custom_metadata = custom_metadata_fn_deserialized(info, audio) info.update(custom_metadata) if "__reject__" in info and info["__reject__"]: @@ -282,7 +286,7 @@ def __init__( for config in configs: self.filenames.extend(get_latent_filenames(config.path, [latent_extension])) if config.custom_metadata_fn is not None: - self.custom_metadata_fns[config.path] = config.custom_metadata_fn + self.custom_metadata_fns[config.path] = dill.dumps(config.custom_metadata_fn) self.latent_crop_length = latent_crop_length self.random_crop = random_crop @@ -339,8 +343,9 @@ def __getitem__(self, idx): for custom_md_path in self.custom_metadata_fns.keys(): if custom_md_path in latent_filename: - custom_metadata_fn = self.custom_metadata_fns[custom_md_path] - custom_metadata = custom_metadata_fn(info, None) + + custom_metadata_fn_deserialized = dill.loads(self.custom_metadata_fns[custom_md_path]) + custom_metadata = custom_metadata_fn_deserialized(info, None) info.update(custom_metadata) if "__reject__" in info and info["__reject__"]: @@ -849,8 +854,14 @@ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl force_channels=force_channels ) - return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle, - num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn) + # https://docs.pytorch.org/docs/stable/notes/randomness.html#dataloader + g = torch.Generator() + g.manual_seed(0) + + #return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle, + # num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g) + return StatefulDataLoader(train_set, batch_size, shuffle=shuffle, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g) elif dataset_type == "pre_encoded": @@ -899,8 +910,14 @@ def 
create_dataloader_from_config(dataset_config, batch_size, sample_size, sampl latent_extension=latent_extension ) - return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle, - num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn) + # https://docs.pytorch.org/docs/stable/notes/randomness.html#dataloader + g = torch.Generator() + g.manual_seed(0) + + #return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle, + # num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g) + return StatefulDataLoader(train_set, batch_size, shuffle=shuffle, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=dataset_config.get("drop_last", True), collate_fn=collation_fn, generator=g) elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility wds_configs = [] diff --git a/stable_audio_tools/models/conditioners.py b/stable_audio_tools/models/conditioners.py index 64f6e3a8..48d59da2 100644 --- a/stable_audio_tools/models/conditioners.py +++ b/stable_audio_tools/models/conditioners.py @@ -287,7 +287,7 @@ class T5Conditioner(Conditioner): T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", - "google/flan-t5-xl", "google/flan-t5-xxl", "google/t5-v1_1-xl", "google/t5-v1_1-xxl"] + "google/flan-t5-xl", "google/flan-t5-xxl", "google/t5-v1_1-xl", "google/t5-v1_1-xxl", "google/t5gemma-b-b-ul2"] T5_MODEL_DIMS = { "t5-small": 512, @@ -304,6 +304,7 @@ class T5Conditioner(Conditioner): "google/flan-t5-11b": 1024, "google/flan-t5-xl": 2048, "google/flan-t5-xxl": 4096, + "google/t5gemma-b-b-ul2": 768 } def __init__( @@ -317,7 +318,7 @@ def __init__( assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) - from transformers import T5EncoderModel, AutoTokenizer + from transformers import T5EncoderModel, T5GemmaEncoderModel, AutoTokenizer self.max_length = max_length self.enable_grad = enable_grad @@ -331,7 +332,11 @@ def __init__( # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad) self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name) - model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + if 'gemma' in t5_model_name: + #T5GemmaEncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"] + model = T5GemmaEncoderModel.from_pretrained(t5_model_name, is_encoder_decoder=False, torch_dtype=torch.float16).train(enable_grad).requires_grad_(enable_grad) + else: + model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) finally: logging.disable(previous_level) @@ -359,7 +364,7 @@ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> t self.model.eval() - with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): + with torch.amp.autocast('cuda', dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): embeddings = self.model( input_ids=input_ids, attention_mask=attention_mask )["last_hidden_state"] @@ -758,4 
+763,4 @@ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.An else: raise ValueError(f"Unknown conditioner type: {conditioner_type}") - return MultiConditioner(conditioners, default_keys=default_keys, pre_encoded_keys=pre_encoded_keys) \ No newline at end of file + return MultiConditioner(conditioners, default_keys=default_keys, pre_encoded_keys=pre_encoded_keys) diff --git a/stable_audio_tools/training/diffusion.py b/stable_audio_tools/training/diffusion.py index 4583b3e4..1d4eb66c 100644 --- a/stable_audio_tools/training/diffusion.py +++ b/stable_audio_tools/training/diffusion.py @@ -10,6 +10,7 @@ from einops import rearrange from safetensors.torch import save_file from torch import optim +import bitsandbytes as bnb from torch.nn import functional as F from pytorch_lightning.utilities.rank_zero import rank_zero_only @@ -80,7 +81,9 @@ def __init__( self.pre_encoded = pre_encoded def configure_optimizers(self): - return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + #return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + #return bnb.optim.Adam(model.parameters(), lr=self.lr, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5) [*self.diffusion.parameters()] + return bnb.optim.AdamW8bit([*self.diffusion.parameters()], lr=self.lr) def training_step(self, batch, batch_idx): reals = batch[0] @@ -119,7 +122,7 @@ def training_step(self, batch, batch_idx): noised_inputs = diffusion_input * alphas + noise * sigmas targets = noise * alphas - diffusion_input * sigmas - with torch.cuda.amp.autocast(): + with torch.amp.autocast('cuda'): v = self.diffusion(noised_inputs, t) loss_info.update({ @@ -184,7 +187,7 @@ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) try: - with torch.cuda.amp.autocast(): + with torch.amp.autocast('cuda'): fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) if module.diffusion.pretransform is not None: @@ -365,7 +368,7 @@ def training_step(self, batch, batch_idx): self.diffusion.pretransform.to(self.device) if not self.pre_encoded: - with torch.cuda.amp.autocast() and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) diffusion_input = self.diffusion.pretransform.encode(diffusion_input) @@ -501,7 +504,7 @@ def validation_step(self, batch, batch_idx): diffusion_input = reals - with torch.cuda.amp.autocast() and torch.no_grad(): + with torch.amp.autocast('cuda') and torch.no_grad(): conditioning = self.diffusion.conditioner(metadata, self.device) # TODO: decide what to do with padding masks during validation @@ -517,7 +520,7 @@ def validation_step(self, batch, batch_idx): self.diffusion.pretransform.to(self.device) if not self.pre_encoded: - with torch.cuda.amp.autocast() and torch.no_grad(): + with torch.amp.autocast('cuda') and torch.no_grad(): self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) diffusion_input = self.diffusion.pretransform.encode(diffusion_input) @@ -556,7 +559,7 @@ def validation_step(self, batch, batch_idx): # if use_padding_mask: # extra_args["mask"] = padding_masks - with torch.cuda.amp.autocast() and torch.no_grad(): + with torch.amp.autocast('cuda') and torch.no_grad(): output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0, 
**extra_args) val_loss = F.mse_loss(output, targets) @@ -654,7 +657,7 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp try: print("Getting conditioning") - with torch.cuda.amp.autocast(): + with torch.amp.autocast('cuda'): conditioning = module.diffusion.conditioner(demo_cond, module.device) cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) @@ -698,7 +701,7 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp print(f"Generating demo for cfg scale {cfg_scale}") - with torch.cuda.amp.autocast(): + with torch.amp.autocast('cuda'): model = module.diffusion_ema.ema_model if module.diffusion_ema is not None else module.diffusion.model if module.diffusion_objective == "v": @@ -879,7 +882,7 @@ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outp model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model print(f"Generating demo for cfg scale {cfg_scale}") - with torch.cuda.amp.autocast(): + with torch.amp.autocast('cuda'): if module.diffusion_objective == "v": fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, dist_shift=module.diffusion.dist_shift, batch_cfg=True) elif module.diffusion_objective == "rectified_flow": @@ -988,7 +991,7 @@ def __init__( self.losses = MultiLoss(loss_modules) def configure_optimizers(self): - return optim.Adam([*self.diffae.parameters()], lr=self.lr) + return bnb.optim.Adam8bit([*self.diffae.parameters()], lr=self.lr) def training_step(self, batch, batch_idx): reals = batch[0] @@ -1034,7 +1037,7 @@ def training_step(self, batch, batch_idx): noised_reals = reals * alphas + noise * sigmas targets = noise * alphas - reals * sigmas - with torch.cuda.amp.autocast(): + with torch.amp.autocast('cuda'): v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents) loss_info.update({ @@ -1114,7 +1117,7 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe demo_reals = demo_reals.to(module.device) - with torch.no_grad() and torch.cuda.amp.autocast(): + with torch.no_grad() and torch.amp.autocast('cuda'): latents = module.diffae_ema.ema_model.encode(encoder_input).float() fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps) @@ -1147,7 +1150,7 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe audio_spectrogram_image(reals_fakes)) if module.diffae_ema.ema_model.pretransform is not None: - with torch.no_grad() and torch.cuda.amp.autocast(): + with torch.no_grad() and torch.amp.autocast('cuda'): initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input) first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents) first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)') @@ -1163,4 +1166,4 @@ def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrappe tokens_spectrogram_image(initial_latents)) log_image( trainer.logger, "first_stage_melspec_left", - audio_spectrogram_image(first_stage_fakes)) \ No newline at end of file + audio_spectrogram_image(first_stage_fakes)) diff --git a/stable_audio_tools/training/utils.py b/stable_audio_tools/training/utils.py index 8b69cd8c..74b698c6 100644 --- a/stable_audio_tools/training/utils.py +++ b/stable_audio_tools/training/utils.py @@ -1,10 +1,17 @@ from pytorch_lightning.loggers import WandbLogger, CometLogger from ..interface.aeiou import pca_point_cloud +from 
pytorch_optimizer.lr_scheduler.chebyshev import ( + get_chebyshev_perm_steps, + get_chebyshev_permutation, + get_chebyshev_schedule + ) import wandb import torch import os +import bitsandbytes as bnb + def get_rank(): """Get rank of current process.""" @@ -73,6 +80,8 @@ def create_optimizer_from_config(optimizer_config, parameters): if optimizer_type == "FusedAdam": from deepspeed.ops.adam import FusedAdam optimizer = FusedAdam(parameters, **optimizer_config["config"]) + elif optimizer_type == "AdamW8bit": + optimizer = bnb.optim.AdamW8bit(parameters, **optimizer_config["config"]) else: optimizer_fn = getattr(torch.optim, optimizer_type) optimizer = optimizer_fn(parameters, **optimizer_config["config"]) @@ -90,6 +99,8 @@ def create_scheduler_from_config(scheduler_config, optimizer): """ if scheduler_config["type"] == "InverseLR": scheduler_fn = InverseLR + elif scheduler_config["type"] == "Chebyshev": + scheduler_fn = get_chebyshev_schedule(optimizer=optimizer, num_epochs=scheduler_config["num_epochs"], is_warmup=scheduler_config["is_warmup"], last_epoch=scheduler_config["last_epoch"]) else: scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) @@ -126,4 +137,4 @@ def log_point_cloud(logger, key, tokens, caption=None): logger.experiment.log({key: point_cloud}) elif isinstance(logger, CometLogger): point_cloud = pca_point_cloud(tokens, rgb_float=True, output_type="points") - #logger.experiment.log_points_3d(scene_name=key, points=point_cloud) \ No newline at end of file + #logger.experiment.log_points_3d(scene_name=key, points=point_cloud) diff --git a/test6.ini b/test6.ini new file mode 100644 index 00000000..a6962b85 --- /dev/null +++ b/test6.ini @@ -0,0 +1,70 @@ + +#name of the run +name = stable_audio_tools + +# name of the project +project = test6 + +# the batch size +batch_size = 4 + +# If `true`, attempts to resume training from latest checkpoint. +# In this case, each run must have unique config filename. +recover = false + +# Save top K model checkpoints during training. 
+save_top_k = -1 + +# number of nodes to use for training +num_nodes = 1 + +# Multi-GPU strategy for PyTorch Lightning +strategy = "auto" + +# Precision to use for training +precision = "16-mixed" + +# number of CPU workers for the DataLoader +num_workers = 6 + +# the random seed +seed = 42 + +# Batches for gradient accumulation +accum_batches = 1 + +# Number of steps between checkpoints +checkpoint_every = 1000 + +# Number of steps between validation runs +val_every = -1 + +# trainer checkpoint file to restart training from +ckpt_path = '' + +# model checkpoint file to start a new training run from +pretrained_ckpt_path = '' + +# Checkpoint path for the pretransform model if needed +pretransform_ckpt_path = '' + +# configuration model specifying model hyperparameters +model_config = '' + +# configuration for datasets +dataset_config = '' + +# configuration for validation datasets +val_dataset_config = '' + +# directory to save the checkpoints in +save_dir = 'output' + +# gradient_clip_val passed into PyTorch Lightning Trainer +gradient_clip_val = 0.0 + +# remove the weight norm from the pretransform model +remove_pretransform_weight_norm = '' + +# Logger type to use +logger = 'wandb' diff --git a/train.bat b/train.bat new file mode 100644 index 00000000..f88618b5 --- /dev/null +++ b/train.bat @@ -0,0 +1,12 @@ +python train.py ^ + --name saos1 ^ + --pretrained-ckpt-path .\sao_small\base_model.ckpt ^ + --model-config .\sao_small\base_model_config.json ^ + --batch-size 2 ^ + --num-workers 4 ^ + --seed 1937401721 ^ + --checkpoint-every 1000 ^ + --dataset-config dataset_config.json ^ + --save-dir outputs ^ + --precision 16-mixed ^ + diff --git a/vae_model_config.json b/vae_model_config.json new file mode 100644 index 00000000..81ccbfd5 --- /dev/null +++ b/vae_model_config.json @@ -0,0 +1,122 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + }, + "training": { + "learning_rate": 5e-5, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1.5e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 3e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 32, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 25.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 
32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-5 + } + } + }, + "demo": { + "demo_every": 2000 + } + } +} \ No newline at end of file