Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 82 additions & 32 deletions src/core.jl
Original file line number Diff line number Diff line change
@@ -1,58 +1,109 @@
const registry=Dict{Tuple, Any}()
const refs=Set() # Collection of darray identities created on this node
"""
    DArrayRegistry()

Thread-safe registry of `DArray` references.

`data` maps a `Tuple{Int,Int}` ID to the stored value, and `lock` is the
`ReentrantLock` held by the accessor methods while they touch `data`.
"""
struct DArrayRegistry
    data::Dict{Tuple{Int,Int},Any}
    lock::ReentrantLock

    function DArrayRegistry()
        entries = Dict{Tuple{Int,Int},Any}()
        return new(entries, ReentrantLock())
    end
end

# Registry instance shared by the rest of the code.
const REGISTRY = DArrayRegistry()

# Look up `id` in the registry while holding its lock; `default` is returned when
# `id` is not present.
function Base.get(r::DArrayRegistry, id::Tuple{Int,Int}, default)
    return lock(r.lock) do
        get(r.data, id, default)
    end
end
# Indexed lookup of `id` performed while holding the registry lock.
function Base.getindex(r::DArrayRegistry, id::Tuple{Int,Int})
    return lock(r.lock) do
        getindex(r.data, id)
    end
end
# Store `val` under `id` while holding the registry lock; the registry itself is
# returned.
function Base.setindex!(r::DArrayRegistry, val, id::Tuple{Int,Int})
    lock(r.lock) do
        setindex!(r.data, val, id)
    end
    return r
end
# Remove the entry for `id` while holding the registry lock; the registry itself is
# returned.
function Base.delete!(r::DArrayRegistry, id::Tuple{Int,Int})
    lock(r.lock) do
        delete!(r.data, id)
    end
    return r
end

"""
    DArrayRefs()

Thread-safe set of IDs of `DArray`s created on this node.

`data` holds the `Tuple{Int,Int}` IDs, and `lock` is the `ReentrantLock` held by the
accessor methods while they touch `data`.
"""
struct DArrayRefs
    data::Set{Tuple{Int,Int}}
    lock::ReentrantLock

    function DArrayRefs()
        ids = Set{Tuple{Int,Int}}()
        return new(ids, ReentrantLock())
    end
end

# Set instance shared by the rest of the code.
const REFS = DArrayRefs()

let DID::Int = 1
global next_did
next_did() = (id = DID; DID += 1; (myid(), id))
# Add `id` to the set while holding its lock.
#
# Throws an `ArgumentError` when the first component of `id` is not the ID of the
# current worker.
function Base.push!(r::DArrayRefs, id::Tuple{Int,Int})
    # Ensure id refers to a DArray created on this node
    first(id) == myid() || throw(
        ArgumentError(
            lazy"`DArray` is not created on the current worker: Only `DArray`s created on worker $(myid()) can be stored in this set but the `DArray` was created on worker $(first(id))."))
    return lock(r.lock) do
        push!(r.data, id)
    end
end
# Remove `id` from the set while holding its lock; the `DArrayRefs` object itself is
# returned.
function Base.delete!(r::DArrayRefs, id::Tuple{Int,Int})
    lock(r.lock) do
        delete!(r.data, id)
    end
    return r
end

# Global counter to generate a unique ID for each DArray
const DID = Threads.Atomic{Int}(1)

"""
next_did()

Produces an incrementing ID that will be used for DArrays.
"""
next_did
Increment a global counter and return a tuple of the current worker ID and the value the
counter held before the increment.

