Skip to content
54 changes: 54 additions & 0 deletions src/scripts/dev/scenarios/suspend-resume-test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
This file defines a small scenario (population of 1,000; two draws with two runs each) using the full model.
It's used for testing the suspending and resuming of simulations.

Run on the batch system using:
```tlo batch-submit src/scripts/dev/scenarios/suspend-resume-test.py```

or locally using:
```tlo scenario-run src/scripts/dev/scenarios/suspend-resume-test.py```

"""

from tlo import Date, logging
from tlo.analysis.utils import get_parameters_for_status_quo
from tlo.methods.fullmodel import fullmodel
from tlo.scenario import BaseScenario


class SuspendResumeTest(BaseScenario):
    """Small full-model scenario: 1,000 individuals, two draws, two runs per draw.

    The simulation period is configured as 2010-01-01 to 2012-01-01, all
    modules come from ``fullmodel()`` and every draw uses the parameters
    returned by ``get_parameters_for_status_quo()``.
    """

    def __init__(self):
        """Set the seed, simulation period, population size and draw/run counts."""
        super().__init__()
        self.seed = 0
        # The simulation will stop before reaching the end date.
        self.start_date, self.end_date = Date(2010, 1, 1), Date(2012, 1, 1)
        self.pop_size = 1000
        self.number_of_draws = self.runs_per_draw = 2

    def log_configuration(self):
        """Return the logging configuration dictionary for this scenario.

        The ``'*'`` entry is set to WARNING; selected modules are set to INFO.
        """
        info, warning = logging.INFO, logging.WARNING
        levels_by_logger = {
            "*": warning,
            "tlo.methods.demography": info,
            "tlo.methods.demography.detail": warning,
            "tlo.methods.healthburden": info,
            "tlo.methods.healthsystem": info,
            "tlo.methods.healthsystem.summary": info,
            "tlo.methods.contraception": info,
        }
        # 'filename' and 'directory' are specified only for local running.
        return dict(
            filename="suspend_resume_test",
            directory="./outputs",
            custom_levels=levels_by_logger,
        )

    def modules(self):
        """Return the modules to register: whatever ``fullmodel()`` provides."""
        registered = fullmodel()
        return registered

    def draw_parameters(self, draw_number, rng):
        """Return the parameters for a draw.

        ``draw_number`` and ``rng`` are not used: every draw receives the
        result of ``get_parameters_for_status_quo()``.
        """
        status_quo = get_parameters_for_status_quo()
        return status_quo


# When executed directly as a script, hand this file's path to the tlo CLI's
# scenario_run command.
if __name__ == "__main__":
    from tlo.cli import scenario_run

    this_script = __file__
    scenario_run([this_script])
29 changes: 21 additions & 8 deletions src/tlo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,16 @@ def parse_log(log_directory):
with open(path / f"{key}.pickle", "wb") as f:
pickle.dump(output, f)

@cli.command()
@cli.command(context_settings=dict(ignore_unknown_options=True))
@click.argument("scenario_file", type=click.Path(exists=True))
@click.option("--asserts-on", type=bool, default=False, is_flag=True, help="Enable assertions in simulation run.")
@click.option("--more-memory", type=bool, default=False, is_flag=True,
help="Request machine wth more memory (for larger population sizes).")
@click.option("--image-tag", type=str, help="Tag of the Docker image to use.")
@click.option("--keep-pool-alive", type=bool, default=False, is_flag=True, hidden=True)
@click.argument('scenario_args', nargs=-1, type=click.UNPROCESSED)
@click.pass_context
def batch_submit(ctx, scenario_file, asserts_on, more_memory, keep_pool_alive, image_tag=None):
def batch_submit(ctx, scenario_file, asserts_on, more_memory, keep_pool_alive, image_tag=None, scenario_args=None):
"""Submit a scenario to the batch system.

