Skip to content

Commit

Permalink
Format and add ldiv support for NotIPIV
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Aug 3, 2023
1 parent afec32d commit 45d4af5
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 88 deletions.
18 changes: 9 additions & 9 deletions perf/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,22 @@ else
BLAS.vendor() === :mkl ? :MKL : :OpenBLAS
end
df = DataFrame(Size = ns,
Reference = ref_mflops)
Reference = ref_mflops)
setproperty!(df, blaslib, bas_mflops)
setproperty!(df, Symbol("RF with default threshold"), rec_mflops)
setproperty!(df, Symbol("RF fully recursive"), rec4_mflops)
setproperty!(df, Symbol("RF fully iterative"), rec800_mflops)
df = stack(df,
[Symbol("RF with default threshold"),
Symbol("RF fully recursive"),
Symbol("RF fully iterative"),
blaslib,
:Reference], variable_name = :Library, value_name = :GFLOPS)
[Symbol("RF with default threshold"),
Symbol("RF fully recursive"),
Symbol("RF fully iterative"),
blaslib,
:Reference], variable_name = :Library, value_name = :GFLOPS)
plt = df |> @vlplot(:line, color={:Library, scale = {scheme = "category10"}},
x={:Size}, y={:GFLOPS},
width=1000, height=600)
x={:Size}, y={:GFLOPS},
width=1000, height=600)
save(joinpath(homedir(), "Pictures",
"lu_float64_$(VERSION)_$(Sys.CPU_NAME)_$(nc)cores_$blaslib.png"), plt)
"lu_float64_$(VERSION)_$(Sys.CPU_NAME)_$(nc)cores_$blaslib.png"), plt)

#=
using Plot
Expand Down
4 changes: 3 additions & 1 deletion src/RecursiveFactorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ include("./lu.jl")

import PrecompileTools

PrecompileTools.@compile_workload begin lu!(rand(2, 2)) end
PrecompileTools.@compile_workload begin
lu!(rand(2, 2))
end

end # module
108 changes: 59 additions & 49 deletions src/lu.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using LoopVectorization
using TriangularSolve: ldiv!
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
LinearAlgebra, Adjoint, Transpose
LinearAlgebra, Adjoint, Transpose, UpperTriangular
using StrideArraysCore
using Polyester: @batch

Expand Down Expand Up @@ -41,16 +41,22 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)

if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_cols!)
function LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
B::StridedVecOrMat)
B::StridedVecOrMat)
return B
end
end
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_rows!)
function LinearAlgebra._ipiv_rows!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
B::StridedVecOrMat)
B::StridedVecOrMat)
return B
end
end
if CUSTOMIZABLE_PIVOT
    # In-place solve for an `LU` object whose pivot field is a `NotIPIV`:
    # no row permutation is applied to `B`. The solve with the unit lower
    # triangle of `A.factors` runs first, and its result is then passed to the
    # solve with the upper triangle of `A.factors`.
    # The unqualified `ldiv!` calls below resolve to the `ldiv!` brought into
    # scope by `using TriangularSolve: ldiv!`, not to `LinearAlgebra.ldiv!`.
    function LinearAlgebra.ldiv!(A::LU{T, <:StridedMatrix, <:NotIPIV},
            B::StridedVecOrMat{T}) where {T <: BlasFloat}
        factors = A.factors
        lower_solved = ldiv!(UnitLowerTriangular(factors), B)
        return ldiv!(UpperTriangular(factors), lower_solved)
    end
end

function lu!(A, pivot = Val(true), thread = Val(true); check = true, kwargs...)
m, n = size(A)
Expand Down Expand Up @@ -80,11 +86,11 @@ recurse(_) = false
# Generic case: wrap the pivot container in a `PtrArray`.
function _ptrarray(ipiv)
    return PtrArray(ipiv)
end
# A `NotIPIV` pivot object is handed back untouched instead of being wrapped.
function _ptrarray(ipiv::NotIPIV)
    return ipiv
