@@ -326,43 +326,41 @@ else
326326 _tail (nt:: NamedTuple ) = Base. tail (nt)
327327end
328328
329- function subset (varinfo:: UntypedVarInfo , vns:: AbstractVector{<:VarName} )
329+ function subset (varinfo:: VarInfo , vns:: AbstractVector{<:VarName} )
330330 metadata = subset (varinfo. metadata, vns)
331331 return VarInfo (metadata, deepcopy (varinfo. logp), deepcopy (varinfo. num_produce))
332332end
333333
334- function subset (varinfo:: VectorVarInfo , vns:: AbstractVector{<:VarName} )
335- metadata = subset (varinfo. metadata, vns)
336- return VarInfo (metadata, deepcopy (varinfo. logp), deepcopy (varinfo. num_produce))
334+ function subset (metadata:: NamedTuple , vns:: AbstractVector{<:VarName} )
335+ vns_syms = Set (unique (map (getsym, vns)))
336+ syms = filter (Base. Fix2 (in, vns_syms), keys (metadata))
337+ metadatas = map (syms) do sym
338+ subset (getfield (metadata, sym), filter (== (sym) ∘ getsym, vns))
339+ end
340+ return NamedTuple {syms} (metadatas)
337341end
338342
339- function subset (varinfo:: TypedVarInfo , vns:: AbstractVector{<:VarName{sym}} ) where {sym}
340- # If all the variables are using the same symbol, then we can just extract that field from the metadata.
341- metadata = subset (getfield (varinfo. metadata, sym), vns)
342- return VarInfo (
343- NamedTuple {(sym,)} (tuple (metadata)),
344- deepcopy (varinfo. logp),
345- deepcopy (varinfo. num_produce),
346- )
347- end
343+ # The above method is type unstable since we don't know which symbols are in `vns`.
344+ # In the below special case, when all `vns` have the same symbol, we can write a type stable
345+ # version.
348346
349- function subset (varinfo:: TypedVarInfo , vns:: AbstractVector{<:VarName} )
350- syms = Tuple (unique (map (getsym, vns)))
351- metadatas = map (syms) do sym
352- subset (getfield (varinfo. metadata, sym), filter (== (sym) ∘ getsym, vns))
347+ @generated function subset (
348+ metadata:: NamedTuple{names} , vns:: AbstractVector{<:VarName{sym}}
349+ ) where {names,sym}
350+ return if (sym in names)
351+ # TODO (mhauru) Note that this could still generate an empty metadata object if none
352+ # of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for
353+ # emptiness would make this type unstable again.
354+ :((; $ sym= subset (metadata.$ sym, vns)))
355+ else
356+ :(NamedTuple {} ())
353357 end
354-
355- return VarInfo (
356- NamedTuple {syms} (metadatas), deepcopy (varinfo. logp), deepcopy (varinfo. num_produce)
357- )
358358end
359359
360360function subset (metadata:: Metadata , vns_given:: AbstractVector{VN} ) where {VN<: VarName }
361361 # TODO : Should we error if `vns` contains a variable that is not in `metadata`?
362- # For each `vn` in `vns`, get the variables subsumed by `vn`.
363- vns = mapreduce (vcat, vns_given; init= VN[]) do vn
364- filter (Base. Fix1 (subsumes, vn), metadata. vns)
365- end
362+ # Find all the vns in metadata that are subsumed by one of the given vns.
363+ vns = filter (vn -> any (subsumes (vn_given, vn) for vn_given in vns_given), metadata. vns)
366364 indices_for_vns = map (Base. Fix1 (getindex, metadata. idcs), vns)
367365 indices = if isempty (vns)
368366 Dict {VarName,Int} ()
0 commit comments