Skip to content

Commit a59fc5d

Browse files
authored
Merge branch 'main' into pb/ll-inline-roots
2 parents 2af2121 + 79d3219 commit a59fc5d

23 files changed

Lines changed: 1114 additions & 979 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ docs/site/
253253
Manifest.toml
254254
Manifest-v*.toml
255255
.CondaPkg
256+
LocalPreferences.toml
256257

257258
.vscode/*
258259
.vscode/settings.json

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
3-
version = "0.2.253"
3+
version = "0.2.254"
44
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
55

66
[workspace]
@@ -149,7 +149,7 @@ PythonCall = "0.9.25"
149149
Random = "1.10"
150150
Random123 = "1.7"
151151
ReactantCore = "0.1.18"
152-
Reactant_jll = "0.0.371"
152+
Reactant_jll = "0.0.375"
153153
ScopedValues = "1.3.0"
154154
Scratch = "1.3"
155155
Serialization = "1.10"

deps/ReactantExtra/BUILD

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,6 +1672,24 @@ gentbl_cc_library(
16721672
],
16731673
)
16741674

1675+
gentbl_cc_library(
1676+
name = "ImpulseJLIncGen",
1677+
tbl_outs = [
1678+
(
1679+
[
1680+
"--generator=jl-op-defs",
1681+
"--disable-module-wrap=0",
1682+
],
1683+
"Impulse.jl",
1684+
),
1685+
],
1686+
tblgen = "//:mlir-jl-tblgen",
1687+
td_file = "@enzyme//:Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td",
1688+
deps = [
1689+
"@enzyme//:ImpulseDialectTdFiles",
1690+
],
1691+
)
1692+
16751693
gentbl_cc_library(
16761694
name = "EnzymeXLAJLIncGen",
16771695
tbl_outs = [

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "10d4571a0d5fa08a2fbf285dc585cf16a091b573"
7+
ENZYMEXLA_COMMIT = "99d7b16a3ef9cd5cac1f616c94e5be617a16f0d2"
88

99
ENZYMEXLA_SHA256 = ""
1010

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dialect_files = [
1616
"Func.jl",
1717
"Enzyme.jl",
1818
"EnzymeXLA.jl",
19+
"Impulse.jl",
1920
"StableHLO.jl",
2021
"CHLO.jl",
2122
"VHLO.jl",

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ export default defineConfig({
127127
{ text: "EnzymeXLA", link: "/api/dialects/enzymexla" },
128128
{ text: "Func", link: "/api/dialects/func" },
129129
{ text: "GPU", link: "/api/dialects/gpu" },
130+
{ text: "Impulse", link: "/api/dialects/impulse" },
130131
{ text: "LLVM", link: "/api/dialects/llvm" },
131132
{ text: "MPI", link: "/api/dialects/mpi" },
132133
{ text: "MemRef", link: "/api/dialects/memref" },
@@ -225,6 +226,7 @@ export default defineConfig({
225226
{ text: "EnzymeXLA", link: "/api/dialects/enzymexla" },
226227
{ text: "Func", link: "/api/dialects/func" },
227228
{ text: "GPU", link: "/api/dialects/gpu" },
229+
{ text: "Impulse", link: "/api/dialects/impulse" },
228230
{ text: "LLVM", link: "/api/dialects/llvm" },
229231
{ text: "MPI", link: "/api/dialects/mpi" },
230232
{ text: "MemRef", link: "/api/dialects/memref" },

docs/src/api/dialects/impulse.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
```@meta
2+
CollapsedDocStrings = true
3+
```
4+
5+
# Impulse Dialect
6+
7+
```@autodocs
8+
Modules = [Reactant.MLIR.Dialects.impulse]
9+
```

src/Compiler.jl

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,37 @@ Base.@nospecializeinfer function create_result(
277277
return result_cache[tocopy]
278278
end
279279

280+
Base.@nospecializeinfer function create_result(
281+
@nospecialize(tocopy::Enum),
282+
@nospecialize(path::Tuple),
283+
result_stores,
284+
path_to_shard_info,
285+
to_unreshard_results,
286+
_unresharded_code::Vector{Expr},
287+
_unresharded_arrays_cache,
288+
used_shardinfo,
289+
result_cache,
290+
var_idx,
291+
resultgen_code,
292+
)
293+
if !haskey(result_cache, tocopy)
294+
sym = Symbol("result", var_idx[])
295+
var_idx[] += 1
296+
297+
result = Meta.quot(tocopy)
298+
299+
push!(
300+
resultgen_code,
301+
quote
302+
$sym = $result
303+
end,
304+
)
305+
result_cache[tocopy] = sym
306+
end
307+
308+
return result_cache[tocopy]
309+
end
310+
280311
function create_result(
281312
tocopy::ConcretePJRTNumber{T,D},
282313
@nospecialize(path::Tuple),
@@ -823,14 +854,14 @@ end
823854
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
824855
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize,arith-raise{stablehlo=true}\"}"
825856

826-
function probprog_pass(;
857+
function impulse_pass(;
827858
debug_dump::Bool=DEBUG_PROBPROG_DUMP_VALUE[],
828859
disable_optimizations::Bool=DEBUG_PROBPROG_DISABLE_OPT[],
829860
)
830861
if !disable_optimizations
831-
# TODO(#2063): Add probprog optimization passes
862+
# TODO(#2063): Add impulse optimization passes
832863
end
833-
return "probprog{debug-dump=$debug_dump postpasses=\"arith-raise{stablehlo=true}\"}"
864+
return "expand-impulse{debug-dump=$debug_dump postpasses=\"arith-raise{stablehlo=true}\"}"
834865
end
835866

836867
function run_pass_pipeline!(mod, pass_pipeline, key=""; enable_verifier=true)
@@ -1575,8 +1606,8 @@ function compile_mlir!(
15751606
raise_passes,
15761607
"enzyme-batch",
15771608
opt_passes2,
1578-
probprog_pass(),
1579-
"lower-probprog-to-stablehlo{backend=$backend}",
1609+
impulse_pass(),
1610+
"lower-impulse-to-stablehlo{backend=$backend}",
15801611
"outline-enzyme-regions",
15811612
enzyme_pass,
15821613
opt_passes2,
@@ -1592,7 +1623,7 @@ function compile_mlir!(
15921623
)...,
15931624
opt_passes2,
15941625
lower_enzymexla_passes,
1595-
"lower-probprog-trace-ops{backend=$backend}",
1626+
"lower-impulse-trace-ops{backend=$backend}",
15961627
jit,
15971628
]
15981629
else
@@ -1601,8 +1632,8 @@ function compile_mlir!(
16011632
opt_passes,
16021633
"enzyme-batch",
16031634
opt_passes2,
1604-
probprog_pass(),
1605-
"lower-probprog-to-stablehlo{backend=$backend}",
1635+
impulse_pass(),
1636+
"lower-impulse-to-stablehlo{backend=$backend}",
16061637
"outline-enzyme-regions",
16071638
enzyme_pass,
16081639
opt_passes2,
@@ -1620,13 +1651,13 @@ function compile_mlir!(
16201651
kern,
16211652
raise_passes,
16221653
lower_enzymexla_passes,
1623-
"lower-probprog-trace-ops{backend=$backend}",
1654+
"lower-impulse-trace-ops{backend=$backend}",
16241655
jit,
16251656
]
16261657
end,
16271658
",",
16281659
),
1629-
"probprog",
1660+
"impulse",
16301661
)
16311662
elseif compile_options.optimization_passes === :only_enzyme
16321663
run_pass_pipeline!(

src/Enzyme.jl

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
const enzyme_out = 0
2-
const enzyme_dup = 1
3-
const enzyme_const = 2
4-
const enzyme_dupnoneed = 3
5-
const enzyme_outnoneed = 4
6-
const enzyme_constnoneed = 5
1+
@enumx EnzymeActivity begin
2+
OUT = 0
3+
DUPLICATED = 1
4+
CONST = 2
5+
DUPLICATED_NO_NEED = 3
6+
OUT_NO_NEED = 4
7+
CONST_NO_NEED = 5
8+
end
79

810
struct StackedBatchDuplicated{T,N,M,V<:AbstractArray{T,N},W<:AbstractArray{T,M}} <:
911
Annotation{V}
@@ -201,17 +203,17 @@ end
201203
end
202204

203205
@inline function act_from_type(::Type{<:Active}, reverse, needs_primal)
204-
return needs_primal ? enzyme_out : enzyme_outnoneed
206+
return needs_primal ? EnzymeActivity.OUT : EnzymeActivity.OUT_NO_NEED
205207
end
206208
@inline function act_from_type(::Type{<:Const}, reverse, needs_primal)
207-
return needs_primal ? enzyme_const : enzyme_constnoneed
209+
return needs_primal ? EnzymeActivity.CONST : EnzymeActivity.CONST_NO_NEED
208210
end
209211

210212
@inline function act_from_type(::Type{<:Duplicated}, reverse, needs_primal)
211213
if reverse
212-
return needs_primal ? enzyme_out : enzyme_outnoneed
214+
return needs_primal ? EnzymeActivity.OUT : EnzymeActivity.OUT_NO_NEED
213215
else
214-
return needs_primal ? enzyme_dup : enzyme_dupnoneed
216+
return needs_primal ? EnzymeActivity.DUPLICATED : EnzymeActivity.DUPLICATED_NO_NEED
215217
end
216218
end
217219
@inline function act_from_type(
@@ -221,7 +223,7 @@ end
221223
end
222224

223225
@inline function act_from_type(::Type{<:DuplicatedNoNeed}, reverse, needs_primal)
224-
return reverse ? enzyme_out : enzyme_dupnoneed
226+
return reverse ? EnzymeActivity.OUT : EnzymeActivity.DUPLICATED_NO_NEED
225227
end
226228
@inline function act_from_type(
227229
::Type{<:Union{BatchDuplicatedNoNeed,StackedBatchDuplicatedNoNeed}},
@@ -289,7 +291,9 @@ function set_act!(inp, path, reverse, tostore; emptypath=false, width=1)
289291
end
290292

291293
function act_attr(val)
292-
return MLIR.IR.Attribute(MLIR.API.enzymeActivityAttrGet(MLIR.IR.current_context(), val))
294+
return MLIR.IR.Attribute(
295+
MLIR.API.enzymeActivityAttrGet(MLIR.IR.current_context(), Int32(val))
296+
)
293297
end
294298

295299
function infer_activity(
@@ -342,7 +346,7 @@ function overload_autodiff(
342346
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
343347
fnwrap = mlir_fn_res.fnwrapped
344348

345-
activity = Int32[]
349+
activity = EnzymeActivity.T[]
346350
ad_inputs = MLIR.IR.Value[]
347351

348352
reverse_seeds = Dict{Tuple,MLIR.IR.Value}()
@@ -353,7 +357,7 @@ function overload_autodiff(
353357
push!(activity, act_from_type(arg, reverse))
354358
push_acts!(ad_inputs, arg, path[3:end], reverse)
355359

356-
if CMode <: ReverseMode && act_from_type(arg, false) == enzyme_dup
360+
if CMode <: ReverseMode && act_from_type(arg, false) == EnzymeActivity.DUPLICATED
357361
x = if width == 1
358362
arg.dval
359363
elseif arg.dval isa AbstractArray
@@ -370,7 +374,7 @@ function overload_autodiff(
370374
end
371375

372376
outtys = MLIR.IR.Type[]
373-
ret_activity = Int32[]
377+
ret_activity = EnzymeActivity.T[]
374378

375379
for a in linear_results
376380
if TracedUtils.has_idx(a, resprefix)
@@ -395,7 +399,7 @@ function overload_autodiff(
395399

396400
act = act_from_type(A, reverse, EnzymeCore.needs_primal(CMode))
397401
cst = nothing
398-
if act == enzyme_out || act == enzyme_outnoneed
402+
if act == EnzymeActivity.OUT || act == EnzymeActivity.OUT_NO_NEED
399403
if width == 1
400404
cst = @opcall fill(one(unwrapped_eltype(a)), size(a))
401405
else
@@ -407,26 +411,28 @@ function overload_autodiff(
407411
if CMode <: ReverseMode && TracedUtils.has_idx(a, argprefix)
408412
idx, path = TracedUtils.get_argidx(a, argprefix)
409413
arg = idx == 1 && fnwrap ? f : args[idx - fnwrap]
410-
if act_from_type(arg, false) == enzyme_dup
414+
if act_from_type(arg, false) == EnzymeActivity.DUPLICATED
411415
seed = reverse_seeds[path]
412-
if cst == nothing
413-
if act == enzyme_const
414-
act = enzyme_out
415-
elseif act == enzyme_constnoneed
416-
act = enzyme_outnoneed
416+
if cst === nothing
417+
if act == EnzymeActivity.CONST
418+
act = EnzymeActivity.OUT
419+
elseif act == EnzymeActivity.CONST_NO_NEED
420+
act = EnzymeActivity.OUT_NO_NEED
417421
else
418422
@assert false
419423
end
420424
cst = seed
421425
else
422-
@assert act == enzyme_out || act == enzyme_outnoneed
426+
@assert (
427+
act == EnzymeActivity.OUT || act == EnzymeActivity.OUT_NO_NEED
428+
)
423429
cst = MLIR.IR.result(MLIR.Dialects.stablehlo.add(cst, seed), 1)
424430
end
425431
end
426432
end
427433

428434
push!(ret_activity, act)
429-
if cst != nothing
435+
if cst !== nothing
430436
push!(ad_inputs, cst)
431437
end
432438
else
@@ -437,7 +443,7 @@ function overload_autodiff(
437443
act = act_from_type(arg, reverse, true)
438444
push!(ret_activity, act)
439445

440-
if act == enzyme_out || act == enzyme_outnoneed
446+
if act == EnzymeActivity.OUT || act == EnzymeActivity.OUT_NO_NEED
441447
seed = reverse_seeds[path]
442448
push!(ad_inputs, seed)
443449
end
@@ -453,7 +459,11 @@ function overload_autodiff(
453459
end
454460

455461
for (i, act) in enumerate(activity)
456-
if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed
462+
if (
463+
act == EnzymeActivity.OUT ||
464+
act == EnzymeActivity.DUPLICATED ||
465+
act == EnzymeActivity.DUPLICATED_NO_NEED
466+
)
457467
push!(outtys, TracedUtils.batch_ty(width, in_tys[i]))
458468
end
459469
end
@@ -528,7 +538,7 @@ function overload_autodiff(
528538
idx, path = TracedUtils.get_argidx(a, argprefix)
529539

530540
arg = idx == 1 && fnwrap ? f : args[idx - fnwrap]
531-
act_from_type(arg, reverse) != enzyme_out && continue
541+
act_from_type(arg, reverse) != EnzymeActivity.OUT && continue
532542

533543
if idx == 1 && fnwrap && arg isa Active
534544
@assert false

src/Ops.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,8 @@ end
650650
dims::Vector{Int};
651651
location=mlir_stacktrace("reshape", @__FILE__, @__LINE__),
652652
) where {T,N}
653+
@assert length(x) == prod(dims)
654+
653655
# HLO reshape semantics collapse the opposite way
654656
res1 = transpose(x, Int64[N:-1:1...])
655657
restype = mlir_type(TracedRArray{T,length(dims)}, collect(Int64, Base.reverse(dims)))

0 commit comments

Comments
 (0)