Skip to content

Commit

Permalink
Merge pull request #213 from andrewheusser/remove-hdbscan-dep
Browse files Browse the repository at this point in the history
Remove hdbscan dependency
  • Loading branch information
jeremymanning authored Aug 2, 2018
2 parents 671ac89 + 6a461eb commit 1bd82af
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 69 deletions.
2 changes: 1 addition & 1 deletion examples/animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
geo = hyp.load('weights_avg')

# plot
geo.plot(animate=True)
geo.plot(animate=True, legend=['first', 'second'])
101 changes: 54 additions & 47 deletions hypertools/tools/cluster.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
#!/usr/bin/env python
import warnings
from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering
import numpy as np
import six
from hdbscan import HDBSCAN
import numpy as np
from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering
from .._shared.helpers import *
from .format_data import format_data as formatter

# dictionary of models
models = {
'KMeans': KMeans,
'MiniBatchKMeans': MiniBatchKMeans,
'AgglomerativeClustering': AgglomerativeClustering,
'FeatureAgglomeration': FeatureAgglomeration,
'Birch': Birch,
'SpectralClustering': SpectralClustering,
}

try:
from hdbscan import HDBSCAN
_has_hdbscan = True
models.update({'HDBSCAN': HDBSCAN})
except ImportError:
_has_hdbscan = False


@memoize
def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
Expand Down Expand Up @@ -46,48 +62,39 @@ def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
"""

# if cluster is None, just return data
if cluster is None:
if cluster == None:
return x
else:

if ndims is not None:
warnings.warn('The ndims argument is now deprecated. Ignoring dimensionality reduction step.')

if format_data:
x = formatter(x, ppca=True)

# dictionary of models
models = {
'KMeans' : KMeans,
'MiniBatchKMeans' : MiniBatchKMeans,
'AgglomerativeClustering' : AgglomerativeClustering,
'FeatureAgglomeration' : FeatureAgglomeration,
'Birch' : Birch,
'SpectralClustering' : SpectralClustering,
'HDBSCAN' : HDBSCAN
}

# if reduce is a string, find the corresponding model
if isinstance(cluster, six.string_types):
model = models[cluster]
if cluster != 'HDBSCAN':
model_params = {
'n_clusters' : n_clusters
}
else:
model_params = {}
# if its a dict, use custom params
elif type(cluster) is dict:
if isinstance(cluster['model'], six.string_types):
model = models[cluster['model']]
model_params = cluster['params']

# initialize model
model = model(**model_params)

# fit the model
model.fit(np.vstack(x))

# return the labels
return list(model.labels_)
elif (isinstance(cluster, six.string_types) and cluster=='HDBSCAN') or \
(isinstance(cluster, dict) and cluster['model']=='HDBSCAN'):
if not _has_hdbscan:
raise ImportError('HDBSCAN is not installed. Please install hdbscan>=0.8.11')

if ndims != None:
warnings.warn('The ndims argument is now deprecated. Ignoring dimensionality reduction step.')

if format_data:
x = formatter(x, ppca=True)

# if reduce is a string, find the corresponding model
if isinstance(cluster, six.string_types):
model = models[cluster]
if cluster != 'HDBSCAN':
model_params = {
'n_clusters' : n_clusters
}
else:
model_params = {}
# if its a dict, use custom params
elif type(cluster) is dict:
if isinstance(cluster['model'], six.string_types):
model = models[cluster['model']]
model_params = cluster['params']

# initialize model
model = model(**model_params)

# fit the model
model.fit(np.vstack(x))

# return the labels
return list(model.labels_)
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ matplotlib>=1.5.1
scipy>=1.0.0
numpy>=1.10.4
umap-learn>=0.1.5
hdbscan>=0.8.11
future
requests
deepdish
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

setup(
name='hypertools',
version='0.5.0',
version='0.5.1',
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
author='Contextual Dynamics Lab',
Expand All @@ -42,7 +42,6 @@
'matplotlib>=1.5.1',
'scipy>=1.0.0',
'numpy>=1.10.4',
'hdbscan>=0.8.11',
'umap-learn>=0.1.5',
'future',
'requests',
Expand Down
26 changes: 15 additions & 11 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-

import numpy as np
import pytest
from hypertools.tools.cluster import cluster
from hypertools.plot.plot import plot

cluster1 = np.random.multivariate_normal(np.zeros(3), np.eye(3), size=100)
cluster2 = np.random.multivariate_normal(np.zeros(3)+100, np.eye(3), size=100)
data = np.vstack([cluster1,cluster2])
labels = cluster(data,n_clusters=2)
data = np.vstack([cluster1, cluster2])
labels = cluster(data, n_clusters=2)


def test_cluster_n_clusters():
Expand All @@ -19,12 +20,15 @@ def test_cluster_returns_list():


def test_cluster_hdbscan():
# Given well separated clusters this should "just work"
hdbscan_labels = cluster(data, cluster='HDBSCAN')
assert len(set(hdbscan_labels)) == 2


def text_cluster_geo():
geo = plot(data, show=False)
hdbscan_labels = cluster(geo, cluster='HDBSCAN')
assert len(set(hdbscan_labels)) == 2
try:
from hdbscan import HDBSCAN
_has_hdbscan = True
except:
_has_hdbscan = False

if _has_hdbscan:
hdbscan_labels = cluster(data, cluster='HDBSCAN')
assert len(set(hdbscan_labels)) == 2
else:
with pytest.raises(ImportError):
hdbscan_labels = cluster(data, cluster='HDBSCAN')
7 changes: 0 additions & 7 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,6 @@ def test_plot_cluster_n_clusters():
geo = plot.plot(weights, n_clusters=3, show=False)
assert isinstance(geo, DataGeometry)


def test_plot_cluster_HDBSCAN():
# should return 10d data since ndims=10
geo = plot.plot(weights, cluster='HDBSCAN', show=False)
assert isinstance(geo, DataGeometry)


def test_plot_nd():
geo = plot.plot(data, show=False)
assert all([i.shape[1]==d.shape[1] for i, d in zip(geo.data, data)])
Expand Down

0 comments on commit 1bd82af

Please sign in to comment.