Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion simpletuner/simpletuner_sdk/server/services/tab_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,12 @@ def _git_mirror_enabled(self) -> bool:
def _cloud_tab_enabled(self) -> bool:
try:
defaults = WebUIStateStore().load_defaults()
return bool(getattr(defaults, "cloud_tab_enabled", True))
value = getattr(defaults, "cloud_tab_enabled", True)
if value is None:
return True
if isinstance(value, str):
return value.strip().lower() not in {"0", "false", "no", "off"}
return bool(value)
except Exception as exc:
logger.debug("Failed to evaluate cloud tab enabled flag: %s", exc, exc_info=True)
return True
Expand Down
16 changes: 14 additions & 2 deletions simpletuner/simpletuner_sdk/server/services/webui_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,13 @@ def load_defaults(self) -> WebUIDefaults:
payload = self._read_json("defaults")
if not payload:
return WebUIDefaults()
base_defaults = WebUIDefaults()
data: Dict[str, Any] = {}
for key in WebUIDefaults().__dict__.keys():
for key, default_value in base_defaults.__dict__.items():
if key == "accelerate_overrides":
data[key] = _normalise_accelerate_overrides(payload.get(key))
else:
data[key] = payload.get(key)
data[key] = payload.get(key, default_value)
defaults = WebUIDefaults(**data)

# Normalise theme selection
Expand Down Expand Up @@ -429,6 +430,17 @@ def load_defaults(self) -> WebUIDefaults:
defaults.cloud_dataloader_hint_dismissed = bool(payload.get("cloud_dataloader_hint_dismissed", False))
defaults.cloud_git_hint_dismissed = bool(payload.get("cloud_git_hint_dismissed", False))

# Normalise cloud tab enabled (default True)
cloud_tab_value = payload.get("cloud_tab_enabled")
if cloud_tab_value is None:
defaults.cloud_tab_enabled = True
elif isinstance(cloud_tab_value, bool):
defaults.cloud_tab_enabled = cloud_tab_value
elif isinstance(cloud_tab_value, str):
defaults.cloud_tab_enabled = cloud_tab_value.strip().lower() not in {"0", "false", "no", "off"}
else:
defaults.cloud_tab_enabled = bool(cloud_tab_value)

# Normalise cloud data consent
consent_value = payload.get("cloud_data_consent")
if isinstance(consent_value, str) and consent_value in {"ask", "allow", "deny"}:
Expand Down
6 changes: 4 additions & 2 deletions simpletuner/static/js/modules/cloud/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,8 @@ if (!window.cloudDashboardComponent) {
this.setupStatus.outputConfigured =
this.publishingStatus.push_to_hub ||
this.publishingStatus.s3_configured ||
(this.webhookUrl && this.webhookUrl.trim().length > 0);
(this.savedWebhookUrl && this.savedWebhookUrl.trim().length > 0) ||
this.publishingStatus.local_upload_available;
},

