Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 92 additions & 25 deletions animatplot/blocks/lineplots.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import numpy as np

from .base import Block
from animatplot.util import parametric_line
import numpy as np


class Line(Block):
"""Animates lines
"""
Animates a single line.

Accepts additional keyword arguments to be passed to
:meth:`matplotlib.axes.Axes.plot`.

Parameters
----------
x : list of 1D numpy arrays or a 2D numpy array
The x data to be animated.
x : 1D numpy array, list of 1D numpy arrays or a 2D numpy array, optional
The x data to be animated. If 1D then will be constant over animation.
y : list of 1D numpy arrays or a 2D numpy array
The y data to be animated.
ax : matplotlib.axes.Axes, optional
Expand All @@ -21,41 +26,103 @@ class Line(Block):

The default is chosen to be consistent with:
X, T = numpy.meshgrid(x, t)
**kwargs
Passed on to `matplotlib.axes.Axes.plot`.

Attributes
----------
line: matplotlib.lines.Line2D

ax : matplotlib.axes.Axes
The matplotlib axes that the block is attached to.

Notes
-----
This block accepts additional keyword arguments to be passed to
:meth:`matplotlib.axes.Axes.plot`
This block animates a single line - to animate multiple lines you must call
this once for each line, and then animate all of the blocks returned by
passing a list of those blocks to `animatplot.Animation`.
"""
def __init__(self, x, y, ax=None, t_axis=0, **kwargs):
self.x = np.asanyarray(x)
self.y = np.asanyarray(y)
if self.x.shape != self.y.shape:
raise ValueError("x, y must have the same shape"
"or be lists of the same length")
def __init__(self, *args, ax=None, t_axis=0, **kwargs):

super().__init__(ax, t_axis)

self._is_list = (self.x.dtype == 'object')
Slice = self._make_slice(0, 2)
self.line, = self.ax.plot(self.x[Slice], self.y[Slice], **kwargs)
if len(args) == 1:
y = args[0]
x = None
elif len(args) == 2:
[x, y] = args
else:
raise ValueError("Invalid data arguments to Line block")

if y is None:
raise ValueError("Must supply y data to plot")
y = np.asanyarray(y)
if str(y.dtype) == 'object':
self.t_axis = 0

# ragged array
if x is None:
raise ValueError("Must specify x data explicitly when passing"
"a ragged array for y data")

x = np.asanyarray(x)

if not all(len(xline) == len(yline) for xline, yline in zip(x, y)):
raise ValueError("Length of x & y data must match one another "
"for every frame")

self._is_list = True

else:
# Rectangular data
if y.ndim != 2:
raise ValueError("y data must be 2-dimensional")

# x is optional
shape = list(y.shape)
shape.remove(y.shape[t_axis])
data_length, = shape
if x is None:
x = np.arange(data_length)
else:
x = np.asanyarray(x)

shape_mismatch = "The dimensions of x must be compatible with " \
"those of y, but the shape of x is {} and the " \
"shape of y is {}".format(x.shape, y.shape)
if x.ndim == 1:
# x is constant over time
if len(x) == data_length:
# Broadcast x to match y
x = np.expand_dims(x, axis=t_axis)
x = np.repeat(x, repeats=y.shape[t_axis], axis=t_axis)
else:
raise ValueError(shape_mismatch)
elif x.ndim == 2:
if x.shape != y.shape:
raise ValueError(shape_mismatch)
else:
raise ValueError("x, must be either 1- or 2-dimensional")

def _update(self, i):
Slice = self._make_slice(i, 2)
x_vector = self.x[Slice]
y_vector = self.y[Slice]
self.x = x
self.y = y

frame_slice = self._make_slice(i=0, dim=2)

x_first_frame_data = self.x[frame_slice]
y_first_frame_data = self.y[frame_slice]

self.line, = self.ax.plot(x_first_frame_data,
y_first_frame_data, **kwargs)

def _update(self, frame):
frame_slice = self._make_slice(frame, dim=2)
x_vector = self.x[frame_slice]
y_vector = self.y[frame_slice]
self.line.set_data(x_vector, y_vector)
return self.line

def __len__(self):
if self._is_list:
return self.x.shape[0]
return self.x.shape[self.t_axis]
return self.y.shape[self.t_axis]
Comment thread
t-makaro marked this conversation as resolved.


class ParametricLine(Line):
Expand All @@ -79,8 +146,8 @@ class ParametricLine(Line):
:meth:`matplotlib.axes.Axes.plot`
"""
def __init__(self, x, y, *args, **kwargs):
X, Y = parametric_line(x, y)
super().__init__(X, Y, *args, *kwargs)
x_grid, y_grid = parametric_line(x, y)
super().__init__(x_grid, y_grid, *args, *kwargs)


class Scatter(Block):
Expand Down
128 changes: 128 additions & 0 deletions tests/test_blocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from matplotlib.testing import setup
import numpy as np
import numpy.testing as npt
Comment thread
t-makaro marked this conversation as resolved.
import matplotlib.pyplot as plt
import matplotlib as mpl

import pytest

Expand Down Expand Up @@ -69,6 +71,132 @@ def test_mpl_kwargs(self):
assert actual._mpl_kwargs == expected


