Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion manim/animation/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,14 @@ def construct(self):

"""

def __init__(self, mobject, target_mobject, stretch=True, dim_to_match=1, **kwargs):
def __init__(
self,
mobject: Mobject,
target_mobject: Mobject,
stretch: bool = True,
dim_to_match: int = 1,
**kwargs: Any,
):
self.to_add_on_completion = target_mobject
self.stretch = stretch
self.dim_to_match = dim_to_match
Expand Down
20 changes: 12 additions & 8 deletions manim/animation/transform_matching_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

__all__ = ["TransformMatchingShapes", "TransformMatchingTex"]

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np

from manim.mobject.opengl.opengl_mobject import OpenGLGroup, OpenGLMobject
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVGroup, OpenGLVMobject
from manim.mobject.text.tex_mobject import SingleStringMathTex

from .._config import config
from ..constants import RendererType
Expand Down Expand Up @@ -74,10 +75,10 @@ def __init__(
transform_mismatches: bool = False,
fade_transform_mismatches: bool = False,
key_map: dict | None = None,
**kwargs,
**kwargs: Any,
):
if isinstance(mobject, OpenGLVMobject):
group_type = OpenGLVGroup
group_type: type[OpenGLVGroup | OpenGLGroup | VGroup | Group] = OpenGLVGroup
elif isinstance(mobject, OpenGLMobject):
group_type = OpenGLGroup
elif isinstance(mobject, VMobject):
Expand Down Expand Up @@ -141,31 +142,33 @@ def __init__(
self.to_add = target_mobject

def get_shape_map(self, mobject: Mobject) -> dict:
shape_map = {}
shape_map: dict[int | str, VGroup | OpenGLVGroup] = {}
for sm in self.get_mobject_parts(mobject):
key = self.get_mobject_key(sm)
if key not in shape_map:
if config["renderer"] == RendererType.OPENGL:
shape_map[key] = OpenGLVGroup()
else:
shape_map[key] = VGroup()
# error: Argument 1 to "add" of "OpenGLVGroup" has incompatible type "Mobject"; expected "OpenGLVMobject" [arg-type]
shape_map[key].add(sm)
return shape_map

def clean_up_from_scene(self, scene: Scene) -> None:
# Interpolate all animations back to 0 to ensure source mobjects remain unchanged.
for anim in self.animations:
anim.interpolate(0)
# error: Argument 1 to "remove" of "Scene" has incompatible type "OpenGLMobject"; expected "Mobject" [arg-type]
scene.remove(self.mobject)
scene.remove(*self.to_remove)
scene.add(self.to_add)

@staticmethod
def get_mobject_parts(mobject: Mobject):
def get_mobject_parts(mobject: Mobject) -> list[Mobject]:
raise NotImplementedError("To be implemented in subclass.")

@staticmethod
def get_mobject_key(mobject: Mobject):
def get_mobject_key(mobject: Mobject) -> int | str:
raise NotImplementedError("To be implemented in subclass.")


Expand Down Expand Up @@ -205,7 +208,7 @@ def __init__(
transform_mismatches: bool = False,
fade_transform_mismatches: bool = False,
key_map: dict | None = None,
**kwargs,
**kwargs: Any,
):
super().__init__(
mobject,
Expand Down Expand Up @@ -269,7 +272,7 @@ def __init__(
transform_mismatches: bool = False,
fade_transform_mismatches: bool = False,
key_map: dict | None = None,
**kwargs,
**kwargs: Any,
):
super().__init__(
mobject,
Expand All @@ -294,4 +297,5 @@ def get_mobject_parts(mobject: Mobject) -> list[Mobject]:

@staticmethod
def get_mobject_key(mobject: Mobject) -> str:
assert isinstance(mobject, SingleStringMathTex)
return mobject.tex_string
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ ignore_errors = True
[mypy-manim.animation.speedmodifier]
ignore_errors = True

[mypy-manim.animation.transform_matching_parts]
ignore_errors = True

[mypy-manim.animation.transform]
ignore_errors = True

Expand Down
Loading