diff --git a/README.md b/README.md index 6ef7a72..1b21732 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,6 @@ This is a wrapper for matplotlib so I can produce figures with consistent formatting. It also has some pretty nice additions such as using layers and exporting to tikz. -Related packages: [maxtikzlib](https://github.com/max-models/maxtikzlib) and [maxtexlib](https://github.com/max-models/maxtexlib). - ## Install Create and activate python environment diff --git a/pyproject.toml b/pyproject.toml index c7b14ca..2e61bbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "maxplotlibx" -version = "0.1" +version = "0.1.1" description = "A reproducible plotting module with various backends and export options." readme = "README.md" requires-python = ">=3.8" @@ -47,3 +47,12 @@ dev = [ [tool.setuptools.packages.find] where = ["src"] + +[tool.black] +line-length = 88 + +[tool.ruff] +line-length = 88 + +[tool.isort] +profile = "black" \ No newline at end of file diff --git a/src/maxplotlib/canvas/canvas.py b/src/maxplotlib/canvas/canvas.py index 0e6faaf..8fbc51f 100644 --- a/src/maxplotlib/canvas/canvas.py +++ b/src/maxplotlib/canvas/canvas.py @@ -1,16 +1,35 @@ import os +from typing import Dict import matplotlib.pyplot as plt import plotly.graph_objects as go from plotly.subplots import make_subplots -import maxplotlib.backends.matplotlib.utils as plt_utils +from maxplotlib.backends.matplotlib.utils import ( + set_size, + setup_plotstyle, + setup_tex_fonts, +) from maxplotlib.subfigure.line_plot import LinePlot from maxplotlib.subfigure.tikz_figure import TikzFigure +from maxplotlib.utils.options import Backends class Canvas: - def __init__(self, **kwargs): + def __init__( + self, + nrows: int = 1, + ncols: int = 1, + figsize: tuple | None = None, + caption: str | None = None, + description: str | None = None, + label: str | None = None, + fontsize: int = 14, + dpi: int = 300, + width: str = "17cm", + ratio: str = "golden", # TODO Add literal + gridspec_kw: Dict = {"wspace": 0.08, "hspace": 0.1}, + ): """ Initialize the Canvas class for multiple subplots. @@ -18,21 +37,27 @@ def __init__(self, **kwargs): nrows (int): Number of subplot rows. Default is 1. ncols (int): Number of subplot columns. Default is 1. figsize (tuple): Figure size. + caption (str): Caption for the figure. + description (str): Description for the figure. + label (str): Label for the figure. + fontsize (int): Font size. Default is 14. + dpi (int): DPI for the figure. Default is 300. + width (str): Width of the figure. Default is "17cm". + ratio (str): Aspect ratio. Default is "golden". + gridspec_kw (dict): Gridspec keyword arguments. Default is {"wspace": 0.08, "hspace": 0.1}. """ - # nrows=1, ncols=1, caption=None, description=None, label=None, figsize=None - self._nrows = kwargs.get("nrows", 1) - self._ncols = kwargs.get("ncols", 1) - self._figsize = kwargs.get("figsize", None) - self._caption = kwargs.get("caption", None) - self._description = kwargs.get("description", None) - self._label = kwargs.get("label", None) - self._fontsize = kwargs.get("fontsize", 14) - self._dpi = kwargs.get("dpi", 300) - # self._width = kwargs.get("width", 426.79135) - self._width = kwargs.get("width", "17cm") - self._ratio = kwargs.get("ratio", "golden") - self._gridspec_kw = kwargs.get("gridspec_kw", {"wspace": 0.08, "hspace": 0.1}) + self._nrows = nrows + self._ncols = ncols + self._figsize = figsize + self._caption = caption + self._description = description + self._label = label + self._fontsize = fontsize + self._dpi = dpi + self._width = width + self._ratio = ratio + self._gridspec_kw = gridspec_kw self._plotted = False # Dictionary to store lines for each subplot @@ -196,11 +221,11 @@ def add_subplot( def savefig( self, filename, - backend="matplotlib", - layers=None, - layer_by_layer=False, - verbose=False, - plot=True, + backend: Backends = "matplotlib", + layers: list | None = None, + layer_by_layer: bool = False, + verbose: bool = False, + plot: bool = True, ): filename_no_extension, extension = os.path.splitext(filename) if backend == "matplotlib": @@ -238,7 +263,7 @@ def savefig( if verbose: print(f"Saved {full_filepath}") - def plot(self, backend="matplotlib", savefig=False, layers=None): + def plot(self, backend: Backends = "matplotlib", savefig=False, layers=None): if backend == "matplotlib": return self.plot_matplotlib(savefig=savefig, layers=layers) elif backend == "plotly": @@ -263,9 +288,9 @@ def plot_matplotlib(self, savefig=False, layers=None, usetex=False): filename (str, optional): Filename to save the figure. """ - tex_fonts = plt_utils.setup_tex_fonts(fontsize=self.fontsize, usetex=usetex) + tex_fonts = setup_tex_fonts(fontsize=self.fontsize, usetex=usetex) - plt_utils.setup_plotstyle( + setup_plotstyle( tex_fonts=tex_fonts, axes_grid=True, axes_grid_which="major", @@ -276,7 +301,7 @@ def plot_matplotlib(self, savefig=False, layers=None, usetex=False): if self._figsize is not None: fig_width, fig_height = self._figsize else: - fig_width, fig_height = plt_utils.set_size( + fig_width, fig_height = set_size( width=self._width, ratio=self._ratio, dpi=self.dpi, @@ -313,7 +338,7 @@ def plot_plotly(self, show=True, savefig=None, usetex=False): savefig (str, optional): Filename to save the figure if provided. """ - tex_fonts = plt_utils.setup_tex_fonts( + tex_fonts = setup_tex_fonts( fontsize=self.fontsize, usetex=usetex, ) # adjust or redefine for Plotly if needed @@ -322,7 +347,7 @@ def plot_plotly(self, show=True, savefig=None, usetex=False): if self._figsize is not None: fig_width, fig_height = self._figsize else: - fig_width, fig_height = plt_utils.set_size( + fig_width, fig_height = set_size( width=self._width, ratio=self._ratio, ) diff --git a/src/maxplotlib/subfigure/line_plot.py b/src/maxplotlib/subfigure/line_plot.py index e1daf44..9435b37 100644 --- a/src/maxplotlib/subfigure/line_plot.py +++ b/src/maxplotlib/subfigure/line_plot.py @@ -163,7 +163,12 @@ def layers(self): layers.append(layer_name) return layers - def plot_matplotlib(self, ax, layers=None): + def plot_matplotlib( + self, + ax, + layers=None, + verbose: bool = False, + ): """ Plot all lines on the provided axis. @@ -210,13 +215,13 @@ def plot_matplotlib(self, ax, layers=None): ax.legend() if self._grid: ax.grid() - if self.xmin: + if self.xmin is not None: ax.axis(xmin=self.xmin) - if self.xmax: + if self.xmax is not None: ax.axis(xmax=self.xmax) - if self.ymin: + if self.ymin is not None: ax.axis(ymin=self.ymin) - if self.ymax: + if self.ymax is not None: ax.axis(ymax=self.ymax) def plot_plotly(self): diff --git a/src/maxplotlib/utils/__init__.py b/src/maxplotlib/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/maxplotlib/utils/options.py b/src/maxplotlib/utils/options.py new file mode 100644 index 0000000..78d5482 --- /dev/null +++ b/src/maxplotlib/utils/options.py @@ -0,0 +1,3 @@ +from typing import Literal + +Backends = Literal["matplotlib", "plotly"]