Skip to content

Commit 5f1b3d2

Browse files
authored
Merge pull request #43 from JuliaParallel/anj/chol
Add cholesky wrapper
2 parents e9f3338 + 28abba2 commit 5f1b3d2

File tree

6 files changed

+116
-38
lines changed

6 files changed

+116
-38
lines changed

src/blas_like/level1.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ for (elty, relty, ext) in ((:Float32, :Float32, :s),
1313
return y
1414
end
1515

16-
function copy!(src::$mat{$elty}, dest::$mat{$elty})
16+
# Which is opposite Julia's copy! so we call it _copy! to avoid confusion
17+
function _copy!(src::$mat{$elty}, dest::$mat{$elty})
1718
ElError(ccall(($(string("ElCopy", sym, ext)), libEl), Cuint,
1819
(Ptr{Void}, Ptr{Void}),
1920
src.obj, dest.obj))
@@ -72,6 +73,3 @@ for (elty, relty, ext) in ((:Float32, :Float32, :s),
7273
end
7374
end
7475
end
75-
76-
copy(A::ElementalMatrix) = copy!(A, similar(A))
77-
length(A::ElementalMatrix) = prod(size(A))

src/blas_like/level3.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,22 @@ for (elty, relty, ext) in ((:Float32, :Float32, :s),
66
for (mat, sym) in ((:Matrix, "_"),
77
(:DistMatrix, "Dist_"))
88

9-
for (transA, elenumA) in (("", :NORMAL), ("t", :TRANSPOSE), ("c", :ADJOINT))
10-
for (transB, elenumB) in (("", :NORMAL), ("t", :TRANSPOSE), ("c", :ADJOINT))
11-
f = Symbol("A", transA, "_mul_B", transB, "!")
9+
@eval begin
1210

13-
@eval begin
14-
function ($f)(α::$elty, A::$mat{$elty}, B::$mat{$elty}, β::$elty, C::$mat{$elty})
15-
ElError(ccall(($(string("ElGemm", sym, ext)), libEl), Cuint,
16-
(Cint, Cint, $elty, Ptr{Void}, Ptr{Void}, $elty, Ptr{Void}),
17-
$elenumA, $elenumB, α, A.obj, B.obj, β, C.obj))
18-
return C
19-
end
20-
end
11+
function gemm(orientationOfA::Orientation, orientationOfB::Orientation, α::$elty, A::$mat{$elty}, B::$mat{$elty}, β::$elty, C::$mat{$elty})
12+
ElError(ccall(($(string("ElGemm", sym, ext)), libEl), Cuint,
13+
(Orientation, Orientation, $elty, Ptr{Void}, Ptr{Void}, $elty, Ptr{Void}),
14+
orientationOfA, orientationOfB, α, A.obj, B.obj, β, C.obj))
15+
return C
16+
end
17+
18+
function trsm(side::LeftOrRight, uplo::UpperOrLower, orientation::Orientation, diag::UnitOrNonUnit, α::$elty, A::$mat{$elty}, B::$mat{$elty})
19+
ElError(ccall(($(string("ElTrsm", sym, ext)), libEl), Cuint,
20+
(LeftOrRight, UpperOrLower, Orientation, UnitOrNonUnit,
21+
$elty, Ptr{Void}, Ptr{Void}),
22+
side, uplo, orientation, diag,
23+
α, A.obj, B.obj))
24+
return B
2125
end
2226
end
2327
end

src/core/types.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ eltype{T}(A::ElementalMatrix{T}) = T
5858
@enum Dist MC MD MR VC VR STAR CIRC
5959
@enum Orientation NORMAL TRANSPOSE ADJOINT
6060
@enum UpperOrLower LOWER UPPER
61+
@enum LeftOrRight LEFT RIGHT
62+
@enum UnitOrNonUnit NON_UNIT UNIT
6163

6264
# Get MPIWorldComm
6365
function CommWorldValue()

src/julia/generic.jl

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Julia interface when not defined in source files
22

3-
eltype{T}(x::DistMultiVec{T}) = T
3+
Base.eltype{T}(x::DistMultiVec{T}) = T
44

5-
function size(A::ElementalMatrix, i::Integer)
5+
function Base.size(A::ElementalMatrix, i::Integer)
66
if i < 1
77
error("dimension out of range")
88
elseif i == 1
@@ -14,17 +14,62 @@ function size(A::ElementalMatrix, i::Integer)
1414
end
1515
end
1616

17-
size(A::ElementalMatrix) = (size(A, 1), size(A, 2))
17+
Base.size(A::ElementalMatrix) = (size(A, 1), size(A, 2))
1818

19-
(*){T<:ElementalMatrix}(A::T, B::T) = A_mul_B!(one(eltype(A)), A, B, zero(eltype(A)), similar(A, (size(A, 1), size(B, 2))))
20-
(*){T}(A::DistSparseMatrix{T}, B::DistMultiVec{T}) = A_mul_B!(one(T), A, B, zero(T), similar(B, (size(A, 1), size(B, 2))))
21-
Ac_mul_B{T<:ElementalMatrix}(A::T, B::T) = Ac_mul_B!(one(eltype(A)), A, B, zero(eltype(A)), similar(A, (size(A, 2), size(B, 2))))
22-
Ac_mul_B{T}(A::DistSparseMatrix{T}, B::DistMultiVec{T}) = Ac_mul_B!(one(T), A, B, zero(T), similar(B, (size(A, 2), size(B, 2))))
19+
Base.copy!(A::T, B::T) where {T<:ElementalMatrix} = _copy!(B, A)
20+
# copy(A::ElementalMatrix) = copy!(similar(A), A)
21+
Base.length(A::ElementalMatrix) = prod(size(A))
22+
23+
## Current mutating Julia multiplication API
24+
Base.A_mul_B!(C::T, A::T, B::T) where {T<:ElementalMatrix} = gemm(NORMAL, NORMAL, one(eltype(T)), A, B, zero(eltype(T)), C)
25+
Base.Ac_mul_B!(C::T, A::T, B::T) where {T<:ElementalMatrix} = gemm(ADJOINT, NORMAL, one(eltype(T)), A, B, zero(eltype(T)), C)
26+
Base.At_mul_B!(C::T, A::T, B::T) where {T<:ElementalMatrix} = gemm(TRANSPOSE, NORMAL, one(eltype(T)), A, B, zero(eltype(T)), C)
27+
Base.A_mul_Bc!(C::T, A::T, B::T) where {T<:ElementalMatrix} = gemm(NORMAL, ADJOINT, one(eltype(T)), A, B, zero(eltype(T)), C)
28+
Base.A_mul_Bt!(C::T, A::T, B::T) where {T<:ElementalMatrix} = gemm(NORMAL, TRANSPOSE, one(eltype(T)), A, B, zero(eltype(T)), C)
29+
30+
## BLAS like multiplication API (i.e. with α and β)
31+
Base.A_mul_B!::Number, A::S, B::S, β::Number, C::S) where {S<:ElementalMatrix{T}} where {T} =
32+
gemm(NORMAL, NORMAL, T(α), A, B, T(β), C)
33+
Base.Ac_mul_B!::Number, A::S, B::S, β::Number, C::S) where {S<:ElementalMatrix{T}} where {T} =
34+
gemm(ADJOINT, NORMAL, T(α), A, B, T(β), C)
35+
Base.At_mul_B!::Number, A::S, B::S, β::Number, C::S) where {S<:ElementalMatrix{T}} where {T} =
36+
gemm(TRANSPOSE, NORMAL, T(α), A, B, T(β), C)
37+
Base.A_mul_Bc!::Number, A::S, B::S, β::Number, C::S) where {S<:ElementalMatrix{T}} where {T} =
38+
gemm(NORMAL, ADJOINT, T(α), A, B, T(β), C)
39+
Base.A_mul_Bt!::Number, A::S, B::S, β::Number, C::S) where {S<:ElementalMatrix{T}} where {T} =
40+
gemm(NORMAL, TRANSPOSE, T(α), A, B, T(β), C)
41+
42+
## Linear solve API
43+
Base.LinAlg.A_ldiv_B!(A::LowerTriangular{T,S}, B::S) where {T,S<:ElementalMatrix} =
44+
trsm(LEFT, LOWER, NORMAL , NON_UNIT, one(T), A.data, B)
45+
Base.LinAlg.Ac_ldiv_B!(A::LowerTriangular{T,S}, B::S) where {T,S<:ElementalMatrix} =
46+
trsm(LEFT, LOWER, ADJOINT , NON_UNIT, one(T), A.data, B)
47+
Base.LinAlg.At_ldiv_B!(A::LowerTriangular{T,S}, B::S) where {T,S<:ElementalMatrix} =
48+
trsm(LEFT, LOWER, TRANSPOSE, NON_UNIT, one(T), A.data, B)
49+
Base.LinAlg.A_ldiv_B!(A::UpperTriangular{T,S}, B::S) where {T,S<:ElementalMatrix} =
50+
trsm(LEFT, UPPER, NORMAL , NON_UNIT, one(T), A.data, B)
51+
Base.LinAlg.Ac_ldiv_B!(A::UpperTriangular{T,S}, B::S) where {T,S<:ElementalMatrix} =
52+
trsm(LEFT, UPPER, ADJOINT , NON_UNIT, one(T), A.data, B)
53+
Base.LinAlg.At_ldiv_B!(A::UpperTriangular{T,S}, B::S) where {T,S<:ElementalMatrix} =
54+
trsm(LEFT, UPPER, TRANSPOSE, NON_UNIT, one(T), A.data, B)
55+
56+
Base.LinAlg.A_rdiv_B!(A::S, B::LowerTriangular{T,S}) where {T,S<:ElementalMatrix} =
57+
trsm(RIGHT, LOWER, NORMAL , NON_UNIT, one(T), B.data, A)
58+
Base.LinAlg.A_rdiv_Bc!(A::S, B::LowerTriangular{T,S}) where {T,S<:ElementalMatrix} =
59+
trsm(RIGHT, LOWER, ADJOINT , NON_UNIT, one(T), B.data, A)
60+
Base.LinAlg.A_rdiv_Bt!(A::S, B::LowerTriangular{T,S}) where {T,S<:ElementalMatrix} =
61+
trsm(RIGHT, LOWER, TRANSPOSE, NON_UNIT, one(T), B.data, A)
62+
Base.LinAlg.A_rdiv_B!(A::S, B::UpperTriangular{T,S}) where {T,S<:ElementalMatrix} =
63+
trsm(RIGHT, UPPER, NORMAL , NON_UNIT, one(T), B.data, A)
64+
Base.LinAlg.A_rdiv_Bc!(A::S, B::UpperTriangular{T,S}) where {T,S<:ElementalMatrix} =
65+
trsm(RIGHT, UPPER, ADJOINT , NON_UNIT, one(T), B.data, A)
66+
Base.LinAlg.A_rdiv_Bt!(A::S, B::UpperTriangular{T,S}) where {T,S<:ElementalMatrix} =
67+
trsm(RIGHT, UPPER, TRANSPOSE, NON_UNIT, one(T), B.data, A)
2368

