# Compare oversampling methods from the `imbalance` package on the banana
# dataset. For every entry in `methods` the script:
#   1. picks the dataset (original, imbalanced, or the imbalanced data
#      oversampled with the named method),
#   2. for everything except the original data, trains a 3-NN classifier on
#      that dataset, predicts the original dataset, and prints the AUC,
#   3. saves a scatter plot of the dataset to "banana-<method>.png" in the
#      current working directory.

library("imbalance")
library("ggplot2")
library("caret")
set.seed(12345)

# Load original banana dataset
data(banana_orig)
# Load imbalanced banana dataset
data(banana)

# Methods to apply and plot
methods <- c("original", "imbalanced", "SMOTE", "MWMOTE", "RWO", "PDFOS")

# Predictor columns: every column except the class label.
feature_cols <- setdiff(names(banana_orig), "Class")

# Numeric encoding of the predicted class labels that is fed to auc().
label_scores <- c(negative = -1, positive = 1)

for (method in methods) {
  # Any name other than "original"/"imbalanced" is forwarded to oversample().
  dataset <- switch(method,
    original = banana_orig,
    imbalanced = banana,
    oversample(banana, ratio = 0.7, method = method, filtering = TRUE)
  )

  if (method != "original") {
    # Train on the (possibly oversampled) data, evaluate on the original data.
    # NOTE(review): prob = TRUE is requested but only the hard class labels
    # are used below, so the AUC is computed from -1/1 predictions.
    predicted <- knn3Train(
      dataset[, feature_cols], banana_orig[, feature_cols], dataset$Class,
      k = 3, l = 0, prob = TRUE, use.all = TRUE
    )
    scores <- unname(label_scores[as.character(predicted)])

    # NOTE(review): auc() is not obviously exported by the packages attached
    # above; confirm which package is expected to provide it.
    modelauc <- auc(banana_orig$Class, scores)
    print(paste(method, "AUC:", modelauc))
  }

  # Axis ranges are chosen by ggplot2 from the data. For fixed axes across
  # all figures add: coord_cartesian(xlim = c(-4, 4), ylim = c(-3, 4)).
  # NOTE(review): the points are drawn by two stacked layers (default size
  # and size 3); kept as-is so the rendered figure does not change.
  banana_plot <- ggplot(dataset, aes(x = At1, y = At2, colour = Class)) +
    geom_point(alpha = 0.3) +
    scale_colour_manual(values = c("#E69F00", "#000000")) +
    geom_point(size = 3, alpha = 0.3) +
    theme(text = element_text(size = 30), legend.position = "none")

  # theme(axis.text.x = element_text(size = 15),
  #       axis.text.y = element_text(size = 15),
  #       axis.title.x = element_text(size = 20),
  #       axis.title.y = element_text(size = 20))

  # Pass the plot explicitly instead of relying on last_plot().
  ggsave(
    paste0("banana-", method, ".png"),
    plot = banana_plot, device = "png", width = 14, height = 7.85
  )
}


# 3-repeated 3-fold cross validation
# control <- trainControl(method = "repeatedcv", number = 3, repeats = 3,
#                         summaryFunction = twoClassSummary, classProbs = TRUE)
#
# model <- train(Class~., data = dataset, method="kknn",
#                trControl=control,
#                metric = "ROC",
#                tuneGrid = data.frame(kmax = 5, distance = 2, kernel = "rectangular"))
# model$results