diff --git a/R/LearnerClustAffinityPropagation.R b/R/LearnerClustAffinityPropagation.R index 4c577573..a9a5019c 100644 --- a/R/LearnerClustAffinityPropagation.R +++ b/R/LearnerClustAffinityPropagation.R @@ -63,6 +63,7 @@ LearnerClustAP = R6Class("LearnerClustAP", return(m) }, + .predict = function(task) { sim_func = self$param_set$values$s exemplar_data = attributes(self$model)$exemplar_data diff --git a/R/LearnerClustAgnes.R b/R/LearnerClustAgnes.R index d30e0fd4..6b2fa93b 100644 --- a/R/LearnerClustAgnes.R +++ b/R/LearnerClustAgnes.R @@ -70,6 +70,7 @@ LearnerClustAgnes = R6Class("LearnerClustAgnes", return(m) }, + .predict = function(task) { if (self$param_set$values$k > task$nrow) { stopf("`k` needs to be between 1 and %i", task$nrow) diff --git a/R/LearnerClustCMeans.R b/R/LearnerClustCMeans.R index 4294b482..ed688fb8 100644 --- a/R/LearnerClustCMeans.R +++ b/R/LearnerClustCMeans.R @@ -70,6 +70,7 @@ LearnerClustCMeans = R6Class("LearnerClustCMeans", return(m) }, + .predict = function(task) { partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids")) prob = unclass(cl_predict(self$model, newdata = task$data(), type = "memberships")) diff --git a/R/LearnerClustCobweb.R b/R/LearnerClustCobweb.R index ade38ca1..9181abc4 100644 --- a/R/LearnerClustCobweb.R +++ b/R/LearnerClustCobweb.R @@ -49,6 +49,7 @@ LearnerClustCobweb = R6Class("LearnerClustCobweb", return(m) }, + .predict = function(task) { partition = predict(self$model, newdata = task$data(), type = "class") + 1L PredictionClust$new(task = task, partition = partition) diff --git a/R/LearnerClustDBSCAN.R b/R/LearnerClustDBSCAN.R index 7330b7a9..4e592d72 100644 --- a/R/LearnerClustDBSCAN.R +++ b/R/LearnerClustDBSCAN.R @@ -62,6 +62,7 @@ LearnerClustDBSCAN = R6Class("LearnerClustDBSCAN", return(m) }, + .predict = function(task) { partition = predict(self$model, newdata = task$data(), self$model$data) PredictionClust$new(task = task, partition = partition) diff --git a/R/LearnerClustDBSCANfpc.R b/R/LearnerClustDBSCANfpc.R index 91276740..a62827e8 100644 --- a/R/LearnerClustDBSCANfpc.R +++ b/R/LearnerClustDBSCANfpc.R @@ -70,6 +70,7 @@ LearnerClustDBSCANfpc = R6Class("LearnerClustDBSCANfpc", return(m) }, + .predict = function(task) { partition = as.integer(predict(self$model, data = self$model$data, newdata = task$data())) PredictionClust$new(task = task, partition = partition) diff --git a/R/LearnerClustDiana.R b/R/LearnerClustDiana.R index d9abdfd5..66d5e225 100644 --- a/R/LearnerClustDiana.R +++ b/R/LearnerClustDiana.R @@ -51,6 +51,7 @@ LearnerClustDiana = R6Class("LearnerClustDiana", return(m) }, + .predict = function(task) { if (test_true(self$param_set$values$k > task$nrow)) { stopf("`k` needs to be between 1 and %s", task$nrow) diff --git a/R/LearnerClustEM.R b/R/LearnerClustEM.R index 7d00bb9e..f6b23d0a 100644 --- a/R/LearnerClustEM.R +++ b/R/LearnerClustEM.R @@ -60,6 +60,7 @@ LearnerClustEM = R6Class("LearnerClustEM", return(m) }, + .predict = function(task) { partition = predict(self$model, newdata = task$data(), type = "class") + 1L PredictionClust$new(task = task, partition = partition) diff --git a/R/LearnerClustFanny.R b/R/LearnerClustFanny.R index cc1756c6..1d382450 100644 --- a/R/LearnerClustFanny.R +++ b/R/LearnerClustFanny.R @@ -57,6 +57,7 @@ LearnerClustFanny = R6Class("LearnerClustFanny", return(m) }, + .predict = function(task) { warn_prediction_useless(self$id) diff --git a/R/LearnerClustFarthestFirst.R b/R/LearnerClustFarthestFirst.R index 8cf3da58..de8677cd 100644 --- a/R/LearnerClustFarthestFirst.R +++ b/R/LearnerClustFarthestFirst.R @@ -59,6 +59,7 @@ LearnerClustFarthestFirst = R6Class("LearnerClustFF", return(m) }, + .predict = function(task) { partition = predict(self$model, newdata = task$data(), type = "class") + 1L PredictionClust$new(task = task, partition = partition) diff --git a/R/LearnerClustFeatureless.R b/R/LearnerClustFeatureless.R index d477c8f9..b6774569 100644 --- a/R/LearnerClustFeatureless.R +++ b/R/LearnerClustFeatureless.R @@ -55,6 +55,7 @@ LearnerClustFeatureless = R6Class("LearnerClustFeatureless", "clust.featureless_model" ) }, + .predict = function(task) { pv = self$param_set$get_values(tags = "predict") n = task$nrow diff --git a/R/LearnerClustHclust.R b/R/LearnerClustHclust.R index de31a14c..4c2361d7 100644 --- a/R/LearnerClustHclust.R +++ b/R/LearnerClustHclust.R @@ -69,6 +69,7 @@ LearnerClustHclust = R6Class("LearnerClustHclust", return(m) }, + .predict = function(task) { if (self$param_set$values$k > task$nrow) { stopf("`k` needs to be between 1 and %i", task$nrow) diff --git a/R/LearnerClustKKMeans.R b/R/LearnerClustKKMeans.R index 5caaf8e7..a5cc9ca5 100644 --- a/R/LearnerClustKKMeans.R +++ b/R/LearnerClustKKMeans.R @@ -72,6 +72,7 @@ LearnerClustKKMeans = R6Class("LearnerClustKKMeans", } return(m) }, + .predict = function(task) { # all of predict is taken from mlr2 diff --git a/R/LearnerClustKMeans.R b/R/LearnerClustKMeans.R index e472cfca..3745f3bb 100644 --- a/R/LearnerClustKMeans.R +++ b/R/LearnerClustKMeans.R @@ -47,6 +47,7 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans", ) } ), + private = list( .train = function(task) { if ("nstart" %in% names(self$param_set$values)) { @@ -65,6 +66,7 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans", return(m) }, + .predict = function(task) { partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids")) PredictionClust$new(task = task, partition = partition) diff --git a/R/LearnerClustMclust.R b/R/LearnerClustMclust.R index cb8c54e7..4ba26380 100644 --- a/R/LearnerClustMclust.R +++ b/R/LearnerClustMclust.R @@ -53,6 +53,7 @@ LearnerClustMclust = R6Class("LearnerClustMclust", return(m) }, + .predict = function(task) { predictions = predict(self$model, newdata = task$data()) partition = as.integer(predictions$classification) diff --git a/R/LearnerClustMeanShift.R b/R/LearnerClustMeanShift.R index 56425f9b..f8e1e454 100644 --- a/R/LearnerClustMeanShift.R +++ b/R/LearnerClustMeanShift.R @@ -62,6 +62,7 @@ LearnerClustMeanShift = R6Class("LearnerClustMeanShift", return(m) }, + .predict = function(task) { warn_prediction_useless(self$id) partition = as.integer(self$model$cluster.label) diff --git a/R/LearnerClustMiniBatchKMeans.R b/R/LearnerClustMiniBatchKMeans.R index f9d8a953..a7afff39 100644 --- a/R/LearnerClustMiniBatchKMeans.R +++ b/R/LearnerClustMiniBatchKMeans.R @@ -78,6 +78,7 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans", return(m) }, + .predict = function(task) { if (self$predict_type == "partition") { partition = unclass(ClusterR::predict_MBatchKMeans( diff --git a/R/LearnerClustPAM.R b/R/LearnerClustPAM.R index 3e6568bc..2fef1281 100644 --- a/R/LearnerClustPAM.R +++ b/R/LearnerClustPAM.R @@ -73,6 +73,7 @@ LearnerClustPAM = R6Class("LearnerClustPAM", return(m) }, + .predict = function(task) { partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids")) PredictionClust$new(task = task, partition = partition) diff --git a/R/LearnerClustSimpleKMeans.R b/R/LearnerClustSimpleKMeans.R index c9b007fd..a6d2426a 100644 --- a/R/LearnerClustSimpleKMeans.R +++ b/R/LearnerClustSimpleKMeans.R @@ -73,6 +73,7 @@ LearnerClustSimpleKMeans = R6Class("LearnerClustSimpleKMeans", return(m) }, + .predict = function(task) { partition = predict(self$model, newdata = task$data(), type = "class") + 1L PredictionClust$new(task = task, partition = partition) diff --git a/R/LearnerClustXMeans.R b/R/LearnerClustXMeans.R index aa921804..53a23408 100644 --- a/R/LearnerClustXMeans.R +++ b/R/LearnerClustXMeans.R @@ -63,6 +63,7 @@ LearnerClustXMeans = R6Class("LearnerClustXMeans", return(m) }, + .predict = function(task) { partition = predict(self$model, newdata = task$data(), type = "class") + 1L PredictionClust$new(task = task, partition = partition)