11from __future__ import annotations
22from typing import Any , List , Union , Optional , Tuple , Sequence , TYPE_CHECKING
33from matplotlib .axes import Axes # type: ignore
4+ import warnings
45
56if TYPE_CHECKING :
67 from ._collection import ComparerCollection
@@ -44,7 +45,7 @@ def scatter(
4445 xlabel : Optional [str ] = None ,
4546 ylabel : Optional [str ] = None ,
4647 skill_table : Optional [Union [str , List [str ], bool ]] = None ,
47- ax : Optional [ Axes ] = None ,
48+ ax = None ,
4849 ** kwargs ,
4950 ):
5051 """Scatter plot showing compared data: observation vs modelled
@@ -113,11 +114,72 @@ def scatter(
113114 >>> cc.plot.scatter(observations=['c2','HKNA'])
114115 """
115116
116- # select model
117- mod_id = _get_idx (model , self .cc .mod_names )
118- mod_name = self .cc .mod_names [mod_id ]
117+ cc = self .cc
118+ if model is None :
119+ mod_names = cc .mod_names
120+ else :
121+ warnings .warn (
122+ "The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.scatter()" ,
123+ FutureWarning ,
124+ )
125+
126+ model_list = [model ] if isinstance (model , (str , int )) else model
127+ mod_names = [
128+ self .cc .mod_names [_get_idx (m , self .cc .mod_names )] for m in model_list
129+ ]
130+
131+ axes = []
132+ for mod_name in mod_names :
133+ ax_mod = self ._scatter_one_model (
134+ mod_name = mod_name ,
135+ bins = bins ,
136+ quantiles = quantiles ,
137+ fit_to_quantiles = fit_to_quantiles ,
138+ show_points = show_points ,
139+ show_hist = show_hist ,
140+ show_density = show_density ,
141+ backend = backend ,
142+ figsize = figsize ,
143+ xlim = xlim ,
144+ ylim = ylim ,
145+ reg_method = reg_method ,
146+ title = title ,
147+ xlabel = xlabel ,
148+ ylabel = ylabel ,
149+ skill_table = skill_table ,
150+ ax = ax ,
151+ ** kwargs ,
152+ )
153+ axes .append (ax_mod )
154+ return axes [0 ] if len (axes ) == 1 else axes
119155
120- cmp = self .cc
156+ def _scatter_one_model (
157+ self ,
158+ * ,
159+ mod_name : str ,
160+ bins : int | float ,
161+ quantiles : int | Sequence [float ] | None ,
162+ fit_to_quantiles : bool ,
163+ show_points : bool | int | float | None ,
164+ show_hist : Optional [bool ],
165+ show_density : Optional [bool ],
166+ backend : str ,
167+ figsize : Tuple [float , float ],
168+ xlim : Optional [Tuple [float , float ]],
169+ ylim : Optional [Tuple [float , float ]],
170+ reg_method : str | bool ,
171+ title : Optional [str ],
172+ xlabel : Optional [str ],
173+ ylabel : Optional [str ],
174+ skill_table : Optional [Union [str , List [str ], bool ]],
175+ ax ,
176+ ** kwargs ,
177+ ):
178+ assert (
179+ mod_name in self .cc .mod_names
180+ ), f"Model { mod_name } not found in collection { self .cc .mod_names } "
181+
182+ cmp = self .cc .sel (model = mod_name )
121183
122184 if cmp .n_points == 0 :
123185 raise ValueError ("No data found in selection" )
@@ -183,7 +245,7 @@ def scatter(
183245
184246 return ax
185247
186- def kde (self , ax = None , figsize = None , title = None , ** kwargs ) -> Axes :
248+ def kde (self , * , ax = None , figsize = None , title = None , ** kwargs ) -> Axes :
187249 """Plot kernel density estimate of observation and model data.
188250
189251 Parameters
@@ -247,10 +309,11 @@ def kde(self, ax=None, figsize=None, title=None, **kwargs) -> Axes:
247309
248310 def hist (
249311 self ,
250- model = None ,
251- bins = 100 ,
312+ bins : int | Sequence = 100 ,
313+ * ,
314+ model : str | int | None = None ,
252315 title : Optional [str ] = None ,
253- density = True ,
316+ density : bool = True ,
254317 alpha : float = 0.5 ,
255318 ax = None ,
256319 figsize : Optional [Tuple [float , float ]] = None ,
@@ -262,8 +325,6 @@ def hist(
262325
263326 Parameters
264327 ----------
265- model : str, optional
266- model name, by default None, i.e. the first model
267328 bins : int, optional
268329 number of bins, by default 100
269330 title : str, optional
@@ -292,12 +353,53 @@ def hist(
292353 pandas.Series.hist
293354 matplotlib.axes.Axes.hist
294355 """
356+ if model is None :
357+ mod_names = self .cc .mod_names
358+ else :
359+ warnings .warn (
360+ "The 'model' keyword is deprecated! Instead, filter comparer before plotting cmp.sel(model=...).plot.hist()" ,
361+ FutureWarning ,
362+ )
363+ model_list = [model ] if isinstance (model , (str , int )) else model
364+ mod_names = [
365+ self .cc .mod_names [_get_idx (m , self .cc .mod_names )] for m in model_list
366+ ]
367+
368+ axes = []
369+ for mod_name in mod_names :
370+ ax_mod = self ._hist_one_model (
371+ mod_name = mod_name ,
372+ bins = bins ,
373+ title = title ,
374+ density = density ,
375+ alpha = alpha ,
376+ ax = ax ,
377+ figsize = figsize ,
378+ ** kwargs ,
379+ )
380+ axes .append (ax_mod )
381+ return axes [0 ] if len (axes ) == 1 else axes
382+
383+ def _hist_one_model (
384+ self ,
385+ * ,
386+ mod_name : str ,
387+ bins : int | Sequence ,
388+ title : Optional [str ],
389+ density : bool ,
390+ alpha : float ,
391+ ax ,
392+ figsize : Optional [Tuple [float , float ]],
393+ ** kwargs ,
394+ ):
295395 from ._comparison import MOD_COLORS
296396
297397 _ , ax = _get_fig_ax (ax , figsize )
298398
299- mod_id = _get_idx (model , self .cc .mod_names )
300- mod_name = self .cc .mod_names [mod_id ]
399+ assert (
400+ mod_name in self .cc .mod_names
401+ ), f"Model { mod_name } not found in collection"
402+ mod_id = _get_idx (mod_name , self .cc .mod_names )
301403
302404 title = (
303405 _default_univarate_title ("Histogram" , self .cc ) if title is None else title
@@ -331,6 +433,7 @@ def hist(
331433
332434 def taylor (
333435 self ,
436+ * ,
334437 normalize_std : bool = False ,
335438 aggregate_observations : bool = True ,
336439 figsize : Tuple [float , float ] = (7 , 7 ),
0 commit comments