Skip to content

Commit

Permalink
refactor: use invoke for predict calls and replace do.call for invoke…
Browse files Browse the repository at this point in the history
… where relevant
  • Loading branch information
m-muecke committed Apr 22, 2024
1 parent 67248b1 commit a922215
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions R/LearnerClustCMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ LearnerClustCMeans = R6Class("LearnerClustCMeans",
},

.predict = function(task) {
partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids"))
prob = unclass(cl_predict(self$model, newdata = task$data(), type = "memberships"))
partition = unclass(invoke(cl_predict, self$model, newdata = task$data(), type = "class_ids"))
prob = unclass(invoke(cl_predict, self$model, newdata = task$data(), type = "memberships"))
colnames(prob) = seq_len(ncol(prob))

PredictionClust$new(task = task, partition = partition, prob = prob)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustCobweb.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ LearnerClustCobweb = R6Class("LearnerClustCobweb",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::Cobweb, x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -53,7 +53,7 @@ LearnerClustCobweb = R6Class("LearnerClustCobweb",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustEM.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ LearnerClustEM = R6Class("LearnerClustEM",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = chartr("_", "-", names(pv))
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::make_Weka_clusterer("weka/clusterers/EM"), x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -64,7 +64,7 @@ LearnerClustEM = R6Class("LearnerClustEM",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustFarthestFirst.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ LearnerClustFarthestFirst = R6Class("LearnerClustFF",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = chartr("_", "-", names(pv))
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::FarthestFirst, x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -54,7 +54,7 @@ LearnerClustFarthestFirst = R6Class("LearnerClustFF",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans",
},

.predict = function(task) {
partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids"))
partition = unclass(invoke(cl_predict, self$model, newdata = task$data(), type = "class_ids"))
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustMclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ LearnerClustMclust = R6Class("LearnerClustMclust",
},

.predict = function(task) {
predictions = predict(self$model, newdata = task$data())
predictions = invoke(predict, self$model, newdata = task$data())
partition = as.integer(predictions$classification)
prob = predictions$z
PredictionClust$new(task = task, partition = partition, prob = prob)
Expand Down
6 changes: 3 additions & 3 deletions R/LearnerClustMiniBatchKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans",
pv = self$param_set$get_values(tags = "train")
m = invoke(ClusterR::MiniBatchKmeans, data = task$data(), .args = pv)
if (self$save_assignments) {
self$assignments = unclass(ClusterR::predict_MBatchKMeans(
self$assignments = unclass(invoke(ClusterR::predict_MBatchKMeans,
data = task$data(),
CENTROIDS = m$centroids,
fuzzy = FALSE
Expand All @@ -82,15 +82,15 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans",

.predict = function(task) {
if (self$predict_type == "partition") {
partition = unclass(ClusterR::predict_MBatchKMeans(
partition = unclass(invoke(ClusterR::predict_MBatchKMeans,
data = task$data(),
CENTROIDS = self$model$centroids,
fuzzy = FALSE
))
partition = as.integer(partition)
pred = PredictionClust$new(task = task, partition = partition)
} else if (self$predict_type == "prob") {
partition = unclass(ClusterR::predict_MBatchKMeans(
partition = unclass(invoke(ClusterR::predict_MBatchKMeans,
data = task$data(),
CENTROIDS = self$model$centroids,
fuzzy = TRUE
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustPAM.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ LearnerClustPAM = R6Class("LearnerClustPAM",
},

.predict = function(task) {
partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids"))
partition = unclass(invoke(cl_predict, self$model, newdata = task$data(), type = "class_ids"))
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustSimpleKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ LearnerClustSimpleKMeans = R6Class("LearnerClustSimpleKMeans",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = chartr("_", "-", names(pv))
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::SimpleKMeans, x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -68,7 +68,7 @@ LearnerClustSimpleKMeans = R6Class("LearnerClustSimpleKMeans",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustXMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ LearnerClustXMeans = R6Class("LearnerClustXMeans",
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = chartr("_", "-", names(pv))
ctrl = do.call(RWeka::Weka_control, pv)
ctrl = invoke(RWeka::Weka_control, .args = pv)
m = invoke(RWeka::XMeans, x = task$data(), control = ctrl)
if (self$save_assignments) {
self$assignments = unname(m$class_ids + 1L)
Expand All @@ -67,7 +67,7 @@ LearnerClustXMeans = R6Class("LearnerClustXMeans",
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
partition = invoke(predict, self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
}
)
Expand Down

0 comments on commit a922215

Please sign in to comment.