Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions manim/mobject/types/vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from manim.constants import *
from manim.mobject.mobject import Mobject
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
from manim.mobject.opengl.opengl_mobject import OpenGLMobject

Check notice

Code scanning / CodeQL

Cyclic import

Import of module [manim.mobject.opengl.opengl_mobject](1) begins an import cycle.
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject
from manim.mobject.three_d.three_d_utils import (
get_3d_vmob_gradient_start_and_end_points,
Expand Down Expand Up @@ -1990,7 +1991,11 @@ def construct(self):

"""

def __init__(self, *vmobjects, **kwargs):
def __init__(
self,
*vmobjects: VMobject | Iterable[VMobject] | types.GeneratorType[VMobject],
**kwargs,
):
super().__init__(**kwargs)
self.add(*vmobjects)

Expand All @@ -2003,7 +2008,9 @@ def __str__(self) -> str:
f"submobject{'s' if len(self.submobjects) > 0 else ''}"
)

def add(self, *vmobjects: VMobject) -> Self:
def add(
self, *vmobjects: VMobject | Iterable[VMobject] | GeneratorType[VMobject]
) -> Self:
"""Checks if all passed elements are an instance of VMobject and then add them to submobjects

Parameters
Expand Down Expand Up @@ -2051,15 +2058,42 @@ def construct(self):
(gr-circle_red).animate.shift(RIGHT)
)
"""

flattened_args = []

for m in vmobjects:
if not isinstance(m, (VMobject, OpenGLVMobject)):
# Mobject and its subclasses are iterable
if not isinstance(m, (Iterable, Generator)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps I may not have been clear enough in my last review, but you could reduce the amount of explanation needed and the clarity of the code by doing the following:

if isinstance(mob, VMobject): # first
    ...
elif isinstance(mob, Iterable): # generator not needed, it's a subprotocol of Iterable
    for submob in mob:
        if not isinstance(submob, VMobject):
            raise
    flattened.extend(mob)
else:
    raise

The benefit of this is that it's easier to read and it avoids repeating the same text for the ValueError in 3 places (it repeats it twice, in the ideal case it should never repeat it)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thank you for the reviews! One problem that I'm running into testing this change locally is that when the constructor is passed an Mobject instance, it does not throw a TypeError.
The test is located in (tests/ ... /test_vectorized_mobject.py:77).
These are the changes I'm testing:

if isinstance(mob, VMobject):
    flattened_args.append(mob)
elif isinstance(mob, Iterable):
    for submob in mob:
        if not isinstance(submob, VMobject):
            raise TypeError(...)
        flattened_args.extend(mob)
else:
    raise TypeError(...)

The problem is that an empty Mobject is iterable but, since it is empty, it has no submobs and so the constructor never throws a TypeError. My first thought was to add another check before iterating through the submob as such:

elif isinstance(mob, Iterable):
    if (mob, is instance(Mobject)):
        raise TypeError(...)
    for submob in mob:
        if not isinstance(submob, VMobject):
            raise TypeError(...)
        flattened_args.extend(mob)

but, going off what you've already mentioned, I take it that it's not the best solution because it adds yet another copy of the TypeError and it is less readable. How do you think we could solve this without degrading the code's readability?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch, didn't think about that. It could be circumvented with something like this

if isinstance(mob, VMobject):
    ...
elif isinstance(mob, Iterable) and type(mob) is not Mobject:
    for submob in mob:
        if not isinstance(submob, VMobject):
            raise
    flattened.extend(mob)
else:
    raise

Not as clean as I like, but it should keep only two copies of the error message.

I would also suggest putting the error message into a function and simply call that function both times to reduce the repetition of the error.

Thanks for your patience :)

raise TypeError(
f"All submobjects of {self.__class__.__name__} must be of type VMobject. "
f"Got {repr(m)} ({type(m).__name__}) instead. "
"You can try using `Group` instead."
)

# If it's not a subclass of Mobject or OpenGLMobject, it must be an iterable or generator
if not isinstance(m, (Mobject, OpenGLMobject)):
temp = tuple(m)

# Verify that every element in the iterable is VMobject or OpenGLVMobject
for t in temp:
if not isinstance(t, (VMobject, OpenGLVMobject)):
raise TypeError(
f"All submobjects of {self.__class__.__name__} must be of type VMobject. "
f"Got {repr(t)} ({type(t).__name__}) instead. "
"You can try using `Group` instead."
)

flattened_args.extend(temp)
elif isinstance(m, (VMobject, OpenGLVMobject)):
flattened_args.append(m)
else:
raise TypeError(
f"All submobjects of {self.__class__.__name__} must be of type VMobject. "
f"Got {repr(m)} ({type(m).__name__}) instead. "
"You can try using `Group` instead."
)

return super().add(*vmobjects)
return super().add(*flattened_args)

def __add__(self, vmobject: VMobject) -> Self:
return VGroup(*self.submobjects, vmobject)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ def test_vgroup_init():
VGroup(Mobject(), Mobject())


def test_vgroup_iter_init():
"""Test the VGroup instantiation with an iterable type."""

def basic_generator(n, to_generate):
for _ in range(n):
yield to_generate()

obj = VGroup(basic_generator(5, VMobject))
assert len(obj.submobjects) == 5
with pytest.raises(TypeError):
VGroup(basic_generator(1, Mobject))


def test_vgroup_add():
"""Test the VGroup add method."""
obj = VGroup()
Expand Down
16 changes: 16 additions & 0 deletions tests/opengl/test_opengl_vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@ def test_vgroup_init(using_opengl_renderer):
VGroup(OpenGLMobject(), OpenGLMobject())


def test_vgroup_iter_init(using_opengl_renderer):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See feedback above

"""Test the VGroup instantiation with an iterable type."""

def basic_generator(n, type):
i = 0
while i < n:
i += 1
yield type()

obj = VGroup(basic_generator(5, OpenGLVMobject))
assert len(obj.submobjects) == 5

with pytest.raises(TypeError):
VGroup(basic_generator(5, OpenGLMobject))


def test_vgroup_add(using_opengl_renderer):
"""Test the VGroup add method."""
obj = VGroup()
Expand Down