From 8eb2d68ccd213ee599656186a6b0ae689f864a02 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 23:52:58 +0530 Subject: [PATCH 001/157] fix: fix `ODEProblem` construction during precompilation with `eval_expression = true` --- src/systems/problem_utils.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 2a1208586b..445e771380 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1175,7 +1175,6 @@ function maybe_build_initialization_problem( sys, initializeprob.f.sys; p_constructor) end - reqd_syms = parameter_symbols(initializeprob) # we still want the `initialization_data` because it helps with `remake` if initializeprobmap === nothing && initializeprobpmap === nothing update_initializeprob! = nothing @@ -1186,7 +1185,9 @@ function maybe_build_initialization_problem( filter!(punknowns) do p is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing end - pvals = getu(initializeprob, punknowns)(initializeprob) + # See comment below for why `getu` is not used here. + _pgetter = build_explicit_observed_function(initializeprob.f.sys, punknowns) + pvals = _pgetter(state_values(initializeprob), parameter_values(initializeprob)) for (p, pval) in zip(punknowns, pvals) p = unwrap(p) op[p] = pval @@ -1198,7 +1199,13 @@ function maybe_build_initialization_problem( end if time_dependent_init - uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob) + # We can't use `getu` here because that goes to `SII.observed`, which goes to + # `ObservedFunctionCache` which uses `eval_expression` and `eval_module`. If + # `eval_expression == true`, this then runs into world-age issues. Building an + # RGF here is fine since it is always discarded. We can't use `eval_module` for + # the RGF since the user may not have run RGF's init. + _ugetter = build_explicit_observed_function(initializeprob.f.sys, collect(missing_unknowns)) + uvals = _ugetter(state_values(initializeprob), parameter_values(initializeprob)) for (v, val) in zip(missing_unknowns, uvals) op[v] = val end From fba3cccb5ae51cdc8119edcd4ee0ae556ecd5a4d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 8 Oct 2025 14:11:51 +0530 Subject: [PATCH 002/157] fix: fix overspecialization of nonnumeric buffers in `remake_buffer` --- src/systems/parameter_buffer.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 637ed674ae..f91348d0c9 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -763,8 +763,14 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru oldbuf.discrete, newbuf.discrete) @set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.( oldbuf.constant, newbuf.constant) - @set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.( - oldbuf.nonnumeric, newbuf.nonnumeric) + for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric) + for i in eachindex(oldv) + isassigned(newv, i) && continue + newv[i] = oldv[i] + end + end + @set! newbuf.nonnumeric = Tuple( + typeof(oldv)(newv) for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric)) if !ArrayInterface.ismutable(oldbuf) @set! newbuf.tunable = similar_type(oldbuf.tunable, eltype(newbuf.tunable))(newbuf.tunable) @set! newbuf.initials = similar_type(oldbuf.initials, eltype(newbuf.initials))(newbuf.initials) From 638a7a91a5c8f52fec4205356245307ca7f1d3d3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 8 Oct 2025 14:12:08 +0530 Subject: [PATCH 003/157] fix: fix `eval_expression = true` construction of problems --- src/systems/problem_utils.jl | 57 +++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 445e771380..d1c85e437f 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -701,9 +701,10 @@ function. Note that the getter ONLY works for problem-like objects, since it generates an observed function. It does NOT work for solutions. """ -Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector) +Base.@nospecializeinfer function concrete_getu(indp, syms; eval_expression, eval_module) @nospecialize - obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false) + obsfn = build_explicit_observed_function( + indp, syms; wrap_delays = false, eval_expression, eval_module) return ObservedWrapper{is_time_dependent(indp)}(obsfn) end @@ -757,7 +758,8 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns - `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`. """ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; - initials = false, unwrap_initials = false, p_constructor = identity) + initials = false, unwrap_initials = false, p_constructor = identity, + eval_expression = false, eval_module = @__MODULE__) _p_constructor = p_constructor p_constructor = PConstructorApplicator(p_constructor) # if we call `getu` on this (and it were able to handle empty tuples) we get the @@ -773,7 +775,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac tunable_getter = if isempty(tunable_syms) Returns(SizedVector{0, Float64}()) else - p_constructor ∘ concrete_getu(srcsys, tunable_syms) + p_constructor ∘ concrete_getu(srcsys, tunable_syms; eval_expression, eval_module) end initials_getter = if initials && !isempty(syms[2]) initsyms = Vector{Any}(syms[2]) @@ -792,7 +794,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac end end end - p_constructor ∘ concrete_getu(srcsys, initsyms) + p_constructor ∘ concrete_getu(srcsys, initsyms; eval_expression, eval_module) else Returns(SizedVector{0, Float64}()) end @@ -810,7 +812,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac # tuple of `BlockedArray`s Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘ Base.Fix1(broadcast, p_constructor) ∘ - getu(srcsys, syms[3]) + concrete_getu(srcsys, syms[3]; eval_expression, eval_module) end const_getter = if syms[4] == () Returns(()) @@ -826,7 +828,8 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac end) # nonnumerics retain the assigned buffer type without narrowing Base.Fix1(broadcast, _p_constructor) ∘ - Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ getu(srcsys, syms[5]) + Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ + concrete_getu(srcsys, syms[5]; eval_expression, eval_module) end getters = ( tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter) @@ -853,14 +856,19 @@ Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of ` with values from `srcsys`. """ function ReconstructInitializeprob( - srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity) + srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity, + eval_expression = false, eval_module = @__MODULE__) @assert is_initializesystem(dstsys) - ugetter = u0_constructor ∘ getu(srcsys, unknowns(dstsys)) + ugetter = u0_constructor ∘ + concrete_getu(srcsys, unknowns(dstsys); eval_expression, eval_module) if is_split(dstsys) - pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor) + pgetter = get_mtkparameters_reconstructor( + srcsys, dstsys; p_constructor, eval_expression, eval_module) else syms = parameters(dstsys) - pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor + pgetter = let inner = concrete_getu(srcsys, syms; eval_expression, eval_module), + p_constructor = p_constructor + function _getter2(valp, initprob) p_constructor(inner(valp)) end @@ -924,18 +932,20 @@ Given `sys` and its corresponding initialization system `initsys`, return the `initializeprobpmap` function in `OverrideInitData` for the systems. """ function construct_initializeprobpmap( - sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity) + sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity, eval_expression, eval_module) @assert is_initializesystem(initsys) if is_split(sys) return let getter = get_mtkparameters_reconstructor( - initsys, sys; initials = true, unwrap_initials = true, p_constructor) + initsys, sys; initials = true, unwrap_initials = true, p_constructor, + eval_expression, eval_module) function initprobpmap_split(prob, initsol) getter(initsol, prob) end end else - return let getter = getu(initsys, parameters(sys; initial_parameters = true)), - p_constructor = p_constructor + return let getter = concrete_getu( + initsys, parameters(sys; initial_parameters = true); + eval_expression, eval_module), p_constructor = p_constructor function initprobpmap_nosplit(prob, initsol) return p_constructor(getter(initsol)) @@ -1039,14 +1049,14 @@ struct GetUpdatedU0{GG, GIU} get_initial_unknowns::GIU end -function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict) +function GetUpdatedU0(sys::AbstractSystem, initprob::SciMLBase.AbstractNonlinearProblem, op::AbstractDict) dvs = unknowns(sys) eqs = equations(sys) guessvars = trues(length(dvs)) for (i, var) in enumerate(dvs) guessvars[i] = !isequal(get(op, var, nothing), Initial(var)) end - get_guessvars = getu(initsys, dvs[guessvars]) + get_guessvars = getu(initprob, dvs[guessvars]) get_initial_unknowns = getu(sys, Initial.(dvs)) return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns) end @@ -1108,7 +1118,7 @@ function maybe_build_initialization_problem( guesses, missing_unknowns; implicit_dae = false, time_dependent_init = is_time_dependent(sys), u0_constructor = identity, p_constructor = identity, floatT = Float64, initialization_eqs = [], - use_scc = true, kwargs...) + use_scc = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) if t === nothing && is_time_dependent(sys) @@ -1117,7 +1127,7 @@ function maybe_build_initialization_problem( initializeprob = ModelingToolkit.InitializationProblem{iip}( sys, t, op; guesses, time_dependent_init, initialization_eqs, - use_scc, u0_constructor, p_constructor, kwargs...) + use_scc, u0_constructor, p_constructor, eval_expression, eval_module, kwargs...) if state_values(initializeprob) !== nothing _u0 = state_values(initializeprob) if ArrayInterface.ismutable(_u0) @@ -1145,7 +1155,7 @@ function maybe_build_initialization_problem( initializeprob = remake(initializeprob; p = initp) get_initial_unknowns = if time_dependent_init - GetUpdatedU0(sys, initializeprob.f.sys, op) + GetUpdatedU0(sys, initializeprob, op) else nothing end @@ -1153,7 +1163,8 @@ function maybe_build_initialization_problem( copy(op), copy(guesses), Vector{Equation}(initialization_eqs), use_scc, time_dependent_init, ReconstructInitializeprob( - sys, initializeprob.f.sys; u0_constructor, p_constructor), + sys, initializeprob.f.sys; u0_constructor, + p_constructor, eval_expression, eval_module), get_initial_unknowns, SetInitialUnknowns(sys)) if time_dependent_init @@ -1172,7 +1183,7 @@ function maybe_build_initialization_problem( initializeprobpmap = nothing else initializeprobpmap = construct_initializeprobpmap( - sys, initializeprob.f.sys; p_constructor) + sys, initializeprob.f.sys; p_constructor, eval_expression, eval_module) end # we still want the `initialization_data` because it helps with `remake` @@ -1468,7 +1479,7 @@ function process_SciMLProblem( if is_time_dependent(sys) && t0 === nothing t0 = zero(floatT) end - initialization_data = SciMLBase.remake_initialization_data( + initialization_data = @invokelatest SciMLBase.remake_initialization_data( sys, kwargs, u0, t0, p, u0, p) kwargs = merge(kwargs, (; initialization_data)) end From 83db1e5af773407399354a68159758156e3515a7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 9 Oct 2025 13:20:35 +0530 Subject: [PATCH 004/157] test: test no specialization of nonnumeric buffers --- test/mtkparameters.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 28ab3759ef..08beb7ed53 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -357,7 +357,7 @@ ps = MTKParameters( (BlockedArray([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [3, 3]), BlockedArray(falses(1), [1, 0])), (), (), ()) -@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}} +@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, BitVector} tsidx1 = 1 tsidx2 = 2 @test length(ps.discrete[1][Block(tsidx1)]) == 3 @@ -368,3 +368,14 @@ with_updated_parameter_timeseries_values( sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false]))) @test ps.discrete[1][Block(tsidx1)] == [10.0, 11.0, 12.0] @test ps.discrete[2][Block(tsidx1)][] == false + +@testset "Avoid specialization of nonnumeric parameters on `remake_buffer`" begin + @variables x(t) + @parameters p::Any + @named sys = System(D(x) ~ x, t, [x], [p]) + sys = complete(sys) + ps = MTKParameters(sys, [p => 1.0]) + @test ps.nonnumeric isa Tuple{Vector{Any}} + ps2 = remake_buffer(sys, ps, [p], [:a]) + @test ps2.nonnumeric isa Tuple{Vector{Any}} +end From 135383ba2f3f24540f625c8fa48bbeb517e20cc9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 9 Oct 2025 13:20:47 +0530 Subject: [PATCH 005/157] test: test solving `ODEProblem` with `eval_expression = true` --- test/precompile_test.jl | 3 +++ test/precompile_test/ODEPrecompileTest.jl | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/test/precompile_test.jl b/test/precompile_test.jl index 38051d9d49..6fdc1b5cbf 100644 --- a/test/precompile_test.jl +++ b/test/precompile_test.jl @@ -1,5 +1,6 @@ using Test using ModelingToolkit +using OrdinaryDiffEqDefault using Distributed @@ -38,3 +39,5 @@ ODEPrecompileTest.f_eval_bad(u, p, 0.1) @test parentmodule(typeof(ODEPrecompileTest.f_eval_good.f.f_oop)) == ODEPrecompileTest @test ODEPrecompileTest.f_eval_good(u, p, 0.1) == [4, 0, -16] + +@test_nowarn solve(ODEPrecompileTest.prob_eval) diff --git a/test/precompile_test/ODEPrecompileTest.jl b/test/precompile_test/ODEPrecompileTest.jl index 2111f7ba64..c77c897655 100644 --- a/test/precompile_test/ODEPrecompileTest.jl +++ b/test/precompile_test/ODEPrecompileTest.jl @@ -36,4 +36,24 @@ const f_eval_bad = system(; eval_expression = true, eval_module = @__MODULE__) # Change the module the eval'd function is eval'd into to be the containing module, # which should make it be in the package image const f_eval_good = system(; eval_expression = true, eval_module = @__MODULE__) + +function problem(; kwargs...) + # Define some variables + @independent_variables t + @parameters σ ρ β + @variables x(t) y(t) z(t) + D = Differential(t) + + # Define a differential equation + eqs = [D(x) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + + @named de = System(eqs, t) + de = complete(de) + return ODEProblem(de, [x => 1, y => 0, z => 0, σ => 10, ρ => 28, β => 8/3], (0.0, 5.0); kwargs...) +end + +const prob_eval = problem(; eval_expression = true, eval_module = @__MODULE__) + end From af71ccaaaccc61b09d8c72f2dca46828d9bc3e88 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 9 Oct 2025 17:56:03 +0530 Subject: [PATCH 006/157] docs: fix bad docstring --- docs/src/API/codegen.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/API/codegen.md b/docs/src/API/codegen.md index 4f31405174..0092e1cb8e 100644 --- a/docs/src/API/codegen.md +++ b/docs/src/API/codegen.md @@ -50,5 +50,5 @@ ModelingToolkit.calculate_A_b All code generation eventually calls `build_function_wrapper`. ```@docs -build_function_wrapper +ModelingToolkit.build_function_wrapper ``` From aaaac6e7ccbd6505bfaceb17ef81b38d8d715062 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 18:17:38 +0530 Subject: [PATCH 007/157] fix: fix type-change of time-dependent parameters with `respecialize` --- src/systems/diffeqs/basic_transformations.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index f823260e2a..200e0d3d56 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -1037,9 +1037,11 @@ function respecialize(sys::AbstractSystem, mapping; all = false) """ if iscall(k) - op = operation(k) + op = operation(k)::BasicSymbolic + @assert !iscall(op) + op = SymbolicUtils.Sym{SymbolicUtils.FnType{Tuple{Any}, T}}(nameof(op)) args = arguments(k) - new_p = SymbolicUtils.term(op, args...; type = T) + new_p = op(args...) else new_p = SymbolicUtils.Sym{T}(getname(k)) end From 9ede8ab5cc5b14f45dee2e52245340479a7454d2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 18:18:10 +0530 Subject: [PATCH 008/157] fix: add `@invokelatest` to allow trivial initialization with `eval_expression = true` --- src/systems/problem_utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index d1c85e437f..2f5db01479 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1791,7 +1791,8 @@ Construct SciMLProblem `T` with positional arguments `args` and keywords `kwargs """ function maybe_codegen_scimlproblem(::Type{Val{false}}, T, args::NamedTuple; kwargs...) # Call `remake` so it runs initialization if it is trivial - remake(T(args...; kwargs...)) + # Use `@invokelatest` to avoid world-age issues with `eval_expression = true` + @invokelatest remake(T(args...; kwargs...)) end """ From eb1c901aa004d1adf04dddd415f633b6883dd462 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 18:18:19 +0530 Subject: [PATCH 009/157] test: fix `respecialize` test --- test/basic_transformations.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/basic_transformations.jl b/test/basic_transformations.jl index dc9d71f300..19899aa50f 100644 --- a/test/basic_transformations.jl +++ b/test/basic_transformations.jl @@ -340,11 +340,12 @@ foofn(x) = 4 @testset "`respecialize`" begin @parameters p::AbstractFoo p2(t)::AbstractFoo = p q[1:2]::AbstractFoo r - rp, - rp2 = let - only(@parameters p::Bar), - SymbolicUtils.term(operation(p2), arguments(p2)...; type = Baz) - end + rp = only(let p = nothing + @parameters p::Bar + end) + rp2 = only(let p2 = nothing + @parameters p2(t)::Baz + end) @variables x(t) = 1.0 @named sys1 = System([D(x) ~ foofn(p) + foofn(p2) + x], t, [x], [p, p2, q, r]) From 9afe05e20b4ebc19ba42fcdb6a72c5217914cea2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 18:28:21 +0530 Subject: [PATCH 010/157] refactor: remove usages of `Symbolic` --- ext/MTKCasADiDynamicOptExt.jl | 2 +- ext/MTKInfiniteOptExt.jl | 6 +++--- ext/MTKPyomoDynamicOptExt.jl | 4 ++-- src/ModelingToolkit.jl | 4 ++-- src/clock.jl | 2 +- src/constants.jl | 2 +- src/discretedomain.jl | 2 +- src/independent_variables.jl | 2 +- src/parameters.jl | 10 +++++----- src/problems/jumpproblem.jl | 4 ++-- .../StructuralTransformations.jl | 2 +- src/structural_transformation/pantelides.jl | 4 ++-- .../symbolics_tearing.jl | 4 ++-- src/structural_transformation/utils.jl | 2 +- src/systems/abstractsystem.jl | 16 +++++++-------- src/systems/callbacks.jl | 9 +++++---- src/systems/connectors.jl | 2 +- src/systems/system.jl | 2 +- src/systems/systemstructure.jl | 6 +++--- src/systems/unit_check.jl | 16 +++++++-------- src/systems/validation.jl | 20 +++++++++---------- src/utils.jl | 10 +++++----- src/variables.jl | 20 +++++++++---------- 23 files changed, 75 insertions(+), 76 deletions(-) diff --git a/ext/MTKCasADiDynamicOptExt.jl b/ext/MTKCasADiDynamicOptExt.jl index addc478d98..b1762e7f7f 100644 --- a/ext/MTKCasADiDynamicOptExt.jl +++ b/ext/MTKCasADiDynamicOptExt.jl @@ -122,7 +122,7 @@ end function MTK.lowered_var(m::CasADiModel, uv, i, t) X = getfield(m, uv) - t isa Union{Num, Symbolics.Symbolic} ? X.u[i, :] : X(t)[i] + t isa Union{Num, SymbolicT} ? X.u[i, :] : X(t)[i] end function MTK.lowered_integral(model::CasADiModel, expr, lo, hi) diff --git a/ext/MTKInfiniteOptExt.jl b/ext/MTKInfiniteOptExt.jl index e0f02c0436..acb80c1041 100644 --- a/ext/MTKInfiniteOptExt.jl +++ b/ext/MTKInfiniteOptExt.jl @@ -122,7 +122,7 @@ end function MTK.lowered_var(m::InfiniteOptModel, uv, i, t) X = getfield(m, uv) - t isa Union{Num, Symbolics.Symbolic} ? X[i] : X[i](t) + t isa Union{Num, SymbolicT} ? X[i] : X[i](t) end function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau) @@ -256,13 +256,13 @@ for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqr end # JuMP variables and Symbolics variables never compare equal. When tracing through dynamics, a function argument can be either a JuMP variable or A Symbolics variable, it can never be both. -function Base.isequal(::SymbolicUtils.Symbolic, +function Base.isequal(::SymbolicT, ::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr}) false end function Base.isequal( ::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr}, - ::SymbolicUtils.Symbolic) + ::SymbolicT) false end end diff --git a/ext/MTKPyomoDynamicOptExt.jl b/ext/MTKPyomoDynamicOptExt.jl index 5b4e9e7a1c..fe18b2678d 100644 --- a/ext/MTKPyomoDynamicOptExt.jl +++ b/ext/MTKPyomoDynamicOptExt.jl @@ -53,7 +53,7 @@ struct PyomoDynamicOptProblem{uType, tType, isinplace, P, F, K} <: end end -function pysym_getproperty(s::Union{Num, Symbolics.Symbolic}, name::Symbol) +function pysym_getproperty(s::Union{Num, SymbolicT}, name::Symbol) Symbolics.wrap(SymbolicUtils.term( _getproperty, Symbolics.unwrap(s), Val{name}(), type = Symbolics.Struct{PyomoVar})) end @@ -165,7 +165,7 @@ end function MTK.lowered_var(m::PyomoDynamicOptModel, uv, i, t) X = Symbolics.value(pysym_getproperty(m.model_sym, uv)) - var = t isa Union{Num, Symbolics.Symbolic} ? X[i, m.t_sym] : X[i, t] + var = t isa Union{Num, SymbolicT} ? X[i, m.t_sym] : X[i, t] Symbolics.unwrap(var) end diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6c79eb4fb1..428fe10e7a 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -10,7 +10,7 @@ end import SymbolicUtils import SymbolicUtils: iscall, arguments, operation, maketerm, promote_symtype, - Symbolic, isadd, ismul, ispow, issym, FnType, + isadd, ismul, ispow, issym, FnType, @rule, Rewriters, substitute, metadata, BasicSymbolic, Sym, Term using SymbolicUtils.Code @@ -74,7 +74,7 @@ import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk using RuntimeGeneratedFunctions using RuntimeGeneratedFunctions: drop_expr -using Symbolics: degree +using Symbolics: degree, VartypeT, SymbolicT using Symbolics: _parse_vars, value, @derivatives, get_variables, exprs_occur_in, symbolic_linear_solve, build_expr, unwrap, wrap, VariableSource, getname, variable, diff --git a/src/clock.jl b/src/clock.jl index df3b6f4b47..03ae6efdbf 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -45,7 +45,7 @@ has_time_domain(_, x) = has_time_domain(x) Determine if variable `x` has a time-domain attributed to it. """ -function has_time_domain(x::Symbolic) +function has_time_domain(x::SymbolicT) # getmetadata(x, ContinuousClock, nothing) !== nothing || # getmetadata(x, Discrete, nothing) !== nothing getmetadata(x, VariableTimeDomain, nothing) !== nothing diff --git a/src/constants.jl b/src/constants.jl index 4113287ad4..9685e89b07 100644 --- a/src/constants.jl +++ b/src/constants.jl @@ -3,7 +3,7 @@ Test whether `x` is a constant-type Sym. """ function isconstant(x) x = unwrap(x) - x isa Symbolic && !getmetadata(x, VariableTunable, true) + x isa SymbolicT && !getmetadata(x, VariableTunable, true) end """ diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 9e57296d9f..ffa36c9dbd 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -60,7 +60,7 @@ julia> Δ = Shift(t) """ struct Shift <: Operator """Fixed Shift""" - t::Union{Nothing, Symbolic} + t::Union{Nothing, SymbolicT} steps::Int Shift(t, steps = 1) = new(value(t), steps) end diff --git a/src/independent_variables.jl b/src/independent_variables.jl index d1f2ab4210..e1fd9be262 100644 --- a/src/independent_variables.jl +++ b/src/independent_variables.jl @@ -13,6 +13,6 @@ macro independent_variables(ts...) toiv) |> esc end -toiv(s::Symbolic) = GlobalScope(setmetadata(s, MTKVariableTypeCtx, PARAMETER)) +toiv(s::SymbolicT) = GlobalScope(setmetadata(s, MTKVariableTypeCtx, PARAMETER)) toiv(s::Symbolics.Arr) = wrap(toiv(value(s))) toiv(s::Num) = Num(toiv(value(s))) diff --git a/src/parameters.jl b/src/parameters.jl index d8ff1bf1be..413caf7985 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -25,19 +25,19 @@ Check if the variable contains the metadata identifying it as a parameter. function isparameter(x) x = unwrap(x) - if x isa Symbolic && (varT = getvariabletype(x, nothing)) !== nothing + if x isa SymbolicT && (varT = getvariabletype(x, nothing)) !== nothing return varT === PARAMETER #TODO: Delete this branch - elseif x isa Symbolic && Symbolics.getparent(x, false) !== false + elseif x isa SymbolicT && Symbolics.getparent(x, false) !== false p = Symbolics.getparent(x) isparameter(p) || (hasmetadata(p, Symbolics.VariableSource) && getmetadata(p, Symbolics.VariableSource)[1] == :parameters) - elseif iscall(x) && operation(x) isa Symbolic + elseif iscall(x) && operation(x) isa SymbolicT varT === PARAMETER || isparameter(operation(x)) elseif iscall(x) && operation(x) == (getindex) isparameter(arguments(x)[1]) - elseif x isa Symbolic + elseif x isa SymbolicT varT === PARAMETER else false @@ -80,7 +80,7 @@ toparam(s::Num) = wrap(toparam(value(s))) Maps the variable to an unknown. """ -tovar(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, VARIABLE) +tovar(s::SymbolicT) = setmetadata(s, MTKVariableTypeCtx, VARIABLE) tovar(s::Union{Num, Symbolics.Arr}) = wrap(tovar(unwrap(s))) """ diff --git a/src/problems/jumpproblem.jl b/src/problems/jumpproblem.jl index 32aa25182f..113f5fc2f2 100644 --- a/src/problems/jumpproblem.jl +++ b/src/problems/jumpproblem.jl @@ -196,13 +196,13 @@ end ### Functions to determine which unknowns a jump depends on function get_variables!(dep, jump::Union{ConstantRateJump, VariableRateJump}, variables) jr = value(jump.rate) - (jr isa Symbolic) && get_variables!(dep, jr, variables) + (jr isa SymbolicT) && get_variables!(dep, jr, variables) dep end function get_variables!(dep, jump::MassActionJump, variables) sr = value(jump.scaled_rates) - (sr isa Symbolic) && get_variables!(dep, sr, variables) + (sr isa SymbolicT) && get_variables!(dep, sr, variables) for varasop in jump.reactant_stoch any(isequal(varasop[1]), variables) && push!(dep, varasop[1]) end diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 681025cb81..7c874dd7c4 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -12,7 +12,7 @@ using SymbolicUtils: maketerm, iscall using ModelingToolkit using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Differential, - unknowns, equations, vars, Symbolic, diff2term_with_unit, + unknowns, equations, vars, SymbolicT, diff2term_with_unit, shift2term_with_unit, value, operation, arguments, Sym, Term, simplify, symbolic_linear_solve, isdiffeq, isdifferential, isirreducible, diff --git a/src/structural_transformation/pantelides.jl b/src/structural_transformation/pantelides.jl index 871bd99ef4..0a215c0714 100644 --- a/src/structural_transformation/pantelides.jl +++ b/src/structural_transformation/pantelides.jl @@ -37,7 +37,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) # LHS variable is looked up from var_to_diff # the var_to_diff[i]-th variable is the differentiated version of var at i eq = out_eqs[eqidx] - lhs = if !(eq.lhs isa Symbolic) + lhs = if !(eq.lhs isa SymbolicT) 0 elseif isdiffeq(eq) # look up the variable that represents D(lhs) @@ -56,7 +56,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) rhs = ModelingToolkit.expand_derivatives(D(eq.rhs)) rhs = fast_substitute(rhs, state.param_derivative_map) substitution_dict = Dict(x.lhs => x.rhs - for x in out_eqs if x !== nothing && x.lhs isa Symbolic) + for x in out_eqs if x !== nothing && x.lhs isa SymbolicT) sub_rhs = substitute(rhs, substitution_dict) out_eqs[diff] = lhs ~ sub_rhs end diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 39b959c5a6..1d4539968d 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -144,7 +144,7 @@ function to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar::F, eq = 0 ~ eq.rhs - eq.lhs end rhs = eq.rhs - if rhs isa Symbolic + if rhs isa SymbolicT # Check if the RHS is solvable in all unknown variable derivatives and if those # the linear terms for them are all zero. If so, move them to the # LHS. @@ -1189,7 +1189,7 @@ end Backshift the given expression `ex`. """ function backshift_expr(ex, iv) - ex isa Symbolic || return ex + ex isa SymbolicT || return ex return descend_lower_shift_varname_with_unit( simplify_shifts(distribute_shift(Shift(iv, -1)(ex))), iv) end diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 3fa4f28aa9..84d6c1b01c 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -249,7 +249,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no a, b, islinear = linear_expansion(term, var) a, b = unwrap(a), unwrap(b) islinear || (all_int_vars = false; continue) - if a isa Symbolic + if a isa SymbolicT all_int_vars = false if !allow_symbolic if allow_parameter diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 0bd05bb4b9..c18011f284 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1019,7 +1019,7 @@ struct LocalScope <: SymScope end Apply `LocalScope` to `sym`. """ -function LocalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) +function LocalScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) apply_to_variables(sym) do sym if iscall(sym) && operation(sym) === getindex args = arguments(sym) @@ -1051,7 +1051,7 @@ end Apply `ParentScope` to `sym`, with `parent` being `LocalScope`. """ -function ParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) +function ParentScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) apply_to_variables(sym) do sym if iscall(sym) && operation(sym) === getindex args = arguments(sym) @@ -1081,7 +1081,7 @@ struct GlobalScope <: SymScope end Apply `GlobalScope` to `sym`. """ -function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) +function GlobalScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) apply_to_variables(sym) do sym if iscall(sym) && operation(sym) == getindex args = arguments(sym) @@ -1106,7 +1106,7 @@ Namespace `x` with the name of `sys`. function renamespace(sys, x) sys === nothing && return x x = unwrap(x) - if x isa Symbolic + if x isa SymbolicT T = typeof(x) if iscall(x) && operation(x) isa Operator return maketerm(typeof(x), operation(x), @@ -1500,10 +1500,8 @@ function defaults_and_guesses(sys::AbstractSystem) end unknowns(sys::Union{AbstractSystem, Nothing}, v) = namespace_expr(v, sys) -for vType in [Symbolics.Arr, Symbolics.Symbolic{<:AbstractArray}] - @eval unknowns(sys::AbstractSystem, v::$vType) = namespace_expr(v, sys) - @eval parameters(sys::AbstractSystem, v::$vType) = toparam(unknowns(sys, v)) -end +unknowns(sys::AbstractSystem, v::Symbolics.Arr) = namespace_expr(v, sys) +parameters(sys::AbstractSystem, v::Symbolics.Arr) = toparam(unknowns(sys, v)) parameters(sys::Union{AbstractSystem, Nothing}, v) = toparam(unknowns(sys, v)) for f in [:unknowns, :parameters] @eval function $f(sys::AbstractSystem, vs::AbstractArray) @@ -2257,7 +2255,7 @@ end function default_to_parentscope(v) uv = unwrap(v) - uv isa Symbolic || return v + uv isa SymbolicT || return v apply_to_variables(v) do sym ParentScope(sym) end diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index c94166103b..69fb702072 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -169,7 +169,7 @@ Base.nameof(::Pre) = :Pre Base.show(io::IO, x::Pre) = print(io, "Pre") unPre(x::Num) = unPre(unwrap(x)) unPre(x::Symbolics.Arr) = unPre(unwrap(x)) -unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x +unPre(x::SymbolicT) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x function (p::Pre)(x) iw = Symbolics.iswrapped(x) @@ -420,14 +420,14 @@ A callback that triggers at the first timestep that the conditions are satisfied The condition can be one of: - Δt::Real - periodic events with period Δt - ts::Vector{Real} - events trigger at these preset times given by `ts` -- eqs::Vector{Symbolic} - events trigger when the condition evaluates to true +- eqs::Vector{SymbolicT} - events trigger when the condition evaluates to true Arguments: - iv: The independent variable of the system. This must be specified if the independent variable appears in one of the equations explicitly, as in x ~ t + 1. - alg_eqs: Algebraic equations of the system that must be satisfied after the callback occurs. """ struct SymbolicDiscreteCallback <: AbstractCallback - conditions::Union{Number, Vector{<:Number}, Symbolic{Bool}} + conditions::Union{Number, Vector{<:Number}, SymbolicT} affect::Union{Affect, SymbolicAffect, Nothing} initialize::Union{Affect, SymbolicAffect, Nothing} finalize::Union{Affect, SymbolicAffect, Nothing} @@ -435,9 +435,10 @@ struct SymbolicDiscreteCallback <: AbstractCallback end function SymbolicDiscreteCallback( - condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing; + condition::Union{SymbolicT, Number, Vector{<:Number}}, affect = nothing; initialize = nothing, finalize = nothing, reinitializealg = nothing, kwargs...) + @assert !(condition isa SymbolicT && symtype(condition) != Bool) c = is_timed_condition(condition) ? condition : value(scalarize(condition)) if isnothing(reinitializealg) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index c0ddf5baee..89be35b716 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -72,7 +72,7 @@ end Get the connection type of symbolic variable `s` from the `VariableConnectType` metadata. Defaults to `Equality` if not present. """ -function get_connection_type(s::Symbolic) +function get_connection_type(s::SymbolicT) s = unwrap(s) if iscall(s) && operation(s) === getindex s = arguments(s)[1] diff --git a/src/systems/system.jl b/src/systems/system.jl index dcb3ed6f9b..74e233791a 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -484,7 +484,7 @@ function System(eqs::Vector{Equation}, iv; kwargs...) diffeqs = Equation[] othereqs = Equation[] for eq in eqs - if !(eq.lhs isa Union{Symbolic, Number, AbstractArray}) + if !(eq.lhs isa Union{SymbolicT, Number, AbstractArray}) push!(othereqs, eq) continue end diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index a1460731cb..e9d68e162f 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -1,11 +1,11 @@ using DataStructures using Symbolics: linear_expansion, unwrap -using SymbolicUtils: iscall, operation, arguments, Symbolic +using SymbolicUtils: iscall, operation, arguments using SymbolicUtils: quick_cancel, maketerm using ..ModelingToolkit import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, value, InvalidSystemException, isdifferential, _iszero, - isparameter, Connection, + isparameter, Connection, SymbolicT independent_variables, SparseMatrixCLIL, AbstractSystem, equations, isirreducible, input_timedomain, TimeDomain, InferredTimeDomain, @@ -372,7 +372,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) union!(dvs, xx) end end - ps = Set{Symbolic}() + ps = Set{SymbolicT}() for x in full_parameters(sys) push!(ps, x) if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown() diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index acf7451065..eac27b58be 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -12,7 +12,7 @@ function __get_literal_unit(x) if x isa Pair x = x[1] end - if !(x isa Union{Num, Symbolic}) + if !(x isa Union{Num, SymbolicT}) return nothing end v = value(x) @@ -129,7 +129,7 @@ function get_unit(op::Comparison, args) return unitless end -function get_unit(x::Symbolic) +function get_unit(x::SymbolicT) if (u = __get_literal_unit(x)) !== nothing screen_unit(u) elseif issym(x) @@ -249,14 +249,14 @@ function _validate(conn::Connection; info::String = "") end function validate(jump::Union{VariableRateJump, - ConstantRateJump}, t::Symbolic; + ConstantRateJump}, t::SymbolicT; info::String = "") newinfo = replace(info, "eq." => "jump") _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units validate(jump.affect!, info = newinfo) end -function validate(jump::MassActionJump, t::Symbolic; info::String = "") +function validate(jump::MassActionJump, t::SymbolicT; info::String = "") left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols net_symbols = [x[1] for x in jump.net_stoch] all_symbols = vcat(left_symbols, net_symbols) @@ -267,7 +267,7 @@ function validate(jump::MassActionJump, t::Symbolic; info::String = "") ["scaled_rates", "1/(t*reactants^$n))"]; info) end -function validate(jumps::Vector{JumpType}, t::Symbolic) +function validate(jumps::Vector{JumpType}, t::SymbolicT) labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"] majs = filter(x -> x isa MassActionJump, jumps) crjs = filter(x -> x isa ConstantRateJump, jumps) @@ -284,7 +284,7 @@ function validate(eq::Union{Inequality, Equation}; info::String = "") end end function validate(eq::Equation, - term::Union{Symbolic, DQ.AbstractQuantity, Num}; info::String = "") + term::Union{SymbolicT, DQ.AbstractQuantity, Num}; info::String = "") _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) end function validate(eq::Equation, terms::Vector; info::String = "") @@ -306,10 +306,10 @@ function validate(eqs::Vector, noise::Matrix; info::String = "") all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -function validate(eqs::Vector, term::Symbolic; info::String = "") +function validate(eqs::Vector, term::SymbolicT; info::String = "") all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== nothing +validate(term::SymbolicT) = safe_get_unit(term, "") !== nothing """ Throws error if units of equations are invalid. diff --git a/src/systems/validation.jl b/src/systems/validation.jl index d416a02ea2..cdf75a1631 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -6,11 +6,11 @@ using ..ModelingToolkit: ValidationError, get_systems, Conditional, Comparison using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump -using Symbolics: Symbolic, value, issym, isadd, ismul, ispow +using Symbolics: SymbolicT, value, issym, isadd, ismul, ispow const MT = ModelingToolkit -Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y -Base.:/(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y +Base.:*(x::Union{Num, SymbolicT}, y::Unitful.AbstractQuantity) = x * y +Base.:/(x::Union{Num, SymbolicT}, y::Unitful.AbstractQuantity) = x / y """ Throw exception on invalid unit types, otherwise return argument. @@ -104,7 +104,7 @@ function get_unit(op::Comparison, args) return unitless end -function get_unit(x::Symbolic) +function get_unit(x::SymbolicT) if issym(x) get_literal_unit(x) elseif isadd(x) @@ -214,14 +214,14 @@ function _validate(conn::Connection; info::String = "") end function validate(jump::Union{MT.VariableRateJump, - MT.ConstantRateJump}, t::Symbolic; + MT.ConstantRateJump}, t::SymbolicT; info::String = "") newinfo = replace(info, "eq." => "jump") _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units validate(jump.affect!, info = newinfo) end -function validate(jump::MT.MassActionJump, t::Symbolic; info::String = "") +function validate(jump::MT.MassActionJump, t::SymbolicT; info::String = "") left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols net_symbols = [x[1] for x in jump.net_stoch] all_symbols = vcat(left_symbols, net_symbols) @@ -232,7 +232,7 @@ function validate(jump::MT.MassActionJump, t::Symbolic; info::String = "") ["scaled_rates", "1/(t*reactants^$n))"]; info) end -function validate(jumps::Vector{JumpType}, t::Symbolic) +function validate(jumps::Vector{JumpType}, t::SymbolicT) labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"] majs = filter(x -> x isa MassActionJump, jumps) crjs = filter(x -> x isa ConstantRateJump, jumps) @@ -249,7 +249,7 @@ function validate(eq::MT.Equation; info::String = "") end end function validate(eq::MT.Equation, - term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "") + term::Union{SymbolicT, Unitful.Quantity, Num}; info::String = "") _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) end function validate(eq::MT.Equation, terms::Vector; info::String = "") @@ -271,10 +271,10 @@ function validate(eqs::Vector, noise::Matrix; info::String = "") all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -function validate(eqs::Vector, term::Symbolic; info::String = "") +function validate(eqs::Vector, term::SymbolicT; info::String = "") all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== nothing +validate(term::SymbolicT) = safe_get_unit(term, "") !== nothing """ Throws error if units of equations are invalid. diff --git a/src/utils.jl b/src/utils.jl index 0da7e4860b..e89d7f1d3a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -389,7 +389,7 @@ isdiffeq(eq) = isdifferential(eq.lhs) || isoperator(eq.lhs, Shift) isvariable(x::Num)::Bool = isvariable(value(x)) function isvariable(x)::Bool - x isa Symbolic || return false + x isa SymbolicT || return false p = getparent(x, nothing) p === nothing || (x = p) hasmetadata(x, VariableSource) @@ -412,7 +412,7 @@ v = ModelingToolkit.vars(D(y) ~ u) v == Set([D(y), u]) ``` """ -function vars(exprs::Symbolic; op = Differential) +function vars(exprs::SymbolicT; op = Differential) iscall(exprs) ? vars([exprs]; op = op) : Set([exprs]) end vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op) @@ -544,7 +544,7 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif for eq in equations(sys) eqtype_supports_collect_vars(eq) || continue if eq isa Equation - eq.lhs isa Union{Symbolic, Number} || continue + eq.lhs isa Union{SymbolicT, Number} || continue end collect_vars!(unknowns, parameters, eq, iv; depth, op) end @@ -803,7 +803,7 @@ end function _with_unit(f, x, t, args...) x = f(x, args...) - if hasmetadata(x, VariableUnit) && (t isa Symbolic && hasmetadata(t, VariableUnit)) + if hasmetadata(x, VariableUnit) && (t isa SymbolicT && hasmetadata(t, VariableUnit)) xu = getmetadata(x, VariableUnit) tu = getmetadata(t, VariableUnit) x = setmetadata(x, VariableUnit, xu / tu) @@ -1184,4 +1184,4 @@ function wrap_with_D(n, D, repeats) else wrap_with_D(D(n), D, repeats - 1) end -end \ No newline at end of file +end diff --git a/src/variables.jl b/src/variables.jl index 46c9c95bc6..e5a167d656 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -179,7 +179,7 @@ struct Stream <: AbstractConnectType end # special stream connector Get the connect type of x. See also [`hasconnect`](@ref). """ getconnect(x::Num) = getconnect(unwrap(x)) -getconnect(x::Symbolic) = Symbolics.getmetadata(x, VariableConnectType, nothing) +getconnect(x::SymbolicT) = Symbolics.getmetadata(x, VariableConnectType, nothing) """ hasconnect(x) @@ -280,7 +280,7 @@ Create parameters with bounds like this @parameters p [bounds=(-1, 1)] ``` """ -function getbounds(x::Union{Num, Symbolics.Arr, SymbolicUtils.Symbolic}) +function getbounds(x::Union{Num, Symbolics.Arr, SymbolicT}) x = unwrap(x) p = Symbolics.getparent(x, nothing) if p === nothing @@ -512,7 +512,7 @@ end Maps the brownianiable to an unknown. """ -tobrownian(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, BROWNIAN) +tobrownian(s::SymbolicT) = setmetadata(s, MTKVariableTypeCtx, BROWNIAN) tobrownian(s::Num) = Num(tobrownian(value(s))) isbrownian(s) = getvariabletype(s) === BROWNIAN @@ -587,7 +587,7 @@ Fetch any miscellaneous data associated with symbolic variable `x`. See also [`hasmisc(x)`](@ref). """ getmisc(x::Num) = getmisc(unwrap(x)) -getmisc(x::Symbolic) = Symbolics.getmetadata(x, VariableMisc, nothing) +getmisc(x::SymbolicT) = Symbolics.getmetadata(x, VariableMisc, nothing) """ hasmisc(x) @@ -606,7 +606,7 @@ setmisc(x, miscdata) = setmetadata(x, VariableMisc, miscdata) Fetch the unit associated with variable `x`. This function is a metadata getter for an individual variable, while `get_unit` is used for unit inference on more complicated sdymbolic expressions. """ getunit(x::Num) = getunit(unwrap(x)) -getunit(x::Symbolic) = Symbolics.getmetadata(x, VariableUnit, nothing) +getunit(x::SymbolicT) = Symbolics.getmetadata(x, VariableUnit, nothing) """ hasunit(x) @@ -615,10 +615,10 @@ Check if the variable `x` has a unit. hasunit(x) = getunit(x) !== nothing getunshifted(x::Num) = getunshifted(unwrap(x)) -getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing) +getunshifted(x::SymbolicT) = Symbolics.getmetadata(x, VariableUnshifted, nothing) getshift(x::Num) = getshift(unwrap(x)) -getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) +getshift(x::SymbolicT) = Symbolics.getmetadata(x, VariableShift, 0) ################### ### Evaluate at ### @@ -629,7 +629,7 @@ getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) An operator that evaluates time-dependent variables at a specific absolute time point `t`. # Fields -- `t::Union{Symbolic, Number}`: The absolute time at which to evaluate the variable. +- `t::Union{SymbolicT, Number}`: The absolute time at which to evaluate the variable. # Description `EvalAt` is used to evaluate time-dependent variables at a specific time point. This is particularly @@ -677,10 +677,10 @@ end See also: [`Differential`](@ref) """ struct EvalAt <: Symbolics.Operator - t::Union{Symbolic, Number} + t::Union{SymbolicT, Number} end -function (A::EvalAt)(x::Symbolic) +function (A::EvalAt)(x::SymbolicT) if symbolic_type(x) == NotSymbolic() || !iscall(x) if x isa Symbolics.CallWithMetadata return x(A.t) From f04e9e7eaec4e4fda88a5c577dbaa0fc51c2b625 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 18:29:02 +0530 Subject: [PATCH 011/157] refactor: remove usages of `_parse_vars` --- src/ModelingToolkit.jl | 2 +- src/constants.jl | 4 ++-- src/independent_variables.jl | 4 ++-- src/parameters.jl | 4 ++-- src/variables.jl | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 428fe10e7a..6f63dd73e7 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -75,7 +75,7 @@ using RuntimeGeneratedFunctions using RuntimeGeneratedFunctions: drop_expr using Symbolics: degree, VartypeT, SymbolicT -using Symbolics: _parse_vars, value, @derivatives, get_variables, +using Symbolics: parse_vars, value, @derivatives, get_variables, exprs_occur_in, symbolic_linear_solve, build_expr, unwrap, wrap, VariableSource, getname, variable, NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval, diff --git a/src/constants.jl b/src/constants.jl index 9685e89b07..fed010a2ee 100644 --- a/src/constants.jl +++ b/src/constants.jl @@ -26,8 +26,8 @@ Define one or more constants. See also [`@independent_variables`](@ref), [`@parameters`](@ref) and [`@variables`](@ref). """ macro constants(xs...) - Symbolics._parse_vars(:constants, + Symbolics.parse_vars(:constants, Real, xs, - toconstant) |> esc + toconstant) end diff --git a/src/independent_variables.jl b/src/independent_variables.jl index e1fd9be262..fce2d93873 100644 --- a/src/independent_variables.jl +++ b/src/independent_variables.jl @@ -7,10 +7,10 @@ Define one or more independent variables. For example: @variables x(t) """ macro independent_variables(ts...) - Symbolics._parse_vars(:independent_variables, + Symbolics.parse_vars(:independent_variables, Real, ts, - toiv) |> esc + toiv) end toiv(s::SymbolicT) = GlobalScope(setmetadata(s, MTKVariableTypeCtx, PARAMETER)) diff --git a/src/parameters.jl b/src/parameters.jl index 413caf7985..d5d96120b8 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -91,10 +91,10 @@ Define one or more known parameters. See also [`@independent_variables`](@ref), [`@variables`](@ref) and [`@constants`](@ref). """ macro parameters(xs...) - Symbolics._parse_vars(:parameters, + Symbolics.parse_vars(:parameters, Real, xs, - toparam) |> esc + toparam) end function find_types(array) diff --git a/src/variables.jl b/src/variables.jl index e5a167d656..8089b9bffd 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -526,10 +526,10 @@ macro brownians(xs...) x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$ || Meta.isexpr(x, :$), xs) || error("@brownians only takes scalar expressions!") - Symbolics._parse_vars(:brownian, + Symbolics.parse_vars(:brownian, Real, xs, - tobrownian) |> esc + tobrownian) end ## Guess ====================================================================== From 4d42365bebdd8bce7747bebd6294b7696482607d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 18:30:24 +0530 Subject: [PATCH 012/157] refactor: remove outdated imports --- src/ModelingToolkit.jl | 6 +++--- src/clock.jl | 2 +- src/systems/index_cache.jl | 5 ----- src/systems/model_parsing.jl | 2 +- src/systems/parameter_buffer.jl | 1 - 5 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6f63dd73e7..0c823c1b4e 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -76,9 +76,9 @@ using RuntimeGeneratedFunctions: drop_expr using Symbolics: degree, VartypeT, SymbolicT using Symbolics: parse_vars, value, @derivatives, get_variables, - exprs_occur_in, symbolic_linear_solve, build_expr, unwrap, wrap, + exprs_occur_in, symbolic_linear_solve, unwrap, wrap, VariableSource, getname, variable, - NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval, + NAMESPACE_SEPARATOR, setdefaultval, hasnode, fixpoint_sub, fast_substitute, CallWithMetadata, CallWithParent const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) @@ -89,7 +89,7 @@ import Symbolics: rename, get_variables!, _solve, hessian_sparsity, ParallelForm, SerialForm, MultithreadedForm, build_function, rhss, lhss, prettify_expr, gradient, jacobian, hessian, derivative, sparsejacobian, sparsehessian, - substituter, scalarize, getparent, hasderiv, hasdiff + scalarize, getparent, hasderiv import DiffEqBase: @add_kwonly export independent_variables, unknowns, observables, parameters, full_parameters, diff --git a/src/clock.jl b/src/clock.jl index 03ae6efdbf..6537334645 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -77,7 +77,7 @@ See also [`is_continuous_domain`](@ref) """ function has_continuous_domain(x) issym(x) && return is_continuous_domain(x) - hasderiv(x) || hasdiff(x) || hassample(x) || hashold(x) + hasderiv(x) || hassample(x) || hashold(x) end """ diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 19c78413cf..b5575ea628 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -3,11 +3,6 @@ struct BufferTemplate length::Int end -function BufferTemplate(s::Type{<:Symbolics.Struct}, length::Int) - T = Symbolics.juliatype(s) - BufferTemplate(T, length) -end - struct Nonnumeric <: SciMLStructures.AbstractPortion end const NONNUMERIC_PORTION = Nonnumeric() diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 699cfee8fd..0369b12212 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -629,7 +629,7 @@ function _set_var_metadata!(metadata_with_exprs, a, m, v::Expr) a end function _set_var_metadata!(metadata_with_exprs, a, m, v) - wrap(set_scalar_metadata(unwrap(a), m, v)) + wrap(setmetadata(unwrap(a), m, v)) end function set_var_metadata(a, ms) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index f91348d0c9..b42ef1e3ef 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -1,4 +1,3 @@ -symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x) symconvert(::Type{T}, x::V) where {T, V} = convert(promote_type(T, V), x) symconvert(::Type{Real}, x::Integer) = convert(Float16, x) symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x)) From 194831cdbc437eadc18215339f7f65af2184d823 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 18:35:48 +0530 Subject: [PATCH 013/157] refactor: remove usages of `fast_substitute` --- src/ModelingToolkit.jl | 2 +- src/inputoutput.jl | 2 +- .../StructuralTransformations.jl | 2 +- src/structural_transformation/pantelides.jl | 2 +- .../symbolics_tearing.jl | 10 +++--- src/systems/abstractsystem.jl | 20 ++++++------ src/systems/alias_elimination.jl | 2 +- src/systems/callbacks.jl | 31 +++++++++---------- src/systems/codegen.jl | 4 +-- src/systems/diffeqs/basic_transformations.jl | 2 +- src/systems/imperative_affect.jl | 8 ++--- src/systems/optimal_control_interface.jl | 10 +++--- src/systems/solver_nlprob.jl | 2 +- src/systems/system.jl | 2 +- src/systems/systems.jl | 2 +- src/systems/systemstructure.jl | 6 ++-- src/utils.jl | 2 +- test/optimizationsystem.jl | 2 +- 18 files changed, 54 insertions(+), 57 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 0c823c1b4e..69c6486eff 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -79,7 +79,7 @@ using Symbolics: parse_vars, value, @derivatives, get_variables, exprs_occur_in, symbolic_linear_solve, unwrap, wrap, VariableSource, getname, variable, NAMESPACE_SEPARATOR, setdefaultval, - hasnode, fixpoint_sub, fast_substitute, + hasnode, fixpoint_sub, CallWithMetadata, CallWithParent const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) import Symbolics: rename, get_variables!, _solve, hessian_sparsity, diff --git a/src/inputoutput.jl b/src/inputoutput.jl index c113c4e753..8d4772ef64 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -312,7 +312,7 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) @set! structure.graph = complete(new_graph) @set! sys.eqs = isempty(input_to_parameters) ? equations(sys) : - fast_substitute(equations(sys), input_to_parameters) + substitute(equations(sys), input_to_parameters) @set! sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters)) ps = parameters(sys) diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 7c874dd7c4..e407678c18 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -3,7 +3,7 @@ module StructuralTransformations using Setfield: @set!, @set using UnPack: @unpack -using Symbolics: unwrap, linear_expansion, fast_substitute +using Symbolics: unwrap, linear_expansion import Symbolics using SymbolicUtils using SymbolicUtils.Code diff --git a/src/structural_transformation/pantelides.jl b/src/structural_transformation/pantelides.jl index 0a215c0714..47fa5aa762 100644 --- a/src/structural_transformation/pantelides.jl +++ b/src/structural_transformation/pantelides.jl @@ -54,7 +54,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) D(eq.lhs) end rhs = ModelingToolkit.expand_derivatives(D(eq.rhs)) - rhs = fast_substitute(rhs, state.param_derivative_map) + rhs = substitute(rhs, state.param_derivative_map) substitution_dict = Dict(x.lhs => x.rhs for x in out_eqs if x !== nothing && x.lhs isa SymbolicT) sub_rhs = substitute(rhs, substitution_dict) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 1d4539968d..02b3b7ada1 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -65,7 +65,7 @@ function eq_derivative!(ts::TearingState, ieq::Int; kwargs...) sys = ts.sys eq = equations(ts)[ieq] - eq = 0 ~ fast_substitute( + eq = 0 ~ substitute( ModelingToolkit.derivative( eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true), ts.param_derivative_map) @@ -217,7 +217,7 @@ function substitute_derivatives_algevars!( v_t = setio(diff2term_with_unit(unwrap(dd), unwrap(iv)), false, false) for eq in 𝑑neighbors(graph, dv) dummy_sub[dd] = v_t - neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t) + neweqs[eq] = substitute(neweqs[eq], dd => v_t) end fullvars[dv] = v_t # If we have: @@ -230,7 +230,7 @@ function substitute_derivatives_algevars!( while (ddx = var_to_diff[dx]) !== nothing dx_t = D(x_t) for eq in 𝑑neighbors(graph, ddx) - neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t) + neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dx_t) end fullvars[ddx] = dx_t dx = ddx @@ -961,8 +961,8 @@ function update_simplified_system!( obs_sub[eq.lhs] = eq.rhs end # TODO: compute the dependency correctly so that we don't have to do this - obs = [fast_substitute(observed(sys), obs_sub); solved_eqs; - fast_substitute(state.additional_observed, obs_sub)] + obs = [substitute(observed(sys), obs_sub); solved_eqs; + substitute(state.additional_observed, obs_sub)] unknown_idxs = filter( i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars)) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index c18011f284..d057e7cfe5 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -526,7 +526,7 @@ function (f::Initial)(x) return result end -# This is required so `fast_substitute` works +# This is required so `substitute` works function SymbolicUtils.maketerm(::Type{<:BasicSymbolic}, ::Initial, args, meta) val = Initial()(args...) if symbolic_type(val) == NotSymbolic() @@ -2733,26 +2733,26 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair}, elseif sys isa System rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]), collect(rules))) - newsys = @set sys.eqs = fast_substitute(get_eqs(sys), rules) + newsys = @set sys.eqs = substitute(get_eqs(sys), rules) @set! newsys.unknowns = map(get_unknowns(sys)) do var get(rules, var, var) end @set! newsys.ps = map(get_ps(sys)) do var get(rules, var, var) end - @set! newsys.parameter_dependencies = fast_substitute( + @set! newsys.parameter_dependencies = substitute( get_parameter_dependencies(sys), rules) - @set! newsys.defaults = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) + @set! newsys.defaults = Dict(substitute(k, rules) => substitute(v, rules) for (k, v) in get_defaults(sys)) - @set! newsys.guesses = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) + @set! newsys.guesses = Dict(substitute(k, rules) => substitute(v, rules) for (k, v) in get_guesses(sys)) - @set! newsys.noise_eqs = fast_substitute(get_noise_eqs(sys), rules) - @set! newsys.costs = Vector{Union{Real, BasicSymbolic}}(fast_substitute( + @set! newsys.noise_eqs = substitute(get_noise_eqs(sys), rules) + @set! newsys.costs = Vector{Union{Real, BasicSymbolic}}(substitute( get_costs(sys), rules)) - @set! newsys.observed = fast_substitute(get_observed(sys), rules) - @set! newsys.initialization_eqs = fast_substitute( + @set! newsys.observed = substitute(get_observed(sys), rules) + @set! newsys.initialization_eqs = substitute( get_initialization_eqs(sys), rules) - @set! newsys.constraints = fast_substitute(get_constraints(sys), rules) + @set! newsys.constraints = substitute(get_constraints(sys), rules) @set! newsys.systems = map(s -> substitute(s, rules), get_systems(sys)) else error("substituting symbols is not supported for $(typeof(sys))") diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index dc25378b4a..40114c50a4 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -97,7 +97,7 @@ function alias_elimination!(state::TearingState; kwargs...) nvs_orig = ndsts(graph_orig) for ieq in eqs_to_update eq = eqs[ieq] - eqs[ieq] = fast_substitute(eq, subs) + eqs[ieq] = substitute(eq, subs) end @set! mm.nparentrows = nsrcs(graph) @set! mm.row_cols = eltype(mm.row_cols)[mm.row_cols[i] diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 69fb702072..354a661bc6 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -25,10 +25,8 @@ function SymbolicAffect(affect::SymbolicAffect; kwargs...) end SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...) -function Symbolics.fast_substitute(aff::SymbolicAffect, rules) - substituter = Base.Fix2(fast_substitute, rules) - SymbolicAffect(map(substituter, aff.affect), map(substituter, aff.alg_eqs), - map(substituter, aff.discrete_parameters)) +function (s::SymbolicUtils.Substituter)(aff::SymbolicAffect) + SymbolicAffect(s(aff.affect), s(aff.alg_eqs), s(aff.discrete_parameters)) end struct AffectSystem @@ -42,17 +40,16 @@ struct AffectSystem discretes::Vector end -function Symbolics.fast_substitute(aff::AffectSystem, rules) - substituter = Base.Fix2(fast_substitute, rules) +function (s::SymbolicUtils.Substituter)(aff::AffectSystem) sys = aff.system - @set! sys.eqs = map(substituter, get_eqs(sys)) - @set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys)) - @set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)]) - @set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)]) - @set! sys.unknowns = map(substituter, get_unknowns(sys)) - @set! sys.ps = map(substituter, get_ps(sys)) - AffectSystem(sys, map(substituter, aff.unknowns), - map(substituter, aff.parameters), map(substituter, aff.discretes)) + @set! sys.eqs = s(get_eqs(sys)) + @set! sys.parameter_dependencies = (get_parameter_dependencies(sys)) + @set! sys.defaults = Dict([k => s(v) for (k, v) in defaults(sys)]) + @set! sys.guesses = Dict([k => s(v) for (k, v) in guesses(sys)]) + @set! sys.unknowns = s(get_unknowns(sys)) + @set! sys.ps = s(get_ps(sys)) + AffectSystem(sys, s(aff.unknowns), s(aff.parameters), s(aff.discretes)) + end function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...) @@ -103,8 +100,8 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], rev_map = Dict(zip(discrete_parameters, discretes)) subs = merge(rev_map, Dict(zip(dvs, _dvs))) - affect = Symbolics.fast_substitute(affect, subs) - alg_eqs = Symbolics.fast_substitute(alg_eqs, subs) + affect = substitute(affect, subs) + alg_eqs = substitute(alg_eqs, subs) @named affectsys = System( vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)), @@ -898,7 +895,7 @@ function compile_equational_affect( obseqs, eqs = unhack_observed(observed(affsys), equations(affsys)) if isempty(equations(affsys)) - update_eqs = Symbolics.fast_substitute( + update_eqs = substitute( obseqs, Dict([p => unPre(p) for p in parameters(affsys)])) rhss = map(x -> x.rhs, update_eqs) lhss = map(x -> x.lhs, update_eqs) diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl index 2687fedb80..b5230e7454 100644 --- a/src/systems/codegen.jl +++ b/src/systems/codegen.jl @@ -562,7 +562,7 @@ function generate_boundary_conditions(sys::System, u0, u0_idxs, t0; expression = cons = [con.lhs - con.rhs for con in constraints(sys)] # conssubs = Dict() # get_constraint_unknown_subs!(conssubs, cons, stidxmap, iv, sol) - # cons = map(x -> fast_substitute(x, conssubs), cons) + # cons = map(x -> substitute(x, conssubs), cons) init_conds = Any[] for i in u0_idxs @@ -1066,7 +1066,7 @@ function build_explicit_observed_function(sys, ts; Base.throw(ArgumentError("Symbol $var is not present in the system.")) end end - ts = fast_substitute(ts, namespace_subs) + ts = substitute(ts, namespace_subs) obsfilter = if param_only if is_split(sys) diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index 200e0d3d56..e2eee33bae 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -1051,7 +1051,7 @@ function respecialize(sys::AbstractSystem, mapping; all = false) subrules[unwrap(k)] = unwrap(new_p) end - substituter = Base.Fix2(fast_substitute, subrules) + substituter = Base.Fix2(substitute, subrules) @set! sys.eqs = map(substituter, get_eqs(sys)) @set! sys.observed = map(substituter, get_observed(sys)) @set! sys.initialization_eqs = map(substituter, get_initialization_eqs(sys)) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 1c43022f4b..6579ad63cc 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -67,10 +67,10 @@ function ImperativeAffect(; f, kwargs...) ImperativeAffect(f; kwargs...) end -function Symbolics.fast_substitute(aff::ImperativeAffect, rules) - substituter = Base.Fix2(fast_substitute, rules) - ImperativeAffect(aff.f, map(substituter, aff.obs), aff.obs_syms, - map(substituter, aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks) +function (s::SymbolicUtils.Substituter)(aff::ImperativeAffect) + ImperativeAffect(aff.f, s(aff.obs), aff.obs_syms, + s(aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks) + end function Base.show(io::IO, mfa::ImperativeAffect) diff --git a/src/systems/optimal_control_interface.jl b/src/systems/optimal_control_interface.jl index 5a0ddbf8d5..108eb05893 100644 --- a/src/systems/optimal_control_interface.jl +++ b/src/systems/optimal_control_interface.jl @@ -357,16 +357,16 @@ function substitute_model_vars(model, sys, exprs, tspan) t = get_iv(sys) exprs = map( - c -> Symbolics.fast_substitute(c, whole_t_map(model, t, x_ops, c_ops)), exprs) + c -> substitute(c, whole_t_map(model, t, x_ops, c_ops)), exprs) (ti, tf) = tspan if symbolic_type(tf) === ScalarSymbolic() _tf = model.tₛ + ti exprs = map( - c -> Symbolics.fast_substitute(c, free_t_map(model, tf, x_ops, c_ops)), exprs) - exprs = map(c -> Symbolics.fast_substitute(c, Dict(tf => _tf)), exprs) + c -> substitute(c, free_t_map(model, tf, x_ops, c_ops)), exprs) + exprs = map(c -> substitute(c, Dict(tf => _tf)), exprs) end - exprs = map(c -> Symbolics.fast_substitute(c, fixed_t_map(model, x_ops, c_ops)), exprs) + exprs = map(c -> substitute(c, fixed_t_map(model, x_ops, c_ops)), exprs) exprs end @@ -440,7 +440,7 @@ end function substitute_toterm(vars, exprs) toterm_map = Dict([u => default_toterm(value(u)) for u in vars]) - exprs = map(c -> Symbolics.fast_substitute(c, toterm_map), exprs) + exprs = map(c -> substitute(c, toterm_map), exprs) end function substitute_params(pmap, exprs) diff --git a/src/systems/solver_nlprob.jl b/src/systems/solver_nlprob.jl index badfe21efb..d4772018c1 100644 --- a/src/systems/solver_nlprob.jl +++ b/src/systems/solver_nlprob.jl @@ -55,7 +55,7 @@ function inner_nlsystem(sys::System, mm, nlstep_compile::Bool) subrules = Dict([v => unwrap(gamma2*v + inner_tmp[i]) for (i, v) in enumerate(dvs)]) subrules[t] = unwrap(c) - new_rhss = map(Base.Fix2(fast_substitute, subrules), rhss) + new_rhss = map(Base.Fix2(substitute, subrules), rhss) new_rhss = collect(outer_tmp) .+ gamma1 .* new_rhss .- gamma3 * mm * dvs new_eqs = [0 ~ rhs for rhs in new_rhss] diff --git a/src/systems/system.jl b/src/systems/system.jl index 74e233791a..5f8b33216f 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -899,7 +899,7 @@ function NonlinearSystem(sys::System) subrules[var] = 0.0 end eqs = map(eqs) do eq - fast_substitute(eq, subrules) + substitute(eq, subrules) end nsys = System(eqs, unknowns(sys), [parameters(sys); get_iv(sys)]; defaults = merge(defaults(sys), Dict(get_iv(sys) => Inf)), guesses = guesses(sys), diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 4c52300239..6aa4166d03 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -195,7 +195,7 @@ function simplify_optimization_system(sys::System; split = true, kwargs...) dvs[i] = irrvar end end - econs = fast_substitute.(econs, (irreducible_subs,)) + econs = substitute.(econs, (irreducible_subs,)) nlsys = System(econs, dvs, parameters(sys); name = :___tmp_nlsystem) snlsys = mtkcompile(nlsys; kwargs..., fully_determined = false) obs = observed(snlsys) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index e9d68e162f..37706c616f 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -408,7 +408,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) if iscall(eq.lhs) && (op = operation(eq.lhs)) isa Differential && isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq.lhs)), ps, iv) # parameter derivatives are opted out by specifying `D(p) ~ missing`, but - # we want to store `nothing` in the map because that means `fast_substitute` + # we want to store `nothing` in the map because that means `substitute` # will ignore the rule. We will this identify the presence of `eq′.lhs` in # the differentiated expression and error. param_derivative_map[eq.lhs] = coalesce(eq.rhs, nothing) @@ -751,11 +751,11 @@ function shift_discrete_system(ts::TearingState) if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold, Pre})) for i in eachindex(fullvars) - fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute( + fullvars[i] = StructuralTransformations.simplify_shifts(substitute( fullvars[i], discmap; operator = Union{Sample, Hold, Pre})) end for i in eachindex(eqs) - eqs[i] = StructuralTransformations.simplify_shifts(fast_substitute( + eqs[i] = StructuralTransformations.simplify_shifts(substitute( eqs[i], discmap; operator = Union{Sample, Hold, Pre})) end @set! ts.sys.eqs = eqs diff --git a/src/utils.jl b/src/utils.jl index e89d7f1d3a..b55a8ea0ae 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -687,7 +687,7 @@ function collect_var!(unknowns, parameters, var, iv; depth = 0) Encountered a wrapped value in `collect_var!`. This function should only ever \ receive unwrapped symbolic variables. This is likely a bug in the code generating \ an expression passed to `collect_vars!` or `collect_scoped_vars!`. A common cause \ - is using `substitute` or `fast_substitute` with rules where the values are \ + is using `substitute` with rules where the values are \ wrapped symbolic variables. """) end diff --git a/test/optimizationsystem.jl b/test/optimizationsystem.jl index 30e80e1cec..c29b68b9be 100644 --- a/test/optimizationsystem.jl +++ b/test/optimizationsystem.jl @@ -399,7 +399,7 @@ end prob = OptimizationProblem(sys, [x => [42.0, 12.37]]; hess = true, sparse = true) symbolic_hess = Symbolics.hessian(cost(sys), x) - symbolic_hess_value = Symbolics.fast_substitute(symbolic_hess, Dict(x[1] => prob[x[1]], x[2] => prob[x[2]])) + symbolic_hess_value = substitute(symbolic_hess, Dict(x[1] => prob[x[1]], x[2] => prob[x[2]])) oop_hess = prob.f.hess(prob.u0, prob.p) @test oop_hess ≈ symbolic_hess_value From 2cb147bb3d6101afec9c26f9926a087d348f2628 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 18:43:12 +0530 Subject: [PATCH 014/157] refactor: remove usages of `CallWithMetadata` --- src/ModelingToolkit.jl | 3 +-- src/parameters.jl | 10 +++------- src/systems/index_cache.jl | 16 ++++++---------- src/systems/validation.jl | 4 ++-- src/variables.jl | 2 +- test/model_parsing.jl | 2 +- 6 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 69c6486eff..144770269d 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -79,8 +79,7 @@ using Symbolics: parse_vars, value, @derivatives, get_variables, exprs_occur_in, symbolic_linear_solve, unwrap, wrap, VariableSource, getname, variable, NAMESPACE_SEPARATOR, setdefaultval, - hasnode, fixpoint_sub, - CallWithMetadata, CallWithParent + hasnode, fixpoint_sub, CallAndWrap const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) import Symbolics: rename, get_variables!, _solve, hessian_sparsity, jacobian_sparsity, isaffine, islinear, _iszero, _isone, diff --git a/src/parameters.jl b/src/parameters.jl index d5d96120b8..59dede20e6 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -46,17 +46,13 @@ end function iscalledparameter(x) x = unwrap(x) - return isparameter(getmetadata(x, CallWithParent, nothing)) + return SymbolicUtils.is_called_function_symbolic(x) && isparameter(operation(x)) end function getcalledparameter(x) x = unwrap(x) - # `parent` is a `CallWithMetadata` with the correct metadata, - # but no namespacing. `operation(x)` has the correct namespacing, - # but is not a `CallWithMetadata` and doesn't have any metadata. - # This approach combines both. - parent = getmetadata(x, CallWithParent) - return CallWithMetadata(operation(x), metadata(parent)) + @assert iscalledparameter(x) + return operation(x) end """ diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index b5575ea628..c861a2e1c3 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -27,15 +27,14 @@ struct DiscreteIndex end const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}} -const NonnumericMap = Dict{ - Union{BasicSymbolic, Symbolics.CallWithMetadata}, Tuple{Int, Int}} +const NonnumericMap = Dict{SymbolicT, Tuple{Int, Int}} const UnknownIndexMap = Dict{ BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}} const TunableIndexMap = Dict{BasicSymbolic, Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}} const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}} -const SymbolicParam = Union{BasicSymbolic, CallWithMetadata} +const SymbolicParam = SymbolicT struct IndexCache unknown_idx::UnknownIndexMap @@ -48,8 +47,7 @@ struct IndexCache constant_idx::ParamIndexMap nonnumeric_idx::NonnumericMap observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType} - dependent_pars_to_timeseries::Dict{ - Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType} + dependent_pars_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_size::BufferTemplate initials_buffer_size::BufferTemplate @@ -307,8 +305,7 @@ function IndexCache(sys::AbstractSystem) end end - dependent_pars_to_timeseries = Dict{ - Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}() + dependent_pars_to_timeseries = Dict{SymbolicT, TimeseriesSetType}() for eq in get_parameter_dependencies(sys) sym = eq.lhs @@ -515,10 +512,9 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(sum(x -> x.length, temp))] for temp in ic.discrete_buffer_sizes) - const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] + const_buf = Tuple(SymbolicT[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.constant_buffer_sizes) - nonnumeric_buf = Tuple(Union{BasicSymbolic, CallWithMetadata}[unwrap(variable(:DEF)) - for _ in 1:(temp.length)] + nonnumeric_buf = Tuple(SymbolicT[unwrap(variable(:DEF)) for _ in 1:(temp.length)] for temp in ic.nonnumeric_buffer_sizes) for p in ps p = unwrap(p) diff --git a/src/systems/validation.jl b/src/systems/validation.jl index cdf75a1631..7d8f39e0ec 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -6,7 +6,7 @@ using ..ModelingToolkit: ValidationError, get_systems, Conditional, Comparison using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump -using Symbolics: SymbolicT, value, issym, isadd, ismul, ispow +using Symbolics: SymbolicT, value, issym, isadd, ismul, ispow, CallAndWrap const MT = ModelingToolkit Base.:*(x::Union{Num, SymbolicT}, y::Unitful.AbstractQuantity) = x * y @@ -49,7 +49,7 @@ get_unit(x::Real) = unitless get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x)) get_unit(x::AbstractArray) = map(get_unit, x) get_unit(x::Num) = get_unit(value(x)) -function get_unit(x::Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata}) +function get_unit(x::Union{Symbolics.Arr, CallAndWrap}) get_literal_unit(x) end get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) diff --git a/src/variables.jl b/src/variables.jl index 8089b9bffd..1d5475c12c 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -682,7 +682,7 @@ end function (A::EvalAt)(x::SymbolicT) if symbolic_type(x) == NotSymbolic() || !iscall(x) - if x isa Symbolics.CallWithMetadata + if x isa CallAndWrap return x(A.t) else return x diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 2c713d4149..73a54b7c51 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -1007,7 +1007,7 @@ end vars = Symbolics.get_variables(only(equations(ex))) @test length(vars) == 2 for u in Symbolics.unwrap.(unknowns(ex)) - @test !Symbolics.hasmetadata(u, Symbolics.CallWithParent) + @test !SymbolicUtils.is_function_symbolic(u) @test any(isequal(u), vars) end end From 36ea1d8cb688710e6c3bae056835c0f0d8a7e635 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 19:09:04 +0530 Subject: [PATCH 015/157] refactor: remove usages of old `symtype` syntax --- docs/src/basics/FAQ.md | 4 +--- src/ModelingToolkit.jl | 6 +++--- src/discretedomain.jl | 4 ++-- .../StructuralTransformations.jl | 2 +- src/systems/abstractsystem.jl | 2 +- src/systems/codegen_utils.jl | 4 ++-- src/systems/diffeqs/basic_transformations.jl | 8 ++++---- src/systems/system.jl | 3 ++- src/systems/systemstructure.jl | 2 +- src/utils.jl | 4 ++-- test/model_parsing.jl | 3 +-- test/odesystem.jl | 2 +- test/sciml_problem_inputs.jl | 4 ++-- test/simplify.jl | 4 +++- test/variable_parsing.jl | 15 ++++++++------- 15 files changed, 34 insertions(+), 33 deletions(-) diff --git a/docs/src/basics/FAQ.md b/docs/src/basics/FAQ.md index 2511c00675..7b712395b3 100644 --- a/docs/src/basics/FAQ.md +++ b/docs/src/basics/FAQ.md @@ -192,9 +192,7 @@ p, replace, alias = SciMLStructures.canonicalize(Tunable(), prob.p) # changes to the array will be reflected in parameter values ``` -See the [basic example on optimizing](https://docs.sciml.ai/ModelingToolkit/dev/examples/remake/#Optimizing-through-an-ODE-solve-and-re-creating-MTK-Problems) for combining these steps to optimizing parameters and use ForwardDiff.jl as the backend for Automatic Differentiation. - -# ERROR: ArgumentError: SymbolicUtils.BasicSymbolic{Real}[xˍt(t)] are missing from the variable map. +# ERROR: ArgumentError: `[xˍt(t)]` are missing from the variable map. This error can come up after running `mtkcompile` on a system that generates dummy derivatives (i.e. variables with `ˍt`). For example, here even though all the variables are defined with initial values, the `ODEProblem` generation will throw an error that defaults are missing from the variable map. diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 144770269d..6828228d16 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -9,10 +9,10 @@ using PrecompileTools, Reexport end import SymbolicUtils +import SymbolicUtils as SU import SymbolicUtils: iscall, arguments, operation, maketerm, promote_symtype, isadd, ismul, ispow, issym, FnType, - @rule, Rewriters, substitute, metadata, BasicSymbolic, - Sym, Term + @rule, Rewriters, substitute, metadata, BasicSymbolic using SymbolicUtils.Code import SymbolicUtils.Code: toexpr import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint @@ -79,7 +79,7 @@ using Symbolics: parse_vars, value, @derivatives, get_variables, exprs_occur_in, symbolic_linear_solve, unwrap, wrap, VariableSource, getname, variable, NAMESPACE_SEPARATOR, setdefaultval, - hasnode, fixpoint_sub, CallAndWrap + hasnode, fixpoint_sub, CallAndWrap, SArgsT, SSym, STerm const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) import Symbolics: rename, get_variables!, _solve, hessian_sparsity, jacobian_sparsity, isaffine, islinear, _iszero, _isone, diff --git a/src/discretedomain.jl b/src/discretedomain.jl index ffa36c9dbd..eb47eecec9 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -162,7 +162,7 @@ function Sample(arg::Real) Sample()(arg) end end -(D::Sample)(x) = Term{symtype(x)}(D, Any[x]) +(D::Sample)(x) = STerm(D, SArgsT((x,)); type = symtype(x), shape = SU.shape(x)) (D::Sample)(x::Num) = Num(D(value(x))) SymbolicUtils.promote_symtype(::Sample, x) = x Base.nameof(::Sample) = :Sample @@ -208,7 +208,7 @@ end is_transparent_operator(::Type{Hold}) = true -(D::Hold)(x) = Term{symtype(x)}(D, Any[x]) +(D::Hold)(x) = STerm(D, SArgsT((x,)); type = symtype(x), shape = SU.shape(x)) (D::Hold)(x::Num) = Num(D(value(x))) SymbolicUtils.promote_symtype(::Hold, x) = x Base.nameof(::Hold) = :Hold diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index e407678c18..90883ce7e6 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -14,7 +14,7 @@ using ModelingToolkit using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Differential, unknowns, equations, vars, SymbolicT, diff2term_with_unit, shift2term_with_unit, value, - operation, arguments, Sym, Term, simplify, symbolic_linear_solve, + operation, arguments, simplify, symbolic_linear_solve, isdiffeq, isdifferential, isirreducible, empty_substitutions, get_substitutions, get_tearing_state, get_iv, independent_variables, diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d057e7cfe5..b07f23e8f2 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -69,7 +69,7 @@ function wrap_assignments(isscalar, assignments; let_block = false) end end -const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___) +const MTKPARAMETERS_ARG = SSym(:___mtkparameters___; type = Vector{Vector{Any}}, shape = SymbolicUtils.Unknown(1)) """ $(TYPEDSIGNATURES) diff --git a/src/systems/codegen_utils.jl b/src/systems/codegen_utils.jl index dbbd7f85a8..b77bba98b4 100644 --- a/src/systems/codegen_utils.jl +++ b/src/systems/codegen_utils.jl @@ -135,8 +135,8 @@ end """ The argument of generated functions corresponding to the history function. """ -const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___) -const BVP_SOLUTION = Sym{Symbolics.FnType{Tuple{<:Real}, Vector{Real}}}(:__sol__) +const DDE_HISTORY_FUN = SSym(:___history___; type = SU.FnType{Tuple{Any, <:Real}, Vector{Real}}, shape = SU.Unknown(1)) +const BVP_SOLUTION = SSym(:__sol__; type = Symbolics.FnType{Tuple{<:Real}, Vector{Real}}, shape = SU.Unknown(1)) """ $(TYPEDSIGNATURES) diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index e2eee33bae..4cc71db108 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -470,7 +470,7 @@ julia> M = change_independent_variable(M, x); julia> M = mtkcompile(M; allow_symbolic = true); julia> unknowns(M) -3-element Vector{SymbolicUtils.BasicSymbolic{Real}}: +3-element Vector{Symbolics.SymbolicsT}: xˍt(x) y(x) yˍx(x) @@ -1037,13 +1037,13 @@ function respecialize(sys::AbstractSystem, mapping; all = false) """ if iscall(k) - op = operation(k)::BasicSymbolic + op = operation(k)::SymbolicT @assert !iscall(op) - op = SymbolicUtils.Sym{SymbolicUtils.FnType{Tuple{Any}, T}}(nameof(op)) + op = SU.Sym{VartypeT}(nameof(op); type = SU.FnType{Tuple, T, Nothing}, shape = SU.shape(k)) args = arguments(k) new_p = op(args...) else - new_p = SymbolicUtils.Sym{T}(getname(k)) + new_p = SSym(getname(k); type = T, shape = SU.shape(v)) end get_ps(sys)[idx] = new_p diff --git a/src/systems/system.jl b/src/systems/system.jl index 5f8b33216f..83f70ee0b6 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -94,7 +94,7 @@ struct System <: IntermediateDeprecationSystem The independent variable for a time-dependent system, or `nothing` for a time-independent system. """ - iv::Union{Nothing, BasicSymbolic{Real}} + iv::Union{Nothing, SymbolicT} """ Equations that compute variables of a system that have been eliminated from the set of unknowns by `mtkcompile`. More generally, this contains all variables that can be @@ -278,6 +278,7 @@ struct System <: IntermediateDeprecationSystem variable $iv. """)) end + @assert iv === nothing || symtype(iv) === Real jumps = Vector{JumpType}(jumps) if (checks == true || (checks & CheckComponents) > 0) && iv !== nothing check_independent_variables([iv]) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 37706c616f..fc7b6ec5ab 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -846,7 +846,7 @@ function Base.show(io::IO, mime::MIME"text/plain", s::SystemStructure) " variables\n") Base.print_matrix(io, SystemStructurePrintMatrix(s)) else - S = incidence_matrix(s.graph, Num(Sym{Real}(:×))) + S = incidence_matrix(s.graph, Num(SSym(:×; type = Real, shape = SU.ShapeVecT()))) print(io, "Incidence matrix:") show(io, mime, S) end diff --git a/src/utils.jl b/src/utils.jl index b55a8ea0ae..0a7c9e0bb4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -20,7 +20,7 @@ function detime_dvs(op) if !iscall(op) op elseif issym(operation(op)) - Sym{Real}(nameof(operation(op))) + SSym(nameof(operation(op)); type = Real, shape = SU.ShapeVecT()) else maketerm(typeof(op), operation(op), detime_dvs.(arguments(op)), metadata(op)) @@ -33,7 +33,7 @@ end Reverse `detime_dvs` for the given `dvs` using independent variable `iv`. """ function retime_dvs(op, dvs, iv) - issym(op) && return Sym{FnType{Tuple{symtype(iv)}, Real}}(nameof(op))(iv) + issym(op) && return SSym(nameof(op); type = FnType{Tuple{symtype(iv)}, Real}, shape = SU.ShapeVecT())(iv) iscall(op) ? maketerm(typeof(op), operation(op), retime_dvs.(arguments(op), (dvs,), (iv,)), metadata(op)) : diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 73a54b7c51..eefd155db6 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -201,8 +201,7 @@ resistor = getproperty(rc, :resistor; namespace = false) @named pi_model = PiModel() - @test typeof(ModelingToolkit.getdefault(pi_model.p)) <: - SymbolicUtils.BasicSymbolic{Irrational} + @test symtype(ModelingToolkit.getdefault(pi_model.p)) <: Irrational @test getdefault(getdefault(pi_model.p)) == π end diff --git a/test/odesystem.jl b/test/odesystem.jl index cf16b31a42..417eaa70a7 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -535,7 +535,7 @@ sys = complete(sys) us = map(s -> (@variables $s(t))[1], syms) ps = map(s -> (@variables $s(t))[1], syms_p) buffer, = @variables $buffername[1:length(u0)] - dummy_var = Sym{Any}(:_) # this is safe because _ cannot be a rvalue in Julia + dummy_var = Symbolics.SSym(:_; type = Any) # this is safe because _ cannot be a rvalue in Julia ss = Iterators.flatten((us, ps)) vv = Iterators.flatten((u0, p0)) diff --git a/test/sciml_problem_inputs.jl b/test/sciml_problem_inputs.jl index a91a8d8c7c..cdb83ad0e2 100644 --- a/test/sciml_problem_inputs.jl +++ b/test/sciml_problem_inputs.jl @@ -2,7 +2,7 @@ # Fetch packages using ModelingToolkit, JumpProcesses, NonlinearSolve, OrdinaryDiffEq, StaticArrays, - SteadyStateDiffEq, StochasticDiffEq, SciMLBase, Test + SteadyStateDiffEq, StochasticDiffEq, SciMLBase, Test, SymbolicUtils using ModelingToolkit: t_nounits as t, D_nounits as D # Sets rnd number. @@ -29,7 +29,7 @@ begin ] noise_eqs = fill(0.01, 3, 6) jumps = [ - MassActionJump(kp, Pair{Symbolics.BasicSymbolic{Real}, Int64}[], [X => 1]), + MassActionJump(kp, Pair{Symbolics.SymbolicT, Int64}[], [X => 1]), MassActionJump(kd, [X => 1], [X => -1]), MassActionJump(k1, [X => 1], [X => -1, Y => 1]), MassActionJump(k2, [Y => 1], [X => 1, Y => -1]), diff --git a/test/simplify.jl b/test/simplify.jl index 4252e3262e..6968883132 100644 --- a/test/simplify.jl +++ b/test/simplify.jl @@ -1,5 +1,7 @@ using ModelingToolkit using ModelingToolkit: value +using Symbolics: STerm +import SymbolicUtils using Test @independent_variables t @@ -11,7 +13,7 @@ null_op = 0 * t one_op = 1 * t @test isequal(simplify(one_op), t) -identity_op = Num(Term(identity, [value(x)])) +identity_op = Num(STerm(identity, [value(x)]; type = Real, shape = SymbolicUtils.ShapeVecT())) @test isequal(simplify(identity_op), x) minus_op = -x diff --git a/test/variable_parsing.jl b/test/variable_parsing.jl index 60b4e24d64..a00aa2f729 100644 --- a/test/variable_parsing.jl +++ b/test/variable_parsing.jl @@ -2,15 +2,16 @@ using ModelingToolkit using Test using ModelingToolkit: value, Flow -using SymbolicUtils: FnType +using Symbolics: SSym +using SymbolicUtils: FnType, ShapeVecT @independent_variables t @variables x(t) y(t) # test multi-arg @variables z(t) # test single-arg -x1 = Num(Sym{FnType{Tuple{Any}, Real}}(:x)(value(t))) -y1 = Num(Sym{FnType{Tuple{Any}, Real}}(:y)(value(t))) -z1 = Num(Sym{FnType{Tuple{Any}, Real}}(:z)(value(t))) +x1 = Num(SSym(:x; type = FnType{Tuple{Any}, Real}, shape = ShapeVecT())(value(t))) +y1 = Num(SSym(:y; type = FnType{Tuple{Any}, Real}, shape = ShapeVecT())(value(t))) +z1 = Num(SSym(:z; type = FnType{Tuple{Any}, Real}, shape = ShapeVecT())(value(t))) @test isequal(x1, x) @test isequal(y1, y) @@ -22,9 +23,9 @@ z1 = Num(Sym{FnType{Tuple{Any}, Real}}(:z)(value(t))) end @parameters σ(..) -t1 = Num(Sym{Real}(:t)) -s1 = Num(Sym{Real}(:s)) -σ1 = Num(Sym{FnType{Tuple, Real}}(:σ)) +t1 = Num(SSym(:t; type = Real, shape = ShapeVecT())) +s1 = Num(SSym(:s; type = Real, shape = ShapeVecT())) +σ1 = Num(SSym(:σ; type = FnType{Tuple, Real}, shape = ShapeVecT())) @test isequal(t1, t) @test isequal(s1, s) @test isequal(σ1(t), σ(t)) From 51440fd46acabfd35bdf9976496013998a946521 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 19:14:28 +0530 Subject: [PATCH 016/157] refactor: remove usages of `Difference` --- src/systems/unit_check.jl | 1 - src/utils.jl | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index eac27b58be..46d765e604 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -71,7 +71,6 @@ get_unit(x::AbstractArray) = map(get_unit, x) get_unit(x::Num) = get_unit(unwrap(x)) get_unit(x::Symbolics.Arr) = get_unit(unwrap(x)) get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) -get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t) get_unit(op::typeof(getindex), args) = get_unit(args[1]) get_unit(x::SciMLBase.NullParameters) = unitless get_unit(op::typeof(instream), args) = get_unit(args[1]) diff --git a/src/utils.jl b/src/utils.jl index 0a7c9e0bb4..667e79bfa3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -329,9 +329,7 @@ end Throw error when difference/derivative operation occurs in the R.H.S. """ @noinline function throw_invalid_operator(opvar, eq, op::Type) - if op === Difference - error("The Difference operator is deprecated, use ShiftIndex instead") - elseif op === Differential + if op === Differential optext = "derivative" end msg = "The $optext variable must be isolated to the left-hand " * From e0b6e72d75ec8cb222fd09463c14b9a95b9a6750 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 19:23:23 +0530 Subject: [PATCH 017/157] refactor: remove usages of `Symbolics._mapreduce` --- src/systems/unit_check.jl | 2 +- src/systems/validation.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index 46d765e604..e3a113e691 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -113,7 +113,7 @@ function get_unit(op::Conditional, args) return terms[2] end -function get_unit(op::typeof(Symbolics._mapreduce), args) +function get_unit(op::typeof(mapreduce), args) if args[2] == + get_unit(args[3]) else diff --git a/src/systems/validation.jl b/src/systems/validation.jl index 7d8f39e0ec..bb2c622d92 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -89,7 +89,7 @@ function get_unit(op::Conditional, args) return terms[2] end -function get_unit(op::typeof(Symbolics._mapreduce), args) +function get_unit(op::typeof(mapreduce), args) if args[2] == + get_unit(args[3]) else From eafd2daba7c46754b48c53fe386210533d9089e5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 19:33:25 +0530 Subject: [PATCH 018/157] refactor: remove usages of `Symbolics.getparent` --- src/ModelingToolkit.jl | 2 +- src/parameters.jl | 4 ++-- src/systems/unit_check.jl | 4 ++-- src/systems/validation.jl | 4 ++-- src/utils.jl | 20 ++++++-------------- src/variables.jl | 34 +++++++++++++--------------------- 6 files changed, 26 insertions(+), 42 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6828228d16..223c6096bf 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -88,7 +88,7 @@ import Symbolics: rename, get_variables!, _solve, hessian_sparsity, ParallelForm, SerialForm, MultithreadedForm, build_function, rhss, lhss, prettify_expr, gradient, jacobian, hessian, derivative, sparsejacobian, sparsehessian, - scalarize, getparent, hasderiv + scalarize, hasderiv import DiffEqBase: @add_kwonly export independent_variables, unknowns, observables, parameters, full_parameters, diff --git a/src/parameters.jl b/src/parameters.jl index 59dede20e6..7bb76d7bf0 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -28,8 +28,8 @@ function isparameter(x) if x isa SymbolicT && (varT = getvariabletype(x, nothing)) !== nothing return varT === PARAMETER #TODO: Delete this branch - elseif x isa SymbolicT && Symbolics.getparent(x, false) !== false - p = Symbolics.getparent(x) + elseif x isa SymbolicT && iscall(x) && operation(x) === getindex + p = arguments(x)[1] isparameter(p) || (hasmetadata(p, Symbolics.VariableSource) && getmetadata(p, Symbolics.VariableSource)[1] == :parameters) diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index e3a113e691..839b77094f 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -155,8 +155,8 @@ function get_unit(x::SymbolicT) op = operation(x) if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i] - elseif iscall(op) && !iscall(operation(op)) - gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t) + elseif iscall(op) && operation(op) === getindex + gp = arguments(op)[1] return screen_unit(getmetadata(gp, VariableUnit, unitless)) end # Actual function calls: args = arguments(x) diff --git a/src/systems/validation.jl b/src/systems/validation.jl index bb2c622d92..ecd98b1d43 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -129,8 +129,8 @@ function get_unit(x::SymbolicT) op = operation(x) if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i] - elseif iscall(op) && !iscall(operation(op)) - gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t) + elseif iscall(op) && operation(op) === getindex + gp = arguments(op)[1] return screen_unit(getmetadata(gp, VariableUnit, unitless)) end # Actual function calls: args = arguments(x) diff --git a/src/utils.jl b/src/utils.jl index 667e79bfa3..a311dd9d05 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -310,18 +310,12 @@ function collect_var_to_name!(vars, xs) for x in xs symbolic_type(x) == NotSymbolic() && continue x = unwrap(x) - if hasmetadata(x, Symbolics.GetindexParent) - xarr = getmetadata(x, Symbolics.GetindexParent) - hasname(xarr) || continue - vars[Symbolics.getname(xarr)] = xarr - else - if iscall(x) && operation(x) === getindex - x = arguments(x)[1] - end - x = unwrap(x) - hasname(x) || continue - vars[Symbolics.getname(unwrap(x))] = x + if iscall(x) && operation(x) === getindex + x = arguments(x)[1] end + x = unwrap(x) + hasname(x) || continue + vars[Symbolics.getname(unwrap(x))] = x end end @@ -388,9 +382,7 @@ isdiffeq(eq) = isdifferential(eq.lhs) || isoperator(eq.lhs, Shift) isvariable(x::Num)::Bool = isvariable(value(x)) function isvariable(x)::Bool x isa SymbolicT || return false - p = getparent(x, nothing) - p === nothing || (x = p) - hasmetadata(x, VariableSource) + hasmetadata(x, VariableSource) || iscall(x) && operation(x) === getindex && isvariable(arguments(x)[1]) end """ diff --git a/src/variables.jl b/src/variables.jl index 1d5475c12c..d1bd03f97b 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -190,13 +190,13 @@ function setconnect(x, t::Type{T}) where {T <: AbstractConnectType} setmetadata(x, VariableConnectType, t) end -### Input, Output, Irreducible -isvarkind(m, x::Union{Num, Symbolics.Arr}) = isvarkind(m, value(x)) -function isvarkind(m, x) - iskind = getmetadata(x, m, nothing) - iskind !== nothing && return iskind - x = getparent(x, x) - getmetadata(x, m, false) +### Input, Output, Irreducible +isvarkind(m, x, def = false) = safe_getmetadata(m, x, def) +safe_getmetadata(m, x::Union{Num, Symbolics.Arr}, def) = safe_getmetadata(m, value(x), def) +function safe_getmetadata(m, x, default) + hasmetadata(x, m) && return getmetadata(x, m) + iscall(x) && operation(x) === getindex && return safe_getmetadata(m, arguments(x)[1], default) + return default end """ @@ -282,8 +282,8 @@ Create parameters with bounds like this """ function getbounds(x::Union{Num, Symbolics.Arr, SymbolicT}) x = unwrap(x) - p = Symbolics.getparent(x, nothing) - if p === nothing + if operation(p) === getindex + p = arguments(p)[1] bounds = Symbolics.getmetadata(x, VariableBounds, (-Inf, Inf)) if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown() bounds = map(bounds) do b @@ -339,9 +339,7 @@ isdisturbance(x::Num) = isdisturbance(Symbolics.unwrap(x)) Determine whether symbolic variable `x` is marked as a disturbance input. """ function isdisturbance(x) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableDisturbance, false) + isvarkind(VariableDisturbance, x) end setdisturbance(x, v) = setmetadata(x, VariableDisturbance, v) @@ -372,9 +370,7 @@ Create a tunable parameter by See also [`tunable_parameters`](@ref), [`getbounds`](@ref) """ function istunable(x, default = true) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableTunable, default) + isvarkind(VariableTunable, x, default) end ## Dist ======================================================================== @@ -398,9 +394,7 @@ getdist(u) # retrieve distribution ``` """ function getdist(x) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableDistribution, nothing) + safe_getmetadata(VariableDistribution, x, nothing) end """ @@ -492,9 +486,7 @@ getdescription(x::Symbolics.Arr) = getdescription(Symbolics.unwrap(x)) Return any description attached to variables `x`. If no description is attached, an empty string is returned. """ function getdescription(x) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableDescription, "") + safe_getmetadata(VariableDescription, x, "") end """ From 251ee2b24eedf6ae24e2bfbcbe73e59170061f86 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 19:39:42 +0530 Subject: [PATCH 019/157] fix: implement `SII.getname` for `System` --- src/systems/system.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/systems/system.jl b/src/systems/system.jl index 83f70ee0b6..6accc2b633 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -459,6 +459,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; initializesystem, is_initializesystem, is_discrete; checks) end +SymbolicIndexingInterface.getname(x::System) = nameof(x) + """ $(TYPEDSIGNATURES) From 334dcfdb36d2fb5316199f01dd867c70b0cda468 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 19:53:52 +0530 Subject: [PATCH 020/157] fix: remove usages of `Symbolics.Unknown`, `Symbolics.shape` --- src/problems/initializationproblem.jl | 2 +- .../StructuralTransformations.jl | 1 + src/structural_transformation/symbolics_tearing.jl | 4 ++-- src/structural_transformation/utils.jl | 2 +- src/systems/abstractsystem.jl | 2 +- src/systems/callbacks.jl | 2 +- src/systems/index_cache.jl | 4 ++-- src/systems/nonlinear/initializesystem.jl | 5 ++--- src/systems/parameter_buffer.jl | 12 ++++++------ src/systems/problem_utils.jl | 12 ++++++------ src/systems/system.jl | 6 ++---- src/systems/systemstructure.jl | 4 ++-- src/utils.jl | 4 ++-- src/variables.jl | 4 ++-- 14 files changed, 31 insertions(+), 33 deletions(-) diff --git a/src/problems/initializationproblem.jl b/src/problems/initializationproblem.jl index 6960811bbd..5267ad4abc 100644 --- a/src/problems/initializationproblem.jl +++ b/src/problems/initializationproblem.jl @@ -39,7 +39,7 @@ All other keyword arguments are forwarded to the wrapped nonlinear problem const for k in keys(op) has_u0_ics |= is_variable(sys, k) || isdifferential(k) || symbolic_type(k) == ArraySymbolic() && - is_sized_array_symbolic(k) && is_variable(sys, unwrap(first(wrap(k)))) + symbolic_has_known_size(k) && is_variable(sys, unwrap(first(wrap(k)))) end if !has_u0_ics && get_initializesystem(sys) !== nothing isys = get_initializesystem(sys; initialization_eqs, check_units) diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 90883ce7e6..cbc01de2b5 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -9,6 +9,7 @@ using SymbolicUtils using SymbolicUtils.Code using SymbolicUtils.Rewriters using SymbolicUtils: maketerm, iscall +import SymbolicUtils as SU using ModelingToolkit using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Differential, diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 02b3b7ada1..051387b939 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -1251,7 +1251,7 @@ function tearing_hacks(sys, obs, unknowns, neweqs; array = true) array || continue iscall(lhs) || continue operation(lhs) === getindex || continue - Symbolics.shape(lhs) != Symbolics.Unknown() || continue + SU.shape(lhs) isa SU.Unknown && continue arg1 = arguments(lhs)[1] cnt = get(arr_obs_occurrences, arg1, 0) arr_obs_occurrences[arg1] = cnt + 1 @@ -1264,7 +1264,7 @@ function tearing_hacks(sys, obs, unknowns, neweqs; array = true) for sym in unknowns iscall(sym) || continue operation(sym) === getindex || continue - Symbolics.shape(sym) != Symbolics.Unknown() || continue + SU.shape(sym) isa SU.Unknown && continue arg1 = arguments(sym)[1] cnt = get(arr_obs_occurrences, arg1, 0) cnt == 0 && continue diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 84d6c1b01c..d4211b16e1 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -257,7 +257,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no if any( v -> any(isequal(v), fullvars) || symbolic_type(v) == ArraySymbolic() && - Symbolics.shape(v) != Symbolics.Unknown() && + SU.shape(v) isa SU.Unknown || any(x -> any(isequal(x), fullvars), collect(v)), vars( a; op = Union{Differential, Shift, Pre, Sample, Hold, Initial})) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index b07f23e8f2..a1f76145cf 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2783,7 +2783,7 @@ function process_parameter_equations(sys::AbstractSystem) if all(varsbuf) do sym is_parameter(sys, sym) || symbolic_type(sym) == ArraySymbolic() && - is_sized_array_symbolic(sym) && + symbolic_has_known_size(sym) && all(Base.Fix1(is_parameter, sys), collect(sym)) || iscall(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[1]) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 354a661bc6..19725cbfc7 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -869,7 +869,7 @@ function default_operating_point(affsys::AffectSystem) T = symtype(p) if T <: Number op[p] = false - elseif T <: Array{<:Real} && is_sized_array_symbolic(p) + elseif T <: Array{<:Real} && symbolic_has_known_size(p) op[p] = zeros(size(p)) end end diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index c861a2e1c3..ea651c236f 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -202,7 +202,7 @@ function IndexCache(sys::AbstractSystem) haskey(nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue insert_by_type!( if ctype <: Real || ctype <: AbstractArray{<:Real} - if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown() && + if istunable(p, true) && symbolic_has_known_size(p) && (ctype == Real || ctype <: AbstractFloat || ctype <: AbstractArray{Real} || ctype <: AbstractArray{<:AbstractFloat}) @@ -417,7 +417,7 @@ function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) end sym = unwrap(sym) validate_size = Symbolics.isarraysymbolic(sym) && symtype(sym) <: AbstractArray && - Symbolics.shape(sym) !== Symbolics.Unknown() + symbolic_has_known_size(sym) return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing ParameterIndex(SciMLStructures.Tunable(), idx, validate_size) elseif (idx = check_index_map(ic.initials_idx, sym)) !== nothing diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index f377f0202f..928a2f3118 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -271,8 +271,7 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem; vars!(vs, eq; op = Initial) allpars = full_parameters(sys) for p in allpars - if symbolic_type(p) == ArraySymbolic() && - Symbolics.shape(p) != Symbolics.Unknown() + if symbolic_type(p) == ArraySymbolic() && SU.shape(p) isa SU.Unknown append!(allpars, Symbolics.scalarize(p)) end end @@ -502,7 +501,7 @@ function get_possibly_array_fallback_singletons(varmap, p) return varmap[p] end if symbolic_type(p) == ArraySymbolic() - is_sized_array_symbolic(p) || return nothing + symbolic_has_known_size(p) || return nothing scal = collect(p) if all(x -> haskey(varmap, x), scal) res = [varmap[x] for x in scal] diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index b42ef1e3ef..322889016e 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -123,7 +123,7 @@ function MTKParameters( val = symconvert(ctype, val) done = set_value(sym, val) if !done && Symbolics.isarraysymbolic(sym) - if Symbolics.shape(sym) === Symbolics.Unknown() + if !symbolic_has_known_size(sym) for i in eachindex(val) set_value(sym[i], val[i]) end @@ -463,11 +463,11 @@ function validate_parameter_type(ic::IndexCache, p, idx::ParameterIndex, val) end stype = symtype(p) sz = if stype <: AbstractArray - Symbolics.shape(p) == Symbolics.Unknown() ? Symbolics.Unknown() : size(p) + size(p) elseif stype <: Number size(p) else - Symbolics.Unknown() + SU.Unknown(-1) end validate_parameter_type(ic, stype, sz, p, idx, val) end @@ -479,7 +479,7 @@ function validate_parameter_type(ic::IndexCache, idx::ParameterIndex, val) stype = AbstractArray{<:stype} end validate_parameter_type( - ic, stype, Symbolics.Unknown(), nothing, idx, val) + ic, stype, SU.Unknown(-1), nothing, idx, val) end function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) @@ -499,7 +499,7 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) :validate_parameter_type, sym === nothing ? index : sym, stype, val)) end # ... and must match sizes - if stype <: AbstractArray && sz != Symbolics.Unknown() && size(val) != sz + if stype <: AbstractArray && !(sz isa SU.Unknown) && size(val) != sz throw(InvalidParameterSizeException(sym, val)) end # Early exit @@ -718,7 +718,7 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru sym = idx idx = parameter_index(ic, sym) if idx === nothing - Symbolics.shape(sym) == Symbolics.Unknown() && + symbolic_has_known_size(sym) || throw(ParameterNotInSystem(sym)) size(sym) == size(val) || throw(InvalidParameterSizeException(sym, val)) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 2f5db01479..7a301ea34b 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -17,10 +17,10 @@ anydict(x) = AnyDict(x) """ $(TYPEDSIGNATURES) -Check if `x` is a symbolic with known size. Assumes `Symbolics.shape(unwrap(x))` +Check if `x` is a symbolic with known size. Assumes `SymbolicUtils.shape(unwrap(x))` is a valid operation. """ -is_sized_array_symbolic(x) = Symbolics.shape(unwrap(x)) != Symbolics.Unknown() +symbolic_has_known_size(x) = !(SU.shape(unwrap(x)) isa SU.Unknown) """ $(TYPEDSIGNATURES) @@ -128,7 +128,7 @@ function add_fallbacks!( haskey(varmap, ttvar) && continue # array symbolics with a defined size may be present in the scalarized form - if Symbolics.isarraysymbolic(var) && is_sized_array_symbolic(var) + if Symbolics.isarraysymbolic(var) && symbolic_has_known_size(var) val = map(eachindex(var)) do idx # @something is lazy and saves from writing a massive if-elseif-else @something(get(varmap, var[idx], nothing), @@ -162,7 +162,7 @@ function add_fallbacks!( fallbacks, arrvar, nothing) get(fallbacks, ttarrvar, nothing) Some(nothing) if val !== nothing val = val[idxs...] - is_sized_array_symbolic(arrvar) && push!(arrvars, arrvar) + symbolic_has_known_size(arrvar) && push!(arrvars, arrvar) end else val = nothing @@ -197,7 +197,7 @@ function missingvars( ttsym = toterm(var) haskey(varmap, ttsym) && continue - if Symbolics.isarraysymbolic(var) && is_sized_array_symbolic(var) + if Symbolics.isarraysymbolic(var) && symbolic_has_known_size(var) mask = map(eachindex(var)) do idx !haskey(varmap, var[idx]) && !haskey(varmap, ttsym[idx]) end @@ -535,7 +535,7 @@ If a scalarized entry already exists, it is not overridden. function scalarize_vars_in_varmap!(varmap::AbstractDict, vars) for var in vars symbolic_type(var) == ArraySymbolic() || continue - is_sized_array_symbolic(var) || continue + symbolic_has_known_size(var) || continue haskey(varmap, var) || continue for i in eachindex(var) haskey(varmap, var[i]) && continue diff --git a/src/systems/system.jl b/src/systems/system.jl index 6accc2b633..c0f71cc160 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -614,15 +614,13 @@ function gather_array_params(ps) for p in ps if iscall(p) && operation(p) === getindex par = arguments(p)[begin] - if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && - all(par[i] in ps for i in eachindex(par)) + if symbolic_has_known_size(p) && all(par[i] in ps for i in eachindex(par)) push!(new_ps, par) else push!(new_ps, p) end else - if symbolic_type(p) == ArraySymbolic() && - Symbolics.shape(unwrap(p)) != Symbolics.Unknown() + if symbolic_type(p) == ArraySymbolic() && symbolic_has_known_size(p) for i in eachindex(p) delete!(new_ps, p[i]) end diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index fc7b6ec5ab..3c9c429468 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -304,7 +304,7 @@ end function symbolic_contains(var, set) var in set || symbolic_type(var) == ArraySymbolic() && - Symbolics.shape(var) != Symbolics.Unknown() && + symbolic_has_known_size(var) && all(x -> x in set, Symbolics.scalarize(var)) end @@ -375,7 +375,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) ps = Set{SymbolicT}() for x in full_parameters(sys) push!(ps, x) - if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown() + if symbolic_type(x) == ArraySymbolic() && symbolic_has_known_size(x) xx = Symbolics.scalarize(x) union!(ps, xx) end diff --git a/src/utils.jl b/src/utils.jl index a311dd9d05..1c1727781c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -968,7 +968,7 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) end any(isequal(expr), vars) && return expr iscall(expr) || return expr - Symbolics.shape(expr) == Symbolics.Unknown() && return expr + symbolic_has_known_size(expr) || return expr haskey(state, expr) && return state[expr] op = operation(expr) args = arguments(expr) @@ -1059,7 +1059,7 @@ function var_in_varlist(var, varlist::AbstractSet, iv) # indexed array symbolic, unscalarized array present (iscall(var) && operation(var) === getindex && arguments(var)[1] in varlist) || # unscalarized sized array symbolic, all scalarized elements present - (symbolic_type(var) == ArraySymbolic() && is_sized_array_symbolic(var) && + (symbolic_type(var) == ArraySymbolic() && symbolic_has_known_size(var) && all(x -> x in varlist, collect(var))) || # delayed variables (isdelay(var, iv) && var_in_varlist(operation(var)(iv), varlist, iv)) diff --git a/src/variables.jl b/src/variables.jl index d1bd03f97b..9ae5c48803 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -285,7 +285,7 @@ function getbounds(x::Union{Num, Symbolics.Arr, SymbolicT}) if operation(p) === getindex p = arguments(p)[1] bounds = Symbolics.getmetadata(x, VariableBounds, (-Inf, Inf)) - if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown() + if symbolic_type(x) == ArraySymbolic() && symbolic_has_known_size(x) bounds = map(bounds) do b b isa AbstractArray && return b return fill(b, size(x)) @@ -297,7 +297,7 @@ function getbounds(x::Union{Num, Symbolics.Arr, SymbolicT}) idxs = arguments(x)[2:end] bounds = map(bounds) do b if b isa AbstractArray - if Symbolics.shape(p) != Symbolics.Unknown() && size(p) != size(b) + if symbolic_has_known_size(p) && size(p) != size(b) throw(DimensionMismatch("Expected array variable $p with shape $(size(p)) to have bounds of identical size. Found $bounds of size $(size(bounds)).")) end return b[idxs...] From d82061753ebd480adef63ac52549b43b3028e564 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 30 Sep 2025 23:19:00 +0530 Subject: [PATCH 021/157] refactor: remove usages of `array_term` --- src/discretedomain.jl | 18 +++++++++--------- src/systems/abstractsystem.jl | 19 +++++-------------- src/systems/callbacks.jl | 5 +++-- src/systems/connectors.jl | 3 +-- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/src/discretedomain.jl b/src/discretedomain.jl index eb47eecec9..54f451af78 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -33,7 +33,8 @@ at the inferred clock for that equation. struct SampleTime <: Operator SampleTime() = SymbolicUtils.term(SampleTime, type = Real) end -SymbolicUtils.promote_symtype(::Type{<:SampleTime}, t...) = Real +SymbolicUtils.promote_symtype(::Type{SampleTime}, ::Type{T}) where {T} = Real +SymbolicUtils.promote_shape(::Type{SampleTime}, @nospecialize(x::SU.ShapeT)) = x Base.nameof(::SampleTime) = :SampleTime SymbolicUtils.isbinop(::SampleTime) = false @@ -71,11 +72,7 @@ SymbolicUtils.isbinop(::Shift) = false function (D::Shift)(x, allow_zero = false) !allow_zero && D.steps == 0 && return x - if Symbolics.isarraysymbolic(x) - Symbolics.array_term(D, x) - else - term(D, x) - end + term(D, x; type = symtype(x), shape = SU.shape(x)) end function (D::Shift)(x::Union{Num, Symbolics.Arr}, allow_zero = false) !allow_zero && D.steps == 0 && return x @@ -94,7 +91,8 @@ function (D::Shift)(x::Union{Num, Symbolics.Arr}, allow_zero = false) end wrap(D(vt, allow_zero)) end -SymbolicUtils.promote_symtype(::Shift, t) = t +SymbolicUtils.promote_symtype(::Shift, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Shift, @nospecialize(x::SU.ShapeT)) = x Base.show(io::IO, D::Shift) = print(io, "Shift(", D.t, ", ", D.steps, ")") @@ -164,7 +162,8 @@ function Sample(arg::Real) end (D::Sample)(x) = STerm(D, SArgsT((x,)); type = symtype(x), shape = SU.shape(x)) (D::Sample)(x::Num) = Num(D(value(x))) -SymbolicUtils.promote_symtype(::Sample, x) = x +SymbolicUtils.promote_symtype(::Sample, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Sample, @nospecialize(x::SU.ShapeT)) = x Base.nameof(::Sample) = :Sample SymbolicUtils.isbinop(::Sample) = false @@ -210,7 +209,8 @@ is_transparent_operator(::Type{Hold}) = true (D::Hold)(x) = STerm(D, SArgsT((x,)); type = symtype(x), shape = SU.shape(x)) (D::Hold)(x::Num) = Num(D(value(x))) -SymbolicUtils.promote_symtype(::Hold, x) = x +SymbolicUtils.promote_symtype(::Hold, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Hold, @nospecialize(x::SU.ShapeT)) = x Base.nameof(::Hold) = :Hold SymbolicUtils.isbinop(::Hold) = false diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index a1f76145cf..f9a1f09f4b 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -488,7 +488,8 @@ of a system. See the documentation section on initialization for more informatio struct Initial <: Symbolics.Operator end is_timevarying_operator(::Type{Initial}) = false Initial(x) = Initial()(x) -SymbolicUtils.promote_symtype(::Type{Initial}, T) = T +SymbolicUtils.promote_symtype(::Initial, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Initial, @nospecialize(x::SU.ShapeT)) = x SymbolicUtils.isbinop(::Initial) = false Base.nameof(::Initial) = :Initial Base.show(io::IO, x::Initial) = print(io, "Initial") @@ -508,15 +509,14 @@ function (f::Initial)(x) # don't double wrap iscall(x) && operation(x) isa Initial && return x result = if symbolic_type(x) == ArraySymbolic() - # create an array for `Initial(array)` - Symbolics.array_term(f, x) + term(f, x; type = symtype(x), shape = SU.shape(x)) elseif iscall(x) && operation(x) == getindex # instead of `Initial(x[1])` create `Initial(x)[1]` # which allows parameter indexing to handle this case automatically. arr = arguments(x)[1] - term(getindex, f(arr), arguments(x)[2:end]...) + f(arr)[arguments(x)[2:end]...] else - term(f, x) + term(f, x; type = symtype(x), shape = SU.shape(x)) end # the result should be a parameter result = toparam(result) @@ -526,15 +526,6 @@ function (f::Initial)(x) return result end -# This is required so `substitute` works -function SymbolicUtils.maketerm(::Type{<:BasicSymbolic}, ::Initial, args, meta) - val = Initial()(args...) - if symbolic_type(val) == NotSymbolic() - return val - end - return metadata(val, meta) -end - supports_initialization(sys::AbstractSystem) = true function add_initialization_parameters(sys::AbstractSystem; split = true) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 19725cbfc7..bc603db3ff 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -183,14 +183,15 @@ function (p::Pre)(x) iscall(x) && operation(x) isa Pre && return x result = if symbolic_type(x) == ArraySymbolic() # create an array for `Pre(array)` + term(p, x; type = symtype(x), shape = SU.shape(x)) Symbolics.array_term(p, x) elseif iscall(x) && operation(x) == getindex # instead of `Pre(x[1])` create `Pre(x)[1]` # which allows parameter indexing to handle this case automatically. arr = arguments(x)[1] - term(getindex, p(arr), arguments(x)[2:end]...) + p(arr)[arguments(x)[2:end]...] else - term(p, x) + term(p, x; type = symtype(x), shape = SU.shape(x)) end # the result should be a parameter result = toparam(result) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 89be35b716..c987407538 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -948,8 +948,7 @@ function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractS stream_var = only(arguments(expr)) iscall(stream_var) && operation(stream_var) === getindex || continue args = arguments(stream_var) - new_expr = Symbolics.array_term( - instream, args[1]; size = size(args[1]), ndims = ndims(args[1]))[args[2:end]...] + new_expr = term(instream, args[1]; type = symtype(args[1]), shape = SU.shape(args[1]))[args[2:end]...] instream_subs[expr] = new_expr end From c0e5e021438086e6ca0a1bc328cf4a0c1ddb1b81 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:47:30 +0530 Subject: [PATCH 022/157] fix: move ChainRulesCore to an extension --- Project.toml | 3 ++- src/adjoints.jl => ext/MTKChainRulesCoreExt.jl | 18 ++++++++++++++++++ src/ModelingToolkit.jl | 3 --- src/systems/problem_utils.jl | 4 +++- 4 files changed, 23 insertions(+), 5 deletions(-) rename src/adjoints.jl => ext/MTKChainRulesCoreExt.jl (85%) diff --git a/Project.toml b/Project.toml index 661debc6ce..8fbc943102 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -68,6 +67,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [weakdeps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac" InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57" @@ -77,6 +77,7 @@ Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816" [extensions] MTKBifurcationKitExt = "BifurcationKit" MTKCasADiDynamicOptExt = "CasADi" +MTKChainRulesCoreExt = "ChainRulesCore" MTKDeepDiffsExt = "DeepDiffs" MTKFMIExt = "FMI" MTKInfiniteOptExt = "InfiniteOpt" diff --git a/src/adjoints.jl b/ext/MTKChainRulesCoreExt.jl similarity index 85% rename from src/adjoints.jl rename to ext/MTKChainRulesCoreExt.jl index 98266de938..c213a164a3 100644 --- a/src/adjoints.jl +++ b/ext/MTKChainRulesCoreExt.jl @@ -1,3 +1,13 @@ +module MTKChainRulesCoreExt + +import ChainRulesCore +import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk +using ModelingToolkit: MTKParameters, NONNUMERIC_PORTION, AbstractSystem +import ModelingToolkit as MTK +import SciMLStructures +import SymbolicIndexingInterface: remake_buffer +import SciMLBase: AbstractNonlinearProblem, remake + function ChainRulesCore.rrule(::Type{MTKParameters}, tunables, args...) function mtp_pullback(dt) dt = unthunk(dt) @@ -104,3 +114,11 @@ function ChainRulesCore.rrule( end ChainRulesCore.@non_differentiable Base.getproperty(sys::AbstractSystem, x::Symbol) + +function ModelingToolkit.update_initializeprob!(initprob::AbstractNonlinearProblem, prob) + pgetter = ChainRulesCore.@ignore_derivatives MTK.get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter + p = pgetter(prob, initprob) + return remake(initprob; p) +end + +end diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 223c6096bf..aaaa618e3e 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -68,8 +68,6 @@ import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, bloc using OffsetArrays: Origin import CommonSolve import EnumX -import ChainRulesCore -import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk using RuntimeGeneratedFunctions using RuntimeGeneratedFunctions: drop_expr @@ -230,7 +228,6 @@ include("structural_transformation/StructuralTransformations.jl") @reexport using .StructuralTransformations include("inputoutput.jl") -include("adjoints.jl") include("deprecations.jl") const t_nounits = let diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 7a301ea34b..3c64ed05fc 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -968,9 +968,11 @@ end A function to be used as `update_initializeprob!` in `OverrideInitData`. Requires `is_update_oop = Val(true)` to be passed to `update_initializeprob!`. + +Any changes to this method should also be made to the one in ChainRulesCoreExt. """ function update_initializeprob!(initprob, prob) - pgetter = ChainRulesCore.@ignore_derivatives get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter + pgetter = get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter p = pgetter(prob, initprob) return remake(initprob; p) end From 99ebed168aade368954cf60b26f8bce90f63111b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:48:20 +0530 Subject: [PATCH 023/157] fix: move JuliaFormatter to an extension --- Project.toml | 3 ++- ext/MTKJuliaFormatterExt.jl | 12 ++++++++++++ src/ModelingToolkit.jl | 1 - src/utils.jl | 2 +- 4 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 ext/MTKJuliaFormatterExt.jl diff --git a/Project.toml b/Project.toml index 8fbc943102..c4303b57bd 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,6 @@ FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ImplicitDiscreteSolve = "3263718b-31ed-49cf-8a0f-35a466e8af96" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -71,6 +70,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac" InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816" @@ -81,6 +81,7 @@ MTKChainRulesCoreExt = "ChainRulesCore" MTKDeepDiffsExt = "DeepDiffs" MTKFMIExt = "FMI" MTKInfiniteOptExt = "InfiniteOpt" +MTKJuliaFormatterExt = "JuliaFormatter" MTKLabelledArraysExt = "LabelledArrays" MTKPyomoDynamicOptExt = "Pyomo" diff --git a/ext/MTKJuliaFormatterExt.jl b/ext/MTKJuliaFormatterExt.jl new file mode 100644 index 0000000000..a4efeec931 --- /dev/null +++ b/ext/MTKJuliaFormatterExt.jl @@ -0,0 +1,12 @@ +module MTKJuliaFormatterExt + +import ModelingToolkit: readable_code, _readable_code, rec_remove_macro_linenums! +import JuliaFormatter + +function readable_code(expr::Expr) + expr = Base.remove_linenums!(_readable_code(expr)) + rec_remove_macro_linenums!(expr) + JuliaFormatter.format_text(string(expr), JuliaFormatter.SciMLStyle()) +end + +end diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index aaaa618e3e..b76b02abc2 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -54,7 +54,6 @@ using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, Ti PeriodicClock, Clock, SolverStepClock, ContinuousClock, OverrideInit, NoInit using Distributed -import JuliaFormatter using MLStyle import Moshi using Moshi.Data: @data diff --git a/src/utils.jl b/src/utils.jl index 1c1727781c..4ddc13b22a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -84,7 +84,7 @@ end function readable_code(expr) expr = Base.remove_linenums!(_readable_code(expr)) rec_remove_macro_linenums!(expr) - JuliaFormatter.format_text(string(expr), JuliaFormatter.SciMLStyle()) + return string(expr) end # System validation enums From 60e5f5f9b0985538d0f5e7409c9de0d2de230303 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:51:15 +0530 Subject: [PATCH 024/157] fix: fix invalidations from `promote_symtype` method --- src/systems/connectors.jl | 2 +- src/systems/state_machines.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index c987407538..50d79cadec 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -1115,4 +1115,4 @@ function instream_rt(ins::Val{inner_n}, outs::Val{outer_n}, for k in 1:M and ck.m_flow.max > 0 =# end -SymbolicUtils.promote_symtype(::typeof(instream_rt), ::Vararg) = Real +SymbolicUtils.promote_symtype(::typeof(instream_rt), ::Type{T}, ::Type{S}, ::Type{R}) where {T, S, R} = Real diff --git a/src/systems/state_machines.jl b/src/systems/state_machines.jl index ea65981804..48d9d2f4f6 100644 --- a/src/systems/state_machines.jl +++ b/src/systems/state_machines.jl @@ -75,7 +75,7 @@ for (s, T) in [(:timeInState, :Real), seed = hash(s) @eval begin $s(x) = wrap(term($s, x)) - SymbolicUtils.promote_symtype(::typeof($s), _...) = $T + SymbolicUtils.promote_symtype(::typeof($s), ::Type{S}) where {S} = $T function SymbolicUtils.show_call(io, ::typeof($s), args) if isempty(args) print(io, $s, "()") From f7533e2c76de67a4a96133959ad3790ac1455f78 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:52:47 +0530 Subject: [PATCH 025/157] fix: remove usages of `occursin` for searching expressions --- ext/MTKPyomoDynamicOptExt.jl | 5 +++-- src/structural_transformation/symbolics_tearing.jl | 2 +- src/systems/callbacks.jl | 2 +- src/systems/diffeqs/basic_transformations.jl | 2 +- src/systems/nonlinear/homotopy_continuation.jl | 6 +++--- src/systems/system.jl | 2 +- src/systems/systemstructure.jl | 2 +- src/utils.jl | 2 +- 8 files changed, 12 insertions(+), 11 deletions(-) diff --git a/ext/MTKPyomoDynamicOptExt.jl b/ext/MTKPyomoDynamicOptExt.jl index fe18b2678d..58888bd974 100644 --- a/ext/MTKPyomoDynamicOptExt.jl +++ b/ext/MTKPyomoDynamicOptExt.jl @@ -5,6 +5,7 @@ using DiffEqBase using UnPack using NaNMath using Setfield +import SymbolicUtils as SU const MTK = ModelingToolkit const SPECIAL_FUNCTIONS_DICT = Dict([acos => Pyomo.py_acos, @@ -112,7 +113,7 @@ function MTK.add_constraint!(pmodel::PyomoDynamicOptModel, cons; n_idxs = 1) Symbolics.unwrap(expr), SPECIAL_FUNCTIONS_DICT, fold = false) cons_sym = Symbol("cons", hash(cons)) - if occursin(Symbolics.unwrap(t_sym), expr) + if SU.query(isequal(Symbolics.unwrap(t_sym)), expr) f = eval(Symbolics.build_function(expr, model_sym, t_sym)) setproperty!(model, cons_sym, pyomo.Constraint(model.t, rule = Pyomo.pyfunc(f))) else @@ -124,7 +125,7 @@ end function MTK.set_objective!(pmodel::PyomoDynamicOptModel, expr) @unpack model, model_sym, t_sym, dummy_sym = pmodel expr = Symbolics.substitute(expr, SPECIAL_FUNCTIONS_DICT, fold = false) - if occursin(Symbolics.unwrap(t_sym), expr) + if SU.query(isequal(Symbolics.unwrap(t_sym)), expr) f = eval(Symbolics.build_function(expr, model_sym, t_sym)) model.obj = pyomo.Objective(model.t, rule = Pyomo.pyfunc(f)) else diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 051387b939..ad3e889dff 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -108,7 +108,7 @@ end function solve_equation(eq, var, simplify) rhs = value(symbolic_linear_solve(eq, var; simplify = simplify, check = false)) - occursin(var, rhs) && throw(EquationSolveErrors(eq, var, rhs)) + SU.query(in(var), rhs) && throw(EquationSolveErrors(eq, var, rhs)) var ~ rhs end diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index bc603db3ff..a560e11208 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -69,7 +69,7 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], discrete_parameters = unwrap.(discrete_parameters) for p in discrete_parameters - occursin(unwrap(iv), unwrap(p)) || + SU.query(isequal(unwrap(iv)), unwrap(p)) || error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).") end diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index 4cc71db108..cb173a15fc 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -142,7 +142,7 @@ function change_of_variables( for (new_var, ex, first, second) in zip(new_vars, dfdt, ∂f∂x, ∂2f∂x2) for (eqs, neq) in zip(old_eqs, neqs) - if occursin(value(eqs.lhs), value(ex)) + if SU.query(isequal(value(eqs.lhs)), value(ex)) ex = substitute(ex, eqs.lhs => eqs.rhs) if isSDE for (noise, B) in zip(neq, brownvars) diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl index 96c00411ad..cce180511d 100644 --- a/src/systems/nonlinear/homotopy_continuation.jl +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -1,5 +1,5 @@ function contains_variable(x, wrt) - any(y -> occursin(y, x), wrt) + any(y -> SU.query(isequal(y), x), wrt) end """ @@ -270,7 +270,7 @@ function PolynomialTransformation(sys::System) transformation_err = nothing for t in all_non_poly_terms # if the term involves multiple unknowns, we can't invert it - dvs_in_term = map(x -> occursin(x, t), dvs) + dvs_in_term = map(x -> SU.query(isequal(x), t), dvs) if count(dvs_in_term) > 1 transformation_err = MultivarTerm(t, dvs[dvs_in_term]) is_poly = false @@ -369,7 +369,7 @@ function transform_system(sys::System, transformation::PolynomialTransformation; t = Symbolics.fixpoint_sub(t, subrules; maxiters = length(dvs)) # the substituted variable occurs outside the substituted term poly_and_nonpoly = map(dvs) do x - all(!isequal(x), new_dvs) && occursin(x, t) + all(!isequal(x), new_dvs) && SU.query(isequal(x), t) end if any(poly_and_nonpoly) return NotPolynomialError( diff --git a/src/systems/system.jl b/src/systems/system.jl index c0f71cc160..77ce78fb5d 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -681,7 +681,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) for var in auxvars if !iscall(var) - occursin(iv, var) && (var ∈ sts || + SU.query(isequal(iv), var) && (var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system."))) elseif length(arguments(var)) > 1 throw(ArgumentError("Too many arguments for variable $var.")) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 3c9c429468..2ffda19830 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -680,7 +680,7 @@ function trivial_tearing!(ts::TearingState) end isvalid || continue # skip if the LHS is present in the RHS, since then this isn't explicit - if occursin(eq.lhs, eq.rhs) + if SU.query(isequal(eq.lhs), eq.rhs) push!(blacklist, i) continue end diff --git a/src/utils.jl b/src/utils.jl index 4ddc13b22a..42859dc37b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -151,7 +151,7 @@ function check_variables(dvs, iv) for dv in dvs isequal(iv, dv) && throw(ArgumentError("Independent variable $iv not allowed in dependent variables.")) - (is_delay_var(iv, dv) || occursin(iv, dv)) || + (is_delay_var(iv, dv) || SU.query!(isequal(iv), dv)) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv.")) end end From 04c7e95277dc3ca4dc043dbceee282d98e299ba2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:53:03 +0530 Subject: [PATCH 026/157] fix: remove usages of deprecated `children` --- src/structural_transformation/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index d4211b16e1..6f54f33029 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -535,7 +535,7 @@ function shift2term(var) Symbol(string(nameof(oldop))) newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), - Symbolics.children(O), Symbolics.metadata(O)) + arguments(O), SU.metadata(O)) newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname)) newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O) newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift) From 072384057896a38b2a7c4175eb2e060503f37487 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:56:01 +0530 Subject: [PATCH 027/157] refactor: concretely type some utility functions --- src/utils.jl | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 42859dc37b..34c28bcf3d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -271,51 +271,51 @@ function setdefault(v, val) val === nothing ? v : wrap(setdefaultval(unwrap(v), value(val))) end -function process_variables!(var_to_name, defs, guesses, vars) +function process_variables!(var_to_name::Dict{Symbol, SymbolicT}, defs::SymmapT, guesses::SymmapT, vars::Vector{SymbolicT}) collect_defaults!(defs, vars) collect_guesses!(guesses, vars) collect_var_to_name!(var_to_name, vars) return nothing end -function process_variables!(var_to_name, defs, vars) +function process_variables!(var_to_name::Dict{Symbol, SymbolicT}, defs::SymmapT, vars::Vector{SymbolicT}) collect_defaults!(defs, vars) collect_var_to_name!(var_to_name, vars) return nothing end -function collect_defaults!(defs, vars) +function collect_defaults!(defs::SymmapT, vars::Vector{SymbolicT}) for v in vars - symbolic_type(v) == NotSymbolic() && continue - if haskey(defs, v) || !hasdefault(unwrap(v)) || (def = getdefault(v)) === nothing + isconst(v) && continue + if haskey(defs, v) || (def = Symbolics.getdefaultval(v, nothing)) === nothing continue end - defs[v] = getdefault(v) + defs[v] = SU.Const{VartypeT}(def) end return defs end -function collect_guesses!(guesses, vars) +function collect_guesses!(guesses::SymmapT, vars::Vector{SymbolicT}) for v in vars + isconst(v) && continue symbolic_type(v) == NotSymbolic() && continue - if haskey(guesses, v) || !hasguess(unwrap(v)) || (def = getguess(v)) === nothing + if haskey(guesses, v) || (def = getguess(v)) === nothing continue end - guesses[v] = getguess(v) + guesses[v] = SU.Const{VartypeT}(def) end return guesses end -function collect_var_to_name!(vars, xs) +function collect_var_to_name!(vars::Dict{Symbol, SymbolicT}, xs::Vector{SymbolicT}) for x in xs - symbolic_type(x) == NotSymbolic() && continue - x = unwrap(x) - if iscall(x) && operation(x) === getindex - x = arguments(x)[1] + x = Moshi.Match.@match x begin + BSImpl.Const(;) => continue + BSImpl.Term(; f, args) && if f === getindex end => args[1] + _ => x end - x = unwrap(x) hasname(x) || continue - vars[Symbolics.getname(unwrap(x))] = x + vars[getname(x)] = x end end From b5b373b5d76db581cdf3de28dff6c67eab5d884b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:56:22 +0530 Subject: [PATCH 028/157] fix: handle `Const` in connection expansion --- src/systems/analysis_points.jl | 2 +- src/systems/connectors.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl index a5a612b9ca..42bf94eb02 100644 --- a/src/systems/analysis_points.jl +++ b/src/systems/analysis_points.jl @@ -250,7 +250,7 @@ Remove all `AnalysisPoint`s in `sys` and any of its subsystems, replacing them b """ function remove_analysis_points(sys::AbstractSystem) eqs = map(get_eqs(sys)) do eq - eq.lhs isa AnalysisPoint ? to_connection(eq.rhs) : eq + value(eq.lhs) isa AnalysisPoint ? to_connection(value(eq.rhs)) : eq end @set! sys.eqs = eqs @set! sys.systems = map(remove_analysis_points, get_systems(sys)) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 50d79cadec..95ca70e6be 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -632,8 +632,8 @@ function returned from `generate_isouter`. """ function handle_maybe_connect_equation!(eqs, state::AbstractConnectionState, eq::Equation, namespace::Vector{Symbol}, isouter) - lhs = eq.lhs - rhs = eq.rhs + lhs = value(eq.lhs) + rhs = value(eq.rhs) if !(lhs isa Connection) # split connections and equations From b3f1d323f3194fd4d5a7cd4809e09c5d6449db54 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:56:47 +0530 Subject: [PATCH 029/157] refactor: improve type-stability of `renamespace` --- src/systems/abstractsystem.jl | 61 +++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index f9a1f09f4b..9ae0b0cc82 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1089,40 +1089,51 @@ renamespace(sys, eq::Equation) = namespace_equation(eq, sys) renamespace(names::AbstractVector, x) = foldr(renamespace, names, init = x) +renamespace(sys, tgt::AbstractSystem) = rename(tgt, renamespace(sys, nameof(tgt))) +renamespace(sys, tgt::Symbol) = Symbol(getname(sys), NAMESPACE_SEPARATOR_SYMBOL, tgt) + """ $(TYPEDSIGNATURES) Namespace `x` with the name of `sys`. """ -function renamespace(sys, x) - sys === nothing && return x - x = unwrap(x) - if x isa SymbolicT - T = typeof(x) - if iscall(x) && operation(x) isa Operator - return maketerm(typeof(x), operation(x), - Any[renamespace(sys, only(arguments(x)))], - metadata(x))::T - end - if iscall(x) && operation(x) === getindex - args = arguments(x) - return maketerm( - typeof(x), operation(x), vcat(renamespace(sys, args[1]), args[2:end]), - metadata(x))::T - end - let scope = getmetadata(x, SymScope, LocalScope()) +function renamespace(sys, x::SymbolicT) + Moshi.Match.@match x begin + BSImpl.Sym(; name) => let scope = getmetadata(x, SymScope, LocalScope())::Union{LocalScope, ParentScope, GlobalScope} if scope isa LocalScope - rename(x, renamespace(getname(sys), getname(x)))::T + return rename(x, renamespace(getname(sys), name))::SymbolicT elseif scope isa ParentScope - setmetadata(x, SymScope, scope.parent)::T - else # GlobalScope - x::T + return setmetadata(x, SymScope, scope.parent)::SymbolicT + elseif scope isa GlobalScope + return x end + error() + end + BSImpl.Term(; f, args, shape, type, metadata) => begin + if f === getindex + newargs = copy(parent(args)) + newargs[1] = renamespace(sys, args[1]) + return BSImpl.Term{VartypeT}(getindex, newargs; type, shape, metadata) + elseif f isa SymbolicT + let scope = getmetadata(x, SymScope, LocalScope())::Union{LocalScope, ParentScope, GlobalScope} + if scope isa LocalScope + return rename(x, renamespace(getname(sys), getname(x)))::SymbolicT + elseif scope isa ParentScope + return setmetadata(x, SymScope, scope.parent)::SymbolicT + elseif scope isa GlobalScope + return x + end + error() + end + elseif f isa Operator + newargs = copy(parent(args)) + for (i, arg) in enumerate(args) + newargs[i] = renamespace(sys, arg) + end + return BSImpl.Term{VartypeT}(f, newargs; type, shape, metadata) + end + error() end - elseif x isa AbstractSystem - rename(x, renamespace(sys, nameof(x))) - else - Symbol(getname(sys), NAMESPACE_SEPARATOR_SYMBOL, x) end end From 89cf3a7855114f5448e129d857403dd6c22c0455 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:57:59 +0530 Subject: [PATCH 030/157] refactor: make `System` more concretely typed --- src/ModelingToolkit.jl | 3 + src/systems/system.jl | 135 ++++++++++++++++++++++++----------------- 2 files changed, 83 insertions(+), 55 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index b76b02abc2..d23311b95c 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -159,6 +159,9 @@ include("parameters.jl") include("independent_variables.jl") include("constants.jl") +const SymmapT = Dict{SymbolicT, SymbolicT} +const COMMON_NOTHING = SU.Const{VartypeT}(nothing) + include("utils.jl") include("systems/index_cache.jl") diff --git a/src/systems/system.jl b/src/systems/system.jl index 77ce78fb5d..62644b6643 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -63,7 +63,7 @@ struct System <: IntermediateDeprecationSystem loss of an optimization problem. Scalar loss values must also be provided as a single- element vector. """ - costs::Vector{<:Union{BasicSymbolic, Real}} + costs::Vector{SymbolicT} """ A function which combines costs into a scalar value. This should take two arguments, the `costs` of this system and the consolidated costs of all subsystems in the order @@ -76,20 +76,20 @@ struct System <: IntermediateDeprecationSystem The variables being solved for by this system. For example, in a differential equation system, this contains the dependent variables. """ - unknowns::Vector + unknowns::Vector{SymbolicT} """ The parameters of the system. Parameters can either be variables that parameterize the problem being solved for (e.g. the spring constant of a mass-spring system) or additional unknowns not part of the main dynamics of the system (e.g. discrete/clocked variables in a hybrid ODE). """ - ps::Vector + ps::Vector{SymbolicT} """ The brownian variables of the system, created via `@brownians`. Each brownian variable represents an independent noise. A system with brownians cannot be simulated directly. It needs to be compiled using `mtkcompile` into `noise_eqs`. """ - brownians::Vector + brownians::Vector{SymbolicT} """ The independent variable for a time-dependent system, or `nothing` for a time-independent system. @@ -117,7 +117,7 @@ struct System <: IntermediateDeprecationSystem A mapping from the name of a variable to the actual symbolic variable in the system. This is used to enable `getproperty` syntax to access variables of a system. """ - var_to_name::Dict{Symbol, Any} + var_to_name::Dict{Symbol, SymbolicT} """ The name of the system. """ @@ -132,11 +132,11 @@ struct System <: IntermediateDeprecationSystem by initial values provided to the problem constructor. Defaults of parent systems take priority over those in child systems. """ - defaults::Dict + defaults::SymmapT """ Guess values for variables of a system that are solved for during initialization. """ - guesses::Dict + guesses::SymmapT """ A list of subsystems of this system. Used for hierarchically building models. """ @@ -167,7 +167,7 @@ struct System <: IntermediateDeprecationSystem associated error message. By default these assertions cause the generated code to output `NaN`s if violated, but can be made to error using `debug_system`. """ - assertions::Dict{BasicSymbolic, String} + assertions::Dict{SymbolicT, String} """ The metadata associated with this system, as a `Base.ImmutableDict`. This follows the same interface as SymbolicUtils.jl. Metadata can be queried and updated using @@ -193,12 +193,12 @@ struct System <: IntermediateDeprecationSystem $INTERNAL_FIELD_WARNING The list of input variables of the system. """ - inputs::OrderedSet{BasicSymbolic} + inputs::OrderedSet{SymbolicT} """ $INTERNAL_FIELD_WARNING The list of output variables of the system. """ - outputs::OrderedSet{BasicSymbolic} + outputs::OrderedSet{SymbolicT} """ The `TearingState` of the system post-simplification with `mtkcompile`. """ @@ -264,9 +264,9 @@ struct System <: IntermediateDeprecationSystem tag, eqs, noise_eqs, jumps, constraints, costs, consolidate, unknowns, ps, brownians, iv, observed, parameter_dependencies, var_to_name, name, description, defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events, - connector_type, assertions = Dict{BasicSymbolic, String}(), + connector_type, assertions = Dict{SymbolicT, String}(), metadata = MetadataT(), gui_metadata = nothing, is_dde = false, tstops = [], - inputs = Set{BasicSymbolic}(), outputs = Set{BasicSymbolic}(), + inputs = Set{SymbolicT}(), outputs = Set{SymbolicT}(), tearing_state = nothing, namespacing = true, complete = false, index_cache = nothing, ignored_connections = nothing, preface = nothing, parent = nothing, initializesystem = nothing, @@ -321,6 +321,26 @@ function default_consolidate(costs, subcosts) return reduce(+, costs; init = 0.0) + reduce(+, subcosts; init = 0.0) end +function unwrap_vars(vars::AbstractArray{SymbolicT}) + vec(vars) +end +function unwrap_vars(vars) + result = SymbolicT[] + for var in vars + push!(result, unwrap(var)) + end + return result +end + +defsdict(x::SymmapT) = x +function defsdict(x::Union{AbstractDict, AbstractArray{<:Pair}}) + result = SymmapT() + for (k, v) in x + result[unwrap(k)] = SU.Const{VartypeT}(v) + end + return result +end + """ $(TYPEDSIGNATURES) @@ -337,74 +357,74 @@ for time-independent systems, unknowns `dvs`, parameters `ps` and brownian varia All other keyword arguments are named identically to the corresponding fields in [`System`](@ref). """ -function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; - constraints = Union{Equation, Inequality}[], noise_eqs = nothing, jumps = [], - costs = BasicSymbolic[], consolidate = default_consolidate, - observed = Equation[], parameter_dependencies = Equation[], defaults = Dict(), - guesses = Dict(), systems = System[], initialization_eqs = Equation[], +function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; + constraints = Union{Equation, Inequality}[], noise_eqs = nothing, jumps = JumpType[], + costs = SymbolicT[], consolidate = default_consolidate, + observed = Equation[], parameter_dependencies = Equation[], defaults = SymmapT(), + guesses = SymmapT(), systems = System[], initialization_eqs = Equation[], continuous_events = SymbolicContinuousCallback[], discrete_events = SymbolicDiscreteCallback[], - connector_type = nothing, assertions = Dict{BasicSymbolic, String}(), + connector_type = nothing, assertions = Dict{SymbolicT, String}(), metadata = MetadataT(), gui_metadata = nothing, - is_dde = nothing, tstops = [], inputs = OrderedSet{BasicSymbolic}(), - outputs = OrderedSet{BasicSymbolic}(), tearing_state = nothing, + is_dde = nothing, tstops = [], inputs = OrderedSet{SymbolicT}(), + outputs = OrderedSet{SymbolicT}(), tearing_state = nothing, ignored_connections = nothing, parent = nothing, description = "", name = nothing, discover_from_metadata = true, initializesystem = nothing, is_initializesystem = false, is_discrete = false, preface = [], checks = true) name === nothing && throw(NoNameError()) if !isempty(parameter_dependencies) - @warn """ - The `parameter_dependencies` keyword argument is deprecated. Please provide all - such equations as part of the normal equations of the system. - """ + @invokelatest warn_pdeps() eqs = Equation[eqs; parameter_dependencies] end iv = unwrap(iv) - ps = unwrap.(ps) - dvs = unwrap.(dvs) - filter!(!Base.Fix2(isdelay, iv), dvs) - brownians = unwrap.(brownians) + ps = unwrap_vars(ps) + dvs = unwrap_vars(dvs) + if iv !== nothing + filter!(!Base.Fix2(isdelay, iv), dvs) + end + brownians = unwrap_vars(brownians) - if !(eqs isa AbstractArray) - eqs = [eqs] + if !(eqs isa Vector{Equation}) + eqs = Equation[eqs] end + eqs = eqs::Vector{Equation} if noise_eqs !== nothing noise_eqs = unwrap.(noise_eqs) end - costs = unwrap.(costs) - if isempty(costs) - costs = Union{BasicSymbolic, Real}[] - end - - defaults = anydict(defaults) - guesses = anydict(guesses) + costs = unwrap_vars(costs) - inputs = unwrap.(inputs) - outputs = unwrap.(outputs) - inputs = OrderedSet{BasicSymbolic}(inputs) - outputs = OrderedSet{BasicSymbolic}(outputs) + defaults = defsdict(defaults) + guesses = defsdict(guesses) + inputs = unwrap_vars(inputs) + outputs = unwrap_vars(outputs) + if !(inputs isa OrderedSet{SymbolicT}) + inputs = OrderedSet{SymbolicT}(inputs) + end + if !(outputs isa OrderedSet{SymbolicT}) + outputs = OrderedSet{SymbolicT}(outputs) + end for subsys in systems - for var in ModelingToolkit.inputs(subsys) + for var in get_inputs(subsys) push!(inputs, renamespace(subsys, var)) end - for var in ModelingToolkit.outputs(subsys) + for var in get_outputs(subsys) push!(outputs, renamespace(subsys, var)) end end - var_to_name = anydict() + var_to_name = Dict{Symbol, SymbolicT}() - let defaults = discover_from_metadata ? defaults : Dict(), - guesses = discover_from_metadata ? guesses : Dict(), - inputs = discover_from_metadata ? inputs : Set(), - outputs = discover_from_metadata ? outputs : Set() + let defaults = discover_from_metadata ? defaults : SymmapT(), + guesses = discover_from_metadata ? guesses : SymmapT(), + inputs = discover_from_metadata ? inputs : OrderedSet{SymbolicT}(), + outputs = discover_from_metadata ? outputs : OrderedSet{SymbolicT}() process_variables!(var_to_name, defaults, guesses, dvs) process_variables!(var_to_name, defaults, guesses, ps) - process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in observed]) - process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in observed]) + process_variables!(var_to_name, defaults, guesses, SymbolicT[eq.lhs for eq in observed]) + process_variables!(var_to_name, defaults, guesses, SymbolicT[eq.rhs for eq in observed]) for var in dvs if isinput(var) @@ -414,10 +434,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; end end end - filter!(!(isnothing ∘ last), defaults) - filter!(!(isnothing ∘ last), guesses) - defaults = anydict([unwrap(k) => unwrap(v) for (k, v) in defaults]) - guesses = anydict([unwrap(k) => unwrap(v) for (k, v) in guesses]) + filter!(!(Base.Fix1(===, COMMON_NOTHING) ∘ last), defaults) + filter!(!(Base.Fix1(===, COMMON_NOTHING) ∘ last), guesses) sysnames = nameof.(systems) unique_sysnames = Set(sysnames) @@ -436,7 +454,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; is_dde = _check_if_dde(eqs, iv, systems) end - assertions = Dict{BasicSymbolic, String}(unwrap(k) => v for (k, v) in assertions) + assertions = Dict{SymbolicT, String}(unwrap(k) => v for (k, v) in assertions) if isempty(metadata) metadata = MetadataT() @@ -459,6 +477,13 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; initializesystem, is_initializesystem, is_discrete; checks) end +@noinline function warn_pdeps() + @warn """ + The `parameter_dependencies` keyword argument is deprecated. Please provide all + such equations as part of the normal equations of the system. + """ +end + SymbolicIndexingInterface.getname(x::System) = nameof(x) """ From 53207025c24237635c4458354db09f417541ffc9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:58:38 +0530 Subject: [PATCH 031/157] fix: fix usage of new `get_variables` in `is_bound` --- src/inputoutput.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 8d4772ef64..5afc6b2000 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -49,6 +49,13 @@ See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@re """ unbound_outputs(sys) = filter(x -> !is_bound(sys, x), outputs(sys)) +function _is_atomic_inside_operator(ex::SymbolicT) + SU.default_is_atomic(ex) && Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa Operator end => false + _ => true + end +end + """ is_bound(sys, u) @@ -75,8 +82,11 @@ function is_bound(sys, u, stack = []) eqs = equations(sys) eqs = filter(eq -> has_var(eq, u), eqs) # Only look at equations that contain u # isout = isoutput(u) + vars = Set{SymbolicT}() for eq in eqs - vars = [get_variables(eq.rhs); get_variables(eq.lhs)] + empty!(vars) + get_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) + get_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) for var in vars var === u && continue if !same_or_inner_namespace(u, var) @@ -88,7 +98,9 @@ function is_bound(sys, u, stack = []) oeqs = observed(sys) oeqs = filter(eq -> has_var(eq, u), oeqs) # Only look at equations that contain u for eq in oeqs - vars = [get_variables(eq.rhs); get_variables(eq.lhs)] + empty!(vars) + get_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) + get_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) for var in vars var === u && continue if !same_or_inner_namespace(u, var) From 0a8a58b6a291825e61ff2dd2d9f5133d139a14a2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 16:59:08 +0530 Subject: [PATCH 032/157] fix: handle removal of `operator` kwarg of `substitute` --- src/systems/systemstructure.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 2ffda19830..a30e399047 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -752,11 +752,11 @@ function shift_discrete_system(ts::TearingState) for i in eachindex(fullvars) fullvars[i] = StructuralTransformations.simplify_shifts(substitute( - fullvars[i], discmap; operator = Union{Sample, Hold, Pre})) + fullvars[i], discmap; filterer = Symbolics.FPSubFilterer{Union{Sample, Hold, Pre}}())) end for i in eachindex(eqs) eqs[i] = StructuralTransformations.simplify_shifts(substitute( - eqs[i], discmap; operator = Union{Sample, Hold, Pre})) + eqs[i], discmap; filterer = Symbolics.FPSubFilterer{Union{Sample, Hold, Pre}}())) end @set! ts.sys.eqs = eqs @set! ts.fullvars = fullvars From bb5138b9b47e76217b1810b9644b07db15d1fdfe Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 1 Oct 2025 17:57:14 +0530 Subject: [PATCH 033/157] refactor: get `System` to precompile in a trivial case --- src/ModelingToolkit.jl | 101 +++++++++++++++++++++++++++++++- src/systems/callbacks.jl | 3 + src/systems/system.jl | 123 ++++++++++++++++++++++----------------- src/utils.jl | 75 ++++++++++++++---------- 4 files changed, 215 insertions(+), 87 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index d23311b95c..36816a1adc 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -360,10 +360,93 @@ for prop in [SYS_PROPS; [:continuous_events, :discrete_events]] end PrecompileTools.@compile_workload begin - using ModelingToolkit + fold1 = Val{false}() + using SymbolicUtils + using SymbolicUtils: shape + using Symbolics + @syms x y f(t) q[1:5] + SymbolicUtils.Sym{SymReal}(:a; type = Real, shape = SymbolicUtils.ShapeVecT()) + x + y + x * y + x / y + x ^ y + x ^ 5 + 6 ^ x + x - y + -y + 2y + z = 2 + dict = SymbolicUtils.ACDict{VartypeT}() + dict[x] = 1 + dict[y] = 1 + type::typeof(DataType) = rand() < 0.5 ? Real : Float64 + nt = (; type, shape, unsafe = true) + Base.pairs(nt) + BSImpl.AddMul{VartypeT}(1, dict, SymbolicUtils.AddMulVariant.MUL; type, shape = SymbolicUtils.ShapeVecT(), unsafe = true) + *(y, z) + *(z, y) + SymbolicUtils.symtype(y) + f(x) + (5x / 5) + expand((x + y) ^ 2) + simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y) + ex = x + 2y + sin(x) + rules1 = Dict(x => y) + rules2 = Dict(x => 1) + Dx = Differential(x) + Differential(y)(ex) + uex = unwrap(ex) + Symbolics.executediff(Dx, uex) + # Running `fold = Val(true)` invalidates the precompiled statements + # for `fold = Val(false)` and itself doesn't precompile anyway. + # substitute(ex, rules1) + substitute(ex, rules1; fold = fold1) + substitute(ex, rules2; fold = fold1) + @variables foo + f(foo) + @variables x y f(::Real) q[1:5] + x + y + x * y + x / y + x ^ y + x ^ 5 + # 6 ^ x + x - y + -y + 2y + symtype(y) + z = 2 + *(y, z) + *(z, y) + f(x) + (5x / 5) + [x, y] + [x, f, f] + promote_type(Int, Num) + promote_type(Real, Num) + promote_type(Float64, Num) + # expand((x + y) ^ 2) + # simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y) + ex = x + 2y + sin(x) + rules1 = Dict(x => y) + # rules2 = Dict(x => 1) + # Running `fold = Val(true)` invalidates the precompiled statements + # for `fold = Val(false)` and itself doesn't precompile anyway. + # substitute(ex, rules1) + substitute(ex, rules1; fold = fold1) + Symbolics.linear_expansion(ex, y) + # substitute(ex, rules2; fold = fold1) + # substitute(ex, rules2) + # substitute(ex, rules1; fold = fold2) + # substitute(ex, rules2; fold = fold2) + q[1] + q'q + using ModelingToolkit @variables x(ModelingToolkit.t_nounits) - @named sys = System([ModelingToolkit.D_nounits(x) ~ -x], ModelingToolkit.t_nounits) - prob = ODEProblem(mtkcompile(sys), [x => 30.0], (0, 100), jac = true) + isequal(ModelingToolkit.D_nounits.x, ModelingToolkit.t_nounits) + sys = System([ModelingToolkit.D_nounits(x) ~ x], ModelingToolkit.t_nounits, [x], Num[]; name = :sys) + sys = System([ModelingToolkit.D_nounits(x) ~ x], ModelingToolkit.t_nounits, [x], Num[]; name = :sys) + # mtkcompile(sys) @mtkmodel __testmod__ begin @constants begin c = 1.0 @@ -394,4 +477,16 @@ PrecompileTools.@compile_workload begin end end +precompile(Tuple{typeof(Base.merge), NamedTuple{(:f, :args, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, SymbolicUtils.SmallVec{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, Array{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, 1}}, Nothing, UInt64, UInt64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, SymbolicUtils.IDType}}, NamedTuple{(:metadata,), Tuple{Base.ImmutableDict{DataType, Any}}}}) +precompile(Tuple{typeof(Base.merge), NamedTuple{(:f, :args, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, SymbolicUtils.SmallVec{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, Array{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, 1}}, Base.ImmutableDict{DataType, Any}, UInt64, UInt64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, SymbolicUtils.IDType}}, NamedTuple{(:id, :hash, :hash2), Tuple{Nothing, Int64, Int64}}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:f, :args, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, SymbolicUtils.SmallVec{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, Array{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, 1}}, Base.ImmutableDict{DataType, Any}, Int64, Int64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, Nothing}}, Type{SymbolicUtils.BasicSymbolicImpl.Term{SymbolicUtils.SymReal}}}) +precompile(Tuple{typeof(Symbolics.parse_vars), Symbol, Type, Tuple{Symbol, Symbol}, Function}) +precompile(Tuple{typeof(Base.merge), NamedTuple{(:name, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{Symbol, Base.ImmutableDict{DataType, Any}, UInt64, UInt64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, SymbolicUtils.IDType}}, NamedTuple{(:metadata,), Tuple{Base.ImmutableDict{DataType, Any}}}}) +precompile(Tuple{typeof(Base.vect), Symbolics.Equation, Vararg{Symbolics.Equation}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:name, :defaults), Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}}}, Type{ModelingToolkit.System}, Array{Symbolics.Equation, 1}, Symbolics.Num, Array{Symbolics.Num, 1}, Array{Symbolics.Num, 1}}) +precompile(Tuple{Type{NamedTuple{(:name, :defaults), T} where T<:Tuple}, Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}}}) +precompile(Tuple{typeof(SymbolicUtils.isequal_somescalar), Float64, Float64}) +precompile(Tuple{Type{NamedTuple{(:name, :defaults, :guesses), T} where T<:Tuple}, Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}, Base.Dict{Symbolics.Num, Float64}}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:name, :defaults, :guesses), Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}, Base.Dict{Symbolics.Num, Float64}}}, Type{ModelingToolkit.System}, Array{Symbolics.Equation, 1}, Symbolics.Num, Array{Symbolics.Num, 1}, Array{Symbolics.Num, 1}}) + end # module diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index a560e11208..ae63176a92 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -568,6 +568,9 @@ conditions(cb::AbstractCallback) = cb.conditions function conditions(cbs::Vector{<:AbstractCallback}) reduce(vcat, conditions(cb) for cb in cbs; init = []) end +function conditions(cbs::Vector{SymbolicContinuousCallback}) + mapreduce(conditions, vcat, cbs; init = Equation[]) +end equations(cb::AbstractCallback) = conditions(cb) equations(cb::Vector{<:AbstractCallback}) = conditions(cb) diff --git a/src/systems/system.jl b/src/systems/system.jl index 62644b6643..2b2fd9aa7e 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -46,7 +46,7 @@ struct System <: IntermediateDeprecationSystem this noise matrix is diagonal. Diagonal noise can be specified by providing an `N` length vector. If this field is `nothing`, the system does not have noise. """ - noise_eqs::Union{Nothing, AbstractVector, AbstractMatrix} + noise_eqs::Union{Nothing, Vector{SymbolicT}, Matrix{SymbolicT}} """ Jumps associated with the system. Each jump can be a `VariableRateJump`, `ConstantRateJump` or `MassActionJump`. See `JumpProcesses.jl` for more information. @@ -279,30 +279,37 @@ struct System <: IntermediateDeprecationSystem """)) end @assert iv === nothing || symtype(iv) === Real - jumps = Vector{JumpType}(jumps) - if (checks == true || (checks & CheckComponents) > 0) && iv !== nothing - check_independent_variables([iv]) + if (checks isa Bool && checks === true || checks isa Int && (checks & CheckComponents) > 0) && iv !== nothing + check_independent_variables((iv,)) check_variables(unknowns, iv) check_parameters(ps, iv) check_equations(eqs, iv) - if noise_eqs !== nothing && size(noise_eqs, 1) != length(eqs) - throw(IllFormedNoiseEquationsError(size(noise_eqs, 1), length(eqs))) + Neq = length(eqs) + if noise_eqs isa Matrix{SymbolicT} + N1 = size(noise_eqs, 1) + elseif noise_eqs isa Vector{SymbolicT} + N1 = length(noise_eqs) + elseif noise_eqs === nothing + N1 = Neq + else + error() end + N1 == Neq || throw(IllFormedNoiseEquationsError(N1, Neq)) check_equations(equations(continuous_events), iv) check_subsystems(systems) end - if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(unknowns, ps, iv) - if noise_eqs === nothing - check_units(u, eqs) - else - check_units(u, eqs, noise_eqs) - end - if iv !== nothing - check_units(u, jumps, iv) - end - isempty(constraints) || check_units(u, constraints) - end + # if checks == true || (checks & CheckUnits) > 0 + # u = __get_unit_type(unknowns, ps, iv) + # if noise_eqs === nothing + # check_units(u, eqs) + # else + # check_units(u, eqs, noise_eqs) + # end + # if iv !== nothing + # check_units(u, jumps, iv) + # end + # isempty(constraints) || check_units(u, constraints) + # end new(tag, eqs, noise_eqs, jumps, constraints, costs, consolidate, unknowns, ps, brownians, iv, observed, parameter_dependencies, var_to_name, name, description, defaults, @@ -321,13 +328,11 @@ function default_consolidate(costs, subcosts) return reduce(+, costs; init = 0.0) + reduce(+, subcosts; init = 0.0) end -function unwrap_vars(vars::AbstractArray{SymbolicT}) - vec(vars) -end -function unwrap_vars(vars) - result = SymbolicT[] - for var in vars - push!(result, unwrap(var)) +unwrap_vars(vars::AbstractArray{SymbolicT}) = vars +function unwrap_vars(vars::AbstractArray) + result = similar(vars, SymbolicT) + for i in eachindex(vars) + result[i] = SU.Const{VartypeT}(vars[i]) end return result end @@ -372,29 +377,30 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; initializesystem = nothing, is_initializesystem = false, is_discrete = false, preface = [], checks = true) name === nothing && throw(NoNameError()) + + if !(eqs isa Vector{Equation}) + eqs = Equation[eqs] + end + eqs = eqs::Vector{Equation} + if !isempty(parameter_dependencies) @invokelatest warn_pdeps() - eqs = Equation[eqs; parameter_dependencies] + append!(eqs, parameter_dependencies) end iv = unwrap(iv) - ps = unwrap_vars(ps) - dvs = unwrap_vars(dvs) + ps = vec(unwrap_vars(ps)) + dvs = vec(unwrap_vars(dvs)) if iv !== nothing filter!(!Base.Fix2(isdelay, iv), dvs) end brownians = unwrap_vars(brownians) - if !(eqs isa Vector{Equation}) - eqs = Equation[eqs] - end - eqs = eqs::Vector{Equation} - if noise_eqs !== nothing - noise_eqs = unwrap.(noise_eqs) + noise_eqs = unwrap_vars(noise_eqs) end - costs = unwrap_vars(costs) + costs = vec(unwrap_vars(costs)) defaults = defsdict(defaults) guesses = defsdict(guesses) @@ -423,8 +429,12 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; process_variables!(var_to_name, defaults, guesses, dvs) process_variables!(var_to_name, defaults, guesses, ps) - process_variables!(var_to_name, defaults, guesses, SymbolicT[eq.lhs for eq in observed]) - process_variables!(var_to_name, defaults, guesses, SymbolicT[eq.rhs for eq in observed]) + buffer = SymbolicT[] + for eq in observed + push!(buffer, eq.lhs) + push!(buffer, eq.rhs) + end + process_variables!(var_to_name, defaults, guesses, buffer) for var in dvs if isinput(var) @@ -437,10 +447,9 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; filter!(!(Base.Fix1(===, COMMON_NOTHING) ∘ last), defaults) filter!(!(Base.Fix1(===, COMMON_NOTHING) ∘ last), guesses) - sysnames = nameof.(systems) - unique_sysnames = Set(sysnames) - if length(unique_sysnames) != length(sysnames) - throw(NonUniqueSubsystemsError(sysnames, unique_sysnames)) + + if !allunique(map(nameof, systems)) + nonunique_subsystems(systems) end continuous_events, discrete_events = create_symbolic_events( @@ -454,7 +463,10 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; is_dde = _check_if_dde(eqs, iv, systems) end - assertions = Dict{SymbolicT, String}(unwrap(k) => v for (k, v) in assertions) + _assertions = Dict{SymbolicT, String} + for (k, v) in assertions + _assertions[unwrap(k)::SymbolicT] = v + end if isempty(metadata) metadata = MetadataT() @@ -468,6 +480,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; metadata = meta end metadata = refreshed_metadata(metadata) + jumps = Vector{JumpType}(jumps) System(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, noise_eqs, jumps, constraints, costs, consolidate, dvs, ps, brownians, iv, observed, Equation[], var_to_name, name, description, defaults, guesses, systems, initialization_eqs, @@ -477,6 +490,12 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; initializesystem, is_initializesystem, is_discrete; checks) end +@noinline function nonunique_subsystems(systems) + sysnames = nameof.(systems) + unique_sysnames = Set(sysnames) + throw(NonUniqueSubsystemsError(sysnames, unique_sysnames)) +end + @noinline function warn_pdeps() @warn """ The `parameter_dependencies` keyword argument is deprecated. Please provide all @@ -751,19 +770,15 @@ differential equations. """ is_dde(sys::AbstractSystem) = has_is_dde(sys) && get_is_dde(sys) -function _check_if_dde(eqs, iv, subsystems) - is_dde = any(ModelingToolkit.is_dde, subsystems) - if !is_dde - vs = Set() - for eq in eqs - vars!(vs, eq) - is_dde = any(vs) do sym - isdelay(unwrap(sym), iv) - end - is_dde && break - end +_check_if_dde(eqs::Vector{Equation}, iv::Nothing, subsystems::Vector{System}) = false +function _check_if_dde(eqs::Vector{Equation}, iv::SymbolicT, subsystems::Vector{System}) + any(ModelingToolkit.is_dde, subsystems) && return true + pred = Base.Fix2(isdelay, iv) + for eq in eqs + SU.query(pred, eq.lhs) && return true + SU.query(pred, eq.rhs) && return true end - return is_dde + return false end """ diff --git a/src/utils.jl b/src/utils.jl index 34c28bcf3d..f75dd6d375 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -117,11 +117,14 @@ const CheckUnits = 1 << 2 function check_independent_variables(ivs) for iv in ivs - isparameter(iv) || - @warn "Independent variable $iv should be defined with @independent_variables $iv." + isparameter(iv) || @invokelatest warn_indepvar(iv) end end +@noinline function warn_indepvar(iv::SymbolicT) + @warn "Independent variable $iv should be defined with @independent_variables $iv." +end + function check_parameters(ps, iv) for p in ps isequal(iv, p) && @@ -129,22 +132,16 @@ function check_parameters(ps, iv) end end -function is_delay_var(iv, var) - if Symbolics.isarraysymbolic(var) - return is_delay_var(iv, first(collect(var))) - end - args = nothing - try - args = arguments(var) - catch - return false +function is_delay_var(iv::SymbolicT, var::SymbolicT) + Moshi.Match.@match var begin + BSImpl.Term(; f, args) => begin + length(args) > 1 && return false + arg = args[1] + isequal(arg, iv) && return false + return symtype(arg) <: Real + end + _ => false end - length(args) > 1 && return false - isequal(first(args), iv) && return false - delay = iv - first(args) - delay isa Integer || - delay isa AbstractFloat || - (delay isa Num && isreal(value(delay))) end function check_variables(dvs, iv) @@ -187,20 +184,35 @@ function collect_ivs(eqs, op = Differential) return ivs end +struct IndepvarCheckPredicate + iv::SymbolicT +end + +function (icp::IndepvarCheckPredicate)(ex::SymbolicT) + Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa Differential end => begin + f = f::Differential + isequal(f.x, icp.iv) || throw_multiple_iv(icp.iv, f.x) + return false + end + _ => false + end +end + +@noinline function throw_multiple_iv(iv, newiv) + throw(ArgumentError("Differential w.r.t. variable ($newiv) other than the independent variable ($iv) are not allowed.")) +end + """ check_equations(eqs, iv) Assert that equations are well-formed when building ODE, i.e., only containing a single independent variable. """ -function check_equations(eqs, iv) - ivs = collect_ivs(eqs) - display = collect(ivs) - length(ivs) <= 1 || - throw(ArgumentError("Differential w.r.t. multiple variables $display are not allowed.")) - if length(ivs) == 1 - single_iv = pop!(ivs) - isequal(single_iv, iv) || - throw(ArgumentError("Differential w.r.t. variable ($single_iv) other than the independent variable ($iv) are not allowed.")) +function check_equations(eqs::Vector{Equation}, iv::SymbolicT) + icp = IndepvarCheckPredicate(iv) + for eq in eqs + SU.query!(icp, eq.lhs) + SU.query!(icp, eq.rhs) end end @@ -211,10 +223,12 @@ Assert that the subsystems have the appropriate namespacing behavior. """ function check_subsystems(systems) idxs = findall(!does_namespacing, systems) - if !isempty(idxs) - names = join(" " .* string.(nameof.(systems[idxs])), "\n") - throw(ArgumentError("All subsystems must have namespacing enabled. The following subsystems do not perform namespacing:\n$(names)")) - end + isempty(idxs) || throw_bad_namespacing(systems, idxs) +end + +@noinline function throw_bad_namespacing(systems, idxs) + names = join(" " .* string.(nameof.(systems[idxs])), "\n") + throw(ArgumentError("All subsystems must have namespacing enabled. The following subsystems do not perform namespacing:\n$(names)")) end """ @@ -626,6 +640,7 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics if issym(expr) return collect_var!(unknowns, parameters, expr, iv; depth) end + SymbolicUtils.isconst(expr) && return for var in vars(expr; op) while iscall(var) && operation(var) isa op validate_operator(operation(var), arguments(var), iv; context = expr) From 866b2a3a3209dc5fb0604dd78e8b8da442e87bfa Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:24:40 +0530 Subject: [PATCH 034/157] fix: improve inference of variable metadata accessors --- src/utils.jl | 4 ++-- src/variables.jl | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index f75dd6d375..f88da709d1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -394,9 +394,9 @@ isdifferential(expr) = isoperator(expr, Differential) isdiffeq(eq) = isdifferential(eq.lhs) || isoperator(eq.lhs, Shift) isvariable(x::Num)::Bool = isvariable(value(x)) -function isvariable(x)::Bool +function isvariable(x) x isa SymbolicT || return false - hasmetadata(x, VariableSource) || iscall(x) && operation(x) === getindex && isvariable(arguments(x)[1]) + hasmetadata(x, VariableSource) || iscall(x) && operation(x) === getindex && isvariable(arguments(x)[1])::Bool end """ diff --git a/src/variables.jl b/src/variables.jl index 9ae5c48803..b3966f5ea6 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -218,13 +218,13 @@ setio(x, i::Bool, o::Bool) = setoutput(setinput(x, i), o) Check if variable `x` is marked as an input. """ -isinput(x) = isvarkind(VariableInput, x) +isinput(x) = isvarkind(VariableInput, x)::Bool """ $(TYPEDSIGNATURES) Check if variable `x` is marked as an output. """ -isoutput(x) = isvarkind(VariableOutput, x) +isoutput(x) = isvarkind(VariableOutput, x)::Bool # Before the solvability check, we already have handled IO variables, so # irreducibility is independent from IO. @@ -234,7 +234,7 @@ isoutput(x) = isvarkind(VariableOutput, x) Check if `x` is marked as irreducible. This prevents it from being eliminated as an observed variable in `mtkcompile`. """ -isirreducible(x) = isvarkind(VariableIrreducible, x) +isirreducible(x) = isvarkind(VariableIrreducible, x)::Bool setirreducible(x, v::Bool) = setmetadata(x, VariableIrreducible, v) state_priority(x::Union{Num, Symbolics.Arr}) = state_priority(unwrap(x)) """ @@ -339,7 +339,7 @@ isdisturbance(x::Num) = isdisturbance(Symbolics.unwrap(x)) Determine whether symbolic variable `x` is marked as a disturbance input. """ function isdisturbance(x) - isvarkind(VariableDisturbance, x) + isvarkind(VariableDisturbance, x)::Bool end setdisturbance(x, v) = setmetadata(x, VariableDisturbance, v) @@ -370,7 +370,7 @@ Create a tunable parameter by See also [`tunable_parameters`](@ref), [`getbounds`](@ref) """ function istunable(x, default = true) - isvarkind(VariableTunable, x, default) + isvarkind(VariableTunable, x, default)::Bool end ## Dist ======================================================================== @@ -607,10 +607,10 @@ Check if the variable `x` has a unit. hasunit(x) = getunit(x) !== nothing getunshifted(x::Num) = getunshifted(unwrap(x)) -getunshifted(x::SymbolicT) = Symbolics.getmetadata(x, VariableUnshifted, nothing) +getunshifted(x::SymbolicT) = Symbolics.getmetadata(x, VariableUnshifted, nothing)::Union{SymbolicT, Nothing} getshift(x::Num) = getshift(unwrap(x)) -getshift(x::SymbolicT) = Symbolics.getmetadata(x, VariableShift, 0) +getshift(x::SymbolicT) = Symbolics.getmetadata(x, VariableShift, 0)::Int ################### ### Evaluate at ### From 65dc9aeddb50f5cc6714c988ce7e66ab29bcd918 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:24:53 +0530 Subject: [PATCH 035/157] fix: make `default_toterm` type stable --- src/variables.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/variables.jl b/src/variables.jl index b3966f5ea6..543617e7de 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -245,19 +245,30 @@ chosen as a state in `mtkcompile`. """ state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64 -normalize_to_differential(x) = x +function normalize_to_differential(@nospecialize(op)) + if op isa Shift && op.t isa SymbolicT + return Differential(op.t) ^ op.steps + else + return op + end +end -function default_toterm(x) - if iscall(x) && (op = operation(x)) isa Operator - if !(op isa Differential) - if op isa Shift && op.steps < 0 +default_toterm(x) = x +function default_toterm(x::SymbolicT) + Moshi.Match.@match x begin + BSImpl.Term(; f, args, shape, type, metadata) && if f isa Operator end => begin + if f isa Shift && f.steps < 0 return shift2term(x) + elseif f isa Differential + return Symbolics.diff2term(x) + else + newf = normalize_to_differential(f) + f === newf && return x + x = BSImpl.Term{VartypeT}(newf, args; type, shape, metadata) + return Symbolics.diff2term(x) end - x = normalize_to_differential(op)(arguments(x)...) end - Symbolics.diff2term(x) - else - x + _ => return x end end From 9a620276938f327c3272840bd211810d637d1b67 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:25:46 +0530 Subject: [PATCH 036/157] fix: improve inference of several utility functions --- src/utils.jl | 61 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index f88da709d1..fc5fbda207 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -526,13 +526,10 @@ ModelingToolkit.collect_applied_operators(eq, Differential) == Set([D(y)]) The difference compared to `collect_operator_variables` is that `collect_operator_variables` returns the variable without the operator applied. """ -function collect_applied_operators(x, op) - v = vars(x, op = op) - filter(v) do x - issym(x) && return false - iscall(x) && return operation(x) isa op - false - end +function collect_applied_operators(x::SymbolicT, ::Type{op}) where {op} + v = Set{SymbolicT}() + SU.search_variables!(v, x; is_atomic = OnlyOperatorIsAtomic{op}()) + return v end """ @@ -543,12 +540,12 @@ Search through equations and parameter dependencies of `sys`, where sys is at a recursively searches through all subsystems of `sys`, increasing the depth if it is not `-1`. A depth of `-1` indicates searching for variables with `GlobalScope`. """ -function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential) +function collect_scoped_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, sys::AbstractSystem, iv::Union{SymbolicT, Nothing}; depth = 1, op = Differential) if has_eqs(sys) for eq in equations(sys) eqtype_supports_collect_vars(eq) || continue if eq isa Equation - eq.lhs isa Union{SymbolicT, Number} || continue + symtype(eq.lhs) <: Number || continue end collect_vars!(unknowns, parameters, eq, iv; depth, op) end @@ -622,6 +619,24 @@ function Base.showerror(io::IO, err::OperatorIndepvarMismatchError) end end +struct OnlyOperatorIsAtomic{O} end + +function (::OnlyOperatorIsAtomic{O})(ex::SymbolicT) where {O} + Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa O end => true + _ => false + end +end + +struct OperatorIsAtomic{O} end + +function (::OperatorIsAtomic{O})(ex::SymbolicT) where {O} + SU.default_is_atomic(ex) && Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa Operator end => f isa O + _ => true + end +end + """ $(TYPEDSIGNATURES) @@ -636,12 +651,15 @@ can be checked using `check_scope_depth`. This function should return `nothing`. """ -function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics.Operator) - if issym(expr) - return collect_var!(unknowns, parameters, expr, iv; depth) +function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::SymbolicT, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) + Moshi.Match.@match expr begin + BSImpl.Const(;) => return + BSImpl.Sym(;) => return collect_var!(unknowns, parameters, expr, iv; depth) + _ => nothing end - SymbolicUtils.isconst(expr) && return - for var in vars(expr; op) + vars = Set{SymbolicT}() + SU.search_variables!(vars, expr; is_atomic = OperatorIsAtomic{op}()) + for var in vars while iscall(var) && operation(var) isa op validate_operator(operation(var), arguments(var), iv; context = expr) var = arguments(var)[1] @@ -651,6 +669,13 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics return nothing end +function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::AbstractArray{SymbolicT}, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) + for var in expr + collect_vars!(unknowns, parameters, var, iv; depth, op) + end + return nothing +end + """ $(TYPEDSIGNATURES) @@ -696,7 +721,7 @@ function collect_var!(unknowns, parameters, var, iv; depth = 0) wrapped symbolic variables. """) end - check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing + check_scope_depth(getmetadata(var, SymScope, LocalScope())::AllScopes, depth) || return nothing var = setmetadata(var, SymScope, LocalScope()) if iscalledparameter(var) callable = getcalledparameter(var) @@ -724,7 +749,7 @@ function check_scope_depth(scope, depth) if scope isa LocalScope return depth == 0 elseif scope isa ParentScope - return depth > 0 && check_scope_depth(scope.parent, depth - 1) + return depth > 0 && check_scope_depth(scope.parent, depth - 1)::Bool elseif scope isa GlobalScope return depth == -1 end @@ -838,8 +863,8 @@ end Check if `T` is an appropriate symtype for a symbolic variable representing a floating point number or array of such numbers. """ -function is_floatingpoint_symtype(T::Type) - return T == Real || T == Number || T == Complex || T <: AbstractFloat || +function is_floatingpoint_symtype(T) + return T === Real || T === Number || T === Complex || T <: AbstractFloat || T <: AbstractArray && is_floatingpoint_symtype(eltype(T)) end From 85a19165db7f9670c826cc60b344f040b9d588d4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:26:45 +0530 Subject: [PATCH 037/157] fix: make `shift2term` type-stable --- .../StructuralTransformations.jl | 13 +-- src/structural_transformation/utils.jl | 84 +++++++++++-------- 2 files changed, 56 insertions(+), 41 deletions(-) diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index cbc01de2b5..62ae97186f 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -3,17 +3,19 @@ module StructuralTransformations using Setfield: @set!, @set using UnPack: @unpack -using Symbolics: unwrap, linear_expansion +using Symbolics: unwrap, linear_expansion, VartypeT, SymbolicT import Symbolics using SymbolicUtils +using SymbolicUtils: BSImpl using SymbolicUtils.Code using SymbolicUtils.Rewriters -using SymbolicUtils: maketerm, iscall +using SymbolicUtils: maketerm, iscall, symtype import SymbolicUtils as SU +import Moshi using ModelingToolkit using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Differential, - unknowns, equations, vars, SymbolicT, diff2term_with_unit, + unknowns, equations, vars, diff2term_with_unit, shift2term_with_unit, value, operation, arguments, simplify, symbolic_linear_solve, isdiffeq, isdifferential, isirreducible, @@ -28,7 +30,8 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe filter_kwargs, lower_varname_with_unit, lower_shift_varname_with_unit, setio, SparseMatrixCLIL, get_fullvars, has_equations, observed, - Schedule, schedule, iscomplete, get_schedule + Schedule, schedule, iscomplete, get_schedule, VariableUnshifted, + VariableShift using ModelingToolkit.BipartiteGraphs import .BipartiteGraphs: invview, complete @@ -41,7 +44,7 @@ using ModelingToolkit: algeqs, EquationsView, dervars_range, diffvars_range, algvars_range, DiffGraph, complete!, get_fullvars, system_subset -using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic +using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic, getname using ModelingToolkit.DiffEqBase using ModelingToolkit.StaticArrays diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 6f54f33029..f6ff669c0e 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -503,43 +503,55 @@ end """ Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t). """ -function shift2term(var) - iscall(var) || return var - op = operation(var) - op isa Shift || return var - iv = op.t - arg = only(arguments(var)) - if operation(arg) === getindex - idxs = arguments(arg)[2:end] - newvar = shift2term(op(first(arguments(arg))))[idxs...] - unshifted = ModelingToolkit.getunshifted(newvar)[idxs...] - newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, unshifted) - return newvar +function shift2term(var::SymbolicT) + Moshi.Match.@match var begin + BSImpl.Term(f, args) && if f isa Shift end => begin + op = f + arg = args[1] + Moshi.Match.@match arg begin + BSImpl.Term(; f, args, type, shape, metadata) && if f === getindex end => begin + newargs = copy(parent(args)) + newargs[1] = shift2term(op(newargs[1])) + unshifted_args = copy(newargs) + unshifted_args[1] = ModelingToolkit.getunshifted(newargs[1]) + unshifted = BSImpl.Term{VartypeT}(getindex, unshifted_args; type, shape, metadata) + if metadata === nothing + metadata = Base.ImmutableDict{DataType, Any}(VariableUnshifted, unshifted) + elseif metadata isa Base.ImmutableDict{DataType, Any} + metadata = Base.ImmutableDict(metadata, VariableUnshifted, unshifted) + end + return BSImpl.Term{VartypeT}(getindex, newargs; type, shape, metadata) + end + _ => nothing + end + unshifted = ModelingToolkit.getunshifted(arg) + is_lowered = unshifted !== nothing + backshift = op.steps + ModelingToolkit.getshift(arg) + io = IOBuffer() + O = (is_lowered ? unshifted : arg)::SymbolicT + write(io, getname(O)) + # Char(0x209c) = ₜ + write(io, Char(0x209c)) + # Char(0x208b) = ₋ (subscripted minus) + # Char(0x208a) = ₊ (subscripted plus) + pm = backshift > 0 ? Char(0x208a) : Char(0x208b) + write(io, pm) + backshift = abs(backshift) + N = ndigits(backshift) + den = 10 ^ (N - 1) + for _ in 1:N + # subscripted number, e.g. ₁ + write(io, Char(0x2080 + div(backshift, den) % 10)) + den = div(den, 10) + end + newname = Symbol(take!(io)) + newvar = Symbolics.rename(var, newname) + newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O) + newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift) + return newvar + end + _ => return var end - is_lowered = !isnothing(ModelingToolkit.getunshifted(arg)) - - backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps - - # Char(0x208b) = ₋ (subscripted minus) - # Char(0x208a) = ₊ (subscripted plus) - pm = backshift > 0 ? Char(0x208a) : Char(0x208b) - # subscripted number, e.g. ₁ - num = join(Char(0x2080 + d) for d in reverse!(digits(abs(backshift)))) - # Char(0x209c) = ₜ - # ds = ₜ₋₁ - ds = join([Char(0x209c), pm, num]) - - O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg - oldop = operation(O) - newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) : - Symbol(string(nameof(oldop))) - - newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), - arguments(O), SU.metadata(O)) - newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname)) - newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O) - newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift) - return newvar end function isdoubleshift(var) From 99d1c00e350179405931dda2153c6a5670af7da3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:26:57 +0530 Subject: [PATCH 038/157] fix: make `unhack_observed` type-stable --- src/systems/nonlinear/initializesystem.jl | 29 +++++++---------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 928a2f3118..b318b7392a 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -823,30 +823,19 @@ Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works w initialization. """ function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation}) - subs = Dict() - tempvars = Set() - rm_idxs = Int[] + subs = Dict{SymbolicT, SymbolicT}() + mask = trues(length(obseqs)) for (i, eq) in enumerate(obseqs) - iscall(eq.rhs) || continue - if operation(eq.rhs) == StructuralTransformations.change_origin - push!(rm_idxs, i) - continue - end - end - - for (i, eq) in enumerate(obseqs) - if eq.lhs in tempvars - subs[eq.lhs] = eq.rhs - push!(rm_idxs, i) - end + mask[i] = !iscall(eq.rhs) || operation(eq.rhs) !== StructuralTransformations.change_origin end - obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)] - obseqs = map(obseqs) do eq - fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) + obseqs = obseqs[mask] + for i in eachindex(obseqs) + obseqs[i] = fixpoint_sub(obseqs[i].lhs, subs) ~ fixpoint_sub(obseqs[i], subs) end - eqs = map(eqs) do eq - fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) + eqs = copy(eqs) + for i in eachindex(eqs) + eqs[i] = fixpoint_sub(eqs[i].lhs, subs) ~ fixpoint_sub(eqs[i], subs) end return obseqs, eqs end From 0d11876e46194df90018e693d746eeee6c49b407 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:27:15 +0530 Subject: [PATCH 039/157] fix: make `independent_variables` type-stable --- src/systems/abstractsystem.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 9ae0b0cc82..2d40a32277 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -94,11 +94,11 @@ See also [`@independent_variables`](@ref) and [`ModelingToolkit.get_iv`](@ref). """ function independent_variables(sys::AbstractSystem) if isdefined(sys, :iv) && getfield(sys, :iv) !== nothing - return [getfield(sys, :iv)] + return SymbolicT[getfield(sys, :iv)] elseif isdefined(sys, :ivs) - return getfield(sys, :ivs) + return unwrap.(getfield(sys, :ivs))::Vector{SymbolicT} else - return [] + return SymbolicT[] end end From c72bc8af08533abbda1abe10799cfb8c5d269c27 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:27:45 +0530 Subject: [PATCH 040/157] fix: improve type-stability of some SII functions --- src/systems/abstractsystem.jl | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 2d40a32277..946b711405 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -170,17 +170,20 @@ function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem) return unknowns(sys) end -function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) - sym = unwrap(sym) +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) + is_parameter(sys, unwrap(sym)) +end + +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Int) + sym in 1:length(parameter_symbols(sys)) +end + +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::SymbolicT) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing return sym isa ParameterIndex || is_parameter(ic, sym) || - iscall(sym) && - operation(sym) === getindex && + iscall(sym) && operation(sym) === getindex && is_parameter(ic, first(arguments(sym))) end - if unwrap(sym) isa Int - return unwrap(sym) in 1:length(parameter_symbols(sys)) - end return any(isequal(sym), parameter_symbols(sys)) || hasname(sym) && !(iscall(sym) && operation(sym) == getindex) && is_parameter(sys, getname(sym)) @@ -191,7 +194,7 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol return is_parameter(ic, sym) end - named_parameters = [getname(x) + named_parameters = Symbol[getname(x) for x in parameter_symbols(sys) if hasname(x) && !(iscall(x) && operation(x) == getindex)] return any(isequal(sym), named_parameters) || @@ -200,6 +203,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol Symbol.(nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, named_parameters)) == 1 end +SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) = false + function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) sym = unwrap(sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing From 30c6a0b92c10b8534234a5e275361ceef6390315 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:28:07 +0530 Subject: [PATCH 041/157] fix: improve type-stability of the `Initial` operator --- src/systems/abstractsystem.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 946b711405..7625c6fbee 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -513,15 +513,16 @@ function (f::Initial)(x) end # don't double wrap iscall(x) && operation(x) isa Initial && return x - result = if symbolic_type(x) == ArraySymbolic() - term(f, x; type = symtype(x), shape = SU.shape(x)) - elseif iscall(x) && operation(x) == getindex + sh = SU.shape(x) + result = if SU.is_array_shape(sh) + term(f, x; type = symtype(x), shape = sh) + elseif iscall(x) && operation(x) === getindex # instead of `Initial(x[1])` create `Initial(x)[1]` # which allows parameter indexing to handle this case automatically. arr = arguments(x)[1] f(arr)[arguments(x)[2:end]...] else - term(f, x; type = symtype(x), shape = SU.shape(x)) + term(f, x; type = symtype(x), shape = sh) end # the result should be a parameter result = toparam(result) From eecb9c5ee71d382c6f08cfa9a773887a14588f9e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:28:27 +0530 Subject: [PATCH 042/157] fix: make `add_initialization_parameters` type-stable --- src/systems/abstractsystem.jl | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 7625c6fbee..5bc7107433 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -539,16 +539,20 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) supports_initialization(sys) || return sys is_initializesystem(sys) && return sys - all_initialvars = Set{BasicSymbolic}() + all_initialvars = Set{SymbolicT}() # time-independent systems don't initialize unknowns # but may initialize parameters using guesses for unknowns eqs = equations(sys) - if !(eqs isa Vector{Equation}) - eqs = Equation[x for x in eqs if x isa Equation] - end obs, eqs = unhack_observed(observed(sys), eqs) - for x in Iterators.flatten((unknowns(sys), Iterators.map(eq -> eq.lhs, obs))) - x = unwrap(x) + for x in unknowns(sys) + if iscall(x) && operation(x) == getindex && split + push!(all_initialvars, arguments(x)[1]) + else + push!(all_initialvars, x) + end + end + for eq in obs + x = eq.lhs if iscall(x) && operation(x) == getindex && split push!(all_initialvars, arguments(x)[1]) else @@ -558,15 +562,19 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) # add derivatives of all variables for steady-state initial conditions if is_time_dependent(sys) && !is_discrete_system(sys) - D = Differential(get_iv(sys)) - union!(all_initialvars, [D(v) for v in all_initialvars if iscall(v)]) + D = Differential(get_iv(sys)::SymbolicT) + for v in all_initialvars + iscall(v) && push!(all_initialvars, D(v)) + end end for eq in get_parameter_dependencies(sys) is_variable_floatingpoint(eq.lhs) || continue push!(all_initialvars, eq.lhs) end - all_initialvars = collect(all_initialvars) - initials = map(Initial(), all_initialvars) + initials = collect(all_initialvars) + for (i, v) in enumerate(initials) + initials[i] = Initial()(v) + end @set! sys.ps = unique!([get_ps(sys); initials]) defs = copy(get_defaults(sys)) for ivar in initials From cb16cef4f95e8060f0b4abec5a563628e77901eb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:28:41 +0530 Subject: [PATCH 043/157] fix: improve type-stability of `discover_globalscoped` --- src/systems/abstractsystem.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 5bc7107433..d0baf6cc8a 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -606,9 +606,9 @@ end Find [`GlobalScope`](@ref)d variables in `sys` and add them to the unknowns/parameters. """ function discover_globalscoped(sys::AbstractSystem) - newunknowns = OrderedSet() - newparams = OrderedSet() - iv = has_iv(sys) ? get_iv(sys) : nothing + newunknowns = OrderedSet{SymbolicT}() + newparams = OrderedSet{SymbolicT}() + iv::Union{SymbolicT, Nothing} = has_iv(sys) ? get_iv(sys) : nothing collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1) setdiff!(newunknowns, observables(sys)) @set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams))) From a8ba1cbd92e67cc281a342960486da4cc6524625 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:29:19 +0530 Subject: [PATCH 044/157] fix: make topsorting equations type-stable --- src/systems/alias_elimination.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 40114c50a4..03878e6847 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -411,7 +411,7 @@ julia> ModelingToolkit.topsort_equations(eqs, [x, y, z, k]) Equation(x(t), y(t) + z(t)) ``` """ -function topsort_equations(eqs, unknowns; check = true) +function topsort_equations(eqs::Vector{Equation}, unknowns::Vector{SymbolicT}; check = true) graph, assigns = observed2graph(eqs, unknowns) neqs = length(eqs) degrees = zeros(Int, neqs) @@ -460,22 +460,25 @@ function topsort_equations(eqs, unknowns; check = true) return ordered_eqs end -function observed2graph(eqs, unknowns) +function observed2graph(eqs::Vector{Equation}, unknowns::Vector{SymbolicT})::Tuple{BipartiteGraph{Int, Nothing}, Vector{Int}} graph = BipartiteGraph(length(eqs), length(unknowns)) - v2j = Dict(unknowns .=> 1:length(unknowns)) + v2j = Dict{SymbolicT, Int}(unknowns .=> 1:length(unknowns)) # `assigns: eq -> var`, `eq` defines `var` assigns = similar(eqs, Int) - + vars = Set{SymbolicT}() for (i, eq) in enumerate(eqs) lhs_j = get(v2j, eq.lhs, nothing) lhs_j === nothing && throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in unknowns.")) assigns[i] = lhs_j - vs = vars(eq.rhs; op = Symbolics.Operator) - for v in vs + empty!(vars) + SU.search_variables!(vars, eq.rhs; is_atomic = OperatorIsAtomic{SU.Operator}()) + for v in vars j = get(v2j, v, nothing) - j !== nothing && add_edge!(graph, i, j) + if j isa Int + add_edge!(graph, i, j) + end end end From d0aac5530730d28799b629240ff10696d9111568 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:29:50 +0530 Subject: [PATCH 045/157] fix: make `SymbolicAffect` and `AffectSystem` type-stable --- src/systems/callbacks.jl | 67 +++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index ae63176a92..8bf49f508d 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -7,15 +7,19 @@ end struct SymbolicAffect affect::Vector{Equation} alg_eqs::Vector{Equation} - discrete_parameters::Vector{Any} + discrete_parameters::Vector{SymbolicT} end function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[], - discrete_parameters = Any[], kwargs...) - if !(discrete_parameters isa AbstractVector) - discrete_parameters = Any[discrete_parameters] - elseif !(discrete_parameters isa Vector{Any}) - discrete_parameters = Vector{Any}(discrete_parameters) + discrete_parameters = SymbolicT[], kwargs...) + if symbolic_type(discrete_parameters) !== NotSymbolic() + discrete_parameters = SymbolicT[unwrap(discrete_parameters)] + elseif !(discrete_parameters isa Vector{SymbolicT}) + _discs = SymbolicT[] + for p in discrete_parameters + push!(_discs, unwrap(p)) + end + discrete_parameters = _discs end SymbolicAffect(affect, alg_eqs, discrete_parameters) end @@ -33,11 +37,11 @@ struct AffectSystem """The internal implicit discrete system whose equations are solved to obtain values after the affect.""" system::AbstractSystem """Unknowns of the parent ODESystem whose values are modified or accessed by the affect.""" - unknowns::Vector + unknowns::Vector{SymbolicT} """Parameters of the parent ODESystem whose values are accessed by the affect.""" - parameters::Vector + parameters::Vector{SymbolicT} """Parameters of the parent ODESystem whose values are modified by the affect.""" - discretes::Vector + discretes::Vector{SymbolicT} end function (s::SymbolicUtils.Substituter)(aff::AffectSystem) @@ -57,7 +61,11 @@ function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], discrete_parameters = spec.discrete_parameters, kwargs...) end -function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], +@noinline function warn_algebraic_equation(eq::Equation) + @warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1." +end + +function AffectSystem(affect::Vector{Equation}; discrete_parameters = SymbolicT[], iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...) isempty(affect) && return nothing if isnothing(iv) @@ -65,26 +73,24 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], @warn "No independent variable specified. Defaulting to t_nounits." end - discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters]) - discrete_parameters = unwrap.(discrete_parameters) + discrete_parameters = SymbolicAffect(affect; alg_eqs, discrete_parameters).discrete_parameters for p in discrete_parameters SU.query(isequal(unwrap(iv)), unwrap(p)) || error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).") end - dvs = OrderedSet() - params = OrderedSet() - _varsbuf = Set() + dvs = OrderedSet{SymbolicT}() + params = OrderedSet{SymbolicT}() + _varsbuf = Set{SymbolicT}() for eq in affect - if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() || - symbolic_type(eq.lhs) === NotSymbolic()) - @warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1." + if !haspre(eq) && !(isconst(eq.lhs) && isconst(eq.rhs)) + @invokelatest warn_algebraic_equation(eq) end collect_vars!(dvs, params, eq, iv; op = Pre) empty!(_varsbuf) - vars!(_varsbuf, eq; op = Pre) - filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf) + SU.search_variables!(_varsbuf, eq; is_atomic = OperatorIsAtomic{Pre}()) + filter!(x -> iscall(x) && operation(x) === Pre(), _varsbuf) union!(params, _varsbuf) diffvs = collect_applied_operators(eq, Differential) union!(dvs, diffvs) @@ -92,14 +98,20 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], for eq in alg_eqs collect_vars!(dvs, params, eq, iv) end - pre_params = filter(haspre ∘ value, params) - sys_params = collect(setdiff(params, union(discrete_parameters, pre_params))) + pre_params = filter(haspre, params) + sys_params = SymbolicT[] + disc_ps_set = Set{SymbolicT}(discrete_parameters) + for p in params + p in disc_ps_set && continue + p in pre_params && continue + push!(sys_params, p) + end discretes = map(tovar, discrete_parameters) dvs = collect(dvs) _dvs = map(default_toterm, dvs) - rev_map = Dict(zip(discrete_parameters, discretes)) - subs = merge(rev_map, Dict(zip(dvs, _dvs))) + rev_map = Dict{SymbolicT, SymbolicT}(zip(discrete_parameters, discretes)) + subs = merge(rev_map, Dict{SymbolicT, SymbolicT}(zip(dvs, _dvs))) affect = substitute(affect, subs) alg_eqs = substitute(alg_eqs, subs) @@ -108,14 +120,13 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], collect(union(pre_params, sys_params)); is_discrete = true) affectsys = mtkcompile(affectsys; fully_determined = nothing) # get accessed parameters p from Pre(p) in the callback parameters - accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params)))) + accessed_params = Vector{SymbolicT}(filter(isparameter, map(unPre, collect(pre_params)))) union!(accessed_params, sys_params) # add scalarized unknowns to the map. - _dvs = reduce(vcat, map(scalarize, _dvs), init = Any[]) + _dvs = reduce(vcat, map(scalarize, _dvs), init = SymbolicT[]) - AffectSystem(affectsys, collect(_dvs), collect(accessed_params), - collect(discrete_parameters)) + AffectSystem(affectsys, _dvs, accessed_params, discrete_parameters) end system(a::AffectSystem) = a.system From 38d28e581efce8e49cb7df8e0bde0f68afc3ca81 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:30:08 +0530 Subject: [PATCH 046/157] fix: improve type-stability of `ImperativeAffect` --- src/systems/imperative_affect.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 6579ad63cc..7fc0c6abe1 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -85,10 +85,16 @@ context(a::ImperativeAffect) = a.ctx observed(a::ImperativeAffect) = a.obs observed_syms(a::ImperativeAffect) = a.obs_syms function discretes(a::ImperativeAffect) - Iterators.filter(ModelingToolkit.isparameter, - Iterators.flatten(Iterators.map( - x -> symbolic_type(x) == NotSymbolic() && x isa AbstractArray ? x : [x], - a.modified))) + discs = SymbolicT[] + for val in a.modified + val = unwrap(val) + if val isa SymbolicT + isparameter(a) && push!(discs, val) + elseif val isa AbstractArray + append!(discs, filter(isparameter, map(unwrap, val))) + end + end + return discs end modified(a::ImperativeAffect) = a.modified modified_syms(a::ImperativeAffect) = a.mod_syms From ef6f531fade1786173be664ac1501d0286bdf838 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:31:13 +0530 Subject: [PATCH 047/157] fix: make `reorder_parameters` more type-stable --- src/systems/index_cache.jl | 95 +++++++++++++++++--------------------- 1 file changed, 43 insertions(+), 52 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index ea651c236f..d1fe95f997 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -486,7 +486,7 @@ end function reorder_parameters( sys::AbstractSystem, ps = parameters(sys; initial_parameters = true); kwargs...) if has_index_cache(sys) && get_index_cache(sys) !== nothing - reorder_parameters(get_index_cache(sys), ps; kwargs...) + reorder_parameters(get_index_cache(sys)::IndexCache, ps; kwargs...) elseif ps isa Tuple ps else @@ -494,46 +494,54 @@ function reorder_parameters( end end -function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = true) +const COMMON_DEFAULT_VAR = unwrap(only(@variables __DEF__)) + +function reorder_parameters(ic::IndexCache, ps::Vector{SymbolicT}; drop_missing = false, flatten = true) isempty(ps) && return () - param_buf = if ic.tunable_buffer_size.length == 0 - () - else - (BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(ic.tunable_buffer_size.length)],) + result = Vector{Union{Vector{SymbolicT}, Vector{Vector{SymbolicT}}}}() + param_buf = fill(COMMON_DEFAULT_VAR, ic.tunable_buffer_size.length) + push!(result, param_buf) + initials_buf = fill(COMMON_DEFAULT_VAR, ic.initials_buffer_size.length) + push!(result, initials_buf) + + disc_buf = Vector{SymbolicT}[] + for bufszs in ic.discrete_buffer_sizes + push!(disc_buf, fill(COMMON_DEFAULT_VAR, sum(x -> x.length, bufszs))) + end + const_buf = Vector{SymbolicT}[] + for bufsz in ic.constant_buffer_sizes + push!(const_buf, fill(COMMON_DEFAULT_VAR, bufsz.length)) + end + nonnumeric_buf = Vector{SymbolicT}[] + for bufsz in ic.nonnumeric_buffer_sizes + push!(nonnumeric_buf, fill(COMMON_DEFAULT_VAR, bufsz.length)) end - initials_buf = if ic.initials_buffer_size.length == 0 - () + if flatten + append!(result, disc_buf) + append!(result, const_buf) + append!(result, nonnumeric_buf) else - (BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(ic.initials_buffer_size.length)],) + push!(result, disc_buf) + push!(result, const_buf) + push!(result, nonnumeric_buf) end - - disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(sum(x -> x.length, temp))] - for temp in ic.discrete_buffer_sizes) - const_buf = Tuple(SymbolicT[unwrap(variable(:DEF)) for _ in 1:(temp.length)] - for temp in ic.constant_buffer_sizes) - nonnumeric_buf = Tuple(SymbolicT[unwrap(variable(:DEF)) for _ in 1:(temp.length)] - for temp in ic.nonnumeric_buffer_sizes) for p in ps - p = unwrap(p) if haskey(ic.discrete_idx, p) idx = ic.discrete_idx[p] disc_buf[idx.buffer_idx][idx.idx_in_buffer] = p elseif haskey(ic.tunable_idx, p) i = ic.tunable_idx[p] if i isa Int - param_buf[1][i] = unwrap(p) + param_buf[i] = p else - param_buf[1][i] = unwrap.(collect(p)) + param_buf[i] = collect(p) end elseif haskey(ic.initials_idx, p) i = ic.initials_idx[p] if i isa Int - initials_buf[1][i] = unwrap(p) + initials_buf[i] = p else - initials_buf[1][i] = unwrap.(collect(p)) + initials_buf[i] = collect(p) end elseif haskey(ic.constant_idx, p) i, j = ic.constant_idx[p] @@ -546,37 +554,20 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = end end - param_buf = broadcast.(unwrap, param_buf) - initials_buf = broadcast.(unwrap, initials_buf) - disc_buf = broadcast.(unwrap, disc_buf) - const_buf = broadcast.(unwrap, const_buf) - nonnumeric_buf = broadcast.(unwrap, nonnumeric_buf) - if drop_missing - filterer = !isequal(unwrap(variable(:DEF))) - param_buf = filter.(filterer, param_buf) - initials_buf = filter.(filterer, initials_buf) - disc_buf = filter.(filterer, disc_buf) - const_buf = filter.(filterer, const_buf) - nonnumeric_buf = filter.(filterer, nonnumeric_buf) - end - - if flatten - result = ( - param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...) - if all(isempty, result) - return () - end - return result - else - if isempty(param_buf) - param_buf = ((),) - end - if isempty(initials_buf) - initials_buf = ((),) + filterer = !isequal(COMMON_DEFAULT_VAR) + for inner in result + if inner isa Vector{SymbolicT} + filter!(filterer, inner) + elseif inner isa Vector{Vector{SymbolicT}} + for buf in inner + filter!(filterer, buf) + end + end end - return (param_buf..., initials_buf..., disc_buf, const_buf, nonnumeric_buf) end + + return result end # Given a parameter index, find the index of the buffer it is in when From b4d0b35564896713cd02bb5f40b21f2aaa6a5213 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:31:48 +0530 Subject: [PATCH 048/157] fix: make some SII impls of `IndexCache` more type-stable --- src/systems/index_cache.jl | 39 ++++++++++++++++---------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index d1fe95f997..1606878fcd 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -410,14 +410,16 @@ function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym) parameter_index(ic, sym) !== nothing end -function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) - if sym isa Symbol - sym = get(ic.symbol_to_variable, sym, nothing) - sym === nothing && return nothing - end - sym = unwrap(sym) - validate_size = Symbolics.isarraysymbolic(sym) && symtype(sym) <: AbstractArray && - symbolic_has_known_size(sym) +function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) + parameter_index(ic, unwrap(sym)) +end +function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::Symbol) + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + parameter_index(ic, sym) +end +function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::SymbolicT) + validate_size = Symbolics.isarraysymbolic(sym) && symbolic_has_known_size(sym) return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing ParameterIndex(SciMLStructures.Tunable(), idx, validate_size) elseif (idx = check_index_map(ic.initials_idx, sym)) !== nothing @@ -464,23 +466,14 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy idx.timeseries_idx, (idx.parameter_idx..., args[2:end]...)) end -function check_index_map(idxmap, sym) - if (idx = get(idxmap, sym, nothing)) !== nothing - return idx - elseif !isa(sym, Symbol) && (!iscall(sym) || operation(sym) !== getindex) && - hasname(sym) && (idx = get(idxmap, getname(sym), nothing)) !== nothing - return idx - end +function check_index_map(idxmap::Dict{SymbolicT, V}, sym::SymbolicT)::Union{V, Nothing} where {V} + idx = get(idxmap, sym, nothing) + idx === nothing || return idx dsym = default_toterm(sym) isequal(sym, dsym) && return nothing - if (idx = get(idxmap, dsym, nothing)) !== nothing - idx - elseif !isa(dsym, Symbol) && (!iscall(dsym) || operation(dsym) !== getindex) && - hasname(dsym) && (idx = get(idxmap, getname(dsym), nothing)) !== nothing - idx - else - nothing - end + idx = get(idxmap, dsym, nothing) + idx === nothing || return idx + return nothing end function reorder_parameters( From afd19da471503e4f8c659260541549424af64001 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:32:17 +0530 Subject: [PATCH 049/157] fix: make `IndexCache` constructor more type-stable --- src/problems/sccnonlinearproblem.jl | 2 - src/systems/index_cache.jl | 280 ++++++++++++++++------------ 2 files changed, 156 insertions(+), 126 deletions(-) diff --git a/src/problems/sccnonlinearproblem.jl b/src/problems/sccnonlinearproblem.jl index d71124adde..8d131a1f31 100644 --- a/src/problems/sccnonlinearproblem.jl +++ b/src/problems/sccnonlinearproblem.jl @@ -1,5 +1,3 @@ -const TypeT = Union{DataType, UnionAll} - struct CacheWriter{F} fn::F end diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 1606878fcd..09f97079a2 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -1,5 +1,7 @@ +const TypeT = Union{DataType, UnionAll, Union} + struct BufferTemplate - type::Union{DataType, UnionAll, Union} + type::TypeT length::Int end @@ -26,12 +28,12 @@ struct DiscreteIndex idx_in_clock::Int end -const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}} +const MaybeUnknownArrayIndexT = Union{Int, UnitRange{Int}, AbstractArray{Int}} +const MaybeArrayIndexT = Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}} +const ParamIndexMap = Dict{SymbolicT, Tuple{Int, Int}} const NonnumericMap = Dict{SymbolicT, Tuple{Int, Int}} -const UnknownIndexMap = Dict{ - BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}} -const TunableIndexMap = Dict{BasicSymbolic, - Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}} +const UnknownIndexMap = Dict{SymbolicT, MaybeUnknownArrayIndexT} +const TunableIndexMap = Dict{SymbolicT, MaybeArrayIndexT} const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}} const SymbolicParam = SymbolicT @@ -46,8 +48,8 @@ struct IndexCache initials_idx::TunableIndexMap constant_idx::ParamIndexMap nonnumeric_idx::NonnumericMap - observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType} - dependent_pars_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType} + observed_syms_to_timeseries::Dict{SymbolicT, TimeseriesSetType} + dependent_pars_to_timeseries::Dict{SymbolicT, TimeseriesSetType} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_size::BufferTemplate initials_buffer_size::BufferTemplate @@ -63,26 +65,36 @@ function IndexCache(sys::AbstractSystem) let idx = 1 for sym in unks - usym = unwrap(sym) - rsym = renamespace(sys, usym) - sym_idx = if Symbolics.isarraysymbolic(sym) + rsym = renamespace(sys, sym) + sym_idx::MaybeUnknownArrayIndexT = if Symbolics.isarraysymbolic(sym) reshape(idx:(idx + length(sym) - 1), size(sym)) else idx end - unk_idxs[usym] = sym_idx + unk_idxs[sym] = sym_idx unk_idxs[rsym] = sym_idx idx += length(sym) end + found_array_syms = Set{SymbolicT}() for sym in unks - usym = unwrap(sym) iscall(sym) && operation(sym) === getindex || continue arrsym = arguments(sym)[1] - all(haskey(unk_idxs, arrsym[i]) for i in eachindex(arrsym)) || continue - - idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)] + arrsym in found_array_syms && continue + idxs = Int[] + valid_arrsym = true + for i in eachindex(arrsym) + idxsym = arrsym[i] + idx = get(unk_idxs, idxsym, nothing)::Union{Int, Nothing} + valid_arrsym = idx !== nothing + valid_arrsym || break + push!(idxs, idx::Int) + end + push!(found_array_syms, arrsym) + valid_arrsym || break if idxs == idxs[begin]:idxs[end] - idxs = reshape(idxs[begin]:idxs[end], size(idxs)) + idxs = reshape(idxs[begin]:idxs[end], size(idxs))::AbstractArray{Int} + else + idxs = reshape(idxs, size(arrsym))::AbstractArray{Int} end rsym = renamespace(sys, arrsym) unk_idxs[arrsym] = idxs @@ -90,62 +102,24 @@ function IndexCache(sys::AbstractSystem) end end - tunable_pars = BasicSymbolic[] - initial_pars = BasicSymbolic[] - constant_buffers = Dict{Any, Set{BasicSymbolic}}() - nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}() - - function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S} - sym = unwrap(sym) - buf = get!(buffers, ctype, S()) - push!(buf, sym) - end - function insert_by_type!(buffers::Vector{BasicSymbolic}, sym, ctype) - sym = unwrap(sym) - push!(buffers, sym) - end - - disc_param_callbacks = Dict{SymbolicParam, Set{Int}}() - events = vcat(continuous_events(sys), discrete_events(sys)) - for (i, event) in enumerate(events) - discs = Set{SymbolicParam}() - affs = affects(event) - if !(affs isa AbstractArray) - affs = [affs] - end - for affect in affs - if affect isa AffectSystem || affect isa ImperativeAffect - union!(discs, unwrap.(discretes(affect))) - elseif isnothing(affect) - continue - else - error("Unhandled affect type $(typeof(affect))") - end - end - - for sym in discs - is_parameter(sys, sym) || - error("Expected discrete variable $sym in callback to be a parameter") - - # Only `foo(t)`-esque parameters can be saved - if iscall(sym) && length(arguments(sym)) == 1 && - isequal(only(arguments(sym)), get_iv(sys)) - clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym) - push!(clocks, i) - elseif is_variable_floatingpoint(sym) - insert_by_type!(constant_buffers, sym, symtype(sym)) - else - stype = symtype(sym) - if stype <: FnType - stype = fntype_to_function_type(stype) - end - insert_by_type!(nonnumeric_buffers, sym, stype) - end - end - end - clock_partitions = unique(collect(values(disc_param_callbacks))) - disc_symtypes = unique(symtype.(keys(disc_param_callbacks))) - disc_symtype_idx = Dict(disc_symtypes .=> eachindex(disc_symtypes)) + tunable_pars = SymbolicT[] + initial_pars = SymbolicT[] + constant_buffers = Dict{TypeT, Set{SymbolicT}}() + nonnumeric_buffers = Dict{TypeT, Set{SymbolicT}}() + + disc_param_callbacks = Dict{SymbolicParam, BitSet}() + cevs = continuous_events(sys) + devs = discrete_events(sys) + events = Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}[cevs; devs] + parse_callbacks_for_discretes!(cevs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, 0) + parse_callbacks_for_discretes!(devs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, length(cevs)) + clock_partitions = unique(collect(values(disc_param_callbacks)))::Vector{BitSet} + disc_symtypes = Set{TypeT}() + for x in keys(disc_param_callbacks) + push!(disc_symtypes, symtype(x)) + end + disc_symtypes = collect(disc_symtypes)::Vector{TypeT} + disc_symtype_idx = Dict{TypeT, Int}(zip(disc_symtypes, eachindex(disc_symtypes))) disc_syms_by_symtype = [SymbolicParam[] for _ in disc_symtypes] for sym in keys(disc_param_callbacks) push!(disc_syms_by_symtype[disc_symtype_idx[symtype(sym)]], sym) @@ -153,13 +127,12 @@ function IndexCache(sys::AbstractSystem) disc_syms_by_symtype_by_partition = [Vector{SymbolicParam}[] for _ in disc_symtypes] for (i, buffer) in enumerate(disc_syms_by_symtype) for partition in clock_partitions - push!(disc_syms_by_symtype_by_partition[i], - [sym for sym in buffer if disc_param_callbacks[sym] == partition]) + push!(disc_syms_by_symtype_by_partition[i], filter(==(partition) ∘ Base.Fix1(getindex, disc_param_callbacks), buffer)) end end disc_idxs = Dict{SymbolicParam, DiscreteIndex}() callback_to_clocks = Dict{ - Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}() + Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, BitSet}() for (typei, disc_syms_by_partition) in enumerate(disc_syms_by_symtype_by_partition) symi = 0 for (parti, disc_syms) in enumerate(disc_syms_by_partition) @@ -187,15 +160,13 @@ function IndexCache(sys::AbstractSystem) disc_buffer_templates = Vector{BufferTemplate}[] for (symtype, disc_syms_by_partition) in zip( disc_symtypes, disc_syms_by_symtype_by_partition) - push!(disc_buffer_templates, - [BufferTemplate(symtype, length(buf)) for buf in disc_syms_by_partition]) + push!(disc_buffer_templates, map(Base.Fix1(BufferTemplate, symtype) ∘ length, disc_syms_by_partition)) end for p in parameters(sys; initial_parameters = true) - p = unwrap(p) ctype = symtype(p) if ctype <: FnType - ctype = fntype_to_function_type(ctype) + ctype = fntype_to_function_type(ctype)::TypeT end haskey(disc_idxs, p) && continue haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue @@ -206,7 +177,7 @@ function IndexCache(sys::AbstractSystem) (ctype == Real || ctype <: AbstractFloat || ctype <: AbstractArray{Real} || ctype <: AbstractArray{<:AbstractFloat}) - if iscall(p) && operation(p) isa Initial + if iscall(p) && operation(p) === Initial() initial_pars else tunable_pars @@ -222,33 +193,10 @@ function IndexCache(sys::AbstractSystem) ) end - function get_buffer_sizes_and_idxs(T, buffers::Dict) - idxs = T() - buffer_sizes = BufferTemplate[] - for (i, (T, buf)) in enumerate(buffers) - for (j, p) in enumerate(buf) - ttp = default_toterm(p) - rp = renamespace(sys, p) - rttp = renamespace(sys, ttp) - idxs[p] = (i, j) - idxs[ttp] = (i, j) - idxs[rp] = (i, j) - idxs[rttp] = (i, j) - end - if T <: Symbolics.FnType - T = Any - end - push!(buffer_sizes, BufferTemplate(T, length(buf))) - end - return idxs, buffer_sizes - end - const_idxs, - const_buffer_sizes = get_buffer_sizes_and_idxs( - ParamIndexMap, constant_buffers) + const_buffer_sizes = get_buffer_sizes_and_idxs(ParamIndexMap, sys, constant_buffers) nonnumeric_idxs, - nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs( - NonnumericMap, nonnumeric_buffers) + nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(NonnumericMap, sys, nonnumeric_buffers) tunable_idxs = TunableIndexMap() tunable_buffer_size = 0 @@ -257,7 +205,8 @@ function IndexCache(sys::AbstractSystem) empty!(initial_pars) end for p in tunable_pars - idx = if size(p) == () + sh = SU.shape(p) + idx = if !SU.is_array_shape(sh) tunable_buffer_size + 1 else reshape( @@ -275,7 +224,8 @@ function IndexCache(sys::AbstractSystem) initials_idxs = TunableIndexMap() initials_buffer_size = 0 for p in initial_pars - idx = if size(p) == () + sh = SU.shape(p) + idx = if !SU.is_array_shape(sh) initials_buffer_size + 1 else reshape( @@ -293,23 +243,27 @@ function IndexCache(sys::AbstractSystem) for k in collect(keys(tunable_idxs)) v = tunable_idxs[k] v isa AbstractArray || continue - for (kk, vv) in zip(collect(k), v) + v = v::Union{UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}} + iter = vec(collect(k)::Array{SymbolicT})::Vector{SymbolicT} + for (kk::SymbolicT, vv) in zip(iter, v) tunable_idxs[kk] = vv end end for k in collect(keys(initials_idxs)) v = initials_idxs[k] v isa AbstractArray || continue - for (kk, vv) in zip(collect(k), v) + v = v::Union{UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}} + iter = vec(collect(k)::Array{SymbolicT})::Vector{SymbolicT} + for (kk, vv) in zip(iter, v) initials_idxs[kk] = vv end end dependent_pars_to_timeseries = Dict{SymbolicT, TimeseriesSetType}() - + vs = Set{SymbolicT}() for eq in get_parameter_dependencies(sys) sym = eq.lhs - vs = vars(eq.rhs) + SU.search_variables!(vs, eq.rhs) timeseries = TimeseriesSetType() if is_time_dependent(sys) for v in vs @@ -323,24 +277,29 @@ function IndexCache(sys::AbstractSystem) rttsym = renamespace(sys, ttsym) for s in (sym, ttsym, rsym, rttsym) dependent_pars_to_timeseries[s] = timeseries - if hasname(s) && (!iscall(s) || operation(s) != getindex) + if hasname(s) && (!iscall(s) || operation(s) !== getindex) symbol_to_variable[getname(s)] = sym end end end - observed_syms_to_timeseries = Dict{BasicSymbolic, TimeseriesSetType}() + observed_syms_to_timeseries = Dict{SymbolicT, TimeseriesSetType}() for eq in observed(sys) if symbolic_type(eq.lhs) != NotSymbolic() sym = eq.lhs - vs = vars(eq.rhs; op = Nothing) + empty!(vs) + SU.search_variables!(vs, eq.rhs) timeseries = TimeseriesSetType() if is_time_dependent(sys) for v in vs if (idx = get(disc_idxs, v, nothing)) !== nothing push!(timeseries, idx.clock_idx) - elseif iscall(v) && operation(v) === getindex && - (idx = get(disc_idxs, arguments(v)[1], nothing)) !== nothing + elseif Moshi.Match.@match v begin + BSImpl.Term(; f, args) => begin + f === getindex && (idx = get(disc_idxs, args[1], nothing)) !== nothing + end + _ => false + end push!(timeseries, idx.clock_idx) elseif haskey(observed_syms_to_timeseries, v) union!(timeseries, observed_syms_to_timeseries[v]) @@ -361,13 +320,12 @@ function IndexCache(sys::AbstractSystem) end end - for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs), - keys(const_idxs), keys(nonnumeric_idxs), - keys(observed_syms_to_timeseries), independent_variable_symbols(sys))) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - symbol_to_variable[getname(sym)] = sym - end - end + populate_symbol_to_var!(symbol_to_variable, keys(unk_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(disc_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(tunable_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(const_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(nonnumeric_idxs)) + populate_symbol_to_var!(symbol_to_variable, independent_variable_symbols(sys)) return IndexCache( unk_idxs, @@ -388,6 +346,80 @@ function IndexCache(sys::AbstractSystem) ) end +function populate_symbol_to_var!(symbol_to_variable::Dict{Symbol, SymbolicT}, vars) + for sym::SymbolicT in vars + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) + symbol_to_variable[getname(sym)] = sym + end + end +end + +""" + $TYPEDSIGNATURES + +Utility function for the `IndexCache` constructor. +""" +function insert_by_type!(buffers::Dict{TypeT, Set{SymbolicT}}, sym::SymbolicT, ctype::TypeT) + buf = get!(Set{SymbolicT}, buffers, ctype) + push!(buf, sym) +end +function insert_by_type!(buffers::Vector{SymbolicT}, sym::SymbolicT, ::TypeT) + push!(buffers, sym) +end + +function parse_callbacks_for_discretes!(events::Vector, disc_param_callbacks::Dict{SymbolicT, BitSet}, constant_buffers::Dict{TypeT, Set{SymbolicT}}, nonnumeric_buffers::Dict{TypeT, Set{SymbolicT}}, offset::Int) + for (i, event) in enumerate(events) + discs = Set{SymbolicParam}() + affect = event.affect::Union{AffectSystem, ImperativeAffect, Nothing} + if affect isa AffectSystem || affect isa ImperativeAffect + union!(discs, discretes(affect)) + elseif affect === nothing + continue + end + + for sym in discs + is_parameter(sys, sym) || + error("Expected discrete variable $sym in callback to be a parameter") + + # Only `foo(t)`-esque parameters can be saved + if iscall(sym) && length(arguments(sym)) == 1 && + isequal(only(arguments(sym)), get_iv(sys)) + clocks = get!(BitSet, disc_param_callbacks, sym) + push!(clocks, i + offset) + elseif is_variable_floatingpoint(sym) + insert_by_type!(constant_buffers, sym, symtype(sym)) + else + stype = symtype(sym) + if stype <: FnType + stype = fntype_to_function_type(stype)::TypeT + end + insert_by_type!(nonnumeric_buffers, sym, stype) + end + end + end +end + +function get_buffer_sizes_and_idxs(::Type{BufT}, sys::AbstractSystem, buffers::Dict) where {BufT} + idxs = BufT() + buffer_sizes = BufferTemplate[] + for (i, (T, buf)) in enumerate(buffers) + for (j, p) in enumerate(buf) + ttp = default_toterm(p) + rp = renamespace(sys, p) + rttp = renamespace(sys, ttp) + idxs[p] = (i, j) + idxs[ttp] = (i, j) + idxs[rp] = (i, j) + idxs[rttp] = (i, j) + end + if T <: Symbolics.FnType + T = Any + end + push!(buffer_sizes, BufferTemplate(T, length(buf))) + end + return idxs, buffer_sizes +end + function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) variable_index(ic, sym) !== nothing end From d279aaa5da0354182f79873463e06429bc9eb064 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:32:31 +0530 Subject: [PATCH 050/157] fix: improve precompile-friendliness of `complete` --- src/ModelingToolkit.jl | 10 +- src/systems/abstractsystem.jl | 280 ++++++++++++++++++++-------------- src/systems/index_cache.jl | 10 +- 3 files changed, 183 insertions(+), 117 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 36816a1adc..6cf27ea2c5 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -445,7 +445,14 @@ PrecompileTools.@compile_workload begin @variables x(ModelingToolkit.t_nounits) isequal(ModelingToolkit.D_nounits.x, ModelingToolkit.t_nounits) sys = System([ModelingToolkit.D_nounits(x) ~ x], ModelingToolkit.t_nounits, [x], Num[]; name = :sys) - sys = System([ModelingToolkit.D_nounits(x) ~ x], ModelingToolkit.t_nounits, [x], Num[]; name = :sys) + complete(sys) + @syms p[1:2] + ndims(p) + size(p) + axes(p) + length(p) + v = [p] + isempty(v) # mtkcompile(sys) @mtkmodel __testmod__ begin @constants begin @@ -488,5 +495,6 @@ precompile(Tuple{Type{NamedTuple{(:name, :defaults), T} where T<:Tuple}, Tuple{S precompile(Tuple{typeof(SymbolicUtils.isequal_somescalar), Float64, Float64}) precompile(Tuple{Type{NamedTuple{(:name, :defaults, :guesses), T} where T<:Tuple}, Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}, Base.Dict{Symbolics.Num, Float64}}}) precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:name, :defaults, :guesses), Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}, Base.Dict{Symbolics.Num, Float64}}}, Type{ModelingToolkit.System}, Array{Symbolics.Equation, 1}, Symbolics.Num, Array{Symbolics.Num, 1}, Array{Symbolics.Num, 1}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:type, :shape), Tuple{DataType, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}}}, typeof(SymbolicUtils.term), Any, SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}}) end # module diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d0baf6cc8a..8c6cc93c53 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -631,34 +631,36 @@ This namespacing functionality can also be toggled independently of `complete` using [`toggle_namespacing`](@ref). """ function complete( - sys::AbstractSystem; split = true, flatten = true, add_initial_parameters = true) + sys::T; split = true, flatten = true, add_initial_parameters = true) where {T <: AbstractSystem} sys = discover_globalscoped(sys) if flatten - eqs = equations(sys) - if eqs isa AbstractArray && eltype(eqs) <: Equation - newsys = expand_connections(sys) - else - newsys = sys - end + newsys = expand_connections(sys) newsys = ModelingToolkit.flatten(newsys) if has_parent(newsys) && get_parent(sys) === nothing - @set! newsys.parent = complete(sys; split = false, flatten = false) + @set! newsys.parent = complete(sys; split = false, flatten = false)::T end sys = newsys - sys = process_parameter_equations(sys) + sys = process_parameter_equations(sys)::T if add_initial_parameters - sys = add_initialization_parameters(sys; split) + sys = add_initialization_parameters(sys; split)::T end + cb_alg_eqs = Equation[alg_equations(sys); observed(sys)] if has_continuous_events(sys) && is_time_dependent(sys) - @set! sys.continuous_events = complete.( - get_continuous_events(sys); iv = get_iv(sys), - alg_eqs = [alg_equations(sys); observed(sys)]) + cevts = SymbolicContinuousCallback[] + for ev in get_continuous_events(sys) + ev = complete(ev; iv = get_iv(sys)::SymbolicT, alg_eqs = cb_alg_eqs) + push!(cevts, ev) + end + @set! sys.continuous_events = cevts end if has_discrete_events(sys) && is_time_dependent(sys) - @set! sys.discrete_events = complete.( - get_discrete_events(sys); iv = get_iv(sys), - alg_eqs = [alg_equations(sys); observed(sys)]) + devts = SymbolicDiscreteCallback[] + for ev in get_discrete_events(sys) + ev = complete(ev; iv = get_iv(sys)::SymbolicT, alg_eqs = cb_alg_eqs) + push!(devts, ev) + end + @set! sys.discrete_events = devts end end if split && has_index_cache(sys) @@ -666,39 +668,48 @@ function complete( # Ideally we'd do `get_ps` but if `flatten = false` # we don't get all of them. So we call `parameters`. all_ps = parameters(sys; initial_parameters = true) + all_ps_set = Set{SymbolicT}(all_ps) # inputs have to be maintained in a specific order input_vars = inputs(sys) if !isempty(all_ps) # reorder parameters by portions - ps_split = reorder_parameters(sys, all_ps) + ps_split = Vector{Vector{SymbolicT}}(reorder_parameters(sys, all_ps)) # if there are tunables, they will all be in `ps_split[1]` # and the arrays will have been scalarized - ordered_ps = eltype(all_ps)[] + ordered_ps = SymbolicT[] + offset = 0 # if there are no tunables, vcat them if !isempty(get_index_cache(sys).tunable_idx) - unflatten_parameters!(ordered_ps, ps_split[1], all_ps) - ps_split = Base.tail(ps_split) + unflatten_parameters!(ordered_ps, ps_split[1], all_ps_set) + offset += 1 end # unflatten initial parameters if !isempty(get_index_cache(sys).initials_idx) - unflatten_parameters!(ordered_ps, ps_split[1], all_ps) - ps_split = Base.tail(ps_split) + unflatten_parameters!(ordered_ps, ps_split[2], all_ps_set) + offset += 1 + end + for i in (offset+1):length(ps_split) + append!(ordered_ps, ps_split[i]::Vector{SymbolicT}) end - ordered_ps = vcat( - ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[])) if isscheduled(sys) # ensure inputs are sorted - input_idxs = findfirst.(isequal.(input_vars), (ordered_ps,)) - @assert all(!isnothing, input_idxs) - @assert issorted(input_idxs) + last_idx = 0 + for p in input_vars + idx = findfirst(isequal(p), ordered_ps)::Int + @assert last_idx < idx + last_idx = idx + end end @set! sys.ps = ordered_ps end elseif has_index_cache(sys) @set! sys.index_cache = nothing end - if isdefined(sys, :initializesystem) && get_initializesystem(sys) !== nothing - @set! sys.initializesystem = complete(get_initializesystem(sys); split) + if has_initializesystem(sys) + isys = get_initializesystem(sys) + if isys isa T + @set! sys.initializesystem = complete(isys::T; split) + end end sys = toggle_namespacing(sys, false; safe = true) isdefined(sys, :complete) ? (@set! sys.complete = true) : sys @@ -727,26 +738,28 @@ parameters in the system `all_ps`, unscalarize the elements in `params` and appe to `buffer` in the same order as they are present in `params`. Effectively, if `params = [p[1], p[2], p[3], q]` then this is equivalent to `push!(buffer, p, q)`. """ -function unflatten_parameters!(buffer, params, all_ps) +function unflatten_parameters!(buffer::Vector{SymbolicT}, params::Vector{SymbolicT}, all_ps::Set{SymbolicT}) i = 1 # go through all the tunables while i <= length(params) sym = params[i] # if the sym is not a scalarized array symbolic OR it was already scalarized, # just push it as-is - if !iscall(sym) || operation(sym) != getindex || - any(isequal(sym), all_ps) + if !iscall(sym) || operation(sym) !== getindex || sym in all_ps push!(buffer, sym) i += 1 continue end + + arrsym = first(arguments(sym)) # the next `length(sym)` symbols should be scalarized versions of the same # array symbolic - if !allequal(first(arguments(x)) - for x in view(params, i:(i + length(sym) - 1))) - error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.") + for j in (i+1):(i+length(sym)-1) + p = params[j] + if !(iscall(p) && operation(p) === getindex && isequal(arguments(p)[1], arrsym)) + error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.") + end end - arrsym = first(arguments(sym)) push!(buffer, arrsym) i += length(arrsym) end @@ -1099,6 +1112,8 @@ function GlobalScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) end end +const AllScopes = Union{LocalScope, ParentScope, GlobalScope} + renamespace(sys, eq::Equation) = namespace_equation(eq, sys) renamespace(names::AbstractVector, x) = foldr(renamespace, names, init = x) @@ -1112,8 +1127,9 @@ renamespace(sys, tgt::Symbol) = Symbol(getname(sys), NAMESPACE_SEPARATOR_SYMBOL, Namespace `x` with the name of `sys`. """ function renamespace(sys, x::SymbolicT) + isequal(x, SU.idxs_for_arrayop(VartypeT)) && return x Moshi.Match.@match x begin - BSImpl.Sym(; name) => let scope = getmetadata(x, SymScope, LocalScope())::Union{LocalScope, ParentScope, GlobalScope} + BSImpl.Sym(; name) => let scope = getmetadata(x, SymScope, LocalScope())::AllScopes if scope isa LocalScope return rename(x, renamespace(getname(sys), name))::SymbolicT elseif scope isa ParentScope @@ -1172,8 +1188,14 @@ Return `equations(sys)`, namespaced by the name of `sys`. """ function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys)) eqs = equations(sys) - isempty(eqs) && return Equation[] - map(eq -> namespace_equation(eq, sys; ivs), eqs) + isempty(eqs) && return eqs + if eqs === get_eqs(sys) + eqs = copy(eqs) + end + for i in eachindex(eqs) + eqs[i] = namespace_equation(eqs[i], sys; ivs) + end + return eqs end function namespace_initialization_equations( @@ -1220,7 +1242,15 @@ function namespace_jump(j::MassActionJump, sys) end function namespace_jumps(sys::AbstractSystem) - return [namespace_jump(j, sys) for j in get_jumps(sys)] + js = jumps(sys) + isempty(js) && return js + if js === get_jumps(sys) + js = copy(js) + end + for i in eachindex(js) + js[i] = namespace_jump(js[i], sys) + end + return js end function namespace_brownians(sys::AbstractSystem) @@ -1240,48 +1270,63 @@ function is_array_of_symbolics(x) any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x) end -function namespace_expr( - O, sys, n = (sys === nothing ? nothing : nameof(sys)); - ivs = sys === nothing ? nothing : independent_variables(sys)) - sys === nothing && return O - O = unwrap(O) - # Exceptions for arrays of symbolic and Ref of a symbolic, the latter - # of which shows up in broadcasts - if symbolic_type(O) == NotSymbolic() && !(O isa AbstractArray) && !(O isa Ref) - return O - end - if any(isequal(O), ivs) - return O - elseif iscall(O) - T = typeof(O) - renamed = let sys = sys, n = n, T = T - map(a -> namespace_expr(a, sys, n; ivs)::Any, arguments(O)) - end - if isvariable(O) - # Use renamespace so the scope is correct, and make sure to use the - # metadata from the rescoped variable - rescoped = renamespace(n, O) - maketerm(typeof(rescoped), operation(rescoped), renamed, - metadata(rescoped)) - elseif Symbolics.isarraysymbolic(O) - # promote_symtype doesn't work for array symbolics - maketerm(typeof(O), operation(O), renamed, metadata(O)) - else - maketerm(typeof(O), operation(O), renamed, metadata(O)) +function namespace_expr(O, sys::AbstractSystem, n::Symbol = nameof(sys); kw...) + return O +end +function namespace_expr(O::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}, sys::AbstractSystem, n::Symbol = nameof(sys); kw...) + namespace_expr(O, args...; kw...) +end +function namespace_expr(O::AbstractArray, sys::AbstractSystem, n::Symbol = nameof(sys); ivs = independent_variables(sys)) + is_array_of_symbolics(O) || return O + O = copy(O) + for i in eachindex(O) + O[i] = namespace_expr(O[i], sys, n; ivs) + end + return O +end +function namespace_expr(O::SymbolicT, sys::AbstractSystem, n::Symbol = nameof(sys); ivs = independent_variables(sys)) + any(isequal(O), ivs) && return O + isvar = isvariable(O) + Moshi.Match.@match O begin + BSImpl.Const(;) => return O + BSImpl.Sym(;) => return isvar ? renamespace(n, O) : O + BSImpl.Term(; f, args, metadata, type, shape) => begin + newargs = copy(parent(args)) + for i in eachindex(args) + newargs[i] = namespace_expr(newargs[i], sys, n; ivs) + end + if isvar + rescoped = renamespace(n, O) + f = Moshi.Data.variant_getfield(rescoped, BSImpl.Term{VartypeT}, :f) + meta = Moshi.Data.variant_getfield(rescoped, BSImpl.Term{VartypeT}, :metadata) + elseif f isa SymbolicT + f = renamespace(n, f) + meta = metadata + end + return BSImpl.Term{VartypeT}(f, newargs; type, shape, metadata = meta) end - elseif isvariable(O) - renamespace(n, O) - elseif O isa AbstractArray && is_array_of_symbolics(O) - let sys = sys, n = n - map(o -> namespace_expr(o, sys, n; ivs), O) + BSImpl.AddMul(; coeff, dict, variant, type, shape, metadata) => begin + newdict = copy(dict) + for (k, v) in newdict + newdict[namespace_expr(k, sys, n; ivs)] = v + end + return BSImpl.AddMul{VartypeT}(coeff, newdict, variant; type, shape, metadata) + end + BSImpl.Div(; num, den, type, shape, metadata) => begin + num = namespace_expr(num, sys, n; ivs) + den = namespace_expr(den, sys, n; ivs) + return BSImpl.Div{VartypeT}(num, den, false; type, shape, metadata) + end + BSImpl.ArrayOp(; output_idx, expr, term, ranges, reduce, type, shape, metadata) => begin + if term isa SymbolicT + term = namespace_expr(term, sys, n; ivs) + end + expr = namespace_expr(expr, sys, n; ivs) + return BSImpl.ArrayOp{VartypeT}(output_idx, expr, reduce, term, ranges; type, shape, metadata) end - else - O end end -_nonum(@nospecialize x) = x isa Num ? x.val : x - """ $(TYPEDSIGNATURES) @@ -1293,21 +1338,14 @@ See also [`ModelingToolkit.get_unknowns`](@ref). function unknowns(sys::AbstractSystem) sts = get_unknowns(sys) systems = get_systems(sys) - nonunique_unknowns = if isempty(systems) - sts - else - system_unknowns = reduce(vcat, namespace_variables.(systems)) - isempty(sts) ? system_unknowns : [sts; system_unknowns] + if isempty(systems) + return sts end - isempty(nonunique_unknowns) && return nonunique_unknowns - # `Vector{Any}` is incompatible with the `SymbolicIndexingInterface`, which uses - # `elsymtype = symbolic_type(eltype(_arg))` - # which inappropriately returns `NotSymbolic()` - if nonunique_unknowns isa Vector{Any} - nonunique_unknowns = _nonum.(nonunique_unknowns) + result = copy(sts) + for subsys in systems + append!(result, namespace_variables(subsys)) end - @assert typeof(nonunique_unknowns) !== Vector{Any} - unique(nonunique_unknowns) + return result end """ @@ -1331,19 +1369,24 @@ See also [`@parameters`](@ref) and [`ModelingToolkit.get_ps`](@ref). """ function parameters(sys::AbstractSystem; initial_parameters = false) ps = get_ps(sys) - if ps == SciMLBase.NullParameters() + if ps === SciMLBase.NullParameters() return [] end if eltype(ps) <: Pair ps = first.(ps) end systems = get_systems(sys) - result = unique(isempty(systems) ? ps : - [ps; reduce(vcat, namespace_parameters.(systems))]) + if isempty(systems) + return ps + end + result = copy(ps) + for subsys in systems + append!(result, namespace_parameters(subsys)) + end if !initial_parameters && !is_initializesystem(sys) filter!(result) do sym return !(isoperator(sym, Initial) || - iscall(sym) && operation(sym) == getindex && + iscall(sym) && operation(sym) === getindex && isoperator(arguments(sym)[1], Initial)) end end @@ -1539,15 +1582,12 @@ See also [`full_equations`](@ref) and [`ModelingToolkit.get_eqs`](@ref). function equations(sys::AbstractSystem) eqs = get_eqs(sys) systems = get_systems(sys) - if isempty(systems) - return eqs - else - eqs = Equation[eqs; - reduce(vcat, - namespace_equations.(get_systems(sys)); - init = Equation[])] - return eqs + isempty(systems) && return eqs + eqs = copy(eqs) + for subsys in systems + append!(eqs, namespace_equations(subsys)) end + return eqs end """ @@ -1627,10 +1667,12 @@ all the subsystems of `sys` (appropriately namespaced). function jumps(sys::AbstractSystem) js = get_jumps(sys) systems = get_systems(sys) - if isempty(systems) - return js + isempty(systems) && return js + js = copy(js) + for subsys in systems + append!(js, namespace_jumps(subsys)) end - return [js; reduce(vcat, namespace_jumps.(systems); init = [])] + return js end """ @@ -1679,8 +1721,14 @@ end function namespace_constraints(sys) cstrs = constraints(sys) - isempty(cstrs) && return Vector{Union{Equation, Inequality}}(undef, 0) - map(cstr -> namespace_constraint(cstr, sys), cstrs) + isempty(cstrs) && return cstrs + if cstrs === get_constraints(sys) + cstrs = copy(cstrs) + end + for i in eachindex(cstrs) + cstrs[i] = namespace_constraint(cstrs[i], sys) + end + return cstrs end """ @@ -1691,7 +1739,12 @@ Get all constraints in the system `sys` and all of its subsystems, appropriately function constraints(sys::AbstractSystem) cs = get_constraints(sys) systems = get_systems(sys) - isempty(systems) ? cs : [cs; reduce(vcat, namespace_constraints.(systems))] + isempty(systems) && return cs + cs = copy(sys) + for subsys in systems + append!(cs, namespace_constraints(subsys)) + end + return cs end """ @@ -2788,12 +2841,12 @@ function process_parameter_equations(sys::AbstractSystem) if !isempty(get_systems(sys)) throw(ArgumentError("Expected flattened system")) end - varsbuf = Set() + varsbuf = Set{SymbolicT}() pareq_idxs = Int[] eqs = equations(sys) for (i, eq) in enumerate(eqs) empty!(varsbuf) - vars!(varsbuf, eq; op = Union{Differential, Initial, Pre}) + SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{Union{Differential, Initial, Pre}}()) # singular equations isempty(varsbuf) && continue if all(varsbuf) do sym @@ -2816,8 +2869,11 @@ function process_parameter_equations(sys::AbstractSystem) end end - pareqs = [get_parameter_dependencies(sys); eqs[pareq_idxs]] - explicitpars = [eq.lhs for eq in pareqs] + pareqs = Equation[get_parameter_dependencies(sys); eqs[pareq_idxs]] + explicitpars = SymbolicT[] + for eq in pareqs + push!(explicitpars, eq.lhs) + end pareqs = topsort_equations(pareqs, explicitpars) eqs = eqs[setdiff(eachindex(eqs), pareq_idxs)] diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 09f97079a2..e0dc2ef923 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -508,22 +508,24 @@ function check_index_map(idxmap::Dict{SymbolicT, V}, sym::SymbolicT)::Union{V, N return nothing end +const ReorderedParametersT = Vector{Union{Vector{SymbolicT}, Vector{Vector{SymbolicT}}}} + function reorder_parameters( sys::AbstractSystem, ps = parameters(sys; initial_parameters = true); kwargs...) if has_index_cache(sys) && get_index_cache(sys) !== nothing reorder_parameters(get_index_cache(sys)::IndexCache, ps; kwargs...) elseif ps isa Tuple - ps + return ReorderedParametersT(collect(ps)) else - (ps,) + eltype(ReorderedParametersT)[ps] end end const COMMON_DEFAULT_VAR = unwrap(only(@variables __DEF__)) function reorder_parameters(ic::IndexCache, ps::Vector{SymbolicT}; drop_missing = false, flatten = true) - isempty(ps) && return () - result = Vector{Union{Vector{SymbolicT}, Vector{Vector{SymbolicT}}}}() + result = ReorderedParametersT() + isempty(ps) && return result param_buf = fill(COMMON_DEFAULT_VAR, ic.tunable_buffer_size.length) push!(result, param_buf) initials_buf = fill(COMMON_DEFAULT_VAR, ic.initials_buffer_size.length) From ad058d34b4905911ab5a423b6cc23be2a243f0f7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:33:34 +0530 Subject: [PATCH 051/157] refactor: miscellaneous improvements --- src/ModelingToolkit.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6cf27ea2c5..384ed9de20 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -11,14 +11,16 @@ end import SymbolicUtils import SymbolicUtils as SU import SymbolicUtils: iscall, arguments, operation, maketerm, promote_symtype, - isadd, ismul, ispow, issym, FnType, + isadd, ismul, ispow, issym, FnType, isconst, BSImpl, @rule, Rewriters, substitute, metadata, BasicSymbolic using SymbolicUtils.Code import SymbolicUtils.Code: toexpr import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint using DocStringExtensions using SpecialFunctions, NaNMath -using DiffEqCallbacks +@recompile_invalidations begin + using DiffEqCallbacks +end using Graphs import ExprTools: splitdef, combinedef import OrderedCollections @@ -338,6 +340,8 @@ export AbstractCollocation, JuMPCollocation, InfiniteOptCollocation, CasADiCollocation, PyomoCollocation export DynamicOptSolution +const set_scalar_metadata = setmetadata + @public apply_to_variables, equations_toplevel, unknowns_toplevel, parameters_toplevel @public continuous_events_toplevel, discrete_events_toplevel, assertions, is_alg_equation @public is_diff_equation, Equality, linearize_symbolic, reorder_unknowns From 4af9d7c981933587cc14b94e292403b9a11e0e15 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Oct 2025 18:34:22 +0530 Subject: [PATCH 052/157] fix: minor fix for `evaluate_varmap!` --- src/systems/problem_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 3c64ed05fc..ed42c96d89 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -486,7 +486,7 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100) v === nothing && continue symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue haskey(varmap, k) || continue - varmap[k] = fixpoint_sub(v, varmap; maxiters = limit) + varmap[k] = value(fixpoint_sub(v, varmap; maxiters = limit)) end end From ee520b372f79062848d12279dc3026d6390a20e6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:05:58 +0530 Subject: [PATCH 053/157] fix: minor type-stability improvements to `hasbounds` --- src/variables.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/variables.jl b/src/variables.jl index 543617e7de..1c218e0646 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -291,8 +291,8 @@ Create parameters with bounds like this @parameters p [bounds=(-1, 1)] ``` """ -function getbounds(x::Union{Num, Symbolics.Arr, SymbolicT}) - x = unwrap(x) +getbounds(x::Union{Num, Symbolics.Arr}) = getbounds(unwrap(x)) +function getbounds(x::SymbolicT) if operation(p) === getindex p = arguments(p)[1] bounds = Symbolics.getmetadata(x, VariableBounds, (-Inf, Inf)) @@ -329,8 +329,8 @@ Determine whether symbolic variable `x` has bounds associated with it. See also [`getbounds`](@ref). """ function hasbounds(x) - b = getbounds(x) - any(isfinite.(b[1]) .|| isfinite.(b[2])) + b = getbounds(x)::NTuple{2} + any(isfinite.(b[1]) .|| isfinite.(b[2]))::Bool end function setbounds(x::Num, bounds) From b6058b232bbdd7a76f6739ac03d88a013c7e664c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:06:12 +0530 Subject: [PATCH 054/157] fix: minor type-stability improvement to `isoperator` --- src/utils.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index fc5fbda207..b3faa709b8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -387,8 +387,13 @@ function check_operator_variables(eqs, op::T) where {T} end end -isoperator(expr, op) = iscall(expr) && operation(expr) isa op -isoperator(op) = expr -> isoperator(expr, op) +function isoperator(expr::SymbolicT, ::Type{op}) where {op <: SU.Operator} + Moshi.Match.@match expr begin + BSImpl.Term(; f) => f isa op + _ => false + end +end +isoperator(::Type{op}) where {op <: SU.Operator} = Base.Fix2(isoperator, op) isdifferential(expr) = isoperator(expr, Differential) isdiffeq(eq) = isdifferential(eq.lhs) || isoperator(eq.lhs, Shift) From 97d6f89c2273dd873905f1f0f594e4879460b35c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:06:27 +0530 Subject: [PATCH 055/157] fix: make `flatten_equations` type-stable --- src/utils.jl | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index b3faa709b8..967c35225c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1148,25 +1148,34 @@ Given a list of equations where some may be array equations, flatten the array e without scalarizing occurrences of array variables and return the new list of equations. """ function flatten_equations(eqs::Vector{Equation}) - mapreduce(vcat, eqs; init = Equation[]) do eq - islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs) - isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs) - if islhsarr || isrhsarr - islhsarr && isrhsarr || - error(""" - LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions \ - or both scalar - """) - size(eq.lhs) == size(eq.rhs) || - error(""" - Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got \ - $(size(eq.lhs)) and $(size(eq.rhs)) - """) - return vec(collect(eq.lhs) .~ collect(eq.rhs)) + _eqs = Equation[] + for eq in eqs + shlhs = SU.shape(eq.lhs) + if isempty(shlhs) + push!(_eqs, eq) + continue + end + if length(shlhs) == 1 + lhs = collect(eq.lhs)::Vector{SymbolicT} + rhs = collect(eq.rhs)::Vector{SymbolicT} + for (l, r) in zip(lhs, rhs) + push!(_eqs, l ~ r) + end + elseif length(shlhs) == 2 + lhs = collect(eq.lhs)::Matrix{SymbolicT} + rhs = collect(eq.rhs)::Matrix{SymbolicT} + for (l, r) in zip(lhs, rhs) + push!(_eqs, l ~ r) + end else - eq + lhs = collect(eq.lhs)::Matrix{SymbolicT} + rhs = collect(eq.rhs)::Matrix{SymbolicT} + for (l, r) in zip(lhs, rhs) + push!(_eqs, l ~ r) + end end end + return _eqs end const JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump} From 550c10667f2ca053b2e52d3bcef0a652fa00d4e5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:06:57 +0530 Subject: [PATCH 056/157] fix: make `isparameter` type stable --- src/parameters.jl | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/src/parameters.jl b/src/parameters.jl index 7bb76d7bf0..f2eeb3d8e0 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -15,33 +15,17 @@ The symbolic metadata key for storing the `VariableType`. """ struct MTKVariableTypeCtx end -getvariabletype(x, def = VARIABLE) = getmetadata(unwrap(x), MTKVariableTypeCtx, def) +getvariabletype(x, def = VARIABLE) = safe_getmetadata(MTKVariableTypeCtx, unwrap(x), def)::Union{typeof(def), VariableType} """ $TYPEDEF Check if the variable contains the metadata identifying it as a parameter. """ -function isparameter(x) - x = unwrap(x) - - if x isa SymbolicT && (varT = getvariabletype(x, nothing)) !== nothing - return varT === PARAMETER - #TODO: Delete this branch - elseif x isa SymbolicT && iscall(x) && operation(x) === getindex - p = arguments(x)[1] - isparameter(p) || - (hasmetadata(p, Symbolics.VariableSource) && - getmetadata(p, Symbolics.VariableSource)[1] == :parameters) - elseif iscall(x) && operation(x) isa SymbolicT - varT === PARAMETER || isparameter(operation(x)) - elseif iscall(x) && operation(x) == (getindex) - isparameter(arguments(x)[1]) - elseif x isa SymbolicT - varT === PARAMETER - else - false - end +isparameter(x::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) = isparameter(unwrap(x)) +function isparameter(x::SymbolicT) + varT = getvariabletype(x, nothing) + return varT === PARAMETER end function iscalledparameter(x) From 01f298e8c6d37c8afb61ce531867b9fa0c2ab588 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:07:13 +0530 Subject: [PATCH 057/157] fix: make `input_timedomain` type-stable --- src/discretedomain.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 54f451af78..9226a49be0 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -324,6 +324,8 @@ end Base.:+(k::ShiftIndex, i::Int) = ShiftIndex(k.clock, k.steps + i) Base.:-(k::ShiftIndex, i::Int) = k + (-i) +const InputTimeDomainElT = Union{TimeDomain, InferredTimeDomain} + """ input_timedomain(op::Operator) @@ -334,7 +336,7 @@ function input_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - (InferredDiscrete(),) + InputTimeDomainElT[InferredDiscrete()] end """ @@ -349,14 +351,14 @@ function output_timedomain(s::Shift, arg = nothing) InferredDiscrete() end -input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),) +input_timedomain(::Sample, _ = nothing) = InputTimeDomainElT[ContinuousClock()] output_timedomain(s::Sample, _ = nothing) = s.clock -function input_timedomain(h::Hold, arg = nothing) +function input_timedomain(::Hold, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - (InferredDiscrete(),) # the Hold accepts any discrete + InputTimeDomainElT[InferredDiscrete()] # the Hold accepts any discrete end output_timedomain(::Hold, _ = nothing) = ContinuousClock() From c2f3b5fb76c338005d124240f28106d6fa6f6e69 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:08:29 +0530 Subject: [PATCH 058/157] fix: improve type-stability of `default_consolidate`, `flatten` --- src/systems/system.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/systems/system.jl b/src/systems/system.jl index 2b2fd9aa7e..e364847009 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -322,10 +322,14 @@ struct System <: IntermediateDeprecationSystem end end +_sum_costs(costs::Vector{SymbolicT}) = SU.add_worker(VartypeT, costs) +_sum_costs(costs::Vector{Num}) = SU.add_worker(VartypeT, costs) +# `reduce` instead of `sum` because the rrule for `sum` doesn't +# handle the `init` kwarg. +_sum_costs(costs::Vector) = reduce(+, costs; init = 0.0) + function default_consolidate(costs, subcosts) - # `reduce` instead of `sum` because the rrule for `sum` doesn't - # handle the `init` kwarg. - return reduce(+, costs; init = 0.0) + reduce(+, subcosts; init = 0.0) + return _sum_costs(costs) + _sum_costs(subcosts) end unwrap_vars(vars::AbstractArray{SymbolicT}) = vars @@ -792,9 +796,9 @@ function flatten(sys::System, noeqs = false) isempty(systems) && return sys costs = cost(sys) if _iszero(costs) - costs = Union{Real, BasicSymbolic}[] + costs = SymbolicT[] else - costs = [costs] + costs = SymbolicT[costs] end # We don't include `ignored_connections` in the flattened system, because # connection expansion inherently requires the hierarchy structure. If the system From d74eeefe29e69212b65fd8f20d40fa632fdc7cdc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:09:28 +0530 Subject: [PATCH 059/157] fix: improve type-stability of connection infrastructure --- src/systems/analysis_points.jl | 23 +- src/systems/connectiongraph.jl | 4 +- src/systems/connectors.jl | 535 +++++++++++++++++++++------------ 3 files changed, 364 insertions(+), 198 deletions(-) diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl index 42bf94eb02..8a1e8f400d 100644 --- a/src/systems/analysis_points.jl +++ b/src/systems/analysis_points.jl @@ -161,7 +161,17 @@ end Convert an `AnalysisPoint` to a standard connection. """ function to_connection(ap::AnalysisPoint) - return connect(ap.input, ap.outputs...) + if ap.input isa System + vs = System[ap.input] + append!(vs, ap.outputs::Vector{System}) + return Connection() ~ Connection(vs) + elseif ap.input isa SymbolicT + vs = SymbolicT[ap.input] + append!(vs, ap.outputs::Vector{SymbolicT}) + return Connection() ~ Connection(vs) + else + error("Unreachable!") + end end """ @@ -179,7 +189,7 @@ end # create analysis points via `connect` function connect(in, ap::AnalysisPoint, outs...; verbose = true) - return AnalysisPoint() ~ AnalysisPoint(in, ap.name, collect(outs); verbose) + return AnalysisPoint() ~ AnalysisPoint(unwrap(in), ap.name, collect(unwrap.(outs)); verbose) end """ @@ -249,8 +259,13 @@ end Remove all `AnalysisPoint`s in `sys` and any of its subsystems, replacing them by equivalent connections. """ function remove_analysis_points(sys::AbstractSystem) - eqs = map(get_eqs(sys)) do eq - value(eq.lhs) isa AnalysisPoint ? to_connection(value(eq.rhs)) : eq + eqs = Equation[] + for eq in get_eqs(sys) + if unwrap_const(eq.lhs) isa AnalysisPoint + push!(eqs, to_connection(unwrap_const(eq.rhs)::AnalysisPoint)) + else + push!(eqs, eq) + end end @set! sys.eqs = eqs @set! sys.systems = map(remove_analysis_points, get_systems(sys)) diff --git a/src/systems/connectiongraph.jl b/src/systems/connectiongraph.jl index 27e542c2ab..00424789e5 100644 --- a/src/systems/connectiongraph.jl +++ b/src/systems/connectiongraph.jl @@ -45,8 +45,8 @@ Create a `ConnectionVertex` given use for this connection. """ function ConnectionVertex( - namespace::Vector{Symbol}, var::Union{BasicSymbolic, AbstractSystem}, isouter::Bool) - if var isa BasicSymbolic + namespace::Vector{Symbol}, var::Union{SymbolicT, AbstractSystem}, isouter::Bool) + if var isa SymbolicT name = getname(var) else name = nameof(var) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 95ca70e6be..af23951cfa 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -23,8 +23,18 @@ Connect multiple connectors created via `@connector`. All connected connectors must be unique. """ function connect(sys1::AbstractSystem, sys2::AbstractSystem, syss::AbstractSystem...) - syss = (sys1, sys2, syss...) - length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") + _syss = System[] + push!(_syss, sys1) + push!(_syss, sys2) + for sys in syss + push!(_syss, sys) + end + syss = _syss + sysnames = Symbol[] + for sys in syss + push!(sysnames, nameof(sys)) + end + allunique(sysnames) || error("connect takes distinct systems!") Equation(Connection(), Connection(syss)) # the RHS are connected systems end @@ -73,11 +83,7 @@ Get the connection type of symbolic variable `s` from the `VariableConnectType` Defaults to `Equality` if not present. """ function get_connection_type(s::SymbolicT) - s = unwrap(s) - if iscall(s) && operation(s) === getindex - s = arguments(s)[1] - end - getmetadata(s, VariableConnectType, Equality) + safe_getmetadata(VariableConnectType, s, Equality)::DataType end """ @@ -171,8 +177,12 @@ get_systems(c::Connection) = c.systems Refer to the [Connection semantics](@ref connect_semantics) section of the docs for more information. """ -instream(a) = term(instream, unwrap(a), type = symtype(a)) -SymbolicUtils.promote_symtype(::typeof(instream), _) = Real +function instream(a::SymbolicT) + BSImpl.Term{VartypeT}(instream, SArgsT((unwrap(a),)); type = symtype(a), shape = SU.shape(a)) +end +instream(a::Num) = Num(instream(unwrap(a))) +instream(a::Symbolics.Arr{T, N}) where {T, N} = Symbolics.Arr{T, N}(instream(unwrap(a))) +SymbolicUtils.promote_symtype(::typeof(instream), ::Type{T}) where {T} = T isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing @@ -183,7 +193,7 @@ Utility struct which wraps a symbolic variable used in a `Connection` to enable to work. """ struct SymbolicWithNameof - var::Any + var::SymbolicT end function Base.nameof(x::SymbolicWithNameof) @@ -192,7 +202,7 @@ end is_causal_variable_connection(c) = false function is_causal_variable_connection(c::Connection) - all(x -> x isa SymbolicWithNameof, get_systems(c)) + all(Base.Fix2(isa, SymbolicWithNameof), get_systems(c)) end const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr} @@ -214,24 +224,36 @@ end Perform validation for a connect statement involving causal variables. """ -function validate_causal_variables_connection(allvars) - var1 = allvars[1] - var2 = allvars[2] - vars = Base.tail(Base.tail(allvars)) +function validate_causal_variables_connection(allvars::Vector{SymbolicT}) for var in allvars vtype = getvariabletype(var) vtype === VARIABLE || throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`.")) end - if length(unique(allvars)) !== length(allvars) + if !allunique(allvars) throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries.")) end - allsizes = map(size, allvars) - if !allequal(allsizes) - throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively.")) + sh1 = SU.shape(allvars[1])::SU.ShapeVecT + sz1 = SU.SmallV{Int}() + for x in sh1 + push!(sz1, length(x)) end - non_causal_variables = filter(allvars) do var - !isinput(var) && !isoutput(var) + sz2 = SU.SmallV{Int}() + for v in allvars + sh = SU.shape(v)::SU.ShapeVecT + empty!(sz2) + for x in sh + push!(sz2, length(x)) + end + if !isequal(sz1, sz2) + throw(ArgumentError("Expected all connection variables to have the same size. Got variables $(allvars[1]) and $v with sizes $sz1 and $sz2 respectively.")) + + end + end + non_causal_variables = SymbolicT[] + for x in allvars + isinput(x) || isoutput(x) || continue + push!(non_causal_variables, x) end isempty(non_causal_variables) || throw(NonCausalVariableError(non_causal_variables)) end @@ -250,9 +272,14 @@ var1 ~ var3 """ function connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT, vars::ConnectableSymbolicT...) - allvars = (var1, var2, vars...) + allvars = SymbolicT[] + push!(allvars, unwrap(var1)) + push!(allvars, unwrap(var2)) + for var in vars + push!(allvars, unwrap(var)) + end validate_causal_variables_connection(allvars) - return Equation(Connection(), Connection(map(SymbolicWithNameof, unwrap.(allvars)))) + return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars))) end """ @@ -302,6 +329,27 @@ mydiv(num, den) = end @register_symbolic mydiv(n, d) +struct IsOuter + outer_connectors::Set{Symbol} +end + +function (io::IsOuter)(name::Symbol) + name in io.outer_connectors +end + +function (io::IsOuter)(sys) + nm = nameof(sys) + isconnector(sys) || error("$nm is not a connector!") + s = string(nm) + idx = findfirst(NAMESPACE_SEPARATOR, s) + parent_name = if idx === nothing + nm + else + Symbol(@view(s[1:prevind(s, idx)])) + end + return io(parent_name) +end + """ $(TYPEDSIGNATURES) @@ -309,23 +357,12 @@ Return a function which checks whether the connector (system) passed to it is an connector of `sys`. The function can also be given the name of a system as a `Symbol`. """ function generate_isouter(sys::AbstractSystem) - outer_connectors = Symbol[] + outer_connectors = Set{Symbol}() for s in get_systems(sys) n = nameof(s) isconnector(s) && push!(outer_connectors, n) end - let outer_connectors = outer_connectors - function isouter(sys)::Bool - s = string(nameof(sys)) - isconnector(sys) || error("$s is not a connector!") - idx = findfirst(isequal(NAMESPACE_SEPARATOR), s) - parent_name = Symbol(idx === nothing ? s : s[1:prevind(s, idx)]) - isouter(parent_name) - end - function isouter(name::Symbol)::Bool - return name in outer_connectors - end - end + return IsOuter(outer_connectors) end @noinline function connection_error(ss) @@ -336,14 +373,24 @@ abstract type IsFrame end "Return true if the system is a 3D multibody frame, otherwise return false." function isframe(sys) - getmetadata(sys, IsFrame, false) + getmetadata(sys, IsFrame, false)::Bool end abstract type FrameOrientation end +struct RotationMatrix + R::Matrix{SymbolicT} + w::Vector{SymbolicT} + + function RotationMatrix(R::AbstractMatrix, w::AbstractVector) + new(unwrap_vars(R), unwrap_vars(w)) + end + +end + "Return orientation object of a multibody frame." function ori(sys) - getmetadata(sys, FrameOrientation, nothing) + getmetadata(sys, FrameOrientation, nothing)::Union{RotationMatrix, Nothing} end """ @@ -373,13 +420,13 @@ index_from_type(::Type{OutputVar{I}}) where {I} = I Chain `getproperty` calls on sys in the order given by `names` and return the unwrapped result. """ -function iterative_getproperty(sys::AbstractSystem, names::AbstractVector{Symbol}) +function iterative_getproperty(sys::AbstractSystem, names::Vector{Symbol}) # we don't want to namespace the first time - result = toggle_namespacing(sys, false) + result::Union{SymbolicT, System} = toggle_namespacing(sys, false) for name in names - result = getproperty(result, name) + result = getvar(result, name)::Union{SymbolicT, System} end - return unwrap(result) + return result end """ @@ -389,10 +436,13 @@ Return the variable/subsystem of `sys` referred to by vertex `vert`. """ function variable_from_vertex(sys::AbstractSystem, vert::ConnectionVertex) value = iterative_getproperty(sys, vert.name) - value isa AbstractSystem && return value + value isa System && return value + value = value::SymbolicT vert.type <: Union{InputVar, OutputVar} || return value + vert.type === InputVar{CartesianIndex()} && return value + vert.type === OutputVar{CartesianIndex()} && return value # index possibly array causal variable - unwrap(wrap(value)[index_from_type(vert.type)]) + value[index_from_type(vert.type)]::SymbolicT end """ @@ -408,7 +458,7 @@ function returned from [`generate_isouter`](@ref) for the system referred to by `namespace` must not contain the name of the root system. """ function generate_connectionsets!(connection_state::AbstractConnectionState, - namespace::Vector{Symbol}, connected, isouter) + namespace::Vector{Symbol}, connected, isouter::IsOuter) initial_len = length(namespace) _generate_connectionsets!(connection_state, namespace, connected, isouter) # Enforce postcondition as a sanity check that the namespacing is implemented correctly @@ -418,51 +468,57 @@ end function _generate_connectionsets!(connection_state::AbstractConnectionState, namespace::Vector{Symbol}, - connected_vars::Union{ - AbstractVector{SymbolicWithNameof}, Tuple{Vararg{SymbolicWithNameof}}}, - isouter) + connected_vars::Vector{SymbolicWithNameof}, + isouter::IsOuter) # unwrap the `SymbolicWithNameof` into the contained symbolic variables. - connected_vars = map(x -> x.var, connected_vars) - _generate_connectionsets!(connection_state, namespace, connected_vars, isouter) + _connected_vars = SymbolicT[] + for x in connected_vars + push!(_connected_vars, x.var) + end + _generate_connectionsets!(connection_state, namespace, _connected_vars, isouter) end -function _generate_connectionsets!(connection_state::AbstractConnectionState, - namespace::Vector{Symbol}, - connected_vars::Union{ - AbstractVector{<:BasicSymbolic}, Tuple{Vararg{BasicSymbolic}}}, - isouter) - # NOTE: variable connections don't populate the domain network +@noinline function throw_both_input_output(var::SymbolicT, connected_vars::Vector{SymbolicT}) + names = join(string.(connected_vars), ", ") + throw(ArgumentError(""" + Variable $var in connection `connect($names)` is both input and output. + """)) +end +@noinline function throw_not_input_output(var::SymbolicT, connected_vars::Vector{SymbolicT}) + names = join(string.(connected_vars), ", ") + throw(ArgumentError(""" + Variable $var in connection `connect($names)` is neither input nor output. + """)) +end - # wrap to be able to call `eachindex` on a non-array variable - representative = wrap(first(connected_vars)) +function _generate_connectionsets_with_idxs!(connection_state::AbstractConnectionState, + namespace::Vector{Symbol}, connected_vars::Vector{SymbolicT}, isouter::IsOuter, + idxs::CartesianIndices{N, NTuple{N, UnitRange{Int}}}) where {N} # all of them have the same size, but may have different axes/shape # so we iterate over `eachindex(eachindex(..))` since that is identical for all - for sz_i in eachindex(eachindex(representative)) - hyperedge = map(connected_vars) do var - var = unwrap(var) + for sz_i in eachindex(idxs) + hyperedge = ConnectionVertex[] + for var in connected_vars var_ns = namespace_hierarchy(getname(var)) - i = eachindex(wrap(var))[sz_i] - + if N === 0 + i = sz_i + else + i = (eachindex(var)::CartesianIndices{N, NTuple{N, UnitRange{Int}}})[sz_i]::CartesianIndex{N} + end is_input = isinput(var) is_output = isoutput(var) if is_input && is_output - names = join(string.(connected_vars), ", ") - throw(ArgumentError(""" - Variable $var in connection `connect($names)` is both input and output. - """)) + throw_both_input_output(var, connected_vars) elseif is_input type = InputVar{i} elseif is_output type = OutputVar{i} else - names = join(string.(connected_vars), ", ") - throw(ArgumentError(""" - Variable $var in connection `connect($names)` is neither input nor output. - """)) + throw_not_input_output(var, connected_vars) end - - return ConnectionVertex( + vert = ConnectionVertex( [namespace; var_ns], length(var_ns) == 1 || isouter(var_ns[1]), type) + push!(hyperedge, vert) end add_connection_edge!(connection_state, hyperedge) @@ -480,10 +536,36 @@ end function _generate_connectionsets!(connection_state::AbstractConnectionState, namespace::Vector{Symbol}, - systems::Union{AbstractVector{<:AbstractSystem}, Tuple{Vararg{AbstractSystem}}}, - isouter) + connected_vars::Vector{SymbolicT}, + isouter::IsOuter) + # NOTE: variable connections don't populate the domain network + + representative = first(connected_vars) + idxs = eachindex(representative) + # Manual dispatch for common cases + if idxs isa CartesianIndices{0, Tuple{}} + _generate_connectionsets_with_idxs!(connection_state, namespace, connected_vars, + isouter, idxs) + elseif idxs isa CartesianIndices{1, Tuple{UnitRange{Int}}} + _generate_connectionsets_with_idxs!(connection_state, namespace, connected_vars, + isouter, idxs) + elseif idxs isa CartesianIndices{2, NTuple{2, UnitRange{Int}}} + _generate_connectionsets_with_idxs!(connection_state, namespace, connected_vars, + isouter, idxs) + else + # Dynamic dispatch + _generate_connectionsets_with_idxs!(connection_state, namespace, connected_vars, + isouter, idxs) + end +end + +function _generate_connectionsets!(connection_state::AbstractConnectionState, + namespace::Vector{Symbol}, + systems::Vector{T}, + isouter::IsOuter) where {T <: AbstractSystem} + systems = systems::Vector{System} regular_systems = System[] - domain_system = nothing + domain_system::Union{Nothing, System} = nothing for s in systems if is_domain_connector(s) if domain_system === nothing @@ -518,7 +600,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, push!(domain_hyperedge, domain_vertex) push!(hyperedge, dv_vertex) - for (i, sys) in enumerate(systems) + for sys in systems sts = unknowns(sys) sys_is_outer = isouter(sys) @@ -526,6 +608,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, # are properly namespaced sysname = nameof(sys) sys_ns = namespace_hierarchy(sysname) + N = length(namespace) append!(namespace, sys_ns) for v in sts vtype = get_connection_type(v) @@ -539,7 +622,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, push!(domain_hyperedge, sys_vertex) end # remember to remove the added namespace! - foreach(_ -> pop!(namespace), sys_ns) + resize!(namespace, N) end @assert length(hyperedge) > 1 @assert length(domain_hyperedge) == length(hyperedge) @@ -553,10 +636,10 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, # Add 9 orientation variables if connection is between multibody frames if isframe(sys1) # Multibody O = ori(sys1) - orientation_vars = Symbolics.unwrap.(collect(vec(O.R))) - sys1_dvs = [sys1_dvs; orientation_vars] + orientation_vars = vec(O.R) + sys1_dvs = SymbolicT[sys1_dvs; orientation_vars] end - sys1_dvs_set = Set(sys1_dvs) + sys1_dvs_set = Set{SymbolicT}(sys1_dvs) num_unknowns = length(sys1_dvs) # We first build sets of all vertices that are connected together @@ -567,8 +650,8 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, # Add 9 orientation variables if connection is between multibody frames if isframe(sys) # Multibody O = ori(sys) - orientation_vars = Symbolics.unwrap.(vec(O.R)) - unknown_vars = [unknown_vars; orientation_vars] + orientation_vars = vec(O.R) + unknown_vars = SymbolicT[unknown_vars; orientation_vars] end # Error if any subsequent systems do not have the same number of unknowns # or have unknowns not in the others. @@ -580,6 +663,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, # are properly namespaced sysname = nameof(sys) sys_ns = namespace_hierarchy(sysname) + N = length(namespace) append!(namespace, sys_ns) sys_is_outer = isouter(sys) for (j, v) in enumerate(unknown_vars) @@ -588,7 +672,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, domain_vertex = ConnectionVertex(namespace) push!(domain_hyperedge, domain_vertex) # remember to remove the added namespace! - foreach(_ -> pop!(namespace), sys_ns) + resize!(namespace, N) end for var_set in var_sets # all connected variables should have the same type @@ -630,35 +714,39 @@ can be pushed, unmodified. Connection equations update the given `state`. The eq present at the path in the hierarchical system given by `namespace`. `isouter` is the function returned from `generate_isouter`. """ -function handle_maybe_connect_equation!(eqs, state::AbstractConnectionState, - eq::Equation, namespace::Vector{Symbol}, isouter) +function handle_maybe_connect_equation!(eqs::Vector{Equation}, state::AbstractConnectionState, + eq::Equation, namespace::Vector{Symbol}, isouter::IsOuter) lhs = value(eq.lhs) rhs = value(eq.rhs) if !(lhs isa Connection) # split connections and equations - if eq.lhs isa AbstractArray || eq.rhs isa AbstractArray - append!(eqs, Symbolics.scalarize(eq)) - else - push!(eqs, eq) - end + push!(eqs, eq) return end + lhs = lhs::Connection + rhs = rhs::Connection + handle_maybe_connect_equation!(state, lhs, rhs, namespace, isouter) +end +function handle_maybe_connect_equation!(state::AbstractConnectionState, + lhs::Connection, rhs::Connection, namespace::Vector{Symbol}, isouter::IsOuter) if get_systems(lhs) === :domain # This is a domain connection, so we only update the domain connection graph - hyperedge = map(get_systems(rhs)) do sys - sys isa AbstractSystem || error("Domain connections can only connect systems!") + syss = get_systems(rhs)::Vector{System} + hyperedge = ConnectionVertex[] + for sys in syss sysname = nameof(sys) sys_ns = namespace_hierarchy(sysname) + N = length(namespace) append!(namespace, sys_ns) vertex = ConnectionVertex(namespace) - foreach(_ -> pop!(namespace), sys_ns) - return vertex + resize!(namespace, N) + push!(hyperedge, vertex) end add_domain_connection_edge!(state, hyperedge) else - connected_systems = get_systems(rhs) + connected_systems = get_systems(rhs)::Union{Vector{System}, Vector{SymbolicWithNameof}} generate_connectionsets!(state, namespace, connected_systems, isouter) end return nothing @@ -712,12 +800,13 @@ function _generate_connection_set!(connection_state::ConnectionState, end # go through the removed connections and update the negative graph - for conn in something(get_ignored_connections(sys), ()) - eq = Equation(Connection(), conn) - # there won't be any standard equations, so we can pass `nothing` instead of - # `eqs`. - handle_maybe_connect_equation!( - nothing, negative_connection_state, eq, namespace, isouter) + ignored = get_ignored_connections(sys) + if ignored isa Vector{Connection} + for conn in ignored + # there won't be any standard equations, so we can pass `nothing` instead of + # `eqs`. + handle_maybe_connect_equation!(negative_connection_state, Connection(), conn, namespace, isouter) + end end # all connectors are eventually inside connectors, and all flow variables @@ -732,17 +821,41 @@ function _generate_connection_set!(connection_state::ConnectionState, end pop!(namespace) end - - # recurse down the hierarchy - @set! sys.systems = map(subsys) do s - generate_connection_set!(connection_state, negative_connection_state, s, namespace) + new_systems = System[] + for s in subsys + news = generate_connection_set!(connection_state, negative_connection_state, s, namespace) + push!(new_systems, news) end + # recurse down the hierarchy + @set! sys.systems = new_systems @set! sys.eqs = eqs # Remember to pop the name at the end! does_namespacing(sys) && pop!(namespace) return sys end +function _flow_equations_from_idxs!(eqs::Vector{Equation}, cset::Vector{ConnectionVertex}, idxs::CartesianIndices{N, NTuple{N, UnitRange{Int}}}) where {N} + add_buffer = SymbolicT[] + # each variable can have different axes, but they all have the same size + for sz_i in eachindex(idxs) + empty!(add_buffer) + for cvert in cset + # all of this wrapping/unwrapping is necessary because the relevant + # methods are defined on `Arr/Num` and not `BasicSymbolic`. + v = variable_from_vertex(sys, cvert)::SymbolicT + if N === 0 + v = v + else + vidxs = eachindex(v)::CartesianIndices{N, NTuple{N, UnitRange{Int}}} + v = v[vidxs[sz_i]] + end + push!(add_buffer, cvert.isouter ? -v : v) + end + rhs = SU.add_worker(VartypeT, add_buffer) + push!(eqs, Symbolics.COMMON_ZERO ~ rhs) + end +end + """ $(TYPEDSIGNATURES) @@ -756,7 +869,7 @@ function generate_connection_equations_and_stream_connections( for cset in csets cvert = cset[1] - var = variable_from_vertex(sys, cvert)::BasicSymbolic + var = variable_from_vertex(sys, cvert)::SymbolicT vtype = cvert.type if vtype <: Union{InputVar, OutputVar} length(cset) > 1 || continue @@ -782,10 +895,10 @@ function generate_connection_equations_and_stream_connections( end end root_vert = something(inner_output, outer_input) - root_var = variable_from_vertex(sys, root_vert) + root_var = variable_from_vertex(sys, root_vert)::SymbolicT for cvert in cset isequal(cvert, root_vert) && continue - push!(eqs, variable_from_vertex(sys, cvert) ~ root_var) + push!(eqs, variable_from_vertex(sys, cvert)::SymbolicT ~ root_var) end elseif vtype === Stream push!(stream_connections, cset) @@ -793,48 +906,48 @@ function generate_connection_equations_and_stream_connections( # arrays have to be broadcasted to be added/subtracted/negated which leads # to bad-looking equations. Just generate scalar equations instead since # mtkcompile will scalarize anyway. - representative = variable_from_vertex(sys, cset[1]) - # each variable can have different axes, but they all have the same size - for sz_i in eachindex(eachindex(wrap(representative))) - rhs = 0 - for cvert in cset - # all of this wrapping/unwrapping is necessary because the relevant - # methods are defined on `Arr/Num` and not `BasicSymbolic`. - v = variable_from_vertex(sys, cvert)::BasicSymbolic - idxs = eachindex(wrap(v)) - v = unwrap(wrap(v)[idxs[sz_i]]) - rhs += cvert.isouter ? unwrap(-wrap(v)) : v - end - push!(eqs, 0 ~ rhs) + representative = variable_from_vertex(sys, cset[1])::SymbolicT + idxs = eachindex(representative) + if idxs isa CartesianIndices{0, Tuple{}} + _flow_equations_from_idxs!(eqs, cset, idxs) + elseif idxs isa CartesianIndices{1, Tuple{UnitRange{Int}}} + _flow_equations_from_idxs!(eqs, cset, idxs) + elseif idxs isa CartesianIndices{2, NTuple{2, UnitRange{Int}}} + _flow_equations_from_idxs!(eqs, cset, idxs) + else + _flow_equations_from_idxs!(eqs, cset, idxs) end else # Equality - vars = map(Base.Fix1(variable_from_vertex, sys), cset) - outer_input = inner_output = nothing + vars = SymbolicT[] + for cvar in cset + push!(vars, variable_from_vertex(sys, cvar)::SymbolicT) + end + outer_input = inner_output = 0 all_io = true # attempt to interpret the equality as a causal connectionset if # possible - for (cvert, vert) in zip(cset, vars) + for (i, vert) in enumerate(vars) is_i = isinput(vert) is_o = isoutput(vert) all_io &= is_i || is_o all_io || break if cvert.isouter && is_i && outer_input === nothing - outer_input = cvert + outer_input = i elseif !cvert.isouter && is_o && inner_output === nothing - inner_output = cvert + inner_output = i end end # this doesn't necessarily mean this is a well-structured causal connection, # but it is sufficient and we're generating equalities anyway. - if all_io && xor(outer_input !== nothing, inner_output !== nothing) - root_vert = something(inner_output, outer_input) - root_var = variable_from_vertex(sys, root_vert) - for (cvert, var) in zip(cset, vars) - isequal(cvert, root_vert) && continue + if all_io && xor(!iszero(outer_input), !iszero(inner_output)) + root_vert_i = iszero(outer_input) ? inner_output : outer_input + root_var = vars[root_vert_i] + for (i, var) in enumerate(vars) + i == root_vert_i && continue push!(eqs, var ~ root_var) end else - base = variable_from_vertex(sys, cset[1]) + base = vars[1] for i in 2:length(cset) v = vars[i] push!(eqs, base ~ v) @@ -852,12 +965,15 @@ Generate the defaults for parameters in the domain sets given by `domain_csets`. """ function domain_defaults( sys::AbstractSystem, domain_csets::Vector{Vector{ConnectionVertex}}) - defs = Dict() + defs = Dict{SymbolicT, SymbolicT}() for cset in domain_csets - systems = map(Base.Fix1(variable_from_vertex, sys), cset) - @assert all(x -> x isa AbstractSystem, systems) + systems = System[] + for cvar in cset + push!(systems, variable_from_vertex(sys, cvar)::System) + end idx = findfirst(is_domain_connector, systems) idx === nothing && continue + idx = idx::Int domain_sys = systems[idx] # note that these will not be namespaced with `domain_sys`. domain_defs = defaults(domain_sys) @@ -871,7 +987,7 @@ function domain_defaults( for par in parameters(csys) defval = get(domain_defs, par, nothing) defval === nothing && continue - defs[parameters(csys, par)] = parameters(domain_sys, par) + defs[renamespace(csys, par)] = renamespace(domain_sys, par) end end end @@ -915,16 +1031,24 @@ Given a connection vertex `cvert` referring to a variable in a connector in `sys the flow variable in that connector. """ function get_flowvar(sys::AbstractSystem, cvert::ConnectionVertex) - parent_names = @view cvert.name[1:(end - 1)] - parent_sys = iterative_getproperty(sys, parent_names) + tmp = pop!(cvert.name) + parent_sys = iterative_getproperty(sys, cvert.name)::System + push!(cvert.name, tmp) for var in unknowns(parent_sys) type = get_connection_type(var) - type == Flow || continue - return unwrap(unknowns(parent_sys, var)) + type === Flow || continue + return renamespace(parent_sys, var) end throw(ArgumentError("There is no flow variable in system `$(nameof(parent_sys))`")) end +function instream_is_atomic(ex::SymbolicT) + Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f === instream end => true + _ => false + end +end + """ $(TYPEDSIGNATURES) @@ -937,42 +1061,54 @@ function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractS tol = 1e-8) eqs = equations(sys) # collect all `instream` terms in the equations - instream_exprs = Set{BasicSymbolic}() + instream_exprs = Set{SymbolicT}() for eq in eqs - collect_instream!(instream_exprs, eq) + SU.search_variables!(instream_exprs, eq; is_atomic = instream_is_atomic) end # specifically substitute `instream(x[i]) => instream(x)[i]` - instream_subs = Dict{BasicSymbolic, BasicSymbolic}() + instream_subs = Dict{SymbolicT, SymbolicT}() for expr in instream_exprs - stream_var = only(arguments(expr)) - iscall(stream_var) && operation(stream_var) === getindex || continue - args = arguments(stream_var) - new_expr = term(instream, args[1]; type = symtype(args[1]), shape = SU.shape(args[1]))[args[2:end]...] - instream_subs[expr] = new_expr + exargs = Moshi.Data.variant_getfield(expr, BSImpl.Term{VartypeT}, :args) + stream_var = only(exargs) + Moshi.Match.@match stream_var begin + BSImpl.Term(; f, args, type, shape) && if f === getindex end => begin + newargs = copy(parent(args)) + arg = newargs[1] + sharg = SU.shape(arg) + starg = SU.symtype(arg) + newargs[1] = BSImpl.Term{VartypeT}(instream, SArgsT((arg,)); type = starg, shape = sharg) + new_expr = BSImpl.Term{VartypeT}(getindex, newargs; type, shape) + instream_subs[expr] = new_expr + end + _ => nothing + end end # for all the newly added `instream(x)[i]`, add `instream(x)` to `instream_exprs` # also remove all `instream(x[i])` for (k, v) in instream_subs - push!(instream_exprs, arguments(v)[1]) + push!(instream_exprs, Moshi.Match.@match v begin + BSImpl.Term(; args) => args[1] + end) delete!(instream_exprs, k) end # This is an implementation of the modelica spec # https://specification.modelica.org/maint/3.6/stream-connectors.html additional_eqs = Equation[] + add_buffer = SymbolicT[] for cset in csets n_outer = count(cvert -> cvert.isouter, cset) n_inner = length(cset) - n_outer if n_inner == 1 && n_outer == 0 cvert = only(cset) - stream_var = variable_from_vertex(sys, cvert)::BasicSymbolic + stream_var = variable_from_vertex(sys, cvert)::SymbolicT instream_subs[instream(stream_var)] = stream_var elseif n_inner == 2 && n_outer == 0 cvert1, cvert2 = cset - stream_var1 = variable_from_vertex(sys, cvert1)::BasicSymbolic - stream_var2 = variable_from_vertex(sys, cvert2)::BasicSymbolic + stream_var1 = variable_from_vertex(sys, cvert1)::SymbolicT + stream_var2 = variable_from_vertex(sys, cvert2)::SymbolicT instream_subs[instream(stream_var1)] = stream_var2 instream_subs[instream(stream_var2)] = stream_var1 elseif n_inner == 1 && n_outer == 1 @@ -980,14 +1116,14 @@ function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractS if cvert_inner.isouter cvert_inner, cvert_outer = cvert_outer, cvert_inner end - streamvar_inner = variable_from_vertex(sys, cvert_inner)::BasicSymbolic - streamvar_outer = variable_from_vertex(sys, cvert_outer)::BasicSymbolic + streamvar_inner = variable_from_vertex(sys, cvert_inner)::SymbolicT + streamvar_outer = variable_from_vertex(sys, cvert_outer)::SymbolicT instream_subs[instream(streamvar_inner)] = instream(streamvar_outer) push!(additional_eqs, (streamvar_outer ~ streamvar_inner)) elseif n_inner == 0 && n_outer == 2 cvert1, cvert2 = cset - stream_var1 = variable_from_vertex(sys, cvert1)::BasicSymbolic - stream_var2 = variable_from_vertex(sys, cvert2)::BasicSymbolic + stream_var1 = variable_from_vertex(sys, cvert1)::SymbolicT + stream_var2 = variable_from_vertex(sys, cvert2)::SymbolicT push!(additional_eqs, (stream_var1 ~ instream(stream_var2)), (stream_var2 ~ instream(stream_var1))) else @@ -996,55 +1132,70 @@ function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractS # https://specification.modelica.org/maint/3.6/stream-connectors.html#instream-and-connection-equations # We could implement the "if" case using variable bounds? It would be nice to # move that metadata to the system (storing it similar to `defaults`). - outer_cverts = filter(cvert -> cvert.isouter, cset) - inner_cverts = filter(cvert -> !cvert.isouter, cset) - - outer_streamvars = map(Base.Fix1(variable_from_vertex, sys), outer_cverts) - inner_streamvars = map(Base.Fix1(variable_from_vertex, sys), inner_cverts) - - outer_flowvars = map(Base.Fix1(get_flowvar, sys), outer_cverts) - inner_flowvars = map(Base.Fix1(get_flowvar, sys), inner_cverts) + outer_cverts = ConnectionVertex[] + inner_cverts = ConnectionVertex[] + outer_streamvars = SymbolicT[] + inner_streamvars = SymbolicT[] + outer_flowvars = SymbolicT[] + inner_flowvars = SymbolicT[] + for cvert in cset + svar = variable_from_vertex(sys, cvert)::SymbolicT + fvar = get_flowvar(sys, cvert)::SymbolicT + push!(cvert.isouter ? outer_cverts : inner_cverts, cvert) + push!(cvert.isouter ? outer_streamvars : inner_streamvars, svar) + push!(cvert.isouter ? outer_flowvars : inner_flowvars, fvar) + end - mask = trues(length(inner_cverts)) for inner_i in eachindex(inner_cverts) - # mask out the current variable - mask[inner_i] = false svar = inner_streamvars[inner_i] - instream_subs[instream(svar)] = term( - instream_rt, Val(n_inner - 1), Val(n_outer), inner_flowvars[mask]..., - inner_streamvars[mask]..., outer_flowvars..., outer_streamvars...) - # make sure to reset the mask - mask[inner_i] = true + args = SArgsT() + push!(args, SU.Const{VartypeT}(Val(n_inner - 1))) + push!(args, SU.Const{VartypeT}(Val(n_outer - 1))) + for i in eachindex(inner_cverts) + i == inner_i && continue + push!(args, inner_flowvars[i]) + end + for i in eachindex(inner_cverts) + i == inner_i && continue + push!(args, inner_streamvars[i]) + end + append!(args, outer_flowvars) + append!(args, outer_streamvars) + expr = BSImpl.Term{VartypeT}(instream_rt, args; + type = Real, shape = SU.ShapeVecT()) + instream_subs[instream(svar)] = expr end for q in 1:n_outer - sq = mapreduce(+, inner_flowvars) do fvar - max(-fvar, 0) + empty!(add_buffer) + for fvar in inner_flowvars + push!(add_buffer, max(-fvar, 0)) end - sq += mapreduce(+, enumerate(outer_flowvars)) do (outer_i, fvar) - outer_i == q && return 0 - max(fvar, 0) + for (i, fvar) in outer_flowvars + i == q && continue + push!(add_buffer, max(fvar, 0)) end - # sanity check to make sure it isn't going to codegen a `mapreduce` - @assert operation(sq) == (+) + sq = SU.add_worker(VartypeT, add_buffer) - num = mapreduce(+, inner_flowvars, inner_streamvars) do fvar, svar - positivemax(-fvar, sq; tol) * svar + empty!(add_buffer) + for (fvar, svar) in zip(inner_flowvars, inner_streamvars) + push!(add_buffer, positivemax(-fvar, sq; tol) * svar) end - num += mapreduce( - +, enumerate(outer_flowvars), outer_streamvars) do (outer_i, fvar), svar - outer_i == q && return 0 - positivemax(fvar, sq; tol) * instream(svar) + for (i, (fvar, svar)) in enumerate(zip(outer_flowvars, outer_streamvars)) + i == q && continue + push!(add_buffer, positivemax(fvar, sq; tol) * instream(svar)) end - @assert operation(num) == (+) + num = SU.add_worker(VartypeT, add_buffer) - den = mapreduce(+, inner_flowvars) do fvar - positivemax(-fvar, sq; tol) + empty!(add_buffer) + for fvar in inner_flowvars + push!(add_buffer, positivemax(-fvar, sq; tol)) end - den += mapreduce(+, enumerate(outer_flowvars)) do (outer_i, fvar) - outer_i == q && return 0 - positivemax(fvar, sq; tol) + for (i, fvar) in enumerate(outer_flowvars) + i == q && continue + push!(add_buffer, positivemax(fvar, sq; tol)) end + den = SU.add_worker(VartypeT, add_buffer) push!(additional_eqs, (outer_streamvars[q] ~ num / den)) end From 5fba4988a9e6e8ac26a21f21aa057b6f548dba1a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:10:06 +0530 Subject: [PATCH 060/157] fix: improve type-stability of `getvar` --- src/systems/abstractsystem.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 8c6cc93c53..5c754dfde1 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -976,8 +976,12 @@ function getvar(sys::AbstractSystem, name::Symbol; namespace = does_namespacing( if has_eqs(sys) for eq in get_eqs(sys) eq isa Equation || continue - if eq.lhs isa AnalysisPoint && nameof(eq.rhs) == name - return namespace ? renamespace(sys, eq.rhs) : eq.rhs + lhs = value(eq.lhs) + rhs = value(eq.rhs) + if lhs isa AnalysisPoint + rhs = rhs::AnalysisPoint + nameof(rhs) == name || continue + return namespace ? renamespace(sys, rhs) : rhs end end end From f702020957aee483122a10a1455201b634ba3d16 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:10:25 +0530 Subject: [PATCH 061/157] fix: improve type-stability of aggregator functions --- src/systems/abstractsystem.jl | 41 ++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 5c754dfde1..b4093842c8 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1258,7 +1258,14 @@ function namespace_jumps(sys::AbstractSystem) end function namespace_brownians(sys::AbstractSystem) - return [renamespace(sys, b) for b in brownians(sys)] + bs = brownians(sys) + if bs === get_brownians(sys) + bs = copy(bs) + end + for i in eachindex(bs) + bs[i] = renamespace(sys, bs[i]) + end + return bs end function namespace_assignment(eq::Assignment, sys) @@ -1519,10 +1526,15 @@ See also [`observables`](@ref) and [`ModelingToolkit.get_observed()`](@ref). function observed(sys::AbstractSystem) obs = get_observed(sys) systems = get_systems(sys) - [obs; - reduce(vcat, - (map(o -> namespace_equation(o, s), observed(s)) for s in systems), - init = Equation[])] + isempty(systems) && return obs + obs = copy(obs) + for subsys in systems + _obs = observed(subsys) + for eq in _obs + push!(obs, namespace_equation(eq, subsys)) + end + end + return obs end """ @@ -1691,7 +1703,11 @@ function brownians(sys::AbstractSystem) if isempty(systems) return bs end - return [bs; reduce(vcat, namespace_brownians.(systems); init = [])] + bs = copy(bs) + for subsys in systems + append!(bs, namespace_brownians(subsys)) + end + return bs end """ @@ -1705,10 +1721,13 @@ function cost(sys::AbstractSystem) consolidate = get_consolidate(sys) systems = get_systems(sys) if isempty(systems) - return consolidate(cs, Float64[]) + return consolidate(cs, Float64[])::SymbolicT end - subcosts = [namespace_expr(cost(subsys), subsys) for subsys in systems] - return consolidate(cs, subcosts) + subcosts = SymbolicT[] + for subsys in systems + push!(subcosts, namespace_expr(cost(subsys), subsys)) + end + return consolidate(cs, subcosts)::SymbolicT end namespace_constraint(eq::Equation, sys) = namespace_equation(eq, sys) @@ -2853,13 +2872,15 @@ function process_parameter_equations(sys::AbstractSystem) SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{Union{Differential, Initial, Pre}}()) # singular equations isempty(varsbuf) && continue - if all(varsbuf) do sym + if let sys = sys + all(varsbuf) do sym is_parameter(sys, sym) || symbolic_type(sym) == ArraySymbolic() && symbolic_has_known_size(sym) && all(Base.Fix1(is_parameter, sys), collect(sym)) || iscall(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[1]) + end end # Everything in `varsbuf` is a parameter, so this is a cheap `is_parameter` # check. From 5aeee69f7c433b6ea782703328461a925ee0d13b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:10:41 +0530 Subject: [PATCH 062/157] fix: improve type-stability of `noise_to_brownians` --- src/systems/diffeqs/basic_transformations.jl | 32 +++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index cb173a15fc..b7baad8861 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -836,33 +836,43 @@ function noise_to_brownians(sys::System; names::Union{Symbol, Vector{Symbol}} = if neqs === nothing throw(ArgumentError("Expected a system with `noise_eqs`.")) end + neqs = neqs::Union{Vector{SymbolicT}, Matrix{SymbolicT}} if !isempty(get_systems(sys)) throw(ArgumentError("The system must be flattened.")) end # vector means diagonal noise - nbrownians = ndims(neqs) == 1 ? length(neqs) : size(neqs, 2) - if names isa Symbol - names = [Symbol(names, :_, i) for i in 1:nbrownians] + nbrownians = if neqs isa Vector{SymbolicT} + length(neqs) + elseif neqs isa Matrix{SymbolicT} + size(neqs, 2) end - if length(names) != nbrownians + if names isa Symbol + _names = Symbol[] + for i in 1:nbrownians + push!(_names, Symbol(names, :_, i)) + end + names = _names + elseif names isa Vector{Symbol} && length(names) != nbrownians throw(ArgumentError(""" The system has $nbrownians brownian variables. Received $(length(names)) names \ for the brownian variables. Provide $nbrownians names or a single `Symbol` to use \ an array variable of the appropriately length. """)) end - brownvars = map(names) do name - unwrap(only(@brownians $name)) + names = names::Vector{Symbol} + brownvars = SymbolicT[] + for name in names + push!(brownvars, unwrap(only(@brownians $name))) end - - terms = if ndims(neqs) == 1 + terms = if neqs isa Vector{SymbolicT} neqs .* brownvars - else + elseif neqs isa Matrix{SymbolicT} neqs * brownvars end - eqs = map(get_eqs(sys), terms) do eq, term - eq.lhs ~ eq.rhs + term + eqs = Equation[] + for (eq, term) in zip(get_eqs(sys), terms) + push!(eqs, eq.lhs ~ eq.rhs + term) end @set! sys.eqs = eqs From 7f8f9dcc1268c0b4e49866f8f26bc67f2ea807f1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:31:54 +0530 Subject: [PATCH 063/157] fix: improve type-stability of `simplify_optimization_system` --- src/systems/systems.jl | 54 +++++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 6aa4166d03..f49c901867 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -175,42 +175,62 @@ function simplify_optimization_system(sys::System; split = true, kwargs...) sys = flatten(sys) cons = constraints(sys) econs = Equation[] - icons = similar(cons, 0) + icons = Inequality[] for e in cons if e isa Equation push!(econs, e) - else + elseif e isa Inequality push!(icons, e) end end - irreducible_subs = Dict() - dvs = mapreduce(Symbolics.scalarize, vcat, unknowns(sys)) - if !(dvs isa Array) - dvs = [dvs] + irreducible_subs = Dict{SymbolicT, SymbolicT}() + dvs = SymbolicT[] + for var in unknowns(sys) + sh = SU.shape(var)::SU.ShapeVecT + if isempty(sh) + push!(dvs, var) + else + append!(dvs, vec(collect(var)::Array{SymbolicT})::Vector{SymbolicT}) + end end for i in eachindex(dvs) var = dvs[i] if hasbounds(var) - irreducible_subs[var] = irrvar = setirreducible(var, true) + irreducible_subs[var] = irrvar = setirreducible(var, true)::SymbolicT dvs[i] = irrvar end end - econs = substitute.(econs, (irreducible_subs,)) + subst = SU.Substituter{false}(irreducible_subs, SU.default_substitute_filter) + for i in eachindex(econs) + econs[i] = subst(econs[i]) + end nlsys = System(econs, dvs, parameters(sys); name = :___tmp_nlsystem) - snlsys = mtkcompile(nlsys; kwargs..., fully_determined = false) + snlsys = mtkcompile(nlsys; kwargs..., fully_determined = false)::System obs = observed(snlsys) seqs = equations(snlsys) trueobs, _ = unhack_observed(obs, seqs) - subs = Dict(eq.lhs => eq.rhs for eq in trueobs) - cons_simplified = similar(cons, length(icons) + length(seqs)) - for (i, eq) in enumerate(Iterators.flatten((seqs, icons))) - cons_simplified[i] = fixpoint_sub(eq, subs) + subs = Dict{SymbolicT, SymbolicT}() + for eq in trueobs + subs[eq.lhs] = eq.rhs end - newsts = setdiff(dvs, keys(subs)) + cons_simplified = Union{Equation, Inequality}[] + for eq in seqs + push!(cons_simplified, fixpoint_sub(eq, subs)) + end + for eq in icons + push!(cons_simplified, fixpoint_sub(eq, subs)) + end + setdiff!(dvs, keys(subs)) + newsts = dvs @set! sys.constraints = cons_simplified - @set! sys.observed = [observed(sys); obs] - newcost = fixpoint_sub.(get_costs(sys), (subs,)) - @set! sys.costs = newcost + newobs = copy(observed(sys)) + append!(newobs, obs) + @set! sys.observed = newobs + newcosts = copy(get_costs(sys)) + for i in eachindex(newcosts) + newcosts[i] = fixpoint_sub(newcosts[i], subs) + end + @set! sys.costs = newcosts @set! sys.unknowns = newsts return sys end From a49be16b0dad813b3ffdc1931c7d90f57eac7e6a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 6 Oct 2025 12:32:29 +0530 Subject: [PATCH 064/157] fix: improve type-stability of `TearingState` --- src/ModelingToolkit.jl | 1 + src/systems/systemstructure.jl | 458 ++++++++++++++++++++------------- 2 files changed, 276 insertions(+), 183 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 384ed9de20..1017923ee6 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -163,6 +163,7 @@ include("constants.jl") const SymmapT = Dict{SymbolicT, SymbolicT} const COMMON_NOTHING = SU.Const{VartypeT}(nothing) +const COMMON_MISSING = SU.Const{VartypeT}(missing) include("utils.jl") diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index a30e399047..d18a7984e6 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -153,7 +153,7 @@ Base.@kwdef mutable struct SystemStructure """Graph that connects equations to the variable they will be solved for during simplification.""" solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing} """Variable types (brownian, variable, parameter) in the system.""" - var_types::Union{Vector{VariableType}, Nothing} + var_types::Vector{VariableType} """Whether the system is discrete.""" only_discrete::Bool end @@ -205,10 +205,10 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} """The system of equations.""" sys::T """The set of variables of the system.""" - fullvars::Vector{BasicSymbolic} + fullvars::Vector{SymbolicT} structure::SystemStructure - extra_eqs::Vector - param_derivative_map::Dict{BasicSymbolic, Any} + extra_eqs::Vector{Equation} + param_derivative_map::Dict{SymbolicT, SymbolicT} original_eqs::Vector{Equation} """ Additional user-provided observed equations. The variables calculated here @@ -278,7 +278,7 @@ function Base.show(io::IO, state::TearingState) print(io, "TearingState of ", typeof(state.sys)) end -struct EquationsView{T} <: AbstractVector{Any} +struct EquationsView{T} <: AbstractVector{Equation} ts::TearingState{T} end equations(ts::TearingState) = EquationsView(ts) @@ -301,11 +301,11 @@ function is_time_dependent_parameter(p, allps, iv) (args = arguments(p); length(args)) == 1 && isequal(only(args), iv)) end -function symbolic_contains(var, set) - var in set || - symbolic_type(var) == ArraySymbolic() && - symbolic_has_known_size(var) && - all(x -> x in set, Symbolics.scalarize(var)) +function symbolic_contains(var::SymbolicT, set::Set{SymbolicT}) + var in set # || + # symbolic_type(var) == ArraySymbolic() && + # symbolic_has_known_size(var) && + # all(x -> x in set, Symbolics.scalarize(var)) end """ @@ -315,21 +315,21 @@ Descend through the system hierarchy and look for statemachines. Remove equation the inner statemachine systems. Return the new `sys` and an array of top-level statemachines. """ -function extract_top_level_statemachines(sys::AbstractSystem) +function extract_top_level_statemachines(sys::System) eqs = get_eqs(sys) - - if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs) + predicate = Base.Fix2(isa, StateMachineOperator) ∘ SU.unwrap_const + if !isempty(eqs) && all(predicate, eqs) # top-level statemachine with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys)) return with_removed, [sys] - elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs) + elseif !isempty(eqs) && any(predicate, eqs) # error: can't mix error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.") else # descend subsystems = get_systems(sys) - newsubsystems = eltype(subsystems)[] - statemachines = eltype(subsystems)[] + newsubsystems = System[] + statemachines = System[] for subsys in subsystems newsubsys, sub_statemachines = extract_top_level_statemachines(subsys) push!(newsubsystems, newsubsys) @@ -346,12 +346,12 @@ end Return `sys` with all equations (including those in subsystems) removed. """ function remove_child_equations(sys::AbstractSystem) - @set! sys.eqs = eltype(get_eqs(sys))[] + @set! sys.eqs = Equation[] @set! sys.systems = map(remove_child_equations, get_systems(sys)) return sys end -function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) +function TearingState(sys; check = true, sort_eqs = true) # flatten system sys = flatten(sys) sys = process_parameter_equations(sys) @@ -361,76 +361,29 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) eqs = flatten_equations(equations(sys)) original_eqs = copy(eqs) neqs = length(eqs) - param_derivative_map = Dict{BasicSymbolic, Any}() + param_derivative_map = Dict{SymbolicT, SymbolicT}() + fullvars = SymbolicT[] # * Scalarize unknowns - dvs = Set{BasicSymbolic}() - fullvars = BasicSymbolic[] - for x in unknowns(sys) - push!(dvs, x) - xx = Symbolics.scalarize(x) - if xx isa AbstractArray - union!(dvs, xx) - end - end + dvs = Set{SymbolicT}() + collect_vars_to_set!(dvs, unknowns(sys)) ps = Set{SymbolicT}() - for x in full_parameters(sys) - push!(ps, x) - if symbolic_type(x) == ArraySymbolic() && symbolic_has_known_size(x) - xx = Symbolics.scalarize(x) - union!(ps, xx) - end - end - browns = Set{BasicSymbolic}() - for x in brownians(sys) - push!(browns, x) - xx = Symbolics.scalarize(x) - if xx isa AbstractArray - union!(browns, xx) - end - end - var2idx = Dict{BasicSymbolic, Int}() + collect_vars_to_set!(ps, full_parameters(sys)) + browns = Set{SymbolicT}() + collect_vars_to_set!(browns, brownians(sys)) + var2idx = Dict{SymbolicT, Int}() var_types = VariableType[] - addvar! = let fullvars = fullvars, dvs = dvs, var2idx = var2idx, var_types = var_types - (var, vtype) -> get!(var2idx, var) do - push!(dvs, var) - push!(fullvars, var) - push!(var_types, vtype) - return length(fullvars) - end - end + + addvar! = AddVar!(var2idx, dvs, fullvars, var_types) # build symbolic incidence - symbolic_incidence = Vector{BasicSymbolic}[] - varsbuf = Set() + symbolic_incidence = Vector{SymbolicT}[] + varsbuf = Set{SymbolicT}() eqs_to_retain = trues(length(eqs)) for (i, eq) in enumerate(eqs) - _eq = eq - if iscall(eq.lhs) && (op = operation(eq.lhs)) isa Differential && - isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq.lhs)), ps, iv) - # parameter derivatives are opted out by specifying `D(p) ~ missing`, but - # we want to store `nothing` in the map because that means `substitute` - # will ignore the rule. We will this identify the presence of `eq′.lhs` in - # the differentiated expression and error. - param_derivative_map[eq.lhs] = coalesce(eq.rhs, nothing) - eqs_to_retain[i] = false - # change the equation if the RHS is `missing` so the rest of this loop works - eq = 0.0 ~ coalesce(eq.rhs, 0.0) - end - is_statemachine_equation = false - if eq.lhs isa StateMachineOperator - is_statemachine_equation = true - eq = eq - rhs = eq.rhs - elseif _iszero(eq.lhs) - rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs - else - lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs - rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs - eq = 0 ~ rhs - lhs - end + eq, is_statemachine_equation = canonicalize_eq!(param_derivative_map, eqs_to_retain, ps, iv, i, eq) empty!(varsbuf) - vars!(varsbuf, eq; op = Symbolics.Operator) - incidence = Set{BasicSymbolic}() + SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{SU.Operator}()) + incidence = Set{SymbolicT}() isalgeq = true for v in varsbuf # additionally track brownians in fullvars @@ -446,7 +399,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) !haskey(param_derivative_map, Differential(iv)(v)) # Parameter derivatives default to zero - they stay constant # between callbacks - param_derivative_map[Differential(iv)(v)] = 0.0 + param_derivative_map[Differential(iv)(v)] = Symbolics.COMMON_ZERO end continue end @@ -455,14 +408,27 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) isdelay(v, iv) && continue if !symbolic_contains(v, dvs) - isvalid = iscall(v) && - (operation(v) isa Shift || is_transparent_operator(operation(v))) + isvalid = Moshi.Match.@match v begin + BSImpl.Term(; f) => f isa Shift || f isa Operator && is_transparent_operator(f)::Bool + _ => false + end v′ = v - while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift} - v′ = arguments(v′)[1] - if v′ in dvs || getmetadata(v′, SymScope, LocalScope()) isa GlobalScope - isvalid = true - break + while !isvalid + Moshi.Match.@match v′ begin + BSImpl.Term(; f, args) => begin + if f isa Differential + v′ = args[1] + elseif f isa Shift + v′ = args[1] + else + break + end + if v′ in dvs || getmetadata(v′, SymScope, LocalScope()) isa GlobalScope + isvalid = true + break + end + end + _ => break end end if !isvalid @@ -470,43 +436,50 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) end addvar!(v, VARIABLE) - if iscall(v) && operation(v) isa Symbolics.Operator && !isdifferential(v) && - (it = input_timedomain(v)) !== nothing - for v′ in arguments(v) - addvar!(setmetadata(v′, VariableTimeDomain, it), VARIABLE) + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f isa SU.Operator && + !(f isa Differential) && (it = input_timedomain(v)::Vector{InputTimeDomainElT}) !== nothing + end => begin + for (v′, td) in zip(args, it) + addvar!(setmetadata(v′, VariableTimeDomain, td), VARIABLE) + end end + _ => nothing end end isalgeq &= !isdifferential(v) - if symbolic_type(v) == ArraySymbolic() - vv = collect(v) + sh = SU.shape(v)::SU.ShapeVecT + if isempty(sh) + push!(incidence, v) + addvar!(v, VARIABLE) + elseif length(sh) == 1 + vv = collect(v)::Vector{SymbolicT} union!(incidence, vv) - map(vv) do vi + for vi in vv + addvar!(vi, VARIABLE) + end + elseif length(sh) == 2 + vv = collect(v)::Matrix{SymbolicT} + union!(incidence, vv) + for vi in vv addvar!(vi, VARIABLE) end else - push!(incidence, v) - addvar!(v, VARIABLE) + vv = collect(v) + union!(incidence, vv)::Array{SymbolicT} + for vi in vv + addvar!(vi, VARIABLE) + end end end if isalgeq || is_statemachine_equation eqs[i] = eq - else - eqs[i] = eqs[i].lhs ~ rhs end push!(symbolic_incidence, collect(incidence)) end - dervaridxs = OrderedSet{Int}() - for (i, v) in enumerate(fullvars) - while isdifferential(v) - push!(dervaridxs, i) - v = arguments(v)[1] - i = addvar!(v, VARIABLE) - end - end eqs = eqs[eqs_to_retain] original_eqs = original_eqs[eqs_to_retain] neqs = length(eqs) @@ -521,51 +494,42 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) symbolic_incidence = symbolic_incidence[sortidxs] end + dervaridxs = OrderedSet{Int}() + add_intermediate_derivatives!(fullvars, dervaridxs, addvar!) # Handle shifts - find lowest shift and add intermediates with derivative edges ### Handle discrete variables - lowest_shift = Dict() - for var in fullvars - if ModelingToolkit.isoperator(var, ModelingToolkit.Shift) - steps = operation(var).steps - if steps > 0 - error("Only non-positive shifts allowed. Found $var with a shift of $steps") - end - v = arguments(var)[1] - lowest_shift[v] = min(get(lowest_shift, v, 0), steps) - end - end - for var in fullvars - if ModelingToolkit.isoperator(var, ModelingToolkit.Shift) - op = operation(var) - steps = op.steps - v = arguments(var)[1] - lshift = lowest_shift[v] - tt = op.t - elseif haskey(lowest_shift, var) - lshift = lowest_shift[var] - steps = 0 - tt = iv - v = var - else - continue - end - if lshift < steps - push!(dervaridxs, var2idx[var]) - end - for s in (steps - 1):-1:(lshift + 1) - sf = Shift(tt, s) - dvar = sf(v) - idx = addvar!(dvar, VARIABLE) - if !(idx in dervaridxs) - push!(dervaridxs, idx) - end - end - end - + add_intermediate_shifts!(fullvars, dervaridxs, var2idx, addvar!, iv) # sort `fullvars` such that the mass matrix is as diagonal as possible. dervaridxs = collect(dervaridxs) - sorted_fullvars = OrderedSet(fullvars[dervaridxs]) - var_to_old_var = Dict(zip(fullvars, fullvars)) + fullvars, var_types = sort_fullvars(fullvars, dervaridxs, var_types, iv) + var2idx = Dict{SymbolicT, Int}(fullvars .=> eachindex(fullvars)) + ndervars = length(dervaridxs) + # invalidate `dervaridxs`, it is just `1:ndervars` + dervaridxs = nothing + + # build `var_to_diff` + var_to_diff = build_var_to_diff(fullvars, ndervars, var2idx, iv) + + # build incidence graph + graph = build_incidence_graph(length(fullvars), symbolic_incidence, var2idx) + + @set! sys.eqs = eqs + + eq_to_diff = DiffGraph(nsrcs(graph)) + + return TearingState{typeof(sys)}(sys, fullvars, + SystemStructure(complete(var_to_diff), complete(eq_to_diff), + complete(graph), nothing, var_types, false), + Equation[], param_derivative_map, original_eqs, Equation[], typeof(sys)[]) +end + +function sort_fullvars(fullvars::Vector{SymbolicT}, dervaridxs::Vector{Int}, var_types::Vector{VariableType}, @nospecialize(iv::Union{SymbolicT, Nothing})) + if iv === nothing + return fullvars, var_types + end + iv = iv::SymbolicT + sorted_fullvars = OrderedSet{SymbolicT}(fullvars[dervaridxs]) + var_to_old_var = Dict{SymbolicT, SymbolicT}(zip(fullvars, fullvars)) for dervaridx in dervaridxs dervar = fullvars[dervaridx] diffvar = var_to_old_var[lower_order_var(dervar, iv)] @@ -582,38 +546,164 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) sortperm = indexin(new_fullvars, fullvars) fullvars = new_fullvars var_types = var_types[sortperm] - var2idx = Dict(fullvars .=> eachindex(fullvars)) - dervaridxs = 1:length(dervaridxs) + return fullvars, var_types +end - # build `var_to_diff` +function build_var_to_diff(fullvars::Vector{SymbolicT}, ndervars::Int, var2idx::Dict{SymbolicT, Int}, @nospecialize(iv::Union{SymbolicT, Nothing})) nvars = length(fullvars) - diffvars = [] var_to_diff = DiffGraph(nvars, true) - for dervaridx in dervaridxs + if iv === nothing + return var_to_diff + end + iv = iv::SymbolicT + for dervaridx in 1:ndervars dervar = fullvars[dervaridx] diffvar = lower_order_var(dervar, iv) diffvaridx = var2idx[diffvar] - push!(diffvars, diffvar) var_to_diff[diffvaridx] = dervaridx end + return var_to_diff +end - # build incidence graph +function build_incidence_graph(nvars::Int, symbolic_incidence::Vector{Vector{SymbolicT}}, var2idx::Dict{SymbolicT, Int}) + neqs = length(symbolic_incidence) graph = BipartiteGraph(neqs, nvars, Val(false)) for (ie, vars) in enumerate(symbolic_incidence), v in vars - jv = var2idx[v] add_edge!(graph, ie, jv) end - @set! sys.eqs = eqs + return graph +end - eq_to_diff = DiffGraph(nsrcs(graph)) +function collect_vars_to_set!(buffer::Set{SymbolicT}, vars::Vector{SymbolicT}) + for x in vars + push!(buffer, x) + Moshi.Match.@match x begin + BSImpl.Term(; f, args) && if f === getindex end => push!(buffer, args[1]) + _ => nothing + end + sh = SU.shape(x) + sh isa SU.Unknown && continue + sh = sh::SU.ShapeVecT + isempty(sh) && continue + idxs = SU.stable_eachindex(x) + for i in idxs + push!(buffer, x[i]) + end + end +end - ts = TearingState(sys, fullvars, - SystemStructure(complete(var_to_diff), complete(eq_to_diff), - complete(graph), nothing, var_types, false), - Any[], param_derivative_map, original_eqs, Equation[], typeof(sys)[]) - return ts +function canonicalize_eq!(param_derivative_map::Dict{SymbolicT, SymbolicT}, eqs_to_retain::BitVector, ps::Set{SymbolicT}, @nospecialize(iv::Union{Nothing, SymbolicT}), i::Int, eq::Equation) + is_statemachine_equation = false + lhs = eq.lhs + rhs = eq.rhs + Moshi.Match.@match lhs begin + BSImpl.Term(; f, args) && if f isa Differential && iv isa SymbolicT && isequal(f.x, iv) && + is_time_dependent_parameter(args[1], ps, iv) + end => begin + # parameter derivatives are opted out by specifying `D(p) ~ missing`, but + # we want to store `nothing` in the map because that means `substitute` + # will ignore the rule. We will this identify the presence of `eq′.lhs` in + # the differentiated expression and error. + if eq.rhs !== COMMON_MISSING + param_derivative_map[lhs] = rhs + eq = Symbolics.COMMON_ZERO ~ rhs + else + # change the equation if the RHS is `missing` so the rest of this loop works + eq = Symbolics.COMMON_ZERO ~ Symbolics.COMMON_ZERO + end + eqs_to_retain[i] = false + end + BSImpl.Const(; val) && if val isa StateMachineOperator end => begin + is_statemachine_equation = true + end + BSImpl.Const(;) && if _iszero(lhs) end => nothing + _ => begin + eq = Symbolics.COMMON_ZERO ~ (rhs - lhs) + end + end + return eq, is_statemachine_equation +end + +struct AddVar! + var2idx::Dict{SymbolicT, Int} + dvs::Set{SymbolicT} + fullvars::Vector{SymbolicT} + var_types::Vector{VariableType} +end + +function (avc::AddVar!)(var::SymbolicT, vtype::VariableType) + idx = get(avc.var2idx, var, nothing) + idx === nothing || return idx::Int + push!(avc.dvs, var) + push!(avc.fullvars, var) + push!(avc.var_types, vtype) + return avc.var2idx[var] = length(avc.fullvars) +end + +function add_intermediate_derivatives!(fullvars::Vector{SymbolicT}, dervaridxs::OrderedSet{Int}, addvar!::AddVar!) + for (i, v) in enumerate(fullvars) + while true + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f isa Differential end => begin + push!(dervaridxs, i) + v = args[1] + addvar!(v, VARIABLE) + end + _ => break + end + end + end +end + +function add_intermediate_shifts!(fullvars::Vector{SymbolicT}, dervaridxs::OrderedSet{Int}, var2idx::Dict{SymbolicT, Int}, addvar!::AddVar!, iv::Union{SymbolicT, Nothing}) + lowest_shift = Dict{SymbolicT, Int}() + for var in fullvars + Moshi.Match.@match var begin + BSImpl.Term(; f, args) && if f isa Shift end => begin + steps = f.steps + if steps > 0 + error("Only non-positive shifts allowed. Found $var with a shift of $steps") + end + v = args[1] + lowest_shift[v] = min(get(lowest_shift, v, 0), steps) + end + _ => nothing + end + end + for var in fullvars + lshift = typemax(Int) + steps = typemax(Int) + tt = iv + v = var + Moshi.Match.@match var begin + BSImpl.Term(; f, args) && if f isa Shift end => begin + steps = f.steps + v = args[1] + lshift = lowest_shift[v] + tt = f.t + end + if haskey(lowest_shift, var) end => begin + lshift = lowest_shift[var] + steps = 0 + tt = iv + v = var + end + _ => continue + end + if lshift < steps + push!(dervaridxs, var2idx[var]) + end + for s in (steps - 1):-1:(lshift + 1) + sf = Shift(tt, s) + dvar = sf(v) + idx = addvar!(dvar, VARIABLE) + if !(idx in dervaridxs) + push!(dervaridxs, idx) + end + end + end end """ @@ -719,36 +809,38 @@ function trivial_tearing!(ts::TearingState) return ts end -function lower_order_var(dervar, t) - if isdifferential(dervar) - diffvar = arguments(dervar)[1] - elseif ModelingToolkit.isoperator(dervar, ModelingToolkit.Shift) - s = operation(dervar) - step = s.steps - 1 - vv = arguments(dervar)[1] - if step != 0 - diffvar = Shift(s.t, step)(vv) - else - diffvar = vv +function lower_order_var(dervar::SymbolicT, t::SymbolicT) + Moshi.Match.@match dervar begin + BSImpl.Term(; f, args) && if f isa Differential end => return args[1] + BSImpl.Term(; f, args) && if f isa Shift end => begin + step = f.steps - 1 + vv = args[1] + if step != 0 + diffvar = Shift(f.t, step)(vv) + else + diffvar = vv + end + return diffvar end - else - return Shift(t, -1)(dervar) + _ => return Shift(t, -1)(dervar) end - diffvar end function shift_discrete_system(ts::TearingState) @unpack fullvars, sys = ts - discvars = OrderedSet() + fullvars_set = Set{SymbolicT}(fullvars) + discvars = OrderedSet{SymbolicT}() eqs = equations(sys) for eq in eqs - vars!(discvars, eq; op = Union{Sample, Hold, Pre}) + SU.search_variables!(discvars, eq; is_atomic = OperatorIsAtomic{Union{Sample, Hold, Pre}}()) end - iv = get_iv(sys) - - discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k)) + iv = get_iv(sys)::SymbolicT + discmap = Dict{SymbolicT, SymbolicT}() for k in discvars - if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold, Pre})) + k in fullvars_set || continue + isoperator(k, Union{Sample, Hold, Pre}) && continue + discmap[k] = StructuralTransformations.simplify_shifts(Shift(iv, 1)(k)) + end for i in eachindex(fullvars) fullvars[i] = StructuralTransformations.simplify_shifts(substitute( From 407329001825bc58cfecbe4a6243bad5f6da5dd9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Oct 2025 14:47:44 +0530 Subject: [PATCH 065/157] fix: improve type-stability of `pantelides_reassemble` --- src/structural_transformation/pantelides.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/structural_transformation/pantelides.jl b/src/structural_transformation/pantelides.jl index 47fa5aa762..581c1f1198 100644 --- a/src/structural_transformation/pantelides.jl +++ b/src/structural_transformation/pantelides.jl @@ -8,11 +8,11 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) sys = state.sys # Step 1: write derivative equations in_eqs = equations(sys) - out_eqs = Vector{Any}(undef, nv(eq_to_diff)) + out_eqs = Vector{Equation}(undef, nv(eq_to_diff)) fill!(out_eqs, nothing) out_eqs[1:length(in_eqs)] .= in_eqs - out_vars = Vector{Any}(undef, nv(var_to_diff)) + out_vars = Vector{SymbolicT}(undef, nv(var_to_diff)) fill!(out_vars, nothing) out_vars[1:length(fullvars)] .= fullvars @@ -31,8 +31,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) out_vars[diff] = D(vi) end - d_dict = Dict(zip(fullvars, 1:length(fullvars))) - lhss = Set{Any}([x.lhs for x in in_eqs if isdiffeq(x)]) + d_dict = Dict{SymbolicT, Int}(zip(fullvars, 1:length(fullvars))) for (eqidx, diff) in edges(eq_to_diff) # LHS variable is looked up from var_to_diff # the var_to_diff[i]-th variable is the differentiated version of var at i From 055eedfbadaf3660bc8217fbe41755e69283a9ff Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Oct 2025 14:48:33 +0530 Subject: [PATCH 066/157] fix: improve type-stability of `tearing_reassemble` --- .../symbolics_tearing.jl | 160 ++++++++++-------- 1 file changed, 90 insertions(+), 70 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index ad3e889dff..1a5ddd55f8 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -204,7 +204,7 @@ State selection is done. All non-differentiated variables are algebraic variables, and all variables that appear differentiated are differential variables. """ function substitute_derivatives_algevars!( - ts::TearingState, neweqs, var_eq_matching, dummy_sub; iv = nothing, D = nothing) + ts::TearingState, neweqs::Vector{Equation}, var_eq_matching::Matching, dummy_sub::Dict{SymbolicT, SymbolicT}, iv::Union{Nothing, SymbolicT}, D::Union{Nothing, Differential, Shift}) @unpack fullvars, sys, structure = ts @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure diff_to_var = invview(var_to_diff) @@ -214,7 +214,7 @@ function substitute_derivatives_algevars!( dv === nothing && continue if var_eq_matching[var] !== SelectedState() dd = fullvars[dv] - v_t = setio(diff2term_with_unit(unwrap(dd), unwrap(iv)), false, false) + v_t = setio(diff2term_with_unit(dd, iv), false, false) for eq in 𝑑neighbors(graph, dv) dummy_sub[dd] = v_t neweqs[eq] = substitute(neweqs[eq], dd => v_t) @@ -326,17 +326,20 @@ Effects on the system structure: """ function generate_derivative_variables!( ts::TearingState, neweqs, var_eq_matching, full_var_eq_matching, - var_sccs; mm, iv = nothing, D = nothing) + var_sccs, mm::SparseMatrixCLIL{T, Int}, iv::Union{SymbolicT, Nothing}) where {T} @unpack fullvars, sys, structure = ts @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure eq_var_matching = invview(var_eq_matching) diff_to_var = invview(var_to_diff) is_discrete = is_only_discrete(structure) - linear_eqs = mm === nothing ? Dict{Int, Int}() : - Dict(reverse(en) for en in enumerate(mm.nzrows)) + linear_eqs = Dict{Int, Int}() + for (i, e) in enumerate(mm.nzrows) + linear_eqs[e] = i + end # We need the inverse mapping of `var_sccs` to update it efficiently later. - v_to_scc = Vector{NTuple{2, Int}}(undef, ndsts(graph)) + v_to_scc = NTuple{2, Int}[] + resize!(v_to_scc, ndsts(graph)) for (i, scc) in enumerate(var_sccs), (j, v) in enumerate(scc) v_to_scc[v] = (i, j) @@ -424,9 +427,7 @@ function generate_derivative_variables!( end new_sccs = insert_sccs(var_sccs, sccs_to_insert) - if mm !== nothing - @set! mm.ncols = ndsts(graph) - end + @set! mm.ncols = ndsts(graph) return new_sccs end @@ -539,26 +540,28 @@ Reorder the equations and unknowns to be in the BLT sorted form. Return the new equations, the solved equations, the new orderings, and the number of solved variables and equations. """ -function generate_system_equations!(state::TearingState, neweqs, var_eq_matching, - full_var_eq_matching, var_sccs, extra_eqs_vars; - simplify = false, iv = nothing, D = nothing) +function generate_system_equations!(state::TearingState, neweqs::Vector{Equation}, + var_eq_matching::Matching, full_var_eq_matching::Matching, + var_sccs::Vector{Vector{Int}}, extra_eqs_vars::NTuple{2, Vector{Int}}, + iv::Union{SymbolicT, Nothing}, D::Union{Differential, Shift, Nothing}; + simplify::Bool = false) @unpack fullvars, sys, structure = state @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure eq_var_matching = invview(var_eq_matching) - full_eq_var_matching = invview(full_var_eq_matching) diff_to_var = invview(var_to_diff) extra_eqs, extra_vars = extra_eqs_vars - total_sub = Dict() + total_sub = Dict{SymbolicT, SymbolicT}() if is_only_discrete(structure) for (i, v) in enumerate(fullvars) - op = operation(v) - op isa Shift && (op.steps < 0) && - begin + Moshi.Match.@match v begin + BSImpl.Term(; f) && if f isa Shift && f.steps < 0 end => begin lowered = lower_shift_varname_with_unit(v, iv) total_sub[v] = lowered fullvars[i] = lowered end + _ => nothing + end end end @@ -579,13 +582,11 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching end digraph = DiCMOBiGraph{false}(graph, var_eq_matching) - idep = iv for (i, scc) in enumerate(var_sccs) # note that the `vscc <-> escc` relation is a set-to-set mapping, and not # point-to-point. vscc, escc = get_sorted_scc(digraph, full_var_eq_matching, var_eq_matching, scc) var_sccs[i] = vscc - if length(escc) != length(vscc) isempty(escc) && continue escc = setdiff(escc, extra_eqs) @@ -594,11 +595,10 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching isempty(vscc) && continue end - offset = 1 for ieq in escc iv = eq_var_matching[ieq] - eq = neweqs[ieq] - codegen_equation!(eq_generator, neweqs[ieq], ieq, iv; simplify) + neq = neweqs[ieq] + codegen_equation!(eq_generator, neq, ieq, iv; simplify) end end @@ -620,14 +620,16 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching solved_vars_set = BitSet(solved_vars) # We filled zeros for algebraic variables, so fill them properly here offset = 1 + findnextfn = let diff_vars_set = diff_vars_set, solved_vars_set = solved_vars_set, + diff_to_var = diff_to_var, ispresent = ispresent + j -> !(j in diff_vars_set || j in solved_vars_set) && diff_to_var[j] === nothing && + ispresent(j) + end for (i, v) in enumerate(var_ordering) v == 0 || continue # find the next variable which is not differential or solved, is not the # derivative of another variable and is present in the equations - index = findnext(1:ndsts(graph), offset) do j - !(j in diff_vars_set || j in solved_vars_set) && diff_to_var[j] === nothing && - ispresent(j) - end + index = findnext(findnextfn, 1:ndsts(graph), offset) # in case of overdetermined systems, this may not be present index === nothing && break var_ordering[i] = index @@ -649,11 +651,20 @@ Sort the provided SCC `scc`, given the `digraph` of the system constructed using function get_sorted_scc( digraph::DiCMOBiGraph, full_var_eq_matching::Matching, var_eq_matching::Matching, scc::Vector{Int}) eq_var_matching = invview(var_eq_matching) - full_eq_var_matching = invview(full_var_eq_matching) # obtain the matched equations in the SCC - scc_eqs = Int[full_var_eq_matching[v] for v in scc if full_var_eq_matching[v] isa Int] + scc_eqs = Int[] # obtain the equations in the SCC that are linearly solvable - scc_solved_eqs = Int[var_eq_matching[v] for v in scc if var_eq_matching[v] isa Int] + scc_solved_eqs = Int[] + for v in scc + e = full_var_eq_matching[v] + if e isa Int + push!(scc_eqs, e) + end + e = var_eq_matching[v] + if e isa Int + push!(scc_solved_eqs, e) + end + end # obtain the subgraph of the contracted graph involving the solved equations subgraph, varmap = Graphs.induced_subgraph(digraph, scc_solved_eqs) # topologically sort the solved equations and append the remainder @@ -661,7 +672,13 @@ function get_sorted_scc( setdiff(scc_eqs, scc_solved_eqs)] # the variables of the SCC are obtained by inverse mapping the sorted equations # and appending the rest - scc_vars = [eq_var_matching[e] for e in scc_eqs if eq_var_matching[e] isa Int] + scc_vars = Int[] + for e in scc_eqs + v = eq_var_matching[e] + if v isa Int + push!(scc_vars, v) + end + end append!(scc_vars, setdiff(scc, scc_vars)) return scc_vars, scc_eqs end @@ -672,7 +689,7 @@ end Struct containing the information required to generate equations of a system, as well as the generated equations and associated metadata. """ -struct EquationGenerator{S, D, I} +struct EquationGenerator{S} """ `TearingState` of the system. """ @@ -681,15 +698,15 @@ struct EquationGenerator{S, D, I} Substitutions to perform in all subsequent equations. For each differential equation `D(x) ~ f(..)`, the substitution `D(x) => f(..)` is added to the rules. """ - total_sub::Dict{Any, Any} + total_sub::Dict{SymbolicT, SymbolicT} """ The differential operator, or `nothing` if not applicable. """ - D::D + D::Union{Differential, Shift, Nothing} """ The independent variable, or `nothing` if not applicable. """ - idep::I + idep::Union{SymbolicT, Nothing} """ The new generated equations of the system. """ @@ -829,19 +846,17 @@ Generate a first-order differential equation whose LHS is `dx`. `var` and `dx` represent the same variable, but `var` may be a higher-order differential and `dx` is always first-order. For example, if `var` is D(D(x)), then `dx` would be `D(x_t)`. Solve `eq` for `var`, substitute previously solved variables, and return the differential equation. """ function make_differential_equation(var, dx, eq, total_sub) - dx ~ simplify_shifts(Symbolics.fixpoint_sub( - Symbolics.symbolic_linear_solve(eq, var), - total_sub; operator = ModelingToolkit.Shift)) + v1 = Symbolics.symbolic_linear_solve(eq, var)::SymbolicT + v2 = Symbolics.fixpoint_sub(v1, total_sub; operator = ModelingToolkit.Shift) + v3 = simplify_shifts(v2) + dx ~ v3 end """ Generate an algebraic equation. Substitute solved variables into `eq` and return the equation. """ function make_algebraic_equation(eq, total_sub) - rhs = eq.rhs - if !(eq.lhs isa Number && eq.lhs == 0) - rhs = eq.rhs - eq.lhs - end + rhs = eq.rhs - eq.lhs 0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub)) end @@ -910,7 +925,7 @@ function reorder_vars!(state::TearingState, var_eq_matching, var_sccs, eq_orderi var_ordering_set = BitSet(var_ordering) for scc in var_sccs # Map variables to their new indices - map!(v -> varsperm[v], scc, scc) + map!(Base.Fix1(getindex, varsperm), scc, scc) # Remove variables not in the reduced set filter!(!iszero, scc) end @@ -929,16 +944,20 @@ end Update the system equations, unknowns, and observables after simplification. """ function update_simplified_system!( - state::TearingState, neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns; - array_hack = true, D = nothing, iv = nothing) + state::TearingState, neweqs::Vector{Equation}, solved_eqs::Vector{Equation}, + dummy_sub::Dict{SymbolicT, SymbolicT}, var_sccs::Vector{Vector{Int}}, + extra_unknowns::Vector{SymbolicT}, iv::Union{SymbolicT, Nothing}, + D::Union{Differential, Shift, Nothing}; array_hack = true) @unpack fullvars, structure = state @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure diff_to_var = invview(var_to_diff) # Since we solved the highest order derivative variable in discrete systems, # we make a list of the solved variables and avoid including them in the # unknowns. - solved_vars = Set() + solved_vars = Set{SymbolicT}() if is_only_discrete(structure) + iv = iv::SymbolicT + D = D::Shift for eq in solved_eqs var = eq.lhs if isequal(eq.lhs, eq.rhs) @@ -964,22 +983,25 @@ function update_simplified_system!( obs = [substitute(observed(sys), obs_sub); solved_eqs; substitute(state.additional_observed, obs_sub)] - unknown_idxs = filter( - i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars)) + filterer = let diff_to_var = diff_to_var, ispresent = ispresent, fullvars = fullvars, + solved_vars = solved_vars + i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars) + end + unknown_idxs = filter(filterer, eachindex(state.fullvars)) unknowns = state.fullvars[unknown_idxs] unknowns = [unknowns; extra_unknowns] if is_only_discrete(structure) # Algebraic variables are shifted forward by one, so we backshift them. - unknowns = map(enumerate(unknowns)) do (i, var) - if iscall(var) && operation(var) isa Shift && operation(var).steps == 1 - # We might have shifted a variable with io metadata. That is irrelevant now - # because we handled io variables earlier in `_mtkcompile!` so just ignore - # it here. - setio(backshift_expr(var, iv), false, false) - else - var + _unknowns = SymbolicT[] + for var in unknowns + Moshi.Match.@match var begin + BSImpl.Term(; f, args, type, shape, metadata) && if f isa Shift && f.steps == 1 end => begin + push!(_unknowns, setio(args[1], false, false)) + end + _ => push!(_unknowns, var) end end + unknowns = _unknowns end @set! sys.unknowns = unknowns @@ -1044,7 +1066,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching, extra_eqs_vars = get_extra_eqs_vars( state, var_eq_matching, full_var_eq_matching, fully_determined) neweqs = collect(equations(state)) - dummy_sub = Dict() + dummy_sub = Dict{SymbolicT, SymbolicT}() if ModelingToolkit.has_iv(state.sys) iv = get_iv(state.sys) @@ -1060,29 +1082,28 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching, extra_unknowns = state.fullvars[extra_eqs_vars[2]] if is_only_discrete(state.structure) var_sccs = add_additional_history!( - state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs; iv, D) + state, var_eq_matching, full_var_eq_matching, var_sccs, iv) end # Structural simplification - substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub; iv, D) + substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub, iv, D) var_sccs = generate_derivative_variables!( - state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs; mm, iv, D) - + state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs, mm, iv) neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var = generate_system_equations!( state, neweqs, var_eq_matching, full_var_eq_matching, - var_sccs, extra_eqs_vars; simplify, iv, D) + var_sccs, extra_eqs_vars, iv, D; simplify) state = reorder_vars!( state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var) # var_eq_matching and full_var_eq_matching are now invalidated sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs, - extra_unknowns; array_hack, iv, D) + extra_unknowns, iv, D; array_hack) @set! state.sys = sys @set! sys.tearing_state = state @@ -1120,24 +1141,23 @@ x(k) ~ x(k-1) + x(k-2) Where the last equation is the observed equation. """ function add_additional_history!( - state::TearingState, neweqs::Vector, var_eq_matching::Matching, - full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; iv, D) + state::TearingState, var_eq_matching::Matching, + full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}, iv::Union{SymbolicT, Nothing}) + iv === nothing && return var_sccs + iv = iv::SymbolicT @unpack fullvars, sys, structure = state @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure - eq_var_matching = invview(var_eq_matching) diff_to_var = invview(var_to_diff) - is_discrete = is_only_discrete(structure) - digraph = DiCMOBiGraph{false}(graph, var_eq_matching) # We need the inverse mapping of `var_sccs` to update it efficiently later. - v_to_scc = Vector{NTuple{2, Int}}(undef, ndsts(graph)) + v_to_scc = NTuple{2, Int}[] + resize!(v_to_scc, ndsts(graph)) for (i, scc) in enumerate(var_sccs), (j, v) in enumerate(scc) v_to_scc[v] = (i, j) end vars_to_backshift = BitSet() - eqs_to_backshift = BitSet() # add history for differential variables for ivar in 1:length(fullvars) ieq = var_eq_matching[ivar] @@ -1243,7 +1263,7 @@ occurs in observed equations (and unknowns if it's split). function tearing_hacks(sys, obs, unknowns, neweqs; array = true) # map of array observed variable (unscalarized) to number of its # scalarized terms that appear in observed equations - arr_obs_occurrences = Dict() + arr_obs_occurrences = Dict{SymbolicT, Int}() for (i, eq) in enumerate(obs) lhs = eq.lhs rhs = eq.rhs From 25a411637f66dddfe3a0676d02795716f050b99b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Oct 2025 14:48:57 +0530 Subject: [PATCH 067/157] fix: improve type-stability of structural transformation utils --- src/structural_transformation/utils.jl | 70 +++++++++++++++++--------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index f6ff669c0e..b2ba2059c3 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -243,24 +243,36 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no all_int_vars = true coeffs === nothing || empty!(coeffs) empty!(to_rm) + __indexed_fullvar_is_var = let fullvars = fullvars + function indexed_fullvar_is_var(x::SymbolicT) + for v in fullvars + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f === getindex && isequal(args[1], x) end => return true + _ => nothing + end + end + return false + end + end + __allow_sym_par_cond = let fullvars = fullvars, is_atomic = ModelingToolkit.OperatorIsAtomic{Union{Differential, Shift, Pre, Sample, Hold, Initial}}(), __indexed_fullvar_is_var = __indexed_fullvar_is_var + function allow_sym_par_cond(v) + is_atomic(v) && any(isequal(v), fullvars) || + symbolic_type(v) == ArraySymbolic() && (SU.shape(v) isa SU.Unknown || + __indexed_fullvar_is_var(v)) + end + end for j in 𝑠neighbors(graph, ieq) var = fullvars[j] isirreducible(var) && (all_int_vars = false; continue) a, b, islinear = linear_expansion(term, var) - a, b = unwrap(a), unwrap(b) + islinear || (all_int_vars = false; continue) if a isa SymbolicT all_int_vars = false if !allow_symbolic if allow_parameter # if any of the variables in `a` are present in fullvars (taking into account arrays) - if any( - v -> any(isequal(v), fullvars) || - symbolic_type(v) == ArraySymbolic() && - SU.shape(v) isa SU.Unknown || - any(x -> any(isequal(x), fullvars), collect(v)), - vars( - a; op = Union{Differential, Shift, Pre, Sample, Hold, Initial})) + if SU.query(__allow_sym_par_cond, a) continue end else @@ -559,27 +571,35 @@ function isdoubleshift(var) ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift) end +simplify_shifts(eq::Equation) = simplify_shifts(eq.lhs) ~ simplify_shifts(eq.rhs) + +function _simplify_shifts(var::SymbolicT) + Moshi.Match.@match var begin + BSImpl.Term(; f, args) && if f isa Shift && f.steps == 0 end => return args[1] + BSImpl.Term(; f = op1, args) && if op1 isa Shift end => begin + vv1 = args[1] + Moshi.Match.@match vv1 begin + BSImpl.Term(; f = op2, args = a2) && if op2 isa Shift end => begin + vv2 = a2[1] + s1 = op1.steps + s2 = op2.steps + t1 = op1.t + t2 = op2.t + return simplify_shifts(ModelingToolkit.Shift(t1 === nothing ? t2 : t1, s1 + s2)(vv2)) + end + _ => return var + end + end + _ => var + end +end + """ Simplify multiple shifts: Shift(t, k1)(Shift(t, k2)(x)) becomes Shift(t, k1+k2)(x). """ -function simplify_shifts(var) +function simplify_shifts(var::SymbolicT) ModelingToolkit.hasshift(var) || return var - var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs) - (op = operation(var)) isa Shift && op.steps == 0 && return first(arguments(var)) - if isdoubleshift(var) - op1 = operation(var) - vv1 = arguments(var)[1] - op2 = operation(vv1) - vv2 = arguments(vv1)[1] - s1 = op1.steps - s2 = op2.steps - t1 = op1.t - t2 = op2.t - return simplify_shifts(ModelingToolkit.Shift(t1 === nothing ? t2 : t1, s1 + s2)(vv2)) - else - return maketerm(typeof(var), operation(var), simplify_shifts.(arguments(var)), - unwrap(var).metadata) - end + return SU.Rewriters.Postwalk(_simplify_shifts)(var) end """ From e5132c8aa1d63fa9d8d8b47f39a031c7aa23e3dc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Oct 2025 14:49:06 +0530 Subject: [PATCH 068/157] fix: improve type-stability of alias elimination --- .../symbolics_tearing.jl | 3 +- src/systems/alias_elimination.jl | 63 ++++++++++++------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 1a5ddd55f8..d3f9a4c5da 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -69,7 +69,8 @@ function eq_derivative!(ts::TearingState, ieq::Int; kwargs...) ModelingToolkit.derivative( eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true), ts.param_derivative_map) - vs = ModelingToolkit.vars(eq.rhs) + vs = Set{SymbolicT}() + SU.search_variables!(vs, eq.rhs) for v in vs # parameters with unknown derivatives have a value of `nothing` in the map, # so use `missing` as the default. diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 03878e6847..2a35f59458 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -10,10 +10,12 @@ function alias_eliminate_graph!(state::TransformationState; kwargs...) @unpack graph, var_to_diff, solvable_graph = state.structure mm = alias_eliminate_graph!(state, mm; kwargs...) s = state.structure - for g in (s.graph, s.solvable_graph) - g === nothing && continue + for (ei, e) in enumerate(mm.nzrows) + set_neighbors!(s.graph, e, mm.row_cols[ei]) + end + if s.solvable_graph isa BipartiteGraph{Int, Nothing} for (ei, e) in enumerate(mm.nzrows) - set_neighbors!(g, e, mm.row_cols[ei]) + set_neighbors!(s.solvable_graph, e, mm.row_cols[ei]) end end @@ -46,14 +48,12 @@ alias_elimination(sys) = alias_elimination!(TearingState(sys))[1] function alias_elimination!(state::TearingState; kwargs...) sys = state.sys complete!(state.structure) - graph_orig = copy(state.structure.graph) mm = alias_eliminate_graph!(state; kwargs...) fullvars = state.fullvars @unpack var_to_diff, graph, solvable_graph = state.structure - subs = Dict() - obs = Equation[] + subs = Dict{SymbolicT, SymbolicT}() # If we encounter y = -D(x), then we need to expand the derivative when # D(y) appears in the equation, so that D(-D(x)) becomes -D(D(x)). to_expand = Int[] @@ -62,17 +62,21 @@ function alias_elimination!(state::TearingState; kwargs...) dels = Int[] eqs = collect(equations(state)) resize!(eqs, nsrcs(graph)) + + __trivial_eq_rhs = let fullvars = fullvars + function trivial_eq_rhs(var, coeff) + iszero(coeff) && return Symbolics.COMMON_ZERO + return coeff * fullvars[var] + end + end for (ei, e) in enumerate(mm.nzrows) vs = 𝑠neighbors(graph, e) if isempty(vs) # remove empty equations push!(dels, e) else - rhs = mapfoldl(+, pairs(nonzerosmap(@view mm[ei, :]))) do (var, coeff) - iszero(coeff) && return 0 - return coeff * fullvars[var] - end - eqs[e] = 0 ~ rhs + rhs = mapfoldl(__trivial_eq_rhs, +, pairs(nonzerosmap(@view mm[ei, :]))) + eqs[e] = Symbolics.COMMON_ZERO ~ rhs end end deleteat!(eqs, sort!(dels)) @@ -92,21 +96,22 @@ function alias_elimination!(state::TearingState; kwargs...) n_new_eqs = idx - lineqs = BitSet(mm.nzrows) eqs_to_update = BitSet() - nvs_orig = ndsts(graph_orig) for ieq in eqs_to_update eq = eqs[ieq] eqs[ieq] = substitute(eq, subs) end - @set! mm.nparentrows = nsrcs(graph) - @set! mm.row_cols = eltype(mm.row_cols)[mm.row_cols[i] - for (i, eq) in enumerate(mm.nzrows) - if old_to_new_eq[eq] > 0] - @set! mm.row_vals = eltype(mm.row_vals)[mm.row_vals[i] - for (i, eq) in enumerate(mm.nzrows) - if old_to_new_eq[eq] > 0] - @set! mm.nzrows = Int[old_to_new_eq[eq] for eq in mm.nzrows if old_to_new_eq[eq] > 0] + new_nparentrows = nsrcs(graph) + new_row_cols = eltype(mm.row_cols)[] + new_row_vals = eltype(mm.row_vals)[] + new_nzrows = Int[] + for (i, eq) in enumerate(mm.nzrows) + old_to_new_eq[eq] > 0 || continue + push!(new_row_cols, mm.row_cols[i]) + push!(new_row_vals, mm.row_vals[i]) + push!(new_nzrows, old_to_new_eq[eq]) + end + mm = typeof(mm)(new_nparentrows, mm.ncols, new_nzrows, new_row_cols, new_row_vals) for old_ieq in to_expand ieq = old_to_new_eq[old_ieq] @@ -138,7 +143,13 @@ function alias_elimination!(state::TearingState; kwargs...) sys = state.sys @set! sys.eqs = eqs state.sys = sys - return invalidate_cache!(sys), mm + # This phrasing infers the return type as `Union{Tuple{...}}` instead of + # `Tuple{Union{...}, ...}` + if mm isa SparseMatrixCLIL{BigInt, Int} + return invalidate_cache!(sys), mm + else + return invalidate_cache!(sys), mm + end end """ @@ -301,7 +312,13 @@ function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti} bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff) end - return mm, solvable_variables, bar + # This phrasing infers the return type as `Union{Tuple{...}}` instead of + # `Tuple{Union{...}, ...}` + if mm isa SparseMatrixCLIL{BigInt, Ti} + return mm, solvable_variables, bar + else + return mm, solvable_variables, bar + end end function do_bareiss!(M, Mold, is_linear_variables, is_highest_diff) From 27c893ad8856b4b0e0b994aa20ccde5827b1db53 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Oct 2025 14:49:13 +0530 Subject: [PATCH 069/157] fix: improve type-stability of clock inference --- src/systems/clock_inference.jl | 64 +++++++++++++--------------------- 1 file changed, 25 insertions(+), 39 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index a88e8c42fe..36650723b2 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -58,17 +58,24 @@ struct NotInferredTimeDomain end function error_sample_time(eq) error("$eq\ncontains `SampleTime` but it is not an Inferred discrete equation.") end -function substitute_sample_time(ci::ClockInference, ts::TearingState) +function substitute_sample_time(ci::ClockInference{T}, ts::T) where {T <: TearingState} @unpack eq_domain = ci eqs = copy(equations(ts)) @assert length(eqs) == length(eq_domain) + subrules = Dict{SymbolicT, SymbolicT}() + st = SampleTime() for i in eachindex(eqs) eq = eqs[i] domain = eq_domain[i] - dt = sampletime(domain) - neweq = substitute_sample_time(eq, dt) - if neweq isa NotInferredTimeDomain - error_sample_time(eq) + dt = SU.Const{VartypeT}(sampletime(domain)) + if dt === COMMON_NOTHING + if SU.query(isequal(st), eq.lhs) || SU.query(isequal(st), eq.rhs) + error_sample_time(eq) + end + neweq = eq + else + subrules[st] = dt + neweq = substitute(eq, subrules) end eqs[i] = neweq end @@ -76,30 +83,6 @@ function substitute_sample_time(ci::ClockInference, ts::TearingState) @set! ci.ts = ts end -function substitute_sample_time(eq::Equation, dt) - substitute_sample_time(eq.lhs, dt) ~ substitute_sample_time(eq.rhs, dt) -end - -function substitute_sample_time(ex, dt) - iscall(ex) || return ex - op = operation(ex) - args = arguments(ex) - if op == SampleTime - dt === nothing && return NotInferredTimeDomain() - return dt - else - new_args = similar(args) - for (i, arg) in enumerate(args) - ex_arg = substitute_sample_time(arg, dt) - if ex_arg isa NotInferredTimeDomain - return ex_arg - end - new_args[i] = ex_arg - end - maketerm(typeof(ex), op, new_args, metadata(ex)) - end -end - """ Update the equation-to-time domain mapping by inferring the time domain from the variables. """ @@ -109,7 +92,7 @@ function infer_clocks!(ci::ClockInference) fullvars = get_fullvars(ts) isempty(inferred) && return ci - var_to_idx = Dict(fullvars .=> eachindex(fullvars)) + var_to_idx = Dict{SymbolicT, Int}(fullvars .=> eachindex(fullvars)) # all shifted variables have the same clock as the unshifted variant for (i, v) in enumerate(fullvars) @@ -122,9 +105,9 @@ function infer_clocks!(ci::ClockInference) # preallocated buffers: # variables in each equation - varsbuf = Set() + varsbuf = Set{SymbolicT}() # variables in each argument to an operator - arg_varsbuf = Set() + arg_varsbuf = Set{SymbolicT}() # hyperedge for each equation hyperedge = Set{ClockVertex.Type}() # hyperedge for each argument to an operator @@ -136,7 +119,7 @@ function infer_clocks!(ci::ClockInference) empty!(varsbuf) empty!(hyperedge) # get variables in equation - vars!(varsbuf, eq; op = Symbolics.Operator) + SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{Symbolics.Operator}()) # add the equation to the hyperedge eq_node = if is_initialization_equation ClockVertex.InitEquation(ieq) @@ -155,14 +138,17 @@ function infer_clocks!(ci::ClockInference) # now we only care about synchronous operators iscall(var) || continue op = operation(var) - is_timevarying_operator(op) || continue + is_timevarying_operator(op)::Bool || continue # arguments and corresponding time domains args = arguments(var) tdomains = input_timedomain(op) - if !(tdomains isa AbstractArray || tdomains isa Tuple) - tdomains = [tdomains] + if tdomains isa Tuple + tdomains = Vector{InputTimeDomainElT}(collect(tdomains)) + elseif !(tdomains isa Vector{InputTimeDomainElT}) + tdomains = InputTimeDomainElT[tdomains] end + tdomains = tdomains::Vector{InputTimeDomainElT} nargs = length(args) ndoms = length(tdomains) if nargs != ndoms @@ -178,7 +164,7 @@ function infer_clocks!(ci::ClockInference) empty!(arg_varsbuf) empty!(arg_hyperedge) # get variables in argument - vars!(arg_varsbuf, arg; op = Union{Differential, Shift}) + SU.search_variables!(arg_varsbuf, arg; is_atomic = OperatorIsAtomic{Union{Differential, Shift}}()) # get hyperedge for involved variables for v in arg_varsbuf vidx = get(var_to_idx, v, nothing) @@ -200,7 +186,7 @@ function infer_clocks!(ci::ClockInference) # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't # add the edge, and instead add this to the `relative_hyperedges` mapping. InferredClock.InferredDiscrete(i) => begin - relative_edge = get!(() -> Set{ClockVertex.Type}(), relative_hyperedges, i) + relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i) union!(relative_edge, arg_hyperedge) end end @@ -237,7 +223,7 @@ function infer_clocks!(ci::ClockInference) clock_partitions = connectionsets(inference_graph) for partition in clock_partitions - clockidxs = findall(vert -> Moshi.Data.isa_variant(vert, ClockVertex.Clock), partition) + clockidxs = findall(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.Clock), partition) if isempty(clockidxs) push!(partition, ClockVertex.Clock(ContinuousClock())) push!(clockidxs, length(partition)) From e8d3e875e61e59df6ee1e2865be52ddef463644a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Oct 2025 14:49:25 +0530 Subject: [PATCH 070/157] fix: improve type-stability of `Schedule` struct --- src/systems/system.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/system.jl b/src/systems/system.jl index e364847009..d3f5ade17b 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -3,7 +3,7 @@ struct Schedule """ Mapping of `Differential`s of variables to corresponding derivative expressions. """ - dummy_sub::Dict{Any, Any} + dummy_sub::Dict{SymbolicT, SymbolicT} end const MetadataT = Base.ImmutableDict{DataType, Any} From ed4a4316fbbeb18ca23871281f30d67b7797554c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Oct 2025 14:50:20 +0530 Subject: [PATCH 071/157] fix: improve type-stability of `\itdneighbors` --- src/bipartite_graph.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/bipartite_graph.jl b/src/bipartite_graph.jl index 6e4f359617..0fab095928 100644 --- a/src/bipartite_graph.jl +++ b/src/bipartite_graph.jl @@ -375,7 +375,8 @@ end function 𝑑neighbors(g::BipartiteGraph, j::Integer, with_metadata::Val{M} = Val(false)) where {M} require_complete(g) - M ? zip(g.badjlist[j], (g.metadata[i][j] for i in g.badjlist[j])) : g.badjlist[j] + backj = g.badjlist[j]::Vector{Int} + M ? zip(backj, (g.metadata[i][j] for i in backj)) : backj end Graphs.ne(g::BipartiteGraph) = g.ne Graphs.nv(g::BipartiteGraph) = sum(length, vertices(g)) From 9c7522bd47ee9eb47647a9fe7a8d140fc0928156 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Oct 2025 14:50:49 +0530 Subject: [PATCH 072/157] fix: improve type-stability of IO utils --- src/inputoutput.jl | 18 +++++++++--------- src/linearization.jl | 25 ++++++++++++++++++------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 5afc6b2000..adbc6de6d4 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -264,7 +264,7 @@ end """ Turn input variables into parameters of the system. """ -function inputs_to_parameters!(state::TransformationState, inputsyms) +function inputs_to_parameters!(state::TransformationState, inputsyms::Vector{SymbolicT}) check_bound = inputsyms === nothing @unpack structure, fullvars, sys = state @unpack var_to_diff, graph, solvable_graph = structure @@ -274,9 +274,9 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) var_reidx = zeros(Int, length(fullvars)) ninputs = 0 nvar = 0 - new_parameters = [] - input_to_parameters = Dict() - new_fullvars = [] + new_parameters = SymbolicT[] + input_to_parameters = Dict{SymbolicT, SymbolicT}() + new_fullvars = SymbolicT[] for (i, v) in enumerate(fullvars) if isinput(v) && !(check_bound && is_bound(sys, v)) if var_to_diff[i] !== nothing @@ -295,8 +295,8 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) end end if ninputs == 0 - @set! sys.inputs = OrderedSet{BasicSymbolic}() - @set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars)) + @set! sys.inputs = OrderedSet{SymbolicT}() + @set! sys.outputs = OrderedSet{SymbolicT}(filter(isoutput, fullvars)) state.sys = sys return state end @@ -329,10 +329,10 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) ps = parameters(sys) @set! sys.ps = [ps; new_parameters] - @set! sys.inputs = OrderedSet{BasicSymbolic}(new_parameters) - @set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars)) + @set! sys.inputs = OrderedSet{SymbolicT}(new_parameters) + @set! sys.outputs = OrderedSet{SymbolicT}(filter(isoutput, fullvars)) @set! state.sys = sys - @set! state.fullvars = Vector{BasicSymbolic}(new_fullvars) + @set! state.fullvars = Vector{SymbolicT}(new_fullvars) @set! state.structure = structure return state end diff --git a/src/linearization.jl b/src/linearization.jl index 3c11484d61..aa8c93df94 100644 --- a/src/linearization.jl +++ b/src/linearization.jl @@ -611,13 +611,24 @@ end """ Modify the variable metadata of system variables to indicate which ones are inputs, outputs, and disturbances. Needed for `inputs`, `outputs`, `disturbances`, `unbound_inputs`, `unbound_outputs` to return the proper subsets. """ -function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true) +function markio!(state, orig_inputs::Set{SymbolicT}, + inputs::Vector{SymbolicT}, outputs::Vector{SymbolicT}, + disturbances::Vector{SymbolicT}; check = true) fullvars = get_fullvars(state) - inputset = Dict{Any, Bool}(i => false for i in inputs) - outputset = Dict{Any, Bool}(o => false for o in outputs) - disturbanceset = Dict{Any, Bool}(d => false for d in disturbances) + inputset = Dict{SymbolicT, Bool}() + for i in inputs + inputset[i] = false + end + outputset = Dict{SymbolicT, Bool}() + for o in outputs + outputset[o] = false + end + disturbanceset = Dict{SymbolicT, Bool}() + for d in disturbances + disturbanceset[d] = false + end for (i, v) in enumerate(fullvars) - if v in keys(inputset) + if haskey(inputset, v) if v in keys(outputset) v = setio(v, true, true) outputset[v] = true @@ -626,7 +637,7 @@ function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true end inputset[v] = true fullvars[i] = v - elseif v in keys(outputset) + elseif haskey(outputset, v) v = setio(v, false, true) outputset[v] = true fullvars[i] = v @@ -638,7 +649,7 @@ function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true fullvars[i] = v end - if v in keys(disturbanceset) + if haskey(disturbanceset, v) v = setio(v, true, false) v = setdisturbance(v, true) disturbanceset[v] = true From 119565b525db7edf224fe24a159e96b4dc428228 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 7 Oct 2025 14:51:27 +0530 Subject: [PATCH 073/157] fix: improve type-stability of `mtkcompile` --- src/systems/systems.jl | 35 ++++++++++++------------------ src/systems/systemstructure.jl | 39 ++++++++++++++++++++++------------ 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index f49c901867..9b0d469534 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -28,40 +28,31 @@ present in the equations of the system will be removed in this process. + `sort_eqs=true` controls whether equations are sorted lexicographically before simplification or not. """ function mtkcompile( - sys::AbstractSystem; additional_passes = [], simplify = false, split = true, + sys::System; additional_passes = (), simplify = false, split = true, allow_symbolic = false, allow_parameter = true, conservative = false, fully_determined = true, - inputs = Any[], outputs = Any[], - disturbance_inputs = Any[], + inputs = SymbolicT[], outputs = SymbolicT[], + disturbance_inputs = SymbolicT[], kwargs...) isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) - newsys′ = __mtkcompile(sys; simplify, + # Canonicalize types of arguments to prevent repeated compilation of inner methods + inputs = unwrap_vars(inputs) + outputs = unwrap_vars(outputs) + disturbance_inputs = unwrap_vars(disturbance_inputs) + newsys = __mtkcompile(sys; simplify, allow_symbolic, allow_parameter, conservative, fully_determined, inputs, outputs, disturbance_inputs, additional_passes, kwargs...) - if newsys′ isa Tuple - @assert length(newsys′) == 2 - newsys = newsys′[1] - else - newsys = newsys′ - end for pass in additional_passes newsys = pass(newsys) end - if has_parent(newsys) - @set! newsys.parent = complete(sys; split = false, flatten = false) - end + @set! newsys.parent = complete(sys; split = false, flatten = false) newsys = complete(newsys; split) - if newsys′ isa Tuple - idxs = [parameter_index(newsys, i) for i in io[1]] - return newsys, idxs - else - return newsys - end + return newsys end function __mtkcompile(sys::AbstractSystem; simplify = false, - inputs = Any[], outputs = Any[], - disturbance_inputs = Any[], + inputs::Vector{SymbolicT} = SymbolicT[], outputs::Vector{SymbolicT} = SymbolicT[], + disturbance_inputs::Vector{SymbolicT} = SymbolicT[], sort_eqs = true, kwargs...) # TODO: convert noise_eqs to brownians for simplification @@ -72,7 +63,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, return sys end if isempty(equations(sys)) && !is_time_dependent(sys) && !_iszero(cost(sys)) - return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify) + return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify)::System end sys, statemachines = extract_top_level_statemachines(sys) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index d18a7984e6..45b483db63 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -1003,9 +1003,9 @@ function make_eqs_zero_equals!(ts::TearingState) end function mtkcompile!(state::TearingState; simplify = false, - check_consistency = true, fully_determined = true, warn_initialize_determined = true, - inputs = Any[], outputs = Any[], - disturbance_inputs = Any[], + check_consistency = true, fully_determined = true, + inputs = SymbolicT[], outputs = SymbolicT[], + disturbance_inputs = SymbolicT[], kwargs...) if !is_time_dependent(state.sys) return _mtkcompile!(state; simplify, check_consistency, @@ -1017,8 +1017,6 @@ function mtkcompile!(state::TearingState; simplify = false, # if it's continous keep going, if not then error unless given trait impl in additional passes ci = ModelingToolkit.ClockInference(state) ci = ModelingToolkit.infer_clocks!(ci) - time_domains = merge(Dict(state.fullvars .=> ci.var_domain), - Dict(default_toterm.(state.fullvars) .=> ci.var_domain)) tss, clocked_inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci) if !isempty(tss) && continuous_id == 0 # do a trait check here - handle fully discrete system @@ -1075,17 +1073,17 @@ function mtkcompile!(state::TearingState; simplify = false, end function _mtkcompile!(state::TearingState; simplify = false, - check_consistency = true, fully_determined = true, warn_initialize_determined = false, + check_consistency = true, fully_determined = true, dummy_derivative = true, - inputs = Any[], outputs = Any[], - disturbance_inputs = Any[], + inputs::Vector{SymbolicT} = SymbolicT[], outputs::Vector{SymbolicT} = SymbolicT[], + disturbance_inputs::Vector{SymbolicT} = SymbolicT[], kwargs...) if fully_determined isa Bool check_consistency &= fully_determined else check_consistency = true end - orig_inputs = Set() + orig_inputs = Set{SymbolicT}() ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs) state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs]) trivial_tearing!(state) @@ -1094,6 +1092,22 @@ function _mtkcompile!(state::TearingState; simplify = false, fully_determined = ModelingToolkit.check_consistency( state, orig_inputs; nothrow = fully_determined === nothing) end + # This phrasing avoids making the `kwcall` dynamic dispatch due to the type of a + # keyword (`mm`) being non-concrete + if mm isa SparseMatrixCLIL{BigInt, Int} + sys = _mtkcompile_worker!(state, sys, mm; fully_determined, dummy_derivative, simplify, kwargs...) + else + sys =_mtkcompile_worker!(state, sys, mm; fully_determined, dummy_derivative, simplify, kwargs...) + end + fullunknowns = [observables(sys); unknowns(sys)] + @set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullunknowns) + + ModelingToolkit.invalidate_cache!(sys) +end + +function _mtkcompile_worker!(state::TearingState{S}, sys::S, mm::SparseMatrixCLIL{T, Int}; + fully_determined::Bool, dummy_derivative::Bool, simplify::Bool, + kwargs...) where {S, T} if fully_determined && dummy_derivative sys = ModelingToolkit.dummy_derivative( sys, state; simplify, mm, check_consistency, kwargs...) @@ -1101,17 +1115,14 @@ function _mtkcompile!(state::TearingState; simplify = false, var_eq_matching = pantelides!(state; finalize = false, kwargs...) sys = pantelides_reassemble(state, var_eq_matching) state = TearingState(sys) - sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...) + sys, mm::SparseMatrixCLIL{T, Int} = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...) sys = ModelingToolkit.dummy_derivative( sys, state; simplify, mm, check_consistency, fully_determined, kwargs...) else sys = ModelingToolkit.tearing( sys, state; simplify, mm, check_consistency, fully_determined, kwargs...) end - fullunknowns = [observables(sys); unknowns(sys)] - @set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullunknowns) - - ModelingToolkit.invalidate_cache!(sys) + return sys end struct DifferentiatedVariableNotUnknownError <: Exception From 891fde3f8e36f82278d54a5ebf5935ca4487103e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 16:59:08 +0530 Subject: [PATCH 074/157] refactor: update old `get_variables!` methods to `search_variables!` --- src/problems/jumpproblem.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/problems/jumpproblem.jl b/src/problems/jumpproblem.jl index 113f5fc2f2..6d7eb01f24 100644 --- a/src/problems/jumpproblem.jl +++ b/src/problems/jumpproblem.jl @@ -194,17 +194,17 @@ function collect_vars!(unknowns, parameters, j::Union{ConstantRateJump, Variable end ### Functions to determine which unknowns a jump depends on -function get_variables!(dep, jump::Union{ConstantRateJump, VariableRateJump}, variables) - jr = value(jump.rate) - (jr isa SymbolicT) && get_variables!(dep, jr, variables) +function SU.search_variables!(dep, jump::Union{ConstantRateJump, VariableRateJump}; kw...) + jr = unwrap(jump.rate) + (jr isa SymbolicT) && SU.search_variables!(dep, jr; kw...) dep end -function get_variables!(dep, jump::MassActionJump, variables) - sr = value(jump.scaled_rates) - (sr isa SymbolicT) && get_variables!(dep, sr, variables) +function SU.search_variables!(dep, jump::MassActionJump; is_atomic = SU.default_is_atomic, kw...) + sr = unwrap(jump.scaled_rates) + (sr isa SymbolicT) && SU.search_variables!(dep, sr; kw...) for varasop in jump.reactant_stoch - any(isequal(varasop[1]), variables) && push!(dep, varasop[1]) + is_atomic(varasop[1]) && push!(dep, varasop[1]) end dep end From 357fbf84b14ede538305976c68845ae2134268a2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 16:59:56 +0530 Subject: [PATCH 075/157] refactor: remove dead code in `unhack_observed` --- src/systems/nonlinear/initializesystem.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index b318b7392a..68d5b8f26e 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -823,20 +823,12 @@ Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works w initialization. """ function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation}) - subs = Dict{SymbolicT, SymbolicT}() mask = trues(length(obseqs)) for (i, eq) in enumerate(obseqs) mask[i] = !iscall(eq.rhs) || operation(eq.rhs) !== StructuralTransformations.change_origin end obseqs = obseqs[mask] - for i in eachindex(obseqs) - obseqs[i] = fixpoint_sub(obseqs[i].lhs, subs) ~ fixpoint_sub(obseqs[i], subs) - end - eqs = copy(eqs) - for i in eachindex(eqs) - eqs[i] = fixpoint_sub(eqs[i].lhs, subs) ~ fixpoint_sub(eqs[i], subs) - end return obseqs, eqs end From beac316db5550c83c757baeb310bcb17160c3838 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:04:03 +0530 Subject: [PATCH 076/157] fix: improve type-stability of 2-arg `System` constructor --- src/systems/system.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/systems/system.jl b/src/systems/system.jl index d3f5ade17b..f8789d8938 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -529,11 +529,12 @@ other symbolic expressions passed to the system. function System(eqs::Vector{Equation}, iv; kwargs...) iv === nothing && return System(eqs; kwargs...) - diffvars = OrderedSet() - othervars = OrderedSet() - ps = Set() + diffvars = OrderedSet{SymbolicT}() + othervars = OrderedSet{SymbolicT}() + ps = OrderedSet{SymbolicT}() diffeqs = Equation[] othereqs = Equation[] + iv = unwrap(iv) for eq in eqs if !(eq.lhs isa Union{SymbolicT, Number, AbstractArray}) push!(othereqs, eq) @@ -562,7 +563,7 @@ function System(eqs::Vector{Equation}, iv; kwargs...) allunknowns = union(diffvars, othervars) eqs = [diffeqs; othereqs] - brownians = Set() + brownians = Set{SymbolicT}() for x in allunknowns x = unwrap(x) if getvariabletype(x) == BROWNIAN @@ -601,8 +602,8 @@ function System(eqs::Vector{Equation}, iv; kwargs...) noiseeqs = get(kwargs, :noise_eqs, nothing) if noiseeqs !== nothing # validate noise equations - noisedvs = OrderedSet() - noiseps = OrderedSet() + noisedvs = OrderedSet{SymbolicT}() + noiseps = OrderedSet{SymbolicT}() collect_vars!(noisedvs, noiseps, noiseeqs, iv) for dv in noisedvs dv ∈ allunknowns || From d07b0930a75e26054f9fddfed21d387f312ec98f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:08:28 +0530 Subject: [PATCH 077/157] fixup! fix: improve type-stability of some SII functions --- src/systems/abstractsystem.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index b4093842c8..52c443e35a 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -317,7 +317,8 @@ for traitT in [ ArraySymbolic ] @eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym) - allsyms = vars(sym; op = Symbolics.Operator) + allsyms = Set{SymbolicT}() + SU.search_variables!(allsyms, sym; is_atomic = OperatorIsAtomic{Symbolics.Operator}()) for s in allsyms s = unwrap(s) if is_variable(sys, s) || is_independent_variable(sys, s) From fb899b25f52d83defb977ca970ce2028f977f71e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:09:47 +0530 Subject: [PATCH 078/157] fixup! fix: make `add_initialization_parameters` type-stable --- src/systems/abstractsystem.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 52c443e35a..49aa2635d9 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -583,7 +583,8 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) defs[ivar] = false else defs[ivar] = collect(ivar) - for scal_ivar in defs[ivar] + for idx in SU.stable_eachindex(ivar) + scal_ivar = ivar[idx] defs[scal_ivar] = false end end From 46ea268bbbcc7fe8f6c5c061e95a92aa4f286909 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:09:53 +0530 Subject: [PATCH 079/157] fixup! fix: improve precompile-friendliness of `complete` --- src/systems/abstractsystem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 49aa2635d9..1f2e512b4e 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -682,12 +682,12 @@ function complete( offset = 0 # if there are no tunables, vcat them if !isempty(get_index_cache(sys).tunable_idx) - unflatten_parameters!(ordered_ps, ps_split[1], all_ps_set) + unflatten_parameters!(ordered_ps, ps_split[offset + 1], all_ps_set) offset += 1 end # unflatten initial parameters if !isempty(get_index_cache(sys).initials_idx) - unflatten_parameters!(ordered_ps, ps_split[2], all_ps_set) + unflatten_parameters!(ordered_ps, ps_split[offset + 1], all_ps_set) offset += 1 end for i in (offset+1):length(ps_split) From d05341a75800c61f2741bada818a740b5da69b69 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:10:03 +0530 Subject: [PATCH 080/157] fixup! refactor: improve type-stability of `renamespace` --- src/systems/abstractsystem.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 1f2e512b4e..61eeda204e 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1126,6 +1126,9 @@ renamespace(names::AbstractVector, x) = foldr(renamespace, names, init = x) renamespace(sys, tgt::AbstractSystem) = rename(tgt, renamespace(sys, nameof(tgt))) renamespace(sys, tgt::Symbol) = Symbol(getname(sys), NAMESPACE_SEPARATOR_SYMBOL, tgt) +renamespace(sys, x::Num) = Num(renamespace(sys, unwrap(x))) +renamespace(sys, x::Arr{T, N}) where {T, N} = Arr{T, N}(renamespace(sys, unwrap(x))) +renamespace(sys, x::CallAndWrap{T}) where {T} = CallAndWrap{T}(renamespace(sys, unwrap(x))) """ $(TYPEDSIGNATURES) From 739e1b56ff68dc882d1167fa0edcfdf7c64a9717 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:10:14 +0530 Subject: [PATCH 081/157] fixup! fix: improve precompile-friendliness of `complete` --- src/systems/abstractsystem.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 61eeda204e..7154efb5f6 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1290,7 +1290,7 @@ function namespace_expr(O, sys::AbstractSystem, n::Symbol = nameof(sys); kw...) return O end function namespace_expr(O::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}, sys::AbstractSystem, n::Symbol = nameof(sys); kw...) - namespace_expr(O, args...; kw...) + typeof(O)(namespace_expr(unwrap(O), sys, n; kw...)) end function namespace_expr(O::AbstractArray, sys::AbstractSystem, n::Symbol = nameof(sys); ivs = independent_variables(sys)) is_array_of_symbolics(O) || return O @@ -1318,12 +1318,15 @@ function namespace_expr(O::SymbolicT, sys::AbstractSystem, n::Symbol = nameof(sy elseif f isa SymbolicT f = renamespace(n, f) meta = metadata + else + meta = metadata end return BSImpl.Term{VartypeT}(f, newargs; type, shape, metadata = meta) end BSImpl.AddMul(; coeff, dict, variant, type, shape, metadata) => begin newdict = copy(dict) - for (k, v) in newdict + empty!(newdict) + for (k, v) in dict newdict[namespace_expr(k, sys, n; ivs)] = v end return BSImpl.AddMul{VartypeT}(coeff, newdict, variant; type, shape, metadata) From 8889a67c18e683e85c6a549f398341dfae84c2cc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:10:47 +0530 Subject: [PATCH 082/157] fixup! fix: improve precompile-friendliness of `complete` --- src/systems/abstractsystem.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 7154efb5f6..fc171965e0 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1389,15 +1389,12 @@ See also [`@parameters`](@ref) and [`ModelingToolkit.get_ps`](@ref). function parameters(sys::AbstractSystem; initial_parameters = false) ps = get_ps(sys) if ps === SciMLBase.NullParameters() - return [] + return SymbolicT[] end if eltype(ps) <: Pair - ps = first.(ps) + ps = Vector{SymbolicT}(unwrap.(first.(ps))) end systems = get_systems(sys) - if isempty(systems) - return ps - end result = copy(ps) for subsys in systems append!(result, namespace_parameters(subsys)) From d95e4458df6b106d6958e5eb17c49e72c6594518 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:10:51 +0530 Subject: [PATCH 083/157] fixup! fix: improve precompile-friendliness of `complete` --- src/systems/abstractsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index fc171965e0..cb0c2d2fc9 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1768,7 +1768,7 @@ function constraints(sys::AbstractSystem) cs = get_constraints(sys) systems = get_systems(sys) isempty(systems) && return cs - cs = copy(sys) + cs = copy(cs) for subsys in systems append!(cs, namespace_constraints(subsys)) end From 2a0b1b1a830fc6486869d1ff795959e33aff3e54 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:11:06 +0530 Subject: [PATCH 084/157] fix: improve type-stability of `compose` --- src/systems/abstractsystem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index cb0c2d2fc9..b8a73f05d6 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2752,8 +2752,8 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys) if has_is_dde(sys) @set! sys.is_dde = _check_if_dde(equations(sys), get_iv(sys), get_systems(sys)) end - newunknowns = OrderedSet() - newparams = OrderedSet() + newunknowns = OrderedSet{SymbolicT}() + newparams = OrderedSet{SymbolicT}() iv = has_iv(sys) ? get_iv(sys) : nothing for ssys in systems collect_scoped_vars!(newunknowns, newparams, ssys, iv) From f48b38306faee8db1db70cc8f2d88f57b92264ba Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:11:20 +0530 Subject: [PATCH 085/157] fix: improve type-stability of `split_system` --- src/systems/clock_inference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 36650723b2..96529e0423 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -327,7 +327,7 @@ function split_system(ci::ClockInference{S}) where {S} continuous_id = continuous_id[] # for each clock partition what are the input (indexes/vars) input_idxs = map(_ -> Int[], 1:cid_counter[]) - inputs = map(_ -> Any[], 1:cid_counter[]) + inputs = map(_ -> SymbolicT[], 1:cid_counter[]) # var_domain corresponds to fullvars/all variables in the system nvv = length(var_domain) # put variables into the right clock partition From 0ced1957570d2f9b03cccb874cde59ff130fdfbc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:11:50 +0530 Subject: [PATCH 086/157] refactor: handle new `reorder_parameters` in `build_explicit_observed_function` --- src/systems/codegen.jl | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl index b5230e7454..486d2f6180 100644 --- a/src/systems/codegen.jl +++ b/src/systems/codegen.jl @@ -1024,6 +1024,17 @@ function build_explicit_observed_function(sys, ts; cse = true, mkarray = nothing, wrap_delays = is_dde(sys)) + if inputs === nothing + inputs = () + else + inputs = vec(unwrap_vars(inputs)) + end + if disturbance_inputs === nothing + disturbance_inputs = () + else + disturbance_inputs = vec(unwrap_vars(disturbance_inputs)) + end + ps::Vector{SymbolicT} = vec(unwrap_vars(ps)) # TODO: cleanup is_tuple = ts isa Tuple if is_tuple @@ -1038,7 +1049,8 @@ function build_explicit_observed_function(sys, ts; ts = symbol_to_symbolic(sys, ts; allsyms) end - vs = ModelingToolkit.vars(ts; op) + vs = Set{SymbolicT}() + SU.search_variables!(vs, ts; is_atomic = OperatorIsAtomic{op}()) namespace_subs = Dict() ns_map = Dict{Any, Any}(renamespace(sys, eq.lhs) => eq.lhs for eq in observed(sys)) for sym in unknowns(sys) @@ -1084,13 +1096,11 @@ function build_explicit_observed_function(sys, ts; else (unknowns(sys),) end - if inputs === nothing - inputs = () - else + if inputs isa Vector{SymbolicT} ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list inputs = (inputs,) end - if disturbance_inputs !== nothing + if disturbance_inputs isa Vector{SymbolicT} # Disturbance inputs may or may not be included as inputs, depending on disturbance_argument ps = setdiff(ps, disturbance_inputs) end @@ -1099,15 +1109,15 @@ function build_explicit_observed_function(sys, ts; else disturbance_inputs = () end - ps = reorder_parameters(sys, ps) + rps::ReorderedParametersT = reorder_parameters(sys, ps) iv = if is_time_dependent(sys) (get_iv(sys),) else () end - args = (dvs..., inputs..., ps..., iv..., disturbance_inputs...) + args = (dvs..., inputs..., rps..., iv..., disturbance_inputs...) p_start = length(dvs) + length(inputs) + 1 - p_end = length(dvs) + length(inputs) + length(ps) + p_end = length(dvs) + length(inputs) + length(rps) fns = build_function_wrapper( sys, ts, args...; p_start, p_end, filter_observed = obsfilter, output_type, mkarray, try_namespaced = true, expression = Val{true}, cse, @@ -1118,7 +1128,7 @@ function build_explicit_observed_function(sys, ts; end oop, iip = eval_or_rgf.(fns; eval_expression, eval_module) f = GeneratedFunctionWrapper{( - p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}( + p_start + wrap_delays, length(args) - length(rps) + 1 + wrap_delays, is_split(sys))}( oop, iip) return return_inplace ? (f, f) : f else @@ -1127,7 +1137,7 @@ function build_explicit_observed_function(sys, ts; end f = eval_or_rgf(fns; eval_expression, eval_module) f = GeneratedFunctionWrapper{( - p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}( + p_start + wrap_delays, length(args) - length(rps) + 1 + wrap_delays, is_split(sys))}( f, nothing) return f end From f5d0bbed181f4c51d9c6abfdc87160f216bc7361 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:12:12 +0530 Subject: [PATCH 087/157] fixup! fix: improve type-stability of connection infrastructure --- src/systems/connectors.jl | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index af23951cfa..49f1354a26 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -834,21 +834,15 @@ function _generate_connection_set!(connection_state::ConnectionState, return sys end -function _flow_equations_from_idxs!(eqs::Vector{Equation}, cset::Vector{ConnectionVertex}, idxs::CartesianIndices{N, NTuple{N, UnitRange{Int}}}) where {N} +function _flow_equations_from_idxs!(sys::AbstractSystem, eqs::Vector{Equation}, cset::Vector{ConnectionVertex}, len::Int) add_buffer = SymbolicT[] # each variable can have different axes, but they all have the same size - for sz_i in eachindex(idxs) + for sz_i in 1:len empty!(add_buffer) for cvert in cset - # all of this wrapping/unwrapping is necessary because the relevant - # methods are defined on `Arr/Num` and not `BasicSymbolic`. v = variable_from_vertex(sys, cvert)::SymbolicT - if N === 0 - v = v - else - vidxs = eachindex(v)::CartesianIndices{N, NTuple{N, UnitRange{Int}}} - v = v[vidxs[sz_i]] - end + vidxs = SU.stable_eachindex(v) + v = v[vidxs[sz_i]] push!(add_buffer, cvert.isouter ? -v : v) end rhs = SU.add_worker(VartypeT, add_buffer) @@ -907,16 +901,7 @@ function generate_connection_equations_and_stream_connections( # to bad-looking equations. Just generate scalar equations instead since # mtkcompile will scalarize anyway. representative = variable_from_vertex(sys, cset[1])::SymbolicT - idxs = eachindex(representative) - if idxs isa CartesianIndices{0, Tuple{}} - _flow_equations_from_idxs!(eqs, cset, idxs) - elseif idxs isa CartesianIndices{1, Tuple{UnitRange{Int}}} - _flow_equations_from_idxs!(eqs, cset, idxs) - elseif idxs isa CartesianIndices{2, NTuple{2, UnitRange{Int}}} - _flow_equations_from_idxs!(eqs, cset, idxs) - else - _flow_equations_from_idxs!(eqs, cset, idxs) - end + _flow_equations_from_idxs!(sys, eqs, cset, length(representative)::Int) else # Equality vars = SymbolicT[] for cvar in cset From a795669f5ab938d2f0f51f0cd14cb1413bb4a306 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:12:27 +0530 Subject: [PATCH 088/157] fixup! fix: make some SII impls of `IndexCache` more type-stable --- src/systems/index_cache.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index e0dc2ef923..8caef22efb 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -424,11 +424,15 @@ function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) variable_index(ic, sym) !== nothing end -function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) - if sym isa Symbol - sym = get(ic.symbol_to_variable, sym, nothing) - sym === nothing && return nothing - end +function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) + variable_index(ic, unwrap(sym)) +end +function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym::Symbol) + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + variable_index(ic, sym) +end +function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym::SymbolicT) idx = check_index_map(ic.unknown_idx, sym) idx === nothing || return idx iscall(sym) && operation(sym) == getindex || return nothing @@ -437,6 +441,7 @@ function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) idx === nothing && return nothing return idx[args[2:end]...] end +SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) = false function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym) parameter_index(ic, sym) !== nothing From cf8af8b5234e65756539357b9fa36f0cbcdb76cb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:13:25 +0530 Subject: [PATCH 089/157] fixup! fix: make `reorder_parameters` more type-stable --- src/systems/index_cache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 8caef22efb..13b610bade 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -518,7 +518,7 @@ const ReorderedParametersT = Vector{Union{Vector{SymbolicT}, Vector{Vector{Symbo function reorder_parameters( sys::AbstractSystem, ps = parameters(sys; initial_parameters = true); kwargs...) if has_index_cache(sys) && get_index_cache(sys) !== nothing - reorder_parameters(get_index_cache(sys)::IndexCache, ps; kwargs...) + return reorder_parameters(get_index_cache(sys)::IndexCache, ps; kwargs...) elseif ps isa Tuple return ReorderedParametersT(collect(ps)) else From 89468bf496e6e872db949cd21efc0189d0078214 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:13:28 +0530 Subject: [PATCH 090/157] fixup! fix: improve precompile-friendliness of `complete` --- src/systems/index_cache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 13b610bade..c340960c16 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -522,7 +522,7 @@ function reorder_parameters( elseif ps isa Tuple return ReorderedParametersT(collect(ps)) else - eltype(ReorderedParametersT)[ps] + return eltype(ReorderedParametersT)[ps] end end From c0f7b91c7d39ec902bdb4c0cdcb598be147f047c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:13:33 +0530 Subject: [PATCH 091/157] fixup! fix: make `reorder_parameters` more type-stable --- src/systems/index_cache.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index c340960c16..80f52029d7 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -532,9 +532,13 @@ function reorder_parameters(ic::IndexCache, ps::Vector{SymbolicT}; drop_missing result = ReorderedParametersT() isempty(ps) && return result param_buf = fill(COMMON_DEFAULT_VAR, ic.tunable_buffer_size.length) - push!(result, param_buf) + if !isempty(param_buf) || !flatten + push!(result, param_buf) + end initials_buf = fill(COMMON_DEFAULT_VAR, ic.initials_buffer_size.length) - push!(result, initials_buf) + if !isempty(initials_buf) || !flatten + push!(result, initials_buf) + end disc_buf = Vector{SymbolicT}[] for bufszs in ic.discrete_buffer_sizes From 9443322a517874dd11ca1ef49c220131bd76b133 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:14:22 +0530 Subject: [PATCH 092/157] fix: handle new `reorder_parameters` in `get_mtkparameters_reconstructor` --- src/systems/problem_utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index ed42c96d89..0f444bf882 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -812,12 +812,12 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac # tuple of `BlockedArray`s Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘ Base.Fix1(broadcast, p_constructor) ∘ - concrete_getu(srcsys, syms[3]; eval_expression, eval_module) + concrete_getu(srcsys, Tuple(syms[3]); eval_expression, eval_module) end const_getter = if syms[4] == () Returns(()) else - Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[4]) + Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, Tuple(syms[4])) end nonnumeric_getter = if syms[5] == () Returns(()) @@ -829,7 +829,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac # nonnumerics retain the assigned buffer type without narrowing Base.Fix1(broadcast, _p_constructor) ∘ Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ - concrete_getu(srcsys, syms[5]; eval_expression, eval_module) + concrete_getu(srcsys, Tuple(syms[5]); eval_expression, eval_module) end getters = ( tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter) From 1545218c07fbaa7de6f6ce8c20034ee7221ad4e8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:14:43 +0530 Subject: [PATCH 093/157] refactor: avoid using `vars!` --- src/systems/problem_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 0f444bf882..a42099ffd1 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -446,14 +446,14 @@ function check_substitution_cycles( for (k, v) in varmap kidx = var_to_idx[k] if symbolic_type(v) != NotSymbolic() - vars!(buffer, v) + SU.search_variables!(buffer, v) for var in buffer haskey(var_to_idx, var) || continue add_edge!(graph, kidx, var_to_idx[var]) end elseif v isa AbstractArray for val in v - vars!(buffer, val) + SU.search_variables!(buffer, val) end for var in buffer haskey(var_to_idx, var) || continue From a67adf213c787547c16924afae04bfdf963a1324 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:14:57 +0530 Subject: [PATCH 094/157] feat: handle `Const` variants in problem building --- src/systems/problem_utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index a42099ffd1..169b2c541a 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -242,10 +242,10 @@ symbolic values, all of which need to be unwrapped. Specializes when `x isa Abst to unwrap keys and values, returning an `AnyDict`. """ function recursive_unwrap(x::AbstractArray) - symbolic_type(x) == ArraySymbolic() ? unwrap(x) : recursive_unwrap.(x) + symbolic_type(x) == ArraySymbolic() ? value(x) : recursive_unwrap.(x) end -recursive_unwrap(x) = unwrap(x) +recursive_unwrap(x) = value(x) function recursive_unwrap(x::AbstractDict) return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x) @@ -261,7 +261,7 @@ entry for `eq.lhs`, insert the reverse mapping if `eq.rhs` is not a number. function add_observed_equations!(varmap::AbstractDict, eqs) for eq in eqs if var_in_varlist(eq.lhs, keys(varmap), nothing) - eq.rhs isa Number && continue + SU.isconst(eq.rhs) && continue var_in_varlist(eq.rhs, keys(varmap), nothing) && continue !iscall(eq.rhs) || issym(operation(eq.rhs)) || continue varmap[eq.rhs] = eq.lhs From 8e2b96487b527369b15148ea23fa92cf870a3de4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:15:08 +0530 Subject: [PATCH 095/157] refactor: simplify method definitions --- src/clock.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/clock.jl b/src/clock.jl index 6537334645..08c3d84dd0 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -53,10 +53,8 @@ end has_time_domain(x::Num) = has_time_domain(value(x)) has_time_domain(x) = false -for op in [Differential] - @eval input_timedomain(::$op, arg = nothing) = (ContinuousClock(),) - @eval output_timedomain(::$op, arg = nothing) = ContinuousClock() -end +input_timedomain(::Differential, arg = nothing) = InputTimeDomainElT[ContinuousClock()] +output_timedomain(::Differential, arg = nothing) = ContinuousClock() """ has_discrete_domain(x) From 3955abb1fc42740e0f506bbb7fed0d9cc8f14ca9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:15:23 +0530 Subject: [PATCH 096/157] refactor: improve type-stability of `ShiftIndex` --- src/discretedomain.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 9226a49be0..8a551f1670 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -264,9 +264,10 @@ end function (xn::Num)(k::ShiftIndex) @unpack clock, steps = k - x = value(xn) + x = unwrap(xn) # Verify that the independent variables of k and x match and that the expression doesn't have multiple variables - vars = ModelingToolkit.vars(x) + vars = Set{SymbolicT}() + SU.search_variables!(vars, x) if length(vars) != 1 error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.") end From 4ed662b92cd3eb804e0fe4dd006ce7f15724ecc4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:16:14 +0530 Subject: [PATCH 097/157] fixup! fix: make `input_timedomain` type-stable --- src/discretedomain.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 8a551f1670..bde06ee64d 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -335,7 +335,7 @@ Should return a tuple containing the time domain type for each argument to the o """ function input_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) - return get_time_domain(arg) + return InputTimeDomainElT[get_time_domain(arg)] end InputTimeDomainElT[InferredDiscrete()] end @@ -357,7 +357,7 @@ output_timedomain(s::Sample, _ = nothing) = s.clock function input_timedomain(::Hold, arg = nothing) if has_time_domain(arg) - return get_time_domain(arg) + return InputTimeDomainElT[get_time_domain(arg)] end InputTimeDomainElT[InferredDiscrete()] # the Hold accepts any discrete end From 54bedf56b88d7640c2be4cf9721f6b6ddc0a5a6c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:17:02 +0530 Subject: [PATCH 098/157] refactor: improve performance, type stability of `is_bound` --- src/inputoutput.jl | 86 +++++++++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 35 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index adbc6de6d4..a52e91941c 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -56,17 +56,31 @@ function _is_atomic_inside_operator(ex::SymbolicT) end end -""" - is_bound(sys, u) +struct IsBoundValidator + eqs_vars::Vector{Set{SymbolicT}} + obs_vars::Vector{Set{SymbolicT}} + stack::OrderedSet{SymbolicT} +end -Determine whether input/output variable `u` is "bound" within the system, i.e., if it's to be considered internal to `sys`. -A variable/signal is considered bound if it appears in an equation together with variables from other subsystems. -The typical usecase for this function is to determine whether the input to an IO component is connected to another component, -or if it remains an external input that the user has to supply before simulating the system. +function IsBoundValidator(sys::System) + eqs_vars = Set{SymbolicT}[] + for eq in equations(sys) + vars = Set{SymbolicT}() + SU.search_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) + SU.search_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) + push!(eqs_vars, vars) + end + obs_vars = Set{SymbolicT}[] + for eq in observed(sys) + vars = Set{SymbolicT}() + SU.search_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) + SU.search_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) + push!(obs_vars, vars) + end + return IsBoundValidator(eqs_vars, obs_vars, OrderedSet{SymbolicT}()) +end -See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@ref), [`unbound_outputs`](@ref) -""" -function is_bound(sys, u, stack = []) +function (ibv::IsBoundValidator)(u::SymbolicT) #= For observed quantities, we check if a variable is connected to something that is bound to something further out. In the following scenario @@ -78,40 +92,42 @@ function is_bound(sys, u, stack = []) When asking is_bound(sys₊y(t)), we know that we are looking through observed equations and can thus ask if var is bound, if it is, then sys₊y(t) is also bound. This can lead to an infinite recursion, so we maintain a stack of variables we have previously asked about to be able to break cycles =# - u ∈ Set(stack) && return false # Cycle detected - eqs = equations(sys) - eqs = filter(eq -> has_var(eq, u), eqs) # Only look at equations that contain u - # isout = isoutput(u) - vars = Set{SymbolicT}() - for eq in eqs - empty!(vars) - get_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) - get_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) + u in ibv.stack && return false # Cycle detected + for vars in ibv.eqs_vars + u in vars || continue for var in vars var === u && continue - if !same_or_inner_namespace(u, var) - return true - end + same_or_inner_namespace(u, var) || return true end end - # Look through observed equations as well - oeqs = observed(sys) - oeqs = filter(eq -> has_var(eq, u), oeqs) # Only look at equations that contain u - for eq in oeqs - empty!(vars) - get_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) - get_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) + for vars in ibv.obs_vars + u in vars || continue for var in vars var === u && continue - if !same_or_inner_namespace(u, var) - return true - end - if is_bound(sys, var, [stack; u]) && !inner_namespace(u, var) # The variable we are comparing to can not come from an inner namespace, binding only counts outwards - return true - end + same_or_inner_namespace(u, var) || return true + push!(ibv.stack, u) + isbound = ibv(var) + pop!(ibv.stack) + # The variable we are comparing to can not come from an inner namespace, + # binding only counts outwards + isbound && !inner_namespace(u, var) && return true end end - false + return false +end + +""" + is_bound(sys, u) + +Determine whether input/output variable `u` is "bound" within the system, i.e., if it's to be considered internal to `sys`. +A variable/signal is considered bound if it appears in an equation together with variables from other subsystems. +The typical usecase for this function is to determine whether the input to an IO component is connected to another component, +or if it remains an external input that the user has to supply before simulating the system. + +See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@ref), [`unbound_outputs`](@ref) +""" +function is_bound(sys, u) + return IsBoundValidator(sys)(unwrap(u)) end """ From b7f1011ea0ea5eb3852ed04fd785f864e0b430c7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:17:15 +0530 Subject: [PATCH 099/157] refactor: improve type-stability of `generate_control_function` --- src/inputoutput.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index a52e91941c..8f0b063fb5 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -231,9 +231,9 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs( # add to inputs for the purposes of io processing inputs = [inputs; disturbance_inputs] end - + inputs = vec(unwrap_vars(inputs)) dvs = unknowns(sys) - ps = parameters(sys; initial_parameters = true) + ps::Vector{SymbolicT} = parameters(sys; initial_parameters = true) ps = setdiff(ps, inputs) if disturbance_inputs !== nothing # remove from inputs since we do not want them as actual inputs to the dynamics From 7e2693534a1a9fdf8d93e9913e2efe868243e6bf Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:17:36 +0530 Subject: [PATCH 100/157] fixup! fix: make `isparameter` type stable --- src/parameters.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parameters.jl b/src/parameters.jl index f2eeb3d8e0..d3bc796d2f 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -27,6 +27,7 @@ function isparameter(x::SymbolicT) varT = getvariabletype(x, nothing) return varT === PARAMETER end +isparameter(x) = false function iscalledparameter(x) x = unwrap(x) From 954af3e94f93ba049d657490ccdf77ec2180af1c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:17:46 +0530 Subject: [PATCH 101/157] fixup! fix: remove usages of `occursin` for searching expressions --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 967c35225c..93d53c375f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -148,7 +148,7 @@ function check_variables(dvs, iv) for dv in dvs isequal(iv, dv) && throw(ArgumentError("Independent variable $iv not allowed in dependent variables.")) - (is_delay_var(iv, dv) || SU.query!(isequal(iv), dv)) || + (is_delay_var(iv, dv) || SU.query(isequal(iv), dv)) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv.")) end end From ad3a834db6a3f29490ffcca19538b8bb11b1f89c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:17:49 +0530 Subject: [PATCH 102/157] fixup! refactor: get `System` to precompile in a trivial case --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 93d53c375f..cba76469d1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -211,8 +211,8 @@ Assert that equations are well-formed when building ODE, i.e., only containing a function check_equations(eqs::Vector{Equation}, iv::SymbolicT) icp = IndepvarCheckPredicate(iv) for eq in eqs - SU.query!(icp, eq.lhs) - SU.query!(icp, eq.rhs) + SU.query(icp, eq.lhs) + SU.query(icp, eq.rhs) end end From f372c4d8f43b6b18210f263033254495e0999e71 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:18:04 +0530 Subject: [PATCH 103/157] refactor: improve type-stability of `check_operator_variables` --- src/utils.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index cba76469d1..d323b4ca64 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -360,11 +360,11 @@ end Check if all the LHS are unique """ function check_operator_variables(eqs, op::T) where {T} - ops = Set() - tmp = Set() + ops = Set{SymbolicT}() + tmp = Set{SymbolicT}() for eq in eqs _check_operator_variables(eq, op) - vars!(tmp, eq.lhs) + SU.search_variables!(tmp, eq.lhs; is_atomic = OperatorIsAtomic{Differential}()) if length(tmp) == 1 x = only(tmp) if op === Differential @@ -496,16 +496,16 @@ function collect_operator_variables(eq::Equation, args...) end """ - collect_operator_variables(eqs::AbstractVector{Equation}, op) + collect_operator_variables(eqs::Vector{Equation}, ::Type{op}) where {op} Return a `Set` containing all variables that have Operator `op` applied to them. See also [`collect_differential_variables`](@ref). """ -function collect_operator_variables(eqs::AbstractVector{Equation}, op) - vars = Set() - diffvars = Set() +function collect_operator_variables(eqs::Vector{Equation}, ::Type{op}) where {op} + vars = Set{SymbolicT}() + diffvars = Set{SymbolicT}() for eq in eqs - vars!(vars, eq; op = op) + SU.search_variables!(vars, eq; is_atomic = OperatorIsAtomic{op}()) for v in vars isoperator(v, op) || continue push!(diffvars, arguments(v)[1]) From 288cb103b12c69f9fc143a69d6606e2a58c00139 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:22:28 +0530 Subject: [PATCH 104/157] fixup! fix: improve inference of several utility functions --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index d323b4ca64..1af449ff97 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -674,7 +674,7 @@ function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{S return nothing end -function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::AbstractArray{SymbolicT}, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) +function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::AbstractArray, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) for var in expr collect_vars!(unknowns, parameters, var, iv; depth, op) end From 2c83233166ca3d2ca809bd877b1790443500e274 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:22:53 +0530 Subject: [PATCH 105/157] fixup! fix: improve inference of several utility functions --- src/utils.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 1af449ff97..1ca6f7d8c8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -692,7 +692,7 @@ eqtype_supports_collect_vars(eq::Equation) = true eqtype_supports_collect_vars(eq::Inequality) = true eqtype_supports_collect_vars(eq::Pair) = true -function collect_vars!(unknowns, parameters, eq::Union{Equation, Inequality}, iv; +function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, eq::Union{Equation, Inequality}, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) @@ -700,12 +700,17 @@ function collect_vars!(unknowns, parameters, eq::Union{Equation, Inequality}, iv end function collect_vars!( - unknowns, parameters, p::Pair, iv; depth = 0, op = Symbolics.Operator) + unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, p::Pair, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) collect_vars!(unknowns, parameters, p[1], iv; depth, op) collect_vars!(unknowns, parameters, p[2], iv; depth, op) return nothing end +function collect_vars!( + unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) + return nothing +end + """ $(TYPEDSIGNATURES) @@ -713,7 +718,7 @@ Identify whether `var` belongs to the current system using `depth` and scoping i Add `var` to `unknowns` or `parameters` appropriately, and search through any expressions in known metadata of `var` using `collect_vars!`. """ -function collect_var!(unknowns, parameters, var, iv; depth = 0) +function collect_var!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, var::SymbolicT, iv::Union{SymbolicT, Nothing}; depth = 0) isequal(var, iv) && return nothing if Symbolics.iswrapped(var) error(""" From 90f13205eea1697411942985f8c8712ff1454704 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:23:57 +0530 Subject: [PATCH 106/157] fix: improve type-stability of `observed_equations_used_by` --- src/utils.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 1ca6f7d8c8..d0d82b0e96 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -942,7 +942,16 @@ Keyword arguments: `available_vars` will not be searched for in the observed equations. """ function observed_equations_used_by(sys::AbstractSystem, exprs; - involved_vars = vars(exprs; op = Union{Shift, Differential, Initial}), obs = observed(sys), available_vars = []) + involved_vars = nothing, obs = observed(sys), available_vars = Set{SymbolicT}()) + if involved_vars === nothing + involved_vars = Set{SymbolicT}() + SU.search_variables!(involved_vars, exprs; is_atomic = OperatorIsAtomic{Union{Shift, Differential, Initial}}()) + elseif !(involved_vars isa Set{SymbolicT}) + involved_vars = Set{SymbolicT}(involved_vars) + end + if !(available_vars isa Set) + available_vars = Set(available_vars) + end if iscomplete(sys) && obs == observed(sys) cache = getmetadata(sys, MutableCacheKey, nothing) obs_graph_cache = get!(cache, ObservedGraphCacheKey) do @@ -956,10 +965,6 @@ function observed_equations_used_by(sys::AbstractSystem, exprs; graph = observed_dependency_graph(obs) end - if !(available_vars isa Set) - available_vars = Set(available_vars) - end - obsidxs = BitSet() for sym in involved_vars sym in available_vars && continue From 352e5fc02beb1f8da6303f80e1dd80d2f735233b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:24:18 +0530 Subject: [PATCH 107/157] fixup! fix: make `flatten_equations` type-stable --- src/utils.jl | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index d0d82b0e96..c9efdc6b2e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1160,29 +1160,8 @@ without scalarizing occurrences of array variables and return the new list of eq function flatten_equations(eqs::Vector{Equation}) _eqs = Equation[] for eq in eqs - shlhs = SU.shape(eq.lhs) - if isempty(shlhs) - push!(_eqs, eq) - continue - end - if length(shlhs) == 1 - lhs = collect(eq.lhs)::Vector{SymbolicT} - rhs = collect(eq.rhs)::Vector{SymbolicT} - for (l, r) in zip(lhs, rhs) - push!(_eqs, l ~ r) - end - elseif length(shlhs) == 2 - lhs = collect(eq.lhs)::Matrix{SymbolicT} - rhs = collect(eq.rhs)::Matrix{SymbolicT} - for (l, r) in zip(lhs, rhs) - push!(_eqs, l ~ r) - end - else - lhs = collect(eq.lhs)::Matrix{SymbolicT} - rhs = collect(eq.rhs)::Matrix{SymbolicT} - for (l, r) in zip(lhs, rhs) - push!(_eqs, l ~ r) - end + for (i1, i2) in zip(SU.stable_eachindex(eq.lhs), SU.stable_eachindex(eq.rhs)) + push!(_eqs, eq.lhs[i1] ~ eq.rhs[i2]) end end return _eqs From 3d05173dcd3c61e3f4ca54ad7f2d833b518ab1b7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 17:24:36 +0530 Subject: [PATCH 108/157] fixup! refactor: remove usages of `Symbolics.getparent` --- src/variables.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variables.jl b/src/variables.jl index 1c218e0646..7075aa3a0c 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -193,7 +193,7 @@ end ### Input, Output, Irreducible isvarkind(m, x, def = false) = safe_getmetadata(m, x, def) safe_getmetadata(m, x::Union{Num, Symbolics.Arr}, def) = safe_getmetadata(m, value(x), def) -function safe_getmetadata(m, x, default) +function safe_getmetadata(m::DataType, x::SymbolicT, default) hasmetadata(x, m) && return getmetadata(x, m) iscall(x) && operation(x) === getindex && return safe_getmetadata(m, arguments(x)[1], default) return default From d8c1e584d0e2ccd739ccb1bd46ab99e388074efb Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 10 Oct 2025 18:39:56 +0530 Subject: [PATCH 109/157] fixup! fix: improve inference of several utility functions --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index c9efdc6b2e..866437f5b0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -531,7 +531,7 @@ ModelingToolkit.collect_applied_operators(eq, Differential) == Set([D(y)]) The difference compared to `collect_operator_variables` is that `collect_operator_variables` returns the variable without the operator applied. """ -function collect_applied_operators(x::SymbolicT, ::Type{op}) where {op} +function collect_applied_operators(x, ::Type{op}) where {op} v = Set{SymbolicT}() SU.search_variables!(v, x; is_atomic = OnlyOperatorIsAtomic{op}()) return v From 21fcf5e7f2963470e6c191f23132418be03748f0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:53:05 +0530 Subject: [PATCH 110/157] fixup! fix: move ChainRulesCore to an extension --- ext/MTKChainRulesCoreExt.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/MTKChainRulesCoreExt.jl b/ext/MTKChainRulesCoreExt.jl index c213a164a3..409cad8704 100644 --- a/ext/MTKChainRulesCoreExt.jl +++ b/ext/MTKChainRulesCoreExt.jl @@ -3,6 +3,7 @@ module MTKChainRulesCoreExt import ChainRulesCore import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk using ModelingToolkit: MTKParameters, NONNUMERIC_PORTION, AbstractSystem +import ModelingToolkit import ModelingToolkit as MTK import SciMLStructures import SymbolicIndexingInterface: remake_buffer From 91b7c29733e57b64a22a11b69c5145746af2c337 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:53:10 +0530 Subject: [PATCH 111/157] fixup! fix: make `add_initialization_parameters` type-stable --- src/systems/abstractsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index b8a73f05d6..6810649ab3 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -564,7 +564,7 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) # add derivatives of all variables for steady-state initial conditions if is_time_dependent(sys) && !is_discrete_system(sys) D = Differential(get_iv(sys)::SymbolicT) - for v in all_initialvars + for v in collect(all_initialvars) iscall(v) && push!(all_initialvars, D(v)) end end From 33306999dcc45b963c96de6447fc1c53109de9d5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:53:20 +0530 Subject: [PATCH 112/157] fixup! fix: make `IndexCache` constructor more type-stable --- src/systems/index_cache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 80f52029d7..cabf056c43 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -92,7 +92,7 @@ function IndexCache(sys::AbstractSystem) push!(found_array_syms, arrsym) valid_arrsym || break if idxs == idxs[begin]:idxs[end] - idxs = reshape(idxs[begin]:idxs[end], size(idxs))::AbstractArray{Int} + idxs = reshape(idxs[begin]:idxs[end], size(arrsym))::AbstractArray{Int} else idxs = reshape(idxs, size(arrsym))::AbstractArray{Int} end From 38b72ad18c287deb44b68884db79e3941894e76f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:53:33 +0530 Subject: [PATCH 113/157] fixup! fix: minor fix for `evaluate_varmap!` --- src/systems/problem_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 169b2c541a..fb937c3aa2 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -486,7 +486,7 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100) v === nothing && continue symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue haskey(varmap, k) || continue - varmap[k] = value(fixpoint_sub(v, varmap; maxiters = limit)) + varmap[k] = value(fixpoint_sub(v, varmap; maxiters = limit, fold = Val(true))) end end From 3c2af16c4a5ca7113da8a4a66bdfae5facced12b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:53:43 +0530 Subject: [PATCH 114/157] fixup! refactor: get `System` to precompile in a trivial case --- src/systems/system.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/systems/system.jl b/src/systems/system.jl index f8789d8938..a8b8507767 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -332,6 +332,7 @@ function default_consolidate(costs, subcosts) return _sum_costs(costs) + _sum_costs(subcosts) end +unwrap_vars(x) = unwrap_vars(collect(x)) unwrap_vars(vars::AbstractArray{SymbolicT}) = vars function unwrap_vars(vars::AbstractArray) result = similar(vars, SymbolicT) From d9a4539d7eb0b12310a53d04efb7f2633dd6cfa7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:53:52 +0530 Subject: [PATCH 115/157] fixup! refactor: make `System` more concretely typed --- src/systems/system.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/system.jl b/src/systems/system.jl index a8b8507767..5f63749426 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -409,12 +409,12 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; defaults = defsdict(defaults) guesses = defsdict(guesses) - inputs = unwrap_vars(inputs) - outputs = unwrap_vars(outputs) if !(inputs isa OrderedSet{SymbolicT}) + inputs = unwrap.(inputs) inputs = OrderedSet{SymbolicT}(inputs) end if !(outputs isa OrderedSet{SymbolicT}) + outputs = unwrap.(outputs) outputs = OrderedSet{SymbolicT}(outputs) end for subsys in systems From 4db99bb3e32563d2a7c6b0e09497c18e3086326b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:54:17 +0530 Subject: [PATCH 116/157] refactor: improve type-stability of constraint validation --- src/systems/system.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/systems/system.jl b/src/systems/system.jl index 5f63749426..8b2070075c 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -686,10 +686,10 @@ Process variables in constraints of the (ODE) System. """ function process_constraint_system( constraints::Vector{Union{Equation, Inequality}}, sts, ps, iv; validate = true) - isempty(constraints) && return Set(), Set() + isempty(constraints) && return OrderedSet{SymbolicT}(), OrderedSet{SymbolicT}() - constraintsts = OrderedSet() - constraintps = OrderedSet() + constraintsts = OrderedSet{SymbolicT}() + constraintps = OrderedSet{SymbolicT}() for cons in constraints collect_vars!(constraintsts, constraintps, cons, iv) union!(constraintsts, collect_applied_operators(cons, Differential)) @@ -707,8 +707,8 @@ end Process the costs for the constraint system. """ function process_costs(costs::Vector, sts, ps, iv) - coststs = OrderedSet() - costps = OrderedSet() + coststs = OrderedSet{SymbolicT}() + costps = OrderedSet{SymbolicT}() for cost in costs collect_vars!(coststs, costps, cost, iv) end @@ -743,8 +743,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) operation(var)(iv) ∈ sts || throw(ArgumentError("Variable $var is not a variable of the System. Called variables must be variables of the System.")) - isequal(arg, iv) || isparameter(arg) || arg isa Integer || - arg isa AbstractFloat || + isequal(arg, iv) || isparameter(arg) || isconst(arg) && symtype(arg) <: Real || throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds.")) isparameter(arg) && !isequal(arg, iv) && push!(auxps, arg) From db3a50c76d8a9085d13f172d918f04c37af524ae Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:54:28 +0530 Subject: [PATCH 117/157] fixup! refactor: remove usages of old `symtype` syntax --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 866437f5b0..27fe8e5667 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -20,7 +20,7 @@ function detime_dvs(op) if !iscall(op) op elseif issym(operation(op)) - SSym(nameof(operation(op)); type = Real, shape = SU.ShapeVecT()) + SSym(nameof(operation(op)); type = Real, shape = SU.shape(op)) else maketerm(typeof(op), operation(op), detime_dvs.(arguments(op)), metadata(op)) From 3e79363f55af3b91bb8e6ff5aedeb3725e65d401 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:54:40 +0530 Subject: [PATCH 118/157] fixup! refactor: concretely type some utility functions --- src/utils.jl | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 27fe8e5667..53d3fd0245 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -301,10 +301,21 @@ end function collect_defaults!(defs::SymmapT, vars::Vector{SymbolicT}) for v in vars isconst(v) && continue - if haskey(defs, v) || (def = Symbolics.getdefaultval(v, nothing)) === nothing + haskey(defs, v) && continue + def = Symbolics.getdefaultval(v, nothing) + if def !== nothing + defs[v] = SU.Const{VartypeT}(def) continue end - defs[v] = SU.Const{VartypeT}(def) + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f === getindex end => begin + haskey(defs, args[1]) && continue + def = Symbolics.getdefaultval(args[1], nothing) + def === nothing && continue + defs[args[1]] = def + end + _ => nothing + end end return defs end @@ -313,10 +324,21 @@ function collect_guesses!(guesses::SymmapT, vars::Vector{SymbolicT}) for v in vars isconst(v) && continue symbolic_type(v) == NotSymbolic() && continue - if haskey(guesses, v) || (def = getguess(v)) === nothing + haskey(guesses, v) && continue + def = getguess(v) + if def !== nothing + guesses[v] = SU.Const{VartypeT}(def) continue end - guesses[v] = SU.Const{VartypeT}(def) + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f === getindex end => begin + haskey(guesses, args[1]) && continue + def = Symbolics.getdefaultval(args[1], nothing) + def === nothing && continue + guesses[args[1]] = def + end + _ => nothing + end end return guesses end From 18caee9a87cce03285e3918d5ef02940ac3f0d0d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 16:55:19 +0530 Subject: [PATCH 119/157] fixup! fix: make some SII impls of `IndexCache` more type-stable --- src/systems/index_cache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index cabf056c43..e00a1a250f 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -439,7 +439,7 @@ function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym::SymbolicT args = arguments(sym) idx = variable_index(ic, args[1]) idx === nothing && return nothing - return idx[args[2:end]...] + return idx[unwrap_const.(args[2:end])...] end SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) = false From 058024c35215790231b1e2af1e30f348b8b5208e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 13 Oct 2025 17:06:21 +0530 Subject: [PATCH 120/157] refactor: remove Unitful --- docs/Project.toml | 2 - docs/src/API/model_building.md | 2 - docs/src/basics/Validation.md | 4 +- src/ModelingToolkit.jl | 11 +- src/systems/model_parsing.jl | 17 +- src/systems/unit_check.jl | 2 - src/systems/validation.jl | 287 --------------------------------- 7 files changed, 7 insertions(+), 318 deletions(-) delete mode 100644 src/systems/validation.jl diff --git a/docs/Project.toml b/docs/Project.toml index 499a76a921..db2e22fca3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -32,7 +32,6 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" -Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [compat] Attractors = "1.24" @@ -65,4 +64,3 @@ StochasticDiffEq = "6" SymbolicIndexingInterface = "0.3.1" SymbolicUtils = "3, 4" Symbolics = "6" -Unitful = "1.12" diff --git a/docs/src/API/model_building.md b/docs/src/API/model_building.md index 7a8389fd2f..6ad624298e 100644 --- a/docs/src/API/model_building.md +++ b/docs/src/API/model_building.md @@ -13,8 +13,6 @@ ModelingToolkit.t_nounits ModelingToolkit.D_nounits ModelingToolkit.t ModelingToolkit.D -ModelingToolkit.t_unitful -ModelingToolkit.D_unitful ``` Users are recommended to use the appropriate common definition in their models. The required diff --git a/docs/src/basics/Validation.md b/docs/src/basics/Validation.md index 3f36a06e5e..651cf3f477 100644 --- a/docs/src/basics/Validation.md +++ b/docs/src/basics/Validation.md @@ -155,7 +155,7 @@ future when `ModelingToolkit` is extended to support eliminating `DynamicQuantit ## Other Restrictions -`Unitful` provides non-scalar units such as `dBm`, `°C`, etc. At this time, `ModelingToolkit` only supports scalar quantities. Additionally, angular degrees (`°`) are not supported because trigonometric functions will treat plain numerical values as radians, which would lead systems validated using degrees to behave erroneously when being solved. +`DynamicQuantities` provides non-scalar units such as `°C`, etc. At this time, `ModelingToolkit` only supports scalar quantities. Additionally, angular degrees (`°`) are not supported because trigonometric functions will treat plain numerical values as radians, which would lead systems validated using degrees to behave erroneously when being solved. ## Troubleshooting & Gotchas @@ -169,7 +169,7 @@ Parameter and initial condition values are supplied to problem constructors as p ```julia function remove_units(p::Dict) - Dict(k => Unitful.ustrip(ModelingToolkit.get_unit(k), v) for (k, v) in p) + Dict(k => DynamicQuantities.ustrip(ModelingToolkit.get_unit(k), v) for (k, v) in p) end add_units(p::Dict) = Dict(k => v * ModelingToolkit.get_unit(k) for (k, v) in p) ``` diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 1017923ee6..c78989ce05 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -38,7 +38,7 @@ else const IntDisjointSet = IntDisjointSets end using Base.Threads -using Latexify, Unitful, ArrayInterface +using Latexify, ArrayInterface using Setfield, ConstructionBase import Libdl using DocStringExtensions @@ -96,7 +96,7 @@ export independent_variables, unknowns, observables, parameters, full_parameters @reexport using UnPack RuntimeGeneratedFunctions.init(@__MODULE__) -import DynamicQuantities, Unitful +import DynamicQuantities const DQ = DynamicQuantities import DifferentiationInterface as DI @@ -217,7 +217,6 @@ include("systems/pde/pdesystem.jl") include("systems/sparsematrixclil.jl") include("systems/unit_check.jl") -include("systems/validation.jl") include("systems/dependency_graphs.jl") include("clock.jl") include("discretedomain.jl") @@ -238,15 +237,11 @@ include("deprecations.jl") const t_nounits = let only(@independent_variables t) end -const t_unitful = let - only(@independent_variables t [unit = Unitful.u"s"]) -end const t = let only(@independent_variables t [unit = DQ.u"s"]) end const D_nounits = Differential(t_nounits) -const D_unitful = Differential(t_unitful) const D = Differential(t) export ODEFunction, convert_system_indepvar, @@ -349,7 +344,7 @@ const set_scalar_metadata = setmetadata @public similarity_transform, inputs, outputs, bound_inputs, unbound_inputs, bound_outputs @public unbound_outputs, is_bound @public AbstractSystem, CheckAll, CheckNone, CheckComponents, CheckUnits -@public t, D, t_nounits, D_nounits, t_unitful, D_unitful +@public t, D, t_nounits, D_nounits @public SymbolicContinuousCallback, SymbolicDiscreteCallback @public VariableType, MTKVariableTypeCtx, VariableBounds, VariableConnectType @public VariableDescription, VariableInput, VariableIrreducible, VariableMisc diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 0369b12212..8c6ad113f2 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -573,7 +573,7 @@ function get_t(mod, t) get_var(mod, t) catch e if e isa UndefVarError - @warn("Could not find a predefined `t` in `$mod`; generating a new one within this model.\nConsider defining it or importing `t` (or `t_nounits`, `t_unitful` as `t`) from ModelingToolkit.") + @warn("Could not find a predefined `t` in `$mod`; generating a new one within this model.\nConsider defining it or importing `t` (or `t_nounits as t`) from ModelingToolkit.") variable(:t) else throw(e) @@ -901,18 +901,6 @@ function convert_units( DynamicQuantities.SymbolicUnits.as_quantity(varunits), value)) end -function convert_units(varunits::Unitful.FreeUnits, value) - Unitful.ustrip(varunits, value) -end - -convert_units(::Unitful.FreeUnits, value::NoValue) = NO_VALUE - -function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T} - Unitful.ustrip.(varunits, value) -end - -convert_units(::Unitful.FreeUnits, value::Num) = value - convert_units(::DynamicQuantities.Quantity, value::Num) = value function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types) @@ -930,8 +918,7 @@ function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types) try $setdefault($vv, $convert_units($unit, $name)) catch e - if isa(e, $(DynamicQuantities.DimensionError)) || - isa(e, $(Unitful.DimensionError)) + if isa(e, $(DynamicQuantities.DimensionError)) error("Unable to convert units for \'" * string(:($$vv)) * "\'") elseif isa(e, MethodError) error("No or invalid units provided for \'" * string(:($$vv)) * diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index 839b77094f..64dd1e0290 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -23,8 +23,6 @@ function __get_scalar_unit_type(v) u = __get_literal_unit(v) if u isa DQ.AbstractQuantity return Val(:DynamicQuantities) - elseif u isa Unitful.Unitlike - return Val(:Unitful) end return nothing end diff --git a/src/systems/validation.jl b/src/systems/validation.jl deleted file mode 100644 index ecd98b1d43..0000000000 --- a/src/systems/validation.jl +++ /dev/null @@ -1,287 +0,0 @@ -module UnitfulUnitCheck - -using ..ModelingToolkit, Symbolics, SciMLBase, Unitful, RecursiveArrayTools -using ..ModelingToolkit: ValidationError, - ModelingToolkit, Connection, instream, JumpType, VariableUnit, - get_systems, - Conditional, Comparison -using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump -using Symbolics: SymbolicT, value, issym, isadd, ismul, ispow, CallAndWrap -const MT = ModelingToolkit - -Base.:*(x::Union{Num, SymbolicT}, y::Unitful.AbstractQuantity) = x * y -Base.:/(x::Union{Num, SymbolicT}, y::Unitful.AbstractQuantity) = x / y - -""" -Throw exception on invalid unit types, otherwise return argument. -""" -function screen_unit(result) - result isa Unitful.Unitlike || - throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result)).")) - result isa Unitful.ScalarUnits || - throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead.")) - result == u"°" && - throw(ValidationError("Degrees are not supported. Use radians instead.")) - result -end - -""" -Test unit equivalence. - -Example of implemented behavior: - -```julia -using ModelingToolkit, Unitful -MT = ModelingToolkit -@parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"] -@test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes -@test !MT.equivalent(u"m", u"cm") # Units must be same magnitude -@test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents -``` -""" -equivalent(x, y) = isequal(1 * x, 1 * y) -const unitless = Unitful.unit(1) - -""" -Find the unit of a symbolic item. -""" -get_unit(x::Real) = unitless -get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x)) -get_unit(x::AbstractArray) = map(get_unit, x) -get_unit(x::Num) = get_unit(value(x)) -function get_unit(x::Union{Symbolics.Arr, CallAndWrap}) - get_literal_unit(x) -end -get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) -get_unit(op::typeof(getindex), args) = get_unit(args[1]) -get_unit(x::SciMLBase.NullParameters) = unitless -get_unit(op::typeof(instream), args) = get_unit(args[1]) - -get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless)) - -function get_unit(op, args) # Fallback - result = op(1 .* get_unit.(args)...) - try - unit(result) - catch - throw(ValidationError("Unable to get unit for operation $op with arguments $args.")) - end -end - -function get_unit(op::Integral, args) - unit = 1 - if op.domain.variables isa Vector - for u in op.domain.variables - unit *= get_unit(u) - end - else - unit *= get_unit(op.domain.variables) - end - return get_unit(args[1]) * unit -end - -function get_unit(op::Conditional, args) - terms = get_unit.(args) - terms[1] == unitless || - throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless.")) - equivalent(terms[2], terms[3]) || - throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match.")) - return terms[2] -end - -function get_unit(op::typeof(mapreduce), args) - if args[2] == + - get_unit(args[3]) - else - throw(ValidationError("Unsupported array operation $op")) - end -end - -function get_unit(op::Comparison, args) - terms = get_unit.(args) - equivalent(terms[1], terms[2]) || - throw(ValidationError(", in comparison $op, units [$(terms[1])] and [$(terms[2])] do not match.")) - return unitless -end - -function get_unit(x::SymbolicT) - if issym(x) - get_literal_unit(x) - elseif isadd(x) - terms = get_unit.(arguments(x)) - firstunit = terms[1] - for other in terms[2:end] - termlist = join(map(repr, terms), ", ") - equivalent(other, firstunit) || - throw(ValidationError(", in sum $x, units [$termlist] do not match.")) - end - return firstunit - elseif ispow(x) - pargs = arguments(x) - base, expon = get_unit.(pargs) - @assert expon isa Unitful.DimensionlessUnits - if base == unitless - unitless - else - pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2] - end - elseif iscall(x) - op = operation(x) - if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls - return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i] - elseif iscall(op) && operation(op) === getindex - gp = arguments(op)[1] - return screen_unit(getmetadata(gp, VariableUnit, unitless)) - end # Actual function calls: - args = arguments(x) - return get_unit(op, args) - else # This function should only be reached by Terms, for which `iscall` is true - throw(ArgumentError("Unsupported value $x.")) - end -end - -""" -Get unit of term, returning nothing & showing warning instead of throwing errors. -""" -function safe_get_unit(term, info) - side = nothing - try - side = get_unit(term) - catch err - if err isa Unitful.DimensionError - @warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.") - elseif err isa ValidationError - @warn(info*err.message) - elseif err isa MethodError - @warn("$info: no method matching $(err.f) for arguments $(typeof.(err.args)).") - else - rethrow() - end - end - side -end - -function _validate(terms::Vector, labels::Vector{String}; info::String = "") - valid = true - first_unit = nothing - first_label = nothing - for (term, label) in zip(terms, labels) - equnit = safe_get_unit(term, info * label) - if equnit === nothing - valid = false - elseif !isequal(term, 0) - if first_unit === nothing - first_unit = equnit - first_label = label - elseif !equivalent(first_unit, equnit) - valid = false - @warn("$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match.") - end - end - end - valid -end - -function _validate(conn::Connection; info::String = "") - valid = true - syss = get_systems(conn) - sys = first(syss) - unks = unknowns(sys) - for i in 2:length(syss) - s = syss[i] - _unks = unknowns(s) - if length(unks) != length(_unks) - valid = false - @warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) have $(length(unks)) and $(length(_unks)) unknowns, cannot connect.") - continue - end - for (i, x) in enumerate(unks) - j = findfirst(isequal(x), _unks) - if j == nothing - valid = false - @warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) do not have the same unknowns.") - else - aunit = safe_get_unit(x, info * string(nameof(sys)) * "#$i") - bunit = safe_get_unit(_unks[j], info * string(nameof(s)) * "#$j") - if !equivalent(aunit, bunit) - valid = false - @warn("$info: connected system unknowns $x and $(_unks[j]) have mismatched units.") - end - end - end - end - valid -end - -function validate(jump::Union{MT.VariableRateJump, - MT.ConstantRateJump}, t::SymbolicT; - info::String = "") - newinfo = replace(info, "eq." => "jump") - _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units - validate(jump.affect!, info = newinfo) -end - -function validate(jump::MT.MassActionJump, t::SymbolicT; info::String = "") - left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols - net_symbols = [x[1] for x in jump.net_stoch] - all_symbols = vcat(left_symbols, net_symbols) - allgood = _validate(all_symbols, string.(all_symbols); info) - n = sum(x -> x[2], jump.reactant_stoch, init = 0) - base_unitful = all_symbols[1] #all same, get first - allgood && _validate([jump.scaled_rates, 1 / (t * base_unitful^n)], - ["scaled_rates", "1/(t*reactants^$n))"]; info) -end - -function validate(jumps::Vector{JumpType}, t::SymbolicT) - labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"] - majs = filter(x -> x isa MassActionJump, jumps) - crjs = filter(x -> x isa ConstantRateJump, jumps) - vrjs = filter(x -> x isa VariableRateJump, jumps) - splitjumps = [majs, crjs, vrjs] - all([validate(js, t; info) for (js, info) in zip(splitjumps, labels)]) -end - -function validate(eq::MT.Equation; info::String = "") - if typeof(eq.lhs) == Connection - _validate(eq.rhs; info) - else - _validate([eq.lhs, eq.rhs], ["left", "right"]; info) - end -end -function validate(eq::MT.Equation, - term::Union{SymbolicT, Unitful.Quantity, Num}; info::String = "") - _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) -end -function validate(eq::MT.Equation, terms::Vector; info::String = "") - _validate(vcat([eq.lhs, eq.rhs], terms), - vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info) -end - -""" -Returns true iff units of equations are valid. -""" -function validate(eqs::Vector; info::String = "") - all([validate(eqs[idx], info = info * " in eq. #$idx") for idx in 1:length(eqs)]) -end -function validate(eqs::Vector, noise::Vector; info::String = "") - all([validate(eqs[idx], noise[idx], info = info * " in eq. #$idx") - for idx in 1:length(eqs)]) -end -function validate(eqs::Vector, noise::Matrix; info::String = "") - all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx") - for idx in 1:length(eqs)]) -end -function validate(eqs::Vector, term::SymbolicT; info::String = "") - all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)]) -end -validate(term::SymbolicT) = safe_get_unit(term, "") !== nothing - -""" -Throws error if units of equations are invalid. -""" -function MT.check_units(::Val{:Unitful}, eqs...) - validate(eqs...) || - throw(ValidationError("Some equations had invalid units. See warnings for details.")) -end - -end # module From 240119dc37ec5e61792f4d5676b729464cf387d8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 11:53:59 +0530 Subject: [PATCH 121/157] fix: handle consts in `find_eq_solvables!` --- src/structural_transformation/utils.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index b2ba2059c3..8f2a6cd6ae 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -267,7 +267,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no a, b, islinear = linear_expansion(term, var) islinear || (all_int_vars = false; continue) - if a isa SymbolicT + if !SU.isconst(a) all_int_vars = false if !allow_symbolic if allow_parameter @@ -282,20 +282,20 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no add_edge!(solvable_graph, ieq, j) continue end - if !(a isa Number) + if !(symtype(a) <: Number) all_int_vars = false continue end # When the expression is linear with numeric `a`, then we can safely # only consider `b` for the following iterations. term = b - if isone(abs(a)) - coeffs === nothing || push!(coeffs, convert(Int, a)) + if SU._isone(abs(a)) + coeffs === nothing || push!(coeffs, convert(Int, unwrap_const(a))) else all_int_vars = false conservative && continue end - if a != 0 + if !SU._iszero(a) add_edge!(solvable_graph, ieq, j) else if may_be_zero From bc517bff43ae82b6fc736aba11fc96a2203dad8c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 11:54:51 +0530 Subject: [PATCH 122/157] refactor: unwrap `AnalysisPoint` inside consts where necessary --- src/systems/abstractsystem.jl | 2 +- src/systems/analysis_points.jl | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 6810649ab3..9886f8d4c9 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -980,7 +980,7 @@ function getvar(sys::AbstractSystem, name::Symbol; namespace = does_namespacing( eq isa Equation || continue lhs = value(eq.lhs) rhs = value(eq.rhs) - if lhs isa AnalysisPoint + if value(lhs) isa AnalysisPoint rhs = rhs::AnalysisPoint nameof(rhs) == name || continue return namespace ? renamespace(sys, rhs) : rhs diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl index 8a1e8f400d..97241cbe62 100644 --- a/src/systems/analysis_points.jl +++ b/src/systems/analysis_points.jl @@ -428,7 +428,7 @@ Search for the analysis point with the given `name` in `get_eqs(sys)`. function analysis_point_index(sys::AbstractSystem, name::Symbol) name = namespace_hierarchy(name)[end] findfirst(get_eqs(sys)) do eq - eq.lhs isa AnalysisPoint && nameof(eq.rhs) == name + value(eq.lhs) isa AnalysisPoint && nameof(value(eq.rhs)::AnalysisPoint) == name end end @@ -540,7 +540,7 @@ function apply_transformation(tf::Break, sys::AbstractSystem) breaksys_eqs = copy(get_eqs(breaksys)) @set! breaksys.eqs = breaksys_eqs - ap = breaksys_eqs[ap_idx].rhs + ap = value(breaksys_eqs[ap_idx].rhs)::AnalysisPoint deleteat!(breaksys_eqs, ap_idx) breaksys = with_analysis_point_ignored(breaksys, ap) @@ -598,7 +598,7 @@ function apply_transformation(tf::GetInput, sys::AbstractSystem) error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).") # get the analysis point ap_sys_eqs = get_eqs(ap_sys) - ap = ap_sys_eqs[ap_idx].rhs + ap = value(ap_sys_eqs[ap_idx].rhs)::AnalysisPoint # input variable ap_ivar = ap_var(ap.input) @@ -653,7 +653,7 @@ function apply_transformation(tf::PerturbOutput, sys::AbstractSystem) # modified equations ap_sys_eqs = copy(get_eqs(ap_sys)) @set! ap_sys.eqs = ap_sys_eqs - ap = ap_sys_eqs[ap_idx].rhs + ap = value(ap_sys_eqs[ap_idx].rhs)::AnalysisPoint # remove analysis point deleteat!(ap_sys_eqs, ap_idx) ap_sys = with_analysis_point_ignored(ap_sys, ap) @@ -723,7 +723,7 @@ function apply_transformation(tf::AddVariable, sys::AbstractSystem) ap_idx === nothing && error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).") ap_sys_eqs = get_eqs(ap_sys) - ap = ap_sys_eqs[ap_idx].rhs + ap = value(ap_sys_eqs[ap_idx].rhs)::AnalysisPoint # add equations involving new variable ap_ivar = ap_var(ap.input) From e408e062d953fef6dfa19a6dc7b1af3dff2d9a50 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 11:55:20 +0530 Subject: [PATCH 123/157] refactor: make `AnalysisPoint` more type-stable --- src/ModelingToolkit.jl | 2 +- src/systems/analysis_points.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index c78989ce05..6ae5debdf0 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -174,10 +174,10 @@ include("systems/model_parsing.jl") include("systems/connectiongraph.jl") include("systems/connectors.jl") include("systems/state_machines.jl") -include("systems/analysis_points.jl") include("systems/imperative_affect.jl") include("systems/callbacks.jl") include("systems/system.jl") +include("systems/analysis_points.jl") include("systems/codegen_utils.jl") include("problems/docs.jl") include("systems/codegen.jl") diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl index 97241cbe62..d25f717d72 100644 --- a/src/systems/analysis_points.jl +++ b/src/systems/analysis_points.jl @@ -71,7 +71,7 @@ struct AnalysisPoint The outputs of the connection. In the context of ModelingToolkitStandardLibrary.jl, these are all `RealInput` connectors. """ - outputs::Union{Nothing, Vector{Any}} + outputs::Union{Nothing, Vector{System}, Vector{SymbolicT}} function AnalysisPoint(input, name::Symbol, outputs; verbose = true) # input to analysis point should be an output variable @@ -230,7 +230,7 @@ typically is not (unless the model is an inverse model). warning if you are analyzing an inverse model. """ function connect(in::AbstractSystem, name::Symbol, out, outs...; verbose = true) - return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose) + return AnalysisPoint() ~ AnalysisPoint(in, name, System[out; collect(outs)]; verbose) end function connect( From 13bdba90b517916c621fc03e58b828160d0c2666 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 11:55:26 +0530 Subject: [PATCH 124/157] fixup! fix: improve type-stability of connection infrastructure --- src/systems/connectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 49f1354a26..438b7231ed 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -252,7 +252,7 @@ function validate_causal_variables_connection(allvars::Vector{SymbolicT}) end non_causal_variables = SymbolicT[] for x in allvars - isinput(x) || isoutput(x) || continue + (isinput(x) || isoutput(x)) && continue push!(non_causal_variables, x) end isempty(non_causal_variables) || throw(NonCausalVariableError(non_causal_variables)) From bdf423e65b7d9accf3d510c71cae9574353837fa Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 11:55:44 +0530 Subject: [PATCH 125/157] fix: fix parameter default parsing in `@mtkmodel` --- src/systems/model_parsing.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 8c6ad113f2..b2f9ad4736 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -259,6 +259,9 @@ function unit_handled_variable_value(meta, varname) return varval end +no_value_default_to_nothing(::NoValue) = nothing +no_value_default_to_nothing(x) = x + # This function parses various variable/parameter definitions. # # The comments indicate the syntax matched by a block; either when parsed directly @@ -336,17 +339,17 @@ Base.@nospecializeinfer function parse_variable_def!( unit_handled_variable_value(meta, varname) if varclass == :parameters Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@parameters ($a[$(indices...)]::$type = $varval), + var = :($varname = $first(@parameters ($a[$(indices...)]::$type = $no_value_default_to_nothing($varval)), $meta_val)) elseif varclass == :constants Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@constants ($a[$(indices...)]::$type = $varval), + var = :($varname = $first(@constants ($a[$(indices...)]::$type = $no_value_default_to_nothing($varval)), $meta_val)) else Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@variables ($a[$(indices)]::$type = $varval), + var = :($varname = $first(@variables ($a[$(indices)]::$type = $no_value_default_to_nothing($varval)), $meta_val)) end update_array_kwargs_and_metadata!( From 81681e14cfced526384bc638a31f1c822e919d81 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 11:55:52 +0530 Subject: [PATCH 126/157] fixup! refactor: get `System` to precompile in a trivial case --- src/systems/system.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/systems/system.jl b/src/systems/system.jl index 8b2070075c..baaf8dc085 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -298,18 +298,18 @@ struct System <: IntermediateDeprecationSystem check_equations(equations(continuous_events), iv) check_subsystems(systems) end - # if checks == true || (checks & CheckUnits) > 0 - # u = __get_unit_type(unknowns, ps, iv) - # if noise_eqs === nothing - # check_units(u, eqs) - # else - # check_units(u, eqs, noise_eqs) - # end - # if iv !== nothing - # check_units(u, jumps, iv) - # end - # isempty(constraints) || check_units(u, constraints) - # end + if checks == true || (checks & CheckUnits) > 0 + u = __get_unit_type(unknowns, ps, iv) + if noise_eqs === nothing + check_units(u, eqs) + else + check_units(u, eqs, noise_eqs) + end + if iv !== nothing + check_units(u, jumps, iv) + end + isempty(constraints) || check_units(u, constraints) + end new(tag, eqs, noise_eqs, jumps, constraints, costs, consolidate, unknowns, ps, brownians, iv, observed, parameter_dependencies, var_to_name, name, description, defaults, From 0a9ad5723bc00e71077ca344d4f1da1a33be0ede Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 11:56:05 +0530 Subject: [PATCH 127/157] fix: handle symbolic consts in unit checking --- src/systems/unit_check.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index 64dd1e0290..050da35edb 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -127,7 +127,9 @@ function get_unit(op::Comparison, args) end function get_unit(x::SymbolicT) - if (u = __get_literal_unit(x)) !== nothing + if isconst(x) + return get_unit(value(x)) + elseif (u = __get_literal_unit(x)) !== nothing screen_unit(u) elseif issym(x) get_literal_unit(x) @@ -147,7 +149,7 @@ function get_unit(x::SymbolicT) if base == unitless unitless else - pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2] + isconst(pargs[2]) ? base^unwrap_const(pargs[2]) : (1 * base)^pargs[2] end elseif iscall(x) op = operation(x) @@ -193,7 +195,7 @@ function _validate(terms::Vector, labels::Vector{String}; info::String = "") equnit = safe_get_unit(term, info * label) if equnit === nothing valid = false - elseif !isequal(term, 0) + elseif !SU._iszero(term) if first_unit === nothing first_unit = equnit first_label = label @@ -274,8 +276,8 @@ function validate(jumps::Vector{JumpType}, t::SymbolicT) end function validate(eq::Union{Inequality, Equation}; info::String = "") - if typeof(eq.lhs) == Connection - _validate(eq.rhs; info) + if isconst(eq.lhs) && value(eq.lhs) isa Connection + _validate(value(eq.rhs)::Connection; info) else _validate([eq.lhs, eq.rhs], ["left", "right"]; info) end From 9a3a0d6a2f530a6cbb9e3e757a9341065b5c6467 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 11:56:18 +0530 Subject: [PATCH 128/157] fixup! fixup! fix: make `flatten_equations` type-stable --- src/utils.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 53d3fd0245..6eb6e427a1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1182,8 +1182,14 @@ without scalarizing occurrences of array variables and return the new list of eq function flatten_equations(eqs::Vector{Equation}) _eqs = Equation[] for eq in eqs - for (i1, i2) in zip(SU.stable_eachindex(eq.lhs), SU.stable_eachindex(eq.rhs)) - push!(_eqs, eq.lhs[i1] ~ eq.rhs[i2]) + if !SU.is_array_shape(SU.shape(eq.lhs)) + push!(_eqs, eq) + continue + end + lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT} + rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT} + for (l, r) in zip(lhs, rhs) + push!(_eqs, l ~ r) end end return _eqs From 83faba1325365e3751fd69341b03f00e46b446fe Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:33:07 +0530 Subject: [PATCH 129/157] fix: handle wrapped constants in `dummy_derivative_graph` --- src/structural_transformation/partial_state_selection.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index ab8e7f0f3d..5b205f414b 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -77,10 +77,12 @@ function dummy_derivative_graph!( # only accept small integers to avoid overflow is_all_small_int = all(_J) do x′ x = unwrap(x′) + SU.isconst(x) || return false x isa Number || return false - isinteger(x) && typemin(Int8) <= x <= typemax(Int8) + x = value(x) + isinteger(x) && typemin(Int8) <= Int(x) <= typemax(Int8) end - J = is_all_small_int ? Int.(unwrap.(_J)) : nothing + J = is_all_small_int ? Int.(value.(_J)) : nothing end while true nrows = length(eqs) From 6804adb1534fdad1899f0407167cc0e21c7e8462 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:33:52 +0530 Subject: [PATCH 130/157] fix: better handle absent parameter derivatives in simplification --- src/structural_transformation/symbolics_tearing.jl | 4 +--- src/systems/systemstructure.jl | 14 ++++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index d3f9a4c5da..2646e94722 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -72,9 +72,7 @@ function eq_derivative!(ts::TearingState, ieq::Int; kwargs...) vs = Set{SymbolicT}() SU.search_variables!(vs, eq.rhs) for v in vs - # parameters with unknown derivatives have a value of `nothing` in the map, - # so use `missing` as the default. - get(ts.param_derivative_map, v, missing) === nothing || continue + v in ts.no_deriv_params || continue _original_eq = equations(ts)[ieq] error(""" Encountered derivative of discrete variable `$(only(arguments(v)))` when \ diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 45b483db63..82ac0d8a8a 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -209,6 +209,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} structure::SystemStructure extra_eqs::Vector{Equation} param_derivative_map::Dict{SymbolicT, SymbolicT} + no_deriv_params::Set{SymbolicT} original_eqs::Vector{Equation} """ Additional user-provided observed equations. The variables calculated here @@ -362,6 +363,7 @@ function TearingState(sys; check = true, sort_eqs = true) original_eqs = copy(eqs) neqs = length(eqs) param_derivative_map = Dict{SymbolicT, SymbolicT}() + no_deriv_params = Set{SymbolicT}() fullvars = SymbolicT[] # * Scalarize unknowns dvs = Set{SymbolicT}() @@ -380,7 +382,7 @@ function TearingState(sys; check = true, sort_eqs = true) varsbuf = Set{SymbolicT}() eqs_to_retain = trues(length(eqs)) for (i, eq) in enumerate(eqs) - eq, is_statemachine_equation = canonicalize_eq!(param_derivative_map, eqs_to_retain, ps, iv, i, eq) + eq, is_statemachine_equation = canonicalize_eq!(param_derivative_map, no_deriv_params, eqs_to_retain, ps, iv, i, eq) empty!(varsbuf) SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{SU.Operator}()) incidence = Set{SymbolicT}() @@ -396,7 +398,7 @@ function TearingState(sys; check = true, sort_eqs = true) if symbolic_contains(v, ps) || getmetadata(v, SymScope, LocalScope()) isa GlobalScope && isparameter(v) if is_time_dependent_parameter(v, ps, iv) && - !haskey(param_derivative_map, Differential(iv)(v)) + !haskey(param_derivative_map, Differential(iv)(v)) && !(Differential(iv)(v) in no_deriv_params) # Parameter derivatives default to zero - they stay constant # between callbacks param_derivative_map[Differential(iv)(v)] = Symbolics.COMMON_ZERO @@ -480,6 +482,8 @@ function TearingState(sys; check = true, sort_eqs = true) push!(symbolic_incidence, collect(incidence)) end + filter!(Base.Fix2(!==, COMMON_NOTHING) ∘ last, param_derivative_map) + eqs = eqs[eqs_to_retain] original_eqs = original_eqs[eqs_to_retain] neqs = length(eqs) @@ -520,7 +524,7 @@ function TearingState(sys; check = true, sort_eqs = true) return TearingState{typeof(sys)}(sys, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, false), - Equation[], param_derivative_map, original_eqs, Equation[], typeof(sys)[]) + Equation[], param_derivative_map, no_deriv_params, original_eqs, Equation[], typeof(sys)[]) end function sort_fullvars(fullvars::Vector{SymbolicT}, dervaridxs::Vector{Int}, var_types::Vector{VariableType}, @nospecialize(iv::Union{SymbolicT, Nothing})) @@ -594,7 +598,7 @@ function collect_vars_to_set!(buffer::Set{SymbolicT}, vars::Vector{SymbolicT}) end end -function canonicalize_eq!(param_derivative_map::Dict{SymbolicT, SymbolicT}, eqs_to_retain::BitVector, ps::Set{SymbolicT}, @nospecialize(iv::Union{Nothing, SymbolicT}), i::Int, eq::Equation) +function canonicalize_eq!(param_derivative_map::Dict{SymbolicT, SymbolicT}, no_deriv_params::Set{SymbolicT}, eqs_to_retain::BitVector, ps::Set{SymbolicT}, @nospecialize(iv::Union{Nothing, SymbolicT}), i::Int, eq::Equation) is_statemachine_equation = false lhs = eq.lhs rhs = eq.rhs @@ -612,6 +616,8 @@ function canonicalize_eq!(param_derivative_map::Dict{SymbolicT, SymbolicT}, eqs_ else # change the equation if the RHS is `missing` so the rest of this loop works eq = Symbolics.COMMON_ZERO ~ Symbolics.COMMON_ZERO + push!(no_deriv_params, lhs) + delete!(param_derivative_map, lhs) end eqs_to_retain[i] = false end From 89f1803c1d53b35d6a10697e6da41a44c734727f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:33:58 +0530 Subject: [PATCH 131/157] fixup! fix: improve type-stability of `tearing_reassemble` --- src/structural_transformation/symbolics_tearing.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 2646e94722..f599c7e455 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -325,15 +325,17 @@ Effects on the system structure: """ function generate_derivative_variables!( ts::TearingState, neweqs, var_eq_matching, full_var_eq_matching, - var_sccs, mm::SparseMatrixCLIL{T, Int}, iv::Union{SymbolicT, Nothing}) where {T} + var_sccs, mm::Union{Nothing, SparseMatrixCLIL}, iv::Union{SymbolicT, Nothing}) @unpack fullvars, sys, structure = ts @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure eq_var_matching = invview(var_eq_matching) diff_to_var = invview(var_to_diff) is_discrete = is_only_discrete(structure) linear_eqs = Dict{Int, Int}() - for (i, e) in enumerate(mm.nzrows) - linear_eqs[e] = i + if mm !== nothing + for (i, e) in enumerate(mm.nzrows) + linear_eqs[e] = i + end end # We need the inverse mapping of `var_sccs` to update it efficiently later. From 2883d732718a7724dde175a9e1d8f18cb6444627 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:34:12 +0530 Subject: [PATCH 132/157] fix: handle wrapped constants in `find_eq_solvables!` --- src/structural_transformation/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 8f2a6cd6ae..c0807ae604 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -239,7 +239,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no fullvars = state.fullvars @unpack graph, solvable_graph = state.structure eq = equations(state)[ieq] - term = value(eq.rhs - eq.lhs) + term = unwrap(eq.rhs - eq.lhs) all_int_vars = true coeffs === nothing || empty!(coeffs) empty!(to_rm) @@ -289,7 +289,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no # When the expression is linear with numeric `a`, then we can safely # only consider `b` for the following iterations. term = b - if SU._isone(abs(a)) + if SU._isone(abs(unwrap_const(a))) coeffs === nothing || push!(coeffs, convert(Int, unwrap_const(a))) else all_int_vars = false From 8328947b431d793cdfb48595a132e1609e194927 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:34:34 +0530 Subject: [PATCH 133/157] fix: minor bug fix in initsys generation --- src/systems/nonlinear/initializesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 68d5b8f26e..fdde6ed32a 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -77,7 +77,7 @@ function generate_initializesystem_timevarying(sys::AbstractSystem; # PREPROCESSING op = anydict(op) if isempty(op) - op = copy(defs) + op = anydict(copy(defs)) end scalarize_vars_in_varmap!(op, arrvars) u0map = anydict() From ed1895cb6125bcde511df5e4b3f710817018cce4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:34:52 +0530 Subject: [PATCH 134/157] refactor: avoid using `vars!` in initsys generation --- src/systems/nonlinear/initializesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index fdde6ed32a..ea6ad2f00b 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -268,7 +268,7 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem; vs = Set() initialization_eqs = filter(initialization_eqs) do eq empty!(vs) - vars!(vs, eq; op = Initial) + SU.search_variables!(vs, eq; is_atomic = OperatorIsAtomic{Initial}()) allpars = full_parameters(sys) for p in allpars if symbolic_type(p) == ArraySymbolic() && SU.shape(p) isa SU.Unknown From af9488965345471b7bfbe82cd32c7abc7e9b134c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:35:06 +0530 Subject: [PATCH 135/157] fix: unwrap symbolic constants in initsys generation --- src/systems/nonlinear/initializesystem.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index ea6ad2f00b..9af6e54367 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -498,13 +498,13 @@ end function get_possibly_array_fallback_singletons(varmap, p) if haskey(varmap, p) - return varmap[p] + return value(varmap[p]) end if symbolic_type(p) == ArraySymbolic() symbolic_has_known_size(p) || return nothing scal = collect(p) if all(x -> haskey(varmap, x), scal) - res = [varmap[x] for x in scal] + res = [value(varmap[x]) for x in scal] if any(x -> x === nothing, res) return nothing elseif any(x -> x === missing, res) From f1b8c49e17d8b4331cccff4227b245223790c0e1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:35:10 +0530 Subject: [PATCH 136/157] fixup! fix: improve type-stability of some SII functions --- src/systems/abstractsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 9886f8d4c9..73ba6712b0 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -175,7 +175,7 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Union{ end function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Int) - sym in 1:length(parameter_symbols(sys)) + !is_split(sys) && sym in 1:length(parameter_symbols(sys)) end function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::SymbolicT) From 796899e7d942b3d1e22056be2a92a36c76263076 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:35:16 +0530 Subject: [PATCH 137/157] fixup! fix: improve type-stability of alias elimination --- src/systems/alias_elimination.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index 2a35f59458..8ac742f6be 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -64,7 +64,8 @@ function alias_elimination!(state::TearingState; kwargs...) resize!(eqs, nsrcs(graph)) __trivial_eq_rhs = let fullvars = fullvars - function trivial_eq_rhs(var, coeff) + function trivial_eq_rhs(pair) + var, coeff = pair iszero(coeff) && return Symbolics.COMMON_ZERO return coeff * fullvars[var] end From a171233c03f636772596ad1c5f5f4995b793f35d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:35:40 +0530 Subject: [PATCH 138/157] fixup! fix: fix parameter default parsing in `@mtkmodel` --- src/systems/model_parsing.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index b2f9ad4736..74832f6223 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -376,13 +376,13 @@ Base.@nospecializeinfer function parse_variable_def!( Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $val : $varname; - $varname = $first(@parameters ($a[$(indices...)]::$type = $varval), + $varname = $first(@parameters ($a[$(indices...)]::$type = $no_value_default_to_nothing($varval)), $(def_n_meta...))) elseif varclass == :constants Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $val : $varname; - $varname = $first(@constants ($a[$(indices...)]::$type = $varval), + $varname = $first(@constants ($a[$(indices...)]::$type = $no_value_default_to_nothing($varval)), $(def_n_meta...))) else Meta.isexpr(a, :call) || @@ -390,7 +390,7 @@ Base.@nospecializeinfer function parse_variable_def!( assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $val : $varname; $varname = $first(@variables $a[$(indices...)]::$type = ( - $varval), + $no_value_default_to_nothing($varval)), $(def_n_meta...))) end else @@ -398,18 +398,18 @@ Base.@nospecializeinfer function parse_variable_def!( Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $def_n_meta : $varname; - $varname = $first(@parameters $a[$(indices...)]::$type = $varname)) + $varname = $first(@parameters $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) elseif varclass == :constants Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $def_n_meta : $varname; - $varname = $first(@constants $a[$(indices...)]::$type = $varname)) + $varname = $first(@constants $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) else Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $def_n_meta : $varname; - $varname = $first(@variables $a[$(indices...)]::$type = $varname)) + $varname = $first(@variables $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) end varval, meta = def_n_meta, nothing end @@ -432,15 +432,15 @@ Base.@nospecializeinfer function parse_variable_def!( varname = a isa Expr && a.head == :call ? a.args[1] : a if varclass == :parameters Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@parameters $a[$(indices...)]::$type = $varname)) + var = :($varname = $first(@parameters $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) elseif varclass == :constants Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@constants $a[$(indices...)]::$type = $varname)) + var = :($varname = $first(@constants $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) elseif varclass == :variables Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@variables $a[$(indices...)]::$type = $varname)) + var = :($varname = $first(@variables $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) else throw("Symbolic array with arbitrary length is not handled for $varclass. Please open an issue with an example.") From f4ae3a01a6986e66ece9f54ee06c4a19275b5cf8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:35:51 +0530 Subject: [PATCH 139/157] fix: add edge case for unit handling in `@mtkmodel` --- src/systems/model_parsing.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 74832f6223..89737bf6b9 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -904,6 +904,8 @@ function convert_units( DynamicQuantities.SymbolicUnits.as_quantity(varunits), value)) end +convert_units(::DynamicQuantities.Quantity, value::AbstractArray{Num}) = value + convert_units(::DynamicQuantities.Quantity, value::Num) = value function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types) From 0e54dde643e31e68e87e63e38fc30f4d2a719250 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:36:07 +0530 Subject: [PATCH 140/157] fix: handle new `guesses` type in initprob generation --- src/systems/problem_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index fb937c3aa2..146331c0de 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1162,7 +1162,7 @@ function maybe_build_initialization_problem( nothing end meta = InitializationMetadata( - copy(op), copy(guesses), Vector{Equation}(initialization_eqs), + copy(op), anydict(copy(guesses)), Vector{Equation}(initialization_eqs), use_scc, time_dependent_init, ReconstructInitializeprob( sys, initializeprob.f.sys; u0_constructor, From 65c2e46479bbc6936a69094556e38259036d431b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:36:20 +0530 Subject: [PATCH 141/157] fix: improve type-stability of `System` constructor --- src/systems/system.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/systems/system.jl b/src/systems/system.jl index baaf8dc085..db15f218d6 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -383,6 +383,9 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; preface = [], checks = true) name === nothing && throw(NoNameError()) + if !(systems isa Vector{System}) + systems = Vector{System}(systems) + end if !(eqs isa Vector{Equation}) eqs = Equation[eqs] end @@ -626,8 +629,8 @@ the system. function System(eqs::Vector{Equation}; kwargs...) eqs = collect(eqs) - allunknowns = OrderedSet() - ps = OrderedSet() + allunknowns = OrderedSet{SymbolicT}() + ps = OrderedSet{SymbolicT}() for eq in eqs collect_vars!(allunknowns, ps, eq, nothing) end From 7a59416ae75951c7a1ff8df1b122d63e6f793823 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:36:43 +0530 Subject: [PATCH 142/157] fix: properly hashcons constant global symbolics --- src/ModelingToolkit.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6ae5debdf0..1f91ade436 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -359,6 +359,13 @@ for prop in [SYS_PROPS; [:continuous_events, :discrete_events]] @eval @public $getter, $hasfn end +function __init__() + SU.hashcons(unwrap(t_nounits), true) + SU.hashcons(unwrap(t), true) + SU.hashcons(COMMON_NOTHING, true) + SU.hashcons(COMMON_MISSING, true) +end + PrecompileTools.@compile_workload begin fold1 = Val{false}() using SymbolicUtils From 0b4000c1ac70c99e896d08d06168d62311883b6d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:37:09 +0530 Subject: [PATCH 143/157] fix: improve type-stability of `subexpressions_not_involving_vars!` --- src/utils.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 6eb6e427a1..0c8a00b080 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1054,9 +1054,8 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) # OR # none of `vars` are involved in `expr` if op === getindex && (issym(args[1]) || !iscalledparameter(args[1])) || - (vs = ModelingToolkit.vars(expr); intersect!(vs, vars); isempty(vs)) + (vs = SU.search_variables(expr); intersect!(vs, vars); isempty(vs)) sym = gensym(:subexpr) - stype = symtype(expr) var = similar_variable(expr, sym) state[expr] = var return var @@ -1066,7 +1065,7 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) indep_args = [] dep_args = [] for arg in args - _vs = ModelingToolkit.vars(arg) + _vs = SU.search_variables(arg) intersect!(_vs, vars) if !isempty(_vs) push!(dep_args, subexpressions_not_involving_vars!(arg, vars, state)) From 6595ace18b7fad28ea9c92bb442fd925f94f762d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:37:25 +0530 Subject: [PATCH 144/157] fix: handle `Shift` applied to `Equation` --- src/discretedomain.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/discretedomain.jl b/src/discretedomain.jl index bde06ee64d..d9fb3496f7 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -70,6 +70,9 @@ normalize_to_differential(s::Shift) = Differential(s.t)^s.steps Base.nameof(::Shift) = :Shift SymbolicUtils.isbinop(::Shift) = false +function (D::Shift)(x::Equation, allow_zero = false) + D(x.lhs, allow_zero) ~ D(x.rhs, allow_zero) +end function (D::Shift)(x, allow_zero = false) !allow_zero && D.steps == 0 && return x term(D, x; type = symtype(x), shape = SU.shape(x)) From f7186707bdc9f861132932ae3d9c889da74e6627 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Oct 2025 16:38:15 +0530 Subject: [PATCH 145/157] refactor: reorganize imports to recompile invalidations --- src/ModelingToolkit.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 1f91ade436..b8e80a5a14 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -6,6 +6,10 @@ using PrecompileTools, Reexport @recompile_invalidations begin using StaticArrays using Symbolics + using ImplicitDiscreteSolve + using JumpProcesses + # ONLY here for the invalidations + import REPL end import SymbolicUtils @@ -20,16 +24,16 @@ using DocStringExtensions using SpecialFunctions, NaNMath @recompile_invalidations begin using DiffEqCallbacks +using DiffEqNoiseProcess: DiffEqNoiseProcess, WienerProcess +using DiffEqBase, SciMLBase, ForwardDiff end using Graphs import ExprTools: splitdef, combinedef import OrderedCollections -using DiffEqNoiseProcess: DiffEqNoiseProcess, WienerProcess using SymbolicIndexingInterface using LinearAlgebra, SparseArrays using InteractiveUtils -using JumpProcesses using DataStructures @static if pkgversion(DataStructures) >= v"0.19" import DataStructures: IntDisjointSet @@ -51,7 +55,6 @@ using URIs: URI using SciMLStructures using Compat using AbstractTrees -using DiffEqBase, SciMLBase, ForwardDiff using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain, PeriodicClock, Clock, SolverStepClock, ContinuousClock, OverrideInit, NoInit @@ -60,7 +63,6 @@ using MLStyle import Moshi using Moshi.Data: @data import SCCNonlinearSolve -using ImplicitDiscreteSolve using Reexport using RecursiveArrayTools import Graphs: SimpleDiGraph, add_edge!, incidence_matrix @@ -77,7 +79,7 @@ using Symbolics: degree, VartypeT, SymbolicT using Symbolics: parse_vars, value, @derivatives, get_variables, exprs_occur_in, symbolic_linear_solve, unwrap, wrap, VariableSource, getname, variable, - NAMESPACE_SEPARATOR, setdefaultval, + NAMESPACE_SEPARATOR, setdefaultval, Arr, hasnode, fixpoint_sub, CallAndWrap, SArgsT, SSym, STerm const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) import Symbolics: rename, get_variables!, _solve, hessian_sparsity, From 7a4a3a52bcdf9751628b8ce557b9a7a7983d1c41 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:57:21 +0530 Subject: [PATCH 146/157] fix: fix `pantelides_reassemble` --- src/structural_transformation/pantelides.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/structural_transformation/pantelides.jl b/src/structural_transformation/pantelides.jl index 581c1f1198..ad2dff0282 100644 --- a/src/structural_transformation/pantelides.jl +++ b/src/structural_transformation/pantelides.jl @@ -2,6 +2,8 @@ ### Reassemble: structural information -> system ### +const NOTHING_EQ = nothing ~ nothing + function pantelides_reassemble(state::TearingState, var_eq_matching) fullvars = state.fullvars @unpack var_to_diff, eq_to_diff = state.structure @@ -9,11 +11,11 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) # Step 1: write derivative equations in_eqs = equations(sys) out_eqs = Vector{Equation}(undef, nv(eq_to_diff)) - fill!(out_eqs, nothing) + fill!(out_eqs, NOTHING_EQ) out_eqs[1:length(in_eqs)] .= in_eqs out_vars = Vector{SymbolicT}(undef, nv(var_to_diff)) - fill!(out_vars, nothing) + fill!(out_vars, ModelingToolkit.COMMON_NOTHING) out_vars[1:length(fullvars)] .= fullvars iv = get_iv(sys) @@ -22,7 +24,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) for (varidx, diff) in edges(var_to_diff) # fullvars[diff] = D(fullvars[var]) vi = out_vars[varidx] - @assert vi!==nothing "Something went wrong on reconstructing unknowns from variable association list" + @assert vi!==ModelingToolkit.COMMON_NOTHING "Something went wrong on reconstructing unknowns from variable association list" # `fullvars[i]` needs to be not a `D(...)`, because we want the DAE to be # first-order. if isdifferential(vi) @@ -36,8 +38,8 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) # LHS variable is looked up from var_to_diff # the var_to_diff[i]-th variable is the differentiated version of var at i eq = out_eqs[eqidx] - lhs = if !(eq.lhs isa SymbolicT) - 0 + lhs = if SU.isconst(eq.lhs) + Symbolics.COMMON_ZERO elseif isdiffeq(eq) # look up the variable that represents D(lhs) lhsarg1 = arguments(eq.lhs)[1] @@ -47,7 +49,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) D(eq.lhs) else # remove clashing equations - lhs = Num(nothing) + lhs = ModelingToolkit.COMMON_NOTHING end else D(eq.lhs) @@ -55,14 +57,14 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) rhs = ModelingToolkit.expand_derivatives(D(eq.rhs)) rhs = substitute(rhs, state.param_derivative_map) substitution_dict = Dict(x.lhs => x.rhs - for x in out_eqs if x !== nothing && x.lhs isa SymbolicT) + for x in out_eqs if x !== NOTHING_EQ && !SU.isconst(eq.lhs)) sub_rhs = substitute(rhs, substitution_dict) out_eqs[diff] = lhs ~ sub_rhs end final_vars = unique(filter(x -> !(operation(x) isa Differential), fullvars)) final_eqs = map(identity, - filter(x -> value(x.lhs) !== nothing, + filter(x -> x.lhs !== ModelingToolkit.COMMON_NOTHING, out_eqs[sort(filter(x -> x !== unassigned, var_eq_matching))])) @set! sys.eqs = final_eqs From c8eebc16a7456eebf8d66cbc80383c2e6007e5af Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:57:26 +0530 Subject: [PATCH 147/157] fixup! fix: improve type-stability of `tearing_reassemble` --- src/structural_transformation/symbolics_tearing.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index f599c7e455..db7ba301df 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -428,7 +428,9 @@ function generate_derivative_variables!( end new_sccs = insert_sccs(var_sccs, sccs_to_insert) - @set! mm.ncols = ndsts(graph) + if mm !== nothing + @set! mm.ncols = ndsts(graph) + end return new_sccs end From 54350fbf69807c8e0abe5be41fc188fcbb4034b4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:57:46 +0530 Subject: [PATCH 148/157] fix: fix `change_independent_variable` --- src/systems/diffeqs/basic_transformations.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index b7baad8861..a73a8358ad 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -424,7 +424,7 @@ end """ change_independent_variable( sys::System, iv, eqs = []; - add_old_diff = false, simplify = true, fold = false + add_old_diff = false, simplify = true, fold = Val(false) ) Transform the independent variable (e.g. ``t``) of the ODE system `sys` to a dependent variable `iv` (e.g. ``u(t)``). @@ -478,7 +478,7 @@ julia> unknowns(M) """ function change_independent_variable( sys::System, iv, eqs = []; - add_old_diff = false, simplify = true, fold = false + add_old_diff = false, simplify = true, fold = Val(false) ) iv2_of_iv1 = unwrap(iv) # e.g. u(t) iv1 = get_iv(sys) # e.g. t @@ -538,11 +538,11 @@ function change_independent_variable( # e.g. (d/dt)(f(t)) -> (d/dt)(f(u(t))) -> df(u(t))/du(t) * du(t)/dt -> df(u)/du * uˍt(u) function transform(ex::T) where {T} # 1) Replace the argument of every function; e.g. f(t) -> f(u(t)) - for var in vars(ex; op = Nothing) # loop over all variables in expression (op = Nothing prevents interpreting "D(f(t))" as one big variable) + for var in SU.search_variables(ex; is_atomic = OperatorIsAtomic{Nothing}()) # loop over all variables in expression (op = Nothing prevents interpreting "D(f(t))" as one big variable) if is_function_of(var, iv1) && !isequal(var, iv2_of_iv1) # of the form f(t)? but prevent e.g. u(t) -> u(u(t)) var_of_iv1 = var # e.g. f(t) - var_of_iv2_of_iv1 = substitute(var_of_iv1, iv1 => iv2_of_iv1) # e.g. f(u(t)) - ex = substitute(ex, var_of_iv1 => var_of_iv2_of_iv1; fold) + var_of_iv2_of_iv1 = substitute(var_of_iv1, iv1 => iv2_of_iv1; filterer = Returns(true)) # e.g. f(u(t)) + ex = substitute(ex, var_of_iv1 => var_of_iv2_of_iv1; fold, filterer = Returns(true)) end end # 2) Repeatedly expand chain rule until nothing changes anymore @@ -550,10 +550,10 @@ function change_independent_variable( while !isequal(ex, orgex) orgex = ex # save original ex = expand_derivatives(ex, simplify) # expand chain rule, e.g. (d/dt)(f(u(t)))) -> df(u(t))/du(t) * du(t)/dt - ex = substitute(ex, D1(iv2_of_iv1) => div2_of_iv2_of_iv1; fold) # e.g. du(t)/dt -> uˍt(u(t)) + ex = substitute(ex, D1(iv2_of_iv1) => div2_of_iv2_of_iv1; fold, filterer = Returns(true)) # e.g. du(t)/dt -> uˍt(u(t)) end # 3) Set new independent variable - ex = substitute(ex, iv2_of_iv1 => iv2; fold) # set e.g. u(t) -> u everywhere + ex = substitute(ex, iv2_of_iv1 => iv2; fold, filterer = Returns(true)) # set e.g. u(t) -> u everywhere ex = substitute(ex, iv1 => iv1_of_iv2; fold) # set e.g. t -> t(u) everywhere return ex::T end From 0d1a8636f92cf285ff39304648794f12b1dec649 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:57:55 +0530 Subject: [PATCH 149/157] fix: unwrap consts in `respecialize` --- src/systems/diffeqs/basic_transformations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index a73a8358ad..f7678f9b81 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -1019,7 +1019,7 @@ function respecialize(sys::AbstractSystem, mapping; all = false) k, v = element else k = element - v = get(final_defs, k, nothing) + v = value(get(final_defs, k, nothing)) @assert v !== nothing """ Parameter $k needs an associated value to be respecialized. """ From 3bac999d8b87cafaca6f70ed26296921e7d3c2b1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:58:22 +0530 Subject: [PATCH 150/157] fix: handle edge case in `add_initialization_parameters` --- src/systems/abstractsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 73ba6712b0..7ff5155921 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -576,7 +576,7 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) for (i, v) in enumerate(initials) initials[i] = Initial()(v) end - @set! sys.ps = unique!([get_ps(sys); initials]) + @set! sys.ps = unique!([filter(!isinitial, get_ps(sys)); initials]) defs = copy(get_defaults(sys)) for ivar in initials if symbolic_type(ivar) == ScalarSymbolic() From 7a52bf6a7af3055700bb74850108ad0cc5df920b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:58:31 +0530 Subject: [PATCH 151/157] fixup! fix: improve type-stability of connection infrastructure --- src/systems/connectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 438b7231ed..5dcef24360 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -1135,7 +1135,7 @@ function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractS svar = inner_streamvars[inner_i] args = SArgsT() push!(args, SU.Const{VartypeT}(Val(n_inner - 1))) - push!(args, SU.Const{VartypeT}(Val(n_outer - 1))) + push!(args, SU.Const{VartypeT}(Val(n_outer))) for i in eachindex(inner_cverts) i == inner_i && continue push!(args, inner_flowvars[i]) From 20a234f9497621697d2d1a27641a64ce21d9aced Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:58:34 +0530 Subject: [PATCH 152/157] fixup! fix: fix invalidations from `promote_symtype` method --- src/systems/connectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index 5dcef24360..182c38e0b5 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -1251,4 +1251,4 @@ function instream_rt(ins::Val{inner_n}, outs::Val{outer_n}, for k in 1:M and ck.m_flow.max > 0 =# end -SymbolicUtils.promote_symtype(::typeof(instream_rt), ::Type{T}, ::Type{S}, ::Type{R}) where {T, S, R} = Real +SymbolicUtils.promote_symtype(::typeof(instream_rt), _...) = Real From 18200ec4c7d5efde9677313dd945db5cd29096d8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:59:06 +0530 Subject: [PATCH 153/157] fix: do not use `vars` in `ImperativeAffect` --- src/systems/imperative_affect.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 7fc0c6abe1..817f725bf2 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -138,7 +138,7 @@ function namespace_affect(affect::ImperativeAffect, s) end function invalid_variables(sys, expr) - filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init = [])) + return SU.get_variables(expr, Set{SymbolicT}(all_symbols(sys))) end function unassignable_variables(sys, expr) From 40658303847d8b535c860024f488138fe5b41dd6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:59:11 +0530 Subject: [PATCH 154/157] fixup! fix: make `IndexCache` constructor more type-stable --- src/systems/index_cache.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index e00a1a250f..2195b9e7ca 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -111,8 +111,8 @@ function IndexCache(sys::AbstractSystem) cevs = continuous_events(sys) devs = discrete_events(sys) events = Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}[cevs; devs] - parse_callbacks_for_discretes!(cevs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, 0) - parse_callbacks_for_discretes!(devs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, length(cevs)) + parse_callbacks_for_discretes!(sys, cevs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, 0) + parse_callbacks_for_discretes!(sys, devs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, length(cevs)) clock_partitions = unique(collect(values(disc_param_callbacks)))::Vector{BitSet} disc_symtypes = Set{TypeT}() for x in keys(disc_param_callbacks) @@ -367,7 +367,7 @@ function insert_by_type!(buffers::Vector{SymbolicT}, sym::SymbolicT, ::TypeT) push!(buffers, sym) end -function parse_callbacks_for_discretes!(events::Vector, disc_param_callbacks::Dict{SymbolicT, BitSet}, constant_buffers::Dict{TypeT, Set{SymbolicT}}, nonnumeric_buffers::Dict{TypeT, Set{SymbolicT}}, offset::Int) +function parse_callbacks_for_discretes!(sys::AbstractSystem, events::Vector, disc_param_callbacks::Dict{SymbolicT, BitSet}, constant_buffers::Dict{TypeT, Set{SymbolicT}}, nonnumeric_buffers::Dict{TypeT, Set{SymbolicT}}, offset::Int) for (i, event) in enumerate(events) discs = Set{SymbolicParam}() affect = event.affect::Union{AffectSystem, ImperativeAffect, Nothing} From 83dc56b06a727acccefbc4649dac723ed52137ae Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 12:59:30 +0530 Subject: [PATCH 155/157] chore: additional precompile statements --- src/ModelingToolkit.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index b8e80a5a14..ada97ac7ad 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -451,10 +451,12 @@ PrecompileTools.@compile_workload begin q[1] q'q using ModelingToolkit - @variables x(ModelingToolkit.t_nounits) + @variables x(ModelingToolkit.t_nounits) y(ModelingToolkit.t_nounits) isequal(ModelingToolkit.D_nounits.x, ModelingToolkit.t_nounits) - sys = System([ModelingToolkit.D_nounits(x) ~ x], ModelingToolkit.t_nounits, [x], Num[]; name = :sys) + sys = System([ModelingToolkit.D_nounits(x) ~ x * y, y ~ 3x + 4 * D(y)], ModelingToolkit.t_nounits, [x, y], Num[]; name = :sys) + TearingState(sys) complete(sys) + mtkcompile(sys) @syms p[1:2] ndims(p) size(p) From 0d14c93e0b8e00bec1831ed512112a4cae5bc84d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 13:00:10 +0530 Subject: [PATCH 156/157] refactor: remove `vars`, `vars!` --- src/utils.jl | 84 ---------------------------------------------------- 1 file changed, 84 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 0c8a00b080..0793bcd717 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -426,90 +426,6 @@ function isvariable(x) hasmetadata(x, VariableSource) || iscall(x) && operation(x) === getindex && isvariable(arguments(x)[1])::Bool end -""" - vars(x; op=Differential) - -Return a `Set` containing all variables in `x` that appear in - - - differential equations if `op = Differential` - -Example: - -``` -t = ModelingToolkit.t_nounits -@variables u(t) y(t) -D = Differential(t) -v = ModelingToolkit.vars(D(y) ~ u) -v == Set([D(y), u]) -``` -""" -function vars(exprs::SymbolicT; op = Differential) - iscall(exprs) ? vars([exprs]; op = op) : Set([exprs]) -end -vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op) -vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op) -function vars(exprs; op = Differential) - if hasmethod(iterate, Tuple{typeof(exprs)}) - foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set()) - else - vars!(Set(), unwrap(exprs); op) - end -end -vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) -function vars!(vars, eq::Equation; op = Differential) - (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) -end -function vars!(vars, O::AbstractSystem; op = Differential) - for eq in equations(O) - vars!(vars, eq; op) - end - return vars -end -function vars!(vars, O; op = Differential) - if isvariable(O) - if iscall(O) && operation(O) === getindex && iscalledparameter(first(arguments(O))) - O = first(arguments(O)) - end - if iscalledparameter(O) - f = getcalledparameter(O) - push!(vars, f) - for arg in arguments(O) - if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray - for el in arg - vars!(vars, unwrap(el); op) - end - else - vars!(vars, arg; op) - end - end - return vars - end - return push!(vars, O) - end - if symbolic_type(O) == NotSymbolic() && O isa AbstractArray - for arg in O - vars!(vars, unwrap(arg); op) - end - return vars - end - !iscall(O) && return vars - - operation(O) isa op && return push!(vars, O) - - if operation(O) === (getindex) - arr = first(arguments(O)) - iscall(arr) && operation(arr) isa op && return push!(vars, O) - isvariable(arr) && return push!(vars, O) - end - - isvariable(operation(O)) && push!(vars, O) - for arg in arguments(O) - vars!(vars, arg; op = op) - end - - return vars -end - function collect_operator_variables(sys::AbstractSystem, args...) collect_operator_variables(equations(sys), args...) end From d4e43960dbb947bbf30b35a57d9fc236ad4541b2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 15 Oct 2025 13:00:31 +0530 Subject: [PATCH 157/157] refactor: improve type-stability of `subexpressions_not_involving_vars!` --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 0793bcd717..ff421e523a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -978,8 +978,8 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) end if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic() - indep_args = [] - dep_args = [] + indep_args = SymbolicT[] + dep_args = SymbolicT[] for arg in args _vs = SU.search_variables(arg) intersect!(_vs, vars)