Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,18 @@ end
# This will turn local AbstractArrays into DArrays
dbc = bcdistribute(bc)

asyncmap(procs(dest)) do p
remotecall_fetch(p) do
@sync for p in procs(dest)
@async remotecall_wait(p) do
# get the indices for the localpart
lpidx = localpartindex(dest)
@assert lpidx != 0
# create a local version of the broadcast, by constructing views
# Note: creates copies of the argument
lbc = bclocal(dbc, dest.indices[lpidx])
copyto!(localpart(dest), lbc)
return nothing
end
end

return dest
end

Expand Down
2 changes: 1 addition & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function close_by_id(id, pids)
global refs
@sync begin
for p in pids
@async remotecall_fetch(release_localpart, p, id)
@async remotecall_wait(release_localpart, p, id)
end
if !(myid() in pids)
release_localpart(id)
Expand Down
37 changes: 18 additions & 19 deletions src/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function DArray(id, init, dims, pids, idxs, cuts)

if length(unique(localtypes)) != 1
@sync for p in pids
@async remotecall_fetch(release_localpart, p, id)
@async remotecall_wait(release_localpart, p, id)
end
throw(ErrorException("Constructed localparts have different `eltype`: $(localtypes)"))
end
Expand Down Expand Up @@ -147,8 +147,8 @@ function ddata(;T::Type=Any, init::Function=I->nothing, pids=workers(), data::Ve
end
end

@sync for i = 1:length(pids)
@async remotecall_fetch(construct_localparts, pids[i], init, id, (npids,), pids, idxs, cuts; T=T, A=T)
@sync for p in pids
@async remotecall_wait(construct_localparts, p, init, id, (npids,), pids, idxs, cuts; T=T, A=T)
end

if myid() in pids
Expand All @@ -161,9 +161,10 @@ function ddata(;T::Type=Any, init::Function=I->nothing, pids=workers(), data::Ve
end

function gather(d::DArray{T,1,T}) where T
a=Array{T}(undef, length(procs(d)))
@sync for (i,p) in enumerate(procs(d))
@async a[i] = remotecall_fetch(localpart, p, d)
pids = procs(d)
a = Vector{T}(undef, length(pids))
asyncmap!(a, pids) do p
remotecall_fetch(localpart, p, d)
end
a
end
Expand Down Expand Up @@ -195,12 +196,9 @@ function DArray(refs)
dimdist = size(refs)
id = next_did()

npids = [r.where for r in refs]
nsizes = Array{Tuple}(undef, dimdist)
@sync for i in 1:length(refs)
let i=i
@async nsizes[i] = remotecall_fetch(sz_localpart_ref, npids[i], refs[i], id)
end
asyncmap!(nsizes, refs) do r
remotecall_fetch(sz_localpart_ref, r.where, r, id)
end

nindices = Array{NTuple{length(dimdist),UnitRange{Int}}}(undef, dimdist...)
Expand All @@ -223,7 +221,7 @@ function DArray(refs)
ncuts = Array{Int,1}[pushfirst!(sort(unique(lastidxs[x,:])), 1) for x in 1:length(dimdist)]
ndims = tuple([sort(unique(lastidxs[x,:]))[end]-1 for x in 1:length(dimdist)]...)

DArray(id, refs, ndims, reshape(npids, dimdist), nindices, ncuts)
DArray(id, refs, ndims, map(r -> r.where, refs), nindices, ncuts)
end

macro DArray(ex0::Expr)
Expand Down Expand Up @@ -673,8 +671,8 @@ Base.copy(d::SubDArray) = copyto!(similar(d), d)
Base.copy(d::SubDArray{<:Any,2}) = copyto!(similar(d), d)

function Base.copyto!(dest::SubOrDArray, src::AbstractArray)
asyncmap(procs(dest)) do p
remotecall_fetch(p) do
@sync for p in procs(dest)
@async remotecall_wait(p) do
ldest = localpart(dest)
copyto!(ldest, view(src, localindices(dest)...))
end
Expand All @@ -684,8 +682,8 @@ end

function Base.deepcopy(src::DArray)
dest = similar(src)
asyncmap(procs(src)) do p
remotecall_fetch(p) do
@sync for p in procs(src)
@async remotecall_wait(p) do
dest[:L] = deepcopy(src[:L])
end
end
Expand Down Expand Up @@ -835,14 +833,15 @@ end

function Base.fill!(A::DArray, x)
@sync for p in procs(A)
@async remotecall_fetch((A,x)->(fill!(localpart(A), x); nothing), p, A, x)
@async remotecall_wait((A,x)->fill!(localpart(A), x), p, A, x)
end
return A
end

function Random.rand!(A::DArray, ::Type{T}) where T
asyncmap(procs(A)) do p
remotecall_wait((A, T)->rand!(localpart(A), T), p, A, T)
@sync for p in procs(A)
@async remotecall_wait((A, T)->rand!(localpart(A), T), p, A, T)
end
return A
end

39 changes: 16 additions & 23 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ function LinearAlgebra.axpy!(α, x::DArray, y::DArray)
if length(x) != length(y)
throw(DimensionMismatch("vectors must have same length"))
end
asyncmap(procs(y)) do p
@async remotecall_fetch(p) do
@sync for p in procs(y)
@async remotecall_wait(p) do
axpy!(α, localpart(x), localpart(y))
return nothing
end
end
return y
Expand All @@ -39,26 +38,22 @@ function LinearAlgebra.dot(x::DVector, y::DVector)
throw(DimensionMismatch(""))
end

results=Any[]
asyncmap(procs(x)) do p
push!(results, remotecall_fetch((x, y) -> dot(localpart(x), makelocal(y, localindices(x)...)), p, x, y))
results = asyncmap(procs(x)) do p
remotecall_fetch((x, y) -> dot(localpart(x), makelocal(y, localindices(x)...)), p, x, y)
end
return reduce(+, results)
end

function LinearAlgebra.norm(x::DArray, p::Real = 2)
results = []
@sync begin
for pp in procs(x)
@async push!(results, remotecall_fetch(() -> norm(localpart(x), p), pp))
end
results = asyncmap(procs(x)) do pp
remotecall_fetch(() -> norm(localpart(x), p), pp)
end
return norm(results, p)
end

function LinearAlgebra.rmul!(A::DArray, x::Number)
@sync for p in procs(A)
@async remotecall_fetch((A,x)->(rmul!(localpart(A), x); nothing), p, A, x)
@async remotecall_wait((A,x)->rmul!(localpart(A), x), p, A, x)
end
return A
end
Expand Down Expand Up @@ -104,13 +99,12 @@ function LinearAlgebra.mul!(y::DVector, A::DMatrix, x::AbstractVector, α::Numbe
# Scale y if necessary
if β != one(β)
asyncmap(procs(y)) do p
remotecall_fetch(p) do
remotecall_wait(p) do
if !iszero(β)
rmul!(localpart(y), β)
else
fill!(localpart(y), 0)
end
return nothing
end
end
end
Expand All @@ -120,7 +114,7 @@ function LinearAlgebra.mul!(y::DVector, A::DMatrix, x::AbstractVector, α::Numbe
p = y.pids[i]
for j = 1:size(R, 2)
rij = R[i,j]
@async remotecall_fetch(() -> (add!(localpart(y), fetch(rij), α); nothing), p)
@async remotecall_wait(() -> add!(localpart(y), fetch(rij), α), p)
end
end

Expand Down Expand Up @@ -150,14 +144,13 @@ function LinearAlgebra.mul!(y::DVector, adjA::Adjoint{<:Number,<:DMatrix}, x::Ab

# Scale y if necessary
if β != one(β)
asyncmap(procs(y)) do p
remotecall_fetch(p) do
@sync for p in procs(y)
@async remotecall_wait(p) do
if !iszero(β)
rmul!(localpart(y), β)
else
fill!(localpart(y), 0)
end
return nothing
end
end
end
Expand All @@ -167,7 +160,7 @@ function LinearAlgebra.mul!(y::DVector, adjA::Adjoint{<:Number,<:DMatrix}, x::Ab
p = y.pids[i]
for j = 1:size(R, 2)
rij = R[i,j]
@async remotecall_fetch(() -> (add!(localpart(y), fetch(rij), α); nothing), p)
@async remotecall_wait(() -> add!(localpart(y), fetch(rij), α), p)
end
end
return y
Expand Down Expand Up @@ -238,10 +231,10 @@ function _matmatmul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number, β::
# Scale C if necessary
if β != one(β)
@sync for p in C.pids
if β != zero(β)
@async remotecall_fetch(() -> (rmul!(localpart(C), β); nothing), p)
if iszero(β)
@async remotecall_wait(() -> fill!(localpart(C), 0), p)
else
@async remotecall_fetch(() -> (fill!(localpart(C), 0); nothing), p)
@async remotecall_wait(() -> rmul!(localpart(C), β), p)
end
end
end
Expand All @@ -252,7 +245,7 @@ function _matmatmul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number, β::
p = C.pids[i,k]
for j = 1:size(R, 2)
rijk = R[i,j,k]
@async remotecall_fetch(d -> (add!(localpart(d), fetch(rijk), α); nothing), p, C)
@async remotecall_wait(d -> add!(localpart(d), fetch(rijk), α), p, C)
end
end
end
Expand Down
18 changes: 8 additions & 10 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
Base.map(f, d0::DArray, ds::AbstractArray...) = broadcast(f, d0, ds...)

function Base.map!(f::F, dest::DArray, src::DArray{<:Any,<:Any,A}) where {F,A}
asyncmap(procs(dest)) do p
remotecall_fetch(p) do
@sync for p in procs(dest)
@async remotecall_wait(p) do
map!(f, localpart(dest), makelocal(src, localindices(dest)...))
return nothing
end
end
return dest
Expand Down Expand Up @@ -38,8 +37,8 @@ function Base.reducedim_initarray(A::DArray, region, v0, ::Type{R}) where {R}
# Store reduction on lowest pids
pids = A.pids[ntuple(i -> i in region ? (1:1) : (:), ndims(A))...]
chunks = similar(pids, Future)
@sync for i in eachindex(pids)
@async chunks[i...] = remotecall_wait(() -> Base.reducedim_initarray(localpart(A), region, v0, R), pids[i...])
asyncmap!(chunks, pids) do p
remotecall_wait(() -> Base.reducedim_initarray(localpart(A), region, v0, R), p)
end
return DArray(chunks)
end
Expand All @@ -64,13 +63,12 @@ end
# has been run on each localpart with mapreducedim_within. Eventually, we might
# want to write mapreducedim_between! as a binary reduction.
function mapreducedim_between!(f, op, R::DArray, A::DArray, region)
asyncmap(procs(R)) do p
remotecall_fetch(p, f, op, R, A, region) do f, op, R, A, region
@sync for p in procs(R)
@async remotecall_wait(p, f, op, R, A, region) do f, op, R, A, region
localind = [r for r = localindices(A)]
localind[[region...]] = [1:n for n = size(A)[[region...]]]
B = convert(Array, A[localind...])
Base.mapreducedim!(f, op, localpart(R), B)
nothing
end
end
return R
Expand Down Expand Up @@ -163,8 +161,8 @@ function map_localparts(f::Callable, A::Array, DA::DArray)
end

function map_localparts!(f::Callable, d::DArray)
asyncmap(procs(d)) do p
remotecall_fetch((f,d)->(f(localpart(d)); nothing), p, f, d)
@sync for p in procs(d)
@async remotecall_wait((f,d)->f(localpart(d)), p, f, d)
end
return d
end
Expand Down
4 changes: 2 additions & 2 deletions src/spmd.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module SPMD

using Distributed: RemoteChannel, myid, procs, remote_do, remotecall_fetch
using Distributed: RemoteChannel, myid, procs, remote_do, remotecall_fetch, remotecall_wait
using ..DistributedArrays: DistributedArrays, gather, next_did

export sendto, recvfrom, recvfrom_any, barrier, bcast, scatter, gather
Expand Down Expand Up @@ -243,7 +243,7 @@ function spmd(f, args...; pids=procs(), context=nothing)
ctxt_id = context.id
end
@sync for p in pids
@async remotecall_fetch(spmd_local, p, f_noarg, ctxt_id, clear_ctxt)
@async remotecall_wait(spmd_local, p, f_noarg, ctxt_id, clear_ctxt)
end
nothing
end
Expand Down
21 changes: 20 additions & 1 deletion test/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ unpack(ex) = ex
A = randn(100,100)
DA = distribute(A)

# sum either throws an ArgumentError or a CompositeException of ArgumentErrors
# sum either throws an ArgumentError, a CompositeException of ArgumentErrors,
# or a RemoteException wrapping an ArgumentError
try
sum(DA, dims=-1)
catch err
Expand All @@ -369,6 +370,9 @@ unpack(ex) = ex
orig_err = unpack(excep)
@test isa(orig_err, ArgumentError)
end
elseif isa(err, RemoteException)
@test err.captured isa CapturedException
@test err.captured.ex isa ArgumentError
else
@test isa(err, ArgumentError)
end
Expand All @@ -383,6 +387,9 @@ unpack(ex) = ex
orig_err = unpack(excep)
@test isa(orig_err, ArgumentError)
end
elseif isa(err, RemoteException)
@test err.captured isa CapturedException
@test err.captured.ex isa ArgumentError
else
@test isa(err, ArgumentError)
end
Expand Down Expand Up @@ -1039,6 +1046,8 @@ check_leaks()
close(d)
end

check_leaks()

@testset "rand!" begin
d = dzeros(30, 30)
rand!(d)
Expand All @@ -1048,6 +1057,16 @@ end

check_leaks()

@testset "fill!" begin
d = dzeros(30, 30)
fill!(d, 3.14)
@test all(x-> x == 3.14, d)

close(d)
end

check_leaks()

d_closeall()

@testset "test for any leaks" begin
Expand Down
Loading