Skip to content

Commit ad98c0c

Browse files
authored
Bug fixes in pre-processing methods
1 parent cb8c6b6 commit ad98c0c

File tree

4 files changed

+14
-4
lines changed

4 files changed

+14
-4
lines changed

src/aequitas/flow/datasets/folktables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464

6565
BOOL_COLUMNS = {
6666
"ACSIncome": ["SEX"],
67-
"ACSEmployment": ["SEX", "DIS", "NATIVTY", "DEAR", "DEYE", "DREM"],
67+
"ACSEmployment": ["SEX", "DIS", "NATIVITY", "DEAR", "DEYE", "DREM"],
6868
"ACSMobility": [
6969
"SEX",
7070
"DIS",

src/aequitas/flow/methods/preprocessing/data_repairer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
self.repair_level = repair_level
5555
self.columns = columns
5656
self.definition = definition
57+
self.used_in_inference = True
5758

5859
def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series] = None) -> None:
5960
"""
@@ -72,7 +73,11 @@ def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series] = None) -> N
7273
super().fit(X, y, s)
7374

7475
if self.columns is None:
75-
self.columns = X.columns.tolist()
76+
self.columns = [
77+
column
78+
for column in X.columns
79+
if (X[column].dtype != "category" and X[column].dtype != "bool")
80+
]
7681
if s is None:
7782
raise ValueError("s must be passed.")
7883
self._quantile_points = np.linspace(0, 1, self.definition)
@@ -141,7 +146,7 @@ def transform(
141146
Transformed features, labels, and sensitive attribute.
142147
"""
143148
super().transform(X, y, s)
144-
149+
145150
if s is None:
146151
raise ValueError("s must be passed.")
147152

src/aequitas/flow/methods/preprocessing/feature_importance_suppression.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(
1616
feature_importance_threshold: Optional[float] = 0.1,
1717
n_estimators: Optional[int] = 10,
1818
seed: int = 0,
19+
n_jobs: int = 1,
1920
):
2021
"""Iteratively removes the most important features with respect to the sensitive
2122
attribute.
@@ -32,6 +33,8 @@ def __init__(
3233
The number of trees in the random forest. Defaults to 10.
3334
seed : int, optional
3435
The seed for the random forest. Defaults to 0.
36+
n_jobs : int, optional
37+
The number of jobs to run in parallel. Defaults to 1.
3538
"""
3639
self.logger = create_logger(
3740
"methods.preprocessing.FeatureImportanceSuppression"
@@ -45,6 +48,7 @@ def __init__(
4548
self.feature_importance_threshold = feature_importance_threshold
4649
self.n_estimators = n_estimators
4750
self.seed = seed
51+
self.n_jobs = n_jobs
4852

4953
def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None:
5054
"""Iteratively removes the most important features to predict the sensitive
@@ -64,7 +68,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None:
6468
self.logger.info("Identifying features to remove.")
6569

6670
rf = RandomForestClassifier(
67-
n_estimators=self.n_estimators, random_state=self.seed
71+
n_estimators=self.n_estimators, random_state=self.seed, n_jobs=self.n_jobs
6872
)
6973

7074
features = pd.concat([X, y], axis=1)

src/aequitas/flow/methods/preprocessing/massaging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424

2525
self.classifier = instantiate_object(classifier, **classifier_args)
2626
self.logger.info(f"Created base estimator {self.classifier}")
27+
self.used_in_inference = False
2728

2829
def _rank(
2930
self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]

0 commit comments

Comments
 (0)