diff --git a/Project.toml b/Project.toml index 661debc6ce..c4303b57bd 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -33,7 +32,6 @@ FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ImplicitDiscreteSolve = "3263718b-31ed-49cf-8a0f-35a466e8af96" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -68,18 +66,22 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [weakdeps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac" InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816" [extensions] MTKBifurcationKitExt = "BifurcationKit" MTKCasADiDynamicOptExt = "CasADi" +MTKChainRulesCoreExt = "ChainRulesCore" MTKDeepDiffsExt = "DeepDiffs" MTKFMIExt = "FMI" MTKInfiniteOptExt = "InfiniteOpt" +MTKJuliaFormatterExt = "JuliaFormatter" MTKLabelledArraysExt = "LabelledArrays" MTKPyomoDynamicOptExt = "Pyomo" diff --git a/docs/Project.toml b/docs/Project.toml index 499a76a921..db2e22fca3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -32,7 +32,6 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" -Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [compat] Attractors = "1.24" @@ -65,4 +64,3 @@ StochasticDiffEq = "6" SymbolicIndexingInterface = "0.3.1" SymbolicUtils = "3, 4" Symbolics = "6" -Unitful = "1.12" diff --git a/docs/src/API/codegen.md b/docs/src/API/codegen.md index 4f31405174..0092e1cb8e 100644 --- a/docs/src/API/codegen.md +++ b/docs/src/API/codegen.md @@ -50,5 +50,5 @@ ModelingToolkit.calculate_A_b All code generation eventually calls `build_function_wrapper`. ```@docs -build_function_wrapper +ModelingToolkit.build_function_wrapper ``` diff --git a/docs/src/API/model_building.md b/docs/src/API/model_building.md index 7a8389fd2f..6ad624298e 100644 --- a/docs/src/API/model_building.md +++ b/docs/src/API/model_building.md @@ -13,8 +13,6 @@ ModelingToolkit.t_nounits ModelingToolkit.D_nounits ModelingToolkit.t ModelingToolkit.D -ModelingToolkit.t_unitful -ModelingToolkit.D_unitful ``` Users are recommended to use the appropriate common definition in their models. The required diff --git a/docs/src/basics/FAQ.md b/docs/src/basics/FAQ.md index 2511c00675..7b712395b3 100644 --- a/docs/src/basics/FAQ.md +++ b/docs/src/basics/FAQ.md @@ -192,9 +192,7 @@ p, replace, alias = SciMLStructures.canonicalize(Tunable(), prob.p) # changes to the array will be reflected in parameter values ``` -See the [basic example on optimizing](https://docs.sciml.ai/ModelingToolkit/dev/examples/remake/#Optimizing-through-an-ODE-solve-and-re-creating-MTK-Problems) for combining these steps to optimizing parameters and use ForwardDiff.jl as the backend for Automatic Differentiation. - -# ERROR: ArgumentError: SymbolicUtils.BasicSymbolic{Real}[xˍt(t)] are missing from the variable map. +# ERROR: ArgumentError: `[xˍt(t)]` are missing from the variable map. This error can come up after running `mtkcompile` on a system that generates dummy derivatives (i.e. variables with `ˍt`). For example, here even though all the variables are defined with initial values, the `ODEProblem` generation will throw an error that defaults are missing from the variable map. diff --git a/docs/src/basics/Validation.md b/docs/src/basics/Validation.md index 3f36a06e5e..651cf3f477 100644 --- a/docs/src/basics/Validation.md +++ b/docs/src/basics/Validation.md @@ -155,7 +155,7 @@ future when `ModelingToolkit` is extended to support eliminating `DynamicQuantit ## Other Restrictions -`Unitful` provides non-scalar units such as `dBm`, `°C`, etc. At this time, `ModelingToolkit` only supports scalar quantities. Additionally, angular degrees (`°`) are not supported because trigonometric functions will treat plain numerical values as radians, which would lead systems validated using degrees to behave erroneously when being solved. +`DynamicQuantities` provides non-scalar units such as `°C`, etc. At this time, `ModelingToolkit` only supports scalar quantities. Additionally, angular degrees (`°`) are not supported because trigonometric functions will treat plain numerical values as radians, which would lead systems validated using degrees to behave erroneously when being solved. ## Troubleshooting & Gotchas @@ -169,7 +169,7 @@ Parameter and initial condition values are supplied to problem constructors as p ```julia function remove_units(p::Dict) - Dict(k => Unitful.ustrip(ModelingToolkit.get_unit(k), v) for (k, v) in p) + Dict(k => DynamicQuantities.ustrip(ModelingToolkit.get_unit(k), v) for (k, v) in p) end add_units(p::Dict) = Dict(k => v * ModelingToolkit.get_unit(k) for (k, v) in p) ``` diff --git a/ext/MTKCasADiDynamicOptExt.jl b/ext/MTKCasADiDynamicOptExt.jl index addc478d98..b1762e7f7f 100644 --- a/ext/MTKCasADiDynamicOptExt.jl +++ b/ext/MTKCasADiDynamicOptExt.jl @@ -122,7 +122,7 @@ end function MTK.lowered_var(m::CasADiModel, uv, i, t) X = getfield(m, uv) - t isa Union{Num, Symbolics.Symbolic} ? X.u[i, :] : X(t)[i] + t isa Union{Num, SymbolicT} ? X.u[i, :] : X(t)[i] end function MTK.lowered_integral(model::CasADiModel, expr, lo, hi) diff --git a/src/adjoints.jl b/ext/MTKChainRulesCoreExt.jl similarity index 85% rename from src/adjoints.jl rename to ext/MTKChainRulesCoreExt.jl index 98266de938..409cad8704 100644 --- a/src/adjoints.jl +++ b/ext/MTKChainRulesCoreExt.jl @@ -1,3 +1,14 @@ +module MTKChainRulesCoreExt + +import ChainRulesCore +import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk +using ModelingToolkit: MTKParameters, NONNUMERIC_PORTION, AbstractSystem +import ModelingToolkit +import ModelingToolkit as MTK +import SciMLStructures +import SymbolicIndexingInterface: remake_buffer +import SciMLBase: AbstractNonlinearProblem, remake + function ChainRulesCore.rrule(::Type{MTKParameters}, tunables, args...) function mtp_pullback(dt) dt = unthunk(dt) @@ -104,3 +115,11 @@ function ChainRulesCore.rrule( end ChainRulesCore.@non_differentiable Base.getproperty(sys::AbstractSystem, x::Symbol) + +function ModelingToolkit.update_initializeprob!(initprob::AbstractNonlinearProblem, prob) + pgetter = ChainRulesCore.@ignore_derivatives MTK.get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter + p = pgetter(prob, initprob) + return remake(initprob; p) +end + +end diff --git a/ext/MTKInfiniteOptExt.jl b/ext/MTKInfiniteOptExt.jl index e0f02c0436..acb80c1041 100644 --- a/ext/MTKInfiniteOptExt.jl +++ b/ext/MTKInfiniteOptExt.jl @@ -122,7 +122,7 @@ end function MTK.lowered_var(m::InfiniteOptModel, uv, i, t) X = getfield(m, uv) - t isa Union{Num, Symbolics.Symbolic} ? X[i] : X[i](t) + t isa Union{Num, SymbolicT} ? X[i] : X[i](t) end function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau) @@ -256,13 +256,13 @@ for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqr end # JuMP variables and Symbolics variables never compare equal. When tracing through dynamics, a function argument can be either a JuMP variable or A Symbolics variable, it can never be both. -function Base.isequal(::SymbolicUtils.Symbolic, +function Base.isequal(::SymbolicT, ::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr}) false end function Base.isequal( ::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr}, - ::SymbolicUtils.Symbolic) + ::SymbolicT) false end end diff --git a/ext/MTKJuliaFormatterExt.jl b/ext/MTKJuliaFormatterExt.jl new file mode 100644 index 0000000000..a4efeec931 --- /dev/null +++ b/ext/MTKJuliaFormatterExt.jl @@ -0,0 +1,12 @@ +module MTKJuliaFormatterExt + +import ModelingToolkit: readable_code, _readable_code, rec_remove_macro_linenums! +import JuliaFormatter + +function readable_code(expr::Expr) + expr = Base.remove_linenums!(_readable_code(expr)) + rec_remove_macro_linenums!(expr) + JuliaFormatter.format_text(string(expr), JuliaFormatter.SciMLStyle()) +end + +end diff --git a/ext/MTKPyomoDynamicOptExt.jl b/ext/MTKPyomoDynamicOptExt.jl index 5b4e9e7a1c..58888bd974 100644 --- a/ext/MTKPyomoDynamicOptExt.jl +++ b/ext/MTKPyomoDynamicOptExt.jl @@ -5,6 +5,7 @@ using DiffEqBase using UnPack using NaNMath using Setfield +import SymbolicUtils as SU const MTK = ModelingToolkit const SPECIAL_FUNCTIONS_DICT = Dict([acos => Pyomo.py_acos, @@ -53,7 +54,7 @@ struct PyomoDynamicOptProblem{uType, tType, isinplace, P, F, K} <: end end -function pysym_getproperty(s::Union{Num, Symbolics.Symbolic}, name::Symbol) +function pysym_getproperty(s::Union{Num, SymbolicT}, name::Symbol) Symbolics.wrap(SymbolicUtils.term( _getproperty, Symbolics.unwrap(s), Val{name}(), type = Symbolics.Struct{PyomoVar})) end @@ -112,7 +113,7 @@ function MTK.add_constraint!(pmodel::PyomoDynamicOptModel, cons; n_idxs = 1) Symbolics.unwrap(expr), SPECIAL_FUNCTIONS_DICT, fold = false) cons_sym = Symbol("cons", hash(cons)) - if occursin(Symbolics.unwrap(t_sym), expr) + if SU.query(isequal(Symbolics.unwrap(t_sym)), expr) f = eval(Symbolics.build_function(expr, model_sym, t_sym)) setproperty!(model, cons_sym, pyomo.Constraint(model.t, rule = Pyomo.pyfunc(f))) else @@ -124,7 +125,7 @@ end function MTK.set_objective!(pmodel::PyomoDynamicOptModel, expr) @unpack model, model_sym, t_sym, dummy_sym = pmodel expr = Symbolics.substitute(expr, SPECIAL_FUNCTIONS_DICT, fold = false) - if occursin(Symbolics.unwrap(t_sym), expr) + if SU.query(isequal(Symbolics.unwrap(t_sym)), expr) f = eval(Symbolics.build_function(expr, model_sym, t_sym)) model.obj = pyomo.Objective(model.t, rule = Pyomo.pyfunc(f)) else @@ -165,7 +166,7 @@ end function MTK.lowered_var(m::PyomoDynamicOptModel, uv, i, t) X = Symbolics.value(pysym_getproperty(m.model_sym, uv)) - var = t isa Union{Num, Symbolics.Symbolic} ? X[i, m.t_sym] : X[i, t] + var = t isa Union{Num, SymbolicT} ? X[i, m.t_sym] : X[i, t] Symbolics.unwrap(var) end diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6c79eb4fb1..b8e80a5a14 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -6,28 +6,34 @@ using PrecompileTools, Reexport @recompile_invalidations begin using StaticArrays using Symbolics + using ImplicitDiscreteSolve + using JumpProcesses + # ONLY here for the invalidations + import REPL end import SymbolicUtils +import SymbolicUtils as SU import SymbolicUtils: iscall, arguments, operation, maketerm, promote_symtype, - Symbolic, isadd, ismul, ispow, issym, FnType, - @rule, Rewriters, substitute, metadata, BasicSymbolic, - Sym, Term + isadd, ismul, ispow, issym, FnType, isconst, BSImpl, + @rule, Rewriters, substitute, metadata, BasicSymbolic using SymbolicUtils.Code import SymbolicUtils.Code: toexpr import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint using DocStringExtensions using SpecialFunctions, NaNMath -using DiffEqCallbacks +@recompile_invalidations begin + using DiffEqCallbacks +using DiffEqNoiseProcess: DiffEqNoiseProcess, WienerProcess +using DiffEqBase, SciMLBase, ForwardDiff +end using Graphs import ExprTools: splitdef, combinedef import OrderedCollections -using DiffEqNoiseProcess: DiffEqNoiseProcess, WienerProcess using SymbolicIndexingInterface using LinearAlgebra, SparseArrays using InteractiveUtils -using JumpProcesses using DataStructures @static if pkgversion(DataStructures) >= v"0.19" import DataStructures: IntDisjointSet @@ -36,7 +42,7 @@ else const IntDisjointSet = IntDisjointSets end using Base.Threads -using Latexify, Unitful, ArrayInterface +using Latexify, ArrayInterface using Setfield, ConstructionBase import Libdl using DocStringExtensions @@ -49,17 +55,14 @@ using URIs: URI using SciMLStructures using Compat using AbstractTrees -using DiffEqBase, SciMLBase, ForwardDiff using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, TimeDomain, PeriodicClock, Clock, SolverStepClock, ContinuousClock, OverrideInit, NoInit using Distributed -import JuliaFormatter using MLStyle import Moshi using Moshi.Data: @data import SCCNonlinearSolve -using ImplicitDiscreteSolve using Reexport using RecursiveArrayTools import Graphs: SimpleDiGraph, add_edge!, incidence_matrix @@ -68,19 +71,16 @@ import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, bloc using OffsetArrays: Origin import CommonSolve import EnumX -import ChainRulesCore -import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk using RuntimeGeneratedFunctions using RuntimeGeneratedFunctions: drop_expr -using Symbolics: degree -using Symbolics: _parse_vars, value, @derivatives, get_variables, - exprs_occur_in, symbolic_linear_solve, build_expr, unwrap, wrap, +using Symbolics: degree, VartypeT, SymbolicT +using Symbolics: parse_vars, value, @derivatives, get_variables, + exprs_occur_in, symbolic_linear_solve, unwrap, wrap, VariableSource, getname, variable, - NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval, - hasnode, fixpoint_sub, fast_substitute, - CallWithMetadata, CallWithParent + NAMESPACE_SEPARATOR, setdefaultval, Arr, + hasnode, fixpoint_sub, CallAndWrap, SArgsT, SSym, STerm const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) import Symbolics: rename, get_variables!, _solve, hessian_sparsity, jacobian_sparsity, isaffine, islinear, _iszero, _isone, @@ -89,7 +89,7 @@ import Symbolics: rename, get_variables!, _solve, hessian_sparsity, ParallelForm, SerialForm, MultithreadedForm, build_function, rhss, lhss, prettify_expr, gradient, jacobian, hessian, derivative, sparsejacobian, sparsehessian, - substituter, scalarize, getparent, hasderiv, hasdiff + scalarize, hasderiv import DiffEqBase: @add_kwonly export independent_variables, unknowns, observables, parameters, full_parameters, @@ -98,7 +98,7 @@ export independent_variables, unknowns, observables, parameters, full_parameters @reexport using UnPack RuntimeGeneratedFunctions.init(@__MODULE__) -import DynamicQuantities, Unitful +import DynamicQuantities const DQ = DynamicQuantities import DifferentiationInterface as DI @@ -163,6 +163,10 @@ include("parameters.jl") include("independent_variables.jl") include("constants.jl") +const SymmapT = Dict{SymbolicT, SymbolicT} +const COMMON_NOTHING = SU.Const{VartypeT}(nothing) +const COMMON_MISSING = SU.Const{VartypeT}(missing) + include("utils.jl") include("systems/index_cache.jl") @@ -172,10 +176,10 @@ include("systems/model_parsing.jl") include("systems/connectiongraph.jl") include("systems/connectors.jl") include("systems/state_machines.jl") -include("systems/analysis_points.jl") include("systems/imperative_affect.jl") include("systems/callbacks.jl") include("systems/system.jl") +include("systems/analysis_points.jl") include("systems/codegen_utils.jl") include("problems/docs.jl") include("systems/codegen.jl") @@ -215,7 +219,6 @@ include("systems/pde/pdesystem.jl") include("systems/sparsematrixclil.jl") include("systems/unit_check.jl") -include("systems/validation.jl") include("systems/dependency_graphs.jl") include("clock.jl") include("discretedomain.jl") @@ -231,21 +234,16 @@ include("structural_transformation/StructuralTransformations.jl") @reexport using .StructuralTransformations include("inputoutput.jl") -include("adjoints.jl") include("deprecations.jl") const t_nounits = let only(@independent_variables t) end -const t_unitful = let - only(@independent_variables t [unit = Unitful.u"s"]) -end const t = let only(@independent_variables t [unit = DQ.u"s"]) end const D_nounits = Differential(t_nounits) -const D_unitful = Differential(t_unitful) const D = Differential(t) export ODEFunction, convert_system_indepvar, @@ -340,13 +338,15 @@ export AbstractCollocation, JuMPCollocation, InfiniteOptCollocation, CasADiCollocation, PyomoCollocation export DynamicOptSolution +const set_scalar_metadata = setmetadata + @public apply_to_variables, equations_toplevel, unknowns_toplevel, parameters_toplevel @public continuous_events_toplevel, discrete_events_toplevel, assertions, is_alg_equation @public is_diff_equation, Equality, linearize_symbolic, reorder_unknowns @public similarity_transform, inputs, outputs, bound_inputs, unbound_inputs, bound_outputs @public unbound_outputs, is_bound @public AbstractSystem, CheckAll, CheckNone, CheckComponents, CheckUnits -@public t, D, t_nounits, D_nounits, t_unitful, D_unitful +@public t, D, t_nounits, D_nounits @public SymbolicContinuousCallback, SymbolicDiscreteCallback @public VariableType, MTKVariableTypeCtx, VariableBounds, VariableConnectType @public VariableDescription, VariableInput, VariableIrreducible, VariableMisc @@ -361,11 +361,108 @@ for prop in [SYS_PROPS; [:continuous_events, :discrete_events]] @eval @public $getter, $hasfn end +function __init__() + SU.hashcons(unwrap(t_nounits), true) + SU.hashcons(unwrap(t), true) + SU.hashcons(COMMON_NOTHING, true) + SU.hashcons(COMMON_MISSING, true) +end + PrecompileTools.@compile_workload begin - using ModelingToolkit + fold1 = Val{false}() + using SymbolicUtils + using SymbolicUtils: shape + using Symbolics + @syms x y f(t) q[1:5] + SymbolicUtils.Sym{SymReal}(:a; type = Real, shape = SymbolicUtils.ShapeVecT()) + x + y + x * y + x / y + x ^ y + x ^ 5 + 6 ^ x + x - y + -y + 2y + z = 2 + dict = SymbolicUtils.ACDict{VartypeT}() + dict[x] = 1 + dict[y] = 1 + type::typeof(DataType) = rand() < 0.5 ? Real : Float64 + nt = (; type, shape, unsafe = true) + Base.pairs(nt) + BSImpl.AddMul{VartypeT}(1, dict, SymbolicUtils.AddMulVariant.MUL; type, shape = SymbolicUtils.ShapeVecT(), unsafe = true) + *(y, z) + *(z, y) + SymbolicUtils.symtype(y) + f(x) + (5x / 5) + expand((x + y) ^ 2) + simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y) + ex = x + 2y + sin(x) + rules1 = Dict(x => y) + rules2 = Dict(x => 1) + Dx = Differential(x) + Differential(y)(ex) + uex = unwrap(ex) + Symbolics.executediff(Dx, uex) + # Running `fold = Val(true)` invalidates the precompiled statements + # for `fold = Val(false)` and itself doesn't precompile anyway. + # substitute(ex, rules1) + substitute(ex, rules1; fold = fold1) + substitute(ex, rules2; fold = fold1) + @variables foo + f(foo) + @variables x y f(::Real) q[1:5] + x + y + x * y + x / y + x ^ y + x ^ 5 + # 6 ^ x + x - y + -y + 2y + symtype(y) + z = 2 + *(y, z) + *(z, y) + f(x) + (5x / 5) + [x, y] + [x, f, f] + promote_type(Int, Num) + promote_type(Real, Num) + promote_type(Float64, Num) + # expand((x + y) ^ 2) + # simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y) + ex = x + 2y + sin(x) + rules1 = Dict(x => y) + # rules2 = Dict(x => 1) + # Running `fold = Val(true)` invalidates the precompiled statements + # for `fold = Val(false)` and itself doesn't precompile anyway. + # substitute(ex, rules1) + substitute(ex, rules1; fold = fold1) + Symbolics.linear_expansion(ex, y) + # substitute(ex, rules2; fold = fold1) + # substitute(ex, rules2) + # substitute(ex, rules1; fold = fold2) + # substitute(ex, rules2; fold = fold2) + q[1] + q'q + using ModelingToolkit @variables x(ModelingToolkit.t_nounits) - @named sys = System([ModelingToolkit.D_nounits(x) ~ -x], ModelingToolkit.t_nounits) - prob = ODEProblem(mtkcompile(sys), [x => 30.0], (0, 100), jac = true) + isequal(ModelingToolkit.D_nounits.x, ModelingToolkit.t_nounits) + sys = System([ModelingToolkit.D_nounits(x) ~ x], ModelingToolkit.t_nounits, [x], Num[]; name = :sys) + complete(sys) + @syms p[1:2] + ndims(p) + size(p) + axes(p) + length(p) + v = [p] + isempty(v) + # mtkcompile(sys) @mtkmodel __testmod__ begin @constants begin c = 1.0 @@ -396,4 +493,17 @@ PrecompileTools.@compile_workload begin end end +precompile(Tuple{typeof(Base.merge), NamedTuple{(:f, :args, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, SymbolicUtils.SmallVec{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, Array{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, 1}}, Nothing, UInt64, UInt64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, SymbolicUtils.IDType}}, NamedTuple{(:metadata,), Tuple{Base.ImmutableDict{DataType, Any}}}}) +precompile(Tuple{typeof(Base.merge), NamedTuple{(:f, :args, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, SymbolicUtils.SmallVec{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, Array{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, 1}}, Base.ImmutableDict{DataType, Any}, UInt64, UInt64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, SymbolicUtils.IDType}}, NamedTuple{(:id, :hash, :hash2), Tuple{Nothing, Int64, Int64}}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:f, :args, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, SymbolicUtils.SmallVec{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, Array{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, 1}}, Base.ImmutableDict{DataType, Any}, Int64, Int64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, Nothing}}, Type{SymbolicUtils.BasicSymbolicImpl.Term{SymbolicUtils.SymReal}}}) +precompile(Tuple{typeof(Symbolics.parse_vars), Symbol, Type, Tuple{Symbol, Symbol}, Function}) +precompile(Tuple{typeof(Base.merge), NamedTuple{(:name, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{Symbol, Base.ImmutableDict{DataType, Any}, UInt64, UInt64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, SymbolicUtils.IDType}}, NamedTuple{(:metadata,), Tuple{Base.ImmutableDict{DataType, Any}}}}) +precompile(Tuple{typeof(Base.vect), Symbolics.Equation, Vararg{Symbolics.Equation}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:name, :defaults), Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}}}, Type{ModelingToolkit.System}, Array{Symbolics.Equation, 1}, Symbolics.Num, Array{Symbolics.Num, 1}, Array{Symbolics.Num, 1}}) +precompile(Tuple{Type{NamedTuple{(:name, :defaults), T} where T<:Tuple}, Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}}}) +precompile(Tuple{typeof(SymbolicUtils.isequal_somescalar), Float64, Float64}) +precompile(Tuple{Type{NamedTuple{(:name, :defaults, :guesses), T} where T<:Tuple}, Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}, Base.Dict{Symbolics.Num, Float64}}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:name, :defaults, :guesses), Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}, Base.Dict{Symbolics.Num, Float64}}}, Type{ModelingToolkit.System}, Array{Symbolics.Equation, 1}, Symbolics.Num, Array{Symbolics.Num, 1}, Array{Symbolics.Num, 1}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:type, :shape), Tuple{DataType, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}}}, typeof(SymbolicUtils.term), Any, SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}}) + end # module diff --git a/src/bipartite_graph.jl b/src/bipartite_graph.jl index 6e4f359617..0fab095928 100644 --- a/src/bipartite_graph.jl +++ b/src/bipartite_graph.jl @@ -375,7 +375,8 @@ end function 𝑑neighbors(g::BipartiteGraph, j::Integer, with_metadata::Val{M} = Val(false)) where {M} require_complete(g) - M ? zip(g.badjlist[j], (g.metadata[i][j] for i in g.badjlist[j])) : g.badjlist[j] + backj = g.badjlist[j]::Vector{Int} + M ? zip(backj, (g.metadata[i][j] for i in backj)) : backj end Graphs.ne(g::BipartiteGraph) = g.ne Graphs.nv(g::BipartiteGraph) = sum(length, vertices(g)) diff --git a/src/clock.jl b/src/clock.jl index df3b6f4b47..08c3d84dd0 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -45,7 +45,7 @@ has_time_domain(_, x) = has_time_domain(x) Determine if variable `x` has a time-domain attributed to it. """ -function has_time_domain(x::Symbolic) +function has_time_domain(x::SymbolicT) # getmetadata(x, ContinuousClock, nothing) !== nothing || # getmetadata(x, Discrete, nothing) !== nothing getmetadata(x, VariableTimeDomain, nothing) !== nothing @@ -53,10 +53,8 @@ end has_time_domain(x::Num) = has_time_domain(value(x)) has_time_domain(x) = false -for op in [Differential] - @eval input_timedomain(::$op, arg = nothing) = (ContinuousClock(),) - @eval output_timedomain(::$op, arg = nothing) = ContinuousClock() -end +input_timedomain(::Differential, arg = nothing) = InputTimeDomainElT[ContinuousClock()] +output_timedomain(::Differential, arg = nothing) = ContinuousClock() """ has_discrete_domain(x) @@ -77,7 +75,7 @@ See also [`is_continuous_domain`](@ref) """ function has_continuous_domain(x) issym(x) && return is_continuous_domain(x) - hasderiv(x) || hasdiff(x) || hassample(x) || hashold(x) + hasderiv(x) || hassample(x) || hashold(x) end """ diff --git a/src/constants.jl b/src/constants.jl index 4113287ad4..fed010a2ee 100644 --- a/src/constants.jl +++ b/src/constants.jl @@ -3,7 +3,7 @@ Test whether `x` is a constant-type Sym. """ function isconstant(x) x = unwrap(x) - x isa Symbolic && !getmetadata(x, VariableTunable, true) + x isa SymbolicT && !getmetadata(x, VariableTunable, true) end """ @@ -26,8 +26,8 @@ Define one or more constants. See also [`@independent_variables`](@ref), [`@parameters`](@ref) and [`@variables`](@ref). """ macro constants(xs...) - Symbolics._parse_vars(:constants, + Symbolics.parse_vars(:constants, Real, xs, - toconstant) |> esc + toconstant) end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 9e57296d9f..d9fb3496f7 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -33,7 +33,8 @@ at the inferred clock for that equation. struct SampleTime <: Operator SampleTime() = SymbolicUtils.term(SampleTime, type = Real) end -SymbolicUtils.promote_symtype(::Type{<:SampleTime}, t...) = Real +SymbolicUtils.promote_symtype(::Type{SampleTime}, ::Type{T}) where {T} = Real +SymbolicUtils.promote_shape(::Type{SampleTime}, @nospecialize(x::SU.ShapeT)) = x Base.nameof(::SampleTime) = :SampleTime SymbolicUtils.isbinop(::SampleTime) = false @@ -60,7 +61,7 @@ julia> Δ = Shift(t) """ struct Shift <: Operator """Fixed Shift""" - t::Union{Nothing, Symbolic} + t::Union{Nothing, SymbolicT} steps::Int Shift(t, steps = 1) = new(value(t), steps) end @@ -69,13 +70,12 @@ normalize_to_differential(s::Shift) = Differential(s.t)^s.steps Base.nameof(::Shift) = :Shift SymbolicUtils.isbinop(::Shift) = false +function (D::Shift)(x::Equation, allow_zero = false) + D(x.lhs, allow_zero) ~ D(x.rhs, allow_zero) +end function (D::Shift)(x, allow_zero = false) !allow_zero && D.steps == 0 && return x - if Symbolics.isarraysymbolic(x) - Symbolics.array_term(D, x) - else - term(D, x) - end + term(D, x; type = symtype(x), shape = SU.shape(x)) end function (D::Shift)(x::Union{Num, Symbolics.Arr}, allow_zero = false) !allow_zero && D.steps == 0 && return x @@ -94,7 +94,8 @@ function (D::Shift)(x::Union{Num, Symbolics.Arr}, allow_zero = false) end wrap(D(vt, allow_zero)) end -SymbolicUtils.promote_symtype(::Shift, t) = t +SymbolicUtils.promote_symtype(::Shift, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Shift, @nospecialize(x::SU.ShapeT)) = x Base.show(io::IO, D::Shift) = print(io, "Shift(", D.t, ", ", D.steps, ")") @@ -162,9 +163,10 @@ function Sample(arg::Real) Sample()(arg) end end -(D::Sample)(x) = Term{symtype(x)}(D, Any[x]) +(D::Sample)(x) = STerm(D, SArgsT((x,)); type = symtype(x), shape = SU.shape(x)) (D::Sample)(x::Num) = Num(D(value(x))) -SymbolicUtils.promote_symtype(::Sample, x) = x +SymbolicUtils.promote_symtype(::Sample, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Sample, @nospecialize(x::SU.ShapeT)) = x Base.nameof(::Sample) = :Sample SymbolicUtils.isbinop(::Sample) = false @@ -208,9 +210,10 @@ end is_transparent_operator(::Type{Hold}) = true -(D::Hold)(x) = Term{symtype(x)}(D, Any[x]) +(D::Hold)(x) = STerm(D, SArgsT((x,)); type = symtype(x), shape = SU.shape(x)) (D::Hold)(x::Num) = Num(D(value(x))) -SymbolicUtils.promote_symtype(::Hold, x) = x +SymbolicUtils.promote_symtype(::Hold, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Hold, @nospecialize(x::SU.ShapeT)) = x Base.nameof(::Hold) = :Hold SymbolicUtils.isbinop(::Hold) = false @@ -264,9 +267,10 @@ end function (xn::Num)(k::ShiftIndex) @unpack clock, steps = k - x = value(xn) + x = unwrap(xn) # Verify that the independent variables of k and x match and that the expression doesn't have multiple variables - vars = ModelingToolkit.vars(x) + vars = Set{SymbolicT}() + SU.search_variables!(vars, x) if length(vars) != 1 error("Cannot shift a multivariate expression $x. Either create a new unknown and shift this, or shift the individual variables in the expression.") end @@ -324,6 +328,8 @@ end Base.:+(k::ShiftIndex, i::Int) = ShiftIndex(k.clock, k.steps + i) Base.:-(k::ShiftIndex, i::Int) = k + (-i) +const InputTimeDomainElT = Union{TimeDomain, InferredTimeDomain} + """ input_timedomain(op::Operator) @@ -332,9 +338,9 @@ Should return a tuple containing the time domain type for each argument to the o """ function input_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) - return get_time_domain(arg) + return InputTimeDomainElT[get_time_domain(arg)] end - (InferredDiscrete(),) + InputTimeDomainElT[InferredDiscrete()] end """ @@ -349,14 +355,14 @@ function output_timedomain(s::Shift, arg = nothing) InferredDiscrete() end -input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),) +input_timedomain(::Sample, _ = nothing) = InputTimeDomainElT[ContinuousClock()] output_timedomain(s::Sample, _ = nothing) = s.clock -function input_timedomain(h::Hold, arg = nothing) +function input_timedomain(::Hold, arg = nothing) if has_time_domain(arg) - return get_time_domain(arg) + return InputTimeDomainElT[get_time_domain(arg)] end - (InferredDiscrete(),) # the Hold accepts any discrete + InputTimeDomainElT[InferredDiscrete()] # the Hold accepts any discrete end output_timedomain(::Hold, _ = nothing) = ContinuousClock() diff --git a/src/independent_variables.jl b/src/independent_variables.jl index d1f2ab4210..fce2d93873 100644 --- a/src/independent_variables.jl +++ b/src/independent_variables.jl @@ -7,12 +7,12 @@ Define one or more independent variables. For example: @variables x(t) """ macro independent_variables(ts...) - Symbolics._parse_vars(:independent_variables, + Symbolics.parse_vars(:independent_variables, Real, ts, - toiv) |> esc + toiv) end -toiv(s::Symbolic) = GlobalScope(setmetadata(s, MTKVariableTypeCtx, PARAMETER)) +toiv(s::SymbolicT) = GlobalScope(setmetadata(s, MTKVariableTypeCtx, PARAMETER)) toiv(s::Symbolics.Arr) = wrap(toiv(value(s))) toiv(s::Num) = Num(toiv(value(s))) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index c113c4e753..8f0b063fb5 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -49,17 +49,38 @@ See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@re """ unbound_outputs(sys) = filter(x -> !is_bound(sys, x), outputs(sys)) -""" - is_bound(sys, u) +function _is_atomic_inside_operator(ex::SymbolicT) + SU.default_is_atomic(ex) && Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa Operator end => false + _ => true + end +end -Determine whether input/output variable `u` is "bound" within the system, i.e., if it's to be considered internal to `sys`. -A variable/signal is considered bound if it appears in an equation together with variables from other subsystems. -The typical usecase for this function is to determine whether the input to an IO component is connected to another component, -or if it remains an external input that the user has to supply before simulating the system. +struct IsBoundValidator + eqs_vars::Vector{Set{SymbolicT}} + obs_vars::Vector{Set{SymbolicT}} + stack::OrderedSet{SymbolicT} +end -See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@ref), [`unbound_outputs`](@ref) -""" -function is_bound(sys, u, stack = []) +function IsBoundValidator(sys::System) + eqs_vars = Set{SymbolicT}[] + for eq in equations(sys) + vars = Set{SymbolicT}() + SU.search_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) + SU.search_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) + push!(eqs_vars, vars) + end + obs_vars = Set{SymbolicT}[] + for eq in observed(sys) + vars = Set{SymbolicT}() + SU.search_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) + SU.search_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) + push!(obs_vars, vars) + end + return IsBoundValidator(eqs_vars, obs_vars, OrderedSet{SymbolicT}()) +end + +function (ibv::IsBoundValidator)(u::SymbolicT) #= For observed quantities, we check if a variable is connected to something that is bound to something further out. In the following scenario @@ -71,35 +92,42 @@ function is_bound(sys, u, stack = []) When asking is_bound(sys₊y(t)), we know that we are looking through observed equations and can thus ask if var is bound, if it is, then sys₊y(t) is also bound. This can lead to an infinite recursion, so we maintain a stack of variables we have previously asked about to be able to break cycles =# - u ∈ Set(stack) && return false # Cycle detected - eqs = equations(sys) - eqs = filter(eq -> has_var(eq, u), eqs) # Only look at equations that contain u - # isout = isoutput(u) - for eq in eqs - vars = [get_variables(eq.rhs); get_variables(eq.lhs)] + u in ibv.stack && return false # Cycle detected + for vars in ibv.eqs_vars + u in vars || continue for var in vars var === u && continue - if !same_or_inner_namespace(u, var) - return true - end + same_or_inner_namespace(u, var) || return true end end - # Look through observed equations as well - oeqs = observed(sys) - oeqs = filter(eq -> has_var(eq, u), oeqs) # Only look at equations that contain u - for eq in oeqs - vars = [get_variables(eq.rhs); get_variables(eq.lhs)] + for vars in ibv.obs_vars + u in vars || continue for var in vars var === u && continue - if !same_or_inner_namespace(u, var) - return true - end - if is_bound(sys, var, [stack; u]) && !inner_namespace(u, var) # The variable we are comparing to can not come from an inner namespace, binding only counts outwards - return true - end + same_or_inner_namespace(u, var) || return true + push!(ibv.stack, u) + isbound = ibv(var) + pop!(ibv.stack) + # The variable we are comparing to can not come from an inner namespace, + # binding only counts outwards + isbound && !inner_namespace(u, var) && return true end end - false + return false +end + +""" + is_bound(sys, u) + +Determine whether input/output variable `u` is "bound" within the system, i.e., if it's to be considered internal to `sys`. +A variable/signal is considered bound if it appears in an equation together with variables from other subsystems. +The typical usecase for this function is to determine whether the input to an IO component is connected to another component, +or if it remains an external input that the user has to supply before simulating the system. + +See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@ref), [`unbound_outputs`](@ref) +""" +function is_bound(sys, u) + return IsBoundValidator(sys)(unwrap(u)) end """ @@ -203,9 +231,9 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs( # add to inputs for the purposes of io processing inputs = [inputs; disturbance_inputs] end - + inputs = vec(unwrap_vars(inputs)) dvs = unknowns(sys) - ps = parameters(sys; initial_parameters = true) + ps::Vector{SymbolicT} = parameters(sys; initial_parameters = true) ps = setdiff(ps, inputs) if disturbance_inputs !== nothing # remove from inputs since we do not want them as actual inputs to the dynamics @@ -252,7 +280,7 @@ end """ Turn input variables into parameters of the system. """ -function inputs_to_parameters!(state::TransformationState, inputsyms) +function inputs_to_parameters!(state::TransformationState, inputsyms::Vector{SymbolicT}) check_bound = inputsyms === nothing @unpack structure, fullvars, sys = state @unpack var_to_diff, graph, solvable_graph = structure @@ -262,9 +290,9 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) var_reidx = zeros(Int, length(fullvars)) ninputs = 0 nvar = 0 - new_parameters = [] - input_to_parameters = Dict() - new_fullvars = [] + new_parameters = SymbolicT[] + input_to_parameters = Dict{SymbolicT, SymbolicT}() + new_fullvars = SymbolicT[] for (i, v) in enumerate(fullvars) if isinput(v) && !(check_bound && is_bound(sys, v)) if var_to_diff[i] !== nothing @@ -283,8 +311,8 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) end end if ninputs == 0 - @set! sys.inputs = OrderedSet{BasicSymbolic}() - @set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars)) + @set! sys.inputs = OrderedSet{SymbolicT}() + @set! sys.outputs = OrderedSet{SymbolicT}(filter(isoutput, fullvars)) state.sys = sys return state end @@ -312,15 +340,15 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) @set! structure.graph = complete(new_graph) @set! sys.eqs = isempty(input_to_parameters) ? equations(sys) : - fast_substitute(equations(sys), input_to_parameters) + substitute(equations(sys), input_to_parameters) @set! sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters)) ps = parameters(sys) @set! sys.ps = [ps; new_parameters] - @set! sys.inputs = OrderedSet{BasicSymbolic}(new_parameters) - @set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars)) + @set! sys.inputs = OrderedSet{SymbolicT}(new_parameters) + @set! sys.outputs = OrderedSet{SymbolicT}(filter(isoutput, fullvars)) @set! state.sys = sys - @set! state.fullvars = Vector{BasicSymbolic}(new_fullvars) + @set! state.fullvars = Vector{SymbolicT}(new_fullvars) @set! state.structure = structure return state end diff --git a/src/linearization.jl b/src/linearization.jl index 3c11484d61..aa8c93df94 100644 --- a/src/linearization.jl +++ b/src/linearization.jl @@ -611,13 +611,24 @@ end """ Modify the variable metadata of system variables to indicate which ones are inputs, outputs, and disturbances. Needed for `inputs`, `outputs`, `disturbances`, `unbound_inputs`, `unbound_outputs` to return the proper subsets. """ -function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true) +function markio!(state, orig_inputs::Set{SymbolicT}, + inputs::Vector{SymbolicT}, outputs::Vector{SymbolicT}, + disturbances::Vector{SymbolicT}; check = true) fullvars = get_fullvars(state) - inputset = Dict{Any, Bool}(i => false for i in inputs) - outputset = Dict{Any, Bool}(o => false for o in outputs) - disturbanceset = Dict{Any, Bool}(d => false for d in disturbances) + inputset = Dict{SymbolicT, Bool}() + for i in inputs + inputset[i] = false + end + outputset = Dict{SymbolicT, Bool}() + for o in outputs + outputset[o] = false + end + disturbanceset = Dict{SymbolicT, Bool}() + for d in disturbances + disturbanceset[d] = false + end for (i, v) in enumerate(fullvars) - if v in keys(inputset) + if haskey(inputset, v) if v in keys(outputset) v = setio(v, true, true) outputset[v] = true @@ -626,7 +637,7 @@ function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true end inputset[v] = true fullvars[i] = v - elseif v in keys(outputset) + elseif haskey(outputset, v) v = setio(v, false, true) outputset[v] = true fullvars[i] = v @@ -638,7 +649,7 @@ function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true fullvars[i] = v end - if v in keys(disturbanceset) + if haskey(disturbanceset, v) v = setio(v, true, false) v = setdisturbance(v, true) disturbanceset[v] = true diff --git a/src/parameters.jl b/src/parameters.jl index d8ff1bf1be..d3bc796d2f 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -15,48 +15,29 @@ The symbolic metadata key for storing the `VariableType`. """ struct MTKVariableTypeCtx end -getvariabletype(x, def = VARIABLE) = getmetadata(unwrap(x), MTKVariableTypeCtx, def) +getvariabletype(x, def = VARIABLE) = safe_getmetadata(MTKVariableTypeCtx, unwrap(x), def)::Union{typeof(def), VariableType} """ $TYPEDEF Check if the variable contains the metadata identifying it as a parameter. """ -function isparameter(x) - x = unwrap(x) - - if x isa Symbolic && (varT = getvariabletype(x, nothing)) !== nothing - return varT === PARAMETER - #TODO: Delete this branch - elseif x isa Symbolic && Symbolics.getparent(x, false) !== false - p = Symbolics.getparent(x) - isparameter(p) || - (hasmetadata(p, Symbolics.VariableSource) && - getmetadata(p, Symbolics.VariableSource)[1] == :parameters) - elseif iscall(x) && operation(x) isa Symbolic - varT === PARAMETER || isparameter(operation(x)) - elseif iscall(x) && operation(x) == (getindex) - isparameter(arguments(x)[1]) - elseif x isa Symbolic - varT === PARAMETER - else - false - end +isparameter(x::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) = isparameter(unwrap(x)) +function isparameter(x::SymbolicT) + varT = getvariabletype(x, nothing) + return varT === PARAMETER end +isparameter(x) = false function iscalledparameter(x) x = unwrap(x) - return isparameter(getmetadata(x, CallWithParent, nothing)) + return SymbolicUtils.is_called_function_symbolic(x) && isparameter(operation(x)) end function getcalledparameter(x) x = unwrap(x) - # `parent` is a `CallWithMetadata` with the correct metadata, - # but no namespacing. `operation(x)` has the correct namespacing, - # but is not a `CallWithMetadata` and doesn't have any metadata. - # This approach combines both. - parent = getmetadata(x, CallWithParent) - return CallWithMetadata(operation(x), metadata(parent)) + @assert iscalledparameter(x) + return operation(x) end """ @@ -80,7 +61,7 @@ toparam(s::Num) = wrap(toparam(value(s))) Maps the variable to an unknown. """ -tovar(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, VARIABLE) +tovar(s::SymbolicT) = setmetadata(s, MTKVariableTypeCtx, VARIABLE) tovar(s::Union{Num, Symbolics.Arr}) = wrap(tovar(unwrap(s))) """ @@ -91,10 +72,10 @@ Define one or more known parameters. See also [`@independent_variables`](@ref), [`@variables`](@ref) and [`@constants`](@ref). """ macro parameters(xs...) - Symbolics._parse_vars(:parameters, + Symbolics.parse_vars(:parameters, Real, xs, - toparam) |> esc + toparam) end function find_types(array) diff --git a/src/problems/initializationproblem.jl b/src/problems/initializationproblem.jl index 6960811bbd..5267ad4abc 100644 --- a/src/problems/initializationproblem.jl +++ b/src/problems/initializationproblem.jl @@ -39,7 +39,7 @@ All other keyword arguments are forwarded to the wrapped nonlinear problem const for k in keys(op) has_u0_ics |= is_variable(sys, k) || isdifferential(k) || symbolic_type(k) == ArraySymbolic() && - is_sized_array_symbolic(k) && is_variable(sys, unwrap(first(wrap(k)))) + symbolic_has_known_size(k) && is_variable(sys, unwrap(first(wrap(k)))) end if !has_u0_ics && get_initializesystem(sys) !== nothing isys = get_initializesystem(sys; initialization_eqs, check_units) diff --git a/src/problems/jumpproblem.jl b/src/problems/jumpproblem.jl index 32aa25182f..6d7eb01f24 100644 --- a/src/problems/jumpproblem.jl +++ b/src/problems/jumpproblem.jl @@ -194,17 +194,17 @@ function collect_vars!(unknowns, parameters, j::Union{ConstantRateJump, Variable end ### Functions to determine which unknowns a jump depends on -function get_variables!(dep, jump::Union{ConstantRateJump, VariableRateJump}, variables) - jr = value(jump.rate) - (jr isa Symbolic) && get_variables!(dep, jr, variables) +function SU.search_variables!(dep, jump::Union{ConstantRateJump, VariableRateJump}; kw...) + jr = unwrap(jump.rate) + (jr isa SymbolicT) && SU.search_variables!(dep, jr; kw...) dep end -function get_variables!(dep, jump::MassActionJump, variables) - sr = value(jump.scaled_rates) - (sr isa Symbolic) && get_variables!(dep, sr, variables) +function SU.search_variables!(dep, jump::MassActionJump; is_atomic = SU.default_is_atomic, kw...) + sr = unwrap(jump.scaled_rates) + (sr isa SymbolicT) && SU.search_variables!(dep, sr; kw...) for varasop in jump.reactant_stoch - any(isequal(varasop[1]), variables) && push!(dep, varasop[1]) + is_atomic(varasop[1]) && push!(dep, varasop[1]) end dep end diff --git a/src/problems/sccnonlinearproblem.jl b/src/problems/sccnonlinearproblem.jl index d71124adde..8d131a1f31 100644 --- a/src/problems/sccnonlinearproblem.jl +++ b/src/problems/sccnonlinearproblem.jl @@ -1,5 +1,3 @@ -const TypeT = Union{DataType, UnionAll} - struct CacheWriter{F} fn::F end diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 681025cb81..62ae97186f 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -3,18 +3,21 @@ module StructuralTransformations using Setfield: @set!, @set using UnPack: @unpack -using Symbolics: unwrap, linear_expansion, fast_substitute +using Symbolics: unwrap, linear_expansion, VartypeT, SymbolicT import Symbolics using SymbolicUtils +using SymbolicUtils: BSImpl using SymbolicUtils.Code using SymbolicUtils.Rewriters -using SymbolicUtils: maketerm, iscall +using SymbolicUtils: maketerm, iscall, symtype +import SymbolicUtils as SU +import Moshi using ModelingToolkit using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Differential, - unknowns, equations, vars, Symbolic, diff2term_with_unit, + unknowns, equations, vars, diff2term_with_unit, shift2term_with_unit, value, - operation, arguments, Sym, Term, simplify, symbolic_linear_solve, + operation, arguments, simplify, symbolic_linear_solve, isdiffeq, isdifferential, isirreducible, empty_substitutions, get_substitutions, get_tearing_state, get_iv, independent_variables, @@ -27,7 +30,8 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe filter_kwargs, lower_varname_with_unit, lower_shift_varname_with_unit, setio, SparseMatrixCLIL, get_fullvars, has_equations, observed, - Schedule, schedule, iscomplete, get_schedule + Schedule, schedule, iscomplete, get_schedule, VariableUnshifted, + VariableShift using ModelingToolkit.BipartiteGraphs import .BipartiteGraphs: invview, complete @@ -40,7 +44,7 @@ using ModelingToolkit: algeqs, EquationsView, dervars_range, diffvars_range, algvars_range, DiffGraph, complete!, get_fullvars, system_subset -using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic +using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic, getname using ModelingToolkit.DiffEqBase using ModelingToolkit.StaticArrays diff --git a/src/structural_transformation/pantelides.jl b/src/structural_transformation/pantelides.jl index 871bd99ef4..581c1f1198 100644 --- a/src/structural_transformation/pantelides.jl +++ b/src/structural_transformation/pantelides.jl @@ -8,11 +8,11 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) sys = state.sys # Step 1: write derivative equations in_eqs = equations(sys) - out_eqs = Vector{Any}(undef, nv(eq_to_diff)) + out_eqs = Vector{Equation}(undef, nv(eq_to_diff)) fill!(out_eqs, nothing) out_eqs[1:length(in_eqs)] .= in_eqs - out_vars = Vector{Any}(undef, nv(var_to_diff)) + out_vars = Vector{SymbolicT}(undef, nv(var_to_diff)) fill!(out_vars, nothing) out_vars[1:length(fullvars)] .= fullvars @@ -31,13 +31,12 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) out_vars[diff] = D(vi) end - d_dict = Dict(zip(fullvars, 1:length(fullvars))) - lhss = Set{Any}([x.lhs for x in in_eqs if isdiffeq(x)]) + d_dict = Dict{SymbolicT, Int}(zip(fullvars, 1:length(fullvars))) for (eqidx, diff) in edges(eq_to_diff) # LHS variable is looked up from var_to_diff # the var_to_diff[i]-th variable is the differentiated version of var at i eq = out_eqs[eqidx] - lhs = if !(eq.lhs isa Symbolic) + lhs = if !(eq.lhs isa SymbolicT) 0 elseif isdiffeq(eq) # look up the variable that represents D(lhs) @@ -54,9 +53,9 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) D(eq.lhs) end rhs = ModelingToolkit.expand_derivatives(D(eq.rhs)) - rhs = fast_substitute(rhs, state.param_derivative_map) + rhs = substitute(rhs, state.param_derivative_map) substitution_dict = Dict(x.lhs => x.rhs - for x in out_eqs if x !== nothing && x.lhs isa Symbolic) + for x in out_eqs if x !== nothing && x.lhs isa SymbolicT) sub_rhs = substitute(rhs, substitution_dict) out_eqs[diff] = lhs ~ sub_rhs end diff --git a/src/structural_transformation/partial_state_selection.jl b/src/structural_transformation/partial_state_selection.jl index ab8e7f0f3d..5b205f414b 100644 --- a/src/structural_transformation/partial_state_selection.jl +++ b/src/structural_transformation/partial_state_selection.jl @@ -77,10 +77,12 @@ function dummy_derivative_graph!( # only accept small integers to avoid overflow is_all_small_int = all(_J) do x′ x = unwrap(x′) + SU.isconst(x) || return false x isa Number || return false - isinteger(x) && typemin(Int8) <= x <= typemax(Int8) + x = value(x) + isinteger(x) && typemin(Int8) <= Int(x) <= typemax(Int8) end - J = is_all_small_int ? Int.(unwrap.(_J)) : nothing + J = is_all_small_int ? Int.(value.(_J)) : nothing end while true nrows = length(eqs) diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 39b959c5a6..f599c7e455 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -65,15 +65,14 @@ function eq_derivative!(ts::TearingState, ieq::Int; kwargs...) sys = ts.sys eq = equations(ts)[ieq] - eq = 0 ~ fast_substitute( + eq = 0 ~ substitute( ModelingToolkit.derivative( eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true), ts.param_derivative_map) - vs = ModelingToolkit.vars(eq.rhs) + vs = Set{SymbolicT}() + SU.search_variables!(vs, eq.rhs) for v in vs - # parameters with unknown derivatives have a value of `nothing` in the map, - # so use `missing` as the default. - get(ts.param_derivative_map, v, missing) === nothing || continue + v in ts.no_deriv_params || continue _original_eq = equations(ts)[ieq] error(""" Encountered derivative of discrete variable `$(only(arguments(v)))` when \ @@ -108,7 +107,7 @@ end function solve_equation(eq, var, simplify) rhs = value(symbolic_linear_solve(eq, var; simplify = simplify, check = false)) - occursin(var, rhs) && throw(EquationSolveErrors(eq, var, rhs)) + SU.query(in(var), rhs) && throw(EquationSolveErrors(eq, var, rhs)) var ~ rhs end @@ -144,7 +143,7 @@ function to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar::F, eq = 0 ~ eq.rhs - eq.lhs end rhs = eq.rhs - if rhs isa Symbolic + if rhs isa SymbolicT # Check if the RHS is solvable in all unknown variable derivatives and if those # the linear terms for them are all zero. If so, move them to the # LHS. @@ -204,7 +203,7 @@ State selection is done. All non-differentiated variables are algebraic variables, and all variables that appear differentiated are differential variables. """ function substitute_derivatives_algevars!( - ts::TearingState, neweqs, var_eq_matching, dummy_sub; iv = nothing, D = nothing) + ts::TearingState, neweqs::Vector{Equation}, var_eq_matching::Matching, dummy_sub::Dict{SymbolicT, SymbolicT}, iv::Union{Nothing, SymbolicT}, D::Union{Nothing, Differential, Shift}) @unpack fullvars, sys, structure = ts @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure diff_to_var = invview(var_to_diff) @@ -214,10 +213,10 @@ function substitute_derivatives_algevars!( dv === nothing && continue if var_eq_matching[var] !== SelectedState() dd = fullvars[dv] - v_t = setio(diff2term_with_unit(unwrap(dd), unwrap(iv)), false, false) + v_t = setio(diff2term_with_unit(dd, iv), false, false) for eq in 𝑑neighbors(graph, dv) dummy_sub[dd] = v_t - neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t) + neweqs[eq] = substitute(neweqs[eq], dd => v_t) end fullvars[dv] = v_t # If we have: @@ -230,7 +229,7 @@ function substitute_derivatives_algevars!( while (ddx = var_to_diff[dx]) !== nothing dx_t = D(x_t) for eq in 𝑑neighbors(graph, ddx) - neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t) + neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dx_t) end fullvars[ddx] = dx_t dx = ddx @@ -326,17 +325,22 @@ Effects on the system structure: """ function generate_derivative_variables!( ts::TearingState, neweqs, var_eq_matching, full_var_eq_matching, - var_sccs; mm, iv = nothing, D = nothing) + var_sccs, mm::Union{Nothing, SparseMatrixCLIL}, iv::Union{SymbolicT, Nothing}) @unpack fullvars, sys, structure = ts @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure eq_var_matching = invview(var_eq_matching) diff_to_var = invview(var_to_diff) is_discrete = is_only_discrete(structure) - linear_eqs = mm === nothing ? Dict{Int, Int}() : - Dict(reverse(en) for en in enumerate(mm.nzrows)) + linear_eqs = Dict{Int, Int}() + if mm !== nothing + for (i, e) in enumerate(mm.nzrows) + linear_eqs[e] = i + end + end # We need the inverse mapping of `var_sccs` to update it efficiently later. - v_to_scc = Vector{NTuple{2, Int}}(undef, ndsts(graph)) + v_to_scc = NTuple{2, Int}[] + resize!(v_to_scc, ndsts(graph)) for (i, scc) in enumerate(var_sccs), (j, v) in enumerate(scc) v_to_scc[v] = (i, j) @@ -424,9 +428,7 @@ function generate_derivative_variables!( end new_sccs = insert_sccs(var_sccs, sccs_to_insert) - if mm !== nothing - @set! mm.ncols = ndsts(graph) - end + @set! mm.ncols = ndsts(graph) return new_sccs end @@ -539,26 +541,28 @@ Reorder the equations and unknowns to be in the BLT sorted form. Return the new equations, the solved equations, the new orderings, and the number of solved variables and equations. """ -function generate_system_equations!(state::TearingState, neweqs, var_eq_matching, - full_var_eq_matching, var_sccs, extra_eqs_vars; - simplify = false, iv = nothing, D = nothing) +function generate_system_equations!(state::TearingState, neweqs::Vector{Equation}, + var_eq_matching::Matching, full_var_eq_matching::Matching, + var_sccs::Vector{Vector{Int}}, extra_eqs_vars::NTuple{2, Vector{Int}}, + iv::Union{SymbolicT, Nothing}, D::Union{Differential, Shift, Nothing}; + simplify::Bool = false) @unpack fullvars, sys, structure = state @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure eq_var_matching = invview(var_eq_matching) - full_eq_var_matching = invview(full_var_eq_matching) diff_to_var = invview(var_to_diff) extra_eqs, extra_vars = extra_eqs_vars - total_sub = Dict() + total_sub = Dict{SymbolicT, SymbolicT}() if is_only_discrete(structure) for (i, v) in enumerate(fullvars) - op = operation(v) - op isa Shift && (op.steps < 0) && - begin + Moshi.Match.@match v begin + BSImpl.Term(; f) && if f isa Shift && f.steps < 0 end => begin lowered = lower_shift_varname_with_unit(v, iv) total_sub[v] = lowered fullvars[i] = lowered end + _ => nothing + end end end @@ -579,13 +583,11 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching end digraph = DiCMOBiGraph{false}(graph, var_eq_matching) - idep = iv for (i, scc) in enumerate(var_sccs) # note that the `vscc <-> escc` relation is a set-to-set mapping, and not # point-to-point. vscc, escc = get_sorted_scc(digraph, full_var_eq_matching, var_eq_matching, scc) var_sccs[i] = vscc - if length(escc) != length(vscc) isempty(escc) && continue escc = setdiff(escc, extra_eqs) @@ -594,11 +596,10 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching isempty(vscc) && continue end - offset = 1 for ieq in escc iv = eq_var_matching[ieq] - eq = neweqs[ieq] - codegen_equation!(eq_generator, neweqs[ieq], ieq, iv; simplify) + neq = neweqs[ieq] + codegen_equation!(eq_generator, neq, ieq, iv; simplify) end end @@ -620,14 +621,16 @@ function generate_system_equations!(state::TearingState, neweqs, var_eq_matching solved_vars_set = BitSet(solved_vars) # We filled zeros for algebraic variables, so fill them properly here offset = 1 + findnextfn = let diff_vars_set = diff_vars_set, solved_vars_set = solved_vars_set, + diff_to_var = diff_to_var, ispresent = ispresent + j -> !(j in diff_vars_set || j in solved_vars_set) && diff_to_var[j] === nothing && + ispresent(j) + end for (i, v) in enumerate(var_ordering) v == 0 || continue # find the next variable which is not differential or solved, is not the # derivative of another variable and is present in the equations - index = findnext(1:ndsts(graph), offset) do j - !(j in diff_vars_set || j in solved_vars_set) && diff_to_var[j] === nothing && - ispresent(j) - end + index = findnext(findnextfn, 1:ndsts(graph), offset) # in case of overdetermined systems, this may not be present index === nothing && break var_ordering[i] = index @@ -649,11 +652,20 @@ Sort the provided SCC `scc`, given the `digraph` of the system constructed using function get_sorted_scc( digraph::DiCMOBiGraph, full_var_eq_matching::Matching, var_eq_matching::Matching, scc::Vector{Int}) eq_var_matching = invview(var_eq_matching) - full_eq_var_matching = invview(full_var_eq_matching) # obtain the matched equations in the SCC - scc_eqs = Int[full_var_eq_matching[v] for v in scc if full_var_eq_matching[v] isa Int] + scc_eqs = Int[] # obtain the equations in the SCC that are linearly solvable - scc_solved_eqs = Int[var_eq_matching[v] for v in scc if var_eq_matching[v] isa Int] + scc_solved_eqs = Int[] + for v in scc + e = full_var_eq_matching[v] + if e isa Int + push!(scc_eqs, e) + end + e = var_eq_matching[v] + if e isa Int + push!(scc_solved_eqs, e) + end + end # obtain the subgraph of the contracted graph involving the solved equations subgraph, varmap = Graphs.induced_subgraph(digraph, scc_solved_eqs) # topologically sort the solved equations and append the remainder @@ -661,7 +673,13 @@ function get_sorted_scc( setdiff(scc_eqs, scc_solved_eqs)] # the variables of the SCC are obtained by inverse mapping the sorted equations # and appending the rest - scc_vars = [eq_var_matching[e] for e in scc_eqs if eq_var_matching[e] isa Int] + scc_vars = Int[] + for e in scc_eqs + v = eq_var_matching[e] + if v isa Int + push!(scc_vars, v) + end + end append!(scc_vars, setdiff(scc, scc_vars)) return scc_vars, scc_eqs end @@ -672,7 +690,7 @@ end Struct containing the information required to generate equations of a system, as well as the generated equations and associated metadata. """ -struct EquationGenerator{S, D, I} +struct EquationGenerator{S} """ `TearingState` of the system. """ @@ -681,15 +699,15 @@ struct EquationGenerator{S, D, I} Substitutions to perform in all subsequent equations. For each differential equation `D(x) ~ f(..)`, the substitution `D(x) => f(..)` is added to the rules. """ - total_sub::Dict{Any, Any} + total_sub::Dict{SymbolicT, SymbolicT} """ The differential operator, or `nothing` if not applicable. """ - D::D + D::Union{Differential, Shift, Nothing} """ The independent variable, or `nothing` if not applicable. """ - idep::I + idep::Union{SymbolicT, Nothing} """ The new generated equations of the system. """ @@ -829,19 +847,17 @@ Generate a first-order differential equation whose LHS is `dx`. `var` and `dx` represent the same variable, but `var` may be a higher-order differential and `dx` is always first-order. For example, if `var` is D(D(x)), then `dx` would be `D(x_t)`. Solve `eq` for `var`, substitute previously solved variables, and return the differential equation. """ function make_differential_equation(var, dx, eq, total_sub) - dx ~ simplify_shifts(Symbolics.fixpoint_sub( - Symbolics.symbolic_linear_solve(eq, var), - total_sub; operator = ModelingToolkit.Shift)) + v1 = Symbolics.symbolic_linear_solve(eq, var)::SymbolicT + v2 = Symbolics.fixpoint_sub(v1, total_sub; operator = ModelingToolkit.Shift) + v3 = simplify_shifts(v2) + dx ~ v3 end """ Generate an algebraic equation. Substitute solved variables into `eq` and return the equation. """ function make_algebraic_equation(eq, total_sub) - rhs = eq.rhs - if !(eq.lhs isa Number && eq.lhs == 0) - rhs = eq.rhs - eq.lhs - end + rhs = eq.rhs - eq.lhs 0 ~ simplify_shifts(Symbolics.fixpoint_sub(rhs, total_sub)) end @@ -910,7 +926,7 @@ function reorder_vars!(state::TearingState, var_eq_matching, var_sccs, eq_orderi var_ordering_set = BitSet(var_ordering) for scc in var_sccs # Map variables to their new indices - map!(v -> varsperm[v], scc, scc) + map!(Base.Fix1(getindex, varsperm), scc, scc) # Remove variables not in the reduced set filter!(!iszero, scc) end @@ -929,16 +945,20 @@ end Update the system equations, unknowns, and observables after simplification. """ function update_simplified_system!( - state::TearingState, neweqs, solved_eqs, dummy_sub, var_sccs, extra_unknowns; - array_hack = true, D = nothing, iv = nothing) + state::TearingState, neweqs::Vector{Equation}, solved_eqs::Vector{Equation}, + dummy_sub::Dict{SymbolicT, SymbolicT}, var_sccs::Vector{Vector{Int}}, + extra_unknowns::Vector{SymbolicT}, iv::Union{SymbolicT, Nothing}, + D::Union{Differential, Shift, Nothing}; array_hack = true) @unpack fullvars, structure = state @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure diff_to_var = invview(var_to_diff) # Since we solved the highest order derivative variable in discrete systems, # we make a list of the solved variables and avoid including them in the # unknowns. - solved_vars = Set() + solved_vars = Set{SymbolicT}() if is_only_discrete(structure) + iv = iv::SymbolicT + D = D::Shift for eq in solved_eqs var = eq.lhs if isequal(eq.lhs, eq.rhs) @@ -961,25 +981,28 @@ function update_simplified_system!( obs_sub[eq.lhs] = eq.rhs end # TODO: compute the dependency correctly so that we don't have to do this - obs = [fast_substitute(observed(sys), obs_sub); solved_eqs; - fast_substitute(state.additional_observed, obs_sub)] + obs = [substitute(observed(sys), obs_sub); solved_eqs; + substitute(state.additional_observed, obs_sub)] - unknown_idxs = filter( - i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars)) + filterer = let diff_to_var = diff_to_var, ispresent = ispresent, fullvars = fullvars, + solved_vars = solved_vars + i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars) + end + unknown_idxs = filter(filterer, eachindex(state.fullvars)) unknowns = state.fullvars[unknown_idxs] unknowns = [unknowns; extra_unknowns] if is_only_discrete(structure) # Algebraic variables are shifted forward by one, so we backshift them. - unknowns = map(enumerate(unknowns)) do (i, var) - if iscall(var) && operation(var) isa Shift && operation(var).steps == 1 - # We might have shifted a variable with io metadata. That is irrelevant now - # because we handled io variables earlier in `_mtkcompile!` so just ignore - # it here. - setio(backshift_expr(var, iv), false, false) - else - var + _unknowns = SymbolicT[] + for var in unknowns + Moshi.Match.@match var begin + BSImpl.Term(; f, args, type, shape, metadata) && if f isa Shift && f.steps == 1 end => begin + push!(_unknowns, setio(args[1], false, false)) + end + _ => push!(_unknowns, var) end end + unknowns = _unknowns end @set! sys.unknowns = unknowns @@ -1044,7 +1067,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching, extra_eqs_vars = get_extra_eqs_vars( state, var_eq_matching, full_var_eq_matching, fully_determined) neweqs = collect(equations(state)) - dummy_sub = Dict() + dummy_sub = Dict{SymbolicT, SymbolicT}() if ModelingToolkit.has_iv(state.sys) iv = get_iv(state.sys) @@ -1060,29 +1083,28 @@ function tearing_reassemble(state::TearingState, var_eq_matching::Matching, extra_unknowns = state.fullvars[extra_eqs_vars[2]] if is_only_discrete(state.structure) var_sccs = add_additional_history!( - state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs; iv, D) + state, var_eq_matching, full_var_eq_matching, var_sccs, iv) end # Structural simplification - substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub; iv, D) + substitute_derivatives_algevars!(state, neweqs, var_eq_matching, dummy_sub, iv, D) var_sccs = generate_derivative_variables!( - state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs; mm, iv, D) - + state, neweqs, var_eq_matching, full_var_eq_matching, var_sccs, mm, iv) neweqs, solved_eqs, eq_ordering, var_ordering, nelim_eq, nelim_var = generate_system_equations!( state, neweqs, var_eq_matching, full_var_eq_matching, - var_sccs, extra_eqs_vars; simplify, iv, D) + var_sccs, extra_eqs_vars, iv, D; simplify) state = reorder_vars!( state, var_eq_matching, var_sccs, eq_ordering, var_ordering, nelim_eq, nelim_var) # var_eq_matching and full_var_eq_matching are now invalidated sys = update_simplified_system!(state, neweqs, solved_eqs, dummy_sub, var_sccs, - extra_unknowns; array_hack, iv, D) + extra_unknowns, iv, D; array_hack) @set! state.sys = sys @set! sys.tearing_state = state @@ -1120,24 +1142,23 @@ x(k) ~ x(k-1) + x(k-2) Where the last equation is the observed equation. """ function add_additional_history!( - state::TearingState, neweqs::Vector, var_eq_matching::Matching, - full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; iv, D) + state::TearingState, var_eq_matching::Matching, + full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}, iv::Union{SymbolicT, Nothing}) + iv === nothing && return var_sccs + iv = iv::SymbolicT @unpack fullvars, sys, structure = state @unpack solvable_graph, var_to_diff, eq_to_diff, graph = structure - eq_var_matching = invview(var_eq_matching) diff_to_var = invview(var_to_diff) - is_discrete = is_only_discrete(structure) - digraph = DiCMOBiGraph{false}(graph, var_eq_matching) # We need the inverse mapping of `var_sccs` to update it efficiently later. - v_to_scc = Vector{NTuple{2, Int}}(undef, ndsts(graph)) + v_to_scc = NTuple{2, Int}[] + resize!(v_to_scc, ndsts(graph)) for (i, scc) in enumerate(var_sccs), (j, v) in enumerate(scc) v_to_scc[v] = (i, j) end vars_to_backshift = BitSet() - eqs_to_backshift = BitSet() # add history for differential variables for ivar in 1:length(fullvars) ieq = var_eq_matching[ivar] @@ -1189,7 +1210,7 @@ end Backshift the given expression `ex`. """ function backshift_expr(ex, iv) - ex isa Symbolic || return ex + ex isa SymbolicT || return ex return descend_lower_shift_varname_with_unit( simplify_shifts(distribute_shift(Shift(iv, -1)(ex))), iv) end @@ -1243,7 +1264,7 @@ occurs in observed equations (and unknowns if it's split). function tearing_hacks(sys, obs, unknowns, neweqs; array = true) # map of array observed variable (unscalarized) to number of its # scalarized terms that appear in observed equations - arr_obs_occurrences = Dict() + arr_obs_occurrences = Dict{SymbolicT, Int}() for (i, eq) in enumerate(obs) lhs = eq.lhs rhs = eq.rhs @@ -1251,7 +1272,7 @@ function tearing_hacks(sys, obs, unknowns, neweqs; array = true) array || continue iscall(lhs) || continue operation(lhs) === getindex || continue - Symbolics.shape(lhs) != Symbolics.Unknown() || continue + SU.shape(lhs) isa SU.Unknown && continue arg1 = arguments(lhs)[1] cnt = get(arr_obs_occurrences, arg1, 0) arr_obs_occurrences[arg1] = cnt + 1 @@ -1264,7 +1285,7 @@ function tearing_hacks(sys, obs, unknowns, neweqs; array = true) for sym in unknowns iscall(sym) || continue operation(sym) === getindex || continue - Symbolics.shape(sym) != Symbolics.Unknown() || continue + SU.shape(sym) isa SU.Unknown && continue arg1 = arguments(sym)[1] cnt = get(arr_obs_occurrences, arg1, 0) cnt == 0 && continue diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 3fa4f28aa9..c0807ae604 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -239,28 +239,40 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no fullvars = state.fullvars @unpack graph, solvable_graph = state.structure eq = equations(state)[ieq] - term = value(eq.rhs - eq.lhs) + term = unwrap(eq.rhs - eq.lhs) all_int_vars = true coeffs === nothing || empty!(coeffs) empty!(to_rm) + __indexed_fullvar_is_var = let fullvars = fullvars + function indexed_fullvar_is_var(x::SymbolicT) + for v in fullvars + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f === getindex && isequal(args[1], x) end => return true + _ => nothing + end + end + return false + end + end + __allow_sym_par_cond = let fullvars = fullvars, is_atomic = ModelingToolkit.OperatorIsAtomic{Union{Differential, Shift, Pre, Sample, Hold, Initial}}(), __indexed_fullvar_is_var = __indexed_fullvar_is_var + function allow_sym_par_cond(v) + is_atomic(v) && any(isequal(v), fullvars) || + symbolic_type(v) == ArraySymbolic() && (SU.shape(v) isa SU.Unknown || + __indexed_fullvar_is_var(v)) + end + end for j in 𝑠neighbors(graph, ieq) var = fullvars[j] isirreducible(var) && (all_int_vars = false; continue) a, b, islinear = linear_expansion(term, var) - a, b = unwrap(a), unwrap(b) + islinear || (all_int_vars = false; continue) - if a isa Symbolic + if !SU.isconst(a) all_int_vars = false if !allow_symbolic if allow_parameter # if any of the variables in `a` are present in fullvars (taking into account arrays) - if any( - v -> any(isequal(v), fullvars) || - symbolic_type(v) == ArraySymbolic() && - Symbolics.shape(v) != Symbolics.Unknown() && - any(x -> any(isequal(x), fullvars), collect(v)), - vars( - a; op = Union{Differential, Shift, Pre, Sample, Hold, Initial})) + if SU.query(__allow_sym_par_cond, a) continue end else @@ -270,20 +282,20 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no add_edge!(solvable_graph, ieq, j) continue end - if !(a isa Number) + if !(symtype(a) <: Number) all_int_vars = false continue end # When the expression is linear with numeric `a`, then we can safely # only consider `b` for the following iterations. term = b - if isone(abs(a)) - coeffs === nothing || push!(coeffs, convert(Int, a)) + if SU._isone(abs(unwrap_const(a))) + coeffs === nothing || push!(coeffs, convert(Int, unwrap_const(a))) else all_int_vars = false conservative && continue end - if a != 0 + if !SU._iszero(a) add_edge!(solvable_graph, ieq, j) else if may_be_zero @@ -503,43 +515,55 @@ end """ Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t). """ -function shift2term(var) - iscall(var) || return var - op = operation(var) - op isa Shift || return var - iv = op.t - arg = only(arguments(var)) - if operation(arg) === getindex - idxs = arguments(arg)[2:end] - newvar = shift2term(op(first(arguments(arg))))[idxs...] - unshifted = ModelingToolkit.getunshifted(newvar)[idxs...] - newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, unshifted) - return newvar - end - is_lowered = !isnothing(ModelingToolkit.getunshifted(arg)) - - backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps - - # Char(0x208b) = ₋ (subscripted minus) - # Char(0x208a) = ₊ (subscripted plus) - pm = backshift > 0 ? Char(0x208a) : Char(0x208b) - # subscripted number, e.g. ₁ - num = join(Char(0x2080 + d) for d in reverse!(digits(abs(backshift)))) - # Char(0x209c) = ₜ - # ds = ₜ₋₁ - ds = join([Char(0x209c), pm, num]) - - O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg - oldop = operation(O) - newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) : - Symbol(string(nameof(oldop))) - - newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), - Symbolics.children(O), Symbolics.metadata(O)) - newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname)) - newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O) - newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift) - return newvar +function shift2term(var::SymbolicT) + Moshi.Match.@match var begin + BSImpl.Term(f, args) && if f isa Shift end => begin + op = f + arg = args[1] + Moshi.Match.@match arg begin + BSImpl.Term(; f, args, type, shape, metadata) && if f === getindex end => begin + newargs = copy(parent(args)) + newargs[1] = shift2term(op(newargs[1])) + unshifted_args = copy(newargs) + unshifted_args[1] = ModelingToolkit.getunshifted(newargs[1]) + unshifted = BSImpl.Term{VartypeT}(getindex, unshifted_args; type, shape, metadata) + if metadata === nothing + metadata = Base.ImmutableDict{DataType, Any}(VariableUnshifted, unshifted) + elseif metadata isa Base.ImmutableDict{DataType, Any} + metadata = Base.ImmutableDict(metadata, VariableUnshifted, unshifted) + end + return BSImpl.Term{VartypeT}(getindex, newargs; type, shape, metadata) + end + _ => nothing + end + unshifted = ModelingToolkit.getunshifted(arg) + is_lowered = unshifted !== nothing + backshift = op.steps + ModelingToolkit.getshift(arg) + io = IOBuffer() + O = (is_lowered ? unshifted : arg)::SymbolicT + write(io, getname(O)) + # Char(0x209c) = ₜ + write(io, Char(0x209c)) + # Char(0x208b) = ₋ (subscripted minus) + # Char(0x208a) = ₊ (subscripted plus) + pm = backshift > 0 ? Char(0x208a) : Char(0x208b) + write(io, pm) + backshift = abs(backshift) + N = ndigits(backshift) + den = 10 ^ (N - 1) + for _ in 1:N + # subscripted number, e.g. ₁ + write(io, Char(0x2080 + div(backshift, den) % 10)) + den = div(den, 10) + end + newname = Symbol(take!(io)) + newvar = Symbolics.rename(var, newname) + newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O) + newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift) + return newvar + end + _ => return var + end end function isdoubleshift(var) @@ -547,27 +571,35 @@ function isdoubleshift(var) ModelingToolkit.isoperator(arguments(var)[1], ModelingToolkit.Shift) end +simplify_shifts(eq::Equation) = simplify_shifts(eq.lhs) ~ simplify_shifts(eq.rhs) + +function _simplify_shifts(var::SymbolicT) + Moshi.Match.@match var begin + BSImpl.Term(; f, args) && if f isa Shift && f.steps == 0 end => return args[1] + BSImpl.Term(; f = op1, args) && if op1 isa Shift end => begin + vv1 = args[1] + Moshi.Match.@match vv1 begin + BSImpl.Term(; f = op2, args = a2) && if op2 isa Shift end => begin + vv2 = a2[1] + s1 = op1.steps + s2 = op2.steps + t1 = op1.t + t2 = op2.t + return simplify_shifts(ModelingToolkit.Shift(t1 === nothing ? t2 : t1, s1 + s2)(vv2)) + end + _ => return var + end + end + _ => var + end +end + """ Simplify multiple shifts: Shift(t, k1)(Shift(t, k2)(x)) becomes Shift(t, k1+k2)(x). """ -function simplify_shifts(var) +function simplify_shifts(var::SymbolicT) ModelingToolkit.hasshift(var) || return var - var isa Equation && return simplify_shifts(var.lhs) ~ simplify_shifts(var.rhs) - (op = operation(var)) isa Shift && op.steps == 0 && return first(arguments(var)) - if isdoubleshift(var) - op1 = operation(var) - vv1 = arguments(var)[1] - op2 = operation(vv1) - vv2 = arguments(vv1)[1] - s1 = op1.steps - s2 = op2.steps - t1 = op1.t - t2 = op2.t - return simplify_shifts(ModelingToolkit.Shift(t1 === nothing ? t2 : t1, s1 + s2)(vv2)) - else - return maketerm(typeof(var), operation(var), simplify_shifts.(arguments(var)), - unwrap(var).metadata) - end + return SU.Rewriters.Postwalk(_simplify_shifts)(var) end """ diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 0bd05bb4b9..73ba6712b0 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -69,7 +69,7 @@ function wrap_assignments(isscalar, assignments; let_block = false) end end -const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___) +const MTKPARAMETERS_ARG = SSym(:___mtkparameters___; type = Vector{Vector{Any}}, shape = SymbolicUtils.Unknown(1)) """ $(TYPEDSIGNATURES) @@ -94,11 +94,11 @@ See also [`@independent_variables`](@ref) and [`ModelingToolkit.get_iv`](@ref). """ function independent_variables(sys::AbstractSystem) if isdefined(sys, :iv) && getfield(sys, :iv) !== nothing - return [getfield(sys, :iv)] + return SymbolicT[getfield(sys, :iv)] elseif isdefined(sys, :ivs) - return getfield(sys, :ivs) + return unwrap.(getfield(sys, :ivs))::Vector{SymbolicT} else - return [] + return SymbolicT[] end end @@ -170,17 +170,20 @@ function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem) return unknowns(sys) end -function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) - sym = unwrap(sym) +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) + is_parameter(sys, unwrap(sym)) +end + +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Int) + !is_split(sys) && sym in 1:length(parameter_symbols(sys)) +end + +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::SymbolicT) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing return sym isa ParameterIndex || is_parameter(ic, sym) || - iscall(sym) && - operation(sym) === getindex && + iscall(sym) && operation(sym) === getindex && is_parameter(ic, first(arguments(sym))) end - if unwrap(sym) isa Int - return unwrap(sym) in 1:length(parameter_symbols(sys)) - end return any(isequal(sym), parameter_symbols(sys)) || hasname(sym) && !(iscall(sym) && operation(sym) == getindex) && is_parameter(sys, getname(sym)) @@ -191,7 +194,7 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol return is_parameter(ic, sym) end - named_parameters = [getname(x) + named_parameters = Symbol[getname(x) for x in parameter_symbols(sys) if hasname(x) && !(iscall(x) && operation(x) == getindex)] return any(isequal(sym), named_parameters) || @@ -200,6 +203,8 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol Symbol.(nameof(sys), NAMESPACE_SEPARATOR_SYMBOL, named_parameters)) == 1 end +SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) = false + function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) sym = unwrap(sym) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing @@ -312,7 +317,8 @@ for traitT in [ ArraySymbolic ] @eval function _all_ts_idxs!(ts_idxs, ::$traitT, sys, sym) - allsyms = vars(sym; op = Symbolics.Operator) + allsyms = Set{SymbolicT}() + SU.search_variables!(allsyms, sym; is_atomic = OperatorIsAtomic{Symbolics.Operator}()) for s in allsyms s = unwrap(s) if is_variable(sys, s) || is_independent_variable(sys, s) @@ -488,7 +494,8 @@ of a system. See the documentation section on initialization for more informatio struct Initial <: Symbolics.Operator end is_timevarying_operator(::Type{Initial}) = false Initial(x) = Initial()(x) -SymbolicUtils.promote_symtype(::Type{Initial}, T) = T +SymbolicUtils.promote_symtype(::Initial, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Initial, @nospecialize(x::SU.ShapeT)) = x SymbolicUtils.isbinop(::Initial) = false Base.nameof(::Initial) = :Initial Base.show(io::IO, x::Initial) = print(io, "Initial") @@ -507,16 +514,16 @@ function (f::Initial)(x) end # don't double wrap iscall(x) && operation(x) isa Initial && return x - result = if symbolic_type(x) == ArraySymbolic() - # create an array for `Initial(array)` - Symbolics.array_term(f, x) - elseif iscall(x) && operation(x) == getindex + sh = SU.shape(x) + result = if SU.is_array_shape(sh) + term(f, x; type = symtype(x), shape = sh) + elseif iscall(x) && operation(x) === getindex # instead of `Initial(x[1])` create `Initial(x)[1]` # which allows parameter indexing to handle this case automatically. arr = arguments(x)[1] - term(getindex, f(arr), arguments(x)[2:end]...) + f(arr)[arguments(x)[2:end]...] else - term(f, x) + term(f, x; type = symtype(x), shape = sh) end # the result should be a parameter result = toparam(result) @@ -526,15 +533,6 @@ function (f::Initial)(x) return result end -# This is required so `fast_substitute` works -function SymbolicUtils.maketerm(::Type{<:BasicSymbolic}, ::Initial, args, meta) - val = Initial()(args...) - if symbolic_type(val) == NotSymbolic() - return val - end - return metadata(val, meta) -end - supports_initialization(sys::AbstractSystem) = true function add_initialization_parameters(sys::AbstractSystem; split = true) @@ -542,16 +540,20 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) supports_initialization(sys) || return sys is_initializesystem(sys) && return sys - all_initialvars = Set{BasicSymbolic}() + all_initialvars = Set{SymbolicT}() # time-independent systems don't initialize unknowns # but may initialize parameters using guesses for unknowns eqs = equations(sys) - if !(eqs isa Vector{Equation}) - eqs = Equation[x for x in eqs if x isa Equation] - end obs, eqs = unhack_observed(observed(sys), eqs) - for x in Iterators.flatten((unknowns(sys), Iterators.map(eq -> eq.lhs, obs))) - x = unwrap(x) + for x in unknowns(sys) + if iscall(x) && operation(x) == getindex && split + push!(all_initialvars, arguments(x)[1]) + else + push!(all_initialvars, x) + end + end + for eq in obs + x = eq.lhs if iscall(x) && operation(x) == getindex && split push!(all_initialvars, arguments(x)[1]) else @@ -561,15 +563,19 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) # add derivatives of all variables for steady-state initial conditions if is_time_dependent(sys) && !is_discrete_system(sys) - D = Differential(get_iv(sys)) - union!(all_initialvars, [D(v) for v in all_initialvars if iscall(v)]) + D = Differential(get_iv(sys)::SymbolicT) + for v in collect(all_initialvars) + iscall(v) && push!(all_initialvars, D(v)) + end end for eq in get_parameter_dependencies(sys) is_variable_floatingpoint(eq.lhs) || continue push!(all_initialvars, eq.lhs) end - all_initialvars = collect(all_initialvars) - initials = map(Initial(), all_initialvars) + initials = collect(all_initialvars) + for (i, v) in enumerate(initials) + initials[i] = Initial()(v) + end @set! sys.ps = unique!([get_ps(sys); initials]) defs = copy(get_defaults(sys)) for ivar in initials @@ -577,7 +583,8 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) defs[ivar] = false else defs[ivar] = collect(ivar) - for scal_ivar in defs[ivar] + for idx in SU.stable_eachindex(ivar) + scal_ivar = ivar[idx] defs[scal_ivar] = false end end @@ -601,9 +608,9 @@ end Find [`GlobalScope`](@ref)d variables in `sys` and add them to the unknowns/parameters. """ function discover_globalscoped(sys::AbstractSystem) - newunknowns = OrderedSet() - newparams = OrderedSet() - iv = has_iv(sys) ? get_iv(sys) : nothing + newunknowns = OrderedSet{SymbolicT}() + newparams = OrderedSet{SymbolicT}() + iv::Union{SymbolicT, Nothing} = has_iv(sys) ? get_iv(sys) : nothing collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1) setdiff!(newunknowns, observables(sys)) @set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams))) @@ -626,34 +633,36 @@ This namespacing functionality can also be toggled independently of `complete` using [`toggle_namespacing`](@ref). """ function complete( - sys::AbstractSystem; split = true, flatten = true, add_initial_parameters = true) + sys::T; split = true, flatten = true, add_initial_parameters = true) where {T <: AbstractSystem} sys = discover_globalscoped(sys) if flatten - eqs = equations(sys) - if eqs isa AbstractArray && eltype(eqs) <: Equation - newsys = expand_connections(sys) - else - newsys = sys - end + newsys = expand_connections(sys) newsys = ModelingToolkit.flatten(newsys) if has_parent(newsys) && get_parent(sys) === nothing - @set! newsys.parent = complete(sys; split = false, flatten = false) + @set! newsys.parent = complete(sys; split = false, flatten = false)::T end sys = newsys - sys = process_parameter_equations(sys) + sys = process_parameter_equations(sys)::T if add_initial_parameters - sys = add_initialization_parameters(sys; split) + sys = add_initialization_parameters(sys; split)::T end + cb_alg_eqs = Equation[alg_equations(sys); observed(sys)] if has_continuous_events(sys) && is_time_dependent(sys) - @set! sys.continuous_events = complete.( - get_continuous_events(sys); iv = get_iv(sys), - alg_eqs = [alg_equations(sys); observed(sys)]) + cevts = SymbolicContinuousCallback[] + for ev in get_continuous_events(sys) + ev = complete(ev; iv = get_iv(sys)::SymbolicT, alg_eqs = cb_alg_eqs) + push!(cevts, ev) + end + @set! sys.continuous_events = cevts end if has_discrete_events(sys) && is_time_dependent(sys) - @set! sys.discrete_events = complete.( - get_discrete_events(sys); iv = get_iv(sys), - alg_eqs = [alg_equations(sys); observed(sys)]) + devts = SymbolicDiscreteCallback[] + for ev in get_discrete_events(sys) + ev = complete(ev; iv = get_iv(sys)::SymbolicT, alg_eqs = cb_alg_eqs) + push!(devts, ev) + end + @set! sys.discrete_events = devts end end if split && has_index_cache(sys) @@ -661,39 +670,48 @@ function complete( # Ideally we'd do `get_ps` but if `flatten = false` # we don't get all of them. So we call `parameters`. all_ps = parameters(sys; initial_parameters = true) + all_ps_set = Set{SymbolicT}(all_ps) # inputs have to be maintained in a specific order input_vars = inputs(sys) if !isempty(all_ps) # reorder parameters by portions - ps_split = reorder_parameters(sys, all_ps) + ps_split = Vector{Vector{SymbolicT}}(reorder_parameters(sys, all_ps)) # if there are tunables, they will all be in `ps_split[1]` # and the arrays will have been scalarized - ordered_ps = eltype(all_ps)[] + ordered_ps = SymbolicT[] + offset = 0 # if there are no tunables, vcat them if !isempty(get_index_cache(sys).tunable_idx) - unflatten_parameters!(ordered_ps, ps_split[1], all_ps) - ps_split = Base.tail(ps_split) + unflatten_parameters!(ordered_ps, ps_split[offset + 1], all_ps_set) + offset += 1 end # unflatten initial parameters if !isempty(get_index_cache(sys).initials_idx) - unflatten_parameters!(ordered_ps, ps_split[1], all_ps) - ps_split = Base.tail(ps_split) + unflatten_parameters!(ordered_ps, ps_split[offset + 1], all_ps_set) + offset += 1 + end + for i in (offset+1):length(ps_split) + append!(ordered_ps, ps_split[i]::Vector{SymbolicT}) end - ordered_ps = vcat( - ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[])) if isscheduled(sys) # ensure inputs are sorted - input_idxs = findfirst.(isequal.(input_vars), (ordered_ps,)) - @assert all(!isnothing, input_idxs) - @assert issorted(input_idxs) + last_idx = 0 + for p in input_vars + idx = findfirst(isequal(p), ordered_ps)::Int + @assert last_idx < idx + last_idx = idx + end end @set! sys.ps = ordered_ps end elseif has_index_cache(sys) @set! sys.index_cache = nothing end - if isdefined(sys, :initializesystem) && get_initializesystem(sys) !== nothing - @set! sys.initializesystem = complete(get_initializesystem(sys); split) + if has_initializesystem(sys) + isys = get_initializesystem(sys) + if isys isa T + @set! sys.initializesystem = complete(isys::T; split) + end end sys = toggle_namespacing(sys, false; safe = true) isdefined(sys, :complete) ? (@set! sys.complete = true) : sys @@ -722,26 +740,28 @@ parameters in the system `all_ps`, unscalarize the elements in `params` and appe to `buffer` in the same order as they are present in `params`. Effectively, if `params = [p[1], p[2], p[3], q]` then this is equivalent to `push!(buffer, p, q)`. """ -function unflatten_parameters!(buffer, params, all_ps) +function unflatten_parameters!(buffer::Vector{SymbolicT}, params::Vector{SymbolicT}, all_ps::Set{SymbolicT}) i = 1 # go through all the tunables while i <= length(params) sym = params[i] # if the sym is not a scalarized array symbolic OR it was already scalarized, # just push it as-is - if !iscall(sym) || operation(sym) != getindex || - any(isequal(sym), all_ps) + if !iscall(sym) || operation(sym) !== getindex || sym in all_ps push!(buffer, sym) i += 1 continue end + + arrsym = first(arguments(sym)) # the next `length(sym)` symbols should be scalarized versions of the same # array symbolic - if !allequal(first(arguments(x)) - for x in view(params, i:(i + length(sym) - 1))) - error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.") + for j in (i+1):(i+length(sym)-1) + p = params[j] + if !(iscall(p) && operation(p) === getindex && isequal(arguments(p)[1], arrsym)) + error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.") + end end - arrsym = first(arguments(sym)) push!(buffer, arrsym) i += length(arrsym) end @@ -958,8 +978,12 @@ function getvar(sys::AbstractSystem, name::Symbol; namespace = does_namespacing( if has_eqs(sys) for eq in get_eqs(sys) eq isa Equation || continue - if eq.lhs isa AnalysisPoint && nameof(eq.rhs) == name - return namespace ? renamespace(sys, eq.rhs) : eq.rhs + lhs = value(eq.lhs) + rhs = value(eq.rhs) + if value(lhs) isa AnalysisPoint + rhs = rhs::AnalysisPoint + nameof(rhs) == name || continue + return namespace ? renamespace(sys, rhs) : rhs end end end @@ -1019,7 +1043,7 @@ struct LocalScope <: SymScope end Apply `LocalScope` to `sym`. """ -function LocalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) +function LocalScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) apply_to_variables(sym) do sym if iscall(sym) && operation(sym) === getindex args = arguments(sym) @@ -1051,7 +1075,7 @@ end Apply `ParentScope` to `sym`, with `parent` being `LocalScope`. """ -function ParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) +function ParentScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) apply_to_variables(sym) do sym if iscall(sym) && operation(sym) === getindex args = arguments(sym) @@ -1081,7 +1105,7 @@ struct GlobalScope <: SymScope end Apply `GlobalScope` to `sym`. """ -function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) +function GlobalScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) apply_to_variables(sym) do sym if iscall(sym) && operation(sym) == getindex args = arguments(sym) @@ -1094,44 +1118,61 @@ function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) end end +const AllScopes = Union{LocalScope, ParentScope, GlobalScope} + renamespace(sys, eq::Equation) = namespace_equation(eq, sys) renamespace(names::AbstractVector, x) = foldr(renamespace, names, init = x) +renamespace(sys, tgt::AbstractSystem) = rename(tgt, renamespace(sys, nameof(tgt))) +renamespace(sys, tgt::Symbol) = Symbol(getname(sys), NAMESPACE_SEPARATOR_SYMBOL, tgt) +renamespace(sys, x::Num) = Num(renamespace(sys, unwrap(x))) +renamespace(sys, x::Arr{T, N}) where {T, N} = Arr{T, N}(renamespace(sys, unwrap(x))) +renamespace(sys, x::CallAndWrap{T}) where {T} = CallAndWrap{T}(renamespace(sys, unwrap(x))) + """ $(TYPEDSIGNATURES) Namespace `x` with the name of `sys`. """ -function renamespace(sys, x) - sys === nothing && return x - x = unwrap(x) - if x isa Symbolic - T = typeof(x) - if iscall(x) && operation(x) isa Operator - return maketerm(typeof(x), operation(x), - Any[renamespace(sys, only(arguments(x)))], - metadata(x))::T - end - if iscall(x) && operation(x) === getindex - args = arguments(x) - return maketerm( - typeof(x), operation(x), vcat(renamespace(sys, args[1]), args[2:end]), - metadata(x))::T - end - let scope = getmetadata(x, SymScope, LocalScope()) +function renamespace(sys, x::SymbolicT) + isequal(x, SU.idxs_for_arrayop(VartypeT)) && return x + Moshi.Match.@match x begin + BSImpl.Sym(; name) => let scope = getmetadata(x, SymScope, LocalScope())::AllScopes if scope isa LocalScope - rename(x, renamespace(getname(sys), getname(x)))::T + return rename(x, renamespace(getname(sys), name))::SymbolicT elseif scope isa ParentScope - setmetadata(x, SymScope, scope.parent)::T - else # GlobalScope - x::T + return setmetadata(x, SymScope, scope.parent)::SymbolicT + elseif scope isa GlobalScope + return x end + error() + end + BSImpl.Term(; f, args, shape, type, metadata) => begin + if f === getindex + newargs = copy(parent(args)) + newargs[1] = renamespace(sys, args[1]) + return BSImpl.Term{VartypeT}(getindex, newargs; type, shape, metadata) + elseif f isa SymbolicT + let scope = getmetadata(x, SymScope, LocalScope())::Union{LocalScope, ParentScope, GlobalScope} + if scope isa LocalScope + return rename(x, renamespace(getname(sys), getname(x)))::SymbolicT + elseif scope isa ParentScope + return setmetadata(x, SymScope, scope.parent)::SymbolicT + elseif scope isa GlobalScope + return x + end + error() + end + elseif f isa Operator + newargs = copy(parent(args)) + for (i, arg) in enumerate(args) + newargs[i] = renamespace(sys, arg) + end + return BSImpl.Term{VartypeT}(f, newargs; type, shape, metadata) + end + error() end - elseif x isa AbstractSystem - rename(x, renamespace(sys, nameof(x))) - else - Symbol(getname(sys), NAMESPACE_SEPARATOR_SYMBOL, x) end end @@ -1156,8 +1197,14 @@ Return `equations(sys)`, namespaced by the name of `sys`. """ function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys)) eqs = equations(sys) - isempty(eqs) && return Equation[] - map(eq -> namespace_equation(eq, sys; ivs), eqs) + isempty(eqs) && return eqs + if eqs === get_eqs(sys) + eqs = copy(eqs) + end + for i in eachindex(eqs) + eqs[i] = namespace_equation(eqs[i], sys; ivs) + end + return eqs end function namespace_initialization_equations( @@ -1204,11 +1251,26 @@ function namespace_jump(j::MassActionJump, sys) end function namespace_jumps(sys::AbstractSystem) - return [namespace_jump(j, sys) for j in get_jumps(sys)] + js = jumps(sys) + isempty(js) && return js + if js === get_jumps(sys) + js = copy(js) + end + for i in eachindex(js) + js[i] = namespace_jump(js[i], sys) + end + return js end function namespace_brownians(sys::AbstractSystem) - return [renamespace(sys, b) for b in brownians(sys)] + bs = brownians(sys) + if bs === get_brownians(sys) + bs = copy(bs) + end + for i in eachindex(bs) + bs[i] = renamespace(sys, bs[i]) + end + return bs end function namespace_assignment(eq::Assignment, sys) @@ -1224,48 +1286,66 @@ function is_array_of_symbolics(x) any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x) end -function namespace_expr( - O, sys, n = (sys === nothing ? nothing : nameof(sys)); - ivs = sys === nothing ? nothing : independent_variables(sys)) - sys === nothing && return O - O = unwrap(O) - # Exceptions for arrays of symbolic and Ref of a symbolic, the latter - # of which shows up in broadcasts - if symbolic_type(O) == NotSymbolic() && !(O isa AbstractArray) && !(O isa Ref) - return O - end - if any(isequal(O), ivs) - return O - elseif iscall(O) - T = typeof(O) - renamed = let sys = sys, n = n, T = T - map(a -> namespace_expr(a, sys, n; ivs)::Any, arguments(O)) - end - if isvariable(O) - # Use renamespace so the scope is correct, and make sure to use the - # metadata from the rescoped variable - rescoped = renamespace(n, O) - maketerm(typeof(rescoped), operation(rescoped), renamed, - metadata(rescoped)) - elseif Symbolics.isarraysymbolic(O) - # promote_symtype doesn't work for array symbolics - maketerm(typeof(O), operation(O), renamed, metadata(O)) - else - maketerm(typeof(O), operation(O), renamed, metadata(O)) +function namespace_expr(O, sys::AbstractSystem, n::Symbol = nameof(sys); kw...) + return O +end +function namespace_expr(O::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}, sys::AbstractSystem, n::Symbol = nameof(sys); kw...) + typeof(O)(namespace_expr(unwrap(O), sys, n; kw...)) +end +function namespace_expr(O::AbstractArray, sys::AbstractSystem, n::Symbol = nameof(sys); ivs = independent_variables(sys)) + is_array_of_symbolics(O) || return O + O = copy(O) + for i in eachindex(O) + O[i] = namespace_expr(O[i], sys, n; ivs) + end + return O +end +function namespace_expr(O::SymbolicT, sys::AbstractSystem, n::Symbol = nameof(sys); ivs = independent_variables(sys)) + any(isequal(O), ivs) && return O + isvar = isvariable(O) + Moshi.Match.@match O begin + BSImpl.Const(;) => return O + BSImpl.Sym(;) => return isvar ? renamespace(n, O) : O + BSImpl.Term(; f, args, metadata, type, shape) => begin + newargs = copy(parent(args)) + for i in eachindex(args) + newargs[i] = namespace_expr(newargs[i], sys, n; ivs) + end + if isvar + rescoped = renamespace(n, O) + f = Moshi.Data.variant_getfield(rescoped, BSImpl.Term{VartypeT}, :f) + meta = Moshi.Data.variant_getfield(rescoped, BSImpl.Term{VartypeT}, :metadata) + elseif f isa SymbolicT + f = renamespace(n, f) + meta = metadata + else + meta = metadata + end + return BSImpl.Term{VartypeT}(f, newargs; type, shape, metadata = meta) end - elseif isvariable(O) - renamespace(n, O) - elseif O isa AbstractArray && is_array_of_symbolics(O) - let sys = sys, n = n - map(o -> namespace_expr(o, sys, n; ivs), O) + BSImpl.AddMul(; coeff, dict, variant, type, shape, metadata) => begin + newdict = copy(dict) + empty!(newdict) + for (k, v) in dict + newdict[namespace_expr(k, sys, n; ivs)] = v + end + return BSImpl.AddMul{VartypeT}(coeff, newdict, variant; type, shape, metadata) + end + BSImpl.Div(; num, den, type, shape, metadata) => begin + num = namespace_expr(num, sys, n; ivs) + den = namespace_expr(den, sys, n; ivs) + return BSImpl.Div{VartypeT}(num, den, false; type, shape, metadata) + end + BSImpl.ArrayOp(; output_idx, expr, term, ranges, reduce, type, shape, metadata) => begin + if term isa SymbolicT + term = namespace_expr(term, sys, n; ivs) + end + expr = namespace_expr(expr, sys, n; ivs) + return BSImpl.ArrayOp{VartypeT}(output_idx, expr, reduce, term, ranges; type, shape, metadata) end - else - O end end -_nonum(@nospecialize x) = x isa Num ? x.val : x - """ $(TYPEDSIGNATURES) @@ -1277,21 +1357,14 @@ See also [`ModelingToolkit.get_unknowns`](@ref). function unknowns(sys::AbstractSystem) sts = get_unknowns(sys) systems = get_systems(sys) - nonunique_unknowns = if isempty(systems) - sts - else - system_unknowns = reduce(vcat, namespace_variables.(systems)) - isempty(sts) ? system_unknowns : [sts; system_unknowns] + if isempty(systems) + return sts end - isempty(nonunique_unknowns) && return nonunique_unknowns - # `Vector{Any}` is incompatible with the `SymbolicIndexingInterface`, which uses - # `elsymtype = symbolic_type(eltype(_arg))` - # which inappropriately returns `NotSymbolic()` - if nonunique_unknowns isa Vector{Any} - nonunique_unknowns = _nonum.(nonunique_unknowns) + result = copy(sts) + for subsys in systems + append!(result, namespace_variables(subsys)) end - @assert typeof(nonunique_unknowns) !== Vector{Any} - unique(nonunique_unknowns) + return result end """ @@ -1315,19 +1388,21 @@ See also [`@parameters`](@ref) and [`ModelingToolkit.get_ps`](@ref). """ function parameters(sys::AbstractSystem; initial_parameters = false) ps = get_ps(sys) - if ps == SciMLBase.NullParameters() - return [] + if ps === SciMLBase.NullParameters() + return SymbolicT[] end if eltype(ps) <: Pair - ps = first.(ps) + ps = Vector{SymbolicT}(unwrap.(first.(ps))) end systems = get_systems(sys) - result = unique(isempty(systems) ? ps : - [ps; reduce(vcat, namespace_parameters.(systems))]) + result = copy(ps) + for subsys in systems + append!(result, namespace_parameters(subsys)) + end if !initial_parameters && !is_initializesystem(sys) filter!(result) do sym return !(isoperator(sym, Initial) || - iscall(sym) && operation(sym) == getindex && + iscall(sym) && operation(sym) === getindex && isoperator(arguments(sym)[1], Initial)) end end @@ -1456,10 +1531,15 @@ See also [`observables`](@ref) and [`ModelingToolkit.get_observed()`](@ref). function observed(sys::AbstractSystem) obs = get_observed(sys) systems = get_systems(sys) - [obs; - reduce(vcat, - (map(o -> namespace_equation(o, s), observed(s)) for s in systems), - init = Equation[])] + isempty(systems) && return obs + obs = copy(obs) + for subsys in systems + _obs = observed(subsys) + for eq in _obs + push!(obs, namespace_equation(eq, subsys)) + end + end + return obs end """ @@ -1500,10 +1580,8 @@ function defaults_and_guesses(sys::AbstractSystem) end unknowns(sys::Union{AbstractSystem, Nothing}, v) = namespace_expr(v, sys) -for vType in [Symbolics.Arr, Symbolics.Symbolic{<:AbstractArray}] - @eval unknowns(sys::AbstractSystem, v::$vType) = namespace_expr(v, sys) - @eval parameters(sys::AbstractSystem, v::$vType) = toparam(unknowns(sys, v)) -end +unknowns(sys::AbstractSystem, v::Symbolics.Arr) = namespace_expr(v, sys) +parameters(sys::AbstractSystem, v::Symbolics.Arr) = toparam(unknowns(sys, v)) parameters(sys::Union{AbstractSystem, Nothing}, v) = toparam(unknowns(sys, v)) for f in [:unknowns, :parameters] @eval function $f(sys::AbstractSystem, vs::AbstractArray) @@ -1525,15 +1603,12 @@ See also [`full_equations`](@ref) and [`ModelingToolkit.get_eqs`](@ref). function equations(sys::AbstractSystem) eqs = get_eqs(sys) systems = get_systems(sys) - if isempty(systems) - return eqs - else - eqs = Equation[eqs; - reduce(vcat, - namespace_equations.(get_systems(sys)); - init = Equation[])] - return eqs + isempty(systems) && return eqs + eqs = copy(eqs) + for subsys in systems + append!(eqs, namespace_equations(subsys)) end + return eqs end """ @@ -1613,10 +1688,12 @@ all the subsystems of `sys` (appropriately namespaced). function jumps(sys::AbstractSystem) js = get_jumps(sys) systems = get_systems(sys) - if isempty(systems) - return js + isempty(systems) && return js + js = copy(js) + for subsys in systems + append!(js, namespace_jumps(subsys)) end - return [js; reduce(vcat, namespace_jumps.(systems); init = [])] + return js end """ @@ -1631,7 +1708,11 @@ function brownians(sys::AbstractSystem) if isempty(systems) return bs end - return [bs; reduce(vcat, namespace_brownians.(systems); init = [])] + bs = copy(bs) + for subsys in systems + append!(bs, namespace_brownians(subsys)) + end + return bs end """ @@ -1645,10 +1726,13 @@ function cost(sys::AbstractSystem) consolidate = get_consolidate(sys) systems = get_systems(sys) if isempty(systems) - return consolidate(cs, Float64[]) + return consolidate(cs, Float64[])::SymbolicT end - subcosts = [namespace_expr(cost(subsys), subsys) for subsys in systems] - return consolidate(cs, subcosts) + subcosts = SymbolicT[] + for subsys in systems + push!(subcosts, namespace_expr(cost(subsys), subsys)) + end + return consolidate(cs, subcosts)::SymbolicT end namespace_constraint(eq::Equation, sys) = namespace_equation(eq, sys) @@ -1665,8 +1749,14 @@ end function namespace_constraints(sys) cstrs = constraints(sys) - isempty(cstrs) && return Vector{Union{Equation, Inequality}}(undef, 0) - map(cstr -> namespace_constraint(cstr, sys), cstrs) + isempty(cstrs) && return cstrs + if cstrs === get_constraints(sys) + cstrs = copy(cstrs) + end + for i in eachindex(cstrs) + cstrs[i] = namespace_constraint(cstrs[i], sys) + end + return cstrs end """ @@ -1677,7 +1767,12 @@ Get all constraints in the system `sys` and all of its subsystems, appropriately function constraints(sys::AbstractSystem) cs = get_constraints(sys) systems = get_systems(sys) - isempty(systems) ? cs : [cs; reduce(vcat, namespace_constraints.(systems))] + isempty(systems) && return cs + cs = copy(cs) + for subsys in systems + append!(cs, namespace_constraints(subsys)) + end + return cs end """ @@ -2257,7 +2352,7 @@ end function default_to_parentscope(v) uv = unwrap(v) - uv isa Symbolic || return v + uv isa SymbolicT || return v apply_to_variables(v) do sym ParentScope(sym) end @@ -2657,8 +2752,8 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys) if has_is_dde(sys) @set! sys.is_dde = _check_if_dde(equations(sys), get_iv(sys), get_systems(sys)) end - newunknowns = OrderedSet() - newparams = OrderedSet() + newunknowns = OrderedSet{SymbolicT}() + newparams = OrderedSet{SymbolicT}() iv = has_iv(sys) ? get_iv(sys) : nothing for ssys in systems collect_scoped_vars!(newunknowns, newparams, ssys, iv) @@ -2735,26 +2830,26 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair}, elseif sys isa System rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]), collect(rules))) - newsys = @set sys.eqs = fast_substitute(get_eqs(sys), rules) + newsys = @set sys.eqs = substitute(get_eqs(sys), rules) @set! newsys.unknowns = map(get_unknowns(sys)) do var get(rules, var, var) end @set! newsys.ps = map(get_ps(sys)) do var get(rules, var, var) end - @set! newsys.parameter_dependencies = fast_substitute( + @set! newsys.parameter_dependencies = substitute( get_parameter_dependencies(sys), rules) - @set! newsys.defaults = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) + @set! newsys.defaults = Dict(substitute(k, rules) => substitute(v, rules) for (k, v) in get_defaults(sys)) - @set! newsys.guesses = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) + @set! newsys.guesses = Dict(substitute(k, rules) => substitute(v, rules) for (k, v) in get_guesses(sys)) - @set! newsys.noise_eqs = fast_substitute(get_noise_eqs(sys), rules) - @set! newsys.costs = Vector{Union{Real, BasicSymbolic}}(fast_substitute( + @set! newsys.noise_eqs = substitute(get_noise_eqs(sys), rules) + @set! newsys.costs = Vector{Union{Real, BasicSymbolic}}(substitute( get_costs(sys), rules)) - @set! newsys.observed = fast_substitute(get_observed(sys), rules) - @set! newsys.initialization_eqs = fast_substitute( + @set! newsys.observed = substitute(get_observed(sys), rules) + @set! newsys.initialization_eqs = substitute( get_initialization_eqs(sys), rules) - @set! newsys.constraints = fast_substitute(get_constraints(sys), rules) + @set! newsys.constraints = substitute(get_constraints(sys), rules) @set! newsys.systems = map(s -> substitute(s, rules), get_systems(sys)) else error("substituting symbols is not supported for $(typeof(sys))") @@ -2774,21 +2869,23 @@ function process_parameter_equations(sys::AbstractSystem) if !isempty(get_systems(sys)) throw(ArgumentError("Expected flattened system")) end - varsbuf = Set() + varsbuf = Set{SymbolicT}() pareq_idxs = Int[] eqs = equations(sys) for (i, eq) in enumerate(eqs) empty!(varsbuf) - vars!(varsbuf, eq; op = Union{Differential, Initial, Pre}) + SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{Union{Differential, Initial, Pre}}()) # singular equations isempty(varsbuf) && continue - if all(varsbuf) do sym + if let sys = sys + all(varsbuf) do sym is_parameter(sys, sym) || symbolic_type(sym) == ArraySymbolic() && - is_sized_array_symbolic(sym) && + symbolic_has_known_size(sym) && all(Base.Fix1(is_parameter, sys), collect(sym)) || iscall(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[1]) + end end # Everything in `varsbuf` is a parameter, so this is a cheap `is_parameter` # check. @@ -2802,8 +2899,11 @@ function process_parameter_equations(sys::AbstractSystem) end end - pareqs = [get_parameter_dependencies(sys); eqs[pareq_idxs]] - explicitpars = [eq.lhs for eq in pareqs] + pareqs = Equation[get_parameter_dependencies(sys); eqs[pareq_idxs]] + explicitpars = SymbolicT[] + for eq in pareqs + push!(explicitpars, eq.lhs) + end pareqs = topsort_equations(pareqs, explicitpars) eqs = eqs[setdiff(eachindex(eqs), pareq_idxs)] diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index dc25378b4a..8ac742f6be 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -10,10 +10,12 @@ function alias_eliminate_graph!(state::TransformationState; kwargs...) @unpack graph, var_to_diff, solvable_graph = state.structure mm = alias_eliminate_graph!(state, mm; kwargs...) s = state.structure - for g in (s.graph, s.solvable_graph) - g === nothing && continue + for (ei, e) in enumerate(mm.nzrows) + set_neighbors!(s.graph, e, mm.row_cols[ei]) + end + if s.solvable_graph isa BipartiteGraph{Int, Nothing} for (ei, e) in enumerate(mm.nzrows) - set_neighbors!(g, e, mm.row_cols[ei]) + set_neighbors!(s.solvable_graph, e, mm.row_cols[ei]) end end @@ -46,14 +48,12 @@ alias_elimination(sys) = alias_elimination!(TearingState(sys))[1] function alias_elimination!(state::TearingState; kwargs...) sys = state.sys complete!(state.structure) - graph_orig = copy(state.structure.graph) mm = alias_eliminate_graph!(state; kwargs...) fullvars = state.fullvars @unpack var_to_diff, graph, solvable_graph = state.structure - subs = Dict() - obs = Equation[] + subs = Dict{SymbolicT, SymbolicT}() # If we encounter y = -D(x), then we need to expand the derivative when # D(y) appears in the equation, so that D(-D(x)) becomes -D(D(x)). to_expand = Int[] @@ -62,17 +62,22 @@ function alias_elimination!(state::TearingState; kwargs...) dels = Int[] eqs = collect(equations(state)) resize!(eqs, nsrcs(graph)) + + __trivial_eq_rhs = let fullvars = fullvars + function trivial_eq_rhs(pair) + var, coeff = pair + iszero(coeff) && return Symbolics.COMMON_ZERO + return coeff * fullvars[var] + end + end for (ei, e) in enumerate(mm.nzrows) vs = 𝑠neighbors(graph, e) if isempty(vs) # remove empty equations push!(dels, e) else - rhs = mapfoldl(+, pairs(nonzerosmap(@view mm[ei, :]))) do (var, coeff) - iszero(coeff) && return 0 - return coeff * fullvars[var] - end - eqs[e] = 0 ~ rhs + rhs = mapfoldl(__trivial_eq_rhs, +, pairs(nonzerosmap(@view mm[ei, :]))) + eqs[e] = Symbolics.COMMON_ZERO ~ rhs end end deleteat!(eqs, sort!(dels)) @@ -92,21 +97,22 @@ function alias_elimination!(state::TearingState; kwargs...) n_new_eqs = idx - lineqs = BitSet(mm.nzrows) eqs_to_update = BitSet() - nvs_orig = ndsts(graph_orig) for ieq in eqs_to_update eq = eqs[ieq] - eqs[ieq] = fast_substitute(eq, subs) + eqs[ieq] = substitute(eq, subs) end - @set! mm.nparentrows = nsrcs(graph) - @set! mm.row_cols = eltype(mm.row_cols)[mm.row_cols[i] - for (i, eq) in enumerate(mm.nzrows) - if old_to_new_eq[eq] > 0] - @set! mm.row_vals = eltype(mm.row_vals)[mm.row_vals[i] - for (i, eq) in enumerate(mm.nzrows) - if old_to_new_eq[eq] > 0] - @set! mm.nzrows = Int[old_to_new_eq[eq] for eq in mm.nzrows if old_to_new_eq[eq] > 0] + new_nparentrows = nsrcs(graph) + new_row_cols = eltype(mm.row_cols)[] + new_row_vals = eltype(mm.row_vals)[] + new_nzrows = Int[] + for (i, eq) in enumerate(mm.nzrows) + old_to_new_eq[eq] > 0 || continue + push!(new_row_cols, mm.row_cols[i]) + push!(new_row_vals, mm.row_vals[i]) + push!(new_nzrows, old_to_new_eq[eq]) + end + mm = typeof(mm)(new_nparentrows, mm.ncols, new_nzrows, new_row_cols, new_row_vals) for old_ieq in to_expand ieq = old_to_new_eq[old_ieq] @@ -138,7 +144,13 @@ function alias_elimination!(state::TearingState; kwargs...) sys = state.sys @set! sys.eqs = eqs state.sys = sys - return invalidate_cache!(sys), mm + # This phrasing infers the return type as `Union{Tuple{...}}` instead of + # `Tuple{Union{...}, ...}` + if mm isa SparseMatrixCLIL{BigInt, Int} + return invalidate_cache!(sys), mm + else + return invalidate_cache!(sys), mm + end end """ @@ -301,7 +313,13 @@ function aag_bareiss!(structure, mm_orig::SparseMatrixCLIL{T, Ti}) where {T, Ti} bar = do_bareiss!(mm, mm_orig, is_linear_variables, is_highest_diff) end - return mm, solvable_variables, bar + # This phrasing infers the return type as `Union{Tuple{...}}` instead of + # `Tuple{Union{...}, ...}` + if mm isa SparseMatrixCLIL{BigInt, Ti} + return mm, solvable_variables, bar + else + return mm, solvable_variables, bar + end end function do_bareiss!(M, Mold, is_linear_variables, is_highest_diff) @@ -411,7 +429,7 @@ julia> ModelingToolkit.topsort_equations(eqs, [x, y, z, k]) Equation(x(t), y(t) + z(t)) ``` """ -function topsort_equations(eqs, unknowns; check = true) +function topsort_equations(eqs::Vector{Equation}, unknowns::Vector{SymbolicT}; check = true) graph, assigns = observed2graph(eqs, unknowns) neqs = length(eqs) degrees = zeros(Int, neqs) @@ -460,22 +478,25 @@ function topsort_equations(eqs, unknowns; check = true) return ordered_eqs end -function observed2graph(eqs, unknowns) +function observed2graph(eqs::Vector{Equation}, unknowns::Vector{SymbolicT})::Tuple{BipartiteGraph{Int, Nothing}, Vector{Int}} graph = BipartiteGraph(length(eqs), length(unknowns)) - v2j = Dict(unknowns .=> 1:length(unknowns)) + v2j = Dict{SymbolicT, Int}(unknowns .=> 1:length(unknowns)) # `assigns: eq -> var`, `eq` defines `var` assigns = similar(eqs, Int) - + vars = Set{SymbolicT}() for (i, eq) in enumerate(eqs) lhs_j = get(v2j, eq.lhs, nothing) lhs_j === nothing && throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in unknowns.")) assigns[i] = lhs_j - vs = vars(eq.rhs; op = Symbolics.Operator) - for v in vs + empty!(vars) + SU.search_variables!(vars, eq.rhs; is_atomic = OperatorIsAtomic{SU.Operator}()) + for v in vars j = get(v2j, v, nothing) - j !== nothing && add_edge!(graph, i, j) + if j isa Int + add_edge!(graph, i, j) + end end end diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl index a5a612b9ca..d25f717d72 100644 --- a/src/systems/analysis_points.jl +++ b/src/systems/analysis_points.jl @@ -71,7 +71,7 @@ struct AnalysisPoint The outputs of the connection. In the context of ModelingToolkitStandardLibrary.jl, these are all `RealInput` connectors. """ - outputs::Union{Nothing, Vector{Any}} + outputs::Union{Nothing, Vector{System}, Vector{SymbolicT}} function AnalysisPoint(input, name::Symbol, outputs; verbose = true) # input to analysis point should be an output variable @@ -161,7 +161,17 @@ end Convert an `AnalysisPoint` to a standard connection. """ function to_connection(ap::AnalysisPoint) - return connect(ap.input, ap.outputs...) + if ap.input isa System + vs = System[ap.input] + append!(vs, ap.outputs::Vector{System}) + return Connection() ~ Connection(vs) + elseif ap.input isa SymbolicT + vs = SymbolicT[ap.input] + append!(vs, ap.outputs::Vector{SymbolicT}) + return Connection() ~ Connection(vs) + else + error("Unreachable!") + end end """ @@ -179,7 +189,7 @@ end # create analysis points via `connect` function connect(in, ap::AnalysisPoint, outs...; verbose = true) - return AnalysisPoint() ~ AnalysisPoint(in, ap.name, collect(outs); verbose) + return AnalysisPoint() ~ AnalysisPoint(unwrap(in), ap.name, collect(unwrap.(outs)); verbose) end """ @@ -220,7 +230,7 @@ typically is not (unless the model is an inverse model). warning if you are analyzing an inverse model. """ function connect(in::AbstractSystem, name::Symbol, out, outs...; verbose = true) - return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose) + return AnalysisPoint() ~ AnalysisPoint(in, name, System[out; collect(outs)]; verbose) end function connect( @@ -249,8 +259,13 @@ end Remove all `AnalysisPoint`s in `sys` and any of its subsystems, replacing them by equivalent connections. """ function remove_analysis_points(sys::AbstractSystem) - eqs = map(get_eqs(sys)) do eq - eq.lhs isa AnalysisPoint ? to_connection(eq.rhs) : eq + eqs = Equation[] + for eq in get_eqs(sys) + if unwrap_const(eq.lhs) isa AnalysisPoint + push!(eqs, to_connection(unwrap_const(eq.rhs)::AnalysisPoint)) + else + push!(eqs, eq) + end end @set! sys.eqs = eqs @set! sys.systems = map(remove_analysis_points, get_systems(sys)) @@ -413,7 +428,7 @@ Search for the analysis point with the given `name` in `get_eqs(sys)`. function analysis_point_index(sys::AbstractSystem, name::Symbol) name = namespace_hierarchy(name)[end] findfirst(get_eqs(sys)) do eq - eq.lhs isa AnalysisPoint && nameof(eq.rhs) == name + value(eq.lhs) isa AnalysisPoint && nameof(value(eq.rhs)::AnalysisPoint) == name end end @@ -525,7 +540,7 @@ function apply_transformation(tf::Break, sys::AbstractSystem) breaksys_eqs = copy(get_eqs(breaksys)) @set! breaksys.eqs = breaksys_eqs - ap = breaksys_eqs[ap_idx].rhs + ap = value(breaksys_eqs[ap_idx].rhs)::AnalysisPoint deleteat!(breaksys_eqs, ap_idx) breaksys = with_analysis_point_ignored(breaksys, ap) @@ -583,7 +598,7 @@ function apply_transformation(tf::GetInput, sys::AbstractSystem) error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).") # get the analysis point ap_sys_eqs = get_eqs(ap_sys) - ap = ap_sys_eqs[ap_idx].rhs + ap = value(ap_sys_eqs[ap_idx].rhs)::AnalysisPoint # input variable ap_ivar = ap_var(ap.input) @@ -638,7 +653,7 @@ function apply_transformation(tf::PerturbOutput, sys::AbstractSystem) # modified equations ap_sys_eqs = copy(get_eqs(ap_sys)) @set! ap_sys.eqs = ap_sys_eqs - ap = ap_sys_eqs[ap_idx].rhs + ap = value(ap_sys_eqs[ap_idx].rhs)::AnalysisPoint # remove analysis point deleteat!(ap_sys_eqs, ap_idx) ap_sys = with_analysis_point_ignored(ap_sys, ap) @@ -708,7 +723,7 @@ function apply_transformation(tf::AddVariable, sys::AbstractSystem) ap_idx === nothing && error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).") ap_sys_eqs = get_eqs(ap_sys) - ap = ap_sys_eqs[ap_idx].rhs + ap = value(ap_sys_eqs[ap_idx].rhs)::AnalysisPoint # add equations involving new variable ap_ivar = ap_var(ap.input) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index c94166103b..8bf49f508d 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -7,15 +7,19 @@ end struct SymbolicAffect affect::Vector{Equation} alg_eqs::Vector{Equation} - discrete_parameters::Vector{Any} + discrete_parameters::Vector{SymbolicT} end function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[], - discrete_parameters = Any[], kwargs...) - if !(discrete_parameters isa AbstractVector) - discrete_parameters = Any[discrete_parameters] - elseif !(discrete_parameters isa Vector{Any}) - discrete_parameters = Vector{Any}(discrete_parameters) + discrete_parameters = SymbolicT[], kwargs...) + if symbolic_type(discrete_parameters) !== NotSymbolic() + discrete_parameters = SymbolicT[unwrap(discrete_parameters)] + elseif !(discrete_parameters isa Vector{SymbolicT}) + _discs = SymbolicT[] + for p in discrete_parameters + push!(_discs, unwrap(p)) + end + discrete_parameters = _discs end SymbolicAffect(affect, alg_eqs, discrete_parameters) end @@ -25,34 +29,31 @@ function SymbolicAffect(affect::SymbolicAffect; kwargs...) end SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...) -function Symbolics.fast_substitute(aff::SymbolicAffect, rules) - substituter = Base.Fix2(fast_substitute, rules) - SymbolicAffect(map(substituter, aff.affect), map(substituter, aff.alg_eqs), - map(substituter, aff.discrete_parameters)) +function (s::SymbolicUtils.Substituter)(aff::SymbolicAffect) + SymbolicAffect(s(aff.affect), s(aff.alg_eqs), s(aff.discrete_parameters)) end struct AffectSystem """The internal implicit discrete system whose equations are solved to obtain values after the affect.""" system::AbstractSystem """Unknowns of the parent ODESystem whose values are modified or accessed by the affect.""" - unknowns::Vector + unknowns::Vector{SymbolicT} """Parameters of the parent ODESystem whose values are accessed by the affect.""" - parameters::Vector + parameters::Vector{SymbolicT} """Parameters of the parent ODESystem whose values are modified by the affect.""" - discretes::Vector + discretes::Vector{SymbolicT} end -function Symbolics.fast_substitute(aff::AffectSystem, rules) - substituter = Base.Fix2(fast_substitute, rules) +function (s::SymbolicUtils.Substituter)(aff::AffectSystem) sys = aff.system - @set! sys.eqs = map(substituter, get_eqs(sys)) - @set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys)) - @set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)]) - @set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)]) - @set! sys.unknowns = map(substituter, get_unknowns(sys)) - @set! sys.ps = map(substituter, get_ps(sys)) - AffectSystem(sys, map(substituter, aff.unknowns), - map(substituter, aff.parameters), map(substituter, aff.discretes)) + @set! sys.eqs = s(get_eqs(sys)) + @set! sys.parameter_dependencies = (get_parameter_dependencies(sys)) + @set! sys.defaults = Dict([k => s(v) for (k, v) in defaults(sys)]) + @set! sys.guesses = Dict([k => s(v) for (k, v) in guesses(sys)]) + @set! sys.unknowns = s(get_unknowns(sys)) + @set! sys.ps = s(get_ps(sys)) + AffectSystem(sys, s(aff.unknowns), s(aff.parameters), s(aff.discretes)) + end function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...) @@ -60,7 +61,11 @@ function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], discrete_parameters = spec.discrete_parameters, kwargs...) end -function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], +@noinline function warn_algebraic_equation(eq::Equation) + @warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1." +end + +function AffectSystem(affect::Vector{Equation}; discrete_parameters = SymbolicT[], iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...) isempty(affect) && return nothing if isnothing(iv) @@ -68,26 +73,24 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], @warn "No independent variable specified. Defaulting to t_nounits." end - discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters]) - discrete_parameters = unwrap.(discrete_parameters) + discrete_parameters = SymbolicAffect(affect; alg_eqs, discrete_parameters).discrete_parameters for p in discrete_parameters - occursin(unwrap(iv), unwrap(p)) || + SU.query(isequal(unwrap(iv)), unwrap(p)) || error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).") end - dvs = OrderedSet() - params = OrderedSet() - _varsbuf = Set() + dvs = OrderedSet{SymbolicT}() + params = OrderedSet{SymbolicT}() + _varsbuf = Set{SymbolicT}() for eq in affect - if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() || - symbolic_type(eq.lhs) === NotSymbolic()) - @warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1." + if !haspre(eq) && !(isconst(eq.lhs) && isconst(eq.rhs)) + @invokelatest warn_algebraic_equation(eq) end collect_vars!(dvs, params, eq, iv; op = Pre) empty!(_varsbuf) - vars!(_varsbuf, eq; op = Pre) - filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf) + SU.search_variables!(_varsbuf, eq; is_atomic = OperatorIsAtomic{Pre}()) + filter!(x -> iscall(x) && operation(x) === Pre(), _varsbuf) union!(params, _varsbuf) diffvs = collect_applied_operators(eq, Differential) union!(dvs, diffvs) @@ -95,30 +98,35 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], for eq in alg_eqs collect_vars!(dvs, params, eq, iv) end - pre_params = filter(haspre ∘ value, params) - sys_params = collect(setdiff(params, union(discrete_parameters, pre_params))) + pre_params = filter(haspre, params) + sys_params = SymbolicT[] + disc_ps_set = Set{SymbolicT}(discrete_parameters) + for p in params + p in disc_ps_set && continue + p in pre_params && continue + push!(sys_params, p) + end discretes = map(tovar, discrete_parameters) dvs = collect(dvs) _dvs = map(default_toterm, dvs) - rev_map = Dict(zip(discrete_parameters, discretes)) - subs = merge(rev_map, Dict(zip(dvs, _dvs))) - affect = Symbolics.fast_substitute(affect, subs) - alg_eqs = Symbolics.fast_substitute(alg_eqs, subs) + rev_map = Dict{SymbolicT, SymbolicT}(zip(discrete_parameters, discretes)) + subs = merge(rev_map, Dict{SymbolicT, SymbolicT}(zip(dvs, _dvs))) + affect = substitute(affect, subs) + alg_eqs = substitute(alg_eqs, subs) @named affectsys = System( vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)), collect(union(pre_params, sys_params)); is_discrete = true) affectsys = mtkcompile(affectsys; fully_determined = nothing) # get accessed parameters p from Pre(p) in the callback parameters - accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params)))) + accessed_params = Vector{SymbolicT}(filter(isparameter, map(unPre, collect(pre_params)))) union!(accessed_params, sys_params) # add scalarized unknowns to the map. - _dvs = reduce(vcat, map(scalarize, _dvs), init = Any[]) + _dvs = reduce(vcat, map(scalarize, _dvs), init = SymbolicT[]) - AffectSystem(affectsys, collect(_dvs), collect(accessed_params), - collect(discrete_parameters)) + AffectSystem(affectsys, _dvs, accessed_params, discrete_parameters) end system(a::AffectSystem) = a.system @@ -169,7 +177,7 @@ Base.nameof(::Pre) = :Pre Base.show(io::IO, x::Pre) = print(io, "Pre") unPre(x::Num) = unPre(unwrap(x)) unPre(x::Symbolics.Arr) = unPre(unwrap(x)) -unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x +unPre(x::SymbolicT) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x function (p::Pre)(x) iw = Symbolics.iswrapped(x) @@ -186,14 +194,15 @@ function (p::Pre)(x) iscall(x) && operation(x) isa Pre && return x result = if symbolic_type(x) == ArraySymbolic() # create an array for `Pre(array)` + term(p, x; type = symtype(x), shape = SU.shape(x)) Symbolics.array_term(p, x) elseif iscall(x) && operation(x) == getindex # instead of `Pre(x[1])` create `Pre(x)[1]` # which allows parameter indexing to handle this case automatically. arr = arguments(x)[1] - term(getindex, p(arr), arguments(x)[2:end]...) + p(arr)[arguments(x)[2:end]...] else - term(p, x) + term(p, x; type = symtype(x), shape = SU.shape(x)) end # the result should be a parameter result = toparam(result) @@ -420,14 +429,14 @@ A callback that triggers at the first timestep that the conditions are satisfied The condition can be one of: - Δt::Real - periodic events with period Δt - ts::Vector{Real} - events trigger at these preset times given by `ts` -- eqs::Vector{Symbolic} - events trigger when the condition evaluates to true +- eqs::Vector{SymbolicT} - events trigger when the condition evaluates to true Arguments: - iv: The independent variable of the system. This must be specified if the independent variable appears in one of the equations explicitly, as in x ~ t + 1. - alg_eqs: Algebraic equations of the system that must be satisfied after the callback occurs. """ struct SymbolicDiscreteCallback <: AbstractCallback - conditions::Union{Number, Vector{<:Number}, Symbolic{Bool}} + conditions::Union{Number, Vector{<:Number}, SymbolicT} affect::Union{Affect, SymbolicAffect, Nothing} initialize::Union{Affect, SymbolicAffect, Nothing} finalize::Union{Affect, SymbolicAffect, Nothing} @@ -435,9 +444,10 @@ struct SymbolicDiscreteCallback <: AbstractCallback end function SymbolicDiscreteCallback( - condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing; + condition::Union{SymbolicT, Number, Vector{<:Number}}, affect = nothing; initialize = nothing, finalize = nothing, reinitializealg = nothing, kwargs...) + @assert !(condition isa SymbolicT && symtype(condition) != Bool) c = is_timed_condition(condition) ? condition : value(scalarize(condition)) if isnothing(reinitializealg) @@ -569,6 +579,9 @@ conditions(cb::AbstractCallback) = cb.conditions function conditions(cbs::Vector{<:AbstractCallback}) reduce(vcat, conditions(cb) for cb in cbs; init = []) end +function conditions(cbs::Vector{SymbolicContinuousCallback}) + mapreduce(conditions, vcat, cbs; init = Equation[]) +end equations(cb::AbstractCallback) = conditions(cb) equations(cb::Vector{<:AbstractCallback}) = conditions(cb) @@ -871,7 +884,7 @@ function default_operating_point(affsys::AffectSystem) T = symtype(p) if T <: Number op[p] = false - elseif T <: Array{<:Real} && is_sized_array_symbolic(p) + elseif T <: Array{<:Real} && symbolic_has_known_size(p) op[p] = zeros(size(p)) end end @@ -897,7 +910,7 @@ function compile_equational_affect( obseqs, eqs = unhack_observed(observed(affsys), equations(affsys)) if isempty(equations(affsys)) - update_eqs = Symbolics.fast_substitute( + update_eqs = substitute( obseqs, Dict([p => unPre(p) for p in parameters(affsys)])) rhss = map(x -> x.rhs, update_eqs) lhss = map(x -> x.lhs, update_eqs) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index a88e8c42fe..96529e0423 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -58,17 +58,24 @@ struct NotInferredTimeDomain end function error_sample_time(eq) error("$eq\ncontains `SampleTime` but it is not an Inferred discrete equation.") end -function substitute_sample_time(ci::ClockInference, ts::TearingState) +function substitute_sample_time(ci::ClockInference{T}, ts::T) where {T <: TearingState} @unpack eq_domain = ci eqs = copy(equations(ts)) @assert length(eqs) == length(eq_domain) + subrules = Dict{SymbolicT, SymbolicT}() + st = SampleTime() for i in eachindex(eqs) eq = eqs[i] domain = eq_domain[i] - dt = sampletime(domain) - neweq = substitute_sample_time(eq, dt) - if neweq isa NotInferredTimeDomain - error_sample_time(eq) + dt = SU.Const{VartypeT}(sampletime(domain)) + if dt === COMMON_NOTHING + if SU.query(isequal(st), eq.lhs) || SU.query(isequal(st), eq.rhs) + error_sample_time(eq) + end + neweq = eq + else + subrules[st] = dt + neweq = substitute(eq, subrules) end eqs[i] = neweq end @@ -76,30 +83,6 @@ function substitute_sample_time(ci::ClockInference, ts::TearingState) @set! ci.ts = ts end -function substitute_sample_time(eq::Equation, dt) - substitute_sample_time(eq.lhs, dt) ~ substitute_sample_time(eq.rhs, dt) -end - -function substitute_sample_time(ex, dt) - iscall(ex) || return ex - op = operation(ex) - args = arguments(ex) - if op == SampleTime - dt === nothing && return NotInferredTimeDomain() - return dt - else - new_args = similar(args) - for (i, arg) in enumerate(args) - ex_arg = substitute_sample_time(arg, dt) - if ex_arg isa NotInferredTimeDomain - return ex_arg - end - new_args[i] = ex_arg - end - maketerm(typeof(ex), op, new_args, metadata(ex)) - end -end - """ Update the equation-to-time domain mapping by inferring the time domain from the variables. """ @@ -109,7 +92,7 @@ function infer_clocks!(ci::ClockInference) fullvars = get_fullvars(ts) isempty(inferred) && return ci - var_to_idx = Dict(fullvars .=> eachindex(fullvars)) + var_to_idx = Dict{SymbolicT, Int}(fullvars .=> eachindex(fullvars)) # all shifted variables have the same clock as the unshifted variant for (i, v) in enumerate(fullvars) @@ -122,9 +105,9 @@ function infer_clocks!(ci::ClockInference) # preallocated buffers: # variables in each equation - varsbuf = Set() + varsbuf = Set{SymbolicT}() # variables in each argument to an operator - arg_varsbuf = Set() + arg_varsbuf = Set{SymbolicT}() # hyperedge for each equation hyperedge = Set{ClockVertex.Type}() # hyperedge for each argument to an operator @@ -136,7 +119,7 @@ function infer_clocks!(ci::ClockInference) empty!(varsbuf) empty!(hyperedge) # get variables in equation - vars!(varsbuf, eq; op = Symbolics.Operator) + SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{Symbolics.Operator}()) # add the equation to the hyperedge eq_node = if is_initialization_equation ClockVertex.InitEquation(ieq) @@ -155,14 +138,17 @@ function infer_clocks!(ci::ClockInference) # now we only care about synchronous operators iscall(var) || continue op = operation(var) - is_timevarying_operator(op) || continue + is_timevarying_operator(op)::Bool || continue # arguments and corresponding time domains args = arguments(var) tdomains = input_timedomain(op) - if !(tdomains isa AbstractArray || tdomains isa Tuple) - tdomains = [tdomains] + if tdomains isa Tuple + tdomains = Vector{InputTimeDomainElT}(collect(tdomains)) + elseif !(tdomains isa Vector{InputTimeDomainElT}) + tdomains = InputTimeDomainElT[tdomains] end + tdomains = tdomains::Vector{InputTimeDomainElT} nargs = length(args) ndoms = length(tdomains) if nargs != ndoms @@ -178,7 +164,7 @@ function infer_clocks!(ci::ClockInference) empty!(arg_varsbuf) empty!(arg_hyperedge) # get variables in argument - vars!(arg_varsbuf, arg; op = Union{Differential, Shift}) + SU.search_variables!(arg_varsbuf, arg; is_atomic = OperatorIsAtomic{Union{Differential, Shift}}()) # get hyperedge for involved variables for v in arg_varsbuf vidx = get(var_to_idx, v, nothing) @@ -200,7 +186,7 @@ function infer_clocks!(ci::ClockInference) # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't # add the edge, and instead add this to the `relative_hyperedges` mapping. InferredClock.InferredDiscrete(i) => begin - relative_edge = get!(() -> Set{ClockVertex.Type}(), relative_hyperedges, i) + relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i) union!(relative_edge, arg_hyperedge) end end @@ -237,7 +223,7 @@ function infer_clocks!(ci::ClockInference) clock_partitions = connectionsets(inference_graph) for partition in clock_partitions - clockidxs = findall(vert -> Moshi.Data.isa_variant(vert, ClockVertex.Clock), partition) + clockidxs = findall(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.Clock), partition) if isempty(clockidxs) push!(partition, ClockVertex.Clock(ContinuousClock())) push!(clockidxs, length(partition)) @@ -341,7 +327,7 @@ function split_system(ci::ClockInference{S}) where {S} continuous_id = continuous_id[] # for each clock partition what are the input (indexes/vars) input_idxs = map(_ -> Int[], 1:cid_counter[]) - inputs = map(_ -> Any[], 1:cid_counter[]) + inputs = map(_ -> SymbolicT[], 1:cid_counter[]) # var_domain corresponds to fullvars/all variables in the system nvv = length(var_domain) # put variables into the right clock partition diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl index 2687fedb80..486d2f6180 100644 --- a/src/systems/codegen.jl +++ b/src/systems/codegen.jl @@ -562,7 +562,7 @@ function generate_boundary_conditions(sys::System, u0, u0_idxs, t0; expression = cons = [con.lhs - con.rhs for con in constraints(sys)] # conssubs = Dict() # get_constraint_unknown_subs!(conssubs, cons, stidxmap, iv, sol) - # cons = map(x -> fast_substitute(x, conssubs), cons) + # cons = map(x -> substitute(x, conssubs), cons) init_conds = Any[] for i in u0_idxs @@ -1024,6 +1024,17 @@ function build_explicit_observed_function(sys, ts; cse = true, mkarray = nothing, wrap_delays = is_dde(sys)) + if inputs === nothing + inputs = () + else + inputs = vec(unwrap_vars(inputs)) + end + if disturbance_inputs === nothing + disturbance_inputs = () + else + disturbance_inputs = vec(unwrap_vars(disturbance_inputs)) + end + ps::Vector{SymbolicT} = vec(unwrap_vars(ps)) # TODO: cleanup is_tuple = ts isa Tuple if is_tuple @@ -1038,7 +1049,8 @@ function build_explicit_observed_function(sys, ts; ts = symbol_to_symbolic(sys, ts; allsyms) end - vs = ModelingToolkit.vars(ts; op) + vs = Set{SymbolicT}() + SU.search_variables!(vs, ts; is_atomic = OperatorIsAtomic{op}()) namespace_subs = Dict() ns_map = Dict{Any, Any}(renamespace(sys, eq.lhs) => eq.lhs for eq in observed(sys)) for sym in unknowns(sys) @@ -1066,7 +1078,7 @@ function build_explicit_observed_function(sys, ts; Base.throw(ArgumentError("Symbol $var is not present in the system.")) end end - ts = fast_substitute(ts, namespace_subs) + ts = substitute(ts, namespace_subs) obsfilter = if param_only if is_split(sys) @@ -1084,13 +1096,11 @@ function build_explicit_observed_function(sys, ts; else (unknowns(sys),) end - if inputs === nothing - inputs = () - else + if inputs isa Vector{SymbolicT} ps = setdiff(ps, inputs) # Inputs have been converted to parameters by io_preprocessing, remove those from the parameter list inputs = (inputs,) end - if disturbance_inputs !== nothing + if disturbance_inputs isa Vector{SymbolicT} # Disturbance inputs may or may not be included as inputs, depending on disturbance_argument ps = setdiff(ps, disturbance_inputs) end @@ -1099,15 +1109,15 @@ function build_explicit_observed_function(sys, ts; else disturbance_inputs = () end - ps = reorder_parameters(sys, ps) + rps::ReorderedParametersT = reorder_parameters(sys, ps) iv = if is_time_dependent(sys) (get_iv(sys),) else () end - args = (dvs..., inputs..., ps..., iv..., disturbance_inputs...) + args = (dvs..., inputs..., rps..., iv..., disturbance_inputs...) p_start = length(dvs) + length(inputs) + 1 - p_end = length(dvs) + length(inputs) + length(ps) + p_end = length(dvs) + length(inputs) + length(rps) fns = build_function_wrapper( sys, ts, args...; p_start, p_end, filter_observed = obsfilter, output_type, mkarray, try_namespaced = true, expression = Val{true}, cse, @@ -1118,7 +1128,7 @@ function build_explicit_observed_function(sys, ts; end oop, iip = eval_or_rgf.(fns; eval_expression, eval_module) f = GeneratedFunctionWrapper{( - p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}( + p_start + wrap_delays, length(args) - length(rps) + 1 + wrap_delays, is_split(sys))}( oop, iip) return return_inplace ? (f, f) : f else @@ -1127,7 +1137,7 @@ function build_explicit_observed_function(sys, ts; end f = eval_or_rgf(fns; eval_expression, eval_module) f = GeneratedFunctionWrapper{( - p_start + wrap_delays, length(args) - length(ps) + 1 + wrap_delays, is_split(sys))}( + p_start + wrap_delays, length(args) - length(rps) + 1 + wrap_delays, is_split(sys))}( f, nothing) return f end diff --git a/src/systems/codegen_utils.jl b/src/systems/codegen_utils.jl index dbbd7f85a8..b77bba98b4 100644 --- a/src/systems/codegen_utils.jl +++ b/src/systems/codegen_utils.jl @@ -135,8 +135,8 @@ end """ The argument of generated functions corresponding to the history function. """ -const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___) -const BVP_SOLUTION = Sym{Symbolics.FnType{Tuple{<:Real}, Vector{Real}}}(:__sol__) +const DDE_HISTORY_FUN = SSym(:___history___; type = SU.FnType{Tuple{Any, <:Real}, Vector{Real}}, shape = SU.Unknown(1)) +const BVP_SOLUTION = SSym(:__sol__; type = Symbolics.FnType{Tuple{<:Real}, Vector{Real}}, shape = SU.Unknown(1)) """ $(TYPEDSIGNATURES) diff --git a/src/systems/connectiongraph.jl b/src/systems/connectiongraph.jl index 27e542c2ab..00424789e5 100644 --- a/src/systems/connectiongraph.jl +++ b/src/systems/connectiongraph.jl @@ -45,8 +45,8 @@ Create a `ConnectionVertex` given use for this connection. """ function ConnectionVertex( - namespace::Vector{Symbol}, var::Union{BasicSymbolic, AbstractSystem}, isouter::Bool) - if var isa BasicSymbolic + namespace::Vector{Symbol}, var::Union{SymbolicT, AbstractSystem}, isouter::Bool) + if var isa SymbolicT name = getname(var) else name = nameof(var) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index c0ddf5baee..438b7231ed 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -23,8 +23,18 @@ Connect multiple connectors created via `@connector`. All connected connectors must be unique. """ function connect(sys1::AbstractSystem, sys2::AbstractSystem, syss::AbstractSystem...) - syss = (sys1, sys2, syss...) - length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") + _syss = System[] + push!(_syss, sys1) + push!(_syss, sys2) + for sys in syss + push!(_syss, sys) + end + syss = _syss + sysnames = Symbol[] + for sys in syss + push!(sysnames, nameof(sys)) + end + allunique(sysnames) || error("connect takes distinct systems!") Equation(Connection(), Connection(syss)) # the RHS are connected systems end @@ -72,12 +82,8 @@ end Get the connection type of symbolic variable `s` from the `VariableConnectType` metadata. Defaults to `Equality` if not present. """ -function get_connection_type(s::Symbolic) - s = unwrap(s) - if iscall(s) && operation(s) === getindex - s = arguments(s)[1] - end - getmetadata(s, VariableConnectType, Equality) +function get_connection_type(s::SymbolicT) + safe_getmetadata(VariableConnectType, s, Equality)::DataType end """ @@ -171,8 +177,12 @@ get_systems(c::Connection) = c.systems Refer to the [Connection semantics](@ref connect_semantics) section of the docs for more information. """ -instream(a) = term(instream, unwrap(a), type = symtype(a)) -SymbolicUtils.promote_symtype(::typeof(instream), _) = Real +function instream(a::SymbolicT) + BSImpl.Term{VartypeT}(instream, SArgsT((unwrap(a),)); type = symtype(a), shape = SU.shape(a)) +end +instream(a::Num) = Num(instream(unwrap(a))) +instream(a::Symbolics.Arr{T, N}) where {T, N} = Symbolics.Arr{T, N}(instream(unwrap(a))) +SymbolicUtils.promote_symtype(::typeof(instream), ::Type{T}) where {T} = T isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing @@ -183,7 +193,7 @@ Utility struct which wraps a symbolic variable used in a `Connection` to enable to work. """ struct SymbolicWithNameof - var::Any + var::SymbolicT end function Base.nameof(x::SymbolicWithNameof) @@ -192,7 +202,7 @@ end is_causal_variable_connection(c) = false function is_causal_variable_connection(c::Connection) - all(x -> x isa SymbolicWithNameof, get_systems(c)) + all(Base.Fix2(isa, SymbolicWithNameof), get_systems(c)) end const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr} @@ -214,24 +224,36 @@ end Perform validation for a connect statement involving causal variables. """ -function validate_causal_variables_connection(allvars) - var1 = allvars[1] - var2 = allvars[2] - vars = Base.tail(Base.tail(allvars)) +function validate_causal_variables_connection(allvars::Vector{SymbolicT}) for var in allvars vtype = getvariabletype(var) vtype === VARIABLE || throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`.")) end - if length(unique(allvars)) !== length(allvars) + if !allunique(allvars) throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries.")) end - allsizes = map(size, allvars) - if !allequal(allsizes) - throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively.")) + sh1 = SU.shape(allvars[1])::SU.ShapeVecT + sz1 = SU.SmallV{Int}() + for x in sh1 + push!(sz1, length(x)) + end + sz2 = SU.SmallV{Int}() + for v in allvars + sh = SU.shape(v)::SU.ShapeVecT + empty!(sz2) + for x in sh + push!(sz2, length(x)) + end + if !isequal(sz1, sz2) + throw(ArgumentError("Expected all connection variables to have the same size. Got variables $(allvars[1]) and $v with sizes $sz1 and $sz2 respectively.")) + + end end - non_causal_variables = filter(allvars) do var - !isinput(var) && !isoutput(var) + non_causal_variables = SymbolicT[] + for x in allvars + (isinput(x) || isoutput(x)) && continue + push!(non_causal_variables, x) end isempty(non_causal_variables) || throw(NonCausalVariableError(non_causal_variables)) end @@ -250,9 +272,14 @@ var1 ~ var3 """ function connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT, vars::ConnectableSymbolicT...) - allvars = (var1, var2, vars...) + allvars = SymbolicT[] + push!(allvars, unwrap(var1)) + push!(allvars, unwrap(var2)) + for var in vars + push!(allvars, unwrap(var)) + end validate_causal_variables_connection(allvars) - return Equation(Connection(), Connection(map(SymbolicWithNameof, unwrap.(allvars)))) + return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars))) end """ @@ -302,6 +329,27 @@ mydiv(num, den) = end @register_symbolic mydiv(n, d) +struct IsOuter + outer_connectors::Set{Symbol} +end + +function (io::IsOuter)(name::Symbol) + name in io.outer_connectors +end + +function (io::IsOuter)(sys) + nm = nameof(sys) + isconnector(sys) || error("$nm is not a connector!") + s = string(nm) + idx = findfirst(NAMESPACE_SEPARATOR, s) + parent_name = if idx === nothing + nm + else + Symbol(@view(s[1:prevind(s, idx)])) + end + return io(parent_name) +end + """ $(TYPEDSIGNATURES) @@ -309,23 +357,12 @@ Return a function which checks whether the connector (system) passed to it is an connector of `sys`. The function can also be given the name of a system as a `Symbol`. """ function generate_isouter(sys::AbstractSystem) - outer_connectors = Symbol[] + outer_connectors = Set{Symbol}() for s in get_systems(sys) n = nameof(s) isconnector(s) && push!(outer_connectors, n) end - let outer_connectors = outer_connectors - function isouter(sys)::Bool - s = string(nameof(sys)) - isconnector(sys) || error("$s is not a connector!") - idx = findfirst(isequal(NAMESPACE_SEPARATOR), s) - parent_name = Symbol(idx === nothing ? s : s[1:prevind(s, idx)]) - isouter(parent_name) - end - function isouter(name::Symbol)::Bool - return name in outer_connectors - end - end + return IsOuter(outer_connectors) end @noinline function connection_error(ss) @@ -336,14 +373,24 @@ abstract type IsFrame end "Return true if the system is a 3D multibody frame, otherwise return false." function isframe(sys) - getmetadata(sys, IsFrame, false) + getmetadata(sys, IsFrame, false)::Bool end abstract type FrameOrientation end +struct RotationMatrix + R::Matrix{SymbolicT} + w::Vector{SymbolicT} + + function RotationMatrix(R::AbstractMatrix, w::AbstractVector) + new(unwrap_vars(R), unwrap_vars(w)) + end + +end + "Return orientation object of a multibody frame." function ori(sys) - getmetadata(sys, FrameOrientation, nothing) + getmetadata(sys, FrameOrientation, nothing)::Union{RotationMatrix, Nothing} end """ @@ -373,13 +420,13 @@ index_from_type(::Type{OutputVar{I}}) where {I} = I Chain `getproperty` calls on sys in the order given by `names` and return the unwrapped result. """ -function iterative_getproperty(sys::AbstractSystem, names::AbstractVector{Symbol}) +function iterative_getproperty(sys::AbstractSystem, names::Vector{Symbol}) # we don't want to namespace the first time - result = toggle_namespacing(sys, false) + result::Union{SymbolicT, System} = toggle_namespacing(sys, false) for name in names - result = getproperty(result, name) + result = getvar(result, name)::Union{SymbolicT, System} end - return unwrap(result) + return result end """ @@ -389,10 +436,13 @@ Return the variable/subsystem of `sys` referred to by vertex `vert`. """ function variable_from_vertex(sys::AbstractSystem, vert::ConnectionVertex) value = iterative_getproperty(sys, vert.name) - value isa AbstractSystem && return value + value isa System && return value + value = value::SymbolicT vert.type <: Union{InputVar, OutputVar} || return value + vert.type === InputVar{CartesianIndex()} && return value + vert.type === OutputVar{CartesianIndex()} && return value # index possibly array causal variable - unwrap(wrap(value)[index_from_type(vert.type)]) + value[index_from_type(vert.type)]::SymbolicT end """ @@ -408,7 +458,7 @@ function returned from [`generate_isouter`](@ref) for the system referred to by `namespace` must not contain the name of the root system. """ function generate_connectionsets!(connection_state::AbstractConnectionState, - namespace::Vector{Symbol}, connected, isouter) + namespace::Vector{Symbol}, connected, isouter::IsOuter) initial_len = length(namespace) _generate_connectionsets!(connection_state, namespace, connected, isouter) # Enforce postcondition as a sanity check that the namespacing is implemented correctly @@ -418,51 +468,57 @@ end function _generate_connectionsets!(connection_state::AbstractConnectionState, namespace::Vector{Symbol}, - connected_vars::Union{ - AbstractVector{SymbolicWithNameof}, Tuple{Vararg{SymbolicWithNameof}}}, - isouter) + connected_vars::Vector{SymbolicWithNameof}, + isouter::IsOuter) # unwrap the `SymbolicWithNameof` into the contained symbolic variables. - connected_vars = map(x -> x.var, connected_vars) - _generate_connectionsets!(connection_state, namespace, connected_vars, isouter) + _connected_vars = SymbolicT[] + for x in connected_vars + push!(_connected_vars, x.var) + end + _generate_connectionsets!(connection_state, namespace, _connected_vars, isouter) end -function _generate_connectionsets!(connection_state::AbstractConnectionState, - namespace::Vector{Symbol}, - connected_vars::Union{ - AbstractVector{<:BasicSymbolic}, Tuple{Vararg{BasicSymbolic}}}, - isouter) - # NOTE: variable connections don't populate the domain network +@noinline function throw_both_input_output(var::SymbolicT, connected_vars::Vector{SymbolicT}) + names = join(string.(connected_vars), ", ") + throw(ArgumentError(""" + Variable $var in connection `connect($names)` is both input and output. + """)) +end +@noinline function throw_not_input_output(var::SymbolicT, connected_vars::Vector{SymbolicT}) + names = join(string.(connected_vars), ", ") + throw(ArgumentError(""" + Variable $var in connection `connect($names)` is neither input nor output. + """)) +end - # wrap to be able to call `eachindex` on a non-array variable - representative = wrap(first(connected_vars)) +function _generate_connectionsets_with_idxs!(connection_state::AbstractConnectionState, + namespace::Vector{Symbol}, connected_vars::Vector{SymbolicT}, isouter::IsOuter, + idxs::CartesianIndices{N, NTuple{N, UnitRange{Int}}}) where {N} # all of them have the same size, but may have different axes/shape # so we iterate over `eachindex(eachindex(..))` since that is identical for all - for sz_i in eachindex(eachindex(representative)) - hyperedge = map(connected_vars) do var - var = unwrap(var) + for sz_i in eachindex(idxs) + hyperedge = ConnectionVertex[] + for var in connected_vars var_ns = namespace_hierarchy(getname(var)) - i = eachindex(wrap(var))[sz_i] - + if N === 0 + i = sz_i + else + i = (eachindex(var)::CartesianIndices{N, NTuple{N, UnitRange{Int}}})[sz_i]::CartesianIndex{N} + end is_input = isinput(var) is_output = isoutput(var) if is_input && is_output - names = join(string.(connected_vars), ", ") - throw(ArgumentError(""" - Variable $var in connection `connect($names)` is both input and output. - """)) + throw_both_input_output(var, connected_vars) elseif is_input type = InputVar{i} elseif is_output type = OutputVar{i} else - names = join(string.(connected_vars), ", ") - throw(ArgumentError(""" - Variable $var in connection `connect($names)` is neither input nor output. - """)) + throw_not_input_output(var, connected_vars) end - - return ConnectionVertex( + vert = ConnectionVertex( [namespace; var_ns], length(var_ns) == 1 || isouter(var_ns[1]), type) + push!(hyperedge, vert) end add_connection_edge!(connection_state, hyperedge) @@ -480,10 +536,36 @@ end function _generate_connectionsets!(connection_state::AbstractConnectionState, namespace::Vector{Symbol}, - systems::Union{AbstractVector{<:AbstractSystem}, Tuple{Vararg{AbstractSystem}}}, - isouter) + connected_vars::Vector{SymbolicT}, + isouter::IsOuter) + # NOTE: variable connections don't populate the domain network + + representative = first(connected_vars) + idxs = eachindex(representative) + # Manual dispatch for common cases + if idxs isa CartesianIndices{0, Tuple{}} + _generate_connectionsets_with_idxs!(connection_state, namespace, connected_vars, + isouter, idxs) + elseif idxs isa CartesianIndices{1, Tuple{UnitRange{Int}}} + _generate_connectionsets_with_idxs!(connection_state, namespace, connected_vars, + isouter, idxs) + elseif idxs isa CartesianIndices{2, NTuple{2, UnitRange{Int}}} + _generate_connectionsets_with_idxs!(connection_state, namespace, connected_vars, + isouter, idxs) + else + # Dynamic dispatch + _generate_connectionsets_with_idxs!(connection_state, namespace, connected_vars, + isouter, idxs) + end +end + +function _generate_connectionsets!(connection_state::AbstractConnectionState, + namespace::Vector{Symbol}, + systems::Vector{T}, + isouter::IsOuter) where {T <: AbstractSystem} + systems = systems::Vector{System} regular_systems = System[] - domain_system = nothing + domain_system::Union{Nothing, System} = nothing for s in systems if is_domain_connector(s) if domain_system === nothing @@ -518,7 +600,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, push!(domain_hyperedge, domain_vertex) push!(hyperedge, dv_vertex) - for (i, sys) in enumerate(systems) + for sys in systems sts = unknowns(sys) sys_is_outer = isouter(sys) @@ -526,6 +608,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, # are properly namespaced sysname = nameof(sys) sys_ns = namespace_hierarchy(sysname) + N = length(namespace) append!(namespace, sys_ns) for v in sts vtype = get_connection_type(v) @@ -539,7 +622,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, push!(domain_hyperedge, sys_vertex) end # remember to remove the added namespace! - foreach(_ -> pop!(namespace), sys_ns) + resize!(namespace, N) end @assert length(hyperedge) > 1 @assert length(domain_hyperedge) == length(hyperedge) @@ -553,10 +636,10 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, # Add 9 orientation variables if connection is between multibody frames if isframe(sys1) # Multibody O = ori(sys1) - orientation_vars = Symbolics.unwrap.(collect(vec(O.R))) - sys1_dvs = [sys1_dvs; orientation_vars] + orientation_vars = vec(O.R) + sys1_dvs = SymbolicT[sys1_dvs; orientation_vars] end - sys1_dvs_set = Set(sys1_dvs) + sys1_dvs_set = Set{SymbolicT}(sys1_dvs) num_unknowns = length(sys1_dvs) # We first build sets of all vertices that are connected together @@ -567,8 +650,8 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, # Add 9 orientation variables if connection is between multibody frames if isframe(sys) # Multibody O = ori(sys) - orientation_vars = Symbolics.unwrap.(vec(O.R)) - unknown_vars = [unknown_vars; orientation_vars] + orientation_vars = vec(O.R) + unknown_vars = SymbolicT[unknown_vars; orientation_vars] end # Error if any subsequent systems do not have the same number of unknowns # or have unknowns not in the others. @@ -580,6 +663,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, # are properly namespaced sysname = nameof(sys) sys_ns = namespace_hierarchy(sysname) + N = length(namespace) append!(namespace, sys_ns) sys_is_outer = isouter(sys) for (j, v) in enumerate(unknown_vars) @@ -588,7 +672,7 @@ function _generate_connectionsets!(connection_state::AbstractConnectionState, domain_vertex = ConnectionVertex(namespace) push!(domain_hyperedge, domain_vertex) # remember to remove the added namespace! - foreach(_ -> pop!(namespace), sys_ns) + resize!(namespace, N) end for var_set in var_sets # all connected variables should have the same type @@ -630,35 +714,39 @@ can be pushed, unmodified. Connection equations update the given `state`. The eq present at the path in the hierarchical system given by `namespace`. `isouter` is the function returned from `generate_isouter`. """ -function handle_maybe_connect_equation!(eqs, state::AbstractConnectionState, - eq::Equation, namespace::Vector{Symbol}, isouter) - lhs = eq.lhs - rhs = eq.rhs +function handle_maybe_connect_equation!(eqs::Vector{Equation}, state::AbstractConnectionState, + eq::Equation, namespace::Vector{Symbol}, isouter::IsOuter) + lhs = value(eq.lhs) + rhs = value(eq.rhs) if !(lhs isa Connection) # split connections and equations - if eq.lhs isa AbstractArray || eq.rhs isa AbstractArray - append!(eqs, Symbolics.scalarize(eq)) - else - push!(eqs, eq) - end + push!(eqs, eq) return end + lhs = lhs::Connection + rhs = rhs::Connection + handle_maybe_connect_equation!(state, lhs, rhs, namespace, isouter) +end +function handle_maybe_connect_equation!(state::AbstractConnectionState, + lhs::Connection, rhs::Connection, namespace::Vector{Symbol}, isouter::IsOuter) if get_systems(lhs) === :domain # This is a domain connection, so we only update the domain connection graph - hyperedge = map(get_systems(rhs)) do sys - sys isa AbstractSystem || error("Domain connections can only connect systems!") + syss = get_systems(rhs)::Vector{System} + hyperedge = ConnectionVertex[] + for sys in syss sysname = nameof(sys) sys_ns = namespace_hierarchy(sysname) + N = length(namespace) append!(namespace, sys_ns) vertex = ConnectionVertex(namespace) - foreach(_ -> pop!(namespace), sys_ns) - return vertex + resize!(namespace, N) + push!(hyperedge, vertex) end add_domain_connection_edge!(state, hyperedge) else - connected_systems = get_systems(rhs) + connected_systems = get_systems(rhs)::Union{Vector{System}, Vector{SymbolicWithNameof}} generate_connectionsets!(state, namespace, connected_systems, isouter) end return nothing @@ -712,12 +800,13 @@ function _generate_connection_set!(connection_state::ConnectionState, end # go through the removed connections and update the negative graph - for conn in something(get_ignored_connections(sys), ()) - eq = Equation(Connection(), conn) - # there won't be any standard equations, so we can pass `nothing` instead of - # `eqs`. - handle_maybe_connect_equation!( - nothing, negative_connection_state, eq, namespace, isouter) + ignored = get_ignored_connections(sys) + if ignored isa Vector{Connection} + for conn in ignored + # there won't be any standard equations, so we can pass `nothing` instead of + # `eqs`. + handle_maybe_connect_equation!(negative_connection_state, Connection(), conn, namespace, isouter) + end end # all connectors are eventually inside connectors, and all flow variables @@ -732,17 +821,35 @@ function _generate_connection_set!(connection_state::ConnectionState, end pop!(namespace) end - - # recurse down the hierarchy - @set! sys.systems = map(subsys) do s - generate_connection_set!(connection_state, negative_connection_state, s, namespace) + new_systems = System[] + for s in subsys + news = generate_connection_set!(connection_state, negative_connection_state, s, namespace) + push!(new_systems, news) end + # recurse down the hierarchy + @set! sys.systems = new_systems @set! sys.eqs = eqs # Remember to pop the name at the end! does_namespacing(sys) && pop!(namespace) return sys end +function _flow_equations_from_idxs!(sys::AbstractSystem, eqs::Vector{Equation}, cset::Vector{ConnectionVertex}, len::Int) + add_buffer = SymbolicT[] + # each variable can have different axes, but they all have the same size + for sz_i in 1:len + empty!(add_buffer) + for cvert in cset + v = variable_from_vertex(sys, cvert)::SymbolicT + vidxs = SU.stable_eachindex(v) + v = v[vidxs[sz_i]] + push!(add_buffer, cvert.isouter ? -v : v) + end + rhs = SU.add_worker(VartypeT, add_buffer) + push!(eqs, Symbolics.COMMON_ZERO ~ rhs) + end +end + """ $(TYPEDSIGNATURES) @@ -756,7 +863,7 @@ function generate_connection_equations_and_stream_connections( for cset in csets cvert = cset[1] - var = variable_from_vertex(sys, cvert)::BasicSymbolic + var = variable_from_vertex(sys, cvert)::SymbolicT vtype = cvert.type if vtype <: Union{InputVar, OutputVar} length(cset) > 1 || continue @@ -782,10 +889,10 @@ function generate_connection_equations_and_stream_connections( end end root_vert = something(inner_output, outer_input) - root_var = variable_from_vertex(sys, root_vert) + root_var = variable_from_vertex(sys, root_vert)::SymbolicT for cvert in cset isequal(cvert, root_vert) && continue - push!(eqs, variable_from_vertex(sys, cvert) ~ root_var) + push!(eqs, variable_from_vertex(sys, cvert)::SymbolicT ~ root_var) end elseif vtype === Stream push!(stream_connections, cset) @@ -793,48 +900,39 @@ function generate_connection_equations_and_stream_connections( # arrays have to be broadcasted to be added/subtracted/negated which leads # to bad-looking equations. Just generate scalar equations instead since # mtkcompile will scalarize anyway. - representative = variable_from_vertex(sys, cset[1]) - # each variable can have different axes, but they all have the same size - for sz_i in eachindex(eachindex(wrap(representative))) - rhs = 0 - for cvert in cset - # all of this wrapping/unwrapping is necessary because the relevant - # methods are defined on `Arr/Num` and not `BasicSymbolic`. - v = variable_from_vertex(sys, cvert)::BasicSymbolic - idxs = eachindex(wrap(v)) - v = unwrap(wrap(v)[idxs[sz_i]]) - rhs += cvert.isouter ? unwrap(-wrap(v)) : v - end - push!(eqs, 0 ~ rhs) - end + representative = variable_from_vertex(sys, cset[1])::SymbolicT + _flow_equations_from_idxs!(sys, eqs, cset, length(representative)::Int) else # Equality - vars = map(Base.Fix1(variable_from_vertex, sys), cset) - outer_input = inner_output = nothing + vars = SymbolicT[] + for cvar in cset + push!(vars, variable_from_vertex(sys, cvar)::SymbolicT) + end + outer_input = inner_output = 0 all_io = true # attempt to interpret the equality as a causal connectionset if # possible - for (cvert, vert) in zip(cset, vars) + for (i, vert) in enumerate(vars) is_i = isinput(vert) is_o = isoutput(vert) all_io &= is_i || is_o all_io || break if cvert.isouter && is_i && outer_input === nothing - outer_input = cvert + outer_input = i elseif !cvert.isouter && is_o && inner_output === nothing - inner_output = cvert + inner_output = i end end # this doesn't necessarily mean this is a well-structured causal connection, # but it is sufficient and we're generating equalities anyway. - if all_io && xor(outer_input !== nothing, inner_output !== nothing) - root_vert = something(inner_output, outer_input) - root_var = variable_from_vertex(sys, root_vert) - for (cvert, var) in zip(cset, vars) - isequal(cvert, root_vert) && continue + if all_io && xor(!iszero(outer_input), !iszero(inner_output)) + root_vert_i = iszero(outer_input) ? inner_output : outer_input + root_var = vars[root_vert_i] + for (i, var) in enumerate(vars) + i == root_vert_i && continue push!(eqs, var ~ root_var) end else - base = variable_from_vertex(sys, cset[1]) + base = vars[1] for i in 2:length(cset) v = vars[i] push!(eqs, base ~ v) @@ -852,12 +950,15 @@ Generate the defaults for parameters in the domain sets given by `domain_csets`. """ function domain_defaults( sys::AbstractSystem, domain_csets::Vector{Vector{ConnectionVertex}}) - defs = Dict() + defs = Dict{SymbolicT, SymbolicT}() for cset in domain_csets - systems = map(Base.Fix1(variable_from_vertex, sys), cset) - @assert all(x -> x isa AbstractSystem, systems) + systems = System[] + for cvar in cset + push!(systems, variable_from_vertex(sys, cvar)::System) + end idx = findfirst(is_domain_connector, systems) idx === nothing && continue + idx = idx::Int domain_sys = systems[idx] # note that these will not be namespaced with `domain_sys`. domain_defs = defaults(domain_sys) @@ -871,7 +972,7 @@ function domain_defaults( for par in parameters(csys) defval = get(domain_defs, par, nothing) defval === nothing && continue - defs[parameters(csys, par)] = parameters(domain_sys, par) + defs[renamespace(csys, par)] = renamespace(domain_sys, par) end end end @@ -915,16 +1016,24 @@ Given a connection vertex `cvert` referring to a variable in a connector in `sys the flow variable in that connector. """ function get_flowvar(sys::AbstractSystem, cvert::ConnectionVertex) - parent_names = @view cvert.name[1:(end - 1)] - parent_sys = iterative_getproperty(sys, parent_names) + tmp = pop!(cvert.name) + parent_sys = iterative_getproperty(sys, cvert.name)::System + push!(cvert.name, tmp) for var in unknowns(parent_sys) type = get_connection_type(var) - type == Flow || continue - return unwrap(unknowns(parent_sys, var)) + type === Flow || continue + return renamespace(parent_sys, var) end throw(ArgumentError("There is no flow variable in system `$(nameof(parent_sys))`")) end +function instream_is_atomic(ex::SymbolicT) + Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f === instream end => true + _ => false + end +end + """ $(TYPEDSIGNATURES) @@ -937,43 +1046,54 @@ function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractS tol = 1e-8) eqs = equations(sys) # collect all `instream` terms in the equations - instream_exprs = Set{BasicSymbolic}() + instream_exprs = Set{SymbolicT}() for eq in eqs - collect_instream!(instream_exprs, eq) + SU.search_variables!(instream_exprs, eq; is_atomic = instream_is_atomic) end # specifically substitute `instream(x[i]) => instream(x)[i]` - instream_subs = Dict{BasicSymbolic, BasicSymbolic}() + instream_subs = Dict{SymbolicT, SymbolicT}() for expr in instream_exprs - stream_var = only(arguments(expr)) - iscall(stream_var) && operation(stream_var) === getindex || continue - args = arguments(stream_var) - new_expr = Symbolics.array_term( - instream, args[1]; size = size(args[1]), ndims = ndims(args[1]))[args[2:end]...] - instream_subs[expr] = new_expr + exargs = Moshi.Data.variant_getfield(expr, BSImpl.Term{VartypeT}, :args) + stream_var = only(exargs) + Moshi.Match.@match stream_var begin + BSImpl.Term(; f, args, type, shape) && if f === getindex end => begin + newargs = copy(parent(args)) + arg = newargs[1] + sharg = SU.shape(arg) + starg = SU.symtype(arg) + newargs[1] = BSImpl.Term{VartypeT}(instream, SArgsT((arg,)); type = starg, shape = sharg) + new_expr = BSImpl.Term{VartypeT}(getindex, newargs; type, shape) + instream_subs[expr] = new_expr + end + _ => nothing + end end # for all the newly added `instream(x)[i]`, add `instream(x)` to `instream_exprs` # also remove all `instream(x[i])` for (k, v) in instream_subs - push!(instream_exprs, arguments(v)[1]) + push!(instream_exprs, Moshi.Match.@match v begin + BSImpl.Term(; args) => args[1] + end) delete!(instream_exprs, k) end # This is an implementation of the modelica spec # https://specification.modelica.org/maint/3.6/stream-connectors.html additional_eqs = Equation[] + add_buffer = SymbolicT[] for cset in csets n_outer = count(cvert -> cvert.isouter, cset) n_inner = length(cset) - n_outer if n_inner == 1 && n_outer == 0 cvert = only(cset) - stream_var = variable_from_vertex(sys, cvert)::BasicSymbolic + stream_var = variable_from_vertex(sys, cvert)::SymbolicT instream_subs[instream(stream_var)] = stream_var elseif n_inner == 2 && n_outer == 0 cvert1, cvert2 = cset - stream_var1 = variable_from_vertex(sys, cvert1)::BasicSymbolic - stream_var2 = variable_from_vertex(sys, cvert2)::BasicSymbolic + stream_var1 = variable_from_vertex(sys, cvert1)::SymbolicT + stream_var2 = variable_from_vertex(sys, cvert2)::SymbolicT instream_subs[instream(stream_var1)] = stream_var2 instream_subs[instream(stream_var2)] = stream_var1 elseif n_inner == 1 && n_outer == 1 @@ -981,14 +1101,14 @@ function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractS if cvert_inner.isouter cvert_inner, cvert_outer = cvert_outer, cvert_inner end - streamvar_inner = variable_from_vertex(sys, cvert_inner)::BasicSymbolic - streamvar_outer = variable_from_vertex(sys, cvert_outer)::BasicSymbolic + streamvar_inner = variable_from_vertex(sys, cvert_inner)::SymbolicT + streamvar_outer = variable_from_vertex(sys, cvert_outer)::SymbolicT instream_subs[instream(streamvar_inner)] = instream(streamvar_outer) push!(additional_eqs, (streamvar_outer ~ streamvar_inner)) elseif n_inner == 0 && n_outer == 2 cvert1, cvert2 = cset - stream_var1 = variable_from_vertex(sys, cvert1)::BasicSymbolic - stream_var2 = variable_from_vertex(sys, cvert2)::BasicSymbolic + stream_var1 = variable_from_vertex(sys, cvert1)::SymbolicT + stream_var2 = variable_from_vertex(sys, cvert2)::SymbolicT push!(additional_eqs, (stream_var1 ~ instream(stream_var2)), (stream_var2 ~ instream(stream_var1))) else @@ -997,55 +1117,70 @@ function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractS # https://specification.modelica.org/maint/3.6/stream-connectors.html#instream-and-connection-equations # We could implement the "if" case using variable bounds? It would be nice to # move that metadata to the system (storing it similar to `defaults`). - outer_cverts = filter(cvert -> cvert.isouter, cset) - inner_cverts = filter(cvert -> !cvert.isouter, cset) - - outer_streamvars = map(Base.Fix1(variable_from_vertex, sys), outer_cverts) - inner_streamvars = map(Base.Fix1(variable_from_vertex, sys), inner_cverts) - - outer_flowvars = map(Base.Fix1(get_flowvar, sys), outer_cverts) - inner_flowvars = map(Base.Fix1(get_flowvar, sys), inner_cverts) + outer_cverts = ConnectionVertex[] + inner_cverts = ConnectionVertex[] + outer_streamvars = SymbolicT[] + inner_streamvars = SymbolicT[] + outer_flowvars = SymbolicT[] + inner_flowvars = SymbolicT[] + for cvert in cset + svar = variable_from_vertex(sys, cvert)::SymbolicT + fvar = get_flowvar(sys, cvert)::SymbolicT + push!(cvert.isouter ? outer_cverts : inner_cverts, cvert) + push!(cvert.isouter ? outer_streamvars : inner_streamvars, svar) + push!(cvert.isouter ? outer_flowvars : inner_flowvars, fvar) + end - mask = trues(length(inner_cverts)) for inner_i in eachindex(inner_cverts) - # mask out the current variable - mask[inner_i] = false svar = inner_streamvars[inner_i] - instream_subs[instream(svar)] = term( - instream_rt, Val(n_inner - 1), Val(n_outer), inner_flowvars[mask]..., - inner_streamvars[mask]..., outer_flowvars..., outer_streamvars...) - # make sure to reset the mask - mask[inner_i] = true + args = SArgsT() + push!(args, SU.Const{VartypeT}(Val(n_inner - 1))) + push!(args, SU.Const{VartypeT}(Val(n_outer - 1))) + for i in eachindex(inner_cverts) + i == inner_i && continue + push!(args, inner_flowvars[i]) + end + for i in eachindex(inner_cverts) + i == inner_i && continue + push!(args, inner_streamvars[i]) + end + append!(args, outer_flowvars) + append!(args, outer_streamvars) + expr = BSImpl.Term{VartypeT}(instream_rt, args; + type = Real, shape = SU.ShapeVecT()) + instream_subs[instream(svar)] = expr end for q in 1:n_outer - sq = mapreduce(+, inner_flowvars) do fvar - max(-fvar, 0) + empty!(add_buffer) + for fvar in inner_flowvars + push!(add_buffer, max(-fvar, 0)) end - sq += mapreduce(+, enumerate(outer_flowvars)) do (outer_i, fvar) - outer_i == q && return 0 - max(fvar, 0) + for (i, fvar) in outer_flowvars + i == q && continue + push!(add_buffer, max(fvar, 0)) end - # sanity check to make sure it isn't going to codegen a `mapreduce` - @assert operation(sq) == (+) + sq = SU.add_worker(VartypeT, add_buffer) - num = mapreduce(+, inner_flowvars, inner_streamvars) do fvar, svar - positivemax(-fvar, sq; tol) * svar + empty!(add_buffer) + for (fvar, svar) in zip(inner_flowvars, inner_streamvars) + push!(add_buffer, positivemax(-fvar, sq; tol) * svar) end - num += mapreduce( - +, enumerate(outer_flowvars), outer_streamvars) do (outer_i, fvar), svar - outer_i == q && return 0 - positivemax(fvar, sq; tol) * instream(svar) + for (i, (fvar, svar)) in enumerate(zip(outer_flowvars, outer_streamvars)) + i == q && continue + push!(add_buffer, positivemax(fvar, sq; tol) * instream(svar)) end - @assert operation(num) == (+) + num = SU.add_worker(VartypeT, add_buffer) - den = mapreduce(+, inner_flowvars) do fvar - positivemax(-fvar, sq; tol) + empty!(add_buffer) + for fvar in inner_flowvars + push!(add_buffer, positivemax(-fvar, sq; tol)) end - den += mapreduce(+, enumerate(outer_flowvars)) do (outer_i, fvar) - outer_i == q && return 0 - positivemax(fvar, sq; tol) + for (i, fvar) in enumerate(outer_flowvars) + i == q && continue + push!(add_buffer, positivemax(fvar, sq; tol)) end + den = SU.add_worker(VartypeT, add_buffer) push!(additional_eqs, (outer_streamvars[q] ~ num / den)) end @@ -1116,4 +1251,4 @@ function instream_rt(ins::Val{inner_n}, outs::Val{outer_n}, for k in 1:M and ck.m_flow.max > 0 =# end -SymbolicUtils.promote_symtype(::typeof(instream_rt), ::Vararg) = Real +SymbolicUtils.promote_symtype(::typeof(instream_rt), ::Type{T}, ::Type{S}, ::Type{R}) where {T, S, R} = Real diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index f823260e2a..b7baad8861 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -142,7 +142,7 @@ function change_of_variables( for (new_var, ex, first, second) in zip(new_vars, dfdt, ∂f∂x, ∂2f∂x2) for (eqs, neq) in zip(old_eqs, neqs) - if occursin(value(eqs.lhs), value(ex)) + if SU.query(isequal(value(eqs.lhs)), value(ex)) ex = substitute(ex, eqs.lhs => eqs.rhs) if isSDE for (noise, B) in zip(neq, brownvars) @@ -470,7 +470,7 @@ julia> M = change_independent_variable(M, x); julia> M = mtkcompile(M; allow_symbolic = true); julia> unknowns(M) -3-element Vector{SymbolicUtils.BasicSymbolic{Real}}: +3-element Vector{Symbolics.SymbolicsT}: xˍt(x) y(x) yˍx(x) @@ -836,33 +836,43 @@ function noise_to_brownians(sys::System; names::Union{Symbol, Vector{Symbol}} = if neqs === nothing throw(ArgumentError("Expected a system with `noise_eqs`.")) end + neqs = neqs::Union{Vector{SymbolicT}, Matrix{SymbolicT}} if !isempty(get_systems(sys)) throw(ArgumentError("The system must be flattened.")) end # vector means diagonal noise - nbrownians = ndims(neqs) == 1 ? length(neqs) : size(neqs, 2) - if names isa Symbol - names = [Symbol(names, :_, i) for i in 1:nbrownians] + nbrownians = if neqs isa Vector{SymbolicT} + length(neqs) + elseif neqs isa Matrix{SymbolicT} + size(neqs, 2) end - if length(names) != nbrownians + if names isa Symbol + _names = Symbol[] + for i in 1:nbrownians + push!(_names, Symbol(names, :_, i)) + end + names = _names + elseif names isa Vector{Symbol} && length(names) != nbrownians throw(ArgumentError(""" The system has $nbrownians brownian variables. Received $(length(names)) names \ for the brownian variables. Provide $nbrownians names or a single `Symbol` to use \ an array variable of the appropriately length. """)) end - brownvars = map(names) do name - unwrap(only(@brownians $name)) + names = names::Vector{Symbol} + brownvars = SymbolicT[] + for name in names + push!(brownvars, unwrap(only(@brownians $name))) end - - terms = if ndims(neqs) == 1 + terms = if neqs isa Vector{SymbolicT} neqs .* brownvars - else + elseif neqs isa Matrix{SymbolicT} neqs * brownvars end - eqs = map(get_eqs(sys), terms) do eq, term - eq.lhs ~ eq.rhs + term + eqs = Equation[] + for (eq, term) in zip(get_eqs(sys), terms) + push!(eqs, eq.lhs ~ eq.rhs + term) end @set! sys.eqs = eqs @@ -1037,11 +1047,13 @@ function respecialize(sys::AbstractSystem, mapping; all = false) """ if iscall(k) - op = operation(k) + op = operation(k)::SymbolicT + @assert !iscall(op) + op = SU.Sym{VartypeT}(nameof(op); type = SU.FnType{Tuple, T, Nothing}, shape = SU.shape(k)) args = arguments(k) - new_p = SymbolicUtils.term(op, args...; type = T) + new_p = op(args...) else - new_p = SymbolicUtils.Sym{T}(getname(k)) + new_p = SSym(getname(k); type = T, shape = SU.shape(v)) end get_ps(sys)[idx] = new_p @@ -1049,7 +1061,7 @@ function respecialize(sys::AbstractSystem, mapping; all = false) subrules[unwrap(k)] = unwrap(new_p) end - substituter = Base.Fix2(fast_substitute, subrules) + substituter = Base.Fix2(substitute, subrules) @set! sys.eqs = map(substituter, get_eqs(sys)) @set! sys.observed = map(substituter, get_observed(sys)) @set! sys.initialization_eqs = map(substituter, get_initialization_eqs(sys)) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 1c43022f4b..7fc0c6abe1 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -67,10 +67,10 @@ function ImperativeAffect(; f, kwargs...) ImperativeAffect(f; kwargs...) end -function Symbolics.fast_substitute(aff::ImperativeAffect, rules) - substituter = Base.Fix2(fast_substitute, rules) - ImperativeAffect(aff.f, map(substituter, aff.obs), aff.obs_syms, - map(substituter, aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks) +function (s::SymbolicUtils.Substituter)(aff::ImperativeAffect) + ImperativeAffect(aff.f, s(aff.obs), aff.obs_syms, + s(aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks) + end function Base.show(io::IO, mfa::ImperativeAffect) @@ -85,10 +85,16 @@ context(a::ImperativeAffect) = a.ctx observed(a::ImperativeAffect) = a.obs observed_syms(a::ImperativeAffect) = a.obs_syms function discretes(a::ImperativeAffect) - Iterators.filter(ModelingToolkit.isparameter, - Iterators.flatten(Iterators.map( - x -> symbolic_type(x) == NotSymbolic() && x isa AbstractArray ? x : [x], - a.modified))) + discs = SymbolicT[] + for val in a.modified + val = unwrap(val) + if val isa SymbolicT + isparameter(a) && push!(discs, val) + elseif val isa AbstractArray + append!(discs, filter(isparameter, map(unwrap, val))) + end + end + return discs end modified(a::ImperativeAffect) = a.modified modified_syms(a::ImperativeAffect) = a.mod_syms diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 19c78413cf..e00a1a250f 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -1,13 +1,10 @@ +const TypeT = Union{DataType, UnionAll, Union} + struct BufferTemplate - type::Union{DataType, UnionAll, Union} + type::TypeT length::Int end -function BufferTemplate(s::Type{<:Symbolics.Struct}, length::Int) - T = Symbolics.juliatype(s) - BufferTemplate(T, length) -end - struct Nonnumeric <: SciMLStructures.AbstractPortion end const NONNUMERIC_PORTION = Nonnumeric() @@ -31,16 +28,15 @@ struct DiscreteIndex idx_in_clock::Int end -const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}} -const NonnumericMap = Dict{ - Union{BasicSymbolic, Symbolics.CallWithMetadata}, Tuple{Int, Int}} -const UnknownIndexMap = Dict{ - BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}} -const TunableIndexMap = Dict{BasicSymbolic, - Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}} +const MaybeUnknownArrayIndexT = Union{Int, UnitRange{Int}, AbstractArray{Int}} +const MaybeArrayIndexT = Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}} +const ParamIndexMap = Dict{SymbolicT, Tuple{Int, Int}} +const NonnumericMap = Dict{SymbolicT, Tuple{Int, Int}} +const UnknownIndexMap = Dict{SymbolicT, MaybeUnknownArrayIndexT} +const TunableIndexMap = Dict{SymbolicT, MaybeArrayIndexT} const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}} -const SymbolicParam = Union{BasicSymbolic, CallWithMetadata} +const SymbolicParam = SymbolicT struct IndexCache unknown_idx::UnknownIndexMap @@ -52,9 +48,8 @@ struct IndexCache initials_idx::TunableIndexMap constant_idx::ParamIndexMap nonnumeric_idx::NonnumericMap - observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType} - dependent_pars_to_timeseries::Dict{ - Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType} + observed_syms_to_timeseries::Dict{SymbolicT, TimeseriesSetType} + dependent_pars_to_timeseries::Dict{SymbolicT, TimeseriesSetType} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_size::BufferTemplate initials_buffer_size::BufferTemplate @@ -70,26 +65,36 @@ function IndexCache(sys::AbstractSystem) let idx = 1 for sym in unks - usym = unwrap(sym) - rsym = renamespace(sys, usym) - sym_idx = if Symbolics.isarraysymbolic(sym) + rsym = renamespace(sys, sym) + sym_idx::MaybeUnknownArrayIndexT = if Symbolics.isarraysymbolic(sym) reshape(idx:(idx + length(sym) - 1), size(sym)) else idx end - unk_idxs[usym] = sym_idx + unk_idxs[sym] = sym_idx unk_idxs[rsym] = sym_idx idx += length(sym) end + found_array_syms = Set{SymbolicT}() for sym in unks - usym = unwrap(sym) iscall(sym) && operation(sym) === getindex || continue arrsym = arguments(sym)[1] - all(haskey(unk_idxs, arrsym[i]) for i in eachindex(arrsym)) || continue - - idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)] + arrsym in found_array_syms && continue + idxs = Int[] + valid_arrsym = true + for i in eachindex(arrsym) + idxsym = arrsym[i] + idx = get(unk_idxs, idxsym, nothing)::Union{Int, Nothing} + valid_arrsym = idx !== nothing + valid_arrsym || break + push!(idxs, idx::Int) + end + push!(found_array_syms, arrsym) + valid_arrsym || break if idxs == idxs[begin]:idxs[end] - idxs = reshape(idxs[begin]:idxs[end], size(idxs)) + idxs = reshape(idxs[begin]:idxs[end], size(arrsym))::AbstractArray{Int} + else + idxs = reshape(idxs, size(arrsym))::AbstractArray{Int} end rsym = renamespace(sys, arrsym) unk_idxs[arrsym] = idxs @@ -97,62 +102,24 @@ function IndexCache(sys::AbstractSystem) end end - tunable_pars = BasicSymbolic[] - initial_pars = BasicSymbolic[] - constant_buffers = Dict{Any, Set{BasicSymbolic}}() - nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}() - - function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S} - sym = unwrap(sym) - buf = get!(buffers, ctype, S()) - push!(buf, sym) - end - function insert_by_type!(buffers::Vector{BasicSymbolic}, sym, ctype) - sym = unwrap(sym) - push!(buffers, sym) - end - - disc_param_callbacks = Dict{SymbolicParam, Set{Int}}() - events = vcat(continuous_events(sys), discrete_events(sys)) - for (i, event) in enumerate(events) - discs = Set{SymbolicParam}() - affs = affects(event) - if !(affs isa AbstractArray) - affs = [affs] - end - for affect in affs - if affect isa AffectSystem || affect isa ImperativeAffect - union!(discs, unwrap.(discretes(affect))) - elseif isnothing(affect) - continue - else - error("Unhandled affect type $(typeof(affect))") - end - end - - for sym in discs - is_parameter(sys, sym) || - error("Expected discrete variable $sym in callback to be a parameter") - - # Only `foo(t)`-esque parameters can be saved - if iscall(sym) && length(arguments(sym)) == 1 && - isequal(only(arguments(sym)), get_iv(sys)) - clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym) - push!(clocks, i) - elseif is_variable_floatingpoint(sym) - insert_by_type!(constant_buffers, sym, symtype(sym)) - else - stype = symtype(sym) - if stype <: FnType - stype = fntype_to_function_type(stype) - end - insert_by_type!(nonnumeric_buffers, sym, stype) - end - end - end - clock_partitions = unique(collect(values(disc_param_callbacks))) - disc_symtypes = unique(symtype.(keys(disc_param_callbacks))) - disc_symtype_idx = Dict(disc_symtypes .=> eachindex(disc_symtypes)) + tunable_pars = SymbolicT[] + initial_pars = SymbolicT[] + constant_buffers = Dict{TypeT, Set{SymbolicT}}() + nonnumeric_buffers = Dict{TypeT, Set{SymbolicT}}() + + disc_param_callbacks = Dict{SymbolicParam, BitSet}() + cevs = continuous_events(sys) + devs = discrete_events(sys) + events = Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}[cevs; devs] + parse_callbacks_for_discretes!(cevs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, 0) + parse_callbacks_for_discretes!(devs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, length(cevs)) + clock_partitions = unique(collect(values(disc_param_callbacks)))::Vector{BitSet} + disc_symtypes = Set{TypeT}() + for x in keys(disc_param_callbacks) + push!(disc_symtypes, symtype(x)) + end + disc_symtypes = collect(disc_symtypes)::Vector{TypeT} + disc_symtype_idx = Dict{TypeT, Int}(zip(disc_symtypes, eachindex(disc_symtypes))) disc_syms_by_symtype = [SymbolicParam[] for _ in disc_symtypes] for sym in keys(disc_param_callbacks) push!(disc_syms_by_symtype[disc_symtype_idx[symtype(sym)]], sym) @@ -160,13 +127,12 @@ function IndexCache(sys::AbstractSystem) disc_syms_by_symtype_by_partition = [Vector{SymbolicParam}[] for _ in disc_symtypes] for (i, buffer) in enumerate(disc_syms_by_symtype) for partition in clock_partitions - push!(disc_syms_by_symtype_by_partition[i], - [sym for sym in buffer if disc_param_callbacks[sym] == partition]) + push!(disc_syms_by_symtype_by_partition[i], filter(==(partition) ∘ Base.Fix1(getindex, disc_param_callbacks), buffer)) end end disc_idxs = Dict{SymbolicParam, DiscreteIndex}() callback_to_clocks = Dict{ - Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}() + Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, BitSet}() for (typei, disc_syms_by_partition) in enumerate(disc_syms_by_symtype_by_partition) symi = 0 for (parti, disc_syms) in enumerate(disc_syms_by_partition) @@ -194,26 +160,24 @@ function IndexCache(sys::AbstractSystem) disc_buffer_templates = Vector{BufferTemplate}[] for (symtype, disc_syms_by_partition) in zip( disc_symtypes, disc_syms_by_symtype_by_partition) - push!(disc_buffer_templates, - [BufferTemplate(symtype, length(buf)) for buf in disc_syms_by_partition]) + push!(disc_buffer_templates, map(Base.Fix1(BufferTemplate, symtype) ∘ length, disc_syms_by_partition)) end for p in parameters(sys; initial_parameters = true) - p = unwrap(p) ctype = symtype(p) if ctype <: FnType - ctype = fntype_to_function_type(ctype) + ctype = fntype_to_function_type(ctype)::TypeT end haskey(disc_idxs, p) && continue haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue haskey(nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue insert_by_type!( if ctype <: Real || ctype <: AbstractArray{<:Real} - if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown() && + if istunable(p, true) && symbolic_has_known_size(p) && (ctype == Real || ctype <: AbstractFloat || ctype <: AbstractArray{Real} || ctype <: AbstractArray{<:AbstractFloat}) - if iscall(p) && operation(p) isa Initial + if iscall(p) && operation(p) === Initial() initial_pars else tunable_pars @@ -229,33 +193,10 @@ function IndexCache(sys::AbstractSystem) ) end - function get_buffer_sizes_and_idxs(T, buffers::Dict) - idxs = T() - buffer_sizes = BufferTemplate[] - for (i, (T, buf)) in enumerate(buffers) - for (j, p) in enumerate(buf) - ttp = default_toterm(p) - rp = renamespace(sys, p) - rttp = renamespace(sys, ttp) - idxs[p] = (i, j) - idxs[ttp] = (i, j) - idxs[rp] = (i, j) - idxs[rttp] = (i, j) - end - if T <: Symbolics.FnType - T = Any - end - push!(buffer_sizes, BufferTemplate(T, length(buf))) - end - return idxs, buffer_sizes - end - const_idxs, - const_buffer_sizes = get_buffer_sizes_and_idxs( - ParamIndexMap, constant_buffers) + const_buffer_sizes = get_buffer_sizes_and_idxs(ParamIndexMap, sys, constant_buffers) nonnumeric_idxs, - nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs( - NonnumericMap, nonnumeric_buffers) + nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(NonnumericMap, sys, nonnumeric_buffers) tunable_idxs = TunableIndexMap() tunable_buffer_size = 0 @@ -264,7 +205,8 @@ function IndexCache(sys::AbstractSystem) empty!(initial_pars) end for p in tunable_pars - idx = if size(p) == () + sh = SU.shape(p) + idx = if !SU.is_array_shape(sh) tunable_buffer_size + 1 else reshape( @@ -282,7 +224,8 @@ function IndexCache(sys::AbstractSystem) initials_idxs = TunableIndexMap() initials_buffer_size = 0 for p in initial_pars - idx = if size(p) == () + sh = SU.shape(p) + idx = if !SU.is_array_shape(sh) initials_buffer_size + 1 else reshape( @@ -300,24 +243,27 @@ function IndexCache(sys::AbstractSystem) for k in collect(keys(tunable_idxs)) v = tunable_idxs[k] v isa AbstractArray || continue - for (kk, vv) in zip(collect(k), v) + v = v::Union{UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}} + iter = vec(collect(k)::Array{SymbolicT})::Vector{SymbolicT} + for (kk::SymbolicT, vv) in zip(iter, v) tunable_idxs[kk] = vv end end for k in collect(keys(initials_idxs)) v = initials_idxs[k] v isa AbstractArray || continue - for (kk, vv) in zip(collect(k), v) + v = v::Union{UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}} + iter = vec(collect(k)::Array{SymbolicT})::Vector{SymbolicT} + for (kk, vv) in zip(iter, v) initials_idxs[kk] = vv end end - dependent_pars_to_timeseries = Dict{ - Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}() - + dependent_pars_to_timeseries = Dict{SymbolicT, TimeseriesSetType}() + vs = Set{SymbolicT}() for eq in get_parameter_dependencies(sys) sym = eq.lhs - vs = vars(eq.rhs) + SU.search_variables!(vs, eq.rhs) timeseries = TimeseriesSetType() if is_time_dependent(sys) for v in vs @@ -331,24 +277,29 @@ function IndexCache(sys::AbstractSystem) rttsym = renamespace(sys, ttsym) for s in (sym, ttsym, rsym, rttsym) dependent_pars_to_timeseries[s] = timeseries - if hasname(s) && (!iscall(s) || operation(s) != getindex) + if hasname(s) && (!iscall(s) || operation(s) !== getindex) symbol_to_variable[getname(s)] = sym end end end - observed_syms_to_timeseries = Dict{BasicSymbolic, TimeseriesSetType}() + observed_syms_to_timeseries = Dict{SymbolicT, TimeseriesSetType}() for eq in observed(sys) if symbolic_type(eq.lhs) != NotSymbolic() sym = eq.lhs - vs = vars(eq.rhs; op = Nothing) + empty!(vs) + SU.search_variables!(vs, eq.rhs) timeseries = TimeseriesSetType() if is_time_dependent(sys) for v in vs if (idx = get(disc_idxs, v, nothing)) !== nothing push!(timeseries, idx.clock_idx) - elseif iscall(v) && operation(v) === getindex && - (idx = get(disc_idxs, arguments(v)[1], nothing)) !== nothing + elseif Moshi.Match.@match v begin + BSImpl.Term(; f, args) => begin + f === getindex && (idx = get(disc_idxs, args[1], nothing)) !== nothing + end + _ => false + end push!(timeseries, idx.clock_idx) elseif haskey(observed_syms_to_timeseries, v) union!(timeseries, observed_syms_to_timeseries[v]) @@ -369,13 +320,12 @@ function IndexCache(sys::AbstractSystem) end end - for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs), - keys(const_idxs), keys(nonnumeric_idxs), - keys(observed_syms_to_timeseries), independent_variable_symbols(sys))) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - symbol_to_variable[getname(sym)] = sym - end - end + populate_symbol_to_var!(symbol_to_variable, keys(unk_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(disc_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(tunable_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(const_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(nonnumeric_idxs)) + populate_symbol_to_var!(symbol_to_variable, independent_variable_symbols(sys)) return IndexCache( unk_idxs, @@ -396,36 +346,117 @@ function IndexCache(sys::AbstractSystem) ) end +function populate_symbol_to_var!(symbol_to_variable::Dict{Symbol, SymbolicT}, vars) + for sym::SymbolicT in vars + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) + symbol_to_variable[getname(sym)] = sym + end + end +end + +""" + $TYPEDSIGNATURES + +Utility function for the `IndexCache` constructor. +""" +function insert_by_type!(buffers::Dict{TypeT, Set{SymbolicT}}, sym::SymbolicT, ctype::TypeT) + buf = get!(Set{SymbolicT}, buffers, ctype) + push!(buf, sym) +end +function insert_by_type!(buffers::Vector{SymbolicT}, sym::SymbolicT, ::TypeT) + push!(buffers, sym) +end + +function parse_callbacks_for_discretes!(events::Vector, disc_param_callbacks::Dict{SymbolicT, BitSet}, constant_buffers::Dict{TypeT, Set{SymbolicT}}, nonnumeric_buffers::Dict{TypeT, Set{SymbolicT}}, offset::Int) + for (i, event) in enumerate(events) + discs = Set{SymbolicParam}() + affect = event.affect::Union{AffectSystem, ImperativeAffect, Nothing} + if affect isa AffectSystem || affect isa ImperativeAffect + union!(discs, discretes(affect)) + elseif affect === nothing + continue + end + + for sym in discs + is_parameter(sys, sym) || + error("Expected discrete variable $sym in callback to be a parameter") + + # Only `foo(t)`-esque parameters can be saved + if iscall(sym) && length(arguments(sym)) == 1 && + isequal(only(arguments(sym)), get_iv(sys)) + clocks = get!(BitSet, disc_param_callbacks, sym) + push!(clocks, i + offset) + elseif is_variable_floatingpoint(sym) + insert_by_type!(constant_buffers, sym, symtype(sym)) + else + stype = symtype(sym) + if stype <: FnType + stype = fntype_to_function_type(stype)::TypeT + end + insert_by_type!(nonnumeric_buffers, sym, stype) + end + end + end +end + +function get_buffer_sizes_and_idxs(::Type{BufT}, sys::AbstractSystem, buffers::Dict) where {BufT} + idxs = BufT() + buffer_sizes = BufferTemplate[] + for (i, (T, buf)) in enumerate(buffers) + for (j, p) in enumerate(buf) + ttp = default_toterm(p) + rp = renamespace(sys, p) + rttp = renamespace(sys, ttp) + idxs[p] = (i, j) + idxs[ttp] = (i, j) + idxs[rp] = (i, j) + idxs[rttp] = (i, j) + end + if T <: Symbolics.FnType + T = Any + end + push!(buffer_sizes, BufferTemplate(T, length(buf))) + end + return idxs, buffer_sizes +end + function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) variable_index(ic, sym) !== nothing end -function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) - if sym isa Symbol - sym = get(ic.symbol_to_variable, sym, nothing) - sym === nothing && return nothing - end +function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) + variable_index(ic, unwrap(sym)) +end +function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym::Symbol) + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + variable_index(ic, sym) +end +function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym::SymbolicT) idx = check_index_map(ic.unknown_idx, sym) idx === nothing || return idx iscall(sym) && operation(sym) == getindex || return nothing args = arguments(sym) idx = variable_index(ic, args[1]) idx === nothing && return nothing - return idx[args[2:end]...] + return idx[unwrap_const.(args[2:end])...] end +SymbolicIndexingInterface.variable_index(ic::IndexCache, sym) = false function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym) parameter_index(ic, sym) !== nothing end -function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) - if sym isa Symbol - sym = get(ic.symbol_to_variable, sym, nothing) - sym === nothing && return nothing - end - sym = unwrap(sym) - validate_size = Symbolics.isarraysymbolic(sym) && symtype(sym) <: AbstractArray && - Symbolics.shape(sym) !== Symbolics.Unknown() +function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) + parameter_index(ic, unwrap(sym)) +end +function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::Symbol) + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + parameter_index(ic, sym) +end +function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::SymbolicT) + validate_size = Symbolics.isarraysymbolic(sym) && symbolic_has_known_size(sym) return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing ParameterIndex(SciMLStructures.Tunable(), idx, validate_size) elseif (idx = check_index_map(ic.initials_idx, sym)) !== nothing @@ -472,77 +503,81 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy idx.timeseries_idx, (idx.parameter_idx..., args[2:end]...)) end -function check_index_map(idxmap, sym) - if (idx = get(idxmap, sym, nothing)) !== nothing - return idx - elseif !isa(sym, Symbol) && (!iscall(sym) || operation(sym) !== getindex) && - hasname(sym) && (idx = get(idxmap, getname(sym), nothing)) !== nothing - return idx - end +function check_index_map(idxmap::Dict{SymbolicT, V}, sym::SymbolicT)::Union{V, Nothing} where {V} + idx = get(idxmap, sym, nothing) + idx === nothing || return idx dsym = default_toterm(sym) isequal(sym, dsym) && return nothing - if (idx = get(idxmap, dsym, nothing)) !== nothing - idx - elseif !isa(dsym, Symbol) && (!iscall(dsym) || operation(dsym) !== getindex) && - hasname(dsym) && (idx = get(idxmap, getname(dsym), nothing)) !== nothing - idx - else - nothing - end + idx = get(idxmap, dsym, nothing) + idx === nothing || return idx + return nothing end +const ReorderedParametersT = Vector{Union{Vector{SymbolicT}, Vector{Vector{SymbolicT}}}} + function reorder_parameters( sys::AbstractSystem, ps = parameters(sys; initial_parameters = true); kwargs...) if has_index_cache(sys) && get_index_cache(sys) !== nothing - reorder_parameters(get_index_cache(sys), ps; kwargs...) + return reorder_parameters(get_index_cache(sys)::IndexCache, ps; kwargs...) elseif ps isa Tuple - ps + return ReorderedParametersT(collect(ps)) else - (ps,) + return eltype(ReorderedParametersT)[ps] end end -function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = true) - isempty(ps) && return () - param_buf = if ic.tunable_buffer_size.length == 0 - () - else - (BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(ic.tunable_buffer_size.length)],) +const COMMON_DEFAULT_VAR = unwrap(only(@variables __DEF__)) + +function reorder_parameters(ic::IndexCache, ps::Vector{SymbolicT}; drop_missing = false, flatten = true) + result = ReorderedParametersT() + isempty(ps) && return result + param_buf = fill(COMMON_DEFAULT_VAR, ic.tunable_buffer_size.length) + if !isempty(param_buf) || !flatten + push!(result, param_buf) + end + initials_buf = fill(COMMON_DEFAULT_VAR, ic.initials_buffer_size.length) + if !isempty(initials_buf) || !flatten + push!(result, initials_buf) + end + + disc_buf = Vector{SymbolicT}[] + for bufszs in ic.discrete_buffer_sizes + push!(disc_buf, fill(COMMON_DEFAULT_VAR, sum(x -> x.length, bufszs))) + end + const_buf = Vector{SymbolicT}[] + for bufsz in ic.constant_buffer_sizes + push!(const_buf, fill(COMMON_DEFAULT_VAR, bufsz.length)) end - initials_buf = if ic.initials_buffer_size.length == 0 - () + nonnumeric_buf = Vector{SymbolicT}[] + for bufsz in ic.nonnumeric_buffer_sizes + push!(nonnumeric_buf, fill(COMMON_DEFAULT_VAR, bufsz.length)) + end + if flatten + append!(result, disc_buf) + append!(result, const_buf) + append!(result, nonnumeric_buf) else - (BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(ic.initials_buffer_size.length)],) - end - - disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(sum(x -> x.length, temp))] - for temp in ic.discrete_buffer_sizes) - const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] - for temp in ic.constant_buffer_sizes) - nonnumeric_buf = Tuple(Union{BasicSymbolic, CallWithMetadata}[unwrap(variable(:DEF)) - for _ in 1:(temp.length)] - for temp in ic.nonnumeric_buffer_sizes) + push!(result, disc_buf) + push!(result, const_buf) + push!(result, nonnumeric_buf) + end for p in ps - p = unwrap(p) if haskey(ic.discrete_idx, p) idx = ic.discrete_idx[p] disc_buf[idx.buffer_idx][idx.idx_in_buffer] = p elseif haskey(ic.tunable_idx, p) i = ic.tunable_idx[p] if i isa Int - param_buf[1][i] = unwrap(p) + param_buf[i] = p else - param_buf[1][i] = unwrap.(collect(p)) + param_buf[i] = collect(p) end elseif haskey(ic.initials_idx, p) i = ic.initials_idx[p] if i isa Int - initials_buf[1][i] = unwrap(p) + initials_buf[i] = p else - initials_buf[1][i] = unwrap.(collect(p)) + initials_buf[i] = collect(p) end elseif haskey(ic.constant_idx, p) i, j = ic.constant_idx[p] @@ -555,37 +590,20 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = end end - param_buf = broadcast.(unwrap, param_buf) - initials_buf = broadcast.(unwrap, initials_buf) - disc_buf = broadcast.(unwrap, disc_buf) - const_buf = broadcast.(unwrap, const_buf) - nonnumeric_buf = broadcast.(unwrap, nonnumeric_buf) - if drop_missing - filterer = !isequal(unwrap(variable(:DEF))) - param_buf = filter.(filterer, param_buf) - initials_buf = filter.(filterer, initials_buf) - disc_buf = filter.(filterer, disc_buf) - const_buf = filter.(filterer, const_buf) - nonnumeric_buf = filter.(filterer, nonnumeric_buf) - end - - if flatten - result = ( - param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...) - if all(isempty, result) - return () - end - return result - else - if isempty(param_buf) - param_buf = ((),) - end - if isempty(initials_buf) - initials_buf = ((),) + filterer = !isequal(COMMON_DEFAULT_VAR) + for inner in result + if inner isa Vector{SymbolicT} + filter!(filterer, inner) + elseif inner isa Vector{Vector{SymbolicT}} + for buf in inner + filter!(filterer, buf) + end + end end - return (param_buf..., initials_buf..., disc_buf, const_buf, nonnumeric_buf) end + + return result end # Given a parameter index, find the index of the buffer it is in when diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 699cfee8fd..89737bf6b9 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -259,6 +259,9 @@ function unit_handled_variable_value(meta, varname) return varval end +no_value_default_to_nothing(::NoValue) = nothing +no_value_default_to_nothing(x) = x + # This function parses various variable/parameter definitions. # # The comments indicate the syntax matched by a block; either when parsed directly @@ -336,17 +339,17 @@ Base.@nospecializeinfer function parse_variable_def!( unit_handled_variable_value(meta, varname) if varclass == :parameters Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@parameters ($a[$(indices...)]::$type = $varval), + var = :($varname = $first(@parameters ($a[$(indices...)]::$type = $no_value_default_to_nothing($varval)), $meta_val)) elseif varclass == :constants Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@constants ($a[$(indices...)]::$type = $varval), + var = :($varname = $first(@constants ($a[$(indices...)]::$type = $no_value_default_to_nothing($varval)), $meta_val)) else Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@variables ($a[$(indices)]::$type = $varval), + var = :($varname = $first(@variables ($a[$(indices)]::$type = $no_value_default_to_nothing($varval)), $meta_val)) end update_array_kwargs_and_metadata!( @@ -373,13 +376,13 @@ Base.@nospecializeinfer function parse_variable_def!( Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $val : $varname; - $varname = $first(@parameters ($a[$(indices...)]::$type = $varval), + $varname = $first(@parameters ($a[$(indices...)]::$type = $no_value_default_to_nothing($varval)), $(def_n_meta...))) elseif varclass == :constants Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $val : $varname; - $varname = $first(@constants ($a[$(indices...)]::$type = $varval), + $varname = $first(@constants ($a[$(indices...)]::$type = $no_value_default_to_nothing($varval)), $(def_n_meta...))) else Meta.isexpr(a, :call) || @@ -387,7 +390,7 @@ Base.@nospecializeinfer function parse_variable_def!( assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $val : $varname; $varname = $first(@variables $a[$(indices...)]::$type = ( - $varval), + $no_value_default_to_nothing($varval)), $(def_n_meta...))) end else @@ -395,18 +398,18 @@ Base.@nospecializeinfer function parse_variable_def!( Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $def_n_meta : $varname; - $varname = $first(@parameters $a[$(indices...)]::$type = $varname)) + $varname = $first(@parameters $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) elseif varclass == :constants Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $def_n_meta : $varname; - $varname = $first(@constants $a[$(indices...)]::$type = $varname)) + $varname = $first(@constants $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) else Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") assert_unique_independent_var(dict, a.args[end]) var = :($varname = $varname === $NO_VALUE ? $def_n_meta : $varname; - $varname = $first(@variables $a[$(indices...)]::$type = $varname)) + $varname = $first(@variables $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) end varval, meta = def_n_meta, nothing end @@ -429,15 +432,15 @@ Base.@nospecializeinfer function parse_variable_def!( varname = a isa Expr && a.head == :call ? a.args[1] : a if varclass == :parameters Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@parameters $a[$(indices...)]::$type = $varname)) + var = :($varname = $first(@parameters $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) elseif varclass == :constants Meta.isexpr(a, :call) && assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@constants $a[$(indices...)]::$type = $varname)) + var = :($varname = $first(@constants $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) elseif varclass == :variables Meta.isexpr(a, :call) || throw("$a is not a variable of the independent variable") assert_unique_independent_var(dict, a.args[end]) - var = :($varname = $first(@variables $a[$(indices...)]::$type = $varname)) + var = :($varname = $first(@variables $a[$(indices...)]::$type = $no_value_default_to_nothing($varname))) else throw("Symbolic array with arbitrary length is not handled for $varclass. Please open an issue with an example.") @@ -573,7 +576,7 @@ function get_t(mod, t) get_var(mod, t) catch e if e isa UndefVarError - @warn("Could not find a predefined `t` in `$mod`; generating a new one within this model.\nConsider defining it or importing `t` (or `t_nounits`, `t_unitful` as `t`) from ModelingToolkit.") + @warn("Could not find a predefined `t` in `$mod`; generating a new one within this model.\nConsider defining it or importing `t` (or `t_nounits as t`) from ModelingToolkit.") variable(:t) else throw(e) @@ -629,7 +632,7 @@ function _set_var_metadata!(metadata_with_exprs, a, m, v::Expr) a end function _set_var_metadata!(metadata_with_exprs, a, m, v) - wrap(set_scalar_metadata(unwrap(a), m, v)) + wrap(setmetadata(unwrap(a), m, v)) end function set_var_metadata(a, ms) @@ -901,17 +904,7 @@ function convert_units( DynamicQuantities.SymbolicUnits.as_quantity(varunits), value)) end -function convert_units(varunits::Unitful.FreeUnits, value) - Unitful.ustrip(varunits, value) -end - -convert_units(::Unitful.FreeUnits, value::NoValue) = NO_VALUE - -function convert_units(varunits::Unitful.FreeUnits, value::AbstractArray{T}) where {T} - Unitful.ustrip.(varunits, value) -end - -convert_units(::Unitful.FreeUnits, value::Num) = value +convert_units(::DynamicQuantities.Quantity, value::AbstractArray{Num}) = value convert_units(::DynamicQuantities.Quantity, value::Num) = value @@ -930,8 +923,7 @@ function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types) try $setdefault($vv, $convert_units($unit, $name)) catch e - if isa(e, $(DynamicQuantities.DimensionError)) || - isa(e, $(Unitful.DimensionError)) + if isa(e, $(DynamicQuantities.DimensionError)) error("Unable to convert units for \'" * string(:($$vv)) * "\'") elseif isa(e, MethodError) error("No or invalid units provided for \'" * string(:($$vv)) * diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl index 96c00411ad..cce180511d 100644 --- a/src/systems/nonlinear/homotopy_continuation.jl +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -1,5 +1,5 @@ function contains_variable(x, wrt) - any(y -> occursin(y, x), wrt) + any(y -> SU.query(isequal(y), x), wrt) end """ @@ -270,7 +270,7 @@ function PolynomialTransformation(sys::System) transformation_err = nothing for t in all_non_poly_terms # if the term involves multiple unknowns, we can't invert it - dvs_in_term = map(x -> occursin(x, t), dvs) + dvs_in_term = map(x -> SU.query(isequal(x), t), dvs) if count(dvs_in_term) > 1 transformation_err = MultivarTerm(t, dvs[dvs_in_term]) is_poly = false @@ -369,7 +369,7 @@ function transform_system(sys::System, transformation::PolynomialTransformation; t = Symbolics.fixpoint_sub(t, subrules; maxiters = length(dvs)) # the substituted variable occurs outside the substituted term poly_and_nonpoly = map(dvs) do x - all(!isequal(x), new_dvs) && occursin(x, t) + all(!isequal(x), new_dvs) && SU.query(isequal(x), t) end if any(poly_and_nonpoly) return NotPolynomialError( diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index f377f0202f..9af6e54367 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -77,7 +77,7 @@ function generate_initializesystem_timevarying(sys::AbstractSystem; # PREPROCESSING op = anydict(op) if isempty(op) - op = copy(defs) + op = anydict(copy(defs)) end scalarize_vars_in_varmap!(op, arrvars) u0map = anydict() @@ -268,11 +268,10 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem; vs = Set() initialization_eqs = filter(initialization_eqs) do eq empty!(vs) - vars!(vs, eq; op = Initial) + SU.search_variables!(vs, eq; is_atomic = OperatorIsAtomic{Initial}()) allpars = full_parameters(sys) for p in allpars - if symbolic_type(p) == ArraySymbolic() && - Symbolics.shape(p) != Symbolics.Unknown() + if symbolic_type(p) == ArraySymbolic() && SU.shape(p) isa SU.Unknown append!(allpars, Symbolics.scalarize(p)) end end @@ -499,13 +498,13 @@ end function get_possibly_array_fallback_singletons(varmap, p) if haskey(varmap, p) - return varmap[p] + return value(varmap[p]) end if symbolic_type(p) == ArraySymbolic() - is_sized_array_symbolic(p) || return nothing + symbolic_has_known_size(p) || return nothing scal = collect(p) if all(x -> haskey(varmap, x), scal) - res = [varmap[x] for x in scal] + res = [value(varmap[x]) for x in scal] if any(x -> x === nothing, res) return nothing elseif any(x -> x === missing, res) @@ -824,31 +823,12 @@ Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works w initialization. """ function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation}) - subs = Dict() - tempvars = Set() - rm_idxs = Int[] + mask = trues(length(obseqs)) for (i, eq) in enumerate(obseqs) - iscall(eq.rhs) || continue - if operation(eq.rhs) == StructuralTransformations.change_origin - push!(rm_idxs, i) - continue - end - end - - for (i, eq) in enumerate(obseqs) - if eq.lhs in tempvars - subs[eq.lhs] = eq.rhs - push!(rm_idxs, i) - end + mask[i] = !iscall(eq.rhs) || operation(eq.rhs) !== StructuralTransformations.change_origin end - obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)] - obseqs = map(obseqs) do eq - fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) - end - eqs = map(eqs) do eq - fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) - end + obseqs = obseqs[mask] return obseqs, eqs end diff --git a/src/systems/optimal_control_interface.jl b/src/systems/optimal_control_interface.jl index 5a0ddbf8d5..108eb05893 100644 --- a/src/systems/optimal_control_interface.jl +++ b/src/systems/optimal_control_interface.jl @@ -357,16 +357,16 @@ function substitute_model_vars(model, sys, exprs, tspan) t = get_iv(sys) exprs = map( - c -> Symbolics.fast_substitute(c, whole_t_map(model, t, x_ops, c_ops)), exprs) + c -> substitute(c, whole_t_map(model, t, x_ops, c_ops)), exprs) (ti, tf) = tspan if symbolic_type(tf) === ScalarSymbolic() _tf = model.tₛ + ti exprs = map( - c -> Symbolics.fast_substitute(c, free_t_map(model, tf, x_ops, c_ops)), exprs) - exprs = map(c -> Symbolics.fast_substitute(c, Dict(tf => _tf)), exprs) + c -> substitute(c, free_t_map(model, tf, x_ops, c_ops)), exprs) + exprs = map(c -> substitute(c, Dict(tf => _tf)), exprs) end - exprs = map(c -> Symbolics.fast_substitute(c, fixed_t_map(model, x_ops, c_ops)), exprs) + exprs = map(c -> substitute(c, fixed_t_map(model, x_ops, c_ops)), exprs) exprs end @@ -440,7 +440,7 @@ end function substitute_toterm(vars, exprs) toterm_map = Dict([u => default_toterm(value(u)) for u in vars]) - exprs = map(c -> Symbolics.fast_substitute(c, toterm_map), exprs) + exprs = map(c -> substitute(c, toterm_map), exprs) end function substitute_params(pmap, exprs) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 637ed674ae..322889016e 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -1,4 +1,3 @@ -symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x) symconvert(::Type{T}, x::V) where {T, V} = convert(promote_type(T, V), x) symconvert(::Type{Real}, x::Integer) = convert(Float16, x) symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x)) @@ -124,7 +123,7 @@ function MTKParameters( val = symconvert(ctype, val) done = set_value(sym, val) if !done && Symbolics.isarraysymbolic(sym) - if Symbolics.shape(sym) === Symbolics.Unknown() + if !symbolic_has_known_size(sym) for i in eachindex(val) set_value(sym[i], val[i]) end @@ -464,11 +463,11 @@ function validate_parameter_type(ic::IndexCache, p, idx::ParameterIndex, val) end stype = symtype(p) sz = if stype <: AbstractArray - Symbolics.shape(p) == Symbolics.Unknown() ? Symbolics.Unknown() : size(p) + size(p) elseif stype <: Number size(p) else - Symbolics.Unknown() + SU.Unknown(-1) end validate_parameter_type(ic, stype, sz, p, idx, val) end @@ -480,7 +479,7 @@ function validate_parameter_type(ic::IndexCache, idx::ParameterIndex, val) stype = AbstractArray{<:stype} end validate_parameter_type( - ic, stype, Symbolics.Unknown(), nothing, idx, val) + ic, stype, SU.Unknown(-1), nothing, idx, val) end function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) @@ -500,7 +499,7 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) :validate_parameter_type, sym === nothing ? index : sym, stype, val)) end # ... and must match sizes - if stype <: AbstractArray && sz != Symbolics.Unknown() && size(val) != sz + if stype <: AbstractArray && !(sz isa SU.Unknown) && size(val) != sz throw(InvalidParameterSizeException(sym, val)) end # Early exit @@ -719,7 +718,7 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru sym = idx idx = parameter_index(ic, sym) if idx === nothing - Symbolics.shape(sym) == Symbolics.Unknown() && + symbolic_has_known_size(sym) || throw(ParameterNotInSystem(sym)) size(sym) == size(val) || throw(InvalidParameterSizeException(sym, val)) @@ -763,8 +762,14 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru oldbuf.discrete, newbuf.discrete) @set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.( oldbuf.constant, newbuf.constant) - @set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.( - oldbuf.nonnumeric, newbuf.nonnumeric) + for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric) + for i in eachindex(oldv) + isassigned(newv, i) && continue + newv[i] = oldv[i] + end + end + @set! newbuf.nonnumeric = Tuple( + typeof(oldv)(newv) for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric)) if !ArrayInterface.ismutable(oldbuf) @set! newbuf.tunable = similar_type(oldbuf.tunable, eltype(newbuf.tunable))(newbuf.tunable) @set! newbuf.initials = similar_type(oldbuf.initials, eltype(newbuf.initials))(newbuf.initials) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 2a1208586b..146331c0de 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -17,10 +17,10 @@ anydict(x) = AnyDict(x) """ $(TYPEDSIGNATURES) -Check if `x` is a symbolic with known size. Assumes `Symbolics.shape(unwrap(x))` +Check if `x` is a symbolic with known size. Assumes `SymbolicUtils.shape(unwrap(x))` is a valid operation. """ -is_sized_array_symbolic(x) = Symbolics.shape(unwrap(x)) != Symbolics.Unknown() +symbolic_has_known_size(x) = !(SU.shape(unwrap(x)) isa SU.Unknown) """ $(TYPEDSIGNATURES) @@ -128,7 +128,7 @@ function add_fallbacks!( haskey(varmap, ttvar) && continue # array symbolics with a defined size may be present in the scalarized form - if Symbolics.isarraysymbolic(var) && is_sized_array_symbolic(var) + if Symbolics.isarraysymbolic(var) && symbolic_has_known_size(var) val = map(eachindex(var)) do idx # @something is lazy and saves from writing a massive if-elseif-else @something(get(varmap, var[idx], nothing), @@ -162,7 +162,7 @@ function add_fallbacks!( fallbacks, arrvar, nothing) get(fallbacks, ttarrvar, nothing) Some(nothing) if val !== nothing val = val[idxs...] - is_sized_array_symbolic(arrvar) && push!(arrvars, arrvar) + symbolic_has_known_size(arrvar) && push!(arrvars, arrvar) end else val = nothing @@ -197,7 +197,7 @@ function missingvars( ttsym = toterm(var) haskey(varmap, ttsym) && continue - if Symbolics.isarraysymbolic(var) && is_sized_array_symbolic(var) + if Symbolics.isarraysymbolic(var) && symbolic_has_known_size(var) mask = map(eachindex(var)) do idx !haskey(varmap, var[idx]) && !haskey(varmap, ttsym[idx]) end @@ -242,10 +242,10 @@ symbolic values, all of which need to be unwrapped. Specializes when `x isa Abst to unwrap keys and values, returning an `AnyDict`. """ function recursive_unwrap(x::AbstractArray) - symbolic_type(x) == ArraySymbolic() ? unwrap(x) : recursive_unwrap.(x) + symbolic_type(x) == ArraySymbolic() ? value(x) : recursive_unwrap.(x) end -recursive_unwrap(x) = unwrap(x) +recursive_unwrap(x) = value(x) function recursive_unwrap(x::AbstractDict) return anydict(unwrap(k) => recursive_unwrap(v) for (k, v) in x) @@ -261,7 +261,7 @@ entry for `eq.lhs`, insert the reverse mapping if `eq.rhs` is not a number. function add_observed_equations!(varmap::AbstractDict, eqs) for eq in eqs if var_in_varlist(eq.lhs, keys(varmap), nothing) - eq.rhs isa Number && continue + SU.isconst(eq.rhs) && continue var_in_varlist(eq.rhs, keys(varmap), nothing) && continue !iscall(eq.rhs) || issym(operation(eq.rhs)) || continue varmap[eq.rhs] = eq.lhs @@ -446,14 +446,14 @@ function check_substitution_cycles( for (k, v) in varmap kidx = var_to_idx[k] if symbolic_type(v) != NotSymbolic() - vars!(buffer, v) + SU.search_variables!(buffer, v) for var in buffer haskey(var_to_idx, var) || continue add_edge!(graph, kidx, var_to_idx[var]) end elseif v isa AbstractArray for val in v - vars!(buffer, val) + SU.search_variables!(buffer, val) end for var in buffer haskey(var_to_idx, var) || continue @@ -486,7 +486,7 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100) v === nothing && continue symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue haskey(varmap, k) || continue - varmap[k] = fixpoint_sub(v, varmap; maxiters = limit) + varmap[k] = value(fixpoint_sub(v, varmap; maxiters = limit, fold = Val(true))) end end @@ -535,7 +535,7 @@ If a scalarized entry already exists, it is not overridden. function scalarize_vars_in_varmap!(varmap::AbstractDict, vars) for var in vars symbolic_type(var) == ArraySymbolic() || continue - is_sized_array_symbolic(var) || continue + symbolic_has_known_size(var) || continue haskey(varmap, var) || continue for i in eachindex(var) haskey(varmap, var[i]) && continue @@ -701,9 +701,10 @@ function. Note that the getter ONLY works for problem-like objects, since it generates an observed function. It does NOT work for solutions. """ -Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector) +Base.@nospecializeinfer function concrete_getu(indp, syms; eval_expression, eval_module) @nospecialize - obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false) + obsfn = build_explicit_observed_function( + indp, syms; wrap_delays = false, eval_expression, eval_module) return ObservedWrapper{is_time_dependent(indp)}(obsfn) end @@ -757,7 +758,8 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns - `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`. """ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; - initials = false, unwrap_initials = false, p_constructor = identity) + initials = false, unwrap_initials = false, p_constructor = identity, + eval_expression = false, eval_module = @__MODULE__) _p_constructor = p_constructor p_constructor = PConstructorApplicator(p_constructor) # if we call `getu` on this (and it were able to handle empty tuples) we get the @@ -773,7 +775,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac tunable_getter = if isempty(tunable_syms) Returns(SizedVector{0, Float64}()) else - p_constructor ∘ concrete_getu(srcsys, tunable_syms) + p_constructor ∘ concrete_getu(srcsys, tunable_syms; eval_expression, eval_module) end initials_getter = if initials && !isempty(syms[2]) initsyms = Vector{Any}(syms[2]) @@ -792,7 +794,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac end end end - p_constructor ∘ concrete_getu(srcsys, initsyms) + p_constructor ∘ concrete_getu(srcsys, initsyms; eval_expression, eval_module) else Returns(SizedVector{0, Float64}()) end @@ -810,12 +812,12 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac # tuple of `BlockedArray`s Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘ Base.Fix1(broadcast, p_constructor) ∘ - getu(srcsys, syms[3]) + concrete_getu(srcsys, Tuple(syms[3]); eval_expression, eval_module) end const_getter = if syms[4] == () Returns(()) else - Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[4]) + Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, Tuple(syms[4])) end nonnumeric_getter = if syms[5] == () Returns(()) @@ -826,7 +828,8 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac end) # nonnumerics retain the assigned buffer type without narrowing Base.Fix1(broadcast, _p_constructor) ∘ - Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ getu(srcsys, syms[5]) + Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ + concrete_getu(srcsys, Tuple(syms[5]); eval_expression, eval_module) end getters = ( tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter) @@ -853,14 +856,19 @@ Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of ` with values from `srcsys`. """ function ReconstructInitializeprob( - srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity) + srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity, + eval_expression = false, eval_module = @__MODULE__) @assert is_initializesystem(dstsys) - ugetter = u0_constructor ∘ getu(srcsys, unknowns(dstsys)) + ugetter = u0_constructor ∘ + concrete_getu(srcsys, unknowns(dstsys); eval_expression, eval_module) if is_split(dstsys) - pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor) + pgetter = get_mtkparameters_reconstructor( + srcsys, dstsys; p_constructor, eval_expression, eval_module) else syms = parameters(dstsys) - pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor + pgetter = let inner = concrete_getu(srcsys, syms; eval_expression, eval_module), + p_constructor = p_constructor + function _getter2(valp, initprob) p_constructor(inner(valp)) end @@ -924,18 +932,20 @@ Given `sys` and its corresponding initialization system `initsys`, return the `initializeprobpmap` function in `OverrideInitData` for the systems. """ function construct_initializeprobpmap( - sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity) + sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity, eval_expression, eval_module) @assert is_initializesystem(initsys) if is_split(sys) return let getter = get_mtkparameters_reconstructor( - initsys, sys; initials = true, unwrap_initials = true, p_constructor) + initsys, sys; initials = true, unwrap_initials = true, p_constructor, + eval_expression, eval_module) function initprobpmap_split(prob, initsol) getter(initsol, prob) end end else - return let getter = getu(initsys, parameters(sys; initial_parameters = true)), - p_constructor = p_constructor + return let getter = concrete_getu( + initsys, parameters(sys; initial_parameters = true); + eval_expression, eval_module), p_constructor = p_constructor function initprobpmap_nosplit(prob, initsol) return p_constructor(getter(initsol)) @@ -958,9 +968,11 @@ end A function to be used as `update_initializeprob!` in `OverrideInitData`. Requires `is_update_oop = Val(true)` to be passed to `update_initializeprob!`. + +Any changes to this method should also be made to the one in ChainRulesCoreExt. """ function update_initializeprob!(initprob, prob) - pgetter = ChainRulesCore.@ignore_derivatives get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter + pgetter = get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter p = pgetter(prob, initprob) return remake(initprob; p) end @@ -1039,14 +1051,14 @@ struct GetUpdatedU0{GG, GIU} get_initial_unknowns::GIU end -function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict) +function GetUpdatedU0(sys::AbstractSystem, initprob::SciMLBase.AbstractNonlinearProblem, op::AbstractDict) dvs = unknowns(sys) eqs = equations(sys) guessvars = trues(length(dvs)) for (i, var) in enumerate(dvs) guessvars[i] = !isequal(get(op, var, nothing), Initial(var)) end - get_guessvars = getu(initsys, dvs[guessvars]) + get_guessvars = getu(initprob, dvs[guessvars]) get_initial_unknowns = getu(sys, Initial.(dvs)) return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns) end @@ -1108,7 +1120,7 @@ function maybe_build_initialization_problem( guesses, missing_unknowns; implicit_dae = false, time_dependent_init = is_time_dependent(sys), u0_constructor = identity, p_constructor = identity, floatT = Float64, initialization_eqs = [], - use_scc = true, kwargs...) + use_scc = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) if t === nothing && is_time_dependent(sys) @@ -1117,7 +1129,7 @@ function maybe_build_initialization_problem( initializeprob = ModelingToolkit.InitializationProblem{iip}( sys, t, op; guesses, time_dependent_init, initialization_eqs, - use_scc, u0_constructor, p_constructor, kwargs...) + use_scc, u0_constructor, p_constructor, eval_expression, eval_module, kwargs...) if state_values(initializeprob) !== nothing _u0 = state_values(initializeprob) if ArrayInterface.ismutable(_u0) @@ -1145,15 +1157,16 @@ function maybe_build_initialization_problem( initializeprob = remake(initializeprob; p = initp) get_initial_unknowns = if time_dependent_init - GetUpdatedU0(sys, initializeprob.f.sys, op) + GetUpdatedU0(sys, initializeprob, op) else nothing end meta = InitializationMetadata( - copy(op), copy(guesses), Vector{Equation}(initialization_eqs), + copy(op), anydict(copy(guesses)), Vector{Equation}(initialization_eqs), use_scc, time_dependent_init, ReconstructInitializeprob( - sys, initializeprob.f.sys; u0_constructor, p_constructor), + sys, initializeprob.f.sys; u0_constructor, + p_constructor, eval_expression, eval_module), get_initial_unknowns, SetInitialUnknowns(sys)) if time_dependent_init @@ -1172,10 +1185,9 @@ function maybe_build_initialization_problem( initializeprobpmap = nothing else initializeprobpmap = construct_initializeprobpmap( - sys, initializeprob.f.sys; p_constructor) + sys, initializeprob.f.sys; p_constructor, eval_expression, eval_module) end - reqd_syms = parameter_symbols(initializeprob) # we still want the `initialization_data` because it helps with `remake` if initializeprobmap === nothing && initializeprobpmap === nothing update_initializeprob! = nothing @@ -1186,7 +1198,9 @@ function maybe_build_initialization_problem( filter!(punknowns) do p is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing end - pvals = getu(initializeprob, punknowns)(initializeprob) + # See comment below for why `getu` is not used here. + _pgetter = build_explicit_observed_function(initializeprob.f.sys, punknowns) + pvals = _pgetter(state_values(initializeprob), parameter_values(initializeprob)) for (p, pval) in zip(punknowns, pvals) p = unwrap(p) op[p] = pval @@ -1198,7 +1212,13 @@ function maybe_build_initialization_problem( end if time_dependent_init - uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob) + # We can't use `getu` here because that goes to `SII.observed`, which goes to + # `ObservedFunctionCache` which uses `eval_expression` and `eval_module`. If + # `eval_expression == true`, this then runs into world-age issues. Building an + # RGF here is fine since it is always discarded. We can't use `eval_module` for + # the RGF since the user may not have run RGF's init. + _ugetter = build_explicit_observed_function(initializeprob.f.sys, collect(missing_unknowns)) + uvals = _ugetter(state_values(initializeprob), parameter_values(initializeprob)) for (v, val) in zip(missing_unknowns, uvals) op[v] = val end @@ -1461,7 +1481,7 @@ function process_SciMLProblem( if is_time_dependent(sys) && t0 === nothing t0 = zero(floatT) end - initialization_data = SciMLBase.remake_initialization_data( + initialization_data = @invokelatest SciMLBase.remake_initialization_data( sys, kwargs, u0, t0, p, u0, p) kwargs = merge(kwargs, (; initialization_data)) end @@ -1773,7 +1793,8 @@ Construct SciMLProblem `T` with positional arguments `args` and keywords `kwargs """ function maybe_codegen_scimlproblem(::Type{Val{false}}, T, args::NamedTuple; kwargs...) # Call `remake` so it runs initialization if it is trivial - remake(T(args...; kwargs...)) + # Use `@invokelatest` to avoid world-age issues with `eval_expression = true` + @invokelatest remake(T(args...; kwargs...)) end """ diff --git a/src/systems/solver_nlprob.jl b/src/systems/solver_nlprob.jl index badfe21efb..d4772018c1 100644 --- a/src/systems/solver_nlprob.jl +++ b/src/systems/solver_nlprob.jl @@ -55,7 +55,7 @@ function inner_nlsystem(sys::System, mm, nlstep_compile::Bool) subrules = Dict([v => unwrap(gamma2*v + inner_tmp[i]) for (i, v) in enumerate(dvs)]) subrules[t] = unwrap(c) - new_rhss = map(Base.Fix2(fast_substitute, subrules), rhss) + new_rhss = map(Base.Fix2(substitute, subrules), rhss) new_rhss = collect(outer_tmp) .+ gamma1 .* new_rhss .- gamma3 * mm * dvs new_eqs = [0 ~ rhs for rhs in new_rhss] diff --git a/src/systems/state_machines.jl b/src/systems/state_machines.jl index ea65981804..48d9d2f4f6 100644 --- a/src/systems/state_machines.jl +++ b/src/systems/state_machines.jl @@ -75,7 +75,7 @@ for (s, T) in [(:timeInState, :Real), seed = hash(s) @eval begin $s(x) = wrap(term($s, x)) - SymbolicUtils.promote_symtype(::typeof($s), _...) = $T + SymbolicUtils.promote_symtype(::typeof($s), ::Type{S}) where {S} = $T function SymbolicUtils.show_call(io, ::typeof($s), args) if isempty(args) print(io, $s, "()") diff --git a/src/systems/system.jl b/src/systems/system.jl index dcb3ed6f9b..db15f218d6 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -3,7 +3,7 @@ struct Schedule """ Mapping of `Differential`s of variables to corresponding derivative expressions. """ - dummy_sub::Dict{Any, Any} + dummy_sub::Dict{SymbolicT, SymbolicT} end const MetadataT = Base.ImmutableDict{DataType, Any} @@ -46,7 +46,7 @@ struct System <: IntermediateDeprecationSystem this noise matrix is diagonal. Diagonal noise can be specified by providing an `N` length vector. If this field is `nothing`, the system does not have noise. """ - noise_eqs::Union{Nothing, AbstractVector, AbstractMatrix} + noise_eqs::Union{Nothing, Vector{SymbolicT}, Matrix{SymbolicT}} """ Jumps associated with the system. Each jump can be a `VariableRateJump`, `ConstantRateJump` or `MassActionJump`. See `JumpProcesses.jl` for more information. @@ -63,7 +63,7 @@ struct System <: IntermediateDeprecationSystem loss of an optimization problem. Scalar loss values must also be provided as a single- element vector. """ - costs::Vector{<:Union{BasicSymbolic, Real}} + costs::Vector{SymbolicT} """ A function which combines costs into a scalar value. This should take two arguments, the `costs` of this system and the consolidated costs of all subsystems in the order @@ -76,25 +76,25 @@ struct System <: IntermediateDeprecationSystem The variables being solved for by this system. For example, in a differential equation system, this contains the dependent variables. """ - unknowns::Vector + unknowns::Vector{SymbolicT} """ The parameters of the system. Parameters can either be variables that parameterize the problem being solved for (e.g. the spring constant of a mass-spring system) or additional unknowns not part of the main dynamics of the system (e.g. discrete/clocked variables in a hybrid ODE). """ - ps::Vector + ps::Vector{SymbolicT} """ The brownian variables of the system, created via `@brownians`. Each brownian variable represents an independent noise. A system with brownians cannot be simulated directly. It needs to be compiled using `mtkcompile` into `noise_eqs`. """ - brownians::Vector + brownians::Vector{SymbolicT} """ The independent variable for a time-dependent system, or `nothing` for a time-independent system. """ - iv::Union{Nothing, BasicSymbolic{Real}} + iv::Union{Nothing, SymbolicT} """ Equations that compute variables of a system that have been eliminated from the set of unknowns by `mtkcompile`. More generally, this contains all variables that can be @@ -117,7 +117,7 @@ struct System <: IntermediateDeprecationSystem A mapping from the name of a variable to the actual symbolic variable in the system. This is used to enable `getproperty` syntax to access variables of a system. """ - var_to_name::Dict{Symbol, Any} + var_to_name::Dict{Symbol, SymbolicT} """ The name of the system. """ @@ -132,11 +132,11 @@ struct System <: IntermediateDeprecationSystem by initial values provided to the problem constructor. Defaults of parent systems take priority over those in child systems. """ - defaults::Dict + defaults::SymmapT """ Guess values for variables of a system that are solved for during initialization. """ - guesses::Dict + guesses::SymmapT """ A list of subsystems of this system. Used for hierarchically building models. """ @@ -167,7 +167,7 @@ struct System <: IntermediateDeprecationSystem associated error message. By default these assertions cause the generated code to output `NaN`s if violated, but can be made to error using `debug_system`. """ - assertions::Dict{BasicSymbolic, String} + assertions::Dict{SymbolicT, String} """ The metadata associated with this system, as a `Base.ImmutableDict`. This follows the same interface as SymbolicUtils.jl. Metadata can be queried and updated using @@ -193,12 +193,12 @@ struct System <: IntermediateDeprecationSystem $INTERNAL_FIELD_WARNING The list of input variables of the system. """ - inputs::OrderedSet{BasicSymbolic} + inputs::OrderedSet{SymbolicT} """ $INTERNAL_FIELD_WARNING The list of output variables of the system. """ - outputs::OrderedSet{BasicSymbolic} + outputs::OrderedSet{SymbolicT} """ The `TearingState` of the system post-simplification with `mtkcompile`. """ @@ -264,9 +264,9 @@ struct System <: IntermediateDeprecationSystem tag, eqs, noise_eqs, jumps, constraints, costs, consolidate, unknowns, ps, brownians, iv, observed, parameter_dependencies, var_to_name, name, description, defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events, - connector_type, assertions = Dict{BasicSymbolic, String}(), + connector_type, assertions = Dict{SymbolicT, String}(), metadata = MetadataT(), gui_metadata = nothing, is_dde = false, tstops = [], - inputs = Set{BasicSymbolic}(), outputs = Set{BasicSymbolic}(), + inputs = Set{SymbolicT}(), outputs = Set{SymbolicT}(), tearing_state = nothing, namespacing = true, complete = false, index_cache = nothing, ignored_connections = nothing, preface = nothing, parent = nothing, initializesystem = nothing, @@ -278,15 +278,23 @@ struct System <: IntermediateDeprecationSystem variable $iv. """)) end - jumps = Vector{JumpType}(jumps) - if (checks == true || (checks & CheckComponents) > 0) && iv !== nothing - check_independent_variables([iv]) + @assert iv === nothing || symtype(iv) === Real + if (checks isa Bool && checks === true || checks isa Int && (checks & CheckComponents) > 0) && iv !== nothing + check_independent_variables((iv,)) check_variables(unknowns, iv) check_parameters(ps, iv) check_equations(eqs, iv) - if noise_eqs !== nothing && size(noise_eqs, 1) != length(eqs) - throw(IllFormedNoiseEquationsError(size(noise_eqs, 1), length(eqs))) + Neq = length(eqs) + if noise_eqs isa Matrix{SymbolicT} + N1 = size(noise_eqs, 1) + elseif noise_eqs isa Vector{SymbolicT} + N1 = length(noise_eqs) + elseif noise_eqs === nothing + N1 = Neq + else + error() end + N1 == Neq || throw(IllFormedNoiseEquationsError(N1, Neq)) check_equations(equations(continuous_events), iv) check_subsystems(systems) end @@ -314,10 +322,33 @@ struct System <: IntermediateDeprecationSystem end end +_sum_costs(costs::Vector{SymbolicT}) = SU.add_worker(VartypeT, costs) +_sum_costs(costs::Vector{Num}) = SU.add_worker(VartypeT, costs) +# `reduce` instead of `sum` because the rrule for `sum` doesn't +# handle the `init` kwarg. +_sum_costs(costs::Vector) = reduce(+, costs; init = 0.0) + function default_consolidate(costs, subcosts) - # `reduce` instead of `sum` because the rrule for `sum` doesn't - # handle the `init` kwarg. - return reduce(+, costs; init = 0.0) + reduce(+, subcosts; init = 0.0) + return _sum_costs(costs) + _sum_costs(subcosts) +end + +unwrap_vars(x) = unwrap_vars(collect(x)) +unwrap_vars(vars::AbstractArray{SymbolicT}) = vars +function unwrap_vars(vars::AbstractArray) + result = similar(vars, SymbolicT) + for i in eachindex(vars) + result[i] = SU.Const{VartypeT}(vars[i]) + end + return result +end + +defsdict(x::SymmapT) = x +function defsdict(x::Union{AbstractDict, AbstractArray{<:Pair}}) + result = SymmapT() + for (k, v) in x + result[unwrap(k)] = SU.Const{VartypeT}(v) + end + return result end """ @@ -336,74 +367,82 @@ for time-independent systems, unknowns `dvs`, parameters `ps` and brownian varia All other keyword arguments are named identically to the corresponding fields in [`System`](@ref). """ -function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; - constraints = Union{Equation, Inequality}[], noise_eqs = nothing, jumps = [], - costs = BasicSymbolic[], consolidate = default_consolidate, - observed = Equation[], parameter_dependencies = Equation[], defaults = Dict(), - guesses = Dict(), systems = System[], initialization_eqs = Equation[], +function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; + constraints = Union{Equation, Inequality}[], noise_eqs = nothing, jumps = JumpType[], + costs = SymbolicT[], consolidate = default_consolidate, + observed = Equation[], parameter_dependencies = Equation[], defaults = SymmapT(), + guesses = SymmapT(), systems = System[], initialization_eqs = Equation[], continuous_events = SymbolicContinuousCallback[], discrete_events = SymbolicDiscreteCallback[], - connector_type = nothing, assertions = Dict{BasicSymbolic, String}(), + connector_type = nothing, assertions = Dict{SymbolicT, String}(), metadata = MetadataT(), gui_metadata = nothing, - is_dde = nothing, tstops = [], inputs = OrderedSet{BasicSymbolic}(), - outputs = OrderedSet{BasicSymbolic}(), tearing_state = nothing, + is_dde = nothing, tstops = [], inputs = OrderedSet{SymbolicT}(), + outputs = OrderedSet{SymbolicT}(), tearing_state = nothing, ignored_connections = nothing, parent = nothing, description = "", name = nothing, discover_from_metadata = true, initializesystem = nothing, is_initializesystem = false, is_discrete = false, preface = [], checks = true) name === nothing && throw(NoNameError()) + + if !(systems isa Vector{System}) + systems = Vector{System}(systems) + end + if !(eqs isa Vector{Equation}) + eqs = Equation[eqs] + end + eqs = eqs::Vector{Equation} + if !isempty(parameter_dependencies) - @warn """ - The `parameter_dependencies` keyword argument is deprecated. Please provide all - such equations as part of the normal equations of the system. - """ - eqs = Equation[eqs; parameter_dependencies] + @invokelatest warn_pdeps() + append!(eqs, parameter_dependencies) end iv = unwrap(iv) - ps = unwrap.(ps) - dvs = unwrap.(dvs) - filter!(!Base.Fix2(isdelay, iv), dvs) - brownians = unwrap.(brownians) - - if !(eqs isa AbstractArray) - eqs = [eqs] + ps = vec(unwrap_vars(ps)) + dvs = vec(unwrap_vars(dvs)) + if iv !== nothing + filter!(!Base.Fix2(isdelay, iv), dvs) end + brownians = unwrap_vars(brownians) if noise_eqs !== nothing - noise_eqs = unwrap.(noise_eqs) + noise_eqs = unwrap_vars(noise_eqs) end - costs = unwrap.(costs) - if isempty(costs) - costs = Union{BasicSymbolic, Real}[] - end - - defaults = anydict(defaults) - guesses = anydict(guesses) + costs = vec(unwrap_vars(costs)) - inputs = unwrap.(inputs) - outputs = unwrap.(outputs) - inputs = OrderedSet{BasicSymbolic}(inputs) - outputs = OrderedSet{BasicSymbolic}(outputs) + defaults = defsdict(defaults) + guesses = defsdict(guesses) + if !(inputs isa OrderedSet{SymbolicT}) + inputs = unwrap.(inputs) + inputs = OrderedSet{SymbolicT}(inputs) + end + if !(outputs isa OrderedSet{SymbolicT}) + outputs = unwrap.(outputs) + outputs = OrderedSet{SymbolicT}(outputs) + end for subsys in systems - for var in ModelingToolkit.inputs(subsys) + for var in get_inputs(subsys) push!(inputs, renamespace(subsys, var)) end - for var in ModelingToolkit.outputs(subsys) + for var in get_outputs(subsys) push!(outputs, renamespace(subsys, var)) end end - var_to_name = anydict() + var_to_name = Dict{Symbol, SymbolicT}() - let defaults = discover_from_metadata ? defaults : Dict(), - guesses = discover_from_metadata ? guesses : Dict(), - inputs = discover_from_metadata ? inputs : Set(), - outputs = discover_from_metadata ? outputs : Set() + let defaults = discover_from_metadata ? defaults : SymmapT(), + guesses = discover_from_metadata ? guesses : SymmapT(), + inputs = discover_from_metadata ? inputs : OrderedSet{SymbolicT}(), + outputs = discover_from_metadata ? outputs : OrderedSet{SymbolicT}() process_variables!(var_to_name, defaults, guesses, dvs) process_variables!(var_to_name, defaults, guesses, ps) - process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in observed]) - process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in observed]) + buffer = SymbolicT[] + for eq in observed + push!(buffer, eq.lhs) + push!(buffer, eq.rhs) + end + process_variables!(var_to_name, defaults, guesses, buffer) for var in dvs if isinput(var) @@ -413,15 +452,12 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; end end end - filter!(!(isnothing ∘ last), defaults) - filter!(!(isnothing ∘ last), guesses) - defaults = anydict([unwrap(k) => unwrap(v) for (k, v) in defaults]) - guesses = anydict([unwrap(k) => unwrap(v) for (k, v) in guesses]) + filter!(!(Base.Fix1(===, COMMON_NOTHING) ∘ last), defaults) + filter!(!(Base.Fix1(===, COMMON_NOTHING) ∘ last), guesses) - sysnames = nameof.(systems) - unique_sysnames = Set(sysnames) - if length(unique_sysnames) != length(sysnames) - throw(NonUniqueSubsystemsError(sysnames, unique_sysnames)) + + if !allunique(map(nameof, systems)) + nonunique_subsystems(systems) end continuous_events, discrete_events = create_symbolic_events( @@ -435,7 +471,10 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; is_dde = _check_if_dde(eqs, iv, systems) end - assertions = Dict{BasicSymbolic, String}(unwrap(k) => v for (k, v) in assertions) + _assertions = Dict{SymbolicT, String} + for (k, v) in assertions + _assertions[unwrap(k)::SymbolicT] = v + end if isempty(metadata) metadata = MetadataT() @@ -449,6 +488,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; metadata = meta end metadata = refreshed_metadata(metadata) + jumps = Vector{JumpType}(jumps) System(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, noise_eqs, jumps, constraints, costs, consolidate, dvs, ps, brownians, iv, observed, Equation[], var_to_name, name, description, defaults, guesses, systems, initialization_eqs, @@ -458,6 +498,21 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; initializesystem, is_initializesystem, is_discrete; checks) end +@noinline function nonunique_subsystems(systems) + sysnames = nameof.(systems) + unique_sysnames = Set(sysnames) + throw(NonUniqueSubsystemsError(sysnames, unique_sysnames)) +end + +@noinline function warn_pdeps() + @warn """ + The `parameter_dependencies` keyword argument is deprecated. Please provide all + such equations as part of the normal equations of the system. + """ +end + +SymbolicIndexingInterface.getname(x::System) = nameof(x) + """ $(TYPEDSIGNATURES) @@ -478,13 +533,14 @@ other symbolic expressions passed to the system. function System(eqs::Vector{Equation}, iv; kwargs...) iv === nothing && return System(eqs; kwargs...) - diffvars = OrderedSet() - othervars = OrderedSet() - ps = Set() + diffvars = OrderedSet{SymbolicT}() + othervars = OrderedSet{SymbolicT}() + ps = OrderedSet{SymbolicT}() diffeqs = Equation[] othereqs = Equation[] + iv = unwrap(iv) for eq in eqs - if !(eq.lhs isa Union{Symbolic, Number, AbstractArray}) + if !(eq.lhs isa Union{SymbolicT, Number, AbstractArray}) push!(othereqs, eq) continue end @@ -511,7 +567,7 @@ function System(eqs::Vector{Equation}, iv; kwargs...) allunknowns = union(diffvars, othervars) eqs = [diffeqs; othereqs] - brownians = Set() + brownians = Set{SymbolicT}() for x in allunknowns x = unwrap(x) if getvariabletype(x) == BROWNIAN @@ -550,8 +606,8 @@ function System(eqs::Vector{Equation}, iv; kwargs...) noiseeqs = get(kwargs, :noise_eqs, nothing) if noiseeqs !== nothing # validate noise equations - noisedvs = OrderedSet() - noiseps = OrderedSet() + noisedvs = OrderedSet{SymbolicT}() + noiseps = OrderedSet{SymbolicT}() collect_vars!(noisedvs, noiseps, noiseeqs, iv) for dv in noisedvs dv ∈ allunknowns || @@ -573,8 +629,8 @@ the system. function System(eqs::Vector{Equation}; kwargs...) eqs = collect(eqs) - allunknowns = OrderedSet() - ps = OrderedSet() + allunknowns = OrderedSet{SymbolicT}() + ps = OrderedSet{SymbolicT}() for eq in eqs collect_vars!(allunknowns, ps, eq, nothing) end @@ -611,15 +667,13 @@ function gather_array_params(ps) for p in ps if iscall(p) && operation(p) === getindex par = arguments(p)[begin] - if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && - all(par[i] in ps for i in eachindex(par)) + if symbolic_has_known_size(p) && all(par[i] in ps for i in eachindex(par)) push!(new_ps, par) else push!(new_ps, p) end else - if symbolic_type(p) == ArraySymbolic() && - Symbolics.shape(unwrap(p)) != Symbolics.Unknown() + if symbolic_type(p) == ArraySymbolic() && symbolic_has_known_size(p) for i in eachindex(p) delete!(new_ps, p[i]) end @@ -635,10 +689,10 @@ Process variables in constraints of the (ODE) System. """ function process_constraint_system( constraints::Vector{Union{Equation, Inequality}}, sts, ps, iv; validate = true) - isempty(constraints) && return Set(), Set() + isempty(constraints) && return OrderedSet{SymbolicT}(), OrderedSet{SymbolicT}() - constraintsts = OrderedSet() - constraintps = OrderedSet() + constraintsts = OrderedSet{SymbolicT}() + constraintps = OrderedSet{SymbolicT}() for cons in constraints collect_vars!(constraintsts, constraintps, cons, iv) union!(constraintsts, collect_applied_operators(cons, Differential)) @@ -656,8 +710,8 @@ end Process the costs for the constraint system. """ function process_costs(costs::Vector, sts, ps, iv) - coststs = OrderedSet() - costps = OrderedSet() + coststs = OrderedSet{SymbolicT}() + costps = OrderedSet{SymbolicT}() for cost in costs collect_vars!(coststs, costps, cost, iv) end @@ -680,7 +734,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) for var in auxvars if !iscall(var) - occursin(iv, var) && (var ∈ sts || + SU.query(isequal(iv), var) && (var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system."))) elseif length(arguments(var)) > 1 throw(ArgumentError("Too many arguments for variable $var.")) @@ -692,8 +746,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) operation(var)(iv) ∈ sts || throw(ArgumentError("Variable $var is not a variable of the System. Called variables must be variables of the System.")) - isequal(arg, iv) || isparameter(arg) || arg isa Integer || - arg isa AbstractFloat || + isequal(arg, iv) || isparameter(arg) || isconst(arg) && symtype(arg) <: Real || throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds.")) isparameter(arg) && !isequal(arg, iv) && push!(auxps, arg) @@ -725,19 +778,15 @@ differential equations. """ is_dde(sys::AbstractSystem) = has_is_dde(sys) && get_is_dde(sys) -function _check_if_dde(eqs, iv, subsystems) - is_dde = any(ModelingToolkit.is_dde, subsystems) - if !is_dde - vs = Set() - for eq in eqs - vars!(vs, eq) - is_dde = any(vs) do sym - isdelay(unwrap(sym), iv) - end - is_dde && break - end +_check_if_dde(eqs::Vector{Equation}, iv::Nothing, subsystems::Vector{System}) = false +function _check_if_dde(eqs::Vector{Equation}, iv::SymbolicT, subsystems::Vector{System}) + any(ModelingToolkit.is_dde, subsystems) && return true + pred = Base.Fix2(isdelay, iv) + for eq in eqs + SU.query(pred, eq.lhs) && return true + SU.query(pred, eq.rhs) && return true end - return is_dde + return false end """ @@ -751,9 +800,9 @@ function flatten(sys::System, noeqs = false) isempty(systems) && return sys costs = cost(sys) if _iszero(costs) - costs = Union{Real, BasicSymbolic}[] + costs = SymbolicT[] else - costs = [costs] + costs = SymbolicT[costs] end # We don't include `ignored_connections` in the flattened system, because # connection expansion inherently requires the hierarchy structure. If the system @@ -899,7 +948,7 @@ function NonlinearSystem(sys::System) subrules[var] = 0.0 end eqs = map(eqs) do eq - fast_substitute(eq, subrules) + substitute(eq, subrules) end nsys = System(eqs, unknowns(sys), [parameters(sys); get_iv(sys)]; defaults = merge(defaults(sys), Dict(get_iv(sys) => Inf)), guesses = guesses(sys), diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 4c52300239..9b0d469534 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -28,40 +28,31 @@ present in the equations of the system will be removed in this process. + `sort_eqs=true` controls whether equations are sorted lexicographically before simplification or not. """ function mtkcompile( - sys::AbstractSystem; additional_passes = [], simplify = false, split = true, + sys::System; additional_passes = (), simplify = false, split = true, allow_symbolic = false, allow_parameter = true, conservative = false, fully_determined = true, - inputs = Any[], outputs = Any[], - disturbance_inputs = Any[], + inputs = SymbolicT[], outputs = SymbolicT[], + disturbance_inputs = SymbolicT[], kwargs...) isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) - newsys′ = __mtkcompile(sys; simplify, + # Canonicalize types of arguments to prevent repeated compilation of inner methods + inputs = unwrap_vars(inputs) + outputs = unwrap_vars(outputs) + disturbance_inputs = unwrap_vars(disturbance_inputs) + newsys = __mtkcompile(sys; simplify, allow_symbolic, allow_parameter, conservative, fully_determined, inputs, outputs, disturbance_inputs, additional_passes, kwargs...) - if newsys′ isa Tuple - @assert length(newsys′) == 2 - newsys = newsys′[1] - else - newsys = newsys′ - end for pass in additional_passes newsys = pass(newsys) end - if has_parent(newsys) - @set! newsys.parent = complete(sys; split = false, flatten = false) - end + @set! newsys.parent = complete(sys; split = false, flatten = false) newsys = complete(newsys; split) - if newsys′ isa Tuple - idxs = [parameter_index(newsys, i) for i in io[1]] - return newsys, idxs - else - return newsys - end + return newsys end function __mtkcompile(sys::AbstractSystem; simplify = false, - inputs = Any[], outputs = Any[], - disturbance_inputs = Any[], + inputs::Vector{SymbolicT} = SymbolicT[], outputs::Vector{SymbolicT} = SymbolicT[], + disturbance_inputs::Vector{SymbolicT} = SymbolicT[], sort_eqs = true, kwargs...) # TODO: convert noise_eqs to brownians for simplification @@ -72,7 +63,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, return sys end if isempty(equations(sys)) && !is_time_dependent(sys) && !_iszero(cost(sys)) - return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify) + return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify)::System end sys, statemachines = extract_top_level_statemachines(sys) @@ -175,42 +166,62 @@ function simplify_optimization_system(sys::System; split = true, kwargs...) sys = flatten(sys) cons = constraints(sys) econs = Equation[] - icons = similar(cons, 0) + icons = Inequality[] for e in cons if e isa Equation push!(econs, e) - else + elseif e isa Inequality push!(icons, e) end end - irreducible_subs = Dict() - dvs = mapreduce(Symbolics.scalarize, vcat, unknowns(sys)) - if !(dvs isa Array) - dvs = [dvs] + irreducible_subs = Dict{SymbolicT, SymbolicT}() + dvs = SymbolicT[] + for var in unknowns(sys) + sh = SU.shape(var)::SU.ShapeVecT + if isempty(sh) + push!(dvs, var) + else + append!(dvs, vec(collect(var)::Array{SymbolicT})::Vector{SymbolicT}) + end end for i in eachindex(dvs) var = dvs[i] if hasbounds(var) - irreducible_subs[var] = irrvar = setirreducible(var, true) + irreducible_subs[var] = irrvar = setirreducible(var, true)::SymbolicT dvs[i] = irrvar end end - econs = fast_substitute.(econs, (irreducible_subs,)) + subst = SU.Substituter{false}(irreducible_subs, SU.default_substitute_filter) + for i in eachindex(econs) + econs[i] = subst(econs[i]) + end nlsys = System(econs, dvs, parameters(sys); name = :___tmp_nlsystem) - snlsys = mtkcompile(nlsys; kwargs..., fully_determined = false) + snlsys = mtkcompile(nlsys; kwargs..., fully_determined = false)::System obs = observed(snlsys) seqs = equations(snlsys) trueobs, _ = unhack_observed(obs, seqs) - subs = Dict(eq.lhs => eq.rhs for eq in trueobs) - cons_simplified = similar(cons, length(icons) + length(seqs)) - for (i, eq) in enumerate(Iterators.flatten((seqs, icons))) - cons_simplified[i] = fixpoint_sub(eq, subs) + subs = Dict{SymbolicT, SymbolicT}() + for eq in trueobs + subs[eq.lhs] = eq.rhs end - newsts = setdiff(dvs, keys(subs)) + cons_simplified = Union{Equation, Inequality}[] + for eq in seqs + push!(cons_simplified, fixpoint_sub(eq, subs)) + end + for eq in icons + push!(cons_simplified, fixpoint_sub(eq, subs)) + end + setdiff!(dvs, keys(subs)) + newsts = dvs @set! sys.constraints = cons_simplified - @set! sys.observed = [observed(sys); obs] - newcost = fixpoint_sub.(get_costs(sys), (subs,)) - @set! sys.costs = newcost + newobs = copy(observed(sys)) + append!(newobs, obs) + @set! sys.observed = newobs + newcosts = copy(get_costs(sys)) + for i in eachindex(newcosts) + newcosts[i] = fixpoint_sub(newcosts[i], subs) + end + @set! sys.costs = newcosts @set! sys.unknowns = newsts return sys end diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index a1460731cb..82ac0d8a8a 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -1,11 +1,11 @@ using DataStructures using Symbolics: linear_expansion, unwrap -using SymbolicUtils: iscall, operation, arguments, Symbolic +using SymbolicUtils: iscall, operation, arguments using SymbolicUtils: quick_cancel, maketerm using ..ModelingToolkit import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, value, InvalidSystemException, isdifferential, _iszero, - isparameter, Connection, + isparameter, Connection, SymbolicT independent_variables, SparseMatrixCLIL, AbstractSystem, equations, isirreducible, input_timedomain, TimeDomain, InferredTimeDomain, @@ -153,7 +153,7 @@ Base.@kwdef mutable struct SystemStructure """Graph that connects equations to the variable they will be solved for during simplification.""" solvable_graph::Union{BipartiteGraph{Int, Nothing}, Nothing} """Variable types (brownian, variable, parameter) in the system.""" - var_types::Union{Vector{VariableType}, Nothing} + var_types::Vector{VariableType} """Whether the system is discrete.""" only_discrete::Bool end @@ -205,10 +205,11 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} """The system of equations.""" sys::T """The set of variables of the system.""" - fullvars::Vector{BasicSymbolic} + fullvars::Vector{SymbolicT} structure::SystemStructure - extra_eqs::Vector - param_derivative_map::Dict{BasicSymbolic, Any} + extra_eqs::Vector{Equation} + param_derivative_map::Dict{SymbolicT, SymbolicT} + no_deriv_params::Set{SymbolicT} original_eqs::Vector{Equation} """ Additional user-provided observed equations. The variables calculated here @@ -278,7 +279,7 @@ function Base.show(io::IO, state::TearingState) print(io, "TearingState of ", typeof(state.sys)) end -struct EquationsView{T} <: AbstractVector{Any} +struct EquationsView{T} <: AbstractVector{Equation} ts::TearingState{T} end equations(ts::TearingState) = EquationsView(ts) @@ -301,11 +302,11 @@ function is_time_dependent_parameter(p, allps, iv) (args = arguments(p); length(args)) == 1 && isequal(only(args), iv)) end -function symbolic_contains(var, set) - var in set || - symbolic_type(var) == ArraySymbolic() && - Symbolics.shape(var) != Symbolics.Unknown() && - all(x -> x in set, Symbolics.scalarize(var)) +function symbolic_contains(var::SymbolicT, set::Set{SymbolicT}) + var in set # || + # symbolic_type(var) == ArraySymbolic() && + # symbolic_has_known_size(var) && + # all(x -> x in set, Symbolics.scalarize(var)) end """ @@ -315,21 +316,21 @@ Descend through the system hierarchy and look for statemachines. Remove equation the inner statemachine systems. Return the new `sys` and an array of top-level statemachines. """ -function extract_top_level_statemachines(sys::AbstractSystem) +function extract_top_level_statemachines(sys::System) eqs = get_eqs(sys) - - if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs) + predicate = Base.Fix2(isa, StateMachineOperator) ∘ SU.unwrap_const + if !isempty(eqs) && all(predicate, eqs) # top-level statemachine with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys)) return with_removed, [sys] - elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs) + elseif !isempty(eqs) && any(predicate, eqs) # error: can't mix error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.") else # descend subsystems = get_systems(sys) - newsubsystems = eltype(subsystems)[] - statemachines = eltype(subsystems)[] + newsubsystems = System[] + statemachines = System[] for subsys in subsystems newsubsys, sub_statemachines = extract_top_level_statemachines(subsys) push!(newsubsystems, newsubsys) @@ -346,12 +347,12 @@ end Return `sys` with all equations (including those in subsystems) removed. """ function remove_child_equations(sys::AbstractSystem) - @set! sys.eqs = eltype(get_eqs(sys))[] + @set! sys.eqs = Equation[] @set! sys.systems = map(remove_child_equations, get_systems(sys)) return sys end -function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) +function TearingState(sys; check = true, sort_eqs = true) # flatten system sys = flatten(sys) sys = process_parameter_equations(sys) @@ -361,76 +362,30 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) eqs = flatten_equations(equations(sys)) original_eqs = copy(eqs) neqs = length(eqs) - param_derivative_map = Dict{BasicSymbolic, Any}() + param_derivative_map = Dict{SymbolicT, SymbolicT}() + no_deriv_params = Set{SymbolicT}() + fullvars = SymbolicT[] # * Scalarize unknowns - dvs = Set{BasicSymbolic}() - fullvars = BasicSymbolic[] - for x in unknowns(sys) - push!(dvs, x) - xx = Symbolics.scalarize(x) - if xx isa AbstractArray - union!(dvs, xx) - end - end - ps = Set{Symbolic}() - for x in full_parameters(sys) - push!(ps, x) - if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown() - xx = Symbolics.scalarize(x) - union!(ps, xx) - end - end - browns = Set{BasicSymbolic}() - for x in brownians(sys) - push!(browns, x) - xx = Symbolics.scalarize(x) - if xx isa AbstractArray - union!(browns, xx) - end - end - var2idx = Dict{BasicSymbolic, Int}() + dvs = Set{SymbolicT}() + collect_vars_to_set!(dvs, unknowns(sys)) + ps = Set{SymbolicT}() + collect_vars_to_set!(ps, full_parameters(sys)) + browns = Set{SymbolicT}() + collect_vars_to_set!(browns, brownians(sys)) + var2idx = Dict{SymbolicT, Int}() var_types = VariableType[] - addvar! = let fullvars = fullvars, dvs = dvs, var2idx = var2idx, var_types = var_types - (var, vtype) -> get!(var2idx, var) do - push!(dvs, var) - push!(fullvars, var) - push!(var_types, vtype) - return length(fullvars) - end - end + + addvar! = AddVar!(var2idx, dvs, fullvars, var_types) # build symbolic incidence - symbolic_incidence = Vector{BasicSymbolic}[] - varsbuf = Set() + symbolic_incidence = Vector{SymbolicT}[] + varsbuf = Set{SymbolicT}() eqs_to_retain = trues(length(eqs)) for (i, eq) in enumerate(eqs) - _eq = eq - if iscall(eq.lhs) && (op = operation(eq.lhs)) isa Differential && - isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq.lhs)), ps, iv) - # parameter derivatives are opted out by specifying `D(p) ~ missing`, but - # we want to store `nothing` in the map because that means `fast_substitute` - # will ignore the rule. We will this identify the presence of `eq′.lhs` in - # the differentiated expression and error. - param_derivative_map[eq.lhs] = coalesce(eq.rhs, nothing) - eqs_to_retain[i] = false - # change the equation if the RHS is `missing` so the rest of this loop works - eq = 0.0 ~ coalesce(eq.rhs, 0.0) - end - is_statemachine_equation = false - if eq.lhs isa StateMachineOperator - is_statemachine_equation = true - eq = eq - rhs = eq.rhs - elseif _iszero(eq.lhs) - rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs - else - lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs - rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs - eq = 0 ~ rhs - lhs - end + eq, is_statemachine_equation = canonicalize_eq!(param_derivative_map, no_deriv_params, eqs_to_retain, ps, iv, i, eq) empty!(varsbuf) - vars!(varsbuf, eq; op = Symbolics.Operator) - incidence = Set{BasicSymbolic}() + SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{SU.Operator}()) + incidence = Set{SymbolicT}() isalgeq = true for v in varsbuf # additionally track brownians in fullvars @@ -443,10 +398,10 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) if symbolic_contains(v, ps) || getmetadata(v, SymScope, LocalScope()) isa GlobalScope && isparameter(v) if is_time_dependent_parameter(v, ps, iv) && - !haskey(param_derivative_map, Differential(iv)(v)) + !haskey(param_derivative_map, Differential(iv)(v)) && !(Differential(iv)(v) in no_deriv_params) # Parameter derivatives default to zero - they stay constant # between callbacks - param_derivative_map[Differential(iv)(v)] = 0.0 + param_derivative_map[Differential(iv)(v)] = Symbolics.COMMON_ZERO end continue end @@ -455,14 +410,27 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) isdelay(v, iv) && continue if !symbolic_contains(v, dvs) - isvalid = iscall(v) && - (operation(v) isa Shift || is_transparent_operator(operation(v))) + isvalid = Moshi.Match.@match v begin + BSImpl.Term(; f) => f isa Shift || f isa Operator && is_transparent_operator(f)::Bool + _ => false + end v′ = v - while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift} - v′ = arguments(v′)[1] - if v′ in dvs || getmetadata(v′, SymScope, LocalScope()) isa GlobalScope - isvalid = true - break + while !isvalid + Moshi.Match.@match v′ begin + BSImpl.Term(; f, args) => begin + if f isa Differential + v′ = args[1] + elseif f isa Shift + v′ = args[1] + else + break + end + if v′ in dvs || getmetadata(v′, SymScope, LocalScope()) isa GlobalScope + isvalid = true + break + end + end + _ => break end end if !isvalid @@ -470,43 +438,52 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) end addvar!(v, VARIABLE) - if iscall(v) && operation(v) isa Symbolics.Operator && !isdifferential(v) && - (it = input_timedomain(v)) !== nothing - for v′ in arguments(v) - addvar!(setmetadata(v′, VariableTimeDomain, it), VARIABLE) + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f isa SU.Operator && + !(f isa Differential) && (it = input_timedomain(v)::Vector{InputTimeDomainElT}) !== nothing + end => begin + for (v′, td) in zip(args, it) + addvar!(setmetadata(v′, VariableTimeDomain, td), VARIABLE) + end end + _ => nothing end end isalgeq &= !isdifferential(v) - if symbolic_type(v) == ArraySymbolic() - vv = collect(v) + sh = SU.shape(v)::SU.ShapeVecT + if isempty(sh) + push!(incidence, v) + addvar!(v, VARIABLE) + elseif length(sh) == 1 + vv = collect(v)::Vector{SymbolicT} + union!(incidence, vv) + for vi in vv + addvar!(vi, VARIABLE) + end + elseif length(sh) == 2 + vv = collect(v)::Matrix{SymbolicT} union!(incidence, vv) - map(vv) do vi + for vi in vv addvar!(vi, VARIABLE) end else - push!(incidence, v) - addvar!(v, VARIABLE) + vv = collect(v) + union!(incidence, vv)::Array{SymbolicT} + for vi in vv + addvar!(vi, VARIABLE) + end end end if isalgeq || is_statemachine_equation eqs[i] = eq - else - eqs[i] = eqs[i].lhs ~ rhs end push!(symbolic_incidence, collect(incidence)) end - dervaridxs = OrderedSet{Int}() - for (i, v) in enumerate(fullvars) - while isdifferential(v) - push!(dervaridxs, i) - v = arguments(v)[1] - i = addvar!(v, VARIABLE) - end - end + filter!(Base.Fix2(!==, COMMON_NOTHING) ∘ last, param_derivative_map) + eqs = eqs[eqs_to_retain] original_eqs = original_eqs[eqs_to_retain] neqs = length(eqs) @@ -521,51 +498,42 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) symbolic_incidence = symbolic_incidence[sortidxs] end + dervaridxs = OrderedSet{Int}() + add_intermediate_derivatives!(fullvars, dervaridxs, addvar!) # Handle shifts - find lowest shift and add intermediates with derivative edges ### Handle discrete variables - lowest_shift = Dict() - for var in fullvars - if ModelingToolkit.isoperator(var, ModelingToolkit.Shift) - steps = operation(var).steps - if steps > 0 - error("Only non-positive shifts allowed. Found $var with a shift of $steps") - end - v = arguments(var)[1] - lowest_shift[v] = min(get(lowest_shift, v, 0), steps) - end - end - for var in fullvars - if ModelingToolkit.isoperator(var, ModelingToolkit.Shift) - op = operation(var) - steps = op.steps - v = arguments(var)[1] - lshift = lowest_shift[v] - tt = op.t - elseif haskey(lowest_shift, var) - lshift = lowest_shift[var] - steps = 0 - tt = iv - v = var - else - continue - end - if lshift < steps - push!(dervaridxs, var2idx[var]) - end - for s in (steps - 1):-1:(lshift + 1) - sf = Shift(tt, s) - dvar = sf(v) - idx = addvar!(dvar, VARIABLE) - if !(idx in dervaridxs) - push!(dervaridxs, idx) - end - end - end - + add_intermediate_shifts!(fullvars, dervaridxs, var2idx, addvar!, iv) # sort `fullvars` such that the mass matrix is as diagonal as possible. dervaridxs = collect(dervaridxs) - sorted_fullvars = OrderedSet(fullvars[dervaridxs]) - var_to_old_var = Dict(zip(fullvars, fullvars)) + fullvars, var_types = sort_fullvars(fullvars, dervaridxs, var_types, iv) + var2idx = Dict{SymbolicT, Int}(fullvars .=> eachindex(fullvars)) + ndervars = length(dervaridxs) + # invalidate `dervaridxs`, it is just `1:ndervars` + dervaridxs = nothing + + # build `var_to_diff` + var_to_diff = build_var_to_diff(fullvars, ndervars, var2idx, iv) + + # build incidence graph + graph = build_incidence_graph(length(fullvars), symbolic_incidence, var2idx) + + @set! sys.eqs = eqs + + eq_to_diff = DiffGraph(nsrcs(graph)) + + return TearingState{typeof(sys)}(sys, fullvars, + SystemStructure(complete(var_to_diff), complete(eq_to_diff), + complete(graph), nothing, var_types, false), + Equation[], param_derivative_map, no_deriv_params, original_eqs, Equation[], typeof(sys)[]) +end + +function sort_fullvars(fullvars::Vector{SymbolicT}, dervaridxs::Vector{Int}, var_types::Vector{VariableType}, @nospecialize(iv::Union{SymbolicT, Nothing})) + if iv === nothing + return fullvars, var_types + end + iv = iv::SymbolicT + sorted_fullvars = OrderedSet{SymbolicT}(fullvars[dervaridxs]) + var_to_old_var = Dict{SymbolicT, SymbolicT}(zip(fullvars, fullvars)) for dervaridx in dervaridxs dervar = fullvars[dervaridx] diffvar = var_to_old_var[lower_order_var(dervar, iv)] @@ -582,38 +550,166 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) sortperm = indexin(new_fullvars, fullvars) fullvars = new_fullvars var_types = var_types[sortperm] - var2idx = Dict(fullvars .=> eachindex(fullvars)) - dervaridxs = 1:length(dervaridxs) + return fullvars, var_types +end - # build `var_to_diff` +function build_var_to_diff(fullvars::Vector{SymbolicT}, ndervars::Int, var2idx::Dict{SymbolicT, Int}, @nospecialize(iv::Union{SymbolicT, Nothing})) nvars = length(fullvars) - diffvars = [] var_to_diff = DiffGraph(nvars, true) - for dervaridx in dervaridxs + if iv === nothing + return var_to_diff + end + iv = iv::SymbolicT + for dervaridx in 1:ndervars dervar = fullvars[dervaridx] diffvar = lower_order_var(dervar, iv) diffvaridx = var2idx[diffvar] - push!(diffvars, diffvar) var_to_diff[diffvaridx] = dervaridx end + return var_to_diff +end - # build incidence graph +function build_incidence_graph(nvars::Int, symbolic_incidence::Vector{Vector{SymbolicT}}, var2idx::Dict{SymbolicT, Int}) + neqs = length(symbolic_incidence) graph = BipartiteGraph(neqs, nvars, Val(false)) for (ie, vars) in enumerate(symbolic_incidence), v in vars - jv = var2idx[v] add_edge!(graph, ie, jv) end - @set! sys.eqs = eqs + return graph +end - eq_to_diff = DiffGraph(nsrcs(graph)) +function collect_vars_to_set!(buffer::Set{SymbolicT}, vars::Vector{SymbolicT}) + for x in vars + push!(buffer, x) + Moshi.Match.@match x begin + BSImpl.Term(; f, args) && if f === getindex end => push!(buffer, args[1]) + _ => nothing + end + sh = SU.shape(x) + sh isa SU.Unknown && continue + sh = sh::SU.ShapeVecT + isempty(sh) && continue + idxs = SU.stable_eachindex(x) + for i in idxs + push!(buffer, x[i]) + end + end +end - ts = TearingState(sys, fullvars, - SystemStructure(complete(var_to_diff), complete(eq_to_diff), - complete(graph), nothing, var_types, false), - Any[], param_derivative_map, original_eqs, Equation[], typeof(sys)[]) - return ts +function canonicalize_eq!(param_derivative_map::Dict{SymbolicT, SymbolicT}, no_deriv_params::Set{SymbolicT}, eqs_to_retain::BitVector, ps::Set{SymbolicT}, @nospecialize(iv::Union{Nothing, SymbolicT}), i::Int, eq::Equation) + is_statemachine_equation = false + lhs = eq.lhs + rhs = eq.rhs + Moshi.Match.@match lhs begin + BSImpl.Term(; f, args) && if f isa Differential && iv isa SymbolicT && isequal(f.x, iv) && + is_time_dependent_parameter(args[1], ps, iv) + end => begin + # parameter derivatives are opted out by specifying `D(p) ~ missing`, but + # we want to store `nothing` in the map because that means `substitute` + # will ignore the rule. We will this identify the presence of `eq′.lhs` in + # the differentiated expression and error. + if eq.rhs !== COMMON_MISSING + param_derivative_map[lhs] = rhs + eq = Symbolics.COMMON_ZERO ~ rhs + else + # change the equation if the RHS is `missing` so the rest of this loop works + eq = Symbolics.COMMON_ZERO ~ Symbolics.COMMON_ZERO + push!(no_deriv_params, lhs) + delete!(param_derivative_map, lhs) + end + eqs_to_retain[i] = false + end + BSImpl.Const(; val) && if val isa StateMachineOperator end => begin + is_statemachine_equation = true + end + BSImpl.Const(;) && if _iszero(lhs) end => nothing + _ => begin + eq = Symbolics.COMMON_ZERO ~ (rhs - lhs) + end + end + return eq, is_statemachine_equation +end + +struct AddVar! + var2idx::Dict{SymbolicT, Int} + dvs::Set{SymbolicT} + fullvars::Vector{SymbolicT} + var_types::Vector{VariableType} +end + +function (avc::AddVar!)(var::SymbolicT, vtype::VariableType) + idx = get(avc.var2idx, var, nothing) + idx === nothing || return idx::Int + push!(avc.dvs, var) + push!(avc.fullvars, var) + push!(avc.var_types, vtype) + return avc.var2idx[var] = length(avc.fullvars) +end + +function add_intermediate_derivatives!(fullvars::Vector{SymbolicT}, dervaridxs::OrderedSet{Int}, addvar!::AddVar!) + for (i, v) in enumerate(fullvars) + while true + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f isa Differential end => begin + push!(dervaridxs, i) + v = args[1] + addvar!(v, VARIABLE) + end + _ => break + end + end + end +end + +function add_intermediate_shifts!(fullvars::Vector{SymbolicT}, dervaridxs::OrderedSet{Int}, var2idx::Dict{SymbolicT, Int}, addvar!::AddVar!, iv::Union{SymbolicT, Nothing}) + lowest_shift = Dict{SymbolicT, Int}() + for var in fullvars + Moshi.Match.@match var begin + BSImpl.Term(; f, args) && if f isa Shift end => begin + steps = f.steps + if steps > 0 + error("Only non-positive shifts allowed. Found $var with a shift of $steps") + end + v = args[1] + lowest_shift[v] = min(get(lowest_shift, v, 0), steps) + end + _ => nothing + end + end + for var in fullvars + lshift = typemax(Int) + steps = typemax(Int) + tt = iv + v = var + Moshi.Match.@match var begin + BSImpl.Term(; f, args) && if f isa Shift end => begin + steps = f.steps + v = args[1] + lshift = lowest_shift[v] + tt = f.t + end + if haskey(lowest_shift, var) end => begin + lshift = lowest_shift[var] + steps = 0 + tt = iv + v = var + end + _ => continue + end + if lshift < steps + push!(dervaridxs, var2idx[var]) + end + for s in (steps - 1):-1:(lshift + 1) + sf = Shift(tt, s) + dvar = sf(v) + idx = addvar!(dvar, VARIABLE) + if !(idx in dervaridxs) + push!(dervaridxs, idx) + end + end + end end """ @@ -680,7 +776,7 @@ function trivial_tearing!(ts::TearingState) end isvalid || continue # skip if the LHS is present in the RHS, since then this isn't explicit - if occursin(eq.lhs, eq.rhs) + if SU.query(isequal(eq.lhs), eq.rhs) push!(blacklist, i) continue end @@ -719,44 +815,46 @@ function trivial_tearing!(ts::TearingState) return ts end -function lower_order_var(dervar, t) - if isdifferential(dervar) - diffvar = arguments(dervar)[1] - elseif ModelingToolkit.isoperator(dervar, ModelingToolkit.Shift) - s = operation(dervar) - step = s.steps - 1 - vv = arguments(dervar)[1] - if step != 0 - diffvar = Shift(s.t, step)(vv) - else - diffvar = vv +function lower_order_var(dervar::SymbolicT, t::SymbolicT) + Moshi.Match.@match dervar begin + BSImpl.Term(; f, args) && if f isa Differential end => return args[1] + BSImpl.Term(; f, args) && if f isa Shift end => begin + step = f.steps - 1 + vv = args[1] + if step != 0 + diffvar = Shift(f.t, step)(vv) + else + diffvar = vv + end + return diffvar end - else - return Shift(t, -1)(dervar) + _ => return Shift(t, -1)(dervar) end - diffvar end function shift_discrete_system(ts::TearingState) @unpack fullvars, sys = ts - discvars = OrderedSet() + fullvars_set = Set{SymbolicT}(fullvars) + discvars = OrderedSet{SymbolicT}() eqs = equations(sys) for eq in eqs - vars!(discvars, eq; op = Union{Sample, Hold, Pre}) + SU.search_variables!(discvars, eq; is_atomic = OperatorIsAtomic{Union{Sample, Hold, Pre}}()) end - iv = get_iv(sys) - - discmap = Dict(k => StructuralTransformations.simplify_shifts(Shift(iv, 1)(k)) + iv = get_iv(sys)::SymbolicT + discmap = Dict{SymbolicT, SymbolicT}() for k in discvars - if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold, Pre})) + k in fullvars_set || continue + isoperator(k, Union{Sample, Hold, Pre}) && continue + discmap[k] = StructuralTransformations.simplify_shifts(Shift(iv, 1)(k)) + end for i in eachindex(fullvars) - fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute( - fullvars[i], discmap; operator = Union{Sample, Hold, Pre})) + fullvars[i] = StructuralTransformations.simplify_shifts(substitute( + fullvars[i], discmap; filterer = Symbolics.FPSubFilterer{Union{Sample, Hold, Pre}}())) end for i in eachindex(eqs) - eqs[i] = StructuralTransformations.simplify_shifts(fast_substitute( - eqs[i], discmap; operator = Union{Sample, Hold, Pre})) + eqs[i] = StructuralTransformations.simplify_shifts(substitute( + eqs[i], discmap; filterer = Symbolics.FPSubFilterer{Union{Sample, Hold, Pre}}())) end @set! ts.sys.eqs = eqs @set! ts.fullvars = fullvars @@ -846,7 +944,7 @@ function Base.show(io::IO, mime::MIME"text/plain", s::SystemStructure) " variables\n") Base.print_matrix(io, SystemStructurePrintMatrix(s)) else - S = incidence_matrix(s.graph, Num(Sym{Real}(:×))) + S = incidence_matrix(s.graph, Num(SSym(:×; type = Real, shape = SU.ShapeVecT()))) print(io, "Incidence matrix:") show(io, mime, S) end @@ -911,9 +1009,9 @@ function make_eqs_zero_equals!(ts::TearingState) end function mtkcompile!(state::TearingState; simplify = false, - check_consistency = true, fully_determined = true, warn_initialize_determined = true, - inputs = Any[], outputs = Any[], - disturbance_inputs = Any[], + check_consistency = true, fully_determined = true, + inputs = SymbolicT[], outputs = SymbolicT[], + disturbance_inputs = SymbolicT[], kwargs...) if !is_time_dependent(state.sys) return _mtkcompile!(state; simplify, check_consistency, @@ -925,8 +1023,6 @@ function mtkcompile!(state::TearingState; simplify = false, # if it's continous keep going, if not then error unless given trait impl in additional passes ci = ModelingToolkit.ClockInference(state) ci = ModelingToolkit.infer_clocks!(ci) - time_domains = merge(Dict(state.fullvars .=> ci.var_domain), - Dict(default_toterm.(state.fullvars) .=> ci.var_domain)) tss, clocked_inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci) if !isempty(tss) && continuous_id == 0 # do a trait check here - handle fully discrete system @@ -983,17 +1079,17 @@ function mtkcompile!(state::TearingState; simplify = false, end function _mtkcompile!(state::TearingState; simplify = false, - check_consistency = true, fully_determined = true, warn_initialize_determined = false, + check_consistency = true, fully_determined = true, dummy_derivative = true, - inputs = Any[], outputs = Any[], - disturbance_inputs = Any[], + inputs::Vector{SymbolicT} = SymbolicT[], outputs::Vector{SymbolicT} = SymbolicT[], + disturbance_inputs::Vector{SymbolicT} = SymbolicT[], kwargs...) if fully_determined isa Bool check_consistency &= fully_determined else check_consistency = true end - orig_inputs = Set() + orig_inputs = Set{SymbolicT}() ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs) state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs]) trivial_tearing!(state) @@ -1002,6 +1098,22 @@ function _mtkcompile!(state::TearingState; simplify = false, fully_determined = ModelingToolkit.check_consistency( state, orig_inputs; nothrow = fully_determined === nothing) end + # This phrasing avoids making the `kwcall` dynamic dispatch due to the type of a + # keyword (`mm`) being non-concrete + if mm isa SparseMatrixCLIL{BigInt, Int} + sys = _mtkcompile_worker!(state, sys, mm; fully_determined, dummy_derivative, simplify, kwargs...) + else + sys =_mtkcompile_worker!(state, sys, mm; fully_determined, dummy_derivative, simplify, kwargs...) + end + fullunknowns = [observables(sys); unknowns(sys)] + @set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullunknowns) + + ModelingToolkit.invalidate_cache!(sys) +end + +function _mtkcompile_worker!(state::TearingState{S}, sys::S, mm::SparseMatrixCLIL{T, Int}; + fully_determined::Bool, dummy_derivative::Bool, simplify::Bool, + kwargs...) where {S, T} if fully_determined && dummy_derivative sys = ModelingToolkit.dummy_derivative( sys, state; simplify, mm, check_consistency, kwargs...) @@ -1009,17 +1121,14 @@ function _mtkcompile!(state::TearingState; simplify = false, var_eq_matching = pantelides!(state; finalize = false, kwargs...) sys = pantelides_reassemble(state, var_eq_matching) state = TearingState(sys) - sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...) + sys, mm::SparseMatrixCLIL{T, Int} = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...) sys = ModelingToolkit.dummy_derivative( sys, state; simplify, mm, check_consistency, fully_determined, kwargs...) else sys = ModelingToolkit.tearing( sys, state; simplify, mm, check_consistency, fully_determined, kwargs...) end - fullunknowns = [observables(sys); unknowns(sys)] - @set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullunknowns) - - ModelingToolkit.invalidate_cache!(sys) + return sys end struct DifferentiatedVariableNotUnknownError <: Exception diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index acf7451065..050da35edb 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -12,7 +12,7 @@ function __get_literal_unit(x) if x isa Pair x = x[1] end - if !(x isa Union{Num, Symbolic}) + if !(x isa Union{Num, SymbolicT}) return nothing end v = value(x) @@ -23,8 +23,6 @@ function __get_scalar_unit_type(v) u = __get_literal_unit(v) if u isa DQ.AbstractQuantity return Val(:DynamicQuantities) - elseif u isa Unitful.Unitlike - return Val(:Unitful) end return nothing end @@ -71,7 +69,6 @@ get_unit(x::AbstractArray) = map(get_unit, x) get_unit(x::Num) = get_unit(unwrap(x)) get_unit(x::Symbolics.Arr) = get_unit(unwrap(x)) get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) -get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t) get_unit(op::typeof(getindex), args) = get_unit(args[1]) get_unit(x::SciMLBase.NullParameters) = unitless get_unit(op::typeof(instream), args) = get_unit(args[1]) @@ -114,7 +111,7 @@ function get_unit(op::Conditional, args) return terms[2] end -function get_unit(op::typeof(Symbolics._mapreduce), args) +function get_unit(op::typeof(mapreduce), args) if args[2] == + get_unit(args[3]) else @@ -129,8 +126,10 @@ function get_unit(op::Comparison, args) return unitless end -function get_unit(x::Symbolic) - if (u = __get_literal_unit(x)) !== nothing +function get_unit(x::SymbolicT) + if isconst(x) + return get_unit(value(x)) + elseif (u = __get_literal_unit(x)) !== nothing screen_unit(u) elseif issym(x) get_literal_unit(x) @@ -150,14 +149,14 @@ function get_unit(x::Symbolic) if base == unitless unitless else - pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2] + isconst(pargs[2]) ? base^unwrap_const(pargs[2]) : (1 * base)^pargs[2] end elseif iscall(x) op = operation(x) if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i] - elseif iscall(op) && !iscall(operation(op)) - gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t) + elseif iscall(op) && operation(op) === getindex + gp = arguments(op)[1] return screen_unit(getmetadata(gp, VariableUnit, unitless)) end # Actual function calls: args = arguments(x) @@ -196,7 +195,7 @@ function _validate(terms::Vector, labels::Vector{String}; info::String = "") equnit = safe_get_unit(term, info * label) if equnit === nothing valid = false - elseif !isequal(term, 0) + elseif !SU._iszero(term) if first_unit === nothing first_unit = equnit first_label = label @@ -249,14 +248,14 @@ function _validate(conn::Connection; info::String = "") end function validate(jump::Union{VariableRateJump, - ConstantRateJump}, t::Symbolic; + ConstantRateJump}, t::SymbolicT; info::String = "") newinfo = replace(info, "eq." => "jump") _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units validate(jump.affect!, info = newinfo) end -function validate(jump::MassActionJump, t::Symbolic; info::String = "") +function validate(jump::MassActionJump, t::SymbolicT; info::String = "") left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols net_symbols = [x[1] for x in jump.net_stoch] all_symbols = vcat(left_symbols, net_symbols) @@ -267,7 +266,7 @@ function validate(jump::MassActionJump, t::Symbolic; info::String = "") ["scaled_rates", "1/(t*reactants^$n))"]; info) end -function validate(jumps::Vector{JumpType}, t::Symbolic) +function validate(jumps::Vector{JumpType}, t::SymbolicT) labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"] majs = filter(x -> x isa MassActionJump, jumps) crjs = filter(x -> x isa ConstantRateJump, jumps) @@ -277,14 +276,14 @@ function validate(jumps::Vector{JumpType}, t::Symbolic) end function validate(eq::Union{Inequality, Equation}; info::String = "") - if typeof(eq.lhs) == Connection - _validate(eq.rhs; info) + if isconst(eq.lhs) && value(eq.lhs) isa Connection + _validate(value(eq.rhs)::Connection; info) else _validate([eq.lhs, eq.rhs], ["left", "right"]; info) end end function validate(eq::Equation, - term::Union{Symbolic, DQ.AbstractQuantity, Num}; info::String = "") + term::Union{SymbolicT, DQ.AbstractQuantity, Num}; info::String = "") _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) end function validate(eq::Equation, terms::Vector; info::String = "") @@ -306,10 +305,10 @@ function validate(eqs::Vector, noise::Matrix; info::String = "") all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -function validate(eqs::Vector, term::Symbolic; info::String = "") +function validate(eqs::Vector, term::SymbolicT; info::String = "") all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== nothing +validate(term::SymbolicT) = safe_get_unit(term, "") !== nothing """ Throws error if units of equations are invalid. diff --git a/src/systems/validation.jl b/src/systems/validation.jl deleted file mode 100644 index d416a02ea2..0000000000 --- a/src/systems/validation.jl +++ /dev/null @@ -1,287 +0,0 @@ -module UnitfulUnitCheck - -using ..ModelingToolkit, Symbolics, SciMLBase, Unitful, RecursiveArrayTools -using ..ModelingToolkit: ValidationError, - ModelingToolkit, Connection, instream, JumpType, VariableUnit, - get_systems, - Conditional, Comparison -using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump -using Symbolics: Symbolic, value, issym, isadd, ismul, ispow -const MT = ModelingToolkit - -Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y -Base.:/(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y - -""" -Throw exception on invalid unit types, otherwise return argument. -""" -function screen_unit(result) - result isa Unitful.Unitlike || - throw(ValidationError("Unit must be a subtype of Unitful.Unitlike, not $(typeof(result)).")) - result isa Unitful.ScalarUnits || - throw(ValidationError("Non-scalar units such as $result are not supported. Use a scalar unit instead.")) - result == u"°" && - throw(ValidationError("Degrees are not supported. Use radians instead.")) - result -end - -""" -Test unit equivalence. - -Example of implemented behavior: - -```julia -using ModelingToolkit, Unitful -MT = ModelingToolkit -@parameters γ P [unit = u"MW"] E [unit = u"kJ"] τ [unit = u"ms"] -@test MT.equivalent(u"MW", u"kJ/ms") # Understands prefixes -@test !MT.equivalent(u"m", u"cm") # Units must be same magnitude -@test MT.equivalent(MT.get_unit(P^γ), MT.get_unit((E / τ)^γ)) # Handles symbolic exponents -``` -""" -equivalent(x, y) = isequal(1 * x, 1 * y) -const unitless = Unitful.unit(1) - -""" -Find the unit of a symbolic item. -""" -get_unit(x::Real) = unitless -get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x)) -get_unit(x::AbstractArray) = map(get_unit, x) -get_unit(x::Num) = get_unit(value(x)) -function get_unit(x::Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata}) - get_literal_unit(x) -end -get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) -get_unit(op::typeof(getindex), args) = get_unit(args[1]) -get_unit(x::SciMLBase.NullParameters) = unitless -get_unit(op::typeof(instream), args) = get_unit(args[1]) - -get_literal_unit(x) = screen_unit(getmetadata(x, VariableUnit, unitless)) - -function get_unit(op, args) # Fallback - result = op(1 .* get_unit.(args)...) - try - unit(result) - catch - throw(ValidationError("Unable to get unit for operation $op with arguments $args.")) - end -end - -function get_unit(op::Integral, args) - unit = 1 - if op.domain.variables isa Vector - for u in op.domain.variables - unit *= get_unit(u) - end - else - unit *= get_unit(op.domain.variables) - end - return get_unit(args[1]) * unit -end - -function get_unit(op::Conditional, args) - terms = get_unit.(args) - terms[1] == unitless || - throw(ValidationError(", in $op, [$(terms[1])] is not dimensionless.")) - equivalent(terms[2], terms[3]) || - throw(ValidationError(", in $op, units [$(terms[2])] and [$(terms[3])] do not match.")) - return terms[2] -end - -function get_unit(op::typeof(Symbolics._mapreduce), args) - if args[2] == + - get_unit(args[3]) - else - throw(ValidationError("Unsupported array operation $op")) - end -end - -function get_unit(op::Comparison, args) - terms = get_unit.(args) - equivalent(terms[1], terms[2]) || - throw(ValidationError(", in comparison $op, units [$(terms[1])] and [$(terms[2])] do not match.")) - return unitless -end - -function get_unit(x::Symbolic) - if issym(x) - get_literal_unit(x) - elseif isadd(x) - terms = get_unit.(arguments(x)) - firstunit = terms[1] - for other in terms[2:end] - termlist = join(map(repr, terms), ", ") - equivalent(other, firstunit) || - throw(ValidationError(", in sum $x, units [$termlist] do not match.")) - end - return firstunit - elseif ispow(x) - pargs = arguments(x) - base, expon = get_unit.(pargs) - @assert expon isa Unitful.DimensionlessUnits - if base == unitless - unitless - else - pargs[2] isa Number ? base^pargs[2] : (1 * base)^pargs[2] - end - elseif iscall(x) - op = operation(x) - if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls - return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i] - elseif iscall(op) && !iscall(operation(op)) - gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t) - return screen_unit(getmetadata(gp, VariableUnit, unitless)) - end # Actual function calls: - args = arguments(x) - return get_unit(op, args) - else # This function should only be reached by Terms, for which `iscall` is true - throw(ArgumentError("Unsupported value $x.")) - end -end - -""" -Get unit of term, returning nothing & showing warning instead of throwing errors. -""" -function safe_get_unit(term, info) - side = nothing - try - side = get_unit(term) - catch err - if err isa Unitful.DimensionError - @warn("$info: $(err.x) and $(err.y) are not dimensionally compatible.") - elseif err isa ValidationError - @warn(info*err.message) - elseif err isa MethodError - @warn("$info: no method matching $(err.f) for arguments $(typeof.(err.args)).") - else - rethrow() - end - end - side -end - -function _validate(terms::Vector, labels::Vector{String}; info::String = "") - valid = true - first_unit = nothing - first_label = nothing - for (term, label) in zip(terms, labels) - equnit = safe_get_unit(term, info * label) - if equnit === nothing - valid = false - elseif !isequal(term, 0) - if first_unit === nothing - first_unit = equnit - first_label = label - elseif !equivalent(first_unit, equnit) - valid = false - @warn("$info: units [$(first_unit)] for $(first_label) and [$(equnit)] for $(label) do not match.") - end - end - end - valid -end - -function _validate(conn::Connection; info::String = "") - valid = true - syss = get_systems(conn) - sys = first(syss) - unks = unknowns(sys) - for i in 2:length(syss) - s = syss[i] - _unks = unknowns(s) - if length(unks) != length(_unks) - valid = false - @warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) have $(length(unks)) and $(length(_unks)) unknowns, cannot connect.") - continue - end - for (i, x) in enumerate(unks) - j = findfirst(isequal(x), _unks) - if j == nothing - valid = false - @warn("$info: connected systems $(nameof(sys)) and $(nameof(s)) do not have the same unknowns.") - else - aunit = safe_get_unit(x, info * string(nameof(sys)) * "#$i") - bunit = safe_get_unit(_unks[j], info * string(nameof(s)) * "#$j") - if !equivalent(aunit, bunit) - valid = false - @warn("$info: connected system unknowns $x and $(_unks[j]) have mismatched units.") - end - end - end - end - valid -end - -function validate(jump::Union{MT.VariableRateJump, - MT.ConstantRateJump}, t::Symbolic; - info::String = "") - newinfo = replace(info, "eq." => "jump") - _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units - validate(jump.affect!, info = newinfo) -end - -function validate(jump::MT.MassActionJump, t::Symbolic; info::String = "") - left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols - net_symbols = [x[1] for x in jump.net_stoch] - all_symbols = vcat(left_symbols, net_symbols) - allgood = _validate(all_symbols, string.(all_symbols); info) - n = sum(x -> x[2], jump.reactant_stoch, init = 0) - base_unitful = all_symbols[1] #all same, get first - allgood && _validate([jump.scaled_rates, 1 / (t * base_unitful^n)], - ["scaled_rates", "1/(t*reactants^$n))"]; info) -end - -function validate(jumps::Vector{JumpType}, t::Symbolic) - labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"] - majs = filter(x -> x isa MassActionJump, jumps) - crjs = filter(x -> x isa ConstantRateJump, jumps) - vrjs = filter(x -> x isa VariableRateJump, jumps) - splitjumps = [majs, crjs, vrjs] - all([validate(js, t; info) for (js, info) in zip(splitjumps, labels)]) -end - -function validate(eq::MT.Equation; info::String = "") - if typeof(eq.lhs) == Connection - _validate(eq.rhs; info) - else - _validate([eq.lhs, eq.rhs], ["left", "right"]; info) - end -end -function validate(eq::MT.Equation, - term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "") - _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) -end -function validate(eq::MT.Equation, terms::Vector; info::String = "") - _validate(vcat([eq.lhs, eq.rhs], terms), - vcat(["left", "right"], "noise #" .* string.(1:length(terms))); info) -end - -""" -Returns true iff units of equations are valid. -""" -function validate(eqs::Vector; info::String = "") - all([validate(eqs[idx], info = info * " in eq. #$idx") for idx in 1:length(eqs)]) -end -function validate(eqs::Vector, noise::Vector; info::String = "") - all([validate(eqs[idx], noise[idx], info = info * " in eq. #$idx") - for idx in 1:length(eqs)]) -end -function validate(eqs::Vector, noise::Matrix; info::String = "") - all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx") - for idx in 1:length(eqs)]) -end -function validate(eqs::Vector, term::Symbolic; info::String = "") - all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)]) -end -validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== nothing - -""" -Throws error if units of equations are invalid. -""" -function MT.check_units(::Val{:Unitful}, eqs...) - validate(eqs...) || - throw(ValidationError("Some equations had invalid units. See warnings for details.")) -end - -end # module diff --git a/src/utils.jl b/src/utils.jl index 0da7e4860b..0c8a00b080 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -20,7 +20,7 @@ function detime_dvs(op) if !iscall(op) op elseif issym(operation(op)) - Sym{Real}(nameof(operation(op))) + SSym(nameof(operation(op)); type = Real, shape = SU.shape(op)) else maketerm(typeof(op), operation(op), detime_dvs.(arguments(op)), metadata(op)) @@ -33,7 +33,7 @@ end Reverse `detime_dvs` for the given `dvs` using independent variable `iv`. """ function retime_dvs(op, dvs, iv) - issym(op) && return Sym{FnType{Tuple{symtype(iv)}, Real}}(nameof(op))(iv) + issym(op) && return SSym(nameof(op); type = FnType{Tuple{symtype(iv)}, Real}, shape = SU.ShapeVecT())(iv) iscall(op) ? maketerm(typeof(op), operation(op), retime_dvs.(arguments(op), (dvs,), (iv,)), metadata(op)) : @@ -84,7 +84,7 @@ end function readable_code(expr) expr = Base.remove_linenums!(_readable_code(expr)) rec_remove_macro_linenums!(expr) - JuliaFormatter.format_text(string(expr), JuliaFormatter.SciMLStyle()) + return string(expr) end # System validation enums @@ -117,11 +117,14 @@ const CheckUnits = 1 << 2 function check_independent_variables(ivs) for iv in ivs - isparameter(iv) || - @warn "Independent variable $iv should be defined with @independent_variables $iv." + isparameter(iv) || @invokelatest warn_indepvar(iv) end end +@noinline function warn_indepvar(iv::SymbolicT) + @warn "Independent variable $iv should be defined with @independent_variables $iv." +end + function check_parameters(ps, iv) for p in ps isequal(iv, p) && @@ -129,29 +132,23 @@ function check_parameters(ps, iv) end end -function is_delay_var(iv, var) - if Symbolics.isarraysymbolic(var) - return is_delay_var(iv, first(collect(var))) - end - args = nothing - try - args = arguments(var) - catch - return false +function is_delay_var(iv::SymbolicT, var::SymbolicT) + Moshi.Match.@match var begin + BSImpl.Term(; f, args) => begin + length(args) > 1 && return false + arg = args[1] + isequal(arg, iv) && return false + return symtype(arg) <: Real + end + _ => false end - length(args) > 1 && return false - isequal(first(args), iv) && return false - delay = iv - first(args) - delay isa Integer || - delay isa AbstractFloat || - (delay isa Num && isreal(value(delay))) end function check_variables(dvs, iv) for dv in dvs isequal(iv, dv) && throw(ArgumentError("Independent variable $iv not allowed in dependent variables.")) - (is_delay_var(iv, dv) || occursin(iv, dv)) || + (is_delay_var(iv, dv) || SU.query(isequal(iv), dv)) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv.")) end end @@ -187,20 +184,35 @@ function collect_ivs(eqs, op = Differential) return ivs end +struct IndepvarCheckPredicate + iv::SymbolicT +end + +function (icp::IndepvarCheckPredicate)(ex::SymbolicT) + Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa Differential end => begin + f = f::Differential + isequal(f.x, icp.iv) || throw_multiple_iv(icp.iv, f.x) + return false + end + _ => false + end +end + +@noinline function throw_multiple_iv(iv, newiv) + throw(ArgumentError("Differential w.r.t. variable ($newiv) other than the independent variable ($iv) are not allowed.")) +end + """ check_equations(eqs, iv) Assert that equations are well-formed when building ODE, i.e., only containing a single independent variable. """ -function check_equations(eqs, iv) - ivs = collect_ivs(eqs) - display = collect(ivs) - length(ivs) <= 1 || - throw(ArgumentError("Differential w.r.t. multiple variables $display are not allowed.")) - if length(ivs) == 1 - single_iv = pop!(ivs) - isequal(single_iv, iv) || - throw(ArgumentError("Differential w.r.t. variable ($single_iv) other than the independent variable ($iv) are not allowed.")) +function check_equations(eqs::Vector{Equation}, iv::SymbolicT) + icp = IndepvarCheckPredicate(iv) + for eq in eqs + SU.query(icp, eq.lhs) + SU.query(icp, eq.rhs) end end @@ -211,10 +223,12 @@ Assert that the subsystems have the appropriate namespacing behavior. """ function check_subsystems(systems) idxs = findall(!does_namespacing, systems) - if !isempty(idxs) - names = join(" " .* string.(nameof.(systems[idxs])), "\n") - throw(ArgumentError("All subsystems must have namespacing enabled. The following subsystems do not perform namespacing:\n$(names)")) - end + isempty(idxs) || throw_bad_namespacing(systems, idxs) +end + +@noinline function throw_bad_namespacing(systems, idxs) + names = join(" " .* string.(nameof.(systems[idxs])), "\n") + throw(ArgumentError("All subsystems must have namespacing enabled. The following subsystems do not perform namespacing:\n$(names)")) end """ @@ -271,57 +285,73 @@ function setdefault(v, val) val === nothing ? v : wrap(setdefaultval(unwrap(v), value(val))) end -function process_variables!(var_to_name, defs, guesses, vars) +function process_variables!(var_to_name::Dict{Symbol, SymbolicT}, defs::SymmapT, guesses::SymmapT, vars::Vector{SymbolicT}) collect_defaults!(defs, vars) collect_guesses!(guesses, vars) collect_var_to_name!(var_to_name, vars) return nothing end -function process_variables!(var_to_name, defs, vars) +function process_variables!(var_to_name::Dict{Symbol, SymbolicT}, defs::SymmapT, vars::Vector{SymbolicT}) collect_defaults!(defs, vars) collect_var_to_name!(var_to_name, vars) return nothing end -function collect_defaults!(defs, vars) +function collect_defaults!(defs::SymmapT, vars::Vector{SymbolicT}) for v in vars - symbolic_type(v) == NotSymbolic() && continue - if haskey(defs, v) || !hasdefault(unwrap(v)) || (def = getdefault(v)) === nothing + isconst(v) && continue + haskey(defs, v) && continue + def = Symbolics.getdefaultval(v, nothing) + if def !== nothing + defs[v] = SU.Const{VartypeT}(def) continue end - defs[v] = getdefault(v) + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f === getindex end => begin + haskey(defs, args[1]) && continue + def = Symbolics.getdefaultval(args[1], nothing) + def === nothing && continue + defs[args[1]] = def + end + _ => nothing + end end return defs end -function collect_guesses!(guesses, vars) +function collect_guesses!(guesses::SymmapT, vars::Vector{SymbolicT}) for v in vars + isconst(v) && continue symbolic_type(v) == NotSymbolic() && continue - if haskey(guesses, v) || !hasguess(unwrap(v)) || (def = getguess(v)) === nothing + haskey(guesses, v) && continue + def = getguess(v) + if def !== nothing + guesses[v] = SU.Const{VartypeT}(def) continue end - guesses[v] = getguess(v) + Moshi.Match.@match v begin + BSImpl.Term(; f, args) && if f === getindex end => begin + haskey(guesses, args[1]) && continue + def = Symbolics.getdefaultval(args[1], nothing) + def === nothing && continue + guesses[args[1]] = def + end + _ => nothing + end end return guesses end -function collect_var_to_name!(vars, xs) +function collect_var_to_name!(vars::Dict{Symbol, SymbolicT}, xs::Vector{SymbolicT}) for x in xs - symbolic_type(x) == NotSymbolic() && continue - x = unwrap(x) - if hasmetadata(x, Symbolics.GetindexParent) - xarr = getmetadata(x, Symbolics.GetindexParent) - hasname(xarr) || continue - vars[Symbolics.getname(xarr)] = xarr - else - if iscall(x) && operation(x) === getindex - x = arguments(x)[1] - end - x = unwrap(x) - hasname(x) || continue - vars[Symbolics.getname(unwrap(x))] = x + x = Moshi.Match.@match x begin + BSImpl.Const(;) => continue + BSImpl.Term(; f, args) && if f === getindex end => args[1] + _ => x end + hasname(x) || continue + vars[getname(x)] = x end end @@ -329,9 +359,7 @@ end Throw error when difference/derivative operation occurs in the R.H.S. """ @noinline function throw_invalid_operator(opvar, eq, op::Type) - if op === Difference - error("The Difference operator is deprecated, use ShiftIndex instead") - elseif op === Differential + if op === Differential optext = "derivative" end msg = "The $optext variable must be isolated to the left-hand " * @@ -354,11 +382,11 @@ end Check if all the LHS are unique """ function check_operator_variables(eqs, op::T) where {T} - ops = Set() - tmp = Set() + ops = Set{SymbolicT}() + tmp = Set{SymbolicT}() for eq in eqs _check_operator_variables(eq, op) - vars!(tmp, eq.lhs) + SU.search_variables!(tmp, eq.lhs; is_atomic = OperatorIsAtomic{Differential}()) if length(tmp) == 1 x = only(tmp) if op === Differential @@ -381,18 +409,21 @@ function check_operator_variables(eqs, op::T) where {T} end end -isoperator(expr, op) = iscall(expr) && operation(expr) isa op -isoperator(op) = expr -> isoperator(expr, op) +function isoperator(expr::SymbolicT, ::Type{op}) where {op <: SU.Operator} + Moshi.Match.@match expr begin + BSImpl.Term(; f) => f isa op + _ => false + end +end +isoperator(::Type{op}) where {op <: SU.Operator} = Base.Fix2(isoperator, op) isdifferential(expr) = isoperator(expr, Differential) isdiffeq(eq) = isdifferential(eq.lhs) || isoperator(eq.lhs, Shift) isvariable(x::Num)::Bool = isvariable(value(x)) -function isvariable(x)::Bool - x isa Symbolic || return false - p = getparent(x, nothing) - p === nothing || (x = p) - hasmetadata(x, VariableSource) +function isvariable(x) + x isa SymbolicT || return false + hasmetadata(x, VariableSource) || iscall(x) && operation(x) === getindex && isvariable(arguments(x)[1])::Bool end """ @@ -412,7 +443,7 @@ v = ModelingToolkit.vars(D(y) ~ u) v == Set([D(y), u]) ``` """ -function vars(exprs::Symbolic; op = Differential) +function vars(exprs::SymbolicT; op = Differential) iscall(exprs) ? vars([exprs]; op = op) : Set([exprs]) end vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op) @@ -487,16 +518,16 @@ function collect_operator_variables(eq::Equation, args...) end """ - collect_operator_variables(eqs::AbstractVector{Equation}, op) + collect_operator_variables(eqs::Vector{Equation}, ::Type{op}) where {op} Return a `Set` containing all variables that have Operator `op` applied to them. See also [`collect_differential_variables`](@ref). """ -function collect_operator_variables(eqs::AbstractVector{Equation}, op) - vars = Set() - diffvars = Set() +function collect_operator_variables(eqs::Vector{Equation}, ::Type{op}) where {op} + vars = Set{SymbolicT}() + diffvars = Set{SymbolicT}() for eq in eqs - vars!(vars, eq; op = op) + SU.search_variables!(vars, eq; is_atomic = OperatorIsAtomic{op}()) for v in vars isoperator(v, op) || continue push!(diffvars, arguments(v)[1]) @@ -522,13 +553,10 @@ ModelingToolkit.collect_applied_operators(eq, Differential) == Set([D(y)]) The difference compared to `collect_operator_variables` is that `collect_operator_variables` returns the variable without the operator applied. """ -function collect_applied_operators(x, op) - v = vars(x, op = op) - filter(v) do x - issym(x) && return false - iscall(x) && return operation(x) isa op - false - end +function collect_applied_operators(x, ::Type{op}) where {op} + v = Set{SymbolicT}() + SU.search_variables!(v, x; is_atomic = OnlyOperatorIsAtomic{op}()) + return v end """ @@ -539,12 +567,12 @@ Search through equations and parameter dependencies of `sys`, where sys is at a recursively searches through all subsystems of `sys`, increasing the depth if it is not `-1`. A depth of `-1` indicates searching for variables with `GlobalScope`. """ -function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential) +function collect_scoped_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, sys::AbstractSystem, iv::Union{SymbolicT, Nothing}; depth = 1, op = Differential) if has_eqs(sys) for eq in equations(sys) eqtype_supports_collect_vars(eq) || continue if eq isa Equation - eq.lhs isa Union{Symbolic, Number} || continue + symtype(eq.lhs) <: Number || continue end collect_vars!(unknowns, parameters, eq, iv; depth, op) end @@ -618,6 +646,24 @@ function Base.showerror(io::IO, err::OperatorIndepvarMismatchError) end end +struct OnlyOperatorIsAtomic{O} end + +function (::OnlyOperatorIsAtomic{O})(ex::SymbolicT) where {O} + Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa O end => true + _ => false + end +end + +struct OperatorIsAtomic{O} end + +function (::OperatorIsAtomic{O})(ex::SymbolicT) where {O} + SU.default_is_atomic(ex) && Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa Operator end => f isa O + _ => true + end +end + """ $(TYPEDSIGNATURES) @@ -632,11 +678,15 @@ can be checked using `check_scope_depth`. This function should return `nothing`. """ -function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics.Operator) - if issym(expr) - return collect_var!(unknowns, parameters, expr, iv; depth) +function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::SymbolicT, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) + Moshi.Match.@match expr begin + BSImpl.Const(;) => return + BSImpl.Sym(;) => return collect_var!(unknowns, parameters, expr, iv; depth) + _ => nothing end - for var in vars(expr; op) + vars = Set{SymbolicT}() + SU.search_variables!(vars, expr; is_atomic = OperatorIsAtomic{op}()) + for var in vars while iscall(var) && operation(var) isa op validate_operator(operation(var), arguments(var), iv; context = expr) var = arguments(var)[1] @@ -646,6 +696,13 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics return nothing end +function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::AbstractArray, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) + for var in expr + collect_vars!(unknowns, parameters, var, iv; depth, op) + end + return nothing +end + """ $(TYPEDSIGNATURES) @@ -657,7 +714,7 @@ eqtype_supports_collect_vars(eq::Equation) = true eqtype_supports_collect_vars(eq::Inequality) = true eqtype_supports_collect_vars(eq::Pair) = true -function collect_vars!(unknowns, parameters, eq::Union{Equation, Inequality}, iv; +function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, eq::Union{Equation, Inequality}, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) @@ -665,12 +722,17 @@ function collect_vars!(unknowns, parameters, eq::Union{Equation, Inequality}, iv end function collect_vars!( - unknowns, parameters, p::Pair, iv; depth = 0, op = Symbolics.Operator) + unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, p::Pair, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) collect_vars!(unknowns, parameters, p[1], iv; depth, op) collect_vars!(unknowns, parameters, p[2], iv; depth, op) return nothing end +function collect_vars!( + unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) + return nothing +end + """ $(TYPEDSIGNATURES) @@ -678,7 +740,7 @@ Identify whether `var` belongs to the current system using `depth` and scoping i Add `var` to `unknowns` or `parameters` appropriately, and search through any expressions in known metadata of `var` using `collect_vars!`. """ -function collect_var!(unknowns, parameters, var, iv; depth = 0) +function collect_var!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, var::SymbolicT, iv::Union{SymbolicT, Nothing}; depth = 0) isequal(var, iv) && return nothing if Symbolics.iswrapped(var) error(""" @@ -687,11 +749,11 @@ function collect_var!(unknowns, parameters, var, iv; depth = 0) Encountered a wrapped value in `collect_var!`. This function should only ever \ receive unwrapped symbolic variables. This is likely a bug in the code generating \ an expression passed to `collect_vars!` or `collect_scoped_vars!`. A common cause \ - is using `substitute` or `fast_substitute` with rules where the values are \ + is using `substitute` with rules where the values are \ wrapped symbolic variables. """) end - check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing + check_scope_depth(getmetadata(var, SymScope, LocalScope())::AllScopes, depth) || return nothing var = setmetadata(var, SymScope, LocalScope()) if iscalledparameter(var) callable = getcalledparameter(var) @@ -719,7 +781,7 @@ function check_scope_depth(scope, depth) if scope isa LocalScope return depth == 0 elseif scope isa ParentScope - return depth > 0 && check_scope_depth(scope.parent, depth - 1) + return depth > 0 && check_scope_depth(scope.parent, depth - 1)::Bool elseif scope isa GlobalScope return depth == -1 end @@ -803,7 +865,7 @@ end function _with_unit(f, x, t, args...) x = f(x, args...) - if hasmetadata(x, VariableUnit) && (t isa Symbolic && hasmetadata(t, VariableUnit)) + if hasmetadata(x, VariableUnit) && (t isa SymbolicT && hasmetadata(t, VariableUnit)) xu = getmetadata(x, VariableUnit) tu = getmetadata(t, VariableUnit) x = setmetadata(x, VariableUnit, xu / tu) @@ -833,8 +895,8 @@ end Check if `T` is an appropriate symtype for a symbolic variable representing a floating point number or array of such numbers. """ -function is_floatingpoint_symtype(T::Type) - return T == Real || T == Number || T == Complex || T <: AbstractFloat || +function is_floatingpoint_symtype(T) + return T === Real || T === Number || T === Complex || T <: AbstractFloat || T <: AbstractArray && is_floatingpoint_symtype(eltype(T)) end @@ -902,7 +964,16 @@ Keyword arguments: `available_vars` will not be searched for in the observed equations. """ function observed_equations_used_by(sys::AbstractSystem, exprs; - involved_vars = vars(exprs; op = Union{Shift, Differential, Initial}), obs = observed(sys), available_vars = []) + involved_vars = nothing, obs = observed(sys), available_vars = Set{SymbolicT}()) + if involved_vars === nothing + involved_vars = Set{SymbolicT}() + SU.search_variables!(involved_vars, exprs; is_atomic = OperatorIsAtomic{Union{Shift, Differential, Initial}}()) + elseif !(involved_vars isa Set{SymbolicT}) + involved_vars = Set{SymbolicT}(involved_vars) + end + if !(available_vars isa Set) + available_vars = Set(available_vars) + end if iscomplete(sys) && obs == observed(sys) cache = getmetadata(sys, MutableCacheKey, nothing) obs_graph_cache = get!(cache, ObservedGraphCacheKey) do @@ -916,10 +987,6 @@ function observed_equations_used_by(sys::AbstractSystem, exprs; graph = observed_dependency_graph(obs) end - if !(available_vars isa Set) - available_vars = Set(available_vars) - end - obsidxs = BitSet() for sym in involved_vars sym in available_vars && continue @@ -978,7 +1045,7 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) end any(isequal(expr), vars) && return expr iscall(expr) || return expr - Symbolics.shape(expr) == Symbolics.Unknown() && return expr + symbolic_has_known_size(expr) || return expr haskey(state, expr) && return state[expr] op = operation(expr) args = arguments(expr) @@ -987,9 +1054,8 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) # OR # none of `vars` are involved in `expr` if op === getindex && (issym(args[1]) || !iscalledparameter(args[1])) || - (vs = ModelingToolkit.vars(expr); intersect!(vs, vars); isempty(vs)) + (vs = SU.search_variables(expr); intersect!(vs, vars); isempty(vs)) sym = gensym(:subexpr) - stype = symtype(expr) var = similar_variable(expr, sym) state[expr] = var return var @@ -999,7 +1065,7 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) indep_args = [] dep_args = [] for arg in args - _vs = ModelingToolkit.vars(arg) + _vs = SU.search_variables(arg) intersect!(_vs, vars) if !isempty(_vs) push!(dep_args, subexpressions_not_involving_vars!(arg, vars, state)) @@ -1069,7 +1135,7 @@ function var_in_varlist(var, varlist::AbstractSet, iv) # indexed array symbolic, unscalarized array present (iscall(var) && operation(var) === getindex && arguments(var)[1] in varlist) || # unscalarized sized array symbolic, all scalarized elements present - (symbolic_type(var) == ArraySymbolic() && is_sized_array_symbolic(var) && + (symbolic_type(var) == ArraySymbolic() && symbolic_has_known_size(var) && all(x -> x in varlist, collect(var))) || # delayed variables (isdelay(var, iv) && var_in_varlist(operation(var)(iv), varlist, iv)) @@ -1113,25 +1179,19 @@ Given a list of equations where some may be array equations, flatten the array e without scalarizing occurrences of array variables and return the new list of equations. """ function flatten_equations(eqs::Vector{Equation}) - mapreduce(vcat, eqs; init = Equation[]) do eq - islhsarr = eq.lhs isa AbstractArray || Symbolics.isarraysymbolic(eq.lhs) - isrhsarr = eq.rhs isa AbstractArray || Symbolics.isarraysymbolic(eq.rhs) - if islhsarr || isrhsarr - islhsarr && isrhsarr || - error(""" - LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions \ - or both scalar - """) - size(eq.lhs) == size(eq.rhs) || - error(""" - Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got \ - $(size(eq.lhs)) and $(size(eq.rhs)) - """) - return vec(collect(eq.lhs) .~ collect(eq.rhs)) - else - eq + _eqs = Equation[] + for eq in eqs + if !SU.is_array_shape(SU.shape(eq.lhs)) + push!(_eqs, eq) + continue + end + lhs = vec(collect(eq.lhs)::Array{SymbolicT})::Vector{SymbolicT} + rhs = vec(collect(eq.rhs)::Array{SymbolicT})::Vector{SymbolicT} + for (l, r) in zip(lhs, rhs) + push!(_eqs, l ~ r) end end + return _eqs end const JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump} @@ -1184,4 +1244,4 @@ function wrap_with_D(n, D, repeats) else wrap_with_D(D(n), D, repeats - 1) end -end \ No newline at end of file +end diff --git a/src/variables.jl b/src/variables.jl index 46c9c95bc6..7075aa3a0c 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -179,7 +179,7 @@ struct Stream <: AbstractConnectType end # special stream connector Get the connect type of x. See also [`hasconnect`](@ref). """ getconnect(x::Num) = getconnect(unwrap(x)) -getconnect(x::Symbolic) = Symbolics.getmetadata(x, VariableConnectType, nothing) +getconnect(x::SymbolicT) = Symbolics.getmetadata(x, VariableConnectType, nothing) """ hasconnect(x) @@ -190,13 +190,13 @@ function setconnect(x, t::Type{T}) where {T <: AbstractConnectType} setmetadata(x, VariableConnectType, t) end -### Input, Output, Irreducible -isvarkind(m, x::Union{Num, Symbolics.Arr}) = isvarkind(m, value(x)) -function isvarkind(m, x) - iskind = getmetadata(x, m, nothing) - iskind !== nothing && return iskind - x = getparent(x, x) - getmetadata(x, m, false) +### Input, Output, Irreducible +isvarkind(m, x, def = false) = safe_getmetadata(m, x, def) +safe_getmetadata(m, x::Union{Num, Symbolics.Arr}, def) = safe_getmetadata(m, value(x), def) +function safe_getmetadata(m::DataType, x::SymbolicT, default) + hasmetadata(x, m) && return getmetadata(x, m) + iscall(x) && operation(x) === getindex && return safe_getmetadata(m, arguments(x)[1], default) + return default end """ @@ -218,13 +218,13 @@ setio(x, i::Bool, o::Bool) = setoutput(setinput(x, i), o) Check if variable `x` is marked as an input. """ -isinput(x) = isvarkind(VariableInput, x) +isinput(x) = isvarkind(VariableInput, x)::Bool """ $(TYPEDSIGNATURES) Check if variable `x` is marked as an output. """ -isoutput(x) = isvarkind(VariableOutput, x) +isoutput(x) = isvarkind(VariableOutput, x)::Bool # Before the solvability check, we already have handled IO variables, so # irreducibility is independent from IO. @@ -234,7 +234,7 @@ isoutput(x) = isvarkind(VariableOutput, x) Check if `x` is marked as irreducible. This prevents it from being eliminated as an observed variable in `mtkcompile`. """ -isirreducible(x) = isvarkind(VariableIrreducible, x) +isirreducible(x) = isvarkind(VariableIrreducible, x)::Bool setirreducible(x, v::Bool) = setmetadata(x, VariableIrreducible, v) state_priority(x::Union{Num, Symbolics.Arr}) = state_priority(unwrap(x)) """ @@ -245,19 +245,30 @@ chosen as a state in `mtkcompile`. """ state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64 -normalize_to_differential(x) = x +function normalize_to_differential(@nospecialize(op)) + if op isa Shift && op.t isa SymbolicT + return Differential(op.t) ^ op.steps + else + return op + end +end -function default_toterm(x) - if iscall(x) && (op = operation(x)) isa Operator - if !(op isa Differential) - if op isa Shift && op.steps < 0 +default_toterm(x) = x +function default_toterm(x::SymbolicT) + Moshi.Match.@match x begin + BSImpl.Term(; f, args, shape, type, metadata) && if f isa Operator end => begin + if f isa Shift && f.steps < 0 return shift2term(x) + elseif f isa Differential + return Symbolics.diff2term(x) + else + newf = normalize_to_differential(f) + f === newf && return x + x = BSImpl.Term{VartypeT}(newf, args; type, shape, metadata) + return Symbolics.diff2term(x) end - x = normalize_to_differential(op)(arguments(x)...) end - Symbolics.diff2term(x) - else - x + _ => return x end end @@ -280,12 +291,12 @@ Create parameters with bounds like this @parameters p [bounds=(-1, 1)] ``` """ -function getbounds(x::Union{Num, Symbolics.Arr, SymbolicUtils.Symbolic}) - x = unwrap(x) - p = Symbolics.getparent(x, nothing) - if p === nothing +getbounds(x::Union{Num, Symbolics.Arr}) = getbounds(unwrap(x)) +function getbounds(x::SymbolicT) + if operation(p) === getindex + p = arguments(p)[1] bounds = Symbolics.getmetadata(x, VariableBounds, (-Inf, Inf)) - if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown() + if symbolic_type(x) == ArraySymbolic() && symbolic_has_known_size(x) bounds = map(bounds) do b b isa AbstractArray && return b return fill(b, size(x)) @@ -297,7 +308,7 @@ function getbounds(x::Union{Num, Symbolics.Arr, SymbolicUtils.Symbolic}) idxs = arguments(x)[2:end] bounds = map(bounds) do b if b isa AbstractArray - if Symbolics.shape(p) != Symbolics.Unknown() && size(p) != size(b) + if symbolic_has_known_size(p) && size(p) != size(b) throw(DimensionMismatch("Expected array variable $p with shape $(size(p)) to have bounds of identical size. Found $bounds of size $(size(bounds)).")) end return b[idxs...] @@ -318,8 +329,8 @@ Determine whether symbolic variable `x` has bounds associated with it. See also [`getbounds`](@ref). """ function hasbounds(x) - b = getbounds(x) - any(isfinite.(b[1]) .|| isfinite.(b[2])) + b = getbounds(x)::NTuple{2} + any(isfinite.(b[1]) .|| isfinite.(b[2]))::Bool end function setbounds(x::Num, bounds) @@ -339,9 +350,7 @@ isdisturbance(x::Num) = isdisturbance(Symbolics.unwrap(x)) Determine whether symbolic variable `x` is marked as a disturbance input. """ function isdisturbance(x) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableDisturbance, false) + isvarkind(VariableDisturbance, x)::Bool end setdisturbance(x, v) = setmetadata(x, VariableDisturbance, v) @@ -372,9 +381,7 @@ Create a tunable parameter by See also [`tunable_parameters`](@ref), [`getbounds`](@ref) """ function istunable(x, default = true) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableTunable, default) + isvarkind(VariableTunable, x, default)::Bool end ## Dist ======================================================================== @@ -398,9 +405,7 @@ getdist(u) # retrieve distribution ``` """ function getdist(x) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableDistribution, nothing) + safe_getmetadata(VariableDistribution, x, nothing) end """ @@ -492,9 +497,7 @@ getdescription(x::Symbolics.Arr) = getdescription(Symbolics.unwrap(x)) Return any description attached to variables `x`. If no description is attached, an empty string is returned. """ function getdescription(x) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableDescription, "") + safe_getmetadata(VariableDescription, x, "") end """ @@ -512,7 +515,7 @@ end Maps the brownianiable to an unknown. """ -tobrownian(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, BROWNIAN) +tobrownian(s::SymbolicT) = setmetadata(s, MTKVariableTypeCtx, BROWNIAN) tobrownian(s::Num) = Num(tobrownian(value(s))) isbrownian(s) = getvariabletype(s) === BROWNIAN @@ -526,10 +529,10 @@ macro brownians(xs...) x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$ || Meta.isexpr(x, :$), xs) || error("@brownians only takes scalar expressions!") - Symbolics._parse_vars(:brownian, + Symbolics.parse_vars(:brownian, Real, xs, - tobrownian) |> esc + tobrownian) end ## Guess ====================================================================== @@ -587,7 +590,7 @@ Fetch any miscellaneous data associated with symbolic variable `x`. See also [`hasmisc(x)`](@ref). """ getmisc(x::Num) = getmisc(unwrap(x)) -getmisc(x::Symbolic) = Symbolics.getmetadata(x, VariableMisc, nothing) +getmisc(x::SymbolicT) = Symbolics.getmetadata(x, VariableMisc, nothing) """ hasmisc(x) @@ -606,7 +609,7 @@ setmisc(x, miscdata) = setmetadata(x, VariableMisc, miscdata) Fetch the unit associated with variable `x`. This function is a metadata getter for an individual variable, while `get_unit` is used for unit inference on more complicated sdymbolic expressions. """ getunit(x::Num) = getunit(unwrap(x)) -getunit(x::Symbolic) = Symbolics.getmetadata(x, VariableUnit, nothing) +getunit(x::SymbolicT) = Symbolics.getmetadata(x, VariableUnit, nothing) """ hasunit(x) @@ -615,10 +618,10 @@ Check if the variable `x` has a unit. hasunit(x) = getunit(x) !== nothing getunshifted(x::Num) = getunshifted(unwrap(x)) -getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing) +getunshifted(x::SymbolicT) = Symbolics.getmetadata(x, VariableUnshifted, nothing)::Union{SymbolicT, Nothing} getshift(x::Num) = getshift(unwrap(x)) -getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) +getshift(x::SymbolicT) = Symbolics.getmetadata(x, VariableShift, 0)::Int ################### ### Evaluate at ### @@ -629,7 +632,7 @@ getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) An operator that evaluates time-dependent variables at a specific absolute time point `t`. # Fields -- `t::Union{Symbolic, Number}`: The absolute time at which to evaluate the variable. +- `t::Union{SymbolicT, Number}`: The absolute time at which to evaluate the variable. # Description `EvalAt` is used to evaluate time-dependent variables at a specific time point. This is particularly @@ -677,12 +680,12 @@ end See also: [`Differential`](@ref) """ struct EvalAt <: Symbolics.Operator - t::Union{Symbolic, Number} + t::Union{SymbolicT, Number} end -function (A::EvalAt)(x::Symbolic) +function (A::EvalAt)(x::SymbolicT) if symbolic_type(x) == NotSymbolic() || !iscall(x) - if x isa Symbolics.CallWithMetadata + if x isa CallAndWrap return x(A.t) else return x diff --git a/test/basic_transformations.jl b/test/basic_transformations.jl index dc9d71f300..19899aa50f 100644 --- a/test/basic_transformations.jl +++ b/test/basic_transformations.jl @@ -340,11 +340,12 @@ foofn(x) = 4 @testset "`respecialize`" begin @parameters p::AbstractFoo p2(t)::AbstractFoo = p q[1:2]::AbstractFoo r - rp, - rp2 = let - only(@parameters p::Bar), - SymbolicUtils.term(operation(p2), arguments(p2)...; type = Baz) - end + rp = only(let p = nothing + @parameters p::Bar + end) + rp2 = only(let p2 = nothing + @parameters p2(t)::Baz + end) @variables x(t) = 1.0 @named sys1 = System([D(x) ~ foofn(p) + foofn(p2) + x], t, [x], [p, p2, q, r]) diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 2c713d4149..eefd155db6 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -201,8 +201,7 @@ resistor = getproperty(rc, :resistor; namespace = false) @named pi_model = PiModel() - @test typeof(ModelingToolkit.getdefault(pi_model.p)) <: - SymbolicUtils.BasicSymbolic{Irrational} + @test symtype(ModelingToolkit.getdefault(pi_model.p)) <: Irrational @test getdefault(getdefault(pi_model.p)) == π end @@ -1007,7 +1006,7 @@ end vars = Symbolics.get_variables(only(equations(ex))) @test length(vars) == 2 for u in Symbolics.unwrap.(unknowns(ex)) - @test !Symbolics.hasmetadata(u, Symbolics.CallWithParent) + @test !SymbolicUtils.is_function_symbolic(u) @test any(isequal(u), vars) end end diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 28ab3759ef..08beb7ed53 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -357,7 +357,7 @@ ps = MTKParameters( (BlockedArray([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [3, 3]), BlockedArray(falses(1), [1, 0])), (), (), ()) -@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}} +@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, BitVector} tsidx1 = 1 tsidx2 = 2 @test length(ps.discrete[1][Block(tsidx1)]) == 3 @@ -368,3 +368,14 @@ with_updated_parameter_timeseries_values( sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false]))) @test ps.discrete[1][Block(tsidx1)] == [10.0, 11.0, 12.0] @test ps.discrete[2][Block(tsidx1)][] == false + +@testset "Avoid specialization of nonnumeric parameters on `remake_buffer`" begin + @variables x(t) + @parameters p::Any + @named sys = System(D(x) ~ x, t, [x], [p]) + sys = complete(sys) + ps = MTKParameters(sys, [p => 1.0]) + @test ps.nonnumeric isa Tuple{Vector{Any}} + ps2 = remake_buffer(sys, ps, [p], [:a]) + @test ps2.nonnumeric isa Tuple{Vector{Any}} +end diff --git a/test/odesystem.jl b/test/odesystem.jl index cf16b31a42..417eaa70a7 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -535,7 +535,7 @@ sys = complete(sys) us = map(s -> (@variables $s(t))[1], syms) ps = map(s -> (@variables $s(t))[1], syms_p) buffer, = @variables $buffername[1:length(u0)] - dummy_var = Sym{Any}(:_) # this is safe because _ cannot be a rvalue in Julia + dummy_var = Symbolics.SSym(:_; type = Any) # this is safe because _ cannot be a rvalue in Julia ss = Iterators.flatten((us, ps)) vv = Iterators.flatten((u0, p0)) diff --git a/test/optimizationsystem.jl b/test/optimizationsystem.jl index 30e80e1cec..c29b68b9be 100644 --- a/test/optimizationsystem.jl +++ b/test/optimizationsystem.jl @@ -399,7 +399,7 @@ end prob = OptimizationProblem(sys, [x => [42.0, 12.37]]; hess = true, sparse = true) symbolic_hess = Symbolics.hessian(cost(sys), x) - symbolic_hess_value = Symbolics.fast_substitute(symbolic_hess, Dict(x[1] => prob[x[1]], x[2] => prob[x[2]])) + symbolic_hess_value = substitute(symbolic_hess, Dict(x[1] => prob[x[1]], x[2] => prob[x[2]])) oop_hess = prob.f.hess(prob.u0, prob.p) @test oop_hess ≈ symbolic_hess_value diff --git a/test/precompile_test.jl b/test/precompile_test.jl index 38051d9d49..6fdc1b5cbf 100644 --- a/test/precompile_test.jl +++ b/test/precompile_test.jl @@ -1,5 +1,6 @@ using Test using ModelingToolkit +using OrdinaryDiffEqDefault using Distributed @@ -38,3 +39,5 @@ ODEPrecompileTest.f_eval_bad(u, p, 0.1) @test parentmodule(typeof(ODEPrecompileTest.f_eval_good.f.f_oop)) == ODEPrecompileTest @test ODEPrecompileTest.f_eval_good(u, p, 0.1) == [4, 0, -16] + +@test_nowarn solve(ODEPrecompileTest.prob_eval) diff --git a/test/precompile_test/ODEPrecompileTest.jl b/test/precompile_test/ODEPrecompileTest.jl index 2111f7ba64..c77c897655 100644 --- a/test/precompile_test/ODEPrecompileTest.jl +++ b/test/precompile_test/ODEPrecompileTest.jl @@ -36,4 +36,24 @@ const f_eval_bad = system(; eval_expression = true, eval_module = @__MODULE__) # Change the module the eval'd function is eval'd into to be the containing module, # which should make it be in the package image const f_eval_good = system(; eval_expression = true, eval_module = @__MODULE__) + +function problem(; kwargs...) + # Define some variables + @independent_variables t + @parameters σ ρ β + @variables x(t) y(t) z(t) + D = Differential(t) + + # Define a differential equation + eqs = [D(x) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + + @named de = System(eqs, t) + de = complete(de) + return ODEProblem(de, [x => 1, y => 0, z => 0, σ => 10, ρ => 28, β => 8/3], (0.0, 5.0); kwargs...) +end + +const prob_eval = problem(; eval_expression = true, eval_module = @__MODULE__) + end diff --git a/test/sciml_problem_inputs.jl b/test/sciml_problem_inputs.jl index a91a8d8c7c..cdb83ad0e2 100644 --- a/test/sciml_problem_inputs.jl +++ b/test/sciml_problem_inputs.jl @@ -2,7 +2,7 @@ # Fetch packages using ModelingToolkit, JumpProcesses, NonlinearSolve, OrdinaryDiffEq, StaticArrays, - SteadyStateDiffEq, StochasticDiffEq, SciMLBase, Test + SteadyStateDiffEq, StochasticDiffEq, SciMLBase, Test, SymbolicUtils using ModelingToolkit: t_nounits as t, D_nounits as D # Sets rnd number. @@ -29,7 +29,7 @@ begin ] noise_eqs = fill(0.01, 3, 6) jumps = [ - MassActionJump(kp, Pair{Symbolics.BasicSymbolic{Real}, Int64}[], [X => 1]), + MassActionJump(kp, Pair{Symbolics.SymbolicT, Int64}[], [X => 1]), MassActionJump(kd, [X => 1], [X => -1]), MassActionJump(k1, [X => 1], [X => -1, Y => 1]), MassActionJump(k2, [Y => 1], [X => 1, Y => -1]), diff --git a/test/simplify.jl b/test/simplify.jl index 4252e3262e..6968883132 100644 --- a/test/simplify.jl +++ b/test/simplify.jl @@ -1,5 +1,7 @@ using ModelingToolkit using ModelingToolkit: value +using Symbolics: STerm +import SymbolicUtils using Test @independent_variables t @@ -11,7 +13,7 @@ null_op = 0 * t one_op = 1 * t @test isequal(simplify(one_op), t) -identity_op = Num(Term(identity, [value(x)])) +identity_op = Num(STerm(identity, [value(x)]; type = Real, shape = SymbolicUtils.ShapeVecT())) @test isequal(simplify(identity_op), x) minus_op = -x diff --git a/test/variable_parsing.jl b/test/variable_parsing.jl index 60b4e24d64..a00aa2f729 100644 --- a/test/variable_parsing.jl +++ b/test/variable_parsing.jl @@ -2,15 +2,16 @@ using ModelingToolkit using Test using ModelingToolkit: value, Flow -using SymbolicUtils: FnType +using Symbolics: SSym +using SymbolicUtils: FnType, ShapeVecT @independent_variables t @variables x(t) y(t) # test multi-arg @variables z(t) # test single-arg -x1 = Num(Sym{FnType{Tuple{Any}, Real}}(:x)(value(t))) -y1 = Num(Sym{FnType{Tuple{Any}, Real}}(:y)(value(t))) -z1 = Num(Sym{FnType{Tuple{Any}, Real}}(:z)(value(t))) +x1 = Num(SSym(:x; type = FnType{Tuple{Any}, Real}, shape = ShapeVecT())(value(t))) +y1 = Num(SSym(:y; type = FnType{Tuple{Any}, Real}, shape = ShapeVecT())(value(t))) +z1 = Num(SSym(:z; type = FnType{Tuple{Any}, Real}, shape = ShapeVecT())(value(t))) @test isequal(x1, x) @test isequal(y1, y) @@ -22,9 +23,9 @@ z1 = Num(Sym{FnType{Tuple{Any}, Real}}(:z)(value(t))) end @parameters σ(..) -t1 = Num(Sym{Real}(:t)) -s1 = Num(Sym{Real}(:s)) -σ1 = Num(Sym{FnType{Tuple, Real}}(:σ)) +t1 = Num(SSym(:t; type = Real, shape = ShapeVecT())) +s1 = Num(SSym(:s; type = Real, shape = ShapeVecT())) +σ1 = Num(SSym(:σ; type = FnType{Tuple, Real}, shape = ShapeVecT())) @test isequal(t1, t) @test isequal(s1, s) @test isequal(σ1(t), σ(t))