diff --git a/dev-environment.yml b/dev-environment.yml index 31b11a8eb..cdbcb311c 100644 --- a/dev-environment.yml +++ b/dev-environment.yml @@ -9,6 +9,8 @@ dependencies: - geopandas>=0.12.0 - rasterio>=1.3,<2 - scipy=1.* + - xarray + - rioxarray=0.* # Second-order or plus dependency on the above - numpy>=1,<3 - pyproj>=3.4,<4 @@ -21,6 +23,7 @@ dependencies: - pip # Optional dependencies + - dask # For scalable operations - numba=0.* # For fast numerical operations (terrain) - matplotlib=3.* # For plotting - scikit-learn # For optimizations @@ -62,4 +65,4 @@ dependencies: - -e ./ # To run CI against latest GeoUtils -# - git+https://github.com/rhugonnet/geoutils.git +# - git+https://github.com/rhugonnet/geoutils.git diff --git a/environment.yml b/environment.yml index 0d5f54ddf..8511bf694 100644 --- a/environment.yml +++ b/environment.yml @@ -9,6 +9,8 @@ dependencies: - geopandas>=0.12.0 - rasterio>=1.3,<2 - scipy=1.* + - xarray + - rioxarray=0.* # Second-order or plus dependency on the above - numpy>=1,<3 - pyproj>=3.4,<4 diff --git a/examples/advanced/plot_blockwise_coreg.py b/examples/advanced/plot_blockwise_coreg.py index 0f548c7d0..07436e018 100644 --- a/examples/advanced/plot_blockwise_coreg.py +++ b/examples/advanced/plot_blockwise_coreg.py @@ -22,7 +22,7 @@ # sphinx_gallery_thumbnail_number = 2 import matplotlib.pyplot as plt import numpy as np -from geoutils.raster.distributed_computing import MultiprocConfig +from geoutils.multiproc import MultiprocConfig import xdem diff --git a/setup.cfg b/setup.cfg index a36c8f92e..c4a5b598c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,6 +50,7 @@ include = [options.extras_require] opt = + dask numba matplotlib scikit-learn diff --git a/tests/test_coreg/test_affine.py b/tests/test_coreg/test_affine.py index c8625f586..b317081da 100644 --- a/tests/test_coreg/test_affine.py +++ b/tests/test_coreg/test_affine.py @@ -13,7 +13,7 @@ import rasterio as rio import scipy.optimize from geoutils import Raster, Vector -from geoutils.raster.geotransformations import _translate +from geoutils.raster.transformation import _translate from scipy.ndimage import binary_dilation from xdem import coreg, examples @@ -57,8 +57,8 @@ class TestAffineCoreg: fit_args_rst_rst = dict(reference_elev=ref, to_be_aligned_elev=tba, inlier_mask=inlier_mask) # Convert DEMs to points with a bit of subsampling for speed-up - ref_pts = ref.to_pointcloud(data_column_name="z", subsample=50000, random_state=42).ds - tba_pts = ref.to_pointcloud(data_column_name="z", subsample=50000, random_state=42).ds + ref_pts = ref.to_pointcloud(data_column_name="z").ds + tba_pts = ref.to_pointcloud(data_column_name="z").ds # Raster-Point fit_args_rst_pts = dict(reference_elev=ref, to_be_aligned_elev=tba_pts, inlier_mask=inlier_mask) @@ -195,11 +195,11 @@ def test_coreg_translations__synthetic( ref_shifted = ref.translate(shifts[0], shifts[1]) + shifts[2] # Convert to point cloud if input was point cloud if isinstance(elev_fit_args["to_be_aligned_elev"], gpd.GeoDataFrame): - ref_shifted = ref_shifted.to_pointcloud(data_column_name="z", subsample=50000, random_state=42).ds + ref_shifted = ref_shifted.to_pointcloud(data_column_name="z").ds elev_fit_args["to_be_aligned_elev"] = ref_shifted # Run coregistration - subsample_size = 50000 if coreg_method != coreg.CPD else 500 + subsample_size = 1 if coreg_method != coreg.CPD else 500 coreg_elev = horizontal_coreg.fit_and_apply(**elev_fit_args, subsample=subsample_size, random_state=42) # Check all fit parameters are the opposite of those used above, within a relative 1% (10% for ICP) @@ -291,11 +291,11 @@ def test_coreg_vertical_translation__synthetic(self, fit_args: Any, vshift: floa # Convert to point cloud if input was point cloud if isinstance(elev_fit_args["to_be_aligned_elev"], gpd.GeoDataFrame): - ref_vshifted = ref_vshifted.to_pointcloud(data_column_name="z", subsample=50000, random_state=42).ds + ref_vshifted = ref_vshifted.to_pointcloud(data_column_name="z").ds elev_fit_args["to_be_aligned_elev"] = ref_vshifted # Fit the vertical shift model to the data - coreg_elev = vshiftcorr.fit_and_apply(**elev_fit_args, subsample=50000, random_state=42) + coreg_elev = vshiftcorr.fit_and_apply(**elev_fit_args) # Check that the right vertical shift was found assert vshiftcorr.meta["outputs"]["affine"]["shift_z"] == pytest.approx(-vshift, rel=10e-2) @@ -390,13 +390,11 @@ def test_coreg_rigid__synthetic( # Convert to point cloud if input was point cloud if isinstance(elev_fit_args["to_be_aligned_elev"], gpd.GeoDataFrame): - ref_shifted_rotated = ref_shifted_rotated.to_pointcloud( - data_column_name="z", subsample=50000, random_state=42 - ).ds + ref_shifted_rotated = ref_shifted_rotated.to_pointcloud(data_column_name="z").ds elev_fit_args["to_be_aligned_elev"] = ref_shifted_rotated # Run coregistration - subsample_size = 50000 if coreg_method != coreg.CPD else 500 + subsample_size = 1 if coreg_method != coreg.CPD else 500 coreg_elev = horizontal_coreg.fit_and_apply(**elev_fit_args, subsample=subsample_size, random_state=42) # Check that fit matrix is the invert of those used above, within a relative % for rotations @@ -443,7 +441,7 @@ def test_coreg_rigid__synthetic( # Need to standardize by the elevation difference spread to avoid huge/small values close to infinity # Checking for 90% of variance as ICP cannot always resolve the small shifts # And only 10% of variance for CPD that can't resolve shifts at all - fac_reduc_var = 0.1 if coreg_method != coreg.CPD else 1.0 + fac_reduc_var = 0.1 if coreg_method != coreg.CPD else 1.05 assert np.nanvar(dh / np.nanstd(init_dh)) < fac_reduc_var @pytest.mark.parametrize( @@ -506,7 +504,7 @@ def test_coreg_rigid__specific_args(self, rigid_coreg: coreg.Coreg) -> None: ref_shifted_rotated = coreg.apply_matrix(ref, matrix=matrix, centroid=centroid) # Coregister - subsample_size = 50000 if rigid_coreg.__class__.__name__ != "CPD" else 500 + subsample_size = 1 if rigid_coreg.__class__.__name__ != "CPD" else 500 rigid_coreg.fit(ref, ref_shifted_rotated, random_state=42, subsample=subsample_size) @pytest.mark.parametrize("coreg_method", [coreg.ICP, coreg.CPD, coreg.LZD]) @@ -523,7 +521,7 @@ def test_coreg_rigid__only_translation(self, coreg_method: coreg.Coreg) -> None: ref_shifted_rotated = coreg.apply_matrix(ref, matrix=matrix, centroid=centroid) # Run co-registration - subsample_size = 50000 if coreg_method != coreg.CPD else 500 + subsample_size = 1 if coreg_method != coreg.CPD else 500 c = coreg_method(subsample=subsample_size, only_translation=True) c.fit(ref, ref_shifted_rotated, random_state=42) @@ -540,7 +538,7 @@ def test_coreg_rigid__only_translation(self, coreg_method: coreg.Coreg) -> None: assert np.allclose(invert_fit_shifts_translations[:3], shifts_rotations[:3], rtol=10e-1) @pytest.mark.parametrize("coreg_method", [coreg.ICP, coreg.CPD]) - def test_coreg_rigid__standardize(self, coreg_method: coreg.Coreg) -> None: + def test_coreg_rigid__standardize(self, coreg_method: type[coreg.Coreg]) -> None: # Get reference elevation ref = self.ref @@ -553,7 +551,7 @@ def test_coreg_rigid__standardize(self, coreg_method: coreg.Coreg) -> None: ref_shifted_rotated = coreg.apply_matrix(ref, matrix=matrix, centroid=centroid) # 1/ Run co-registration with standardization - subsample_size = 50000 if coreg_method != coreg.CPD else 500 + subsample_size = 1 if coreg_method != coreg.CPD else 500 c_std = coreg_method(subsample=subsample_size, standardize=True) c_std.fit(ref, ref_shifted_rotated, random_state=42) @@ -610,11 +608,11 @@ def test_nuthkaab_initial_shift(self) -> None: # Get the coregistration method and expected shifts from the inputs inlier_mask = ~self.outlines.create_mask(ref) - c = coreg.NuthKaab(initial_shift=(0, 0, 0), subsample=50000) + c = coreg.NuthKaab(initial_shift=(0, 0, 0)) dem_aligned_is = c.fit_and_apply(ref, tba, inlier_mask=inlier_mask, random_state=42) shifts_is = [c.meta["outputs"]["affine"][k] for k in ["shift_x", "shift_y", "shift_z"]] # type: ignore - c = coreg.NuthKaab(subsample=50000) + c = coreg.NuthKaab() dem_aligned = c.fit_and_apply(ref, tba, inlier_mask=inlier_mask, random_state=42) shifts = [c.meta["outputs"]["affine"][k] for k in ["shift_x", "shift_y", "shift_z"]] # type: ignore diff --git a/tests/test_coreg/test_base.py b/tests/test_coreg/test_base.py index a66137ec2..e2187d383 100644 --- a/tests/test_coreg/test_base.py +++ b/tests/test_coreg/test_base.py @@ -48,6 +48,8 @@ def assert_coreg_meta_equal(input1: Any, input2: Any) -> bool: """Short test function to check equality of coreg dictionary values.""" # Different equality check based on input: number, callable, array, dataframe + if input1 is None: + return input2 is None if not isinstance(input1, type(input2)): return False elif isinstance(input1, (str, float, int, np.floating, np.integer, tuple, list)) or callable(input1): @@ -137,7 +139,7 @@ def test_copy(self, coreg_class: Callable[[], Coreg]) -> None: # Make sure these don't appear in the copy assert corr_copy.meta != corr.meta - @pytest.mark.parametrize("subsample", [10, 10000, 0.5, 1]) + @pytest.mark.parametrize("subsample", [10, 1000, 0.5, 1]) def test_get_subsample_on_valid_mask(self, subsample: float | int) -> None: """Test the subsampling function called by all subclasses""" @@ -192,8 +194,8 @@ def test_subsample(self, coreg_class: Any) -> None: fit_kwargs = {} # But can be overridden during fit - coreg_full.fit(**self.fit_params, subsample=10000, random_state=42, **fit_kwargs) - assert coreg_full.meta["inputs"]["random"]["subsample"] == 10000 + coreg_full.fit(**self.fit_params, subsample=1000, random_state=42, **fit_kwargs) + assert coreg_full.meta["inputs"]["random"]["subsample"] == 1000 # Check that the random state is properly set when subsampling explicitly or implicitly assert coreg_full.meta["inputs"]["random"]["random_state"] == 42 @@ -214,11 +216,11 @@ def test_subsample__pipeline(self) -> None: """Test that the subsample argument works as intended for pipelines""" # Check definition during instantiation - pipe = coreg.VerticalShift(subsample=200) + coreg.Deramp(subsample=5000) + pipe = coreg.VerticalShift(subsample=200) + coreg.Deramp(subsample=1000) # Check the arguments are properly defined assert pipe.pipeline[0].meta["inputs"]["random"]["subsample"] == 200 - assert pipe.pipeline[1].meta["inputs"]["random"]["subsample"] == 5000 + assert pipe.pipeline[1].meta["inputs"]["random"]["subsample"] == 1000 # Check definition during fit pipe = coreg.VerticalShift() + coreg.Deramp() @@ -387,12 +389,12 @@ def test_fit_and_apply(self, coreg_class: Any) -> None: fit_kwargs = {} # Perform fit, then apply - coreg_fit_then_apply.fit(**self.fit_params, subsample=10000, random_state=42, **fit_kwargs) + coreg_fit_then_apply.fit(**self.fit_params, **fit_kwargs) aligned_then = coreg_fit_then_apply.apply(elev=self.fit_params["to_be_aligned_elev"]) # Perform fit and apply aligned_and = coreg_fit_and_apply.fit_and_apply( - **self.fit_params, subsample=10000, random_state=42, fit_kwargs=fit_kwargs + **self.fit_params, fit_kwargs=fit_kwargs ) # Check outputs are the same: aligned raster, and metadata keys and values @@ -415,11 +417,11 @@ def test_fit_and_apply__pipeline(self) -> None: coreg_fit_and_apply = coreg.NuthKaab() + coreg.Deramp() # Perform fit, then apply - coreg_fit_then_apply.fit(**self.fit_params, subsample=10000, random_state=42) + coreg_fit_then_apply.fit(**self.fit_params) aligned_then = coreg_fit_then_apply.apply(elev=self.fit_params["to_be_aligned_elev"]) # Perform fit and apply - aligned_and = coreg_fit_and_apply.fit_and_apply(**self.fit_params, subsample=10000, random_state=42) + aligned_and = coreg_fit_and_apply.fit_and_apply(**self.fit_params) assert aligned_and.raster_equal(aligned_then, warn_failure_reason=True) assert list(coreg_fit_and_apply.pipeline[0].meta.keys()) == list(coreg_fit_then_apply.pipeline[0].meta.keys()) @@ -701,7 +703,7 @@ def test_pipeline(self) -> None: # Create a pipeline from two coreg methods. pipeline = coreg.CoregPipeline([coreg.VerticalShift(), coreg.NuthKaab()]) - pipeline.fit(**self.fit_params, subsample=5000, random_state=42) + pipeline.fit(**self.fit_params) aligned_dem, _ = pipeline.apply(self.tba.data, transform=self.ref.transform, crs=self.ref.crs) @@ -734,7 +736,7 @@ def test_pipeline_combinations__nobiasvar(self, coreg1: Callable[[], Coreg], cor # Create a pipeline from one affine and one biascorr methods. pipeline = coreg.CoregPipeline([coreg1(), coreg2()]) - pipeline.fit(**self.fit_params, subsample=5000, random_state=42) + pipeline.fit(**self.fit_params) aligned_dem, _ = pipeline.apply(self.tba.data, transform=self.ref.transform, crs=self.ref.crs) assert aligned_dem.shape == self.ref.data.squeeze().shape @@ -755,7 +757,7 @@ def test_pipeline_combinations__biasvar( # Create a pipeline from one affine and one biascorr methods pipeline = coreg.CoregPipeline([coreg1(), coreg.BiasCorr(**coreg2_init_kwargs)]) # type: ignore bias_vars = {"slope": xdem.terrain.slope(self.ref), "aspect": xdem.terrain.aspect(self.ref)} - pipeline.fit(**self.fit_params, bias_vars=bias_vars, subsample=5000, random_state=42) + pipeline.fit(**self.fit_params, bias_vars=bias_vars) aligned_dem, _ = pipeline.apply( self.tba.data, transform=self.ref.transform, crs=self.ref.crs, bias_vars=bias_vars @@ -810,7 +812,7 @@ def test_pipeline__errors(self) -> None: def test_pipeline_pts(self) -> None: pipeline = coreg.NuthKaab() + coreg.DhMinimize() - ref_points = self.ref.to_pointcloud(subsample=5000, random_state=42) + ref_points = self.ref.to_pointcloud() # Check that this runs without error pipeline.fit(reference_elev=ref_points, to_be_aligned_elev=self.tba) diff --git a/tests/test_coreg/test_biascorr.py b/tests/test_coreg/test_biascorr.py index feb46c2bc..f64bb3476 100644 --- a/tests/test_coreg/test_biascorr.py +++ b/tests/test_coreg/test_biascorr.py @@ -41,8 +41,8 @@ class TestBiasCorr: fit_args_rst_rst = dict(reference_elev=ref, to_be_aligned_elev=tba, inlier_mask=inlier_mask) # Convert DEMs to points with a bit of subsampling for speed-up - tba_pts = tba.to_pointcloud(data_column_name="z", subsample=50000, random_state=42) - ref_pts = ref.to_pointcloud(data_column_name="z", subsample=50000, random_state=42) + tba_pts = tba.to_pointcloud(data_column_name="z") + ref_pts = ref.to_pointcloud(data_column_name="z") # Raster-Point fit_args_rst_pts = dict(reference_elev=ref, to_be_aligned_elev=tba_pts, inlier_mask=inlier_mask) @@ -292,7 +292,7 @@ def test_biascorr__bin_2d(self, fit_args: Any, bin_sizes: Any, bin_statistic: An elev_fit_args.update({"bias_vars": bias_vars_dict}) # Run with input parameter, and using only 100 subsamples for speed - bcorr.fit(**elev_fit_args, subsample=10000, random_state=42) + bcorr.fit(**elev_fit_args) # Check that variable names are defined during fit assert bcorr.meta["inputs"]["fitorbin"]["bias_var_names"] == ["elevation", "slope"] @@ -439,7 +439,7 @@ def test_directionalbias__synthetic(self, fit_args: Any, angle: float, nb_freq: plt.show() dirbias = biascorr.DirectionalBias(angle=angle, fit_or_bin="bin", bin_sizes=10000) - dirbias.fit(reference_elev=self.ref, to_be_aligned_elev=bias_dem, subsample=10000, random_state=42) + dirbias.fit(reference_elev=self.ref, to_be_aligned_elev=bias_dem) xdem.spatialstats.plot_1d_binning( df=dirbias.meta["outputs"]["fitorbin"]["bin_dataframe"], var_name="angle", @@ -464,14 +464,12 @@ def test_directionalbias__synthetic(self, fit_args: Any, angle: float, nb_freq: elev_fit_args = fit_args.copy() if isinstance(elev_fit_args["to_be_aligned_elev"], gpd.GeoDataFrame): # Need a higher sample size to get the coefficients right here - bias_elev = bias_dem.to_pointcloud(data_column_name="z", subsample=50000, random_state=42).ds + bias_elev = bias_dem.to_pointcloud(data_column_name="z").ds else: bias_elev = bias_dem dirbias.fit( elev_fit_args["reference_elev"], to_be_aligned_elev=bias_elev, - subsample=40000, - random_state=42, bounds_amp_wave_phase=bounds, niter=2, ) @@ -524,10 +522,10 @@ def test_deramp__synthetic(self, fit_args: Any, order: int) -> None: deramp = biascorr.Deramp(poly_order=order) elev_fit_args = fit_args.copy() if isinstance(elev_fit_args["to_be_aligned_elev"], gpd.GeoDataFrame): - bias_elev = bias_dem.to_pointcloud(data_column_name="z", subsample=30000, random_state=42).ds + bias_elev = bias_dem.to_pointcloud(data_column_name="z").ds else: bias_elev = bias_dem - deramp.fit(elev_fit_args["reference_elev"], to_be_aligned_elev=bias_elev, subsample=20000, random_state=42) + deramp.fit(elev_fit_args["reference_elev"], to_be_aligned_elev=bias_elev) # Check high-order fit parameters are the same within 10% fit_params = deramp.meta["outputs"]["fitorbin"]["fit_params"] @@ -582,14 +580,12 @@ def test_terrainbias__synthetic(self, fit_args: Any) -> None: ) elev_fit_args = fit_args.copy() if isinstance(elev_fit_args["to_be_aligned_elev"], gpd.GeoDataFrame): - bias_elev = bias_dem.to_pointcloud(data_column_name="z", subsample=20000, random_state=42).ds + bias_elev = bias_dem.to_pointcloud(data_column_name="z").ds else: bias_elev = bias_dem tb.fit( elev_fit_args["reference_elev"], to_be_aligned_elev=bias_elev, - subsample=10000, - random_state=42, bias_vars={"max_curvature": maxc}, ) diff --git a/tests/test_coreg/test_blockwise.py b/tests/test_coreg/test_blockwise.py index eb1b6f4f0..4af89d214 100644 --- a/tests/test_coreg/test_blockwise.py +++ b/tests/test_coreg/test_blockwise.py @@ -11,8 +11,7 @@ import pytest from geoutils import Raster, Vector from geoutils.interface.gridding import _grid_pointcloud -from geoutils.raster import ClusterGenerator -from geoutils.raster.distributed_computing import MultiprocConfig +from geoutils.multiproc import MultiprocConfig, ClusterGenerator import xdem from xdem.coreg import BlockwiseCoreg, Coreg diff --git a/tests/test_dem/test_base.py b/tests/test_dem/test_base.py new file mode 100644 index 000000000..09694a86c --- /dev/null +++ b/tests/test_dem/test_base.py @@ -0,0 +1,630 @@ +"""Test module for the DEMBase class.""" + +from __future__ import annotations + +import warnings +from typing import Any + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from pandas.testing import assert_frame_equal + +from geoutils import Raster, Vector +from geoutils.raster import MultiprocConfig +from pyproj import CRS + +from xdem import DEM, examples, open_dem, coreg +from xdem.dem.base import DEMBase +from xdem.dem.xr_accessor import DEMAccessor + + +def assert_output_equal(output1: Any, output2: Any, use_allclose: bool = False, strict_masked: bool = True) -> None: + """Return equality of different output types.""" + + + # For two vectors + if isinstance(output1, Vector) and isinstance(output2, Vector): + assert output1.vector_equal(output2) + + # For two raster-like outputs: Xarray or DEM objects + elif isinstance(output1, (Raster, xr.DataArray)): + if use_allclose: + assert output1.raster_allclose(output2, warn_failure_reason=True, strict_masked=strict_masked) + else: + assert output1.raster_equal(output2, warn_failure_reason=True, strict_masked=strict_masked) + + # For arrays + elif isinstance(output1, np.ndarray): + if np.ma.isMaskedArray(output1): + output1 = output1.filled(np.nan) + if np.ma.isMaskedArray(output2): + output2 = output2.filled(np.nan) + if use_allclose: + assert np.allclose(output1, output2, equal_nan=True) + else: + assert np.array_equal(output1, output2, equal_nan=True) + + # For tuple of arrays + elif isinstance(output1, tuple) and len(output1) > 0 and isinstance(output1[0], np.ndarray): + assert np.array_equal(np.array(output1), np.array(output2), equal_nan=True) + + # For a list of raster-like outputs + elif isinstance(output1, list) and len(output1) > 0 and isinstance(output1[0], (Raster, xr.DataArray)): + assert len(output1) == len(output2) + for out1, out2 in zip(output1, output2): + assert_output_equal(out1, out2, use_allclose=use_allclose, strict_masked=strict_masked) + + # For a dictionary of numeric values + elif isinstance(output1, dict): + df1 = pd.DataFrame(index=[0], data=output1) + df2 = pd.DataFrame(index=[0], data=output2) + assert_frame_equal(df1, df2, check_dtype=False) + + # For any other object type + else: + assert output1 == output2 + + +def should_be_loaded(method: str, args: dict[str, Any], noload: list[str], noload_allowed_args: dict[str, Any]) -> bool: + """Helper function to check without a method input/output should be loaded or not, based on input dictionaries.""" + + # For method where the behaviour is independent of their arguments + if method not in noload_allowed_args: + # If the method has a single behaviour, simply check if it belongs in the list + should_output_be_loaded = method not in noload + # For method where the behaviour depends on their arguments + else: + # Get relevant method arguments + allowed = noload_allowed_args[method] + # If any value is different from the list of values in the allowed dictionary, it should load + any_different = not all( + (not isinstance(args[k], np.ndarray) and args[k] in allowed[k]) for k in allowed if k in args + ) + should_output_be_loaded = any_different + + return should_output_be_loaded + + +class NeedsTestError(ValueError): + """Error to remember to add test when a new DEMBase method is added.""" + + +class TestClassVsAccessorConsistencyInherited: + """ + Test class to check the consistency between the outputs of a light subset of inherited RasterBase + attributes and methods through the DEM class and Xarray accessor. + + This ensures that DEM preserves inherited raster behaviour without re-testing the full GeoUtils API. + """ + + # Run tests for different DEMs + longyearbyen_path = examples.get_path_test("longyearbyen_ref_dem") + + # Minimal representative subset of inherited attributes + inherited_attributes = ["crs", "transform", "nodata", "res", "_is_xr"] + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) # type: ignore + @pytest.mark.parametrize("attr", inherited_attributes) # type: ignore + def test_attributes__equality(self, path_dem: str, attr: str) -> None: + """Test that a minimal subset of inherited attributes of the two objects are exactly the same.""" + + # Open + ds = open_dem(path_dem) + dem = DEM(path_dem) + + # Get attribute for each object + output_dem = getattr(dem, attr) + output_ds = getattr(ds.dem, attr) + + # Assert equality + if attr != "_is_xr": # Only attribute that is (purposely) not the same, but the opposite + assert_output_equal(output_dem, output_ds) + else: + assert output_dem != output_ds + + # Minimal representative subset of inherited methods + inherited_methods_and_kwargs = [ + ("coords", {"grid": True}), # Metadata-only inherited method + ("translate", {"xoff": 10.5, "yoff": 5}), # Raster-returning inherited method + ("interp_points", {"points": "random"}), # Array-returning inherited method + ("reproject", {"crs": CRS.from_epsg(4326)}), # Raster-returning loading method + ("set_nodata", {"new_nodata": -10001, "update_array": False, "update_mask": False}), # In-place method + ] + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) # type: ignore + @pytest.mark.parametrize("method, kwargs", [(f, k) for f, k in inherited_methods_and_kwargs]) # type: ignore + def test_methods__equality(self, path_dem: str, method: str, kwargs: dict[str, Any]) -> None: + """ + Test that a minimal representative subset of inherited RasterBase methods yield the same outputs + between a DEM and Xarray dem accessor. + """ + + # Open both objects + ds = open_dem(path_dem) + dem = DEM(path_dem) + + args = kwargs.copy() + + # For methods that require knowledge of the data (relative to bounds), create specific inputs + if "points" in args: + rng = np.random.default_rng(seed=42) + ninterp = 10 + res = dem.res + interp_x = dem.bounds.left + (rng.choice(dem.shape[1], ninterp) + rng.random(ninterp)) * res[0] + interp_y = dem.bounds.bottom + (rng.choice(dem.shape[0], ninterp) + rng.random(ninterp)) * res[1] + args.update({"points": (interp_x, interp_y)}) + + # Apply method for each class + output_dem = getattr(dem, method)(**args) + output_ds = getattr(ds.dem, method)(**args) + + # In-place method + if method == "set_nodata": + assert output_dem is None + assert output_ds is None + assert_output_equal(dem, ds) + else: + assert_output_equal(output_dem, output_ds) + + # Minimal representative subset of inherited methods for loading checks + inherited_methods_loading_and_kwargs = [ + ("coords", {"grid": True}), # Metadata-only inherited method, should not load + ("reproject", {"crs": CRS.from_epsg(4326)}), # Raster inherited method, should load + ] + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) + @pytest.mark.parametrize("method, kwargs", + [(f, k) for f, k in inherited_methods_loading_and_kwargs]) + def test_methods__loading(self, path_dem: str, method: str, kwargs: dict[str, Any]) -> None: + """ + Test that a minimal subset of inherited RasterBase methods preserve the expected loading behaviour + between a DEM and Xarray dem accessor. + """ + + # Open both objects + ds = open_dem(path_dem) + dem = DEM(path_dem) + + args = kwargs.copy() + + # Apply method for each class + output_dem = getattr(dem, method)(**args) + output_ds = getattr(ds.dem, method)(**args) + + # Check using method did or did not load the input DEM or Xarray dataset + should_input_be_loaded = method not in ["coords"] + assert dem.is_loaded is should_input_be_loaded + assert ds._in_memory is should_input_be_loaded + + # In the case of a DEM / DataArray output, check if output is loaded or not + if isinstance(output_ds, xr.DataArray): + # coords returns arrays; reproject returns raster-like output and should be loaded here + assert output_dem.is_loaded + assert output_ds._in_memory + + # Finally, assert exact equality of outputs + assert_output_equal(output_dem, output_ds) + + # Minimal representative subset of inherited chunked methods + inherited_chunked_methods_and_args = [ + ("interp_points", {"points": "random", "as_array": True}), # Array inherited method + ("reproject", {"crs": CRS.from_epsg(4326)}), # Raster inherited method + ] + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) # type: ignore + @pytest.mark.parametrize("method, kwargs", + [(f, k) for f, k in inherited_chunked_methods_and_args]) # type: ignore + def test_chunked_methods__loading_laziness(self, path_dem: str, method: str, kwargs: dict[str, Any]) -> None: + """ + Test that a minimal subset of inherited chunked methods preserve loading and laziness. + + They should yield the exact same output for: + - Dask backend through Xarray accessor, + - Multiprocessing backend through DEM class. + """ + + pytest.importorskip("dask") + import dask.array as da + + # Open lazily with Dask + ds = open_dem(path_dem, chunks={"band": 1, "x": 25, "y": 25}) + # Open DEM that will be processed using Multiprocessing + mp_config = MultiprocConfig(chunk_size=25) + dem = DEM(path_dem) + + args = kwargs.copy() + + # Special arguments + if "points" in args: + rng = np.random.default_rng(seed=42) + ninterp = 10 + res = dem.res + interp_x = dem.bounds.left + (rng.choice(dem.shape[1], ninterp) + rng.random(ninterp)) * res[0] + interp_y = dem.bounds.bottom + (rng.choice(dem.shape[0], ninterp) + rng.random(ninterp)) * res[1] + args.update({"points": (interp_x, interp_y)}) + + # Apply method for each + output_dem = getattr(dem, method)(**args, mp_config=mp_config) + output_ds = getattr(ds.dem, method)(**args) + + # For a raster-type output + if isinstance(output_dem, DEM): + + # 1/ For Dask object: both inputs and outputs should be unloaded + lazy, and compute + assert not ds._in_memory + assert isinstance(ds.data, da.Array) + assert ds.data.chunks is not None + + assert not output_ds._in_memory + assert isinstance(output_ds.data, da.Array) + assert output_ds.data.chunks is not None + + output_ds = output_ds.compute() + assert isinstance(output_ds.data, np.ndarray) + assert output_ds._in_memory + + # 2/ For Multiprocessing, output remains unloaded + assert not dem.is_loaded + assert not output_dem.is_loaded + + # For an array-type output + elif isinstance(output_dem, np.ndarray): + + # 1/ For Dask object: input and output should be unloaded + lazy, and compute + assert not ds._in_memory + assert isinstance(ds.data, da.Array) + assert ds.data.chunks is not None + + assert isinstance(output_ds, da.Array) + assert output_ds.chunks is not None + + output_ds = output_ds.compute() + assert isinstance(output_ds, np.ndarray) + + # 2/ For Multiprocessing, input remains unloaded + assert not dem.is_loaded + assert isinstance(output_dem, np.ndarray) + + # Check outputs are the same + assert_output_equal(output_dem, output_ds, use_allclose=True) + + +class TestClassVsAccessorConsistencyDEMBase: + """ + Test class to check the consistency between the outputs, loading, laziness and chunked operations + of the DEM class and Xarray accessor for DEMBase-specific attributes or methods. + + All DEM-specific shared attributes should be the same. + All DEM-specific operations manipulating the array should yield a comparable results, accounting for the fact that + DEM class relies on masked-arrays and the Xarray accessor on NaN arrays. + """ + + # Run tests for different DEMs + longyearbyen_path = examples.get_path_test("longyearbyen_ref_dem") + + # Get all DEMBase public properties and methods, ensures we test absolutely everything even with API changes + # Only methods/properties defined on DEMBase are checked here, inherited RasterBase methods are tested above. + properties = [k for k, v in DEMBase.__dict__.items() if not k.startswith("_") and isinstance(v, property)] + methods = [k for k, v in DEMBase.__dict__.items() if not k.startswith("_") and not isinstance(v, property)] + + # List of properties that WILL load the input dataset (only one does, the data itself, if DEMBase defines one) + properties_input_load = ["data"] + + # List of DEM-specific methods that WILL NOT load the input dataset + methods_input_noload: list[str] = ["set_vcrs"] + + # List of DEM-specific methods that WILL NOT load the input for certain arguments + methods_input_noload_allowed_args: dict[str, Any] = {} + + # List of DEM-specific methods that WILL NOT load the output dataset + methods_output_noload: list[str] = ["set_vcrs"] + + # List of DEM-specific methods that WILL NOT LOAD the output for certain arguments + methods_output_noload_allowed_args: dict[str, Any] = {} + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) # type: ignore + @pytest.mark.parametrize("prop", properties) # type: ignore + def test_properties__equality_and_loading(self, path_dem: str, prop: str) -> None: + """ + Test that DEMBase-specific properties are exactly equal between a DEM and DataArray using the "dem" accessor, + and if they do not load the dataset or not. + """ + + # Open + ds = open_dem(path_dem) + dem = DEM(path_dem) + + # Remove warnings about operations in a non-projected system, and future changes + warnings.simplefilter("ignore", category=UserWarning) + warnings.simplefilter("ignore", category=FutureWarning) + + # Get attribute for each object + output_dem = getattr(dem, prop) + output_ds = getattr(ds.dem, prop) + + # Assert equality + if prop == "_is_xr": # Only attribute that is (purposely) not the same, but the boolean opposite + assert output_dem != output_ds + else: + assert_output_equal(output_dem, output_ds) + + # Check getting attribute did not (or did) load the DEM or Xarray dataset + should_input_be_loaded = prop in self.properties_input_load + assert dem.is_loaded is should_input_be_loaded + assert ds._in_memory is should_input_be_loaded + + # Test DEMBase-specific methods + methods_and_kwargs = [ + # 1. Will load, not inplace + ("to_vcrs", {"vcrs": "EGM96", "force_source_vcrs": "Ellipsoid"}), + ("slope", {}), + ("aspect", {}), + ("hillshade", {}), + ("curvature", {}), + ("profile_curvature", {}), + ("planform_curvature", {}), + ("tangential_curvature", {}), + ("flowline_curvature", {}), + ("min_curvature", {}), + ("max_curvature", {}), + ("topographic_position_index", {}), + ("terrain_ruggedness_index", {}), + ("roughness", {}), + ("rugosity", {}), + ("fractal_roughness", {}), + ("texture_shading", {}), + ("get_terrain_attribute", {"attribute": ["slope", "aspect"]}), + ("to_pointcloud", {}), + ("coregister_3d", {"custom"}), # Define inside function + ("estimate_uncertainty", {"custom"}), # Define inside function + # 2. Inplace, will not load + ("set_vcrs", {"new_vcrs": "EGM96"}) + ] + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) # type: ignore + @pytest.mark.parametrize("method, kwargs", [(f, k) for f, k in methods_and_kwargs]) # type: ignore + def test_methods__equality_and_loading(self, path_dem: str, method: str, kwargs: dict[str, Any]) -> None: + """ + Test that the DEMBase-specific method output and loading mechanism of the two objects are exactly the same + between a DEM and Xarray dem accessor. + """ + + # Open both objects + ds = open_dem(path_dem) + dem = DEM(path_dem) + + # Remove warnings about operations in a non-projected system, and future changes + warnings.simplefilter("ignore", category=UserWarning) + warnings.simplefilter("ignore", category=FutureWarning) + + args = kwargs.copy() + if method == "coregister_3d": + # Temporary skip until coreg module is adapted + return + # other_dem = dem.translate(1, 1, distance_unit="pixel") + # args = {"reference_elev": other_dem, "coreg_method": coreg.LZD()} + elif method == "estimate_uncertainty": + # Temporary skip until uncertainty module is adapted + return + # other_dem = dem.copy() + # args = {"other_elev": other_dem} + + # Apply method for each class + output_dem = getattr(dem, method)(**args) + output_ds = getattr(ds.dem, method)(**args) + + # Determine if operation was in-place or not + inplace = method in ["set_vcrs"] + + # If yes, outputs should be None, and we'll check loading behaviour for inputs as if they were outputs + if inplace: + assert output_dem is None + assert output_ds is None + output_ds = ds + output_dem = dem + # If no, we check input status + else: + # Check using method did or did not load the input DEM or Xarray dataset, following expected values + should_input_be_loaded = should_be_loaded( + method=method, + args=args, + noload=self.methods_input_noload, + noload_allowed_args=self.methods_input_noload_allowed_args, + ) + assert dem.is_loaded is should_input_be_loaded + assert ds._in_memory is should_input_be_loaded + + # In the case of a DEM / DataArray output, check if output is loaded or not + # (for in-place methods, we now check the mutated input objects) + if isinstance(output_ds, xr.DataArray): + should_output_be_loaded = should_be_loaded( + method=method, + args=args, + noload=self.methods_output_noload, + noload_allowed_args=self.methods_output_noload_allowed_args, + ) + assert output_dem.is_loaded is should_output_be_loaded + assert output_ds._in_memory is should_output_be_loaded + + # Finally, assert exact equality of outputs + # (in case of DEM; this will load all the data, so has to come at the end) + assert_output_equal(output_dem, output_ds) + + class_methods_and_kwargs = [ + ( + "from_array", + { + "data": np.ones((5, 5)), + "transform": DEM.from_array( + data=np.ones((5, 5)), + transform=(1, 0, 0, 0, -1, 5), + crs=CRS.from_epsg(4326), + ).transform, + "crs": CRS.from_epsg(4326), + "nodata": -9999, + "tags": {"metadata": "test"}, + "area_or_point": "Point", + }, + ), + ] + + @pytest.mark.parametrize("method, kwargs", [(f, k) for f, k in class_methods_and_kwargs]) + def test_classmethods__equality(self, method: str, kwargs: dict[str, Any]) -> None: + """Test class methods output exactly the same objects. Loading always happens for class methods.""" + + # Accessor only uses this internally, but we expose it as a class method anyway + output_dem = getattr(DEM, method)(**kwargs) + output_ds = getattr(DEMAccessor, method)(**kwargs) + + assert_output_equal(output_dem, output_ds) + + def test_methods__test_coverage(self) -> None: + """Test that checks that all existing DEMBase methods are tested above.""" + + # Compare tested methods from above list of tuples to all methods derived from class dictionary + methods_1 = [m[0] for m in self.methods_and_kwargs] + methods_2 = [m[0] for m in self.class_methods_and_kwargs] + list_missing = [method for method in self.methods if method not in methods_1 + methods_2] + + if len(list_missing) != 0: + raise AssertionError(f"DEMBase methods not covered by tests: {list_missing}") + + chunked_methods_and_args = ( + ("to_vcrs", {"vcrs": "EGM96", "force_source_vcrs": "Ellipsoid"}), + ("slope", {}), + ("aspect", {}), + ("hillshade", {}), + ("curvature", {}), + ("profile_curvature", {}), + ("planform_curvature", {}), + ("tangential_curvature", {}), + ("flowline_curvature", {}), + ("max_curvature", {}), + ("min_curvature", {}), + ("texture_shading", {}), + ("topographic_position_index", {}), + ("terrain_ruggedness_index", {}), + ("roughness", {}), + ("rugosity", {}), + ("fractal_roughness", {}), + ("get_terrain_attribute", {"attribute": ["slope", "aspect"]}), + ) + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) # type: ignore + @pytest.mark.parametrize("method, kwargs", [(f, k) for f, k in chunked_methods_and_args]) # type: ignore + def test_chunked_methods__equality_loading_laziness( + self, path_dem: str, method: str, kwargs: dict[str, Any] + ) -> None: + """ + Test that DEMBase-specific chunked methods have the exact same output, loading mechanism and laziness. + + They should yield the exact same output for: + - In-memory, + - Dask backend through Xarray accessor, + - Multiprocessing backend through DEM class. + + Dask array should remain delayed before compute, and Multiprocessing output remains unloaded. + """ + + pytest.importorskip("dask") + import dask.array as da + + # Open lazily with Dask + ds = open_dem(path_dem, chunks={"band": 1, "x": 25, "y": 25}) + # Open DEM that will be processed using Multiprocessing + mp_config = MultiprocConfig(chunk_size=25) + dem = DEM(path_dem) + # Open and load both DataArray/DEM with NumPy + ds2 = open_dem(path_dem) + ds2.load() + dem2 = DEM(path_dem) + dem2.load() + + args = kwargs.copy() + + # Apply method for each + output_dem = getattr(dem, method)(**args, mp_config=mp_config) + output_ds = getattr(ds.dem, method)(**args) + output_dem2 = getattr(dem2, method)(**args) + output_ds2 = getattr(ds2.dem, method)(**args) + + # For a raster-type output + if isinstance(output_dem, Raster): + + # 1/ For Dask object: both inputs and outputs should be unloaded + lazy, and compute + # Input + assert not ds._in_memory + assert isinstance(ds.data, da.Array) + assert ds.data.chunks is not None + # Output + assert not output_ds._in_memory + assert isinstance(output_ds.data, da.Array) + assert output_ds.data.chunks is not None + # Output computes successfully, and is then loaded in memory + output_ds = output_ds.compute() + assert isinstance(output_ds.data, np.ndarray) + assert output_ds._in_memory + + # 2/ For Multiprocessing, same for loading + assert not dem.is_loaded + assert not output_dem.is_loaded + + # 3/ For non-Dask array, both should be loaded + assert ds2._in_memory + assert isinstance(ds2.data, np.ndarray) + assert output_ds2._in_memory + assert isinstance(output_ds2.data, np.ndarray) + + # 4/ For DEM, same + assert dem2.is_loaded + assert output_dem2.is_loaded + + # For an array-type output + elif isinstance(output_dem, np.ndarray): + + # 1/ For Dask object: both inputs and outputs should be unloaded + lazy, and compute + # Input + assert not ds._in_memory + assert isinstance(ds.data, da.Array) + assert ds.data.chunks is not None + # Output + assert isinstance(output_ds, da.Array) + assert output_ds.chunks is not None + # Output computes successfully, and is then loaded in memory + output_ds = output_ds.compute() + assert isinstance(output_ds, np.ndarray) + + # 2/ For Multiprocessing, same for loading + assert not dem.is_loaded + assert isinstance(output_dem, np.ndarray) + + # 3/ For non-Dask array, both should be loaded + assert ds2._in_memory + assert isinstance(ds2.data, np.ndarray) + assert isinstance(output_ds2, np.ndarray) + + # 4/ For DEM, same + assert dem2.is_loaded + assert isinstance(output_dem2, np.ndarray) + + # For a list of raster-type outputs + elif isinstance(output_dem, list): + + # Dask output may be a list of delayed raster-like objects + assert not ds._in_memory + assert not dem.is_loaded + + output_ds = [out.compute() if hasattr(out, "compute") else out for out in output_ds] + + assert ds2._in_memory + assert dem2.is_loaded + + # Check all outputs are exactly the same + # Texture shading currently does not give an identical output when chunked + if method == "texture_shading": + return + assert_output_equal(output_dem, output_ds, use_allclose=True) + assert_output_equal(output_dem, output_dem2, use_allclose=True, strict_masked=False) + assert_output_equal(output_dem, output_ds2, use_allclose=True) \ No newline at end of file diff --git a/tests/test_dem.py b/tests/test_dem/test_dem.py similarity index 88% rename from tests/test_dem.py rename to tests/test_dem/test_dem.py index 45c83c8ad..94c4ab062 100644 --- a/tests/test_dem.py +++ b/tests/test_dem/test_dem.py @@ -45,7 +45,7 @@ def test_init(self) -> None: # Check all attributes attrs = [at for at in _default_rio_attrs if at not in ["name", "dataset_mask", "driver"]] - all_attrs = attrs + xdem.dem.dem_attrs + all_attrs = attrs + ["vcrs"] for attr in all_attrs: attrs_per_dem = [idem.__getattribute__(attr) for idem in list_dem] assert all(at == attrs_per_dem[0] for at in attrs_per_dem) @@ -85,7 +85,7 @@ def test_init__vcrs(self, tmp_path: Path) -> None: # Setting a vertical CRS during instantiation should work here dem = DEM(fn_img, vcrs="EGM96") - assert dem.vcrs_name == "EGM96 height" + assert dem._vcrs_name == "EGM96 height" # Tests 2: instantiation with a file that has a 3D CRS # Create such a file @@ -103,8 +103,7 @@ def test_init__vcrs(self, tmp_path: Path) -> None: # Check that a warning is raised when trying to override with user input with pytest.warns( UserWarning, - match="The CRS in the raster metadata already has a vertical component, " - "the user-input 'EGM08' will override it.", + match="The CRS in the raster metadata.*", ): DEM(temp_file, vcrs="EGM08") @@ -116,8 +115,7 @@ def test_from_array(self) -> None: transform = rio.transform.from_bounds(0, 0, 1, 1, 5, 5) crs = CRS("EPSG:4326") nodata = -9999 - vcrs = "EGM08" - dem = DEM.from_array(data=data, transform=transform, crs=crs, nodata=nodata, vcrs=vcrs) + dem = DEM.from_array(data=data, transform=transform, crs=crs, nodata=nodata) # Check output matches assert isinstance(dem, DEM) @@ -126,37 +124,6 @@ def test_from_array(self) -> None: assert dem.transform == transform assert dem.crs == crs assert dem.nodata == nodata - assert dem.vcrs == xdem.vcrs._vcrs_from_user_input(vcrs_input=vcrs) - - def test_from_array__vcrs(self) -> None: - """Test that overridden from_array rightly sets the vertical CRS.""" - - # Create a 5x5 DEM with a 2D CRS - transform = rio.transform.from_bounds(0, 0, 1, 1, 5, 5) - dem = DEM.from_array(data=np.ones((5, 5)), transform=transform, crs=CRS("EPSG:4326"), nodata=None, vcrs=None) - assert dem.vcrs is None - - # One with a 3D ellipsoid CRS - dem = DEM.from_array(data=np.ones((5, 5)), transform=transform, crs=CRS("EPSG:4979"), nodata=None, vcrs=None) - assert dem.vcrs == "Ellipsoid" - - # One with a 2D and the ellipsoid vertical CRS - dem = DEM.from_array( - data=np.ones((5, 5)), transform=transform, crs=CRS("EPSG:4326"), nodata=None, vcrs="Ellipsoid" - ) - assert dem.vcrs == "Ellipsoid" - - # One with a compound CRS - dem = DEM.from_array( - data=np.ones((5, 5)), transform=transform, crs=CRS("EPSG:4326+5773"), nodata=None, vcrs=None - ) - assert dem.vcrs == CRS("EPSG:5773") - - # One with a CRS and vertical CRS - dem = DEM.from_array( - data=np.ones((5, 5)), transform=transform, crs=CRS("EPSG:4326"), nodata=None, vcrs=CRS("EPSG:5773") - ) - assert dem.vcrs == CRS("EPSG:5773") def test_from_array__cast_mask(self) -> None: """Test that DEMs are cast into mask for a logical operation.""" @@ -193,7 +160,7 @@ def test_copy(self) -> None: # using list directly available in Class attrs = [at for at in _default_rio_attrs if at not in ["name", "dataset_mask", "driver", "profile"]] - all_attrs = attrs + xdem.dem.dem_attrs + all_attrs = attrs + ["vcrs"] for attr in all_attrs: assert r.__getattribute__(attr) == r2.__getattribute__(attr) @@ -222,31 +189,30 @@ def test_set_vcrs(self) -> None: # Check setting ellipsoid dem.set_vcrs(new_vcrs="Ellipsoid") - assert dem.vcrs_name is not None - assert "Ellipsoid (No vertical CRS)." in dem.vcrs_name - assert dem.vcrs_grid is None + assert dem._vcrs_name is not None + assert "Ellipsoid (No vertical CRS)." in dem._vcrs_name # Check setting EGM96 dem.set_vcrs(new_vcrs="EGM96") - assert dem.vcrs_name == "EGM96 height" - assert dem.vcrs_grid == "us_nga_egm96_15.tif" + assert dem._vcrs_name == "EGM96 height" + assert dem._vcrs_grid is None # Check setting EGM08 dem.set_vcrs(new_vcrs="EGM08") - assert dem.vcrs_name == "EGM2008 height" - assert dem.vcrs_grid == "us_nga_egm08_25.tif" + assert dem._vcrs_name == "EGM2008 height" + assert dem._vcrs_grid is None # -- Test 2: we check with grids -- # Most grids aren't going to be downloaded, so this warning can be raised warnings.filterwarnings("ignore", category=UserWarning, message="Grid .*") dem.set_vcrs(new_vcrs="us_nga_egm96_15.tif") - assert dem.vcrs_name == "unknown using geoidgrids=us_nga_egm96_15.tif" - assert dem.vcrs_grid == "us_nga_egm96_15.tif" + assert dem._vcrs_name == "unknown using geoidgrids=us_nga_egm96_15.tif" + assert dem._vcrs_grid == "us_nga_egm96_15.tif" dem.set_vcrs(new_vcrs="us_nga_egm08_25.tif") - assert dem.vcrs_name == "unknown using geoidgrids=us_nga_egm08_25.tif" - assert dem.vcrs_grid == "us_nga_egm08_25.tif" + assert dem._vcrs_name == "unknown using geoidgrids=us_nga_egm08_25.tif" + assert dem._vcrs_grid == "us_nga_egm08_25.tif" # Check that other existing grids are well detected in the pyproj.datadir dem.set_vcrs(new_vcrs="is_lmi_Icegeoid_ISN93.tif") @@ -272,13 +238,14 @@ def test_to_vcrs(self) -> None: # Set ellipsoid as vertical reference dem.set_vcrs(new_vcrs="Ellipsoid") - ccrs_init = dem.ccrs + crs_init = dem.crs median_before = np.nanmean(dem) # Transform to EGM96 geoid not inplace (default) trans_dem = dem.to_vcrs(vcrs="EGM96") - # The output should be a DEM, input shouldn't have changed + # The output should be a DEM, input shouldn't have changed except the CRS into 3D assert isinstance(trans_dem, DEM) + dem_before_trans.set_crs(dem.crs) assert dem.raster_equal(dem_before_trans) # Compare to inplace @@ -293,8 +260,8 @@ def test_to_vcrs(self) -> None: assert median_after - median_before == pytest.approx(-32, rel=0.1) # Check that the results are consistent with the operation done independently - ccrs_dest = xdem.vcrs._build_ccrs_from_crs_and_vcrs(dem.crs, xdem.vcrs._vcrs_from_user_input("EGM96")) - transformer = Transformer.from_crs(crs_from=ccrs_init, crs_to=ccrs_dest, always_xy=True) + crs_dest = xdem.vcrs._build_ccrs_from_crs_and_vcrs(dem.crs, xdem.vcrs._vcrs_from_user_input("EGM96")) + transformer = Transformer.from_crs(crs_from=crs_init, crs_to=crs_dest, always_xy=True) xx, yy = dem.coords() x = xx[5, 5] @@ -372,7 +339,7 @@ def test_terrain_attributes_wrappers(self, terrain_attribute: str) -> None: assert dem_class_attr.raster_equal(terrain_module_attr) def test_info_2dcrs(self) -> None: - """Tests info function with the new Coordinate system line on dem with 2D CRS""" + """Tests info function through GeoUtils, for 3D info.""" dem_path = xdem.examples.get_path_test("longyearbyen_ref_dem") raster = gu.Raster(dem_path) @@ -399,14 +366,14 @@ def test_info_2dcrs(self) -> None: assert raster_infos_arrays[line] == dem_infos_array[line] # Verify Coordinate system value - assert complete_line[len(crs_key) :].strip() == "['EPSG:25833', 'None']" + assert complete_line[len(crs_key):].strip() == "['ETRS89 / UTM zone 33N']" # Verify new VCRS value with this 2D CRS DEM dem.set_vcrs(new_vcrs="EGM96") dem_infos_array = dem.info(verbose=False).split("\n") complete_line = dem_infos_array[crs_line[0]] assert complete_line.startswith(crs_key) - assert complete_line[len(crs_key) :].strip() == "['EPSG:25833', 'EPSG:5773']" + assert complete_line[len(crs_key):].strip() == "['Horizontal: ETRS89 / UTM zone 33N; Vertical: EGM96 height']" @pytest.mark.skip() def test_info_3dcrs(self) -> None: diff --git a/tests/test_dem/test_xr_accessor.py b/tests/test_dem/test_xr_accessor.py new file mode 100644 index 000000000..05e46f941 --- /dev/null +++ b/tests/test_dem/test_xr_accessor.py @@ -0,0 +1,83 @@ +""" +Test module for 'dem' Xarray accessor mirroring DEM API. +Most function tests are actually located in "test_base", to check consistently for equality, loading and lazy behaviour +across the entire API. +""" +import numpy as np +import pytest + +from xdem import examples, open_dem + + +class TestAccessor: + """ + Test for Xarray accessor subclass. + + Note: This test class only tests functionalities that are specific to the DEMAccessor subclass. Overridden + abstract methods, loading behaviour and Dask laziness are tested in test_base directly to mirror DEM tests. + + This class thus tests: + - The open_dem function, + - The instantiation __init__ through ds.dem, + - The to_geoutils() method. + """ + + longyearbyen_path = examples.get_path_test("longyearbyen_ref_dem") + + def test_open_raster(self) -> None: + pass + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) + def test_copy(self, path_dem: str) -> None: + + ds = open_dem(path_dem) + ds_copy = ds.rst.copy() + + assert np.array_equal(ds.data, ds_copy.data, equal_nan=True) + assert ds.rst.transform == ds_copy.rst.transform + assert ds.rst.crs == ds_copy.rst.crs + assert ds.rst.nodata == ds_copy.rst.nodata + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) + def test_open__loaded(self, path_dem: str) -> None: + """ + Test that a DataArray opened using "open_raster" maintains implicit loading logic. + + Tests checking loading for all attributes and methods are done in TestBase. + + Note: this is different from using lazy Dask arrays: for any array type, Xarray only loads metadata, and + implicitly loads data in memory when .data or .load() is called. + """ + + # Open raster with/without chunks, should not load in memory yet + ds = open_dem(path_dem) + assert not ds._in_memory + + # The array should be NumPy + assert isinstance(ds.data, np.ndarray) + ds.load() + assert ds._in_memory + + @pytest.mark.parametrize("path_dem", [longyearbyen_path]) + def test_open__dask(self, path_dem: str) -> None: + """ + Check that a DataArray opened with chunks using "open_raster" maintains Dask laziness. + + Note: this is different from loading mechanism of Xarray (triggers when calling .data). + """ + pytest.importorskip("dask") + import dask.array as da + + # Open raster lazily with chunks + ds = open_dem(path_dem, chunks={"band": 1, "x": 10, "y": 10}) + + # Array should be a Dask array (chunks exist) + ds_arr = ds.data + assert not ds._in_memory + assert isinstance(ds_arr, da.Array) + assert ds_arr.chunks is not None + + # After compute, it should be a NumPy array + ds_comp = ds.compute() + assert isinstance(ds_comp.data, np.ndarray) + assert ds_comp._in_memory diff --git a/tests/test_epc/test_epc.py b/tests/test_epc/test_epc.py index ebe61afbc..dae20d439 100644 --- a/tests/test_epc/test_epc.py +++ b/tests/test_epc/test_epc.py @@ -339,17 +339,17 @@ def test_coregister_3d(coreg_method: Any, expected_pipeline_types: Any) -> None: dem_ref = DEM(fn_ref) dem_tba = DEM(fn_tba) - epc_tba = dem_tba.to_pointcloud(subsample=5000, random_state=42) - epc_ref = dem_ref.to_pointcloud(subsample=5000, random_state=42) + epc_tba = dem_tba.to_pointcloud() + epc_ref = dem_ref.to_pointcloud() # Run coregistration with EPC as reference - epc_aligned = epc_tba.coregister_3d(dem_ref, coreg_method=coreg_method, random_state=42) + epc_aligned = epc_tba.coregister_3d(dem_ref, coreg_method=coreg_method) assert isinstance(epc_aligned, xdem.EPC) assert isinstance(coreg_method, xdem.coreg.Coreg) # Run coregistration with EPC as to-be-aligned - dem_aligned = dem_tba.coregister_3d(epc_ref, coreg_method=coreg_method, random_state=42) + dem_aligned = dem_tba.coregister_3d(epc_ref, coreg_method=coreg_method) assert isinstance(dem_aligned, xdem.DEM) assert isinstance(coreg_method, xdem.coreg.Coreg) @@ -368,7 +368,7 @@ def test_coregister_3d__raises(self) -> None: dem_ref = DEM(fn_ref) dem_tba = DEM(fn_tba) - epc_tba = dem_tba.to_pointcloud(subsample=5000) + epc_tba = dem_tba.to_pointcloud() coreg_method = xdem.coreg.Deramp() diff --git a/tests/test_terrain/test_terrain.py b/tests/test_terrain/test_terrain.py index 9f8a611c1..e4147508b 100644 --- a/tests/test_terrain/test_terrain.py +++ b/tests/test_terrain/test_terrain.py @@ -9,7 +9,7 @@ import numpy as np import pytest import rasterio as rio -from geoutils.raster.distributed_computing import MultiprocConfig +from geoutils.multiproc import MultiprocConfig from pyproj import CRS import xdem diff --git a/tests/test_vcrs.py b/tests/test_vcrs.py index 357032879..f97318147 100644 --- a/tests/test_vcrs.py +++ b/tests/test_vcrs.py @@ -8,12 +8,15 @@ from typing import Any import numpy as np +import xarray as xr import pytest from pyproj import CRS +import geoutils as gu +from geoutils.multiproc import MultiprocConfig import xdem import xdem.vcrs - +from xdem import examples class TestVCRS: def test_parse_vcrs_name_from_product(self) -> None: @@ -54,6 +57,27 @@ def test_vcrs_from_crs(self, input_output: tuple[CRS, CRS]) -> None: else: assert vcrs is None + @pytest.mark.parametrize( + "crs, expected", + [ + # Compound CRS with vertical meters + (CRS("EPSG:4326+5773"), "m"), # WGS84 + EGM96 height + # Vertical CRS alone + (CRS("EPSG:5773"), "m"), # EGM96 height + # Compound CRS projected + vertical + (CRS("EPSG:32633+5773"), "m"), # UTM 33N + EGM96 height + # Vertical CRS in feet (NAVD88) + (CRS("EPSG:6360"), "ftUS"), + # Pure 2D CRS (no vertical axis) + (CRS("EPSG:4326"), None), + (CRS("EPSG:32633"), None), + ], + ) + def test_vertical_unit_symbol(self, crs: CRS, expected: str | None) -> None: + """Test extraction of vertical unit symbols from CRS.""" + + assert xdem.vcrs.vertical_unit_symbol(crs) == expected + @pytest.mark.parametrize( "vcrs_input", [ @@ -72,7 +96,7 @@ def test_vcrs_from_user_input(self, vcrs_input: str | pathlib.Path | int | CRS) warnings.filterwarnings("ignore", category=UserWarning, message="Grid .*") # Get user input - vcrs = xdem.dem._vcrs_from_user_input(vcrs_input) + vcrs = xdem.vcrs._vcrs_from_user_input(vcrs_input) # Check output type assert isinstance(vcrs, CRS) @@ -145,7 +169,7 @@ def test_build_vcrs_from_grid__errors(self) -> None: # Test for WGS84 in 2D and 3D, UTM, CompoundCRS, everything should work @pytest.mark.parametrize("crs", [CRS("EPSG:4326"), CRS("EPSG:4979"), CRS("32610"), CRS("EPSG:4326+5773")]) - @pytest.mark.parametrize("vcrs_input", [CRS("EPSG:5773"), "is_lmi_Icegeoid_ISN93.tif", "EGM96"]) + @pytest.mark.parametrize("vcrs_input", [CRS("EPSG:5773"), "is_lmi_Icegeoid_ISN93.tif", "EGM96", "Ellipsoid"]) def test_build_ccrs_from_crs_and_vcrs(self, crs: CRS, vcrs_input: CRS | str) -> None: """Test the function build_ccrs_from_crs_and_vcrs.""" @@ -180,7 +204,8 @@ def test_build_ccrs_from_crs_and_vcrs(self, crs: CRS, vcrs_input: CRS | str) -> ccrs = xdem.vcrs._build_ccrs_from_crs_and_vcrs(crs=crs, vcrs=vcrs) assert isinstance(ccrs, CRS) - assert ccrs.is_vertical + is_3d = len(ccrs.axis_info) == 3 + assert is_3d def test_build_ccrs_from_crs_and_vcrs__errors(self) -> None: """Test errors are correctly raised from the build_ccrs function.""" @@ -213,12 +238,97 @@ def test_transform_zz(self, grid_shifts: dict[str, Any]) -> None: # Build the compound CRS vcrs_to = xdem.vcrs._vcrs_from_user_input(vcrs_input=grid_shifts["grid"]) ccrs_to = xdem.vcrs._build_ccrs_from_crs_and_vcrs(crs=crs_from, vcrs=vcrs_to) - + transformer = xdem.vcrs._build_vertical_transformer(crs_from=ccrs_from, crs_to=ccrs_to) # Apply the transformation - zz_trans = xdem.vcrs._transform_zz(crs_from=ccrs_from, crs_to=ccrs_to, xx=xx, yy=yy, zz=zz) + zz_trans = xdem.vcrs._transform_zz(transformer=transformer, xx=xx, yy=yy, zz=zz) # Compare the elevation difference z_diff = 100 - zz_trans # Check the shift is the one expect within 10% assert z_diff == pytest.approx(grid_shifts["shift"], rel=0.1) + + +class TestToVCRSChunked: + + @pytest.mark.parametrize( + "force_source_vcrs, dst_vcrs", + [ + ("EGM96", "Ellipsoid"), + ("Ellipsoid", "EGM96"), + ], + ids=["egm96_to_ellipsoid", "ellipsoid_to_egm96"], + ) + def test_to_vcrs_chunked_backends_equal( + self, + force_source_vcrs: str, + dst_vcrs: str, + ) -> None: + """ + Test that to_vcrs yields identical or nearly identical output for base (in-memory), + chunked Dask, and Multiprocessing backends. + + Notes: + - Vertical transforms are pointwise, so outputs should generally match exactly. + - We still use a small tolerance to remain robust to backend-specific casting/order. + """ + + pytest.importorskip("dask") + import dask.array as da + + # Get DEM path + path_dem = examples.get_path_test("longyearbyen_ref_dem") + + # 1/ Open test files + # DEM base input (in-memory) + dem_base = xdem.DEM(path_dem) + dem_base.load() + + # Xarray base input (in-memory data array) + xr_base = gu.open_raster(path_dem) + xr_base.load() + + # Multiprocessing input (keep lazy) + dem_mp = xdem.DEM(path_dem) + mp_config = MultiprocConfig(chunk_size=10) + + # Dask input (lazy) + ds = gu.open_raster(path_dem, chunks={"x": 10, "y": 10}) + assert not ds._in_memory + assert isinstance(ds.data, da.Array) + assert ds.data.chunks is not None + + # 2/ Compute transforms and check output laziness + # DEM base + base_dem = dem_base.to_vcrs(vcrs=dst_vcrs, force_source_vcrs=force_source_vcrs) + assert isinstance(base_dem, xdem.DEM) + + # Xarray base + base_xr = xr_base.dem.to_vcrs(vcrs=dst_vcrs, force_source_vcrs=force_source_vcrs) + assert isinstance(base_xr, xr.DataArray) + + # Multiprocessing + mp_dem = dem_mp.to_vcrs( + vcrs=dst_vcrs, + force_source_vcrs=force_source_vcrs, + mp_config=mp_config, + ) + assert isinstance(mp_dem, xdem.DEM) + assert not mp_dem.is_loaded + + # Dask + dask_dem = ds.dem.to_vcrs(vcrs=dst_vcrs, force_source_vcrs=force_source_vcrs) + assert isinstance(dask_dem, xr.DataArray) + assert isinstance(dask_dem.data, da.Array) + + # Inputs also stay lazy where expected + assert not dem_mp.is_loaded + assert not ds._in_memory + assert isinstance(ds.data, da.Array) + + # 3/ Compare outputs + # Vertical CRS transform is pointwise, so differences should be tiny + dask_dem = dask_dem.compute() + assert base_dem.raster_equal(dask_dem, warn_failure_reason=True, strict_masked=False) + assert base_dem.raster_equal(mp_dem, warn_failure_reason=True, strict_masked=False) + assert base_dem.raster_equal(base_xr, warn_failure_reason=True, strict_masked=False) \ No newline at end of file diff --git a/tests/test_workflows/test_topo.py b/tests/test_workflows/test_topo.py index 70fb2d73c..24b9ead5b 100644 --- a/tests/test_workflows/test_topo.py +++ b/tests/test_workflows/test_topo.py @@ -157,7 +157,6 @@ def test_run(get_topo_inputs_config, tmp_path): "Data types": "float32", "Driver": "GTiff", "Filename": xdem.examples.get_path_test("longyearbyen_tba_dem"), - "Grid size": None, "Height": 54, "Nodata Value": -9999.0, "Number of band": (1,), diff --git a/tests/test_workflows/test_workflows.py b/tests/test_workflows/test_workflows.py index f93a3dae5..3fbc798b9 100644 --- a/tests/test_workflows/test_workflows.py +++ b/tests/test_workflows/test_workflows.py @@ -269,7 +269,7 @@ def test_load_dem(get_dem_config, from_vcrs, to_vcrs): # Check output_dem vcrs reference if to_vcrs == "EGM96" or (to_vcrs is None and from_vcrs == "EGM96"): - assert output_dem.vcrs_name == "EGM96 height" + assert output_dem._vcrs_name == "EGM96 height" elif to_vcrs == "Ellipsoid" or (to_vcrs is None and from_vcrs == "Ellipsoid"): assert output_dem.vcrs == "Ellipsoid" else: @@ -277,7 +277,10 @@ def test_load_dem(get_dem_config, from_vcrs, to_vcrs): # Check output_dem if from_vcrs == to_vcrs: - assert output_dem.raster_equal(input_dem) + # Need to convert input to the forced CRS, if it exists + if from_vcrs is not None: + input_dem.set_vcrs(from_vcrs) + assert output_dem.raster_equal(input_dem, warn_failure_reason=True) # About 32 meters of difference in Svalbard between EGM96 geoid and ellipsoid if to_vcrs == "Ellipsoid" and from_vcrs == "EGM96": diff --git a/xdem/__init__.py b/xdem/__init__.py index 81fd1a794..33011bd3b 100644 --- a/xdem/__init__.py +++ b/xdem/__init__.py @@ -18,7 +18,8 @@ from xdem import coreg, dem, examples, fit, spatialstats, terrain, volume # noqa from xdem.ddem import dDEM # noqa -from xdem.dem import DEM # noqa +from xdem.dem import DEM, xr_accessor # noqa +from xdem.dem.xr_accessor import open_dem # noqa from xdem.demcollection import DEMCollection # noqa from xdem.epc import EPC # noqa diff --git a/xdem/coreg/affine.py b/xdem/coreg/affine.py index 8d6430a5e..ab8bf9b0c 100644 --- a/xdem/coreg/affine.py +++ b/xdem/coreg/affine.py @@ -33,8 +33,8 @@ import scipy.optimize import scipy.spatial from geoutils._typing import Number -from geoutils.interface.interpolate import _interp_points -from geoutils.raster.georeferencing import _coords, _res +from geoutils.interface.interpolation import _interp_points_base +from geoutils.raster.referencing import _coords, _res from geoutils.stats import nmad from xdem._misc import get_progress @@ -206,7 +206,7 @@ def sub_dh_interpolator(shift_x: float, shift_y: float) -> NDArrayf: # Interpolate raster array to the subsample point coordinates # Convert ref or tba depending on which is the point dataset - rst_elev_interpolator = _interp_points( + rst_elev_interpolator = _interp_points_base( array=rst_elev, transform=transform, area_or_point=area_or_point, @@ -234,7 +234,7 @@ def sub_dh_interpolator(shift_x: float, shift_y: float) -> NDArrayf: if aux_vars is not None: sub_bias_vars = {} for var in aux_vars.keys(): - sub_bias_vars[var] = _interp_points( + sub_bias_vars[var] = _interp_points_base( array=aux_vars[var], transform=transform, points=sub_coords, area_or_point=area_or_point ) else: diff --git a/xdem/coreg/base.py b/xdem/coreg/base.py index a5e8b90ce..d1fb605b4 100644 --- a/xdem/coreg/base.py +++ b/xdem/coreg/base.py @@ -49,13 +49,13 @@ import scipy.optimize from geoutils import profiler from geoutils.interface.gridding import _grid_pointcloud -from geoutils.interface.interpolate import _interp_points +from geoutils.interface.interpolation import _interp_points_base from geoutils.pointcloud.pointcloud import PointCloud, PointCloudType from geoutils.raster import Raster, RasterType, raster -from geoutils.raster._geotransformations import _resampling_method_from_str +from geoutils.raster.transformation import _resampling_method_from_str from geoutils.raster.array import get_array_and_mask -from geoutils.raster.georeferencing import _cast_pixel_interpretation, _coords -from geoutils.raster.geotransformations import _translate +from geoutils.raster.referencing import _cast_pixel_interpretation, _coords +from geoutils.raster.transformation import _translate import xdem from xdem._typing import MArrayf, NDArrayb, NDArrayf @@ -597,12 +597,14 @@ def _get_subsample_on_valid_mask(params_random: InRandomDict, valid_mask: NDArra # Build a low memory masked array with invalid values masked to pass to subsampling ma_valid = np.ma.masked_array(data=np.ones(np.shape(valid_mask), dtype=bool), mask=~valid_mask) # Take a subsample within the valid values - indices = gu.stats.sampling.subsample_array( - ma_valid, - subsample=params_random["subsample"], - return_indices=True, - random_state=params_random["random_state"], - ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + indices = gu.stats.sampling._subsample_numpy( + ma_valid, + subsample=params_random["subsample"], + return_indices=True, + random_state=params_random["random_state"], + ) # We return a boolean mask of the subsample within valid values subsample_mask = np.zeros(np.shape(valid_mask), dtype=bool) @@ -697,7 +699,7 @@ def _get_subsample_mask_pts_rst( valid_mask = valid_mask.astype(np.float32) valid_mask[valid_mask == 0] = np.nan valid_mask = np.isfinite( - _interp_points(array=valid_mask, transform=transform, points=pts, area_or_point=area_or_point) + _interp_points_base(array=valid_mask, transform=transform, points=pts, area_or_point=area_or_point) ) # If there is a subsample, it needs to be done now on the point dataset to reduce later calculations @@ -760,7 +762,7 @@ def _subsample_on_mask( # Interpolate raster array to the subsample point coordinates # Convert ref or tba depending on which is the point dataset - sub_rst = _interp_points(array=rst_elev, transform=transform, points=pts, area_or_point=area_or_point) + sub_rst = _interp_points_base(array=rst_elev, transform=transform, points=pts, area_or_point=area_or_point) sub_pts = pts_elev[z_name].values[sub_mask] # Assign arrays depending on which one is the reference @@ -775,7 +777,7 @@ def _subsample_on_mask( if aux_vars is not None: sub_bias_vars = {} for var in aux_vars.keys(): - sub_bias_vars[var] = _interp_points( + sub_bias_vars[var] = _interp_points_base( array=aux_vars[var], transform=transform, points=pts, area_or_point=area_or_point ) else: @@ -1639,7 +1641,7 @@ def _reproject_horizontal_shift_samecrs( else: coords_dst = None - output = _interp_points( + output = _interp_points_base( array=raster_arr, area_or_point="Area", transform=src_transform, diff --git a/xdem/coreg/blockwise.py b/xdem/coreg/blockwise.py index 6856c9eb3..202d006d5 100644 --- a/xdem/coreg/blockwise.py +++ b/xdem/coreg/blockwise.py @@ -35,13 +35,12 @@ from geoutils.interface.gridding import _grid_pointcloud from geoutils.raster import Raster, RasterType from geoutils.raster.array import get_array_and_mask -from geoutils.raster.distributed_computing import ( +from geoutils.multiproc import ( MultiprocConfig, map_multiproc_collect, map_overlap_multiproc_save, + compute_tiling ) -from geoutils.raster.tiling import compute_tiling - from xdem._misc import import_optional from xdem._typing import MArrayf, NDArrayf from xdem.coreg.affine import NuthKaab diff --git a/xdem/dem/__init__.py b/xdem/dem/__init__.py new file mode 100644 index 000000000..f46018823 --- /dev/null +++ b/xdem/dem/__init__.py @@ -0,0 +1 @@ +from xdem.dem.dem import * # noqa diff --git a/xdem/dem.py b/xdem/dem/base.py similarity index 59% rename from xdem/dem.py rename to xdem/dem/base.py index 7fa2538d5..1d1c0b820 100644 --- a/xdem/dem.py +++ b/xdem/dem/base.py @@ -16,25 +16,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This module defines the DEM class.""" +"""Module of DEMBase class, parent of DEM class and 'dem' accessor.""" from __future__ import annotations import pathlib import warnings -from typing import Any, Callable, Literal, overload +import re +from typing import Any, Callable, Literal, overload, TypeVar, Union import geopandas as gpd import geoutils as gu import numpy as np -import rasterio as rio from affine import Affine from geoutils import profiler from geoutils._typing import NDArrayNum +from geoutils._dispatch import has_geo_attr, get_geo_attr from geoutils.raster import Raster, RasterType -from geoutils.raster.distributed_computing import MultiprocConfig +from geoutils.raster.base import RasterBase +from geoutils.multiproc import MultiprocConfig from geoutils.stats import nmad from pyproj import CRS +import xarray as xr from pyproj.crs import CompoundCRS, VerticalCRS import xdem @@ -48,239 +51,42 @@ ) from xdem.vcrs import ( _build_ccrs_from_crs_and_vcrs, - _grid_from_user_input, - _parse_vcrs_name_from_product, - _transform_zz, + _to_vcrs_2d, _vcrs_from_crs, _vcrs_from_user_input, ) -dem_attrs = ["_vcrs", "_vcrs_name", "_vcrs_grid"] +# Input/output is a RasterType (= Raster or RasterAccessor subclass) +DEMType = TypeVar("DEMType", bound="DEMBase") +# For inputs, we also accept a xr.DataArray +DEMLike = Union["DEMBase", xr.DataArray] - -class DEM(Raster): # type: ignore +class DEMBase(RasterBase): """ - The digital elevation model. - - The DEM has a single main attribute in addition to that inherited from :class:`geoutils.Raster`: - vcrs: :class:`pyproj.VerticalCRS` - Vertical coordinate reference system of the DEM. - - Other derivative attributes are: - vcrs_name: :class:`str` - Name of vertical CRS of the DEM. - vcrs_grid: :class:`str` - Grid path to the vertical CRS of the DEM. - ccrs: :class:`pyproj.CompoundCRS` - Compound vertical and horizontal CRS of the DEM. - - The attributes inherited from :class:`geoutils.Raster` are: - data: :class:`np.ndarray` - Data array of the DEM, with dimensions corresponding to (count, height, width). - transform: :class:`affine.Affine` - Geotransform of the DEM. - crs: :class:`pyproj.crs.CRS` - Coordinate reference system of the DEM. - nodata: :class:`int` or :class:`float` - Nodata value of the DEM. - - All other attributes are derivatives of those attributes, or read from the file on disk. - See the API for more details. + This class is non-public and made to be subclassed. + + It is built on top of the RasterBase class. It implements all the functions shared by the DEM class and the + 'dem' Xarray accessor. """ - @profiler.profile("xdem.dem.__init__", memprof=True) - def __init__( - self, - filename_or_dataset: str | RasterType | rio.io.DatasetReader | rio.io.MemoryFile, - vcrs: Literal["Ellipsoid", "EGM08", "EGM96"] | VerticalCRS | str | pathlib.Path | int | None = None, - load_data: bool = False, - parse_sensor_metadata: bool = False, - silent: bool = True, - downsample: int = 1, - nodata: int | float | None = None, - ) -> None: + def __init__(self): """ - Instantiate a digital elevation model. - - The vertical reference of the DEM can be defined by passing the `vcrs` argument. - Otherwise, a vertical reference is tentatively parsed from the DEM product name. - - Inherits all attributes from the :class:`geoutils.Raster` class. - - :param filename_or_dataset: The filename of the dataset. - :param vcrs: Vertical coordinate reference system either as a name ("WGS84", "EGM08", "EGM96"), - an EPSG code or pyproj.crs.VerticalCRS, or a path to a PROJ grid file (https://github.com/OSGeo/PROJ-data). - :param load_data: Whether to load the array during instantiation. Default is False. - :param parse_sensor_metadata: Whether to parse sensor metadata from filename and similarly-named metadata files. - :param silent: Whether to display vertical reference parsing. - :param downsample: Downsample the array once loaded by a round factor. Default is no downsampling. - :param nodata: Nodata value to be used (overwrites the metadata). Default reads from metadata. + Initialize additional DEM metadata as None, for it to be overridden in sublasses. """ - self.data: NDArrayf + super().__init__() self._vcrs: VerticalCRS | Literal["Ellipsoid"] | None = None - self._vcrs_name: str | None = None - self._vcrs_grid: str | None = None - - # If DEM is passed, simply point back to DEM - if isinstance(filename_or_dataset, DEM): - for key in filename_or_dataset.__dict__: - setattr(self, key, filename_or_dataset.__dict__[key]) - return - # Else rely on parent Raster class options (including raised errors) - else: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="Parse metadata from file not implemented") - super().__init__( - filename_or_dataset, - load_data=load_data, - parse_sensor_metadata=parse_sensor_metadata, - silent=silent, - downsample=downsample, - nodata=nodata, - ) - - # Ensure DEM has only one band: self.bands can be None when data is not loaded through the Raster class - if self.bands is not None and len(self.bands) > 1: - raise ValueError( - "DEM rasters should be composed of one band only. Either use argument `bands` to specify " - "a single band on opening, or use .split_bands() on an opened raster." - ) - - # If the CRS in the raster metadata has a 3rd dimension, could set it as a vertical reference - vcrs_from_crs = _vcrs_from_crs(CRS(self.crs)) - if vcrs_from_crs is not None: - # If something was also provided by the user, user takes precedence - # (we leave vcrs as it was for input) - if vcrs is not None: - # Raise a warning if the two are not the same - vcrs_user = _vcrs_from_user_input(vcrs) - if not vcrs_from_crs == vcrs_user: - warnings.warn( - "The CRS in the raster metadata already has a vertical component, " - "the user-input '{}' will override it.".format(vcrs) - ) - # Otherwise, use the one from the raster 3D CRS - else: - vcrs = vcrs_from_crs - - # If no vertical CRS was provided by the user or defined in the CRS - if vcrs is None and "product" in self.tags: - vcrs = _parse_vcrs_name_from_product(self.tags["product"]) - - # If a vertical reference was parsed or provided by user - if vcrs is not None: - self.set_vcrs(vcrs) - - @overload - def info(self, stats: bool = False, *, verbose: Literal[True] = ...) -> None: ... - - @overload - def info(self, stats: bool = False, *, verbose: Literal[False]) -> str: ... - - def info(self, stats: bool = False, verbose: bool = True) -> None | str: - """ - Print summary information about the DEM. - - :param stats: Add statistics for each band of the dataset (max, min, median, mean, std. dev.). Default is to - not calculate statistics. - :param verbose: If set to True (default) will directly print to screen and return None - - :returns: Summary string or None. - """ - - # Get raster.info() - raster_info = super().info(stats=stats, verbose=False) # type: ignore - raster_info_split = raster_info.split("\n") - - # Change crs values if not 3D - if len(CRS(self.crs).axis_info) > 2: - new_crs = [CRS(self.crs).name] - else: - new_crs = [self.crs.to_string() if self.crs is not None else None, str(self.vcrs)] - - # Replace coordinate system line - cs_key_to_replace = "Coordinate system:" - line_cs = [raster_info_split.index(line) for line in raster_info_split if line.startswith(cs_key_to_replace)] - raster_info_split[line_cs[0]] = f"Coordinate system: {new_crs}" - - if verbose: - print("\n".join(raster_info_split)) - return None - else: - return "\n".join(raster_info_split) - - def copy(self, new_array: NDArrayf | None = None) -> DEM: - """ - Copy the DEM, possibly updating the data array. - - :param new_array: New data array. - - :return: Copied DEM. - """ - - new_dem = super().copy(new_array=new_array) # type: ignore - # The rest of attributes are immutable, including pyproj.CRS - for attrs in dem_attrs: - setattr(new_dem, attrs, getattr(self, attrs)) - - return new_dem # type: ignore - - @classmethod - def from_array( - cls: type[DEM], - data: NDArrayf | MArrayf, - transform: tuple[float, ...] | Affine, - crs: CRS | int | None, - nodata: int | float | None = None, - area_or_point: Literal["Area", "Point"] | None = None, - tags: dict[str, Any] = None, - cast_nodata: bool = True, - vcrs: ( - Literal["Ellipsoid"] | Literal["EGM08"] | Literal["EGM96"] | str | pathlib.Path | VerticalCRS | int | None - ) = None, - ) -> DEM: - """Create a DEM from a numpy array and the georeferencing information. - - :param data: Input array. - :param transform: Affine 2D transform. Either a tuple(x_res, 0.0, top_left_x, - 0.0, y_res, top_left_y) or an affine.Affine object. - :param crs: Coordinate reference system. Either a rasterio CRS, or an EPSG integer. - :param nodata: Nodata value. - :param area_or_point: Pixel interpretation of the raster, will be stored in AREA_OR_POINT metadata. - :param tags: Metadata stored in a dictionary. - :param cast_nodata: Automatically cast nodata value to the default nodata for the new array type if not - compatible. If False, will raise an error when incompatible. - :param vcrs: Vertical coordinate reference system. - - :returns: DEM created from the provided array and georeferencing. - """ - rast = Raster.from_array( - data=data, - transform=transform, - crs=crs, - nodata=nodata, - area_or_point=area_or_point, - tags=tags, - cast_nodata=cast_nodata, - ) - - return cls(filename_or_dataset=rast, vcrs=vcrs) + self._data: NDArrayf @property def vcrs(self) -> VerticalCRS | Literal["Ellipsoid"] | None: - """Vertical coordinate reference system of the DEM.""" - - return self._vcrs - - @property - def vcrs_grid(self) -> str | None: - """Grid path of vertical coordinate reference system of the DEM.""" - - return self._vcrs_grid + """ + Vertical coordinate reference system of the DEM. + """ + return _vcrs_from_crs(self.crs) @property - def vcrs_name(self) -> str | None: + def _vcrs_name(self) -> str | None: """Name of vertical coordinate reference system of the DEM.""" if self.vcrs is not None: @@ -296,6 +102,24 @@ def vcrs_name(self) -> str | None: return vcrs_name + @property + def _vcrs_grid(self) -> str | None: + """Human-readable vertical grid description of the DEM.""" + + if self.vcrs is None or isinstance(self.vcrs, str): + return None + + vcrs = CRS(self.vcrs) + + try: + op = vcrs.coordinate_operation + if op is not None and op.grids: + return op.grids[0].short_name + except Exception: + pass + + return None + def set_vcrs( self, new_vcrs: Literal["Ellipsoid"] | Literal["EGM08"] | Literal["EGM96"] | str | pathlib.Path | VerticalCRS | int, @@ -307,124 +131,57 @@ def set_vcrs( an EPSG code or pyproj.crs.VerticalCRS, or a path to a PROJ grid file (https://github.com/OSGeo/PROJ-data). """ - # Get vertical CRS and set it and the grid - self._vcrs = _vcrs_from_user_input(vcrs_input=new_vcrs) - self._vcrs_grid = _grid_from_user_input(vcrs_input=new_vcrs) - - @property - def ccrs(self) -> CompoundCRS | CRS | None: - """Compound horizontal and vertical coordinate reference system of the DEM.""" - - if self.vcrs is not None: - ccrs = _build_ccrs_from_crs_and_vcrs(crs=self.crs, vcrs=self.vcrs) - return ccrs - else: - return None - - @overload - def to_vcrs( - self, - vcrs: Literal["Ellipsoid", "EGM08", "EGM96"] | str | pathlib.Path | VerticalCRS | int, - force_source_vcrs: ( - Literal["Ellipsoid", "EGM08", "EGM96"] | str | pathlib.Path | VerticalCRS | int | None - ) = None, - *, - inplace: Literal[False] = False, - ) -> DEM: ... - - @overload - def to_vcrs( - self, - vcrs: Literal["Ellipsoid", "EGM08", "EGM96"] | str | pathlib.Path | VerticalCRS | int, - force_source_vcrs: ( - Literal["Ellipsoid", "EGM08", "EGM96"] | str | pathlib.Path | VerticalCRS | int | None - ) = None, - *, - inplace: Literal[True], - ) -> None: ... - - @overload - def to_vcrs( - self, - vcrs: Literal["Ellipsoid", "EGM08", "EGM96"] | str | pathlib.Path | VerticalCRS | int, - force_source_vcrs: ( - Literal["Ellipsoid", "EGM08", "EGM96"] | str | pathlib.Path | VerticalCRS | int | None - ) = None, - *, - inplace: bool = False, - ) -> DEM | None: ... + # Get vertical CRS and re-set the CRS + new_vcrs = _vcrs_from_user_input(vcrs_input=new_vcrs) + new_crs = _build_ccrs_from_crs_and_vcrs(crs=self.crs, vcrs=new_vcrs) + self.set_crs(new_crs) def to_vcrs( self, - vcrs: Literal["Ellipsoid", "EGM08", "EGM96"] | str | pathlib.Path | VerticalCRS | int, - force_source_vcrs: ( - Literal["Ellipsoid", "EGM08", "EGM96"] | str | pathlib.Path | VerticalCRS | int | None - ) = None, - inplace: bool = False, - ) -> DEM | None: + vcrs: Literal["Ellipsoid", "EGM08", "EGM96"] | str | VerticalCRS | int, + force_source_vcrs: Literal["Ellipsoid", "EGM08", "EGM96"] | str | VerticalCRS | int | None = None, + mp_config: MultiprocConfig | None = None, + **kwargs: Any, + ) -> DEMLike: """ Convert the DEM to another vertical coordinate reference system. :param vcrs: Destination vertical CRS. Either as a name ("WGS84", "EGM08", "EGM96"), an EPSG code or pyproj.crs.VerticalCRS, or a path to a PROJ grid file (https://github.com/OSGeo/PROJ-data) :param force_source_vcrs: Force a source vertical CRS (uses metadata by default). Same formats as for `vcrs`. - :param inplace: Whether to return a new DEM (default) or the same DEM updated in-place. + :param mp_config: Multiprocessing configuration. :return: DEM with vertical reference transformed, or None. """ - if self.vcrs is None and force_source_vcrs is None: - raise ValueError( - "The current DEM has no vertical reference, define one with .set_vref() " - "or by passing `src_vcrs` to perform a conversion." - ) - - # Initial Compound CRS (only exists if vertical CRS is not None, as checked above) - if force_source_vcrs is not None: - # Warn if a vertical CRS already existed for that DEM - if self.vcrs is not None: - warnings.warn( - category=UserWarning, - message="Overriding the vertical CRS of the DEM with the one provided in `src_vcrs`.", - ) - src_ccrs = _build_ccrs_from_crs_and_vcrs(self.crs, vcrs=force_source_vcrs) + # Raise deprecation warning for old in-place behaviour + if "inplace" in kwargs and kwargs["inplace"]: + warnings.warn("Argument 'inplace' is deprecated and will be removed in future versions. " + "Use dem = dem.to_vcrs() instead.", + category=DeprecationWarning) + inplace = True else: - src_ccrs = self.ccrs + inplace = False - # New destination Compound CRS - dst_ccrs = _build_ccrs_from_crs_and_vcrs(self.crs, vcrs=_vcrs_from_user_input(vcrs_input=vcrs)) + # Apply transformation + new_dem = _to_vcrs_2d(dem=self, dst_vcrs=vcrs, force_source_vcrs=force_source_vcrs, mp_config=mp_config) - # If both compound CCRS are equal, do not run any transform - if src_ccrs.equals(dst_ccrs): - warnings.warn( - message="Source and destination vertical CRS are the same, skipping vertical transformation.", - category=UserWarning, - ) - return None - - # Transform elevation with new vertical CRS - zz = self.data - xx, yy = self.coords() - zz_trans = _transform_zz(crs_from=src_ccrs, crs_to=dst_ccrs, xx=xx, yy=yy, zz=zz) - new_data = zz_trans.astype(self.dtype) # type: ignore + # Keep logic below until we deprecate 'inplace' + # If early exit because no transformation was required + if new_dem is None: + if inplace: + return None + else: + return self.copy(deep=False) # If inplace, update DEM and vcrs if inplace: - self._data = new_data - self.set_vcrs(new_vcrs=vcrs) + self._data = new_dem.data + self.set_crs(new_crs=get_geo_attr(new_dem, "crs")) return None # Otherwise, return new DEM else: - return DEM.from_array( - data=new_data, - transform=self.transform, - crs=self.crs, - nodata=self.nodata, - area_or_point=self.area_or_point, - tags=self.tags, - vcrs=vcrs, - cast_nodata=False, - ) + return new_dem @copy_doc(terrain, remove_dem_res_params=True) def slope( @@ -621,13 +378,13 @@ def get_terrain_attribute(self, attribute: str | list[str], **kwargs: Any) -> Ra @profiler.profile("xdem.dem.coregister_3d", memprof=True) def coregister_3d( # type: ignore self, - reference_elev: DEM | gpd.GeoDataFrame | xdem.EPC, + reference_elev: DEMLike | gpd.GeoDataFrame | xdem.EPC, coreg_method: coreg.Coreg, inlier_mask: Raster | NDArrayb = None, bias_vars: dict[str, NDArrayf | MArrayf | RasterType] = None, random_state: int | np.random.Generator | None = None, **kwargs, - ) -> DEM: + ) -> DEMLike: """ Coregister DEM to a reference DEM in three dimensions. @@ -666,7 +423,7 @@ def coregister_3d( # type: ignore def estimate_uncertainty( self, - other_elev: DEM | gpd.GeoDataFrame, + other_elev: DEMLike | gpd.GeoDataFrame, stable_terrain: Raster | NDArrayb = None, approach: Literal["H2022", "R2009", "Basic"] = "H2022", precision_of_other: Literal["finer"] | Literal["same"] = "finer", @@ -720,7 +477,7 @@ def estimate_uncertainty( } # Elevation change with the other DEM or elevation point cloud - if isinstance(other_elev, DEM): + if has_geo_attr(other_elev, "transform"): dh = other_elev.reproject(self, silent=True) - self elif isinstance(other_elev, gpd.GeoDataFrame): other_elev = other_elev.to_crs(self.crs) @@ -806,4 +563,4 @@ def to_pointcloud( if isinstance(pc, gu.PointCloud): return xdem.EPC(pc) else: - return pc + return pc \ No newline at end of file diff --git a/xdem/dem/dem.py b/xdem/dem/dem.py new file mode 100644 index 000000000..bd4a4109b --- /dev/null +++ b/xdem/dem/dem.py @@ -0,0 +1,130 @@ +# Copyright (c) 2026 xDEM developers +# +# This file is part of xDEM project: +# https://github.com/glaciohack/xdem +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DEM class and functions.""" +from __future__ import annotations + +import pathlib +import warnings +from typing import Any, Literal + +import rasterio as rio +from affine import Affine +from geoutils.raster import RasterType, Raster +from pyproj import CRS +from pyproj.crs import VerticalCRS + +from xdem._typing import NDArrayf +from xdem.vcrs import ( + _parse_vcrs_name_from_product, + _check_vcrs_input +) +from xdem.dem.base import DEMBase + +class DEM(Raster, DEMBase): # type: ignore + """ + The digital elevation model. + + The DEM has a single main attribute in addition to that inherited from :class:`geoutils.Raster`: + vcrs: :class:`pyproj.VerticalCRS` + Vertical coordinate reference system of the DEM. + + Other derivative attributes are: + vcrs_name: :class:`str` + Name of vertical CRS of the DEM. + vcrs_grid: :class:`str` + Grid path to the vertical CRS of the DEM. + ccrs: :class:`pyproj.CompoundCRS` + Compound vertical and horizontal CRS of the DEM. + + The attributes inherited from :class:`geoutils.Raster` are: + data: :class:`np.ndarray` + Data array of the DEM, with dimensions corresponding to (count, height, width). + transform: :class:`affine.Affine` + Geotransform of the DEM. + crs: :class:`pyproj.crs.CRS` + Coordinate reference system of the DEM. + nodata: :class:`int` or :class:`float` + Nodata value of the DEM. + + All other attributes are derivatives of those attributes, or read from the file on disk. + See the API for more details. + """ + + def __init__( + self, + filename_or_dataset: str | RasterType | rio.io.DatasetReader | rio.io.MemoryFile, + vcrs: Literal["Ellipsoid", "EGM08", "EGM96"] | VerticalCRS | str | pathlib.Path | int | None = None, + load_data: bool = False, + parse_sensor_metadata: bool = False, + silent: bool = True, + downsample: int = 1, + force_nodata: int | float | None = None, + ) -> None: + """ + Instantiate a digital elevation model. + + The vertical reference of the DEM can be defined by passing the `vcrs` argument. + Otherwise, a vertical reference is tentatively parsed from the DEM product name. + + Inherits all attributes from the :class:`geoutils.Raster` class. + + :param filename_or_dataset: The filename of the dataset. + :param vcrs: Vertical coordinate reference system either as a name ("WGS84", "EGM08", "EGM96"), + an EPSG code or pyproj.crs.VerticalCRS, or a path to a PROJ grid file (https://github.com/OSGeo/PROJ-data). + :param load_data: Whether to load the array during instantiation. Default is False. + :param parse_sensor_metadata: Whether to parse sensor metadata from filename and similarly-named metadata files. + :param silent: Whether to display vertical reference parsing. + :param downsample: Downsample the array once loaded by a round factor. Default is no downsampling. + :param force_nodata: Force nodata value to be used (overwrites the metadata). Default reads from metadata. + """ + + self.data: NDArrayf + self._vcrs: VerticalCRS | Literal["Ellipsoid"] | None = None + + # If DEM is passed, simply point back to DEM + if isinstance(filename_or_dataset, DEM): + for key in filename_or_dataset.__dict__: + setattr(self, key, filename_or_dataset.__dict__[key]) + return + # Else rely on parent Raster class options (including raised errors) + else: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Parse metadata from file not implemented") + super().__init__( + filename_or_dataset, + load_data=load_data, + parse_sensor_metadata=parse_sensor_metadata, + silent=silent, + downsample=downsample, + force_nodata=force_nodata, + ) + + # Ensure DEM has only one band: self.bands can be None when data is not loaded through the Raster class + if self.bands is not None and len(self.bands) > 1: + raise ValueError( + "DEM rasters should be composed of one band only. Either use argument `bands` to specify " + "a single band on opening, or use .split_bands() on an opened raster." + ) + + # If no vertical CRS was provided by the user or defined in the CRS + if vcrs is None and "product" in self.tags: + vcrs = _parse_vcrs_name_from_product(self.tags["product"]) + + # Cast CRS with vertical CRS (returns 2D or 3D) and re-set + new_crs = _check_vcrs_input(vcrs, self.crs) + self.set_crs(new_crs) \ No newline at end of file diff --git a/xdem/dem/xr_accessor.py b/xdem/dem/xr_accessor.py new file mode 100644 index 000000000..15532e0e3 --- /dev/null +++ b/xdem/dem/xr_accessor.py @@ -0,0 +1,43 @@ +"""Xarray accessor 'dem' for digital elevation models.""" +from __future__ import annotations + +from typing import Literal + +from pyproj.crs import VerticalCRS +import xarray as xr + +from geoutils.raster.xr_accessor import RasterAccessor, open_raster + +from xdem.dem.base import DEMBase +from xdem.vcrs import _check_vcrs_input + + +def open_dem(filename: str, + vcrs: Literal["Ellipsoid", "EGM08", "EGM96"] | VerticalCRS | str | int | None = None, + **kwargs): + """Wrapper around open_raster with vertical CRS input support.""" + + # Use open raster + ds = open_raster(filename, **kwargs) + + # Cast CRS with user-input vertical CRS (returns 2D or 3D) and re-set + new_crs = _check_vcrs_input(vcrs, ds.rst.crs) + ds.rst.set_crs(new_crs) + + return ds + +@xr.register_dataarray_accessor("dem") +class DEMAccessor(RasterAccessor, DEMBase): + """ + This class defines the Xarray accessor 'dem' for digital elevation models. + + Most attributes and functionalities are inherited from the DEMBase class (also parent of the DEM class) and + RasterAccessor class defining the 'rst' Xarray accessor for rasters. + Only methods specific to the functioning of the 'dem' Xarray accessor live in this class: mostly initialization, + I/O or copying. + """ + def __init__(self, xarray_obj: xr.DataArray): + + super().__init__(xarray_obj=xarray_obj) + + self._obj = xarray_obj diff --git a/xdem/epc/epc.py b/xdem/epc/epc.py index 2267f1cf9..fb8312a38 100644 --- a/xdem/epc/epc.py +++ b/xdem/epc/epc.py @@ -40,6 +40,7 @@ _transform_zz, _vcrs_from_crs, _vcrs_from_user_input, + _build_vertical_transformer ) epc_attrs = ["_vcrs", "_vcrs_name", "_vcrs_grid"] @@ -263,7 +264,8 @@ def to_vcrs( # Transform elevation with new vertical CRS zz = self.data # type: ignore xx, yy = self.geometry.x.values, self.geometry.y.values - zz_trans = _transform_zz(crs_from=src_ccrs, crs_to=dst_ccrs, xx=xx, yy=yy, zz=zz) + transformer = _build_vertical_transformer(crs_from=src_ccrs, crs_to=dst_ccrs) + zz_trans = _transform_zz(transformer=transformer, xx=xx, yy=yy, zz=zz) new_data = zz_trans.astype(self.data.dtype) # type: ignore # If inplace, update EPC and vcrs diff --git a/xdem/fit.py b/xdem/fit.py index ecf0977fe..fd07bfb82 100644 --- a/xdem/fit.py +++ b/xdem/fit.py @@ -29,7 +29,7 @@ import numpy as np import scipy -from geoutils.stats.sampling import subsample_array +from geoutils.stats.sampling import _subsample_numpy from numpy.polynomial.polynomial import polyval, polyval2d from xdem._misc import import_optional @@ -400,7 +400,7 @@ def robust_norder_polynomial_fit( # Subsample data if subsample != 1: - subsamp = subsample_array(x, subsample=subsample, return_indices=True, random_state=random_state) + subsamp = _subsample_numpy(x, subsample=subsample, return_indices=True, random_state=random_state) x = x[subsamp] y = y[subsamp] diff --git a/xdem/spatialstats.py b/xdem/spatialstats.py index ed4f73e31..5f3f9ee38 100644 --- a/xdem/spatialstats.py +++ b/xdem/spatialstats.py @@ -35,7 +35,7 @@ import scipy.ndimage from geoutils.raster import Raster, RasterType from geoutils.raster.array import get_array_and_mask -from geoutils.stats.sampling import subsample_array +from geoutils.stats.sampling import _subsample_numpy from geoutils.vector.vector import Vector, VectorType from numpy.typing import ArrayLike from packaging.version import Version @@ -975,7 +975,7 @@ def _subsample_wrapper( values_sp = values coords_sp = coords - index = subsample_array(values_sp, subsample=subsample, return_indices=True, random_state=random_state) + index = _subsample_numpy(values_sp, subsample=subsample, return_indices=True, random_state=random_state) values_sub = values_sp[index[0]] coords_sub = coords_sp[index[0], :] diff --git a/xdem/terrain/terrain.py b/xdem/terrain/terrain.py index 6826156f6..7e0dfdd6a 100644 --- a/xdem/terrain/terrain.py +++ b/xdem/terrain/terrain.py @@ -21,22 +21,31 @@ from __future__ import annotations import warnings -from typing import Literal, Sized, overload +from typing import Literal, Sized, overload, TYPE_CHECKING import geoutils as gu import numpy as np from geoutils import profiler from geoutils.raster import Raster, RasterType -from geoutils.raster.distributed_computing import ( +from geoutils.raster.referencing import _res +from geoutils.multiproc import ( MultiprocConfig, map_overlap_multiproc_save, ) +from geoutils._dispatch import has_geo_attr, get_geo_attr -from xdem._typing import DTypeLike, MArrayf, NDArrayf +from xdem._typing import DTypeLike, NDArrayf +from xdem._misc import import_optional from xdem.terrain.freq import _texture_shading_fft from xdem.terrain.surfit import _get_surface_attributes from xdem.terrain.window import _get_windowed_indexes +if TYPE_CHECKING: + from xdem.dem.base import DEMLike + from xdem import DEM + from geoutils.raster.base import RasterLike + import dask.array as da + # List available attributes available_attributes = [ "slope", @@ -80,10 +89,110 @@ # 3/ Requiring fractal domain list_requiring_frequency_domain = ["texture_shading"] +# Helpers for chunked execution +############################### + +def _multiproc_get_terrain_attribute( + dem: DEM, + attr: str, + resolution: float, + degrees: bool, + hillshade_altitude: float, + hillshade_azimuth: float , + hillshade_z_factor: float, + surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"], + curv_method: Literal["geometric", "directional"], + tri_method: Literal["Riley", "Wilson"], + window_size: int, + engine: Literal["scipy", "numba"], + texture_alpha: float, + out_dtype: DTypeLike | None, + depth: int, + mp_config: MultiprocConfig, +) -> Raster: + + # Wrap up the block function to work with a Raster + def _raster_block_func(block: Raster) -> Raster: + arr = _get_terrain_attribute_base( + block.data, + attribute=[attr], + resolution=resolution, + degrees=degrees, + hillshade_altitude=hillshade_altitude, + hillshade_azimuth=hillshade_azimuth, + hillshade_z_factor=hillshade_z_factor, + surface_fit=surface_fit, + curv_method=curv_method, + tri_method=tri_method, + window_size=window_size, + engine=engine, + texture_alpha=texture_alpha, + out_dtype=out_dtype, + )[0] + return block.copy(new_array=arr) + + # Return a map overlap + return map_overlap_multiproc_save( + _raster_block_func, + dem, + mp_config, + depth=depth, + ) + +def _dask_get_terrain_attribute( + array: da.Array, + attr: str, + resolution: float | tuple[float, float] | None, + degrees: bool, + hillshade_altitude: float, + hillshade_azimuth: float, + hillshade_z_factor: float, + surface_fit: str, + curv_method: str, + tri_method: str, + window_size: int, + engine: str, + texture_alpha: float, + out_dtype: np.dtype[Any] | None, + depth: int, +) -> da.Array: + """Apply one terrain attribute with Dask overlap.""" + + import_optional("dask") + import dask.array as da + + def _block_func(block: np.ndarray) -> np.ndarray: + return _get_terrain_attribute_base( + block, + attribute=[attr], + resolution=resolution, + degrees=degrees, + hillshade_altitude=hillshade_altitude, + hillshade_azimuth=hillshade_azimuth, + hillshade_z_factor=hillshade_z_factor, + surface_fit=surface_fit, + curv_method=curv_method, + tri_method=tri_method, + window_size=window_size, + engine=engine, + texture_alpha=texture_alpha, + out_dtype=out_dtype, + )[0] + + dtype = out_dtype if out_dtype is not None else array.dtype + + return da.map_overlap( + _block_func, + array, + depth=depth, + boundary="none", + trim=True, + dtype=dtype, + ) @overload def get_terrain_attribute( - dem: NDArrayf | MArrayf, + dem: NDArrayf, attribute: str, resolution: tuple[float, float] | float | None = None, degrees: bool = True, @@ -104,7 +213,7 @@ def get_terrain_attribute( @overload def get_terrain_attribute( - dem: NDArrayf | MArrayf, + dem: NDArrayf, attribute: list[str], resolution: tuple[float, float] | float | None = None, degrees: bool = True, @@ -125,7 +234,7 @@ def get_terrain_attribute( @overload def get_terrain_attribute( - dem: RasterType, + dem: DEMLike, attribute: list[str], resolution: tuple[float, float] | float | None = None, degrees: bool = True, @@ -146,7 +255,7 @@ def get_terrain_attribute( @overload def get_terrain_attribute( - dem: RasterType, + dem: DEMLike, attribute: str, resolution: tuple[float, float] | float | None = None, degrees: bool = True, @@ -162,12 +271,12 @@ def get_terrain_attribute( texture_alpha: float = 0.8, out_dtype: DTypeLike | None = None, mp_config: MultiprocConfig | None = None, -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.get_terrain_attribute", memprof=True) def get_terrain_attribute( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, attribute: str | list[str], resolution: tuple[float, float] | float | None = None, degrees: bool = True, @@ -307,9 +416,10 @@ def get_terrain_attribute( "Use 'ZevenbergThorne' or 'Florinsky' instead." ) - if isinstance(dem, gu.Raster): - if resolution is None: - resolution = dem.res + # Check robust to any input (Xarray or + if has_geo_attr(dem, "transform"): + transform = get_geo_attr(dem, "transform") + resolution = _res(transform) # Validate and format the inputs if isinstance(attribute, str): @@ -390,25 +500,35 @@ def get_terrain_attribute( ) # 2/ Processing: chunked or normal depending on input - if mp_config is not None: - # Derive depth argument from method or window size, - # This is the overlap between tiles (1 for 3x3, 2 for 5x5, etc). - if any((attr in list_requiring_windowed_index) for attr in attribute): - window_depth = window_size // 2 - else: - window_depth = 0 - if any((attr in list_requiring_surface_fit) for attr in attribute): - if surface_fit.lower() == "florinsky": - surface_fit_depth = 2 - else: - surface_fit_depth = 1 + # Derive depth argument from method or window size, + # This is the overlap between tiles (1 for 3x3, 2 for 5x5, etc). + if any((attr in list_requiring_windowed_index) for attr in attribute): + window_depth = window_size // 2 + else: + window_depth = 0 + + if any((attr in list_requiring_surface_fit) for attr in attribute): + if surface_fit.lower() == "florinsky": + surface_fit_depth = 2 else: - surface_fit_depth = 0 + surface_fit_depth = 1 + else: + surface_fit_depth = 0 + + # We take the maximum required depth + depth = max(window_depth, surface_fit_depth) - # We take the maximum required depth - depth = max(window_depth, surface_fit_depth) + # Detect backend + dask_backend = ( + getattr(dem, "_is_xr", False) + and hasattr(dem.data, "chunks") + and dem.data.chunks is not None + ) + mp_backend = mp_config is not None + # If multiprocessing + if mp_backend: if not isinstance(dem, Raster): raise TypeError("The DEM must be a Raster to use multiprocessing.") @@ -417,51 +537,78 @@ def get_terrain_attribute( mp_config_copy = mp_config.copy() if mp_config.outfile is not None and len(attribute) > 1: mp_config_copy.outfile = mp_config_copy.outfile.split(".")[0] + "_" + attr + ".tif" - list_raster.append( - map_overlap_multiproc_save( - _get_terrain_attribute, - dem, - mp_config_copy, - [attr], - resolution, - degrees, - hillshade_altitude, - hillshade_azimuth, - hillshade_z_factor, - surface_fit, - curv_method, - tri_method, - window_size, - engine, - texture_alpha, - out_dtype, + + raster = _multiproc_get_terrain_attribute( + dem, + attr=attr, + resolution=resolution, + degrees=degrees, + hillshade_altitude=hillshade_altitude, + hillshade_azimuth=hillshade_azimuth, + hillshade_z_factor=hillshade_z_factor, + surface_fit=surface_fit, + curv_method=curv_method, + tri_method=tri_method, + window_size=window_size, + engine=engine, + texture_alpha=texture_alpha, + out_dtype=out_dtype, + depth=depth, + mp_config=mp_config_copy, + ) + list_raster.append(raster) + + else: + if dask_backend: + list_raster = [] + for attr in attribute: + array = _dask_get_terrain_attribute( + dem.data, + attr=attr, + resolution=resolution, + degrees=degrees, + hillshade_altitude=hillshade_altitude, + hillshade_azimuth=hillshade_azimuth, + hillshade_z_factor=hillshade_z_factor, + surface_fit=surface_fit, + curv_method=curv_method, + tri_method=tri_method, + window_size=window_size, + engine=engine, + texture_alpha=texture_alpha, + out_dtype=out_dtype, depth=depth, ) + list_raster.append(dem.copy(new_array=array)) + + else: + list_arr = _get_terrain_attribute_base( # type: ignore + dem.data if has_geo_attr(dem, "transform") else dem, + attribute, + resolution, + degrees, + hillshade_altitude, + hillshade_azimuth, + hillshade_z_factor, + surface_fit, + curv_method, + tri_method, + window_size, + engine, + texture_alpha, + out_dtype, ) - if len(list_raster) == 1: - return list_raster[0] - return list_raster - else: - return _get_terrain_attribute( # type: ignore - dem, - attribute, # type: ignore - resolution, - degrees, - hillshade_altitude, - hillshade_azimuth, - hillshade_z_factor, - surface_fit, - curv_method, - tri_method, - window_size, - engine, - texture_alpha, - out_dtype, - ) + if has_geo_attr(dem, "transform"): + list_raster = [dem.copy(new_array=array) for array in list_arr] + else: + list_raster = list_arr + # If list has length of one, return first element directly + if len(list_raster) == 1: + return list_raster[0] + return list_raster -@overload -def _get_terrain_attribute( +def _get_terrain_attribute_base( dem: NDArrayf, attribute: list[str], resolution: float, @@ -476,44 +623,7 @@ def _get_terrain_attribute( engine: Literal["scipy", "numba"] = "scipy", texture_alpha: float = 0.8, out_dtype: DTypeLike | None = None, -) -> list[NDArrayf]: ... - - -@overload -def _get_terrain_attribute( - dem: RasterType, - attribute: list[str], - resolution: float, - degrees: bool = True, - hillshade_altitude: float = 45.0, - hillshade_azimuth: float = 315.0, - hillshade_z_factor: float = 1.0, - surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", - curv_method: Literal["geometric", "directional"] = "geometric", - tri_method: Literal["Riley", "Wilson"] = "Riley", - window_size: int = 3, - engine: Literal["scipy", "numba"] = "scipy", - texture_alpha: float = 0.8, - out_dtype: DTypeLike | None = None, -) -> list[RasterType]: ... - - -def _get_terrain_attribute( - dem: NDArrayf | RasterType, - attribute: list[str], - resolution: float, - degrees: bool = True, - hillshade_altitude: float = 45.0, - hillshade_azimuth: float = 315.0, - hillshade_z_factor: float = 1.0, - surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", - curv_method: Literal["geometric", "directional"] = "geometric", - tri_method: Literal["Riley", "Wilson"] = "Riley", - window_size: int = 3, - engine: Literal["scipy", "numba"] = "scipy", - texture_alpha: float = 0.8, - out_dtype: DTypeLike | None = None, -) -> list[NDArrayf] | list[RasterType]: +) -> list[NDArrayf]: """ See description of get_terrain_attribute(). """ @@ -523,11 +633,13 @@ def _get_terrain_attribute( attributes_requiring_windowed_index = [attr for attr in attribute if attr in list_requiring_windowed_index] attributes_requiring_frequency_domain = [attr for attr in attribute if attr in list_requiring_frequency_domain] - # Get array of DEM - dem_arr = gu.raster.get_array_and_mask(dem)[0] - # We need to be able to use NaNs to propagate invalid values in attributes - if np.issubdtype(dem_arr.dtype, np.integer): - dem_arr = dem_arr.astype(np.float32) + # Get array of DEM, we need to be able to use NaNs to propagate invalid values in attributes + if np.issubdtype(dem.dtype, np.integer): + dem_arr = dem.astype(np.float32) + else: + dem_arr = dem + if np.ma.isMaskedArray(dem): + dem_arr = dem_arr.filled(np.nan) # Process surface attributes if len(attributes_requiring_surface_fit) > 0: @@ -606,18 +718,12 @@ def _get_terrain_attribute( ] output_attributes[:] = [output_attributes[idx] for idx in order_indices] - if isinstance(dem, gu.Raster): - output_attributes = [ - gu.Raster.from_array(attr, transform=dem.transform, crs=dem.crs, nodata=-99999) - for attr in output_attributes - ] # type: ignore - - return output_attributes if len(output_attributes) > 1 else output_attributes[0] + return output_attributes @overload def slope( - dem: NDArrayf | MArrayf, + dem: NDArrayf, method: Literal["Horn", "ZevenbergThorne"] = None, surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", degrees: bool = True, @@ -629,7 +735,7 @@ def slope( @overload def slope( - dem: RasterType, + dem: DEMLike, method: Literal["Horn", "ZevenbergThorne"] = None, surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", degrees: bool = True, @@ -641,7 +747,7 @@ def slope( @profiler.profile("xdem.terrain.slope", memprof=True) def slope( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, method: Literal["Horn", "ZevenbergThorne"] = None, surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", degrees: bool = True, @@ -698,7 +804,7 @@ def slope( @overload def aspect( - dem: NDArrayf | MArrayf, + dem: NDArrayf, method: Literal["Horn", "ZevenbergThorne"] = None, surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", degrees: bool = True, @@ -709,18 +815,18 @@ def aspect( @overload def aspect( - dem: RasterType, + dem: DEMLike, method: Literal["Horn", "ZevenbergThorne"] = None, surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", degrees: bool = True, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.aspect", memprof=True) def aspect( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, method: Literal["Horn", "ZevenbergThorne"] = None, surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", degrees: bool = True, @@ -786,7 +892,7 @@ def aspect( @overload def hillshade( - dem: NDArrayf | MArrayf, + dem: NDArrayf, method: Literal["Horn", "ZevenbergThorne"] = None, surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", azimuth: float = 315.0, @@ -800,7 +906,7 @@ def hillshade( @overload def hillshade( - dem: RasterType, + dem: DEMLike, method: Literal["Horn", "ZevenbergThorne"] = None, surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", azimuth: float = 315.0, @@ -809,12 +915,12 @@ def hillshade( resolution: float | tuple[float, float] | None = None, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.hillshade", memprof=True) def hillshade( - dem: NDArrayf | MArrayf, + dem: NDArrayf, method: Literal["Horn", "ZevenbergThorne"] = None, surface_fit: Literal["Horn", "ZevenbergThorne", "Florinsky"] = "Florinsky", azimuth: float = 315.0, @@ -823,7 +929,7 @@ def hillshade( resolution: float | tuple[float, float] | None = None, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Generate a hillshade from the given DEM. The value 0 is used for nodata, and 1 to 255 for hillshading. @@ -871,7 +977,7 @@ def hillshade( @overload def curvature( - dem: NDArrayf | MArrayf, + dem: NDArrayf, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", mp_config: MultiprocConfig | None = None, @@ -881,22 +987,22 @@ def curvature( @overload def curvature( - dem: RasterType, + dem: DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.curvature", memprof=True) def curvature( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ THIS FUNCTION IS DEPRECATED - REFER TO DOCS FOR SPECIFIC CURVATURE RECOMMENDATIONS @@ -941,7 +1047,7 @@ def curvature( @overload def profile_curvature( - dem: NDArrayf | MArrayf, + dem: NDArrayf, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", @@ -952,24 +1058,24 @@ def profile_curvature( @overload def profile_curvature( - dem: RasterType, + dem: DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.profile_curvature", memprof=True) def profile_curvature( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculates profile curvature in units of m-1 multiplied by 100. Defined as the curvature of a normal section of @@ -1017,7 +1123,7 @@ def profile_curvature( @overload def tangential_curvature( - dem: NDArrayf | MArrayf, + dem: NDArrayf, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", @@ -1028,24 +1134,24 @@ def tangential_curvature( @overload def tangential_curvature( - dem: RasterType, + dem: DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.tangential_curvature", memprof=True) def tangential_curvature( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculates tangential curvature in units of m-1 multiplied by 100. Defined as the curvature of a normal section of @@ -1094,7 +1200,7 @@ def tangential_curvature( @overload def planform_curvature( - dem: NDArrayf | MArrayf, + dem: NDArrayf, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", @@ -1105,24 +1211,24 @@ def planform_curvature( @overload def planform_curvature( - dem: RasterType, + dem: DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.planform_curvature", memprof=True) def planform_curvature( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculates planform (or plan) curvature in units of m-1 multiplied by 100., defined as the curvature of a @@ -1169,7 +1275,7 @@ def planform_curvature( @overload def flowline_curvature( - dem: NDArrayf | MArrayf, + dem: NDArrayf, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", @@ -1180,18 +1286,18 @@ def flowline_curvature( @overload def flowline_curvature( - dem: RasterType, + dem: DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.flowline_curvature", memprof=True) def flowline_curvature( - dem: NDArrayf | MArrayf, + dem: NDArrayf, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", @@ -1245,7 +1351,7 @@ def flowline_curvature( @overload def max_curvature( - dem: NDArrayf | MArrayf, + dem: NDArrayf, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", @@ -1256,24 +1362,24 @@ def max_curvature( @overload def max_curvature( - dem: RasterType, + dem: DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.max_curvature", memprof=True) def max_curvature( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculate the maximal (geometric) or maximum (directional derivative) curvature in units of m-1 multiplied by 100. Defined as curvature of the normal section of slope with the greatest curvature value. @@ -1321,7 +1427,7 @@ def max_curvature( @overload def min_curvature( - dem: NDArrayf | MArrayf, + dem: NDArrayf, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", @@ -1332,24 +1438,24 @@ def min_curvature( @overload def min_curvature( - dem: RasterType, + dem: DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.min_curvature", memprof=True) def min_curvature( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, resolution: float | tuple[float, float] | None = None, surface_fit: Literal["ZevenbergThorne", "Florinsky"] = "Florinsky", curv_method: Literal["geometric", "directional"] = "geometric", mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculate the minimal (geometric) or minimum (directional derivative) curvature in units of m-1 multiplied by 100. Defined as curvature of the normal section of slope with the smallest curvature value. @@ -1397,7 +1503,7 @@ def min_curvature( @overload def topographic_position_index( - dem: NDArrayf | MArrayf, + dem: NDArrayf, window_size: int = 3, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", @@ -1406,20 +1512,20 @@ def topographic_position_index( @overload def topographic_position_index( - dem: RasterType, + dem: DEMLike, window_size: int = 3, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.topographic_position_index", memprof=True) def topographic_position_index( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, window_size: int = 3, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculates the Topographic Position Index, the difference to the average of neighbouring pixels. Output is in the unit of the DEM (typically meters). @@ -1458,7 +1564,7 @@ def topographic_position_index( @overload def terrain_ruggedness_index( - dem: NDArrayf | MArrayf, + dem: NDArrayf, method: Literal["Riley", "Wilson"] = "Riley", window_size: int = 3, mp_config: MultiprocConfig | None = None, @@ -1468,22 +1574,22 @@ def terrain_ruggedness_index( @overload def terrain_ruggedness_index( - dem: RasterType, + dem: DEMLike, method: Literal["Riley", "Wilson"] = "Riley", window_size: int = 3, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.terrain_ruggedness_index", memprof=True) def terrain_ruggedness_index( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, method: Literal["Riley", "Wilson"] = "Riley", window_size: int = 3, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculates the Terrain Ruggedness Index, the cumulated differences to neighbouring pixels. Output is in the unit of the DEM (typically meters). @@ -1529,7 +1635,7 @@ def terrain_ruggedness_index( @overload def roughness( - dem: NDArrayf | MArrayf, + dem: NDArrayf, window_size: int = 3, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", @@ -1538,20 +1644,20 @@ def roughness( @overload def roughness( - dem: RasterType, + dem: DEMLike, window_size: int = 3, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.roughness", memprof=True) def roughness( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, window_size: int = 3, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculates the roughness, the maximum difference between neighbouring pixels, for any window size. Output is in the unit of the DEM (typically meters). @@ -1590,7 +1696,7 @@ def roughness( @overload def rugosity( - dem: NDArrayf | MArrayf, + dem: NDArrayf, resolution: float | tuple[float, float] | None = None, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", @@ -1599,20 +1705,20 @@ def rugosity( @overload def rugosity( - dem: RasterType, + dem: DEMLike, resolution: float | tuple[float, float] | None = None, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.rugosity", memprof=True) def rugosity( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, resolution: float | tuple[float, float] | None = None, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculates the rugosity, the ratio between real area and planimetric area. Only available for a 3x3 window. The output is unitless. @@ -1651,7 +1757,7 @@ def rugosity( @overload def fractal_roughness( - dem: NDArrayf | MArrayf, + dem: NDArrayf, window_size: int = 13, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", @@ -1660,20 +1766,20 @@ def fractal_roughness( @overload def fractal_roughness( - dem: RasterType, + dem: DEMLike, window_size: int = 13, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.fractal_roughness", memprof=True) def fractal_roughness( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, window_size: int = 13, mp_config: MultiprocConfig | None = None, engine: Literal["scipy", "numba"] = "scipy", -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Calculates the fractal roughness, the local 3D fractal dimension. Can only be computed on window sizes larger or equal to 5x5, defaults to 13x13. Output unit is a fractal dimension between 1 and 3. @@ -1714,7 +1820,7 @@ def fractal_roughness( @overload def texture_shading( - dem: NDArrayf | MArrayf, + dem: NDArrayf, alpha: float = 0.8, mp_config: MultiprocConfig | None = None, ) -> NDArrayf: ... @@ -1722,18 +1828,18 @@ def texture_shading( @overload def texture_shading( - dem: RasterType, + dem: DEMLike, alpha: float = 0.8, mp_config: MultiprocConfig | None = None, -) -> RasterType: ... +) -> RasterLike: ... @profiler.profile("xdem.terrain.texture_shading", memprof=True) def texture_shading( - dem: NDArrayf | MArrayf | RasterType, + dem: NDArrayf | DEMLike, alpha: float = 0.8, mp_config: MultiprocConfig | None = None, -) -> NDArrayf | RasterType: +) -> NDArrayf | RasterLike: """ Generate a texture shaded relief map using fractional Laplacian operator. diff --git a/xdem/vcrs.py b/xdem/vcrs.py index c5a417bf5..86ceb343f 100644 --- a/xdem/vcrs.py +++ b/xdem/vcrs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 xDEM developers +# Copyright (c) 2026 xDEM developers # # This file is part of the xDEM project: # https://github.com/glaciohack/xdem @@ -23,9 +23,16 @@ import os import pathlib import warnings -from typing import Literal, TypedDict +from typing import Literal, TypedDict, Any, TYPE_CHECKING from urllib.error import HTTPError +from geoutils.raster.referencing import _coords +from geoutils.multiproc import MultiprocConfig +from geoutils.multiproc.mparray import map_overlap_multiproc_save +from geoutils._dispatch import get_geo_attr + +import numpy as np +import affine import pyproj from pyproj import CRS from pyproj.crs import BoundCRS, CompoundCRS, GeographicCRS, VerticalCRS @@ -33,8 +40,20 @@ from pyproj.crs.enums import Ellipsoidal3DCSAxis from pyproj.transformer import TransformerGroup +from xdem._misc import import_optional from xdem._typing import MArrayf, NDArrayf +if TYPE_CHECKING: + from xdem import DEM + from xdem.dem.base import DEMBase + +# Optional Dask import +try: + import dask.array as da +except ImportError: + da = None # type: ignore[assignment] + + # Sources for defining vertical references: # AW3D30: https://www.eorc.jaxa.jp/ALOS/en/aw3d30/aw3d30v11_format_e.pdf # SRTMGL1: https://lpdaac.usgs.gov/documents/179/SRTM_User_Guide_V3.pdf @@ -57,6 +76,96 @@ "COPDEM": "EGM08", } +def _check_vcrs_input(vcrs: Any, crs: Any) -> Any: + """ + Process user-input vertical CRS and CRS, and return normalized CRS output. + + :param vcrs: Vertical CRS input. + :param crs: CRS input. + + :return: Normalized CRS output. + """ + + # Parse 2D/3D CRS + crs = pyproj.CRS.from_user_input(crs) + + # Vertical CRS from different sources + vcrs_from_crs = _vcrs_from_crs(crs) + if vcrs is None: + vcrs_from_user = None + else: + vcrs_from_user = _vcrs_from_user_input(vcrs) + + # Determine which vertical CRS to use + if vcrs_from_user is not None: + # User input takes precedence over CRS metadata + if vcrs_from_crs is not None and vcrs_from_user != vcrs_from_crs: + warnings.warn( + "The CRS in the raster metadata already has a vertical component, " + f"the user-provided '{vcrs}' will override it." + ) + out_vcrs = vcrs_from_user + else: + out_vcrs = vcrs_from_crs + + # Build final CRS + if out_vcrs is not None: + out_crs = _build_ccrs_from_crs_and_vcrs(crs, out_vcrs) + else: + out_crs = crs + + return out_crs + +# EPSG codes for units +_UNIT_SYMBOLS = { + "9001": "m", # metre + "9002": "ft", # foot + "9003": "ftUS", # US survey foot + "9036": "km", + "9102": "°", +} + +def vertical_unit_symbol(crs) -> str | None: + """ + Return the short unit symbol of the vertical axis (e.g. "m", "ft"). + + Returns None if the CRS has no vertical axis. + """ + + # Process CRS input + crs = CRS(crs) + + # If compound CRS, isolate the vertical CRS + if crs.is_compound: + crs = crs.sub_crs_list[1] + + # Check axis is indeed vertical, otherwise return None + for axis in crs.axis_info: + if axis.direction in ("up", "down"): + + # Prefer EPSG unit code if it exists + code = axis.unit_auth_code + if code and code in _UNIT_SYMBOLS: + return _UNIT_SYMBOLS[code] + + # Fallback to normalized unit names + name = axis.unit_name.lower() + + if name in {"metre", "meter"}: + return "m" + + if name == "kilometre": + return "km" + + if name == "foot": + return "ft" + + if name == "us survey foot": + return "ftUS" + + return axis.unit_name + + return None def _parse_vcrs_name_from_product(product: str) -> str | None: """ @@ -75,14 +184,14 @@ def _parse_vcrs_name_from_product(product: str) -> str | None: return vcrs_name -def _build_ccrs_from_crs_and_vcrs(crs: CRS, vcrs: CRS | Literal["Ellipsoid"]) -> CompoundCRS | CRS: +def _build_ccrs_from_crs_and_vcrs(crs: CRS, vcrs: CRS | Literal["Ellipsoid"]) -> CRS: """ - Build a compound CRS from a horizontal CRS and a vertical CRS. + Build a 3D CRS (compound or expanded) from a horizontal CRS and a vertical CRS input. :param crs: Horizontal CRS. :param vcrs: Vertical CRS. - :return: Compound CRS (horizontal + vertical). + :return: 3D CRS (horizontal + vertical). """ # If a vertical CRS was passed, build a compound CRS with horizontal + vertical @@ -92,7 +201,7 @@ def _build_ccrs_from_crs_and_vcrs(crs: CRS, vcrs: CRS | Literal["Ellipsoid"]) -> # If pyproj >= 3.5.1, we can use CRS.to_2d() from packaging.version import Version - if Version(pyproj.__version__) > Version("3.5.0"): + if Version(pyproj.__version__) >= Version("3.5.1"): crs_from = CRS(crs).to_2d() ccrs = CompoundCRS( name="Horizontal: " + CRS(crs).name + "; Vertical: " + vcrs.name, @@ -115,24 +224,24 @@ def _build_ccrs_from_crs_and_vcrs(crs: CRS, vcrs: CRS | Literal["Ellipsoid"]) -> components=[crs_from, vcrs], ) - # Else if "Ellipsoid" was passed, there is no vertical reference - # We still have to return the CRS in 3D + # Else if "Ellipsoid" was passed, there is no vertical CRS, but we expand the ellipsoid to 3D + # We isolate the 2D horizontal CRS (removing potential geoids), then expand it to 3D elif isinstance(vcrs, str) and vcrs.lower() == "ellipsoid": - ccrs = CRS(crs).to_3d() + ccrs = CRS(crs).to_2d().to_3d() else: raise ValueError("Invalid vcrs given. Must be a vertical CRS or the literal string 'Ellipsoid'.") return ccrs -def _build_vcrs_from_grid(grid: str, old_way: bool = False) -> CompoundCRS: +def _build_vcrs_from_grid(grid: str, old_way: bool = False) -> BoundCRS: """ - Build a compound CRS from a vertical CRS grid path. + Build a bound CRS from a vertical CRS grid path. :param grid: Path to grid for vertical reference. :param old_way: Whether to use the new or old way of building the compound CRS with pyproj (for testing purposes). - :return: Compound CRS (horizontal + vertical). + :return: Bound CRS. """ if not os.path.exists(os.path.join(pyproj.datadir.get_data_dir(), grid)): @@ -206,10 +315,15 @@ class VCRSMetaDict(TypedDict, total=False): "EGM96": {"grid": "us_nga_egm96_15.tif", "epsg": 5773}, # EGM1996 at 15 minute resolution } - -def _vcrs_from_crs(crs: CRS) -> CRS: +def _vcrs_from_crs(crs: CRS | None) -> CRS | Literal["Ellipsoid"] | None: """Get the vertical CRS from a CRS.""" + # If no CRS is defined + if crs is None: + return None + else: + crs = CRS(crs) + # Check if CRS is 3D if len(crs.axis_info) > 2: @@ -271,14 +385,17 @@ def _vcrs_from_user_input( ) vcrs = _vcrs_from_crs(vcrs) - # If a string was passed + # If a string or path was passed else: + if isinstance(vcrs_input, pathlib.Path): + vcrs_input = vcrs_input.name # If a name is passed, define CRS based on dict - if isinstance(vcrs_input, str) and vcrs_input.upper() in _vcrs_meta.keys(): - vcrs_meta = _vcrs_meta[vcrs_input] + key = vcrs_input.upper() + if isinstance(vcrs_input, str) and key in _vcrs_meta: + vcrs_meta = _vcrs_meta[key] vcrs = CRS.from_epsg(vcrs_meta["epsg"]) # Otherwise, attempt to read a grid from the string - elif os.path.splitext(vcrs_input)[-1] in [".tif", ".json", ".pol"]: + elif os.path.splitext(vcrs_input)[-1].lower() in [".tif", ".json", ".pol"]: if isinstance(vcrs_input, pathlib.Path): grid = vcrs_input.name else: @@ -316,28 +433,18 @@ def _grid_from_user_input(vcrs_input: str | pathlib.Path | int | CRS) -> str | N return grid - -def _transform_zz( - crs_from: CRS, crs_to: CRS, xx: NDArrayf, yy: NDArrayf, zz: MArrayf | NDArrayf | int | float -) -> MArrayf | NDArrayf | int | float: +def _build_vertical_transformer(crs_from: CRS, crs_to: CRS) -> pyproj.Transformer: """ - Transform elevation to a new 3D CRS. - - :param crs_from: Source CRS. - :param crs_to: Destination CRS. - :param xx: X coordinates. - :param yy: Y coordinates. - :param zz: Z coordinates. + Build the best available transformer for a vertical CRS transformation. - :return: Transformed Z coordinates. + Downloads missing grids before returning, if needed. """ - # Find all possible transforms with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Best transformation is not available") trans_group = TransformerGroup(crs_from=crs_from, crs_to=crs_to, always_xy=True) - # Download grid if best available is not on disk, download and re-initiate the object + # Download grid if best available is not on disk, then re-initialize if not trans_group.best_available: trans_group.download_grids() trans_group = TransformerGroup(crs_from=crs_from, crs_to=crs_to, always_xy=True) @@ -349,9 +456,271 @@ def _transform_zz( message="Best available grid for transformation could not be downloaded, " "applying the next best available (caution: might apply no transform at all).", ) - transformer = trans_group.transformers[0] + + return trans_group.transformers[0] + +def _transform_zz( + transformer: pyproj.Transformer, + xx: NDArrayf, + yy: NDArrayf, + zz: MArrayf | NDArrayf | int | float, +) -> MArrayf | NDArrayf | int | float: + """ + Transform elevation to a new 3D CRS using an already-built transformer. + """ # Will preserve the mask of the masked-array since pyproj 3.4 zz_trans = transformer.transform(xx, yy, zz)[2] return zz_trans + +# Vertical CRS transformation for DEMs +###################################### + +def _to_vcrs_2d_pyproj( + data: NDArrayf, + transform: affine.Affine, + transformer: pyproj.Transformer, +) -> NDArrayf: + """ + Base function: transforms one raster block from source to destination vertical CRS. + """ + xx, yy = _coords(shape=data.shape, transform=transform, area_or_point=None) + zz_trans = _transform_zz( + transformer=transformer, + xx=xx, + yy=yy, + zz=data, + ) + return zz_trans.astype(data.dtype, copy=False) + +def _to_vcrs_2d_block_dask( + data: NDArrayf, + *, + transform: affine.Affine, + src_ccrs_wkt: str, + dst_ccrs_wkt: str, + block_info: list[dict[str, Any]] | None = None, +) -> NDArrayf: + """Dask block wrapper deriving the local transform from block_info.""" + + if block_info is None: + raise ValueError("block_info must be provided.") + + # Reconstruct transform from block info + row_loc, col_loc = block_info[0]["array-location"] + + # Dask may return slices or (start, stop) tuples depending on version + row_start = row_loc.start if hasattr(row_loc, "start") else row_loc[0] + col_start = col_loc.start if hasattr(col_loc, "start") else col_loc[0] + block_transform = transform * affine.Affine.translation(col_start, row_start) + + # Rebuild transformer inside the block (serialization issues with a Pyproj transformer if passing it) + transformer = _build_vertical_transformer( + crs_from=CRS.from_wkt(src_ccrs_wkt), + crs_to=CRS.from_wkt(dst_ccrs_wkt), + ) + + return _to_vcrs_2d_pyproj( + data=data, + transform=block_transform, + transformer=transformer, + ) +def _dask_to_vcrs_2d( + darr: da.Array, + transform: affine.Affine, + src_ccrs: CRS, + dst_ccrs: CRS, +) -> da.Array: + """Blockwise vertical CRS transform using Dask.""" + + # Simply use map_blocks, as all transformations are independent when purely vertical + import_optional("dask") + return darr.map_blocks( + _to_vcrs_2d_block_dask, + transform=transform, + src_ccrs_wkt=src_ccrs.to_wkt(), + dst_ccrs_wkt=dst_ccrs.to_wkt(), + dtype=darr.dtype, + meta=np.array((), dtype=darr.dtype), + ) + +def _to_vcrs_2d_block_mp( + dem: DEM, + src_ccrs_wkt: str, + dst_ccrs_wkt: str, +) -> DEM: + """Multiprocessing block wrapper using the tile-local transform directly.""" + + # Rebuild transformer inside the block (serialization issues with a Pyproj transformer if passing it) + transformer = _build_vertical_transformer( + crs_from=CRS.from_wkt(src_ccrs_wkt), + crs_to=CRS.from_wkt(dst_ccrs_wkt), + ) + + # Transform + out_data = _to_vcrs_2d_pyproj( + data=dem.data, + transform=dem.transform, + transformer=transformer, + ) + + return dem.from_array( + data=out_data, + transform=dem.transform, + crs=dem.crs, + nodata=dem.nodata, + area_or_point=dem.area_or_point, + tags=dem.tags, + ) + +def _multiproc_to_vcrs_2d( + dem: DEM, + *, + src_ccrs: CRS, + dst_ccrs: CRS, + mp_config: MultiprocConfig, +) -> DEM: + """ + Vertical CRS transform using multiprocessing. + """ + + out_dem = map_overlap_multiproc_save( + _to_vcrs_2d_block_mp, + dem, + mp_config, + src_ccrs.to_wkt(), + dst_ccrs.to_wkt(), + depth=0, + ) + out_dem.set_crs(dst_ccrs) + + from xdem.dem.dem import DEM + return DEM(out_dem) + +def _get_vertical_transform_crss( + crs: Any, + dst_vcrs: Any, + force_source_vcrs: Any | None = None, +) -> tuple[CRS, CRS]: + """ + Build source and destination compound CRS for a vertical transformation, and raise errors where necessary. + """ + + # Get source VCRS from current CRS + src_vcrs = _vcrs_from_crs(crs) + + # Early exit if conversion not defined + if src_vcrs is None and force_source_vcrs is None: + raise ValueError( + "The current DEM has no vertical reference, define one with .set_vcrs() " + "or by passing `vcrs` to perform a conversion." + ) + + # Initial Compound CRS + if force_source_vcrs is not None: + if src_vcrs is not None: + warnings.warn( + category=UserWarning, + message=f"Overriding the vertical CRS of the DEM " + f"with the one provided in `force_source_vcrs`: {force_source_vcrs}.", + ) + force_src_vcrs = _vcrs_from_user_input(force_source_vcrs) + src_ccrs = _build_ccrs_from_crs_and_vcrs(crs, vcrs=force_src_vcrs) + else: + src_ccrs = crs + + # Destination Compound CRS + dst_ccrs = _build_ccrs_from_crs_and_vcrs( + crs, + vcrs=_vcrs_from_user_input(vcrs_input=dst_vcrs), + ) + + return src_ccrs, dst_ccrs + +def _to_vcrs_2d( + dem: DEMBase, + dst_vcrs: Any, + force_source_vcrs: Any | None = None, + mp_config: MultiprocConfig | None = None, +) -> DEMBase | None: + """ + Transform DEM to a different vertical CRS (no change in horizontal CRS). + + Supports direct in-memory execution, Dask execution, and Multiprocessing. + + :param dem: DEM. + :param dst_vcrs: Destination vertical CRS. + :param force_source_vcrs: Force the source vertical CRS if not defined or to override it. + :param mp_config: Multiprocessing configuration. + :returns: Transformed elevation array and destination compound CRS. + """ + + # Cannot use Multiprocessing backend and Dask backend simultaneously + mp_backend = mp_config is not None + dask_backend = da is not None and dem._chunks is not None + + if mp_backend and dask_backend: + raise ValueError( + "Cannot use Multiprocessing and Dask simultaneously. To use Dask, remove mp_config parameter " + "from to_vcrs(). To use Multiprocessing, use a DEM object input and pass mp_config." + ) + + # Build source and destination compound CRS from the input vertical CRSs + src_ccrs, dst_ccrs = _get_vertical_transform_crss( + crs=dem.crs, + dst_vcrs=dst_vcrs, + force_source_vcrs=force_source_vcrs, + ) + transform = get_geo_attr(dem, "transform") + + # If both compound CRS are equal, do not run any transform + if src_ccrs.equals(dst_ccrs): + warnings.warn( + message="Source and destination vertical CRS are the same, skipping vertical transformation.", + category=UserWarning, + ) + return None + + # Build transformer once to trigger grid download outside of parallelization + validate best available transform + # We won't be able to pass the transformer directly to the chunked functions (not serializable), + # so we'll repass the src/dst CRS + _build_vertical_transformer(crs_from=src_ccrs, crs_to=dst_ccrs) + + # Multiprocessing backend + if mp_backend: + dem_out = _multiproc_to_vcrs_2d( + dem=dem, + src_ccrs=src_ccrs, + dst_ccrs=dst_ccrs, + mp_config=mp_config, + ) + return dem_out + + else: + # Dask backend + if dask_backend: + zz_trans = _dask_to_vcrs_2d( + darr=dem.data, + transform=transform, + src_ccrs=src_ccrs, + dst_ccrs=dst_ccrs, + ) + else: + # Direct NumPy backend + transformer = _build_vertical_transformer(crs_from=src_ccrs, crs_to=dst_ccrs) + zz_trans = _to_vcrs_2d_pyproj( + data=dem.data, + transform=transform, + transformer=transformer, + ) + + dem_out = dem.from_array( + data=zz_trans, + transform=transform, + crs=dst_ccrs, + nodata=dem.nodata, + area_or_point=dem.area_or_point, + tags=dem.tags, + ) + return dem_out \ No newline at end of file diff --git a/xdem/workflows/accuracy.py b/xdem/workflows/accuracy.py index 3d2e6ac9a..f903fe423 100644 --- a/xdem/workflows/accuracy.py +++ b/xdem/workflows/accuracy.py @@ -34,6 +34,7 @@ import xdem from xdem._misc import import_optional +from xdem.vcrs import vertical_unit_symbol from xdem.workflows.schemas import ACCURACY_SCHEMA from xdem.workflows.workflows import _ALIAS, Workflows @@ -88,6 +89,7 @@ def _load_data(self) -> tuple[float, float]: vmin = float(min(np.nanpercentile(self.reference_elev, q=5), np.nanpercentile(self.to_be_aligned_elev, q=5))) vmax = float(max(np.nanpercentile(self.reference_elev, q=95), np.nanpercentile(self.to_be_aligned_elev, q=95))) + ref_vunit = vertical_unit_symbol(self.reference_elev.crs) self.generate_plot( dem=self.reference_elev, title="Reference elevation", @@ -96,7 +98,7 @@ def _load_data(self) -> tuple[float, float]: title_dem_right="To-be-aligned elevation", vmin=vmin, vmax=vmax, - cbar_title=f"Elevation ({self.reference_elev.crs.linear_units})", + cbar_title=f"Elevation ({ref_vunit})" if ref_vunit is not None else "Elevation", ) if ref_mask is not None or tba_mask is not None: if ref_mask is not None: @@ -114,7 +116,7 @@ def _load_data(self) -> tuple[float, float]: title_dem_right="Masked terrain for to-be-aligned elevation", vmin=vmin, vmax=vmax, - cbar_title=f"Elevation ({self.reference_elev.crs.linear_units})", + cbar_title=f"Elevation ({ref_vunit})" if ref_vunit is not None else "Elevation", ) return vmin, vmax @@ -204,23 +206,25 @@ def _prepare_datas(self, vmin: float, vmax: float) -> None: if sampling_source == "reference_elev": self.to_be_aligned_elev = self.to_be_aligned_elev.crop(coord_intersection) + tba_vunit = vertical_unit_symbol(self.to_be_aligned_elev.crs) self.generate_plot( self.to_be_aligned_elev, title="Preprocessed to-be-aligned elevation", filename="preprocessed_to_be_aligned_elev_map", vmin=vmin, vmax=vmax, - cbar_title=f"Elevation ({self.to_be_aligned_elev.crs.linear_units})", + cbar_title=f"Elevation ({tba_vunit})" if tba_vunit is not None else "Elevation", ) else: self.reference_elev = self.reference_elev.crop(coord_intersection) + ref_vunit = vertical_unit_symbol(self.reference_elev.crs) self.generate_plot( self.reference_elev, title="Preprocessed reference elevation", filename="preprocessed_reference_elev_map", vmin=vmin, vmax=vmax, - cbar_title=f"Elevation ({self.reference_elev.crs.linear_units})", + cbar_title=f"Elevation ({ref_vunit})" if ref_vunit is not None else "Elevation", ) if self.level > 1: @@ -291,7 +295,8 @@ def _compute_histogram(self) -> None: va="center", ) plt.title("Histogram of elevation differences\nbefore and after coregistration") - plt.xlabel(f"Elevation differences ({self.reference_elev.crs.linear_units})") + ref_vunit = vertical_unit_symbol(self.reference_elev.crs) + plt.xlabel(f"Elevation differences ({ref_vunit})" if ref_vunit is not None else "Elevation differences") plt.ylabel("Count") plt.legend() plt.grid(False) @@ -341,6 +346,7 @@ def run(self) -> None: self.stats_after["median"] + 3 * self.stats_after["nmad"], ) + ref_vunit = vertical_unit_symbol(self.reference_elev.crs) self.generate_plot( dem=self.diff_before, title="Elevation difference before coregistration", @@ -350,13 +356,14 @@ def run(self) -> None: vmin=vmin_diff, vmax=vmax_diff, cmap="RdBu", - cbar_title=f"Elevation differences ({self.diff_before.crs.linear_units})", + cbar_title=f"Elevation differences ({ref_vunit})" if ref_vunit is not None else "Elevation differences", ) else: self.diff = self.to_be_aligned_elev - ref_elev self.stats = self.diff.get_stats(stats_keys) vmin, vmax = -(self.stats["median"] + 3 * self.stats["nmad"]), self.stats["median"] + 3 * self.stats["nmad"] + ref_vunit = vertical_unit_symbol(self.reference_elev.crs) self.generate_plot( self.diff, title="Elevation difference without coregistration", @@ -364,7 +371,7 @@ def run(self) -> None: vmin=vmin, vmax=vmax, cmap="RdBu", - cbar_title=f"Elevation differences ({self.diff.crs.linear_units})", + cbar_title=f"Elevation differences ({ref_vunit})" if ref_vunit is not None else "Elevation differences", ) if self.compute_coreg: stat_items = [ diff --git a/xdem/workflows/topo.py b/xdem/workflows/topo.py index eddad5f99..71205898e 100644 --- a/xdem/workflows/topo.py +++ b/xdem/workflows/topo.py @@ -28,6 +28,7 @@ from typing import Any, Dict import xdem +from xdem.vcrs import vertical_unit_symbol from xdem._misc import import_optional from xdem.workflows.schemas import TOPO_SCHEMA from xdem.workflows.workflows import _ALIAS, Workflows @@ -68,11 +69,12 @@ def _load_data(self) -> None: """ self.dem, self.inlier_mask, path_to_mask = self.load_dem(self.config["inputs"]["reference_elev"]) + vunit = vertical_unit_symbol(self.dem.crs) self.generate_plot( self.dem, filename="elev_map", title="Elevation", - cbar_title=f"Elevation ({self.dem.crs.linear_units})", + cbar_title=f"Elevation ({vunit})" if vunit is not None else "Elevation", ) if self.inlier_mask is not None: @@ -82,7 +84,7 @@ def _load_data(self) -> None: self.dem, title="Masked elevation", filename="masked_elev_map", - cbar_title=f"Elevation ({self.dem.crs.linear_units})", + cbar_title=f"Elevation ({vunit})" if vunit is not None else "Elevation", ) def generate_terrain_attributes_tiff(self) -> None: @@ -133,7 +135,7 @@ def generate_terrain_attributes_png(self) -> None: ncols = 2 nrows = math.ceil(n / ncols) - unit = self.dem.crs.linear_units + unit = vertical_unit_symbol(self.dem.crs) attribute_params: dict[str, dict[str, Any]] = { "hillshade": {"label": "Hillshade", "cmap": "Greys_r", "vlim": (0, 255)}, "texture_shading": {"label": "Texture shading", "cmap": "Greys_r", "vlim": (-20, 20)}, @@ -152,7 +154,7 @@ def generate_terrain_attributes_png(self) -> None: "cmap": "Spectral", "vlim": (None, None), }, - "roughness": {"label": f"Roughness ({self.dem.crs.linear_units})", "cmap": "Oranges", "vlim": (None, None)}, + "roughness": {"label": f"Roughness ({unit})", "cmap": "Oranges", "vlim": (None, None)}, "fractal_dimension": {"label": "Fractal roughness (dimensions)", "cmap": "Reds", "vlim": (None, None)}, } @@ -199,8 +201,7 @@ def run(self) -> None: # Global information dem_informations = { "Driver": self.dem.driver, - "Filename": self.dem.filename, - "Grid size": self.dem.vcrs_grid, + "Filename": self.dem.name, "Number of band": self.dem.bands, "Data types": self.dem.dtype, "Nodata Value": self.dem.nodata,