Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ disabled_tools:
# MCP search server (FAISS + Qwen3-Embedding-8B)
# Build index first: python src/utils/browsecomp-plus-tools/setup_data.py
# model_name: local path to Qwen3-Embedding-8B, or HuggingFace model ID
# embedding.api_base: optional OpenAI-compatible embeddings API (vLLM/sglang/ModelScope/etc.)
# When set, query vectors are fetched via API instead of loading the model locally.
# The API model must be the same embedding model that was used to build the index (e.g. Qwen/Qwen3-Embedding-8B).
mcp_server:
auto_start: true
startup_timeout: 300
Expand All @@ -29,6 +32,11 @@ mcp_server:
model_name: Qwen/Qwen3-Embedding-8B
k: 5
snippet_max_tokens: 512
# embedding:
# api_base: ${EMBEDDING_API_BASE}
# api_key: ${EMBEDDING_API_KEY}
# model: Qwen/Qwen3-Embedding-8B
# normalize: false

# LLM Judge for answer verification
# Set model + api_base to your LLM endpoint; disable_thinking recommended for vLLM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,24 @@
logger = logging.getLogger(__name__)


def _resolve_env(value: str | None) -> str:
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
return os.environ.get(value[2:-1], "")
return value or ""


def _maybe_l2_normalize(vec: np.ndarray, enabled: bool) -> np.ndarray:
"""L2-normalize vec when enabled, skipping if already unit length."""
if not enabled:
return vec

norm = float(np.linalg.norm(vec))
if norm <= 0 or np.isclose(norm, 1.0, atol=1e-3, rtol=1e-3):
return vec

return vec / norm


class FaissSearcher(BaseSearcher):
@classmethod
def parse_args(cls, parser):
Expand Down Expand Up @@ -73,6 +91,22 @@ def parse_args(cls, parser):
default=8192,
help="Maximum sequence length for FAISS search (default: 8192)",
)
parser.add_argument(
"--embedding-api-base",
default=None,
help="OpenAI-compatible embeddings API base URL. If set, query "
"embeddings are fetched via API instead of loading a local model.",
)
parser.add_argument(
"--embedding-api-key",
default="EMPTY",
help="API key for the embeddings endpoint (default: EMPTY).",
)
parser.add_argument(
"--embedding-api-model",
default=None,
help="Model name for the embeddings API. Defaults to --model-name.",
)

def __init__(self, args):
if args.model_name == "bm25":
Expand All @@ -84,6 +118,8 @@ def __init__(self, args):
self.retriever = None
self.model = None
self.tokenizer = None
self._embedding_client = None
self._embedding_model = None
self.lookup = None
self.docid_to_text = None

Expand Down Expand Up @@ -133,7 +169,15 @@ def _setup_gpu(self) -> None:
# Keep FAISS index on CPU to avoid GPU memory contention with LLM services.
logger.info("FAISS index on CPU.")

@property
def _use_embedding_api(self) -> bool:
    """Whether an embeddings API base URL is configured.

    True when ``args.embedding_api_base`` exists and is non-empty after any
    ``${VAR}`` placeholder has been resolved from the environment.
    """
    configured_base = getattr(self.args, "embedding_api_base", None)
    resolved_base = _resolve_env(configured_base)
    return bool(resolved_base)

def _load_model(self) -> None:
if self._use_embedding_api:
self._setup_embedding_api()
return

logger.info(f"Loading model: {self.args.model_name}")

hf_home = os.getenv("HF_HOME")
Expand Down Expand Up @@ -180,6 +224,59 @@ def _load_model(self) -> None:

logger.info("Model loaded successfully")

