Skip to content

Commit

Permalink
Fix #5 (#6)
Browse files Browse the repository at this point in the history
- Add `exclude_diagonal` kwarg to the constructor and default it to `True` to keep compatibility.
- Improve docstrings for `IntraGroupSimilarityLoss` and `InterGroupSimilarityLoss`
  • Loading branch information
ulupo authored May 14, 2024
1 parent c26b106 commit 6170564
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 18 deletions.
32 changes: 23 additions & 9 deletions diffpass/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,13 +542,18 @@ class InterGroupSimilarityLoss(Module):
relationships.
Similarity matrices are expected to be square and symmetric. The loss is computed
by comparing the (unrolled and concatenated) upper triangular blocks containing
inter-group similarities."""
by comparing the (flattened and concatenated) blocks containing inter-group
similarities."""

def __init__(
self,
*,
# Number of entries in each group (e.g. species). Groups are assumed to be
# contiguous in the input similarity matrices
group_sizes: Iterable[int],
# If not ``None``, custom callable to compute the differentiable score between
# the flattened and concatenated inter-group blocks of the similarity matrices.
# Default: dot product
score_fn: Union[callable, None] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -588,17 +593,24 @@ class IntraGroupSimilarityLoss(Module):
relationships.
Similarity matrices are expected to be square and symmetric. Their diagonal
elements are ignored.
If `group_sizes` is provided, the loss is computed by comparing the (unrolled
and concatenated) upper triangular blocks containing intra-group similarities.
elements are ignored if `exclude_diagonal` is set to True.
If `group_sizes` is provided, the loss is computed by comparing the flattened
and concatenated upper triangular blocks containing intra-group similarities.
Otherwise, the loss is computed by comparing the upper triangular part of the
full similarity matrices, excluding the main diagonal."""
full similarity matrices."""

def __init__(
self,
*,
# Number of entries in each group (e.g. species). Groups are assumed to be
# contiguous in the input similarity matrices
group_sizes: Optional[Iterable[int]] = None,
# If not ``None``, custom callable to compute the differentiable score between
# the flattened and concatenated intra-group blocks of the similarity matrices
# Default: dot product
score_fn: Union[callable, None] = None,
# If ``True``, exclude the diagonal elements from the computation
exclude_diagonal: bool = True,
) -> None:
super().__init__()
self.group_sizes = (
Expand All @@ -607,15 +619,17 @@ def __init__(
self.score_fn = (
partial(torch.tensordot, dims=1) if score_fn is None else score_fn
)
self.exclude_diagonal = exclude_diagonal

if self.group_sizes is not None:
# Boolean mask for the main diagonal blocks corresponding to groups
diag_blocks_mask = torch.block_diag(
*[torch.ones((s, s), dtype=torch.bool) for s in self.group_sizes]
)
# Extract the upper triangular part, excluding the main diagonal
# Extract the upper triangular part
self.register_buffer(
"_upper_diag_blocks_mask", torch.triu(diag_blocks_mask, diagonal=1)
"_upper_diag_blocks_mask",
torch.triu(diag_blocks_mask, diagonal=int(self.exclude_diagonal)),
)
else:
self._upper_diag_blocks_mask = None
Expand All @@ -638,7 +652,7 @@ def forward(
layout=similarities_x.layout,
device=similarities_x.device,
),
diagonal=1,
diagonal=int(self.exclude_diagonal),
)
else:
mask = self._upper_diag_blocks_mask
Expand Down
32 changes: 23 additions & 9 deletions nbs/model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1042,13 +1042,18 @@
" relationships.\n",
"\n",
" Similarity matrices are expected to be square and symmetric. The loss is computed\n",
" by comparing the (unrolled and concatenated) upper triangular blocks containing\n",
" inter-group similarities.\"\"\"\n",
" by comparing the (flattened and concatenated) blocks containing inter-group\n",
" similarities.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" *,\n",
" # Number of entries in each group (e.g. species). Groups are assumed to be\n",
" # contiguous in the input similarity matrices\n",
" group_sizes: Iterable[int],\n",
" # If not ``None``, custom callable to compute the differentiable score between\n",
" # the flattened and concatenated inter-group blocks of the similarity matrices.\n",
" # Default: dot product\n",
" score_fn: Union[callable, None] = None,\n",
" ) -> None:\n",
" super().__init__()\n",
Expand Down Expand Up @@ -1088,17 +1093,24 @@
" relationships.\n",
"\n",
" Similarity matrices are expected to be square and symmetric. Their diagonal\n",
" elements are ignored.\n",
" If `group_sizes` is provided, the loss is computed by comparing the (unrolled\n",
" and concatenated) upper triangular blocks containing intra-group similarities.\n",
" elements are ignored if `exclude_diagonal` is set to True.\n",
" If `group_sizes` is provided, the loss is computed by comparing the flattened\n",
" and concatenated upper triangular blocks containing intra-group similarities.\n",
" Otherwise, the loss is computed by comparing the upper triangular part of the\n",
" full similarity matrices, excluding the main diagonal.\"\"\"\n",
" full similarity matrices.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" *,\n",
" # Number of entries in each group (e.g. species). Groups are assumed to be\n",
" # contiguous in the input similarity matrices\n",
" group_sizes: Optional[Iterable[int]] = None,\n",
" # If not ``None``, custom callable to compute the differentiable score between\n",
" # the flattened and concatenated intra-group blocks of the similarity matrices\n",
" # Default: dot product\n",
" score_fn: Union[callable, None] = None,\n",
" # If ``True``, exclude the diagonal elements from the computation\n",
" exclude_diagonal: bool = True,\n",
" ) -> None:\n",
" super().__init__()\n",
" self.group_sizes = (\n",
Expand All @@ -1107,15 +1119,17 @@
" self.score_fn = (\n",
" partial(torch.tensordot, dims=1) if score_fn is None else score_fn\n",
" )\n",
" self.exclude_diagonal = exclude_diagonal\n",
"\n",
" if self.group_sizes is not None:\n",
" # Boolean mask for the main diagonal blocks corresponding to groups\n",
" diag_blocks_mask = torch.block_diag(\n",
" *[torch.ones((s, s), dtype=torch.bool) for s in self.group_sizes]\n",
" )\n",
" # Extract the upper triangular part, excluding the main diagonal\n",
" # Extract the upper triangular part\n",
" self.register_buffer(\n",
" \"_upper_diag_blocks_mask\", torch.triu(diag_blocks_mask, diagonal=1)\n",
" \"_upper_diag_blocks_mask\",\n",
" torch.triu(diag_blocks_mask, diagonal=int(self.exclude_diagonal)),\n",
" )\n",
" else:\n",
" self._upper_diag_blocks_mask = None\n",
Expand All @@ -1138,7 +1152,7 @@
" layout=similarities_x.layout,\n",
" device=similarities_x.device,\n",
" ),\n",
" diagonal=1,\n",
" diagonal=int(self.exclude_diagonal),\n",
" )\n",
" else:\n",
" mask = self._upper_diag_blocks_mask\n",
Expand Down

0 comments on commit 6170564

Please sign in to comment.