Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion simpletuner/simpletuner_sdk/server/services/tab_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,12 @@ def _git_mirror_enabled(self) -> bool:
def _cloud_tab_enabled(self) -> bool:
try:
defaults = WebUIStateStore().load_defaults()
return bool(getattr(defaults, "cloud_tab_enabled", True))
value = getattr(defaults, "cloud_tab_enabled", True)
if value is None:
return True
if isinstance(value, str):
return value.strip().lower() not in {"0", "false", "no", "off"}
return bool(value)
except Exception as exc:
logger.debug("Failed to evaluate cloud tab enabled flag: %s", exc, exc_info=True)
return True
Expand Down
16 changes: 14 additions & 2 deletions simpletuner/simpletuner_sdk/server/services/webui_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,13 @@ def load_defaults(self) -> WebUIDefaults:
payload = self._read_json("defaults")
if not payload:
return WebUIDefaults()
base_defaults = WebUIDefaults()
data: Dict[str, Any] = {}
for key in WebUIDefaults().__dict__.keys():
for key, default_value in base_defaults.__dict__.items():
if key == "accelerate_overrides":
data[key] = _normalise_accelerate_overrides(payload.get(key))
else:
data[key] = payload.get(key)
data[key] = payload.get(key, default_value)
defaults = WebUIDefaults(**data)

# Normalise theme selection
Expand Down Expand Up @@ -429,6 +430,17 @@ def load_defaults(self) -> WebUIDefaults:
defaults.cloud_dataloader_hint_dismissed = bool(payload.get("cloud_dataloader_hint_dismissed", False))
defaults.cloud_git_hint_dismissed = bool(payload.get("cloud_git_hint_dismissed", False))

# Normalise cloud tab enabled (default True)
cloud_tab_value = payload.get("cloud_tab_enabled")
if cloud_tab_value is None:
defaults.cloud_tab_enabled = True
elif isinstance(cloud_tab_value, bool):
defaults.cloud_tab_enabled = cloud_tab_value
elif isinstance(cloud_tab_value, str):
defaults.cloud_tab_enabled = cloud_tab_value.strip().lower() not in {"0", "false", "no", "off"}
else:
defaults.cloud_tab_enabled = bool(cloud_tab_value)

# Normalise cloud data consent
consent_value = payload.get("cloud_data_consent")
if isinstance(consent_value, str) and consent_value in {"ask", "allow", "deny"}:
Expand Down
6 changes: 4 additions & 2 deletions simpletuner/static/js/modules/cloud/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,8 @@ if (!window.cloudDashboardComponent) {
this.setupStatus.outputConfigured =
this.publishingStatus.push_to_hub ||
this.publishingStatus.s3_configured ||
(this.webhookUrl && this.webhookUrl.trim().length > 0);
(this.savedWebhookUrl && this.savedWebhookUrl.trim().length > 0) ||
this.publishingStatus.local_upload_available;
},

