Skip to content

Commit

Permalink
Use square_view to avoid square check, instead of new hack
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Apr 15, 2024
1 parent 58fdb2f commit db710d5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ LinearAlgebra = "1.5"
LoopVectorization = "0.10,0.11, 0.12"
Polyester = "0.3.2,0.4.1, 0.5, 0.6, 0.7"
PrecompileTools = "1"
StrideArraysCore = "0.1.13, 0.2.1, 0.3, 0.4.1, 0.5"
TriangularSolve = "0.1.1"
StrideArraysCore = "0.5.5"
TriangularSolve = "0.2"
julia = "1.5"

[extras]
Expand Down
15 changes: 7 additions & 8 deletions src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ using TriangularSolve: ldiv!
using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
using StrideArraysCore
using StrideArraysCore: square_view
using Polyester: @batch

@generated function _unit_lower_triangular(B::A) where {T, A <: AbstractMatrix{T}}
Expr(:new, UnitLowerTriangular{T, A}, :B)
end
# 1.7 compat
normalize_pivot(t::Val{T}) where {T} = t
to_stdlib_pivot(t) = t
Expand Down Expand Up @@ -55,6 +53,7 @@ end
if CUSTOMIZABLE_PIVOT
function LinearAlgebra.ldiv!(A::LU{T, <:StridedMatrix, <:NotIPIV},
B::StridedVecOrMat{T}) where {T <: BlasFloat}
tri = @inbounds square_view(A.factors, size(A.factors, 1))
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), B))
end
end
Expand Down Expand Up @@ -138,10 +137,10 @@ end
info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))::Int
@inbounds if m < n # fat matrix
# [AL AR]
AL = @view A[:, 1:m]
AL = square_view(A, m)
AR = @view A[:, (m + 1):n]
Pivot && apply_permutation!(ipiv, AR, Val{Thread}())
ldiv!(_unit_lower_triangular(AL), AR, Val{Thread}())
ldiv!(UnitLowerTriangular(AL), AR, Val{Thread}())
end
info
end
Expand Down Expand Up @@ -190,7 +189,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b

# ======================================== #
# Now, our LU process looks like this
# [ P1 ] [ A11 A21 ] [ L11 0 ] [ U11 U12 ]
# [ P1 ] [ A11 A12 ] [ L11 0 ] [ U11 U12 ]
# [ ] [ ] = [ ] [ ]
# [ P2 ] [ A21 A22 ] [ L21 I ] [ 0 A′22 ]
# ======================================== #
Expand All @@ -203,7 +202,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
# AL AR
# [A11 A12]
# [A21 A22]
A11 = @view A[1:n1, 1:n1]
A11 = square_view(A, n1)
A12 = @view A[1:n1, (n1 + 1):n]
A21 = @view A[(n1 + 1):m, 1:n1]
A22 = @view A[(n1 + 1):m, (n1 + 1):n]
Expand All @@ -223,7 +222,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
# [ A22 ] [ 0 ] [ A22 ]
Pivot && apply_permutation!(P1, AR, thread)
# A12 = L11 U12 => U12 = L11 \ A12
ldiv!(_unit_lower_triangular(A11), A12, thread)
ldiv!(UnitLowerTriangular(A11), A12, thread)
# Schur complement:
# We have A22 = L21 U12 + A′22, hence
# A′22 = A22 - L21 U12
Expand Down

0 comments on commit db710d5

Please sign in to comment.