-
Notifications
You must be signed in to change notification settings - Fork 218
Refactor EXLA block lowering through EXLA.CustomCall protocol #1739
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
Chapaman
wants to merge
14
commits into
elixir-nx:main
Choose a base branch
from
Chapaman:exla_block_implementation
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
014eb64
Refactor EXLA block lowering through EXLA.CustomCall protocol
Chapaman 5db679f
update EXLA.CustomCall to handle C-backed Nx.block tags (QR, Eigh)
Chapaman 4011521
test(exla): add QR FFI alias plugin + dlopen NIF and MLIR/JIT coverag…
Chapaman 3152003
fix formatting
Chapaman f2aa558
CustomCall now has only one callback + add documentation
Chapaman 2d8e9c2
update based on polvalente comments
Chapaman adf2369
upcast integers to float in C
Chapaman c058794
remove name from qr and eigh _cpu_target in custom_call.ex
Chapaman e424803
remove Value.qr/3 and Value.eigh/3 as they are not being used
Chapaman 238743d
remove CustomCall.Builtins, add integer test on custom_call_alias_test
Chapaman 3f2cb9e
.
Chapaman 1f812dd
refactor: defp Defn op helpers; CustomCall.call/4 + Spec (backend_con…
Chapaman 1e1fb6b
now integer QR and Eigh custom calls lower to f32 in elixir
Chapaman 3edbdf6
update based on feedback
Chapaman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| // Test-only shared library: registers an alias FFI name that reuses the | ||
| // existing qr_cpu_custom_call_f32 handler symbol from libxla_extension.so. | ||
| #ifndef EXLA_PROD | ||
|
|
||
| #include "xla/ffi/api/api.h" | ||
| #include "xla/ffi/ffi_api.h" | ||
|
|
||
| namespace ffi = xla::ffi; | ||
|
|
||
| extern "C" XLA_FFI_Error *qr_cpu_custom_call_f32(XLA_FFI_CallFrame *call_frame); | ||
|
|
||
| XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "qr_cpu_custom_call_f32_exla_alias", | ||
| "Host", qr_cpu_custom_call_f32); | ||
|
|
||
| #endif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| defprotocol EXLA.CustomCall do | ||
| @moduledoc """ | ||
| Extension point for lowering selected `Nx.block/4` tags to **XLA custom calls** | ||
| (`stablehlo.custom_call` in MLIR), the same style as helpers on | ||
| `EXLA.MLIR.Value` such as `qr/3` and `eigh/3`. | ||
|
|
||
| Other blocks (for example gather-based `take` or FFT) are lowered inline in | ||
| `EXLA.Defn` and do not use this protocol. | ||
|
|
||
| ## When `EXLA.Defn` calls it | ||
|
|
||
| During compilation with `compiler: EXLA`, when the builder is an MLIR | ||
| `EXLA.MLIR.Function`, each `Nx.block(tag, inputs, outputs, fn ... end)` is | ||
| passed here. `EXLA.Defn` invokes `call/4` once per block. | ||
|
|
||
| If `call/4` returns `:skip`, EXLA compiles the block's **default callback** | ||
| (the anonymous function body) instead of emitting a custom call. | ||
|
|
||
| ## `call/4` arguments | ||
|
|
||
| Callback arity is `call(struct, args, out, client)`, matching | ||
| `Nx.block(tag, inputs, outputs, fn ... end)` (tag, inputs, outputs, then client). | ||
|
|
||
| * `struct` — the **tag** passed as the first argument to `Nx.block/4` | ||
| (your own `defstruct` or an existing tag such as `%Nx.Block.LinAlg.QR{}`). | ||
|
|
||
| * `args` — list of **input templates**, in the same order as `inputs` in | ||
| `Nx.block/4`. | ||
|
|
||
| * `out` — the **output template** tuple passed to `Nx.block/4` (expression | ||
| metadata for shapes and types, not runtime tensors). | ||
|
|
||
| * `client` — the active `EXLA.Client` (use e.g. `client.platform` to gate | ||
| host-only lowerings). | ||
|
|
||
| ## `call/4` return value | ||
|
|
||
| * **`:skip`** — this implementation does not apply (unsupported type, | ||
| non-host platform, wrong arity, etc.). The default block implementation | ||
| is used instead. | ||
|
|
||
| * **`{:ok, %EXLA.CustomCall.Spec{}}`** — emit a StableHLO custom call; see | ||
| `EXLA.CustomCall.Spec` for `call_target_name`, optional `attributes` | ||
| (`[{name, attr}]` string pairs for the `stablehlo.custom_call` `backend_config` dictionary), and optional | ||
| `operand_element_types` (operand converts when they differ | ||
| from the lowered inputs). | ||
|
|
||
| ## Dispatch | ||
|
|
||
| The protocol uses `@fallback_to_any true`. Built-in lowerings for known tags | ||
| live in `defimpl EXLA.CustomCall, for: Any`. Your application or dependency can | ||
| add `defimpl EXLA.CustomCall, for: YourStruct`; that implementation is chosen | ||
| whenever the block tag is `%YourStruct{}`, instead of the `Any` fallback. | ||
|
|
||
| ## Native handlers | ||
|
|
||
| Emitting a custom call in MLIR is only half of the story: the **target name** | ||
| must be registered with XLA on the relevant platform (typically via a native | ||
| library loaded into the process). That registration is **not** configured | ||
| through `config :exla, ...`; you load or link the native code by the same | ||
| means you would for any other NIF-backed extension. | ||
|
|
||
| ## Example | ||
|
|
||
| defmodule MyApp.CustomQrTag do | ||
| defstruct [] | ||
| end | ||
|
|
||
| defimpl EXLA.CustomCall, for: MyApp.CustomQrTag do | ||
| def call(_tag, [_input], {%{type: {kind, size}}, _r_expr}, %{platform: :host}) | ||
| when kind != :c and kind in [:f, :bf] and size in [16, 32, 64] do | ||
| {:ok, %EXLA.CustomCall.Spec{call_target_name: "my_custom_qr_target"}} | ||
| end | ||
|
|
||
| def call(_, _, _, _), do: :skip | ||
| end | ||
|
|
||
| Then use `Nx.block(%MyApp.CustomQrTag{}, ...)` inside a `defn` compiled with | ||
| `compiler: EXLA`. | ||
| """ | ||
|
|
||
| @fallback_to_any true | ||
|
|
||
| @doc """ | ||
| Returns `:skip` or `{:ok, %EXLA.CustomCall.Spec{}}`. | ||
|
|
||
| Invoked as `call(struct, args, out, client)`. | ||
| """ | ||
| def call(struct, args, out, client) | ||
| end | ||
|
|
||
| # Default EXLA lowerings for **C-backed custom_call** `Nx.block/4` tags live | ||
| # in this `defimpl ..., for: Any` module. With `@fallback_to_any true` on the | ||
| # protocol, applications and libraries can define their own | ||
| # `defimpl EXLA.CustomCall, for: SomeStruct` — protocol dispatch uses that | ||
| # implementation instead of this fallback when the block tag matches (you can | ||
| # also target a built-in struct such as `Nx.Block...` from your app if needed). | ||
| # | ||
| defimpl EXLA.CustomCall, for: Any do | ||
| @moduledoc false | ||
|
|
||
| alias EXLA.CustomCall.Spec | ||
|
|
||
| def call( | ||
| %Nx.Block.LinAlg.QR{}, | ||
| [%{type: in_type} | _], | ||
| {%{type: q_type}, _r_expr}, | ||
| %{platform: :host} | ||
| ) | ||
| when elem(q_type, 0) != :c and elem(in_type, 0) != :c do | ||
| qr_cpu_custom_call(in_type) | ||
| end | ||
|
|
||
| # Native target names depend only on the input dtype; output templates may use | ||
| # different element types (e.g. promotion) and must not change the call target. | ||
| def call(%Nx.Block.LinAlg.Eigh{}, [%{type: in_type} | _], _out, %{platform: :host}) | ||
| when elem(in_type, 0) != :c do | ||
| eigh_cpu_custom_call(in_type) | ||
| end | ||
|
|
||
| def call(_, _, _, _), do: :skip | ||
|
|
||
| defp qr_cpu_custom_call({kind, _bits}) when kind in [:s, :u] do | ||
| {:ok, | ||
| %Spec{ | ||
| call_target_name: "qr_cpu_custom_call_f32", | ||
| operand_element_types: [{:f, 32}] | ||
| }} | ||
| end | ||
|
|
||
| defp qr_cpu_custom_call(in_type) do | ||
| case in_type do | ||
| {:f, 32} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_f32"}} | ||
| {:f, 64} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_f64"}} | ||
| {:f, 16} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_f16"}} | ||
| {:bf, 16} -> {:ok, %Spec{call_target_name: "qr_cpu_custom_call_bf16"}} | ||
| _ -> :skip | ||
| end | ||
| end | ||
|
|
||
| defp eigh_cpu_custom_call({kind, _bits}) when kind in [:s, :u] do | ||
| {:ok, | ||
| %Spec{ | ||
| call_target_name: "eigh_cpu_custom_call_f32", | ||
| operand_element_types: [{:f, 32}] | ||
| }} | ||
| end | ||
|
|
||
| defp eigh_cpu_custom_call(in_type) do | ||
| case in_type do | ||
| {:f, 32} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_f32"}} | ||
| {:f, 64} -> {:ok, %Spec{call_target_name: "eigh_cpu_custom_call_f64"}} | ||
| _ -> :skip | ||
| end | ||
| end | ||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| defmodule EXLA.CustomCall.Spec do | ||
| @moduledoc """ | ||
| Result of `EXLA.CustomCall.call/4` when lowering a tagged `Nx.block/4` to | ||
| `stablehlo.custom_call`. | ||
|
|
||
| * **`call_target_name`** — XLA FFI handler name (`call_target_name` on the op). | ||
|
|
||
| * **`attributes`** — Optional `{name, attr}` pairs, default `[]`, merged into | ||
| the `backend_config` dictionary on `stablehlo.custom_call` (StableHLO’s name | ||
| for that attribute). Each `name` must be a **binary** MLIR identifier; each | ||
| `attr` must be a **binary** with valid MLIR attribute syntax for the RHS after | ||
| `name = ` (for example `{"k", "42 : i64"}`). An empty list omits the dictionary | ||
| from the op. | ||
|
|
||
| * **`operand_element_types`** — How operand SSA values are presented to the handler: | ||
|
|
||
| * **`:default`** — use each lowered operand’s element type as produced from the | ||
| block inputs. No extra converts. | ||
|
|
||
| * **`[Nx.Type.t(), ...]`** — one type per block input, same order and length as | ||
| `Nx.block/4` inputs. Before building the custom call, each operand is | ||
| converted (StableHLO `convert`) when its element type differs from the | ||
| requested type; shapes are unchanged. Use this when the native kernel’s | ||
| FFI signature expects dtypes that may differ from the traced expression | ||
| types (for example after promotion rules). | ||
| """ | ||
|
|
||
| @enforce_keys [:call_target_name] | ||
|
|
||
| defstruct [:call_target_name, attributes: [], operand_element_types: :default] | ||
|
|
||
| @type t :: %__MODULE__{ | ||
| call_target_name: String.t(), | ||
| attributes: [{String.t(), String.t()}], | ||
| operand_element_types: :default | [Nx.Type.t()] | ||
| } | ||
| end |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: maybe
argsbeforeoutin the argument list?