// Note: hasDatasets getter moved to final return object
Expand Down Expand Up @@ -677,7 +678,8 @@ if (!window.cloudDashboardComponent) {
if (!this.publishingStatus) return false;
return this.publishingStatus.push_to_hub ||
this.publishingStatus.s3_configured ||
(this.webhookUrl && this.webhookUrl.trim().length > 0);
(this.savedWebhookUrl && this.savedWebhookUrl.trim().length > 0) ||
this.publishingStatus.local_upload_available;
},
get allSetupComplete() {
return this.hasDatasets && this.hasActiveConfig && this.hasOutputDestination;
Expand Down
29 changes: 26 additions & 3 deletions simpletuner/static/js/modules/cloud/metrics.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,42 @@ window.cloudMetricsMethods = {

async saveWebhookConfig() {
this.configSaving = true;
const webhookUrl = typeof this.webhookUrl === 'string' ? this.webhookUrl.trim() : '';
try {
const response = await fetch('/api/cloud/providers/replicate/config', {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ webhook_url: this.webhookUrl || null }),
body: JSON.stringify({ webhook_url: webhookUrl }),
});
if (response.ok && window.showToast) {

let data = {};
try {
data = await response.json();
} catch (_) {}

if (!response.ok) {
throw new Error(data.detail || 'Failed to save webhook config');
}

const savedUrl = (data.config && data.config.webhook_url) || data.webhook_url || webhookUrl || '';
this.savedWebhookUrl = savedUrl;
this.webhookUrl = savedUrl;
if (this.publishingStatus) {
this.publishingStatus.local_upload_available = savedUrl.length > 0;
if (!savedUrl) {
this.publishingStatus.local_upload_dir = null;
}
}

if (window.showToast) {
window.showToast('Webhook configuration saved', 'success');
}
return true;
} catch (error) {
if (window.showToast) {
window.showToast('Failed to save webhook config', 'error');
window.showToast(error.message || 'Failed to save webhook config', 'error');
}
return false;
} finally {
this.configSaving = false;
}
Expand Down
5 changes: 3 additions & 2 deletions simpletuner/static/js/modules/cloud/providers.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ window.cloudProviderMethods = {
const response = await fetch(`/api/cloud/providers/${this.activeProvider}/config`);
if (response.ok) {
const data = await response.json();
this.providerConfig = data || {};
this.webhookUrl = data.webhook_url || '';
this.providerConfig = data.config || data || {};
this.savedWebhookUrl = this.providerConfig.webhook_url || '';
this.webhookUrl = this.savedWebhookUrl;
this.loadCostLimitStatus();
}
} catch (error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ window.cloudPublishingStateFactory = function(initial) {
availableConfigs: [],
selectedConfigName: null,
webhookUrl: initialData.webhook_url || '',
savedWebhookUrl: initialData.webhook_url || '',
webhookTesting: false,
webhookTestMode: null,
webhookTestResult: null,
Expand Down
2 changes: 1 addition & 1 deletion simpletuner/templates/cloud_tab.html
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ <h6 class="mb-0"><i class="fas fa-download me-2"></i>Local Outputs</h6>
<div class="mb-2">
<label class="small text-muted mb-1 d-block" style="font-size: 0.7rem;">Endpoint URL</label>
<code class="d-block small px-2 py-1 rounded" style="background: rgba(0,0,0,0.2); word-break: break-all; font-size: 0.7rem;"
x-text="webhookUrl.replace(/\/$/, '') + '/api/cloud/storage'"></code>
x-text="savedWebhookUrl.replace(/\/$/, '') + '/api/cloud/storage'"></code>
</div>
<div>
<label class="small text-muted mb-1 d-block" style="font-size: 0.7rem;">Output Directory</label>
Expand Down
2 changes: 1 addition & 1 deletion simpletuner/templates/partials/cloud_dataloader_hint.html
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ <h6 class="mb-0">
<template x-if="publishingStatus.s3_configured && !publishingStatus.push_to_hub">
<span class="text-warning"><i class="fab fa-aws me-1"></i>S3/Cloud Storage</span>
</template>
<template x-if="webhookUrl && webhookUrl.trim().length > 0 && !publishingStatus.push_to_hub && !publishingStatus.s3_configured">
<template x-if="savedWebhookUrl && savedWebhookUrl.trim().length > 0 && !publishingStatus.push_to_hub && !publishingStatus.s3_configured">
<span class="text-info"><i class="fas fa-link me-1"></i>Webhook to local machine</span>
</template>
</div>
Expand Down
4 changes: 2 additions & 2 deletions simpletuner/templates/partials/cloud_onboarding_flow.html
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ <h3>How You Get Your Model Back</h3>
Replicate uploads the model back here. Requires exposing your server via ngrok or cloudflare tunnel.
</p>
<div class="option-status">
<template x-if="webhookUrl && webhookUrl.trim().length > 0">
<template x-if="savedWebhookUrl && savedWebhookUrl.trim().length > 0">
<span class="text-success small"><i class="fas fa-check-circle me-1"></i>Webhook configured</span>
</template>
<template x-if="!webhookUrl || webhookUrl.trim().length === 0">
<template x-if="!savedWebhookUrl || savedWebhookUrl.trim().length === 0">
<button type="button"
class="btn btn-sm btn-outline-secondary"
@click="showSettingsPanel = true">
Expand Down
101 changes: 101 additions & 0 deletions tests/js/cloud_metrics.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/**
* Tests for cloud metrics and webhook configuration methods.
*/

global.fetch = jest.fn();

global.console = {
...console,
error: jest.fn(),
};

global.showToast = jest.fn();
window.showToast = global.showToast;

require('../../simpletuner/static/js/modules/cloud/metrics.js');

describe('cloudMetricsMethods webhook configuration', () => {
let context;

beforeEach(() => {
jest.resetAllMocks();
fetch.mockReset();
context = {
webhookUrl: '',
savedWebhookUrl: '',
configSaving: false,
publishingStatus: {
local_upload_available: false,
local_upload_dir: null,
},
};
});

test('saves valid webhook and marks local upload configured after success', async () => {
context.webhookUrl = ' https://webhook.example.com ';
fetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
provider: 'replicate',
config: { webhook_url: 'https://webhook.example.com' },
}),
});

const result = await window.cloudMetricsMethods.saveWebhookConfig.call(context);

expect(result).toBe(true);
expect(fetch).toHaveBeenCalledWith('/api/cloud/providers/replicate/config', {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ webhook_url: 'https://webhook.example.com' }),
});
expect(context.savedWebhookUrl).toBe('https://webhook.example.com');
expect(context.webhookUrl).toBe('https://webhook.example.com');
expect(context.publishingStatus.local_upload_available).toBe(true);
expect(window.showToast).toHaveBeenCalledWith('Webhook configuration saved', 'success');
});