2469
# Spectral
25-
svd(A::ElementalMatrix) = svd!(copy(A))
26-
svd(A::ElementalMatrix, ctrl::SVDCtrl) = svd!(copy(A), ctrl)
27-
svdvals(A::ElementalMatrix, ctrl::SVDCtrl) = svdvals!(copy(A), ctrl)
70+
Base.LinAlg.svd(A::ElementalMatrix) = svd!(copy(A))
71+
Base.LinAlg.svd(A::ElementalMatrix, ctrl::SVDCtrl) = svd!(copy(A), ctrl)
72+
Base.LinAlg.svdvals(A::ElementalMatrix, ctrl::SVDCtrl) = svdvals!(copy(A), ctrl)
2873

2974
# conversions to and from julia arrays
3075

@@ -43,7 +88,7 @@ svdvals(A::ElementalMatrix, ctrl::SVDCtrl) = svdvals!(copy(A), ctrl)
4388
# return dest
4489
# end
4590

46-
function copy!{T}(dest::DistMatrix{T}, src::Base.VecOrMat)
91+
function Base.copy!{T}(dest::DistMatrix{T}, src::Base.VecOrMat)
4792
m, n = size(src, 1), size(src, 2)
4893
zeros!(dest, m, n)
4994
if MPI.commRank(comm(dest)) == 0
@@ -57,21 +102,21 @@ function copy!{T}(dest::DistMatrix{T}, src::Base.VecOrMat)
57102
return dest
58103
end
59104

