44
55__all__ = ["TransformMatchingShapes" , "TransformMatchingTex" ]
66
7- from typing import TYPE_CHECKING
7+ from typing import TYPE_CHECKING , Any
88
99import numpy as np
1010
1111from manim .mobject .opengl .opengl_mobject import OpenGLGroup , OpenGLMobject
1212from manim .mobject .opengl .opengl_vectorized_mobject import OpenGLVGroup , OpenGLVMobject
13+ from manim .mobject .text .tex_mobject import SingleStringMathTex
1314
1415from .._config import config
1516from ..constants import RendererType
@@ -74,10 +75,10 @@ def __init__(
7475 transform_mismatches : bool = False ,
7576 fade_transform_mismatches : bool = False ,
7677 key_map : dict | None = None ,
77- ** kwargs ,
78+ ** kwargs : Any ,
7879 ):
7980 if isinstance (mobject , OpenGLVMobject ):
80- group_type = OpenGLVGroup
81+ group_type : type [ OpenGLVGroup | OpenGLGroup | VGroup | Group ] = OpenGLVGroup
8182 elif isinstance (mobject , OpenGLMobject ):
8283 group_type = OpenGLGroup
8384 elif isinstance (mobject , VMobject ):
@@ -141,31 +142,33 @@ def __init__(
141142 self .to_add = target_mobject
142143
143144 def get_shape_map (self , mobject : Mobject ) -> dict :
144- shape_map = {}
145+ shape_map : dict [ int | str , VGroup | OpenGLVGroup ] = {}
145146 for sm in self .get_mobject_parts (mobject ):
146147 key = self .get_mobject_key (sm )
147148 if key not in shape_map :
148149 if config ["renderer" ] == RendererType .OPENGL :
149150 shape_map [key ] = OpenGLVGroup ()
150151 else :
151152 shape_map [key ] = VGroup ()
153+ # error: Argument 1 to "add" of "OpenGLVGroup" has incompatible type "Mobject"; expected "OpenGLVMobject" [arg-type]
152154 shape_map [key ].add (sm )
153155 return shape_map
154156
155157 def clean_up_from_scene (self , scene : Scene ) -> None :
156158 # Interpolate all animations back to 0 to ensure source mobjects remain unchanged.
157159 for anim in self .animations :
158160 anim .interpolate (0 )
161+ # error: Argument 1 to "remove" of "Scene" has incompatible type "OpenGLMobject"; expected "Mobject" [arg-type]
159162 scene .remove (self .mobject )
160163 scene .remove (* self .to_remove )
161164 scene .add (self .to_add )
162165
163166 @staticmethod
164- def get_mobject_parts (mobject : Mobject ):
167+ def get_mobject_parts (mobject : Mobject ) -> list [ Mobject ] :
165168 raise NotImplementedError ("To be implemented in subclass." )
166169
167170 @staticmethod
168- def get_mobject_key (mobject : Mobject ):
171+ def get_mobject_key (mobject : Mobject ) -> int | str :
169172 raise NotImplementedError ("To be implemented in subclass." )
170173
171174
@@ -205,7 +208,7 @@ def __init__(
205208 transform_mismatches : bool = False ,
206209 fade_transform_mismatches : bool = False ,
207210 key_map : dict | None = None ,
208- ** kwargs ,
211+ ** kwargs : Any ,
209212 ):
210213 super ().__init__ (
211214 mobject ,
@@ -269,7 +272,7 @@ def __init__(
269272 transform_mismatches : bool = False ,
270273 fade_transform_mismatches : bool = False ,
271274 key_map : dict | None = None ,
272- ** kwargs ,
275+ ** kwargs : Any ,
273276 ):
274277 super ().__init__ (
275278 mobject ,
@@ -294,4 +297,5 @@ def get_mobject_parts(mobject: Mobject) -> list[Mobject]:
294297
295298 @staticmethod
296299 def get_mobject_key (mobject : Mobject ) -> str :
300+ assert isinstance (mobject , SingleStringMathTex )
297301 return mobject .tex_string
0 commit comments