From 6671c21102ab504f99987cd483b00d341e1c763b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gonzalo=20Pe=C3=B1a-Castellanos?=
Date: Mon, 4 May 2026 18:30:33 -0500
Subject: [PATCH] Replace 'aws s3 ls' shell-out in dataset.py with boto3

---
 .gitignore                         |   4 +-
 README.md                          |  20 +++
 setup.py                           |   1 +
 stable_audio_tools/data/dataset.py | 124 +++++++++++++----------
 tests/test_dataset_s3.py           | 173 ++++++++++++++++++++++++++++++
 5 files changed, 275 insertions(+), 47 deletions(-)
 create mode 100644 tests/test_dataset_s3.py

diff --git a/.gitignore b/.gitignore
index 3e6aee68..300055d7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,4 +161,6 @@ cython_debug/
 *.ckpt
 *.wav
-wandb/*
\ No newline at end of file
+wandb/*
+# macOS
+.DS_Store
diff --git a/README.md b/README.md
index a0cca502..00bb8ba6 100644
--- a/README.md
+++ b/README.md
@@ -152,6 +152,26 @@ The following properties are defined in the top level of the model configuration
 ## Dataset config
 `stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md)
 
+## S3-compatible storage (Backblaze B2)
+The S3 dataset loader honors the `AWS_ENDPOINT_URL` environment variable, so you can point it at any S3-compatible host without changing the dataset config.
+
+Example for [Backblaze B2](https://www.backblaze.com/cloud-storage):
+```bash
+export AWS_ENDPOINT_URL=https://s3.us-west-004.backblazeb2.com
+export AWS_ACCESS_KEY_ID=
+export AWS_SECRET_ACCESS_KEY=
+```
+
+When `AWS_ENDPOINT_URL` is unset, the loader uses the default AWS S3 endpoint, so existing setups are unaffected.
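+
+As a quick, illustrative sanity check that the endpoint and credentials are picked up, you can list shard URLs from Python (the bucket and dataset names below are placeholders, not real defaults):
+```python
+from stable_audio_tools.data.dataset import get_all_s3_urls
+
+urls = get_all_s3_urls(names=["my-dataset"], subsets=["train"], s3_url_prefix="s3://my-bucket", filter_str="tar")
+print(urls[:3])  # each entry is a "pipe:curl <presigned URL>" string consumed by WebDataset
+```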
+
 # Todo
 - [ ] Add troubleshooting section
 - [ ] Add contribution guidelines
diff --git a/setup.py b/setup.py
index f96f3bc1..f8d2ecf0 100644
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,7 @@
     install_requires=[
         'alias-free-torch==0.0.6',
         'auraloss==0.4.0',
+        'boto3',
         'descript-audio-codec==1.0.0',
         'einops',
         'einops-exts',
diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py
index 7543ac17..a42f0f38 100644
--- a/stable_audio_tools/data/dataset.py
+++ b/stable_audio_tools/data/dataset.py
@@ -5,8 +5,6 @@
 import os
 import posixpath
 import random
-import re
-import subprocess
 import time
 import torch
 import torchaudio
@@ -359,51 +357,75 @@ def __getitem__(self, idx):
 
 # S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
 
+def _get_s3_client(profile=None):
+    """
+    Build a boto3 S3 client. Honors AWS_ENDPOINT_URL when set so the same
+    code path works against any S3-compatible endpoint (AWS S3 by default;
+    set AWS_ENDPOINT_URL to a Backblaze B2 endpoint to point it at B2).
+    When the env var is unset, behavior matches the default AWS client.
+    """
+    import boto3  # local import so boto3 is only required when S3 is used
+
+    endpoint_url = os.environ.get("AWS_ENDPOINT_URL") or None
+    session = boto3.Session(profile_name=profile) if profile else boto3.Session()
+    return session.client("s3", endpoint_url=endpoint_url)
+
+
 def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
     """
-    Returns a list of full S3 paths to files in a given S3 bucket and directory path.
+    Returns a list of S3 keys (relative to ``dataset_path``) for objects in a
+    given S3 bucket and directory path. Uses boto3 directly so it works
+    against any S3-compatible endpoint when ``AWS_ENDPOINT_URL`` is set.
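+
+    Example (illustrative; assumes valid credentials and an existing bucket)::
+
+        get_s3_contents("s3://my-bucket/data/")  # -> ["a.tar", "sub/b.tar", ...]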
     """
     # Ensure dataset_path ends with a trailing slash
     if dataset_path != '' and not dataset_path.endswith('/'):
         dataset_path += '/'
-    # Use posixpath to construct the S3 URL path
+    # Use posixpath to construct the S3 URL path (e.g. "s3://bucket/prefix/")
     bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
-    # Construct the `aws s3 ls` command
-    cmd = ['aws', 's3', 'ls', bucket_path]
-    if profile is not None:
-        cmd.extend(['--profile', profile])
+    # Parse "s3://bucket/prefix/..." into bucket + prefix.
+    if not bucket_path.startswith("s3://"):
+        raise ValueError(
+            f"get_s3_contents expected an s3:// URL, got: {bucket_path!r}"
+        )
+    without_scheme = bucket_path[len("s3://"):]
+    bucket, _, prefix = without_scheme.partition("/")
+
+    s3 = _get_s3_client(profile=profile)
+    paginator = s3.get_paginator("list_objects_v2")
+    list_kwargs = {"Bucket": bucket, "Prefix": prefix}
+    if not recursive:
+        list_kwargs["Delimiter"] = "/"
+
+    keys = []
+    for page in paginator.paginate(**list_kwargs):
+        for obj in page.get("Contents", []) or []:
+            key = obj.get("Key", "")
+            if not key or key.endswith("/"):
+                continue
+            keys.append(key)
 
-    if recursive:
-        # Add the --recursive flag if requested
-        cmd.append('--recursive')
-
-    # Run the `aws s3 ls` command and capture the output
-    run_ls = subprocess.run(cmd, capture_output=True, check=True)
-    # Split the output into lines and strip whitespace from each line
-    contents = run_ls.stdout.decode('utf-8').split('\n')
-    contents = [x.strip() for x in contents if x]
-    # Remove the timestamp from lines that begin with a timestamp
-    contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
-                if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
-    # Construct a full S3 path for each file in the contents list
-    contents = [posixpath.join(s3_url_prefix or '', x)
-                for x in contents if not x.endswith('/')]
     # Apply the filter, if specified
     if filter:
-        contents = [x for x in contents if filter in x]
-    # Remove redundant directory names in the S3 URL
-    if recursive:
-        # Get the main directory name from the S3 URL
-        main_dir = "/".join(bucket_path.split('/')[3:])
-        # Remove the redundant directory names from each file path
-        contents = [x.replace(f'{main_dir}', '').replace(
-            '//', '/') for x in contents]
-    # Print debugging information, if requested
+        keys = [k for k in keys if filter in k]
+
+    # Match the legacy `aws s3 ls` output shape: paths relative to dataset_path.
+    # The legacy CLI emitted basenames in non-recursive mode and full keys
+    # (which it then stripped) in recursive mode; both ended up relative to
+    # dataset_path. boto3 always returns full keys, so strip the prefix
+    # unconditionally.
+    if prefix:
+        keys = [k[len(prefix):] if k.startswith(prefix) else k for k in keys]
+    keys = [k.lstrip('/') for k in keys]
+
     if debug:
-        print("contents = \n", contents)
-    # Return the list of S3 paths to files
-    return contents
+        print("contents = \n", keys)
+
+    return keys
 
 
 def get_all_s3_urls(
@@ -436,22 +458,32 @@ def get_all_s3_urls(
             profile = profiles.get(name, None)
             tar_list = get_s3_contents(
                 subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
+            # Build a boto3 client once per (name, subset) for presigning.
+            s3_client = _get_s3_client(profile=profile)
             for tar in tar_list:
-                # Escape spaces and parentheses in the tar filename for use in the shell command
-                tar = tar.replace(" ", "\ ").replace(
-                    "(", "\(").replace(")", "\)")
-                # Construct the S3 path to the current tar file
-                s3_path = posixpath.join(name, subset, tar) + " -"
-                # Construct the AWS CLI command to download the current tar file
+                # Construct the full s3:// URL for the current tar file.
                 if s3_url_prefix is None:
-                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
+                    full_s3_url = posixpath.join(name, subset, tar)
                 else:
-                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
-                if profiles.get(name):
-                    request_str += f" --profile {profiles.get(name)}"
+                    full_s3_url = posixpath.join(s3_url_prefix, name, subset, tar)
+
+                if not full_s3_url.startswith("s3://"):
+                    raise ValueError(
+                        f"get_all_s3_urls expected an s3:// URL, got: {full_s3_url!r}"
+                    )
+                without_scheme = full_s3_url[len("s3://"):]
+                bucket, _, key = without_scheme.partition("/")
+
+                # Short-lived (1h) presigned GET URL works against AWS and any
+                # S3-compatible endpoint when AWS_ENDPOINT_URL is set.
+                presigned = s3_client.generate_presigned_url(
+                    "get_object",
+                    Params={"Bucket": bucket, "Key": key},
+                    ExpiresIn=3600,
+                )
+                request_str = f'pipe:curl -fsSL "{presigned}"'
                 if debug:
                     print("request_str = ", request_str)
-                # Add the constructed URL to the list of URLs
                 urls.append(request_str)
     return urls
diff --git a/tests/test_dataset_s3.py b/tests/test_dataset_s3.py
new file mode 100644
index 00000000..8d7aa468
--- /dev/null
+++ b/tests/test_dataset_s3.py
@@ -0,0 +1,173 @@
+import os
+from unittest import mock
+
+import pytest
+
+from stable_audio_tools.data import dataset as ds
+
+
+def _fake_paginator(pages):
+    "Paginator-like mock; records last paginate(**kwargs) on .last_kwargs."
+    pag = mock.MagicMock()
+    pag.last_kwargs = {}
+
+    def paginate(**kwargs):
+        pag.last_kwargs = kwargs
+        return iter(pages)
+
+    pag.paginate.side_effect = paginate
+    return pag
+
+
+def _fake_client(pages=None, presigned_url="https://example.com/signed"):
+    client = mock.MagicMock()
+    client.get_paginator.return_value = _fake_paginator(pages or [])
+    client.generate_presigned_url.return_value = presigned_url
+    return client
+
+
+def test_get_s3_client_uses_aws_endpoint_url_env():
+    fake_boto3 = mock.MagicMock()
+    fake_session = mock.MagicMock()
+    fake_boto3.Session.return_value = fake_session
+
+    with mock.patch.dict(os.environ, {"AWS_ENDPOINT_URL": "https://s3.us-west-004.backblazeb2.com"}, clear=False):
+        with mock.patch.dict("sys.modules", {"boto3": fake_boto3}):
+            ds._get_s3_client()
+
+    fake_session.client.assert_called_once_with(
+        "s3", endpoint_url="https://s3.us-west-004.backblazeb2.com"
+    )
+
+
+def test_get_s3_client_default_when_env_unset():
+    fake_boto3 = mock.MagicMock()
+    fake_session = mock.MagicMock()
+    fake_boto3.Session.return_value = fake_session
+
+    env = {k: v for k, v in os.environ.items() if k != "AWS_ENDPOINT_URL"}
+    with mock.patch.dict(os.environ, env, clear=True):
+        with mock.patch.dict("sys.modules", {"boto3": fake_boto3}):
+            ds._get_s3_client()
+
+    # endpoint_url=None preserves boto3's default (AWS) behavior.
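+    # (boto3 treats endpoint_url=None as "not provided", so the client falls
+    # back to its normal AWS endpoint resolution.)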
+ fake_session.client.assert_called_once_with("s3", endpoint_url=None) + + +def test_get_s3_client_uses_profile_when_given(): + fake_boto3 = mock.MagicMock() + fake_session = mock.MagicMock() + fake_boto3.Session.return_value = fake_session + + with mock.patch.dict("sys.modules", {"boto3": fake_boto3}): + ds._get_s3_client(profile="myprofile") + + fake_boto3.Session.assert_called_once_with(profile_name="myprofile") + + +def test_get_s3_contents_returns_keys_relative_to_prefix(): + pages = [ + {"Contents": [ + {"Key": "prefix/a.tar"}, + {"Key": "prefix/sub/b.tar"}, + {"Key": "prefix/"}, # directory marker -> skipped + ]}, + ] + client = _fake_client(pages=pages) + + with mock.patch.object(ds, "_get_s3_client", return_value=client): + keys = ds.get_s3_contents("s3://bucket/prefix/", recursive=True) + + client.get_paginator.assert_called_once_with("list_objects_v2") + pag = client.get_paginator.return_value + assert pag.last_kwargs == {"Bucket": "bucket", "Prefix": "prefix/"} + # Recursive mode strips the bucket-level prefix from each key. + assert keys == ["a.tar", "sub/b.tar"] + + +def test_get_s3_contents_non_recursive_adds_delimiter(): + client = _fake_client(pages=[{"Contents": []}]) + + with mock.patch.object(ds, "_get_s3_client", return_value=client): + ds.get_s3_contents("s3://bucket/prefix/", recursive=False) + + pag = client.get_paginator.return_value + assert pag.last_kwargs == { + "Bucket": "bucket", + "Prefix": "prefix/", + "Delimiter": "/", + } + + +def test_get_s3_contents_non_recursive_strips_prefix_from_keys(): + pages = [{"Contents": [ + {"Key": "prefix/a.tar"}, + {"Key": "prefix/b.tar"}, + ]}] + client = _fake_client(pages=pages) + + with mock.patch.object(ds, "_get_s3_client", return_value=client): + keys = ds.get_s3_contents("s3://bucket/prefix/", recursive=False) + + # Keys must be relative to dataset_path in BOTH recursive and non-recursive + # modes (matches the legacy `aws s3 ls` output shape; without this strip, + # `get_all_s3_urls(..., recursive=False)` joins the prefix twice). + assert keys == ["a.tar", "b.tar"] + + +def test_get_s3_contents_applies_filter(): + pages = [{"Contents": [ + {"Key": "prefix/a.tar"}, + {"Key": "prefix/b.txt"}, + {"Key": "prefix/c.tar"}, + ]}] + client = _fake_client(pages=pages) + + with mock.patch.object(ds, "_get_s3_client", return_value=client): + keys = ds.get_s3_contents("s3://bucket/prefix/", filter="tar", recursive=True) + + assert keys == ["a.tar", "c.tar"] + + +def test_get_s3_contents_rejects_non_s3_url(): + with pytest.raises(ValueError): + ds.get_s3_contents("not-an-s3-url/") + + +def test_get_all_s3_urls_emits_pipe_curl_with_presigned_url(): + pages = [{"Contents": [{"Key": "name/train/shard-000.tar"}]}] + fake_url = "https://signed.example.com/shard-000.tar?X-Amz-Signature=abc" + client = _fake_client(pages=pages, presigned_url=fake_url) + + with mock.patch.object(ds, "_get_s3_client", return_value=client): + urls = ds.get_all_s3_urls( + names=["name"], + subsets=["train"], + s3_url_prefix="s3://bucket", + recursive=True, + filter_str="tar", + ) + + assert urls == [f'pipe:curl -fsSL "{fake_url}"'] + client.generate_presigned_url.assert_called_with( + "get_object", + Params={"Bucket": "bucket", "Key": "name/train/shard-000.tar"}, + ExpiresIn=3600, + )
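+
+
+def test_get_s3_contents_concatenates_paginated_pages():
+    # Added sketch: pagination should be transparent; keys from every
+    # list_objects_v2 page end up, in order, in one prefix-stripped list.
+    pages = [
+        {"Contents": [{"Key": "prefix/a.tar"}]},
+        {"Contents": [{"Key": "prefix/b.tar"}]},
+    ]
+    client = _fake_client(pages=pages)
+
+    with mock.patch.object(ds, "_get_s3_client", return_value=client):
+        keys = ds.get_s3_contents("s3://bucket/prefix/", recursive=True)
+
+    assert keys == ["a.tar", "b.tar"]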