60-
function convert{T}(::Type{Matrix{T}}, A::Base.VecOrMat{T})
105+
function Base.convert{T}(::Type{Matrix{T}}, A::Base.VecOrMat{T})
61106
m, n = size(A, 1), size(A, 2)
62107
B = Matrix(T)
63108
resize!(B, m, n)
64109
Base.unsafe_copy!(pointer(B), pointer(A), m*n)
65110
return B
66111
end
67-
function convert{T}(::Type{Base.Matrix{T}}, A::Matrix{T})
112+
function Base.convert{T}(::Type{Base.Matrix{T}}, A::Matrix{T})
68113
m, n = size(A)
69114
B = Base.Matrix{T}(m, n)
70115
Base.unsafe_copy!(pointer(B), pointer(A), m*n)
71116
return B
72117
end
73118

74-
function convert{T}(::Type{DistMatrix{T}}, A::Base.VecOrMat{T})
119+
function Base.convert{T}(::Type{DistMatrix{T}}, A::Base.VecOrMat{T})
75120
m, n = size(A, 1), size(A, 2)
76121
B = DistMatrix(T)
77122
zeros!(B, m, n)
@@ -86,7 +131,7 @@ function convert{T}(::Type{DistMatrix{T}}, A::Base.VecOrMat{T})
86131
return B
87132
end
88133

89-
function convert{T}(::Type{DistMultiVec{T}}, A::Base.VecOrMat{T})
134+
function Base.convert{T}(::Type{DistMultiVec{T}}, A::Base.VecOrMat{T})
90135
m, n = size(A, 1), size(A, 2)
91136
B = DistMultiVec(T)
92137
zeros!(B, m, n)
@@ -101,7 +146,7 @@ function convert{T}(::Type{DistMultiVec{T}}, A::Base.VecOrMat{T})
101146
return B
102147
end
103148

