Skip to content

Commit c2c0305

Browse files
committed
Adds reproducible script for the article
1 parent 82d2073 commit c2c0305

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

R/imbalance.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ NULL
2424

2525
#' @useDynLib imbalance, .registration = TRUE
2626
#' @importFrom Rcpp sourceCpp
27+
#' @import smotefamily
2728
NULL
2829

2930

@@ -48,7 +49,6 @@ NULL
4849
#'
4950
#' @return A balanced \code{data.frame} with same structure as \code{dataset},
5051
#' containing both original instances and new ones
51-
#' @import smotefamily
5252
#' @export
5353
#'
5454
#' @examples

data-raw/banana.R

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
library("imbalance")
2+
library("ggplot2")
3+
library("caret")
4+
set.seed(12345)
5+
6+
# Load original banana dataset
7+
data(banana_orig)
8+
# Load imbalanced banana dataset
9+
data(banana)
10+
# Limits for the plot axis
11+
myxlim <- c(-4, 4)
12+
myylim <- c(-3, 4)
13+
# Methods to apply and plot
14+
methods <- c("original", "imbalanced", "SMOTE", "MWMOTE", "RWO", "PDFOS")
15+
16+
for(i in seq_along(methods)){
17+
method <- methods[i]
18+
19+
if(method == "original"){
20+
dataset <- banana_orig
21+
}else{
22+
if(method == "imbalanced"){
23+
dataset <- banana
24+
}else{
25+
dataset <- oversample(banana, ratio = 0.7, method = method, filtering = T)
26+
}
27+
28+
model <- knn3Train(dataset[, -3], banana_orig[, -3], dataset$Class, k = 3, l = 0, prob = TRUE, use.all = TRUE)
29+
model[model == "negative"] <- -1
30+
model[model == "positive"] <- 1
31+
model <- as.numeric(model)
32+
modelauc <- auc(banana_orig[, 3], model)
33+
print(paste(method, "AUC:", modelauc))
34+
}
35+
36+
qplot(At1, At2, col = Class, data = dataset,
37+
xlab = "at1", ylab = "at2", xlim = myxlim, ylim = myylim) +
38+
scale_colour_manual(values = c("#E69F00", "#000000")) +
39+
geom_point(size = 2, alpha = 0.3) +
40+
theme(text = element_text(size = 30), legend.position="none")
41+
42+
ggplot(dataset, aes_string("At1", "At2", col = "Class")) +
43+
geom_point(alpha = 0.3) +
44+
scale_color_manual(values = c("#E69F00", "#000000")) +
45+
geom_point(size = 3, alpha = 0.3) +
46+
theme(text = element_text(size = 30), legend.position="none")
47+
48+
# theme(axis.text.x = element_text(size = 15),
49+
# axis.text.y = element_text(size = 15),
50+
# axis.title.x = element_text(size = 20),
51+
# axis.title.y = element_text(size = 20))
52+
ggsave(paste("banana-", method, ".png", sep = ""), device = "png", width = 14, height = 7.85)
53+
}
54+
55+
56+
# 3-repeated 3-fold cross validation
57+
# control <- trainControl(method = "repeatedcv", number = 3, repeats = 3,
58+
# summaryFunction = twoClassSummary, classProbs = TRUE)
59+
#
60+
# model <- train(Class~., data = dataset, method="kknn",
61+
# trControl=control,
62+
# metric = "ROC",
63+
# tuneGrid = data.frame(kmax = 5, distance = 2, kernel = "rectangular"))
64+
# model$results
65+
66+

0 commit comments

Comments
 (0)