1
1
"""Contains the CVTArchive class."""
2
+ import numbers
3
+
2
4
import numpy as np
3
5
from scipy .spatial import cKDTree # pylint: disable=no-name-in-module
6
+ from scipy .stats .qmc import Halton , Sobol
4
7
from sklearn .cluster import k_means
5
8
6
9
from ribs ._utils import check_batch_shape , check_finite
@@ -90,17 +93,22 @@ class CVTArchive(ArchiveBase):
90
93
and a "bar" field that contains 10D values. Note that field names
91
94
must be valid Python identifiers, and names already used in the
92
95
archive are not allowed.
96
+ custom_centroids (array-like): If passed in, this (cells, measure_dim)
97
+ array will be used as the centroids of the CVT instead of generating
98
+ new ones. In this case, ``samples`` will be ignored, and
99
+ ``archive.samples`` will be None. This can be useful when one wishes
100
+ to use the same CVT across experiments for fair comparison.
101
+ centroid_method (str): Pass in the following methods for
102
+ generating centroids: "random", "sobol", "scrambled_sobol",
103
+ "halton". Default method is "kmeans". These methods are derived from
104
+ Mouret 2023: https://dl.acm.org/doi/pdf/10.1145/3583133.3590726.
105
+ Note: ``samples`` is only used when ``centroid_method`` is "kmeans".
93
106
samples (int or array-like): If it is an int, this specifies the number
94
107
of samples to generate when creating the CVT. Otherwise, this must
95
108
be a (num_samples, measure_dim) array where samples[i] is a sample
96
109
to use when creating the CVT. It can be useful to pass in custom
97
110
samples when there are restrictions on what samples in the measure
98
111
space are (physically) possible.
99
- custom_centroids (array-like): If passed in, this (cells, measure_dim)
100
- array will be used as the centroids of the CVT instead of generating
101
- new ones. In this case, ``samples`` will be ignored, and
102
- ``archive.samples`` will be None. This can be useful when one wishes
103
- to use the same CVT across experiments for fair comparison.
104
112
k_means_kwargs (dict): kwargs for :func:`~sklearn.cluster.k_means`. By
105
113
default, we pass in `n_init=1`, `init="random"`,
106
114
`algorithm="lloyd"`, and `random_state=seed`.
@@ -128,12 +136,13 @@ def __init__(self,
128
136
seed = None ,
129
137
dtype = np .float64 ,
130
138
extra_fields = None ,
131
- samples = 100_000 ,
132
139
custom_centroids = None ,
133
- chunk_size = None ,
140
+ centroid_method = "kmeans" ,
141
+ samples = 100_000 ,
134
142
k_means_kwargs = None ,
135
143
use_kd_tree = True ,
136
- ckdtree_kwargs = None ):
144
+ ckdtree_kwargs = None ,
145
+ chunk_size = None ):
137
146
138
147
ArchiveBase .__init__ (
139
148
self ,
@@ -167,23 +176,55 @@ def __init__(self,
167
176
self ._k_means_kwargs .setdefault ("algorithm" , "lloyd" )
168
177
self ._k_means_kwargs .setdefault ("random_state" , seed )
169
178
170
- self ._use_kd_tree = use_kd_tree
171
- self ._centroid_kd_tree = None
172
- self ._ckdtree_kwargs = ({} if ckdtree_kwargs is None else
173
- ckdtree_kwargs .copy ())
174
- self ._chunk_size = chunk_size
175
-
176
179
if custom_centroids is None :
177
- if not isinstance (samples , int ):
178
- # Validate shape of custom samples. These are ignored when
179
- # `custom_centroids` is provided.
180
- samples = np .asarray (samples , dtype = self .dtype )
181
- if samples .shape [1 ] != self ._measure_dim :
182
- raise ValueError (
183
- f"Samples has shape { samples .shape } but must be of "
184
- f"shape (n_samples, len(ranges)={ self ._measure_dim } )" )
185
- self ._samples = samples
186
- self ._centroids = None
180
+ self ._samples = None
181
+ if centroid_method == "kmeans" :
182
+ if not isinstance (samples , numbers .Integral ):
183
+ # Validate shape of custom samples.
184
+ samples = np .asarray (samples , dtype = self .dtype )
185
+ if samples .shape [1 ] != self ._measure_dim :
186
+ raise ValueError (
187
+ f"Samples has shape { samples .shape } but must be of "
188
+ f"shape (n_samples, len(ranges)="
189
+ f"{ self ._measure_dim } )" )
190
+ self ._samples = samples
191
+ else :
192
+ self ._samples = self ._rng .uniform (
193
+ self ._lower_bounds ,
194
+ self ._upper_bounds ,
195
+ size = (samples , self ._measure_dim ),
196
+ ).astype (self .dtype )
197
+
198
+ self ._centroids = k_means (self ._samples , self ._cells ,
199
+ ** self ._k_means_kwargs )[0 ]
200
+
201
+ if self ._centroids .shape [0 ] < self ._cells :
202
+ raise RuntimeError (
203
+ "While generating the CVT, k-means clustering found "
204
+ f"{ self ._centroids .shape [0 ]} centroids, but this "
205
+ f"archive needs { self ._cells } cells. This most "
206
+ "likely happened because there are too few samples "
207
+ "and/or too many cells." )
208
+ elif centroid_method == "random" :
209
+ # Generate random centroids for the archive.
210
+ self ._centroids = self ._rng .uniform (self ._lower_bounds ,
211
+ self ._upper_bounds ,
212
+ size = (self ._cells ,
213
+ self ._measure_dim ))
214
+ elif centroid_method == "sobol" :
215
+ # Generate self._cells number of centroids as a Sobol sequence.
216
+ sampler = Sobol (d = self ._measure_dim , scramble = False )
217
+ num_points = np .log2 (self ._cells ).astype (int )
218
+ self ._centroids = sampler .random_base2 (num_points )
219
+ elif centroid_method == "scrambled_sobol" :
220
+ # Generates centroids as a scrambled Sobol sequence.
221
+ sampler = Sobol (d = self ._measure_dim , scramble = True )
222
+ num_points = np .log2 (self ._cells ).astype (int )
223
+ self ._centroids = sampler .random_base2 (num_points )
224
+ elif centroid_method == "halton" :
225
+ # Generates centroids using a Halton sequence.
226
+ sampler = Halton (d = self ._measure_dim )
227
+ self ._centroids = sampler .random (n = self ._cells )
187
228
else :
188
229
# Validate shape of `custom_centroids` when they are provided.
189
230
custom_centroids = np .asarray (custom_centroids , dtype = self .dtype )
@@ -195,24 +236,11 @@ def __init__(self,
195
236
self ._centroids = custom_centroids
196
237
self ._samples = None
197
238
198
- if self ._centroids is None :
199
- self ._samples = self ._rng .uniform (
200
- self ._lower_bounds ,
201
- self ._upper_bounds ,
202
- size = (self ._samples , self ._measure_dim ),
203
- ).astype (self .dtype ) if isinstance (self ._samples ,
204
- int ) else self ._samples
205
-
206
- self ._centroids = k_means (self ._samples , self ._cells ,
207
- ** self ._k_means_kwargs )[0 ]
208
-
209
- if self ._centroids .shape [0 ] < self ._cells :
210
- raise RuntimeError (
211
- "While generating the CVT, k-means clustering found "
212
- f"{ self ._centroids .shape [0 ]} centroids, but this archive "
213
- f"needs { self ._cells } cells. This most likely happened "
214
- "because there are too few samples and/or too many cells." )
215
-
239
+ self ._use_kd_tree = use_kd_tree
240
+ self ._centroid_kd_tree = None
241
+ self ._ckdtree_kwargs = ({} if ckdtree_kwargs is None else
242
+ ckdtree_kwargs .copy ())
243
+ self ._chunk_size = chunk_size
216
244
if self ._use_kd_tree :
217
245
self ._centroid_kd_tree = cKDTree (self ._centroids ,
218
246
** self ._ckdtree_kwargs )
0 commit comments