Skip to content

Commit a82f9e6

Browse files
committed
Optionally enforce exact distances
1 parent 47b585c commit a82f9e6

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

umap/umap_.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,7 @@ def __init__(
16111611
transform_seed=42,
16121612
transform_mode="embedding",
16131613
force_approximation_algorithm=False,
1614+
force_exact_distances=False,
16141615
verbose=False,
16151616
unique=False,
16161617
densmap=False,
@@ -1648,6 +1649,7 @@ def __init__(
16481649
self.transform_seed = transform_seed
16491650
self.transform_mode = transform_mode
16501651
self.force_approximation_algorithm = force_approximation_algorithm
1652+
self.force_exact_distances = force_exact_distances
16511653
self.verbose = verbose
16521654
self.unique = unique
16531655

@@ -1842,6 +1844,9 @@ def _dist_only(x, y, *kwds):
18421844
if self.n_jobs < -1 or self.n_jobs == 0:
18431845
raise ValueError("n_jobs must be a postive integer, or -1 (for all cores)")
18441846

1847+
if self.force_approximation_algorithm and self.force_exact_distances:
1848+
raise ValueError("enforcing both exact distances and an approximation contradict each other")
1849+
18451850
if self.dens_lambda < 0.0:
18461851
raise ValueError("dens_lambda cannot be negative")
18471852
if self.dens_frac < 0.0 or self.dens_frac > 1.0:
@@ -1930,6 +1935,9 @@ def _populate_combined_params(self, *models):
19301935
self.force_approximation_algorithm = flattened(
19311936
[m.force_approximation_algorithm for m in models]
19321937
)
1938+
self.force_exact_distances = flattened(
1939+
[m.force_exact_distances for m in models]
1940+
)
19331941
self.verbose = flattened([m.verbose for m in models])
19341942
self.unique = flattened([m.unique for m in models])
19351943

@@ -2332,7 +2340,8 @@ def fit(self, X, y=None):
23322340
verbose=self.verbose,
23332341
)
23342342
# Compute all pairwise distances exactly: either because the data is small
# enough for that to be efficient, or because exact distances were requested
2335-
elif X[index].shape[0] < 4096 and not self.force_approximation_algorithm:
2343+
elif self.force_exact_distances or (
2344+
X[index].shape[0] < 4096 and not self.force_approximation_algorithm):
23362345
self._small_data = True
23372346
try:
23382347
# sklearn pairwise_distances fails for callable metric on sparse data

0 commit comments

Comments
 (0)