-
Notifications
You must be signed in to change notification settings - Fork 6
Convergence, bootstrap, p-values #316
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
e80389c
add converged method
Maximilian-Stefan-Ernst 1623119
refactor bootstrap and add bootstrap for any statistic
Maximilian-Stefan-Ernst 8f2eaf1
add CI and p values
Maximilian-Stefan-Ernst 6f1246c
add exports
Maximilian-Stefan-Ernst 20c8b63
add boostrap, CI and p-values tests
Maximilian-Stefan-Ernst 0eb04ef
Apply suggestions from formatter
Maximilian-Stefan-Ernst 5e5575b
fix bootstrap tests
Maximilian-Stefan-Ernst c27706d
fix bootstrap tests
Maximilian-Stefan-Ernst 24a7947
fix nobs var check in update_observed, keep previous args in update_o…
Maximilian-Stefan-Ernst aa62557
increase tolerance for bootstrap test
Maximilian-Stefan-Ernst fdca879
increase tolerance for bootstrap test
Maximilian-Stefan-Ernst ccf8c0d
remove bootstrap try-catch and update tests
Maximilian-Stefan-Ernst 1adb1fa
fix bootstrap
Maximilian-Stefan-Ernst 9478d55
fix bootstrap
Maximilian-Stefan-Ernst 2e37710
fix bootstrap
Maximilian-Stefan-Ernst 58cdb0f
fix mg weights
Maximilian-Stefan-Ernst 7f072f4
fix bootstrap tests
Maximilian-Stefan-Ernst cacb73c
fix bootstrap tests
Maximilian-Stefan-Ernst a1c76af
add WLS option to fix weight matrix for updaeting observed data
Maximilian-Stefan-Ernst 537cf16
narrower output types and remove locks for bootstrap + code formatting
Maximilian-Stefan-Ernst d00a65e
Apply suggestions from code review
Maximilian-Stefan-Ernst File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,39 +1,164 @@ | ||
| """ | ||
| bootstrap( | ||
| fitted::SemFit, | ||
| specification::SemSpecification; | ||
| statistic = solution, | ||
| n_boot = 3000, | ||
| data = nothing, | ||
| engine = :Optim, | ||
| parallel = false, | ||
| fit_kwargs = Dict(), | ||
| replace_kwargs = Dict()) | ||
|
|
||
| Return bootstrap samples for `statistic`. | ||
|
|
||
| # Arguments | ||
| - `fitted`: a fitted SEM. | ||
| - `specification`: a `ParameterTable` or `RAMMatrices` object passed to `replace_observed`. | ||
| - `statistic`: any function that can be called on a `SemFit` object. | ||
| The output will be returned as the bootstrap sample. | ||
| - `n_boot`: number of boostrap samples | ||
| - `data`: data to sample from. Only needed if different than the data from `sem_fit` | ||
| - `engine`: optimizer engine, passed to `fit`. | ||
| - `parallel`: if `true`, run bootstrap samples in parallel on all available threads. | ||
| The number of threads is controlled by the `JULIA_NUM_THREADS` environment variable or | ||
| the `--threads` flag when starting Julia. | ||
| - `fit_kwargs` : a `Dict` controlling model fitting for each bootstrap sample, | ||
| passed to `fit` | ||
| - `replace_kwargs`: a `Dict` passed to `replace_observed` | ||
|
|
||
| # Example | ||
| ```julia | ||
| # 1000 boostrap samples of the minimum, fitted with :Optim | ||
| bootstrap( | ||
| fitted; | ||
| statistic = StructuralEquationModels.minimum, | ||
| n_boot = 1000, | ||
| engine = :Optim, | ||
| ) | ||
| ``` | ||
| """ | ||
| function bootstrap( | ||
| fitted::SemFit, | ||
| specification::SemSpecification; | ||
| statistic = solution, | ||
| n_boot = 3000, | ||
| data = nothing, | ||
| engine = :Optim, | ||
| parallel = false, | ||
| fit_kwargs = Dict(), | ||
| replace_kwargs = Dict(), | ||
| ) | ||
| # access data and convert to matrix | ||
| data = prepare_data_bootstrap(data, fitted.model) | ||
| start = solution(fitted) | ||
| # pre-allocations | ||
| out = Vector{Any}(nothing, n_boot) | ||
| conv = fill(false, n_boot) | ||
| # fit to bootstrap samples | ||
| if !parallel | ||
| for i in 1:n_boot | ||
| sample_data = bootstrap_sample(data) | ||
| new_model = replace_observed( | ||
| fitted.model; | ||
| data = sample_data, | ||
| specification = specification, | ||
| replace_kwargs..., | ||
| ) | ||
| new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...) | ||
| sample = statistic(new_fit) | ||
| c = converged(new_fit) | ||
| out[i] = sample | ||
| conv[i] = c | ||
| end | ||
| else | ||
| n_threads = Threads.nthreads() | ||
| # Pre-create one independent model copy per thread via deepcopy. | ||
| model_pool = Channel(n_threads) | ||
| for _ in 1:n_threads | ||
| put!(model_pool, deepcopy(fitted.model)) | ||
| end | ||
| # fit models in parallel | ||
| lk = ReentrantLock() | ||
| Threads.@threads for i in 1:n_boot | ||
| thread_model = take!(model_pool) | ||
| sample_data = bootstrap_sample(data) | ||
| new_model = replace_observed( | ||
| thread_model; | ||
| data = sample_data, | ||
| specification = specification, | ||
| replace_kwargs..., | ||
| ) | ||
| new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...) | ||
| sample = statistic(new_fit) | ||
| c = converged(new_fit) | ||
| out[i] = sample | ||
| conv[i] = c | ||
| put!(model_pool, thread_model) | ||
| end | ||
| end | ||
| return Dict( | ||
| :samples => collect(a for a in out), | ||
| :n_boot => n_boot, | ||
| :n_converged => sum(conv), | ||
| :converged => conv, | ||
| ) | ||
| end | ||
|
|
||
| """ | ||
| se_bootstrap( | ||
| sem_fit::SemFit; | ||
| fitted::SemFit, | ||
| specification::SemSpecification; | ||
| n_boot = 3000, | ||
| data = nothing, | ||
| specification = nothing, | ||
| parallel = false, | ||
| kwargs...) | ||
| fit_kwargs = Dict(), | ||
| replace_kwargs = Dict()) | ||
|
|
||
| Return bootstrap standard errors. | ||
|
|
||
| # Arguments | ||
| - `fitted`: a fitted SEM. | ||
| - `specification`: a `ParameterTable` or `RAMMatrices` object passed to `replace_observed`. | ||
| - `n_boot`: number of boostrap samples | ||
| - `data`: data to sample from. Only needed if different than the data from `sem_fit` | ||
| - `specification`: a `ParameterTable` or `RAMMatrices` object passed down to `replace_observed`. | ||
| Necessary for FIML models. | ||
| - `engine`: optimizer engine, passed to `fit`. | ||
| - `parallel`: if `true`, run bootstrap samples in parallel on all available threads. | ||
| The number of threads is controlled by the `JULIA_NUM_THREADS` environment variable or | ||
| the `--threads` flag when starting Julia. | ||
| - `kwargs...`: passed down to `replace_observed` | ||
| - `fit_kwargs` : a `Dict` controlling model fitting for each bootstrap sample, | ||
| passed to `sem_fit` | ||
| - `replace_kwargs`: a `Dict` passed to `replace_observed` | ||
|
|
||
| # Example | ||
| ```julia | ||
| # 1000 boostrap samples, fitted with :NLopt | ||
| using NLopt | ||
|
|
||
| se_bootstrap( | ||
| fitted; | ||
| n_boot = 1000, | ||
| engine = :NLopt, | ||
| ) | ||
| ``` | ||
| """ | ||
| function se_bootstrap( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be possible to make |
||
| fitted::SemFit{Mi, So, St, Mo, O}; | ||
| fitted::SemFit, | ||
| specification::SemSpecification; | ||
| n_boot = 3000, | ||
| data = nothing, | ||
| specification = nothing, | ||
| engine = :Optim, | ||
| parallel = false, | ||
| kwargs..., | ||
| ) where {Mi, So, St, Mo, O} | ||
| fit_kwargs = Dict(), | ||
| replace_kwargs = Dict(), | ||
| ) | ||
| # access data and convert to matrix | ||
| data = prepare_data_bootstrap(data, fitted.model) | ||
| start = solution(fitted) | ||
| # pre-allocations | ||
| total_sum = zero(start) | ||
| total_squared_sum = zero(start) | ||
| n_failed = Ref(0) | ||
| n_conv = Ref(0) | ||
| # fit to bootstrap samples | ||
| if !parallel | ||
| for _ in 1:n_boot | ||
|
|
@@ -42,14 +167,15 @@ function se_bootstrap( | |
| fitted.model; | ||
| data = sample_data, | ||
| specification = specification, | ||
| kwargs..., | ||
| replace_kwargs..., | ||
| ) | ||
| try | ||
| sol = solution(fit(new_model; start_val = start)) | ||
| new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...) | ||
| sol = solution(new_fit) | ||
| conv = converged(new_fit) | ||
| if conv | ||
| n_conv[] += 1 | ||
| @. total_sum += sol | ||
| @. total_squared_sum += sol^2 | ||
| catch | ||
| n_failed[] += 1 | ||
| end | ||
| end | ||
| else | ||
|
|
@@ -63,37 +189,37 @@ function se_bootstrap( | |
| lk = ReentrantLock() | ||
| Threads.@threads for _ in 1:n_boot | ||
| thread_model = take!(model_pool) | ||
| try | ||
| sample_data = bootstrap_sample(data) | ||
| new_model = replace_observed( | ||
| thread_model; | ||
| data = sample_data, | ||
| specification = specification, | ||
| kwargs..., | ||
| ) | ||
| sol = solution(fit(new_model; start_val = start)) | ||
| sample_data = bootstrap_sample(data) | ||
| new_model = replace_observed( | ||
| thread_model; | ||
| data = sample_data, | ||
| specification = specification, | ||
| replace_kwargs..., | ||
| ) | ||
| new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...) | ||
| sol = solution(new_fit) | ||
| conv = converged(new_fit) | ||
| if conv | ||
| lock(lk) do | ||
| n_conv[] += 1 | ||
| @. total_sum += sol | ||
| @. total_squared_sum += sol^2 | ||
| end | ||
| catch | ||
| lock(lk) do | ||
| n_failed[] += 1 | ||
| end | ||
Maximilian-Stefan-Ernst marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| finally | ||
| put!(model_pool, thread_model) | ||
| end | ||
| put!(model_pool, thread_model) | ||
| end | ||
| end | ||
| # compute parameters | ||
| n_conv = n_boot - n_failed[] | ||
| n_conv = n_conv[] | ||
| sd = sqrt.(total_squared_sum / n_conv - (total_sum / n_conv) .^ 2) | ||
| if !iszero(n_failed[]) | ||
| @warn "During bootstrap sampling, "*string(n_failed[])*" models did not converge" | ||
| end | ||
| @info string(n_conv)*" models converged" | ||
| return sd | ||
| end | ||
|
|
||
| ############################################################################################ | ||
| ### Helper Functions | ||
| ############################################################################################ | ||
|
|
||
| function bootstrap_sample(data::Matrix) | ||
| nobs = size(data, 1) | ||
| index_new = rand(1:nobs, nobs) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| _doc_normal_CI = """ | ||
| (1) normal_CI(fitted, se; α = 0.05, name_lower = :ci_lower, name_upper = :ci_upper) | ||
|
|
||
| (2) normal_CI!(partable, fitted, se; α = 0.05, name_lower = :ci_lower, name_upper = :ci_upper) | ||
|
|
||
| Return normal-theory confidence intervals for all model parameters. | ||
| `normal_CI!` additionally writes the result into `partable`. | ||
|
|
||
| # Arguments | ||
| - `fitted`: a fitted SEM. | ||
| - `se`: standard errors for each parameter, e.g. from [`se_hessian`](@ref) or | ||
| [`se_bootstrap`](@ref). | ||
| - `partable`: a [`ParameterTable`](@ref) to write confidence intervals to. | ||
| - `α`: significance level. Defaults to `0.05` (95% intervals). | ||
| - `name_lower`: column name for the lower bound in `partable`. Defaults to `:ci_lower`. | ||
| - `name_upper`: column name for the upper bound in `partable`. Defaults to `:ci_upper`. | ||
|
|
||
| # Returns | ||
| - a `Dict` with keys `name_lower` and `name_upper`, each mapping to a vector of bounds | ||
| over all parameters. | ||
| """ | ||
|
|
||
| @doc "$(_doc_normal_CI)" | ||
| function normal_CI(fitted, se; α = 0.05, name_lower = :ci_lower, name_upper = :ci_upper) | ||
| qnt = quantile(Normal(0, 1), 1-α/2); | ||
| sol = solution(fitted) | ||
| return Dict(name_lower => sol - qnt*se, name_upper => sol + qnt*se) | ||
| end | ||
|
|
||
| @doc "$(_doc_normal_CI)" | ||
| function normal_CI!( | ||
| partable, | ||
| fitted, | ||
| se; | ||
| α = 0.05, | ||
| name_lower = :ci_lower, | ||
| name_upper = :ci_upper, | ||
| ) | ||
| cis = normal_CI(fitted, se; α, name_lower, name_upper) | ||
| update_partable!(partable, name_lower, fitted, cis[name_lower]) | ||
| update_partable!(partable, name_upper, fitted, cis[name_upper]) | ||
| return cis | ||
| end |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's better to use
NamedTuplehere.Also, while it might be fragile to use
Base.return_types()to infer the result ofstatistic()before it is called, here it should be in principle possible to useT = mapreduce(typeof, promote_type, out)andVector{T}(out)to help with the type inference and faster code generation downstream.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree in principle that the NamedTuple is better, but I will keep the Dict for now because it has a much nicer console output that is actually humanly readable. Nice idea to fix the type, curiously
collect(a for a in out)seems to be faster by quite a margin.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since
bootstrap()will be the generic method to calculate different statistics, it would be nice to make its result play nicely with the type inference in its callers.NamedTupleis just an ad hoc way to achieve it. If pretty printing is a concern, it could be its own typeSemBootstrapResult{T}, and SEM.jl may rely on e.g. PrettyPrinting.jl to quickly implementBase.show(io, bootstrap::SemBootstrapResult) = pprint(bootstrap)