Skip to content

Commit 1bd82af

Browse files
Merge pull request #213 from andrewheusser/remove-hdbscan-dep
Remove hdbscan dependency
2 parents 671ac89 + 6a461eb commit 1bd82af

File tree

6 files changed

+71
-69
lines changed

6 files changed

+71
-69
lines changed

examples/animate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
geo = hyp.load('weights_avg')
1919

2020
# plot
21-
geo.plot(animate=True)
21+
geo.plot(animate=True, legend=['first', 'second'])

hypertools/tools/cluster.py

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
#!/usr/bin/env python
22
import warnings
3-
from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering
4-
import numpy as np
53
import six
6-
from hdbscan import HDBSCAN
4+
import numpy as np
5+
from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering
76
from .._shared.helpers import *
87
from .format_data import format_data as formatter
98

9+
# dictionary of models
10+
models = {
11+
'KMeans': KMeans,
12+
'MiniBatchKMeans': MiniBatchKMeans,
13+
'AgglomerativeClustering': AgglomerativeClustering,
14+
'FeatureAgglomeration': FeatureAgglomeration,
15+
'Birch': Birch,
16+
'SpectralClustering': SpectralClustering,
17+
}
18+
19+
try:
20+
from hdbscan import HDBSCAN
21+
_has_hdbscan = True
22+
models.update({'HDBSCAN': HDBSCAN})
23+
except ImportError:
24+
_has_hdbscan = False
25+
1026

1127
@memoize
1228
def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
@@ -46,48 +62,39 @@ def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
4662
4763
"""
4864

49-
# if cluster is None, just return data
50-
if cluster is None:
65+
if cluster == None:
5166
return x
52-
else:
53-
54-
if ndims is not None:
55-
warnings.warn('The ndims argument is now deprecated. Ignoring dimensionality reduction step.')
56-
57-
if format_data:
58-
x = formatter(x, ppca=True)
59-
60-
# dictionary of models
61-
models = {
62-
'KMeans' : KMeans,
63-
'MiniBatchKMeans' : MiniBatchKMeans,
64-
'AgglomerativeClustering' : AgglomerativeClustering,
65-
'FeatureAgglomeration' : FeatureAgglomeration,
66-
'Birch' : Birch,
67-
'SpectralClustering' : SpectralClustering,
68-
'HDBSCAN' : HDBSCAN
69-
}
70-
71-
# if reduce is a string, find the corresponding model
72-
if isinstance(cluster, six.string_types):
73-
model = models[cluster]
74-
if cluster != 'HDBSCAN':
75-
model_params = {
76-
'n_clusters' : n_clusters
77-
}
78-
else:
79-
model_params = {}
80-
# if its a dict, use custom params
81-
elif type(cluster) is dict:
82-
if isinstance(cluster['model'], six.string_types):
83-
model = models[cluster['model']]
84-
model_params = cluster['params']
85-
86-
# initialize model
87-
model = model(**model_params)
88-
89-
# fit the model
90-
model.fit(np.vstack(x))
91-
92-
# return the labels
93-
return list(model.labels_)
67+
elif (isinstance(cluster, six.string_types) and cluster=='HDBSCAN') or \
68+
(isinstance(cluster, dict) and cluster['model']=='HDBSCAN'):
69+
if not _has_hdbscan:
70+
raise ImportError('HDBSCAN is not installed. Please install hdbscan>=0.8.11')
71+
72+
if ndims != None:
73+
warnings.warn('The ndims argument is now deprecated. Ignoring dimensionality reduction step.')
74+
75+
if format_data:
76+
x = formatter(x, ppca=True)
77+
78+
# if reduce is a string, find the corresponding model
79+
if isinstance(cluster, six.string_types):
80+
model = models[cluster]
81+
if cluster != 'HDBSCAN':
82+
model_params = {
83+
'n_clusters' : n_clusters
84+
}
85+
else:
86+
model_params = {}
87+
# if its a dict, use custom params
88+
elif type(cluster) is dict:
89+
if isinstance(cluster['model'], six.string_types):
90+
model = models[cluster['model']]
91+
model_params = cluster['params']
92+
93+
# initialize model
94+
model = model(**model_params)
95+
96+
# fit the model
97+
model.fit(np.vstack(x))
98+
99+
# return the labels
100+
return list(model.labels_)

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ matplotlib>=1.5.1
66
scipy>=1.0.0
77
numpy>=1.10.4
88
umap-learn>=0.1.5
9-
hdbscan>=0.8.11
109
future
1110
requests
1211
deepdish

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
setup(
2727
name='hypertools',
28-
version='0.5.0',
28+
version='0.5.1',
2929
description=DESCRIPTION,
3030
long_description=LONG_DESCRIPTION,
3131
author='Contextual Dynamics Lab',
@@ -42,7 +42,6 @@
4242
'matplotlib>=1.5.1',
4343
'scipy>=1.0.0',
4444
'numpy>=1.10.4',
45-
'hdbscan>=0.8.11',
4645
'umap-learn>=0.1.5',
4746
'future',
4847
'requests',

tests/test_cluster.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# -*- coding: utf-8 -*-
22

33
import numpy as np
4+
import pytest
45
from hypertools.tools.cluster import cluster
56
from hypertools.plot.plot import plot
67

78
cluster1 = np.random.multivariate_normal(np.zeros(3), np.eye(3), size=100)
89
cluster2 = np.random.multivariate_normal(np.zeros(3)+100, np.eye(3), size=100)
9-
data = np.vstack([cluster1,cluster2])
10-
labels = cluster(data,n_clusters=2)
10+
data = np.vstack([cluster1, cluster2])
11+
labels = cluster(data, n_clusters=2)
1112

1213

1314
def test_cluster_n_clusters():
@@ -19,12 +20,15 @@ def test_cluster_returns_list():
1920

2021

2122
def test_cluster_hdbscan():
22-
# Given well separated clusters this should "just work"
23-
hdbscan_labels = cluster(data, cluster='HDBSCAN')
24-
assert len(set(hdbscan_labels)) == 2
25-
26-
27-
def text_cluster_geo():
28-
geo = plot(data, show=False)
29-
hdbscan_labels = cluster(geo, cluster='HDBSCAN')
30-
assert len(set(hdbscan_labels)) == 2
23+
try:
24+
from hdbscan import HDBSCAN
25+
_has_hdbscan = True
26+
except:
27+
_has_hdbscan = False
28+
29+
if _has_hdbscan:
30+
hdbscan_labels = cluster(data, cluster='HDBSCAN')
31+
assert len(set(hdbscan_labels)) == 2
32+
else:
33+
with pytest.raises(ImportError):
34+
hdbscan_labels = cluster(data, cluster='HDBSCAN')

tests/test_plot.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,6 @@ def test_plot_cluster_n_clusters():
9999
geo = plot.plot(weights, n_clusters=3, show=False)
100100
assert isinstance(geo, DataGeometry)
101101

102-
103-
def test_plot_cluster_HDBSCAN():
104-
# should return 10d data since ndims=10
105-
geo = plot.plot(weights, cluster='HDBSCAN', show=False)
106-
assert isinstance(geo, DataGeometry)
107-
108-
109102
def test_plot_nd():
110103
geo = plot.plot(data, show=False)
111104
assert all([i.shape[1]==d.shape[1] for i, d in zip(geo.data, data)])

0 commit comments

Comments
 (0)