Skip to content

Commit 8cca083

Browse files
committed
Merge branch 'main' of https://github.com/DHI/modelskill
2 parents 8b671b1 + f843b9a commit 8cca083

3 files changed

Lines changed: 241 additions & 36 deletions

File tree

modelskill/comparison/_collection_plotter.py

Lines changed: 116 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
from typing import Any, List, Union, Optional, Tuple, Sequence, TYPE_CHECKING
33
from matplotlib.axes import Axes # type: ignore
4+
import warnings
45

56
if TYPE_CHECKING:
67
from ._collection import ComparerCollection
@@ -44,7 +45,7 @@ def scatter(
4445
xlabel: Optional[str] = None,
4546
ylabel: Optional[str] = None,
4647
skill_table: Optional[Union[str, List[str], bool]] = None,
47-
ax: Optional[Axes] = None,
48+
ax=None,
4849
**kwargs,
4950
):
5051
"""Scatter plot showing compared data: observation vs modelled
@@ -113,11 +114,72 @@ def scatter(
113114
>>> cc.plot.scatter(observations=['c2','HKNA'])
114115
"""
115116

116-
# select model
117-
mod_id = _get_idx(model, self.cc.mod_names)
118-
mod_name = self.cc.mod_names[mod_id]
117+
cc = self.cc
118+
if model is None:
119+
mod_names = cc.mod_names
120+
else:
121+
warnings.warn(
122+
"The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.scatter()",
123+
FutureWarning,
124+
)
125+
126+
model_list = [model] if isinstance(model, (str, int)) else model
127+
mod_names = [
128+
self.cc.mod_names[_get_idx(m, self.cc.mod_names)] for m in model_list
129+
]
130+
131+
axes = []
132+
for mod_name in mod_names:
133+
ax_mod = self._scatter_one_model(
134+
mod_name=mod_name,
135+
bins=bins,
136+
quantiles=quantiles,
137+
fit_to_quantiles=fit_to_quantiles,
138+
show_points=show_points,
139+
show_hist=show_hist,
140+
show_density=show_density,
141+
backend=backend,
142+
figsize=figsize,
143+
xlim=xlim,
144+
ylim=ylim,
145+
reg_method=reg_method,
146+
title=title,
147+
xlabel=xlabel,
148+
ylabel=ylabel,
149+
skill_table=skill_table,
150+
ax=ax,
151+
**kwargs,
152+
)
153+
axes.append(ax_mod)
154+
return axes[0] if len(axes) == 1 else axes
119155

120-
cmp = self.cc
156+
def _scatter_one_model(
157+
self,
158+
*,
159+
mod_name: str,
160+
bins: int | float,
161+
quantiles: int | Sequence[float] | None,
162+
fit_to_quantiles: bool,
163+
show_points: bool | int | float | None,
164+
show_hist: Optional[bool],
165+
show_density: Optional[bool],
166+
backend: str,
167+
figsize: Tuple[float, float],
168+
xlim: Optional[Tuple[float, float]],
169+
ylim: Optional[Tuple[float, float]],
170+
reg_method: str | bool,
171+
title: Optional[str],
172+
xlabel: Optional[str],
173+
ylabel: Optional[str],
174+
skill_table: Optional[Union[str, List[str], bool]],
175+
ax,
176+
**kwargs,
177+
):
178+
assert (
179+
mod_name in self.cc.mod_names
180+
), f"Model {mod_name} not found in collection {self.cc.mod_names}"
181+
182+
cmp = self.cc.sel(model=mod_name)
121183

122184
if cmp.n_points == 0:
123185
raise ValueError("No data found in selection")
@@ -183,7 +245,7 @@ def scatter(
183245

184246
return ax
185247

186-
def kde(self, ax=None, figsize=None, title=None, **kwargs) -> Axes:
248+
def kde(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes:
187249
"""Plot kernel density estimate of observation and model data.
188250
189251
Parameters
@@ -247,10 +309,11 @@ def kde(self, ax=None, figsize=None, title=None, **kwargs) -> Axes:
247309

248310
def hist(
249311
self,
250-
model=None,
251-
bins=100,
312+
bins: int | Sequence = 100,
313+
*,
314+
model: str | int | None = None,
252315
title: Optional[str] = None,
253-
density=True,
316+
density: bool = True,
254317
alpha: float = 0.5,
255318
ax=None,
256319
figsize: Optional[Tuple[float, float]] = None,
@@ -262,8 +325,6 @@ def hist(
262325
263326
Parameters
264327
----------
265-
model : str, optional
266-
model name, by default None, i.e. the first model
267328
bins : int, optional
268329
number of bins, by default 100
269330
title : str, optional
@@ -292,12 +353,53 @@ def hist(
292353
pandas.Series.hist
293354
matplotlib.axes.Axes.hist
294355
"""
356+
if model is None:
357+
mod_names = self.cc.mod_names
358+
else:
359+
warnings.warn(
360+
"The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.hist()",
361+
FutureWarning,
362+
)
363+
model_list = [model] if isinstance(model, (str, int)) else model
364+
mod_names = [
365+
self.cc.mod_names[_get_idx(m, self.cc.mod_names)] for m in model_list
366+
]
367+
368+
axes = []
369+
for mod_name in mod_names:
370+
ax_mod = self._hist_one_model(
371+
mod_name=mod_name,
372+
bins=bins,
373+
title=title,
374+
density=density,
375+
alpha=alpha,
376+
ax=ax,
377+
figsize=figsize,
378+
**kwargs,
379+
)
380+
axes.append(ax_mod)
381+
return axes[0] if len(axes) == 1 else axes
382+
383+
def _hist_one_model(
384+
self,
385+
*,
386+
mod_name: str,
387+
bins: int | Sequence,
388+
title: Optional[str],
389+
density: bool,
390+
alpha: float,
391+
ax,
392+
figsize: Optional[Tuple[float, float]],
393+
**kwargs,
394+
):
295395
from ._comparison import MOD_COLORS
296396

297397
_, ax = _get_fig_ax(ax, figsize)
298398

299-
mod_id = _get_idx(model, self.cc.mod_names)
300-
mod_name = self.cc.mod_names[mod_id]
399+
assert (
400+
mod_name in self.cc.mod_names
401+
), f"Model {mod_name} not found in collection"
402+
mod_id = _get_idx(mod_name, self.cc.mod_names)
301403

302404
title = (
303405
_default_univarate_title("Histogram", self.cc) if title is None else title
@@ -331,6 +433,7 @@ def hist(
331433

332434
def taylor(
333435
self,
436+
*,
334437
normalize_std: bool = False,
335438
aggregate_observations: bool = True,
336439
figsize: Tuple[float, float] = (7, 7),

0 commit comments

Comments
 (0)