Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions lambeq/backend/drawing/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
)
from lambeq.backend.grammar import Box, Diagram


if TYPE_CHECKING:
from IPython.core.display import HTML as HTML_ty

Expand Down Expand Up @@ -135,7 +134,7 @@ def draw(diagram: Diagram, **params) -> None:
backend: DrawingBackend = params.pop('backend')
elif params.get('to_tikz', False):
backend = TikzBackend(
use_tikzstyles=params.get('use_tikzstyles', None),
use_tikzstyles=params.get('use_tikzstyles', False),
box_linewidth=params.get('box_linewidth', TIKZ_BOX_LINEWIDTH),
wire_linewidth=params.get('wire_linewidth',
TIKZ_WIRE_LINEWIDTH),
Expand Down Expand Up @@ -222,7 +221,7 @@ def draw_pregroup(diagram: Diagram, **params) -> None:
backend: DrawingBackend = params.pop('backend')
elif params.get('to_tikz', False):
backend = TikzBackend(
use_tikzstyles=params.get('use_tikzstyles', None))
use_tikzstyles=params.get('use_tikzstyles', False))
else:
backend = MatBackend(figsize=params.get('figsize', None))

Expand Down Expand Up @@ -373,7 +372,7 @@ def draw_equation(*terms: grammar.Diagram,
backend: DrawingBackend = params.pop('backend')
elif params.get('to_tikz', False):
backend = TikzBackend(
use_tikzstyles=params.get('use_tikzstyles', None))
use_tikzstyles=params.get('use_tikzstyles', False))
else:
backend = MatBackend(figsize=params.get('figsize', None))

Expand Down
2 changes: 1 addition & 1 deletion lambeq/backend/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@

from lambeq.core.utils import fast_deepcopy


if TYPE_CHECKING:
import discopy

from lambeq.text2diagram.pregroup_tree import PregroupTreeNode


Expand Down
1 change: 0 additions & 1 deletion lambeq/backend/pennylane.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
to_circuital)
from lambeq.backend.symbol import lambdify, Symbol


if TYPE_CHECKING:
from lambeq.backend.quantum import Diagram

Expand Down
8 changes: 4 additions & 4 deletions lambeq/backend/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,7 @@ def array(self):
with backend() as np:
return np.array(self.data)

__hash__: Callable[[Box], int] = Box.__hash__
__hash__: Callable[[Box], int] = Box.__hash__ # type: ignore[assignment]

def dagger(self):
return replace(self, data=self.data.conjugate())
Expand All @@ -1109,7 +1109,7 @@ def array(self):
with backend() as np:
return np.array(self.data ** .5)

__hash__: Callable[[], int] = Scalar.__hash__
__hash__: Callable[[], int] = Scalar.__hash__ # type: ignore[assignment]

def dagger(self):
return replace(self, data=np.conjugate(self.data))
Expand Down Expand Up @@ -1150,8 +1150,8 @@ def __setattr__(self, __name: str, __value: Any) -> None:
def dagger(self) -> Box:
return self.box

__hash__: Callable[[Box], int] = Box.__hash__
__repr__: Callable[[Box], str] = Box.__repr__
__hash__: Callable[[Box], int] = Box.__hash__ # type: ignore[assignment]
__repr__: Callable[[Box], str] = Box.__repr__ # type: ignore[assignment]


class Bit(Box):
Expand Down
9 changes: 7 additions & 2 deletions lambeq/backend/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Dim(grammar.Ty):
Product of contained dimensions.

"""
objects: list[Self] # type: ignore[assignment,misc]
objects: list[Self] # type: ignore[misc]

def __init__(self,
*dim: int,
Expand Down Expand Up @@ -242,7 +242,12 @@ def recursive_free_symbols(data) -> set[Symbol]:
if isinstance(data, Mapping):
data = data.values()
if isinstance(data, Iterable):
if not hasattr(data, 'shape') or data.shape != ():
if (
not hasattr(data, 'shape') or data.shape != ()
) and not (
not isinstance(data, np.ndarray)
or np.issubdtype(data.dtype, np.number)
):
return set().union(*map(recursive_free_symbols, data))
# Remove scale before adding to set
return {data.unscaled} if isinstance(data, Symbol) else set()
Expand Down
2 changes: 1 addition & 1 deletion lambeq/bobcat/lexicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Atom(FastIntEnum):

for atom in Atom._member_map_.values():
if TYPE_CHECKING:
from typing import cast
from typing import cast # noqa: I300
atom = cast(Atom, atom)
atom.is_punct = atom >= Atom.COMMA

Expand Down
1 change: 0 additions & 1 deletion lambeq/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import spacy


if TYPE_CHECKING:
import spacy.cli

Expand Down
1 change: 0 additions & 1 deletion lambeq/experimental/discocirc/coref_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from lambeq.core.utils import get_spacy_tokeniser


if TYPE_CHECKING:
import spacy.cli

Expand Down
3 changes: 2 additions & 1 deletion lambeq/text2diagram/depccg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from depccg.annotator import (annotate_XX, english_annotator,
japanese_annotator)
from depccg.cat import Category

from lambeq.backend.grammar import Diagram


Expand Down Expand Up @@ -256,7 +257,7 @@ def sentences2trees(self,
'`sentences` does not have type '
'`list[list[str]]`.')
if TYPE_CHECKING: # temporary fix
from typing import cast
from typing import cast # noqa: I100
sentences = cast(list[list[str]], sentences)
else:
if not untokenised_batch_type_check(sentences):
Expand Down
1 change: 0 additions & 1 deletion lambeq/tokeniser/spacy_tokeniser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from lambeq.core.utils import get_spacy_tokeniser
from lambeq.tokeniser import Tokeniser


if TYPE_CHECKING:
import spacy.cli

Expand Down
3 changes: 2 additions & 1 deletion lambeq/training/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
import numpy as np

if TYPE_CHECKING:
from jax import numpy as jnp
from types import ModuleType

from jax import numpy as jnp


class LossFunction(ABC):
"""Loss function base class.
Expand Down
2 changes: 1 addition & 1 deletion lambeq/training/nelder_mead_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def objective(self, x: Iterable[Any], y: ArrayLike, w: ArrayLike) -> float:
raise ValueError(
'Objective function must return a scalar'
) from e
return result
return result # type: ignore[return-value]

def backward(self, batch: tuple[Iterable[Any], np.ndarray]) -> float:
"""Calculate the gradients of the loss function.
Expand Down
1 change: 0 additions & 1 deletion lambeq/training/numpy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from lambeq.backend.tensor import Diagram
from lambeq.training.quantum_model import QuantumModel


if TYPE_CHECKING:
from jax import numpy as jnp

Expand Down
Loading