Skip to content

Commit

Permalink
refactor: replace custom warningCondition to warningf and make rename…
Browse files Browse the repository at this point in the history
… check_ to assert_ for checkmate style behaviour
  • Loading branch information
m-muecke committed Apr 24, 2024
1 parent c0ca811 commit 7bde3de
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 7 deletions.
2 changes: 1 addition & 1 deletion R/LearnerClustCMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ LearnerClustCMeans = R6Class("LearnerClustCMeans",
),
private = list(
.train = function(task) {
check_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")
assert_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")

pv = self$param_set$get_values(tags = "train")
m = invoke(e1071::cmeans, x = task$data(), .args = pv, .opts = allow_partial_matching)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustKKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ LearnerClustKKMeans = R6Class("LearnerClustKKMeans",
),
private = list(
.train = function(task) {
check_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")
assert_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")

pv = self$param_set$get_values(tags = "train")
m = invoke(kernlab::kkmeans, x = as.matrix(task$data()), .args = pv)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans",
warningf("`nstart` parameter is only relevant when `centers` is integer.")
}

check_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")
assert_centers_param(self$param_set$values$centers, task, test_data_frame, "centers")

pv = self$param_set$get_values(tags = "train")
m = invoke(stats::kmeans, x = task$data(), .args = pv)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClustMiniBatchKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans",
),
private = list(
.train = function(task) {
check_centers_param(self$param_set$values$CENTROIDS, task, test_matrix, "CENTROIDS")
assert_centers_param(self$param_set$values$CENTROIDS, task, test_matrix, "CENTROIDS")
if (test_matrix(self$param_set$values$CENTROIDS) &&
nrow(self$param_set$values$CENTROIDS) != self$param_set$values$clusters) {
stopf("`CENTROIDS` must have same number of rows as `clusters`")
Expand Down
5 changes: 2 additions & 3 deletions R/helper.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
warn_prediction_useless = function(id) {
msg = sprintf("Learner '%s' doesn't predict on new data and predictions may not make sense on new data", id)
warning(warningCondition(msg, class = "predictionUselessWarning"))
warningf("Learner '%s' doesn't predict on new data and predictions may not make sense on new data.", id)
}

allow_partial_matching = list(
Expand All @@ -9,7 +8,7 @@ allow_partial_matching = list(
warnPartialMatchDollar = FALSE
)

check_centers_param = function(centers, task, test_class, name) {
assert_centers_param = function(centers, task, test_class, name) {
if (test_class(centers) && ncol(centers) != task$ncol) {
stopf("`%s` must have same number of columns as data.", name)
}
Expand Down

0 comments on commit 7bde3de

Please sign in to comment.