Don't overlay mul! for sparse arrays#1739
Don't overlay mul! for sparse arrays#1739albertomercurio wants to merge 3 commits intoEnzymeAD:mainfrom
Conversation
| (:AbstractVector, :AbstractMatrix, :AbstractVector), | ||
| (:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat), |
There was a problem hiding this comment.
Should we make cT and bT Dense as well?
There was a problem hiding this comment.
I was just limiting the matrix to be dense to make it more general. But I can add the constraint also to the other two arguments.
There was a problem hiding this comment.
Should I make them DenseArrays as well?
There was a problem hiding this comment.
Seems like this prevents dispatches for Triangular matrices and such from going down this route.
|
Most importantly, this PR should wait #1696 to be merged first. |
|
I have relaxed again the method back to When using a sparse matrix, it correctly calls ERROR: LoadError: "Cannot trace existing trace type"
Stacktrace:
[1] #make_tracer#142
@ ~/.julia/dev/Reactant/src/Tracing.jl:1292
[2] prepare_mlir_fn_args(args::Tuple{Int64, Nothing, KernelAbstractions.Kernel{ReactantKernelAbstractionsExt.ReactantBackend, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicSize, typeof(gpu_spmv_kernel!)}, Reactant.TracedRArray{Float64, 1}, GenericSparseMatrixCSR{Float64, Int64, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Float64, 1}}, Reactant.TracedRArray{Float64, 1}}, name::String, concretein::Bool, toscalar::Bool, argprefix::Symbol, runtime::Val{:PJRT}, optimize_then_pad::Bool, do_transpose::Bool, input_shardings::Nothing, verify_arg_names::Nothing)
@ Reactant.TracedUtils ~/.julia/dev/Reactant/src/TracedUtils.jl:453
[3] make_mlir_fn(f::typeof(ReactantKernelAbstractionsExt.tokw), args::Tuple{Int64, Nothing, KernelAbstractions.Kernel{ReactantKernelAbstractionsExt.ReactantBackend, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicSize, typeof(gpu_spmv_kernel!)}, Reactant.TracedRArray{Float64, 1}, GenericSparseMatrixCSR{Float64, Int64, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Float64, 1}}, Reactant.TracedRArray{Float64, 1}}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/dev/Reactant/src/TracedUtils.jl:324
[4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{Int64, Nothing, KernelAbstractions.Kernel{ReactantKernelAbstractionsExt.ReactantBackend, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicSize, typeof(gpu_spmv_kernel!)}, Reactant.TracedRArray{Float64, 1}, GenericSparseMatrixCSR{Float64, Int64, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Float64, 1}}, Reactant.TracedRArray{Float64, 1}}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, NTuple{var"#s1742", Symbol} where var"#s1742", NTuple{N, Int64} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:1603
[5] compile_mlir! (repeats 2 times)
@ ~/.julia/dev/Reactant/src/Compiler.jl:1570 [inlined]
[6] compile_xla(f::Function, args::Tuple{Int64, Nothing, KernelAbstractions.Kernel{ReactantKernelAbstractionsExt.ReactantBackend, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicSize, typeof(gpu_spmv_kernel!)}, Reactant.TracedRArray{Float64, 1}, GenericSparseMatrixCSR{Float64, Int64, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Float64, 1}}, Reactant.TracedRArray{Float64, 1}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3492
[7] compile_xla
@ ~/.julia/dev/Reactant/src/Compiler.jl:3465 [inlined]
[8] compile(f::Function, args::Tuple{Int64, Nothing, KernelAbstractions.Kernel{ReactantKernelAbstractionsExt.ReactantBackend, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicSize, typeof(gpu_spmv_kernel!)}, Reactant.TracedRArray{Float64, 1}, GenericSparseMatrixCSR{Float64, Int64, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Float64, 1}}, Reactant.TracedRArray{Float64, 1}}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3567
[9] macro expansion
@ ~/.julia/dev/Reactant/src/Compiler.jl:2642 [inlined]
[10] (::KernelAbstractions.Kernel{ReactantKernelAbstractionsExt.ReactantBackend, KernelAbstractions.NDIteration.DynamicSize, KernelAbstractions.NDIteration.DynamicSize, typeof(gpu_spmv_kernel!)})(::Reactant.TracedRArray{Float64, 1}, ::Vararg{Any}; ndrange::Int64, workgroupsize::Nothing)
@ ReactantKernelAbstractionsExt ~/.julia/dev/Reactant/ext/ReactantKernelAbstractionsExt.jl:107
[11] Kernel
@ ~/.julia/dev/Reactant/ext/ReactantKernelAbstractionsExt.jl:103 [inlined]
[12] spmv!
@ ~/.julia/dev/Reactant/test_sparse_debug.jl:46 [inlined]
[13] mul!(y::Reactant.TracedRArray{Float64, 1}, A::GenericSparseMatrixCSR{Float64, Int64, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Float64, 1}}, x::Reactant.TracedRArray{Float64, 1}, α::Bool, β::Bool)
@ Main ~/.julia/dev/Reactant/test_sparse_debug.jl:64
[14] #mul!
@ ~/.julia/dev/Reactant/src/Overlay.jl:136 [inlined]
[15] (::Nothing)(none::typeof(mul!), none::Reactant.TracedRArray{Float64, 1}, none::GenericSparseMatrixCSR{Float64, Int64, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Float64, 1}}, none::Reactant.TracedRArray{Float64, 1}, none::Bool, none::Bool)
@ Reactant ./<missing>:0
[16] call_with_reactant(::typeof(mul!), ::Reactant.TracedRArray{Float64, 1}, ::GenericSparseMatrixCSR{Float64, Int64, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Int64, 1}, Reactant.TracedRArray{Float64, 1}}, ::Reactant.TracedRArray{Float64, 1}, ::Bool, ::Bool)
@ Reactant ~/.julia/dev/Reactant/src/utils.jl:519
[17] make_mlir_fn(f::typeof(mul!), args::Tuple{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, GenericSparseMatrixCSR{Float64, Int64, ConcretePJRTArray{Int64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Int64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Bool, Bool}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{:PJRT}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/dev/Reactant/src/TracedUtils.jl:348
[18] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, GenericSparseMatrixCSR{Float64, Int64, ConcretePJRTArray{Int64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Int64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Bool, Bool}, compile_options::CompileOptions, callcache::Dict{Vector, @NamedTuple{f_name::String, mlir_result_types::Vector{Reactant.MLIR.IR.Type}, traced_result, mutated_args::Vector{Int64}, linear_results::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, fnwrapped::Bool, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol}}, sdycache::Dict{Tuple{AbstractVector{Int64}, NTuple{var"#s1742", Symbol} where var"#s1742", NTuple{N, Int64} where N}, @NamedTuple{sym_name::Reactant.MLIR.IR.Attribute, mesh_attr::Reactant.MLIR.IR.Attribute, mesh_op::Reactant.MLIR.IR.Operation, mesh::Reactant.Sharding.Mesh}}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{:PJRT}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:1603
[19] compile_mlir! (repeats 2 times)
@ ~/.julia/dev/Reactant/src/Compiler.jl:1570 [inlined]
[20] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, GenericSparseMatrixCSR{Float64, Int64, ConcretePJRTArray{Int64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Int64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Bool, Bool}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3492
[21] compile_xla
@ ~/.julia/dev/Reactant/src/Compiler.jl:3465 [inlined]
[22] compile(f::Function, args::Tuple{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, GenericSparseMatrixCSR{Float64, Int64, ConcretePJRTArray{Int64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Int64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Bool, Bool}; kwargs::@Kwargs{fn_kwargs::@NamedTuple{}, client::Nothing, reshape_propagate::Symbol, raise_first::Bool, assert_nonallocating::Bool, serializable::Bool, legalize_chlo_to_stablehlo::Bool, transpose_propagate::Symbol, donated_args::Symbol, optimize_then_pad::Bool, cudnn_hlo_optimize::Bool, compile_options::Missing, sync::Bool, no_nan::Bool, raise::Bool, shardy_passes::Symbol, optimize::Bool, optimize_communications::Bool})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3567
[23] top-level scope
@ ~/.julia/dev/Reactant/src/Compiler.jl:2642 |
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1739 +/- ##
==========================================
- Coverage 68.16% 64.69% -3.48%
==========================================
Files 109 113 +4
Lines 11779 12552 +773
==========================================
+ Hits 8029 8120 +91
- Misses 3750 4432 +682 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Another approach would be to define the overlayed function for all the possible matrices except sparse. |
e8915ce to
563c4b8
Compare
|
Ok I should have managed to make it work. I have added a check in the KernelAbstractions.jl method, I don't know if this makes sense or not. |
| function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing) | ||
| # If we're already inside a compilation/tracing context, or if any arguments are traced, | ||
| # we should trace through this kernel call instead of trying to compile it again. | ||
| if Reactant.within_compile() || any(ReactantCore.is_traced, args) |
There was a problem hiding this comment.
this seems extraneous can this be done without?
There was a problem hiding this comment.
If I don't put it I get the error
ERROR: "Cannot trace existing trace type"
Stacktrace:
[1] make_tracer(seen::Reactant.OrderedIdDict{…}, prev::Reactant.TracedRArray{…}, path::Any, mode::Reactant.TraceMode; toscalar::Bool, tobatch::Nothing, sharding::Any, runtime::Any, kwargs::@Kwargs{})
@ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:1298
[2] prepare_mlir_fn_args(args::Tuple{…}, name::String, concretein::Bool, toscalar::Bool, argprefix::Symbol, runtime::Val{…}, optimize_then_pad::Bool, do_transpose::Bool, input_shardings::Nothing, verify_arg_names::Nothing)
@ Reactant.TracedUtils ~/.julia/dev/Reactant/src/TracedUtils.jl:450
[3] make_mlir_fn(f::typeof(ReactantKernelAbstractionsExt.tokw), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/dev/Reactant/src/TracedUtils.jl:321
[4] make_mlir_fn
@ ~/.julia/dev/Reactant/src/TracedUtils.jl:275 [inlined]
[5] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(ReactantKernelAbstractionsExt.tokw), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:1608
[6] compile_mlir!
@ ~/.julia/dev/Reactant/src/Compiler.jl:1572 [inlined]
[7] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3500
[8] compile_xla
@ ~/.julia/dev/Reactant/src/Compiler.jl:3472 [inlined]
[9] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3576
[10] macro expansion
@ ~/.julia/dev/Reactant/src/Compiler.jl:2649 [inlined]
[11] (::KernelAbstractions.Kernel{…})(::Reactant.TracedRArray{…}, ::Vararg{…}; ndrange::Int64, workgroupsize::Nothing)
@ ReactantKernelAbstractionsExt ~/.julia/dev/Reactant/ext/ReactantKernelAbstractionsExt.jl:116
[12] Kernel
@ ~/.julia/dev/Reactant/ext/ReactantKernelAbstractionsExt.jl:104 [inlined]
[13] spmv!
@ ~/.julia/dev/Reactant/testsparse/script.jl:48 [inlined]
[14] mul!(y::Reactant.TracedRArray{…}, A::GenericSparseMatrixCSR{…}, x::Reactant.TracedRArray{…}, α::Bool, β::Bool)
@ Main ~/.julia/dev/Reactant/testsparse/script.jl:66
[15] #mul!
@ ~/.julia/dev/Reactant/src/Overlay.jl:136 [inlined]
[16] (::Nothing)(none::typeof(mul!), none::Reactant.TracedRArray{…}, none::GenericSparseMatrixCSR{…}, none::Reactant.TracedRArray{…}, none::Bool, none::Bool)
@ Reactant ./<missing>:0
[17] call_with_reactant(::typeof(mul!), ::Reactant.TracedRArray{…}, ::GenericSparseMatrixCSR{…}, ::Reactant.TracedRArray{…}, ::Bool, ::Bool)
@ Reactant ~/.julia/dev/Reactant/src/utils.jl:519
[18] #mul!
@ ~/.julia/dev/Reactant/src/Overlay.jl:143 [inlined]
[19] (::Nothing)(none::typeof(mul!), none::Reactant.TracedRArray{…}, none::GenericSparseMatrixCSR{…}, none::Reactant.TracedRArray{…})
@ Reactant ./<missing>:0
[20] #mul!
@ ~/.julia/dev/Reactant/src/Overlay.jl:143 [inlined]
[21] call_with_reactant(::typeof(mul!), ::Reactant.TracedRArray{…}, ::GenericSparseMatrixCSR{…}, ::Reactant.TracedRArray{…})
@ Reactant ~/.julia/dev/Reactant/src/utils.jl:0
[22] make_mlir_fn(f::typeof(mul!), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
@ Reactant.TracedUtils ~/.julia/dev/Reactant/src/TracedUtils.jl:345
[23] make_mlir_fn
@ ~/.julia/dev/Reactant/src/TracedUtils.jl:275 [inlined]
[24] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(mul!), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:1608
[25] compile_mlir!
@ ~/.julia/dev/Reactant/src/Compiler.jl:1572 [inlined]
[26] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3500
[27] compile_xla
@ ~/.julia/dev/Reactant/src/Compiler.jl:3472 [inlined]
[28] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3576
[29] top-level scope
@ ~/.julia/dev/Reactant/src/Compiler.jl:2649
Some type information was truncated. Use `show(err)` to see complete types.
Another alternative it to remove this check and specify the method directly as
for (cT, aT, bT) in (
(:AbstractVector, :AnyDenseMatrix, :AbstractVector),
(:AbstractMatrix, :AnyDenseMatrix, :AbstractVecOrMat),
)
@eval begin
@reactant_overlay @noinline function LinearAlgebra.mul!(
C::$cT, A::$aT, B::$bT, α::Number, β::Number
)where AnyDenseMatrix is something like
const AnyDenseMatrix = Union{DenseMatrix, Transpose{Any, <:DenseMatrix}, Symmetric{Any, <:DenseMatrix}, UpperTriangular{Any, <:DenseMatrix}} # And all the other possible wrappersThis basically keeps the orgiginal code unchanged.
I don't know which case you do prefer.
There was a problem hiding this comment.
It does sound a bit dangerous to assume that we know in advance every possible wrapper of a dense matrix? Not all of them are in Base or LinearAlgebra
|
Hello, I wrote PR #2767 as an alternative to this PR. I think that one is slightly better, but let me know. |
The
mul!should not be overlayed for sparse arrays, as they require custom methods.This PR fixes #1296