diff --git a/simpletuner/simpletuner_sdk/server/services/cloud/pricing.py b/simpletuner/simpletuner_sdk/server/services/cloud/pricing.py index d1e58a30f..865ba02c1 100644 --- a/simpletuner/simpletuner_sdk/server/services/cloud/pricing.py +++ b/simpletuner/simpletuner_sdk/server/services/cloud/pricing.py @@ -245,12 +245,12 @@ def get_default_hardware(self) -> HardwareOption: REPLICATE_DEFAULT_HARDWARE = { "gpu-l40s": { "name": "L40S (48GB)", - "cost_per_second": 0.000975, + "cost_per_second": 0.000972222, "memory_gb": 48, }, - "gpu-a100-large": { - "name": "A100 (80GB)", - "cost_per_second": 0.001400, + "gpu-h100": { + "name": "H100 (80GB)", + "cost_per_second": 0.001525, "memory_gb": 80, }, } diff --git a/simpletuner/simpletuner_sdk/server/services/cloud/provider_registry.py b/simpletuner/simpletuner_sdk/server/services/cloud/provider_registry.py index 96a73cbfa..1fbfb6dce 100644 --- a/simpletuner/simpletuner_sdk/server/services/cloud/provider_registry.py +++ b/simpletuner/simpletuner_sdk/server/services/cloud/provider_registry.py @@ -93,7 +93,12 @@ async def get_enriched_providers() -> List[Dict[str, Any]]: - Hardware info and costs """ from ...routes.cloud._shared import get_job_store - from .replicate_client import DEFAULT_MODEL, get_default_hardware_cost_per_hour, get_hardware_info_async + from .replicate_client import ( + DEFAULT_HARDWARE_INFO, + DEFAULT_MODEL, + get_default_hardware_cost_per_hour, + get_hardware_info_async, + ) from .replicate_profiles import ( DEFAULT_REPLICATE_HARDWARE_PROFILE, get_replicate_hardware_profile, @@ -120,6 +125,25 @@ async def get_enriched_providers() -> List[Dict[str, Any]]: hardware_info = await get_hardware_info_async(store) l40s_info = hardware_info.get(definition.default_hardware_id, {}) cost_per_hour = await get_default_hardware_cost_per_hour(store) + profile_options = list_replicate_hardware_profiles() + base_costs = { + "h100": (hardware_info.get("gpu-h100") or DEFAULT_HARDWARE_INFO["gpu-h100"]).get("cost_per_second", 
0.001525) + * 3600, + "l40s": (hardware_info.get("gpu-l40s") or DEFAULT_HARDWARE_INFO["gpu-l40s"]).get( + "cost_per_second", 0.000972222 + ) + * 3600, + } + for option in profile_options: + base_profile = "h100" if option["id"].startswith("h100") else "l40s" + multiplier = 1 + if "-x" in option["id"]: + try: + multiplier = int(option["id"].rsplit("-x", 1)[1]) + except (TypeError, ValueError): + multiplier = 1 + option["cost_per_hour"] = round(base_costs[base_profile] * multiplier, 2) + option["cost_per_second"] = (base_costs[base_profile] / 3600) * multiplier provider_data.update( { @@ -129,7 +153,7 @@ async def get_enriched_providers() -> List[Dict[str, Any]]: "cost_per_hour": round(cost_per_hour, 2), "configured": bool(get_secrets_manager().get_replicate_token()), "hardware_profile": default_profile.id, - "hardware_profiles": list_replicate_hardware_profiles(), + "hardware_profiles": profile_options, } ) diff --git a/simpletuner/simpletuner_sdk/server/services/cloud/replicate_client.py b/simpletuner/simpletuner_sdk/server/services/cloud/replicate_client.py index bf8fe1508..4680f9684 100644 --- a/simpletuner/simpletuner_sdk/server/services/cloud/replicate_client.py +++ b/simpletuner/simpletuner_sdk/server/services/cloud/replicate_client.py @@ -41,10 +41,10 @@ def _get_credential_resolver(): DEFAULT_MODEL = get_replicate_hardware_profile(DEFAULT_REPLICATE_HARDWARE_PROFILE).model # Default hardware info for cost estimation (fallback values if not configured) -# These are based on Replicate's published pricing as of 2024 +# These are based on Replicate's published pricing. 
DEFAULT_HARDWARE_INFO: Dict[str, Dict[str, Any]] = { - "gpu-l40s": {"name": "L40S (48GB)", "cost_per_second": 0.000975}, - "gpu-a100-large": {"name": "A100 (80GB)", "cost_per_second": 0.001400}, + "gpu-l40s": {"name": "L40S (48GB)", "cost_per_second": 0.000972222}, + "gpu-h100": {"name": "H100 (80GB)", "cost_per_second": 0.001525}, } # Cached hardware info (loaded from config or defaults) @@ -112,7 +112,7 @@ async def get_default_hardware_cost_per_hour(store: Optional[Any] = None) -> flo """ hardware = await get_hardware_info_async(store) l40s = hardware.get("gpu-l40s", DEFAULT_HARDWARE_INFO["gpu-l40s"]) - return l40s.get("cost_per_second", 0.000975) * 3600 + return l40s.get("cost_per_second", 0.000972222) * 3600 def update_hardware_info_cache(hardware_info: Dict[str, Dict[str, Any]]) -> None: diff --git a/simpletuner/simpletuner_sdk/server/services/cloud/replicate_profiles.py b/simpletuner/simpletuner_sdk/server/services/cloud/replicate_profiles.py index 85bdcba24..591e3922d 100644 --- a/simpletuner/simpletuner_sdk/server/services/cloud/replicate_profiles.py +++ b/simpletuner/simpletuner_sdk/server/services/cloud/replicate_profiles.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict, List +from typing import Any, Dict, List DEFAULT_REPLICATE_HARDWARE_PROFILE = "h100" @@ -70,7 +70,7 @@ class ReplicateHardwareProfile: } -def list_replicate_hardware_profiles() -> List[Dict[str, str]]: +def list_replicate_hardware_profiles() -> List[Dict[str, Any]]: """Return profile metadata suitable for API/UI responses.""" return [ { diff --git a/simpletuner/static/js/modules/cloud/job-submission.js b/simpletuner/static/js/modules/cloud/job-submission.js index d48bd8c25..f1c8683f4 100644 --- a/simpletuner/static/js/modules/cloud/job-submission.js +++ b/simpletuner/static/js/modules/cloud/job-submission.js @@ -255,6 +255,66 @@ window.cloudSubmissionMethods = { ]; }, + getReplicateBaseHardwareOptions() { + return 
this.getReplicateHardwareProfiles() + .filter((profile) => profile.id === 'l40s' || profile.id === 'h100') + .sort((a, b) => (a.id === 'l40s' ? 0 : 1) - (b.id === 'l40s' ? 0 : 1)); /* l40s first; equal ids compare as 0 so the comparator stays consistent */ + }, + + getSelectedReplicateBaseHardware() { + const selected = this.preSubmitModal?.hardwareProfile || this.getDefaultReplicateHardwareProfile(); + return String(selected || '').startsWith('l40s') ? 'l40s' : 'h100'; + }, + + getSelectedReplicateHardwareProfile() { + const selected = this.preSubmitModal?.hardwareProfile || this.getDefaultReplicateHardwareProfile(); + const profiles = this.getReplicateHardwareProfiles(); + return profiles.find((profile) => profile.id === selected) || + profiles.find((profile) => profile.id === this.getSelectedReplicateBaseHardware()) || + null; + }, + + setReplicateBaseHardwareProfile(profileId) { + if (!profileId) { + return; + } + const currentProfile = this.preSubmitModal?.hardwareProfile || this.getDefaultReplicateHardwareProfile(); + const multiplier = String(currentProfile || '').match(/-x\d+$/)?.[0] || ''; + const candidateProfile = `${profileId}${multiplier}`; + const profiles = this.getReplicateHardwareProfiles(); + const nextProfile = profiles.some((profile) => profile.id === candidateProfile) + ? 
candidateProfile + : profileId; + if (this.preSubmitModal) { this.preSubmitModal.hardwareProfile = nextProfile; } /* modal state is read with ?. above, so it may be unset; still persist the choice */ + this.saveHardwareProfile(nextProfile); + }, + + getReplicateBaseHardwareCostDisplay() { + const selected = this.getSelectedReplicateBaseHardware(); + const profile = this.getSelectedReplicateHardwareProfile() || + this.getReplicateBaseHardwareOptions().find((option) => option.id === selected); + if (typeof profile?.cost_per_hour === 'number') { + return '$' + profile.cost_per_hour.toFixed(2) + '/hr'; + } + if (selected === 'h100') { + return '$5.49/hr'; + } + return '$3.50/hr'; + }, + + getReplicateBaseHardwareCostDetail() { + const selected = this.getSelectedReplicateBaseHardware(); + const profile = this.getSelectedReplicateHardwareProfile() || + this.getReplicateBaseHardwareOptions().find((option) => option.id === selected); + if (typeof profile?.cost_per_second === 'number') { + return '$' + profile.cost_per_second.toFixed(6) + '/sec'; + } + if (selected === 'h100') { + return '$0.001525/sec'; + } + return '$0.000972/sec'; + }, + getDefaultReplicateHardwareProfile() { const storedProfile = localStorage.getItem('cloud_replicate_hardware_profile'); if (storedProfile) { diff --git a/simpletuner/templates/cloud_tab.html b/simpletuner/templates/cloud_tab.html index e262b3936..1a52e32f1 100644 --- a/simpletuner/templates/cloud_tab.html +++ b/simpletuner/templates/cloud_tab.html @@ -161,18 +161,47 @@