This file provides guidance to AI coding agents when working with code in this repository.
IRON is a close-to-metal Python API for AMD Ryzen™ AI NPUs (XDNA architecture). It provides language bindings around the MLIR-AIE dialect to enable fast and efficient execution on NPU hardware.
Key Technologies:
- MLIR-AIE: Dialect for programming AMD AI Engines (AIE) array architectures
- XRT (Xilinx Runtime): Low-level runtime for interfacing with NPU hardware
- Target Hardware: AMD Ryzen AI NPUs (AIE2/AIE2P architectures - NPU1/NPU2)
- Primary Datatype: bfloat16
```bash
# 1. Source XRT (required for all operations)
source /opt/xilinx/xrt/setup.sh

# 2. Create virtual environment (may already be present)
python3 -m venv ironenv

# 3. Activate virtual environment
source ironenv/bin/activate

# 4. Install dependencies
python3 -m pip install --upgrade pip
python3 -m pip install -r requirements.txt
```

Note: XRT must be sourced before running any tests or operators.
Compiled artifacts (`.xclbin`, `.bin`, `.o` files) are stored in the `build/` directory by default. The build directory can be customized via `AIEContext(build_dir="path/to/build")`.
`IRON_EXAMPLE_WEIGHTS_DIR`: Path to model weights for applications (default: `/srv`)
```bash
# Fast tests (non-extensive, single iteration)
pytest iron/operators/ -m "not extensive" --iterations 1

# All operator tests
pytest iron/operators/

# Single operator
pytest iron/operators/axpy/

# Application tests
pytest iron/applications/

# Single test
pytest iron/operators/gemm/test.py::test_gemm

# Parallel test run
pytest iron/operators/ -n auto -m "not extensive"
```

```bash
# Check formatting
black --check .

# Auto-format
black .
```

```bash
# Check C++ formatting
python scripts/clang-format-wrapper.py --check

# Show differences
python scripts/clang-format-wrapper.py --diff

# Auto-format all
python scripts/clang-format-wrapper.py --fix

# Format specific directory
python scripts/clang-format-wrapper.py --fix --path aie_kernels/
```

```bash
# Check all files have proper license headers
reuse lint
```
- Operators (`iron/operators/`): each operator directory contains:
  - `op.py`: Python interface (inherits from `MLIROperator`). Defines operator parameters, compilation artifacts, and runtime argument specs.
  - `design.py`: NPU implementation using the MLIR-AIE Python API. Defines ObjectFIFOs, Workers, and Runtime sequences.
  - `reference.py`: CPU reference implementation for validation.
  - `test.py`: End-to-end test (build, run, verify against reference).
- AIE Kernels (`aie_kernels/`): architecture-specific C++ compute kernels:
  - `generic/`: Works on both AIE2 and AIE2P
  - `aie2/`: AIE2-specific (NPU1)
  - `aie2p/`: AIE2P-specific (NPU2)
  - Use the AIE API for vectorization (e.g., `aie::mmul`, `aie::add`, `aie::mul`)
  - Compiled to `.o` files and linked into the operator `.xclbin`
- Common Infrastructure (`iron/common/`):
  - `base.py`: Base classes (`AIEOperatorBase`, `MLIROperator`, `CompositeOperator`)
  - `compilation/`: Compilation artifact system (MLIR → xclbin)
  - `fusion.py`: Operator fusion framework (`FusedMLIROperator`)
  - `device_manager.py`: XRT device initialization and management (singleton pattern)
  - `context.py`: `AIEContext` for operator compilation/execution
  - `utils.py`: Helper functions (`torch_to_numpy`, `numpy_to_torch`)
  - `test_utils.py`: Test utilities (`verify_buffer`, `nearly_equal`)
ObjectFIFO: Data movement primitive in MLIR-AIE
- Connects producers and consumers (shim DMA ↔ compute tiles)
- Uses `acquire()` to get buffer access, `release()` to free it
- Pattern: always pair acquire with release in loops
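The acquire/release discipline can be modeled with a toy Python class (an illustration only; this is not the MLIR-AIE API, and `ToyObjectFifo` is a made-up name):

```python
# Toy model of ObjectFIFO acquire/release semantics (illustration only,
# not the MLIR-AIE API): a bounded pool of buffers that must be released
# before they can be reused.
class ToyObjectFifo:
    def __init__(self, depth):
        self.depth = depth   # number of backing buffers (e.g. 2 = double buffering)
        self.in_use = 0

    def acquire(self, n=1):
        # Acquiring more buffers than exist would deadlock on real hardware.
        assert self.in_use + n <= self.depth, "FIFO over-subscribed: missing release()?"
        self.in_use += n
        return [bytearray(16) for _ in range(n)]  # stand-in buffer objects

    def release(self, n=1):
        assert self.in_use >= n, "release() without matching acquire()"
        self.in_use -= n

fifo = ToyObjectFifo(depth=2)
for _ in range(8):           # the always-pair-acquire-with-release pattern
    buf = fifo.acquire(1)
    # ... process buf[0] ...
    fifo.release(1)
```

Forgetting a `release()` in the loop would exhaust the pool after `depth` iterations, which mirrors how an unbalanced acquire stalls a real design.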
Worker: Compute tile task
- Wraps a Python function that runs on an AIE compute core
- Function uses `range_()` for loops (not Python `range`)
- Calls compiled C++ kernels via `Kernel` objects

TensorAccessPattern (TAP): Describes how data is sliced and distributed
- Used to parallelize work across multiple columns
- Format: `(tensor_shape, offset, dimensions, strides)`
Runtime Sequence: Host-side control flow
- `rt.fill()`: DMA data from host → NPU (shim → L2/L1)
- `rt.drain()`: DMA data from NPU → host
- `rt.start()`: Launch workers
- `rt.task_group()`: Coordinate parallel DMA operations
Compilation Flow:

```
design.py (Python MLIR-AIE API)
        ↓
PythonGeneratedMLIRArtifact
        ↓
MLIR (.mlir file)
        ↓  (aie-opt + aie-translate via Peano toolchain)
xclbin (NPU binary) + insts.bin (instruction sequence)
```
AIEContext: Manages compilation and runtime state
- Default build directory: `build/` in the current working directory
- Compilation rules: Defines the pipeline from Python → MLIR → xclbin
- Device manager: Singleton for XRT resource sharing
- Use `AIEContext(build_dir="...", mlir_verbose=True)` for custom settings

Device Manager: Singleton that manages XRT resources
- Automatically initializes `pyxrt.device(0)`
- Caches contexts and kernels per xclbin path
- Shared across all operators to avoid resource conflicts
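The singleton-plus-cache pattern described above can be sketched in plain Python (names like `DeviceManager` and `get_kernel` are illustrative; the real `device_manager.py` interface may differ):

```python
# Minimal sketch of a singleton device manager with per-xclbin caching.
# Names (DeviceManager, get_kernel) are illustrative, not the real API.
class DeviceManager:
    _instance = None

    def __new__(cls):
        # Return the same object on every instantiation (singleton).
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._kernel_cache = {}
        return cls._instance

    def get_kernel(self, xclbin_path, loader):
        # Load each xclbin at most once; later calls hit the cache.
        if xclbin_path not in self._kernel_cache:
            self._kernel_cache[xclbin_path] = loader(xclbin_path)
        return self._kernel_cache[xclbin_path]

mgr = DeviceManager()
assert mgr is DeviceManager()  # every caller shares one instance
```

Sharing one instance is what prevents two operators from opening the same NPU device twice and conflicting over XRT resources.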
- NPU1 (AIE2): 4 rows × 4 columns (AMD Ryzen AI "Phoenix"/"Hawk Point")
  - The array physically has 5 columns, but only 4 are accessible.
- NPU2 (AIE2P): 4 rows × 8 columns (AMD Ryzen AI 300 Series: "Strix Point", "Strix Halo", "Krackan Point")
Common operator parameters and their constraints:
- `tile_size`: Typically 64, 128, 256, or 4096 (depends on operator and data type)
- `num_aie_columns`: Must match hardware (1-4 for NPU1, up to 8 for NPU2)
- `num_aie_rows`: Always 4 for current NPU architectures

GEMM-specific:
- `tile_m`, `tile_k`, `tile_n`: Matrix tile dimensions (typically 64)
- Minimum tile sizes depend on the `emulate_bf16_mmul_with_bfp16` flag:
  - `True` (default): 8×8×8 minimum
  - `False`: 4×8×8 minimum
- Matrix dimensions must be multiples of tile × num_rows/columns:
  - `M % (tile_m * 4) == 0`
  - `K % tile_k == 0`
  - `N % (tile_n * num_aie_columns) == 0`

Element-wise ops (add, mul, relu, gelu, etc.):
- `size % (num_aie_columns * tile_size) == 0`
- `size % tile_size == 0`
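These divisibility rules can be captured in a small checker (a sketch only; the function names are hypothetical, and the actual validations live in each operator's `op.py`):

```python
# Hypothetical helpers encoding the divisibility constraints above.
def gemm_dims_valid(M, K, N, tile_m=64, tile_k=64, tile_n=64,
                    num_aie_columns=4, num_aie_rows=4):
    """M must divide by tile_m*rows, K by tile_k, N by tile_n*columns."""
    return (M % (tile_m * num_aie_rows) == 0
            and K % tile_k == 0
            and N % (tile_n * num_aie_columns) == 0)

def elementwise_size_valid(size, tile_size=64, num_aie_columns=4):
    """Total size must divide evenly across columns and tiles."""
    return size % (num_aie_columns * tile_size) == 0 and size % tile_size == 0

assert gemm_dims_valid(512, 256, 512)      # multiples of 64*4, 64, 64*4
assert not gemm_dims_valid(100, 256, 512)  # 100 is not a multiple of 256
assert elementwise_size_valid(1024, tile_size=64, num_aie_columns=4)
```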
- L3: Host memory (DDR)
- L2: Shared memory tiles (MemTiles in AIE-ML)
- L1: Per-core local memory (limited, ~32-64 KB per tile)
Data movement pattern: L3 → Shim DMA → L2 → L1 (tile local) → Compute
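L1 capacity bounds how much data a core can hold at once, which constrains usable tile sizes. A quick arithmetic sanity check (assuming 2-byte bfloat16 elements and the upper end of the L1 range above):

```python
# How many bfloat16 elements fit in one compute tile's L1?
l1_bytes = 64 * 1024          # upper end of the ~32-64 KB range above
bf16_bytes = 2                # bfloat16 is 2 bytes per element
max_elems = l1_bytes // bf16_bytes
print(max_elems)              # 32768

# With double buffering for both input and output (2 + 2 buffers),
# the per-buffer budget shrinks accordingly:
per_buffer = max_elems // 4   # 8192 elements per buffer
```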
1. Create a directory in `iron/operators/<operator_name>/`
2. Implement `op.py`:
   - Subclass `MLIROperator`
   - Implement `get_operator_name()`, `get_mlir_artifact()`, `get_kernel_artifacts()`, `get_arg_spec()`
   - Add validation for dimension constraints (assert statements)
   - Define tile sizes and column counts
3. Implement `design.py`:
   - Import from `aie.iron` (Program, Runtime, Worker, ObjectFifo, Kernel)
   - Define a function that builds the MLIR-AIE design
   - Use `range_()` for loops (not Python `range`)
   - Handle device-specific logic (NPU1 vs NPU2) if needed
4. Implement a C++ kernel in `aie_kernels/<arch>/` if needed:
   - Choose the appropriate directory: `generic/`, `aie2/`, or `aie2p/`
   - Use the AIE API for portable vectorization when possible
   - Add `event0()` and `event1()` for performance profiling
5. Implement `reference.py` with a CPU reference
6. Implement `test.py` with pytest tests:
   - Use `@pytest.mark.extensive` for slower/larger tests
   - Use `verify_buffer()` from `iron.common.test_utils`
7. Register the operator in `iron/operators/__init__.py`
IRON supports fusing multiple operators into a single ELF file. This improves performance by enabling a single runtime dispatch for a chain of operators. Fusion works only with the "full ELF" flow, which uses ELF files in place of xclbins at runtime:

```python
from iron.common.fusion import FusedMLIROperator

# Define individual operators
gemm1 = AIEGEMM(...)
relu = AIERELU(...)
gemm2 = AIEGEMM(...)

# Create fused operator with runlist
# Intermediate buffers are automatically managed
fused_op = FusedMLIROperator(
    name="fused_gemm_relu_gemm",
    runlist=[
        (gemm1, "in", "temp1"),  # (operator, input_buffers, output_buffers)
        (relu, "temp1", "temp2"),
        (gemm2, "temp2", "out"),
    ],
    input_args={"in": size_in},
    output_args={"out": size_out},
    context=ctx,
)
```

Benefits of fusion:
- Reduces host ↔ NPU data transfers
- Runs a chain of operators with a single host-side dispatch (one CPU/host interrupt after fusion vs. one interrupt per operator without it)
Distribute work across NPU columns using TensorAccessPattern:

```python
num_columns = 4
chunk = total_elements // num_columns
taps = [
    TensorAccessPattern(
        (1, total_elements),
        chunk * i,          # offset for column i
        [1, 1, 1, chunk],   # sizes
        [0, 0, 0, 1],       # strides
    )
    for i in range(num_columns)
]
```

```python
def core_body(of_in, of_out, kernel_fn):
    for _ in range_(num_iterations):
        elem_in = of_in.acquire(1)
        elem_out = of_out.acquire(1)
        kernel_fn(elem_in, elem_out, size)
        of_in.release(1)
        of_out.release(1)
```

- Always use `range_()` in Worker functions (NPU-side code)
- Use Python `range` only in Runtime sequences (host-side code)
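The per-column offsets used above partition the tensor exactly; that can be checked with plain Python, independent of the TAP class (a standalone sketch):

```python
# Verify that splitting total_elements across num_columns with
# offset = chunk * i and length = chunk covers every element exactly once.
total_elements = 1024
num_columns = 4
chunk = total_elements // num_columns   # 256 elements per column

covered = []
for i in range(num_columns):
    offset = chunk * i                  # same offset formula as the TAPs above
    covered.extend(range(offset, offset + chunk))

assert covered == list(range(total_elements))  # no gaps, no overlap
```

Note the sketch assumes `total_elements` divides evenly by `num_columns`, which is exactly what the `size % (num_aie_columns * tile_size) == 0` constraint guarantees.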
```c++
#include <aie_api/aie.hpp>

void my_kernel(bfloat16* in, bfloat16* out, int32_t size) {
    event0();  // Start performance counter
    aie::vector<bfloat16, 32> vec_in = aie::load_v<32>(in);
    // ... vectorized operations ...
    aie::store_v(out, vec_out);
    event1();  // Stop performance counter
}
```

Note: `event0()` and `event1()` are performance profiling markers.
```python
from iron.common.test_utils import verify_buffer

# Compare NPU output against CPU reference
errors = verify_buffer(
    output=npu_output,
    buf_name="output",
    reference=cpu_reference,
    rel_tol=0.04,        # 4% relative tolerance
    abs_tol=1e-6,        # Absolute tolerance for small values
    max_error_rate=0.0,  # 0% of elements may fail (strict)
)
assert len(errors) == 0, f"Found {len(errors)} mismatches"
```

```python
from iron.common.utils import torch_to_numpy, numpy_to_torch

# Convert torch tensor to numpy (preserves bfloat16)
np_array = torch_to_numpy(torch_tensor)

# Convert numpy array to torch (preserves bfloat16)
torch_tensor = numpy_to_torch(np_array)
```

These utilities handle bfloat16 conversion correctly (avoiding a float32 intermediate).
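The mixed absolute/relative tolerance rule that such a check typically applies can be sketched in plain Python (an illustration only; the actual `verify_buffer()` implementation may differ):

```python
# Hypothetical sketch of a mixed absolute/relative tolerance check,
# in the spirit of verify_buffer() (not its actual implementation).
def mismatches(output, reference, rel_tol=0.04, abs_tol=1e-6):
    errs = []
    for i, (o, r) in enumerate(zip(output, reference)):
        diff = abs(o - r)
        # An element passes if it is within abs_tol (covers values near zero)
        # or within rel_tol of the reference magnitude.
        if diff > abs_tol and diff > rel_tol * abs(r):
            errs.append((i, o, r))
    return errs

assert mismatches([1.00, 2.03], [1.0, 2.0]) == []    # within 4% relative tol
assert len(mismatches([1.0, 3.0], [1.0, 2.0])) == 1  # element 1 is 50% off
```

The absolute-tolerance branch matters for bfloat16: values near zero produce large relative errors from rounding alone, so a relative check by itself would flag them spuriously.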
Disable the XRT runlist for easier debugging (executes kernels individually):

```python
context = AIEContext(use_runlist=False)
```

This sacrifices performance but makes it easier to identify which kernel fails.

Enable verbose MLIR compilation output:

```python
context = AIEContext(mlir_verbose=True)
```

C++ kernels use `event0()` and `event1()` markers for performance profiling. These can be analyzed with AIE trace tools to measure cycle counts.

The codebase uses Python's standard logging module. Enable debug logging:

```python
import logging
logging.basicConfig(level=logging.DEBUG)
```

- small.yml: Fast operator tests (non-extensive, runs on every PR)
- extensive.yml: Full test suite (all operators with extensive tests)
- test-examples.yml: Application tests (e.g., Llama inference)
- ci-lint.yml: Linting checks (black, clang-format, reuse)
- Target Branch: Always submit PRs to `devel`
- CI Tests: Run on self-hosted runners with NPU hardware
- All CI must pass: Including linting and formatting checks
- Pre-Push Hook (optional but recommended):

  ```bash
  cp scripts/hooks/pre-push .git/hooks/pre-push
  chmod +x .git/hooks/pre-push
  ```

- PR Prefixes: Use "DRAFT:" for work-in-progress, "REFACTOR:" for refactoring
"No XRT device found"
- Ensure `source /opt/xilinx/xrt/setup.sh` was run
- Check the XDNA driver is installed: `lsmod | grep amdxdna`

"Kernel not found" or "Symbol not defined"
- Verify the kernel `.cc` file is in the correct `aie_kernels/<arch>/` directory
- Check `get_kernel_artifacts()` in `op.py` references the correct kernel path
- Ensure the kernel function signature matches the `Kernel()` declaration in `design.py`

Compilation hangs or fails
- Check MLIR-AIE is installed: `python -c "import aie.iron"`
- Verify `llvm-aie` is available: `which aie-opt`
- Look for syntax errors in `design.py` (common: using `range` instead of `range_()`)

Test failures with numerical differences
- Check datatype consistency (bfloat16 has limited precision)
- Verify the reference implementation matches the NPU kernel exactly
- Look for memory alignment issues in the C++ kernel
- Adjust tolerances in `verify_buffer()` if needed (`rel_tol`, `abs_tol`)

Dimension mismatch errors
- Check operator constraints (e.g., `M % (tile_m * 4) == 0` for GEMM)
- Verify `tile_size`, `num_aie_columns`, and the total size are compatible
- Ensure tensor dimensions are multiples of the required alignment

"Invalid configuration: NPU2 has 8 columns"
- NPU1 supports 1-4 columns only
- NPU2 supports up to 8 columns
- Device type is auto-detected via XRT

Kernel compilation failures
- Check the kernel is in the correct architecture directory (`generic/`, `aie2/`, `aie2p/`)
- Verify `#include <aie_api/aie.hpp>` for AIE API kernels
- Ensure template parameters match the function signature
- Check for syntax errors in vectorization code
Full LLM inference example at `iron/applications/llama_3.2_1b/`:
- Required files: `model.safetensors`, `tokenizer.model` from Hugging Face
- Default location: `/srv/llama3.2-1b/` (configurable via `IRON_EXAMPLE_WEIGHTS_DIR`)
- Additional deps: `pip install -r requirements_examples.txt`
- Run: `pytest iron/applications/llama_3.2_1b/`
See `aie_kernels/README.md` for a catalog of available kernels:
- Element-wise ops (add, mul, scale)
- Matrix operations (mm, mv)
- Reductions (add, max, min)
- ML ops (conv2d, relu, exp)
- Vision ops (rgba2gray, filter2d)
Kernels are organized by coding style:
- AIE API: Portable C++ template library (recommended)
- Intrinsics: Architecture-specific low-level intrinsics (max performance)
- Generic C: Works on any AIE family (basic functionality)