4 changes: 3 additions & 1 deletion .gitignore
@@ -161,4 +161,6 @@ cython_debug/

*.ckpt
*.wav
wandb/*
# macOS
.DS_Store
12 changes: 12 additions & 0 deletions README.md
@@ -152,6 +152,18 @@ The following properties are defined in the top level of the model configuration
## Dataset config
`stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md)
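
As a quick sketch of how the S3 helpers expand a dataset path into shard URLs (the bucket, prefix, and dataset names below are placeholders, not part of the repo; see also the S3-compatible storage section below):

```python
from stable_audio_tools.data.dataset import get_s3_contents, get_all_s3_urls

# Keys come back relative to the prefix, e.g. ["shard-000.tar", "sub/shard-001.tar"].
keys = get_s3_contents("s3://my-bucket/audio/", filter="tar")

# One 'pipe:curl "<presigned url>"' spec per shard, ready for WebDataset.
urls = get_all_s3_urls(names=["audio"], subsets=["train"], s3_url_prefix="s3://my-bucket")
```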

## S3-compatible storage (Backblaze B2)
The S3 dataset loader honors the `AWS_ENDPOINT_URL` environment variable, so you can point it at any S3-compatible host without changing the dataset config.

Example for [Backblaze B2](https://www.backblaze.com/cloud-storage):
```bash
export AWS_ENDPOINT_URL=https://s3.us-west-004.backblazeb2.com
export AWS_ACCESS_KEY_ID=<B2 application key ID>
export AWS_SECRET_ACCESS_KEY=<B2 application key>
```

When `AWS_ENDPOINT_URL` is unset, the loader falls back to the default AWS S3 endpoint, so existing setups are unaffected.
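
To sanity-check that the endpoint and credentials resolve before starting a training run, you can list a few objects with boto3 directly (a minimal sketch; the bucket and prefix are placeholders):

```python
import os
import boto3

# Same endpoint override the dataset loader honors.
s3 = boto3.Session().client("s3", endpoint_url=os.environ.get("AWS_ENDPOINT_URL") or None)

# List a handful of keys to confirm access.
resp = s3.list_objects_v2(Bucket="my-bucket", Prefix="audio/", MaxKeys=5)
for obj in resp.get("Contents", []):
    print(obj["Key"])
```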

# Todo
- [ ] Add troubleshooting section
- [ ] Add contribution guidelines
1 change: 1 addition & 0 deletions setup.py
@@ -10,6 +10,7 @@
    install_requires=[
        'alias-free-torch==0.0.6',
        'auraloss==0.4.0',
        'boto3',
        'descript-audio-codec==1.0.0',
        'einops',
        'einops-exts',
120 changes: 74 additions & 46 deletions stable_audio_tools/data/dataset.py
@@ -5,8 +5,6 @@
import os
import posixpath
import random
import time
import torch
import torchaudio
@@ -359,51 +357,71 @@ def __getitem__(self, idx):

# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py

def _get_s3_client(profile=None):
    """
    Build a boto3 S3 client. Honors AWS_ENDPOINT_URL when set, so the same
    code path works against any S3-compatible endpoint (AWS S3 by default;
    set AWS_ENDPOINT_URL to a Backblaze B2 endpoint to point it at B2).
    When the env var is unset, behavior matches the default AWS client.
    """
    import boto3  # local import so boto3 is only required when S3 is used

    endpoint_url = os.environ.get("AWS_ENDPOINT_URL") or None
    session = boto3.Session(profile_name=profile) if profile else boto3.Session()
    return session.client("s3", endpoint_url=endpoint_url)


def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
    """
    Returns a list of S3 keys (relative to ``dataset_path``) for objects in a
    given S3 bucket and directory path. Uses boto3 directly, so it works
    against any S3-compatible endpoint when ``AWS_ENDPOINT_URL`` is set.
    """
    # Ensure dataset_path ends with a trailing slash
    if dataset_path != '' and not dataset_path.endswith('/'):
        dataset_path += '/'
    # Use posixpath to construct the S3 URL path (e.g. "s3://bucket/prefix/")
    bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)

    # Parse "s3://bucket/prefix/..." into bucket + prefix.
    if not bucket_path.startswith("s3://"):
        raise ValueError(
            f"get_s3_contents expected an s3:// URL, got: {bucket_path!r}"
        )
    without_scheme = bucket_path[len("s3://"):]
    bucket, _, prefix = without_scheme.partition("/")

    s3 = _get_s3_client(profile=profile)
    paginator = s3.get_paginator("list_objects_v2")
    list_kwargs = {"Bucket": bucket, "Prefix": prefix}
    if not recursive:
        list_kwargs["Delimiter"] = "/"

    keys = []
    for page in paginator.paginate(**list_kwargs):
        for obj in page.get("Contents", []) or []:
            key = obj.get("Key", "")
            if not key or key.endswith("/"):
                continue
            keys.append(key)

    # Apply the filter, if specified
    if filter:
        keys = [k for k in keys if filter in k]

    # Match the legacy `aws s3 ls` output shape: paths relative to dataset_path.
    # The legacy CLI emitted basenames in non-recursive mode and full keys
    # (which it then stripped) in recursive mode; both ended up relative to
    # dataset_path. boto3 always returns full keys, so strip the prefix
    # unconditionally.
    if prefix:
        keys = [k[len(prefix):] if k.startswith(prefix) else k for k in keys]
        keys = [k.lstrip('/') for k in keys]

    if debug:
        print("contents = \n", keys)

    return keys


def get_all_s3_urls(
@@ -436,22 +454,32 @@ def get_all_s3_urls(
            profile = profiles.get(name, None)
            tar_list = get_s3_contents(
                subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
            # Build a boto3 client once per (name, subset) for presigning.
            s3_client = _get_s3_client(profile=profile)
            for tar in tar_list:
                # Construct the full s3:// URL for the current tar file.
                if s3_url_prefix is None:
                    full_s3_url = posixpath.join(name, subset, tar)
                else:
                    full_s3_url = posixpath.join(s3_url_prefix, name, subset, tar)

                if not full_s3_url.startswith("s3://"):
                    raise ValueError(
                        f"get_all_s3_urls expected an s3:// URL, got: {full_s3_url!r}"
                    )
                without_scheme = full_s3_url[len("s3://"):]
                bucket, _, key = without_scheme.partition("/")

                # A short-lived (1h) presigned GET URL works against AWS and any
                # S3-compatible endpoint when AWS_ENDPOINT_URL is set.
                presigned = s3_client.generate_presigned_url(
                    "get_object",
                    Params={"Bucket": bucket, "Key": key},
                    ExpiresIn=3600,
                )
                request_str = f'pipe:curl -fsSL "{presigned}"'
                if debug:
                    print("request_str = ", request_str)
                # Add the constructed URL to the list of URLs
                urls.append(request_str)
    return urls

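The presigned `pipe:curl` specs drop the AWS CLI dependency at load time while keeping WebDataset's streaming path unchanged. A minimal consumption sketch (assumes the `webdataset` package; the dataset name and bucket are placeholders):

```python
import webdataset as wds
from stable_audio_tools.data.dataset import get_all_s3_urls

# Each entry looks like: pipe:curl -fsSL "https://...X-Amz-Signature=..."
urls = get_all_s3_urls(names=["name"], subsets=["train"], s3_url_prefix="s3://bucket")

# WebDataset shells out to curl per shard, so no `aws` binary is needed.
dataset = wds.WebDataset(urls)
```

Since each presigned URL expires an hour after listing, very long runs that open shards late may need to regenerate the URL list.
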
156 changes: 156 additions & 0 deletions tests/test_dataset_s3.py
@@ -0,0 +1,156 @@
import os
from unittest import mock

import pytest

from stable_audio_tools.data import dataset as ds


def _fake_paginator(pages):
    "Paginator-like mock; records last paginate(**kwargs) on .last_kwargs."
    pag = mock.MagicMock()
    pag.last_kwargs = {}

    def paginate(**kwargs):
        pag.last_kwargs = kwargs
        return iter(pages)

    pag.paginate.side_effect = paginate
    return pag


def _fake_client(pages=None, presigned_url="https://example.com/signed"):
    client = mock.MagicMock()
    client.get_paginator.return_value = _fake_paginator(pages or [])
    client.generate_presigned_url.return_value = presigned_url
    return client


def test_get_s3_client_uses_aws_endpoint_url_env():
    fake_boto3 = mock.MagicMock()
    fake_session = mock.MagicMock()
    fake_boto3.Session.return_value = fake_session

    with mock.patch.dict(os.environ, {"AWS_ENDPOINT_URL": "https://s3.us-west-004.backblazeb2.com"}, clear=False):
        with mock.patch.dict("sys.modules", {"boto3": fake_boto3}):
            ds._get_s3_client()

    fake_session.client.assert_called_once_with(
        "s3", endpoint_url="https://s3.us-west-004.backblazeb2.com"
    )


def test_get_s3_client_default_when_env_unset():
    fake_boto3 = mock.MagicMock()
    fake_session = mock.MagicMock()
    fake_boto3.Session.return_value = fake_session

    env = {k: v for k, v in os.environ.items() if k != "AWS_ENDPOINT_URL"}
    with mock.patch.dict(os.environ, env, clear=True):
        with mock.patch.dict("sys.modules", {"boto3": fake_boto3}):
            ds._get_s3_client()

    # endpoint_url=None preserves boto3's default (AWS) behavior.
    fake_session.client.assert_called_once_with("s3", endpoint_url=None)


def test_get_s3_client_uses_profile_when_given():
    fake_boto3 = mock.MagicMock()
    fake_session = mock.MagicMock()
    fake_boto3.Session.return_value = fake_session

    with mock.patch.dict("sys.modules", {"boto3": fake_boto3}):
        ds._get_s3_client(profile="myprofile")

    fake_boto3.Session.assert_called_once_with(profile_name="myprofile")


def test_get_s3_contents_returns_keys_relative_to_prefix():
    pages = [
        {"Contents": [
            {"Key": "prefix/a.tar"},
            {"Key": "prefix/sub/b.tar"},
            {"Key": "prefix/"},  # directory marker -> skipped
        ]},
    ]
    client = _fake_client(pages=pages)

    with mock.patch.object(ds, "_get_s3_client", return_value=client):
        keys = ds.get_s3_contents("s3://bucket/prefix/", recursive=True)

    client.get_paginator.assert_called_once_with("list_objects_v2")
    pag = client.get_paginator.return_value
    assert pag.last_kwargs == {"Bucket": "bucket", "Prefix": "prefix/"}
    # Recursive mode strips the bucket-level prefix from each key.
    assert keys == ["a.tar", "sub/b.tar"]


def test_get_s3_contents_non_recursive_adds_delimiter():
    client = _fake_client(pages=[{"Contents": []}])

    with mock.patch.object(ds, "_get_s3_client", return_value=client):
        ds.get_s3_contents("s3://bucket/prefix/", recursive=False)

    pag = client.get_paginator.return_value
    assert pag.last_kwargs == {
        "Bucket": "bucket",
        "Prefix": "prefix/",
        "Delimiter": "/",
    }


def test_get_s3_contents_non_recursive_strips_prefix_from_keys():
    pages = [{"Contents": [
        {"Key": "prefix/a.tar"},
        {"Key": "prefix/b.tar"},
    ]}]
    client = _fake_client(pages=pages)

    with mock.patch.object(ds, "_get_s3_client", return_value=client):
        keys = ds.get_s3_contents("s3://bucket/prefix/", recursive=False)

    # Keys must be relative to dataset_path in BOTH recursive and non-recursive
    # modes (matches the legacy `aws s3 ls` output shape; without this strip,
    # `get_all_s3_urls(..., recursive=False)` joins the prefix twice).
    assert keys == ["a.tar", "b.tar"]


def test_get_s3_contents_applies_filter():
    pages = [{"Contents": [
        {"Key": "prefix/a.tar"},
        {"Key": "prefix/b.txt"},
        {"Key": "prefix/c.tar"},
    ]}]
    client = _fake_client(pages=pages)

    with mock.patch.object(ds, "_get_s3_client", return_value=client):
        keys = ds.get_s3_contents("s3://bucket/prefix/", filter="tar", recursive=True)

    assert keys == ["a.tar", "c.tar"]


def test_get_s3_contents_rejects_non_s3_url():
    with pytest.raises(ValueError):
        ds.get_s3_contents("not-an-s3-url/")


def test_get_all_s3_urls_emits_pipe_curl_with_presigned_url():
    pages = [{"Contents": [{"Key": "name/train/shard-000.tar"}]}]
    fake_url = "https://signed.example.com/shard-000.tar?X-Amz-Signature=abc"
    client = _fake_client(pages=pages, presigned_url=fake_url)

    with mock.patch.object(ds, "_get_s3_client", return_value=client):
        urls = ds.get_all_s3_urls(
            names=["name"],
            subsets=["train"],
            s3_url_prefix="s3://bucket",
            recursive=True,
            filter_str="tar",
        )

    assert urls == [f'pipe:curl -fsSL "{fake_url}"']
    client.generate_presigned_url.assert_called_with(
        "get_object",
        Params={"Bucket": "bucket", "Key": "name/train/shard-000.tar"},
        ExpiresIn=3600,
    )