Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 107 additions & 84 deletions plugins/stashAI/stashai.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
(function () {
"use strict";

let STASHMARKER_API_URL = "https://cc1234-stashtag.hf.space/api/predict";
let STASHMARKER_API_URL = "https://cc1234-stashtag-onnx.hf.space/gradio_api/call/predict_tags";
let STASHMARKER_API_MARKER = "https://cc1234-stashtag-onnx.hf.space/gradio_api/call/predict_markers";

var OPTIONS = [
"Anal",
Expand Down Expand Up @@ -2543,6 +2544,80 @@
});
}

async function gradioCall(url, image, vtt, threshold, retries = 3) {
for (let attempt = 0; attempt < retries; attempt++) {
try {
return await _gradioCall(url, image, vtt, threshold);
} catch (err) {
if (attempt === retries - 1) throw err;
await new Promise((r) => setTimeout(r, 3000));
}
}
}

async function _gradioCall(url, image, vtt, threshold) {
const body = {
data: [
{ url: image, meta: { _type: "gradio.FileData" } },
vtt,
threshold,
],
};

const response = await fetch(url, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(body),
});

if (!response.ok) {
throw new Error("HTTP " + response.status);
}

const { event_id } = await response.json();
const sseUrl = url + "/" + event_id;

const controller = new AbortController();
const timeout = setTimeout(() => controller.abort(), 120000);

let text;
try {
const resp = await fetch(sseUrl, { signal: controller.signal });
if (!resp.ok) {
throw new Error("HTTP " + resp.status);
}
text = await resp.text();
} finally {
clearTimeout(timeout);
}

let currentEvent = "";
let currentData = "";

for (const line of text.split("\n")) {
if (line.startsWith("event: ")) {
currentEvent = line.slice(7).trim();
} else if (line.startsWith("data: ")) {
currentData = line.slice(6);
} else if (line === "") {
if (currentEvent === "complete") {
try {
return JSON.parse(currentData);
} catch (e) {
throw new Error("Failed to parse result");
}
}
if (currentEvent === "error") {
throw new Error(currentData || "API error");
}
currentEvent = "";
currentData = "";
}
}

throw new Error("No result received");
}

function instance$3($$self, $$props, $$invalidate) {
let { $$slots: slots = {}, $$scope } = $$props;
validate_slots("MarkerButton", slots, []);
Expand All @@ -2569,51 +2644,23 @@

let vtt = await download(vtt_url);

// query the api with a threshold of 0.4 as we want to do the filtering ourselves
var data = { data: [image, vtt, 0.4] };

fetch(STASHMARKER_API_URL + "_1", {
method: "POST",
headers: {
"Content-Type": "application/json; charset=utf-8",
},
body: JSON.stringify(data),
})
.then((response) => {
if (response.status !== 200) {
$$invalidate(0, (scanner = false));
alert(
"Something went wrong. It's likely a server issue, Please try again later."
);
return;
}

return response.json();
})
.then((data) => {
$$invalidate(0, (scanner = false));
let frames = data.data[0];
$$invalidate(0, (scanner = false));

if (frames.length === 0) {
alert("No tags found");
return;
}
try {
let result = await gradioCall(STASHMARKER_API_MARKER, image, vtt, 0.4);
let frames = result[0];

// find a div with class row
let row = document.querySelector(".row");
$$invalidate(0, (scanner = false));

new MarkerMatches({ target: row, props: { frames, url } });
})
.catch((error) => {
$$invalidate(0, (scanner = false));
if (!frames || frames.length === 0) {
alert("No tags found");
return;
}

if (error.message === "") {
alert("Error: Service may be down. please try again later.");
} else {
alert("Error: " + error.message);
}
});
let row = document.querySelector(".row");
new MarkerMatches({ target: row, props: { frames, url } });
} catch (error) {
$$invalidate(0, (scanner = false));
alert("Error: " + (error.message || "Service may be down. Please try again later."));
}
}

const writable_props = [];
Expand Down Expand Up @@ -4027,52 +4074,28 @@
reader.readAsDataURL(vblob);
});

// query the api with a threshold of 0.2 as we want to do the filtering ourselves
var data = { data: [image, vtt, 0.2] };

fetch(STASHMARKER_API_URL, {
method: "POST",
headers: {
"Content-Type": "application/json; charset=utf-8",
},
body: JSON.stringify(data),
})
.then((response) => {
if (response.status !== 200) {
$$invalidate(0, (scanner = false));
alert(
"Something went wrong. It's likely a server issue, Please try again later."
);
return;
}
try {
let result = await gradioCall(STASHMARKER_API_URL, image, vtt, 0.2);
let tags = {};
result.forEach((item) => Object.assign(tags, item));

return response.json();
})
.then((data) => {
$$invalidate(0, (scanner = false));
$$invalidate(0, (scanner = false));

if (data.data[0].length === 0) {
alert("No tags found");
return;
}
if (Object.keys(tags).length === 0) {
alert("No tags found");
return;
}

// grab stash-tag-threshold from local storage or set to default
let threshold = localStorage.getItem("stash-tag-threshold") || 0.4;
let threshold = localStorage.getItem("stash-tag-threshold") || 0.4;

new TagMatches({
target: document.body,
props: { matches: data.data[0], url, threshold },
});
})
.catch((error) => {
$$invalidate(0, (scanner = false));

if (error.message === "") {
alert("Error: Service may be down. please try again later.");
} else {
alert("Error: " + error.message);
}
new TagMatches({
target: document.body,
props: { matches: tags, url, threshold },
});
} catch (error) {
$$invalidate(0, (scanner = false));
alert("Error: " + (error.message || "Service may be down. Please try again later."));
}
}

const writable_props = [];
Expand Down
2 changes: 1 addition & 1 deletion plugins/stashAI/stashai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ ui:
- stashai.css
csp:
connect-src:
- "https://cc1234-stashtag.hf.space"
- "https://cc1234-stashtag-onnx.hf.space"
Loading