Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ dependencies = [
"pyyaml",
"xarray[io]",

# Distributed Computing
"dask[distributed]",
# Numerics
"numpy",
"scipy",
Expand Down
121 changes: 121 additions & 0 deletions tests/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from types import SimpleNamespace

import numpy as np
import pandas as pd
import pytest
import xarray as xr
from hypothesis import given
from hypothesis import strategies as st

Expand Down Expand Up @@ -169,3 +171,122 @@ def test_create_hf_dataset_structure() -> None:
assert ds.attrs["dt"] == dt
assert ds.attrs["nt"] == n_time
assert ds.attrs["units"] == "cm/s^2"


def test_load_hf_dataset(tmp_path: Path) -> None:
    """load_hf_dataset reads a .ll station file into an xarray Dataset.

    Writes a three-station `lon lat name` file, loads it through
    hf_sim.load_hf_dataset with minimal stand-in config objects, and
    checks the per-station variables and the global attributes.
    """
    stations_path = tmp_path / "stations.ll"
    stations_path.write_text(
        "172.6 -43.5 STAT_A\n172.7 -43.6 STAT_B\n172.8 -43.7 STAT_C\n"
    )

    # SimpleNamespace stands in for the project config dataclasses; the
    # `type: ignore` comments acknowledge the intentional type mismatch.
    dataset = hf_sim.load_hf_dataset(
        stations_path,
        SimpleNamespace(hf_seed=42),  # type: ignore[arg-type]
        Resolution(resolution=0.1),
        SimpleNamespace(t_sec=0.0),  # type: ignore[arg-type]
        SimpleNamespace(model=pd.DataFrame({"Vs": [0.5]})),  # type: ignore[arg-type]
        SimpleNamespace(duration=100.0),  # type: ignore[arg-type]
    )

    # Vs of 0.5 km/s in the velocity model becomes a vref of 500 m/s.
    names = ["STAT_A", "STAT_B", "STAT_C"]
    reference = xr.Dataset(
        {
            "latitude": ("station", [-43.5, -43.6, -43.7]),
            "longitude": ("station", [172.6, 172.7, 172.8]),
            "vref": ("station", [500.0] * 3),
        },
        coords={"station": ("station", names)},
    )
    xr.testing.assert_allclose(
        dataset[["latitude", "longitude", "vref"]], reference
    )

    assert "seed" in dataset.data_vars
    assert dataset.attrs["nt"] > 0
    assert dataset.attrs["dt"] == pytest.approx(0.005)
    assert dataset.attrs["start_sec"] == 0.0


def test_load_hf_dataset_chunking(tmp_path: Path) -> None:
    """With many stations, the dataset is dask-chunked along the station dim.

    1500 stations exercises the chunking logic: the expected chunk size is
    max(1, 1500 // 500) = 3, so no station chunk may exceed 3 and the chunks
    must cover every station exactly once.
    """
    n_stations = 1500
    stations_path = tmp_path / "stations.ll"
    stations_path.write_text(
        "".join(
            f"172.{i % 10} -43.{i % 10} STAT_{i:04d}\n" for i in range(n_stations)
        )
    )

    dataset = hf_sim.load_hf_dataset(
        stations_path,
        SimpleNamespace(hf_seed=42),  # type: ignore[arg-type]
        Resolution(resolution=0.1),
        SimpleNamespace(t_sec=0.0),  # type: ignore[arg-type]
        SimpleNamespace(model=pd.DataFrame({"Vs": [0.5]})),  # type: ignore[arg-type]
        SimpleNamespace(duration=10.0),  # type: ignore[arg-type]
    )

    # chunk_size = max(1, 1500 // 500) = 3
    assert dataset.chunks is not None
    station_chunks = dataset.chunks["station"]
    assert not any(size > 3 for size in station_chunks)
    assert sum(station_chunks) == n_stations


def test_process_hf_dataset_structure() -> None:
    """process_hf_dataset assembles per-station results into the output layout.

    The HF simulation binary cannot run in the test environment, so
    hf_simulate_station is mocked to return canned
    (name, epicentre_distance, waveform) tuples; only the structure of the
    resulting Dataset is verified.
    """
    from unittest import mock

    n_time = 100
    timestep = 0.02
    names = np.array(["STAT_A", "STAT_B"])

    source = xr.Dataset(
        {
            "latitude": ("station", np.array([-43.5, -43.6])),
            "longitude": ("station", np.array([172.6, 172.7])),
            "seed": ("station", np.array([123, 456])),
            "vref": ("station", np.array([500.0, 500.0])),
        },
        coords={"station": ("station", names)},
        attrs={"nt": n_time, "dt": timestep, "start_sec": 0.0},
    )

    # One (nt, 3) waveform reused for both stations — values are irrelevant,
    # only the array shape matters for the structural checks below.
    fake_waveform = np.random.rand(n_time, 3).astype(np.float32)
    patched = mock.patch.object(
        hf_sim,
        "hf_simulate_station",
        side_effect=[
            ("STAT_A", 10.5, fake_waveform),
            ("STAT_B", 20.1, fake_waveform),
        ],
    )
    with patched:
        result = hf_sim.process_hf_dataset(
            source,
            hf_sim_path="/fake/path",
            hf_input_template="template",
        )

    assert "waveform" in result.data_vars
    assert "epicentre_distance" in result.data_vars
    assert result["waveform"].dims == ("component", "station", "time")
    assert result.sizes == {"component": 3, "station": 2, "time": n_time}
    assert result["epicentre_distance"].dims == ("station",)
    xr.testing.assert_equal(result.station, source.station)
    expected_components = xr.DataArray(
        ["x", "y", "z"], coords={"component": ["x", "y", "z"]}, dims="component"
    )
    xr.testing.assert_equal(result.component, expected_components)
Loading