def assert_jagged_arrays_equal(x, y):
for x, y in zip(x, y):
npt.assert_equal(x, y)


class TestLineBlock:
def test_2d_inputs(self):
x = np.linspace(0, 1, 10)
t = np.linspace(0, 1, 5)
x_grid, t_grid = np.meshgrid(x, t)
y_data = np.sin(2 * np.pi * (x_grid + t_grid))

line_block = amp.blocks.Line(x_grid, y_data)

assert isinstance(line_block, amp.blocks.Line)
npt.assert_equal(line_block.x, x_grid)
npt.assert_equal(line_block.y, y_data)
assert len(line_block) == len(t)

assert isinstance(line_block.line, mpl.lines.Line2D)
xdata, ydata = line_block.line.get_data()
npt.assert_equal(xdata, x)
npt.assert_equal(ydata, y_data[0, :])

def test_update(self):
x = np.linspace(0, 1, 10)
t = np.linspace(0, 1, 5)
x_grid, t_grid = np.meshgrid(x, t)
y_data = np.sin(2 * np.pi * (x_grid + t_grid))

line_block = amp.blocks.Line(x_grid, y_data)
line_block._update(frame=1)

npt.assert_equal(line_block.line.get_xdata(), x)
npt.assert_equal(line_block.line.get_ydata(), y_data[1, :])

def test_constant_x(self):
Comment thread
t-makaro marked this conversation as resolved.
x = np.linspace(0, 1, 10)
t = np.linspace(0, 1, 5)
x_grid, t_grid = np.meshgrid(x, t)
y_data = np.sin(2 * np.pi * (x_grid + t_grid))

line_block = amp.blocks.Line(x, y_data)

npt.assert_equal(line_block.line.get_xdata(), x)
npt.assert_equal(line_block.x[-1], x)

def test_no_x_input(self):
x = np.linspace(0, 1, 10)
t = np.linspace(0, 1, 5)
x_grid, t_grid = np.meshgrid(x, t)
y_data = np.sin(2 * np.pi * (x_grid + t_grid))

line_block = amp.blocks.Line(y_data)

expected_x = np.arange(10)
npt.assert_equal(line_block.line.get_xdata(), expected_x)

def test_list_input(self):
Comment thread
t-makaro marked this conversation as resolved.
x_data = [np.array([1, 2, 3]), np.array([1, 2, 3])]
y_data = [np.array([5, 6, 7]), np.array([4, 2, 9])]
line_block = amp.blocks.Line(x_data, y_data)
npt.assert_equal(line_block.y, np.array([[5, 6, 7], [4, 2, 9]]))
npt.assert_equal(line_block.x, np.array([[1, 2, 3], [1, 2, 3]]))

def test_ragged_list_input(self):
x_data = [np.array([1, 2, 3]), np.array([1, 2, 3, 4])]
y_data = [np.array([5, 6, 7]), np.array([4, 2, 9, 10])]

with pytest.raises(ValueError) as err:
line_block = amp.blocks.Line(y_data)
assert "Must specify x data explicitly" in str(err)

line_block = amp.blocks.Line(x_data, y_data)

assert_jagged_arrays_equal(line_block.x, np.array(x_data))
assert_jagged_arrays_equal(line_block.y, np.array(y_data))

def test_bad_ragged_list_input(self):
x_data = np.array([np.array([1, 2, 3]), np.array([1, 2, 3, 4])])
y_data = np.array([np.array([5, 6, 7]), np.array([4, 2, 9, 10, 11])])

with pytest.raises(ValueError) as err:
line_block = amp.blocks.Line(x_data, y_data)
assert "x & y data must match" in str(err)

def test_bad_input(self):
# incorrect number of args
with pytest.raises(ValueError) as err:
amp.blocks.Line(1, 2, 3)
assert 'Invalid data arguments' in str(err.value)
with pytest.raises(ValueError) as err:
amp.blocks.Line()
assert 'Invalid data arguments' in str(err.value)

# No y data
with pytest.raises(ValueError) as err:
amp.blocks.Line(np.arange(5), None)
assert 'Must supply y data' in str(err.value)
with pytest.raises(ValueError) as err:
amp.blocks.Line(None)
assert 'Must supply y data' in str(err.value)

# y data not 2d
with pytest.raises(ValueError) as err:
amp.blocks.Line(np.arange(5), np.random.randn(5, 2, 2))
assert 'y data must be 2-dimensional' in str(err.value)

# 1d x doesn't match y
with pytest.raises(ValueError) as err:
amp.blocks.Line(np.arange(5), np.random.randn(4, 2))
assert 'dimensions of x must be compatible' in str(err.value)

# 2d x doesn't match y
with pytest.raises(ValueError) as err:
x = np.array([np.arange(5), np.arange(5)])
amp.blocks.Line(x, np.random.randn(4, 2), t_axis=1)
assert 'dimensions of x must be compatible' in str(err.value)

def test_kwarg_throughput(self):
x = np.array([np.arange(5), np.arange(5)])
line_block = amp.blocks.Line(x, np.random.randn(2, 5), t_axis=1,
alpha=0.5)
assert line_block.line.get_alpha() == 0.5


class TestComparisons:
@animation_compare(baseline_images='Blocks/Line', nframes=5)
def test_Line(self):
Expand Down