Skip to content

Commit 103036b

Browse files
authored
fix: better symmetric materialization (#1909)
1 parent 4687017 commit 103036b

1 file changed

Lines changed: 5 additions & 11 deletions

File tree

src/stdlibs/LinearAlgebra.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,17 +124,11 @@ function ReactantCore.materialize_traced_array(
124124
m, n = size(x)
125125
row_idxs = @opcall iota(Int, [m, n]; iota_dimension=1)
126126
col_idxs = @opcall iota(Int, [m, n]; iota_dimension=2)
127-
if x.uplo == 'L'
128-
indicator = @opcall compare(row_idxs, col_idxs; comparison_direction="GT")
129-
x_lt = @opcall select(indicator, parent(x), zero(parent(x)))
130-
x_ltd = materialize_traced_array(LowerTriangular(parent(x)))
131-
return @opcall add(x_lt, @opcall(transpose(x_ltd, [2, 1])))
132-
else
133-
indicator = @opcall compare(row_idxs, col_idxs; comparison_direction="LT")
134-
x_ut = @opcall select(indicator, parent(x), zero(parent(x)))
135-
x_utd = materialize_traced_array(UpperTriangular(parent(x)))
136-
return @opcall add(@opcall(transpose(x_utd, [2, 1])), x_ut)
137-
end
127+
indicator = @opcall compare(
128+
row_idxs, col_idxs; comparison_direction=x.uplo == 'L' ? "GT" : "LT"
129+
)
130+
x_transposed = @opcall transpose(parent(x), [2, 1])
131+
return @opcall select(indicator, parent(x), x_transposed)
138132
end
139133

140134
function TracedUtils.set_mlir_data!(

0 commit comments

Comments
 (0)