Skip to content

Commit

Permalink
refactor: add back newline between train and predict methods
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Feb 5, 2024
1 parent b480dcb commit 05b4573
Show file tree
Hide file tree
Showing 20 changed files with 21 additions and 0 deletions.
1 change: 1 addition & 0 deletions R/LearnerClustAffinityPropagation.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ LearnerClustAP = R6Class("LearnerClustAP",

return(m)
},

.predict = function(task) {
sim_func = self$param_set$values$s
exemplar_data = attributes(self$model)$exemplar_data
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustAgnes.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ LearnerClustAgnes = R6Class("LearnerClustAgnes",

return(m)
},

.predict = function(task) {
if (self$param_set$values$k > task$nrow) {
stopf("`k` needs to be between 1 and %i", task$nrow)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustCMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ LearnerClustCMeans = R6Class("LearnerClustCMeans",

return(m)
},

.predict = function(task) {
partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids"))
prob = unclass(cl_predict(self$model, newdata = task$data(), type = "memberships"))
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustCobweb.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ LearnerClustCobweb = R6Class("LearnerClustCobweb",

return(m)
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustDBSCAN.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ LearnerClustDBSCAN = R6Class("LearnerClustDBSCAN",

return(m)
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), self$model$data)
PredictionClust$new(task = task, partition = partition)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustDBSCANfpc.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ LearnerClustDBSCANfpc = R6Class("LearnerClustDBSCANfpc",

return(m)
},

.predict = function(task) {
partition = as.integer(predict(self$model, data = self$model$data, newdata = task$data()))
PredictionClust$new(task = task, partition = partition)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustDiana.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ LearnerClustDiana = R6Class("LearnerClustDiana",

return(m)
},

.predict = function(task) {
if (test_true(self$param_set$values$k > task$nrow)) {
stopf("`k` needs to be between 1 and %s", task$nrow)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustEM.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ LearnerClustEM = R6Class("LearnerClustEM",

return(m)
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustFanny.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ LearnerClustFanny = R6Class("LearnerClustFanny",

return(m)
},

.predict = function(task) {
warn_prediction_useless(self$id)

Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustFarthestFirst.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ LearnerClustFarthestFirst = R6Class("LearnerClustFF",

return(m)
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ LearnerClustFeatureless = R6Class("LearnerClustFeatureless",
"clust.featureless_model"
)
},

.predict = function(task) {
pv = self$param_set$get_values(tags = "predict")
n = task$nrow
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustHclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ LearnerClustHclust = R6Class("LearnerClustHclust",

return(m)
},

.predict = function(task) {
if (self$param_set$values$k > task$nrow) {
stopf("`k` needs to be between 1 and %i", task$nrow)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustKKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ LearnerClustKKMeans = R6Class("LearnerClustKKMeans",
}
return(m)
},

.predict = function(task) {
# all of predict is taken from mlr2

Expand Down
2 changes: 2 additions & 0 deletions R/LearnerClustKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans",
)
}
),

private = list(
.train = function(task) {
if ("nstart" %in% names(self$param_set$values)) {
Expand All @@ -65,6 +66,7 @@ LearnerClustKMeans = R6Class("LearnerClustKMeans",

return(m)
},

.predict = function(task) {
partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids"))
PredictionClust$new(task = task, partition = partition)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustMclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ LearnerClustMclust = R6Class("LearnerClustMclust",

return(m)
},

.predict = function(task) {
predictions = predict(self$model, newdata = task$data())
partition = as.integer(predictions$classification)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustMeanShift.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ LearnerClustMeanShift = R6Class("LearnerClustMeanShift",

return(m)
},

.predict = function(task) {
warn_prediction_useless(self$id)
partition = as.integer(self$model$cluster.label)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustMiniBatchKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans",

return(m)
},

.predict = function(task) {
if (self$predict_type == "partition") {
partition = unclass(ClusterR::predict_MBatchKMeans(
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustPAM.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ LearnerClustPAM = R6Class("LearnerClustPAM",

return(m)
},

.predict = function(task) {
partition = unclass(cl_predict(self$model, newdata = task$data(), type = "class_ids"))
PredictionClust$new(task = task, partition = partition)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustSimpleKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ LearnerClustSimpleKMeans = R6Class("LearnerClustSimpleKMeans",

return(m)
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
Expand Down
1 change: 1 addition & 0 deletions R/LearnerClustXMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ LearnerClustXMeans = R6Class("LearnerClustXMeans",

return(m)
},

.predict = function(task) {
partition = predict(self$model, newdata = task$data(), type = "class") + 1L
PredictionClust$new(task = task, partition = partition)
Expand Down

0 comments on commit 05b4573

Please sign in to comment.