Skip to content

Commit

Permalink
Move transform_manager to submodule
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFabisch committed Jul 17, 2023
1 parent 20fc809 commit d6b76af
Show file tree
Hide file tree
Showing 6 changed files with 338 additions and 319 deletions.
6 changes: 3 additions & 3 deletions pytransform3d/test/test_transform_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,16 @@ def test_png_export():

def test_png_export_without_pydot_fails():
"""Test graph export without pydot."""
pydot_available = transform_manager.PYDOT_AVAILABLE
pydot_available = transform_manager._transform_manager.PYDOT_AVAILABLE
tm = TransformManager()
try:
transform_manager.PYDOT_AVAILABLE = False
transform_manager._transform_manager.PYDOT_AVAILABLE = False
with pytest.raises(
ImportError,
match="pydot must be installed to use this feature."):
tm.write_png("bla")
finally:
transform_manager.PYDOT_AVAILABLE = pydot_available
transform_manager._transform_manager.PYDOT_AVAILABLE = pydot_available


def test_deactivate_transform_manager_precision_error():
Expand Down
9 changes: 9 additions & 0 deletions pytransform3d/transform_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Manage complex chains of transformations.
See :doc:`transform_manager` for more information.
"""
from ._transform_graph_base import TransformGraphBase
from ._transform_manager import TransformManager


__all__ = ["TransformGraphBase", "TransformManager"]
259 changes: 259 additions & 0 deletions pytransform3d/transform_manager/_transform_graph_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import abc
import numpy as np
import scipy.sparse as sp
from scipy.sparse import csgraph


class TransformGraphBase(abc.ABC):
"""Base class for all trees of rigid transformations.
Parameters
----------
strict_check : bool, optional (default: True)
Raise a ValueError if the transformation matrix is not numerically
close enough to a real transformation matrix. Otherwise we print a
warning.
check : bool, optional (default: True)
Check if transformation matrices are valid and requested nodes exist,
which might significantly slow down some operations.
"""
def __init__(self, strict_check=True, check=True):
self.strict_check = strict_check
self.check = check

# Names of nodes
self.nodes = []

# A pair (self.i[n], self.j[n]) represents indices of connected nodes
self.i = []
self.j = []
# We have to store the index n associated to a transformation to be
# able to remove the transformation later
self.transform_to_ij_index = {}
# Connection information as sparse matrix
self.connections = sp.csr_matrix((0, 0))
# Result of shortest path algorithm:
# distance matrix (distance is the number of transformations)
self.dist = np.empty(0)
self.predecessors = np.empty(0, dtype=np.int32)

self._cached_shortest_paths = {}

@property
@abc.abstractmethod
def transforms(self):
"""Rigid transformations between nodes."""

@abc.abstractmethod
def _check_transform(self, A2B):
"""Check validity of rigid transformation."""

@abc.abstractmethod
def _invert_transform(self, A2B):
"""Invert rigid transformation stored in the tree."""

@abc.abstractmethod
def _path_transform(self, path):
"""Convert sequence of node names to rigid transformation."""

def has_frame(self, frame):
"""Check if frame has been registered.
Parameters
----------
frame : Hashable
Frame name
Returns
-------
has_frame : bool
Frame is registered
"""
return frame in self.nodes

def add_transform(self, from_frame, to_frame, A2B):
"""Register a transformation.
Parameters
----------
from_frame : Hashable
Name of the frame for which the transformation is added in the
to_frame coordinate system
to_frame : Hashable
Name of the frame in which the transformation is defined
A2B : Any
Transformation from 'from_frame' to 'to_frame'
Returns
-------
self : TransformManager
This object for chaining
"""
if self.check:
A2B = self._check_transform(A2B)

if from_frame not in self.nodes:
self.nodes.append(from_frame)
if to_frame not in self.nodes:
self.nodes.append(to_frame)

transform_key = (from_frame, to_frame)

recompute_shortest_path = False
if transform_key not in self.transforms:
ij_index = len(self.i)
self.i.append(self.nodes.index(from_frame))
self.j.append(self.nodes.index(to_frame))
self.transform_to_ij_index[transform_key] = ij_index
recompute_shortest_path = True

if recompute_shortest_path:
self._recompute_shortest_path()

self.transforms[transform_key] = A2B

return self

def _recompute_shortest_path(self):
n_nodes = len(self.nodes)
self.connections = sp.csr_matrix(
(np.zeros(len(self.i)), (self.i, self.j)),
shape=(n_nodes, n_nodes))
self.dist, self.predecessors = csgraph.shortest_path(
self.connections, unweighted=True, directed=False, method="D",
return_predecessors=True)
self._cached_shortest_paths.clear()

def remove_transform(self, from_frame, to_frame):
"""Remove a transformation.
Nothing happens if there is no such transformation.
Parameters
----------
from_frame : Hashable
Name of the frame for which the transformation is added in the
to_frame coordinate system
to_frame : Hashable
Name of the frame in which the transformation is defined
Returns
-------
self : TransformManager
This object for chaining
"""
transform_key = (from_frame, to_frame)
if transform_key in self.transforms:
del self.transforms[transform_key]
ij_index = self.transform_to_ij_index[transform_key]
del self.transform_to_ij_index[transform_key]
self.transform_to_ij_index = dict(
(k, v if v < ij_index else v - 1)
for k, v in self.transform_to_ij_index.items())
del self.i[ij_index]
del self.j[ij_index]
self._recompute_shortest_path()
return self

def get_transform(self, from_frame, to_frame):
"""Request a transformation.
Parameters
----------
from_frame : Hashable
Name of the frame for which the transformation is requested in the
to_frame coordinate system
to_frame : Hashable
Name of the frame in which the transformation is defined
Returns
-------
A2B : Any
Transformation from 'from_frame' to 'to_frame'
Raises
------
KeyError
If one of the frames is unknown or there is no connection between
them
"""
if self.check:
if from_frame not in self.nodes:
raise KeyError("Unknown frame '%s'" % from_frame)
if to_frame not in self.nodes:
raise KeyError("Unknown frame '%s'" % to_frame)

if (from_frame, to_frame) in self.transforms:
return self.transforms[(from_frame, to_frame)]

if (to_frame, from_frame) in self.transforms:
return self._invert_transform(
self.transforms[(to_frame, from_frame)])

i = self.nodes.index(from_frame)
j = self.nodes.index(to_frame)
if not np.isfinite(self.dist[i, j]):
raise KeyError("Cannot compute path from frame '%s' to "
"frame '%s'." % (from_frame, to_frame))

path = self._shortest_path(i, j)
return self._path_transform(path)

def _shortest_path(self, i, j):
"""Names of nodes along the shortest path between two indices."""
if (i, j) in self._cached_shortest_paths:
return self._cached_shortest_paths[(i, j)]

path = []
k = i
while k != -9999:
path.append(self.nodes[k])
k = self.predecessors[j, k]
self._cached_shortest_paths[(i, j)] = path
return path

def connected_components(self):
"""Get number of connected components.
If the number is larger than 1 there will be frames without
connections.
Returns
-------
n_connected_components : int
Number of connected components.
"""
return csgraph.connected_components(
self.connections, directed=False, return_labels=False)

def check_consistency(self):
"""Check consistency of the known transformations.
The complexity of this is between :math:`O(n^2)` and :math:`O(n^3)`,
where :math:`n` is the number of nodes. In graphs where each pair of
nodes is directly connected the complexity is :math:`O(n^2)`. In graphs
that are actually paths, the complexity is :math:`O(n^3)`.
Returns
-------
consistent : bool
Is the graph consistent, i.e. is A2B always the same as the inverse
of B2A?
"""
consistent = True
for node1 in self.nodes:
for node2 in self.nodes:
try:
node1_to_node2 = self.get_transform(node1, node2)
node2_to_node1 = self.get_transform(node2, node1)
node1_to_node2_inv = self._invert_transform(node2_to_node1)
consistent = consistent and np.allclose(node1_to_node2,
node1_to_node2_inv)
except KeyError:
pass # Frames are not connected
return consistent
54 changes: 54 additions & 0 deletions pytransform3d/transform_manager/_transform_graph_base.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import abc
from typing import Dict, Tuple, List, Hashable, Any
import scipy.sparse as sp
import numpy as np
import numpy.typing as npt


class TransformGraphBase(abc.ABC):
strict_check: bool
check: bool
nodes: List[Hashable]
i: List[int]
j: List[int]
transform_to_ij_index = Dict[Tuple[Hashable, Hashable], int]
connections: sp.csr_matrix
dist: np.ndarray
predecessors: np.ndarray
_cached_shortest_paths: Dict[Tuple[int, int], List[Hashable]]

def __init__(self, strict_check: bool = ...,
check: bool = ...) -> "TransformGraphBase": ...

@property
@abc.abstractmethod
def transforms(self) -> Dict[Tuple[Hashable, Hashable], np.ndarray]: ...

@abc.abstractmethod
def _check_transform(self, A2B: Any) -> Any: ...

@abc.abstractmethod
def _invert_transform(self, A2B: Any) -> Any: ...

@abc.abstractmethod
def _path_transform(self, path: List[Hashable]) -> Any: ...

def has_frame(self, frame: Hashable) -> bool: ...

def add_transform(self, from_frame: Hashable, to_frame: Hashable,
A2B: Any) -> "TransformGraphBase": ...

def _recompute_shortest_path(self): ...

def remove_transform(
self, from_frame: Hashable,
to_frame: Hashable) -> "TransformGraphBase": ...

def get_transform(
self, from_frame: Hashable, to_frame: Hashable) -> Any: ...

def _shortest_path(self, i: int, j: int) -> List[Hashable]: ...

def connected_components(self) -> int: ...

def check_consistency(self) -> bool: ...
Loading

0 comments on commit d6b76af

Please sign in to comment.