Skip to content

Commit 537cf16

Browse files
narrower output types and remove locks for bootstrap + code formatting
1 parent a1c76af commit 537cf16

7 files changed

Lines changed: 38 additions & 28 deletions

File tree

src/frontend/fit/standard_errors/bootstrap.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ function bootstrap(
5353
data = prepare_data_bootstrap(data, fitted.model)
5454
start = solution(fitted)
5555
# pre-allocations
56-
out = []
57-
conv = []
56+
out = Vector{Any}(nothing, n_boot)
57+
conv = fill(false, n_boot)
5858
# fit to bootstrap samples
5959
if !parallel
60-
for _ in 1:n_boot
60+
for i in 1:n_boot
6161
sample_data = bootstrap_sample(data)
6262
new_model = replace_observed(
6363
fitted.model;
@@ -68,8 +68,8 @@ function bootstrap(
6868
new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...)
6969
sample = statistic(new_fit)
7070
c = converged(new_fit)
71-
push!(out, sample)
72-
push!(conv, c)
71+
out[i] = sample
72+
conv[i] = c
7373
end
7474
else
7575
n_threads = Threads.nthreads()
@@ -80,7 +80,7 @@ function bootstrap(
8080
end
8181
# fit models in parallel
8282
lk = ReentrantLock()
83-
Threads.@threads for _ in 1:n_boot
83+
Threads.@threads for i in 1:n_boot
8484
thread_model = take!(model_pool)
8585
sample_data = bootstrap_sample(data)
8686
new_model = replace_observed(
@@ -92,17 +92,15 @@ function bootstrap(
9292
new_fit = fit(new_model; start_val = start, engine = engine, fit_kwargs...)
9393
sample = statistic(new_fit)
9494
c = converged(new_fit)
95-
lock(lk) do
96-
push!(out, sample)
97-
push!(conv, c)
98-
end
95+
out[i] = sample
96+
conv[i] = c
9997
put!(model_pool, thread_model)
10098
end
10199
end
102100
return Dict(
103-
:samples => out,
101+
:samples => collect(a for a in out),
104102
:n_boot => n_boot,
105-
:n_converged => isempty(conv) ? 0 : sum(conv),
103+
:n_converged => sum(conv),
106104
:converged => conv,
107105
)
108106
end

src/implied/RAM/generic.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ function update_observed(implied::RAM, observed::SemObserved; kwargs...)
203203
observed = observed,
204204
gradient_required = !isnothing(implied.∇A),
205205
meanstructure = MeanStruct(implied) == HasMeanStruct,
206-
kwargs...)
206+
kwargs...,
207+
)
207208
end
208209
end

src/implied/RAM/symbolic.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,14 @@ function update_observed(implied::RAMSymbolic, observed::SemObserved; kwargs...)
214214
return implied
215215
else
216216
return RAMSymbolic(;
217-
observed = observed,
218-
vech = implied.Σ isa Vector,
219-
gradient = !isnothing(implied.∇Σ),
220-
hessian = !isnothing(implied.∇²Σ),
221-
meanstructure = MeanStruct(implied) == HasMeanStruct,
222-
approximate_hessian = isnothing(implied.∇²Σ),
223-
kwargs...)
217+
observed = observed,
218+
vech = implied.Σ isa Vector,
219+
gradient = !isnothing(implied.∇Σ),
220+
hessian = !isnothing(implied.∇²Σ),
221+
meanstructure = MeanStruct(implied) == HasMeanStruct,
222+
approximate_hessian = isnothing(implied.∇²Σ),
223+
kwargs...,
224+
)
224225
end
225226
end
226227

src/loss/ML/ML.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,9 @@ function update_observed(lossfun::SemML, observed::SemObserved; kwargs...)
238238
return lossfun
239239
else
240240
return SemML(;
241-
observed = observed,
242-
approximate_hessian = HessianEval(lossfun) == ApproxHessian,
243-
kwargs...)
241+
observed = observed,
242+
approximate_hessian = HessianEval(lossfun) == ApproxHessian,
243+
kwargs...,
244+
)
244245
end
245246
end

src/loss/WLS/WLS.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,16 @@ function update_observed(lossfun::SemWLS, observed::SemObserved; recompute_V = t
178178
return SemWLS(;
179179
observed = observed,
180180
meanstructure = MeanStruct(kwargs[:implied]) == HasMeanStruct,
181-
kwargs...)
181+
kwargs...,
182+
)
182183
else
183184
return SemWLS(;
184185
observed = observed,
185186
wls_weight_matrix = lossfun.V,
186187
wls_weight_matrix_mean = lossfun.V_μ,
187188
meanstructure = MeanStruct(kwargs[:implied]) == HasMeanStruct,
188-
kwargs...)
189+
kwargs...,
190+
)
189191

190192
end
191193
end

test/examples/helper.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ function test_bootstrap(
144144
rtol_hessian = 0.2,
145145
compare_bs = true,
146146
rtol_bs = 0.1,
147-
n_boot = 500)
147+
n_boot = 500,
148+
)
148149
@testset rng = Random.seed!(32432) "bootstrap" begin
149150
se_bs = @suppress se_bootstrap(model_fit, spec; n_boot = n_boot)
150151
# hessian and bootstrap se are close
@@ -157,7 +158,8 @@ function test_bootstrap(
157158
if compare_bs
158159
bs_samples = bootstrap(model_fit, spec; n_boot = n_boot)
159160
@test bs_samples[:n_converged] >= 0.95*n_boot
160-
bs_samples = cat(bs_samples[:samples][BitVector(bs_samples[:converged])]..., dims = 2)
161+
bs_samples =
162+
cat(bs_samples[:samples][BitVector(bs_samples[:converged])]..., dims = 2)
161163
se_bs_2 = sqrt.(var(bs_samples, corrected = false, dims = 2))
162164
#println(maximum(abs.(se_bs_2 - se_bs)))
163165
@test isapprox(se_bs_2, se_bs, rtol = rtol_bs)

test/examples/multigroup/build_models.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,12 @@ model_ls_g2 = Sem(
253253
loss = SemWLS,
254254
)
255255

256-
model_ls_multigroup = SemEnsemble(model_ls_g1, model_ls_g2; groups = [:Pasteur, :Grant_White], optimizer = semoptimizer)
256+
model_ls_multigroup = SemEnsemble(
257+
model_ls_g1,
258+
model_ls_g2;
259+
groups = [:Pasteur, :Grant_White],
260+
optimizer = semoptimizer,
261+
)
257262

258263
@testset "ls_gradients_multigroup" begin
259264
test_gradient(model_ls_multigroup, start_test; atol = 1e-9)

0 commit comments

Comments
 (0)