diff --git a/animatplot/blocks/lineplots.py b/animatplot/blocks/lineplots.py index 007b1f6..cb65bed 100644 --- a/animatplot/blocks/lineplots.py +++ b/animatplot/blocks/lineplots.py @@ -1,15 +1,20 @@ +import numpy as np + from .base import Block from animatplot.util import parametric_line -import numpy as np class Line(Block): - """Animates lines + """ + Animates a single line. + + Accepts additional keyword arguments to be passed to + :meth:`matplotlib.axes.Axes.plot`. Parameters ---------- - x : list of 1D numpy arrays or a 2D numpy array - The x data to be animated. + x : 1D numpy array, list of 1D numpy arrays or a 2D numpy array, optional + The x data to be animated. If 1D then will be constant over animation. y : list of 1D numpy arrays or a 2D numpy array The y data to be animated. ax : matplotlib.axes.Axes, optional @@ -21,41 +26,103 @@ class Line(Block): The default is chosen to be consistent with: X, T = numpy.meshgrid(x, t) + **kwargs + Passed on to `matplotlib.axes.Axes.plot`. Attributes ---------- + line: matplotlib.lines.Line2D + ax : matplotlib.axes.Axes The matplotlib axes that the block is attached to. Notes ----- - This block accepts additional keyword arguments to be passed to - :meth:`matplotlib.axes.Axes.plot` + This block animates a single line - to animate multiple lines you must call + this once for each line, and then animate all of the blocks returned by + passing a list of those blocks to `animatplot.Animation`. """ - def __init__(self, x, y, ax=None, t_axis=0, **kwargs): - self.x = np.asanyarray(x) - self.y = np.asanyarray(y) - if self.x.shape != self.y.shape: - raise ValueError("x, y must have the same shape" - "or be lists of the same length") + def __init__(self, *args, ax=None, t_axis=0, **kwargs): + super().__init__(ax, t_axis) - self._is_list = (self.x.dtype == 'object') - Slice = self._make_slice(0, 2) - self.line, = self.ax.plot(self.x[Slice], self.y[Slice], **kwargs) + if len(args) == 1: + y = args[0] + x = None + elif len(args) == 2: + [x, y] = args + else: + raise ValueError("Invalid data arguments to Line block") + + if y is None: + raise ValueError("Must supply y data to plot") + y = np.asanyarray(y) + if str(y.dtype) == 'object': + self.t_axis = 0 + + # ragged array + if x is None: + raise ValueError("Must specify x data explicitly when passing" + "a ragged array for y data") + + x = np.asanyarray(x) + + if not all(len(xline) == len(yline) for xline, yline in zip(x, y)): + raise ValueError("Length of x & y data must match one another " + "for every frame") + + self._is_list = True + + else: + # Rectangular data + if y.ndim != 2: + raise ValueError("y data must be 2-dimensional") + + # x is optional + shape = list(y.shape) + shape.remove(y.shape[t_axis]) + data_length, = shape + if x is None: + x = np.arange(data_length) + else: + x = np.asanyarray(x) + + shape_mismatch = "The dimensions of x must be compatible with " \ + "those of y, but the shape of x is {} and the " \ + "shape of y is {}".format(x.shape, y.shape) + if x.ndim == 1: + # x is constant over time + if len(x) == data_length: + # Broadcast x to match y + x = np.expand_dims(x, axis=t_axis) + x = np.repeat(x, repeats=y.shape[t_axis], axis=t_axis) + else: + raise ValueError(shape_mismatch) + elif x.ndim == 2: + if x.shape != y.shape: + raise ValueError(shape_mismatch) + else: + raise ValueError("x, must be either 1- or 2-dimensional") - def _update(self, i): - Slice = self._make_slice(i, 2) - x_vector = self.x[Slice] - y_vector = self.y[Slice] + self.x = x + self.y = y + + frame_slice = self._make_slice(i=0, dim=2) + + x_first_frame_data = self.x[frame_slice] + y_first_frame_data = self.y[frame_slice] + + self.line, = self.ax.plot(x_first_frame_data, + y_first_frame_data, **kwargs) + def _update(self, frame): + frame_slice = self._make_slice(frame, dim=2) + x_vector = self.x[frame_slice] + y_vector = self.y[frame_slice] self.line.set_data(x_vector, y_vector) - return self.line def __len__(self): - if self._is_list: - return self.x.shape[0] - return self.x.shape[self.t_axis] + return self.y.shape[self.t_axis] class ParametricLine(Line): @@ -79,8 +146,8 @@ class ParametricLine(Line): :meth:`matplotlib.axes.Axes.plot` """ def __init__(self, x, y, *args, **kwargs): - X, Y = parametric_line(x, y) - super().__init__(X, Y, *args, *kwargs) + x_grid, y_grid = parametric_line(x, y) + super().__init__(x_grid, y_grid, *args, *kwargs) class Scatter(Block): diff --git a/tests/test_blocks.py b/tests/test_blocks.py index a74e9f5..358ac7f 100644 --- a/tests/test_blocks.py +++ b/tests/test_blocks.py @@ -1,6 +1,8 @@ from matplotlib.testing import setup import numpy as np +import numpy.testing as npt import matplotlib.pyplot as plt +import matplotlib as mpl import pytest @@ -69,6 +71,132 @@ def test_mpl_kwargs(self): assert actual._mpl_kwargs == expected +def assert_jagged_arrays_equal(x, y): + for x, y in zip(x, y): + npt.assert_equal(x, y) + + +class TestLineBlock: + def test_2d_inputs(self): + x = np.linspace(0, 1, 10) + t = np.linspace(0, 1, 5) + x_grid, t_grid = np.meshgrid(x, t) + y_data = np.sin(2 * np.pi * (x_grid + t_grid)) + + line_block = amp.blocks.Line(x_grid, y_data) + + assert isinstance(line_block, amp.blocks.Line) + npt.assert_equal(line_block.x, x_grid) + npt.assert_equal(line_block.y, y_data) + assert len(line_block) == len(t) + + assert isinstance(line_block.line, mpl.lines.Line2D) + xdata, ydata = line_block.line.get_data() + npt.assert_equal(xdata, x) + npt.assert_equal(ydata, y_data[0, :]) + + def test_update(self): + x = np.linspace(0, 1, 10) + t = np.linspace(0, 1, 5) + x_grid, t_grid = np.meshgrid(x, t) + y_data = np.sin(2 * np.pi * (x_grid + t_grid)) + + line_block = amp.blocks.Line(x_grid, y_data) + line_block._update(frame=1) + + npt.assert_equal(line_block.line.get_xdata(), x) + npt.assert_equal(line_block.line.get_ydata(), y_data[1, :]) + + def test_constant_x(self): + x = np.linspace(0, 1, 10) + t = np.linspace(0, 1, 5) + x_grid, t_grid = np.meshgrid(x, t) + y_data = np.sin(2 * np.pi * (x_grid + t_grid)) + + line_block = amp.blocks.Line(x, y_data) + + npt.assert_equal(line_block.line.get_xdata(), x) + npt.assert_equal(line_block.x[-1], x) + + def test_no_x_input(self): + x = np.linspace(0, 1, 10) + t = np.linspace(0, 1, 5) + x_grid, t_grid = np.meshgrid(x, t) + y_data = np.sin(2 * np.pi * (x_grid + t_grid)) + + line_block = amp.blocks.Line(y_data) + + expected_x = np.arange(10) + npt.assert_equal(line_block.line.get_xdata(), expected_x) + + def test_list_input(self): + x_data = [np.array([1, 2, 3]), np.array([1, 2, 3])] + y_data = [np.array([5, 6, 7]), np.array([4, 2, 9])] + line_block = amp.blocks.Line(x_data, y_data) + npt.assert_equal(line_block.y, np.array([[5, 6, 7], [4, 2, 9]])) + npt.assert_equal(line_block.x, np.array([[1, 2, 3], [1, 2, 3]])) + + def test_ragged_list_input(self): + x_data = [np.array([1, 2, 3]), np.array([1, 2, 3, 4])] + y_data = [np.array([5, 6, 7]), np.array([4, 2, 9, 10])] + + with pytest.raises(ValueError) as err: + line_block = amp.blocks.Line(y_data) + assert "Must specify x data explicitly" in str(err) + + line_block = amp.blocks.Line(x_data, y_data) + + assert_jagged_arrays_equal(line_block.x, np.array(x_data)) + assert_jagged_arrays_equal(line_block.y, np.array(y_data)) + + def test_bad_ragged_list_input(self): + x_data = np.array([np.array([1, 2, 3]), np.array([1, 2, 3, 4])]) + y_data = np.array([np.array([5, 6, 7]), np.array([4, 2, 9, 10, 11])]) + + with pytest.raises(ValueError) as err: + line_block = amp.blocks.Line(x_data, y_data) + assert "x & y data must match" in str(err) + + def test_bad_input(self): + # incorrect number of args + with pytest.raises(ValueError) as err: + amp.blocks.Line(1, 2, 3) + assert 'Invalid data arguments' in str(err.value) + with pytest.raises(ValueError) as err: + amp.blocks.Line() + assert 'Invalid data arguments' in str(err.value) + + # No y data + with pytest.raises(ValueError) as err: + amp.blocks.Line(np.arange(5), None) + assert 'Must supply y data' in str(err.value) + with pytest.raises(ValueError) as err: + amp.blocks.Line(None) + assert 'Must supply y data' in str(err.value) + + # y data not 2d + with pytest.raises(ValueError) as err: + amp.blocks.Line(np.arange(5), np.random.randn(5, 2, 2)) + assert 'y data must be 2-dimensional' in str(err.value) + + # 1d x doesn't match y + with pytest.raises(ValueError) as err: + amp.blocks.Line(np.arange(5), np.random.randn(4, 2)) + assert 'dimensions of x must be compatible' in str(err.value) + + # 2d x doesn't match y + with pytest.raises(ValueError) as err: + x = np.array([np.arange(5), np.arange(5)]) + amp.blocks.Line(x, np.random.randn(4, 2), t_axis=1) + assert 'dimensions of x must be compatible' in str(err.value) + + def test_kwarg_throughput(self): + x = np.array([np.arange(5), np.arange(5)]) + line_block = amp.blocks.Line(x, np.random.randn(2, 5), t_axis=1, + alpha=0.5) + assert line_block.line.get_alpha() == 0.5 + + class TestComparisons: @animation_compare(baseline_images='Blocks/Line', nframes=5) def test_Line(self):