Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 101 additions & 2 deletions evaluation/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from typing import Any, Callable

import numpy as np
from numpy import fft
import xarray as xr
from tqdm import tqdm
import pandas as pd

from numpy import fft
from pysteps.verification.salscores import sal as pysteps_sal

def _filter_by_season(x: xr.Dataset, season: str | None) -> xr.Dataset:
"""
Expand Down Expand Up @@ -35,6 +38,33 @@ def _filter_by_season(x: xr.Dataset, season: str | None) -> xr.Dataset:
return x.where(x["time.season"] == "SON", drop=True)
return x

def _filter_by_dates(x: xr.Dataset, dates: list[str] | None) -> xr.Dataset:
"""
Filter the dataset by specific dates.

Parameters
----------
x : xr.Dataset
Dataset to filter.
dates : list[str] | None
List of date strings in format 'YYYY-MM-DD' or None to skip filtering.

Returns
-------
xr.Dataset
Filtered dataset.
"""
if dates is None:
return x

target_dates = pd.to_datetime(dates)
x_dates = pd.to_datetime(x['time'].dt.date)

mask = x_dates.isin(target_dates)
mask_da = xr.DataArray(mask, coords={'time': x['time']}, dims=['time'])

return x.where(mask_da, drop=True)

def _radial_average(array_2d: np.ndarray) -> np.ndarray:
"""
Compute the radial average of a two-dimensional field.
Expand Down Expand Up @@ -280,4 +310,73 @@ def psd(
psd_x0_da = xr.DataArray(avg_psd_x0, dims=["wavenumber"], name="PSD_x0")
psd_x1_da = xr.DataArray(avg_psd_x1, dims=["wavenumber"], name="PSD_x1")

return psd_x0_da, psd_x1_da
return psd_x0_da, psd_x1_da


def sal(
x0: xr.Dataset,
x1: xr.Dataset,
var: str,
season: str | None = None,
dates: list[str] | None = None,
thr_factor: float = 1/15.,
thr_quantile: float = 0.95,
) -> xr.Dataset:
"""
Compute the Structure-Amplitude-Location (SAL) spatial verification metric.

Parameters
----------
x0 : xr.Dataset
Ground truth (observation) dataset.
x1 : xr.Dataset
Predicted dataset.
var : str
Variable name to analyse.
season : str | None, optional
Season to filter before computing the SAL metric.
dates : list[str] | None, optional
List of specific dates to filter by (format: 'YYYY-MM-DD'). If None, no date filtering is applied.
thr_factor : float, default 1/15.
Factor used to compute the detection threshold.
thr_quantile : float, default 0.95
The wet quantile between 0 and 1 used to define the detection threshold.

Returns
-------
xr.Dataset
Dataset containing the structure, amplitude, and location components of the SAL score.
"""

x0 = _filter_by_season(x0, season)
x1 = _filter_by_season(x1, season)

x0 = _filter_by_dates(x0, dates)
x1 = _filter_by_dates(x1, dates)

x0_np = np.array(x0[var].values)
x1_np = np.nan_to_num(x1[var].values)

sal_structure = []
sal_amplitude = []
sal_location = []

for i in tqdm(range(x0_np.shape[0]), desc="Computing SAL", leave=False):
s, a, l = pysteps_sal(
x1_np[i],
x0_np[i],
thr_factor=thr_factor,
thr_quantile=thr_quantile,
)

sal_structure.append(s)
sal_amplitude.append(a)
sal_location.append(l)

final_score = xr.Dataset({"SAL_structure": ("time", np.array(sal_structure)),
"SAL_amplitude": ("time", np.array(sal_amplitude)),
"SAL_location": ("time", np.array(sal_location))})

final_score['time'] = x0['time']

return final_score