Skip to content

Commit 9f2afce

Browse files
Copilotavik-pal
andauthored
Fix permutedims! to support AnyTracedRArray destinations (including ReshapedArray)
Agent-Logs-Url: https://github.com/EnzymeAD/Reactant.jl/sessions/36ff73e5-274e-4c26-bbf9-acf3130b3cb0 Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com>
1 parent 47d6412 commit 9f2afce

2 files changed

Lines changed: 17 additions & 1 deletion

File tree

src/TracedRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1338,7 +1338,7 @@ function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
13381338
return @opcall transpose(materialize_traced_array(A), Int64[perm...])
13391339
end
13401340

1341-
function Base.permutedims!(dest::TracedRArray, src::AnyTracedRArray, perm)
1341+
function Base.permutedims!(dest::AnyTracedRArray, src::AnyTracedRArray, perm)
13421342
result = @opcall transpose(materialize_traced_array(src), Int64[perm...])
13431343
TracedUtils.set_mlir_data!(dest, result.mlir_data)
13441344
return dest

test/core/wrapped_arrays.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,3 +294,19 @@ end
294294
@test @jit(view_transpose(x_ra)) view_transpose(x)
295295
@test @jit(view_diagonal(x_ra)) view_diagonal(x)
296296
end
297+
298+
function permutedims!_reshaped(A, B)
299+
A_reshaped = reshape(A, 2, 2, 2)
300+
B_reshaped = reshape(B, 2, 2, 2)
301+
permutedims!(A_reshaped, B_reshaped, (2, 3, 1))
302+
return A
303+
end
304+
305+
@testset "permutedims! on reshaped arrays" begin
306+
A = randn(Float32, 8)
307+
B = randn(Float32, 8)
308+
A_ra = Reactant.to_rarray(A)
309+
B_ra = Reactant.to_rarray(B)
310+
311+
@test Array(@jit(permutedims!_reshaped(A_ra, B_ra))) permutedims!_reshaped(A, B)
312+
end

0 commit comments

Comments
 (0)