60 changes: 56 additions & 4 deletions exo/inference/rkllm/rkllm_engine.py
@@ -67,6 +67,25 @@ class RKLLMInferenceEngine(InferenceEngine):
- Loads complete .rkllm models (no partial layer loading)
- Uses HTTP client to connect to rkllama server (recommended)
- Thread-safe with dedicated executor for blocking operations
- Supports per-core NPU pinning via RKNN_CORE_MASK env var

Core mask values for RK3588 (3 NPU cores, 6 TOPS total):
0x1 = core 0 only (~2 TOPS)
0x2 = core 1 only (~2 TOPS)
0x4 = core 2 only (~2 TOPS)
0x3 = cores 0+1 (~4 TOPS)
0x7 = all cores (~6 TOPS, default)

Set RKNN_CORE_MASK to pin this engine instance to specific NPU cores.
Useful for multi-tenant workloads where multiple models share the
same board.

Note: RKLLM loads complete models (the .rkllm format is compiled as
a monolith). Pipeline-parallel layer sharding across multiple NPU
nodes is not possible until Rockchip adds partial layer loading to
the RKLLM SDK. Tracked at airockchip/rknn-llm#489. For distributed
inference of large models, use GPU/Apple Silicon nodes via exo's
MLX or tinygrad backends.
"""

def __init__(
@@ -82,6 +101,13 @@ def __init__(
self._shard_lock = asyncio.Lock()
self.session = {}

# NPU core pinning: RKNN_CORE_MASK selects which NPU cores this
# engine instance is allowed to use. When running multiple exo
# nodes on the same RK3588 (e.g. in separate LXC containers),
# each node should get its own core to avoid contention.
self.core_mask = int(os.environ.get("RKNN_CORE_MASK", "0"), 0)
self.npu_cores = self._count_cores(self.core_mask) if self.core_mask else 3

# HTTP client configuration
self._server_config = RKLLMServerConfig(
host=os.environ.get("RKLLM_SERVER_HOST", server_host),
@@ -93,15 +119,41 @@ def __init__(
self._stream_tasks = {} # Active streaming tasks per request

if DEBUG >= 1:
print(f"RKLLM engine initialized (HTTP mode: {self._server_config.base_url})")
core_desc = f"cores=all" if not self.core_mask else f"core_mask=0x{self.core_mask:x} ({self.npu_cores} core(s))"
print(f"RKLLM engine initialized (HTTP mode: {self._server_config.base_url}, {core_desc})")

# Set inference engine info metric
INFERENCE_ENGINE_INFO.info({
'engine': 'RKLLMInferenceEngine',
'mode': 'http',
'server': self._server_config.base_url
'server': self._server_config.base_url,
'core_mask': hex(self.core_mask) if self.core_mask else 'all',
'npu_cores': str(self.npu_cores),
})

@staticmethod
def _count_cores(mask: int) -> int:
"""Count set bits in the core mask."""
count = 0
while mask:
count += mask & 1
mask >>= 1
return count
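Aside: the same count can be written as mask.bit_count() on Python 3.10+, or bin(mask).count("1") on any version.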

def get_capability_descriptor(self) -> dict:
"""Return a descriptor of this engine's NPU capabilities.

Used by the topology manager to weigh this node's capacity
relative to other nodes in the cluster. A single-core node
gets roughly 1/3 the weight of a full 3-core node.
"""
return {
"accelerator": "rk3588-npu",
"core_mask": self.core_mask,
"npu_cores": self.npu_cores,
"tops_estimate": self.npu_cores * 2, # ~2 TOPS per core
}
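As a concrete illustration of the descriptor above, a node started with RKNN_CORE_MASK=0x2 reports

    {"accelerator": "rk3588-npu", "core_mask": 2, "npu_cores": 1, "tops_estimate": 2}

while an unpinned node reports core_mask 0, npu_cores 3, and tops_estimate 6, so the topology manager weighs it roughly three times heavier.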

@property
def tokenizer(self):
"""Return the tokenizer for compatibility with exo framework."""
@@ -526,9 +578,9 @@ async def ensure_shard(self, shard: Shard):
if DEBUG >= 1:
print(f"Loading RKLLM model: {model_name}")

# Load model via HTTP API with timing
# Load model via HTTP API with timing, passing core_mask if set
load_start = time.time()
success = await self._http_client.load_model(model_name)
success = await self._http_client.load_model(model_name, core_mask=self.core_mask)
load_duration = time.time() - load_start
RKLLM_MODEL_LOAD_SECONDS.observe(load_duration)
RKLLM_HTTP_REQUESTS.labels(endpoint='/load_model', status='success' if success else 'error').inc()
60 changes: 59 additions & 1 deletion exo/inference/rkllm/rkllm_http_client.py
@@ -20,6 +20,7 @@ class RKLLMServerConfig:
host: str = "localhost"
port: int = 8080
timeout: float = 300.0 # 5 minute timeout for generation
ollama_compat: bool = True # Use Ollama-compatible API (/api/generate)

@property
def base_url(self) -> str:
@@ -71,6 +72,12 @@ async def list_models(self) -> List[str]:
"""Get list of available models on the server."""
try:
session = await self._get_session()
if self.config.ollama_compat:
async with session.get(f"{self.config.base_url}/api/tags") as resp:
if resp.status == 200:
data = await resp.json()
return [m.get("name", m.get("model", "")) for m in data.get("models", [])]
return []
async with session.get(f"{self.config.base_url}/models") as resp:
if resp.status == 200:
data = await resp.json()
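For context, the Ollama-style /api/tags reply parsed by the compat branch above looks roughly like this (illustrative model name; entries may carry "name", "model", or both, which is why the comprehension tries both keys):

    {"models": [{"name": "qwen2.5-1.5b-rkllm", "model": "qwen2.5-1.5b-rkllm"}]}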
@@ -83,6 +90,8 @@ async def list_models(self) -> List[str]:

async def get_current_model(self) -> Optional[str]:
"""Get the currently loaded model name."""
if self.config.ollama_compat:
return self._current_model
try:
session = await self._get_session()
async with session.get(f"{self.config.base_url}/current_model") as resp:
@@ -99,14 +108,16 @@ async def load_model(
self,
model_name: str,
huggingface_path: Optional[str] = None,
from_file: Optional[str] = None
from_file: Optional[str] = None,
core_mask: int = 0
) -> bool:
"""
Load a model on the rkllama server.

Args:
model_name: Name of the model directory in ~/RKLLAMA/models/
huggingface_path: Optional HuggingFace repo for tokenizer
core_mask: NPU core mask (0 = all cores, 0x1/0x2/0x4 = single core)
from_file: Optional .rkllm filename

Returns:
@@ -119,6 +130,19 @@
print(f"Model {model_name} already loaded")
return True

if self.config.ollama_compat:
# Ollama-compat rkllama loads models on first use via /api/generate.
# Just verify the model exists in the list, then mark it as current.
available = await self.list_models()
if model_name in available:
self._current_model = model_name
if DEBUG >= 1:
print(f"RKLLM model {model_name} available (Ollama compat, lazy load)")
return True
if DEBUG >= 1:
print(f"RKLLM model {model_name} not in available list: {available}")
return False

# Unload current model if one is loaded
if current:
await self.unload_model()
@@ -131,6 +155,8 @@
payload["huggingface_path"] = huggingface_path
if from_file:
payload["from"] = from_file
if core_mask:
payload["core_mask"] = core_mask

async with session.post(
f"{self.config.base_url}/load_model",
@@ -266,8 +292,40 @@ async def generate_from_prompt(self, prompt: str) -> str:
# Use prompt as-is
messages = [{"role": "user", "content": prompt}]

if self.config.ollama_compat:
return await self._generate_ollama(prompt if not extracted_content else extracted_content)
return await self.generate(messages, stream=False)

async def _generate_ollama(self, prompt: str) -> str:
"""Generate via the Ollama-compatible /api/generate endpoint."""
try:
session = await self._get_session()
payload = {
"model": self._current_model or "",
"prompt": prompt,
"stream": False,
}
async with session.post(
f"{self.config.base_url}/api/generate",
json=payload
) as resp:
if resp.status == 200:
data = await resp.json()
return data.get("response", "")
else:
error = await resp.text()
if DEBUG >= 1:
print(f"Ollama generate failed ({resp.status}): {error[:200]}")
return ""
except asyncio.TimeoutError:
if DEBUG >= 1:
print("Ollama generate timed out")
return ""
except Exception as e:
if DEBUG >= 1:
print(f"Ollama generate failed: {e}")
return ""

async def generate_stream(
self,
messages: List[Dict[str, str]],
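Putting the compat pieces together, a rough usage sketch (the client class name RKLLMHTTPClient and its constructor are assumed from this module's structure; the model name is borrowed from the test script below; run inside an async function):

    config = RKLLMServerConfig(host="localhost", port=8080, ollama_compat=True)
    client = RKLLMHTTPClient(config)
    # In compat mode this only checks /api/tags and records the model as
    # current; core_mask is forwarded to the server only on the
    # non-compat /load_model path.
    ok = await client.load_model("qwen2.5-1.5b-rkllm", core_mask=0x1)
    if ok:
        # POSTs /api/generate with stream=False and returns the "response" text
        text = await client.generate_from_prompt("What is 2+2?")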
148 changes: 148 additions & 0 deletions scripts/test-3node-lxc.sh
@@ -0,0 +1,148 @@
#!/usr/bin/env bash
# Test distributed RKLLM inference across 3 LXC containers on one RK3588.
#
# Creates 3 Incus containers, each pinned to a single NPU core via
# RKNN_CORE_MASK. All three run exo and discover each other via mDNS
# on the bridge network. A small model is served across the 3 "nodes"
# to exercise exo's multi-node path end to end. Note that each RKLLM
# node loads the model whole (see the rkllm_engine.py docstring), so
# this tests discovery and request routing rather than true layer
# sharding.
#
# Prerequisites:
# - Incus installed on the host
# - /dev/rknpu accessible (RK3588 board)
# - exo repo cloned at ~/exo (this repo)
#
# Usage:
# scripts/test-3node-lxc.sh [setup|run|teardown|all]
#
# "all" runs setup, waits for the cluster to form, sends a test
# prompt, and tears down.
set -euo pipefail

BASE_IMAGE="${EXO_TEST_IMAGE:-images:debian/12}"
CONTAINER_PREFIX="exo-npu"
EXO_REPO="${EXO_REPO:-$(cd "$(dirname "$0")/.." && pwd)}"

log() { printf '\033[1;34m[3node-test]\033[0m %s\n' "$*"; }
die() { printf '\033[1;31m[3node-test]\033[0m %s\n' "$*" >&2; exit 1; }

setup_containers() {
log "creating 3 LXC containers for distributed NPU test"

for i in 0 1 2; do
local name="${CONTAINER_PREFIX}-${i}"
local mask=$((1 << i)) # 0x1, 0x2, 0x4
local port=$((52415 + i))

if incus info "$name" >/dev/null 2>&1; then
log "$name already exists, skipping create"
else
log "creating $name (core_mask=0x$(printf '%x' $mask), api_port=$port)"
incus launch "$BASE_IMAGE" "$name"
sleep 3
fi

# Pass the NPU device into the container
incus config device add "$name" npu unix-char path=/dev/rknpu 2>/dev/null || true

# Mount the exo repo from the host so we don't need to clone inside
incus config device add "$name" exo-src disk \
source="$EXO_REPO" path=/opt/exo 2>/dev/null || true

# Install Python and deps inside the container
incus exec "$name" -- bash -c '
apt-get update -qq
apt-get install -y -qq python3 python3-venv python3-pip git curl
if ! command -v uv >/dev/null 2>&1; then
curl -LsSf https://astral.sh/uv/install.sh | sh
fi
' 2>/dev/null

# Write the startup env vars
incus exec "$name" -- bash -c "
cat > /opt/exo-env.sh <<'ENVEOF'
export RKNN_CORE_MASK=$mask
export RKLLM_SERVER_PORT=8080
export EXO_PORT=$port
export PATH=\$HOME/.local/bin:\$PATH
ENVEOF
"

log "$name ready (core $i, mask=0x$(printf '%x' $mask))"
done

log "all 3 containers created"
}
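Because the outer bash -c string is double-quoted on the host, $mask and $port are expanded before the container ever sees the heredoc, while the escaped \$HOME and \$PATH survive until the file is sourced. The resulting /opt/exo-env.sh for exo-npu-1, for example, is:

    export RKNN_CORE_MASK=2
    export RKLLM_SERVER_PORT=8080
    export EXO_PORT=52416
    export PATH=$HOME/.local/bin:$PATH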

run_cluster() {
log "starting exo on all 3 containers"

for i in 0 1 2; do
local name="${CONTAINER_PREFIX}-${i}"
incus exec "$name" -- bash -c '
source /opt/exo-env.sh
cd /opt/exo
nohup $HOME/.local/bin/uv run exo --port $EXO_PORT > /tmp/exo.log 2>&1 &
echo $! > /tmp/exo.pid
'
log "started exo on $name (pid written to /tmp/exo.pid)"
done

log "waiting 15s for peer discovery"
sleep 15

# Check cluster health via the first node's API
local api_ip
api_ip=$(incus exec "${CONTAINER_PREFIX}-0" -- hostname -I | awk '{print $1}')
log "checking cluster at http://$api_ip:52415"

local cluster_info
cluster_info=$(curl -sS "http://$api_ip:52415/v1/models" 2>/dev/null || echo "FAILED")
echo "$cluster_info"

if echo "$cluster_info" | grep -q "FAILED"; then
die "cluster API not responding"
fi

log "cluster is up, sending test prompt"

local response
response=$(curl -sS "http://$api_ip:52415/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"model": "qwen2.5-1.5b-rkllm",
"messages": [{"role": "user", "content": "What is 2+2?"}],
"max_tokens": 32
}' 2>/dev/null || echo "FAILED")

echo "$response" | head -20
if echo "$response" | grep -q "choices"; then
log "distributed inference test PASSED"
else
log "distributed inference test FAILED (no choices in response)"
fi
}

teardown_containers() {
log "tearing down test containers"
for i in 0 1 2; do
local name="${CONTAINER_PREFIX}-${i}"
incus delete --force "$name" 2>/dev/null || true
log "deleted $name"
done
log "teardown complete"
}

case "${1:-all}" in
setup) setup_containers ;;
run) run_cluster ;;
teardown) teardown_containers ;;
all)
setup_containers
run_cluster
teardown_containers
;;
*)
echo "usage: $0 [setup|run|teardown|all]"
exit 1
;;
esac
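As a quick sanity check of the pinning math the script relies on (mask = 1 << i, port = 52415 + i), a short sketch of what the three nodes end up with:

    for i in range(3):
        mask = 1 << i
        print(f"exo-npu-{i}: RKNN_CORE_MASK=0x{mask:x} EXO_PORT={52415 + i} (~2 TOPS)")
    # exo-npu-0: RKNN_CORE_MASK=0x1 EXO_PORT=52415 (~2 TOPS)
    # exo-npu-1: RKNN_CORE_MASK=0x2 EXO_PORT=52416 (~2 TOPS)
    # exo-npu-2: RKNN_CORE_MASK=0x4 EXO_PORT=52417 (~2 TOPS)

Note that the pinned cluster totals the same ~6 TOPS as a single unpinned node, so the test measures coordination overhead on one board rather than added capacity.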