Skip to content

Commit

Permalink
Don't use transforms in base class
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFabisch committed Jul 18, 2023
1 parent 25fa493 commit d78eac0
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
32 changes: 24 additions & 8 deletions pytransform3d/transform_manager/_transform_graph_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ def _invert_transform(self, A2B):
def _path_transform(self, path):
"""Convert sequence of node names to rigid transformation."""

@abc.abstractmethod
def _transform_available(self, key):
"""Check if transformation key is available."""

@abc.abstractmethod
def _set_transform(self, key, A2B):
"""Store transformation under given key."""

@abc.abstractmethod
def _get_transform(self, key):
"""Retrieve stored transformation under given key."""

@abc.abstractmethod
def _del_transform(self, key):
"""Delete transformation stored under given key."""

def has_frame(self, frame):
"""Check if frame has been registered.
Expand Down Expand Up @@ -103,7 +119,7 @@ def add_transform(self, from_frame, to_frame, A2B):
transform_key = (from_frame, to_frame)

recompute_shortest_path = False
if transform_key not in self.transforms:
if not self._transform_available(transform_key):
ij_index = len(self.i)
self.i.append(self.nodes.index(from_frame))
self.j.append(self.nodes.index(to_frame))
Expand All @@ -113,7 +129,7 @@ def add_transform(self, from_frame, to_frame, A2B):
if recompute_shortest_path:
self._recompute_shortest_path()

self.transforms[transform_key] = A2B
self._set_transform(transform_key, A2B)

return self

Expand Down Expand Up @@ -147,8 +163,8 @@ def remove_transform(self, from_frame, to_frame):
This object for chaining
"""
transform_key = (from_frame, to_frame)
if transform_key in self.transforms:
del self.transforms[transform_key]
if self._transform_available(transform_key):
self._del_transform(transform_key)
ij_index = self.transform_to_ij_index[transform_key]
del self.transform_to_ij_index[transform_key]
self.transform_to_ij_index = dict(
Expand Down Expand Up @@ -188,12 +204,12 @@ def get_transform(self, from_frame, to_frame):
if to_frame not in self.nodes:
raise KeyError("Unknown frame '%s'" % to_frame)

if (from_frame, to_frame) in self.transforms:
return self.transforms[(from_frame, to_frame)]
if self._transform_available((from_frame, to_frame)):
return self._get_transform((from_frame, to_frame))

if (to_frame, from_frame) in self.transforms:
if self._transform_available((to_frame, from_frame)):
return self._invert_transform(
self.transforms[(to_frame, from_frame)])
self._get_transform((to_frame, from_frame)))

i = self.nodes.index(from_frame)
j = self.nodes.index(to_frame)
Expand Down
12 changes: 12 additions & 0 deletions pytransform3d/transform_manager/_transform_graph_base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ class TransformGraphBase(abc.ABC):
@abc.abstractmethod
def _path_transform(self, path: List[Hashable]) -> Any: ...

@abc.abstractmethod
def _transform_available(self, key: Tuple[Hashable, Hashable]) -> bool: ...

@abc.abstractmethod
def _set_transform(self, key: Tuple[Hashable, Hashable], A2B: Any): ...

@abc.abstractmethod
def _get_transform(self, key: Tuple[Hashable, Hashable]) -> Any: ...

@abc.abstractmethod
def _del_transform(self, key: Tuple[Hashable, Hashable]): ...

def has_frame(self, frame: Hashable) -> bool: ...

def add_transform(self, from_frame: Hashable, to_frame: Hashable,
Expand Down
12 changes: 12 additions & 0 deletions pytransform3d/transform_manager/_transform_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def _path_transform(self, path):
strict_check=self.strict_check, check=self.check)
return A2B

def _transform_available(self, key):
return key in self._transforms

def _set_transform(self, key, A2B):
self._transforms[key] = A2B

def _get_transform(self, key):
return self._transforms[key]

def _del_transform(self, key):
del self._transforms[key]

def plot_frames_in(self, frame, ax=None, s=1.0, ax_s=1, show_name=True,
whitelist=None, **kwargs): # pragma: no cover
"""Plot all frames in a given reference frame.
Expand Down
8 changes: 8 additions & 0 deletions pytransform3d/transform_manager/_transform_manager.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ class TransformManager(TransformGraphBase):

def _path_transform(self, path: List[Hashable]) -> np.ndarray: ...

def _transform_available(self, key: Tuple[Hashable, Hashable]) -> bool: ...

def _set_transform(self, key: Tuple[Hashable, Hashable], A2B: Any): ...

def _get_transform(self, key: Tuple[Hashable, Hashable]) -> Any: ...

def _del_transform(self, key: Tuple[Hashable, Hashable]): ...

def add_transform(self, from_frame: Hashable, to_frame: Hashable,
A2B: np.ndarray) -> "TransformManager": ...

Expand Down

0 comments on commit d78eac0

Please sign in to comment.