// Note: hasDatasets getter moved to final return object
Expand Down Expand Up @@ -677,7 +678,8 @@ if (!window.cloudDashboardComponent) {
if (!this.publishingStatus) return false;
return this.publishingStatus.push_to_hub ||
this.publishingStatus.s3_configured ||
(this.webhookUrl && this.webhookUrl.trim().length > 0);
(this.savedWebhookUrl && this.savedWebhookUrl.trim().length > 0) ||
this.publishingStatus.local_upload_available;
},
get allSetupComplete() {
return this.hasDatasets && this.hasActiveConfig && this.hasOutputDestination;
Expand Down
29 changes: 26 additions & 3 deletions simpletuner/static/js/modules/cloud/metrics.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,42 @@ window.cloudMetricsMethods = {

async saveWebhookConfig() {
this.configSaving = true;
const webhookUrl = typeof this.webhookUrl === 'string' ? this.webhookUrl.trim() : '';
try {
const response = await fetch('/api/cloud/providers/replicate/config', {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ webhook_url: this.webhookUrl || null }),
body: JSON.stringify({ webhook_url: webhookUrl }),
});
if (response.ok && window.showToast) {

let data = {};
try {
data = await response.json();
} catch (_) {}

if (!response.ok) {
throw new Error(data.detail || 'Failed to save webhook config');
}

const savedUrl = (data.config && data.config.webhook_url) || data.webhook_url || webhookUrl || '';
this.savedWebhookUrl = savedUrl;
this.webhookUrl = savedUrl;
if (this.publishingStatus) {
this.publishingStatus.local_upload_available = savedUrl.length > 0;
if (!savedUrl) {
this.publishingStatus.local_upload_dir = null;
}
}

if (window.showToast) {
window.showToast('Webhook configuration saved', 'success');
}
return true;
} catch (error) {
if (window.showToast) {
window.showToast('Failed to save webhook config', 'error');
window.showToast(error.message || 'Failed to save webhook config', 'error');
}
return false;
} finally {
this.configSaving = false;
}
Expand Down
5 changes: 3 additions & 2 deletions simpletuner/static/js/modules/cloud/providers.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ window.cloudProviderMethods = {
const response = await fetch(`/api/cloud/providers/${this.activeProvider}/config`);
if (response.ok) {
const data = await response.json();
this.providerConfig = data || {};
this.webhookUrl = data.webhook_url || '';
this.providerConfig = data.config || data || {};
this.savedWebhookUrl = this.providerConfig.webhook_url || '';
this.webhookUrl = this.savedWebhookUrl;
this.loadCostLimitStatus();
}
} catch (error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ window.cloudPublishingStateFactory = function(initial) {
availableConfigs: [],
selectedConfigName: null,
webhookUrl: initialData.webhook_url || '',
savedWebhookUrl: initialData.webhook_url || '',
webhookTesting: false,
webhookTestMode: null,
webhookTestResult: null,
Expand Down
2 changes: 1 addition & 1 deletion simpletuner/templates/cloud_tab.html
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ <h6 class="mb-0"><i class="fas fa-download me-2"></i>Local Outputs</h6>
<div class="mb-2">
<label class="small text-muted mb-1 d-block" style="font-size: 0.7rem;">Endpoint URL</label>
<code class="d-block small px-2 py-1 rounded" style="background: rgba(0,0,0,0.2); word-break: break-all; font-size: 0.7rem;"
x-text="webhookUrl.replace(/\/$/, '') + '/api/cloud/storage'"></code>
x-text="savedWebhookUrl.replace(/\/$/, '') + '/api/cloud/storage'"></code>
</div>
<div>
<label class="small text-muted mb-1 d-block" style="font-size: 0.7rem;">Output Directory</label>
Expand Down
2 changes: 1 addition & 1 deletion simpletuner/templates/partials/cloud_dataloader_hint.html
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ <h6 class="mb-0">
<template x-if="publishingStatus.s3_configured && !publishingStatus.push_to_hub">
<span class="text-warning"><i class="fab fa-aws me-1"></i>S3/Cloud Storage</span>
</template>
<template x-if="webhookUrl && webhookUrl.trim().length > 0 && !publishingStatus.push_to_hub && !publishingStatus.s3_configured">
<template x-if="savedWebhookUrl && savedWebhookUrl.trim().length > 0 && !publishingStatus.push_to_hub && !publishingStatus.s3_configured">
<span class="text-info"><i class="fas fa-link me-1"></i>Webhook to local machine</span>
</template>
</div>
Expand Down
4 changes: 2 additions & 2 deletions simpletuner/templates/partials/cloud_onboarding_flow.html
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ <h3>How You Get Your Model Back</h3>
Replicate uploads the model back here. Requires exposing your server via ngrok or cloudflare tunnel.
</p>
<div class="option-status">
<template x-if="webhookUrl && webhookUrl.trim().length > 0">
<template x-if="savedWebhookUrl && savedWebhookUrl.trim().length > 0">
<span class="text-success small"><i class="fas fa-check-circle me-1"></i>Webhook configured</span>
</template>
<template x-if="!webhookUrl || webhookUrl.trim().length === 0">
<template x-if="!savedWebhookUrl || savedWebhookUrl.trim().length === 0">
<button type="button"
class="btn btn-sm btn-outline-secondary"
@click="showSettingsPanel = true">
Expand Down
101 changes: 101 additions & 0 deletions tests/js/cloud_metrics.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/**
* Tests for cloud metrics and webhook configuration methods.
*/

global.fetch = jest.fn();

global.console = {
...console,
error: jest.fn(),
};

global.showToast = jest.fn();
window.showToast = global.showToast;

require('../../simpletuner/static/js/modules/cloud/metrics.js');

describe('cloudMetricsMethods webhook configuration', () => {
let context;

beforeEach(() => {
jest.resetAllMocks();
fetch.mockReset();
context = {
webhookUrl: '',
savedWebhookUrl: '',
configSaving: false,
publishingStatus: {
local_upload_available: false,
local_upload_dir: null,
},
};
});

test('saves valid webhook and marks local upload configured after success', async () => {
context.webhookUrl = ' https://webhook.example.com ';
fetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
provider: 'replicate',
config: { webhook_url: 'https://webhook.example.com' },
}),
});

const result = await window.cloudMetricsMethods.saveWebhookConfig.call(context);

expect(result).toBe(true);
expect(fetch).toHaveBeenCalledWith('/api/cloud/providers/replicate/config', {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ webhook_url: 'https://webhook.example.com' }),
});
expect(context.savedWebhookUrl).toBe('https://webhook.example.com');
expect(context.webhookUrl).toBe('https://webhook.example.com');
expect(context.publishingStatus.local_upload_available).toBe(true);
expect(window.showToast).toHaveBeenCalledWith('Webhook configuration saved', 'success');
});

