Skip to content

Commit a62372f

Browse files
authored
Merge branch 'develop' into issue-3
2 parents a8ea6db + 195231d commit a62372f

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "pCRscore"
7-
version = "0.0.6+issue4.issue3"
7+
version = "0.0.7"
88
authors = [
99
{name="Youness Azimzade"}
1010
]

src/pCRscore/misc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
def _binary_encode(data, column, out_values=[-1, 1], reverse=False):
2+
unique_values = data[column].unique()
3+
if reverse:
4+
# Flip unique_values order
5+
unique_values = unique_values[::-1]
6+
if len(unique_values) != 2:
7+
raise ValueError(f"{column} must contain exactly two unique values.")
8+
print(
9+
"Recoding '", unique_values[0], "' as ", -1,
10+
" and '", unique_values[1], "' as ", 1, sep=''
11+
)
12+
value_map = {
13+
unique_values[0]: out_values[0], unique_values[1]: out_values[1]
14+
}
15+
data[column] = data[column].map(value_map)
16+
return data

src/pCRscore/svm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
from sklearn.model_selection import \
77
GridSearchCV, train_test_split, KFold, cross_val_score
88
from sklearn.svm import SVC
9+
from .misc import _binary_encode
910

1011

1112
def preprocess(data, split_var='Cohort'):
1213
# Mapping the values in the 'Response' column to binary values 0 and 1
1314
resp = {'pCR': 1, 'RD': 0}
1415
data.Response = [resp[item] for item in data.Response]
1516

16-
# Mapping the values in the 'ER' column to binary values 0 and 1
17-
er = {'Positive': 1, 'Negative': 0}
18-
data.ER = [er[item] for item in data.ER]
17+
# Mapping the values in the 'ER' column to binary values
18+
data = _binary_encode(data, 'ER', out_values=[-1, 1])
1919

2020
# Creating dummy variables for the categorical column 'PAM50'
2121
categorical_cols = ['PAM50']

tests/test_svm.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest import mock
44
import pytest
55
import numpy as np
6+
from pandas.testing import assert_frame_equal
67

78

89
@pytest.fixture
@@ -92,3 +93,28 @@ def test_shapley():
9293
assert isinstance(shapl, np.ndarray)
9394
assert shapl.shape == (30, 44)
9495
svm.shap_plot(shapl, X)
96+
97+
98+
def test__binary_encode():
99+
# Test with likely data
100+
data = pd.DataFrame({'PAM50': ['Lum', 'Bas', 'Lum', 'Bas', 'Lum', 'Bas']})
101+
data_encoded = svm._binary_encode(data, 'PAM50')
102+
data_ref = pd.DataFrame({'PAM50': [-1, 1, -1, 1, -1, 1]})
103+
assert_frame_equal(data_encoded, data_ref)
104+
105+
# Check that first value is always -1
106+
data = pd.DataFrame({'X': ['A', 'Z']})
107+
data_encoded = svm._binary_encode(data, 'X')
108+
data_ref = pd.DataFrame({'X': [-1, 1]})
109+
assert_frame_equal(data_encoded, data_ref)
110+
111+
data = pd.DataFrame({'X': ['Z', 'A']})
112+
data_encoded = svm._binary_encode(data, 'X')
113+
data_ref = pd.DataFrame({'X': [-1, 1]})
114+
assert_frame_equal(data_encoded, data_ref)
115+
116+
# Reversing works
117+
data = pd.DataFrame({'X': ['A', 'Z']})
118+
data_encoded = svm._binary_encode(data, 'X', reverse=True)
119+
data_ref = pd.DataFrame({'X': [1, -1]})
120+
assert_frame_equal(data_encoded, data_ref)

0 commit comments

Comments
 (0)