diff --git a/examples/animate.py b/examples/animate.py
index 3174c8b3..13f36500 100644
--- a/examples/animate.py
+++ b/examples/animate.py
@@ -18,4 +18,4 @@
 geo = hyp.load('weights_avg')
 
 # plot
-geo.plot(animate=True)
+geo.plot(animate=True, legend=['first', 'second'])
diff --git a/hypertools/tools/cluster.py b/hypertools/tools/cluster.py
index e8bd0d13..8ae44309 100644
--- a/hypertools/tools/cluster.py
+++ b/hypertools/tools/cluster.py
@@ -1,12 +1,28 @@
 #!/usr/bin/env python
 import warnings
-from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering
-import numpy as np
 import six
-from hdbscan import HDBSCAN
+import numpy as np
+from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering
 from .._shared.helpers import *
 from .format_data import format_data as formatter
 
+# dictionary of models
+models = {
+    'KMeans': KMeans,
+    'MiniBatchKMeans': MiniBatchKMeans,
+    'AgglomerativeClustering': AgglomerativeClustering,
+    'FeatureAgglomeration': FeatureAgglomeration,
+    'Birch': Birch,
+    'SpectralClustering': SpectralClustering,
+}
+
+try:
+    from hdbscan import HDBSCAN
+    _has_hdbscan = True
+    models.update({'HDBSCAN': HDBSCAN})
+except ImportError:
+    _has_hdbscan = False
+
 
 @memoize
 def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
@@ -46,48 +62,39 @@ def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
 
     """
 
-    # if cluster is None, just return data
-    if cluster is None:
+    if cluster == None:
         return x
-    else:
-
-        if ndims is not None:
-            warnings.warn('The ndims argument is now deprecated. Ignoring dimensionality reduction step.')
-
-        if format_data:
-            x = formatter(x, ppca=True)
-
-        # dictionary of models
-        models = {
-            'KMeans' : KMeans,
-            'MiniBatchKMeans' : MiniBatchKMeans,
-            'AgglomerativeClustering' : AgglomerativeClustering,
-            'FeatureAgglomeration' : FeatureAgglomeration,
-            'Birch' : Birch,
-            'SpectralClustering' : SpectralClustering,
-            'HDBSCAN' : HDBSCAN
-        }
-
-        # if reduce is a string, find the corresponding model
-        if isinstance(cluster, six.string_types):
-            model = models[cluster]
-            if cluster != 'HDBSCAN':
-                model_params = {
-                    'n_clusters' : n_clusters
-                }
-            else:
-                model_params = {}
-        # if its a dict, use custom params
-        elif type(cluster) is dict:
-            if isinstance(cluster['model'], six.string_types):
-                model = models[cluster['model']]
-                model_params = cluster['params']
-
-        # initialize model
-        model = model(**model_params)
-
-        # fit the model
-        model.fit(np.vstack(x))
-
-        # return the labels
-        return list(model.labels_)
+    elif (isinstance(cluster, six.string_types) and cluster=='HDBSCAN') or \
+        (isinstance(cluster, dict) and cluster['model']=='HDBSCAN'):
+        if not _has_hdbscan:
+            raise ImportError('HDBSCAN is not installed. Please install hdbscan>=0.8.11')
+
+    if ndims != None:
+        warnings.warn('The ndims argument is now deprecated. Ignoring dimensionality reduction step.')
+
+    if format_data:
+        x = formatter(x, ppca=True)
+
+    # if reduce is a string, find the corresponding model
+    if isinstance(cluster, six.string_types):
+        model = models[cluster]
+        if cluster != 'HDBSCAN':
+            model_params = {
+                'n_clusters' : n_clusters
+            }
+        else:
+            model_params = {}
+    # if its a dict, use custom params
+    elif type(cluster) is dict:
+        if isinstance(cluster['model'], six.string_types):
+            model = models[cluster['model']]
+            model_params = cluster['params']
+
+    # initialize model
+    model = model(**model_params)
+
+    # fit the model
+    model.fit(np.vstack(x))
+
+    # return the labels
+    return list(model.labels_)
diff --git a/requirements.txt b/requirements.txt
index 843f3c5e..df41590c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,6 @@ matplotlib>=1.5.1
 scipy>=1.0.0
 numpy>=1.10.4
 umap-learn>=0.1.5
-hdbscan>=0.8.11
 future
 requests
 deepdish
diff --git a/setup.py b/setup.py
index 0cea66c4..8580cf73 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@
 
 setup(
     name='hypertools',
-    version='0.5.0',
+    version='0.5.1',
     description=DESCRIPTION,
     long_description=LONG_DESCRIPTION,
     author='Contextual Dynamics Lab',
@@ -42,7 +42,6 @@
         'matplotlib>=1.5.1',
         'scipy>=1.0.0',
         'numpy>=1.10.4',
-        'hdbscan>=0.8.11',
         'umap-learn>=0.1.5',
         'future',
         'requests',
diff --git a/tests/test_cluster.py b/tests/test_cluster.py
index e9b541a0..aa153763 100644
--- a/tests/test_cluster.py
+++ b/tests/test_cluster.py
@@ -1,13 +1,14 @@
 # -*- coding: utf-8 -*-
 import numpy as np
+import pytest
 
 from hypertools.tools.cluster import cluster
 from hypertools.plot.plot import plot
 
 cluster1 = np.random.multivariate_normal(np.zeros(3), np.eye(3), size=100)
 cluster2 = np.random.multivariate_normal(np.zeros(3)+100, np.eye(3), size=100)
-data = np.vstack([cluster1,cluster2])
-labels = cluster(data,n_clusters=2)
+data = np.vstack([cluster1, cluster2])
+labels = cluster(data, n_clusters=2)
 
 
 def test_cluster_n_clusters():
@@ -19,12 +20,15 @@ def test_cluster_returns_list():
 
 
 def test_cluster_hdbscan():
-    # Given well separated clusters this should "just work"
-    hdbscan_labels = cluster(data, cluster='HDBSCAN')
-    assert len(set(hdbscan_labels)) == 2
-
-
-def text_cluster_geo():
-    geo = plot(data, show=False)
-    hdbscan_labels = cluster(geo, cluster='HDBSCAN')
-    assert len(set(hdbscan_labels)) == 2
+    try:
+        from hdbscan import HDBSCAN
+        _has_hdbscan = True
+    except:
+        _has_hdbscan = False
+
+    if _has_hdbscan:
+        hdbscan_labels = cluster(data, cluster='HDBSCAN')
+        assert len(set(hdbscan_labels)) == 2
+    else:
+        with pytest.raises(ImportError):
+            hdbscan_labels = cluster(data, cluster='HDBSCAN')
diff --git a/tests/test_plot.py b/tests/test_plot.py
index d9b520ef..cf5f00d7 100644
--- a/tests/test_plot.py
+++ b/tests/test_plot.py
@@ -99,13 +99,6 @@ def test_plot_cluster_n_clusters():
     geo = plot.plot(weights, n_clusters=3, show=False)
     assert isinstance(geo, DataGeometry)
 
-
-def test_plot_cluster_HDBSCAN():
-    # should return 10d data since ndims=10
-    geo = plot.plot(weights, cluster='HDBSCAN', show=False)
-    assert isinstance(geo, DataGeometry)
-
-
 def test_plot_nd():
     geo = plot.plot(data, show=False)
     assert all([i.shape[1]==d.shape[1] for i, d in zip(geo.data, data)])
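For illustration only (not part of the patch): a minimal usage sketch of how the patched cluster() is expected to behave, assuming hypertools 0.5.1 is installed without the optional hdbscan package. The data setup mirrors tests/test_cluster.py; the error message comes from the ImportError added in cluster.py.

import numpy as np
from hypertools.tools.cluster import cluster

# two well-separated blobs, as in tests/test_cluster.py
data = np.vstack([np.random.multivariate_normal(np.zeros(3), np.eye(3), size=100),
                  np.random.multivariate_normal(np.zeros(3) + 100, np.eye(3), size=100)])

# KMeans still works with only the required dependencies
labels = cluster(data, cluster='KMeans', n_clusters=2)
assert len(set(labels)) == 2

# HDBSCAN is now optional: without hdbscan installed, the patched code raises ImportError
try:
    hdb_labels = cluster(data, cluster='HDBSCAN')
except ImportError as err:
    print(err)  # 'HDBSCAN is not installed. Please install hdbscan>=0.8.11'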