Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions simpletuner/simpletuner_sdk/server/services/cloud/pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,12 @@ def get_default_hardware(self) -> HardwareOption:
REPLICATE_DEFAULT_HARDWARE = {
"gpu-l40s": {
"name": "L40S (48GB)",
"cost_per_second": 0.000975,
"cost_per_second": 0.000972222,
"memory_gb": 48,
},
"gpu-a100-large": {
"name": "A100 (80GB)",
"cost_per_second": 0.001400,
"gpu-h100": {
"name": "H100 (80GB)",
"cost_per_second": 0.001525,
"memory_gb": 80,
},
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,12 @@ async def get_enriched_providers() -> List[Dict[str, Any]]:
- Hardware info and costs
"""
from ...routes.cloud._shared import get_job_store
from .replicate_client import DEFAULT_MODEL, get_default_hardware_cost_per_hour, get_hardware_info_async
from .replicate_client import (
DEFAULT_HARDWARE_INFO,
DEFAULT_MODEL,
get_default_hardware_cost_per_hour,
get_hardware_info_async,
)
from .replicate_profiles import (
DEFAULT_REPLICATE_HARDWARE_PROFILE,
get_replicate_hardware_profile,
Expand All @@ -120,6 +125,25 @@ async def get_enriched_providers() -> List[Dict[str, Any]]:
hardware_info = await get_hardware_info_async(store)
l40s_info = hardware_info.get(definition.default_hardware_id, {})
cost_per_hour = await get_default_hardware_cost_per_hour(store)
profile_options = list_replicate_hardware_profiles()
base_costs = {
"h100": (hardware_info.get("gpu-h100") or DEFAULT_HARDWARE_INFO["gpu-h100"]).get("cost_per_second", 0.001525)
* 3600,
"l40s": (hardware_info.get("gpu-l40s") or DEFAULT_HARDWARE_INFO["gpu-l40s"]).get(
"cost_per_second", 0.000972222
)
* 3600,
}
for option in profile_options:
base_profile = "h100" if option["id"].startswith("h100") else "l40s"
multiplier = 1
if "-x" in option["id"]:
try:
multiplier = int(option["id"].rsplit("-x", 1)[1])
except (TypeError, ValueError):
multiplier = 1
option["cost_per_hour"] = round(base_costs[base_profile] * multiplier, 2)
option["cost_per_second"] = round((base_costs[base_profile] / 3600) * multiplier, 6)

provider_data.update(
{
Expand All @@ -129,7 +153,7 @@ async def get_enriched_providers() -> List[Dict[str, Any]]:
"cost_per_hour": round(cost_per_hour, 2),
"configured": bool(get_secrets_manager().get_replicate_token()),
"hardware_profile": default_profile.id,
"hardware_profiles": list_replicate_hardware_profiles(),
"hardware_profiles": profile_options,
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def _get_credential_resolver():
DEFAULT_MODEL = get_replicate_hardware_profile(DEFAULT_REPLICATE_HARDWARE_PROFILE).model

# Default hardware info for cost estimation (fallback values if not configured)
# These are based on Replicate's published pricing as of 2024
# These are based on Replicate's published pricing.
DEFAULT_HARDWARE_INFO: Dict[str, Dict[str, Any]] = {
"gpu-l40s": {"name": "L40S (48GB)", "cost_per_second": 0.000975},
"gpu-a100-large": {"name": "A100 (80GB)", "cost_per_second": 0.001400},
"gpu-l40s": {"name": "L40S (48GB)", "cost_per_second": 0.000972222},
"gpu-h100": {"name": "H100 (80GB)", "cost_per_second": 0.001525},
}
Comment on lines 43 to 48

# Cached hardware info (loaded from config or defaults)
Expand Down Expand Up @@ -112,7 +112,7 @@ async def get_default_hardware_cost_per_hour(store: Optional[Any] = None) -> flo
"""
hardware = await get_hardware_info_async(store)
l40s = hardware.get("gpu-l40s", DEFAULT_HARDWARE_INFO["gpu-l40s"])
return l40s.get("cost_per_second", 0.000975) * 3600
return l40s.get("cost_per_second", 0.000972222) * 3600


def update_hardware_info_cache(hardware_info: Dict[str, Dict[str, Any]]) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List
from typing import Any, Dict, List

DEFAULT_REPLICATE_HARDWARE_PROFILE = "h100"

Expand Down Expand Up @@ -70,7 +70,7 @@ class ReplicateHardwareProfile:
}


def list_replicate_hardware_profiles() -> List[Dict[str, str]]:
def list_replicate_hardware_profiles() -> List[Dict[str, Any]]:
"""Return profile metadata suitable for API/UI responses."""
return [
{
Expand Down
39 changes: 39 additions & 0 deletions simpletuner/static/js/modules/cloud/job-submission.js
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,45 @@ window.cloudSubmissionMethods = {
];
},

getReplicateBaseHardwareOptions() {
return this.getReplicateHardwareProfiles()
.filter((profile) => profile.id === 'l40s' || profile.id === 'h100')
.sort((a, b) => (a.id === 'l40s' ? -1 : 1));
},

getSelectedReplicateBaseHardware() {
const selected = this.preSubmitModal?.hardwareProfile || this.getDefaultReplicateHardwareProfile();
return String(selected || '').startsWith('l40s') ? 'l40s' : 'h100';
},

setReplicateBaseHardwareProfile(profileId) {
if (!profileId) {
return;
}
this.preSubmitModal.hardwareProfile = profileId;
this.saveHardwareProfile(profileId);
},
Comment on lines +277 to +290

getReplicateBaseHardwareCostDisplay() {
const selected = this.getSelectedReplicateBaseHardware();
const profile = this.getReplicateBaseHardwareOptions().find((option) => option.id === selected);
if (typeof profile?.cost_per_hour === 'number') {
return '$' + profile.cost_per_hour.toFixed(2) + '/hr';
}
if (selected === 'h100') {
return '$5.49/hr';
}
return '$3.50/hr';
},

getReplicateBaseHardwareCostDetail() {
const selected = this.getSelectedReplicateBaseHardware();
if (selected === 'h100') {
return '$0.001525/sec';
}
return '$0.000972/sec';
},
Comment on lines +305 to +316

getDefaultReplicateHardwareProfile() {
const storedProfile = localStorage.getItem('cloud_replicate_hardware_profile');
if (storedProfile) {
Expand Down
42 changes: 35 additions & 7 deletions simpletuner/templates/cloud_tab.html
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,46 @@ <h6 class="mb-0"><i class="fas fa-download me-2"></i>Local Outputs</h6>
<h6 class="mb-0"><i class="fas fa-server me-2"></i>Hardware</h6>
</div>
<div class="cloud-card-body py-2">
<template x-for="provider in providers.filter(p => p.id === activeProvider)" :key="provider.id">
<template x-if="activeProvider === 'replicate'">
<div class="small">
<div class="d-flex justify-content-between mb-1">
<span class="text-muted">GPU</span>
<span class="text-info" x-text="provider.hardware || 'N/A'"></span>
<div class="mb-2">
<span class="text-muted d-block mb-1">Hardware</span>
<div class="d-flex gap-2">
<template x-for="profile in getReplicateBaseHardwareOptions()" :key="profile.id">
<button type="button"
class="btn btn-sm rounded-pill flex-fill"
data-testid="cloud-settings-hardware-profile"
:class="getSelectedReplicateBaseHardware() === profile.id ? 'btn-info text-dark' : 'btn-outline-secondary'"
@click="setReplicateBaseHardwareProfile(profile.id)">
Comment on lines +170 to +175
Comment on lines +170 to +175
<span x-text="profile.label || profile.id.toUpperCase()"></span>
</button>
</template>
</div>
</div>
<div class="d-flex justify-content-between">
<span class="text-muted">Cost</span>
<span x-text="provider.cost_per_hour ? ('$' + provider.cost_per_hour.toFixed(2) + '/hr') : 'N/A'"></span>
<div class="d-flex justify-content-between align-items-start">
<span class="text-muted">Cost (per hour)</span>
<span class="text-end">
<span x-text="getReplicateBaseHardwareCostDisplay()"></span>
<span class="d-block text-muted" style="font-size: 0.7rem;"
x-text="getReplicateBaseHardwareCostDetail()"></span>
</span>
</div>
</div>
</template>
<div x-show="activeProvider !== 'replicate'">
<template x-for="provider in providers.filter(p => p.id === activeProvider)" :key="provider.id">
<div class="small">
<div class="d-flex justify-content-between mb-1">
<span class="text-muted">GPU</span>
<span class="text-info" x-text="provider.hardware || 'N/A'"></span>
</div>
<div class="d-flex justify-content-between">
<span class="text-muted">Cost</span>
<span x-text="provider.cost_per_hour ? ('$' + provider.cost_per_hour.toFixed(2) + '/hr') : 'N/A'"></span>
</div>
</div>
</template>
</div>
</div>
</div>

Expand Down
2 changes: 1 addition & 1 deletion simpletuner/templates/partials/cloud_hero_cta.html
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ <h2 class="hero-title">Get Started with Replicate</h2>
<div class="hero-features">
<div class="hero-feature">
<i class="fas fa-bolt text-warning"></i>
<span>L40S & A100 GPUs</span>
<span>L40S & H100 GPUs</span>
</div>
<div class="hero-feature">
<i class="fas fa-dollar-sign text-success"></i>
Expand Down
22 changes: 19 additions & 3 deletions simpletuner/templates/partials/cloud_onboarding_flow.html
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,27 @@ <h3>What This Will Cost</h3>
<div class="cost-breakdown">
<div class="cost-row">
<span class="cost-label">Hardware</span>
<span class="cost-value">L40S (48GB VRAM)</span>
<span class="cost-value">
<span class="d-flex gap-2 justify-content-end">
<template x-for="profile in getReplicateBaseHardwareOptions()" :key="profile.id">
<button type="button"
class="btn btn-sm rounded-pill"
data-testid="cloud-onboarding-hardware-profile"
:class="getSelectedReplicateBaseHardware() === profile.id ? 'btn-info text-dark' : 'btn-outline-secondary'"
@click="setReplicateBaseHardwareProfile(profile.id)">
Comment on lines +230 to +235
<span x-text="profile.label || profile.id.toUpperCase()"></span>
</button>
</template>
</span>
</span>
</div>
<div class="cost-row">
<span class="cost-label">Rate</span>
<span class="cost-value">~$3.50/hour</span>
<span class="cost-label">Cost (per hour)</span>
<span class="cost-value">
<span x-text="getReplicateBaseHardwareCostDisplay()"></span>
<span class="d-block text-muted" style="font-size: 0.7rem;"
x-text="getReplicateBaseHardwareCostDetail()"></span>
</span>
</div>
<div class="cost-row">
<span class="cost-label">Typical LoRA (2000 steps)</span>
Expand Down
43 changes: 43 additions & 0 deletions tests/js/job_submission.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,49 @@ describe('cloudSubmissionMethods', () => {
'l40s-x8',
]);
});

test('base hardware options expose l40s and h100 buttons', () => {
    // Base and multi-GPU profiles, listed H100-first, so both the filtering
    // and the L40S-first ordering are exercised.
    const hardwareProfiles = [
        { id: 'h100', label: 'H100', cost_per_hour: 5.49 },
        { id: 'h100-x2', label: '2x H100', cost_per_hour: 10.98 },
        { id: 'l40s', label: 'L40S', cost_per_hour: 3.50 },
        { id: 'l40s-x2', label: '2x L40S', cost_per_hour: 7.00 },
    ];
    context.providers = [{ id: 'replicate', hardware_profiles: hardwareProfiles }];

    const optionIds = context.getReplicateBaseHardwareOptions().map(({ id }) => id);

    expect(optionIds).toEqual(['l40s', 'h100']);
});

test('settings hardware selector persists selected base profile', () => {
    const storageKey = 'cloud_replicate_hardware_profile';
    context.preSubmitModal.hardwareProfile = 'h100';

    context.setReplicateBaseHardwareProfile('l40s');

    // The selection must land in both the modal state and localStorage.
    expect(context.preSubmitModal.hardwareProfile).toBe('l40s');
    expect(localStorage.getItem(storageKey)).toBe('l40s');
});

test('hardware hourly cost display follows selected base hardware', () => {
    context.providers = [{
        id: 'replicate',
        hardware_profiles: [
            { id: 'h100', label: 'H100', cost_per_hour: 5.49 },
            { id: 'l40s', label: 'L40S', cost_per_hour: 3.50 },
        ],
    }];

    // [selected profile, expected hourly display, expected per-second detail]
    const expectations = [
        ['h100', '$5.49/hr', '$0.001525/sec'],
        ['l40s', '$3.50/hr', '$0.000972/sec'],
    ];

    for (const [profileId, hourly, perSecond] of expectations) {
        context.preSubmitModal.hardwareProfile = profileId;
        expect(context.getReplicateBaseHardwareCostDisplay()).toBe(hourly);
        expect(context.getReplicateBaseHardwareCostDetail()).toBe(perSecond);
    }
});
});

describe('prepareJobPayload', () => {
Expand Down
6 changes: 6 additions & 0 deletions tests/test_replicate_hardware_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def test_invalid_profile_raises(self):
with self.assertRaises(ValueError):
normalize_replicate_hardware_profile("a10g")

def test_default_hardware_info_includes_h100_and_l40s_pricing(self):
    """Check the fallback per-second prices for the L40S and H100 entries."""
    from simpletuner.simpletuner_sdk.server.services.cloud.replicate_client import DEFAULT_HARDWARE_INFO

    expected_rates = {
        "gpu-l40s": 0.000972222,
        "gpu-h100": 0.001525,
    }
    # Index directly so a missing hardware id or field still fails the test.
    actual_rates = {
        hardware_id: DEFAULT_HARDWARE_INFO[hardware_id]["cost_per_second"]
        for hardware_id in expected_rates
    }

    self.assertEqual(actual_rates, expected_rates)

async def test_replicate_client_uses_profile_model_for_latest_version(self):
from simpletuner.simpletuner_sdk.server.services.cloud.replicate_client import ReplicateCogClient

Expand Down
Loading