end
function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
pivot = Val(true), thread = Val(true);
check::Bool = true,
# the performance is not sensitive wrt blocksize, and 8 is a good default
blocksize::Integer = length(A) ≥ 40_000 ? 8 : 16,
threshold::Integer = pick_threshold()) where {T}
pivot = Val(true), thread = Val(true);
check::Bool = true,
# the performance is not sensitive wrt blocksize, and 8 is a good default
blocksize::Integer = length(A) ≥ 40_000 ? 8 : 16,
threshold::Integer = pick_threshold()) where {T}
pivot = normalize_pivot(pivot)
info = zero(BlasInt)
m, n = size(A)
Expand All @@ -94,10 +100,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
end
if recurse(A) && mnmin > threshold
if T <: Union{Float32, Float64}
GC.@preserve ipiv A begin info = recurse!(view(PtrArray(A), axes(A)...), pivot,
m, n, mnmin,
_ptrarray(ipiv), info, blocksize,
thread) end
GC.@preserve ipiv A begin
info = recurse!(view(PtrArray(A), axes(A)...), pivot,
m, n, mnmin,
_ptrarray(ipiv), info, blocksize,
thread)
end
else
info = recurse!(A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
end
Expand All @@ -109,7 +117,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
end

@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
::Val{true}) where {Pivot}
::Val{true}) where {Pivot}
if length(A) * _sizeof(eltype(A)) >
0.92 * LoopVectorization.VectorizationBase.cache_size(Val(2))
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(true))
Expand All @@ -118,11 +126,11 @@ end
end
end
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
::Val{false}) where {Pivot}
::Val{false}) where {Pivot}
_recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
end
@inline function _recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
::Val{Thread}) where {Pivot, Thread}
::Val{Thread}) where {Pivot, Thread}
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))::Int
@inbounds if m < n # fat matrix
# [AL AR]
Expand Down Expand Up @@ -166,7 +174,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
nothing
end
function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, blocksize,
thread)::BlasInt where {T, Pivot}
thread)::BlasInt where {T, Pivot}
@inbounds begin
if n <= max(blocksize, 1)
info = _generic_lufact!(A, Val(Pivot), ipiv, info)
Expand Down Expand Up @@ -262,44 +270,46 @@ end
function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where {Pivot}
m, n = size(A)
minmn = length(ipiv)
@inbounds begin for k in 1:minmn
# find index max
kp = k
if Pivot
amax = abs(zero(eltype(A)))
for i in k:m
absi = abs(A[i, k])
if absi > amax
kp = i
amax = absi
@inbounds begin
for k in 1:minmn
# find index max
kp = k
if Pivot
amax = abs(zero(eltype(A)))
for i in k:m
absi = abs(A[i, k])
if absi > amax
kp = i
amax = absi
end
end
ipiv[k] = kp
end
ipiv[k] = kp
end
if !iszero(A[kp, k])
if k != kp
# Interchange
@simd for i in 1:n
tmp = A[k, i]
A[k, i] = A[kp, i]
A[kp, i] = tmp
if !iszero(A[kp, k])
if k != kp
# Interchange
@simd for i in 1:n
tmp = A[k, i]
A[k, i] = A[kp, i]
A[kp, i] = tmp
end
end
# Scale first column
Akkinv = inv(A[k, k])
@turbo check_empty=true warn_check_args=false for i in (k + 1):m
A[i, k] *= Akkinv
end
elseif info == 0
info = k
end
# Scale first column
Akkinv = inv(A[k, k])
@turbo check_empty=true warn_check_args=false for i in (k + 1):m
A[i, k] *= Akkinv
end
elseif info == 0
info = k
end
k == minmn && break
# Update the rest
@turbo warn_check_args=false for j in (k + 1):n
for i in (k + 1):m
A[i, j] -= A[i, k] * A[k, j]
k == minmn && break
# Update the rest
@turbo warn_check_args=false for j in (k + 1):n
for i in (k + 1):m
A[i, j] -= A[i, k] * A[k, j]
end
end
end
end end
end
return info
end
60 changes: 31 additions & 29 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,37 @@ function testlu(A, MF, BF)
end
# Transposed/adjoint input: strip the lazy wrapper from the matrix and call
# `parent` on the factorization as well, then defer to the `testlu` method for
# the unwrapped arguments. `BF` is forwarded unchanged.
function testlu(A::Union{Transpose, Adjoint}, MF, BF)
    unwrapped_A = parent(A)
    unwrapped_MF = parent(MF)
    return testlu(unwrapped_A, unwrapped_MF, BF)
end

@testset "Test LU factorization" begin for _p in (true, false),
T in (Float64, Float32, ComplexF64, ComplexF32,
Real)
@testset "Test LU factorization" begin
for _p in (true, false),
T in (Float64, Float32, ComplexF64, ComplexF32,
Real)

p = Val(_p)
for (i, s) in enumerate([1:10; 50:80:200; 300])
iseven(i) && (p = RecursiveFactorization.to_stdlib_pivot(p))
siz = (s, s + 2)
@info("size: $(siz[1]) × $(siz[2]), T = $T, p = $_p")
if isconcretetype(T)
A = rand(T, siz...)
else
_A = rand(siz...)
A = Matrix{T}(undef, siz...)
copyto!(A, _A)
p = Val(_p)
for (i, s) in enumerate([1:10; 50:80:200; 300])
iseven(i) && (p = RecursiveFactorization.to_stdlib_pivot(p))
siz = (s, s + 2)
@info("size: $(siz[1]) × $(siz[2]), T = $T, p = $_p")
if isconcretetype(T)
A = rand(T, siz...)
else
_A = rand(siz...)
A = Matrix{T}(undef, siz...)
copyto!(A, _A)
end
MF = mylu(A, p)
BF = baselu(A, p)
testlu(A, MF, BF)
testlu(A, mylu(A, p, Val(false)), BF)
A′ = permutedims(A)
MF′ = mylu(A′', p)
testlu(A′', MF′, BF)
testlu(A′', mylu(A′', p, Val(false)), BF)
i = rand(1:s) # test `MF.info`
A[:, i] .= 0
MF = mylu(A, p, check = false)
BF = baselu(A, p, check = false)
testlu(A, MF, BF)
testlu(A, mylu(A, p, Val(false), check = false), BF)
end
MF = mylu(A, p)
BF = baselu(A, p)
testlu(A, MF, BF)
testlu(A, mylu(A, p, Val(false)), BF)
A′ = permutedims(A)
MF′ = mylu(A′', p)
testlu(A′', MF′, BF)
testlu(A′', mylu(A′', p, Val(false)), BF)
i = rand(1:s) # test `MF.info`
A[:, i] .= 0
MF = mylu(A, p, check = false)
BF = baselu(A, p, check = false)
testlu(A, MF, BF)
testlu(A, mylu(A, p, Val(false), check = false), BF)
end
end end
end

0 comments on commit 45d4af5

Please sign in to comment.