Remove ThreadSafeVarInfo
#1023
Conversation
* Implement InitContext
* Fix loading order of modules; move `prefix(::Model)` to model.jl
* Add tests for InitContext behaviour
* Inline `rand(::Distributions.Uniform)`. Note that, apart from being simpler code, Distributions.Uniform also doesn't allow the lower and upper bounds to be exactly equal (but we might like to keep that option open in DynamicPPL, e.g. if the user wants to initialise all values to the same value in linked space).
* Document
* Add a test to check that `init!!` doesn't change linking
* Fix `push!` for VarNamedVector. This should have been changed in #940, but slipped through as the file wasn't listed as one of the changed files.
* Add some line breaks (Co-authored-by: Markus Hauru <[email protected]>)
* Add the option of no fallback for ParamsInit
* Improve docstrings
* Fix typo
* `p.default` -> `p.fallback`
* Rename `{Prior,Uniform,Params}Init` -> `InitFrom{Prior,Uniform,Params}`

Co-authored-by: Markus Hauru <[email protected]>
Codecov Report

❌ Patch coverage is

```
@@           Coverage Diff            @@
##             main    #1023    +/-  ##
=========================================
- Coverage   82.34%   80.91%   -1.44%
=========================================
  Files          38       39       +1
  Lines        3949     3810     -139
=========================================
- Hits         3252     3083     -169
- Misses        697      727      +30
```
Benchmark Report for Commit 700ed07

Pull Request Test Coverage Report for Build 17024844925

💛 - Coveralls
DynamicPPL.jl documentation for PR #1023 is available at:
As a bonus, this PR completely fixes all Enzyme issues arising from DPPL 0.37. #947
* Use `varname_leaves` from AbstractPPL instead
* Add changelog entry
* Fix import
…!`, `predict`, `returned`, and `initialize_values` (#984)

* Replace `evaluate_and_sample!!` -> `init!!`
* Use `ParamsInit` for `predict`; remove `setval_and_resample!` and friends
* Use `init!!` for initialisation
* Paper over the `Sampling->Init` context stack (pending removal of SamplingContext)
* Remove SamplingContext from JETExt to avoid triggering `Sampling->Init` pathway
* Remove `predict` on vector of VarInfo
* Fix some tests
* Remove duplicated test
* Simplify context testing
* Rename FooInit -> InitFromFoo
* Fix JETExt
* Fix JETExt properly
* Fix tests
* Improve comments
* Remove duplicated tests
* Docstring improvements (Co-authored-by: Markus Hauru <[email protected]>)
* Concretise `chain_sample_to_varname_dict` using chain value type
* Clarify testset name
* Re-add comment that shouldn't have vanished
* Fix stale Requires dep
* Fix default_varinfo/initialisation for odd models
* Add comment to src/sampler.jl (Co-authored-by: Markus Hauru <[email protected]>)

Co-authored-by: Markus Hauru <[email protected]>
Unfortunately I don't know how to deal with conditioned/fixed variables without a huge amount of faff and macro code duplication 😮💨
This requires a bit more discussion before we make a commitment -- not entirely sure we should introduce a new macro. |
yeah, I remember the discussion we had a few meetings ago |
holy merge conflicts |
Summary
This PR removes `ThreadSafeVarInfo`. In its place, a `@pobserve` macro is added to enable multithreaded tilde-observe statements, according to the plan outlined in #924 (comment). Broadly speaking, the following is converted into (modulo variable names):
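The original code blocks did not survive the page capture, so here is a hedged sketch of the transformation described above. The variable names (`x`, `m`), the use of `acclogp!!`, and the exact expansion are assumptions based on the surrounding description, not the PR's actual code:

```julia
# Before: observe statements inside a threaded loop, which required
# ThreadSafeVarInfo to make log-probability accumulation thread-safe:
Threads.@threads for i in eachindex(x)
    x[i] ~ Normal(m)
end

# After: @pobserve expands to roughly the following. Each spawned task
# only computes its own log-likelihood contribution; no varinfo is
# touched inside the tasks, and the contributions are summed at the end:
tasks = map(eachindex(x)) do i
    Threads.@spawn logpdf(Normal(m), x[i])
end
acclogp!!(__varinfo__, sum(fetch.(tasks)))
```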
No actual varinfo manipulation happens inside the `Threads.@spawn`: instead, the log-likelihood contributions are calculated in each thread, then summed after the individual threads have finished their tasks. Because of this, there is no need to maintain one log-likelihood accumulator per thread, and consequently no need for `ThreadSafeVarInfo`.

Closes #429.
Closes #924.
Closes #947.
Why?
Code simplification in DynamicPPL, and reducing the number of
AbstractVarInfosubtypes, is obviously a big argument.But in fact, that's not my main motivation. I'm mostly motivated to do this because TSVI in general is IMO not good code: it works, but in many ways it's a hack.
Threads.@threads for i in x ... end, and then internally we useThreads.threadid()to index into a vector of accumulators. This is now regarded as "incorrect parallel code that contains the possibility of race conditions which can give wrong results". See https://julialang.org/blog/2023/07/PSA-dont-use-threadid/ and https://discourse.julialang.org/t/behavior-of-threads-threads-for-loop/76042.Threads.nthreads() * 2which is a hacky heuristic. The correct solution would beThreads.maxthreadid(), but Mooncake couldn't differentiate through that.threadid,nthreadsand evenmaxthreadid[is] perilous. Any code that relies on a specificthreadidstaying constant, or on a constant number of threads during execution, is bound to be incorrect.".if Threads.nthreads() > 1, which cannot be determined at compile time. This means that:evaluate!!must be together type stable.evaluate!!. That's just silly IMO.cacheForReverseEnzymeAD/Enzyme.jl#2518Does this actually work?
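To make the first point above concrete, here is a self-contained, runnable illustration (my own, not from the PR) of the `threadid`-indexed-buffer pattern that the Julia blog post deprecates, next to the task-based alternative that `@pobserve` follows:

```julia
# Buggy pattern (structurally what TSVI does): index a shared buffer by
# Threads.threadid(). If a task yields and migrates between threads, or
# if interactive threads push threadid() past nthreads(), this races or
# errors.
function sum_buggy(xs)
    acc = zeros(Threads.nthreads())
    Threads.@threads for i in eachindex(xs)
        acc[Threads.threadid()] += xs[i]  # racy under task migration
    end
    return sum(acc)
end

# Safe pattern: each spawned task owns its partial result; partials are
# combined only after all tasks finish, so no shared mutable indexing.
function sum_safe(xs)
    chunks = Iterators.partition(xs, cld(length(xs), Threads.nthreads()))
    tasks = map(chunks) do chunk
        Threads.@spawn sum(chunk)
    end
    return sum(fetch.(tasks))
end
```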
This PR has no tests yet, but I ran this locally and the log-likelihood gets accumulated correctly:
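The snippet from the local run is missing from the capture; the following is a hypothetical reconstruction of that kind of check. The model definition, `getloglikelihood`, and the comparison are my assumptions about the API on this branch, not the PR's actual test:

```julia
using DynamicPPL, Distributions

# Hypothetical model: observe each element of x inside @pobserve.
@model function demo(x)
    m ~ Normal()
    @pobserve for i in eachindex(x)
        x[i] ~ Normal(m)
    end
end

x = randn(100)
vi = VarInfo(demo(x))
# Sanity check: the accumulated log-likelihood should match a serial sum,
# e.g. getloglikelihood(vi) ≈ sum(logpdf.(Normal(vi[@varname(m)]), x))
```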
I can also confirm that the parallelisation is correctly occurring with this model:
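The timing model itself is also missing from the capture. A sketch of the idea, consistent with the timings quoted below (two observations, each sleeping one second, so wall time reveals whether they run concurrently); the `SlowNormal` distribution and model are my invention:

```julia
using DynamicPPL, Distributions

# A distribution whose logpdf takes ~1 second to evaluate.
struct SlowNormal <: ContinuousUnivariateDistribution end
Distributions.logpdf(::SlowNormal, x::Real) = (sleep(1); logpdf(Normal(), x))

@model function slow(x)
    @pobserve for i in 1:2
        x[i] ~ SlowNormal()
    end
end

@time slow(randn(2))()  # ~2 s with 1 thread, ~1 s with 2 threads
```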
If you run this with 1 thread it takes 2 seconds, and if you run it with 2 threads it takes 1 second.
It also works correctly with
MCMCThreads()(with some minor adjustments to Turing.jl for compatibility with this branch). NOTE: Sampling with@pobserveis now fully reproducible, whereasThreads.@threadswas not reproducible even when seeded.What now?
There are a handful of limitations to this PR. These are the ones I can think of right now:
* It will crash if the VarInfo used for evaluation does not have a likelihood accumulator.
* Prior accumulation (`DynamicPPL.acclogprior!!()`) is not handled inside the threaded block.
* It doesn't work with `.~` (or maybe it does, I haven't tested, but my guess is that it will bug out).
* If `x` is not a model argument or conditioned upon, this will yield wrong results for the typical `x = Vector{Float64}(undef, 2); @pobserve for i in eachindex(x); x[i] ~ dist; end`, as it will naively accumulate `logpdf(dist, x[i])` even though this should be an assumption rather than an observation.
* There is no way to extract other computations from the threads.
* PG's machinery doesn't support `Threads.@spawn`, so PG will throw an error with `@pobserve`.
* The name `@pobserve` is a bit too unambitious. If one day we make it work with assume, then it will have to be renamed, i.e. a breaking change.

I believe that all of these are either unimportant or can be worked around with some additional macro leg-work:
* Not important: nobody is running around evaluating their models with no likelihood accumulator. Not even Turing does this. Also easy enough for us to guard against by wrapping the entire thing in an if/else.
* `acclogprior!!` can be called outside the threaded bit.
* This can be fixed easily by changing the macro to return a tuple of `(retval, loglike)` rather than just `loglike`.
* PG could fall back to `Threads.@threads`.

So for now this should mostly be considered a proof of principle rather than a complete PR.
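For the `(retval, loglike)` point above, a hedged sketch of what the macro's expansion could look like; `expensive_computation`, `dist`, and `x` are placeholder names I introduced, not anything from the PR:

```julia
# Each task returns both the block's return value and its log-likelihood
# contribution, so other computations can be extracted from the threads.
tasks = map(eachindex(x)) do i
    Threads.@spawn begin
        retval  = expensive_computation(x[i])
        loglike = logpdf(dist, x[i])
        (retval, loglike)   # tuple instead of just loglike
    end
end
results = fetch.(tasks)
retvals = first.(results)            # handed back to the caller
total_loglike = sum(last, results)   # accumulated into the varinfo
```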
Finally, note that this PR already removes > 550 lines of code, but this is not a full picture of the simplification afforded. For example, I did not remove the `split`, `combine`, and `convert_eltype` methods on accumulators, which I believe can either be removed or simplified once TSVI is removed.