Fix #9 (#10)
- Add `n_repeats` kwarg to `DiffPaSSModel.fit_bootstrap`
- Tidy up `fit_bootstrap` code
ulupo authored May 15, 2024
1 parent da0a8f5 commit fb7c79b
Showing 4 changed files with 288 additions and 98 deletions.
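
With the new keyword, each bootstrap iteration becomes a best-of-`n_repeats` contest: the gradient-descent run is repeated with a fresh random draw of fixed pairings each time, and only the run with the lowest final hard loss is kept. Below is a minimal usage sketch; the `InformationPairs` model class, its constructor, and the toy tensors are illustrative assumptions, not part of this commit, while `fit_bootstrap` and its `step_size`, `n_repeats` and `show_pbar` arguments come from the diff that follows.

import torch

# Sketch only: the model class and data below are assumed, not from this commit.
from diffpass.model import InformationPairs  # hypothetical choice of DiffPaSSModel subclass

group_sizes = [3, 4, 5]  # assumed sizes of the groups (e.g. species) to pair within
n_samples, seq_len = sum(group_sizes), 20
x = torch.randn(n_samples, seq_len)  # placeholder encodings of the two interacting MSAs
y = torch.randn(n_samples, seq_len)

model = InformationPairs(group_sizes=group_sizes)  # assumed constructor signature
results = model.fit_bootstrap(
    x,
    y,
    step_size=1,  # fix one additional pairing per bootstrap iteration
    n_repeats=3,  # new kwarg: rerun each iteration 3 times, keep the lowest-loss run
    show_pbar=True,
)
final_hard_perms = results.hard_perms[-1]  # hard permutations from the last iteration

The `results.hard_perms[-1]` access mirrors how `fit_bootstrap` itself reads back the latest hard permutations between iterations.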
191 changes: 143 additions & 48 deletions diffpass/base.py
@@ -514,6 +514,7 @@ def fit_bootstrap(
             int
         ] = None, # If ``None``, the bootstrap will end when all pairs are fixed. Otherwise, the bootstrap will end when `n_end` pairs are fixed
         step_size: int = 1, # Difference between the number of fixed pairings chosen at consecutive bootstrap iterations
+        n_repeats: int = 1, # At each bootstrap iteration, `n_repeats` runs will be performed, and the run with the lowest loss will be chosen
         show_pbar: bool = True, # If ``True``, show progress bar. Default: ``True``
         single_fit_cfg: Optional[
             dict
@@ -528,95 +529,187 @@ def fit_bootstrap(
         and fixed for the next run.
         The number of pairings fixed at each iteration ranges between `n_start` (default: 1) and `n_end` (default: total number of pairs), with a step size of `step_size`.
         """
-        self.prepare_fit(x, y)
-        if self.fixed_pairings is None:
-            initial_fixed_pairings = [[] for _ in self.group_sizes]
-        else:
-            initial_fixed_pairings = [list(fm) for fm in self.fixed_pairings]
-
-        _single_fit_cfg = deepcopy(self.single_fit_default_cfg)
-        if single_fit_cfg is not None:
-            _single_fit_cfg.update(single_fit_cfg)
-        single_fit_cfg = _single_fit_cfg
+        ########## Preparations ##########
+
-        # Initialize DiffPaSSResults object
-        results = self._init_results(
-            record_log_alphas=single_fit_cfg["record_log_alphas"],
-            record_soft_perms=single_fit_cfg["record_soft_perms"],
-            record_soft_losses=single_fit_cfg["record_soft_losses"],
-        )
-        available_fields = [
-            field.name
-            for field in fields(results)
-            if getattr(results, field.name) is not None
-        ]
-        field_to_length_so_far = {field_name: 0 for field_name in available_fields}
+        # Input validation
+        self.prepare_fit(x, y)
+
         # Prepare variables for indexing
         n_samples = len(x)
         n_groups = len(self.group_sizes)
         cumsum_group_sizes = np.cumsum([0] + list(self.group_sizes))
         offsets = np.repeat(cumsum_group_sizes[:-1], repeats=self.group_sizes)
         group_idxs = np.repeat(np.arange(n_groups), repeats=self.group_sizes)

-        # First fit with initial fixed matchings
-        can_optimize = self._fit(x, y, results=results, **single_fit_cfg)
+        # Initially fixed pairings as derived from the `fixed_pairings` attribute
+        if self.fixed_pairings is None:
+            initially_fixed_pairings = [[] for _ in self.group_sizes]
+        else:
+            initially_fixed_pairings = [list(fm) for fm in self.fixed_pairings]

-        # Find effective initial fixed matchings
-        effective_initial_fixed_idxs = []
+        # *Effective* initially fixed pairings as global indices (not relative to group)
+        # Used to exclude these pairs from the random sampling of new fixed pairings
+        # and to determine when the bootstrap will end
+        effective_initially_fixed_idxs = []
         for s, efmz in zip(
             cumsum_group_sizes, self.permutation._effective_fixed_pairings_zip
         ):
             if efmz:
-                effective_initial_fixed_idxs += [
-                    (s + efmz_fixed) for efmz_fixed in efmz[1]
+                effective_initially_fixed_idxs += [
+                    s + efmz_fixed for efmz_fixed in efmz[1]
                 ]
-        effective_initial_fixed_idxs = np.asarray(effective_initial_fixed_idxs)
-        nonfixed_idxs = np.setdiff1d(np.arange(n_samples), effective_initial_fixed_idxs)
-        n_effective_initial_fixed_pairings = len(effective_initial_fixed_idxs)
-
+        non_initially_fixed_idxs = np.setdiff1d(
+            np.arange(n_samples), effective_initially_fixed_idxs
+        )
         if n_end is None:
-            n_end = n_samples - n_effective_initial_fixed_pairings - 1
-
-        # Subsequent fits: at a given iteration we use fixed matchings chosen uniformly at
-        # random from the results of the previous iteration (excluding the effective initial
-        # fixed matchings)
-        pbar = list(range(n_start, n_end, step_size))
+            n_end = n_samples - len(effective_initially_fixed_idxs) - 1
+        # Bootstrap range and progress bar
+        pbar = range(n_start, n_end, step_size)
         pbar = tqdm(pbar) if show_pbar else pbar
-        n_iters_with_optimization = int(can_optimize)
-        for N in pbar:
-            latest_hard_perms = results.hard_perms[-1]
-            mapped_idxs = offsets + np.concatenate(latest_hard_perms)
-            rand_fixed_idxs = np.random.permutation(nonfixed_idxs)[:N]
+
+        ########## End preparations ##########
+
+        ########## Closures ##########
+
+        def make_new_fixed_pairings(
+            mapped_idxs: np.ndarray, N: int
+        ) -> IndexPairsInGroups:
+            """Subroutine for randomly sampling new fixed pairings for the next bootstrap iteration."""
+            rand_fixed_idxs = np.random.permutation(non_initially_fixed_idxs)[:N]
             rand_fixed_idxs = np.sort(rand_fixed_idxs)
             rand_mapped_idxs = mapped_idxs[rand_fixed_idxs]
             rand_group_idxs = group_idxs[rand_fixed_idxs]
             rand_fixed_rel_idxs = rand_fixed_idxs - offsets[rand_fixed_idxs]
             rand_mapped_rel_idxs = rand_mapped_idxs - offsets[rand_mapped_idxs]

-            # Update fixed matchings
+            # Update fixed pairings
             fixed_pairings = [[] for _ in range(n_groups)]
             for rand_group_idx, mapped_rel_idx, fixed_rel_idx in zip(
                 rand_group_idxs, rand_mapped_rel_idxs, rand_fixed_rel_idxs
             ):
                 pair = (mapped_rel_idx, fixed_rel_idx)
                 fixed_pairings[rand_group_idx].append(pair)
             fixed_pairings = [
-                initial_fixed_pairings[k] + fixed_pairings[k] for k in range(n_groups)
+                initially_fixed_pairings[k] + fixed_pairings[k] for k in range(n_groups)
             ]
-            self.permutation.init_fixed_pairings_and_log_alphas(
-                fixed_pairings, device=x.device
+
+            return fixed_pairings
+
+        def init_diffpassresults() -> DiffPaSSResults:
+            return self._init_results(
+                record_log_alphas=single_fit_cfg["record_log_alphas"],
+                record_soft_perms=single_fit_cfg["record_soft_perms"],
+                record_soft_losses=single_fit_cfg["record_soft_losses"],
             )
+
+        def extend_results_with_lowest_loss_repeat(
+            results_this_iter: DiffPaSSResults,
+            results: DiffPaSSResults,
+            can_optimize: bool,
+        ) -> None:
+            """Extend the global optimization object `results` with the portion of
+            `results_this_iter` (from the latest bootstrap iteration) corresponding to
+            the repeat with the lowest hard loss."""
+            if can_optimize:
+                # Select run with lowest hard loss, discard the rest
+                reshaped_hard_losses_this_repeat = np.asarray(
+                    results_this_iter.hard_losses
+                ).reshape(n_repeats, -1)
+                min_loss_idx = np.argmin(reshaped_hard_losses_this_repeat[:, -1])
+                size_each_repeat = reshaped_hard_losses_this_repeat.shape[1]
+                # Record complete results of the run with the lowest loss
+                slice_to_append = slice(
+                    min_loss_idx * size_each_repeat,
+                    (min_loss_idx + 1) * size_each_repeat,
+                )
+            else:
+                slice_to_append = slice(None)
+            [
+                getattr(results, field_name).extend(
+                    getattr(results_this_iter, field_name)[slice_to_append]
+                )
+                for field_name in available_fields
+            ]
+
+        postprocess_results_after_repeats = (
+            extend_results_with_lowest_loss_repeat
+            if n_repeats > 1
+            else lambda *args: None
+        )
+
+        ########## End closures ##########
+
+        # Configuration for each gradient descent run
+        _single_fit_cfg = deepcopy(self.single_fit_default_cfg)
+        if single_fit_cfg is not None:
+            _single_fit_cfg.update(single_fit_cfg)
+        single_fit_cfg = _single_fit_cfg
+
+        # Initialize DiffPaSSResults object
+        results = init_diffpassresults()
+        available_fields = [
+            field.name
+            for field in fields(results)
+            if getattr(results, field.name) is not None
+        ]
+        field_to_length_so_far = {field_name: 0 for field_name in available_fields}
+
+        ########## Optimization ##########
+
+        # First fit with initially fixed pairings
+        can_optimize = self._fit(x, y, results=results, **single_fit_cfg)
+        n_iters_with_optimization = int(can_optimize)
+
+        # DiffPaSSResults object for each bootstrap iteration:
+        # new object if `n_repeats` > 1, else the existing `results`
+        get_results_to_use_in_each_bootstrap_iter = (
+            init_diffpassresults if n_repeats > 1 else lambda: results
+        )
+
+        # Subsequent bootstrap fits: at a given iteration we use fixed pairings chosen uniformly at
+        # random from the results of the previous iteration (excluding the effective initially
+        # fixed pairings)
+        for N in pbar:
+            latest_hard_perms = results.hard_perms[-1]
+            mapped_idxs = offsets + np.concatenate(latest_hard_perms)
+
             field_to_length_so_far = {
                 field_name: len(getattr(results, field_name))
                 for field_name in available_fields
             }
-            can_optimize = self._fit(x, y, results=results, **single_fit_cfg)

+            results_this_iter = (
+                get_results_to_use_in_each_bootstrap_iter()
+            ) # `results` alias if `n_repeats` == 1
+            for _ in range(n_repeats):
+                # Randomly sample N fixed pairings
+                fixed_pairings = make_new_fixed_pairings(mapped_idxs, N)
+                # Reinitialize permutation module with new fixed pairings
+                self.permutation.init_fixed_pairings_and_log_alphas(
+                    fixed_pairings, device=x.device
                )
+                # Fit with gradient descent
+                can_optimize = self._fit(
+                    x, y, results=results_this_iter, **single_fit_cfg
+                )
+                if not can_optimize:
+                    # If we can't fit, we break the "repeats" loop
+                    break
+
+            postprocess_results_after_repeats(
+                results_this_iter, results, can_optimize
+            ) # Does nothing if `n_repeats` == 1
+
             if can_optimize:
                 n_iters_with_optimization += 1
             else:
+                # If we could not fit, terminate the bootstrap
                 break

+        ########## End optimization ##########
+
+        ########## Post-processing ##########
+
         # Reshape results according to number of iterations performed
         reshaped_fields = {}
         for field_name in available_fields:
@@ -644,4 +737,6 @@ def fit_bootstrap(
             )
         results = replace(results, **reshaped_fields)

+        ########## End post-processing ##########
+
         return results
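
The heart of the new bookkeeping is the `reshape`/`argmin`/`slice` sequence in `extend_results_with_lowest_loss_repeat` above: hard losses from all repeats of one iteration are recorded back to back in a flat list, reshaped to `(n_repeats, -1)`, and the repeat with the smallest final loss contributes its whole slice of records to the global results. A self-contained rerun of that indexing, with made-up loss values (three repeats of four recorded losses each):

import numpy as np

n_repeats = 3
# Flat per-step hard losses from one bootstrap iteration, repeats appended back to back
hard_losses = [
    0.9, 0.7, 0.6, 0.55,   # repeat 0
    0.9, 0.8, 0.5, 0.40,   # repeat 1: lowest final loss
    0.9, 0.6, 0.58, 0.50,  # repeat 2
]

reshaped = np.asarray(hard_losses).reshape(n_repeats, -1)  # shape (3, 4)
min_loss_idx = np.argmin(reshaped[:, -1])  # compare final losses of each repeat -> 1
size_each_repeat = reshaped.shape[1]       # 4 recorded losses per repeat

# Slice of the flat records belonging to the winning repeat
slice_to_append = slice(
    min_loss_idx * size_each_repeat, (min_loss_idx + 1) * size_each_repeat
)
print(hard_losses[slice_to_append])  # [0.9, 0.8, 0.5, 0.4]

Because every recorded field in `DiffPaSSResults` grows by the same number of entries per run, the same slice can be reused for all of them, which is exactly what the comprehension over `available_fields` does.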
2 changes: 1 addition & 1 deletion diffpass/model.py
@@ -194,7 +194,7 @@ def hard_(self) -> None:
         self.mode = "hard"

     def _impl_fixed_pairings(self, func: callable) -> callable:
-        """Include fixed matchings in the Gumbel-Sinkhorn or Gumbel-matching operators."""
+        """Include fixed pairings in the Gumbel-Sinkhorn or Gumbel-matching operators."""

         def wrapper(gen: Iterator[torch.Tensor]) -> Iterator[torch.Tensor]:
             for s, mat, (row_group, col_group), mask in zip(
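
For readers new to the operators named in this docstring: a Gumbel-Sinkhorn operator adds Gumbel noise to a matrix of log-scores and then alternately normalizes rows and columns in log space, producing an approximately doubly stochastic matrix that relaxes a hard permutation. The sketch below is the generic recipe from Mena et al. (2018), not the diffpass implementation; in diffpass, the wrapper above further constrains the operators so that fixed pairings stay fixed.

import torch

def gumbel_sinkhorn(log_alphas: torch.Tensor, tau: float = 1.0, n_iter: int = 20) -> torch.Tensor:
    # Generic recipe, for illustration only; not the diffpass implementation
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(log_alphas)))
    log_p = (log_alphas + gumbel_noise) / tau  # smaller tau -> closer to a hard permutation
    for _ in range(n_iter):  # Sinkhorn iterations in log space
        log_p = log_p - torch.logsumexp(log_p, dim=-1, keepdim=True)  # normalize rows
        log_p = log_p - torch.logsumexp(log_p, dim=-2, keepdim=True)  # normalize columns
    return log_p.exp()  # approximately doubly stochastic "soft permutation"

soft_perm = gumbel_sinkhorn(torch.randn(5, 5), tau=0.5)
print(soft_perm.sum(dim=-1))  # every row sums to approximately 1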