Skip to content

Commit a3478e6

Browse files
authored
Expose common factorisations (#83)
* qr and lq both work. lu doesnt * need to fix QRColPiv - will come back to it * lu fixed, cholesky started * non distributed cholesky, lu, qr and lq all pass * test distributed too * move helper functions into core * add convenience setindex! with warning if used with scalars * whitespace * add more convenience functions * qol functions with darray * not going to do pivoting this time * delete trailing whitespace * test/factor.jl wasnt designed to be run in loop * delete commented code from test/factor * delete calls to elemental library that were never needed * change indentation to 4 spaces as appears to be used already * AbstractArray -> Array * typeof -> isa, but potentially delete this one * error when setindex! with scalars * added type signatures to factorisation struct outer constructors * remove outdated comment * do not test functions with preprended _ * clearer types to throw error on scalar setindex
1 parent 9807b28 commit a3478e6

File tree

9 files changed

+505
-4
lines changed

9 files changed

+505
-4
lines changed

src/core/distmatrix.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,23 @@ for (elty, ext) in ((:ElInt, :i),
148148
A.obj, i, j))
149149
return A
150150
end
151+
152+
function isLocalRow(A::DistMatrix{$elty}, i::Integer)
153+
rv = Ref{ElInt}(0)
154+
ElError(ccall(($(string("ElDistMatrixIsLocalRow_", ext)), libEl), Cuint,
155+
(Ptr{Cvoid}, ElInt, Ref{ElInt}),
156+
A.obj, i - 1, rv))
157+
return Bool(rv[])
158+
end
159+
160+
function isLocalCol(A::DistMatrix{$elty}, i::Integer)
161+
rv = Ref{ElInt}(0)
162+
ElError(ccall(($(string("ElDistMatrixIsLocalCol_", ext)), libEl), Cuint,
163+
(Ptr{Cvoid}, ElInt, Ref{ElInt}),
164+
A.obj, i - 1, rv))
165+
return Bool(rv[])
166+
end
167+
151168
end
152169
end
153170

@@ -205,3 +222,28 @@ function hcat(x::Vector{DistMatrix{T}}) where {T}
205222
return A
206223
end
207224
end
225+
226+
import DistributedArrays.localpart
227+
# used in testing
228+
function localpart(A::Elemental.DistMatrix{T}) where T
229+
buffer = Base.zeros(T, localHeight(A), localWidth(A))
230+
return localpart!(buffer, A)
231+
end
232+
233+
function localpart!(buffer, A::Elemental.DistMatrix)
234+
@assert size(buffer) == (localHeight(A), localWidth(A))
235+
for j in 1:localWidth(A), i in 1:localHeight(A)
236+
buffer[i, j] = getLocal(A, i, j)
237+
end
238+
return buffer
239+
end
240+
241+
import DistributedArrays.localindices
242+
# used in testing
243+
function localindices(A::Elemental.DistMatrix{T}) where T
244+
# sometimes they aren't contigous so cant do start:start+length
245+
rows = findall(isLocalRow(A, i) for i in 1:height(A))
246+
cols = findall(isLocalCol(A, i) for i in 1:width(A))
247+
return (rows, cols)
248+
end
249+

