@@ -256,6 +256,113 @@ impl Graph {
256256 . link_episode_entity ( episode_id, entity_id, span_start, span_end)
257257 }
258258
259+ // ── Embeddings ──
260+
261+ /// Store an embedding for an episode. The embedding is serialized as
262+ /// little-endian f32 bytes.
263+ pub fn store_embedding ( & self , episode_id : & str , embedding : & [ f32 ] ) -> Result < ( ) > {
264+ let bytes: Vec < u8 > = embedding
265+ . iter ( )
266+ . flat_map ( |f| f. to_le_bytes ( ) )
267+ . collect ( ) ;
268+ self . storage . store_episode_embedding ( episode_id, & bytes)
269+ }
270+
271+ /// Store an embedding for an entity.
272+ pub fn store_entity_embedding ( & self , entity_id : & str , embedding : & [ f32 ] ) -> Result < ( ) > {
273+ let bytes: Vec < u8 > = embedding
274+ . iter ( )
275+ . flat_map ( |f| f. to_le_bytes ( ) )
276+ . collect ( ) ;
277+ self . storage . store_entity_embedding ( entity_id, & bytes)
278+ }
279+
280+ /// Load all episode embeddings as (episode_id, Vec<f32>) pairs.
281+ pub fn get_embeddings ( & self ) -> Result < Vec < ( String , Vec < f32 > ) > > {
282+ let raw = self . storage . get_all_episode_embeddings ( ) ?;
283+ let result = raw
284+ . into_iter ( )
285+ . map ( |( id, bytes) | {
286+ let floats: Vec < f32 > = bytes
287+ . chunks_exact ( 4 )
288+ . map ( |c| f32:: from_le_bytes ( c. try_into ( ) . unwrap ( ) ) )
289+ . collect ( ) ;
290+ ( id, floats)
291+ } )
292+ . collect ( ) ;
293+ Ok ( result)
294+ }
295+
296+ /// Fused search using Reciprocal Rank Fusion (RRF) over FTS5 + semantic results.
297+ ///
298+ /// `query_embedding` should be the pre-computed embedding for `query`.
299+ /// Returns episodes ranked by combined RRF score.
300+ pub fn search_fused (
301+ & self ,
302+ query : & str ,
303+ query_embedding : & [ f32 ] ,
304+ limit : usize ,
305+ ) -> Result < Vec < FusedEpisodeResult > > {
306+ const K : f64 = 60.0 ;
307+
308+ // Accumulate RRF scores per episode id
309+ let mut scores: std:: collections:: HashMap < String , f64 > = std:: collections:: HashMap :: new ( ) ;
310+ let mut episodes_map: std:: collections:: HashMap < String , Episode > =
311+ std:: collections:: HashMap :: new ( ) ;
312+
313+ // --- FTS5 ranked list ---
314+ // Fetch a generous pool for RRF (up to 10x limit or 200)
315+ let fts_pool = ( limit * 10 ) . max ( 200 ) ;
316+ let fts_results = self . storage . search_episodes ( query, fts_pool) ;
317+ if let Ok ( fts) = fts_results {
318+ for ( rank, ( episode, _) ) in fts. into_iter ( ) . enumerate ( ) {
319+ let rrf = 1.0 / ( K + rank as f64 + 1.0 ) ;
320+ * scores. entry ( episode. id . clone ( ) ) . or_insert ( 0.0 ) += rrf;
321+ episodes_map. insert ( episode. id . clone ( ) , episode) ;
322+ }
323+ }
324+
325+ // --- Semantic (cosine similarity) ranked list ---
326+ let all_embeddings = self . get_embeddings ( ) ?;
327+ if !all_embeddings. is_empty ( ) && !query_embedding. is_empty ( ) {
328+ // Compute cosine similarities
329+ let mut semantic: Vec < ( String , f32 ) > = all_embeddings
330+ . into_iter ( )
331+ . map ( |( id, vec) | {
332+ let sim = cosine_similarity ( query_embedding, & vec) ;
333+ ( id, sim)
334+ } )
335+ . collect ( ) ;
336+ // Sort descending by similarity
337+ semantic. sort_by ( |a, b| b. 1 . partial_cmp ( & a. 1 ) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
338+
339+ for ( rank, ( ep_id, _sim) ) in semantic. into_iter ( ) . enumerate ( ) {
340+ let rrf = 1.0 / ( K + rank as f64 + 1.0 ) ;
341+ * scores. entry ( ep_id. clone ( ) ) . or_insert ( 0.0 ) += rrf;
342+ // Fetch episode if not already cached
343+ if !episodes_map. contains_key ( & ep_id) {
344+ if let Ok ( Some ( ep) ) = self . storage . get_episode ( & ep_id) {
345+ episodes_map. insert ( ep_id, ep) ;
346+ }
347+ }
348+ }
349+ }
350+
351+ // Sort by total RRF score descending
352+ let mut fused: Vec < ( String , f64 ) > = scores. into_iter ( ) . collect ( ) ;
353+ fused. sort_by ( |a, b| b. 1 . partial_cmp ( & a. 1 ) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
354+
355+ let results = fused
356+ . into_iter ( )
357+ . take ( limit)
358+ . filter_map ( |( id, score) | {
359+ episodes_map. remove ( & id) . map ( |episode| FusedEpisodeResult { episode, score } )
360+ } )
361+ . collect ( ) ;
362+
363+ Ok ( results)
364+ }
365+
259366 // ── Search ──
260367
261368 /// Search episodes via FTS5 full-text search.
@@ -319,3 +426,19 @@ impl Graph {
319426 self . storage . stats ( )
320427 }
321428}
429+
/// Compute the cosine similarity between two `f32` vectors.
///
/// The dot product and squared norms are accumulated in `f64` in a single
/// pass, so very large or very small `f32` components neither overflow to
/// infinity nor underflow to zero during the computation.
///
/// Returns `0.0` when the similarity is undefined:
/// - the vectors have different lengths,
/// - either vector is empty,
/// - either vector has zero magnitude,
/// - the inputs contain NaN or infinite components.
///
/// Otherwise the result lies in `[-1.0, 1.0]`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let sim = (dot / denom) as f32;
    if sim.is_finite() {
        // Rounding can push the ratio a hair outside the mathematical range.
        sim.clamp(-1.0, 1.0)
    } else {
        // NaN/inf inputs make the ratio undefined; report "no similarity"
        // rather than handing callers a value that breaks ordering.
        0.0
    }
}
0 commit comments