diff --git a/pyproject.toml b/pyproject.toml index fdd6492..94e632d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,8 @@ dependencies = [ "pyyaml", "xarray[io]", + # Distributed Computing + "dask[distributed]", # Numerics "numpy", "scipy", diff --git a/tests/test_hf.py b/tests/test_hf.py index 7d72c8b..561ac1d 100644 --- a/tests/test_hf.py +++ b/tests/test_hf.py @@ -2,7 +2,9 @@ from types import SimpleNamespace import numpy as np +import pandas as pd import pytest +import xarray as xr from hypothesis import given from hypothesis import strategies as st @@ -169,3 +171,122 @@ def test_create_hf_dataset_structure() -> None: assert ds.attrs["dt"] == dt assert ds.attrs["nt"] == n_time assert ds.attrs["units"] == "cm/s^2" + + +def test_load_hf_dataset(tmp_path: Path) -> None: + station_file = tmp_path / "stations.ll" + station_file.write_text("172.6 -43.5 STAT_A\n172.7 -43.6 STAT_B\n172.8 -43.7 STAT_C\n") + + seeds = SimpleNamespace(hf_seed=42) + resolution = Resolution(resolution=0.1) + hf_config = SimpleNamespace(t_sec=0.0) + velocity_model = SimpleNamespace( + model=pd.DataFrame({"Vs": [0.5]}), + ) + domain_parameters = SimpleNamespace(duration=100.0) + + ds = hf_sim.load_hf_dataset( + station_file, + seeds, # type: ignore[arg-type] + resolution, + hf_config, # type: ignore[arg-type] + velocity_model, # type: ignore[arg-type] + domain_parameters, # type: ignore[arg-type] + ) + + station_names = ["STAT_A", "STAT_B", "STAT_C"] + expected = xr.Dataset( + { + "latitude": ("station", [-43.5, -43.6, -43.7]), + "longitude": ("station", [172.6, 172.7, 172.8]), + "vref": ("station", [500.0, 500.0, 500.0]), + }, + coords={"station": ("station", station_names)}, + ) + xr.testing.assert_allclose( + ds[["latitude", "longitude", "vref"]], expected + ) + + assert "seed" in ds.data_vars + assert ds.attrs["nt"] > 0 + assert ds.attrs["dt"] == pytest.approx(0.005) + assert ds.attrs["start_sec"] == 0.0 + + +def test_load_hf_dataset_chunking(tmp_path: Path) -> None: + # Create a station file with 1500 stations to test chunking logic + station_file = tmp_path / "stations.ll" + lines = [f"172.{i % 10} -43.{i % 10} STAT_{i:04d}" for i in range(1500)] + station_file.write_text("\n".join(lines) + "\n") + + seeds = SimpleNamespace(hf_seed=42) + resolution = Resolution(resolution=0.1) + hf_config = SimpleNamespace(t_sec=0.0) + velocity_model = SimpleNamespace( + model=pd.DataFrame({"Vs": [0.5]}), + ) + domain_parameters = SimpleNamespace(duration=10.0) + + ds = hf_sim.load_hf_dataset( + station_file, + seeds, # type: ignore[arg-type] + resolution, + hf_config, # type: ignore[arg-type] + velocity_model, # type: ignore[arg-type] + domain_parameters, # type: ignore[arg-type] + ) + + # chunk_size = max(1, 1500 // 500) = 3 + assert ds.chunks is not None + station_chunks = ds.chunks["station"] + assert all(c <= 3 for c in station_chunks) + assert sum(station_chunks) == 1500 + + +def test_process_hf_dataset_structure() -> None: + nt = 100 + dt = 0.02 + station_names = np.array(["STAT_A", "STAT_B"]) + + input_ds = xr.Dataset( + { + "latitude": ("station", np.array([-43.5, -43.6])), + "longitude": ("station", np.array([172.6, 172.7])), + "seed": ("station", np.array([123, 456])), + "vref": ("station", np.array([500.0, 500.0])), + }, + coords={"station": ("station", station_names)}, + attrs={"nt": nt, "dt": dt, "start_sec": 0.0}, + ) + + # We can't actually run the binary, but we can test that the function + # signature is correct and the output structure is as expected by + # mocking hf_simulate_station + import unittest.mock as mock + + mock_waveform = np.random.rand(nt, 3).astype(np.float32) + + with mock.patch.object( + hf_sim, + "hf_simulate_station", + side_effect=[ + ("STAT_A", 10.5, mock_waveform), + ("STAT_B", 20.1, mock_waveform), + ], + ): + result = hf_sim.process_hf_dataset( + input_ds, + hf_sim_path="/fake/path", + hf_input_template="template", + ) + + assert "waveform" in result.data_vars + assert "epicentre_distance" in result.data_vars + assert result["waveform"].dims == ("component", "station", "time") + assert result.sizes == {"component": 3, "station": 2, "time": nt} + assert result["epicentre_distance"].dims == ("station",) + xr.testing.assert_equal(result.station, input_ds.station) + expected_components = xr.DataArray( + ["x", "y", "z"], coords={"component": ["x", "y", "z"]}, dims="component" + ) + xr.testing.assert_equal(result.component, expected_components) diff --git a/workflow/scripts/hf_sim.py b/workflow/scripts/hf_sim.py index 9904800..8f672d6 100644 --- a/workflow/scripts/hf_sim.py +++ b/workflow/scripts/hf_sim.py @@ -23,6 +23,10 @@ > [!NOTE] > The high-frequency code is very brittle. It is recommended you have both versions 6.0.3 and 5.4.5 built to run with. Sometimes it is necessary to switch between versions if one does not work. +> [!NOTE] +> Dask worker memory limits should account for the external binary's footprint, +> as `hb_high_binmod` runs as a subprocess and its memory usage is not tracked by Dask. + Usage ----- `hf-sim [OPTIONS] REALISATION_FFP STOCH_FFP STATION_FILE OUT_FILE` @@ -32,19 +36,19 @@ See the output of `hf-sim --help`. """ -import concurrent.futures import subprocess import tempfile from collections.abc import Iterable -from concurrent.futures.thread import ThreadPoolExecutor from pathlib import Path from typing import Annotated +import dask.array as da import numpy as np import numpy.typing as npt import pandas as pd import typer import xarray as xr +from dask.distributed import Client, LocalCluster from qcore import cli from workflow import log_utils, realisations, utils @@ -60,6 +64,15 @@ app = typer.Typer() +_TARGET_TASK_COUNT = 500 +"""Target number of Dask tasks for station chunking. + +Chosen to balance scheduler overhead against parallelism: too few tasks +under-utilise workers, while too many (e.g. one per station at 100 k+) +flood the scheduler with graph overhead. A value of 500–1 000 keeps the +task graph manageable while still saturating a large cluster. +""" + def rupture_velocity_hf_transition_bands( rupture_velocity: RuptureVelocity, @@ -272,6 +285,146 @@ def station_seeds(seed: int, stations: Iterable[str]) -> npt.NDArray[np.int32]: return np.int32(seed) ^ station_hashes +def load_hf_dataset( + station_file: Path, + seeds: Seeds, + resolution: Resolution, + hf_config: HFConfig, + velocity_model: HFVelocityModel1D, + domain_parameters: DomainParameters, +) -> xr.Dataset: + """Load station data into a chunked xarray Dataset for distributed processing. + + Parameters + ---------- + station_file : Path + Path to station CSV file (columns: longitude, latitude, name). + seeds : Seeds + Seed configuration for the simulation. + resolution : Resolution + HF simulation resolution. + hf_config : HFConfig + The high-frequency config. + velocity_model : HFVelocityModel1D + The 1D velocity model. + domain_parameters : DomainParameters + The simulation domain parameters. + + Returns + ------- + xr.Dataset + A dataset indexed by ``station`` with variables for latitude, + longitude, seed, and vref, chunked for distributed processing. + """ + stations = pd.read_csv( + station_file, + delimiter=r"\s+", + header=None, + names=["longitude", "latitude", "name"], + ).set_index("name") + + seeds_array = station_seeds(seeds.hf_seed, stations.index) + vs = velocity_model.model["Vs"].iloc[0] * 1000 + vref = np.full(len(stations), vs, dtype=np.float64) + + nt = int( + np.float32(domain_parameters.duration) / np.float32(resolution.dt) + ) + + total_stations = len(stations) + chunk_size = max(1, total_stations // _TARGET_TASK_COUNT) + + ds = xr.Dataset( + { + "latitude": ("station", stations["latitude"].values), + "longitude": ("station", stations["longitude"].values), + "seed": ("station", seeds_array), + "vref": ("station", vref), + }, + coords={ + "station": ("station", stations.index.values.astype(str)), + }, + attrs={ + "nt": nt, + "dt": resolution.dt, + "start_sec": hf_config.t_sec, + }, + ) + return ds.chunk({"station": chunk_size}) + + +def process_hf_dataset( + ds: xr.Dataset, + *, + hf_sim_path: str, + hf_input_template: str, +) -> xr.Dataset: + """Process a chunk of the HF dataset by running station simulations. + + Designed to be used with :func:`xarray.map_blocks`. Iterates over the + stations in the chunk and executes ``hf_simulate_station`` for each one. + + Parameters + ---------- + ds : xr.Dataset + A chunk of the input dataset with ``latitude``, ``longitude``, + and ``seed`` variables indexed by ``station``. + hf_sim_path : str + Path to the HF simulation binary (passed as string for + serialization). + hf_input_template : str + The stdin input template for the HF simulation binary. + + Returns + ------- + xr.Dataset + A dataset containing ``waveform`` (dims: component, station, time) + and ``epicentre_distance`` (dims: station) for the stations in the + chunk. + """ + station_names = ds.station.values + n_stations = len(station_names) + nt = ds.attrs["nt"] + + waveform = np.empty((3, n_stations, nt), dtype=np.float32) + epicentre_distances = np.empty(n_stations, dtype=np.float64) + + for i, station_name in enumerate(station_names): + lat = float(ds["latitude"].values[i]) + lon = float(ds["longitude"].values[i]) + seed = int(ds["seed"].values[i]) + + _, epicentre, station_waveform = hf_simulate_station( + Path(hf_sim_path), + hf_input_template, + lat, + lon, + str(station_name), + seed, + ) + epicentre_distances[i] = epicentre + for component in range(3): + waveform[component, i] = station_waveform[:, component] + + dt = ds.attrs["dt"] + time_coords = np.arange(nt) * dt + + return xr.Dataset( + { + "waveform": ( + ["component", "station", "time"], + waveform, + ), + "epicentre_distance": (["station"], epicentre_distances), + }, + coords={ + "station": ("station", station_names), + "component": ("component", ["x", "y", "z"]), + "time": ("time", time_coords), + }, + ) + + def create_hf_dataset( # array-like used here to reduce the number of times we have to # change the types if the downstream function inputs change. @@ -374,10 +527,10 @@ def run_hf( This function performs the following steps: 1. Reads configuration and domain parameters from the realisation file. - 2. Filters stations based on their location relative to the domain. - 3. Uses multiprocessing to simulate each station and calculate epicentre distances. - 4. Reads the velocity model and calculates the `vs` value. - 5. Writes the HF output file, including header and station-specific data. + 2. Loads station data into a chunked xarray Dataset. + 3. Uses Dask to lazily simulate each station chunk in parallel. + 4. Writes the HF output file in NetCDF format chunk-by-chunk to + support larger-than-memory datasets. Parameters ---------- @@ -416,19 +569,8 @@ def run_hf( realisation_ffp, metadata.defaults_version ) - stations = pd.read_csv( - station_file, - delimiter=r"\s+", - header=None, - names=["longitude", "latitude", "name"], - ).set_index("name") - stations["seed"] = station_seeds(seeds.hf_seed, stations.index) velocity_model_path = work_directory / "velocity_model" velocity_model.write_velocity_model(velocity_model_path) - nt = int( - np.float32(domain_parameters.duration) / np.float32(resolution.dt) - ) # Match Fortran's single-precision for consistent nt calculation - waveform = np.empty((3, len(stations), nt), dtype=np.float32) hf_input_template = build_hf_input( stoch_ffp, @@ -438,43 +580,76 @@ def run_hf( rupture_velocity, domain_parameters, ) - stations["epicentre_distance"] = np.nan - - with ThreadPoolExecutor(max_workers=utils.get_available_cores()) as executor: - station_index = {station: i for i, station in enumerate(stations.index)} - futures = [ - executor.submit( - hf_simulate_station, - hf_sim_path, - hf_input_template, - station["latitude"], - station["longitude"], - str(name), - int(station["seed"]), - ) - for name, station in stations.iterrows() - ] - for future in concurrent.futures.as_completed(futures): - station, epicentre, station_waveform = future.result() - stations.loc[station, "epicentre_distance"] = epicentre - i = station_index[station] - for component in range(3): - waveform[component, i] = station_waveform[:, component] + input_ds = load_hf_dataset( + station_file, + seeds, + resolution, + hf_config, + velocity_model, + domain_parameters, + ) - vs = velocity_model.model["Vs"].iloc[0] * 1000 - stations["vs"] = vs - - ds = create_hf_dataset( - waveform=waveform, - latitude=stations["latitude"], - longitude=stations["longitude"], - names=stations.index, - epicentre_distance=stations["epicentre_distance"], - seed=stations["seed"], - vref=stations["vs"], - dt=resolution.dt, - start_sec=hf_config.t_sec, + nt = input_ds.attrs["nt"] + dt = input_ds.attrs["dt"] + # Station coordinate labels are always in-memory in xarray (not + # dask-backed), so accessing .values here is safe and necessary to + # construct the template. + station_names = input_ds.station.values + n_stations = len(station_names) + time_coords = np.arange(nt) * dt + + # Use dask.array.empty so the template does not allocate memory for + # the full waveform array (which can be 100 GB+ for large runs). + template = xr.Dataset( + { + "waveform": ( + ["component", "station", "time"], + da.empty((3, n_stations, nt), dtype=np.float32), + ), + "epicentre_distance": ( + ["station"], + da.empty(n_stations, dtype=np.float64), + ), + }, + coords={ + "station": ("station", station_names), + "component": ("component", ["x", "y", "z"]), + "time": ("time", time_coords), + }, ) - ds.to_netcdf(out_file, engine="h5netcdf") + + with LocalCluster() as cluster, Client(cluster): + result_ds = xr.map_blocks( + process_hf_dataset, + input_ds, + template=template, + kwargs={ + "hf_sim_path": str(hf_sim_path), + "hf_input_template": hf_input_template, + }, + ) + + # Attach station metadata from the input dataset. These are + # small 1-D arrays (one value per station) so they do not + # contribute to memory pressure. + result_ds["seed"] = input_ds["seed"] + result_ds["vref"] = input_ds["vref"] + result_ds = result_ds.assign_coords( + lat=input_ds["latitude"], + lon=input_ds["longitude"], + ) + result_ds.attrs.update( + { + "start_sec": hf_config.t_sec, + "nt": nt, + "dt": dt, + "units": "cm/s^2", + } + ) + + # Write lazily — Dask streams chunks to disk one at a time so + # the full waveform array never needs to reside in memory. + result_ds.to_netcdf(out_file, engine="h5netcdf") + realisations.append_log_entry(realisation_ffp)