Skip to content

Commit

Permalink
hammer lapx into qbindiff
Browse files Browse the repository at this point in the history
  • Loading branch information
thebabush authored and patacca committed Nov 28, 2024
1 parent 13e156e commit 0ecbc99
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/qbindiff/matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@

# Third-party imports
import numpy as np
from lapjv import lapjv # type: ignore[import-not-found]
from scipy.sparse import csr_matrix, coo_matrix # type: ignore[import-untyped]

# Local imports
from qbindiff.matcher.squares import find_squares # type: ignore[import-untyped]
from qbindiff.matcher.belief_propagation import BeliefMWM, BeliefQAP
from qbindiff.utils import lap


if TYPE_CHECKING:
Expand All @@ -50,7 +50,7 @@ def solve_linear_assignment(cost_matrix: Matrix) -> RawMapping:
cost_matrix = cost_matrix.T
full_cost_matrix = np.zeros((m, m), dtype=cost_matrix.dtype)
full_cost_matrix[:n, :m] = cost_matrix
col_indices = lapjv(full_cost_matrix)[0][:n]
col_indices = lap(full_cost_matrix)[:n]
if transposed:
return col_indices, np.arange(n)
return np.arange(n), col_indices
Expand Down
2 changes: 1 addition & 1 deletion src/qbindiff/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
Collection of utilities used internally.
"""

from .utils import is_debug, iter_csr_matrix, log_once, wrapper_iter
from .utils import is_debug, iter_csr_matrix, lap, log_once, wrapper_iter
18 changes: 17 additions & 1 deletion src/qbindiff/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,17 @@
import logging
from typing import TYPE_CHECKING

try:
from lapjv import lapjv as _lap # type: ignore[import-not-found]
_is_lapjv = True
except ImportError:
from lap import lapjv as _lap
_is_lapjv = False

if TYPE_CHECKING:
from collections.abc import Generator
from typing import Any
from qbindiff.types import SparseMatrix
from qbindiff.types import Matrix, RawMapping, SparseMatrix


def is_debug() -> bool:
Expand All @@ -49,6 +56,15 @@ def iter_csr_matrix(matrix: SparseMatrix) -> Generator[tuple[int, int, Any], Non
yield (x, y, v)


def lap(cost_matrix: Matrix) -> RawMapping:
if _is_lapjv:
row_ind_lapjv, _col_ind_lapjv, _ = _lap(cost_matrix)
return row_ind_lapjv
else:
_cost, x, _y = _lap(cost_matrix)
return x


@cache
def log_once(level: int, message: str) -> None:
"""
Expand Down

0 comments on commit 0ecbc99

Please sign in to comment.