Skip to content

Commit

Permalink
Merge branch 'main' into ASReview2-rf
Browse files Browse the repository at this point in the history
  • Loading branch information
timovdk committed Feb 5, 2025
2 parents 2f982b2 + cbe49a6 commit 75b137e
Show file tree
Hide file tree
Showing 8 changed files with 1,135 additions and 99 deletions.
26 changes: 9 additions & 17 deletions asreview2-optuna/classifiers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import optuna
from asreview.models.classifiers import (
LogisticClassifier,
NaiveBayesClassifier,
SVMClassifier,
Logistic,
NaiveBayes,
RandomForest,
SVM,
)

from sklearn.ensemble import RandomForestClassifier
Expand All @@ -23,16 +24,7 @@ def logistic_params(trial: optuna.trial.FrozenTrial):
def svm_params(trial: optuna.trial.FrozenTrial):
# Use logarithmic normal distribution for C (C effect is non-linear)
C = trial.suggest_float("svm__C", 1e-3, 100, log=True)

# Use categorical for kernel
kernel = trial.suggest_categorical("svm__kernel", ["linear", "rbf"])

# Only set gamma to a value if we use rbf kernel
gamma = "scale"
if kernel == "rbf":
# Use logarithmic normal distribution for gamma (gamma effect is non-linear)
gamma = trial.suggest_float("svm__gamma", 1e-4, 10, log=True)
return {"C": C, "kernel": kernel, "gamma": gamma}
return {"C": C}


def random_forest_params(trial: optuna.trial.FrozenTrial):
Expand Down Expand Up @@ -72,8 +64,8 @@ def __init__(self, n_estimators=100, max_features=10, **kwargs):


classifiers = {
"nb": NaiveBayesClassifier,
"log": LogisticClassifier,
"svm": SVMClassifier,
"rf": RFClassifier,
"nb": NaiveBayes,
"log": Logistic,
"svm": SVM,
"rf": RandomForest,
}
Loading

0 comments on commit 75b137e

Please sign in to comment.