Skip to content

Commit f1fc032

Browse files
committed
Added a new feature that allows another classifier to be used for generating loreplots.
1 parent 0b40c19 commit f1fc032

12 files changed

+503
-29
lines changed

CITATION.cff

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ authors:
1111
given-names: "Jeroen"
1212
orcid: "https://orcid.org/0000-0002-1337-041X"
1313
title: "lorepy: Logistic Regression Plots for Python"
14-
version: 0.1.1
15-
doi: 10.5281/zenodo.8324902
14+
version: 0.2.0
15+
doi: 10.5281/zenodo.8321785
1616
date-released: 2023-09-07
1717
url: "https://github.com/raeslab/lorepy"

LICENSE

Lines changed: 437 additions & 21 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
[![Run Pytest](https://github.com/raeslab/lorepy/actions/workflows/autopytest.yml/badge.svg)](https://github.com/raeslab/lorepy/actions/workflows/autopytest.yml) [![Coverage](https://raw.githubusercontent.com/raeslab/lorepy/main/docs/coverage-badge.svg)](https://raw.githubusercontent.com/raeslab/lorepy/main/docs/coverage-badge.svg) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![DOI](https://zenodo.org/badge/686018963.svg)](https://zenodo.org/badge/latestdoi/686018963) [![PyPI version](https://badge.fury.io/py/lorepy.svg)](https://badge.fury.io/py/lorepy)
1+
[![Run Pytest](https://github.com/raeslab/lorepy/actions/workflows/autopytest.yml/badge.svg)](https://github.com/raeslab/lorepy/actions/workflows/autopytest.yml) [![Coverage](https://raw.githubusercontent.com/raeslab/lorepy/main/docs/coverage-badge.svg)](https://raw.githubusercontent.com/raeslab/lorepy/main/docs/coverage-badge.svg) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![DOI](https://zenodo.org/badge/686018963.svg)](https://zenodo.org/badge/latestdoi/686018963) [![PyPI version](https://badge.fury.io/py/lorepy.svg)](https://badge.fury.io/py/lorepy) [![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
22

33
# lorepy: Logistic Regression Plots for Python
44

@@ -109,13 +109,38 @@ plt.show()
109109

110110
![LoRePlot in subplots](https://raw.githubusercontent.com/raeslab/lorepy/main/docs/img/loreplot_subplot.png)
111111

112+
By default lorepy uses a multi-class logistic regression model, however this can be replaced with any classifier
113+
from scikit-learn that implements ```predict_proba``` and ```fit```. Below you can see the code and output with a
114+
Support Vector Classifier (SVC) and Random Forest Classifier (RF).
115+
116+
```python
117+
from sklearn.svm import SVC
118+
from sklearn.ensemble import RandomForestClassifier
119+
120+
fig, ax = plt.subplots(1, 2, sharex=False, sharey=True)
121+
122+
svc = SVC(probability=True)
123+
rf = RandomForestClassifier(n_estimators=10, max_depth=2)
124+
125+
loreplot(data=iris_df, x="sepal width (cm)", y="species", clf=svc, ax=ax[0])
126+
loreplot(data=iris_df, x="sepal width (cm)", y="species", clf=rf, ax=ax[1])
127+
128+
ax[0].set_title("SVC")
129+
ax[1].set_title("RF")
130+
131+
plt.savefig("./docs/img/loreplot_other_clf.png", dpi=150)
132+
plt.show()
133+
```
134+
135+
[Lorepy with different types of classifiers](./docs/img/loreplot_other_clf.png)
136+
112137
## Contributing
113138

114-
Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.
139+
Any contributions you make are **greatly appreciated**.
115140

116141
* Found a bug or have some suggestions? Open an [issue](https://github.com/raeslab/lorepy/issues).
117142
* Pull requests are welcome! Though open an [issue](https://github.com/raeslab/lorepy/issues) first to discuss which features/changes you wish to implement.
118143

119144
## Contact
120145

121-
lorepy was developed by [Sebastian Proost](https://sebastian.proost.science/) at the [RaesLab](https://raeslab.sites.vib.be/en) and was based on R code written by [Sara Vieira-Silva](https://saravsilva.github.io/).
146+
lorepy was developed by [Sebastian Proost](https://sebastian.proost.science/) at the [RaesLab](https://raeslab.sites.vib.be/en) and was based on R code written by [Sara Vieira-Silva](https://saravsilva.github.io/). As of version 0.2.0 lorepy is available under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.

docs/img/loreplot.png

788 Bytes
Loading

docs/img/loreplot_custom_color.png

1.9 KB
Loading

docs/img/loreplot_custom_markers.png

1.16 KB
Loading

docs/img/loreplot_other_clf.png

82.5 KB
Loading

docs/img/loreplot_subplot.png

4.7 KB
Loading

example.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,21 @@
5252

5353
plt.savefig("./docs/img/loreplot_subplot.png", dpi=150)
5454
plt.show()
55+
56+
# Basic Lore Plot with default style but different classifier
57+
from sklearn.svm import SVC
58+
from sklearn.ensemble import RandomForestClassifier
59+
60+
fig, ax = plt.subplots(1, 2, sharex=False, sharey=True)
61+
62+
svc = SVC(probability=True)
63+
rf = RandomForestClassifier(n_estimators=10, max_depth=2)
64+
65+
loreplot(data=iris_df, x="sepal width (cm)", y="species", clf=svc, ax=ax[0])
66+
loreplot(data=iris_df, x="sepal width (cm)", y="species", clf=rf, ax=ax[1])
67+
68+
ax[0].set_title("SVC")
69+
ax[1].set_title("RF")
70+
71+
plt.savefig("./docs/img/loreplot_other_clf.png", dpi=150)
72+
plt.show()

lorepy/lorepy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def loreplot(
4242
x_range: Optional[Tuple[float, float]] = None,
4343
scatter_kws: dict = dict({}),
4444
ax=None,
45+
clf=None,
4546
**kwargs
4647
):
4748
"""
@@ -54,6 +55,7 @@ def loreplot(
5455
:param x_range: Either None (range will be selected automatically) or a tuple with min and max value for the x-axis
5556
:param scatter_kws: Dictionary with keyword arguments to pass to the scatter function
5657
:param ax: subplot to draw on, in case lorepy is used in a subplot
58+
:param clf: provide a different scikit-learn classifier for the function. Should implement the predict_proba() and fit()
5759
:param kwargs: Additional arguments to pass to pandas' plot.area function
5860
"""
5961
if ax is None:
@@ -62,7 +64,7 @@ def loreplot(
6264
X_reg = np.array(tmp_df[x]).reshape(-1, 1)
6365
y_reg = np.array(tmp_df[y])
6466

65-
lg = LogisticRegression(multi_class="multinomial")
67+
lg = LogisticRegression(multi_class="multinomial") if clf is None else clf
6668
lg.fit(X_reg, y_reg)
6769

6870
if "linestyle" not in kwargs.keys():

0 commit comments

Comments
 (0)