test('keeps draft webhook visible when save fails validation', async () => {
context.webhookUrl = 'h';
context.savedWebhookUrl = '';
fetch.mockResolvedValueOnce({
ok: false,
json: () => Promise.resolve({ detail: 'Invalid webhook URL: Invalid URL format' }),
});

const result = await window.cloudMetricsMethods.saveWebhookConfig.call(context);

expect(result).toBe(false);
expect(context.webhookUrl).toBe('h');
expect(context.savedWebhookUrl).toBe('');
expect(context.publishingStatus.local_upload_available).toBe(false);
expect(window.showToast).toHaveBeenCalledWith('Invalid webhook URL: Invalid URL format', 'error');
});

test('sends empty string when clearing webhook configuration', async () => {
context.webhookUrl = '';
context.savedWebhookUrl = 'https://old-webhook.example.com';
context.publishingStatus.local_upload_available = true;
context.publishingStatus.local_upload_dir = '/tmp/outputs';
fetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
provider: 'replicate',
config: { webhook_url: null },
}),
});

const result = await window.cloudMetricsMethods.saveWebhookConfig.call(context);

expect(result).toBe(true);
expect(fetch).toHaveBeenCalledWith('/api/cloud/providers/replicate/config', {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ webhook_url: '' }),
});
expect(context.savedWebhookUrl).toBe('');
expect(context.webhookUrl).toBe('');
expect(context.publishingStatus.local_upload_available).toBe(false);
expect(context.publishingStatus.local_upload_dir).toBeNull();
});
});
23 changes: 23 additions & 0 deletions tests/js/cloud_providers.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ describe('cloudProviderMethods', () => {
activeProvider: 'replicate',
providerConfig: {},
webhookUrl: '',
savedWebhookUrl: '',
costLimit: {
loading: false,
saving: false,
Expand Down Expand Up @@ -155,6 +156,28 @@ describe('cloudProviderMethods', () => {

expect(fetch).toHaveBeenCalledWith('/api/cloud/providers/replicate/config');
expect(context.webhookUrl).toBe('https://webhook.example.com');
expect(context.savedWebhookUrl).toBe('https://webhook.example.com');
});

test('loads webhook URL from wrapped provider config response', async () => {
fetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
provider: 'replicate',
config: {
webhook_url: 'https://wrapped-webhook.example.com',
version_override: 'v2.0.0',
},
}),
});

context.loadCostLimitStatus = jest.fn();

await window.cloudProviderMethods.loadProviderConfig.call(context);

expect(context.providerConfig.webhook_url).toBe('https://wrapped-webhook.example.com');
expect(context.webhookUrl).toBe('https://wrapped-webhook.example.com');
expect(context.savedWebhookUrl).toBe('https://wrapped-webhook.example.com');
});

