Skip to content

Commit 1e434ae

Browse files
HenryChen4btjanaka
andauthored
Add new centroid generation techniques (#417)
## Description Added new centroid generation techniques in _cvt_archive.py and benchmarked these in benchmark.py. These techniques were studied in Mouret 2023: https://dl.acm.org/doi/10.1145/3583133.3590726 Notably, this PR bumps scipy to 1.7.0 since that is when scipy.stats.qmc was first introduced, but this should not be an issue for most users since scipy 1.7.0 supports Python 3.7+. ## TODO <!-- Notable points that this PR has either accomplished or will accomplish. --> ## Questions <!-- Any concerns or points of confusion? --> ## Status - [ ] I have read the guidelines in [CONTRIBUTING.md](https://github.com/icaros-usc/pyribs/blob/master/CONTRIBUTING.md) - [ ] I have formatted my code using `yapf` - [ ] I have tested my code by running `pytest` - [ ] I have linted my code with `pylint` - [ ] I have added a one-line description of my change to the changelog in `HISTORY.md` - [ ] This PR is ready to go --------- Co-authored-by: Bryon Tjanaka <[email protected]>
1 parent e0ec6db commit 1e434ae

File tree

4 files changed

+97
-68
lines changed

4 files changed

+97
-68
lines changed

benchmarks/benchmark.py renamed to benchmarks/centroid_quality.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,31 @@ def main():
6262
techniques used in the aforementioned paper.
6363
"""
6464

65-
score_seed = 1
66-
num_samples = 10000
67-
archive = CVTArchive(
68-
solution_dim=20,
69-
cells=512,
70-
ranges=[(0., 1.), (0., 1.)],
71-
)
72-
cvt_centroids = archive.centroids
73-
print(
74-
"Score for CVT generation: ",
75-
get_score(centroids=cvt_centroids,
76-
num_samples=num_samples,
77-
seed=score_seed))
78-
79-
centroid_gen_seed = 100
80-
num_centroids = 1024
81-
dim = 2
82-
rng = np.random.default_rng(seed=centroid_gen_seed)
83-
random_centroids = rng.random((num_centroids, dim))
84-
print(
85-
"Score for random generation: ",
86-
get_score(centroids=random_centroids,
87-
num_samples=num_samples,
88-
seed=score_seed))
65+
# Default settings to benchmark different centroid generation techniques.
66+
score_seed = 1823170571
67+
num_samples = 100000
68+
69+
# Settings for creating the CVTArchive.
70+
solution_dim = 20
71+
cells = 512
72+
ranges = [(0., 1.), (0., 1.)]
73+
74+
# Different methods for generating centroids.
75+
generation_methods = [
76+
"kmeans", "random", "sobol", "scrambled_sobol", "halton"
77+
]
78+
79+
# Benchmark each centroid generation technique.
80+
for method in generation_methods:
81+
archive = CVTArchive(solution_dim=solution_dim,
82+
cells=cells,
83+
ranges=ranges,
84+
centroid_method=method)
85+
print(
86+
f"Score for {method} generation: ",
87+
get_score(centroids=archive.centroids,
88+
num_samples=num_samples,
89+
seed=score_seed))
8990

9091

9192
if __name__ == "__main__":

pinned_reqs/install.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ numba==0.51.0
33
pandas==1.0.0
44
sortedcontainers==2.0.0
55
scikit-learn==1.1.0
6-
scipy==1.4.0
6+
scipy==1.7.0
77
threadpoolctl==3.0.0

ribs/archives/_cvt_archive.py

Lines changed: 70 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Contains the CVTArchive class."""
2+
import numbers
3+
24
import numpy as np
35
from scipy.spatial import cKDTree # pylint: disable=no-name-in-module
6+
from scipy.stats.qmc import Halton, Sobol
47
from sklearn.cluster import k_means
58

69
from ribs._utils import check_batch_shape, check_finite
@@ -90,17 +93,22 @@ class CVTArchive(ArchiveBase):
9093
and a "bar" field that contains 10D values. Note that field names
9194
must be valid Python identifiers, and names already used in the
9295
archive are not allowed.
96+
custom_centroids (array-like): If passed in, this (cells, measure_dim)
97+
array will be used as the centroids of the CVT instead of generating
98+
new ones. In this case, ``samples`` will be ignored, and
99+
``archive.samples`` will be None. This can be useful when one wishes
100+
to use the same CVT across experiments for fair comparison.
101+
centroid_method (str): Pass in the following methods for
102+
generating centroids: "random", "sobol", "scrambled sobol",
103+
"halton". Default method is "kmeans". These methods are derived from
104+
Mouret 2023: https://dl.acm.org/doi/pdf/10.1145/3583133.3590726.
105+
Note: Samples are only used when method is "kmeans".
93106
samples (int or array-like): If it is an int, this specifies the number
94107
of samples to generate when creating the CVT. Otherwise, this must
95108
be a (num_samples, measure_dim) array where samples[i] is a sample
96109
to use when creating the CVT. It can be useful to pass in custom
97110
samples when there are restrictions on what samples in the measure
98111
space are (physically) possible.
99-
custom_centroids (array-like): If passed in, this (cells, measure_dim)
100-
array will be used as the centroids of the CVT instead of generating
101-
new ones. In this case, ``samples`` will be ignored, and
102-
``archive.samples`` will be None. This can be useful when one wishes
103-
to use the same CVT across experiments for fair comparison.
104112
k_means_kwargs (dict): kwargs for :func:`~sklearn.cluster.k_means`. By
105113
default, we pass in `n_init=1`, `init="random"`,
106114
`algorithm="lloyd"`, and `random_state=seed`.
@@ -128,12 +136,13 @@ def __init__(self,
128136
seed=None,
129137
dtype=np.float64,
130138
extra_fields=None,
131-
samples=100_000,
132139
custom_centroids=None,
133-
chunk_size=None,
140+
centroid_method="kmeans",
141+
samples=100_000,
134142
k_means_kwargs=None,
135143
use_kd_tree=True,
136-
ckdtree_kwargs=None):
144+
ckdtree_kwargs=None,
145+
chunk_size=None):
137146