104-
function convert{T}(::Type{DistMatrix{T}}, A::DistMultiVec{T})
149+
function Base.convert{T}(::Type{DistMatrix{T}}, A::DistMultiVec{T})
105150
m, n = size(A)
106151
B = DistMatrix(T)
107152
zeros!(B, m, n)
@@ -115,7 +160,7 @@ function convert{T}(::Type{DistMatrix{T}}, A::DistMultiVec{T})
115160
return B
116161
end
117162

118-
function norm(x::ElementalMatrix)
163+
function Base.LinAlg.norm(x::ElementalMatrix)
119164
if size(x, 2) == 1
120165
return nrm2(x)
121166
else
@@ -124,10 +169,11 @@ function norm(x::ElementalMatrix)
124169
end
125170

126171
# Multiplication
127-
(*){T}(A::DistMatrix{T}, B::Base.VecOrMat{T}) = A*convert(DistMatrix{T}, B)
128-
(*){T}(A::DistMultiVec{T}, B::Base.VecOrMat{T}) = convert(DistMatrix{T}, A)*convert(DistMatrix{T}, B)
129-
(*){T}(A::DistSparseMatrix{T}, B::Base.VecOrMat{T}) = A*convert(DistMultiVec{T}, B)
130-
Ac_mul_B{T}(A::DistMatrix{T}, B::Base.VecOrMat{T}) = Ac_mul_B(A, convert(DistMatrix{T}, B))
131-
Ac_mul_B{T}(A::DistMultiVec{T}, B::Base.VecOrMat{T}) = Ac_mul_B(convert(DistMatrix{T}, A), convert(DistMatrix{T}, B))
132-
Ac_mul_B{T}(A::DistSparseMatrix{T}, B::Base.VecOrMat{T}) = Ac_mul_B(A, convert(DistMultiVec{T}, B))
172+
# (*){T}(A::DistMatrix{T}, B::Base.VecOrMat{T}) = A*convert(DistMatrix{T}, B)
173+
# (*){T}(A::DistMultiVec{T}, B::Base.VecOrMat{T}) = convert(DistMatrix{T}, A)*convert(DistMatrix{T}, B)
174+
# (*){T}(A::DistSparseMatrix{T}, B::Base.VecOrMat{T}) = A*convert(DistMultiVec{T}, B)
175+
# Ac_mul_B{T}(A::DistMatrix{T}, B::Base.VecOrMat{T}) = Ac_mul_B(A, convert(DistMatrix{T}, B))
176+
# Ac_mul_B{T}(A::DistMultiVec{T}, B::Base.VecOrMat{T}) = Ac_mul_B(convert(DistMatrix{T}, A), convert(DistMatrix{T}, B))
177+
# Ac_mul_B{T}(A::DistSparseMatrix{T}, B::Base.VecOrMat{T}) = Ac_mul_B(A, convert(DistMultiVec{T}, B))
133178

179+
Base.cholfact!(A::Hermitian{<:Any,<:ElementalMatrix}, ::Type{Val{false}}) = Base.LinAlg.Cholesky(cholesky(A.uplo == 'U' ? UPPER : LOWER, A.data), A.uplo)

src/lapack_like/factor.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,21 @@ function RegSolveCtrl{T<:ElFloatType}(::Type{T};
3333
ElBool(progress),
3434
ElBool(time))
3535
end
36+
37+
for (elty, ext) in ((:Float32, :s),
38+
(:Float64, :d),
39+
(:Complex64, :c),
40+
(:Complex128, :z))
41+
for mattype in ("", "Dist")
42+
mat = Symbol(mattype, "Matrix")
43+
@eval begin
44+
45+
function cholesky(uplo::UpperOrLower, A::$mat{$elty})
46+
ElError(ccall(($(string("ElCholesky", mattype, "_", ext)), libEl), Cuint,
47+
(UpperOrLower, Ptr{Void}),
48+
uplo, A.obj))
49+
return A
50+
end
51+
end
52+
end
53+
end

test/spectral.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using Elemental, Base.Test
2+
3+
@testset "generel eigenvalues (Schur) with eltype: $elty" for elty in (Float32, Float64, Complex{Float32}, Complex{Float64})
4+
n = 10
5+
A = Elemental.DistMatrix(elty)
6+
Elemental.gaussian!(A, n, n)
7+
elvals = Elemental.eigvalsGeneral(A)
8+
lavals = eigvals(Array(A))
9+
@test sort(abs.(vec(Array(elvals)))) sort(abs.(lavals))
10+
end

0 commit comments

Comments
 (0)