Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
353 changes: 353 additions & 0 deletions bin/configure-trusted-publisher
Original file line number Diff line number Diff line change
@@ -0,0 +1,353 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""
Configure a PyPI trusted publisher via the API.

Reads credentials from ~/.pypirc (like twine), auto-detects the provider and
repo from the current git checkout, and lets you pick a workflow file.

Usage:
configure-trusted-publisher <project> [options]

Examples:
configure-trusted-publisher mypackage
configure-trusted-publisher mypackage --dry-run
configure-trusted-publisher mypackage --environment release
configure-trusted-publisher mypackage --api-url https://test.pypi.org
"""

from __future__ import annotations

import argparse
import configparser
import glob
import json
import os
import re
import subprocess
import sys
import urllib.error
import urllib.request
from pathlib import Path


CONTENT_TYPE = "application/vnd.pypi.api-v0-danger+json"
DEFAULT_PYPI_URL = "https://pypi.org"
DEFAULT_REPOSITORY = "pypi"


def _read_pypirc(config_file: str | None, repository: str) -> dict:
"""Read credentials from ~/.pypirc, mirroring twine's logic."""
path = Path(config_file) if config_file else Path.home() / ".pypirc"

parser = configparser.RawConfigParser()
if path.exists():
try:
parser.read(str(path), encoding="utf-8")
except UnicodeDecodeError:
parser.read(str(path))

# Collect server-login defaults (deprecated but supported)
defaults: dict[str, str | None] = {
"username": None,
"password": None,
}
if parser.has_section("server-login"):
defaults["username"] = parser.get("server-login", "username", fallback=None)
defaults["password"] = parser.get("server-login", "password", fallback=None)

if parser.has_section(repository):
return {
"repository": parser.get(repository, "repository", fallback=None),
"username": parser.get(repository, "username", fallback=defaults["username"]),
"password": parser.get(repository, "password", fallback=defaults["password"]),
}

return defaults


def _resolve_token(args: argparse.Namespace) -> str:
"""Resolve API token from CLI arg, env var, or .pypirc."""
if args.token:
return args.token

token = os.environ.get("PYPI_TOKEN")
if token:
return token

config = _read_pypirc(getattr(args, "config_file", None), args.repository)
token = config.get("password")
if token:
return token

