1414
1515import itertools as it
1616import sys
17+ from types import GeneratorType
1718from typing import (
1819 TYPE_CHECKING ,
1920 Callable ,
3233from manim .constants import *
3334from manim .mobject .mobject import Mobject
3435from manim .mobject .opengl .opengl_compatibility import ConvertToOpenGL
36+ from manim .mobject .opengl .opengl_mobject import OpenGLMobject
3537from manim .mobject .opengl .opengl_vectorized_mobject import OpenGLVMobject
3638from manim .mobject .three_d .three_d_utils import (
3739 get_3d_vmob_gradient_start_and_end_points ,
@@ -1990,7 +1992,11 @@ def construct(self):
19901992
19911993 """
19921994
1993- def __init__ (self , * vmobjects , ** kwargs ):
1995+ def __init__ (
1996+ self ,
1997+ * vmobjects : VMobject | Iterable [VMobject ] | types .GeneratorType [VMobject ],
1998+ ** kwargs ,
1999+ ):
19942000 super ().__init__ (** kwargs )
19952001 self .add (* vmobjects )
19962002
@@ -2003,7 +2009,9 @@ def __str__(self) -> str:
20032009 f"submobject{ 's' if len (self .submobjects ) > 0 else '' } "
20042010 )
20052011
2006- def add (self , * vmobjects : VMobject ) -> Self :
2012+ def add (
2013+ self , * vmobjects : VMobject | Iterable [VMobject ] | GeneratorType [VMobject ]
2014+ ) -> Self :
20072015 """Checks if all passed elements are an instance of VMobject and then add them to submobjects
20082016
20092017 Parameters
@@ -2051,15 +2059,44 @@ def construct(self):
20512059 (gr-circle_red).animate.shift(RIGHT)
20522060 )
20532061 """
2062+
2063+ flattened_args = []
2064+
20542065 for m in vmobjects :
2055- if not isinstance (m , (VMobject , OpenGLVMobject )):
2066+ # Mobject and its subclasses are iterable
2067+ if isinstance (m , (Iterable , GeneratorType )):
2068+ # If it's not a subclass of Mobject or OpenGLMobject, it must be an iterable or generator
2069+ if not isinstance (m , (Mobject , OpenGLMobject )):
2070+ temp = []
2071+ temp .extend (m )
2072+
2073+ # Verify that every element in the iterable is VMobject or OpenGLVMobject
2074+ for t in temp :
2075+ if not isinstance (t , (VMobject , OpenGLVMobject )):
2076+ raise TypeError (
2077+ f"All submobjects of { self .__class__ .__name__ } must be of type VMobject. "
2078+ f"Got { repr (m )} ({ type (m ).__name__ } ) instead. "
2079+ "You can try using `Group` instead."
2080+ )
2081+
2082+ flattened_args .extend (temp )
2083+ elif isinstance (m , (VMobject , OpenGLVMobject )):
2084+ flattened_args .append (m )
2085+ else :
2086+ raise TypeError (
2087+ f"All submobjects of { self .__class__ .__name__ } must be of type VMobject. "
2088+ f"Got { repr (m )} ({ type (m ).__name__ } ) instead. "
2089+ "You can try using `Group` instead."
2090+ )
2091+
2092+ else :
20562093 raise TypeError (
20572094 f"All submobjects of { self .__class__ .__name__ } must be of type VMobject. "
20582095 f"Got { repr (m )} ({ type (m ).__name__ } ) instead. "
20592096 "You can try using `Group` instead."
20602097 )
20612098
2062- return super ().add (* vmobjects )
2099+ return super ().add (* flattened_args )
20632100
20642101 def __add__ (self , vmobject : VMobject ) -> Self :
20652102 return VGroup (* self .submobjects , vmobject )
0 commit comments