diff --git a/kwave/executor.py b/kwave/executor.py index 77e10803..7caff892 100644 --- a/kwave/executor.py +++ b/kwave/executor.py @@ -33,6 +33,9 @@ def _make_binary_executable(self): binary_path.chmod(binary_path.stat().st_mode | stat.S_IEXEC) def run_simulation(self, input_filename: str, output_filename: str, options: list[str]) -> dotdict: + # Validate execution options before running simulation + self.execution_options.validate() + command = [str(self.execution_options.binary_path), "-i", input_filename, "-o", output_filename] + options try: @@ -116,7 +119,6 @@ def parse_executable_output(output_filename: str) -> dotdict: # # Combine the sensor data if using a kWaveTransducer as a sensor # if isinstance(sensor, kWaveTransducer): # sensor_data['p'] = sensor.combine_sensor_data(sensor_data['p']) - # # Compute the intensity outputs # if any(key.startswith(('I_avg', 'I')) for key in sensor.get('record', [])): # flags = { diff --git a/kwave/options/simulation_execution_options.py b/kwave/options/simulation_execution_options.py index f7bbf7c2..8e699cfa 100644 --- a/kwave/options/simulation_execution_options.py +++ b/kwave/options/simulation_execution_options.py @@ -51,9 +51,34 @@ def __init__( self.checkpoint_timesteps = checkpoint_timesteps self.checkpoint_file = checkpoint_file - if self.checkpoint_file is not None: - if self.checkpoint_interval is None and self.checkpoint_timesteps is None: - raise ValueError("One of checkpoint_interval or checkpoint_timesteps must be set when checkpoint_file is set.") + self.validate() + + def _validate_checkpoint_options(self): + # Checkpointing parameters are set + if self.checkpoint_interval is not None or self.checkpoint_timesteps is not None: + # No checkpoint file set + if self.checkpoint_file is None: + raise ValueError("`checkpoint_file` must be set when `checkpoint_interval` or `checkpoint_timesteps` is set.") + # Both checkpointing parameters are set + if self.checkpoint_interval is not None and self.checkpoint_timesteps is not None: + raise ValueError("`checkpoint_interval` and `checkpoint_timesteps` cannot be set at the same time.") + # Checkpoint file is set but no checkpointing parameters are set + if self.checkpoint_file is not None and self.checkpoint_interval is None and self.checkpoint_timesteps is None: + raise ValueError("`checkpoint_interval` or `checkpoint_timesteps` must be set when `checkpoint_file` is set.") + + def validate(self): + """Validate all simulation options before running a simulation. + + This method should be called before running a simulation to ensure all options + are in a valid state. It validates: + 1. Checkpoint configuration (if any checkpoint options are set) + + Raises: + ValueError: If any option configuration is invalid + """ + # Validate checkpoint configuration if any checkpoint options are set + if any(x is not None for x in [self.checkpoint_interval, self.checkpoint_timesteps, self.checkpoint_file]): + self._validate_checkpoint_options() @property def num_threads(self) -> Union[int, str]: @@ -184,8 +209,8 @@ def checkpoint_interval(self) -> Optional[int]: @checkpoint_interval.setter def checkpoint_interval(self, value: Optional[int]): if value is not None: - if not isinstance(value, int) or value < 0: - raise ValueError("Checkpoint interval must be a positive integer") + if not isinstance(value, int) or value <= 0: + raise ValueError("Checkpoint interval must be a positive integer in seconds") self._checkpoint_interval = value @property diff --git a/tests/test_simulation_execution_options.py b/tests/test_simulation_execution_options.py index 0af0c330..04da4890 100644 --- a/tests/test_simulation_execution_options.py +++ b/tests/test_simulation_execution_options.py @@ -266,43 +266,20 @@ def test_as_list_with_valid_checkpoint_timesteps_options(self): def test_initialization_with_invalid_checkpoint_options(self): """Test initialization with invalid checkpoint options.""" - - # Test with valid checkpoint_file but unset checkpoint_timesteps and checkpoint_interval with TemporaryDirectory() as temp_dir: checkpoint_file = Path(temp_dir) / "checkpoint.h5" - with self.assertRaises(ValueError): - SimulationExecutionOptions(checkpoint_file=checkpoint_file) - - # Test with invalid checkpoint_file type - with self.assertRaises(ValueError): - SimulationExecutionOptions(checkpoint_file=12345, checkpoint_timesteps=10) - - # Test with invalid checkpoint_timesteps - with self.assertRaises(ValueError): - SimulationExecutionOptions(checkpoint_timesteps=-1, checkpoint_file="checkpoint.h5") - - with self.assertRaises(ValueError): - SimulationExecutionOptions(checkpoint_timesteps="not_an_integer", checkpoint_file="checkpoint.h5") - - # Test with invalid checkpoint_interval - with self.assertRaises(ValueError): - SimulationExecutionOptions(checkpoint_interval=-1, checkpoint_file="checkpoint.h5") - with self.assertRaises(ValueError): - SimulationExecutionOptions(checkpoint_interval="not_an_integer", checkpoint_file="checkpoint.h5") + # Test with invalid checkpoint_file type + with self.assertRaises(ValueError): + SimulationExecutionOptions(checkpoint_file=12345, checkpoint_timesteps=10) - # Test with invalid checkpoint_file - with self.assertRaises(ValueError): - SimulationExecutionOptions( - checkpoint_interval=10, - checkpoint_file="checkpoint.txt", # Wrong extension - ) + # Test with both checkpoint options - should raise ValueError + with self.assertRaises(ValueError): + SimulationExecutionOptions(checkpoint_timesteps=10, checkpoint_interval=20, checkpoint_file=checkpoint_file) - with self.assertRaises(FileNotFoundError): - SimulationExecutionOptions( - checkpoint_interval=10, - checkpoint_file="nonexistent_dir/checkpoint.h5", # Non-existent directory - ) + # Test with just checkpoint_file (should raise ValueError) + with self.assertRaises(ValueError): + SimulationExecutionOptions(checkpoint_file=checkpoint_file) def test_initialization_with_valid_checkpoint_options(self): """Test initialization with valid checkpoint options.""" @@ -319,11 +296,86 @@ def test_initialization_with_valid_checkpoint_options(self): self.assertEqual(options.checkpoint_interval, 20) self.assertEqual(options.checkpoint_file, checkpoint_file) - # Test with both checkpoint options - should be valid - options = SimulationExecutionOptions(checkpoint_timesteps=10, checkpoint_interval=20, checkpoint_file=checkpoint_file) - self.assertEqual(options.checkpoint_timesteps, 10) - self.assertEqual(options.checkpoint_interval, 20) - self.assertEqual(options.checkpoint_file, checkpoint_file) + def test_checkpoint_interval_validation(self): + """Test validation of checkpoint_interval property.""" + options = self.default_options + + # Test valid values + options.checkpoint_interval = 100 + self.assertEqual(options.checkpoint_interval, 100) + + # Test invalid values + with self.assertRaises(ValueError): + options.checkpoint_interval = 0 + with self.assertRaises(ValueError): + options.checkpoint_interval = -1 + with self.assertRaises(ValueError): + options.checkpoint_interval = "invalid" + with self.assertRaises(ValueError): + options.checkpoint_interval = 1.5 + + def test_checkpoint_timesteps_validation(self): + """Test validation of checkpoint_timesteps property.""" + options = self.default_options + + # Test valid values + options.checkpoint_timesteps = 0 + self.assertEqual(options.checkpoint_timesteps, 0) + options.checkpoint_timesteps = 100 + self.assertEqual(options.checkpoint_timesteps, 100) + + # Test invalid values + with self.assertRaises(ValueError): + options.checkpoint_timesteps = -1 + with self.assertRaises(ValueError): + options.checkpoint_timesteps = "invalid" + with self.assertRaises(ValueError): + options.checkpoint_timesteps = 1.5 + + def test_checkpoint_file_validation(self): + """Test validation of checkpoint file path.""" + options = self.default_options + + # Test with non-existent directory + with self.assertRaises(FileNotFoundError) as cm: + options.checkpoint_file = "invalid/path/checkpoint.h5" + expected_folder = str(Path("invalid") / "path") + self.assertEqual(str(cm.exception), f"Checkpoint folder {expected_folder} does not exist.") + + # Test with temporary directory + with TemporaryDirectory() as temp_dir: + # Test invalid file extension + invalid_file = Path(temp_dir) / "checkpoint.txt" + with self.assertRaises(ValueError) as cm: + options.checkpoint_file = invalid_file + self.assertEqual(str(cm.exception), f"Checkpoint file {invalid_file} must have .h5 extension.") + + # Test valid file path + valid_file = Path(temp_dir) / "checkpoint.h5" + options.checkpoint_file = valid_file + self.assertEqual(options.checkpoint_file, valid_file) + + # Test invalid type + with self.assertRaises(ValueError) as cm: + options.checkpoint_file = 123 + self.assertEqual(str(cm.exception), "Checkpoint file must be a string or Path object.") + + def test_checkpoint_file_required_when_parameters_set(self): + """Test that checkpoint file is required when checkpoint parameters are set.""" + options = self.default_options + + # Test with checkpoint_interval + options.checkpoint_interval = 10 + with self.assertRaises(ValueError) as cm: + options.validate() + self.assertEqual(str(cm.exception), "`checkpoint_file` must be set when `checkpoint_interval` or `checkpoint_timesteps` is set.") + + # Test with checkpoint_timesteps + options.checkpoint_interval = None + options.checkpoint_timesteps = 10 + with self.assertRaises(ValueError) as cm: + options.validate() + self.assertEqual(str(cm.exception), "`checkpoint_file` must be set when `checkpoint_interval` or `checkpoint_timesteps` is set.") if __name__ == "__main__":