src/julia/darray.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,5 +250,46 @@ for (elty, ext) in ((:ElInt, :i),
250250
processQueues(A)
251251
return A
252252
end
253+
254+
function convert(::Type{DistMatrix{$elty}}, DA::DistributedArrays.DArray)
255+
npr, npc = size(procs(DA))
256+
if npr*npc != MPI.Comm_size(MPI.COMM_WORLD)
257+
error("Used non MPI.COMM_WORLD DArray for DistMatrix, ",
258+
"as procs(DA)=($npr,$npc) is incompatible with ",
259+
"MPI.Comm_size(MPI.COMM_WORLD)=$(MPI.Comm_size(MPI.COMM_WORLD))")
260+
end
261+
262+
m, n = size(DA)
263+
A = DistMatrix($elty, m, n)
264+
@sync begin
265+
for id in workers()
266+
let A = A, DA = DA
267+
@async remotecall_fetch(id) do
268+
rows, cols = DistributedArrays.localindices(DA)
269+
reserve(A,length(rows) * length(cols))
270+
for j in cols, i in rows
271+
queueUpdate(A, i - 1, j - 1, DA[i, j])
272+
end
273+
end
274+
end
275+
end
276+
end
277+
processQueues(A)
278+
return A
279+
end
280+
281+
function copyto!(DA::DistributedArrays.DArray{$elty}, A::DistMatrix{$elty} )
282+
@sync begin
283+
ijs = localindices(DA)
284+
for j in ijs[2], i in ijs[1]
285+
queuePull(A, i, j)
286+
end
287+
DAlocal = DA[:L]
288+
289+
DAlocal_mat = ndims(DAlocal) == 1 ? reshape(DAlocal, :, 1) : DAlocal
290+
processPullQueue(A, DAlocal_mat)
291+
end
292+
return DA
293+
end
253294
end
254295
end

src/julia/generic.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,21 @@ Base.convert(::Type{Array}, xd::DistMatrix{T}) where {T} =
181181

182182
Base.Array(xd::DistMatrix) = convert(Array, xd)
183183

184+
function Base.setindex!(A::DistMatrix, values::Number, i::Integer, j::Integer)
185+
throw(ArgumentError("setindex! with scalars is disallowed.
186+
Use a large collection to setindex! in bulk."))
187+
end
188+
function Base.setindex!(A::DistMatrix,
189+
values,
190+
globalis,
191+
globaljs)
192+
for (cj, globalj) in enumerate(globaljs), (ci, globali) in enumerate(globalis)
193+
queueUpdate(A, globali, globalj, values[ci, cj])
194+
end
195+
processQueues(A)
196+
end
197+
198+
184199
LinearAlgebra.norm(x::ElementalMatrix) = nrm2(x)
185200
# function LinearAlgebra.norm(x::ElementalMatrix)
186201
# if size(x, 2) == 1
@@ -194,6 +209,12 @@ LinearAlgebra.cholesky!(A::Hermitian{<:Union{Real,Complex},<:ElementalMatrix}) =
194209
LinearAlgebra.cholesky(A::Hermitian{<:Union{Real,Complex},<:ElementalMatrix}) = cholesky!(copy(A))
195210

196211
LinearAlgebra.lu(A::ElementalMatrix) = _lu!(copy(A))
212+
LinearAlgebra.lu!(A::ElementalMatrix) = _lu!(A)
213+
LinearAlgebra.qr(A::ElementalMatrix) = _qr!(copy(A))
214+
LinearAlgebra.qr!(A::ElementalMatrix) = _qr!(A)
215+
LinearAlgebra.lq(A::ElementalMatrix) = _lq!(copy(A))
216+
LinearAlgebra.lq!(A::ElementalMatrix) = _lq!(A)
217+
LinearAlgebra.cholesky!(A::ElementalMatrix) = _cholesky!(A)
197218

198219
# Mixed multiplication with Julia Arrays
199220
(*)(A::DistMatrix{T}, B::StridedVecOrMat{T}) where {T} = A*convert(DistMatrix{T}, B)

src/lapack_like/factor.jl

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,142 @@ for mattype in ("", "Dist")
9292
uplo, A.obj))
9393
return A
9494
end
95+
end
96+
end
97+
end
98+
99+
# These are the number types that Elemental supports
100+
101+
for mattype in ("", "Dist")
102+
mat = Symbol(mattype, "Matrix")
103+
_p = Symbol(mattype, "Permutation")
104+
105+
# TODO - fix QRColPiv
106+
#QRColPivStructName = Symbol("QRColPiv$(string(mattype))")
107+
QRStructName = Symbol("QR$(string(mattype))")
108+
LQStructName = Symbol("LQ$(string(mattype))")
109+
LUStructName = Symbol("LU$(string(mattype))")
110+
CHStructName = Symbol("Cholesky$(string(mattype))")
111+
112+
@eval begin
113+
114+
struct $QRStructName{T,U<:Real}
115+
A::$mat{T}
116+
t::$mat{T}
117+
d::$mat{U}
118+
orientation::Ref{Orientation}
119+
end
120+
function $QRStructName(A::$mat{T}, t::$mat{T}, d::$mat{U}
121+
) where {U<:Union{Float32, Float64}, T<:Union{Complex{U}, U}}
122+
return $QRStructName(A, t, d, Ref(NORMAL::Orientation))
123+
end
124+
125+
struct $LQStructName{T, U<:Real}
126+
A::$mat{T}
127+
householderscalars::$mat{T}
128+
signature::$mat{U}
129+
orientation::Ref{Orientation}
130+
end
131+
function $LQStructName(A::$mat{T}, householderscalars::$mat{T}, signature::$mat{U}
132+
) where {U<:Union{Float32, Float64}, T<:Union{Complex{U}, U}}
133+
return $LQStructName(A, householderscalars, signature, Ref(NORMAL::Orientation))
134+
end
135+
136+
struct $LUStructName{T}
137+
A::$mat{T}
138+
p::$_p
139+
orientation::Ref{Orientation}
140+
end
141+
function $LUStructName(A::$mat{T}, p::$_p
142+
) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}}
143+
return $LUStructName(A, p, Ref(NORMAL::Orientation))
144+
end
145+
146+
struct $CHStructName{T}
147+
uplo::UpperOrLower
148+
A::$mat{T}
149+
orientation::Ref{Orientation}
150+
end
151+
function $CHStructName(uplo::UpperOrLower, A::$mat{T}
152+
) where {T<:Union{Float32, Float64, ComplexF32, ComplexF64}}
153+
return $CHStructName(uplo, A, Ref(NORMAL::Orientation))
154+
end
155+
156+
end
157+
158+
for (elty, ext) in ((:Float32, :s),
159+
(:Float64, :d),
160+
(:ComplexF32, :c),
161+
(:ComplexF64, :z))
162+
163+
@eval begin
95164