test('keeps draft webhook visible when save fails validation', async () => {
context.webhookUrl = 'h';
context.savedWebhookUrl = '';
fetch.mockResolvedValueOnce({
ok: false,
json: () => Promise.resolve({ detail: 'Invalid webhook URL: Invalid URL format' }),
});

const result = await window.cloudMetricsMethods.saveWebhookConfig.call(context);

expect(result).toBe(false);
expect(context.webhookUrl).toBe('h');
expect(context.savedWebhookUrl).toBe('');
expect(context.publishingStatus.local_upload_available).toBe(false);
expect(window.showToast).toHaveBeenCalledWith('Invalid webhook URL: Invalid URL format', 'error');
});

test('sends empty string when clearing webhook configuration', async () => {
context.webhookUrl = '';
context.savedWebhookUrl = 'https://old-webhook.example.com';
context.publishingStatus.local_upload_available = true;
context.publishingStatus.local_upload_dir = '/tmp/outputs';
fetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
provider: 'replicate',
config: { webhook_url: null },
}),
});

const result = await window.cloudMetricsMethods.saveWebhookConfig.call(context);

expect(result).toBe(true);
expect(fetch).toHaveBeenCalledWith('/api/cloud/providers/replicate/config', {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ webhook_url: '' }),
});
expect(context.savedWebhookUrl).toBe('');
expect(context.webhookUrl).toBe('');
expect(context.publishingStatus.local_upload_available).toBe(false);
expect(context.publishingStatus.local_upload_dir).toBeNull();
});
});
23 changes: 23 additions & 0 deletions tests/js/cloud_providers.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ describe('cloudProviderMethods', () => {
activeProvider: 'replicate',
providerConfig: {},
webhookUrl: '',
savedWebhookUrl: '',
costLimit: {
loading: false,
saving: false,
Expand Down Expand Up @@ -155,6 +156,28 @@ describe('cloudProviderMethods', () => {

expect(fetch).toHaveBeenCalledWith('/api/cloud/providers/replicate/config');
expect(context.webhookUrl).toBe('https://webhook.example.com');
expect(context.savedWebhookUrl).toBe('https://webhook.example.com');
});

test('loads webhook URL from wrapped provider config response', async () => {
fetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
provider: 'replicate',
config: {
webhook_url: 'https://wrapped-webhook.example.com',
version_override: 'v2.0.0',
},
}),
});

context.loadCostLimitStatus = jest.fn();

await window.cloudProviderMethods.loadProviderConfig.call(context);

expect(context.providerConfig.webhook_url).toBe('https://wrapped-webhook.example.com');
expect(context.webhookUrl).toBe('https://wrapped-webhook.example.com');
expect(context.savedWebhookUrl).toBe('https://wrapped-webhook.example.com');
});

