Skip to content

Commit 8b0b466

Browse files
committed
feat: add relation classifier (embedding + logreg) with training pipeline
1 parent 55de64b commit 8b0b466

File tree

9 files changed

+1696
-1
lines changed

9 files changed

+1696
-1
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/ctxgraph-extract/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ orp = "0.9"
2323
composable = "0.9"
2424
ndarray = { workspace = true }
2525
tokenizers = "0.22"
26+
fastembed = { workspace = true }
2627

2728
[dev-dependencies]
2829
tempfile = "3"

crates/ctxgraph-extract/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub mod ner;
44
pub mod nli;
55
pub mod pipeline;
66
pub mod rel;
7+
pub mod relclf;
78
pub mod relex;
89
pub mod remap;
910
pub mod schema;

crates/ctxgraph-extract/src/model_manager.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,23 @@ impl ModelManager {
308308
}
309309
}
310310

311+
/// Check for a locally available fine-tuned relation classifier model.
312+
///
313+
/// Looks for `relation_classifier/model_int8.onnx`, falling back to
314+
/// `relation_classifier/model.onnx`, in the cache directory.
315+
///
316+
/// Returns `Some(model_path)` for the first candidate model file that exists, `None` otherwise.
317+
pub fn find_relation_classifier(&self) -> Option<std::path::PathBuf> {
318+
let base = self.cache_dir.join("relation_classifier");
319+
320+
[
321+
base.join("model_int8.onnx"),
322+
base.join("model.onnx"),
323+
]
324+
.into_iter()
325+
.find(|p| p.exists())
326+
}
327+
311328
/// Check for locally exported gliner-relex ONNX model.
312329
///
313330
/// The relex model must be exported manually using:

crates/ctxgraph-extract/src/rel.rs

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use orp::params::RuntimeParameters;
1313
use orp::pipeline::Pipeline;
1414

1515
use crate::ner::ExtractedEntity;
16+
use crate::relclf::RelationClassifier;
1617
use crate::relex::RelexEngine;
1718
use crate::schema::ExtractionSchema;
1819

@@ -42,6 +43,9 @@ pub enum RelEngine {
4243
/// Cached relex engine (loaded once per process).
4344
static RELEX_ENGINE: std::sync::OnceLock<Option<RelexEngine>> = std::sync::OnceLock::new();
4445

46+
/// Cached relation classifier (loaded once per process).
47+
static RELCLF_ENGINE: std::sync::OnceLock<Option<RelationClassifier>> = std::sync::OnceLock::new();
48+
4549
/// Model-based relation extraction using gline-rs.
4650
///
4751
/// Requires `gliner-multitask-large-v0.5` ONNX model.
@@ -185,7 +189,7 @@ impl RelEngine {
185189
.collect();
186190

187191
for rel in &result.relations {
188-
if rel.confidence < 0.7 {
192+
if rel.confidence < 0.80 {
189193
continue;
190194
}
191195

@@ -210,6 +214,57 @@ impl RelEngine {
210214
}
211215
}
212216

217+
// Relation classifier (embedding-based) — auto-enabled if model found.
218+
let relclf = RELCLF_ENGINE.get_or_init(|| {
219+
let mgr = crate::model_manager::ModelManager::new().ok()?;
220+
let model_path = mgr.find_relation_classifier()?;
221+
RelationClassifier::new(&model_path).ok()
222+
});
223+
224+
if let Some(classifier) = relclf {
225+
// Use fastembed for generating embeddings (same model as ctxgraph-embed).
226+
static EMBED_ENGINE: std::sync::OnceLock<
227+
Option<fastembed::TextEmbedding>,
228+
> = std::sync::OnceLock::new();
229+
let embed = EMBED_ENGINE.get_or_init(|| {
230+
fastembed::TextEmbedding::try_new(
231+
fastembed::InitOptions::new(fastembed::EmbeddingModel::AllMiniLML6V2),
232+
)
233+
.ok()
234+
});
235+
236+
if let Some(embed_model) = embed {
237+
let embed_fn = |text: &str| -> Result<Vec<f32>, RelError> {
238+
let mut vecs = embed_model
239+
.embed(vec![text], None)
240+
.map_err(|e| RelError::Inference(e.to_string()))?;
241+
vecs.pop()
242+
.ok_or_else(|| RelError::Inference("empty embedding".into()))
243+
};
244+
245+
if let Ok(clf_relations) =
246+
classifier.classify_batch(text, entities, &embed_fn)
247+
{
248+
let existing: std::collections::HashSet<(String, String)> = relations
249+
.iter()
250+
.map(|r| (r.head.clone(), r.tail.clone()))
251+
.collect();
252+
253+
for rel in clf_relations {
254+
if rel.confidence < 0.76 {
255+
continue;
256+
}
257+
// Only add relations for entity pairs not already covered, in either direction
258+
if !existing.contains(&(rel.head.clone(), rel.tail.clone()))
259+
&& !existing.contains(&(rel.tail.clone(), rel.head.clone()))
260+
{
261+
relations.push(rel);
262+
}
263+
}
264+
}
265+
}
266+
}
267+
213268
Ok(relations)
214269
}
215270
}

0 commit comments

Comments
 (0)