Skip to content
4 changes: 3 additions & 1 deletion kwave/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def _make_binary_executable(self):
binary_path.chmod(binary_path.stat().st_mode | stat.S_IEXEC)

def run_simulation(self, input_filename: str, output_filename: str, options: list[str]) -> dotdict:
# Validate execution options before running simulation
self.execution_options.validate()

command = [str(self.execution_options.binary_path), "-i", input_filename, "-o", output_filename] + options

try:
Expand Down Expand Up @@ -116,7 +119,6 @@ def parse_executable_output(output_filename: str) -> dotdict:
# # Combine the sensor data if using a kWaveTransducer as a sensor
# if isinstance(sensor, kWaveTransducer):
# sensor_data['p'] = sensor.combine_sensor_data(sensor_data['p'])

# # Compute the intensity outputs
# if any(key.startswith(('I_avg', 'I')) for key in sensor.get('record', [])):
# flags = {
Expand Down
35 changes: 30 additions & 5 deletions kwave/options/simulation_execution_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,34 @@ def __init__(
self.checkpoint_timesteps = checkpoint_timesteps
self.checkpoint_file = checkpoint_file

if self.checkpoint_file is not None:
if self.checkpoint_interval is None and self.checkpoint_timesteps is None:
raise ValueError("One of checkpoint_interval or checkpoint_timesteps must be set when checkpoint_file is set.")
self.validate()

def _validate_checkpoint_options(self):
# Checkpointing parameters are set
if self.checkpoint_interval is not None or self.checkpoint_timesteps is not None:
# No checkpoint file set
if self.checkpoint_file is None:
raise ValueError("`checkpoint_file` must be set when `checkpoint_interval` or `checkpoint_timesteps` is set.")
# Both checkpointing parameters are set
if self.checkpoint_interval is not None and self.checkpoint_timesteps is not None:
raise ValueError("`checkpoint_interval` and `checkpoint_timesteps` cannot be set at the same time.")
# Checkpoint file is set but no checkpointing parameters are set
if self.checkpoint_file is not None and self.checkpoint_interval is None and self.checkpoint_timesteps is None:
raise ValueError("`checkpoint_interval` or `checkpoint_timesteps` must be set when `checkpoint_file` is set.")
Comment on lines +61 to +67
Copy link

Copilot AI May 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The backticks in the exception message are literal characters; for consistency with other validators you may want to remove them or switch to quotes.

Suggested change
raise ValueError("`checkpoint_file` must be set when `checkpoint_interval` or `checkpoint_timesteps` is set.")
# Both checkpointing parameters are set
if self.checkpoint_interval is not None and self.checkpoint_timesteps is not None:
raise ValueError("`checkpoint_interval` and `checkpoint_timesteps` cannot be set at the same time.")
# Checkpoint file is set but no checkpointing parameters are set
if self.checkpoint_file is not None and self.checkpoint_interval is None and self.checkpoint_timesteps is None:
raise ValueError("`checkpoint_interval` or `checkpoint_timesteps` must be set when `checkpoint_file` is set.")
raise ValueError("'checkpoint_file' must be set when 'checkpoint_interval' or 'checkpoint_timesteps' is set.")
# Both checkpointing parameters are set
if self.checkpoint_interval is not None and self.checkpoint_timesteps is not None:
raise ValueError("'checkpoint_interval' and 'checkpoint_timesteps' cannot be set at the same time.")
# Checkpoint file is set but no checkpointing parameters are set
if self.checkpoint_file is not None and self.checkpoint_interval is None and self.checkpoint_timesteps is None:
raise ValueError("'checkpoint_interval' or 'checkpoint_timesteps' must be set when 'checkpoint_file' is set.")

Copilot uses AI. Check for mistakes.

def validate(self):
"""Validate all simulation options before running a simulation.

This method should be called before running a simulation to ensure all options
are in a valid state. It validates:
1. Checkpoint configuration (if any checkpoint options are set)

Raises:
ValueError: If any option configuration is invalid
"""
# Validate checkpoint configuration if any checkpoint options are set
if any(x is not None for x in [self.checkpoint_interval, self.checkpoint_timesteps, self.checkpoint_file]):
self._validate_checkpoint_options()

@property
def num_threads(self) -> Union[int, str]:
Expand Down Expand Up @@ -184,8 +209,8 @@ def checkpoint_interval(self) -> Optional[int]:
@checkpoint_interval.setter
def checkpoint_interval(self, value: Optional[int]):
if value is not None:
if not isinstance(value, int) or value < 0:
raise ValueError("Checkpoint interval must be a positive integer")
if not isinstance(value, int) or value <= 0:
raise ValueError("Checkpoint interval must be a positive integer in seconds")
self._checkpoint_interval = value

@property
Expand Down
126 changes: 89 additions & 37 deletions tests/test_simulation_execution_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,43 +266,20 @@ def test_as_list_with_valid_checkpoint_timesteps_options(self):

def test_initialization_with_invalid_checkpoint_options(self):
"""Test initialization with invalid checkpoint options."""

# Test with valid checkpoint_file but unset checkpoint_timesteps and checkpoint_interval
with TemporaryDirectory() as temp_dir:
checkpoint_file = Path(temp_dir) / "checkpoint.h5"
with self.assertRaises(ValueError):
SimulationExecutionOptions(checkpoint_file=checkpoint_file)

# Test with invalid checkpoint_file type
with self.assertRaises(ValueError):
SimulationExecutionOptions(checkpoint_file=12345, checkpoint_timesteps=10)

# Test with invalid checkpoint_timesteps
with self.assertRaises(ValueError):
SimulationExecutionOptions(checkpoint_timesteps=-1, checkpoint_file="checkpoint.h5")

with self.assertRaises(ValueError):
SimulationExecutionOptions(checkpoint_timesteps="not_an_integer", checkpoint_file="checkpoint.h5")

# Test with invalid checkpoint_interval
with self.assertRaises(ValueError):
SimulationExecutionOptions(checkpoint_interval=-1, checkpoint_file="checkpoint.h5")

with self.assertRaises(ValueError):
SimulationExecutionOptions(checkpoint_interval="not_an_integer", checkpoint_file="checkpoint.h5")
# Test with invalid checkpoint_file type
with self.assertRaises(ValueError):
SimulationExecutionOptions(checkpoint_file=12345, checkpoint_timesteps=10)

# Test with invalid checkpoint_file
with self.assertRaises(ValueError):
SimulationExecutionOptions(
checkpoint_interval=10,
checkpoint_file="checkpoint.txt", # Wrong extension
)
# Test with both checkpoint options - should raise ValueError
with self.assertRaises(ValueError):
SimulationExecutionOptions(checkpoint_timesteps=10, checkpoint_interval=20, checkpoint_file=checkpoint_file)

with self.assertRaises(FileNotFoundError):
SimulationExecutionOptions(
checkpoint_interval=10,
checkpoint_file="nonexistent_dir/checkpoint.h5", # Non-existent directory
)
# Test with just checkpoint_file (should raise ValueError)
with self.assertRaises(ValueError):
SimulationExecutionOptions(checkpoint_file=checkpoint_file)

def test_initialization_with_valid_checkpoint_options(self):
"""Test initialization with valid checkpoint options."""
Expand All @@ -319,11 +296,86 @@ def test_initialization_with_valid_checkpoint_options(self):
self.assertEqual(options.checkpoint_interval, 20)
self.assertEqual(options.checkpoint_file, checkpoint_file)

# Test with both checkpoint options - should be valid
options = SimulationExecutionOptions(checkpoint_timesteps=10, checkpoint_interval=20, checkpoint_file=checkpoint_file)
self.assertEqual(options.checkpoint_timesteps, 10)
self.assertEqual(options.checkpoint_interval, 20)
self.assertEqual(options.checkpoint_file, checkpoint_file)
def test_checkpoint_interval_validation(self):
"""Test validation of checkpoint_interval property."""
options = self.default_options

# Test valid values
options.checkpoint_interval = 100
self.assertEqual(options.checkpoint_interval, 100)

# Test invalid values
with self.assertRaises(ValueError):
options.checkpoint_interval = 0
with self.assertRaises(ValueError):
options.checkpoint_interval = -1
with self.assertRaises(ValueError):
options.checkpoint_interval = "invalid"
with self.assertRaises(ValueError):
options.checkpoint_interval = 1.5

def test_checkpoint_timesteps_validation(self):
"""Test validation of checkpoint_timesteps property."""
options = self.default_options

# Test valid values
options.checkpoint_timesteps = 0
self.assertEqual(options.checkpoint_timesteps, 0)
options.checkpoint_timesteps = 100
self.assertEqual(options.checkpoint_timesteps, 100)

# Test invalid values
with self.assertRaises(ValueError):
options.checkpoint_timesteps = -1
with self.assertRaises(ValueError):
options.checkpoint_timesteps = "invalid"
with self.assertRaises(ValueError):
options.checkpoint_timesteps = 1.5

def test_checkpoint_file_validation(self):
"""Test validation of checkpoint file path."""
options = self.default_options

# Test with non-existent directory
with self.assertRaises(FileNotFoundError) as cm:
options.checkpoint_file = "invalid/path/checkpoint.h5"
expected_folder = str(Path("invalid") / "path")
self.assertEqual(str(cm.exception), f"Checkpoint folder {expected_folder} does not exist.")

# Test with temporary directory
with TemporaryDirectory() as temp_dir:
# Test invalid file extension
invalid_file = Path(temp_dir) / "checkpoint.txt"
with self.assertRaises(ValueError) as cm:
options.checkpoint_file = invalid_file
self.assertEqual(str(cm.exception), f"Checkpoint file {invalid_file} must have .h5 extension.")

# Test valid file path
valid_file = Path(temp_dir) / "checkpoint.h5"
options.checkpoint_file = valid_file
self.assertEqual(options.checkpoint_file, valid_file)

# Test invalid type
with self.assertRaises(ValueError) as cm:
options.checkpoint_file = 123
self.assertEqual(str(cm.exception), "Checkpoint file must be a string or Path object.")

def test_checkpoint_file_required_when_parameters_set(self):
"""Test that checkpoint file is required when checkpoint parameters are set."""
options = self.default_options

# Test with checkpoint_interval
options.checkpoint_interval = 10
with self.assertRaises(ValueError) as cm:
options.validate()
self.assertEqual(str(cm.exception), "`checkpoint_file` must be set when `checkpoint_interval` or `checkpoint_timesteps` is set.")

# Test with checkpoint_timesteps
options.checkpoint_interval = None
options.checkpoint_timesteps = 10
with self.assertRaises(ValueError) as cm:
options.validate()
self.assertEqual(str(cm.exception), "`checkpoint_file` must be set when `checkpoint_interval` or `checkpoint_timesteps` is set.")


if __name__ == "__main__":
Expand Down
Loading