test('calls loadCostLimitStatus after loading config', async () => {
Expand Down
28 changes: 28 additions & 0 deletions tests/test_git_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,34 @@ def test_git_tab_gating(self) -> None:
self.assertIsNotNone(config)


class TabServiceCloudGatingTests(unittest.TestCase):
def test_cloud_tab_enabled_when_setting_is_missing_or_null(self) -> None:
dummy_templates = SimpleNamespace(TemplateResponse=lambda request, name, context: {"name": name, "context": context})
with patch("simpletuner.simpletuner_sdk.server.services.tab_service.WebUIStateStore") as mock_store_cls:
defaults = SimpleNamespace(cloud_tab_enabled=None)
mock_store = mock_store_cls.return_value
mock_store.load_defaults.return_value = defaults
service = TabService(dummy_templates) # type: ignore[arg-type]

tabs = service.get_all_tabs()
self.assertTrue(any(tab["name"] == "cloud" for tab in tabs))
config = service.get_tab_config("cloud")
self.assertIsNotNone(config)

def test_cloud_tab_disabled_when_setting_is_false(self) -> None:
dummy_templates = SimpleNamespace(TemplateResponse=lambda request, name, context: {"name": name, "context": context})
with patch("simpletuner.simpletuner_sdk.server.services.tab_service.WebUIStateStore") as mock_store_cls:
defaults = SimpleNamespace(cloud_tab_enabled=False)
mock_store = mock_store_cls.return_value
mock_store.load_defaults.return_value = defaults
service = TabService(dummy_templates) # type: ignore[arg-type]

tabs = service.get_all_tabs()
self.assertFalse(any(tab["name"] == "cloud" for tab in tabs))
with self.assertRaises(HTTPException):
service.get_tab_config("cloud")


class GitRepoServiceBranchTests(unittest.TestCase):
def setUp(self) -> None:
self._config_store_instances = ConfigStore._instances.copy()
Expand Down
28 changes: 28 additions & 0 deletions tests/test_webui_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""unittest-based coverage for WebUI backend state management."""

import json
import os
import tempfile
import unittest
Expand Down Expand Up @@ -106,6 +107,33 @@ def test_load_defaults_returns_empty_when_missing(self) -> None:
self.assertIsNone(loaded.configs_dir)
self.assertIsNone(loaded.output_dir)
self.assertIsNone(loaded.active_config)
self.assertTrue(loaded.cloud_tab_enabled)

def test_legacy_defaults_without_cloud_tab_enabled_keeps_cloud_enabled(self) -> None:
defaults_file = self.store.base_dir / "defaults.json"
defaults_file.write_text(json.dumps({"theme": "dark"}))

loaded = self.store.load_defaults()
bundle = self.store.resolve_defaults(loaded)

self.assertTrue(loaded.cloud_tab_enabled)
self.assertTrue(bundle["resolved"]["cloud_tab_enabled"])

def test_null_cloud_tab_enabled_keeps_default_enabled(self) -> None:
defaults_file = self.store.base_dir / "defaults.json"
defaults_file.write_text(json.dumps({"cloud_tab_enabled": None}))

loaded = self.store.load_defaults()

self.assertTrue(loaded.cloud_tab_enabled)

def test_false_cloud_tab_enabled_is_preserved(self) -> None:
defaults_file = self.store.base_dir / "defaults.json"
defaults_file.write_text(json.dumps({"cloud_tab_enabled": False}))

loaded = self.store.load_defaults()

self.assertFalse(loaded.cloud_tab_enabled)

def test_load_onboarding_returns_empty_when_missing(self) -> None:
loaded = self.store.load_onboarding()
Expand Down
Loading
Loading