diff --git a/docs/src/API/codegen.md b/docs/src/API/codegen.md index 4f31405174..0092e1cb8e 100644 --- a/docs/src/API/codegen.md +++ b/docs/src/API/codegen.md @@ -50,5 +50,5 @@ ModelingToolkit.calculate_A_b All code generation eventually calls `build_function_wrapper`. ```@docs -build_function_wrapper +ModelingToolkit.build_function_wrapper ``` diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index f823260e2a..200e0d3d56 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -1037,9 +1037,11 @@ function respecialize(sys::AbstractSystem, mapping; all = false) """ if iscall(k) - op = operation(k) + op = operation(k)::BasicSymbolic + @assert !iscall(op) + op = SymbolicUtils.Sym{SymbolicUtils.FnType{Tuple{Any}, T}}(nameof(op)) args = arguments(k) - new_p = SymbolicUtils.term(op, args...; type = T) + new_p = op(args...) else new_p = SymbolicUtils.Sym{T}(getname(k)) end diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 637ed674ae..f91348d0c9 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -763,8 +763,14 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru oldbuf.discrete, newbuf.discrete) @set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.( oldbuf.constant, newbuf.constant) - @set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.( - oldbuf.nonnumeric, newbuf.nonnumeric) + for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric) + for i in eachindex(oldv) + isassigned(newv, i) && continue + newv[i] = oldv[i] + end + end + @set! newbuf.nonnumeric = Tuple( + typeof(oldv)(newv) for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric)) if !ArrayInterface.ismutable(oldbuf) @set! newbuf.tunable = similar_type(oldbuf.tunable, eltype(newbuf.tunable))(newbuf.tunable) @set! newbuf.initials = similar_type(oldbuf.initials, eltype(newbuf.initials))(newbuf.initials) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 2a1208586b..2f5db01479 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -701,9 +701,10 @@ function. Note that the getter ONLY works for problem-like objects, since it generates an observed function. It does NOT work for solutions. """ -Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector) +Base.@nospecializeinfer function concrete_getu(indp, syms; eval_expression, eval_module) @nospecialize - obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false) + obsfn = build_explicit_observed_function( + indp, syms; wrap_delays = false, eval_expression, eval_module) return ObservedWrapper{is_time_dependent(indp)}(obsfn) end @@ -757,7 +758,8 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns - `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`. """ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; - initials = false, unwrap_initials = false, p_constructor = identity) + initials = false, unwrap_initials = false, p_constructor = identity, + eval_expression = false, eval_module = @__MODULE__) _p_constructor = p_constructor p_constructor = PConstructorApplicator(p_constructor) # if we call `getu` on this (and it were able to handle empty tuples) we get the @@ -773,7 +775,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac tunable_getter = if isempty(tunable_syms) Returns(SizedVector{0, Float64}()) else - p_constructor ∘ concrete_getu(srcsys, tunable_syms) + p_constructor ∘ concrete_getu(srcsys, tunable_syms; eval_expression, eval_module) end initials_getter = if initials && !isempty(syms[2]) initsyms = Vector{Any}(syms[2]) @@ -792,7 +794,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac end end end - p_constructor ∘ concrete_getu(srcsys, initsyms) + p_constructor ∘ concrete_getu(srcsys, initsyms; eval_expression, eval_module) else Returns(SizedVector{0, Float64}()) end @@ -810,7 +812,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac # tuple of `BlockedArray`s Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘ Base.Fix1(broadcast, p_constructor) ∘ - getu(srcsys, syms[3]) + concrete_getu(srcsys, syms[3]; eval_expression, eval_module) end const_getter = if syms[4] == () Returns(()) @@ -826,7 +828,8 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac end) # nonnumerics retain the assigned buffer type without narrowing Base.Fix1(broadcast, _p_constructor) ∘ - Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ getu(srcsys, syms[5]) + Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ + concrete_getu(srcsys, syms[5]; eval_expression, eval_module) end getters = ( tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter) @@ -853,14 +856,19 @@ Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of ` with values from `srcsys`. """ function ReconstructInitializeprob( - srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity) + srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity, + eval_expression = false, eval_module = @__MODULE__) @assert is_initializesystem(dstsys) - ugetter = u0_constructor ∘ getu(srcsys, unknowns(dstsys)) + ugetter = u0_constructor ∘ + concrete_getu(srcsys, unknowns(dstsys); eval_expression, eval_module) if is_split(dstsys) - pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor) + pgetter = get_mtkparameters_reconstructor( + srcsys, dstsys; p_constructor, eval_expression, eval_module) else syms = parameters(dstsys) - pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor + pgetter = let inner = concrete_getu(srcsys, syms; eval_expression, eval_module), + p_constructor = p_constructor + function _getter2(valp, initprob) p_constructor(inner(valp)) end @@ -924,18 +932,20 @@ Given `sys` and its corresponding initialization system `initsys`, return the `initializeprobpmap` function in `OverrideInitData` for the systems. """ function construct_initializeprobpmap( - sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity) + sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity, eval_expression, eval_module) @assert is_initializesystem(initsys) if is_split(sys) return let getter = get_mtkparameters_reconstructor( - initsys, sys; initials = true, unwrap_initials = true, p_constructor) + initsys, sys; initials = true, unwrap_initials = true, p_constructor, + eval_expression, eval_module) function initprobpmap_split(prob, initsol) getter(initsol, prob) end end else - return let getter = getu(initsys, parameters(sys; initial_parameters = true)), - p_constructor = p_constructor + return let getter = concrete_getu( + initsys, parameters(sys; initial_parameters = true); + eval_expression, eval_module), p_constructor = p_constructor function initprobpmap_nosplit(prob, initsol) return p_constructor(getter(initsol)) @@ -1039,14 +1049,14 @@ struct GetUpdatedU0{GG, GIU} get_initial_unknowns::GIU end -function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict) +function GetUpdatedU0(sys::AbstractSystem, initprob::SciMLBase.AbstractNonlinearProblem, op::AbstractDict) dvs = unknowns(sys) eqs = equations(sys) guessvars = trues(length(dvs)) for (i, var) in enumerate(dvs) guessvars[i] = !isequal(get(op, var, nothing), Initial(var)) end - get_guessvars = getu(initsys, dvs[guessvars]) + get_guessvars = getu(initprob, dvs[guessvars]) get_initial_unknowns = getu(sys, Initial.(dvs)) return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns) end @@ -1108,7 +1118,7 @@ function maybe_build_initialization_problem( guesses, missing_unknowns; implicit_dae = false, time_dependent_init = is_time_dependent(sys), u0_constructor = identity, p_constructor = identity, floatT = Float64, initialization_eqs = [], - use_scc = true, kwargs...) + use_scc = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) if t === nothing && is_time_dependent(sys) @@ -1117,7 +1127,7 @@ function maybe_build_initialization_problem( initializeprob = ModelingToolkit.InitializationProblem{iip}( sys, t, op; guesses, time_dependent_init, initialization_eqs, - use_scc, u0_constructor, p_constructor, kwargs...) + use_scc, u0_constructor, p_constructor, eval_expression, eval_module, kwargs...) if state_values(initializeprob) !== nothing _u0 = state_values(initializeprob) if ArrayInterface.ismutable(_u0) @@ -1145,7 +1155,7 @@ function maybe_build_initialization_problem( initializeprob = remake(initializeprob; p = initp) get_initial_unknowns = if time_dependent_init - GetUpdatedU0(sys, initializeprob.f.sys, op) + GetUpdatedU0(sys, initializeprob, op) else nothing end @@ -1153,7 +1163,8 @@ function maybe_build_initialization_problem( copy(op), copy(guesses), Vector{Equation}(initialization_eqs), use_scc, time_dependent_init, ReconstructInitializeprob( - sys, initializeprob.f.sys; u0_constructor, p_constructor), + sys, initializeprob.f.sys; u0_constructor, + p_constructor, eval_expression, eval_module), get_initial_unknowns, SetInitialUnknowns(sys)) if time_dependent_init @@ -1172,10 +1183,9 @@ function maybe_build_initialization_problem( initializeprobpmap = nothing else initializeprobpmap = construct_initializeprobpmap( - sys, initializeprob.f.sys; p_constructor) + sys, initializeprob.f.sys; p_constructor, eval_expression, eval_module) end - reqd_syms = parameter_symbols(initializeprob) # we still want the `initialization_data` because it helps with `remake` if initializeprobmap === nothing && initializeprobpmap === nothing update_initializeprob! = nothing @@ -1186,7 +1196,9 @@ function maybe_build_initialization_problem( filter!(punknowns) do p is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing end - pvals = getu(initializeprob, punknowns)(initializeprob) + # See comment below for why `getu` is not used here. + _pgetter = build_explicit_observed_function(initializeprob.f.sys, punknowns) + pvals = _pgetter(state_values(initializeprob), parameter_values(initializeprob)) for (p, pval) in zip(punknowns, pvals) p = unwrap(p) op[p] = pval @@ -1198,7 +1210,13 @@ function maybe_build_initialization_problem( end if time_dependent_init - uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob) + # We can't use `getu` here because that goes to `SII.observed`, which goes to + # `ObservedFunctionCache` which uses `eval_expression` and `eval_module`. If + # `eval_expression == true`, this then runs into world-age issues. Building an + # RGF here is fine since it is always discarded. We can't use `eval_module` for + # the RGF since the user may not have run RGF's init. + _ugetter = build_explicit_observed_function(initializeprob.f.sys, collect(missing_unknowns)) + uvals = _ugetter(state_values(initializeprob), parameter_values(initializeprob)) for (v, val) in zip(missing_unknowns, uvals) op[v] = val end @@ -1461,7 +1479,7 @@ function process_SciMLProblem( if is_time_dependent(sys) && t0 === nothing t0 = zero(floatT) end - initialization_data = SciMLBase.remake_initialization_data( + initialization_data = @invokelatest SciMLBase.remake_initialization_data( sys, kwargs, u0, t0, p, u0, p) kwargs = merge(kwargs, (; initialization_data)) end @@ -1773,7 +1791,8 @@ Construct SciMLProblem `T` with positional arguments `args` and keywords `kwargs """ function maybe_codegen_scimlproblem(::Type{Val{false}}, T, args::NamedTuple; kwargs...) # Call `remake` so it runs initialization if it is trivial - remake(T(args...; kwargs...)) + # Use `@invokelatest` to avoid world-age issues with `eval_expression = true` + @invokelatest remake(T(args...; kwargs...)) end """ diff --git a/test/basic_transformations.jl b/test/basic_transformations.jl index dc9d71f300..19899aa50f 100644 --- a/test/basic_transformations.jl +++ b/test/basic_transformations.jl @@ -340,11 +340,12 @@ foofn(x) = 4 @testset "`respecialize`" begin @parameters p::AbstractFoo p2(t)::AbstractFoo = p q[1:2]::AbstractFoo r - rp, - rp2 = let - only(@parameters p::Bar), - SymbolicUtils.term(operation(p2), arguments(p2)...; type = Baz) - end + rp = only(let p = nothing + @parameters p::Bar + end) + rp2 = only(let p2 = nothing + @parameters p2(t)::Baz + end) @variables x(t) = 1.0 @named sys1 = System([D(x) ~ foofn(p) + foofn(p2) + x], t, [x], [p, p2, q, r]) diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 28ab3759ef..08beb7ed53 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -357,7 +357,7 @@ ps = MTKParameters( (BlockedArray([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [3, 3]), BlockedArray(falses(1), [1, 0])), (), (), ()) -@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}} +@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, BitVector} tsidx1 = 1 tsidx2 = 2 @test length(ps.discrete[1][Block(tsidx1)]) == 3 @@ -368,3 +368,14 @@ with_updated_parameter_timeseries_values( sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false]))) @test ps.discrete[1][Block(tsidx1)] == [10.0, 11.0, 12.0] @test ps.discrete[2][Block(tsidx1)][] == false + +@testset "Avoid specialization of nonnumeric parameters on `remake_buffer`" begin + @variables x(t) + @parameters p::Any + @named sys = System(D(x) ~ x, t, [x], [p]) + sys = complete(sys) + ps = MTKParameters(sys, [p => 1.0]) + @test ps.nonnumeric isa Tuple{Vector{Any}} + ps2 = remake_buffer(sys, ps, [p], [:a]) + @test ps2.nonnumeric isa Tuple{Vector{Any}} +end diff --git a/test/precompile_test.jl b/test/precompile_test.jl index 38051d9d49..6fdc1b5cbf 100644 --- a/test/precompile_test.jl +++ b/test/precompile_test.jl @@ -1,5 +1,6 @@ using Test using ModelingToolkit +using OrdinaryDiffEqDefault using Distributed @@ -38,3 +39,5 @@ ODEPrecompileTest.f_eval_bad(u, p, 0.1) @test parentmodule(typeof(ODEPrecompileTest.f_eval_good.f.f_oop)) == ODEPrecompileTest @test ODEPrecompileTest.f_eval_good(u, p, 0.1) == [4, 0, -16] + +@test_nowarn solve(ODEPrecompileTest.prob_eval) diff --git a/test/precompile_test/ODEPrecompileTest.jl b/test/precompile_test/ODEPrecompileTest.jl index 2111f7ba64..c77c897655 100644 --- a/test/precompile_test/ODEPrecompileTest.jl +++ b/test/precompile_test/ODEPrecompileTest.jl @@ -36,4 +36,24 @@ const f_eval_bad = system(; eval_expression = true, eval_module = @__MODULE__) # Change the module the eval'd function is eval'd into to be the containing module, # which should make it be in the package image const f_eval_good = system(; eval_expression = true, eval_module = @__MODULE__) + +function problem(; kwargs...) + # Define some variables + @independent_variables t + @parameters σ ρ β + @variables x(t) y(t) z(t) + D = Differential(t) + + # Define a differential equation + eqs = [D(x) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + + @named de = System(eqs, t) + de = complete(de) + return ODEProblem(de, [x => 1, y => 0, z => 0, σ => 10, ρ => 28, β => 8/3], (0.0, 5.0); kwargs...) +end + +const prob_eval = problem(; eval_expression = true, eval_module = @__MODULE__) + end