Skip to content

Commit a812698

Browse files
feat(contrastive): add _branchSetLogits/_storeMergeLogits NAPI bindings
Add native wiring for the contrastive-decode primitives:
- _branchSetLogits(handle, Float32Array): overwrite a branch's cached logits_snapshot with caller-provided values.
- _storeMergeLogits(dstHandle, srcHandles[], alpha): additively merge the experts' logit snapshots into dst's snapshot, in place. Pure CPU op.

Both bindings delegate to liblloyal v1.5.4 primitives:
- lloyal::branch::set_logits()
- BranchStore::merge_logits()

Deps:
- liblloyal: bumped to v1.5.4
- llama.cpp: b8608 → b8795
1 parent 33dd843 commit a812698

5 files changed

Lines changed: 81 additions & 16 deletions

File tree

llama.cpp

package.json

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@lloyal-labs/lloyal.node",
3-
"version": "2.0.5",
3+
"version": "2.1.0",
44
"description": "Node.js client for liblloyal+llama.cpp",
55
"main": "dist/index.js",
66
"types": "dist/index.d.ts",
@@ -65,19 +65,19 @@
6565
"typescript": "^5.9.3"
6666
},
6767
"optionalDependencies": {
68-
"@lloyal-labs/lloyal.node-darwin-arm64": "2.0.5",
69-
"@lloyal-labs/lloyal.node-darwin-x64": "2.0.5",
70-
"@lloyal-labs/lloyal.node-linux-arm64": "2.0.5",
71-
"@lloyal-labs/lloyal.node-linux-arm64-cuda": "2.0.5",
72-
"@lloyal-labs/lloyal.node-linux-arm64-vulkan": "2.0.5",
73-
"@lloyal-labs/lloyal.node-linux-x64": "2.0.5",
74-
"@lloyal-labs/lloyal.node-linux-x64-cuda": "2.0.5",
75-
"@lloyal-labs/lloyal.node-linux-x64-vulkan": "2.0.5",
76-
"@lloyal-labs/lloyal.node-win32-arm64": "2.0.5",
77-
"@lloyal-labs/lloyal.node-win32-arm64-vulkan": "2.0.5",
78-
"@lloyal-labs/lloyal.node-win32-x64": "2.0.5",
79-
"@lloyal-labs/lloyal.node-win32-x64-cuda": "2.0.5",
80-
"@lloyal-labs/lloyal.node-win32-x64-vulkan": "2.0.5"
68+
"@lloyal-labs/lloyal.node-darwin-arm64": "2.1.0",
69+
"@lloyal-labs/lloyal.node-darwin-x64": "2.1.0",
70+
"@lloyal-labs/lloyal.node-linux-arm64": "2.1.0",
71+
"@lloyal-labs/lloyal.node-linux-arm64-cuda": "2.1.0",
72+
"@lloyal-labs/lloyal.node-linux-arm64-vulkan": "2.1.0",
73+
"@lloyal-labs/lloyal.node-linux-x64": "2.1.0",
74+
"@lloyal-labs/lloyal.node-linux-x64-cuda": "2.1.0",
75+
"@lloyal-labs/lloyal.node-linux-x64-vulkan": "2.1.0",
76+
"@lloyal-labs/lloyal.node-win32-arm64": "2.1.0",
77+
"@lloyal-labs/lloyal.node-win32-arm64-vulkan": "2.1.0",
78+
"@lloyal-labs/lloyal.node-win32-x64": "2.1.0",
79+
"@lloyal-labs/lloyal.node-win32-x64-cuda": "2.1.0",
80+
"@lloyal-labs/lloyal.node-win32-x64-vulkan": "2.1.0"
8181
},
8282
"engines": {
8383
"node": ">=22.0.0"

src/SessionContext.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,7 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) {
817817
InstanceMethod("_branchGetPosition", &SessionContext::_branchGetPosition),
818818
InstanceMethod("_branchGetPerplexity", &SessionContext::_branchGetPerplexity),
819819
InstanceMethod("_branchGetLogits", &SessionContext::_branchGetLogits),
820+
InstanceMethod("_branchSetLogits", &SessionContext::_branchSetLogits),
820821
InstanceMethod("_branchPrune", &SessionContext::_branchPrune),
821822
InstanceMethod("_branchPruneSubtree", &SessionContext::_branchPruneSubtree),
822823
InstanceMethod("_branchParent", &SessionContext::_branchParent),
@@ -839,6 +840,7 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) {
839840
// ===== STORE API (internal, wrapped by lib/BranchStore.js) =====
840841
InstanceMethod("_storeCommit", &SessionContext::_storeCommit),
841842
InstanceMethod("_storePrefill", &SessionContext::_storePrefill),
843+
InstanceMethod("_storeMergeLogits", &SessionContext::_storeMergeLogits),
842844
InstanceMethod("_storeRetainOnly", &SessionContext::_storeRetainOnly),
843845
InstanceMethod("_storeAvailable", &SessionContext::_storeAvailable),
844846
InstanceMethod("_storeKvPressure", &SessionContext::_storeKvPressure),
@@ -1983,6 +1985,33 @@ Napi::Value SessionContext::_branchGetLogits(const Napi::CallbackInfo& info) {
19831985
return result;
19841986
}
19851987

1988+
/**
 * NAPI binding for lloyal::branch::set_logits().
 *
 * JS signature: _branchSetLogits(handle: number, logits: Float32Array): undefined
 *
 * Forwards the caller's Float32Array to the branch identified by `handle`.
 * What the callee does with the values (the commit description says it
 * overwrites the branch's cached logits snapshot) and whether it checks the
 * array length are decided inside liblloyal, not here.
 *
 * @throws Napi::Error when fewer than two arguments are given, when the
 *         arguments are not (number, Float32Array), or when set_logits()
 *         raises a std::exception (its message is prefixed with the
 *         binding name).
 */
Napi::Value SessionContext::_branchSetLogits(const Napi::CallbackInfo& info) {
  Napi::Env env = info.Env();
  ensureNotDisposed();

  // Shape check first: a number followed by some typed array.
  const bool shapeOk =
      info.Length() >= 2 && info[0].IsNumber() && info[1].IsTypedArray();
  if (!shapeOk) {
    throw Napi::Error::New(env, "_branchSetLogits requires (handle, Float32Array)");
  }

  // Narrow "some typed array" down to Float32Array specifically.
  Napi::TypedArray typed = info[1].As<Napi::TypedArray>();
  if (typed.TypedArrayType() != napi_float32_array) {
    throw Napi::Error::New(env, "_branchSetLogits: expected Float32Array");
  }

  // Non-owning view over the typed array's storage; no copy is made here.
  Napi::Float32Array values = typed.As<Napi::Float32Array>();
  std::span<const float> logits(values.Data(), values.ElementLength());

  auto branch = static_cast<lloyal::branch::BranchHandle>(
      info[0].As<Napi::Number>().Uint32Value());

  try {
    lloyal::branch::set_logits(branch, logits, _branchStore);
  } catch (const std::exception& e) {
    // Re-raise as a JS error tagged with the binding name.
    throw Napi::Error::New(env, std::string("_branchSetLogits: ") + e.what());
  }

  return env.Undefined();
}
2014+
19862015
Napi::Value SessionContext::_branchPrune(const Napi::CallbackInfo& info) {
19872016
Napi::Env env = info.Env();
19882017
ensureNotDisposed();
@@ -2381,6 +2410,40 @@ Napi::Value SessionContext::_storePrefill(const Napi::CallbackInfo& info) {
23812410
return worker->GetPromise();
23822411
}
23832412

2413+
/**
 * NAPI binding for BranchStore::merge_logits().
 *
 * JS signature:
 *   _storeMergeLogits(dstHandle: number, srcHandles: number[], alpha: number): undefined
 *
 * Converts the JS arguments to native handles and forwards them to
 * _branchStore.merge_logits(dst, srcs, alpha). The merge itself (described in
 * the commit message as an in-place additive merge of the sources' logit
 * snapshots into dst's) is implemented in liblloyal, not here.
 *
 * @throws Napi::Error when fewer than three arguments are given, when the
 *         arguments are not (number, array, number), when any srcHandles
 *         entry is not a number, or when merge_logits() raises a
 *         std::exception (its message is prefixed with the binding name).
 */
Napi::Value SessionContext::_storeMergeLogits(const Napi::CallbackInfo& info) {
  Napi::Env env = info.Env();
  ensureNotDisposed();

  if (info.Length() < 3 || !info[0].IsNumber() || !info[1].IsArray() || !info[2].IsNumber()) {
    throw Napi::Error::New(env, "_storeMergeLogits requires (dstHandle, srcHandles[], alpha)");
  }

  auto dstHandle = static_cast<lloyal::branch::BranchHandle>(
      info[0].As<Napi::Number>().Uint32Value());

  Napi::Array jsSrcs = info[1].As<Napi::Array>();
  const uint32_t n = jsSrcs.Length();

  std::vector<lloyal::branch::BranchHandle> srcHandles;
  srcHandles.reserve(n);
  for (uint32_t i = 0; i < n; i++) {
    Napi::Value entry = jsSrcs.Get(i);
    // As<Napi::Number>() does not type-check by default, so validate each
    // element explicitly and report which index is wrong rather than
    // surfacing a generic conversion failure.
    if (!entry.IsNumber()) {
      throw Napi::Error::New(
          env, "_storeMergeLogits: srcHandles[" + std::to_string(i) + "] must be a number");
    }
    srcHandles.push_back(static_cast<lloyal::branch::BranchHandle>(
        entry.As<Napi::Number>().Uint32Value()));
  }

  // JS numbers are doubles; the native primitive is called with a float.
  float alpha = static_cast<float>(info[2].As<Napi::Number>().DoubleValue());

  try {
    _branchStore.merge_logits(
        dstHandle,
        std::span<const lloyal::branch::BranchHandle>(srcHandles.data(), srcHandles.size()),
        alpha);
  } catch (const std::exception& e) {
    // Re-raise as a JS error tagged with the binding name.
    throw Napi::Error::New(env, std::string("_storeMergeLogits: ") + e.what());
  }

  return env.Undefined();
}
2446+
23842447
Napi::Value SessionContext::_branchPruneSubtree(const Napi::CallbackInfo& info) {
23852448
Napi::Env env = info.Env();
23862449
ensureNotDisposed();

src/SessionContext.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ class SessionContext : public Napi::ObjectWrap<SessionContext> {
253253
Napi::Value _branchGetPosition(const Napi::CallbackInfo& info);
254254
Napi::Value _branchGetPerplexity(const Napi::CallbackInfo& info);
255255
Napi::Value _branchGetLogits(const Napi::CallbackInfo& info);
256+
Napi::Value _branchSetLogits(const Napi::CallbackInfo& info);
256257
Napi::Value _branchPrune(const Napi::CallbackInfo& info);
257258
Napi::Value _branchPruneSubtree(const Napi::CallbackInfo& info);
258259
Napi::Value _branchParent(const Napi::CallbackInfo& info);
@@ -276,6 +277,7 @@ class SessionContext : public Napi::ObjectWrap<SessionContext> {
276277

277278
Napi::Value _storeCommit(const Napi::CallbackInfo& info);
278279
Napi::Value _storePrefill(const Napi::CallbackInfo& info);
280+
Napi::Value _storeMergeLogits(const Napi::CallbackInfo& info);
279281
Napi::Value _storeRetainOnly(const Napi::CallbackInfo& info);
280282
Napi::Value _storeAvailable(const Napi::CallbackInfo& info);
281283
Napi::Value _storeKvPressure(const Napi::CallbackInfo& info);

0 commit comments

Comments
 (0)