-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Added support for iterables in VGroup constructor #3686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
afe1810
db7b4a6
19a6a38
7362ee8
d29f4ee
c64adfe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -14,6 +14,7 @@ | |||||||
|
||||||||
import itertools as it | ||||||||
import sys | ||||||||
from types import GeneratorType | ||||||||
from typing import ( | ||||||||
TYPE_CHECKING, | ||||||||
Callable, | ||||||||
|
@@ -32,6 +33,7 @@ | |||||||
from manim.constants import * | ||||||||
from manim.mobject.mobject import Mobject | ||||||||
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL | ||||||||
from manim.mobject.opengl.opengl_mobject import OpenGLMobject | ||||||||
Check noticeCode scanning / CodeQL Cyclic import
Import of module [manim.mobject.opengl.opengl_mobject](1) begins an import cycle.
|
||||||||
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject | ||||||||
from manim.mobject.three_d.three_d_utils import ( | ||||||||
get_3d_vmob_gradient_start_and_end_points, | ||||||||
|
@@ -1990,7 +1992,11 @@ def construct(self): | |||||||
|
||||||||
""" | ||||||||
|
||||||||
def __init__(self, *vmobjects, **kwargs): | ||||||||
def __init__( | ||||||||
self, | ||||||||
*vmobjects: VMobject | Iterable[VMobject] | types.GeneratorType[VMobject], | ||||||||
**kwargs, | ||||||||
): | ||||||||
super().__init__(**kwargs) | ||||||||
self.add(*vmobjects) | ||||||||
|
||||||||
|
@@ -2003,7 +2009,9 @@ def __str__(self) -> str: | |||||||
f"submobject{'s' if len(self.submobjects) > 0 else ''}" | ||||||||
) | ||||||||
|
||||||||
def add(self, *vmobjects: VMobject) -> Self: | ||||||||
def add( | ||||||||
self, *vmobjects: VMobject | Iterable[VMobject] | GeneratorType[VMobject] | ||||||||
) -> Self: | ||||||||
"""Checks if all passed elements are an instance of VMobject and then add them to submobjects | ||||||||
|
||||||||
Parameters | ||||||||
|
@@ -2051,15 +2059,44 @@ def construct(self): | |||||||
(gr-circle_red).animate.shift(RIGHT) | ||||||||
) | ||||||||
""" | ||||||||
|
||||||||
flattened_args = [] | ||||||||
|
||||||||
for m in vmobjects: | ||||||||
if not isinstance(m, (VMobject, OpenGLVMobject)): | ||||||||
# Mobject and its subclasses are iterable | ||||||||
if isinstance(m, (Iterable, GeneratorType)): | ||||||||
# If it's not a subclass of Mobject or OpenGLMobject, it must be an iterable or generator | ||||||||
if not isinstance(m, (Mobject, OpenGLMobject)): | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could rewrite this in a way that reduces indentation by just checking if it's a |
||||||||
temp = [] | ||||||||
temp.extend(m) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
# Verify that every element in the iterable is VMobject or OpenGLVMobject | ||||||||
for t in temp: | ||||||||
if not isinstance(t, (VMobject, OpenGLVMobject)): | ||||||||
raise TypeError( | ||||||||
f"All submobjects of {self.__class__.__name__} must be of type VMobject. " | ||||||||
f"Got {repr(m)} ({type(m).__name__}) instead. " | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Copy-paste error? |
||||||||
"You can try using `Group` instead." | ||||||||
) | ||||||||
|
||||||||
flattened_args.extend(temp) | ||||||||
elif isinstance(m, (VMobject, OpenGLVMobject)): | ||||||||
flattened_args.append(m) | ||||||||
else: | ||||||||
raise TypeError( | ||||||||
f"All submobjects of {self.__class__.__name__} must be of type VMobject. " | ||||||||
f"Got {repr(m)} ({type(m).__name__}) instead. " | ||||||||
"You can try using `Group` instead." | ||||||||
) | ||||||||
|
||||||||
else: | ||||||||
raise TypeError( | ||||||||
f"All submobjects of {self.__class__.__name__} must be of type VMobject. " | ||||||||
f"Got {repr(m)} ({type(m).__name__}) instead. " | ||||||||
"You can try using `Group` instead." | ||||||||
) | ||||||||
|
||||||||
return super().add(*vmobjects) | ||||||||
return super().add(*flattened_args) | ||||||||
|
||||||||
def __add__(self, vmobject: VMobject) -> Self: | ||||||||
return VGroup(*self.submobjects, vmobject) | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -85,6 +85,21 @@ def test_vgroup_init(): | |||||||||||||
VGroup(Mobject(), Mobject()) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def test_vgroup_iter_init(): | ||||||||||||||
"""Test the VGroup instantiation with an iterable type.""" | ||||||||||||||
|
||||||||||||||
def basic_generator(n, type): | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would suggest using a different parameter name other than |
||||||||||||||
i = 0 | ||||||||||||||
while i < n: | ||||||||||||||
i += 1 | ||||||||||||||
yield type() | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not something like this?
Suggested change
|
||||||||||||||
|
||||||||||||||
obj = VGroup(basic_generator(5, VMobject)) | ||||||||||||||
assert len(obj.submobjects) == 5 | ||||||||||||||
with pytest.raises(TypeError): | ||||||||||||||
VGroup(basic_generator(Mobject)) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Is this what you meant? |
||||||||||||||
|
||||||||||||||
|
||||||||||||||
def test_vgroup_add(): | ||||||||||||||
"""Test the VGroup add method.""" | ||||||||||||||
obj = VGroup() | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,22 @@ def test_vgroup_init(using_opengl_renderer): | |
VGroup(OpenGLMobject(), OpenGLMobject()) | ||
|
||
|
||
def test_vgroup_iter_init(using_opengl_renderer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See feedback above |
||
"""Test the VGroup instantiation with an iterable type.""" | ||
|
||
def basic_generator(n, type): | ||
i = 0 | ||
while i < n: | ||
i += 1 | ||
yield type() | ||
|
||
obj = VGroup(basic_generator(5, OpenGLVMobject)) | ||
assert len(obj.submobjects) == 5 | ||
|
||
with pytest.raises(TypeError): | ||
VGroup(basic_generator(5, OpenGLMobject)) | ||
|
||
|
||
def test_vgroup_add(using_opengl_renderer): | ||
"""Test the VGroup add method.""" | ||
obj = VGroup() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer this be under
if TYPE_CHECKING
since generators are counted as iterables