Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions manim/mobject/types/vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import itertools as it
import sys
from types import GeneratorType
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer this be under if TYPE_CHECKING since generators are counted as iterables

from typing import (
TYPE_CHECKING,
Callable,
Expand All @@ -32,6 +33,7 @@
from manim.constants import *
from manim.mobject.mobject import Mobject
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
from manim.mobject.opengl.opengl_mobject import OpenGLMobject

Check notice

Code scanning / CodeQL

Cyclic import

Import of module [manim.mobject.opengl.opengl_mobject](1) begins an import cycle.
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject
from manim.mobject.three_d.three_d_utils import (
get_3d_vmob_gradient_start_and_end_points,
Expand Down Expand Up @@ -1990,7 +1992,11 @@ def construct(self):

"""

def __init__(self, *vmobjects, **kwargs):
def __init__(
self,
*vmobjects: VMobject | Iterable[VMobject] | types.GeneratorType[VMobject],
**kwargs,
):
super().__init__(**kwargs)
self.add(*vmobjects)

Expand All @@ -2003,7 +2009,9 @@ def __str__(self) -> str:
f"submobject{'s' if len(self.submobjects) > 0 else ''}"
)

def add(self, *vmobjects: VMobject) -> Self:
def add(
self, *vmobjects: VMobject | Iterable[VMobject] | GeneratorType[VMobject]
) -> Self:
"""Checks if all passed elements are an instance of VMobject and then add them to submobjects

Parameters
Expand Down Expand Up @@ -2051,15 +2059,44 @@ def construct(self):
(gr-circle_red).animate.shift(RIGHT)
)
"""

flattened_args = []

for m in vmobjects:
if not isinstance(m, (VMobject, OpenGLVMobject)):
# Mobject and its subclasses are iterable
if isinstance(m, (Iterable, GeneratorType)):
# If it's not a subclass of Mobject or OpenGLMobject, it must be an iterable or generator
if not isinstance(m, (Mobject, OpenGLMobject)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could rewrite this in a way that reduces indentation by just checking if it's a Mobject before checking if it's an iterable.

temp = []
temp.extend(m)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
temp = []
temp.extend(m)
temp = tuple(m)


# Verify that every element in the iterable is VMobject or OpenGLVMobject
for t in temp:
if not isinstance(t, (VMobject, OpenGLVMobject)):
raise TypeError(
f"All submobjects of {self.__class__.__name__} must be of type VMobject. "
f"Got {repr(m)} ({type(m).__name__}) instead. "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"Got {repr(m)} ({type(m).__name__}) instead. "
f"Got {repr(t)} ({type(t).__name__}) instead. "

Copy-paste error?

"You can try using `Group` instead."
)

flattened_args.extend(temp)
elif isinstance(m, (VMobject, OpenGLVMobject)):
flattened_args.append(m)
else:
raise TypeError(
f"All submobjects of {self.__class__.__name__} must be of type VMobject. "
f"Got {repr(m)} ({type(m).__name__}) instead. "
"You can try using `Group` instead."
)

else:
raise TypeError(
f"All submobjects of {self.__class__.__name__} must be of type VMobject. "
f"Got {repr(m)} ({type(m).__name__}) instead. "
"You can try using `Group` instead."
)

return super().add(*vmobjects)
return super().add(*flattened_args)

def __add__(self, vmobject: VMobject) -> Self:
return VGroup(*self.submobjects, vmobject)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,21 @@ def test_vgroup_init():
VGroup(Mobject(), Mobject())


def test_vgroup_iter_init():
"""Test the VGroup instantiation with an iterable type."""

def basic_generator(n, type):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest using a different parameter name other than type (since it's a built-in function/metaclass/whatever you want to call it).
If you really think type is the best name, use the name type_ as that's the naming standard for python.

i = 0
while i < n:
i += 1
yield type()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not something like this?

Suggested change
i = 0
while i < n:
i += 1
yield type()
for _ in range(n):
yield type()


obj = VGroup(basic_generator(5, VMobject))
assert len(obj.submobjects) == 5
with pytest.raises(TypeError):
VGroup(basic_generator(Mobject))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
VGroup(basic_generator(Mobject))
VGroup(basic_generator(1, Mobject))

Is this what you meant?



def test_vgroup_add():
"""Test the VGroup add method."""
obj = VGroup()
Expand Down
16 changes: 16 additions & 0 deletions tests/opengl/test_opengl_vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@ def test_vgroup_init(using_opengl_renderer):
VGroup(OpenGLMobject(), OpenGLMobject())


def test_vgroup_iter_init(using_opengl_renderer):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See feedback above

"""Test the VGroup instantiation with an iterable type."""

def basic_generator(n, type):
i = 0
while i < n:
i += 1
yield type()

obj = VGroup(basic_generator(5, OpenGLVMobject))
assert len(obj.submobjects) == 5

with pytest.raises(TypeError):
VGroup(basic_generator(5, OpenGLMobject))


def test_vgroup_add(using_opengl_renderer):
"""Test the VGroup add method."""
obj = VGroup()
Expand Down