release_localpart(id::Tuple) = (delete!(registry, id); nothing)
release_localpart(d) = release_localpart(d.id)
This tuple is used as a unique ID for a new `DArray`.
"""
function next_did()
    # `Threads.atomic_add!` returns the value the counter held before the addition.
    counter_value = Threads.atomic_add!(DID, 1)
    return (myid(), counter_value)
end

function close_by_id(id, pids)
# @async println("Finalizer for : ", id)
global refs
release_localpart(id::Tuple{Int,Int}) = (delete!(REGISTRY, id); nothing)
function release_allparts(id::Tuple{Int,Int}, pids::Array{Int})
@sync begin
released_myid = false
for p in pids
@async remotecall_fetch(release_localpart, p, id)
if p == myid()
@async release_localpart(id)
released_myid = true
else
@async remotecall_fetch(release_localpart, p, id)
end
end
if !(myid() in pids)
release_localpart(id)
if !released_myid
@async release_localpart(id)
end
end
delete!(refs, id)
nothing
return nothing
end

function Base.close(d::DArray)
# @async println("close : ", d.id, ", object_id : ", object_id(d), ", myid : ", myid() )
if (myid() == d.id[1]) && d.release
@async close_by_id(d.id, d.pids)
d.release = false
end
function close_by_id(id::Tuple{Int,Int}, pids::Array{Int})
release_allparts(id, pids)
delete!(REFS, id)
nothing
end

function d_closeall()
crefs = copy(refs)
for id in crefs
if id[1] == myid() # sanity check
if haskey(registry, id)
d = d_from_weakref_or_d(id)
(d === nothing) || close(d)
@lock REFS.lock begin
while !isempty(REFS.data)
id = pop!(REFS.data)
d = d_from_weakref_or_d(id)
if d isa DArray
finalize(d)
end
yield()
end
end
return nothing
end

# Closing a `DArray` is delegated to `finalize`.
function Base.close(d::DArray)
    return finalize(d)
end

"""
procs(d::DArray)

Expand All @@ -67,4 +118,3 @@ Distributed.procs(d::SubDArray) = procs(parent(d))
The identity when input is not distributed
"""
function localpart(A)
    # Fallback method: the argument is handed back unchanged.
    return A
end

78 changes: 32 additions & 46 deletions src/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,40 @@ dfill(v, args...) = DArray(I->fill(v, map(length,I)), args...)
```
"""
mutable struct DArray{T,N,A} <: AbstractArray{T,N}
id::Tuple
id::Tuple{Int,Int}
dims::NTuple{N,Int}
pids::Array{Int,N} # pids[i]==p ⇒ processor p has piece i
indices::Array{NTuple{N,UnitRange{Int}},N} # indices held by piece i
cuts::Vector{Vector{Int}} # cuts[d][i] = first index of chunk i in dimension d
localpart::Union{A,Nothing}
release::Bool

function DArray{T,N,A}(id, dims, pids, indices, cuts, lp) where {T,N,A}
function DArray{T,N,A}(id::Tuple{Int,Int}, dims::NTuple{N,Int}, pids, indices, cuts, lp) where {T,N,A}
# check invariants
if dims != map(last, last(indices))
throw(ArgumentError("dimension of DArray (dim) and indices do not match"))
end
release = (myid() == id[1])

d = d_from_weakref_or_d(id)
if d === nothing
d = new(id, dims, pids, indices, cuts, lp, release)
d = new(id, dims, pids, indices, cuts, lp)
end

if release
push!(refs, id)
registry[id] = WeakRef(d)

# println("Installing finalizer for : ", d.id, ", : ", object_id(d), ", isbits: ", isbits(d))
finalizer(close, d)
if first(id) == myid()
push!(REFS, id)
REGISTRY[id] = WeakRef(d)
finalizer(d) do d
@async close_by_id(d.id, d.pids)
end
end
d
end

DArray{T,N,A}() where {T,N,A} = new()
end

function d_from_weakref_or_d(id)
d = get(registry, id, nothing)
isa(d, WeakRef) && return d.value
return d
end
# Return the value held by a `WeakRef`; any other argument is handed back unchanged.
function unpack_weakref(x::WeakRef)
    return x.value
end
function unpack_weakref(x)
    return x
end
# Fetch the `REGISTRY` entry for `id` (`nothing` when absent) and unwrap it with
# `unpack_weakref`.
function d_from_weakref_or_d(id::Tuple{Int,Int})
    entry = get(REGISTRY, id, nothing)
    return unpack_weakref(entry)
end

# The element type of `DArray{T}` is its first type parameter.
function Base.eltype(::Type{DArray{T}}) where {T}
    return T
end
# Build an `A` from an `N`-dimensional `Array{T}` whose extent is zero in every
# dimension.
function empty_localpart(T, N, A)
    zero_dims = ntuple(zero, N)
    buffer = Array{T}(undef, zero_dims)
    return A(buffer)
