diff --git a/src/core.jl b/src/core.jl
index 5192f22..b79aae7 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -1,58 +1,109 @@
-const registry=Dict{Tuple, Any}()
-const refs=Set() # Collection of darray identities created on this node
+# Thread-safe registry of DArray references
+struct DArrayRegistry
+    data::Dict{Tuple{Int,Int}, Any}
+    lock::ReentrantLock
+    DArrayRegistry() = new(Dict{Tuple{Int,Int}, Any}(), ReentrantLock())
+end
+const REGISTRY = DArrayRegistry()
+
+function Base.get(r::DArrayRegistry, id::Tuple{Int,Int}, default)
+    @lock r.lock begin
+        return get(r.data, id, default)
+    end
+end
+function Base.getindex(r::DArrayRegistry, id::Tuple{Int,Int})
+    @lock r.lock begin
+        return r.data[id]
+    end
+end
+function Base.setindex!(r::DArrayRegistry, val, id::Tuple{Int,Int})
+    @lock r.lock begin
+        r.data[id] = val
+    end
+    return r
+end
+function Base.delete!(r::DArrayRegistry, id::Tuple{Int,Int})
+    @lock r.lock delete!(r.data, id)
+    return r
+end
+
+# Thread-safe set of IDs of DArrays created on this node
+struct DArrayRefs
+    data::Set{Tuple{Int,Int}}
+    lock::ReentrantLock
+    DArrayRefs() = new(Set{Tuple{Int,Int}}(), ReentrantLock())
+end
+const REFS = DArrayRefs()
 
-let DID::Int = 1
-    global next_did
-    next_did() = (id = DID; DID += 1; (myid(), id))
+function Base.push!(r::DArrayRefs, id::Tuple{Int,Int})
+    # Ensure id refers to a DArray created on this node
+    if first(id) != myid()
+        throw(
+            ArgumentError(
+                lazy"`DArray` is not created on the current worker: Only `DArray`s created on worker $(myid()) can be stored in this set but the `DArray` was created on worker $(first(id))."))
+    end
+    @lock r.lock begin
+        return push!(r.data, id)
+    end
+end
+function Base.delete!(r::DArrayRefs, id::Tuple{Int,Int})
+    @lock r.lock delete!(r.data, id)
+    return r
 end
 
+# Global counter to generate a unique ID for each DArray
+const DID = Threads.Atomic{Int}(1)
+
 """
     next_did()
 
-Produces an incrementing ID that will be used for DArrays.
-"""
-next_did
+Increment a global counter and return a tuple of the current worker ID and the incremented
+value of the counter.
 
-release_localpart(id::Tuple) = (delete!(registry, id); nothing)
-release_localpart(d) = release_localpart(d.id)
+This tuple is used as a unique ID for a new `DArray`.
+"""
+next_did() = (myid(), Threads.atomic_add!(DID, 1))
 
-function close_by_id(id, pids)
-#    @async println("Finalizer for : ", id)
-    global refs
+release_localpart(id::Tuple{Int,Int}) = (delete!(REGISTRY, id); nothing)
+
+function release_allparts(id::Tuple{Int,Int}, pids::Array{Int})
     @sync begin
+        released_myid = false
         for p in pids
-            @async remotecall_fetch(release_localpart, p, id)
+            if p == myid()
+                @async release_localpart(id)
+                released_myid = true
+            else
+                @async remotecall_fetch(release_localpart, p, id)
+            end
         end
-        if !(myid() in pids)
-            release_localpart(id)
+        if !released_myid
+            @async release_localpart(id)
         end
     end
-    delete!(refs, id)
-    nothing
+    return nothing
 end
 
-function Base.close(d::DArray)
-#    @async println("close : ", d.id, ", object_id : ", object_id(d), ", myid : ", myid() )
-    if (myid() == d.id[1]) && d.release
-        @async close_by_id(d.id, d.pids)
-        d.release = false
-    end
+function close_by_id(id::Tuple{Int,Int}, pids::Array{Int})
+    release_allparts(id, pids)
+    delete!(REFS, id)
     nothing
 end
 
 function d_closeall()
-    crefs = copy(refs)
-    for id in crefs
-        if id[1] == myid() # sanity check
-            if haskey(registry, id)
-                d = d_from_weakref_or_d(id)
-                (d === nothing) || close(d)
+    @lock REFS.lock begin
+        while !isempty(REFS.data)
+            id = pop!(REFS.data)
+            d = d_from_weakref_or_d(id)
+            if d isa DArray
+                finalize(d)
             end
-            yield()
         end
     end
+    return nothing
 end
 
+Base.close(d::DArray) = finalize(d)
+
 """
     procs(d::DArray)
@@ -67,4 +118,3 @@ Distributed.procs(d::SubDArray) = procs(parent(d))
 The identity when input is not distributed
 """
 localpart(A) = A
-
diff --git a/src/darray.jl b/src/darray.jl
index d3868ad..2d95155 100644
--- a/src/darray.jl
+++ b/src/darray.jl
@@ -23,32 +23,30 @@ dfill(v, args...) = DArray(I->fill(v, map(length,I)), args...)
 ```
 """
 mutable struct DArray{T,N,A} <: AbstractArray{T,N}
-    id::Tuple
+    id::Tuple{Int,Int}
     dims::NTuple{N,Int}
     pids::Array{Int,N}                          # pids[i]==p ⇒ processor p has piece i
     indices::Array{NTuple{N,UnitRange{Int}},N}  # indices held by piece i
     cuts::Vector{Vector{Int}}                   # cuts[d][i] = first index of chunk i in dimension d
     localpart::Union{A,Nothing}
-    release::Bool
 
-    function DArray{T,N,A}(id, dims, pids, indices, cuts, lp) where {T,N,A}
+    function DArray{T,N,A}(id::Tuple{Int,Int}, dims::NTuple{N,Int}, pids, indices, cuts, lp) where {T,N,A}
         # check invariants
         if dims != map(last, last(indices))
            throw(ArgumentError("dimension of DArray (dim) and indices do not match"))
         end
-        release = (myid() == id[1])
 
         d = d_from_weakref_or_d(id)
         if d === nothing
-            d = new(id, dims, pids, indices, cuts, lp, release)
+            d = new(id, dims, pids, indices, cuts, lp)
         end
 
-        if release
-            push!(refs, id)
-            registry[id] = WeakRef(d)
-
-#            println("Installing finalizer for : ", d.id, ", : ", object_id(d), ", isbits: ", isbits(d))
-            finalizer(close, d)
+        if first(id) == myid()
+            push!(REFS, id)
+            REGISTRY[id] = WeakRef(d)
+            finalizer(d) do d
+                @async close_by_id(d.id, d.pids)
+            end
         end
         d
     end
@@ -56,11 +54,9 @@ mutable struct DArray{T,N,A} <: AbstractArray{T,N}
     DArray{T,N,A}() where {T,N,A} = new()
 end
 
-function d_from_weakref_or_d(id)
-    d = get(registry, id, nothing)
-    isa(d, WeakRef) && return d.value
-    return d
-end
+unpack_weakref(x) = x
+unpack_weakref(x::WeakRef) = x.value
+d_from_weakref_or_d(id::Tuple{Int,Int}) = unpack_weakref(get(REGISTRY, id, nothing))
 
 Base.eltype(::Type{DArray{T}}) where {T} = T
 empty_localpart(T,N,A) = A(Array{T}(undef, ntuple(zero, N)))
@@ -77,41 +73,34 @@ Base.hash(d::DArray, h::UInt) = Base.hash(d.id, h)
 
 ## core constructors ##
 
-function DArray(id, init, dims, pids, idxs, cuts)
+function DArray(id::Tuple{Int,Int}, init::I, dims, pids, idxs, cuts) where {I}
     localtypes = Vector{DataType}(undef,length(pids))
-
-    @sync begin
-        for i = 1:length(pids)
-            @async begin
-                local typA
-                if isa(init, Function)
-                    typA = remotecall_fetch(construct_localparts, pids[i], init, id, dims, pids, idxs, cuts)
-                else
-                    # constructing from an array of remote refs.
-                    typA = remotecall_fetch(construct_localparts, pids[i], init[i], id, dims, pids, idxs, cuts)
-                end
-                localtypes[i] = typA
-            end
+    if init isa Function
+        asyncmap!(localtypes, pids) do pid
+            return remotecall_fetch(construct_localparts, pid, init, id, dims, pids, idxs, cuts)
+        end
+    else
+        asyncmap!(localtypes, pids, init) do pid, pid_init
+            # constructing from an array of remote refs.
+            return remotecall_fetch(construct_localparts, pid, pid_init, id, dims, pids, idxs, cuts)
         end
     end
 
-    if length(unique(localtypes)) != 1
+    if !allequal(localtypes)
         @sync for p in pids
             @async remotecall_fetch(release_localpart, p, id)
         end
-        throw(ErrorException("Constructed localparts have different `eltype`: $(localtypes)"))
+        throw(ErrorException(lazy"Constructed localparts have different `eltype`: $(localtypes)"))
     end
     A = first(localtypes)
 
     if myid() in pids
-        d = registry[id]
-        d = isa(d, WeakRef) ? d.value : d
+        return unpack_weakref(REGISTRY[id])
     else
         T = eltype(A)
         N = length(dims)
-        d = DArray{T,N,A}(id, dims, pids, idxs, cuts, empty_localpart(T,N,A))
+        return DArray{T,N,A}(id, dims, pids, idxs, cuts, empty_localpart(T,N,A))
     end
-    d
 end
 
 function construct_localparts(init, id, dims, pids, idxs, cuts; T=nothing, A=nothing)
@@ -124,7 +113,7 @@ function construct_localparts(init, id, dims, pids, idxs, cuts; T=nothing, A=not
     end
     N = length(dims)
     d = DArray{T,N,A}(id, dims, pids, idxs, cuts, localpart)
-    registry[id] = d
+    REGISTRY[id] = d
     A
 end
@@ -152,12 +141,10 @@ function ddata(;T::Type=Any, init::Function=I->nothing, pids=workers(), data::Ve
     end
 
     if myid() in pids
-        d = registry[id]
-        d = isa(d, WeakRef) ? d.value : d
+        return unpack_weakref(REGISTRY[id])
     else
-        d = DArray{T,1,T}(id, (npids,), pids, idxs, cuts, nothing)
+        return DArray{T,1,T}(id, (npids,), pids, idxs, cuts, nothing)
     end
-    d
 end
 
 function gather(d::DArray{T,1,T}) where T
@@ -428,7 +415,7 @@ end
 function Base.:(==)(d::SubDArray, a::AbstractArray)
     cd = copy(d)
     t = cd == a
-    close(cd)
+    finalize(cd)
     return t
 end
 Base.:(==)(a::AbstractArray, d::DArray) = d == a
@@ -437,19 +424,19 @@ Base.:(==)(d1::DArray, d2::DArray) = invoke(==, Tuple{DArray, AbstractArray}, d1
 function Base.:(==)(d1::SubDArray, d2::DArray)
     cd1 = copy(d1)
     t = cd1 == d2
-    close(cd1)
+    finalize(cd1)
     return t
 end
 function Base.:(==)(d1::DArray, d2::SubDArray)
     cd2 = copy(d2)
     t = d1 == cd2
-    close(cd2)
+    finalize(cd2)
     return t
 end
 function Base.:(==)(d1::SubDArray, d2::SubDArray)
     cd1 = copy(d1)
     t = cd1 == d2
-    close(cd1)
+    finalize(cd1)
     return t
 end
@@ -845,4 +832,3 @@ function Random.rand!(A::DArray, ::Type{T}) where T
         remotecall_wait((A, T)->rand!(localpart(A), T), p, A, T)
     end
 end
-
diff --git a/src/spmd.jl b/src/spmd.jl
index 014f2a0..0539671 100644
--- a/src/spmd.jl
+++ b/src/spmd.jl
@@ -16,41 +16,52 @@ mutable struct WorkerDataChannel
 end
 
 mutable struct SPMDContext
-    id::Tuple
+    id::Tuple{Int,Int}
     chnl::Channel
     store::Dict{Any,Any}
-    pids::Array
-    release::Bool
-
-    function SPMDContext(id)
-        ctxt = new(id, Channel(typemax(Int)), Dict{Any,Any}(), [], false)
-        finalizer(finalize_ctxt, ctxt)
-        ctxt
+    pids::Array{Int}
+
+    function SPMDContext(id::Tuple{Int,Int}, pids::Vector{Int})
+        ctxt = new(id, Channel(typemax(Int)), Dict{Any,Any}(), pids)
+        if first(id) == myid()
+            finalizer(ctxt) do ctxt
+                for p in ctxt.pids
+                    @async remote_do(delete_ctxt_id, p, ctxt.id)
+                end
+            end
+        end
+        return ctxt
    end
 end
 
-function finalize_ctxt(ctxt::SPMDContext)
-    ctxt.release && close(ctxt)
+
+# Every worker is associated with its own RemoteChannel
+struct WorkerChannelDict
+    data::Dict{Int, WorkerDataChannel}
+    lock::ReentrantLock
+    WorkerChannelDict() = new(Dict{Int, WorkerDataChannel}(), ReentrantLock())
+end
+const WORKERCHANNELS = WorkerChannelDict()
+
+Base.get!(f, x::WorkerChannelDict, id::Int) = @lock x.lock get!(f, x.data, id)
+
+# mapping between a context id and context object
+struct SPMDContextDict
+    data::Dict{Tuple{Int,Int}, SPMDContext}
+    lock::ReentrantLock
+    SPMDContextDict() = new(Dict{Tuple{Int,Int}, SPMDContext}(), ReentrantLock())
 end
+const CONTEXTS = SPMDContextDict()
+
+Base.delete!(x::SPMDContextDict, id::Tuple{Int,Int}) = @lock x.lock delete!(x.data, id)
+Base.get!(f, x::SPMDContextDict, id::Tuple{Int,Int}) = @lock x.lock get!(f, x.data, id)
 
 function context_local_storage()
     ctxt = get_ctxt_from_id(task_local_storage(:SPMD_CTXT))
     ctxt.store
 end
 
-function context(pids=procs())
-    global map_ctxts
-    ctxt = SPMDContext(next_did())
-    ctxt.pids = pids
-    ctxt.release = true
-    ctxt
-end
-
-# Every worker is associated with its own RemoteChannel
-const map_worker_channels = Dict{Int, WorkerDataChannel}()
-
-# mapping between a context id and context object
-const map_ctxts = Dict{Tuple, SPMDContext}()
+context(pids::Vector{Int}=procs()) = SPMDContext(next_did(), pids)
 
 # Multiple SPMD blocks can be executed concurrently,
 # each in its own context. Messages are still sent as part of the
@@ -86,27 +97,21 @@ function get_dc(wc::WorkerDataChannel)
     return wc.rc
 end
 
-function get_ctxt_from_id(ctxt_id)
-    global map_ctxts
-    ctxt = get(map_ctxts, ctxt_id, nothing)
-    if ctxt == nothing
-        ctxt = SPMDContext(ctxt_id)
-        map_ctxts[ctxt_id] = ctxt
+function get_ctxt_from_id(ctxt_id::Tuple{Int,Int})
+    ctxt = get!(CONTEXTS, ctxt_id) do
+        return SPMDContext(ctxt_id, Int[])
     end
     return ctxt
 end
-
 # Since modules may be loaded in any order on the workers,
 # and workers may be dynamically added, pull in the remote channel
 # handles when accessed for the first time.
-function get_remote_dc(pid)
-    global map_worker_channels
-    if !haskey(map_worker_channels, pid)
-        map_worker_channels[pid] = WorkerDataChannel(pid)
+function get_remote_dc(pid::Int)
+    wc = get!(WORKERCHANNELS, pid) do
+        return WorkerDataChannel(pid)
     end
-
-    return get_dc(map_worker_channels[pid])
+    return get_dc(wc)
 end
 
 function send_msg(to, typ, data, tag)
@@ -248,17 +253,8 @@ function spmd(f, args...; pids=procs(), context=nothing)
     nothing
 end
 
-function delete_ctxt_id(ctxt_id)
-    global map_ctxts
-    haskey(map_ctxts, ctxt_id) && delete!(map_ctxts, ctxt_id)
-    nothing
-end
+delete_ctxt_id(ctxt_id::Tuple{Int,Int}) = delete!(CONTEXTS, ctxt_id)
 
-function Base.close(ctxt::SPMDContext)
-    for p in ctxt.pids
-        remote_do(delete_ctxt_id, p, ctxt.id)
-    end
-    ctxt.release = false
-end
+Base.close(ctxt::SPMDContext) = finalize(ctxt)
 
 end
diff --git a/test/darray.jl b/test/darray.jl
index 507bfda..4abd399 100644
--- a/test/darray.jl
+++ b/test/darray.jl
@@ -1052,10 +1052,10 @@ d_closeall()
 @testset "test for any leaks" begin
     sleep(1.0)  # allow time for any cleanup to complete
 
-    allrefszero = Bool[remotecall_fetch(()->length(DistributedArrays.refs) == 0, p) for p in procs()]
+    allrefszero = Bool[remotecall_fetch(()-> @lock(DistributedArrays.REFS.lock, isempty(DistributedArrays.REFS.data)), p) for p in procs()]
     @test all(allrefszero)
 
-    allregistrieszero = Bool[remotecall_fetch(()->length(DistributedArrays.registry) == 0, p) for p in procs()]
+    allregistrieszero = Bool[remotecall_fetch(()-> @lock(DistributedArrays.REGISTRY.lock, isempty(DistributedArrays.REGISTRY.data)), p) for p in procs()]
     @test all(allregistrieszero)
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index e85b812..dcdf380 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -26,9 +26,13 @@ const MYID = myid()
 const OTHERIDS = filter(id-> id != MYID, procs())[rand(1:(nprocs()-1))]
 
 function check_leaks()
-    if length(DistributedArrays.refs) > 0
+    nrefs = @lock DistributedArrays.REFS.lock length(DistributedArrays.REFS.data)
+    if !iszero(nrefs)
         sleep(0.1)  # allow time for any cleanup to complete and test again
-        length(DistributedArrays.refs) > 0 && @warn("Probable leak of ", length(DistributedArrays.refs), " darrays")
+        nrefs = @lock DistributedArrays.REFS.lock length(DistributedArrays.REFS.data)
+        if !iszero(nrefs)
+            @warn("Probable leak of ", nrefs, " darrays")
+        end
     end
 end
diff --git a/test/spmd.jl b/test/spmd.jl
index b14918f..148b987 100644
--- a/test/spmd.jl
+++ b/test/spmd.jl
@@ -171,25 +171,27 @@ end
 @everywhere begin
     if myid() != 1
         local n = 0
-        for (k,v) in DistributedArrays.SPMD.map_ctxts
-            store = v.store
-            localsum = store[:LOCALSUM]
-            if localsum != 2*sum(workers())*2
-                println("localsum ", localsum, " != $(2*sum(workers())*2)")
-                error("localsum mismatch")
+        @lock DistributedArrays.SPMD.CONTEXTS.lock begin
+            for (k,v) in DistributedArrays.SPMD.CONTEXTS.data
+                store = v.store
+                localsum = store[:LOCALSUM]
+                if localsum != 2*sum(workers())*2
+                    println("localsum ", localsum, " != $(2*sum(workers())*2)")
+                    error("localsum mismatch")
+                end
+                n += 1
             end
-            n += 1
         end
         @assert n == 8
     end
 end
 
 # close the contexts
-foreach(x->close(x), contexts)
+foreach(close, contexts)
 
 # verify that the localstores have been deleted.
 @everywhere begin
-    @assert isempty(DistributedArrays.SPMD.map_ctxts)
+    @assert @lock DistributedArrays.SPMD.CONTEXTS.lock isempty(DistributedArrays.SPMD.CONTEXTS.data)
 end
 
 println("SPMD: Passed spmd function with explicit context run concurrently")
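
Note (illustration, not part of the patch): the change above repeatedly applies one pattern, a plain Dict or Set guarded by a ReentrantLock, plus a Threads.Atomic counter for unique IDs. The standalone Julia sketch below shows that pattern under assumed, hypothetical names (Registry, COUNTER, fresh_id); it uses the lock(f, l) do-block form, which is equivalent to the @lock macro the patch relies on.

# Illustrative sketch only -- not DistributedArrays code.
using Distributed: myid

struct Registry{K,V}
    data::Dict{K,V}
    lock::ReentrantLock
    Registry{K,V}() where {K,V} = new(Dict{K,V}(), ReentrantLock())
end

# Every access acquires the lock, so concurrent tasks never mutate `data` unsynchronized.
function Base.setindex!(r::Registry, v, k)
    lock(r.lock) do
        r.data[k] = v
    end
    return r
end
Base.get(r::Registry, k, default) = lock(() -> get(r.data, k, default), r.lock)
Base.delete!(r::Registry, k) = (lock(() -> delete!(r.data, k), r.lock); r)

# atomic_add! returns the pre-increment value, so each call yields a distinct
# (worker id, counter) pair without taking a lock.
const COUNTER = Threads.Atomic{Int}(1)
fresh_id() = (myid(), Threads.atomic_add!(COUNTER, 1))

# Usage:
reg = Registry{Tuple{Int,Int}, String}()
id = fresh_id()
reg[id] = "payload"
@assert get(reg, id, nothing) == "payload"
delete!(reg, id)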