diff --git a/.travis.yml b/.travis.yml index a0758ee1..641e3c31 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ language: python python: - - "3.5" - "3.6" + - "3.7" install: - pip install --upgrade setuptools pip pytest pytest-cov coverage codecov - pip install -r requirements.txt diff --git a/requirements.txt b/requirements.txt index 2df45747..cd273591 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ xarray >= 0.12.2 dask[array] >= 1.0.0 natsort >= 5.5.0 -matplotlib >= 2.2 +matplotlib >= 3.0.3 animatplot >= 0.3 netcdf4 >= 1.4.0 +Pillow >= 6.1.0 diff --git a/setup.py b/setup.py index 9318d5cd..fed11c39 100644 --- a/setup.py +++ b/setup.py @@ -19,14 +19,15 @@ author_email="thomas.nicholas@york.ac.uk", description='Collect data from BOUT++ runs in python using xarray', license="Apache", - python_requires='>=3.5', + python_requires='>=3.6', install_requires=[ 'xarray>=v0.12.2', 'dask[array]>=1.0.0', 'natsort>=5.5.0', - 'matplotlib>=2.2', + 'matplotlib>=3.0.3', 'animatplot>=0.3', 'netcdf4>=1.4.0', + 'Pillow>=6.1.0' ], extras_require={ 'tests': ['pytest >= 3.3.0'], @@ -42,7 +43,6 @@ "License :: OSI Approved :: Apache License", "Natural Language :: English", "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Topic :: Scientific/Engineering :: Visualization" diff --git a/xbout/boutdataarray.py b/xbout/boutdataarray.py index 4a0dcbd9..06eff929 100644 --- a/xbout/boutdataarray.py +++ b/xbout/boutdataarray.py @@ -97,7 +97,6 @@ def animate1D(self, animate_over='t', x='x', y='y', animate=True, print("{} data passed has {} dimensions - will use " "animatplot.blocks.Line()".format(variable, str(n_dims))) line_block = animate_line(data=data, animate_over=animate_over, - x=x, y=y, sep_pos=sep_pos, - animate=animate, fps=fps, + sep_pos=sep_pos, animate=animate, fps=fps, save_as=save_as, ax=ax, **kwargs) return line_block diff --git a/xbout/plotting/animate.py b/xbout/plotting/animate.py index 5a0a4cb9..5124bbf9 100644 --- a/xbout/plotting/animate.py +++ b/xbout/plotting/animate.py @@ -4,10 +4,10 @@ import animatplot as amp from .utils import plot_separatrix - +from matplotlib.animation import PillowWriter def animate_imshow(data, animate_over='t', x='x', y='y', animate=True, - vmin='min', vmax='max', fps=10, save_as=None, + vmin=None, vmax=None, fps=10, save_as=None, sep_pos=None, ax=None, **kwargs): """ Plots a color plot which is animated with time over the specified @@ -60,9 +60,9 @@ def animate_imshow(data, animate_over='t', x='x', y='y', animate=True, image_data = data.values # If not specified, determine max and min values across entire data series - if vmax is 'max': + if vmax is None: vmax = np.max(image_data) - if vmin is 'min': + if vmin is None: vmin = np.min(image_data) if not ax: @@ -93,16 +93,89 @@ def animate_imshow(data, animate_over='t', x='x', y='y', animate=True, if not save_as: save_as = "{}_over_{}".format(variable, animate_over) - # TODO save using PillowWriter instead once matplotlib 3.1 comes out - # see https://github.com/t-makaro/animatplot/issues/24 - anim.save(save_as + '.gif', writer='imagemagick') + anim.save(save_as + '.gif', writer=PillowWriter(fps=fps)) return imshow_block -def animate_line(data, animate_over='t', x='x', y='y', animate=True, - fps=10, save_as=None, sep_pos=None, ax=None, **kwargs): +def animate_line(data, animate_over='t', animate=True, + vmin=None, vmax=None, fps=10, save_as=None, sep_pos=None, ax=None, + **kwargs): + """ + Plots a line plot which is animated with time. + + Currently only supports 1D+1 data, which it plots with xarray's + wrapping of matplotlib's plot. + + Parameters + ---------- + data : xarray.DataArray + animate_over : str, optional + Dimension over which to animate + vmin : float, optional + Minimum value to use for colorbar. Default is to use minimum value of + data across whole timeseries. + vmax : float, optional + Maximum value to use for colorbar. Default is to use maximum value of + data across whole timeseries. + sep_pos : int, optional + Radial position at which to plot the separatrix + save_as: str, optional + Filename to give to the resulting gif + fps : int, optional + Frames per second of resulting gif + kwargs : dict, optional + Additional keyword arguments are passed on to the plotting function + (e.g. imshow for 2D plots). + """ + variable = data.name + # Check plot is the right orientation t_read, x_read = data.dims - raise NotImplementedError + if (t_read is animate_over): + pass + else: + data = data.transpose(x_read, animate_over) + + # Load values eagerly otherwise for some reason the plotting takes + # 100's of times longer - for some reason animatplot does not deal + # well with dask arrays! + image_data = data.values + + # If not specified, determine max and min values across entire data series + if vmax is None: + vmax = np.max(image_data) + if vmin is None: + vmin = np.min(image_data) + + if not ax: + fig, ax = plt.subplots() + + # set range of plot + ax.set_ylim([vmin, vmax]) + + line_block = amp.blocks.Line(image_data, ax=ax, **kwargs) + + timeline = amp.Timeline(np.arange(data.sizes[animate_over]), fps=fps) + + if animate: + anim = amp.Animation([line_block], timeline) + + # Add title and axis labels + ax.set_title("{} variation over {}".format(variable, animate_over)) + ax.set_xlabel(x_read) + ax.set_ylabel(variable) + + # Plot separatrix + if sep_pos: + ax.plot_vline(sep_pos, '--') + + if animate: + anim.controls(timeline_slider_args={'text': animate_over}) + + if not save_as: + save_as = "{}_over_{}".format(variable, animate_over) + anim.save(save_as + '.gif', writer=PillowWriter(fps=fps)) + + return line_block diff --git a/xbout/tests/test_animate.py b/xbout/tests/test_animate.py new file mode 100644 index 00000000..fb67e4be --- /dev/null +++ b/xbout/tests/test_animate.py @@ -0,0 +1,44 @@ +import pytest + +from xbout import open_boutdataset +from xbout.boutdataarray import BoutDataArrayAccessor +from .test_load import create_bout_ds_list + +from animatplot.blocks import Imshow, Line + + +@pytest.fixture +def create_test_file(tmpdir_factory): + + # Create temp dir for output of animate1D/2D + save_dir = tmpdir_factory.mktemp("test_data") + + # Generate some test data + ds_list, file_list = create_bout_ds_list("BOUT.dmp", nxpe=3, nype=3, + syn_data_type="linear") + for ds, file_name in zip(ds_list, file_list): + ds.to_netcdf(str(save_dir.join(str(file_name)))) + + ds = open_boutdataset(save_dir.join("BOUT.dmp.*.nc")) # Open test data + + return save_dir, ds + + +class TestAnimate: + """ + Set of tests to check whether animate1D() and animate2D() are running properly + and PillowWriter is saving each animation correctly + """ + def test_animate2D(self, create_test_file): + + save_dir, ds = create_test_file + animation = ds['n'].isel(y=1).bout.animate2D(y='z', save_as="%s/test" % save_dir) + + assert isinstance(animation, Imshow) + + def test_animate1D(self, create_test_file): + + save_dir, ds = create_test_file + animation = ds['n'].isel(y=2, z=0).bout.animate1D(save_as="%s/test" % save_dir) + + assert isinstance(animation, Line) diff --git a/xbout/tests/test_load.py b/xbout/tests/test_load.py index 59482f70..ad8b2d48 100644 --- a/xbout/tests/test_load.py +++ b/xbout/tests/test_load.py @@ -168,7 +168,7 @@ def bout_xyt_example_files(tmpdir_factory): return _bout_xyt_example_files -def _bout_xyt_example_files(tmpdir_factory, prefix='BOUT.dmp', lengths=(2,4,7,6), +def _bout_xyt_example_files(tmpdir_factory, prefix='BOUT.dmp', lengths=(6,2,4,7), nxpe=4, nype=2, nt=1, guards={}, syn_data_type='random'): """ Mocks up a set of BOUT-like netCDF files, and return the temporary test directory containing them. @@ -197,7 +197,7 @@ def _bout_xyt_example_files(tmpdir_factory, prefix='BOUT.dmp', lengths=(2,4,7,6) return glob_pattern -def create_bout_ds_list(prefix, lengths=(2, 4, 7, 6), nxpe=4, nype=2, nt=1, guards={}, +def create_bout_ds_list(prefix, lengths=(6,2,4,7), nxpe=4, nype=2, nt=1, guards={}, syn_data_type='random'): """ Mocks up a set of BOUT-like datasets. @@ -217,17 +217,8 @@ def create_bout_ds_list(prefix, lengths=(2, 4, 7, 6), nxpe=4, nype=2, nt=1, guar upper_bndry_cells = {dim: guards.get(dim) for dim in guards.keys()} lower_bndry_cells = {dim: guards.get(dim) for dim in guards.keys()} - # Include boundary cells - for dim in ['x', 'y']: - if dim in guards.keys(): - if i == 0: - lower_bndry_cells[dim] = guards[dim] - if i == nxpe-1: - upper_bndry_cells[dim] = guards[dim] - ds = create_bout_ds(syn_data_type=syn_data_type, num=num, lengths=lengths, nxpe=nxpe, nype=nype, - upper_bndry_cells=upper_bndry_cells, lower_bndry_cells=lower_bndry_cells, - guards=guards) + xproc=i, yproc=j, guards=guards) ds_list.append(ds) # Sort this in order of num to remove any BOUT-specific structure @@ -237,25 +228,44 @@ def create_bout_ds_list(prefix, lengths=(2, 4, 7, 6), nxpe=4, nype=2, nt=1, guar return ds_list_sorted, file_list_sorted -def create_bout_ds(syn_data_type='random', lengths=(2,4,7,6), num=0, nxpe=1, nype=1, - upper_bndry_cells={}, lower_bndry_cells={}, guards={}): +def create_bout_ds(syn_data_type='random', lengths=(6,2,4,7), num=0, nxpe=1, nype=1, + xproc=0, yproc=0, guards={}): # Set the shape of the data in this dataset - x_length, y_length, z_length, t_length = lengths - x_length += upper_bndry_cells.get('x', 0) + lower_bndry_cells.get('x', 0) - y_length += upper_bndry_cells.get('y', 0) + lower_bndry_cells.get('y', 0) - z_length += upper_bndry_cells.get('z', 0) + lower_bndry_cells.get('z', 0) - t_length += upper_bndry_cells.get('t', 0) + lower_bndry_cells.get('t', 0) - shape = (x_length, y_length, z_length, t_length) + t_length, x_length, y_length, z_length = lengths + mxg = guards.get('x', 0) + myg = guards.get('y', 0) + x_length += 2*mxg + y_length += 2*myg + shape = (t_length, x_length, y_length, z_length) + + # calculate global nx, ny and nz + nx = nxpe*lengths[1] + 2*mxg + ny = nype*lengths[2] + nz = 1*lengths[3] # Fill with some kind of synthetic data if syn_data_type is 'random': - # Each dataset contains the same random noise - np.random.seed(seed=0) + # Each dataset contains unique random noise + np.random.seed(seed = num) data = np.random.randn(*shape) elif syn_data_type is 'linear': # Variables increase linearly across entire domain - raise NotImplementedError + data = DataArray(-np.ones(shape), dims=('t', 'x', 'y', 'z')) + + t_array = DataArray((nx - 2*mxg)*ny*nz*np.arange(t_length, dtype=float), + dims='t') + x_array = DataArray(ny*nz*(xproc*lengths[1] + mxg + + np.arange(lengths[1], dtype=float)), + dims='x') + y_array = DataArray(nz*(yproc*lengths[2] + myg + + np.arange(lengths[2], dtype=float)), + dims='y') + z_array = DataArray(np.arange(z_length, dtype=float), dims='z') + + data[:, mxg:x_length-mxg, myg:y_length-myg, :] = ( + t_array + x_array + y_array + z_array + ) elif syn_data_type is 'stepped': # Each dataset contains a different number depending on the filename data = np.ones(shape) * num @@ -264,29 +274,74 @@ def create_bout_ds(syn_data_type='random', lengths=(2,4,7,6), num=0, nxpe=1, nyp else: raise ValueError('Not a recognised choice of type of synthetic bout data.') - T = DataArray(data, dims=['x', 'y', 'z', 't']) - n = DataArray(data, dims=['x', 'y', 'z', 't']) + T = DataArray(data, dims=['t', 'x', 'y', 'z']) + n = DataArray(data, dims=['t', 'x', 'y', 'z']) ds = Dataset({'n': n, 'T': T}) - # Include metadata + # Include grid data ds['NXPE'] = nxpe ds['NYPE'] = nype - ds['MXG'] = guards.get('x', 0) - ds['MYG'] = guards.get('y', 0) - ds['nx'] = x_length - ds['MXSUB'] = guards.get('x', 0) - ds['MYSUB'] = guards.get('y', 0) - ds['MZ'] = z_length - ds['jyseps1_1'] = -1 - ds['jyseps1_2'] = -1 - ds['jyseps2_1'] = -1 - ds['jyseps2_2'] = -1 + ds['NZPE'] = 1 + ds['PE_XIND'] = xproc + ds['PE_YIND'] = yproc + ds['MYPE'] = num + + ds['MXG'] = mxg + ds['MYG'] = myg + ds['nx'] = nx + ds['ny'] = ny + ds['nz'] = nz + ds['MZ'] = 1*lengths[3] + ds['MXSUB'] = lengths[1] + ds['MYSUB'] = lengths[2] + ds['MZSUB'] = lengths[3] + ds['ixseps1'] = nx + ds['ixseps2'] = nx + ds['jyseps1_1'] = 0 + ds['jyseps1_2'] = ny + ds['jyseps2_1'] = ny//2 - 1 + ds['jyseps2_2'] = ny//2 - 1 + ds['ny_inner'] = ny//2 + + one = DataArray(np.ones((x_length, y_length)), dims=['x', 'y']) + zero = DataArray(np.zeros((x_length, y_length)), dims=['x', 'y']) + + ds['zperiod'] = 1 + ds['ZMIN'] = 0. + ds['ZMAX'] = 2.*np.pi + ds['g11'] = one + ds['g22'] = one + ds['g33'] = one + ds['g12'] = zero + ds['g13'] = zero + ds['g23'] = zero + ds['g_11'] = one + ds['g_22'] = one + ds['g_33'] = one + ds['g_12'] = zero + ds['g_13'] = zero + ds['g_23'] = zero + ds['G1'] = zero + ds['G2'] = zero + ds['G3'] = zero + ds['J'] = one + ds['Bxy'] = one + ds['zShift'] = zero + + ds['dx'] = 0.5*one + ds['dy'] = 2.*one + ds['dz'] = 0.7 + + ds['iteration'] = t_length + ds['t_array'] = DataArray(np.arange(t_length, dtype=float)*10., dims='t') return ds -METADATA_VARS = ['NXPE', 'NYPE', 'MXG', 'MYG', 'nx', 'MXSUB', 'MYSUB', 'MZ', - 'jyseps1_1', 'jyseps1_2', 'jyseps2_1', 'jyseps2_2'] +METADATA_VARS = ['NXPE', 'NYPE', 'NZPE', 'PE_XIND', 'PE_YIND', 'MYPE', 'MXG', 'MYG', 'nx', + 'ny', 'nz', 'MZ', 'MXSUB', 'MYSUB', 'MZSUB', 'ixseps1', 'ixseps2', + 'jyseps1_1', 'jyseps1_2', 'jyseps2_1', 'jyseps2_2', 'ny_inner', + 'zperiod', 'ZMIN', 'ZMAX', 'dz', 'iteration'] class TestStripMetadata(): @@ -315,7 +370,8 @@ def test_combine_along_x(self, tmpdir_factory, bout_xyt_example_files): actual = open_boutdataset(datapath=path, keep_xboundaries=False) bout_ds = create_bout_ds - expected = concat([bout_ds(0), bout_ds(1), bout_ds(2), bout_ds(3)], dim='x') + expected = concat([bout_ds(0), bout_ds(1), bout_ds(2), bout_ds(3)], dim='x', + data_vars='minimal') xrt.assert_equal(actual.load(), expected.drop(METADATA_VARS)) def test_combine_along_y(self, tmpdir_factory, bout_xyt_example_files): @@ -324,7 +380,8 @@ def test_combine_along_y(self, tmpdir_factory, bout_xyt_example_files): actual = open_boutdataset(datapath=path, keep_xboundaries=False) bout_ds = create_bout_ds - expected = concat([bout_ds(0), bout_ds(1), bout_ds(2)], dim='y') + expected = concat([bout_ds(0), bout_ds(1), bout_ds(2)], dim='y', + data_vars='minimal') xrt.assert_equal(actual.load(), expected.drop(METADATA_VARS)) @pytest.mark.skip @@ -337,10 +394,14 @@ def test_combine_along_xy(self, tmpdir_factory, bout_xyt_example_files): actual = open_boutdataset(datapath=path, keep_xboundaries=False) bout_ds = create_bout_ds - line1 = concat([bout_ds(0), bout_ds(1), bout_ds(2), bout_ds(3)], dim='x') - line2 = concat([bout_ds(4), bout_ds(5), bout_ds(6), bout_ds(7)], dim='x') - line3 = concat([bout_ds(8), bout_ds(9), bout_ds(10), bout_ds(11)], dim='x') - expected = concat([line1, line2, line3], dim='y') + line1 = concat([bout_ds(0), bout_ds(1), bout_ds(2), bout_ds(3)], dim='x', + data_vars='minimal') + line2 = concat([bout_ds(4), bout_ds(5), bout_ds(6), bout_ds(7)], dim='x', + data_vars='minimal') + line3 = concat([bout_ds(8), bout_ds(9), bout_ds(10), bout_ds(11)], dim='x', + data_vars='minimal') + expected = concat([line1, line2, line3], dim='y', + data_vars='minimal') xrt.assert_equal(actual.load(), expected.drop(METADATA_VARS)) @pytest.mark.skip