@@ -817,6 +817,7 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) {
817817 InstanceMethod (" _branchGetPosition" , &SessionContext::_branchGetPosition),
818818 InstanceMethod (" _branchGetPerplexity" , &SessionContext::_branchGetPerplexity),
819819 InstanceMethod (" _branchGetLogits" , &SessionContext::_branchGetLogits),
820+ InstanceMethod (" _branchSetLogits" , &SessionContext::_branchSetLogits),
820821 InstanceMethod (" _branchPrune" , &SessionContext::_branchPrune),
821822 InstanceMethod (" _branchPruneSubtree" , &SessionContext::_branchPruneSubtree),
822823 InstanceMethod (" _branchParent" , &SessionContext::_branchParent),
@@ -839,6 +840,7 @@ Napi::Object SessionContext::Init(Napi::Env env, Napi::Object exports) {
839840 // ===== STORE API (internal, wrapped by lib/BranchStore.js) =====
840841 InstanceMethod (" _storeCommit" , &SessionContext::_storeCommit),
841842 InstanceMethod (" _storePrefill" , &SessionContext::_storePrefill),
843+ InstanceMethod (" _storeMergeLogits" , &SessionContext::_storeMergeLogits),
842844 InstanceMethod (" _storeRetainOnly" , &SessionContext::_storeRetainOnly),
843845 InstanceMethod (" _storeAvailable" , &SessionContext::_storeAvailable),
844846 InstanceMethod (" _storeKvPressure" , &SessionContext::_storeKvPressure),
@@ -1983,6 +1985,33 @@ Napi::Value SessionContext::_branchGetLogits(const Napi::CallbackInfo& info) {
19831985 return result;
19841986}
19851987
1988+ Napi::Value SessionContext::_branchSetLogits (const Napi::CallbackInfo& info) {
1989+ Napi::Env env = info.Env ();
1990+ ensureNotDisposed ();
1991+
1992+ if (info.Length () < 2 || !info[0 ].IsNumber () || !info[1 ].IsTypedArray ()) {
1993+ throw Napi::Error::New (env, " _branchSetLogits requires (handle, Float32Array)" );
1994+ }
1995+
1996+ auto handle = static_cast <lloyal::branch::BranchHandle>(
1997+ info[0 ].As <Napi::Number>().Uint32Value ());
1998+
1999+ Napi::TypedArray ta = info[1 ].As <Napi::TypedArray>();
2000+ if (ta.TypedArrayType () != napi_float32_array) {
2001+ throw Napi::Error::New (env, " _branchSetLogits: expected Float32Array" );
2002+ }
2003+ Napi::Float32Array arr = ta.As <Napi::Float32Array>();
2004+ std::span<const float > data (arr.Data (), arr.ElementLength ());
2005+
2006+ try {
2007+ lloyal::branch::set_logits (handle, data, _branchStore);
2008+ } catch (const std::exception& e) {
2009+ throw Napi::Error::New (env, std::string (" _branchSetLogits: " ) + e.what ());
2010+ }
2011+
2012+ return env.Undefined ();
2013+ }
2014+
19862015Napi::Value SessionContext::_branchPrune (const Napi::CallbackInfo& info) {
19872016 Napi::Env env = info.Env ();
19882017 ensureNotDisposed ();
@@ -2381,6 +2410,40 @@ Napi::Value SessionContext::_storePrefill(const Napi::CallbackInfo& info) {
23812410 return worker->GetPromise ();
23822411}
23832412
2413+ Napi::Value SessionContext::_storeMergeLogits (const Napi::CallbackInfo& info) {
2414+ Napi::Env env = info.Env ();
2415+ ensureNotDisposed ();
2416+
2417+ if (info.Length () < 3 || !info[0 ].IsNumber () || !info[1 ].IsArray () || !info[2 ].IsNumber ()) {
2418+ throw Napi::Error::New (env, " _storeMergeLogits requires (dstHandle, srcHandles[], alpha)" );
2419+ }
2420+
2421+ auto dstHandle = static_cast <lloyal::branch::BranchHandle>(
2422+ info[0 ].As <Napi::Number>().Uint32Value ());
2423+
2424+ Napi::Array jsSrcs = info[1 ].As <Napi::Array>();
2425+ uint32_t n = jsSrcs.Length ();
2426+
2427+ std::vector<lloyal::branch::BranchHandle> srcHandles (n);
2428+ for (uint32_t i = 0 ; i < n; i++) {
2429+ srcHandles[i] = static_cast <lloyal::branch::BranchHandle>(
2430+ jsSrcs.Get (i).As <Napi::Number>().Uint32Value ());
2431+ }
2432+
2433+ float alpha = static_cast <float >(info[2 ].As <Napi::Number>().DoubleValue ());
2434+
2435+ try {
2436+ _branchStore.merge_logits (
2437+ dstHandle,
2438+ std::span<const lloyal::branch::BranchHandle>(srcHandles.data (), srcHandles.size ()),
2439+ alpha);
2440+ } catch (const std::exception& e) {
2441+ throw Napi::Error::New (env, std::string (" _storeMergeLogits: " ) + e.what ());
2442+ }
2443+
2444+ return env.Undefined ();
2445+ }
2446+
23842447Napi::Value SessionContext::_branchPruneSubtree (const Napi::CallbackInfo& info) {
23852448 Napi::Env env = info.Env ();
23862449 ensureNotDisposed ();
0 commit comments