diff --git a/.travis.yml b/.travis.yml index 212ddb77daa..4e5835b6da9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,6 +17,7 @@ matrix: - env: - CONDA_ENV=py36 - EXTRA_FLAGS="--run-flaky --run-network-tests" + - env: CONDA_ENV=py36-animatplot - env: CONDA_ENV=py36-dask-dev - env: CONDA_ENV=py36-pandas-dev - env: CONDA_ENV=py36-rasterio diff --git a/ci/requirements-py36-animatplot.yml b/ci/requirements-py36-animatplot.yml new file mode 100644 index 00000000000..993819dba98 --- /dev/null +++ b/ci/requirements-py36-animatplot.yml @@ -0,0 +1,25 @@ +name: test_env +channels: + - conda-forge +dependencies: + - python=3.6 + - animatplot + - cftime + - dask + - distributed + - h5py + - h5netcdf + - matplotlib + - netcdf4 + - pytest + - pytest-cov + - pytest-env + - coveralls + - pycodestyle + - numpy + - pandas + - scipy + - seaborn + - toolz + - bottleneck + - zarr diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4c126196469..0994654bd25 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -144,8 +144,9 @@ Bug fixes from higher frequencies to lower frequencies. Datapoints outside the bounds of the original time coordinate are now filled with NaN (:issue:`2197`). By `Spencer Clark `_. -- Line plots with the ``x`` argument set to a non-dimensional coord now plot the correct data for 1D DataArrays. - (:issue:`27251`). By `Tom Nicholas `_. +- Line plots with the ``x`` argument set to a non-dimensional coord now plot the + correct data for 1D DataArrays (:issue:`27251`). + By `Tom Nicholas `_. - Subtracting a scalar ``cftime.datetime`` object from a :py:class:`CFTimeIndex` now results in a :py:class:`pandas.TimedeltaIndex` instead of raising a ``TypeError`` (:issue:`2671`). By `Spencer Clark @@ -186,7 +187,7 @@ Bug fixes (e.g. '2000-01-01T00:00:00-05:00') no longer raises an error (:issue:`2649`). By `Spencer Clark `_. - Fixed performance regression with ``open_mfdataset`` (:issue:`2662`). - By `Tom Nicholas `_. + - Fixed supplying an explicit dimension in the ``concat_dim`` argument to to ``open_mfdataset`` (:issue:`2647`). By `Ben Root `_. diff --git a/setup.cfg b/setup.cfg index 18922b1647a..1a477ef45e3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,6 +21,8 @@ known_first_party=xarray multi_line_output=4 # Most of the numerical computing stack doesn't have type annotations yet. +[mypy-animatplot.*] +ignore_missing_imports = True [mypy-bottleneck.*] ignore_missing_imports = True [mypy-cdms2.*] diff --git a/xarray/plot/animate.py b/xarray/plot/animate.py new file mode 100644 index 00000000000..fb26ddeb2ce --- /dev/null +++ b/xarray/plot/animate.py @@ -0,0 +1,186 @@ +""" +Use this module directly: + import xarray.animate as xanim + +Or supply an ``animate`` keyword +argument to a normal plotting function: + DataArray.plot._____(animate='__') +""" + +import datetime + +import numpy as np +import pandas as pd + +from .utils import (_infer_line_data, _ensure_plottable, _update_axes, + get_axis, _rotate_date_xlabels, _check_animate, + _transpose_before_animation, import_matplotlib_pyplot) + + +def line(darray, animate=None, **kwargs): + """ + Line plot of DataArray index against values + + Wraps :func:`animatplot:animatplot.blocks.Line` + + Parameters + ---------- + darray : DataArray + Must be 2 dimensional. + animate: str + Dimension or coord in the DataArray over which to animate. + ``animatplot.blocks.Line`` will be used to animate the plot over this + dimension. + figsize : tuple, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + aspect : scalar, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + ax : matplotlib axes object, optional + Axis on which to plot this figure. By default, use the current axis. + Mutually exclusive with ``size`` and ``figsize``. + hue : string, optional + Dimension or coordinate for which you want multiple lines plotted. + If plotting against a 2D coordinate, ``hue`` must be a dimension. + x, y : string, optional + Dimensions or coordinates for x, y axis. + Only one of these may be specified. + The other coordinate plots values from the DataArray on which this + plot method is called. + xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional + Specifies scaling for the x- and y-axes respectively + xticks, yticks : Specify tick locations for x- and y-axes + xlim, ylim : optional + Specify x- and y-axes limits. + xincrease : None, True, or False, optional + Should the values on the x axes be increasing from left to right? + if None, use the default for the matplotlib function. + yincrease : None, True, or False, optional + Should the values on the y axes be increasing from top to bottom? + if None, use the default for the matplotlib function. + add_legend : boolean, optional + Add legend with y axis coordinates (3D inputs only). + **kwargs : optional + Additional arguments to animatplot.blocks.Line + + """ + + from animatplot.blocks import Line, Title + from animatplot.animation import Animation + + row = kwargs.pop('row', None) + col = kwargs.pop('col', None) + if row or col: + raise NotImplementedError("Animated FacetGrids not yet implemented") + + _check_animate(darray, animate) + darray = _transpose_before_animation(darray, animate) + + ndims = len(darray.dims) + if ndims > 3: + raise ValueError('Animated line plots are for 2- or 3-dimensional ' + 'DataArrays. Passed DataArray has {ndims} ' + 'dimensions'.format(ndims=ndims + 1)) + + # Ensures consistency with .plot method + figsize = kwargs.pop('figsize', None) + aspect = kwargs.pop('aspect', None) + size = kwargs.pop('size', None) + ax = kwargs.pop('ax', None) + hue = kwargs.pop('hue', None) + x = kwargs.pop('x', None) + y = kwargs.pop('y', None) + linestyle = kwargs.get('linestyle', '') + xincrease = kwargs.pop('xincrease', None) # default needs to be None + yincrease = kwargs.pop('yincrease', None) + xscale = kwargs.pop('xscale', None) # default needs to be None + yscale = kwargs.pop('yscale', None) + xticks = kwargs.pop('xticks', None) + yticks = kwargs.pop('yticks', None) + xlim = kwargs.pop('xlim', None) + ylim = kwargs.pop('ylim', None) + add_legend = kwargs.pop('add_legend', True) + _labels = kwargs.pop('_labels', True) + + ax = get_axis(figsize, size, aspect, ax) + xplt_val, yplt_val, hueplt, xlabel, ylabel, huelabel = \ + _infer_line_data(darray, x, y, hue, animate, linestyle) + + _ensure_plottable(xplt_val, yplt_val) + + fps = kwargs.pop('fps', 10) + timeline = _create_timeline(darray, animate, fps) + + if ylim is None: + ylim = [np.min(yplt_val), np.max(yplt_val)] + + # TODO this currently breaks step plots because they have a list of arrays + # for yplt_val + num_lines = len(hueplt) if hueplt is not None else 1 + # We transposed in _infer_line_data so that animate is last dim and hue is + # second-last dim + # TODO think of a more robust way of doing this + hueaxis = -2 if hue else 0 + + line_blocks = [Line(xplt_val, yplt_val_line.squeeze(), + ax=ax, t_axis=-1, **kwargs) + for yplt_val_line in np.split(yplt_val, num_lines, hueaxis)] + + # TODO if not _labels then no Title block is needed + if _labels: + if xlabel is not None: + ax.set_xlabel(xlabel) + + if ylabel is not None: + ax.set_ylabel(ylabel) + + # Would be nicer if we had something like in GH issue #266 + frame_titles = [darray[{animate: i}]._title_for_slice() + for i in range(len(timeline))] + title_block = Title(frame_titles, ax=ax) + + if ndims == 3 and add_legend: + # TODO ensure the legend stays in the same place throughout animation + ax.legend(handles=[block.line for block in line_blocks], + labels=list(hueplt.values), + title=huelabel) + + _rotate_date_xlabels(xplt_val, ax) + + _update_axes(ax, xincrease, yincrease, xscale, yscale, + xticks, yticks, xlim, ylim) + + anim = Animation([*line_blocks, title_block], timeline=timeline) + anim.controls(timeline_slider_args={'text': animate, 'valfmt': '%s'}) + + # Stop subsequent matplotlib plotting calls plotting onto the pause button! + plt = import_matplotlib_pyplot() + plt.sca(ax) + + return anim + + +def _create_timeline(darray, animate, fps): + + from animatplot.animation import Timeline + + if animate in darray.coords: + t_array = darray.coords[animate].values + + # Format datetimes in a nicer way + if isinstance(t_array[0], datetime.date) \ + or np.issubdtype(t_array.dtype, np.datetime64): + t_array = [pd.to_datetime(date) for date in t_array] + + else: # animating over a dimension without coords + t_array = np.arange(darray.sizes[animate]) + + if darray.coords[animate].attrs.get('units'): + units = ' [{}]'.format(darray.coords[animate].attrs['units']) + else: + units = '' + return Timeline(t_array, units=units, fps=fps) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 4f0232236f8..dbf5ebfd2c0 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -267,9 +267,11 @@ def map_dataarray_line(self, func, x, y, **kwargs): mappable = func(subset, x=x, y=y, ax=ax, **func_kwargs) self._mappables.append(mappable) + animate = kwargs.pop('animate', None) _, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data( darray=self.data.loc[self.name_dicts.flat[0]], - x=x, y=y, hue=func_kwargs['hue']) + x=x, y=y, hue=func_kwargs['hue'], animate=animate, + linestyle=func_kwargs.get('linestyle', '')) self._hue_var = hueplt self._hue_label = huelabel diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 8e2457603d6..d814a1c9270 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -13,91 +13,15 @@ from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, - _interval_to_double_bound_points, _interval_to_mid_points, - _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_2dplot, - _update_axes, _valid_other_type, get_axis, import_matplotlib_pyplot, - label_from_attrs) - - -def _infer_line_data(darray, x, y, hue): - error_msg = ('must be either None or one of ({0:s})' - .format(', '.join([repr(dd) for dd in darray.dims]))) - ndims = len(darray.dims) - - if x is not None and x not in darray.dims and x not in darray.coords: - raise ValueError('x ' + error_msg) - - if y is not None and y not in darray.dims and y not in darray.coords: - raise ValueError('y ' + error_msg) - - if x is not None and y is not None: - raise ValueError('You cannot specify both x and y kwargs' - 'for line plots.') - - if ndims == 1: - huename = None - hueplt = None - huelabel = '' - - if x is not None: - xplt = darray[x] - yplt = darray - - elif y is not None: - xplt = darray - yplt = darray[y] - - else: # Both x & y are None - dim = darray.dims[0] - xplt = darray[dim] - yplt = darray - - else: - if x is None and y is None and hue is None: - raise ValueError('For 2D inputs, please' - 'specify either hue, x or y.') - - if y is None: - xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) - xplt = darray[xname] - if xplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - yplt = darray.transpose(otherdim, huename) - xplt = xplt.transpose(otherdim, huename) - else: - raise ValueError('For 2D inputs, hue must be a dimension' - + ' i.e. one of ' + repr(darray.dims)) - - else: - yplt = darray.transpose(xname, huename) - - else: - yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) - yplt = darray[yname] - if yplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - xplt = darray.transpose(otherdim, huename) - else: - raise ValueError('For 2D inputs, hue must be a dimension' - + ' i.e. one of ' + repr(darray.dims)) - - else: - xplt = darray.transpose(yname, huename) - - huelabel = label_from_attrs(darray[huename]) - hueplt = darray[huename] - - xlabel = label_from_attrs(xplt) - ylabel = label_from_attrs(yplt) - - return xplt, yplt, hueplt, xlabel, ylabel, huelabel + _infer_line_data, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, + _resolve_intervals_2dplot, + _update_axes, get_axis, import_matplotlib_pyplot, + label_from_attrs, _rotate_date_xlabels, _check_animate, + _transpose_before_animation, _infer_plot_type) def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, - rtol=0.01, subplot_kws=None, **kwargs): + rtol=0.01, animate=None, subplot_kws=None, **kwargs): """ Default plot of DataArray using matplotlib.pyplot. @@ -123,6 +47,11 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, If passed, make faceted line plots with hue on this dimension name col_wrap : integer, optional Use together with ``col`` to wrap faceted plots + animate: str, optional + Dimension or coord in the DataArray over which to animate. If this + argument is supplied then ``animatplot`` will be used to animate the + corresponding plot. The DataArray must have 1 more dimension than + specified in the table above. ax : matplotlib axes, optional If None, uses the current axis. Not applicable when using facets. rtol : number, optional @@ -135,41 +64,9 @@ def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None, Additional keyword arguments to matplotlib """ - darray = darray.squeeze() - - plot_dims = set(darray.dims) - plot_dims.discard(row) - plot_dims.discard(col) - plot_dims.discard(hue) - - ndims = len(plot_dims) - - error_msg = ('Only 1d and 2d plots are supported for facets in xarray. ' - 'See the package `Seaborn` for more options.') - - if ndims in [1, 2]: - if row or col: - kwargs['row'] = row - kwargs['col'] = col - kwargs['col_wrap'] = col_wrap - kwargs['subplot_kws'] = subplot_kws - if ndims == 1: - plotfunc = line - kwargs['hue'] = hue - elif ndims == 2: - if hue: - plotfunc = line - kwargs['hue'] = hue - else: - plotfunc = pcolormesh - else: - if row or col or hue: - raise ValueError(error_msg) - plotfunc = hist - - kwargs['ax'] = ax - - return plotfunc(darray, **kwargs) + return _infer_plot_type(darray, row=row, col=col, col_wrap=col_wrap, + ax=ax, hue=hue, rtol=rtol, animate=animate, + subplot_kws=subplot_kws, **kwargs) # This function signature should not change so that it can use @@ -183,7 +80,8 @@ def line(darray, *args, **kwargs): Parameters ---------- darray : DataArray - Must be 1 dimensional + Must be 1 dimensional, unless ``animate`` is specified, in which case + it must be 2 dimensional. figsize : tuple, optional A tuple (width, height) of the figure in inches. Mutually exclusive with ``size`` and ``ax``. @@ -199,6 +97,10 @@ def line(darray, *args, **kwargs): hue : string, optional Dimension or coordinate for which you want multiple lines plotted. If plotting against a 2D coordinate, ``hue`` must be a dimension. + animate: str, optional + Dimension or coord in the DataArray over which to animate. If this + argument is supplied then this function will redirect to + ``xarray.animate.animate_line``. x, y : string, optional Dimensions or coordinates for x, y axis. Only one of these may be specified. @@ -221,10 +123,19 @@ def line(darray, *args, **kwargs): """ + animate = kwargs.pop('animate', None) + if animate is not None: + animate_dim = _check_animate(darray, animate) + darray = _transpose_before_animation(darray, animate) + from .animate import line as animate_line + return animate_line(darray, animate=animate, **kwargs) + # Handle facetgrids first row = kwargs.pop('row', None) col = kwargs.pop('col', None) if row or col: + if animate is not None: + raise NotImplementedError allargs = locals().copy() allargs.update(allargs.pop('kwargs')) allargs.pop('darray') @@ -244,6 +155,7 @@ def line(darray, *args, **kwargs): hue = kwargs.pop('hue', None) x = kwargs.pop('x', None) y = kwargs.pop('y', None) + linestyle = kwargs.get('linestyle', '') xincrease = kwargs.pop('xincrease', None) # default needs to be None yincrease = kwargs.pop('yincrease', None) xscale = kwargs.pop('xscale', None) # default needs to be None @@ -258,29 +170,8 @@ def line(darray, *args, **kwargs): args = kwargs.pop('args', ()) ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, xlabel, ylabel, huelabel = \ - _infer_line_data(darray, x, y, hue) - - # Remove pd.Intervals if contained in xplt.values. - if _valid_other_type(xplt.values, [pd.Interval]): - # Is it a step plot? (see matplotlib.Axes.step) - if kwargs.get('linestyle', '').startswith('steps-'): - xplt_val, yplt_val = _interval_to_double_bound_points(xplt.values, - yplt.values) - # Remove steps-* to be sure that matplotlib is not confused - kwargs['linestyle'] = (kwargs['linestyle'] - .replace('steps-pre', '') - .replace('steps-post', '') - .replace('steps-mid', '')) - if kwargs['linestyle'] == '': - kwargs.pop('linestyle') - else: - xplt_val = _interval_to_mid_points(xplt.values) - yplt_val = yplt.values - xlabel += '_center' - else: - xplt_val = xplt.values - yplt_val = yplt.values + xplt_val, yplt_val, hueplt, xlabel, ylabel, huelabel = \ + _infer_line_data(darray, x, y, hue, animate, linestyle) _ensure_plottable(xplt_val, yplt_val) @@ -295,19 +186,12 @@ def line(darray, *args, **kwargs): ax.set_title(darray._title_for_slice()) - if darray.ndim == 2 and add_legend: + if ndims == 2 and add_legend: ax.legend(handles=primitive, labels=list(hueplt.values), title=huelabel) - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_ha('right') + _rotate_date_xlabels(xplt_val, ax) _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) @@ -410,7 +294,7 @@ def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs): # MUST run before any 2d plotting functions are defined since # _plot2d decorator adds them as methods here. -class _PlotMethods(object): +class _PlotMethods: """ Enables use of xarray.plot functions as attributes on a DataArray. For example, DataArray.plot.imshow diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 21523ede4cd..a319dd90afd 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -273,6 +273,177 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, levels=levels, norm=norm) +def _infer_plot_type(darray, row=None, col=None, col_wrap=None, ax=None, + hue=None, rtol=0.01, animate=None, subplot_kws=None, + **kwargs): + from .plot import line, pcolormesh, hist + + darray = darray.squeeze() + + if animate is not None: + animate_dim = _check_animate(darray, animate) + kwargs['animate'] = animate + if col is not None or row is not None: + raise NotImplementedError("Animated FacetGrids not yet supported") + else: + animate_dim = None + + dims = set(darray.dims) + if animate is not None: + plot_dims = dims - set([animate_dim]) + else: + plot_dims = dims + + plot_dims.discard(row) + plot_dims.discard(col) + plot_dims.discard(hue) + + nplotdims = len(plot_dims) + + error_msg = ('Only 1d and 2d plots are supported for facets in xarray. ' + 'See the package `Seaborn` for more options.') + + if nplotdims in [1, 2]: + if row or col: + kwargs['row'] = row + kwargs['col'] = col + kwargs['col_wrap'] = col_wrap + kwargs['subplot_kws'] = subplot_kws + if nplotdims == 1: + plotfunc = line + kwargs['hue'] = hue + elif nplotdims == 2: + if hue: + plotfunc = line + kwargs['hue'] = hue + else: + plotfunc = pcolormesh + else: + if row or col or hue: + raise ValueError(error_msg) + plotfunc = hist + + kwargs['ax'] = ax + + if animate is not None: + if plotfunc is line: + from .animate import line as animate_line + plotfunc = animate_line + else: + raise NotImplementedError + + return plotfunc(darray, **kwargs) + + +def _infer_line_data(darray, x, y, hue, animate, linestyle): + error_msg = ('must be either None or one of ({0:s})' + .format(', '.join([repr(dd) for dd in darray.dims]))) + ndims = len(darray.dims) + + if x is not None and x not in darray.dims and x not in darray.coords: + raise ValueError('x ' + error_msg) + + if y is not None and y not in darray.dims and y not in darray.coords: + raise ValueError('y ' + error_msg) + + if x is not None and y is not None: + raise ValueError('You cannot specify both x and y kwargs' + 'for line plots.') + + # TODO there must be a neat one-line way of doing this check + animate_ndim = 1 if animate is not None else 0 + if ndims - animate_ndim == 1: + huename = None + hueplt = None + huelabel = '' + + if x is not None: + xplt = darray[x] + yplt = darray + + elif y is not None: + xplt = darray + yplt = darray[y] + + else: # Both x & y are None + dim = darray.dims[0] + xplt = darray[dim] + yplt = darray + + else: + if x is None and y is None and hue is None: + raise ValueError('For 2D inputs, please' + 'specify either hue, x or y.') + + if y is None: + xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue, + animate=animate) + xplt = darray[xname] + if xplt.ndim > 1: + if animate is not None: + raise NotImplementedError + + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + yplt = darray.transpose(otherdim, huename) + xplt = xplt.transpose(otherdim, huename) + else: + raise ValueError('For 2D inputs, hue must be a dimension' + + ' i.e. one of ' + repr(darray.dims)) + + else: + if animate is not None: + yplt = darray.transpose(xname, huename, animate) + else: + yplt = darray.transpose(xname, huename) + + else: + if animate is not None: + raise NotImplementedError + + yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) + yplt = darray[yname] + if yplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + xplt = darray.transpose(otherdim, huename) + else: + raise ValueError('For 2D inputs, hue must be a dimension' + + ' i.e. one of ' + repr(darray.dims)) + + else: + xplt = darray.transpose(yname, huename) + + huelabel = label_from_attrs(darray[huename]) + hueplt = darray[huename] + + xlabel = label_from_attrs(xplt) + ylabel = label_from_attrs(yplt) + + # Remove pd.Intervals if contained in xplt.values. + if _valid_other_type(xplt.values, [pd.Interval]): + # Is it a step plot? (see matplotlib.Axes.step) + if linestyle.startswith('steps-'): + xplt_val, yplt_val = _interval_to_double_bound_points(xplt.values, + yplt.values) + # Remove steps-* to be sure that matplotlib is not confused + linestyle = (linestyle.replace('steps-pre', '') + .replace('steps-post', '') + .replace('steps-mid', '')) + # if kwargs['linestyle'] == '': + # kwargs.pop('linestyle') + else: + xplt_val = _interval_to_mid_points(xplt.values) + yplt_val = yplt.values + xlabel += '_center' + else: + xplt_val = xplt.values + yplt_val = yplt.values + + return xplt_val, yplt_val, hueplt, xlabel, ylabel, huelabel + + def _infer_xy_labels_3d(darray, x, y, rgb): """ Determine x and y labels for showing RGB images. @@ -323,28 +494,41 @@ def _infer_xy_labels_3d(darray, x, y, rgb): return _infer_xy_labels(darray.isel(**{rgb: 0}), x, y) -def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): +def _infer_xy_labels(darray, x, y, animate=None, imshow=False, rgb=None): """ Determine x and y labels. For use in _plot2d darray must be a 2 dimensional data array, or 3d for imshow only. """ assert x is None or x != y + if animate is not None: + assert animate != x and animate != y + if imshow and darray.ndim == 3: + if animate is not None: + raise NotImplementedError return _infer_xy_labels_3d(darray, x, y, rgb) + # TODO there must be a more pythonic way of doing this + dims = list(darray.dims) + if animate in dims: + dims.remove(animate) + plotdims = tuple(dims) + if x is None and y is None: - if darray.ndim != 2: - raise ValueError('DataArray must be 2d') - y, x = darray.dims + required_ndims = 2 if animate is None else 3 + if darray.ndim != required_ndims: + raise ValueError('DataArray must be {}d'.format(required_ndims)) + y, x = plotdims elif x is None: if y not in darray.dims and y not in darray.coords: raise ValueError('y must be a dimension name if x is not supplied') - x = darray.dims[0] if y == darray.dims[1] else darray.dims[1] + x = plotdims[0] if y == plotdims[1] else plotdims[1] elif y is None: if x not in darray.dims and x not in darray.coords: - raise ValueError('x must be a dimension name if y is not supplied') - y = darray.dims[0] if x == darray.dims[1] else darray.dims[1] + raise ValueError( + 'x must be a dimension name if y is not supplied') + y = plotdims[0] if x == plotdims[1] else plotdims[1] elif any(k not in darray.coords and k not in darray.dims for k in (x, y)): raise ValueError('x and y must be coordinate variables') return x, y @@ -400,6 +584,17 @@ def label_from_attrs(da, extra=''): return '\n'.join(textwrap.wrap(name + extra + units, 30)) +def _rotate_date_xlabels(xdata, ax): + # Rotate dates on xlabels + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + if np.issubdtype(np.array(xdata).dtype, np.datetime64): + for xlabels in ax.get_xticklabels(): + xlabels.set_rotation(30) + xlabels.set_ha('right') + + def _interval_to_mid_points(array): """ Helper function which returns an array @@ -702,3 +897,28 @@ def _process_cmap_cbar_kwargs(func, kwargs, data): cmap_params = _determine_cmap_params(**cmap_kwargs) return cmap_params, cbar_kwargs + + +def _check_animate(darray, animate): + if animate is None: + raise ValueError + + if animate not in darray.coords and animate not in darray.dims: + raise ValueError("Can only animate over a dimension or coordinate " + "present in the DataArray") + + anim_coord = darray[animate].variable + if anim_coord.ndim != 1: + raise ValueError('Coordinate {} must be 1 dimensional but is {}' + ' dimensional'.format(anim_coord, anim_coord.ndim)) + anim_dim = anim_coord.dims[0] + return anim_dim + + +# TODO _transpose_before_animation should be a decorator applied to +# animate_line etc? +def _transpose_before_animation(darray, animate): + # Set animation dimension to be along last axis of data + dims = list(darray.dims) + dims.remove(animate) + return darray.transpose(*dims, animate) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 525360701fe..1c7cd94243a 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -74,6 +74,8 @@ def LooseVersion(vstring): has_np113, requires_np113 = _importorskip('numpy', minversion='1.13.0') has_iris, requires_iris = _importorskip('iris') has_cfgrib, requires_cfgrib = _importorskip('cfgrib') +has_animatplot, requires_animatplot = _importorskip('animatplot', + minversion='0.3.0') # some special cases has_h5netcdf07, requires_h5netcdf07 = _importorskip('h5netcdf', diff --git a/xarray/tests/test_animate.py b/xarray/tests/test_animate.py new file mode 100644 index 00000000000..b7ba72ee76b --- /dev/null +++ b/xarray/tests/test_animate.py @@ -0,0 +1,251 @@ +from functools import partial + +import numpy as np +import numpy.testing as npt +import pytest + +import xarray as xr +from xarray import DataArray +from . import requires_animatplot + +# import mpl and change the backend before other mpl imports +try: + import matplotlib as mpl + import matplotlib.pyplot as plt +except ImportError: + pass + +# TODO should check that matplotlib >= 2.2 is present first? +try: + import animatplot as amp +except ImportError: + pass + +from .test_plot import PlotTestCase, easy_array + +from xarray.plot.animate import _create_timeline +import xarray.plot.animate + + +@requires_animatplot +class TestTimeline: + def test_coord_timeline(self): + da = DataArray([1, 2, 3], + coords={'duration': ('time', [0.1, 0.2, 0.3])}, + dims='time') + da.coords['duration'].attrs['units'] = 's' + timeline = _create_timeline(da, animate='duration', fps=5) + + assert isinstance(timeline, amp.animation.Timeline) + assert len(timeline) == len(da.coords['duration']) + assert timeline.units == ' [s]' + npt.assert_equal(timeline.t, da.coords['duration'].values) + assert timeline.fps == 5 + + def test_dim_timeline(self): + da = DataArray([10, 20], dims='Time') + timeline = _create_timeline(da, animate='Time', fps=5) + + assert isinstance(timeline, amp.animation.Timeline) + assert len(timeline) == da.sizes['Time'] + assert timeline.units == '' + npt.assert_equal(timeline.t, np.array([0, 1])) + assert timeline.fps == 5 + + def test_datetimeline(self): + dates = np.array(['2000-01-01', '2000-01-02', '2000-01-03'], + dtype=np.datetime64) + da = DataArray([1, 2, 3], + coords={'date': ('time', dates)}, dims='time') + timeline = _create_timeline(da, animate='date', fps=5) + + assert str(timeline.t[0]) == '2000-01-01 00:00:00' + + +@pytest.fixture +def linedata(): + dat1 = np.array([[0.0, 1.1, 0.0, 2], + [0.1, 1.3, 0.2, 2.1], + [0.1, 1.4, 0.3, 2.2], + [0.2, 1.3, 0.2, 2.3], + [0.1, 1.2, 0.2, 2.2]]) + dat2 = np.array([[0.0, 1.1, 0.0, 2], + [0.1, 1.3, 0.2, 2.1], + [0.1, 1.4, 0.3, 2.2], + [0.2, 1.3, 0.2, 2.3], + [0.1, 1.2, 0.2, 2.2]]) + das = [] + for data in [dat1, dat2]: + coords = {'time': 10 * np.arange(data.shape[0]), + 'position': 0.1 * np.arange(data.shape[1])} + da = DataArray(data, name='height', coords=coords, + dims=('time', 'position'), attrs={'units': 'm'}) + da.time.attrs['units'] = 's' + da.position.attrs['units'] = 'cm' + + das.append(da) + + player = DataArray(name='player', data=['Tom', 'Bhavin'], dims='player') + return xr.concat(das, dim=player) + + +@requires_animatplot +class TestAnimateLine(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self, linedata): + self.darray = linedata.sel(player='Tom') + + def test_2d_animated_line_accepts_x_kw(self): + self.darray.plot.line(x='position', animate='time') + assert plt.gca().get_xlabel() == 'position [cm]' + plt.cla() + self.darray.plot.line(x='time', animate='position') + assert plt.gca().get_xlabel() == 'time [s]' + + @pytest.mark.skip + def test_2d_animated_line_accepts_y_kw(self): + self.darray.plot.line(y='position', animate='time') + assert plt.gca().get_ylabel() == 'position [cm]' + plt.cla() + self.darray.plot.line(y='time', animate='position') + assert plt.gca().get_ylabel() == 'time [s]' + + def test_animate_single_line_classes(self): + anim = self.darray.plot(animate='time') + assert isinstance(anim, amp.animation.Animation) + + line_block, title_block = anim.blocks + + assert isinstance(line_block, amp.blocks.Line) + assert isinstance(title_block, amp.blocks.Title) + + def test_animate_single_line_data(self): + line_block, title_block = self.darray.plot(animate='time').blocks + + assert len(line_block) == 5 + assert len(line_block) == len(title_block) + + npt.assert_equal(line_block.y, self.darray.transpose().values) + npt.assert_equal(line_block.x[:, 0], + self.darray.coords['position'].values) + + def test_animate_single_line_text(self): + anim = self.darray.plot(animate='time') + line_block, title_block = anim.blocks + + assert title_block.titles[0] == 'time = 0, player = Tom' + assert line_block.ax.get_xlabel() == 'position [cm]' + assert anim.timeline.units == ' [s]' + + # TODO test that omitting title block is handled gracefully + @pytest.mark.skip + def test_no_labels(self): + ... + + def test_can_pass_in_axis(self): + self.pass_in_axis(partial(self.darray.plot, animate='time')) + + def test_animate_single_line_axes(self): + line_block, title_block = self.darray.plot(animate='time').blocks + + # Check current axes is the plot (not the timeline etc.) + assert plt.gca() is line_block.ax + + def test_animate_as_function(self): + anim = xarray.plot.animate.line(self.darray, animate='time') + assert isinstance(anim, amp.animation.Animation) + + def test_animate_as_argument(self): + anim = self.darray.plot(animate='time') + assert isinstance(anim, amp.animation.Animation) + + anim = self.darray.plot.line(animate='time') + assert isinstance(anim, amp.animation.Animation) + + +@pytest.mark.xfail(reason="np.splitting the y data doesn't work for step plots" + "because they have lists of arrays for some reason") +@requires_animatplot +class TestAnimateStep(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + self.darray = DataArray(easy_array((4, 5, 6))) + + def test_coord_with_interval_step(self): + bins = [-1, 0, 1, 2] + da = self.darray.groupby_bins('dim_0', bins).mean(xr.ALL_DIMS) + da = xr.concat([da, da * 2, da * 1.7], dim='new_dim') + + anim = da.plot.step(animate='new_dim') + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + npt.assert_equal(anim.timeline.t, np.array([0, 1, 2])) + + +@requires_animatplot +class TestAnimateMultipleLines(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self, linedata): + self.darray = linedata + + def test_2d_animated_line_accepts_hue_kw(self): + da = self.darray + print(da) + da.plot.line(hue='player', animate='time') + assert (plt.gca().get_legend().get_title().get_text() + == 'player') + plt.cla() + self.darray.plot.line(hue='time', animate='player') + assert (plt.gca().get_legend().get_title().get_text() + == 'time [s]') + + def test_animate_multiple_lines_classes(self): + anim = self.darray.plot(animate='time', hue='player') + assert isinstance(anim, amp.animation.Animation) + + line_block1, line_block2, title_block = anim.blocks + + assert isinstance(line_block1, amp.blocks.Line) + assert isinstance(line_block2, amp.blocks.Line) + assert isinstance(title_block, amp.blocks.Title) + + def test_animate_multiple_lines_data(self): + anim = self.darray.plot(animate='time', hue='player') + line_block1, _, title_block = anim.blocks + + assert len(line_block1) == 5 + assert len(line_block1) == len(title_block) + + expected = self.darray.isel(player=0).transpose('position', 'time') + npt.assert_equal(line_block1.y, expected.values) + npt.assert_equal(line_block1.x[:, 0], + self.darray.coords['position'].values) + + def test_animate_multiple_lines_text(self): + anim = self.darray.plot(animate='time', hue='player') + line_block1, _, title_block = anim.blocks + + assert title_block.titles[0] == 'time = 0' + assert line_block1.ax.get_xlabel() == 'position [cm]' + assert anim.timeline.units == ' [s]' + + # TODO check legend is correct + + def test_can_pass_in_axis(self): + self.pass_in_axis(partial(self.darray.plot, + animate='time', hue='player')) + + def test_animate_multiple_line_axes(self): + line_block1, line_block2, _ = self.darray.plot(animate='time', + hue='player').blocks + assert line_block1.ax is line_block2.ax + + # Check current axes is the plot (not the timeline etc.) + assert plt.gca() is line_block1.ax + + +class TestAnimatedFacetGrid: + def test_faceting_not_implemented(self): + da = DataArray(easy_array(2, 3, 4)) + + with pytest.raises(NotImplementedError): + da.plot(animate='dim_0', col='dim_1') diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index c0e03b5791c..e92998b6f6f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd import pytest -from numpy.testing import assert_array_equal import xarray as xr import xarray.plot as xplt @@ -187,6 +186,13 @@ def test_2d_line_accepts_x_kw(self): self.darray[:, :, 0].plot.line(x='dim_1') assert plt.gca().get_xlabel() == 'dim_1' + def test_2d_line_accepts_y_kw(self): + self.darray[:, :, 0].plot.line(y='dim_0') + assert plt.gca().get_ylabel() == 'dim_0' + plt.cla() + self.darray[:, :, 0].plot.line(y='dim_1') + assert plt.gca().get_ylabel() == 'dim_1' + def test_2d_line_accepts_hue_kw(self): self.darray[:, :, 0].plot.line(hue='dim_0') assert (plt.gca().get_legend().get_title().get_text()