end
Expand All @@ -77,41 +73,34 @@ Base.hash(d::DArray, h::UInt) = Base.hash(d.id, h)

## core constructors ##

function DArray(id, init, dims, pids, idxs, cuts)
function DArray(id::Tuple{Int,Int}, init::I, dims, pids, idxs, cuts) where {I}
localtypes = Vector{DataType}(undef,length(pids))

@sync begin
for i = 1:length(pids)
@async begin
local typA
if isa(init, Function)
typA = remotecall_fetch(construct_localparts, pids[i], init, id, dims, pids, idxs, cuts)
else
# constructing from an array of remote refs.
typA = remotecall_fetch(construct_localparts, pids[i], init[i], id, dims, pids, idxs, cuts)
end
localtypes[i] = typA
end
if init isa Function
asyncmap!(localtypes, pids) do pid
return remotecall_fetch(construct_localparts, pid, init, id, dims, pids, idxs, cuts)
end
else
asyncmap!(localtypes, pids, init) do pid, pid_init
# constructing from an array of remote refs.
return remotecall_fetch(construct_localparts, pid, pid_init, id, dims, pids, idxs, cuts)
end
end

if length(unique(localtypes)) != 1
if !allequal(localtypes)
@sync for p in pids
@async remotecall_fetch(release_localpart, p, id)
end
throw(ErrorException("Constructed localparts have different `eltype`: $(localtypes)"))
throw(ErrorException(lazy"Constructed localparts have different `eltype`: $(localtypes)"))
end
A = first(localtypes)

if myid() in pids
d = registry[id]
d = isa(d, WeakRef) ? d.value : d
return unpack_weakref(REGISTRY[id])
else
T = eltype(A)
N = length(dims)
d = DArray{T,N,A}(id, dims, pids, idxs, cuts, empty_localpart(T,N,A))
return DArray{T,N,A}(id, dims, pids, idxs, cuts, empty_localpart(T,N,A))
end
d
end

function construct_localparts(init, id, dims, pids, idxs, cuts; T=nothing, A=nothing)
Expand All @@ -124,7 +113,7 @@ function construct_localparts(init, id, dims, pids, idxs, cuts; T=nothing, A=not
end
N = length(dims)
d = DArray{T,N,A}(id, dims, pids, idxs, cuts, localpart)
registry[id] = d
REGISTRY[id] = d
A
end

Expand Down Expand Up @@ -152,12 +141,10 @@ function ddata(;T::Type=Any, init::Function=I->nothing, pids=workers(), data::Ve
end

if myid() in pids
d = registry[id]
d = isa(d, WeakRef) ? d.value : d
return unpack_weakref(REGISTRY[id])
else
d = DArray{T,1,T}(id, (npids,), pids, idxs, cuts, nothing)
return DArray{T,1,T}(id, (npids,), pids, idxs, cuts, nothing)
end
d
end

function gather(d::DArray{T,1,T}) where T
Expand Down Expand Up @@ -428,7 +415,7 @@ end
function Base.:(==)(d::SubDArray, a::AbstractArray)
cd = copy(d)
t = cd == a
close(cd)
finalize(cd)
return t
end
function Base.:(==)(a::AbstractArray, d::DArray)
    # Swap the operands so that the `DArray` is on the left-hand side.
    return d == a
end
Expand All @@ -437,19 +424,19 @@ Base.:(==)(d1::DArray, d2::DArray) = invoke(==, Tuple{DArray, AbstractArray}, d1
function Base.:(==)(d1::SubDArray, d2::DArray)
cd1 = copy(d1)
t = cd1 == d2
close(cd1)
finalize(cd1)
return t
end
function Base.:(==)(d1::DArray, d2::SubDArray)
cd2 = copy(d2)
t = d1 == cd2
close(cd2)
finalize(cd2)
return t
end
function Base.:(==)(d1::SubDArray, d2::SubDArray)
cd1 = copy(d1)
t = cd1 == d2
close(cd1)
finalize(cd1)
return t
end

Expand Down Expand Up @@ -845,4 +832,3 @@ function Random.rand!(A::DArray, ::Type{T}) where T
remotecall_wait((A, T)->rand!(localpart(A), T), p, A, T)
end
end

Loading
Loading