diff --git a/src/main/java/org/neuroml/export/neuron/NeuronWriter.java b/src/main/java/org/neuroml/export/neuron/NeuronWriter.java index 927e4be5..c6f4b65a 100644 --- a/src/main/java/org/neuroml/export/neuron/NeuronWriter.java +++ b/src/main/java/org/neuroml/export/neuron/NeuronWriter.java @@ -34,6 +34,7 @@ import org.lemsml.jlems.core.type.Property; import org.lemsml.jlems.core.type.Requirement; import org.lemsml.jlems.core.type.Target; +import org.lemsml.jlems.core.type.Meta; import org.lemsml.jlems.core.type.dynamics.Case; import org.lemsml.jlems.core.type.dynamics.ConditionalDerivedVariable; import org.lemsml.jlems.core.type.dynamics.DerivedVariable; @@ -329,7 +330,6 @@ public String getMainScript() throws GenerationException, NeuroMLException Target target = lems.getTarget(); - Component simCpt = target.getComponent(); String len = simCpt.getStringValue("length"); @@ -340,21 +340,49 @@ public String getMainScript() throws GenerationException, NeuroMLException len = "" + Float.parseFloat(len) * 1000; } - String dt = simCpt.getStringValue("step"); - dt = dt.replaceAll("ms", "").trim(); - if(dt.indexOf("s") > 0) + /* cvode usage: + * https://nrn.readthedocs.io/en/latest/hoc/simctrl/cvode.html + * - we do not currently support the local variable time step method + */ + boolean nrn_cvode = false; + String dt = "0.01"; + /* defaults from NEURON */ + String abs_tol = "None"; + String rel_tol = "None"; + LemsCollection metas = simCpt.metas; + for(Meta m : metas) { - dt = dt.replaceAll("s", "").trim(); - dt = "" + Float.parseFloat(dt) * 1000; + HashMap attributes = m.getAttributes(); + if (attributes.getOrDefault("for", "").equals("neuron")) + { + if (attributes.getOrDefault("method", "").equals("cvode")) + { + nrn_cvode = true; + abs_tol = attributes.getOrDefault("abs_tolerance", abs_tol); + rel_tol = attributes.getOrDefault("rel_tolerance", rel_tol); + E.info("CVode with abs_tol="+abs_tol+" , rel_tol="+rel_tol+" selected for NEURON simulation"); + } + } + } + if (nrn_cvode == false) + { + dt = simCpt.getStringValue("step"); + dt = dt.replaceAll("ms", "").trim(); + if(dt.indexOf("s") > 0) + { + dt = dt.replaceAll("s", "").trim(); + dt = "" + Float.parseFloat(dt) * 1000; + } + } main.append("class NeuronSimulation():\n\n"); int seed = DLemsWriter.DEFAULT_SEED; if (simCpt.hasStringValue("seed")) seed = Integer.parseInt(simCpt.getStringValue("seed")); - main.append(" def __init__(self, tstop, dt, seed="+seed+"):\n\n"); + main.append(" def __init__(self, tstop, dt=None, seed="+seed+", abs_tol=None, rel_tol=None):\n\n"); Component targetComp = simCpt.getRefComponents().get("target"); @@ -362,6 +390,8 @@ public String getMainScript() throws GenerationException, NeuroMLException main.append(bIndent+"print(\"\\n Starting simulation in NEURON of %sms generated from NeuroML2 model...\\n\"%tstop)\n\n"); main.append(bIndent+"self.setup_start = time.time()\n"); main.append(bIndent+"self.seed = seed\n"); + main.append(bIndent+"self.abs_tol = abs_tol\n"); + main.append(bIndent+"self.rel_tol = rel_tol\n"); if (target.reportFile!=null) { @@ -1306,8 +1336,14 @@ else if(cc.getComponentType().isOrExtends(NeuroMLElements.CONTINUOUS_CONNECTION_ main.append(toRec); main.append(bIndent+"h.tstop = tstop\n\n"); - main.append(bIndent+"h.dt = dt\n\n"); - main.append(bIndent+"h.steps_per_ms = 1/h.dt\n\n"); + main.append(bIndent+"if self.abs_tol is not None and self.rel_tol is not None:\n"); + main.append(bIndent+" cvode = h.CVode()\n"); + main.append(bIndent+" cvode.active(1)\n"); + main.append(bIndent+" cvode.atol(self.abs_tol)\n"); + main.append(bIndent+" cvode.rtol(self.rel_tol)\n"); + main.append(bIndent+"else:\n"); + main.append(bIndent+" h.dt = dt\n"); + main.append(bIndent+" h.steps_per_ms = 1/h.dt\n\n"); if(!nogui) { @@ -1350,7 +1386,8 @@ else if(cc.getComponentType().isOrExtends(NeuroMLElements.CONTINUOUS_CONNECTION_ columnsPre.get(timeRef).add(bIndent+"h(' objectvar v_" + timeRef + " ')"); columnsPre.get(timeRef).add(bIndent+"h(' { v_" + timeRef + " = new Vector() } ')"); columnsPre.get(timeRef).add(bIndent+"h(' { v_" + timeRef + ".record(&t) } ')"); - columnsPre.get(timeRef).add(bIndent+"h.v_" + timeRef + ".resize((h.tstop * h.steps_per_ms) + 1)"); + columnsPre.get(timeRef).add(bIndent+"if self.abs_tol is None or self.rel_tol is None:\n"); + columnsPre.get(timeRef).add(bIndent+" h.v_" + timeRef + ".resize((h.tstop * h.steps_per_ms) + 1)"); columnsPost0.get(timeRef).add(bIndent+"py_v_" + timeRef + " = [ t/1000 for t in h.v_" + timeRef + ".to_python() ] # Convert to Python list for speed..."); @@ -1410,7 +1447,8 @@ else if(cc.getComponentType().isOrExtends(NeuroMLElements.CONTINUOUS_CONNECTION_ columnsPre.get(outfileId).add(bIndent+"h(' objectvar v_" + colId + " ')"); columnsPre.get(outfileId).add(bIndent+"h(' { v_" + colId + " = new Vector() } ')"); columnsPre.get(outfileId).add(bIndent+"h(' { v_" + colId + ".record(&" + lqp.getNeuronVariableReference() + ") } ')"); - columnsPre.get(outfileId).add(bIndent+"h.v_" + colId + ".resize((h.tstop * h.steps_per_ms) + 1)"); + columnsPre.get(outfileId).add(bIndent+"if self.abs_tol is None or self.rel_tol is None:\n"); + columnsPre.get(outfileId).add(bIndent+" h.v_" + colId + ".resize((h.tstop * h.steps_per_ms) + 1)"); float conv = NRNUtils.getNeuronUnitFactor(lqp.getDimension().getName()); String factor = (conv == 1) ? "" : " / " + conv; @@ -1571,7 +1609,12 @@ else if (eofFormat.equals(EventWriter.FORMAT_ID_TIME)) main.append(bIndent+"self.initialized = True\n"); main.append(bIndent+"sim_start = time.time()\n"); - main.append(bIndent+"print(\"Running a simulation of %sms (dt = %sms; seed=%s)\" % (h.tstop, h.dt, self.seed))\n\n"); + + main.append(bIndent+"if self.abs_tol is not None and self.rel_tol is not None:\n"); + main.append(bIndent+" print(\"Running a simulation of %sms (cvode abs_tol = %sms, rel_tol = %sms; seed=%s)\" % (h.tstop, self.abs_tol, self.rel_tol, self.seed))\n"); + main.append(bIndent+"else:\n"); + main.append(bIndent+" print(\"Running a simulation of %sms (dt = %sms; seed=%s)\" % (h.tstop, h.dt, self.seed))\n\n"); + main.append(bIndent+"try:\n"); main.append(bIndent+" h.run()\n"); main.append(bIndent+"except Exception as e:\n"); @@ -1709,7 +1752,7 @@ else if (eofFormat.equals(EventWriter.FORMAT_ID_TIME)) main.append("if __name__ == '__main__':\n\n"); - main.append(" ns = NeuronSimulation(tstop="+len+", dt="+dt+", seed="+seed+")\n\n"); + main.append(" ns = NeuronSimulation(tstop="+len+", dt="+dt+", seed="+seed+", abs_tol="+abs_tol+", rel_tol="+rel_tol+")\n\n"); main.append(" ns.run()\n\n");