print(
"Error: No API token found.\n"
"Provide via --token, PYPI_TOKEN env var, or ~/.pypirc [pypi] password field.",
file=sys.stderr,
)
sys.exit(1)
Comment on lines +69 to +88
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I keep my token safely stored encrypted using keyring (here's how twine does it). Ideally, this script would provide that as an option (or maybe I just pass in --token $(keyring get ...)).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My hope is that this script either becomes part of twine, or its own project, and we can sort that out there. It's a demo here because without an api client it's kind of hard to show that it works :)



def _git_remote_url() -> str | None:
"""Return the upstream or origin remote URL of the current git checkout."""
for remote in ("upstream", "origin"):
try:
result = subprocess.run(
["git", "remote", "get-url", remote],
capture_output=True,
text=True,
check=True,
)
return result.stdout.strip()
except subprocess.CalledProcessError:
continue
return None


def _parse_github_remote(url: str) -> tuple[str, str] | None:
"""Parse owner and repo from a GitHub remote URL."""
patterns = [
r"github\.com[:/]([^/]+)/([^/.]+?)(?:\.git)?$",
]
for pattern in patterns:
m = re.search(pattern, url)
if m:
return m.group(1), m.group(2)
return None


def _parse_gitlab_remote(url: str) -> tuple[str, str, str] | None:
"""Parse namespace, project, and host from a GitLab remote URL."""
# Matches both gitlab.com and self-hosted instances
patterns = [
r"(gitlab\.[^:/]+)[:/](.+)/([^/.]+?)(?:\.git)?$",
]
for pattern in patterns:
m = re.search(pattern, url)
if m:
host = f"https://{m.group(1)}"
namespace = m.group(2)
project = m.group(3)
return namespace, project, host
return None


def _list_github_workflows() -> list[str]:
"""List workflow files in .github/workflows/."""
workflows = sorted(glob.glob(".github/workflows/*.yml") + glob.glob(".github/workflows/*.yaml"))
return [Path(w).name for w in workflows]


def _list_gitlab_pipelines() -> list[str]:
"""List GitLab CI pipeline files."""
candidates = [".gitlab-ci.yml", ".gitlab-ci.yaml"]
candidates += sorted(glob.glob("ci/**/*.yml", recursive=True))
candidates += sorted(glob.glob("ci/**/*.yaml", recursive=True))
return [p for p in candidates if Path(p).exists()]


def _prompt_choice(prompt: str, choices: list[str]) -> str:
"""Print numbered choices and prompt user to pick one."""
for i, choice in enumerate(choices, 1):
print(f" {i}. {choice}")
while True:
raw = input(f"{prompt} [1-{len(choices)}]: ").strip()
try:
idx = int(raw) - 1
if 0 <= idx < len(choices):
return choices[idx]
except ValueError:
pass
print(f"Please enter a number between 1 and {len(choices)}.")


def _api_url(base_url: str, project: str, publisher_id: str | None = None) -> str:
base = base_url.rstrip("/")
path = f"/danger-api/projects/{project}/trusted-publishers"
if publisher_id:
path += f"/{publisher_id}"
return base + path


def _call_api(url: str, token: str, payload: dict) -> dict:
body = json.dumps(payload).encode()
req = urllib.request.Request(
url,
data=body,
method="POST",
headers={
"Authorization": f"Bearer {token}",
"Content-Type": CONTENT_TYPE,
"Accept": CONTENT_TYPE,
},
)
try:
with urllib.request.urlopen(req) as resp:
return json.loads(resp.read())
except urllib.error.HTTPError as e:
body = e.read().decode(errors="replace")
try:
detail = json.loads(body)
except Exception:
detail = {"raw": body}
print(f"Error {e.code}: {detail}", file=sys.stderr)
sys.exit(1)


def _print_curl(url: str, token: str, payload: dict, *, expose_token: bool = False) -> None:
body = json.dumps(payload, indent=2)
token_value = token if expose_token else "<YOUR-API-TOKEN>"
print("\nEquivalent curl command:\n")
print(
f"curl -X POST '{url}' \\\n"
f" -H 'Authorization: Bearer {token_value}' \\\n"
f" -H 'Content-Type: {CONTENT_TYPE}' \\\n"
f" -H 'Accept: {CONTENT_TYPE}' \\\n"
f" -d '{body}'"
)


def main() -> None:
parser = argparse.ArgumentParser(
description="Configure a PyPI trusted publisher via the API.",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("project", help="PyPI project name")
parser.add_argument(
"--api-url",
default=DEFAULT_PYPI_URL,
help=f"PyPI API base URL (default: {DEFAULT_PYPI_URL})",
)
parser.add_argument(
"--repository",
default=DEFAULT_REPOSITORY,
help=f"~/.pypirc repository section (default: {DEFAULT_REPOSITORY})",
)
parser.add_argument("--token", help="API token (overrides .pypirc and PYPI_TOKEN)")
parser.add_argument(
"--environment",
default="",
help="CI environment name (optional, e.g. 'release')",
)
parser.add_argument(
"-n", "--dry-run",
action="store_true",
help="Print the curl command instead of calling the API",
)
parser.add_argument(
"--show-token",
action="store_true",
help="Include the real API token in the --dry-run curl output (CAUTION: token will be visible in your shell history and terminal)",
)
parser.add_argument(
"--workflow",
help="Workflow filename (skips interactive selection)",
)
parser.add_argument(
"--config-file",
help="Path to .pypirc (default: ~/.pypirc)",
)
args = parser.parse_args()

token = _resolve_token(args)

remote_url = _git_remote_url()
if not remote_url:
print(
"Error: Not in a git repository or no 'upstream'/'origin' remote found.",
file=sys.stderr,
)
sys.exit(1)

github_info = _parse_github_remote(remote_url)
gitlab_info = _parse_gitlab_remote(remote_url)

if github_info:
owner, repo = github_info
print(f"Detected GitHub repository: {owner}/{repo}")

workflows = _list_github_workflows()
if args.workflow:
workflow = args.workflow
elif not workflows:
print(
"No workflow files found in .github/workflows/.\n"
"Provide --workflow <filename> to specify one manually.",
file=sys.stderr,
)
sys.exit(1)
elif len(workflows) == 1:
workflow = workflows[0]
print(f"Using workflow: {workflow}")
else:
print("\nAvailable workflows:")
workflow = _prompt_choice("Select workflow", workflows)

payload: dict = {
"publisher": "github",
"owner": owner,
"repository": repo,
"workflow_filename": workflow,
}
if args.environment:
payload["environment"] = args.environment

elif gitlab_info:
namespace, project, host = gitlab_info
print(f"Detected GitLab repository: {namespace}/{project} at {host}")

pipelines = _list_gitlab_pipelines()
if args.workflow:
workflow_filepath = args.workflow
elif not pipelines:
print(
"No pipeline files found. "
"Provide --workflow <filepath> to specify one manually.",
file=sys.stderr,
)
sys.exit(1)
elif len(pipelines) == 1:
workflow_filepath = pipelines[0]
print(f"Using pipeline: {workflow_filepath}")
else:
print("\nAvailable pipeline files:")
workflow_filepath = _prompt_choice("Select pipeline file", pipelines)

payload = {
"publisher": "gitlab",
"namespace": namespace,
"project": project,
"workflow_filepath": workflow_filepath,
"issuer_url": host,
}
if args.environment:
payload["environment"] = args.environment

else:
print(
f"Error: Could not detect GitHub or GitLab from remote URL: {remote_url}\n"
"Only GitHub and GitLab are supported by this auto-detection script.\n"
"For Google or ActiveState publishers, call the API directly.",
file=sys.stderr,
)
sys.exit(1)

url = _api_url(args.api_url, args.project)

if args.dry_run:
_print_curl(url, token, payload, expose_token=args.show_token)
return

result = _call_api(url, token, payload)
publisher = result.get("trusted_publisher", {})
print(
f"\nSuccess! Trusted publisher added to {args.project}:\n"
f" Provider: {publisher.get('publisher_name')}\n"
f" Specifier: {publisher.get('specifier')}\n"
f" URL: {publisher.get('publisher_url')}\n"
f" ID: {publisher.get('id')}"
)


if __name__ == "__main__":
main()
Loading