diff --git a/src/penai/models.py b/src/penai/models.py index 5928291..b46e627 100644 --- a/src/penai/models.py +++ b/src/penai/models.py @@ -273,23 +273,57 @@ def _element_to_string(cls, element: etree.Element) -> str: @classmethod def _remove_unwanted_elements(cls, tree: BetterElement) -> BetterElement: root = deepcopy(tree) + removed_elements = [] - for _i, element in enumerate(root.iter()): - keep = cls._is_penpot_element(element) - if not keep and element.tag == cls._tag("g", "svg"): + retained_element_set = set() + + # traverse the elements of the tree, collecting the ones to remove + for element in root.iter(): + if element in retained_element_set: + continue + + is_penpot_element = cls._is_penpot_element(element) + + # apply special handling depending on penpot element + if is_penpot_element: + if element.tag == cls._name("shape", "penpot"): + attr_type = element.attrib.get(cls._name("type", "penpot")) + if attr_type in ("path", "circle", "ellipse"): + # for certain types penpot shapes, we need to retain the subsequent sibling and all its descendants + subsequent_g = None + sibling = element + while subsequent_g is None: + sibling = sibling.getnext() + if sibling.tag == cls._name("defs", "svg"): + continue + elif sibling.tag == cls._name("g", "svg"): + subsequent_g = sibling + else: + raise ValueError( + f"Unexpected element after a penpot shape path: {sibling}" + ) + retained_element_set.update(subsequent_g.iter()) # type: ignore + + # decide whether to keep or remove the current element: + # We keep penpot elements and elements that have at least one penpot element as a child + keep = is_penpot_element + if not keep and element.tag == cls._name("g", "svg"): for child in element.getchildren(): if cls._is_penpot_element(child): keep = True break if not keep: removed_elements.append(element) + + # apply the actual removal of the elements for element in removed_elements: element.getparent().remove(element) + return root @classmethod - def _tag(cls, tag: str, namespace: str) -> str: - return "{" + cls.NSMAP[namespace] + "}" + tag + def _name(cls, name: str, namespace: str) -> str: + return "{" + cls.NSMAP[namespace] + "}" + name @classmethod def _nsmap_with_default(cls, default_ns: str) -> dict[str | None, str]: @@ -297,7 +331,7 @@ def _nsmap_with_default(cls, default_ns: str) -> dict[str | None, str]: def to_svg_root(self) -> etree.Element: nsmap = self._nsmap_with_default("svg") - root = etree.Element(self._tag("svg", namespace="svg"), nsmap=nsmap) + root = etree.Element(self._name("svg", namespace="svg"), nsmap=nsmap) root.append(self.root) return root