Draft
18 changes: 17 additions & 1 deletion exla/Makefile
@@ -21,6 +21,8 @@ EXLA_LIB_DIR = $(PRIV_DIR)/xla_extension/lib
XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LIB)
EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)

.DEFAULT_GOAL := $(EXLA_SO)

# Build flags
#
# Note that XLA requires c++17, Fine as well
@@ -86,7 +88,21 @@ else
LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib'
endif

$(EXLA_SO): $(EXLA_CACHE_SO)
# Optional test dylib: registers qr_cpu_custom_call_f32_exla_alias -> same
# handler as qr_cpu_custom_call_f32. Built only when MIX_ENV=test.
TEST_PLUGIN_CC = c_src/exla_test/custom_calls.cc
TEST_PLUGIN_SO = $(PRIV_DIR)/test/exla_qr_alias.so

$(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR)
@ mkdir -p $(dir $@)
$(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS)

EXLA_SO_DEPS = $(EXLA_CACHE_SO)
ifeq ($(MIX_ENV),test)
EXLA_SO_DEPS += $(TEST_PLUGIN_SO)
endif

$(EXLA_SO): $(EXLA_SO_DEPS)
@ mkdir -p $(PRIV_DIR)
@ mkdir -p $(PRIV_DIR)/xla_extension
@ if [ "${MIX_BUILD_EMBEDDED}" = "true" ]; then \
17 changes: 17 additions & 0 deletions exla/c_src/exla/exla.cc
@@ -1,3 +1,5 @@
#include <dlfcn.h>

#include <cstring>
#include <fine.hpp>
#include <stdexcept>
@@ -29,6 +31,8 @@
#include "xla/tsl/platform/statusor.h"
#include "llvm/Support/ThreadPool.h"

#include <vector>

namespace exla {

using callback_bridge::Pending;
@@ -535,6 +539,19 @@ fine::Ok<> load_pjrt_plugin(ErlNifEnv *env, std::string device_type,

FINE_NIF(load_pjrt_plugin, 0);

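// Loads a shared library so that its static initializers (such as XLA FFI
// handler registrations) run. RTLD_GLOBAL also makes its symbols visible to
// libraries loaded afterwards. The handle is intentionally never closed.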
fine::Ok<> load_dylib(ErlNifEnv *env, std::string path) {
void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL);
if (handle == nullptr) {
const char *err = dlerror();
throw std::invalid_argument(err ? err : "dlopen failed");
}
(void)handle;
return fine::Ok();
}

FINE_NIF(load_dylib, 0);

int64_t get_device_count(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client) {
return client->client()->device_count();
}
15 changes: 15 additions & 0 deletions exla/c_src/exla_test/custom_calls.cc
@@ -0,0 +1,15 @@
// Test-only shared library: registers an alias FFI name that reuses the
// existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so.
#ifndef EXLA_PROD

#include "xla/ffi/api/api.h"
#include "xla/ffi/ffi_api.h"

namespace ffi = xla::ffi;

extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame);

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias",
"Host", qr_cpu_custom_call_f32);

#endif
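
The alias registration above relies on static registration: a macro defines a namespace-scope object whose constructor records a name-to-handler mapping, so loading the library is enough to register it. The following is a minimal, self-contained sketch of that pattern only. It does not use the XLA FFI API, and `Handler`, `registry`, `Registrar`, `SKETCH_REGISTER_HANDLER`, `qr_sketch`, and `lookup` are all hypothetical names introduced for illustration.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a handler signature. The handler declared in the
// diff has the shape XLA_FFI_Error *(*)(XLA_FFI_CallFrame *).
using Handler = int (*)(int);

// A function-local static is constructed on first use, so it is safe to call
// registry() from the constructor of another static object.
std::map<std::string, Handler> &registry() {
  static std::map<std::string, Handler> handlers;
  return handlers;
}

// Constructing a Registrar at namespace scope registers the handler as a side
// effect of static initialization (for a shared library: when it is loaded).
struct Registrar {
  Registrar(const char *name, Handler handler) { registry()[name] = handler; }
};

// Two-level concatenation so __LINE__ expands before token pasting, giving
// each registration a distinct variable name.
#define SKETCH_CONCAT_INNER(a, b) a##b
#define SKETCH_CONCAT(a, b) SKETCH_CONCAT_INNER(a, b)
#define SKETCH_REGISTER_HANDLER(name, fn) \
  static Registrar SKETCH_CONCAT(sketch_registrar_, __LINE__)(name, fn)

// Hypothetical handler; any function with the Handler signature works.
int qr_sketch(int x) { return x * 2; }

// The original name and an alias both point at the same handler symbol.
SKETCH_REGISTER_HANDLER("qr_sketch", qr_sketch);
SKETCH_REGISTER_HANDLER("qr_sketch_alias", qr_sketch);

// Returns the registered handler, or nullptr when the name is unknown.
Handler lookup(const std::string &name) {
  auto it = registry().find(name);
  return it == registry().end() ? nullptr : it->second;
}
```

In this sketch, `lookup("qr_sketch_alias")` returns the same function pointer as `lookup("qr_sketch")`, which mirrors what the test plugin is meant to achieve: a second target name resolving to an existing handler, with no handler code duplicated.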
7 changes: 7 additions & 0 deletions exla/lib/exla.ex
@@ -75,6 +75,13 @@ defmodule EXLA do
* `:highest` - Slowest but most accurate. Performs computations in float32
or float64 as applicable

## Native custom calls (`EXLA.CustomCall`)

Some `Nx.block/4` tags can be lowered to XLA **custom calls** (StableHLO plus
a registered native handler). Implement the `EXLA.CustomCall` protocol for
your block tag struct; see `EXLA.CustomCall` for the `call/4` contract,
including returning `:skip` to fall back to the block's default Elixir callback.

## Clients

The `EXLA` library uses a client for compiling and executing code.
156 changes: 156 additions & 0 deletions exla/lib/exla/custom_call.ex
@@ -0,0 +1,156 @@
defprotocol EXLA.CustomCall do
@moduledoc """
Extension point for lowering selected `Nx.block/4` tags to **XLA custom calls**
(`stablehlo.custom_call` in MLIR), the same style as helpers on
`EXLA.MLIR.Value` such as `qr/3` and `eigh/3`.

Other blocks (for example gather-based `take` or FFT) are lowered inline in
`EXLA.Defn` and do not use this protocol.

## When `EXLA.Defn` calls it

During compilation with `compiler: EXLA`, when the builder is an MLIR
`EXLA.MLIR.Function`, each `Nx.block(tag, inputs, outputs, fn ... end)` is
passed here. `EXLA.Defn` invokes `call/4` once per block.

If `call/4` returns `:skip`, EXLA compiles the block's **default callback**
(the anonymous function body) instead of emitting a custom call.

## `call/4` arguments

Callback arity is `call(struct, args, out, client)`, matching
`Nx.block(tag, inputs, outputs, fn ... end)` (tag, inputs, outputs, then client).

* `struct` — the **tag** passed as the first argument to `Nx.block/4`
(your own `defstruct` or an existing tag such as `%Nx.Block.LinAlg.QR{}`).

* `args` — list of **input templates**, in the same order as `inputs` in
Contributor


Nitpick: maybe args before out in the argument list?

`Nx.block/4`.

* `out` — the **output template** tuple passed to `Nx.block/4` (expression
metadata for shapes and types, not runtime tensors).

* `client` — the active `EXLA.Client` (use e.g. `client.platform` to gate
host-only lowerings).

## `call/4` return value

* **`:skip`** — this implementation does not apply (unsupported type,
non-host platform, wrong arity, etc.). The default block implementation
is used instead.

* **`{:ok, %EXLA.CustomCall.Spec{}}`** — emit a StableHLO custom call; see
`EXLA.CustomCall.Spec` for `call_target_name`, optional `attributes`
(`[{name, attr}]` string pairs merged into the `backend_config` dictionary
of `stablehlo.custom_call`), and optional `operand_element_types` (operands
are converted when the requested element types differ from those of the
lowered inputs).

## Dispatch

The protocol uses `@fallback_to_any true`. Built-in lowerings for known tags
live in `defimpl EXLA.CustomCall, for: Any`. Your application or dependency can
add `defimpl EXLA.CustomCall, for: YourStruct`; that implementation is chosen
whenever the block tag is `%YourStruct{}`, instead of the `Any` fallback.

## Native handlers

Emitting a custom call in MLIR is only half of the story: the **target name**
must be registered with XLA on the relevant platform (typically via a native
library loaded into the process). That registration is **not** configured
through `config :exla, ...`; you load or link the native code by the same
means you would for any other NIF-backed extension.

## Example

defmodule MyApp.CustomQrTag do
defstruct []
end

defimpl EXLA.CustomCall, for: MyApp.CustomQrTag do
def call(_tag, [_input], {%{type: {kind, size}}, _r_expr}, %{platform: :host})
when kind in [:f, :bf] and size in [16, 32, 64] do
{:ok, %EXLA.CustomCall.Spec{call_target_name: "my_custom_qr_target"}}
end

def call(_, _, _, _), do: :skip
end

Then use `Nx.block(%MyApp.CustomQrTag{}, ...)` inside a `defn` compiled with
`compiler: EXLA`.
"""

@fallback_to_any true

@doc """
Returns `:skip` or `{:ok, %EXLA.CustomCall.Spec{}}`.

Invoked as `call(struct, args, out, client)`.
"""
def call(struct, args, out, client)
end

# Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live
# in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the
# protocol, applications and libraries can define their own
# `defimpl EXLA.CustomCall, for: SomeStruct` — protocol dispatch uses that
# implementation instead of this fallback when the block tag matches (you can
# also target a built-in struct such as `Nx.Block...` from your app if needed).
#
defimpl EXLA.CustomCall, for: Any do
@moduledoc false

alias EXLA.CustomCall.Spec

def call(
%Nx.Block.LinAlg.QR{},
[%{type: in_type} | _],
{%{type: q_type}, _r_expr},
%{platform: :host}
)
when elem(q_type, 0) != :c and elem(in_type, 0) != :c do
qr_cpu_custom_call(in_type)
end

# Native target names depend only on the input dtype; output templates may use
# different element types (e.g. promotion) and must not change the call target.
def call(%Nx.Block.LinAlg.Eigh{}, [%{type: in_type} | _], _out, %{platform: :host})
when elem(in_type, 0) != :c do
eigh_cpu_custom_call(in_type)
end

def call(_, _, _, _), do: :skip

defp qr_cpu_custom_call({kind, _bits}) when kind in [:s, :u] do
{:ok,
%Spec{
call_target_name: "qr_cpu_custom_call_f32",
operand_element_types: [{:f, 32}]
}}
end

defp qr_cpu_custom_call(in_type) do
case in_type do
{:f, 32} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_f32"}}
{:f, 64} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_f64"}}
{:f, 16} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_f16"}}
{:bf, 16} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_bf16"}}
_ -> :skip
end
end

defp eigh_cpu_custom_call({kind, _bits}) when kind in [:s, :u] do
{:ok,
%Spec{
call_target_name: "eigh_cpu_custom_call_f32",
operand_element_types: [{:f, 32}]
}}
end

defp eigh_cpu_custom_call(in_type) do
case in_type do
{:f, 32} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_f32"}}
{:f, 64} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_f64"}}
_ -> :skip
end
end
end
37 changes: 37 additions & 0 deletions exla/lib/exla/custom_call/spec.ex
@@ -0,0 +1,37 @@
defmodule EXLA.CustomCall.Spec do
@moduledoc """
Result of `EXLA.CustomCall.call/4` when lowering a tagged `Nx.block/4` to
`stablehlo.custom_call`.

* **`call_target_name`** — XLA FFI handler name (`call_target_name` on the op).

* **`attributes`** — Optional `{name, attr}` pairs, default `[]`, merged into
the `backend_config` dictionary on `stablehlo.custom_call` (StableHLO’s name
for that attribute). Each `name` must be a **binary** MLIR identifier; each
`attr` must be a **binary** with valid MLIR attribute syntax for the RHS after
`name = ` (for example `{"k", "42 : i64"}`). An empty list omits the dictionary
from the op.

* **`operand_element_types`** — How operand SSA values are presented to the handler:

* **`:default`** — use each lowered operand’s element type as produced from the
block inputs. No extra converts.

* **`[Nx.Type.t(), ...]`** — one type per block input, same order and length as
`Nx.block/4` inputs. Before building the custom call, each operand is
converted (StableHLO `convert`) when its element type differs from the
requested type; shapes are unchanged. Use this when the native kernel’s
FFI signature expects dtypes that may differ from the traced expression
types (for example after promotion rules).
"""

@enforce_keys [:call_target_name]

defstruct [:call_target_name, attributes: [], operand_element_types: :default]

@type t :: %__MODULE__{
call_target_name: String.t(),
attributes: [{String.t(), String.t()}],
operand_element_types: :default | [Nx.Type.t()]
}
end