Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

This is a wrapper for matplotlib so I can produce figures with consistent formatting. It also has some pretty nice additions such as using layers and exporting to tikz.

Related packages: [maxtikzlib](https://github.com/max-models/maxtikzlib) and [maxtexlib](https://github.com/max-models/maxtexlib).

## Install

Create and activate python environment
Expand Down
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "maxplotlibx"
version = "0.1"
version = "0.1.1"
description = "A reproducible plotting module with various backends and export options."
readme = "README.md"
requires-python = ">=3.8"
Expand Down Expand Up @@ -47,3 +47,12 @@ dev = [

[tool.setuptools.packages.find]
where = ["src"]

[tool.black]
line-length = 88

[tool.ruff]
line-length = 88

[tool.isort]
profile = "black"
77 changes: 51 additions & 26 deletions src/maxplotlib/canvas/canvas.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,63 @@
import os
from typing import Dict

import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import maxplotlib.backends.matplotlib.utils as plt_utils
from maxplotlib.backends.matplotlib.utils import (
set_size,
setup_plotstyle,
setup_tex_fonts,
)
from maxplotlib.subfigure.line_plot import LinePlot
from maxplotlib.subfigure.tikz_figure import TikzFigure
from maxplotlib.utils.options import Backends


class Canvas:
def __init__(self, **kwargs):
def __init__(
self,
nrows: int = 1,
ncols: int = 1,
figsize: tuple | None = None,
caption: str | None = None,
description: str | None = None,
label: str | None = None,
fontsize: int = 14,
dpi: int = 300,
width: str = "17cm",
ratio: str = "golden", # TODO Add literal
gridspec_kw: Dict = {"wspace": 0.08, "hspace": 0.1},
):
"""
Initialize the Canvas class for multiple subplots.

Parameters:
nrows (int): Number of subplot rows. Default is 1.
ncols (int): Number of subplot columns. Default is 1.
figsize (tuple): Figure size.
caption (str): Caption for the figure.
description (str): Description for the figure.
label (str): Label for the figure.
fontsize (int): Font size. Default is 14.
dpi (int): DPI for the figure. Default is 300.
width (str): Width of the figure. Default is "17cm".
ratio (str): Aspect ratio. Default is "golden".
gridspec_kw (dict): Gridspec keyword arguments. Default is {"wspace": 0.08, "hspace": 0.1}.
"""

# nrows=1, ncols=1, caption=None, description=None, label=None, figsize=None
self._nrows = kwargs.get("nrows", 1)
self._ncols = kwargs.get("ncols", 1)
self._figsize = kwargs.get("figsize", None)
self._caption = kwargs.get("caption", None)
self._description = kwargs.get("description", None)
self._label = kwargs.get("label", None)
self._fontsize = kwargs.get("fontsize", 14)
self._dpi = kwargs.get("dpi", 300)
# self._width = kwargs.get("width", 426.79135)
self._width = kwargs.get("width", "17cm")
self._ratio = kwargs.get("ratio", "golden")
self._gridspec_kw = kwargs.get("gridspec_kw", {"wspace": 0.08, "hspace": 0.1})
self._nrows = nrows
self._ncols = ncols
self._figsize = figsize
self._caption = caption
self._description = description
self._label = label
self._fontsize = fontsize
self._dpi = dpi
self._width = width
self._ratio = ratio
self._gridspec_kw = gridspec_kw
self._plotted = False

# Dictionary to store lines for each subplot
Expand Down Expand Up @@ -196,11 +221,11 @@ def add_subplot(
def savefig(
self,
filename,
backend="matplotlib",
layers=None,
layer_by_layer=False,
verbose=False,
plot=True,
backend: Backends = "matplotlib",
layers: list | None = None,
layer_by_layer: bool = False,
verbose: bool = False,
plot: bool = True,
):
filename_no_extension, extension = os.path.splitext(filename)
if backend == "matplotlib":
Expand Down Expand Up @@ -238,7 +263,7 @@ def savefig(
if verbose:
print(f"Saved {full_filepath}")

def plot(self, backend="matplotlib", savefig=False, layers=None):
def plot(self, backend: Backends = "matplotlib", savefig=False, layers=None):
if backend == "matplotlib":
return self.plot_matplotlib(savefig=savefig, layers=layers)
elif backend == "plotly":
Expand All @@ -263,9 +288,9 @@ def plot_matplotlib(self, savefig=False, layers=None, usetex=False):
filename (str, optional): Filename to save the figure.
"""

tex_fonts = plt_utils.setup_tex_fonts(fontsize=self.fontsize, usetex=usetex)
tex_fonts = setup_tex_fonts(fontsize=self.fontsize, usetex=usetex)

plt_utils.setup_plotstyle(
setup_plotstyle(
tex_fonts=tex_fonts,
axes_grid=True,
axes_grid_which="major",
Expand All @@ -276,7 +301,7 @@ def plot_matplotlib(self, savefig=False, layers=None, usetex=False):
if self._figsize is not None:
fig_width, fig_height = self._figsize
else:
fig_width, fig_height = plt_utils.set_size(
fig_width, fig_height = set_size(
width=self._width,
ratio=self._ratio,
dpi=self.dpi,
Expand Down Expand Up @@ -313,7 +338,7 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
savefig (str, optional): Filename to save the figure if provided.
"""

tex_fonts = plt_utils.setup_tex_fonts(
tex_fonts = setup_tex_fonts(
fontsize=self.fontsize,
usetex=usetex,
) # adjust or redefine for Plotly if needed
Expand All @@ -322,7 +347,7 @@ def plot_plotly(self, show=True, savefig=None, usetex=False):
if self._figsize is not None:
fig_width, fig_height = self._figsize
else:
fig_width, fig_height = plt_utils.set_size(
fig_width, fig_height = set_size(
width=self._width,
ratio=self._ratio,
)
Expand Down
15 changes: 10 additions & 5 deletions src/maxplotlib/subfigure/line_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ def layers(self):
layers.append(layer_name)
return layers

def plot_matplotlib(self, ax, layers=None):
def plot_matplotlib(
self,
ax,
layers=None,
verbose: bool = False,
):
"""
Plot all lines on the provided axis.

Expand Down Expand Up @@ -210,13 +215,13 @@ def plot_matplotlib(self, ax, layers=None):
ax.legend()
if self._grid:
ax.grid()
if self.xmin:
if self.xmin is not None:
ax.axis(xmin=self.xmin)
if self.xmax:
if self.xmax is not None:
ax.axis(xmax=self.xmax)
if self.ymin:
if self.ymin is not None:
ax.axis(ymin=self.ymin)
if self.ymax:
if self.ymax is not None:
ax.axis(ymax=self.ymax)

def plot_plotly(self):
Expand Down
Empty file.
3 changes: 3 additions & 0 deletions src/maxplotlib/utils/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Literal

Backends = Literal["matplotlib", "plotly"]
Loading