96165
function _lu!(A::$mat{$elty})
97166
p = $_p()
98-
ElError(ccall(($(string("ElLUPartialPiv", mattype, "_", ext)), libEl), Cuint,
99-
(Ptr{Cvoid}, Ptr{Cvoid}),
100-
A.obj, p.obj))
101-
return A, p
167+
ElError(ccall(($(string("ElLU", mattype, "_", ext)), libEl), Cuint,
168+
(Ptr{Cvoid}, Ptr{Cvoid}),
169+
A.obj, p.obj))
170+
return $LUStructName(A, p)
171+
end
172+
173+
function LinearAlgebra.:\(lu::$LUStructName{$elty}, b::$mat{$elty})
174+
x = deepcopy(b)#$mat($elty)
175+
ElError(ccall(($(string("ElSolveAfterLU", mattype, "_", ext)), libEl), Cuint,
176+
(Orientation, Ptr{Cvoid}, Ptr{Cvoid}),
177+
lu.orientation[], lu.A.obj, x.obj))
178+
return x
102179
end
180+
181+
function _qr!(A::$mat{$elty})
182+
t = $mat($elty)
183+
d = $mat(real($elty))
184+
ElError(ccall(($(string("ElQR", mattype, "_", ext)), libEl), Cuint,
185+
(Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
186+
A.obj, t.obj, d.obj))
187+
return $QRStructName(A, t, d)
188+
end
189+
190+
function LinearAlgebra.:\(qr::$QRStructName{$elty}, b::$mat{$elty})
191+
x = $mat($elty)
192+
ElError(ccall(($(string("ElSolveAfterQR", mattype, "_", ext)), libEl), Cuint,
193+
(Orientation, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
194+
qr.orientation[], qr.A.obj, qr.t.obj, qr.d.obj, b.obj, x.obj))
195+
return x
196+
end
197+
198+
function _lq!(A::$mat{$elty})
199+
householderscalars = $mat($elty)
200+
signature = $mat(real($elty))
201+
ElError(ccall(($(string("ElLQ", mattype, "_", ext)), libEl), Cuint,
202+
(Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
203+
A.obj, householderscalars.obj, signature.obj))
204+
return $LQStructName(A, householderscalars, signature)
205+
end
206+
207+
function LinearAlgebra.:\(lq::$LQStructName{$elty}, b::$mat{$elty})
208+
x = $mat($elty)
209+
ElError(ccall(($(string("ElSolveAfterLQ", mattype, "_", ext)), libEl), Cuint,
210+
(Orientation, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}),
211+
lq.orientation[], lq.A.obj, lq.householderscalars.obj,
212+
lq.signature.obj, b.obj, x.obj))
213+
return x
214+
end
215+
216+
function _cholesky!(A::$mat{$elty}, uplo::UpperOrLower=UPPER::UpperOrLower)
217+
ElError(ccall(($(string("ElCholesky", mattype, "_", ext)), libEl), Cuint,
218+
(UpperOrLower, Ptr{Cvoid}),
219+
uplo, A.obj))
220+
return $CHStructName(uplo, A)
221+
end
222+
223+
function LinearAlgebra.:\(ch::$CHStructName{$elty}, b::$mat{$elty})
224+
x = deepcopy(b)#$mat($elty)
225+
ElError(ccall(($(string("ElSolveAfterCholesky", mattype, "_", ext)), libEl), Cuint,
226+
(UpperOrLower, Orientation, Ptr{Cvoid}, Ptr{Cvoid}),
227+
ch.uplo, ch.orientation[], ch.A.obj, x.obj))
228+
return x
229+
end
230+
103231
end
104232
end
105233
end

test/distcholesky.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using MPI, MPIClusterManagers, Distributed
2+
3+
man = MPIManager(np = 2);
4+
5+
addprocs(man);
6+
7+
@everywhere using LinearAlgebra, Elemental
8+
9+
const M = 400
10+
const N = M
11+
12+
@mpi_do man M = @fetchfrom 1 M
13+
@mpi_do man N = @fetchfrom 1 N
14+
15+
const Ahost = rand(Float64, M, N)
16+
Ahost .+= Ahost'
17+
Ahost .+= M * I(M)
18+
const bhost = rand(Float64, M)
19+
20+
@mpi_do man Aall = @fetchfrom 1 Ahost
21+
@mpi_do man ball = @fetchfrom 1 bhost
22+
23+
@mpi_do man A = Elemental.DistMatrix(Float64);
24+
@mpi_do man b = Elemental.DistMatrix(Float64);
25+
26+
@mpi_do man A = Elemental.resize!(A, M, N);
27+
@mpi_do man b = Elemental.resize!(b, M);
28+
29+
@mpi_do man copyto!(A, Aall)
30+
@mpi_do man copyto!(b, ball)
31+
32+
@mpi_do man chA = Elemental.cholesky!(A);
33+
34+
@mpi_do man x = chA \ b;
35+
36+
@mpi_do man localx = zeros(Float64, Elemental.localHeight(x), Elemental.localWidth(x))
37+
@mpi_do man copyto!(localx, Elemental.localpart(x))
38+
39+
using Test
40+
x = vcat((fetch(@spawnat p localx)[:] for p in workers())...)
41+
@testset "Cholesky" begin
42+
@test x Ahost \ bhost
43+
end

test/distlq.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using MPI, MPIClusterManagers, Distributed
2+
3+
man = MPIManager(np = 2);
4+
5+
addprocs(man);
6+
7+
@everywhere using LinearAlgebra, Elemental
8+
9+
const M = 300
10+
const N = 400
11+
12+
@mpi_do man M = @fetchfrom 1 M
13+
@mpi_do man N = @fetchfrom 1 N
14+
15+
const Ahost = rand(Float64, M, N)
16+
const bhost = rand(Float64, M)
17+
18+
@mpi_do man Aall = @fetchfrom 1 Ahost
19+
@mpi_do man ball = @fetchfrom 1 bhost
20+
21+
@mpi_do man A = Elemental.DistMatrix(Float64);
22+
@mpi_do man b = Elemental.DistMatrix(Float64);
23+
24+
@mpi_do man A = Elemental.resize!(A, M, N);
25+
@mpi_do man b = Elemental.resize!(b, M);
26+
27+
@mpi_do man copyto!(A, Aall)
28+
@mpi_do man copyto!(b, ball)
29+
30+
@mpi_do man lqA = Elemental.lq!(A);
31+
32+
@mpi_do man x = lqA \ b;
33+
34+
@mpi_do man localx = zeros(Float64, Elemental.localHeight(x), Elemental.localWidth(x))
35+
@mpi_do man copyto!(localx, Elemental.localpart(x))
36+
37+
using Test
38+
x = vcat((fetch(@spawnat p localx)[:] for p in workers())...)
39+
@testset "lq" begin
40+
@test x Ahost \ bhost
41+
end

0 commit comments

Comments
 (0)