138147
ArchiveBase.__init__(
139148
self,
@@ -167,23 +176,55 @@ def __init__(self,
167176
self._k_means_kwargs.setdefault("algorithm", "lloyd")
168177
self._k_means_kwargs.setdefault("random_state", seed)
169178

170-
self._use_kd_tree = use_kd_tree
171-
self._centroid_kd_tree = None
172-
self._ckdtree_kwargs = ({} if ckdtree_kwargs is None else
173-
ckdtree_kwargs.copy())
174-
self._chunk_size = chunk_size
175-
176179
if custom_centroids is None:
177-
if not isinstance(samples, int):
178-
# Validate shape of custom samples. These are ignored when
179-
# `custom_centroids` is provided.
180-
samples = np.asarray(samples, dtype=self.dtype)
181-
if samples.shape[1] != self._measure_dim:
182-
raise ValueError(
183-
f"Samples has shape {samples.shape} but must be of "
184-
f"shape (n_samples, len(ranges)={self._measure_dim})")
185-
self._samples = samples
186-
self._centroids = None
180+
self._samples = None
181+
if centroid_method == "kmeans":
182+
if not isinstance(samples, numbers.Integral):
183+
# Validate shape of custom samples.
184+
samples = np.asarray(samples, dtype=self.dtype)
185+
if samples.shape[1] != self._measure_dim:
186+
raise ValueError(
187+
f"Samples has shape {samples.shape} but must be of "
188+
f"shape (n_samples, len(ranges)="
189+
f"{self._measure_dim})")
190+
self._samples = samples
191+
else:
192+
self._samples = self._rng.uniform(
193+
self._lower_bounds,
194+
self._upper_bounds,
195+
size=(samples, self._measure_dim),
196+
).astype(self.dtype)
197+
198+
self._centroids = k_means(self._samples, self._cells,
199+
**self._k_means_kwargs)[0]
200+
201+
if self._centroids.shape[0] < self._cells:
202+
raise RuntimeError(
203+
"While generating the CVT, k-means clustering found "
204+
f"{self._centroids.shape[0]} centroids, but this "
205+
f"archive needs {self._cells} cells. This most "
206+
"likely happened because there are too few samples "
207+
"and/or too many cells.")
208+
elif centroid_method == "random":
209+
# Generate random centroids for the archive.
210+
self._centroids = self._rng.uniform(self._lower_bounds,
211+
self._upper_bounds,
212+
size=(self._cells,
213+
self._measure_dim))
214+
elif centroid_method == "sobol":
215+
# Generate self._cells number of centroids as a Sobol sequence.
216+
sampler = Sobol(d=self._measure_dim, scramble=False)
217+
num_points = np.log2(self._cells).astype(int)
218+
self._centroids = sampler.random_base2(num_points)
219+
elif centroid_method == "scrambled_sobol":
220+
# Generates centroids as a scrambled Sobol sequence.
221+
sampler = Sobol(d=self._measure_dim, scramble=True)
222+
num_points = np.log2(self._cells).astype(int)
223+
self._centroids = sampler.random_base2(num_points)
224+
elif centroid_method == "halton":
225+
# Generates centroids using a Halton sequence.
226+
sampler = Halton(d=self._measure_dim)
227+
self._centroids = sampler.random(n=self._cells)
187228
else:
188229
# Validate shape of `custom_centroids` when they are provided.
189230
custom_centroids = np.asarray(custom_centroids, dtype=self.dtype)
@@ -195,24 +236,11 @@ def __init__(self,
195236
self._centroids = custom_centroids
196237
self._samples = None
197238

198-
if self._centroids is None:
199-
self._samples = self._rng.uniform(
200-
self._lower_bounds,
201-
self._upper_bounds,
202-
size=(self._samples, self._measure_dim),
203-
).astype(self.dtype) if isinstance(self._samples,
204-
int) else self._samples
205-
206-
self._centroids = k_means(self._samples, self._cells,
207-
**self._k_means_kwargs)[0]
208-
209-
if self._centroids.shape[0] < self._cells:
210-
raise RuntimeError(
211-
"While generating the CVT, k-means clustering found "
212-
f"{self._centroids.shape[0]} centroids, but this archive "
213-
f"needs {self._cells} cells. This most likely happened "
214-
"because there are too few samples and/or too many cells.")
215-
239+
self._use_kd_tree = use_kd_tree
240+
self._centroid_kd_tree = None
241+
self._ckdtree_kwargs = ({} if ckdtree_kwargs is None else
242+
ckdtree_kwargs.copy())
243+
self._chunk_size = chunk_size
216244
if self._use_kd_tree:
217245
self._centroid_kd_tree = cKDTree(self._centroids,
218246
**self._ckdtree_kwargs)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"pandas>=1.0.0",
2020
"sortedcontainers>=2.0.0", # Primarily used in SlidingBoundariesArchive.
2121
"scikit-learn>=1.1.0", # Primarily used in CVTArchive.
22-
"scipy>=1.4.0", # Primarily used in CVTArchive.
22+
"scipy>=1.7.0", # Primarily used in CVTArchive.
2323
"threadpoolctl>=3.0.0",
2424
]
2525

0 commit comments

Comments
 (0)