test('calls loadCostLimitStatus after loading config', async () => {
Expand Down
28 changes: 28 additions & 0 deletions tests/test_git_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,34 @@ def test_git_tab_gating(self) -> None:
self.assertIsNotNone(config)


class TabServiceCloudGatingTests(unittest.TestCase):
def test_cloud_tab_enabled_when_setting_is_missing_or_null(self) -> None:
dummy_templates = SimpleNamespace(TemplateResponse=lambda request, name, context: {"name": name, "context": context})
with patch("simpletuner.simpletuner_sdk.server.services.tab_service.WebUIStateStore") as mock_store_cls:
defaults = SimpleNamespace(cloud_tab_enabled=None)
mock_store = mock_store_cls.return_value
mock_store.load_defaults.return_value = defaults
service = TabService(dummy_templates) # type: ignore[arg-type]

tabs = service.get_all_tabs()
self.assertTrue(any(tab["name"] == "cloud" for tab in tabs))
config = service.get_tab_config("cloud")
self.assertIsNotNone(config)

def test_cloud_tab_disabled_when_setting_is_false(self) -> None:
dummy_templates = SimpleNamespace(TemplateResponse=lambda request, name, context: {"name": name, "context": context})
with patch("simpletuner.simpletuner_sdk.server.services.tab_service.WebUIStateStore") as mock_store_cls:
defaults = SimpleNamespace(cloud_tab_enabled=False)
mock_store = mock_store_cls.return_value
mock_store.load_defaults.return_value = defaults
service = TabService(dummy_templates) # type: ignore[arg-type]

tabs = service.get_all_tabs()
self.assertFalse(any(tab["name"] == "cloud" for tab in tabs))
with self.assertRaises(HTTPException):
service.get_tab_config("cloud")


class GitRepoServiceBranchTests(unittest.TestCase):
def setUp(self) -> None:
self._config_store_instances = ConfigStore._instances.copy()
Expand Down
28 changes: 28 additions & 0 deletions tests/test_webui_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""unittest-based coverage for WebUI backend state management."""

import json
import os
import tempfile
import unittest
Expand Down Expand Up @@ -106,6 +107,33 @@ def test_load_defaults_returns_empty_when_missing(self) -> None:
self.assertIsNone(loaded.configs_dir)
self.assertIsNone(loaded.output_dir)
self.assertIsNone(loaded.active_config)
self.assertTrue(loaded.cloud_tab_enabled)

def test_legacy_defaults_without_cloud_tab_enabled_keeps_cloud_enabled(self) -> None:
defaults_file = self.store.base_dir / "defaults.json"
defaults_file.write_text(json.dumps({"theme": "dark"}))

loaded = self.store.load_defaults()
bundle = self.store.resolve_defaults(loaded)

self.assertTrue(loaded.cloud_tab_enabled)
self.assertTrue(bundle["resolved"]["cloud_tab_enabled"])

def test_null_cloud_tab_enabled_keeps_default_enabled(self) -> None:
defaults_file = self.store.base_dir / "defaults.json"
defaults_file.write_text(json.dumps({"cloud_tab_enabled": None}))

loaded = self.store.load_defaults()

self.assertTrue(loaded.cloud_tab_enabled)

def test_false_cloud_tab_enabled_is_preserved(self) -> None:
defaults_file = self.store.base_dir / "defaults.json"
defaults_file.write_text(json.dumps({"cloud_tab_enabled": False}))

loaded = self.store.load_defaults()

self.assertFalse(loaded.cloud_tab_enabled)

def test_load_onboarding_returns_empty_when_missing(self) -> None:
loaded = self.store.load_onboarding()
Expand Down
Loading
Loading