SCENARIO_FILE is path to file containing scenario class.
Expand All @@ -132,6 +133,24 @@ def batch_submit(ctx, scenario_file, asserts_on, more_memory, keep_pool_alive, i

scenario = load_scenario(scenario_file)

config = load_config(ctx.obj['config_file'])

# Directory where the file share will be mounted, relative to
# ${AZ_BATCH_NODE_MOUNTS_DIR}.
file_share_mount_point = "mnt"

# if we have other scenario arguments, parse them
if scenario_args is not None:
# we rewrite the path to the simulation to resume
if '--resume-simulation' in scenario_args:
i = scenario_args.index('--resume-simulation')
path_to_job = (f"${{AZ_BATCH_NODE_MOUNTS_DIR}}/"
f"{file_share_mount_point}/"
f"{config['DEFAULT']['USERNAME']}/"
f"{scenario_args[i+1]}")
scenario_args = scenario_args[:i + 1] + (path_to_job, ) + scenario_args[i + 2:]
scenario.parse_arguments(scenario_args)

# get the commit we're going to submit to run on batch, and save the run config for that commit
# it's the most recent commit on current branch
repo = Repo(".")
Expand All @@ -140,8 +159,6 @@ def batch_submit(ctx, scenario_file, asserts_on, more_memory, keep_pool_alive, i

print(">Setting up batch\r", end="")

config = load_config(ctx.obj['config_file'])

# ID of the Batch job.
timestamp = datetime.datetime.utcnow().strftime("%Y-%m-%dT%H%M%SZ")
job_id = scenario.get_log_config()["filename"] + "-" + timestamp
Expand Down Expand Up @@ -221,10 +238,6 @@ def batch_submit(ctx, scenario_file, asserts_on, more_memory, keep_pool_alive, i
# Options for running the Docker container
container_run_options = "--rm --workdir /TLOmodel"

# Directory where the file share will be mounted, relative to
# ${AZ_BATCH_NODE_MOUNTS_DIR}.
file_share_mount_point = "mnt"

azure_file_share_configuration = batch_models.AzureFileShareConfiguration(
account_name=config["STORAGE"]["NAME"],
azure_file_url=azure_file_url,
Expand Down
49 changes: 31 additions & 18 deletions src/tlo/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def draw_parameters(self, draw_number, rng):
import argparse
import datetime
import json
import os
from collections.abc import Iterable
from itertools import product
from pathlib import Path, PurePosixPath
Expand Down Expand Up @@ -160,7 +161,7 @@ def parse_arguments(self, extra_arguments: List[str]) -> None:
for key, value in vars(arguments).items():
if value is not None:
if hasattr(self, key):
logger.info(key="message", data=f"Overriding attribute: {key}: {getattr(self, key)} -> {value}")
print(f"Overriding attribute with argument value: {key}: {getattr(self, key)} -> {value}")
setattr(self, key, value)

def add_arguments(self, parser: argparse.ArgumentParser) -> None:
Expand Down Expand Up @@ -359,8 +360,8 @@ def __init__(self, run_configuration_path):
self.scenario = ScenarioLoader(self.run_config["scenario_script_path"]).get_scenario()
if self.run_config["arguments"] is not None:
self.scenario.parse_arguments(self.run_config["arguments"])
logger.info(key="message", data=f"Loaded scenario using {run_configuration_path}")
logger.info(key="message", data=f"Found {self.number_of_draws} draws; {self.runs_per_draw} runs/draw")
print(f"Loaded scenario from config at {run_configuration_path}")
print(f"Found {self.number_of_draws} draws with {self.runs_per_draw} runs-per-draw.")

@property
def number_of_draws(self):
Expand Down Expand Up @@ -397,27 +398,38 @@ def run_sample_by_number(self, output_directory, draw_number, sample_number):
sample = self.get_sample(draw, sample_number)
log_config = self.scenario.get_log_config(output_directory)

logger.info(
key="message",
data=f"Running draw {sample['draw_number']}, sample {sample['sample_number']}",
)
print(f"Running draw {sample['draw_number']}, run {sample['sample_number']}.")

# if user has specified a restore simulation, we load it from a pickle file
if (
hasattr(self.scenario, "resume_simulation")
and self.scenario.resume_simulation is not None
):
suspended_simulation_path = (
Path(self.scenario.resume_simulation)
/ str(draw_number)
/ str(sample_number)
/ "suspended_simulation.pickle"
)
# expand any environment variables in the path
if "$" in self.scenario.resume_simulation:
self.scenario.resume_simulation = os.path.expandvars(self.scenario.resume_simulation)

suspended_simulation_path = Path(self.scenario.resume_simulation)

# if the resume_simulation doesn't end with a draw number, we are resuming all draws
last_component = self.scenario.resume_simulation.rstrip("/").split("/")[-1]
try:
int(last_component)
except ValueError:
suspended_simulation_path = suspended_simulation_path / str(draw_number)

suspended_simulation_path = suspended_simulation_path / str(sample_number) / "suspended_simulation.pickle"

sim = Simulation.load_from_pickle(pickle_path=suspended_simulation_path, log_config=log_config)

logger.info(
key="message",
data=f"Loading pickled suspended simulation from {suspended_simulation_path}",
data=f"Loading suspended simulation from {suspended_simulation_path}",
)
sim = Simulation.load_from_pickle(pickle_path=suspended_simulation_path, log_config=log_config)

# if parameters are specified, we override them
if sample["parameters"] is not None:
self.override_parameters(sim, sample["parameters"])
else:
sim = Simulation(
start_date=self.scenario.start_date,
Expand All @@ -442,12 +454,13 @@ def run_sample_by_number(self, output_directory, draw_number, sample_number):
):
sim.run_simulation_to(to_date=self.scenario.suspend_date)
suspended_simulation_path = Path(log_config["directory"]) / "suspended_simulation.pickle"
sim.save_to_pickle(pickle_path=suspended_simulation_path)
sim.close_output_file()
logger.info(
key="message",
data=f"Simulation suspended at {self.scenario.suspend_date} and saved to {suspended_simulation_path}",
data=f"Suspending simulation at {self.scenario.suspend_date} and saving to {suspended_simulation_path}."
f" Note, output file handle will be closed first and no more output logged",
)
sim.close_output_file()
sim.save_to_pickle(pickle_path=suspended_simulation_path)
else:
sim.run_simulation_to(to_date=self.scenario.end_date)
sim.finalise()
Expand Down