From db710d5d7f5912dff5246763d20febe5196764ee Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Mon, 15 Apr 2024 12:32:25 -0400 Subject: [PATCH] Use `square_view` to avoid square check, instead of `new` hack --- Project.toml | 4 ++-- src/lu.jl | 15 +++++++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index e8dfb17..c1ff3df 100644 --- a/Project.toml +++ b/Project.toml @@ -16,8 +16,8 @@ LinearAlgebra = "1.5" LoopVectorization = "0.10,0.11, 0.12" Polyester = "0.3.2,0.4.1, 0.5, 0.6, 0.7" PrecompileTools = "1" -StrideArraysCore = "0.1.13, 0.2.1, 0.3, 0.4.1, 0.5" -TriangularSolve = "0.1.1" +StrideArraysCore = "0.5.5" +TriangularSolve = "0.2" julia = "1.5" [extras] diff --git a/src/lu.jl b/src/lu.jl index 458ab4d..667e146 100644 --- a/src/lu.jl +++ b/src/lu.jl @@ -3,11 +3,9 @@ using TriangularSolve: ldiv! using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS, LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat using StrideArraysCore +using StrideArraysCore: square_view using Polyester: @batch -@generated function _unit_lower_triangular(B::A) where {T, A <: AbstractMatrix{T}} - Expr(:new, UnitLowerTriangular{T, A}, :B) -end # 1.7 compat normalize_pivot(t::Val{T}) where {T} = t to_stdlib_pivot(t) = t @@ -55,6 +53,7 @@ end if CUSTOMIZABLE_PIVOT function LinearAlgebra.ldiv!(A::LU{T, <:StridedMatrix, <:NotIPIV}, B::StridedVecOrMat{T}) where {T <: BlasFloat} + tri = @inbounds square_view(A.factors, size(A.factors, 1)) ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), B)) end end @@ -138,10 +137,10 @@ end info = reckernel!(A, Val(Pivot), m, mnmin, ipiv, info, blocksize, Val(Thread))::Int @inbounds if m < n # fat matrix # [AL AR] - AL = @view A[:, 1:m] + AL = square_view(A, m) AR = @view A[:, (m + 1):n] Pivot && apply_permutation!(ipiv, AR, Val{Thread}()) - ldiv!(_unit_lower_triangular(AL), AR, Val{Thread}()) + ldiv!(UnitLowerTriangular(AL), AR, Val{Thread}()) end info end @@ -190,7 +189,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b # ======================================== # # Now, our LU process looks like this - # [ P1 ] [ A11 A21 ] [ L11 0 ] [ U11 U12 ] + # [ P1 ] [ A11 A12 ] [ L11 0 ] [ U11 U12 ] # [ ] [ ] = [ ] [ ] # [ P2 ] [ A21 A22 ] [ L21 I ] [ 0 A′22 ] # ======================================== # @@ -203,7 +202,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b # AL AR # [A11 A12] # [A21 A22] - A11 = @view A[1:n1, 1:n1] + A11 = square_view(A, n1) A12 = @view A[1:n1, (n1 + 1):n] A21 = @view A[(n1 + 1):m, 1:n1] A22 = @view A[(n1 + 1):m, (n1 + 1):n] @@ -223,7 +222,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b # [ A22 ] [ 0 ] [ A22 ] Pivot && apply_permutation!(P1, AR, thread) # A12 = L11 U12 => U12 = L11 \ A12 - ldiv!(_unit_lower_triangular(A11), A12, thread) + ldiv!(UnitLowerTriangular(A11), A12, thread) # Schur complement: # We have A22 = L21 U12 + A′22, hence # A′22 = A22 - L21 U12