Skip to content

Commit 19a6a38

Browse files
oscjacJasonGrace2282
authored andcommitted
Added support for iterators in VGroup constructor and relevant tests
1 parent dff83be commit 19a6a38

File tree

3 files changed

+72
-4
lines changed

3 files changed

+72
-4
lines changed

manim/mobject/types/vectorized_mobject.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import itertools as it
1616
import sys
17+
from types import GeneratorType
1718
from typing import (
1819
TYPE_CHECKING,
1920
Callable,
@@ -32,6 +33,7 @@
3233
from manim.constants import *
3334
from manim.mobject.mobject import Mobject
3435
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
36+
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
3537
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject
3638
from manim.mobject.three_d.three_d_utils import (
3739
get_3d_vmob_gradient_start_and_end_points,
@@ -1990,7 +1992,11 @@ def construct(self):
19901992
19911993
"""
19921994

1993-
def __init__(self, *vmobjects, **kwargs):
1995+
def __init__(
1996+
self,
1997+
*vmobjects: VMobject | Iterable[VMobject] | types.GeneratorType[VMobject],
1998+
**kwargs,
1999+
):
19942000
super().__init__(**kwargs)
19952001
self.add(*vmobjects)
19962002

@@ -2003,7 +2009,9 @@ def __str__(self) -> str:
20032009
f"submobject{'s' if len(self.submobjects) > 0 else ''}"
20042010
)
20052011

2006-
def add(self, *vmobjects: VMobject) -> Self:
2012+
def add(
2013+
self, *vmobjects: VMobject | Iterable[VMobject] | GeneratorType[VMobject]
2014+
) -> Self:
20072015
"""Checks if all passed elements are an instance of VMobject and then add them to submobjects
20082016
20092017
Parameters
@@ -2051,15 +2059,44 @@ def construct(self):
20512059
(gr-circle_red).animate.shift(RIGHT)
20522060
)
20532061
"""
2062+
2063+
flattened_args = []
2064+
20542065
for m in vmobjects:
2055-
if not isinstance(m, (VMobject, OpenGLVMobject)):
2066+
# Mobject and its subclasses are iterable
2067+
if isinstance(m, (Iterable, GeneratorType)):
2068+
# If it's not a subclass of Mobject or OpenGLMobject, it must be an iterable or generator
2069+
if not isinstance(m, (Mobject, OpenGLMobject)):
2070+
temp = []
2071+
temp.extend(m)
2072+
2073+
# Verify that every element in the iterable is VMobject or OpenGLVMobject
2074+
for t in temp:
2075+
if not isinstance(t, (VMobject, OpenGLVMobject)):
2076+
raise TypeError(
2077+
f"All submobjects of {self.__class__.__name__} must be of type VMobject. "
2078+
f"Got {repr(m)} ({type(m).__name__}) instead. "
2079+
"You can try using `Group` instead."
2080+
)
2081+
2082+
flattened_args.extend(temp)
2083+
elif isinstance(m, (VMobject, OpenGLVMobject)):
2084+
flattened_args.append(m)
2085+
else:
2086+
raise TypeError(
2087+
f"All submobjects of {self.__class__.__name__} must be of type VMobject. "
2088+
f"Got {repr(m)} ({type(m).__name__}) instead. "
2089+
"You can try using `Group` instead."
2090+
)
2091+
2092+
else:
20562093
raise TypeError(
20572094
f"All submobjects of {self.__class__.__name__} must be of type VMobject. "
20582095
f"Got {repr(m)} ({type(m).__name__}) instead. "
20592096
"You can try using `Group` instead."
20602097
)
20612098

2062-
return super().add(*vmobjects)
2099+
return super().add(*flattened_args)
20632100

20642101
def __add__(self, vmobject: VMobject) -> Self:
20652102
return VGroup(*self.submobjects, vmobject)

tests/module/mobject/types/vectorized_mobject/test_vectorized_mobject.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,21 @@ def test_vgroup_init():
8585
VGroup(Mobject(), Mobject())
8686

8787

88+
def test_vgroup_iter_init():
89+
"""Test the VGroup instantiation with an iterable type."""
90+
91+
def basic_generator(n, type):
92+
i = 0
93+
while i < n:
94+
i += 1
95+
yield type()
96+
97+
obj = VGroup(basic_generator(5, VMobject))
98+
assert len(obj.submobjects) == 5
99+
with pytest.raises(TypeError):
100+
VGroup(basic_generator(Mobject))
101+
102+
88103
def test_vgroup_add():
89104
"""Test the VGroup add method."""
90105
obj = VGroup()

tests/opengl/test_opengl_vectorized_mobject.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,22 @@ def test_vgroup_init(using_opengl_renderer):
4343
VGroup(OpenGLMobject(), OpenGLMobject())
4444

4545

46+
def test_vgroup_iter_init(using_opengl_renderer):
47+
"""Test the VGroup instantiation with an iterable type."""
48+
49+
def basic_generator(n, type):
50+
i = 0
51+
while i < n:
52+
i += 1
53+
yield type()
54+
55+
obj = VGroup(basic_generator(5, OpenGLVMobject))
56+
assert len(obj.submobjects) == 5
57+
58+
with pytest.raises(TypeError):
59+
VGroup(basic_generator(5, OpenGLMobject))
60+
61+
4662
def test_vgroup_add(using_opengl_renderer):
4763
"""Test the VGroup add method."""
4864
obj = VGroup()

0 commit comments

Comments
 (0)