def _setup_embedding_api(self) -> None:
    """Create the OpenAI-compatible client used to embed queries remotely.

    Reads ``embedding_api_base``, ``embedding_api_key`` and
    ``embedding_api_model`` from ``self.args``, resolving ``${VAR}``
    placeholders from the environment for each of them. The API key falls
    back to ``"EMPTY"`` and the model name falls back to ``args.model_name``.
    The local model and tokenizer are cleared because they are not used in
    API mode.

    Raises:
        ValueError: if the API base is missing or resolves to an empty string.
    """
    # Imported lazily so the dependency is only needed in API mode.
    import openai

    args = self.args

    # Use getattr, as _use_embedding_api does, so an args object that lacks
    # the optional embedding attributes does not raise AttributeError here.
    api_base = _resolve_env(getattr(args, "embedding_api_base", None))
    if not api_base:
        raise ValueError("embedding_api_base is set but empty after env resolution")

    api_key = _resolve_env(getattr(args, "embedding_api_key", None)) or "EMPTY"

    # Resolve ${VAR} for the model too, matching the base URL and key.
    self._embedding_model = (
        _resolve_env(getattr(args, "embedding_api_model", None)) or args.model_name
    )
    self._embedding_client = openai.OpenAI(api_key=api_key, base_url=api_base)
    self.model = None
    self.tokenizer = None
    logger.info(
        "Using embedding API at %s (model=%s)",
        api_base,
        self._embedding_model,
    )

def _encode_query(self, query: str) -> np.ndarray:
    """Embed ``query`` (prefixed with ``args.task_prefix``) for retrieval.

    When an embedding API client is configured, the vector is fetched from
    the API, optionally L2-normalized according to ``args.normalize``, and
    returned as a float32 array of shape ``(1, dim)``. Otherwise the local
    tokenizer and model are used and the output of ``model.encode_query`` is
    returned as a NumPy array.

    Raises:
        RuntimeError: if the API request fails or returns no embedding.
    """
    text = self.args.task_prefix + query

    if self._embedding_client is not None:
        try:
            resp = self._embedding_client.embeddings.create(
                model=self._embedding_model,
                input=text,
                encoding_format="float",
            )
        except Exception as e:
            raise RuntimeError(f"Embedding API request failed: {e}") from e

        # Guard against an empty payload so the caller gets a clear error
        # instead of an IndexError or a zero-width vector.
        if not resp.data:
            raise RuntimeError("Embedding API returned no embeddings for the query")

        vec = np.array(resp.data[0].embedding, dtype=np.float32)
        if vec.size == 0:
            raise RuntimeError("Embedding API returned an empty embedding vector")

        vec = _maybe_l2_normalize(vec, self.args.normalize)
        # Return a single-row 2-D array.
        return vec.reshape(1, -1)

    batch_dict = self.tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=self.args.max_length,
        return_tensors="pt",
    )

    # Move the tokenized inputs to the device the model's parameters are on.
    device = next(self.model.parameters()).device
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}

    # Inference only: autocast for the device type, no gradient tracking.
    with torch.amp.autocast(device.type):
        with torch.no_grad():
            q_reps = self.model.encode_query(batch_dict)
    return q_reps.cpu().detach().numpy()

def _load_dataset(self) -> None:
logger.info(f"Loading dataset: {self.args.dataset_name}")

Expand Down Expand Up @@ -234,24 +331,13 @@ def _load_dataset(self) -> None:
)

def search(self, query: str, k: int = 10) -> List[Dict[str, Any]]:
if not all([self.retriever, self.model, self.tokenizer, self.lookup]):
raise RuntimeError("Searcher not properly initialized")

batch_dict = self.tokenizer(
self.args.task_prefix + query,
padding=True,
truncation=True,
max_length=self.args.max_length,
return_tensors="pt",
encoder_ready = self._embedding_client is not None or (
self.model is not None and self.tokenizer is not None
)
if not all([self.retriever, encoder_ready, self.lookup]):
raise RuntimeError("Searcher not properly initialized")

device = next(self.model.parameters()).device
batch_dict = {k: v.to(device) for k, v in batch_dict.items()}

with torch.amp.autocast(device.type):
with torch.no_grad():
q_reps = self.model.encode_query(batch_dict)
q_reps = q_reps.cpu().detach().numpy()
q_reps = self._encode_query(query)

all_scores, psg_indices = self.retriever.search(q_reps, k)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ def main():
]
if cfg.get("model_name"):
argv += ["--model-name", cfg["model_name"]]

emb_cfg = cfg.get("embedding") or {}
if emb_cfg.get("api_base"):
argv += ["--embedding-api-base", emb_cfg["api_base"]]
if emb_cfg.get("api_key"):
argv += ["--embedding-api-key", emb_cfg["api_key"]]
if emb_cfg.get("model"):
argv += ["--embedding-api-model", emb_cfg["model"]]
if emb_cfg.get("normalize"):
argv += ["--normalize"]

sys.argv = argv

from mcp_server import main as mcp_main
Expand Down